cabbage

cabbage

菜鸟写给小白的教程

0%

timm库(CV利器)的入门教程

你是否想过用深度学习来解决一些视觉问题,比如图像分类、目标检测、人脸识别等?

你是否觉得深度学习的模型太多太复杂,不知道该如何选择和使用?

你是否曾经为了复现一篇论文的结果而苦苦寻找代码和权重,却发现官方没有提供或者提供的不完整?

如果你有以上的困惑,那么我要向你推荐一个神奇的库:timm。

啥是timm?

timm 是 PyTorch Image Models 的缩写

is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations and also training/validating scripts with ability to reproduce ImageNet training results.

timm 库实现了最新的几乎所有的具有影响力的视觉模型,它不仅提供了模型的权重,还提供了一个很棒的分布式训练和评估的代码框架,方便后人开发。

timm 库是由 Ross Wightman 开发和维护的

项目地址:https://github.com/huggingface/pytorch-image-models

写文原因:

我之所以写这篇文章,主要是出于以下几点:

  1. 深度学习在各个领域都取得了令人惊叹的成就,比如自然语言处理、计算机视觉、AIGC等。作为一个对人工智能感兴趣并且想要从事相关工作或研究的人,我认为有必要掌握深度学习相关知识,并且跟进最新进展。

  2. 互联网是一个开放、自由、分享、进步 的平台。我非常欣赏并支持开源社区中那种无私奉献、相互帮助、共同成长的氛围。我也希望通过我的文章来分享我所学到或者感兴趣的东西,并且与读者们交流心得。

  3. 当我刚开始接触深度学习时,我遇到了很多困难和挫折。比如找不到合适的教程,网上信息良莠不齐,解决一个问题要搞好久。虽然我现在也还是菜鸡hh,但是我想把自己的学习经历和心得分享给那些想要入门CV的小白们,希望能够对你们有所帮助。同时,也欢迎各位大佬或者有兴趣的同学给我提出宝贵的意见和建议,让我们一起进步。

  4. 本来是想写完发,但是可能又会拖延,就先发出来一点,欢迎大家监督。可以加我的QQ:7914675,一起交流入门心得

  5. 还有一个原因,不过先容我卖个关子,等到这个timm的系列完结再说。

本系列的目录结构和跳转链接

如下:

0.引言

1.简单的使用

2.训练自己的

3.pth下载到哪里了

4.修改模型 获取特征

5.create _model代码解读以及创建自己的模型

5.create_transform以及resolve_data_config代码解读

后续更新计划: 1.简单的使用model 2.训练自己的model 3.pth下载到哪里了 4.修改model 获取中间特征 5.create _model代码解读以及创建自己的模型 6.create_transform以及resolve_data_config代码解读 7.及以后 更多的tricks

帖子代码基于 timm0.54

之后放出codebook

wx

https://github.com/huggingface/pytorch-image-models

https://huggingface.co/docs/hub/timm

https://huggingface.co/docs/timm/index

https://timm.fast.ai/training

https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055

https://zhuanlan.zhihu.com/p/404107277

https://pytorch.org/tutorials/

optimzer

传统的训练函数,一个batch是这么训练的:

  1. 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  2. optimizer.zero_grad() 清空过往梯度;
  3. loss.backward() 反向传播,计算当前梯度;
  4. optimizer.step() 根据梯度更新网络参数

一、optimizer的基本参数
defaults:优化器超参数
state:参数的缓存,如momentum的缓存
param_groups:管理的参数组
_step_count:记录更新次数,学习率调整中使用(比如要在第n次迭代后降低学习率)

二、optimizer的基本使用方法
zero_grad():清空所有参数的梯度,因为张量不会自动清零会自动累加,在梯度求导之前需要进行梯度清零
step():进行一步更新
add_param_groups():添加一组参数
state_dict():获取优化器当前状态信息字典
load_state_dict():加载状态信息字典
————————————————
https://blog.csdn.net/m0_55769743/article/details/122717287

https://blog.csdn.net/m0_37628604/article/details/121630836
四个函数作用是先将梯度值归零(optimizer.zero_grad()),然后反向传播计算每个参数的梯度值(loss.backward()),通过梯度下降进行参数更新(optimizer.step()),最后根据opeoch训练轮数更新学习率(lr_scheduler.step())。

接下来通过源码对四个函数进行分析。在此之前说明函数中常见的参数变量。
param_groups:Optimizer类在实例化时会创建一个param_groups列表,列表中有num_groups(num_groups取决于你定义optimizer时传入了几组参数)个长度为6的param_group字典,每个param_group包含了[‘param’,‘lr’,‘momentum’,‘dampening’,‘weight_decay’,‘nesterov’]这6组键值对。
params(iterable)—待优化参数w、b 或者定义了参数组的dict
lr(float,可选)—学习率
momentum(float,可选,默认0)—动量因子
weight_decay(float,可选,默认0)—权重衰减
dampening (float, 可选) – 动量的抑制因子(默认:0)
nesterov (bool, 可选) – 使用Nesterov动量(默认:False)
param_group[‘param’]:由传入的模型参数组成的列表,即实例化Optimizer类时传入该group的参数,如果参数没有分组,则为整个模型的参数model.parameters(),每个参数是一个torch.nn.parameter.Parameter对象。

https://blog.csdn.net/qq_40379132/article/details/124573071 -仔细看看

https://blog.csdn.net/weixin_43863869/article/details/128120719

梯度累计

简单的说就是进来一个batch的数据,计算一次梯度,更新一次网络
使用梯度累加是这么写的
获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
loss.backward() 反向传播,计算当前梯度;
多次循环步骤1-2,不清空梯度,使梯度累加在已有梯度上;
梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;

一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize ‘变相’ 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。

对模型参数的梯度置0时通常使用两种方式:model.zero_grad()optimizer.zero_grad()。二者在训练代码都很常见,那么二者的区别在哪里呢?

https://www.cnblogs.com/chaofengya/p/16925125.html
转自 https://cloud.tencent.com/developer/article/1710864

使用

https://blog.csdn.net/lishanlu136/article/details/121284421

auto.grad

https://blog.csdn.net/qq_39208832/article/details/117415229

torch.Tensor 是这个包的核心类。如果设置它的属性 .requires_gradTrue,那么它将会追踪对于该张量的所有操作。当完成计算后可以通过调用 .backward(),来自动计算所有的梯度。这个张量的所有梯度将会自动累加到.grad属性.

要阻止一个张量被跟踪历史,可以调用 .detach() 方法将其与计算历史分离,并阻止它未来的计算记录被跟踪。为了防止跟踪历史记录(和使用内存),可以将代码块包装在 with torch.no_grad(): 中。在评估模型时特别有用,因为模型可能具有 requires_grad = True 的可训练的参数,但是我们不需要在此过程中对他们进行梯度计算。

https://blog.csdn.net/qq_53345829/article/details/124308515

pytorch加载图片数据集有两种方法。

orch能处理的数据只能是torch.Tensor,所以有必要将其他数据转换为torch.Tensor。

常见的有几种数据:

np.ndarray
PIL.Image

https://blog.csdn.net/a19990412/article/details/105402341/

创建自己的数据集需要继承父类torch.utils.data.Dataset,同时需要重载两个私有成员函数:def len(self)和def getitem(self, index) 。 def len(self)应该返回数据集的大小;def getitem(self, index)接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。
https://blog.csdn.net/Vertira/article/details/127482001

https://blog.csdn.net/qq_38683460/article/details/123306447

https://blog.csdn.net/zwy_697198/article/details/123561769

https://blog.csdn.net/qq_41140138/article/details/127084076

Image.open(ImgPath)

https://blog.csdn.net/weixin_43723625/article/details/108158375

3.DataLoader

提供对Dataset的操作,操作如下:

1
2
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
1

参数含义如下:

  • d a t a s e t \color{HotPink}{dataset}datase**t: 加载torch.utils.data.Dataset对象数据
  • b a t c h _ s i z e \color{HotPink}{batch_size}batc**h_size: 每个batch的大小
  • s h u f f l e \color{HotPink}{shuffle}shuffl**e:是否对数据进行打乱
  • d r o p _ l a s t \color{HotPink}{drop_last}drop_las**t:是否对无法整除的最后一个datasize进行丢弃
  • n u m _ w o r k e r s \color{HotPink}{num_workers}num_work**ers:表示加载的时候子进程数

因此,在实现过程中我们测试如下(紧跟上述用例):

1
2
3
4
5
from torch.utils.data import DataLoader

# 读取数据
datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)
1234

此时,我们的数据已经加载完毕了,只需要在训练过程中使用即可。

4.查看数据

我们可以通过迭代器(enumerate)进行输出数据,测试如下:

1
2
3
for i, data in enumerate(datas):
# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
print("第 {} 个Batch \n{}".format(i, data))

相互转换

使用pandas,可以写入到csv或者xlsx格式文件

import pandas as pd

result_list = [[‘1’, 1, 1], [‘2’, 2, 2], [‘3’, 3, 3]]

columns = [“URL”, “predict”, “score”]

dt = pd.Dataframe(result_list, columns=columns)

dt.to_excel(“result_xlsx.xlsx”, index=0)

dt.to_csv(“result_csv.csv”, index=0)

excel文件转csv

https://blog.csdn.net/weixin_44416114/article/details/123894189

平时开发时文件读写都是csv比较简单方便,不过有时需要给别人提供excel,或者别人提供excel给自己,那么csv和excel的互转工具就十分有必要写一个了

下面代码保存一个py文件,比如 csvtool.py,放到csv或者excel所在的文件夹,运行即可

https://blog.csdn.net/weixin_40476348/article/details/120345498

https://blog.csdn.net/weixin_42171682/article/details/120431054

使用python读取和保存为excel、csv、txt文件以及对DataFrame文件的基本操作

https://blog.csdn.net/weixin_45928096/article/details/124034946

pandas

通用啊,前面的了解了解就行了

真正格式转换,读取数据,还是这个好用

https://blog.csdn.net/npm_run_dev__/article/details/125881177DataFrame对象的结构

https://blog.csdn.net/Xiaoyuteacher/article/details/117993785

1、os获取文件夹下所有文件

方法一:读取文件夹下所有文件名(不读取子文件夹下文件名,顺序为乱序)

1
2
3
4
5
6
7
8
9
#采用os.listdir()函数
def listdir(path): #path为文件夹的存储路径
list_name=[] #保存所有文件名至列表
list_dir = os.listdir(path)
# list_dir.sort(key=lambda x:int(x[:-5])) #若文件名无中文,可对列表进行排序
for file in list_dir: #由于子文件名后缀为.xlsx,因此为-5
file_path = os.path.join(path, file)
list_name.append(file_path)
return list_name

方法二:读取文件夹下指定后缀文件名(包括子文件下文件名,顺序为乱序)

1
2
3
4
5
6
7
8
9
10
#采用os.walk()方法,读取文件夹中后缀为.csv的文件名
def listdir2(path): #path为文件夹的存储路径
type = ('.csv') #可在此处更改文件类型
list_name = [] #保存所有文件名至列表
for root, dirs, files in os.walk(path):
for f in files:
fname = os.path.join(root, f)
if fname.endswith(type):
list_name.append(fname)
return list_name

2、读取多个excel内容存到list中

1
2
3
4
5
6
7
8
9
#函数的输入为列表,存储着多张excel的文件位置
#函数的输出为二维列表,保存着多张excel的内容
def textList(list_name): #list_name为多张表的位置信息列表
text_excel = [] #text_excel为存储多张表的二维列表
for num in list_name:
excel = pd.read_excel(num, header=None, skiprows=0)
excel = excel.values.tolist()
text_excel.extend(excel)
return text_excel

3、将二维列表的值写入excel中

1
2
3
4
5
6
7
8
9
10
11
12
#利用openpyxl模块实现,还有一种常用的是利用pandas模块,后续进行介绍
def writeToExcel(file_path, new_list): #file_path为excel存储路径,new_list为二维列表
wb = openpyxl.Workbook()
ws = wb.active
ws.title = 'feature'
for r in range(len(new_list)):
for c in range(len(new_list[0])):
ws.cell(r + 1, c + 1).value = new_list[r][c]
# excel中的行和列是从1开始计数的,所以需要+1
wb.save(file_path) # 注意,写入后一定要保存
print("成功写入文件: " + file_path + " !")
return 1

4、批量更改excel文件名

1
2
3
4
5
6
7
8
#批量修改文件中excel文件名,使用到1中listdir()函数
def rename_excel(path, text): #path为存储文件夹,text为替换的新内容
list_name = listdir(path)
for excel_name in list_name:
new_name = excel_name[:-10] + text + excel_name[-8:]
os.rename(excel_name, new_name)
return list_name
rename_excel(r'D:\test', '27')

执行效果,文件夹中excel命名如下:
在这里插入图片描述
批量将“26”改为“27”,程序执行结果为:
在这里插入图片描述

5、给csv添加第一列,第一列的值为日期

1
2
3
4
5
6
7
#给csv添加一列,该列表示时间,全部设为2021-04-26
data = pd.read_csv(r'test.csv',sep=',', header=None,skiprows=1,encoding="gbk")
col_name = data.columns.tolist() #读取列名,并放入list中
col_name.insert(0,'DATE') #给list中插入第一列
new_data = data.reindex(columns=col_name) #列名重命名
new_data['DATE']='2021-04-26' #给第一列赋值
new_data.to_csv('new_test.csv',index=0) #写入新的csv文件

6、批量修改csv中的值,并写入新的csv中

1
2
3
4
5
6
7
8
9
10
11
12
#读取csv的值,第一列为时间类型,大小不变,对其余列生成随机数字
data = pd.read_csv(path, sep=',',header=None, skiprows=1, encoding="gbk") #原始数据
col_name = list(data) #返回列名
new_col = col_name[1:] #取除第一列之外其余列的名称
col = data.shape[1] #返回列数
index = data.shape[0] #返回行数
#给其余列随机生成[0-100]之间的数字
new_data_randam = pd.DataFrame(np.random.randint(0,100,size=(index, col-1)), columns = new_col)
new_col_name = new_data_randam.columns.tolist() #生成新的列名
new_col_name.insert(0, 'DATE') #插入日期为第一列
new_test_data = new_data_randam.reindex(columns=new_col_name) #列名重命名
new_test_data['DATE'] = '2021-04-26' #给第一列进行赋值

7、csv每隔10行取1行,并写入新的csv中

1
2
3
4
5
6
7
8
#读取csv的内容,每10条进行抽样,path为csv文件位置,new_path为新的存储位置
data = pd.read_csv(path, sep=',', header=None, skiprows=1, encoding="gbk")
sample = [] #存储抽样的行数
for i in range(0, len(data), 10): #每隔10行取数据
sample.append(i)
new_data = data.iloc[sample] #根据行数对数据进行抽样
new_data = pd.DataFrame(new_data)
new_data.to_csv(new_path, index=0, header=0) # 写入新的csv文件

8、csv按列求均值、最大值、最小值,并写入新的csv中

1
2
3
4
5
6
7
8
9
10
11
12
#利用dataframe的.mean().max().min()函数分别求各列的均值、最大值和最小值
#path为csv的存储路径,new_path_mean为均值存储路径,new_path_max为最大值存储路径,new_path_min为最小值存储路径
data = pd.read_csv(path, sep=',', header=None, skiprows=1, encoding="gbk")
data_mean = data.mean(axis=0)) #均值,axis=1时,为对行求均值
data_max = data.max(axis=0)) #最大值
data_min = data.min(axis=0))) #最小值
# 均值写入新的csv文件
data_mean.to_csv(new_path_mean, index=0, header=0, mode='a') # mode='a'为追加模式
# 最大值写入新的csv文件
data_max.to_csv(new_path_max, index=0, header=0, mode='a') #mode='w'为写模式
# 最小值写入新的csv文件
data_min.to_csv(new_path_min, index=0, header=0, mode='a')

excel文件

工作簿:一个excel文件就是一个工作簿

工作表:一个工作簿中可以有多个工作表(至少一个)

单元格:单元格是excel文件保存数据的基本单位

行号和列号:可以确定单元格位置

import openpyxl

excel读操作

1)打开excel文件创建工作簿对象

openpyxl.open(excel文件路径)

openpyxl.load_workbook(excel文件路径)

1
2
3
4
5
import openpyxl

workbook = openpyxl.open('excel_file/三国人物数据.xlsx')
# 或者
workbook = openpyxl.load_workbook('excel_file/三国人物数据.xlsx')

2)获取工作表

工作簿对象.active - 获取活跃表(选中的表)

工作簿对象[工作表名称] - 获取指定名字对应的工作表

1
2
3
4
5
sheet1 = workbook.active
print(sheet1)

sheet2 = workbook['工作表名称']
print(sheet2)

如果要获取工作簿中所有的工作表表名

workbook.sheetnames

1
2
result = workbook.sheetnames
print(result)

3)获取单元格

工作表对象.cell(行号, 列号)

1
2
3
cell1 = sheet2.cell(8, 1)
cell2 = sheet2.cell(12, 1)
print(cell1, cell2)

4)获取单元格内容

单元格对象.value

1
2
print(cell1.value)
print(cell2.value)

5)获取最大行号和最大列号(保存了数据的有效行和有效列)

工作表对象.max_row

工作表对象.max_column

1
2
print(sheet2.max_row)
print(sheet2.max_column)

excel文件写操作

1
注意:不管是以什么样的方式对excel进行写操作,操作完成之后必须保存

(一)新建工作簿

1)新建工作簿对象

openpyxl.Workbook()

例如:

1
workbook = openpyxl.Workbook()

2)保存

工作簿对象.save(文件路径)

实际中新建工作簿的时候需要先判断工作簿对应的文件是否已经存在,存在就不需要新建,不存在才新建

法一

1
2
3
4
5
try:
workbook = openpyxl.open('excel_file/student2.xlsx')
except FileNotFoundError:
workbook = openpyxl.Workbook()
workbook.save('excel_file/student2.xlsx')

法二

os.path.exists(文件路径) - 判断指定文件是否存在,存在返回True,不存在返回False

1
2
3
4
5
6
7
import os

if os.path.exists('excel_file/student2.xlsx'):
workbook = openpyxl.open('excel_file/student2.xlsx')
else:
workbook = openpyxl.Workbook()
workbook.save('excel_file/student2.xlsx')

(二)工作表的写操作

1)新建工作表

工作簿对象.create_sheet(表名, 下标)

实际中的新建表:没有的时候才新建,有的时候直接打开

1
2
3
4
5
if 'Python' in workbook.sheetnames:
sheet = workbook['Python']
else:
sheet = workbook.create_sheet('Python')
workbook.save('excel_file/student2.xlsx')

2)删除工作表

工作簿对象.remove(工作表对象)

1
2
workbook.remove(workbook['Sheet1'])
workbook.save('excel_file/student2.xlsx')

实际中删除表:存在的时候才能删

1
2
3
if 'Sheet1' in workbook.sheetnames:
workbook.remove(workbook['Sheet1'])
workbook.save('excel_file/student2.xlsx')

(三)单元格的写操作

单元格对象.value = 数据

1
2
3
4
5
6
java_sheet = workbook['Java']
java_sheet.cell(1, 3).value = '电话'
java_sheet.cell(2, 1).value = None
java_sheet.cell(4, 2).value = 'stu003'

workbook.save('excel_file/student2.xlsx')

原理

registry

在了解了 create_model 函数的基本使用之后,我们来深入探索一下 create_model 函数的源码,看一下究竟是怎样实现从模型到特征提取器的转换的

https://blog.csdn.net/weixin_44966641/article/details/121364784

在 timm 内部,有一个字典称为 _model_entrypoints 包含了所有的模型名称和他们各自的函数。比如说,我们可以通过 model_entrypoint 函数从 _model_entrypoints 内部得到 xception71 模型的构造函数。

如我们所见,在 timm.models.xception_aligned 模块中有一个函数称为 xception71 。类似的,timm 中的每一个模型都有着一个这样的构造函数。事实上,内部的 _model_entrypoints 字典大概长这个样子:

view code 不适合放代码,都锁在一起了

在 timm 对应的模块中,每个模型都有一个构造器。比如说 ResNets 系列模型被定义在 timm.models.resnet 模块中。因此,实际上我们有两种方式来创建一个 resnet34 模型

import timm from timm.models.resnet import resnet34 # 使用 create_model m = timm.create_model(‘resnet34’) # 直接调用构造函数 m = resnet34()

但使用上,我们无须调用构造函数。所用模型都可以通过 create_model 函数来将创建。

画图,话一个相互之间构成联系的图

whiteboard

boardmix

0521

create_dataset

Dataset factory method 创建数据集,返回一个数据集对象

支持加载的数据集

1.folder类型:基于文件夹结构组织的数据集(比如说imagenet)

2.torch类型:基于torchvision的数据集(cifar10/100/MNIST/places365)

3.TFDS类型:通过IterableImageDataset在IterabeDataset接口中的Tensorflow数据集包装器

源码

1
2
3
4
5
6
7
8
9
10
11
12
13
def create_dataset(
name,
root,
split='validation',
search_split=True,
class_map=None,
load_bytes=False,
is_training=False,
download=False,
batch_size=None,
repeats=0,
**kwargs
):

看的头大?没事,下面带你一个一个看,给出解释和建议的参数,一般也用不到这么多参数啦,按照参数重要性讲解

参数

name 数据集名称(重要)

数据集名称

对于基于文件夹的数据集(folder)可以为空

有坑,如果是想加载cifar10,则name需要写为torch/CIFAR10 这也是我为什么不太喜欢封装的太深的

root:数据集的根文件夹(必需)

数据集所在位置根目录

如果没有该数据集,启用download参数会自动下载文件夹在该路径

split 划分数据集(必需-默认为val)

train 和 val的数据集都由这个create_dataset函数确立,那么如何区分呢,重点就是这个参数。

不同数据集有不同的划分,这里不用管,他有一个字典。不管你输入是train还是training,都会自动识别,并根据数据集给出正确的划分

val/valid/validation/eval/evaluation也一样

download 自动下载数据集(重要-torch/TFDS类型)

默认为Fasle,第一次的话可以指定为True自动下载数据集

如果之前指定的root路径中没有数据集

则自动下载数据集到root目录,仅支持torch类型和TFDS类型

imagenet需要自己下载

cifar10我试了一下,大概20min下好,之后就可以快了的使用了

_search_split 自主找到划分(folder)

默认为True不用改

如果是folder类型,则调用_search_split() 函数

从根目录中搜索拆分特定的子文件夹,以便指定imagenet/而不是/imagenet/val

class_map 索引映射(folder)

把输出的tensor映射成实际的类别

一般不用,不用管,默认为None

load_bytes(folder)

加载数据,以未编码字节的形式返回图像

一般不用,不用管

batch_size/is_training/repeats(TFDS类型)

一般用不到,不管了

0521-create_loader

源码

def create_loader(

​ dataset,

input_size,

​ batch_size,

is_training=False,

use_prefetcher=True,

no_aug=False,

re_prob=0.,

re_mode=’const’,

re_count=1,

​ re_split=False,(re_num_splits = 0 if re_split: re_num_splits = num_aug_splits or 2)

scale=None,

ratio=None,

hflip=0.5,

vflip=0.,

color_jitter=0.4,

auto_augment=None,

​ **num_aug_repeats=0,**(re_num_splits = 0 if re_split: re_num_splits = num_aug_splits or 2)

​ **num_aug_splits=0,**(separate=num_aug_splits > 0,)

interpolation=’bilinear’,

mean=IMAGENET_DEFAULT_MEAN,

std=IMAGENET_DEFAULT_STD,

​ num_workers=1,

​ distributed=False,

crop_pct=None,

​ collate_fn=None,

​ pin_memory=False,

​ fp16=False,

tf_preprocessing=False,

​ use_multi_epochs_loader=False,

​ persistent_workers=True,

​ worker_seeding=’all’,

):

0522-create_transform

def create_transform(

​ input_size,

​ is_training=False,

​ use_prefetcher=False,

​ no_aug=False,

​ scale=None,

​ ratio=None,

​ hflip=0.5,

​ vflip=0.,

​ color_jitter=0.4,

​ auto_augment=None,

​ interpolation=’bilinear’,

​ mean=IMAGENET_DEFAULT_MEAN,

​ std=IMAGENET_DEFAULT_STD,

​ re_prob=0.,

​ re_mode=’const’,

​ re_count=1,

​ re_num_splits=0,

​ crop_pct=None,

​ tf_preprocessing=False,

​ separate=False):

dataloader内部代码中,实际还是基于torch的加载

loader_class = torch.utils.data.DataLoader

loader = loader_class(dataset, **loader_args)

(26条消息) timm.data.create_transform_alien丿明天的博客-CSDN博客

(26条消息) 【transformer】【pytorch】DeiT的数据增强_create_transform_剑宇2022的博客-CSDN博客

0522-

#这里加载cifar数据集
#3套方法timm/ torch/加载自己数据集
#基于torch在这个时候就transformer了(getitem)/timm的在loader的时候才(dataset.transform = create_transform)
#和分布式 pin memory

训练

pytorch训练步骤总结 - 知乎 (zhihu.com)

cs230-code-examples/pytorch/nlp at master · cs230-stanford/cs230-code-examples · GitHub

cs230-code-examples/train.py at master · cs230-stanford/cs230-code-examples · GitHub

深度学习pytorch训练代码模板(个人习惯) - 知乎 (zhihu.com)

深度学习pytorch训练代码模板(个人习惯) - 知乎 (zhihu.com)

(6 封私信 / 80 条消息) 如何提高自己的代码能力以达到熟练使用pytorch? - 知乎 (zhihu.com)

(27条消息) 【Pytorch】学习笔记(训练代码框架)_Chaossll的博客-CSDN博客

(27条消息) Pytorch模型训练&保存/加载(搭建完整流程)_pytorch训练模型_Huterox的博客-CSDN博客

os

listdir()

os.listdir() 函数用于返回指定目录下的文件和文件夹的名称列表。

语法: os.listdir(path) 参数: path : 指定的目录。 返回值: 返回指定目录下的文件和文件夹的名称列表

tqdm

Tqdm 是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。

安装与导入

pip install tqdm

我们能发现使用的核心是tqdm和trange这两个函数,导入这两个

from tqdm import tqdm

from tqdm import trange

其实:trange=tqdm(range())

4种使用方式

tqdm是非常通用的,并且可以以多种方式使用。下面给出三个主要部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# 方法1
for i in tqdm(可迭代对象):
pass

# 方法2
for idx, i in enumerate(tqdm(可迭代对象)):
pass
for idx, i in tqdm(enumerate(可迭代对象)):
pass

# 方法3 而如果想要在迭代过程中变更说明文字,还可以预先实例化进度条对象,在需要刷新说明文字的时候执行相应的程序
p = tqdm(可迭代对象)
for idx, i in enumerate(p):
pass
# 方法4 手动控制更新
with tqdm.tqdm(total=10) as bar: # total为进度条总的迭代次数
# 操作1
time.sleep(1)
# 更新进度条
bar.update(1) # bar.update()里面的数表示更新的次数,和optimizer.step方法类似

# 操作2
time.sleep(2)
# 更新进度条
bar.update(3)

# 操作3
time.sleep(1)
# 更新进度条
bar.update(6) # 建议不要超过total

with tqdm(total=100) as pbar:
for i in range(10):
pbar.update(10)
# 也可以这样
pbar = tqdm(total=100)
for i in range(10):
pbar.update(10)
pbar.close()

3

而如果想要在迭代过程中变更说明文字,还可以预先实例化进度条对象,在需要刷新说明文字的时候执行相应的程序:

tqdm提供了两个个方法:

  1. set_description()
  2. set_postfix()

这两个方法就类似于print,可以在进度条中显示一些变动的信息,在使用set_description()和set_postfix()时一般会创建一个tqdm.tqdm()对象

1
2
3
4
5
6
pbar = tqdm.tqdm(range(epochs), ncols=100)  # ncols设置进度条显示的字符长度,小了就显示不全了

for idx, element in enumerate(pbar):
time.sleep(0.01)
pbar.set_description(f"Epoch {idx}/{epochs}")
pbar.set_postfix({"class": element}, loss=random.random(), cost_time = random.randrange(0, 100))
image-20230328122422753

4

tqdm()的返回值是一个可迭代对象,迭代的每一个元素就是iterable的每一个参数。该返回值可以修改进度条信息。示例

1
2
3
4
5
with tqdm(range(100), desc='Test') as tbar:
for i in tbar:
tbar.set_postfix(loss=i/100, x=i)
tbar.update() # 默认参数n=1,每update一次,进度+n
time.sleep(0.2)

参数

1
2
3
4
@staticmethod
def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False, unit='it',
unit_scale=False, rate=None, bar_format=None, postfix=None,
unit_divisor=1000, initial=0, colour=None, **extra_kwargs):

iterable: 可迭代的对象, 在手动更新时不需要进行设置
desc: 字符串,左边进度条描述,作为进度条标题
total: 预期的迭代次数。一般不填,默认为iterable的长度。
leave: bool值, 迭代完成后是否保留进度条。默认保留。
file: 输出指向位置, 默认是终端, ⼀般不需要设置
ncols: 调整进度条宽度, 默认是根据环境⾃动调节长度, 如果设置为0, 就没有进度条, 只有输出的信息
unit: 描述处理项⽬的⽂字, 默认是it, 例如: 100 it/s, 处理照⽚的话设置为img ,则为 100 img/s
unit_scale: ⾃动根据国际标准进⾏项⽬处理速度单位的换算, 例如 100000 it/s >> 100k it/s

postfix进度条后缀信息以字典形式传入详细信息

colour: 进度条颜色

还有一个复杂的格式参数bar_formmat,见下一节

进度条含义解释与自定义

其中xxxit/s表示每秒迭代的次数,it=iteration(一次迭代)

进度百分比|进度条| 当前迭代数/总迭代个数,[消耗时间<剩余时间,迭代的速度]

其会在每一轮从可迭代对象中取得一个值之后,打印遍历进度条,然后再执行循环中的程序。最后面的速度表示执行一个循环所耗费的时间。

复杂参数bar_formmat

默认格式:'{l_bar}{bar}{r_bar}'

l_bar:进度条左边的文字
bar: 进度条图形部分
r_bar: 进度条右边的文字

bar_format  : str, optional
    Specify a custom bar string formatting. May impact performance.
    [default: '{l_bar}{bar}{r_bar}'], where
    l_bar='{desc}: {percentage:3.0f}%|' and
    r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt,
    percentage, elapsed, elapsed_s, ncols, desc, unit,
    rate, rate_fmt, rate_noinv, rate_noinv_fmt,
    rate_inv, rate_inv_fmt, postfix, unit_divisor,
    remaining, remaining_s.
    Note that a trailing ": " is automatically removed after {desc}
    if the latter is empty.

其具体为:{desc}: 100%|███████████| 10/10 [00:05<00:00, 1.98it/s,{postfix}]

desc、postfix默认为空;

1
2
for i in tqdm(range(10), bar_format='{desc}|{bar}|{percentage:3.0f}%'):
time.sleep(.5)

|█████████████████████|100%

others

嵌套

1
2
3
4
5
tqdm还支持在循环中嵌套使用进度条,例如在for循环中嵌套while循环:
from tqdm import tqdm
for i in tqdm(range(10)):
for j in tqdm(range(100), leave=False):
# do something

这样就会在外层循环中显示一个进度条,在内层循环中显示另一个进度条。参数leave用于控制内层进度条是否在外层进度条完成后消失。

在多线程和多进程中使用

可以通过设置参数desc来给每个进程或线程命名

1
2
3
4
5
6
7
8
9
10
11
from tqdm import tqdm
import multiprocessing

def worker(num):
for i in tqdm(range(1000000), desc=f'Worker {num}'):
pass

if __name__ == '__main__':
with multiprocessing.Pool(4) as p:
p.map(worker, [1, 2, 3, 4])

这样就会在每个进程的进度条上显示进度,并显示进程的名称。

结合pandas的使用

1
2
3
4
5
6
import  pandas as pd
import numpy as np

df = pd.DataFrame(np.random.randint(0, 100, (10000000, 6)))
tqdm.pandas(desc="my bar!")
df.progress_apply(lambda x: x**2)
1
2
3
4
5
6
7
8
9
tqdm对pandas中的apply也提供了支持,用法是:

import pandas as pd
from tqdm.notebook import tqdm
# 每个单独的porgress_apply运行之前一定要先执行tqdm.pandas()
tqdm.pandas()

df=pd.DataFrame({'a', range(10)})
x = df.progress_apply(lambda x: time.sleep(0.2))

减少tqdm进度条在日志文件中的打印频率?

可以使用嵌套循环,其中包含tqdm的循环和内部循环只是一个简单的for循环

1
2
3
for outer in tqdm(range(0,1e5,1e4)):
for inner in range(1e4):
print(outer+inner)

cankao

https://blog.csdn.net/weixin_44878336/article/details/124894210

https://www.cnblogs.com/KuanLez/p/15974970.html

https://blog.csdn.net/weixin_42475060/article/details/121661840

https://blog.csdn.net/qq_41554005/article/details/117297861

https://blog.csdn.net/winter2121/article/details/111356587

https://zhuanlan.zhihu.com/p/378474516

hh