Pytorch入门教程11-Softmax函数和交叉熵

  • A+
所属分类:Pytorch快速入门
摘要这一篇我们介绍一下在多分类问题中经常会使用到的Softmax和交叉熵的概念. 特别的, 在Pytorch中, 因为损失函数CorssEntropyLoss中同时包含了Softmax和交叉熵, 所以我们在构建网络的时候, 最后一层不需要再使用Softmax, 只需要直接使用nn.Linear即可.

简介

这一篇我们会讲解二分类问题和多分类问题. 在二分类问题中, 我们可以使用Sigmoid作为输出, 看成是一类的概率, 另一类的概率就是1-Sogmoid的输出. 但是在多分类问题中, 我们需要每一类的概率, 这个是偶就需要使用Softmax.

 

参考资料

 

关于Softmax和多分类问题

Softmax函数介绍

在多分类问题中, 我们需要对每一个类别输出概率, 同时要保证这些概率和是1. 这个时候就需要使用Sotmax函数了.

假设现在的输出是vi, 那么每一个对应softmax之后的输出如下:

Pytorch入门教程11-Softmax函数和交叉熵

在这里对输入进行指数化, 这样可以使得两个输入之间的差距可以扩大.

下面我们看一个softmax的例子, 并使用Pytorch中自带的函数进行测试.

Pytorch入门教程11-Softmax函数和交叉熵

下面看一下使用Pytorch的测试结果.

  1. x = torch.tensor([2, 1, 0.1])
  2. s = torch.softmax(x, dim=0)
  3. print('Sotfmax的输出:{}'.format(s))
  4. print('Sotfmax的输出总和:{}'.format(s.sum().item()))
  5. """
  6. Sotfmax的输出:tensor([0.6590, 0.2424, 0.0986])
  7. Sotfmax的输出总和:1.0000001192092896
  8. """

对于多组数据来说, 其实softmax做的就是将一个矩阵的值压缩到0到1之间. 例如下面的例子, 测试数据有三类.

Pytorch入门教程11-Softmax函数和交叉熵

The output of the softmax describes the probability (or if you may, the confidence) of the neural network that a particular sample belongs to a certain class.

Thus, for the first example above, the neural network assigns a confidence of 0.71 that it is a cat, 0.26 that it is a dog, and 0.04 that it is a horse. The same goes for each of the samples above.

 

交叉熵损失函数

在介绍交叉熵损失函数之前, 我们首先看一个例子. 现在有一个三分类的问题, 一共有三个测试数据, 两个测试的模型, 他们的结果分别如下所示:

模型一的结果:

Pytorch入门教程11-Softmax函数和交叉熵

模型二的结果:

Pytorch入门教程11-Softmax函数和交叉熵
  • 对于两个模型的准确率来说, 是一样的, 准确率都是2/3;
  • 对于模型一来说, 他在数据1和数据2上, 对结果的确定性都不是很确定, 都是0.4, 只比0.3大一点.
  • 对于模型二来说, 他的预测结果就表现很肯定, 对结果的可能性都是0.7, 会远远比其他类别要大.

所以可以看到, 虽说上面两个模型在准确率上的表现是一样的, 但是实际上, 模型二会比模型一要更好一些. 于是, 我们就需要使用一个loss函数, 能够反映出上面模型的好坏, 这个时候就需要使用交叉熵损失. (关于交叉熵的损失, 可以查看熵, 交叉熵, 和KL散度)

首先关于交叉熵的计算如下所示:

Pytorch入门教程11-Softmax函数和交叉熵

其中:

  • p(真实的分布), 在实际中也就是0和1, 是正确的label就是1, 否则就是0;
  • q(预测的概率), 模型给出的概率(经过softmax), 都是介于0-1之间的数字;

因为q是介于0-1之间的数字, 所以经过log之后为负数, 然后前面又加了一个负号. 如果对正确标签给出的概率越大, 那个这个值就会越小. 于是整个的loss可以定义为下面的样子.

Pytorch入门教程11-Softmax函数和交叉熵

其中:

  • M为总的数据的数量;
  • y是实际标签, 为one-hot编码.
  • p为模型输出的概率;

在实际中, 因为y只有0和1两种不同的取值, 所以negative log-likelihood可以化简为下面的式子.

Pytorch入门教程11-Softmax函数和交叉熵

我们画出上面的L关于不同输入的图像, 可以看到当概率很小的时候, L会很大; 当概率大的时候, L会很小. 也就是说, 当网络给正确的结果一个很低的置信度(confidence)的时候, 此时L很大.

Pytorch入门教程11-Softmax函数和交叉熵

还是上面的例子, 在计算完softmax之后, 我们计算NLLLoss. 下图中红色的表示正确的是哪一类.

Pytorch入门教程11-Softmax函数和交叉熵

于是, 我们使用上面的交叉熵来衡量上面的模型一和模型二. 需要注意的是, 在Pytorch中, 有一个函数是NLLLoss (negative log likelihood loss), 但是他并没有计算log, 这个名字起的很奇怪.

所以, 在Pytorch中, 多分类的loss我们会使用nn.CrossEntropyLoss(), 这里面包含了Logsoftmax+NLLLoss, 他把Log的计算和sotfmax和在了一起.

我们在这里就自己手动计算log, 再使用NLLLoss来测试一下上面两个模型的好坏. (我们认为是模型二比模型一要好), 注意下面的代码, 我们对概率值手动计算了log.

  1. loss = nn.NLLLoss()
  2. Y = torch.tensor([2, 1, 0])
  3. # 模型1对每条数据的预测,每条数据对应三个概率,表示该条数据属于第 i 类的概率值
  4. model_one_pred = torch.tensor(
  5.     [[0.3, 0.3, 0.4],  # predict class 2
  6.      [0.3, 0.4, 0.3],  # predict class 1
  7.      [0.1, 0.2, 0.7]])  # predict class 2
  8. # 模型2对每条数据的预测,每条数据对应三个概率,表示该条数据属于第 i 类的概率值
  9. model_two_pred = torch.tensor(
  10.     [[0.1, 0.2, 0.7],  # predict class 2
  11.      [0.1, 0.7, 0.2],  # predict class 1
  12.      [0.4, 0.3, 0.3]])  # predict class 2l1 = loss(Y_pred_good, Y)
  13. l1 = loss(torch.log(model_one_pred), Y)
  14. l2 = loss(torch.log(model_two_pred), Y)
  15. l1, l2
  16. # > (tensor(1.3784), tensor(0.5432))

可以看到模型一的loss比模型二的要大, 说明模型二比较好.

 

关于Sotfmax和NLLLoss的数学推导

数学推导

上面我们直观的了解了一下Softamx和NLLLoss的相关内容, 这里我们来计算一下他们的导数, 更进一步说明这样做的好处. (这一部分的数学推导来自, Understanding softmax and the negative log-likelihood)

首先重写一下关于softmax的函数:

Pytorch入门教程11-Softmax函数和交叉熵

接着是negative log-likelihood的函数:

Pytorch入门教程11-Softmax函数和交叉熵

我们是要计算L关于f的导数, 我们可以使用链式法则进行如下的展开.

Pytorch入门教程11-Softmax函数和交叉熵

这样就分为了两个部分, 我们首先看一下前面一部分.

Pytorch入门教程11-Softmax函数和交叉熵

接着我们来计算第二部分. 在计算之前, 复习一下除法的求导公式.

Pytorch入门教程11-Softmax函数和交叉熵

按照上面的方式, 对第二部分的式子进行求导.

Pytorch入门教程11-Softmax函数和交叉熵

继续对上面的式子进行化简

Pytorch入门教程11-Softmax函数和交叉熵

最后我们将第一部分和第二部分结合起来, 总的导数为:

Pytorch入门教程11-Softmax函数和交叉熵

可以看到最终的导数很简单, 这也就是为什么通常情况下, 我们都是将softmax和negative log likelihood合起来一起用的, 也就是Pytorch中的CrossEntropyLoss.

下面我们使用CrossEntropyLoss来做一下实验, 看一下导数是不是和我们推导的是一样的.

使用CorssEntropyLoss进行实验

我们还是使用之前的实验数据来进行模拟. 下面我们使用了sum, 也就是将每一条数据的loss进行求和.

  1. loss = nn.CrossEntropyLoss(reduction='sum')
  2. # 实际的分类
  3. Y = torch.tensor([0, 2, 1])
  4. # 模型一对每条数据的预测,每条数据对应三个概率,表示该条数据属于第 i 类的概率值
  5. X = torch.tensor(
  6.     [[5, 4, 2],  # predict class 2
  7.      [4, 2, 8],  # predict class 1
  8.      [4, 4, 1]], dtype=torch.float32, requires_grad=True)  # predict class 2
  9. l = loss(X, Y)
  10. print('loss:{}'.format(l.item()))
  11. # >> loss:1.0873291492462158

我们看一下X进行softmax之后的结果:

  1. p = nn.Softmax(dim=1)
  2. print(p(X))
  3. """
  4. tensor([[0.7054, 0.2595, 0.0351],
  5.         [0.0179, 0.0024, 0.9796],
  6.         [0.4879, 0.4879, 0.0243]], grad_fn=<SoftmaxBackward>)
  7. """

可以看到和上面图上是一样的. 下面我们再来确认一下上面loss的计算. 为了自己手动计算loss, 我们需要对Softmax的结果求log.

  1. -torch.log(p(X))
  2. """
  3. tensor([[0.3490, 1.3490, 3.3490],
  4.         [4.0206, 6.0206, 0.0206],
  5.         [0.7177, 0.7177, 3.7177]], grad_fn=<NegBackward>)
  6. """

于是loss的计算也就变成了, loss=0.3490+0.0206+0.7177=1.0873.

上面我们再次确认了CrossEntropyLoss是如何进行计算的. 下面我们确认上面的求导公式. 我们首先进行反向传播, 计算X的梯度.

  1. l.backward()  # 反向传播,计算梯度
  2. print('X的梯度, \n{}'.format(X.grad))
  3. """
  4. X的梯度, 
  5. tensor([[-0.2946,  0.2595,  0.0351],
  6.         [ 0.0179,  0.0024, -0.0204],
  7.         [ 0.4879, -0.5121,  0.0243]])
  8. """

三个负数的值(对应第一条数据的第一个, 第二条数据的第三个, 第三条数据的第二个),  也就是对应label=1的数据, 都是通过(p-1)计算得到的, 其他的都是p (因为y=0, 所以其实在计算NLL Loss的时候没有参与运算).

下面是按照上面的公式推导出来的导数, 看到和上面解释的是一样的.

  1. p(X)-1
  2. """
  3. tensor([[-0.2946, -0.7405, -0.9649],
  4.         [-0.9821, -0.9976, -0.0204],
  5.         [-0.5121, -0.5121, -0.9757]], grad_fn=<SubBackward0>)
  6. """

于是, 使用CrossEntropyLoss的总的一个流程如下图所示:

Pytorch入门教程11-Softmax函数和交叉熵
  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: