文章目录(Table of Contents)
简介
前面我们介绍了Pytorch中数据加载器的简单使用. 在完成数据加载之后, 我们还需要进行数据的预处理. 例如:
- 特征数据需要进行标准化, 将不同的特征都转换为0-1之间的数.
- 图像数据也需要从0-255转换为0-1之间, 同时图像数据还需要对大小进行剪裁等操作. 同时也可以通过图像变换来增加数据量, 提高模型的准确率和鲁棒性.
所以这一篇, 我们主要介绍一下Pytorch中数据预处理的部分.
参考资料
- 这一部分的所有代码, 见Github仓库: Pytorch数据预处理部分
- 数据预处理一个总的说明: Python科学运算-数据预处理
- 数据预处理的一个例子: 数据预处理说明(操作分析)
- 关于使用matplotlib显示颜色的问题, 图像处理-matplotlib显示opencv图像
- 关于Pytorch中图像的函数, make_grid的使用, Pytorch图像处理,显示与保存
torchvision.transforms.Compose
首先我们介绍一下Pytorch中的Compose操作. 该函数的主要作用是将所有的预处理操作进行打包. 当有数据来的时候, 该函数可以对数据进行所有预处理的操作. 下面看一个简单的例子.
现在有一些来自numpy的数据, 我们想要, 将其转换为tensor的格式; 同时, 对原始数据进行2*x+1的操作. 我们首先定义这两个变换.
- class ToTensor:
- # 将Numpy转换为tensor格式
- def __call__(self, X):
- return torch.from_numpy(X)
- class AddMulTransform:
- def __call__(self, X):
- X = torch.tensor([1]) + 2*X
- return X
接着使用Compose, 合并上面两个变换的操作.
- # 定义预处理集合器
- composed = torchvision.transforms.Compose([ToTensor(), AddMulTransform()])
接着我们放入数据进行测试:
- # 进行数据的测试
- data = np.array([[1, 2, 3], [3,4,5]])
- composed(data)
- """
- tensor([[ 3, 5, 7],
- [ 7, 9, 11]])
- """
可以看到原始数据, 例如2, 通过2*2+1的操作变为了5. 同时数据类型也是转换为tensor.
特征数据预处理
对于数值的特征, 我们会在传入网络之前进行标准化, 使得各个特征之间不会有很大的数值上的差异. 上面我们是直接对原始数据进行变换.
- 除了上面的方法, 我们可以使用sklearn中的数据预处理的方式对数据进行处理
- 我们也可以将其与dataset进行结合, 这样可以把数据导入和数据预处理结合在一起.
特征数据标准化
首先我们看一下如何使用sklearn中的数据预处理的方式来对数据进行处理.
- from sklearn.preprocessing import StandardScaler
- sc = StandardScaler()
- # 对特征进行标准化,标签不要标准化,因为标签只有 0 和 1
- X_train = sc.fit_transform(X_train)
- X_test = sc.transform(X_test)
- X_train
与dataset结合使用
我们定义一个dataset, 基本内容是和之前一样的. 唯一的不同是我们在getitem中加入了transform, 来对数据进行预处理. 下面是一个简单的例子.
- class TestDataset(Dataset):
- # 需要继承dataset数据集
- def __init__(self, transform):
- # I初始化数据
- xy = np.random.random((100, 13))*100
- self.n_samples = xy.shape[0] # 样本的个数
- # 这里我们就不做Tensor的转换了,将其全部放入 transform 中
- self.x_data = xy[:, 1:]
- self.y_data = xy[:, 0].reshape(-1,1)
- # 数据预处理集合
- self.transform = transform
- # 返回 dataset[index]
- def __getitem__(self, index):
- sample = self.x_data[index], self.y_data[index]
- if self.transform:
- # 数据预处理在这里(只对x进行预处理)
- sample = self.transform(sample)
- return sample
- def __len__(self):
- # 返回数据长度
- return self.n_samples
我们定义两种变换, 分别是将数据转换为0-1之间, 和将数据转换为tensor格式. 需要注意的是, 这里归一化操作是只对feature进行的归一化, 可以看下面的代码.
- # 定义归一化操作
- class Normalization:
- """这里是只对features进行归一化处理
- """
- def __call__(self, sample):
- inputs, targets = sample
- amin, amax = inputs.min(), inputs.max() # 求最大最小值
- inputs = (inputs-amin)/(amax-amin) # (矩阵元素-最小值)/(最大值-最小值)
- return inputs, targets
- # 定义numpy转tensor
- class ToTensor:
- def __call__(self, sample):
- inputs, targets = sample
- return torch.from_numpy(inputs), torch.from_numpy(targets)
接着我们使用compose合并上面的两种变换, 同时初始化dataset并传入transform.
- # 定义 composed
- composed = torchvision.transforms.Compose([Normalization(), ToTensor()])
- # 初始化dataset的时候传入transform即可
- dataset = TestDataset(transform=composed)
我们来看一下dataset中的数据, 看一下是否完成了预处理.
- features_0, labels_0 = dataset[0]
- # 查看数据类型
- print(type(features_0), type(labels_0))
- """
- <class 'torch.Tensor'> <class 'torch.Tensor'>
- """
- # 查看是否进行归一化
- print(features_0) # 这里是进行归一化的结果
- print(labels_0) # 这里是没有进行归一化的结果
- """
- tensor([0.0935, 0.2159, 0.9332, 0.0000, 0.0356, 0.3848, 0.4804, 0.0839, 0.1280,
- 0.7919, 0.6031, 1.0000], dtype=torch.float64)
- tensor([44.6151], dtype=torch.float64)
- """
可以看到数据类型转换为tensor, 同时所有的feature都转换为0-1之间的数字, 但是label是没有变换的. 可以看到在初始化dataset的时候, 我们就同时完成了数据的预处理.
图像数据预处理
除了有特征数据之外, 有的时候我们的数据集会是图像的数据集. 我们使用下面的图像作为例子, 来说明图像数据集如何进行预处理. 首先我们显示原始图像:
- import torchvision.transforms as transforms
- from PIL import Image
- import matplotlib.pyplot as plt
- from matplotlib.pyplot import imshow
- %matplotlib inline
- img = Image.open("20200804_141727_tnzpt18.jpg")
- imshow(img)
剪裁图像大小
有的时候, 我们需要从中心开始对原始图像进行剪裁, 剪裁为固定的大小.
- # torchvision.transforms.CenterCrop(size):从中心开始,裁剪给定大小的 PIL 图像
- transform = transforms.CenterCrop((300, 400))
- new_img = transform(img)
- imshow(new_img)
图像的随机剪裁
- #torchvision.transforms.RandomResizedCrop(size, scale,ratio,interpolation)
- new_img = transforms.RandomResizedCrop((300, 400), scale=(0.08, 1.0), ratio=(0.75, 1.333333333), interpolation=2)(img)
- imshow(new_img)
改变图像亮度, 对比对和饱和度
使用transforms.ColorJitter
来完成改变图像亮度, 对比对和饱和度. 下面是一个简单的例子.
- # transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):
- # 改变图片的亮度、对比度和饱和度
- plt.subplot(221)
- imshow(img) # 显示原始图片
- # r随机改变亮度
- my_img1 = transforms.ColorJitter((0.5, 0.6))(img)
- plt.subplot(222)
- imshow(my_img1)
- # 随机改变对比度
- my_img2 = transforms.ColorJitter(0, (0.5, 0.6))(img)
- plt.subplot(223)
- imshow(my_img2)
- # 随机改变饱和度
- my_img3 = transforms.ColorJitter(0, 0, (0.5, 0.6))(img)
- plt.subplot(224)
- imshow(my_img3)
- plt.show()
转换为灰度图
使用torchvision.transforms.Grayscale(num_output_channels)将图像转换为灰度图.
- 如果返回的图像是单通道num_output_channels=1
- 如果返回的图像是3通道,其中num_output_channels=3, 且r == g == b
也就是说三通道也是灰度图, 不过三个通道的值是一样的.
- plt.subplot(131)
- imshow(img) # 显示原始图片
- my_img1 = transforms.Grayscale(1)(img)
- plt.subplot(132)
- imshow(my_img1, 'gray')
- my_img2 = transforms.Grayscale(3)(img)
- plt.subplot(133)
- imshow(my_img2)
图像的填充
使用pad对图像进行填充. 下面的实验中我们在图像的四周加上黑边.
- # transforms.Pad(padding,fill = 0,padding_mode ='constant' ):
- # 使用给定的 pad 值将给定的 PIL 图像四处填充
- plt.subplot(121)
- imshow(img)
- # 四周加黑色的边界
- my_img = transforms.Pad(padding=200, fill=(0, 0, 0), padding_mode='constant')(img)
- plt.subplot(122).set_title("Pad")
- imshow(my_img)
中心仿射变换
- # transforms.RandomAffine(degrees, translate=None, scale=None,
- # shear=None, resample=False, fillcolor=0):
- # 保持图像中心不变的中心仿射变换(可以理解为不同程度的旋转,再在空余位置补 0)
- my_img1 = transforms.RandomAffine(60)(img)
- plt.subplot(221).set_title("rotate_only")
- imshow(my_img1)
- my_img2 = transforms.RandomAffine(60, translate=(0.3, 0.3))(img)
- plt.subplot(222).set_title("rotate_translate")
- imshow(my_img2)
- my_img3 = transforms.RandomAffine(60, scale=(2.0, 2.1))(img)
- plt.subplot(223).set_title("rotate_scale")
- imshow(my_img3)
- my_img4 = transforms.RandomAffine(0, shear=60)(img)
- plt.subplot(224).set_title("shear_only")
- imshow(my_img4)
- plt.tight_layout()
翻转-RandomHorizontalFlip
除了上面的操作之外, 我们还可以对图片进行翻转, 例如下面的例子.
- plt.subplot(121)
- img = Image.open("hymenoptera_data/train/ants/178538489_bec7649292.jpg")
- imshow(img)
- plt.subplot(122)
- transform = torchvision.transforms.RandomHorizontalFlip()
- new_img = transform(img)
- print(new_img.size)
- imshow(new_img)
最终的效果如下图所示, 下面两个图的区别就是进行了翻转:
- 微信公众号
- 关注微信公众号
- QQ群
- 我们的QQ群号
评论