Saliency Maps的原理与简单实现(使用Pytorch实现)

  • A+
所属分类:深度学习
摘要这一篇主要介绍一种模型解释的方式,Saliency Maps来进行对模型的解释。我们会结合论文说明Saliency Maps的基本原理和实验Pytorch来完成Saliency Maps的实验。

介绍

这一篇会介绍一下Saliency Maps的简单原理。Saliency Maps简单来说可以理解为是用来做模型的解释,可以用来知道哪些变量对于模型来说是重要的。我们也可以理解为Saliency map即特征图,可以告诉我们图像中的像素点对图像分类结果的影响。

原始论文 : Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps

这一篇文章介绍了两种可视化的方式,一种在我之前也有简单的涉及,就是通过指定某个class的概率最大,来通过反向传播来修改input image,不过在文章中他有一些改进来保证最后输出结果的可视化。文章链接如下:使用CNN在MNIST上实现简单的攻击样本

这里会主要介绍关于Saliency Maps的内容。

Saliency Maps的原理

下面我就实验论文里的进行解释(具体的可以查看原始的论文).

整体的目标 : 给一张图片I_0, 对应的分类是c, 有一个模型给出图片I_0的概率值是S_c(I),我们想要衡量I_0的某个像素点对分类器S_c(I)的影响.

简单解释

下面给一个比较直观的解释,假如我们的模型S_c(I)是一个线性模型,如下所示:

Saliency Maps的原理与简单实现(使用Pytorch实现)

那么我们就可以通过观察W的大小来看出每个像素点对应的重要度的信息。

对于复杂的网络

对于复杂的网络来说,模型S_c(I)是一个复杂的非线性模型。但是对于给的图像I_0, 我们可以在I_0的周围对模型S_c(I)进行一阶泰勒展开,如下所示

Saliency Maps的原理与简单实现(使用Pytorch实现)

其中的W就是模型S_c(I)I的导数:

Saliency Maps的原理与简单实现(使用Pytorch实现)

所以,最终我们要做的就是计算W的值。到这里W的计算方式也给出来了,其实就是整个网络进行方向传播,计算输入X的梯度,就是这里的W.(也可以看一下后面具体代码的实现)

Saliency Maps的另外一种解释

对于上面计算梯度W可以用来表示每个特征的重要度,我们可以这么来进行理解。导数的大小表示某个像素点改变一点,对最后结果的影响。

如下图所示,我们可以认为是x_n的改变对y_k的改变

Saliency Maps的原理与简单实现(使用Pytorch实现)

于是,这样计算下来相当于是导数的计算

Saliency Maps的原理与简单实现(使用Pytorch实现)

Saliency Maps实验结果

下面一些是原论文的实验的结果。后面会有具体的实现方式讲解。

Saliency Maps的原理与简单实现(使用Pytorch实现)

Saliency Maps的具体操作

下面部分内容参考自 : 利用pytorch实现Visualising Image Classification Models and Saliency Maps

  • 计算Saliency Map的时候首先要计算与图像像素对应的正确分类中的标准化分数的梯度(这是一个标量)。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);
  • 对于图像中的每个像素点,这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。
  • 之后,我们计算出梯度的绝对值,然后再取三个颜色通道的最大值;因此最后的saliency map的形状是(H, W)为一个通道的灰度图。

Saliency Maps简单实现(Pytorch版本)

对于Saliency Maps的实验,这里只给出核心代码,全部的代码见链接(github仓库)。

代码链接 : Saliency Maps的简单实现(使用Pytorch)

计算输入X的导数

  1. def compute_saliency_maps(X, y, model):
  2.     """
  3.     X表示图片, y表示分类结果, model表示使用的分类模型
  4.     
  5.     Input : 
  6.     - X : Input images : Tensor of shape (N, 3, H, W)
  7.     - y : Label for X : LongTensor of shape (N,)
  8.     - model : A pretrained CNN that will be used to computer the saliency map
  9.     
  10.     Return :
  11.     - saliency : A Tensor of shape (N, H, W) giving the saliency maps for the input images
  12.     """
  13.     # 确保model是test模式
  14.     model.eval()
  15.     # 确保X是需要gradient
  16.     X.requires_grad_()
  17.     saliency = None
  18.     logits = model.forward(X)
  19.     logits = logits.gather(1, y.view(-1, 1)).squeeze() # 得到正确分类
  20.     logits.backward(torch.FloatTensor([1., 1., 1., 1., 1.])) # 只计算正确分类部分的loss
  21.     saliency = abs(X.grad.data) # 返回X的梯度绝对值大小
  22.     saliency, _ = torch.max(saliency, dim=1)
  23.     return saliency.squeeze()

 显示实验结果

这一部分就是结果的显示了,不是很重要,核心都在上面那部分,来计算输入X的梯度。

  1. def show_saliency_maps(X, y):
  2.     # Convert X and y from numpy arrays to Torch Tensors
  3.     X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0)
  4.     y_tensor = torch.LongTensor(y)
  5.     # Compute saliency maps for images in X
  6.     saliency = compute_saliency_maps(X_tensor, y_tensor, model)
  7.     # Convert the saliency map from Torch Tensor to numpy array and show images
  8.     # and saliency maps together.
  9.     saliency = saliency.numpy()
  10.     N = X.shape[0]
  11.     for i in range(N):
  12.         plt.subplot(2, N, i + 1)
  13.         plt.imshow(X[i])
  14.         plt.axis('off')
  15.         plt.title(class_names[y[i]])
  16.         plt.subplot(2, N, N + i + 1)
  17.         plt.imshow(saliency[i], cmap=plt.cm.hot)
  18.         plt.axis('off')
  19.         plt.gcf().set_size_inches(12, 5)
  20.     plt.show()
  21. show_saliency_maps(X, y)

最终的实验的结果如下所示:

Saliency Maps的原理与简单实现(使用Pytorch实现)

关于gather的一些说明

Saliency Maps的原理与简单实现(使用Pytorch实现)

Saliency Maps的应用

在原论文中,作者给出了一个Saliency Maps的应用,可以用来进行图像中物体的识别和分割。如下图所示。

Saliency Maps的原理与简单实现(使用Pytorch实现)
  • 最左侧为原图像;
  • 第二列表示计算得到的Saliency Maps;
  • 第三列表示作者设定了一个阈值,超过认为是重要的(图中蓝色的部分是比较重要的);
  • 将上一层的mask与原始图像结合,得到物体的分割;
  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: