create model
https://blog.csdn.net/MengYa_Dream/article/details/126690336-主要参考
The model architectures included come from a wide variety of sources. Sources, including papers, original impl (“reference code”) that Ross rewrote / adapted, and PyTorch impl that he leveraged directly (“code”) are listed below.
Most included models have pretrained weights. The weights are either:
from their original sources
ported by myself from their original impl in a different framework (e.g. Tensorflow models)
trained from scratch using the included training script
The validation results for the pretrained weights can be found here.
最近一年 Vision Transformer 及其相关改进的工作层出不穷,在他们开源的代码中,大部分都用到了这样一个库:timm。各位炼丹师应该已经想必已经对其无比熟悉了,本文将介绍其中最关键的函数之一:create_model 函数
create_model 函数是用来创建一个网络模型(如 ResNet、ViT 等),timm 库本身可供直接调用的模型已有接近400个,用户也可以自己实现一些模型并注册进 timm (这一部分内容将在下一小节着重介绍),供自己调用
源码
create_model
该create_model
函数是用来在里面创建数百个模型的timm
。它还期望一堆**kwargs
诸如features_only
和out_indices
并将这两个传递**kwargs
给create_model
函数来创建一个特征提取器。让我们看看如何?
该create_model
函数本身只有大约 50 行代码。所以所有的神奇的事情都必须在其他地方发生。您可能已经知道,其中的每个模型名称timm.list_models()
实际上都是一个函数。
1 | #model_name:模型的名字 |
timm.models.registry模块
代码中有这样一行
create_fn = model_entrypoint(model_name)
函数获取模型的入口函数
那么他是如何实现的呢,他的运行主要依赖于timm.models.registry模块
有一个字典_model_entrypoints = {} # mapping of model names to entrypoint fns
包含了所有的模型名称和他们各自的函数
1 | def is_model(model_name): |
model_entrypoint 函数从 _model_entrypoints
内部得到模型的构造函数。
1 | def model_entrypoint(model_name): |
timm.models.vision_trasnformer模块
1 | @register_model |
注释写了出处,模型的原论文
model_kwargs把我这个模型的独特参数和传进来的**kwargs一起打包成参数字典,传给_create_vision_transformer
函数
这个函数写在文件的开头,vit的模型都是他生成的,区别就是架构参数不一样,比如层数,注意力头数
1 | def _create_vision_transformer(variant, pretrained=False, **kwargs): |
挖坑提高,星星是什么
variant 变种,就是模型的名字,比如在vit类上生成的vit-base
函数之中,会再调用 build_model_with_cfg 函数并将一个构造器类 VisionTransformer、变量名 resnet34、一个 default_cfg 和一些 **kwargs 传入其中。
**kwargs就是之前打包的模型参数和 **kwargs
build_model_with_cfg
这个 build_model_with_cfg 函数负责:
- 真正地实例化一个模型类来创建一个模型
- 若 pruned=True,对模型进行剪枝
- 若 pretrained=True,加载预训练模型参数
- 若 features_only=True,将模型转换为特征提取器
register_model装饰器
以上就是读者在使用 timm 库时的基本方法,其实到这里你应该已经能够使用它训练自己的分类模型了。但是如果还想进一步搞清楚它的框架原理,并在它的基础上做修改,本节可能会帮到你。
https://zhuanlan.zhihu.com/p/361837010
但是问题来了,这个字典怎么来的,我写好代码之后一个一个添加进去的吗?(也不是不行hh
但是太麻烦,这时候就用到我们的 register_model
装饰器,它可以不断地像其中添加模型名称和它对应的构造函数,一开始的字典是空的,是装饰器自动添加的,其源码如下
2.5 sys.modules
该属性是一个字典,包含的是各种已加载的模块的模块名到模块具体位置的映射。
通过手动修改这个字典,可以重新加载某些模块;但要注意,切记不要大意删除了一些基本的项,否则可能会导致 Python 整个儿无法运行。
mod = sys.modules[fn.__module__]
这里我print一下fn.__module__ 应该是xxx.xxx.vit_base
模型名称就是vit_base
module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
还是打印一下,我有点分不清module_name和model_name
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
展开查看
def register_model(fn):
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
# add model to __all__ in module
model_name = fn.__name__
if hasattr(mod, '__all__'):
mod.__all__.append(model_name)
else:
mod.__all__ = [model_name]
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
has_pretrained = False # check if model has a pretrained url to allow filtering on this
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
_model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn
重点是这一句
1 | model_name = fn.__name__ |
它将给定的 fn
添加到 _model_entrypoints
其键名为 fn.__name__
查看模型定义文件
我们会发现 timm 中的每个模型都有一个 @register_model
例如vit-base
代码很简单明了,但是我还是不懂他怎么发生作用的
他是什么时候开始把构造函数加进那个空字典的?
我要自己注册函数应该如何做?
挖坑拓展,import
modelconfig
timm 中所有的模型都有一个默认的配置,包括指向它的预训练权重参数的URL、类别数、输入图像尺寸、池化尺寸等。
此默认配置与其他参数(如构造函数类和一些模型参数)一起传递给 build_model_with_cfg
函数