7.1 模型保存与加载

保存与加载的概念(序列化与反序列化)

模型训练完毕之后,肯定想要把它保存下来,供以后使用,不需要再次去训练。

那么在pytorch中如何把训练好的模型,保存,保存之后又如何加载呢?

这就用需要序列化与反序列化,序列化与反序列化的概念如下图所示:

因为在内存中的数据,运行结束会进行释放,所以我们需要将数据保存到硬盘中,以二进制序列的形式进行长久存储,便于日后使用。

序列化即把对象转换为字节序列的过程,反序列化则把字节序列恢复为对象。

在pytorch中,对象就是模型,所以我们常常听到序列化和反序列化,就是将训练好的模型从内存中保存到硬盘里,当要使用的时候,再从硬盘中加载。

torch.save / torch.load

pytorch提供的序列化与反序列化函数分别是

1.
torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

功能:保存对象到硬盘中

主要参数:

  • obj- 对象
  • f - 文件路径
2.
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

功能:加载硬盘中对象

主要参数:

  • f - 文件路径
  • map_location - 指定存储位置,如map_location='cpu', map_location={'cuda:1':'cuda:0'}

这里的map_location大有文章,经常需要手动设置,否者会报错。具体可参考以下形式:

GPU->CPU:torch.load(model_path, map_location='cpu')

CPU->GPU:torch.load(model_path, map_location=lambda storage, loc: storage)

两种保存方式

pytorch保存模型有两种方式

  1. 保存整个模型

  2. 保存模型参数

我们通过示意图来区分两者之间的差异

从上图左边知道法1保存整个nn.Module, 而法2只保存模型的参数信息。

我们知道一个module当中包含了很多信息,不仅仅是模型的参数 parameters,还包含了buffers, hooks和modules等一系列信息。

对于模型应用,最重要的是模型的parameters,其余的信息是可以通过model 类再去构建的,所以模型保存就有两种方式

  1. 所有内容都保存;

  2. 仅保存模型的parameters。

通常,我们只需要保存模型的参数,在使用的时候再通过load_state_dict方法加载参数。

由于第一种方法不常用,并且在加载过程中还需要指定的类方法,因此不做演示也不推荐。

对于第二种方法的代码十分简单,请看示例:

net_state_dict = net.state_dict()
torch.save(net_state_dict, "my_model.pth")

常用的代码段

在模型开发过程中,往往不是一次就能训练好模型,经常需要反复训练,因此需要保存训练的“状态信息”,以便于基于某个状态继续训练,这就是常说的resume,可以理解为断点续训练。

在整个训练阶段,除了模型参数需要保存,还有优化器的参数、学习率调整器的参数和迭代次数等信息也需要保存,因此推荐在训练时,采用以下代码段进行模型保存。以下代码来自torchvision的训练脚本

checkpoint = {
    "model": model_without_ddp.state_dict(),
    "optimizer": optimizer.state_dict(),
    "lr_scheduler": lr_scheduler.state_dict(),
    "epoch": epoch,
}
path_save = "model_{}.pth".format(epoch)
torch.save(checkpoint, path_save

# =================== resume ===============
# resume
checkpoint = torch.load(path_save, map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
start_epoch = checkpoint["epoch"] + 1

小结

模型保存与加载比较简单,需要注意的有两点:

  1. torch.load的时候注意map_location的设置;
  2. 理解checkpoint resume的概念,以及训练过程是需要模型、优化器、学习率调整器和已迭代次数的共同配合。
Copyright © TingsongYu 2021 all right reserved,powered by Gitbook文件修订时间: 2024年04月26日21:48:10

results matching ""

    No results matching ""