Pytorch入门教程04-非叶子节点梯度保存

王 茂南 2020年10月3日08:38:25
评论
1 2505字阅读8分21秒
摘要这里主要介绍在Pytorch中Hook的使用, 我们可以用其来保存非叶子节点的梯度.

文章目录(Table of Contents)

简介

在上一篇中, 我们介绍了关于正向传播与反向传播的例子. 但是其中我们只能求叶子节点的梯度. 也就是在下图中, Pytorch只会保存d和e的梯度, 张量b和c的梯度在计算完毕之后会释放.

Pytorch入门教程04-非叶子节点梯度保存

但是在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是:

  • retain_grad()
  • hook

 

参考资料

 

retain_grad()

retain_grad()显式地保存非叶节点的梯度, 当然代价就是会增加显存的消耗(对比hook函数的方法则是在反向计算时直接打印, 因此不会增加显存消耗.)

但是使用起来retain_grad()要比hook函数方便一些, 我们还是使用之前的例子, 有如下的计算图.

Pytorch入门教程04-非叶子节点梯度保存

正常情况下, 我们在反向传播之后, 只有w1和w2的导数, 但是我们可以在z1等变量后面加上retain_grad, 使其梯度保持, 如下所示:

  1. def forwrad(x, y, w1, w2):
  2.     # 其中 x,y 为输入数据,w为该函数所需要的参数
  3.     z_1 = torch.mm(w1, x)
  4.     z_1.retain_grad()
  5.     y_1 = torch.sigmoid(z_1)
  6.     y_1.retain_grad()
  7.     z_2 = torch.mm(w2, y_1)
  8.     z_2.retain_grad()
  9.     y_2 = torch.sigmoid(z_2)
  10.     y_2.retain_grad()
  11.     loss = 1/2*(((y_2 - y)**2).sum())
  12.     loss.retain_grad()
  13.     return loss, z_1, y_1, z_2, y_2

接着我们进行正向传播和反向传播.

  1. # 测试代码
  2. x = torch.tensor([[1.0]])
  3. y = torch.tensor([[1.0], [0.0]])
  4. w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
  5. w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
  6. # w2 = torch.tensor([[3.0, 1.0], [1.0, 6.0]], requires_grad=True)
  7. # 正向
  8. loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
  9. # 反向
  10. loss.backward()  # 反向传播,计算梯度

接下来就可以查看中间变量的梯度了. 我们在下面举几个例子, 完整的可以在Github仓库中进行查看, 正向传播,反向传播与非叶子节点梯度保存. (下面图片可以在新窗口打开查看大图)

Pytorch入门教程04-非叶子节点梯度保存

 

hook的使用

使用retain_grad会消耗额外的显存, 我们可以使用hook在反向计算的时候进行保存. 还是上面的例子, 我们使用hook来完成.

  1. # 我们可以定义一个hook来保存中间的变量
  2. grads = {} # 存储节点名称与节点的grad
  3. def save_grad(name):
  4.     def hook(grad):
  5.         grads[name] = grad
  6.     return hook

对于forward函数, 我们不需要进行修改, 保持最基本的样子即可.

  1. def forwrad(x, y, w1, w2):
  2.     # 其中 x,y 为输入数据,w为该函数所需要的参数
  3.     z_1 = torch.mm(w1, x)
  4.     y_1 = torch.sigmoid(z_1)
  5.     z_2 = torch.mm(w2, y_1)
  6.     y_2 = torch.sigmoid(z_2)
  7.     loss = 1/2*(((y_2 - y)**2).sum())
  8.     return loss, z_1, y_1, z_2, y_2

接着进行正向传播和反向传播, register_hook是发生在backward之前的动作.

  1. # 测试代码
  2. x = torch.tensor([[1.0]])
  3. y = torch.tensor([[1.0], [0.0]])
  4. w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
  5. w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
  6. # 正向传播
  7. loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
  8. # hook中间节点
  9. z_1.register_hook(save_grad('z_1'))
  10. y_1.register_hook(save_grad('y_1'))
  11. z_2.register_hook(save_grad('z_2'))
  12. y_2.register_hook(save_grad('y_2'))
  13. loss.register_hook(save_grad('loss'))
  14. # 反向传播
  15. loss.backward()

最后我们可以打印出中间变量的梯度.

  1. print(grads['z_1'])
  2. print(grads['y_1'])
  3. print(grads['z_2'])
  4. print(grads['y_2'])
  5. print(grads['loss'])
  6. """
  7. tensor([[1.2243e-04],
  8.         [7.8005e-05]])
  9. tensor([[0.0006],
  10.         [0.0007]])
  11. tensor([[-1.0728e-05],
  12.         [ 1.3098e-04]])
  13. tensor([[-0.0033],
  14.         [ 0.9999]])
  15. tensor(1.)
  16. """

 

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南
  • 本文由 发表于 2020年10月3日08:38:25
  • 转载请务必保留本文链接:https://mathpretty.com/12509.html
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: