Pytorch入门教程06-损失函数和优化器

  • A+
所属分类:Pytorch快速入门
摘要在之前的内容中, 我们通过自己定义的损失函数和系数更新的方法来更新系数. 但其实在Pytorch中, 已经包含了相应的函数, 可以直接来计算损失, 和完成梯度下降. 所以在这一部分, 我们会介绍优化器和损失函数.

简介

在上一篇文章中, 我们介绍了梯度下降算法, 在求出系数的偏导之后, 如何来进行系数的优化. 但是在上一次文章中, 我们自己定义了损失函数和系数更新的方法.

但其实在Pytorch中, 包含了相应的函数, 可以直接来计算损失, 和完成梯度下降. 所以在这一篇中, 我们还是使用求解一元线性回归的例子, 来看一下如何使用Pytorch默认提供的函数求解.

Github仓库链接: 损失函数和优化器的介绍

 

损失函数与优化器

定义模型

我们首先定义我们需要用到的线性模型, 这里只需要一个参数w.

  1. w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
  2. def forward(x):
  3.     return w * x

 

损失函数

在线性回归的例子中, 我们应该使用均方差损失函数, 使用torch.nn.MSELoss()表示. 于是我们可以将代码写成下面的形式.

  1. loss = torch.nn.MSELoss() # 定义均方损失函数

 

优化器

之前我们使用了最基础的梯度下降来更新参数w. 但其实, 在torch.optim中存在着各种梯度下降的改进算法, 比如 SGD, Momentum, RMSProp和Adam等. (这些算法都是以传统梯度下降算法为基础改进得到的, 这些算法可以更快更准确地求解最佳模型参数.)

我们在这里定义一个SGD优化器,

  1. optimizer = torch.optim.SGD([w], lr=learning_rate)

其中:

  • w表示我们要更新的参数(网络的权重)
  • lr表示学习率

在Pytorch中, 还optimizer还提供可以一次更新全部的参数, 和参数梯度清零两个功能.

  • optimizer.step(): 对神经网络(复合函数)的相关变量进行更新, 即所有参数值向梯度相反方向走一步;
  • optimizer.zero_grad(): 对神经网络(复合函数)的相关系数进行梯度的清空;

我们还可以在优化器中加入weight decay的内容. 在Pytorch中, 默认的是L2的正则项. 我们可以使用下面的方法来实现. 下面是只对weight进行正则化, 不对bias进行正则化.

  1. trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)

 

更新权重

有了上面定义的损失函数和优化器之后, 我们就可以开始更新权重w了.

  1. X_tensor = torch.from_numpy(X)
  2. Y_tensor = torch.from_numpy(Y)
  3. n_iters = 100
  4. for epoch in range(n_iters):
  5.     y_pred = forward(X_tensor)
  6.     l = loss(Y_tensor, y_pred) # 求误差(注意这里的顺序)
  7.     l.backward() # 求梯度
  8.     optimizer.step()  # 更新权重,即向梯度方向走一步
  9.     optimizer.zero_grad() # 清空梯度
  10.     if epoch % 20 == 0:
  11.         print(f'epoch {epoch+1}: w = {w.item():.3f}, loss = {l.item():.3f}')
  12. print(f'根据训练模型预测, 当x=5时, y的值为: {forward(5):.3f}')
  13. """
  14. epoch 1: w = 0.142, loss = 153.912
  15. epoch 21: w = 1.642, loss = 10.674
  16. epoch 41: w = 2.028, loss = 1.210
  17. epoch 61: w = 2.127, loss = 0.585
  18. epoch 81: w = 2.152, loss = 0.544
  19. 根据训练模型预测, 当x=5时, y的值为: 10.794
  20. """

可以看到在经过100次迭代之后, 最终的系数w接近于2.

 

 

 

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: