Pytorch入门教程12-全连接网络的手写数字识别(MNIST)

  • A+
所属分类:Pytorch快速入门
摘要前面我们介绍了正向传播, 反向传播, 梯度下降法. 也介绍了Pytorch中的损失函数和优化器, 数据加载器, 数据预处理, 和交叉熵. 这一篇, 我们使用之前学习到的所有知识, 建立一个全连接的神经网络, 来完成手写字符的识别.

简介

前面我们介绍了正向传播, 反向传播, 梯度下降法. 也介绍了Pytorch中的损失函数和优化器, 数据加载器, 数据预处理, 和交叉熵. 这一篇, 我们使用之前学习到的所有知识, 建立一个全连接的神经网络, 来完成手写字符的识别.

之前我曾经写过一个版本的使用Pytorch实现手写字符的识别, 当时主要目的是为了实现如何动态修改网络结构, 所以总的结构不是很完整, 这一篇会写的比较完整. 上一个版本的链接: Pytorch模型实例-MNIST dataset

这一部分的代码, 见github链接全连接网络的手写数字识别(MNIST).ipynb

 

全连接网络完成手写数字识别

准备工作

在准备工作阶段, 我们需要导入需要的库, 并且判断实验环境是否支持GPU, 还是只能使用CPU. 

  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. import torchvision.transforms as transforms
  5. import numpy as np
  6. import pandas as pd
  7. import matplotlib.pyplot as plt
  8. %matplotlib inline
  9. # Device configuration
  10. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  11. device

 

数据加载与数据预处理

我们使用MNIST数据集, 该数据集可以使用torchvision.datasets.MNIST获得. 这一阶段的任务如下所示:

  • 创建dataset
    • 加载MNIST数据
    • 进行数据预处理, 转换为tensor
  • 创建dataloader
    • 将dataset传入dataloader, 设置batchsize

首先我们创建dataset, 同时设置数据预处理.

  1. # 将数据集合下载到指定目录下,这里的transform表示,数据加载时所需要做的预处理操作
  2. # 加载训练集合(Train)
  3. train_dataset = torchvision.datasets.MNIST(root='./data',
  4.                                            train=True,
  5.                                            transform=torchvision.transforms.ToTensor(),
  6.                                            download=True)
  7. # 加载测试集合(Test)
  8. test_dataset = torchvision.datasets.MNIST(root='./data',
  9.                                           train=False,
  10.                                           transform=transforms.ToTensor())
  11. print(train_dataset) # 训练集
  12. """
  13. Dataset MNIST
  14.     Number of datapoints: 60000
  15.     Root location: ./data
  16.     Split: Train
  17.     StandardTransform
  18. Transform: ToTensor()
  19. """
  20. print(test_dataset) # 测试集
  21. """
  22. Dataset MNIST
  23.     Number of datapoints: 10000
  24.     Root location: ./data
  25.     Split: Test
  26.     StandardTransform
  27. Transform: ToTensor()
  28. """

接着设置dataloader, 设置batchsize的大小. 这里的dataloader就是训练的时候会用到的.

  1. batch_size = 100
  2. # 根据数据集定义数据加载器
  3. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
  4.                                            batch_size=batch_size,
  5.                                            shuffle=True)
  6. test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
  7.                                           batch_size=batch_size,
  8.                                           shuffle=False)

最后查看一下样例数据(样例图像), 注意如何查看dataloader中的数据:

  1. # 查看数据
  2. examples = iter(test_loader)
  3. example_data, example_target = examples.next() # 100*1*28*28
  4. for i in range(9):
  5.     plt.subplot(3,3,i+1).set_title(example_target[i])
  6.     plt.imshow(example_data[i][0], 'gray')
  7. plt.tight_layout()
  8. plt.show()
Pytorch入门教程12-全连接网络的手写数字识别(MNIST)

 

网络的构建

我们定义一个三层的网络, 网络大小分别是(784, 500, 10), 网络定义如下所示:

  1. # 输入节点数就为图片的大小:28×28×1
  2. input_size = 784
  3. # 由于数字为 0-9,因此是10分类问题,因此输出节点数为 10
  4. num_classes = 10
  5. # 网络的建立
  6. class NeuralNet(nn.Module):
  7.     # 输入数据的维度,中间层的节点数,输出数据的维度
  8.     def __init__(self, input_size, hidden_size, num_classes):
  9.         super(NeuralNet, self).__init__()
  10.         self.input_size = input_size
  11.         self.l1 = nn.Linear(input_size, hidden_size)
  12.         self.relu = nn.ReLU()
  13.         self.l2 = nn.Linear(hidden_size, num_classes)
  14.     def forward(self, x):
  15.         out = self.relu(self.l1(x))
  16.         out = self.l2(out)
  17.         return out
  18. model = NeuralNet(input_size, 500, num_classes).to(device)
  19. model
  20. """
  21. NeuralNet(
  22.   (l1): Linear(in_features=784, out_features=500, bias=True)
  23.   (relu): ReLU()
  24.   (l2): Linear(in_features=500, out_features=10, bias=True)
  25. )
  26. """

我们用dataloader中的数据测试一下网络, 是否可以正常运算.

  1. # 简单测试模型的输出
  2. examples = iter(test_loader)
  3. example_data, _ = examples.next() # 100*1*28*28
  4. model(example_data.reshape(example_data.size(0),-1)).shape
  5. """
  6. torch.Size([100, 10])
  7. """

为了更加详细的了解网络的构造, 我们可以构造一个测试样例, 并打印出每一层的输出大小.

  1. X = torch.randn(1, 1, 224, 224)
  2. for layer in net:
  3.     X=layer(X)
  4.     print(layer.__class__.__name__,'Output shape:\t',X.shape)

 

定义损失函数和优化器

接下来, 我们定义损失函数和优化器.

  1. # 定义学习率
  2. learning_rate = 0.001
  3. criterion = nn.CrossEntropyLoss()
  4. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

 

模型的训练与测试

有了上面的铺垫之后, 我们就可以开始模型的训练和测试了. 我们定义需要的epoch的数量, 同时保存每一个epoch的loss和在测试集上的准确率.

下面的代码有几个注意的点:

  • loss不是每一次的loss, 而是会计算整个epoch的平均loss
  • 每一个epoch结束之后在测试集上进行测试, 保存测试的准确率
  1. num_epochs = 10
  2. n_total_steps = len(train_loader)
  3. LossList = [] # 记录每一个epoch的loss
  4. AccuryList = [] # 每一个epoch的accury
  5. for epoch in range(num_epochs):
  6.     # -------
  7.     # 开始训练
  8.     # -------
  9.     model.train() # 切换为训练模型
  10.     totalLoss = 0
  11.     for i, (images, labels) in enumerate(train_loader):
  12.         images = images.reshape(-1, 28*28).to(device) # 图片大小转换
  13.         labels = labels.to(device)
  14.         # 正向传播以及损失的求取
  15.         outputs = model(images)
  16.         loss = criterion(outputs, labels)
  17.         totalLoss = totalLoss + loss.item()
  18.         # 反向传播
  19.         optimizer.zero_grad() # 梯度清空
  20.         loss.backward() # 反向传播
  21.         optimizer.step() # 权重更新
  22.         if (i+1) % 300 == 0:
  23.             print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, n_total_steps, totalLoss/(i+1)))
  24.     LossList.append(totalLoss/(i+1))
  25.     # ---------
  26.     # 开始测试
  27.     # ---------
  28.     model.eval()
  29.     with torch.no_grad():
  30.         correct = 0
  31.         total = 0
  32.         for images, labels in test_loader:
  33.             images = images.reshape(-1, 28*28).to(device)
  34.             labels = labels.to(device)
  35.             outputs = model(images)
  36.             _, predicted = torch.max(outputs.data, 1) # 预测的结果
  37.             total += labels.size(0)
  38.             correct += (predicted == labels).sum().item()
  39.         acc = 100.0 * correct / total # 在测试集上总的准确率
  40.         AccuryList.append(acc)
  41.         print('Accuracy of the network on the {} test images: {} %'.format(total, acc))
  42. print("模型训练完成")
  43. """
  44. Epoch [9/10], Step [300/600], Loss: 0.0112
  45. Epoch [9/10], Step [600/600], Loss: 0.0133
  46. Accuracy of the network on the 10000 test images: 98.03 %
  47. Epoch [10/10], Step [300/600], Loss: 0.0092
  48. Epoch [10/10], Step [600/600], Loss: 0.0114
  49. Accuracy of the network on the 10000 test images: 98.08 %
  50. 模型训练完成
  51. """

接着我们绘制一下loss和accuracy的变化趋势. 首先查看loss的变化趋势.

  1. # 绘制loss的变化
  2. fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(13,7))
  3. axes.plot(LossList, 'k--')
Pytorch入门教程12-全连接网络的手写数字识别(MNIST)

接着查看accuracy的变化趋势:

  1. # 绘制loss的变化
  2. fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(13,7))
  3. axes.plot(AccuryList, 'k--')
Pytorch入门教程12-全连接网络的手写数字识别(MNIST)

 

查看测试集上预测结果

最后, 我们从测试集中找几个例子来看一下实际的测试结果.

  1. # 测试样例
  2. examples = iter(test_loader)
  3. example_data, example_targets = examples.next()
  4. # 实际图片
  5. for i in range(9):
  6.     plt.subplot(3, 3, i+1)
  7.     plt.imshow(example_data[i][0], cmap='gray')
  8. plt.show()
  9. # 结果的预测
  10. images = example_data.reshape(-1, 28*28).to(device)
  11. labels = example_targets.to(device)
  12. # 正向传播以及损失的求取
  13. outputs = model(images)
  14. # 将 Tensor 类型的变量 example_targets 转为 numpy 类型的,方便展示
  15. print("上面三张图片的真实结果:", example_targets[0:9].detach().numpy())
  16. # 将得到预测结果
  17. # 由于预测结果是 N×10 的矩阵,因此利用 np.argmax 函数取每行最大的那个值,最为预测值
  18. print("上面三张图片的预测结果:", np.argmax(outputs[0:9].detach().numpy(), axis=1))

最终的结果如下, 可以看到前9个数字的预测都是准确的.

Pytorch入门教程12-全连接网络的手写数字识别(MNIST)
  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: