Pytorch训练时候导入大量数据(What’s the best way to load large data?)

  • A+
所属分类:深度学习
摘要这一篇文章主要讲一下在Pytorch中,如何处理数据量较大,无法全部导入memory的情况。同时,也会说明一下如何使用Pytorch中的Dataset。

前言

有的时候,我们会在训练的时候训练数据集较大,无法全部导入到内存中去,于是就有了这篇文章。这里会讲几种我自己在实际使用过程中使用的方法。

不过不管使用什么样的方法,都是需要自己来重写torch.utils.data.Dataset的。在这里也是进行记录。

方法一--使用HDF5文件

首先说一下总体的做法。

  1. 首先将csv文件转换为HDF5文件
  2. 定义MyDataset类, 继承Dataset, 重写抽象方法: __len()__, __getitem()__
    • __len()__ : 此方法应该提供数据集的大小(容量)
    • __getitem()__ : 此方法应该提供支持下标索方式引访问数据集

将csv文件转为HDF5文件

首先第一个步骤,原始文件为有一个较大的csv文件,我们无法全部读入内存中去,于是我们先转换为HDF5文件。

  1. # csv文件的路径
  2. csv_path = './data/train_data_sample.txt'
  3. # 获得文件总的行数
  4. num_lines = 0
  5. with open(csv_path,'r') as f:
  6.     for line in f:
  7.         num_lines = num_lines + 1
  8.         line_data = line.split(',')
  9. print('num_lines : ',num_lines)
  10. num_features = 4
  11. class_dict = {'query_id': 0,
  12.               'query': 1,
  13.               'query_title_id': 2,
  14.               'title' : 3}
  15. # 每次读取的行数
  16. chunksize = 20
  17. dt = h5py.special_dtype(vlen=str# 数据类型为string类型
  18. # 创建HDF5 数据集
  19. with h5py.File('./train_data.h5', 'w') as h5f:
  20.     dset1 = h5f.create_dataset('features',
  21.                               shape=(num_lines, num_features),
  22.                               compression=None,
  23.                               dtype=dt)
  24.     dset2 = h5f.create_dataset('labels',
  25.                               shape=(num_lines,),
  26.                                compression=None,
  27.                                dtype='int32')
  28.     for i in range(0, num_lines, chunksize):
  29.         df = pd.read_csv(csv_path,
  30.                          header=None,
  31.                          nrows = chunksize,
  32.                         skiprows=i) # 跳过读取的行数
  33.         features = df.values[:,:4]
  34.         labels = df.values[:,4]
  35.         dset1[i:i+chunksize, :] = features.astype(dt)
  36.         dset2[i:i+chunksize] = labels.astype('int32') # 这里一定要做数据类型转换(默认是string)
  37.         print("i/Num_Line : {}/{}".format(i,num_lines))

我们打印一下数据的大小,查看一下是否和我们想要的大小是一样的。

  1. # 打印一下测试数据
  2. with h5py.File('./train_data.h5', 'r') as h5f:
  3.     print(h5f['features'].shape)
  4.     print(h5f['labels'].shape)
  5. """
  6. (20000, 4)
  7. (20000,)
  8. """

接下来看一下打印一下具体的数据,来查看一下是否正确。

  1. # 打印一下具体数据
  2. with h5py.File('./train_data.h5', 'r') as h5f:
  3.     print('First feature entry', h5f['features'][5000:5006])
  4.     print('First label entry', h5f['labels'][5000:5006])
  5. """
  6. First feature entry [['733' '1451 64 7903 59 13 13904' '2'
  7.   '5877 7 31 56 1947 15 1451 27 543 38038 31']
  8.  ['733' '1451 64 7903 59 13 13904' '3'
  9.   '22 110 14 497 193 37 86 36 10245 115 29278 5511 33 7903 136']
  10.  ['734' '4107 9646 10397 8469 8469 7073' '1'
  11.   '12 6625 24455 16 89117 27 702 296 2091 14164 3922 11']
  12.  ['734' '4107 9646 10397 8469 8469 7073' '2'
  13.   '10471 292 1526 580 15 27 4844 38616 25752 27 757 54 23229 220223']
  14.  ['734' '4107 9646 10397 8469 8469 7073' '3'
  15.   '4107 12 348 5019 27 16 26606 122 21373 27 5888 25746 27 93241']
  16.  ['734' '4107 9646 10397 8469 8469 7073' '4'
  17.   '1451 64 119 27 6278 22 15 4497 53598 31']]
  18. First label entry [1 0 0 0 0 1]
  19. """

 重写Dataset类

在这里,我们需要继承Dataset, 重写抽象方法: len(), getitem()。其中这两个方法的作用分别如下:

Pytorch训练时候导入大量数据(What's the best way to load large data?)

下面是针对HDF5数据的读取方式。

  1. class MyDataset(torch.utils.data.Dataset):
  2.     def __init__(self, fileName, features='features', labels='labels', transform=None):
  3.         self.h5f = h5py.File(fileName, 'r')
  4.         self.data_X = self.h5f[features] # 返回特征数据
  5.         self.data_Y = self.h5f[labels] # 返回label数据
  6.         self.size = self.data_X.shape[0] # 数据集的整个大小
  7.         self.transform = transform # 对原始数据进行变换
  8.     def __getitem__(self, idx):
  9.         # self.data_X = transform(self.data_X)
  10.         # self.data_Y = transform(self.data_Y)
  11.         query = torch.tensor([int(i) for i in self.data_X[idx,1].split(' ')]).long()
  12.         query_len = query.size(0) # 获取query填充前的真实长度
  13.         # query = self.pad_sequences(query, 50)
  14.         title = torch.tensor([int(i) for i in self.data_X[idx,3].split(' ')]).long()
  15.         title_len = title.size(0) # 获取title填充前的真实长度
  16.         # title = self.pad_sequences(title, 20)
  17.         content = torch.cat([query, title],dim=0)
  18.         content_len = query_len + title_len
  19.         if content_len > 100: # 大于100的len也要修改为100
  20.             content_len = 100
  21.         content = self.pad_sequences(content, 100)
  22.         labels = torch.tensor(self.data_Y[idx]).long()
  23.         return content, content_len, labels
  24.     def __len__(self):
  25.         return self.size
  26.     def pad_sequences(self, x, max_len):
  27.         """定义自动填充的函数
  28.         """
  29.         padded = np.zeros((max_len), dtype=np.int64)
  30.         if len(x) > max_len:
  31.             padded[:] = x[:max_len]
  32.         else:
  33.             padded[:len(x)] = x
  34.         return padded
  35. train_dataset = MyDataset(fileName='./train_data.h5', features='features', labels='labels')

我自己实验了一下,这种方法在数据量很大的时候,数据集进行转换需要消耗较长的时间。所以后来没有使用这种方式,使用了下面的方式。

图片的例子

这里提供一个图片的例子 : pytorch学习(四)—自定义数据集

方法二--将大文件切分

其实这个方法和Pytorch没什么关系,就是把大文件切分成小文件,之后使用小文件来进行训练。

  1. # 分别生成三组对应的数据
  2. csv_path = '/home/kesci/input/bytedance/first-round/train.csv'
  3. base = 1000000 # 每个文件都存储100万条记录
  4. names_list = ['train', 'stacking', 'test']
  5. # train, test, stacking中数据各占的比例
  6. iterations_dict = {'train':90, 'stacking':5, 'test':5}
  7. skiprow = 0
  8. for name in names_list:
  9.     iterations = iterations_dict[name]
  10.     for itera in range(1, iterations+1):
  11.         df = pd.read_csv(csv_path,
  12.                          header=None,
  13.                          nrows = base,
  14.                         skiprows=skiprow) # 跳过读取的行数
  15.         skiprow = skiprow + base
  16.         df.to_csv('/home/kesci/work/data/{}/{}_{}.csv'.format(name,name,itera), header=Falseindex=False)
  17.         print('文件{}_{}.csv导出成功. skiprow={}'.format(name, itera,skiprow))
Pytorch训练时候导入大量数据(What's the best way to load large data?)

参考链接

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: