Image Style Transform–关于图像风格迁移的介绍

王 茂南 2019年5月26日07:53:09
评论
10176字阅读33分55秒
摘要这位一篇文章会介绍一下CNN的一个应用,图像风格迁移。会介绍一下大致的方法和其中Style Loss使用的Gram Matrix。

介绍

关于图像的风格迁移,最早是来源与论文 A Neural Algorithm of Artistic Style,在这篇论文中,作者给出了一种模仿图像 contentstyle 的方式,他原文做出的效果如下所示:

Image Style Transform–关于图像风格迁移的介绍

下面简单说明一下原理,最后会给出详细的实现的过程。这里有一篇文章也是介绍风格迁移的原理的,我觉得写的很不错,链接放在这里,大家可以看一下。

 

参考资料

 

原理介绍

首先整个网络的构成大概如下面所示,CNN 的层是可以更深,我这里为了方便就画了一层。输入有三个,Content ImageStyle ImageRandom Image我们希望最后 Random Image 在内容上可以接近 Content Image,在风格上可以接近 Style Image

Image Style Transform–关于图像风格迁移的介绍

于是我们就会即希望 Loss=ContentLoss+StyleLoss 可以尽量小。下面我们看一下如何定义 ContentLossStyleLoss

 

Content Loss

关于 ContentLoss 的计算,是 Content Image Random Image Convolutional Layer 的输出逐像素点相减的平方(MSE)。这个还是比较好理解的。简单看一下下面的式子。(一会可以看下面的代码来进行分析)

Image Style Transform–关于图像风格迁移的介绍

Style Loss

关于 Style Loss 是这样计算的,计算卷积层输出的 Gram Matrix (Style ImageRandom Image 的 Gram Matrix 都会计算);最后计算两者 Gram Matrix 的差值,希望越接近越好。

关于 Gram Matrix 的计算如下所示。

Image Style Transform–关于图像风格迁移的介绍

我们看一个例子来理解一下 Gram Matrix 是如何进行计算的。

Image Style Transform–关于图像风格迁移的介绍
  • 我们看到,Gram Matrix 是没有考虑像素点之间的关系的,最后输出大小只和 filter 的个数有关。(上面的例子是三通道的,所以最后的输出是 3×3 的矩阵)
  • 他只考虑了两个 feature map 的距离的远近。其实仔细看计算,这个就是在计算余弦距离。关于余弦距离的计算看下面图片,可以看到也是两个向量的相乘,再除模长,与这里的计算是一样的。
  • 我们可以认为一个 filter 其实表示一个特征,于是 Gram Matrix 可以表示出特征出现的关系。
  • 所以说,我们可以通过计算 Gram Matrix 的差,来计算两张图片风格上的差距。
  • 有了两个 Loss 之后,就是进行梯度下降即可。下面看一下详细的实现过程。

下面是余弦距离的计算公式。

Image Style Transform–关于图像风格迁移的介绍

 

代码实现

准备工作

首先做好准备的工作,导入相应要使用的库和定义device。

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. from PIL import Image
  6. import matplotlib.pyplot as plt
  7. import torchvision.transforms as transforms
  8. import torchvision.models as models
  9. import numpy as np
  10. import copycopy
  11. import os

定义device, 判断是否使用cuda

  1. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

定义一些工具函数

下面函数主要是为了加载图片来使用的,将图片转为tensor,同时进行放缩。

  1. def image_loader(image_name,imsize):
  2.     """图片load函数
  3.     """
  4.     # 转换图片大小
  5.     loader = transforms.Compose([
  6.         transforms.Resize(imsize),  # scale imported image
  7.         transforms.ToTensor()])  # transform it into a torch tensor
  8.     image = Image.open(image_name)
  9.     # fake batch dimension required to fit network's input dimensions
  10.     image = loader(image).unsqueeze(0)
  11.     return image.to(device, torch.float)
  12. def image_util(img_size=512,style_img="./images/picasso.jpg", content_img="./images/dancing.jpg"):
  13.     """返回style_image和content_image
  14.        需要保证两张图片的大小是一样的
  15.        """
  16.     imsize = img_size if torch.cuda.is_available() else 128  # use small size if no gpu
  17.     # 加载图片
  18.     style_img = image_loader(image_name=style_img, imsize=img_size)
  19.     content_img = image_loader(image_name=content_img, imsize=img_size)
  20.     # 判断是否加载成功
  21.     print("Style Image Size:{}".format(style_img.size()))
  22.     print("Content Image Size:{}".format(content_img.size()))
  23.     assert style_img.size() == content_img.size(), \
  24.         "we need to import style and content images of the same size"
  25.     return style_img, content_img

定义Content Loss

我们上面介绍了,Content Loss是Content Image和Random Image在卷积层的输出计算MSE,即像素两两相减,于是可以很容易写出Content Loss。

  1. class ContentLoss(nn.Module):
  2.     def __init__(self, target,):
  3.         super(ContentLoss, self).__init__()
  4.         # we 'detach' the target content from the tree used
  5.         # to dynamically compute the gradient: this is a stated value,
  6.         # not a variable. Otherwise the forward method of the criterion
  7.         # will throw an error.
  8.         self.target = target.detach()
  9.     def forward(selfinput):
  10.         self.loss = F.mse_loss(inputself.target)
  11.         return input

定义Style Loss

上面见过,计算Style Loss的,其实就是计算Style Image和Random Image两个图片的Gram Matrix的MSE,所以我们先定义如何计算Gram Matrix.

下面对原始数据做view,是为了做一次矩阵乘法就可以得到Gram Matrix,与我上面画的Gram Matrix的计算的例子相同的意思。

  1. # 我们首先定义 Gram Matrix
  2. def gram_matrix(input):
  3.     a, b, c, d = input.size()  # a=batch size(=1)
  4.     # b=number of feature maps
  5.     # (c,d)=dimensions of a f. map (N=c*d)
  6.     features = input.view(a * b, c * d)  # resise F_XL into \hat F_XL
  7.     G = torch.mm(features, features.t())  # compute the gram product
  8.     # print(G)
  9.     # 对Gram Matrix做正规化, 除总的大小
  10.     return G.div(a * b * c * d)

接着我们就可以计算Style Loss了。

  1. # 接着我们就可以定义Style Loss了
  2. class StyleLoss(nn.Module):
  3.     def __init__(self, target_feature):
  4.         super(StyleLoss, self).__init__()
  5.         self.target = gram_matrix(target_feature).detach()
  6.     def forward(selfinput):
  7.         G = gram_matrix(input)
  8.         self.loss = F.mse_loss(G, self.target)
  9.         return input

 修改VGG网络

我们在这里是使用预训练好的VGG16的网路,但是因为我们要获取中间网络的输出,所以我们可以重新写一下。

首先定义一个标准化的类,因为VGG16对所有输入进行了标准化,我们也要进行同样的操作。

  1. # -------------------
  2. # 模型的标准化
  3. # 因为原始的VGG网络对图片做了normalization, 所在要把下面的Normalization放在新的网络的第一层
  4. # -------------------
  5. class Normalization(nn.Module):
  6.     def __init__(self, mean, std):
  7.         super(Normalization, self).__init__()
  8.         # .view the mean and std to make them [C x 1 x 1] so that they can
  9.         # directly work with image Tensor of shape [B x C x H x W].
  10.         # B is batch size. C is number of channels. H is height and W is width.
  11.         self.mean = mean.view(-1, 1, 1)
  12.         self.std = std.view(-1, 1, 1)
  13.     def forward(self, img):
  14.         # normalize img
  15.         return (img - self.mean) / self.std

我们将上面定义的ContentLoss和StyleLoss这两个类,加到网络指定的层后面,为了方便我们获取输出的值,其实我们也可以可以使用hook来完成相同的操作,具体可以看,CNN可视化Convolutional Features

  1. # --------------------------------
  2. # 网络结构的修改, 生成一个style的网络
  3. # --------------------------------
  4. def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
  5.                                style_img, content_img,
  6.                                content_layers,
  7.                                style_layers):
  8.     # 复制cnn的网络部分
  9.     cnn = copy.deepcopy(cnn)
  10.     # normalization module
  11.     normalization = Normalization(normalization_mean, normalization_std).to(device)
  12.     # just in order to have an iterable access to or list of content/syle
  13.     # losses
  14.     content_losses = []
  15.     style_losses = []
  16.     # assuming that cnn is a nn.Sequential, so we make a new nn.Sequential
  17.     # to put in modules that are supposed to be activated sequentially
  18.     # 之后逐层向model里面增加内容
  19.     model = nn.Sequential(normalization)
  20.     i = 0  # increment every time we see a conv
  21.     for layer in cnn.children():
  22.         if isinstance(layer, nn.Conv2d):
  23.             i += 1
  24.             name = 'conv_{}'.format(i)
  25.         elif isinstance(layer, nn.ReLU):
  26.             name = 'relu_{}'.format(i)
  27.             # The in-place version doesn't play very nicely with the ContentLoss
  28.             # and StyleLoss we insert below. So we replace with out-of-place
  29.             # ones here.
  30.             layer = nn.ReLU(inplace=False)
  31.         elif isinstance(layer, nn.MaxPool2d):
  32.             name = 'pool_{}'.format(i)
  33.         elif isinstance(layer, nn.BatchNorm2d):
  34.             name = 'bn_{}'.format(i)
  35.         else:
  36.             raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
  37.         model.add_module(name, layer)
  38.         if name in content_layers:
  39.             # add content loss:
  40.             target = model(content_img).detach()
  41.             content_loss = ContentLoss(target)
  42.             model.add_module("content_loss_{}".format(i), content_loss)
  43.             content_losses.append(content_loss)
  44.         if name in style_layers:
  45.             # add style loss:
  46.             target_feature = model(style_img).detach()
  47.             style_loss = StyleLoss(target_feature)
  48.             model.add_module("style_loss_{}".format(i), style_loss)
  49.             style_losses.append(style_loss)
  50.     # now we trim off the layers after the last content and style losses\
  51.     # 只需要算到最后一个style loss或是content loss用到的layer就可以了, 后面的可以去掉
  52.     for i in range(len(model) - 1, -1, -1):
  53.         if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
  54.             break
  55.     model = model[:(i + 1)]
  56.     # 返回的是修改后的Model, style_losses和content_losses的list
  57.     return model, style_losses, content_losses

定义优化函数

  1. def get_input_optimizer(input_img):
  2.     # 这里要对图片做梯度下降
  3.     optimizer = optim.LBFGS([input_img.requires_grad_()])
  4.     return optimizer

定义传播函数

这一步即,我们对输入的图片进行修改,使得ContentLoss+StyleLoss可以变小。

  1. def run_style_transfer(cnn, normalization_mean, normalization_std, content_img, style_img, input_img, content_layers,style_layers, num_steps=300, style_weight=1000000, content_weight=1):
  2.     print('Building the style transfer model..')
  3.     model, style_losses, content_losses = get_style_model_and_losses(cnn, normalization_mean, normalization_std, style_img, content_img, content_layers, style_layers)
  4.     optimizer = get_input_optimizer(input_img)
  5.     print('Optimizing..')
  6.     run = [0]
  7.     while run[0] <= num_steps:
  8.         def closure():
  9.             # correct the values of updated input image
  10.             input_img.data.clamp_(0, 1)
  11.             optimizer.zero_grad()
  12.             model(input_img) # 前向传播
  13.             style_score = 0
  14.             content_score = 0
  15.             for sl in style_losses:
  16.                 style_score += sl.loss
  17.             for cl in content_losses:
  18.                 content_score += cl.loss
  19.             style_score *= style_weight
  20.             content_score *= content_weight
  21.             # loss为style loss 和 content loss的和
  22.             loss = style_score + content_score
  23.             loss.backward() # 反向传播
  24.             # 打印loss的变化情况
  25.             run[0] += 1
  26.             if run[0] % 50 == 0:
  27.                 print("run {}:".format(run))
  28.                 print('Style Loss : {:4f} Content Loss: {:4f}'.format(
  29.                     style_score.item(), content_score.item()))
  30.                 print()
  31.             return style_score + content_score
  32.         # 进行参数优化
  33.         optimizer.step(closure)
  34.     # a last correction...
  35.     # 数值范围的纠正, 使其范围在0-1之间
  36.     input_img.data.clamp_(0, 1)
  37.     return input_img

 进行训练

做完所有工作之后,我们就可以开始进行训练了。

  1. # 加载content image和style image
  2. style_img,content_img = image_util(img_size=444,style_img="./images/style/rose.jpg", content_img="./images/content/face.jpg")
  3. # input image使用content image
  4. input_img = content_img.clone()
  5. # 加载预训练好的模型
  6. cnn = models.vgg19(pretrained=True).features.to(device).eval()
  7. # 模型标准化的值
  8. cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
  9. cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
  10. # 定义要计算loss的层
  11. content_layers_default = ['conv_4']
  12. style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
  13. # 模型进行计算
  14. output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, input_img, content_layers=content_layers_default, style_layers=style_layers_default, num_steps=300, style_weight=100000, content_weight=1)
Image Style Transform–关于图像风格迁移的介绍

显示结果

最后把结果打印出来即可。

  1. image = output.cpu().clone()
  2. image = image.squeeze(0)
  3. unloader = transforms.ToPILImage()
  4. unloader(image)
Image Style Transform–关于图像风格迁移的介绍

一些其他结果

尝试了转换为其他的一些风格,把一些结果放在下面。

Image Style Transform–关于图像风格迁移的介绍 Image Style Transform–关于图像风格迁移的介绍 Image Style Transform–关于图像风格迁移的介绍

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南
  • 本文由 发表于 2019年5月26日07:53:09
  • 转载请务必保留本文链接:https://mathpretty.com/10491.html
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: