[Cabbage] timm Part 5: Creating Models


create model

Main reference: https://blog.csdn.net/MengYa_Dream/article/details/126690336

The model architectures included come from a wide variety of sources. Sources, including papers, original impl (“reference code”) that Ross rewrote / adapted, and PyTorch impl that he leveraged directly (“code”) are listed below.

Most included models have pretrained weights. The weights are either:

  • from their original sources
  • ported by myself from their original impl in a different framework (e.g. Tensorflow models)
  • trained from scratch using the included training script
The validation results for the pretrained weights can be found here.

Over the past year, Vision Transformer and its many follow-up improvements have appeared one after another, and most of their open-source code relies on the same library: timm. Practitioners are no doubt thoroughly familiar with it already; this post introduces one of its most essential functions: create_model.

The create_model function creates a network model (ResNet, ViT, and so on). timm itself ships close to 400 models ready for direct use, and users can also implement their own models and register them with timm (the focus of the next section) so that they can be created the same way.
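As a quick orientation, here is a minimal usage sketch (the exact model count depends on your timm version):

import timm
import torch

# List registered model names; wildcard filtering is supported.
print(len(timm.list_models()))            # roughly 400 in the version discussed here
print(timm.list_models('vit_base*')[:3])  # e.g. ['vit_base_patch16_224', ...]

# Create a ViT-B/16 with a fresh 10-class head and run a dummy forward pass.
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 10])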

Source code

create_model

The create_model function is what timm uses to create its hundreds of models. It also accepts a bunch of **kwargs, such as features_only and out_indices, and passes them through to the model's entrypoint, which is how a feature extractor gets created. Let's see how.

The create_model function itself is only about 50 lines of code, so all the magic has to happen elsewhere. As you may already know, every model name in timm.list_models() is actually a function.
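A quick check makes this concrete (module paths as in the timm version used here):

import timm
from timm.models import vision_transformer

# Each name returned by list_models() is backed by an entrypoint function.
print('vit_base_patch16_224' in timm.list_models())       # True
print(callable(vision_transformer.vit_base_patch16_224))  # True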

Here is the (annotated) source of create_model:

# model_name: name of the model to create
# pretrained: whether to load pretrained weights
# checkpoint_path: path of a checkpoint to load after the model is built
# scriptable: whether to configure layers for jit scripting
# exportable: whether to configure layers for tracing / export
# no_jit: whether to disable jit-scripted layers
# **kwargs: other, model-specific arguments

# First, split_model_name splits the model name into a source part and the name proper;
# is_model then checks whether the model exists;
# if it does, model_entrypoint fetches the model's entrypoint function;
# set_layer_config sets the layer configuration;
# finally, load_checkpoint loads weights if a checkpoint path was given
def create_model(
        model_name,
        pretrained=False,
        checkpoint_path='',
        scriptable=None,
        exportable=None,
        no_jit=None,
        **kwargs):
    """Create a model

    Args:
        model_name (str): name of model to instantiate
        pretrained (bool): load pretrained ImageNet-1k weights if true
        checkpoint_path (str): path of checkpoint to load after model is initialized
        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
        no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)

    Keyword Args:
        drop_rate (float): dropout rate for training (default: 0.0)
        global_pool (str): global pool type (default: 'avg')
        **: other kwargs are model specific
    """
    source_name, model_name = split_model_name(model_name)

    # handle backwards compat with drop_connect -> drop_path change
    drop_connect_rate = kwargs.pop('drop_connect_rate', None)
    if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
        print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
              " Setting drop_path to %f." % drop_connect_rate)
        kwargs['drop_path_rate'] = drop_connect_rate

    # Parameters that aren't supported by all models or are intended to only override model defaults if set
    # should default to None in command line args/cfg. Remove them if they are present and not set so that
    # non-supporting models don't break and default args remain in effect.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    if source_name == 'hf_hub':
        # For model names specified in the form `hf_hub:path/architecture_name#revision`,
        # load model weights + default_cfg from Hugging Face hub.
        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
        kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday

    if is_model(model_name):
        create_fn = model_entrypoint(model_name)
    else:
        raise RuntimeError('Unknown model (%s)' % model_name)

    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = create_fn(pretrained=pretrained, **kwargs)

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path)

    return model
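For example, the scriptable flag only adjusts the layer config via set_layer_config; whether scripting then succeeds still depends on the model. A sketch:

import timm
import torch

# Build with script-friendly layers, then compile with TorchScript
# (not guaranteed to work for every architecture).
model = timm.create_model('resnet50', pretrained=False, scriptable=True)
scripted = torch.jit.script(model)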

The timm.models.registry module

The code contains this line:

create_fn = model_entrypoint(model_name)

which fetches the model's entrypoint function. So how does that work? It relies mainly on the timm.models.registry module.

That module holds a dictionary,

_model_entrypoints = {}  # mapping of model names to entrypoint fns

which contains every model name and its corresponding constructor function.

def is_model(model_name):
    """ Check if a model name exists
    """
    return model_name in _model_entrypoints

The model_entrypoint function fetches the model's constructor function from _model_entrypoints.

def model_entrypoint(model_name):
    """Fetch a model entrypoint for specified model name
    """
    return _model_entrypoints[model_name]

print(timm.models.registry._model_entrypoints)
print(timm.models.registry._model_entrypoints.keys())
'''
The output looks like the following; it is far too long, so only a snippet is shown:
{'vit_tiny_patch16_224': <function vit_tiny_patch16_224 at 0x000002450498A048>, 'vit_tiny_patch16_384': <function vit_tiny_patch16_384 at 0x000002450498A0D0>, 'vit_small_patch32_224': <function vit_small_patch32_224 at 0x000002450498A158>, 'vit_small_patch32_384': <function vit_small_patch32_384 at 0x000002450498A1E0>}
As you can see, it is a dict whose keys are model names and whose values are the corresponding constructor functions.
Taking ViT as an example, the constructor of vit_base_patch16_224 lives in timm.models.vision_transformer.py
'''


'''
create_fn = model_entrypoint(model_name)
# taking model_name = 'xception71' as an example:
# <function xception71 at 0x7fc0cba0eca0>
# <function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>
'''
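Putting the two together, you can fetch an entrypoint and call it yourself (a sketch against the registry layout of this timm version; newer releases moved the registry to timm.models._registry):

import timm
from timm.models.registry import is_model, model_entrypoint

name = 'vit_base_patch16_224'
if is_model(name):
    create_fn = model_entrypoint(name)
    model = create_fn(pretrained=False)  # equivalent to timm.create_model(name)
    print(type(model).__name__)          # VisionTransformer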

The timm.models.vision_transformer module

@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model

The docstring records the provenance: the model's original paper.

model_kwargs packs this variant's distinctive parameters together with the incoming **kwargs into a single argument dict and passes it to the _create_vision_transformer function.

That function is defined near the top of the file, and it produces every ViT model; the only difference between variants is the architecture parameters, such as the depth or the number of attention heads.

def _create_vision_transformer(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    if 'flexi' in variant:
        # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
        # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
        _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
    else:
        _filter_fn = checkpoint_filter_fn

    return build_model_with_cfg(
        VisionTransformer, variant, pretrained,
        pretrained_filter_fn=_filter_fn,
        **kwargs,
    )

TODO (to dig into later): what exactly do the stars (* and **) mean? A short answer follows below.

variant is the model name, e.g. the vit-base variant built on the ViT class.

Inside this function, build_model_with_cfg is called with the constructor class VisionTransformer, the variant name vit_base_patch16_224, a default_cfg, and some **kwargs.

Those **kwargs are the model parameters packed earlier, merged with the caller's own **kwargs.
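To answer the star question above: in Python, * packs and unpacks positional arguments, while ** packs and unpacks keyword arguments. A minimal illustration:

def show(**kwargs):
    # ** in a signature packs keyword arguments into a dict
    print(kwargs)

model_kwargs = dict(patch_size=16, embed_dim=768)
# ** at a call site unpacks a dict back into keyword arguments
show(**model_kwargs, depth=12)  # {'patch_size': 16, 'embed_dim': 768, 'depth': 12}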

build_model_with_cfg

This build_model_with_cfg function is responsible for:

  • actually instantiating the model class to create the model
  • pruning the model if pruned=True
  • loading pretrained weights if pretrained=True
  • converting the model into a feature extractor if features_only=True (see the sketch after this list)
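The feature-extractor path is easiest to see with a CNN (plain ViT, as shown above, raises RuntimeError for features_only). A sketch:

import timm
import torch

# features_only turns the classifier into a backbone; out_indices selects
# which feature stages to return.
backbone = timm.create_model('resnet34', pretrained=False, features_only=True, out_indices=(2, 3, 4))
feats = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # three feature maps at decreasing resolution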

The register_model decorator

Everything above covers the basics of using the timm library; by this point you should already be able to train your own classification model with it. But if you want to dig further into how the framework works, and modify it, this section may help.

https://zhuanlan.zhihu.com/p/361837010

But here comes the question: where does that dict come from? Do I add entries one by one after writing my code? (Not impossible, admittedly.)

That would be far too tedious. This is where the register_model decorator comes in: it keeps adding model names and their constructor functions to the dict. The dict starts out empty and is populated automatically by the decorator; its source is shown below.

sys.modules

This attribute is a dictionary mapping the names of all loaded modules to the module objects themselves.

By modifying this dict by hand you can force certain modules to be reloaded; but take care not to delete fundamental entries, or Python may stop working altogether.
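A small demonstration of what sys.modules holds:

import sys
import timm.models.vision_transformer

# sys.modules maps dotted module names to the already-imported module objects.
mod = sys.modules['timm.models.vision_transformer']
print(mod is timm.models.vision_transformer)  # True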

mod = sys.modules[fn.__module__]

If we print fn.__module__ here for vit_base_patch16_224, we get 'timm.models.vision_transformer': the dotted path of the module that defines the function, not the function's own name.

The model name comes from fn.__name__ instead, i.e. 'vit_base_patch16_224'.

module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''

It is easy to mix up module_name and model_name; printing both helps (see the quick check below).

_model_to_module[model_name] = module_name

_module_to_models[module_name].add(model_name)
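To make the distinction concrete:

import timm.models.vision_transformer as vt

fn = vt.vit_base_patch16_224
# model_name: the entrypoint function's own name
print(fn.__name__)                   # 'vit_base_patch16_224'
# module_name: the last component of the defining module's dotted path
print(fn.__module__.split('.')[-1])  # 'vision_transformer'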

The full source of register_model:

def register_model(fn):
    # lookup containing module
    mod = sys.modules[fn.__module__]
    module_name_split = fn.__module__.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''
    # add model to __all__ in module
    model_name = fn.__name__
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
        mod.__all__ = [model_name]
    # add entries to registry dict/sets
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    has_pretrained = False  # check if model has a pretrained url to allow filtering on this
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
        _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
    if has_pretrained:
        _model_has_pretrained.add(model_name)
    return fn
  

The key lines are these:

model_name = fn.__name__

_model_entrypoints[model_name] = fn

They add the given fn to _model_entrypoints under the key fn.__name__.

Looking at the model definition files

We find that every model in timm carries an @register_model decorator,

for example the vit_base_patch16_224 shown earlier.

The code is short and clear, but how does it actually take effect?

When does it start adding constructor functions to that initially empty dict? The answer is at import time: when timm is imported, timm.models imports every model file, and each @register_model decorator runs as its function is defined, filling the registry.

And if I want to register a function of my own, how do I do it? (See the sketch below.)
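A minimal registration sketch (MyTinyNet and my_tiny_net are hypothetical names invented for illustration; the registry import path matches the timm version used here):

import timm
import torch.nn as nn
from timm.models.registry import register_model

class MyTinyNet(nn.Module):
    """A toy model that exists only to demonstrate registration."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(3, num_classes)

    def forward(self, x):
        return self.fc(self.pool(x).flatten(1))

@register_model
def my_tiny_net(pretrained=False, **kwargs):
    # the function name becomes the key in _model_entrypoints
    return MyTinyNet(**kwargs)

# once registered, the usual factory call works
model = timm.create_model('my_tiny_net', num_classes=5)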

TODO (to expand later): exactly how import triggers all of this.

Model config

Every model in timm has a default configuration, including the URL of its pretrained weights, the number of classes, the input image size, the pooling size, and so on.

This default config is passed to the build_model_with_cfg function together with the other arguments (the constructor class and some model parameters).
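You can inspect it directly (field names vary across timm versions; newer releases expose a pretrained_cfg as well):

import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False)
print(model.default_cfg)
# e.g. {'url': '...', 'num_classes': 1000, 'input_size': (3, 224, 224),
#       'crop_pct': 0.9, 'interpolation': 'bicubic', ...}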