理解PyTorch中维度的概念

  • A+
所属分类:深度学习
摘要一篇非常好的, 对于PyTorch中维度来进行介绍. 使用动图直观说明各个维度(dim)下求和和topk操作.

简介

今天在使用torch中的topk的时候, 对于dim产生了一些疑问. 后面也是找到了规律, 但是还是很困惑他为什么是这么设计的, 即dim与tensor本身的行列是不一致的. 然后就查了一下, 真的找到了一篇很好的文章, 解决了我的困惑, 就想在这里记录一下.

我这一篇文章里的所有的动图, 都是来自与下面这篇文章, 写的非常直观.

原文链接(十分棒的文章), Understanding dimensions in PyTorch

关于这里设计的代码, 有一个完整的notebook的文档, 具体链接见GithubPytorch维度介绍.ipynb

 

理解PyTorch维度概念

首先我们从最基础的开始, 当我们在Pytorch中定义一个二维的tensor的时候, 他包含行和列. 例如下面我们创建一个2*3的tensor.

  1. x = torch.tensor([
  2.         [1,2,3],
  3.         [4,5,6]
  4.     ])
  5. # 我们可以看到"行"是dim=0, "列"是dim=1
  6. print(x.shape)
  7. >> torch.Size([2, 3])

我们可以看到打印的结果显示:

  • first dimension (dim=0) stays for rows, 第一个维度代表行, 因为是2, 实际x就是2行
  • the second one (dim=1) for columns, 第二个维度代表列, 因为是3

于是, 我们会认为, torch.sum(x, dim=0)就是(1+2+3, 4+5+6)=tensor([6, 15]), 但是实际情况却不是这个样子的.

  1. torch.sum(x, dim=0)
  2. >> tensor([5, 7, 9])

我们可以看到按照dim=0求和, 其实是在按列相加, 也就是(1+4, 2+5, 3+6) =tensor([5, 7, 9]), 和我们想象的完全不一样. 我们再看一下按照dim=1进行求和.

  1. torch.sum(x, dim=1)
  2. >> tensor([ 6, 15])

可以看到, 在按照dim=1的时候求和的时候, 其实在按照按行进行求和,  (1+2+3, 4+5+6)=tensor([6, 15]), 这就让人很困惑, 明明上面说的是dim=0代表是行.

于是, 原文作者在一篇介绍numpy维度的文章中, 找到了问题的关键所在. 也就是下面的这段话(numpy中的axis也就是这里的dim).

The way to understand the "axis" of numpy sum is that it collapses the specified axis. So when it collapses the axis 0 (the row), it becomes just one row (it sums column-wise).

上面的话简单翻译就是, 当按照axis=0进行求和的时候, 其实可以想象为对axis=0这个维度进行挤压, 最后只剩下一行, 那一行就是结果, 也就是按列在相加.

是不是还是会有一些困惑, 我们还是对于上面的例子(tensor([[1,2,3], [4,5,6]])), 看一下在dim=0的时候, 为什么是列相加, 以及上面的collapse the specific axis(dim)的含义.

理解PyTorch中维度的概念

如上面的动图所示, 当dim=0的时候, 按每一行的元素进行相加, 最后的结果就是和按列求和.

 

对于三维向量

下面我们更进一步, 来看一下对于三维的tensor, 在各个维度进行sum操作的结果. 首先我们看一下每一个dim代表的含义.

  1. # 看一下三维的
  2. x = torch.tensor([
  3.         [
  4.          [1,2,3],
  5.          [4,5,6]
  6.         ],
  7.         [
  8.          [1,2,3],
  9.          [4,5,6]
  10.         ],
  11.         [
  12.          [1,2,3],
  13.          [4,5,6]
  14.         ]
  15.     ])
  16. # 我们可以看到第三维是dim=0, "行"是dim=1, 列是dim=2
  17. print(x.shape)
  18. >> torch.Size([3, 2, 3])

可以看到此时dim=0是第三个维度, dim=1是行, dim=2是列.

  1. torch.sum(x, dim=0)
  2. >>
  3. tensor([[ 3,  6,  9],
  4.         [12, 15, 18]])

我们可以将其看成是各个二维平面对应元素求和, 还是有点绕, 还是直接看下面的动图.

理解PyTorch中维度的概念

接着是对dim=1进行求和.

  1. torch.sum(x, dim=1)
  2. >>
  3. tensor([[5, 7, 9],
  4.         [5, 7, 9],
  5.         [5, 7, 9]])

还是直接看下面的动图, 来进行理解.

理解PyTorch中维度的概念

最后按照dim=2来进行求和.

  1. torch.sum(x, dim=2)
  2. >>
  3. tensor([[ 6, 15],
  4.         [ 6, 15],
  5.         [ 6, 15]])

还是使用动图来进行解释.

理解PyTorch中维度的概念

 

关于TopK的问题

最后说一下topk与dim之间的问题, 这个也是我产生这次疑惑的根源. 关于topk, 我们会在计算准确率的时候用到这个.

比如说一个4分类的问题, 有如下的三个输出. 我们希望找到第一个是预测的结果是第几个分类, 这个时候就需要使用到topk, 返回最大的前K个值和他的位置, 这个位置就可以用来知道label.

  1. # 看一下二维的
  2. x = torch.tensor([
  3.     [0.1, 0.2 ,0.5, 0.2],
  4.     [0.4, 0.3, 0.2, 0.1],
  5.     [0.1, 0.2, 0.5, 0.1],
  6. ])
  7. # -------
  8. print(x.shape)
  9. >> torch.Size([3, 4])

比如上面的数据x, 我们就是要在一行进行比较, 最后结果是一个3*1的tensor, 所以根据上面讲的, 就是对列进行collapse, 也就是dim=1, 我们测试一下看一下结果.

  1. a, b = x.topk(1, dim=1)
  2. print(a)
  3. print('-'*10)
  4. print(b)
  5. print(b.shape)
  6. >>
  7. tensor([[0.5000],
  8.         [0.4000],
  9.         [0.5000]])
  10. ----------
  11. tensor([[2],
  12.         [0],
  13.         [2]])
  14. torch.Size([3, 1])

可以看到和我们预料的是一样的. 如果dim=0, 那么返回的就是一个1*4的向量, 应该就会返回tensor([0.4,0.3,0.5,0.2]), 我们实际做一下看一下是不是这个.

  1. a, b = x.topk(1, dim=0)
  2. print(a)
  3. print('-'*10)
  4. print(b)
  5. print(b.shape)
  6. >>
  7. tensor([[0.4000, 0.3000, 0.5000, 0.2000]])
  8. ----------
  9. tensor([[1, 1, 0, 0]])
  10. torch.Size([1, 4])

可以看到输出的结果是和我们上面判断的是一样的. 到这里就把PyTorch中dim的用法说明白了. 我觉得那一篇文章真的是很不错的.

关于完整的代码, 可以查看链接Pytorch维度介绍.ipynb

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: