4.4 Module常用函数
本小节汇总介绍Module常用的方法,由于文档中是按首字母排序展示所有方法,未按用途进行归类,不便于理解各函数之间的关系。在这里,特别将具有相似功能的相关函数归纳整理,供大家参考学习。
常用方法包括:
设置模型训练、评估模式
eval
train
设置模型存放在cpu/gpu/xpu
- cpu
- cuda
- to
- xpu
获取模型参数、加载权重参数
load_state_dict
state_dict
管理模型的modules, parameters, sub_module
- parameters
- children
- modules
- named_children
- named_modules
- named_parameters
- get_parameter
- get_submodule
- add_module
设置模型的参数精度,可选半精度、单精度、双精度等
- bfloat16
- half
- float
- double
对子模块执行特定功能
- apply
- zero_grad
以上是不完全的列举,有些非高频使用的函数请到文档中查阅。下面通过简介和配套代码的形式学习上述函数的使用。
设置模型训练、评估模式
eval:设置模型为评估模式,这一点与上一小节介绍的BN,Dropout息息相关,评估模式下模型的某些层执行的操作与训练状态下是不同的。
train:设置模型为训练模式,如BN层需要统计running_var这些统计数据,Dropout层需要执行随机失活等。
使用方法过于简单,无需代码展示。
设置模型存放在cpu/gpu
对于gpu的使用会在后面设置单独小节详细介绍,由于这里是基础学习,暂时可不考虑运算速度问题。这里既然遇到了相关的概念,就简单说一下。
pytorch可以利用gpu进行加速运算,早期只支持NVIDIA公司的GPU,现在也逐步开始支持AMD的GPU。使用gpu进行运算的方法很简单,就是把需要运算的数据放到gpu即可。方法就是 xxx.cuda()
,若想回到cpu运算,那就需要xxx.cpu()
即可。但有一个更好的方法是to(),to方法可将对象放到指定的设备中去,如to.("cpu") 、 to.("cuda)、to("cuda:0")
等。
cpu:将Module放到cpu上。
cuda:将Module放到cuda上。为什么是cuda不是gpu呢?因为CUDA(Compute Unified Device Architecture)是NVIDIA推出的运算平台,数据是放到那上面进行运算,而gpu可以有很多个品牌,因此用cuda更合理一些。
to:将Module放到指定的设备上。
关于to通常会配备torch.cuda.is_available()使用,请看配套代码学习。
获取模型参数、加载权重参数
模型训练完毕后,我们需要保存的核心内容是模型参数,这样可以供下次使用,或者是给别人进行finetune。相信大家都用ImageNet上的预训练模型,而使用方法就是官方训练完毕后保存模型的参数,供我们下载,然后加载到自己的模型中。在这里就涉及两个重要操作:保存模型参数与加载模型参数,分别要用到以下两个函数。
state_dict:返回参数字典。key是告诉你这个权重参数是放到哪个网络层。
load_state_dict:将参数字典中的参数复制到当前模型中。这里的复制要求key要一一对应,若key对不上,自然模型不知道要把这个参数放到哪里去。绝大多数开发者都会在load_state_dict这里遇到过报错,如
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: xxxxxxxx
Unexpected key(s) in state_dict: xxxxxxxxxx
这通常是拿到的参数字典与模型当前的结构不匹配。
对于load_state_dict函数,还有两个参数可以设置,请看原型:
参数:
- state_dict (dict) – a dict containing parameters and persistent buffers.
- strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
返回项
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
上述两个方法具体的使用请看配套代码。
管理模型的modules, parameters, sub_module
模型中需要管理的主要是parameter与module,每个对象都有两种方式读取,分别是带名字和不带名字的。针对module还有一个称为children的方法,它与modules方法最大的不同在于modules会返回module本身。具体差异通过配套代码一看便明了。
parameters:返回一个迭代器,迭代器可抛出Module的所有parameter对象
named_parameters:作用同上,不仅可得到parameter对象,还会给出它的名称
modules:返回一个迭代器,迭代器可以抛出Module的所有Module对象,注意:模型本身也是module,所以也会获得自己。
named_modules:作用同上,不仅可得到Module对象,还会给出它的名称
children:作用同modules,但不会返回Module自己。
named_children:作用同named_modules,但不会返回Module自己。
获取某个参数或submodule
当想查看某个部分数据时,可以通过get_xxx方法获取模型特定位置的数据,可获取parameter、submodule,使用方法也很简单,只需要传入对应的name即可。
get_parameter
get_submodule
设置模型的参数精度,可选半精度、单精度、双精度等
为了调整模型占存储空间的大小,可以设置参数的数据类型,默认情况是float32位(单精度),在一些场景可采用半精度、双精度等,以此改变模型的大小或精度。Module提供了几个转换权重参数精度的方法,分别如下:
- half:半精度
- float:单精度
- double:双精度
- bfloat16:Brain Floating Point 是Google开发的一种数据格式,详细参见wikipedia
对子模块执行特定功能
zero_grad:将所有参数的梯度设置为0,或者None
apply:对所有子Module执行指定fn(函数),常见于参数初始化。这个可以参见配套代码。
小结
本节对Module的常用API函数进行了介绍,包括模型两种状态,模型存储于何种设备,模型获取参数,加载参数,管理模型的modules,设置模型参数的精度,对模型子模块执行特定功能。
由于Module是核心模块,其涉及的API非常多,短时间不好消化,建议大家结合代码用例,把这些方法都过一遍,留个印象,待日后项目开发需要的时候知道有这些函数可以使用即可。
下一小节将介绍Module中的Hook函数。