文章目录(Table of Contents)
简介
这一篇还是GAN的实验. 之前一篇介绍了使用GAN生成服从高斯分布的数据,简单GAN的实验--生成高斯分布数据(Pytorch),这一篇使用图片来做一下实验,获得更加直观的感受。
实验内容简介
这一篇会使用MNIST数据集作为训练集,我们希望我们的生成手写数字。这一篇的内容主要参考自Pytorch的指南, generative_adversarial_network--Pytorch
存在的一个问题
需要注意的是,原始代码中是存在问题的,因为MNIST是灰度图,他在数据处理的时候使用了3通道的处理方式,所以直接运行会出现如下的报错:
output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]。
解决方法将transform修改为如下即可。参考链接: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.5], std=[0.5])])
实验步骤
因为这里训练GAN的方式和简单GAN的实验--生成高斯分布数据(Pytorch)是差不多的,所以就不详细说步骤了,我就把代码拆解了进行简单说明一下。
其实主要还是训练的时候,对生成器(G)和分类器(D)的训练. 我会把源代码上传github, 下面就贴出关键的部分的代码。
代码链接: GAN实验代码
数据准备与预处理
首先我们准备训练需要的数据。注意我们要对数据做标准化,使得其范围从[0,1]转换为[-1,1]
- # Image processing
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.5], std=[0.5])])
- # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) # 3 for RGB channels
- # MNIST dataset
- mnist = torchvision.datasets.MNIST(root='./',
- train=True,
- transform=transform,
- download=True)
- # Data loader
- data_loader = torch.utils.data.DataLoader(dataset=mnist,
- batch_size=batch_size,
- shuffle=True)
网络的定义
接下来定义判别器(D)和生成器(G). 首先是判别器的部分.
- class Discriminator(nn.Module):
- def __init__(self, input_size, hidden_size, output_size):
- super(Discriminator, self).__init__()
- self.map1 = nn.Linear(input_size, hidden_size)
- self.map2 = nn.Linear(hidden_size, hidden_size)
- self.map3 = nn.Linear(hidden_size, output_size)
- self.leakyrelu = nn.LeakyReLU(0.2)
- self.sigmoid = nn.Sigmoid()
- def forward(self, x):
- x = self.leakyrelu(self.map1(x))
- x = self.leakyrelu(self.map2(x))
- x = self.sigmoid(self.map3(x))# 最后生成的是概率
- return x
接着是生成器的部分,生成器最后过tanh可以使得输出的范围被要压缩到(-1,1)
- class Generator(nn.Module):
- def __init__(self, input_size, hidden_size, output_size):
- super(Generator, self).__init__()
- self.map1 = nn.Linear(input_size, hidden_size)
- self.map2 = nn.Linear(hidden_size, hidden_size)
- self.map3 = nn.Linear(hidden_size, output_size)
- self.relu = nn.ReLU()
- self.tanh = nn.Tanh() # 激活函数
- def forward(self, x):
- x = self.relu(self.map1(x))
- x = self.relu(self.map2(x))
- x = self.tanh(self.map3(x))
- return x
初始化网络与定义优化器和损失函数
- # ----------
- # 初始化网络
- # ----------
- D = Discriminator(input_size=image_size,
- hidden_size=hidden_size,
- output_size=1).to(device)
- G = Generator(input_size=latent_size,
- hidden_size=hidden_size,
- output_size=image_size).to(device)
- # -----------------------
- # 定义损失函数和优化器
- # -----------------------
- learning_rate = 0.0003
- criterion = nn.BCELoss()
- d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate)
- g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate)
- d_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(d_optimizer, step_size=50, gamma=0.9)
- g_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(g_optimizer, step_size=50, gamma=0.9)
在这里我们使得learning rate每50轮进行递减,每次变为之前的90%.
定义辅助函数
这里定义两个辅助函数,用来做图片的还原,为了最后的显示。
- # 定义辅助函数
- def denorm(x):
- """
- 用来还原图片, 之前做过标准化
- """
- out = (x + 1) / 2
- return out.clamp(0, 1)
- def reset_grad():
- d_optimizer.zero_grad()
- g_optimizer.zero_grad()
开始训练
- total_step = len(data_loader)
- # ------------------
- # 一开始学习率快一些
- # ------------------
- for epoch in range(250):
- d_exp_lr_scheduler.step()
- g_exp_lr_scheduler.step()
- for i, (images, _) in enumerate(data_loader):
- images = images.reshape(batch_size, -1).to(device)
- # 创造real label和fake label
- real_labels = torch.ones(batch_size, 1).to(device)
- fake_labels = torch.zeros(batch_size, 1).to(device)
- # ---------------------
- # 开始训练discriminator
- # ---------------------
- # 首先计算真实的图片
- outputs = D(images)
- d_loss_real = criterion(outputs, real_labels)
- real_score = outputs # 真实图片的分类结果, 越接近1越好
- # 接着使用生成器训练得到图片, 放入判别器
- z = torch.randn(batch_size, latent_size).to(device)
- fake_images = G(z)
- outputs = D(fake_images)
- d_loss_fake = criterion(outputs, fake_labels)
- fake_score = outputs # 错误图片的分类结果, 越接近0越好, 最后会趋于1, 生成器生成的判别器判断不了
- # 两个loss相加, 反向传播进行优化
- d_loss = d_loss_real + d_loss_fake
- g_optimizer.zero_grad() # 两个优化器梯度都要清0
- d_optimizer.zero_grad()
- d_loss.backward()
- d_optimizer.step()
- # -----------------
- # 开始训练generator
- # -----------------
- z = torch.randn(batch_size, latent_size).to(device)
- fake_images = G(z)
- outputs = D(fake_images)
- g_loss = criterion(outputs, real_labels) # 希望生成器生成的图片判别器可以判别为真
- d_optimizer.zero_grad()
- g_optimizer.zero_grad()
- g_loss.backward()
- g_optimizer.step()
- if (i+1) % 200 == 0:
- print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}, d_lr={:.6f},g_lr={:.6f}'
- .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
- real_score.mean().item(), fake_score.mean().item(),
- d_optimizer.param_groups[0]['lr'], g_optimizer.param_groups[0]['lr']))
- # Save real images
- if (epoch+1) == 1:
- images = images.reshape(images.size(0), 1, 28, 28)
- save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
- # Save sampled images
- fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
- save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
- # Save the model checkpoints
- torch.save(G.state_dict(), './models/G.ckpt')
- torch.save(D.state_dict(), './models/D.ckpt')
我这里只贴一部分的训练代码,实际上后面还会对学习率有一些优化的措施。
最终结果
最后,就来看一下最终的效果。下面是训练不同次数的效果,到最后效果还是不错的。
下面是训练了2000轮后的效果。
但其实仔细看我们会发现一个问题,就是会有很多1出现。这是因为使用这种方式来进行训练的时候,我们不能控制输出,之后我们会将Conditional GAN来解决这个问题.
- 微信公众号
- 关注微信公众号
- QQ群
- 我们的QQ群号
评论