理解 PyTorch 中维度的概念

王 茂南 2020年4月11日07:51:51
评论
36 2933字阅读9分46秒
摘要一篇非常好的, 对于PyTorch中维度来进行介绍. 使用动图直观说明各个维度(dim)下求和和topk操作.

简介

今天在使用 Torch 中的 topk 的时候,对于 dim 产生了一些疑问。后面也是找到了规律,但是还是很困惑他为什么是这么设计的,即 dimtensor 本身的行列是不一致的。然后就查了一下,真的找到了一篇很好的文章,解决了我的困惑,就想在这里记录一下。

我这一篇文章里的所有的动图,都是来自与下面这篇文章,写的非常直观。

原文链接(十分棒的文章), Understanding dimensions in PyTorch

关于这里设计的代码, 有一个完整的notebook的文档, 具体链接见GithubPytorch维度介绍.ipynb

 

理解 PyTorch 维度概念

首先我们从最基础的开始, 当我们在 Pytorch 中定义一个二维的 tensor 的时候, 他包含行和列. 例如下面我们创建一个 2✖3tensor

  1. x = torch.tensor([
  2.         [1,2,3],
  3.         [4,5,6]
  4.     ])
  5. # 我们可以看到"行"是dim=0, "列"是dim=1
  6. print(x.shape)
  7. >> torch.Size([2, 3])

我们可以看到打印的结果显示:

  • first dimension (dim=0) stays for rows, 第一个维度代表行, 因为是2, 实际x就是2行
  • the second one (dim=1) for columns, 第二个维度代表列, 因为是3

于是, 我们会认为, torch.sum(x, dim=0)就是(1+2+3, 4+5+6)=tensor([6, 15]), 但是实际情况却不是这个样子的.

  1. torch.sum(x, dim=0)
  2. >> tensor([5, 7, 9])

我们可以看到按照 dim=0 求和, 其实是在按列相加, 也就是 (1+4, 2+5, 3+6) =tensor([5, 7, 9]), 和我们想象的完全不一样. 我们再看一下按照 dim=1 进行求和.

  1. torch.sum(x, dim=1)
  2. >> tensor([ 6, 15])

可以看到, 在按照 dim=1 的时候求和的时候, 其实在按照按行进行求和,  (1+2+3, 4+5+6)=tensor([6, 15]), 这就让人很困惑, 明明上面说的是 dim=0 代表是行。

于是, 原文作者在一篇介绍 numpy 维度的文章中, 找到了问题的关键所在. 也就是下面的这段话( numpy 中的 axis 也就是这里的 dim).

The way to understand the "axis" of numpy sum is that it collapses the specified axis. So when it collapses the axis 0 (the row), it becomes just one row (it sums column-wise).

上面的话简单翻译就是, 当按照axis=0进行求和的时候, 其实可以想象为对axis=0这个维度进行挤压, 最后只剩下一行, 那一行就是结果, 也就是按列在相加.

是不是还是会有一些困惑, 我们还是对于上面的例子(tensor([[1,2,3], [4,5,6]])), 看一下在dim=0的时候, 为什么是列相加, 以及上面的collapse the specific axis(dim)的含义.

理解 PyTorch 中维度的概念

如上面的动图所示, 当dim=0的时候, 按每一行的元素进行相加, 最后的结果就是和按列求和.

 

对于三维向量

下面我们更进一步, 来看一下对于三维的tensor, 在各个维度进行sum操作的结果. 首先我们看一下每一个dim代表的含义.

  1. # 看一下三维的
  2. x = torch.tensor([
  3.         [
  4.          [1,2,3],
  5.          [4,5,6]
  6.         ],
  7.         [
  8.          [1,2,3],
  9.          [4,5,6]
  10.         ],
  11.         [
  12.          [1,2,3],
  13.          [4,5,6]
  14.         ]
  15.     ])
  16. # 我们可以看到第三维是dim=0, "行"是dim=1, 列是dim=2
  17. print(x.shape)
  18. >> torch.Size([3, 2, 3])

可以看到此时dim=0是第三个维度, dim=1是行, dim=2是列.

  1. torch.sum(x, dim=0)
  2. >>
  3. tensor([[ 3,  6,  9],
  4.         [12, 15, 18]])

我们可以将其看成是各个二维平面对应元素求和, 还是有点绕, 还是直接看下面的动图.

理解 PyTorch 中维度的概念

接着是对dim=1进行求和.

  1. torch.sum(x, dim=1)
  2. >>
  3. tensor([[5, 7, 9],
  4.         [5, 7, 9],
  5.         [5, 7, 9]])

还是直接看下面的动图, 来进行理解.

理解 PyTorch 中维度的概念

最后按照dim=2来进行求和.

  1. torch.sum(x, dim=2)
  2. >>
  3. tensor([[ 6, 15],
  4.         [ 6, 15],
  5.         [ 6, 15]])

还是使用动图来进行解释.

理解 PyTorch 中维度的概念

 

关于TopK的问题

最后说一下topk与dim之间的问题, 这个也是我产生这次疑惑的根源. 关于topk, 我们会在计算准确率的时候用到这个.

比如说一个4分类的问题, 有如下的三个输出. 我们希望找到第一个是预测的结果是第几个分类, 这个时候就需要使用到topk, 返回最大的前K个值和他的位置, 这个位置就可以用来知道label.

  1. # 看一下二维的
  2. x = torch.tensor([
  3.     [0.1, 0.2 ,0.5, 0.2],
  4.     [0.4, 0.3, 0.2, 0.1],
  5.     [0.1, 0.2, 0.5, 0.1],
  6. ])
  7. # -------
  8. print(x.shape)
  9. >> torch.Size([3, 4])

比如上面的数据x, 我们就是要在一行进行比较, 最后结果是一个3*1的tensor, 所以根据上面讲的, 就是对列进行collapse, 也就是dim=1, 我们测试一下看一下结果.

  1. a, b = x.topk(1, dim=1)
  2. print(a)
  3. print('-'*10)
  4. print(b)
  5. print(b.shape)
  6. >>
  7. tensor([[0.5000],
  8.         [0.4000],
  9.         [0.5000]])
  10. ----------
  11. tensor([[2],
  12.         [0],
  13.         [2]])
  14. torch.Size([3, 1])

可以看到和我们预料的是一样的. 如果dim=0, 那么返回的就是一个1*4的向量, 应该就会返回tensor([0.4,0.3,0.5,0.2]), 我们实际做一下看一下是不是这个.

  1. a, b = x.topk(1, dim=0)
  2. print(a)
  3. print('-'*10)
  4. print(b)
  5. print(b.shape)
  6. >>
  7. tensor([[0.4000, 0.3000, 0.5000, 0.2000]])
  8. ----------
  9. tensor([[1, 1, 0, 0]])
  10. torch.Size([1, 4])

可以看到输出的结果是和我们上面判断的是一样的. 到这里就把PyTorch中dim的用法说明白了. 我觉得那一篇文章真的是很不错的.

关于完整的代码, 可以查看链接Pytorch维度介绍.ipynb

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南
  • 本文由 发表于 2020年4月11日07:51:51
  • 转载请务必保留本文链接:https://mathpretty.com/12065.html
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: