文章目录(Table of Contents)
简介
这一篇介绍Pytorch中模型的加载和保存. 关于模型的保存和加载, 主要分为下面的几种方法:
- 整个模型的保存和加载;
- 模型的参数的保存和加载;
- 模型的上下文保存, 除了保存模型的权重外, 还会保持此时的学习率, 优化器的参数, 方便恢复.
同时会在最后介绍CPU和GPU情况下的一些特殊情况.
Pytorch中变量的保存
在讲模型的保存之前, 我们首先来简单说明一下Pytorch中tensor的保存.
基础的tensor保存
首先是最基础的对tensor的保存, 可以直接使用torch.save进行保存, torch.load来进行读取. 下面看一个例子.
- x = torch.arange(4)
- # 进行保存
- torch.save(x, 'x-file')
- # 进行读取
- x2 = torch.load("x-file")
- # 比较是否一样
- x == x2
- """
- tensor([True, True, True, True])
- """
使用List的方式保存
除此之后, 我们还可以对tensor按照List的方式进行保存和读取.
- x = torch.arange(4)
- y = torch.zeros(4)
- torch.save([x, y],'x-files')
- x2, y2 = torch.load('x-files')
- (x2, y2)
- """
- (tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
- """
按照字典类型存储
最后, 我们可以按照dict的方式进行存储, 这个就很类似下面模型参数的存储了.
- x = torch.arange(4)
- y = torch.zeros(4)
- mydict = {'x': x, 'y': y}
- torch.save(mydict, 'mydict')
- mydict2 = torch.load('mydict')
- mydict2
- """
- {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
- """
模型的保存与加载
整个模型的保存与加载
我们可以将整个模型进保存, 使用torch.save即可.
- # model是模型, file是要保存的文件名
- torch.save(model, file)
接着我们可以进行模型的加载(load), 导入之前我们需要保证网络定义的类是存在的.
- # 注意load之前, 我们需要先定义模型
- loaded_model = torch.load(file)
我们可以比较一下导入前后两个模型的参数是否有改变.
- # 保存前
- print("保存前:")
- for param in model.parameters():
- print(param)
- print("=====================================")
- # 加载后
- print("保存后:")
- for param in loaded_model.parameters():
- print(param)
模型参数的保存和加载
有的时候, 保存整个模型会显得比较麻烦, 存储的文件会比较大. 所以在实际使用的时候, 我们会通常在训练的过程中, 只保存模型的参数.
我们可以使用model.state_dict()
将模型参数转为字典对象, 于是模型的参数保持可以使用下面的方式来进行保存.
- torch.save(model.state_dict(), file)
在加载参数的时候, 我们需要提前初始化模型, 接着传入保存的参数.
- dicts = torch.load('params.pkl')
- model_object.load_state_dict(dicts)
模型的上下文保存
上面我们介绍了模型的参数的保存. 但是有的时候, 我们还需要保存此时的学习率, 保存优化器的系数等. 这样一旦模型的训练终止, 我们可以从终止的地方继续开始训练我们的模型.
例如, 现在不仅有模型, 还有优化器. 那么我们保存的时候就会将模型参数, 优化器参数, 迭代次数都封装进入一个字典类型的数据.
- checkpoint = {
- "epoch": 90,
- "model_state": model.state_dict(),
- "optim_state": optimizer.state_dict()
- }
接着还是使用torch.save来保存上面字典类型的数据.
- FILE = "checkpoint.pth"
- torch.save(checkpoint, FILE)
关于这些系数的加载, 我们也是使用torch.load来进行参数的加载. 下面看一个简单的例子.
- checkpoint = torch.load(FILE)
- # 加载的文件是一个字典,根据key值,将其加载到模型、优化器、迭次次数中
- model.load_state_dict(checkpoint['model_state'])
- optimizer.load_state_dict(checkpoint['optim_state'])
- epoch = checkpoint['epoch']
GPU与CPU
由于GPU和CPU的训练模型方式不同, 因此保存下来的模型也存在不同. 为此, 面对不同环境下训练出来的模型, 我们的加载方式也存在细微的差别.
保存模型在GPU, 运行在CPU
此时在load_state_dict的时候, 需要指定map_location.
- device = torch.device('cpu')
- model.load_state_dict(torch.load(PATH, map_location=device))
保存模型在CPU, 运行在GPU
同样, 如果运行在GPU上面, 我们也是要通过map_location来进行指定的. 例如下面这个简单的例子.
- model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
- model.to(device)
- 微信公众号
- 关注微信公众号
- QQ群
- 我们的QQ群号
评论