GAN实验–生成手写数字(Pytorch)

王 茂南 2019年7月22日07:21:25
评论
6115字阅读20分23秒
摘要这一篇我们会讲一下使用GAN来生成手写数字,也是一个简单的实验,实验Pytorch来进行完成。

简介

这一篇还是GAN的实验. 之前一篇介绍了使用GAN生成服从高斯分布的数据,简单GAN的实验--生成高斯分布数据(Pytorch),这一篇使用图片来做一下实验,获得更加直观的感受。

实验内容简介

这一篇会使用MNIST数据集作为训练集,我们希望我们的生成手写数字。这一篇的内容主要参考自Pytorch的指南, generative_adversarial_network--Pytorch

存在的一个问题

需要注意的是,原始代码中是存在问题的,因为MNIST是灰度图,他在数据处理的时候使用了3通道的处理方式,所以直接运行会出现如下的报错:

output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]。

解决方法将transform修改为如下即可。参考链接: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28

  1. transform = transforms.Compose([
  2.                 transforms.ToTensor(),
  3.                 transforms.Normalize(mean=[0.5], std=[0.5])])

 实验步骤

因为这里训练GAN的方式和简单GAN的实验--生成高斯分布数据(Pytorch)是差不多的,所以就不详细说步骤了,我就把代码拆解了进行简单说明一下。

其实主要还是训练的时候,对生成器(G)和分类器(D)的训练. 我会把源代码上传github, 下面就贴出关键的部分的代码。

代码链接: GAN实验代码

数据准备与预处理

首先我们准备训练需要的数据。注意我们要对数据做标准化,使得其范围从[0,1]转换为[-1,1]

  1. # Image processing
  2. transform = transforms.Compose([
  3.                 transforms.ToTensor(),
  4.                 transforms.Normalize(mean=[0.5], std=[0.5])])
  5.                 # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) # 3 for RGB channels
  6. # MNIST dataset
  7. mnist = torchvision.datasets.MNIST(root='./',
  8.                                    train=True,
  9.                                    transform=transform,
  10.                                    download=True)
  11. # Data loader
  12. data_loader = torch.utils.data.DataLoader(dataset=mnist,
  13.                                           batch_size=batch_size,
  14.                                           shuffle=True)

 网络的定义

接下来定义判别器(D)和生成器(G). 首先是判别器的部分.

  1. class Discriminator(nn.Module):
  2.     def __init__(self, input_size, hidden_size, output_size):
  3.         super(Discriminator, self).__init__()
  4.         self.map1 = nn.Linear(input_size, hidden_size)
  5.         self.map2 = nn.Linear(hidden_size, hidden_size)
  6.         self.map3 = nn.Linear(hidden_size, output_size)
  7.         self.leakyrelu = nn.LeakyReLU(0.2)
  8.         self.sigmoid = nn.Sigmoid()
  9.     def forward(self, x):
  10.         x = self.leakyrelu(self.map1(x))
  11.         x = self.leakyrelu(self.map2(x))
  12.         x = self.sigmoid(self.map3(x))# 最后生成的是概率
  13.         return x

接着是生成器的部分生成器最后过tanh可以使得输出的范围被要压缩到(-1,1)

  1. class Generator(nn.Module):
  2.     def __init__(self, input_size, hidden_size, output_size):
  3.         super(Generator, self).__init__()
  4.         self.map1 = nn.Linear(input_size, hidden_size)
  5.         self.map2 = nn.Linear(hidden_size, hidden_size)
  6.         self.map3 = nn.Linear(hidden_size, output_size)
  7.         self.relu = nn.ReLU()
  8.         self.tanh = nn.Tanh() # 激活函数
  9.     def forward(self, x):
  10.         x = self.relu(self.map1(x))
  11.         x = self.relu(self.map2(x))
  12.         x = self.tanh(self.map3(x))
  13.         return x

初始化网络与定义优化器和损失函数

  1. # ----------
  2. # 初始化网络
  3. # ----------
  4. D = Discriminator(input_size=image_size,
  5.                   hidden_size=hidden_size,
  6.                   output_size=1).to(device)
  7. G = Generator(input_size=latent_size,
  8.               hidden_size=hidden_size,
  9.               output_size=image_size).to(device)
  10. # -----------------------
  11. # 定义损失函数和优化器
  12. # -----------------------
  13. learning_rate = 0.0003
  14. criterion = nn.BCELoss()
  15. d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate)
  16. g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate)
  17. d_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(d_optimizer, step_size=50, gamma=0.9)
  18. g_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(g_optimizer, step_size=50, gamma=0.9)

在这里我们使得learning rate每50轮进行递减,每次变为之前的90%.

定义辅助函数

这里定义两个辅助函数,用来做图片的还原,为了最后的显示。

  1. # 定义辅助函数
  2. def denorm(x):
  3.     """
  4.     用来还原图片, 之前做过标准化
  5.     """
  6.     out = (x + 1) / 2
  7.     return out.clamp(0, 1)
  8. def reset_grad():
  9.     d_optimizer.zero_grad()
  10.     g_optimizer.zero_grad()

 开始训练

  1. total_step = len(data_loader)
  2. # ------------------
  3. # 一开始学习率快一些
  4. # ------------------
  5. for epoch in range(250):
  6.     d_exp_lr_scheduler.step()
  7.     g_exp_lr_scheduler.step()
  8.     for i, (images, _) in enumerate(data_loader):
  9.         images = images.reshape(batch_size, -1).to(device)
  10.         # 创造real label和fake label
  11.         real_labels = torch.ones(batch_size, 1).to(device)
  12.         fake_labels = torch.zeros(batch_size, 1).to(device)
  13.         # ---------------------
  14.         # 开始训练discriminator
  15.         # ---------------------
  16.         # 首先计算真实的图片
  17.         outputs = D(images)
  18.         d_loss_real = criterion(outputs, real_labels)
  19.         real_score = outputs # 真实图片的分类结果, 越接近1越好
  20.         # 接着使用生成器训练得到图片, 放入判别器
  21.         z = torch.randn(batch_size, latent_size).to(device)
  22.         fake_images = G(z)
  23.         outputs = D(fake_images)
  24.         d_loss_fake = criterion(outputs, fake_labels)
  25.         fake_score = outputs # 错误图片的分类结果, 越接近0越好, 最后会趋于1, 生成器生成的判别器判断不了
  26.         # 两个loss相加, 反向传播进行优化
  27.         d_loss = d_loss_real + d_loss_fake
  28.         g_optimizer.zero_grad() # 两个优化器梯度都要清0
  29.         d_optimizer.zero_grad()
  30.         d_loss.backward()
  31.         d_optimizer.step()
  32.         # -----------------
  33.         # 开始训练generator
  34.         # -----------------
  35.         z = torch.randn(batch_size, latent_size).to(device)
  36.         fake_images = G(z)
  37.         outputs = D(fake_images)
  38.         g_loss = criterion(outputs, real_labels) # 希望生成器生成的图片判别器可以判别为真
  39.         d_optimizer.zero_grad()
  40.         g_optimizer.zero_grad()
  41.         g_loss.backward()
  42.         g_optimizer.step()
  43.         if (i+1) % 200 == 0:
  44.             print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}, d_lr={:.6f},g_lr={:.6f}'
  45.                   .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
  46.                           real_score.mean().item(), fake_score.mean().item(),
  47.                          d_optimizer.param_groups[0]['lr'], g_optimizer.param_groups[0]['lr']))
  48.         # Save real images
  49.         if (epoch+1) == 1:
  50.             images = images.reshape(images.size(0), 1, 28, 28)
  51.             save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
  52.         # Save sampled images
  53.         fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
  54.         save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
  55. # Save the model checkpoints 
  56. torch.save(G.state_dict(), './models/G.ckpt')
  57. torch.save(D.state_dict(), './models/D.ckpt')

我这里只贴一部分的训练代码,实际上后面还会对学习率有一些优化的措施。

最终结果

最后,就来看一下最终的效果。下面是训练不同次数的效果,到最后效果还是不错的。

GAN实验–生成手写数字(Pytorch)

下面是训练了2000轮后的效果。

GAN实验–生成手写数字(Pytorch)

但其实仔细看我们会发现一个问题,就是会有很多1出现。这是因为使用这种方式来进行训练的时候,我们不能控制输出,之后我们会将Conditional GAN来解决这个问题.

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南
  • 本文由 发表于 2019年7月22日07:21:25
  • 转载请务必保留本文链接:https://mathpretty.com/10811.html
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: