4.4 Module常用函数

本小节汇总介绍Module常用的方法,由于文档中是按首字母排序展示所有方法,未按用途进行归类,不便于理解各函数之间的关系。在这里,特别将具有相似功能的相关函数归纳整理,供大家参考学习。

常用方法包括:

  • 设置模型训练、评估模式

    1. eval

    2. train

  • 设置模型存放在cpu/gpu/xpu

    1. cpu
    2. cuda
    3. to
    4. xpu
  • 获取模型参数、加载权重参数

    1. load_state_dict

    2. state_dict

  • 管理模型的modules, parameters, sub_module

    1. parameters
    2. children
    3. modules
    4. named_children
    5. named_modules
    6. named_parameters
    7. get_parameter
    8. get_submodule
    9. add_module
  • 设置模型的参数精度,可选半精度、单精度、双精度等

    1. bfloat16
    2. half
    3. float
    4. double
  • 对子模块执行特定功能

    1. apply
    2. zero_grad

以上是不完全的列举,有些非高频使用的函数请到文档中查阅。下面通过简介和配套代码的形式学习上述函数的使用。

设置模型训练、评估模式

eval:设置模型为评估模式,这一点与上一小节介绍的BN,Dropout息息相关,评估模式下模型的某些层执行的操作与训练状态下是不同的。

train:设置模型为训练模式,如BN层需要统计running_var这些统计数据,Dropout层需要执行随机失活等。

使用方法过于简单,无需代码展示。

设置模型存放在cpu/gpu

对于gpu的使用会在后面设置单独小节详细介绍,由于这里是基础学习,暂时可不考虑运算速度问题。这里既然遇到了相关的概念,就简单说一下。

pytorch可以利用gpu进行加速运算,早期只支持NVIDIA公司的GPU,现在也逐步开始支持AMD的GPU。使用gpu进行运算的方法很简单,就是把需要运算的数据放到gpu即可。方法就是 xxx.cuda(),若想回到cpu运算,那就需要xxx.cpu()即可。但有一个更好的方法是to(),to方法可将对象放到指定的设备中去,如to.("cpu") 、 to.("cuda)、to("cuda:0")等。

cpu:将Module放到cpu上。

cuda:将Module放到cuda上。为什么是cuda不是gpu呢?因为CUDA(Compute Unified Device Architecture)是NVIDIA推出的运算平台,数据是放到那上面进行运算,而gpu可以有很多个品牌,因此用cuda更合理一些。

to:将Module放到指定的设备上。

关于to通常会配备torch.cuda.is_available()使用,请看配套代码学习。

获取模型参数、加载权重参数

模型训练完毕后,我们需要保存的核心内容是模型参数,这样可以供下次使用,或者是给别人进行finetune。相信大家都用ImageNet上的预训练模型,而使用方法就是官方训练完毕后保存模型的参数,供我们下载,然后加载到自己的模型中。在这里就涉及两个重要操作:保存模型参数与加载模型参数,分别要用到以下两个函数。

state_dict:返回参数字典。key是告诉你这个权重参数是放到哪个网络层。

load_state_dict:将参数字典中的参数复制到当前模型中。这里的复制要求key要一一对应,若key对不上,自然模型不知道要把这个参数放到哪里去。绝大多数开发者都会在load_state_dict这里遇到过报错,如

RuntimeError: Error(s) in loading state_dict for ResNet:
   Missing key(s) in state_dict: xxxxxxxx
  Unexpected key(s) in state_dict: xxxxxxxxxx

这通常是拿到的参数字典与模型当前的结构不匹配。

对于load_state_dict函数,还有两个参数可以设置,请看原型:

参数:

  • state_dict (dict) – a dict containing parameters and persistent buffers.
  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

返回项

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

上述两个方法具体的使用请看配套代码

管理模型的modules, parameters, sub_module

模型中需要管理的主要是parameter与module,每个对象都有两种方式读取,分别是带名字和不带名字的。针对module还有一个称为children的方法,它与modules方法最大的不同在于modules会返回module本身。具体差异通过配套代码一看便明了。

parameters:返回一个迭代器,迭代器可抛出Module的所有parameter对象

named_parameters:作用同上,不仅可得到parameter对象,还会给出它的名称

modules:返回一个迭代器,迭代器可以抛出Module的所有Module对象,注意:模型本身也是module,所以也会获得自己。

named_modules:作用同上,不仅可得到Module对象,还会给出它的名称

children:作用同modules,但不会返回Module自己。

named_children:作用同named_modules,但不会返回Module自己。

获取某个参数或submodule

当想查看某个部分数据时,可以通过get_xxx方法获取模型特定位置的数据,可获取parameter、submodule,使用方法也很简单,只需要传入对应的name即可。

get_parameter

get_submodule

设置模型的参数精度,可选半精度、单精度、双精度等

为了调整模型占存储空间的大小,可以设置参数的数据类型,默认情况是float32位(单精度),在一些场景可采用半精度、双精度等,以此改变模型的大小或精度。Module提供了几个转换权重参数精度的方法,分别如下:

  • half:半精度
  • float:单精度
  • double:双精度
  • bfloat16:Brain Floating Point 是Google开发的一种数据格式,详细参见wikipedia

对子模块执行特定功能

zero_grad:将所有参数的梯度设置为0,或者None

apply:对所有子Module执行指定fn(函数),常见于参数初始化。这个可以参见配套代码。

小结

本节对Module的常用API函数进行了介绍,包括模型两种状态,模型存储于何种设备,模型获取参数,加载参数,管理模型的modules,设置模型参数的精度,对模型子模块执行特定功能。

由于Module是核心模块,其涉及的API非常多,短时间不好消化,建议大家结合代码用例,把这些方法都过一遍,留个印象,待日后项目开发需要的时候知道有这些函数可以使用即可。

下一小节将介绍Module中的Hook函数。

Copyright © TingsongYu 2021 all right reserved,powered by Gitbook文件修订时间: 2024年04月26日21:48:10

results matching ""

    No results matching ""