cabbage

cabbage

菜鸟写给小白的教程

0%

create model

https://blog.csdn.net/MengYa_Dream/article/details/126690336-主要参考

The model architectures included come from a wide variety of sources. Sources, including papers, original impl (“reference code”) that Ross rewrote / adapted, and PyTorch impl that he leveraged directly (“code”) are listed below.

Most included models have pretrained weights. The weights are either:

from their original sources
ported by myself from their original impl in a different framework (e.g. Tensorflow models)
trained from scratch using the included training script
The validation results for the pretrained weights can be found here.

最近一年 Vision Transformer 及其相关改进的工作层出不穷,在他们开源的代码中,大部分都用到了这样一个库:timm。各位炼丹师应该已经想必已经对其无比熟悉了,本文将介绍其中最关键的函数之一:create_model 函数

create_model 函数是用来创建一个网络模型(如 ResNet、ViT 等),timm 库本身可供直接调用的模型已有接近400个,用户也可以自己实现一些模型并注册进 timm (这一部分内容将在下一小节着重介绍),供自己调用

源码

create_model

create_model函数是用来在里面创建数百个模型的timm。它还期望一堆**kwargs诸如features_onlyout_indices并将这两个传递**kwargscreate_model函数来创建一个特征提取器。让我们看看如何?

create_model函数本身只有大约 50 行代码。所以所有的神奇的事情都必须在其他地方发生。您可能已经知道,其中的每个模型名称timm.list_models()实际上都是一个函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#model_name:模型的名字
#pretrained:是否加载预训练模型
#checkpoint_path:加载预训练模型的路径
#scriptable:是否使用脚本化
#exportable:是否使用导出
#no_jit:是否使用jit
#**kwargs:其他参数

#首先通过split_model_name函数将模型名字分割成两部分,一部分是模型的来源,一部分是模型的名字
#然后通过is_model函数判断模型是否存在
#如果存在,通过model_entrypoint函数获取模型的入口函数
#然后通过set_layer_config函数设置层的配置
#最后通过load_checkpoint函数加载预训练模型
def create_model(
model_name,
pretrained=False,
checkpoint_path='',
scriptable=None,
exportable=None,
no_jit=None,
**kwargs):
"""Create a model

Args:
model_name (str): name of model to instantiate
pretrained (bool): load pretrained ImageNet-1k weights if true
checkpoint_path (str): path of checkpoint to load after model is initialized
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)

Keyword Args:
drop_rate (float): dropout rate for training (default: 0.0)
global_pool (str): global pool type (default: 'avg')
**: other kwargs are model specific
"""
source_name, model_name = split_model_name(model_name)

# handle backwards compat with drop_connect -> drop_path change
drop_connect_rate = kwargs.pop('drop_connect_rate', None)
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
" Setting drop_path to %f." % drop_connect_rate)
kwargs['drop_path_rate'] = drop_connect_rate

# Parameters that aren't supported by all models or are intended to only override model defaults if set
# should default to None in command line args/cfg. Remove them if they are present and not set so that
# non-supporting models don't break and default args remain in effect.
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if source_name == 'hf_hub':
# For model names specified in the form `hf_hub:path/architecture_name#revision`,
# load model weights + default_cfg from Hugging Face hub.
hf_default_cfg, model_name = load_model_config_from_hf(model_name)
kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday

if is_model(model_name):
create_fn = model_entrypoint(model_name)
else:
raise RuntimeError('Unknown model (%s)' % model_name)

with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(pretrained=pretrained, **kwargs)

if checkpoint_path:
load_checkpoint(model, checkpoint_path)

return model

timm.models.registry模块

代码中有这样一行

create_fn = model_entrypoint(model_name)

函数获取模型的入口函数

那么他是如何实现的呢,他的运行主要依赖于timm.models.registry模块

有一个字典_model_entrypoints = {} # mapping of model names to entrypoint fns

包含了所有的模型名称和他们各自的函数

1
2
3
4
def is_model(model_name):
""" Check if a model name exists
"""
return model_name in _model_entrypoints

model_entrypoint 函数从 _model_entrypoints 内部得到模型的构造函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def model_entrypoint(model_name):
"""Fetch a model entrypoint for specified model name
"""
return _model_entrypoints[model_name]

print(timm.models.registry._model_entrypoints)
print(timm.models.registry._model_entrypoints.keys())
'''
输出是下面这样子的,太多了,只贴出来一点
{'vit_tiny_patch16_224': <function vit_tiny_patch16_224 at 0x000002450498A048>, 'vit_tiny_patch16_384': <function vit_tiny_patch16_384 at 0x000002450498A0D0>, 'vit_small_patch32_224': <function vit_small_patch32_224 at 0x000002450498A158>, 'vit_small_patch32_384': <function vit_small_patch32_384 at 0x000002450498A1E0>}
可以看到是个字典,key是模型名称,value是模型的构造函数,
以vit为例,vit_base_patch16_224模型的构造函数,他存在于timm.models.vision_trasnformer.py
'''


'''
create_fn = model_entrypoint(model_name)
#以model_name = 'xception71' 为例
#<function xception71 at 0x7fc0cba0eca0>
#<function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>
'''

timm.models.vision_trasnformer模块

1
2
3
4
5
6
7
8
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model

注释写了出处,模型的原论文

model_kwargs把我这个模型的独特参数和传进来的**kwargs一起打包成参数字典,传给_create_vision_transformer函数

这个函数写在文件的开头,vit的模型都是他生成的,区别就是架构参数不一样,比如层数,注意力头数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')

if 'flexi' in variant:
# FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
# interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
_filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
else:
_filter_fn = checkpoint_filter_fn

return build_model_with_cfg(
VisionTransformer, variant, pretrained,
pretrained_filter_fn=_filter_fn,
**kwargs,
)

挖坑提高,星星是什么

variant 变种,就是模型的名字,比如在vit类上生成的vit-base

函数之中,会再调用 build_model_with_cfg 函数并将一个构造器类 VisionTransformer、变量名 resnet34、一个 default_cfg 和一些 **kwargs 传入其中。

**kwargs就是之前打包的模型参数和 **kwargs

build_model_with_cfg

这个 build_model_with_cfg 函数负责:

  • 真正地实例化一个模型类来创建一个模型
  • 若 pruned=True,对模型进行剪枝
  • 若 pretrained=True,加载预训练模型参数
  • 若 features_only=True,将模型转换为特征提取器

register_model装饰器

以上就是读者在使用 timm 库时的基本方法,其实到这里你应该已经能够使用它训练自己的分类模型了。但是如果还想进一步搞清楚它的框架原理,并在它的基础上做修改,本节可能会帮到你。

https://zhuanlan.zhihu.com/p/361837010

但是问题来了,这个字典怎么来的,我写好代码之后一个一个添加进去的吗?(也不是不行hh

但是太麻烦,这时候就用到我们的 register_model 装饰器,它可以不断地像其中添加模型名称和它对应的构造函数,一开始的字典是空的,是装饰器自动添加的,其源码如下

2.5 sys.modules
该属性是一个字典,包含的是各种已加载的模块的模块名到模块具体位置的映射。

通过手动修改这个字典,可以重新加载某些模块;但要注意,切记不要大意删除了一些基本的项,否则可能会导致 Python 整个儿无法运行。

mod = sys.modules[fn.__module__]

这里我print一下fn.__module__ 应该是xxx.xxx.vit_base

模型名称就是vit_base

module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''

还是打印一下,我有点分不清module_name和model_name

_model_to_module[model_name] = module_name

_module_to_models[module_name].add(model_name)

展开查看

def register_model(fn):
    # lookup containing module
    mod = sys.modules[fn.__module__]
    module_name_split = fn.__module__.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''
    # add model to __all__ in module
    model_name = fn.__name__
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
        mod.__all__ = [model_name]
    # add entries to registry dict/sets
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    has_pretrained = False  # check if model has a pretrained url to allow filtering on this
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
        _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
    if has_pretrained:
        _model_has_pretrained.add(model_name)
    return fn
  

重点是这一句

1
2
3
model_name = fn.__name__

_model_entrypoints[model_name] = fn

它将给定的 fn 添加到 _model_entrypoints 其键名为 fn.__name__

查看模型定义文件

我们会发现 timm 中的每个模型都有一个 @register_model

例如vit-base

代码很简单明了,但是我还是不懂他怎么发生作用的

他是什么时候开始把构造函数加进那个空字典的?

我要自己注册函数应该如何做?

挖坑拓展,import

modelconfig

timm 中所有的模型都有一个默认的配置,包括指向它的预训练权重参数的URL、类别数、输入图像尺寸、池化尺寸等。

此默认配置与其他参数(如构造函数类和一些模型参数)一起传递给 build_model_with_cfg 函数

写在前面

读写方式 可否读写 文件不存在
w 覆盖写入 创建
w+ 覆盖写入、可读 创建
r 可读 报错
r+ 覆盖写入、可读 报错
a 附加写入 创建
a+ 附加写入、可读 创建

csv

csv文件叫逗号分隔值文件

每一行内容是通过逗号来区分出不同的列

csv文件可以直接通过excel打开,以行列的形式保存和显示数据,但是相对excel文件,它只能存储数据,不能保存公式和函数

使用时要import csv

读操作

1)创建打开csv文件

1
2
3
4
5
6
7
"""
文件对象 = open()
文件对象.close()

with open() as 文件对象:
操作文件
"""

举例:

1
f = open('file/电影.csv', 'r', encoding='utf-8')

2)创建reader获取文件内容

1
2
3
4
5
6
7
8
9
10
11
"""
csv.reader(文件对象) - 获取文件内容,并且以列表为单位返回每一行内容
csv.DictReader(文件对象) - 获取文件内容,并且以字典为单位返回第2行开始的每一行内容(字典的键是第一行内容)
"""
import csv

f = open('file/电影.csv', 'r', encoding='utf-8')
reader1 = csv.reader(f)
print(list(reader1))
reader2 = csv.DictReader(f)
print(list(reader2))

二、csv文件写操作

(一)打开文件

1
f = open('./file/data.csv', 'w', encoding='utf-8')

如果没有文件,则会自动创建

(二)创建writer对象

1
2
3
4
5
6
7
"""
csv.writer(文件对象,delimiter) - 创建writer对象,这个对象在写入数据的时候一行对应一个列表

delimiter 指定同一行每个字段的分隔字符。若不指定,默认以英文逗号(,)分隔,在csv文件中显示的是不同单元格,若以其他符号分隔,则显示在csv同一单元格中

csv.DictWriter(文件对象,键列表,delimiter) - 创建writer对象,以字典为单位写入数据
"""

1)以列表为单位写入一行内容

1
writer = csv.writer(f)

一次写入一行内容

1
2
writer.writerow(['姓名', '出生日期', '性别', '电话'])
writer.writerow(['小明', '2001', '男', '110'])

一次写入多行内容

1
2
3
4
writer.writerows([
['小花', '2001', '女', '110'],
['小华', '2001', '女', '110']
])

2)以字典为单位写入一行内容

1
writer = csv.DictWriter(f, ['姓名', '出生日期', '性别', '电话'])

写入文件头(将字典的键写入到文件开头)

1
writer.writeheader()

一次写入一行内容

1
writer.writerow({'姓名': '小明', '出生日期': '2000-2-3', '性别': '男', '电话': '112'})

一次写入多行内容

1
2
3
4
writer.writerows([
{'姓名': '小花', '出生日期': '2000-2-3', '性别': '男', '电话': '112'},
{'姓名': '小华', '出生日期': '2000-2-3', '性别': '男', '电话': '112'}
])

d待写

dddd

获取模型运行过程中的特征

+hook技术

获取特征 (一篇)

可以通过多种方式获得倒数第二个模型层的特征,而无需进行模型手术(尽管可以随意进行手术)。人们必须首先决定他们是想要池化的还是非池化的特征。

https://blog.csdn.net/qq_41917697/article/details/115026308
获取分类层前(倒数第二层)的特征

直接创建没有池化和分类层的模型,对于基于CNN的模型可以这样做
https://blog.csdn.net/qq_42003943/article/details/118382823?login=from_csdn
检查分类器,我们可以看到 timm 已经用一个新的、未经训练的、具有所需类别数量的线性层替换了最后一层;准备微调我们的数据集!
如果我们想避免完全创建最后一层,我们可以将类的数量设置为 0,这将创建一个具有恒等函数的模型作为最后一层;这对于检查倒数第二层的输出很有用。
identity()

直接调用forward_features()函数
x = torch.randn([1, 3, 224, 224])
Backbone1 = timm.create_model(‘vit_base_patch16_224’)
Backbone2 = timm.create_model(‘resnet50’)
feature1 = Backbone1.forward_features(x)
feature2 = Backbone2.forward_features(x)
print(‘vit_feature:’, feature1.shape, ‘resnet_feature:’, feature2.shape)# Results: vit_feature: torch.Size([1, 768]) resnet_feature: torch.Size([1, 2048, 7, 7])
————————————————
版权声明:本文为CSDN博主「ZhangChen@BJTU」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_42003943/article/details/118382823

通过移除层来获得
https://blog.csdn.net/qq_42003943/article/details/118382823?login=from_csdn

获取中间层特征-注意并非所有model都有此选项
https://blog.csdn.net/qq_42003943/article/details/118382823?login=from_csdn
可通过out_indices参数指定从哪个level获取feature

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
x = torch.randn([1, 3, 224, 224])
feature_extractor = timm.create_model('resnet50', features_only=True) # 并非所有model都有此选项
feature_list = feature_extractor(x)
for a in feature_list:
print(a.shape)
# Results:
# torch.Size([1, 64, 112, 112])
# torch.Size([1, 256, 56, 56])
# torch.Size([1, 512, 28, 28])
# torch.Size([1, 1024, 14, 14])
# torch.Size([1, 2048, 7, 7])

##可通过out_indices参数指定从哪个level获取feature
feature_extractor = timm.create_model('resnet50', features_only=True, out_indices=[1, 3, 4])
feature_list = feature_extractor(x)
for a in feature_list:
print(a.shape)
# Results:
# torch.Size([1, 256, 56, 56])
# torch.Size([1, 1024, 14, 14])
# torch.Size([1, 2048, 7, 7])
————————————————
版权声明:本文为CSDN博主「ZhangChen@BJTU」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_42003943/article/details/118382823

接下来我们来看一下这些特征提取器究竟是什么类型:

1
2
3
4
5
6
7
8
9
10
import timm
feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=[3])

print('type:', type(feature_extractor))
print('len: ', len(feature_extractor))
for item in feature_extractor:
print(item)
————————————————
版权声明:本文为CSDN博主「Adenialzz」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_44966641/article/details/121364784

输出:

type: <class ‘timm.models.features.FeatureListNet’>
len: 7
conv1
bn1
act1
maxpool
layer1
layer2
layer3

可以看到,feature_extractor 其实也是一个神经网络,在 timm 中称为 FeatureListNet,而我们通过 out_indices 参数来指定截取到哪一层特征。

需要注意的是,ViT 模型并不支持 features_only 选项(0.4.12版本)。

extractor = timm.create_model(‘vit_base_patch16_224’, features_only=True)
1
输出:

RuntimeError: features_only not implemented for Vision Transformer models.
1

————————————————
版权声明:本文为CSDN博主「Adenialzz」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_44966641/article/details/121364784

模型架构

部分的PyTorch模型及其对应arxiv链接如下:

所有的来源可以查看timm官方repo

result

各种timm库模型在 ImageNet 数据集训练结果

https://github.com/huggingface/pytorch-image-models/blob/main/results/README.md

下载pretrained model

pretrained

在创建模型时create_model如果我们传入 pretrained=True

那么 timm 会从对应的 URL 下载模型权重参数并载入模型,只有当第一次(即本地还没有对应模型参数时)会去下载,之后会直接从本地加载模型权重参数。

1
model = timm.create_model('vit_small_patch16_224', pretrained=True)

Downloading: “https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz" to /home/xxx/.cache/torch/hub/checkpoints/xxx.pth

下载的时候从model.default_cfg.url中下载权重

模型下载到当前用户root下的.cache/torch/hub/checkpoints/中

如果服务器网不好,可以手动下载再上传到相应位置

数据集

词2077分析
拼图 预测拼图完成时间

数据清洗

数据格式是2k分辨率图像的url,以及拼好所用的时间(64块,100块)

初步筛掉公司提供的有缺失值的数据

剩下1w张图像,时间对

特征提取(模型训练

模型主体

resnet34,效果不好,

用vit

损失函数

一开始选用mseloss

但是训练效果很差

发现是异常值有点多

mseloss均方误差(又称MSE、L2损失)

回归问题中最常见的损失函数。如果对所有样本点只给出一个预测值,那么这个值就是所有目标值的平均值。

优点:

  • 计算方便,逻辑清晰,衡量误差较准确。
  • MSE 曲线的特点是光滑连续、可导,便于使用梯度下降算法。梯度随着误差增大或减小,收敛效果好

缺点:

  • 对异常点会赋予较大的权重,如果异常点不属于考虑范围,是由于某种错误导致的,则此函数指导方向将出现偏差。模型会更加偏向于惩罚较大的点,赋予其更大的权重,忽略掉较小的点的作用,无法避免离群点可能导致的梯度爆炸问题。如果样本中存在离群点,MSE 会给离群点赋予更高的权重,但是却是以牺牲其他正常数据点的预测效果为代价,因此会导致降低模型的整体性能。

平均绝对值误差(又称MAE,L1损失)

优点:

  • 对异常值具有较好鲁棒性

缺点:

  • 梯度不变是个严重问题,MAE训练中梯度始终很大,且在0点连续但不可导,这意味着即使对于小的损失值,其梯度也是大的。这不利于函数的收敛和模型的学习,模型学习速度慢,同时也会导致使用梯度下降训练模型时,在结束时可能会遗漏全局最小值。

解决办法,可变lr与warmup

容易受离群点影响,也可能是过拟合了,加正则化

lr

https://zhuanlan.zhihu.com/p/524650878

可变学习率

https://blog.csdn.net/qq_37085158/article/details/122536201

pytorch中优化器optimizer.param_groups[0]是什么

https://blog.csdn.net/weixin_36670529/article/details/107531773

参考

https://zhuanlan.zhihu.com/p/378822530

https://mp.weixin.qq.com/s?__biz=MzI1MjQ2OTQ3Ng==&mid=2247582694&idx=1&sn=d8ec63b98581155f90963d4414892b37&chksm=e9e08a6dde97037bdec0781a4481959d94fe8bc961ea57573c22edd41afd1f2ce963e1a53cc8&scene=27

https://blog.csdn.net/weixin_45085051/article/details/127308481

image-20230329222008475

省流:训练自己的

如果不想用预训练好的权重,想自己搞怎么办呢

几个方法

微调权重

迁移学习到时候,可以不改前面backbone而只修改后面的分类头

脚本

官方提供了一些示例脚本,可以在github中下载

然后按照实列来运行

Scripts (huggingface.co)

比如

1
./distributed_train.sh 2 /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce

resnet50在imagenet上训练200epochs 余弦学习率 初始lr0.05 使用2张GPU分布式训练

之后会讲到train.py的代码解析

自己写train

省流:使用timm加载CNN进行图像分类,调整CNN使之更适合你的任务

问:使用timm搭建一个可以使用的CNN或ViT拢共需要几步?

答:4步

0.安装 timm

1.import timm

2.创建model

3.运行model

这一节很基础,会的兄弟们可以跳过看后面的

接下来具体讲一下如何使用,代码codebook会之后给出

安装、导入

准备自己python环境,建议用anconda来管理,然后安装上torch 和cuda

对于小白,可以先不安装cuda,只在cpu上跑代码

配置环境网上有很多教程,环境准备好了之后就可以在cmd安装timm包了,很简单:

1
2
3
4
5
6
conda imstall timm
#or
pip install timm
#or
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

新建文件

在文件开头导入必要的包

1
2
import torch
import timm

创建、使用模型

创建模型的最简单方法是使用create_model;

这一个可用于在 timm 库中创建任何模型的工厂函数

这个函数各个参数有什么用,内部具体怎么实现的,怎么玩出花来的之后再讲,先只用它来创建一个CNN用来做分类任务

1
model_resnet34 = timm.create_model('resnet34', pretrained=True)

‘resnet34’是模型架构的名字

pretrained=True则会自动从网上下载训练好的模型权重加载到resnet34上

然后模型就创建好了,可以直接使用了

这里我们使用随机张量表示图像

torch.randn:用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1)

1
2
3
x = torch.randn([1, 1, 224, 224])#创建一个tensor 代表 一张3x224x224的图片
out = model_resnet50(x)#out就是x所对应的表示类别的一个tensor
print(out.shape)# Results: torch.Size([1, 1000])代表1000个类别

我们可以看到模型已经处理了图像并返回了预期的输出形状。

查看模型信息

那么怎么知道timm都可以导入哪些模型来使用呢?

1
2
3
4
5
6
7
model_list = timm.list_models()#返回一个包含所有模型名称的list
print(len(model_list))#964
pretrain_model_list = timm.list_models(pretrained = True)#筛选出带预训练模型的
print(len(pretrain_model_list))#770
##使用通配符字符串来列出可用的不同 ResNet 变体
resnet_model_list = timm.list_models('*resnet*')
pretrain_resnet_model_list = timm.list_models('*resnet*' , pretrained = True)

调整模型-创建适合自己的模型

直接导入训练好的模型并不是万能的,经常会有维度不匹配的情况
比如说我的resnet34模型在cifar10和imagenet两个数据集上进行训练,分类类别不一样,输入的图片大小不一样,那我应该怎么创建合适的模型呢?

改变输出类别数目

分类类别数量:num_classes

model的主体提取特征,之后往往会接一个mlp层用作分类

如果设置num_classes,表示重设全连接层,num_classes设置为你需要分类的类别数量即可

1
2
3
4
5
6
import torch
x = torch.randn([1, 3, 224, 224])

model_resnet34_out10 = timm.create_model('resnet34', pretrained=True, num_classes=10)
out = model_resnet34_out10 (x)
print(out.shape)# Results: torch.Size([1, 10])

改变输入通道数

输入通道数:in_chans

对图片的大小,可以在输入model之前进行resize处理到统一大小

但是如果输入的图片不是传统rgb图片,通道不是3怎么办

当然,我们可以复制单通道像素来创建3通道图像,从而将其单通道输入图像转换为3通道图像。但是对于timm,他又一套申请的参数加载模式,我们可以直接改变in_chans 来指定输入图像的通道数

通道数改变后,对应的权重参数会进行相应的处理,此处不作详细说明
可参照:https://fastai.github.io/timmdocs/models或直接查看源代码

1
2
x = torch.randn([1, 1, 224, 224])
model_resnet34_in1 = timm.create_model('resnet50',pretrained=True, in_chans=1)

特性

用timm有什么好处吗,下面是一些功能特性

所有model都有一个通用的默认配置接口和API

所有模型都支持通过create_model提取中间特征(vit除外)

所有型号都有一个预训练重量加载器,可调整最后一个线性层,也可调整3通道输入为1个通道输入

并且我们还可以直接复用timm的功能模块或者一些训练tricks(Learning rate schedulers/Optimizers/Augment),简直是百宝箱