WGAN-GP训练流程

  • A+
所属分类:深度学习
摘要这一篇主要讲关于使用pytorch来实现WGAN-GP, 我们也是来看一下训练GAN的一个主要的流程, 里面包含训练的每一步以及最后结果的展示.

简介

这一篇文章主要会介绍一下WGAN-GP的训练方式. 也是使用这一个例子, 来说明一下在训练GAN的时候的一些常用的步骤, 包括定义网络, 网络的测试(设置测试集), 训练分类器和生成器的步骤, 模型的保存和使用, 结果的展示.

下面会就每一个部分分别进行解释. 同时原始的notebook也是已经上传了github, 链接地址如下所示.

代码仓库链接GAN的代码仓库

实践步骤

准备工作

在这里我们导入我们需要使用的库, 同时我们需要定义device, 也就是训练的时候使用cpu还是gpu.

  1. import numpy as np
  2. import pandas as pd
  3. import os
  4. import matplotlib.pyplot as plt
  5. from datetime import date,datetime
  6. import logging
  7. import torch
  8. from torch import nn
  9. from torchvision import datasets, transforms
  10. from torch import optim
  11. from torch.autograd import Variable
  12. from torchvision.utils import make_grid
  13. from torchvision.utils import save_image
  14. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  15. device

 

定义网络

接着我们来定义网络, 我们首先定义分类器(discriminator), 这里我们是用来做动漫头像的分类.

  1. class Discriminator(nn.Module):
  2.     def __init__(self):
  3.         super(Discriminator, self).__init__()
  4.         self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
  5.         self.batchN1 = nn.BatchNorm2d(64)
  6.         self.LeakyReLU1 = nn.LeakyReLU(0.2, inplace=True)
  7.         self.conv2 = nn.Conv2d(in_channels=64, out_channels=64*2, kernel_size=4, stride=2, padding=1, bias=False)
  8.         self.batchN2 = nn.BatchNorm2d(64*2)
  9.         self.LeakyReLU2 = nn.LeakyReLU(0.2, inplace=True)
  10.         self.conv3 = nn.Conv2d(in_channels=64*2, out_channels=64*4, kernel_size=4, stride=2, padding=1, bias=False)
  11.         self.batchN3 = nn.BatchNorm2d(64*4)
  12.         self.LeakyReLU3 = nn.LeakyReLU(0.2, inplace=True)
  13.         self.conv4 = nn.Conv2d(in_channels=64*4, out_channels=64*8, kernel_size=4, stride=2, padding=1, bias=False)
  14.         self.batchN4 = nn.BatchNorm2d(64*8)
  15.         self.LeakyReLU4 = nn.LeakyReLU(0.2, inplace=True)
  16.         self.conv5 = nn.Conv2d(in_channels=64*8, out_channels=1, kernel_size=4, bias=False)
  17.         self.sigmoid = nn.Sigmoid()
  18.     def forward(self, x):
  19.         x = self.LeakyReLU1(self.batchN1(self.conv1(x)))
  20.         x = self.LeakyReLU2(self.batchN2(self.conv2(x)))
  21.         x = self.LeakyReLU3(self.batchN3(self.conv3(x)))
  22.         x = self.LeakyReLU4(self.batchN4(self.conv4(x)))
  23.         x = self.conv5(x)
  24.         return x

我们有的时候会测试一下我们的D是否是正确的, 于是我们可以从训练样本中抽取出一些来进行测试.

  1. # 真实的图片
  2. images = torch.stack(([dataset[i][0] for i in range(batch_size)]))
  3. # 测试D是否与想象的是一样的
  4. outputs = D(images)

 

接着我们定义生成器(generator), 生成器是输入随机数, 生成我们要模仿的动漫头像(Anime-Face)

  1. class Generator(nn.Module):
  2.     def __init__(self):
  3.         super(Generator, self).__init__()
  4.         self.ConvT1 = nn.ConvTranspose2d(in_channels=100, out_channels=64*8, kernel_size=4, bias=False# 这里的in_channels是和初始的随机数有关
  5.         self.batchN1 = nn.BatchNorm2d(64*8)
  6.         self.relu1 = nn.ReLU()
  7.         self.ConvT2 = nn.ConvTranspose2d(in_channels=64*8, out_channels=64*4, kernel_size=4, stride=2, padding=1, bias=False# 这里的in_channels是和初始的随机数有关
  8.         self.batchN2 = nn.BatchNorm2d(64*4)
  9.         self.relu2 = nn.ReLU()
  10.         self.ConvT3= nn.ConvTranspose2d(in_channels=64*4, out_channels=64*2, kernel_size=4, stride=2, padding=1, bias=False# 这里的in_channels是和初始的随机数有关
  11.         self.batchN3 = nn.BatchNorm2d(64*2)
  12.         self.relu3 = nn.ReLU()
  13.         self.ConvT4 = nn.ConvTranspose2d(in_channels=64*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False# 这里的in_channels是和初始的随机数有关
  14.         self.batchN4 = nn.BatchNorm2d(64)
  15.         self.relu4 = nn.ReLU()
  16.         self.ConvT5 = nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False)
  17.         self.tanh = nn.Tanh() # 激活函数
  18.     def forward(self, x):
  19.         x = self.relu1(self.batchN1(self.ConvT1(x)))
  20.         x = self.relu2(self.batchN2(self.ConvT2(x)))
  21.         x = self.relu3(self.batchN3(self.ConvT3(x)))
  22.         x = self.relu4(self.batchN4(self.ConvT4(x)))
  23.         x = self.ConvT5(x)
  24.         x = self.tanh(x)
  25.         return x

同样的, 我们可以测试一下G是否是和我们想象中是一样进行工作的. 我们使用下面的方式进行测试.

  1. noise = Variable(torch.randn(batch_size, 100, 1, 1)).to(device) # 随机噪声,生成器输入
  2. # 测试G
  3. fake_images = G(noise)

 

 

加载数据集&定义辅助函数

在这一部分我们进行数据集的加载. 因为我数据集已经提前下载好了, 直接使用pytorch中的dataset即可. 这一部分的使用可以参考这个链接: Pytorch图像处理,显示与保存

我们在图像导入的时候, 首先将其变为64*64的大小, 同时进行归一化, 使其像素值的范围变为(-1,1). 这样于generator最后的tanh也是可以对应起来.

  1. trans = transforms.Compose([
  2.     transforms.Resize(64),
  3.     transforms.ToTensor(),
  4.     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  5. ])
  6. dataset = datasets.ImageFolder('./data', transform=trans) # 数据路径
  7. dataloader = torch.utils.data.DataLoader(dataset,
  8.                                     batch_size=128, # 批量大小
  9.                                     shuffle=True# 乱序
  10.                                     num_workers=2 # 多进程
  11.                                     )

因为我们进行了归一化, 所以在图像最后进行保存的时候, 我们需要进行还原, 所以我们定义一个辅助函数来帮助进行还原.

  1. # 图像像素还原
  2. def denorm(x):
  3.     out = (x + 1) / 2
  4.     return out.clamp(0, 1)

 

开始训练

接下来就可以开始进行训练了. 这里训练主要分为三个步骤, 首先是初始化网络, 定义损失函数和优化器. 接着我们就是分别训练分类器和优化器.

初始化网络和优化器

这里一部分没有什么特别要说明的, 就是和其他的网络训练是一样的.

  1. # ----------
  2. # 初始化网络
  3. # ----------
  4. D = Discriminator().to(device) # 定义分类器
  5. G = Generator().to(device) # 定义生成器
  6. # -----------------------
  7. # 定义损失函数和优化器
  8. # -----------------------
  9. learning_rate = 0.0002
  10. d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate)
  11. g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate)

 

训练Discriminator

接着我们训练分类器(discriminator), 在训练WGAN-GP的discriminator的时候, 他是由三个部分的loss来组成的. 下面我们来每一步进行分解了进行查看.

首先我们定义好要使用的real_label=1和fake_label=0, 和G需要使用的noise.

  1. # 测试时的batch大小
  2. batch_size = 36
  3. # 创造real label和fake label
  4. real_labels = torch.ones(batch_size, 1).to(device) # real的pic的label都是1
  5. fake_labels = torch.zeros(batch_size, 1).to(device) # fake的pic的label都是0
  6. noise = Variable(torch.randn(batch_size, 100, 1, 1)).to(device) # 随机噪声,生成器输入
  7. # 真实的图片
  8. images = torch.stack(([dataset[i][0] for i in range(batch_size)]))

接着我们计算loss的第一个组成部分(这里参考WGAN-GP的loss的计算公式).

  1. # 首先计算真实的图片的loss, d_loss_real
  2. outputs = D(images)
  3. d_loss_real = -torch.mean(outputs)

接着我们计算loss的第二个组成部分.

  1. # 接着计算假的图片的loss, d_loss_fake
  2. fake_images = G(noise)
  3. outputs = D(fake_images)
  4. d_loss_fake = torch.mean(outputs)

接着我们计算penalty region的loss, 也就是我们希望在penalty region中的梯度是越接近1越好.

我们首先生成penalty region, 这一部分是在P_G和P_data之间的, 如下图所示.

WGAN-GP训练流程
  1. # 接着计算penalty region 的loss, d_loss_penalty
  2. # 生成penalty region
  3. alpha = torch.rand((batch_size, 1, 1, 1)).to(device)
  4. x_hat = alpha * images.data + (1 - alpha) * fake_images.data
  5. x_hat.requires_grad = True

接着我们来计算他们的梯度, 我们希望梯度是越接近1越好.

  1. # 将中间的值进行分类
  2. pred_hat = D(x_hat)
  3. # 计算梯度
  4. gradient = torch.autograd.grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).to(device),
  5.                    create_graph=False, retain_graph=False)

这里的梯度计算完毕之后是在每一个像素点处都是有梯度的值的.

  1. # 计算出每一张图, 每一个像素点处的梯度
  2. gradient[0].shape
  3. """
  4. torch.Size([36, 3, 64, 64])
  5. """

接着我们计算L2范数.

  1. penalty_lambda = 10 # 梯度惩罚系数
  2. gradient_penalty = penalty_lambda * ((gradient[0].view(gradient[0].size()[0], -1).norm(p=2,dim=1)-1)**2).mean()
  3. gradient_penalty

最后只需要把上面的三个部分相加, 进行反向传播来进行优化即可.

  1. # 三个loss相加, 反向传播进行优化
  2. d_loss = d_loss_real + d_loss_fake + gradient_penalty
  3. g_optimizer.zero_grad() # 两个优化器梯度都要清0
  4. d_optimizer.zero_grad()
  5. d_loss.backward()
  6. d_optimizer.step()

 

训练Generator

WGAN-GP优化器部分的训练和其他的没有什么太大的不同, 我们这里就简单说明一下即可.

  1. normal_noise = Variable(torch.randn(batch_size, 100, 1, 1)).normal_(0, 1).to(device)
  2. fake_images = G(normal_noise) # 生成假的图片
  3. outputs = D(fake_images) # 放入辨别器
  4. g_loss = -torch.mean(outputs) # 希望生成器生成的图片判别器可以判别为真
  5. d_optimizer.zero_grad()
  6. g_optimizer.zero_grad()
  7. g_loss.backward()
  8. g_optimizer.step()

到这里基本所有的步骤就完毕了, 后面就是开始训练就可以了. 我们后面直接看一下训练完毕之后的结果.

 

结果展示

我们将上面的步骤重复N次, 反复训练D和G, 并将结果进行保存. 下面我们来看一下最后生成器生成的效果.

首先我们导入已经训练好的模型.

  1. G = Generator().to(device) # 定义生成器
  2. # 读入生成器的模型
  3. G.load_state_dict(torch.load('./models/G.ckpt', map_location='cpu'))

接着我们使用G来进行图像的生成, 并显示出来. 在这之前, 我们首先定义一个函数来帮助我们进行显示.

  1. def show(img):
  2.     """
  3.     用来显示图片的
  4.     """
  5.     plt.figure(figsize=(24, 16))
  6.     npimg = img.detach().numpy()
  7.     plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

最后就是查看显示的结果了.

  1. # 使用生成器来进行生成
  2. test_noise = Variable(torch.FloatTensor(40, 100, 1, 1).normal_(0, 1)).to(device)
  3. fake_image = G(test_noise)
  4. show(make_grid(fake_image, nrow=8, padding=1, normalize=Truerange=(-1, 1), scale_each=False, pad_value=0.5))

最终显示的结果如下所示, 可以看到不仔细看还是可以的:

WGAN-GP训练流程

训练结果的渐变

这里我们随机选择两张图片A, B; 作为对角线两端的图片. 接着我们将A的前50个变量在x轴缓慢变为B的前50个变量; 将A的后50个变量在y轴缓慢变为B的后50个变量, 于是就可以得到一个从A到B变换过程的图片.

我们首先随机取出两个图片.

  1. test_noise = Variable(torch.FloatTensor(2, 100, 1, 1).normal_(0, 1)).to(device)
  2. fake_image = G(test_noise)
  3. show(make_grid(fake_image, nrow=2, padding=1, normalize=Truerange=(-1, 1), scale_each=False, pad_value=0.5))
WGAN-GP训练流程

接着定义左上角和右下角的图片.

  1. leftTop = test_noise[0] # 左上角图片
  2. rightBottom = test_noise[1] # 右下角图片

最后我们进行生成他们的中间值,

  1. # 生成10*10的noise, 用来作为输入进行生成
  2. interval = 10 # 图片的大小
  3. rowAdd = [(leftTop[i]-rightBottom[i])/interval for i in range(0,50)] # 每一行每一格变化的长度
  4. rowAddNums = [[leftTop[i] - rowAdd[i]*k for k in range(interval+1) ] for i in range(0,50)] # 变换后每一格的值
  5. rowAddNums = np.transpose(np.array(rowAddNums))
  6. columnAdd = [(leftTop[j]-rightBottom[j])/interval for j in range(50,100)] # 每一列每一格变化的长度
  7. columnAddNums = [[leftTop[j] - columnAdd[j-50]*k for k in range(interval+1)] for j in range(50,100)] # 变换后每一格的值
  8. columnAddNums = np.transpose(np.array(columnAddNums))
  9. InterpolationNoise = leftTop.unsqueeze(0)
  10. for columnAddNum in columnAddNums:
  11.     for rowAddNum in rowAddNums:
  12.         for i in range(0, 50):
  13.             leftTop[i] = torch.tensor(rowAddNum[i])
  14.         for j in range(50, 100):
  15.             leftTop[j] = torch.tensor(columnAddNum[j-50])
  16.         InterpolationNoise = torch.cat((InterpolationNoise, leftTop.unsqueeze(0)), dim=0)

最后进行可视化即可.

  1. fake_image = G(InterpolationNoise[1:])
  2. show(make_grid(fake_image, nrow=11, padding=1, normalize=Truerange=(-1, 1), scale_each=False, pad_value=0.5))

下面我们来放一些最终比较好的结果. 简单放三张结果图.

WGAN-GP训练流程 WGAN-GP训练流程 WGAN-GP训练流程

训练不同轮数的结果展示

下面我们看一下训练不同epoch后的结果, 看一下逐渐变化的过程.

第一轮

WGAN-GP训练流程

第十轮

WGAN-GP训练流程

第五十轮

WGAN-GP训练流程

第一百轮

WGAN-GP训练流程

第两百轮

WGAN-GP训练流程

第五百轮

WGAN-GP训练流程

第一千轮

WGAN-GP训练流程
  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: