原理
registry
在了解了 create_model
函数的基本使用之后,我们来深入探索一下 create_model
函数的源码,看一下究竟是怎样实现从模型到特征提取器的转换的
https://blog.csdn.net/weixin_44966641/article/details/121364784
在 timm 内部,有一个字典称为 _model_entrypoints
包含了所有的模型名称和他们各自的函数。比如说,我们可以通过 model_entrypoint
函数从 _model_entrypoints
内部得到 xception71
模型的构造函数。
如我们所见,在 timm.models.xception_aligned
模块中有一个函数称为 xception71
。类似的,timm 中的每一个模型都有着一个这样的构造函数。事实上,内部的 _model_entrypoints
字典大概长这个样子:
view code
不适合放代码,都锁在一起了在 timm 对应的模块中,每个模型都有一个构造器。比如说 ResNets 系列模型被定义在 timm.models.resnet
模块中。因此,实际上我们有两种方式来创建一个 resnet34
模型
import timm from timm.models.resnet import resnet34 # 使用 create_model m = timm.create_model(‘resnet34’) # 直接调用构造函数 m = resnet34()
但使用上,我们无须调用构造函数。所用模型都可以通过 create_model
函数来将创建。
画图,话一个相互之间构成联系的图
whiteboard
boardmix
0521
create_dataset
Dataset factory method 创建数据集,返回一个数据集对象
支持加载的数据集
1.folder类型:基于文件夹结构组织的数据集(比如说imagenet)
2.torch类型:基于torchvision的数据集(cifar10/100/MNIST/places365)
3.TFDS类型:通过IterableImageDataset在IterabeDataset接口中的Tensorflow数据集包装器
源码
1 | def create_dataset( |
看的头大?没事,下面带你一个一个看,给出解释和建议的参数,一般也用不到这么多参数啦,按照参数重要性讲解
参数
name 数据集名称(重要)
数据集名称
对于基于文件夹的数据集(folder)可以为空
有坑,如果是想加载cifar10,则name需要写为torch/CIFAR10
这也是我为什么不太喜欢封装的太深的
root:数据集的根文件夹(必需)
数据集所在位置根目录
如果没有该数据集,启用download参数会自动下载文件夹在该路径
split 划分数据集(必需-默认为val)
train 和 val的数据集都由这个create_dataset函数确立,那么如何区分呢,重点就是这个参数。
不同数据集有不同的划分,这里不用管,他有一个字典。不管你输入是train还是training,都会自动识别,并根据数据集给出正确的划分
val/valid/validation/eval/evaluation也一样
download 自动下载数据集(重要-torch/TFDS类型)
默认为Fasle,第一次的话可以指定为True自动下载数据集
如果之前指定的root路径中没有数据集
则自动下载数据集到root目录,仅支持torch类型和TFDS类型
imagenet需要自己下载
cifar10我试了一下,大概20min下好,之后就可以快了的使用了
_search_split 自主找到划分(folder)
默认为True不用改
如果是folder类型,则调用_search_split()
函数
从根目录中搜索拆分特定的子文件夹,以便指定imagenet/
而不是/imagenet/val
等
class_map 索引映射(folder)
把输出的tensor映射成实际的类别
一般不用,不用管,默认为None
load_bytes(folder)
加载数据,以未编码字节的形式返回图像
一般不用,不用管
batch_size/is_training/repeats(TFDS类型)
一般用不到,不管了
0521-create_loader
源码
def create_loader(
dataset,
input_size,
batch_size,
is_training=False,
use_prefetcher=True,
no_aug=False,
re_prob=0.,
re_mode=’const’,
re_count=1,
re_split=False,(re_num_splits = 0 if re_split: re_num_splits = num_aug_splits or 2)
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
**num_aug_repeats=0,**(re_num_splits = 0 if re_split: re_num_splits = num_aug_splits or 2)
**num_aug_splits=0,**(separate=num_aug_splits > 0,)
interpolation=’bilinear’,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
distributed=False,
crop_pct=None,
collate_fn=None,
pin_memory=False,
fp16=False,
tf_preprocessing=False,
use_multi_epochs_loader=False,
persistent_workers=True,
worker_seeding=’all’,
):
0522-create_transform
def create_transform(
input_size,
is_training=False,
use_prefetcher=False,
no_aug=False,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
interpolation=’bilinear’,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode=’const’,
re_count=1,
re_num_splits=0,
crop_pct=None,
tf_preprocessing=False,
separate=False):
dataloader内部代码中,实际还是基于torch的加载
loader_class = torch.utils.data.DataLoader
loader = loader_class(dataset, **loader_args)
(26条消息) timm.data.create_transform_alien丿明天的博客-CSDN博客
(26条消息) 【transformer】【pytorch】DeiT的数据增强_create_transform_剑宇2022的博客-CSDN博客
0522-
#这里加载cifar数据集
#3套方法timm/ torch/加载自己数据集
#基于torch在这个时候就transformer了(getitem)/timm的在loader的时候才(dataset.transform = create_transform)
#和分布式 pin memory
训练
pytorch训练步骤总结 - 知乎 (zhihu.com)
cs230-code-examples/pytorch/nlp at master · cs230-stanford/cs230-code-examples · GitHub
cs230-code-examples/train.py at master · cs230-stanford/cs230-code-examples · GitHub
深度学习pytorch训练代码模板(个人习惯) - 知乎 (zhihu.com)
深度学习pytorch训练代码模板(个人习惯) - 知乎 (zhihu.com)
(6 封私信 / 80 条消息) 如何提高自己的代码能力以达到熟练使用pytorch? - 知乎 (zhihu.com)
(27条消息) 【Pytorch】学习笔记(训练代码框架)_Chaossll的博客-CSDN博客
(27条消息) Pytorch模型训练&保存/加载(搭建完整流程)_pytorch训练模型_Huterox的博客-CSDN博客