Pytorch入门教程15-Pytorch中模型的保存和加载

王 茂南 2020年10月14日07:38:46
评论
2447字阅读8分9秒
摘要当我们完成了模型的训练之后, 我们会希望将其保存下来, 之后可以进行使用. 或是在训练过程中, 我们需要定时对模型进行保存. 所以这一篇, 我们会介绍Pytorch中模型的加载和保存.

简介

这一篇介绍Pytorch中模型的加载和保存. 关于模型的保存和加载, 主要分为下面的几种方法:

  • 整个模型的保存和加载;
  • 模型的参数的保存和加载;
  • 模型的上下文保存, 除了保存模型的权重外, 还会保持此时的学习率, 优化器的参数, 方便恢复.

同时会在最后介绍CPU和GPU情况下的一些特殊情况.

 

Pytorch中变量的保存

在讲模型的保存之前, 我们首先来简单说明一下Pytorch中tensor的保存.

基础的tensor保存

首先是最基础的对tensor的保存, 可以直接使用torch.save进行保存, torch.load来进行读取. 下面看一个例子.

  1. x = torch.arange(4)
  2. # 进行保存
  3. torch.save(x, 'x-file')
  4. # 进行读取
  5. x2 = torch.load("x-file")
  6. # 比较是否一样
  7. x == x2
  8. """
  9. tensor([True, True, True, True])
  10. """

 

使用List的方式保存

除此之后, 我们还可以对tensor按照List的方式进行保存和读取.

  1. x = torch.arange(4)
  2. y = torch.zeros(4)
  3. torch.save([x, y],'x-files')
  4. x2, y2 = torch.load('x-files')
  5. (x2, y2)
  6. """
  7. (tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
  8. """

 

按照字典类型存储

最后, 我们可以按照dict的方式进行存储, 这个就很类似下面模型参数的存储了.

  1. x = torch.arange(4)
  2. y = torch.zeros(4)
  3. mydict = {'x': x, 'y': y}
  4. torch.save(mydict, 'mydict')
  5. mydict2 = torch.load('mydict')
  6. mydict2
  7. """
  8. {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
  9. """

 

 

模型的保存与加载

整个模型的保存与加载

我们可以将整个模型进保存, 使用torch.save即可.

  1. # model是模型, file是要保存的文件名
  2. torch.save(model, file)

接着我们可以进行模型的加载(load), 导入之前我们需要保证网络定义的类是存在的.

  1. # 注意load之前, 我们需要先定义模型
  2. loaded_model = torch.load(file)

我们可以比较一下导入前后两个模型的参数是否有改变.

  1. # 保存前
  2. print("保存前:")
  3. for param in model.parameters():
  4.     print(param)
  5. print("=====================================")
  6. # 加载后
  7. print("保存后:")
  8. for param in loaded_model.parameters():
  9.     print(param)

 

模型参数的保存和加载

有的时候, 保存整个模型会显得比较麻烦, 存储的文件会比较大. 所以在实际使用的时候, 我们会通常在训练的过程中, 只保存模型的参数.

我们可以使用model.state_dict()将模型参数转为字典对象, 于是模型的参数保持可以使用下面的方式来进行保存.

  1. torch.save(model.state_dict(), file)

在加载参数的时候, 我们需要提前初始化模型, 接着传入保存的参数.

  1. dicts = torch.load('params.pkl')
  2. model_object.load_state_dict(dicts)

 

模型的上下文保存

上面我们介绍了模型的参数的保存. 但是有的时候, 我们还需要保存此时的学习率, 保存优化器的系数等. 这样一旦模型的训练终止, 我们可以从终止的地方继续开始训练我们的模型.

例如, 现在不仅有模型, 还有优化器. 那么我们保存的时候就会将模型参数, 优化器参数, 迭代次数都封装进入一个字典类型的数据.

  1. checkpoint = {
  2.     "epoch": 90,
  3.     "model_state": model.state_dict(),
  4.     "optim_state": optimizer.state_dict()
  5. }

接着还是使用torch.save来保存上面字典类型的数据.

  1. FILE = "checkpoint.pth"
  2. torch.save(checkpoint, FILE)

关于这些系数的加载, 我们也是使用torch.load来进行参数的加载. 下面看一个简单的例子.

  1. checkpoint = torch.load(FILE)
  2. # 加载的文件是一个字典,根据key值,将其加载到模型、优化器、迭次次数中
  3. model.load_state_dict(checkpoint['model_state'])
  4. optimizer.load_state_dict(checkpoint['optim_state'])
  5. epoch = checkpoint['epoch']

 

GPU与CPU

由于GPU和CPU的训练模型方式不同, 因此保存下来的模型也存在不同. 为此, 面对不同环境下训练出来的模型, 我们的加载方式也存在细微的差别.

保存模型在GPU, 运行在CPU

此时在load_state_dict的时候, 需要指定map_location.

  1. device = torch.device('cpu')
  2. model.load_state_dict(torch.load(PATH, map_location=device))

 

保存模型在CPU, 运行在GPU

同样, 如果运行在GPU上面, 我们也是要通过map_location来进行指定的. 例如下面这个简单的例子.

  1. model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
  2. model.to(device)

 

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南
  • 本文由 发表于 2020年10月14日07:38:46
  • 转载请务必保留本文链接:https://mathpretty.com/12577.html
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: