6.5 模型参数可视化

随着神经网络越来越深,越来越复杂,手动计算模型中间的数据的shape变得困难。

本节将介绍torchinfo,可用一键实现模型参数量计算、各层特征图形状计算和计算量计算等功能。

torchinfo的功能最早来自于TensorFlow和Kearas的summary()函数,torchinfo是学习借鉴而来。而在torchinfo之前还有torchsummary工具,不过torchsummary已经停止更新,并且推荐使用torchinfo。

torchsummay:https://github.com/sksq96/pytorch-summary

torchinfo:https://github.com/TylerYep/torchinfo

torchinfo 主要提供了一个函数,即

def summary(
    model: nn.Module,
    input_size: Optional[INPUT_SIZE_TYPE] = None,
    input_data: Optional[INPUT_DATA_TYPE] = None,
    batch_dim: Optional[int] = None,
    cache_forward_pass: Optional[bool] = None,
    col_names: Optional[Iterable[str]] = None,
    col_width: int = 25,
    depth: int = 3,
    device: Optional[torch.device] = None,
    dtypes: Optional[List[torch.dtype]] = None,
    mode: str | None = None,
    row_settings: Optional[Iterable[str]] = None,
    verbose: int = 1,
    **kwargs: Any,
) -> ModelStatistics:

torchinfo 演示

运行代码

    resnet_50 = models.resnet50(pretrained=False)
    batch_size = 1
    summary(resnet_50, input_size=(batch_size, 3, 224, 224))

可看到resnet50的以下信息

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-8             [1, 256, 56, 56]          512
│    │    └─Sequential: 3-9              [1, 256, 56, 56]          16,896
│    │    └─ReLU: 3-10                   [1, 256, 56, 56]          --

......

│    └─Bottleneck: 2-16                  [1, 2048, 7, 7]           --
│    │    └─Conv2d: 3-140                [1, 512, 7, 7]            1,048,576
│    │    └─BatchNorm2d: 3-141           [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-142                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-143                [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-144           [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-145                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-146                [1, 2048, 7, 7]           1,048,576
│    │    └─BatchNorm2d: 3-147           [1, 2048, 7, 7]           4,096
│    │    └─ReLU: 3-148                  [1, 2048, 7, 7]           --
├─AdaptiveAvgPool2d: 1-9                 [1, 2048, 1, 1]           --
├─Linear: 1-10                           [1, 1000]                 2,049,000
==========================================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
Total mult-adds (G): 4.09
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 177.83
Params size (MB): 102.23
Estimated Total Size (MB): 280.66
==========================================================================================

其中包括各网络层名称,以及层级关系,各网络层输出形状以及参数量。在最后还有模型的总结,包括总的参数量有25,557,032个,总的乘加(Mult-Adds)操作有4.09G(4.09*10^9次方 浮点运算),输入大小为0.60MB,参数占102.23MB。

计算量:1G表示10^9 次浮点运算 (Giga Floating-point Operations Per Second),关于乘加运算,可参考知乎问题

存储量:这里的Input size (MB): 0.60,是通过数据精度计算得到,默认情况下采用float32位存储一个数,因此输入为:3*224*224*32b = 4816896b = 602112B = 602.112 KB = 0.6 MB

同理,Params size (MB): 25557032 * 32b = 817,825,024 b = 102,228,128 B = 102.23 MB

接口详解

summary提供了很多参数可以配置打印信息,这里介绍几个常用参数。

col_names:可选择打印的信息内容,如 ("input_size","output_size","num_params","kernel_size","mult_adds","trainable",)

dtypes:可以设置数据类型,默认的为float32,单精度。

mode:可设置模型在训练还是测试状态。

verbose: 可设置打印信息的详细程度。0是不打印,1是默认,2是将weight和bias也打出来。

小结

本节介绍torchinfo的使用,并分析其参数的计算过程,这里需要了解训练参数数量、特征图参数数量和计算量。其中计算量还有一个好用的工具库进行计算,这里作为额外资料供大家学习——PyTorch-OpCounter

Copyright © TingsongYu 2021 all right reserved,powered by Gitbook文件修订时间: 2024年04月26日21:48:10

results matching ""

    No results matching ""