文章目录(Table of Contents)
前言
有的时候,我们会在训练的时候训练数据集较大,无法全部导入到内存中去,于是就有了这篇文章。这里会讲几种我自己在实际使用过程中使用的方法。
不过不管使用什么样的方法,都是需要自己来重写torch.utils.data.Dataset的。在这里也是进行记录。
方法一--使用HDF5文件
首先说一下总体的做法。
- 首先将csv文件转换为HDF5文件
- 定义MyDataset类, 继承Dataset, 重写抽象方法: __len()__, __getitem()__
-
- __len()__ : 此方法应该提供数据集的大小(容量)
- __getitem()__ : 此方法应该提供支持下标索方式引访问数据集
将csv文件转为HDF5文件
首先第一个步骤,原始文件为有一个较大的csv文件,我们无法全部读入内存中去,于是我们先转换为HDF5文件。
- # csv文件的路径
- csv_path = './data/train_data_sample.txt'
- # 获得文件总的行数
- num_lines = 0
- with open(csv_path,'r') as f:
- for line in f:
- num_lines = num_lines + 1
- line_data = line.split(',')
- print('num_lines : ',num_lines)
- num_features = 4
- class_dict = {'query_id': 0,
- 'query': 1,
- 'query_title_id': 2,
- 'title' : 3}
- # 每次读取的行数
- chunksize = 20
- dt = h5py.special_dtype(vlen=str) # 数据类型为string类型
- # 创建HDF5 数据集
- with h5py.File('./train_data.h5', 'w') as h5f:
- dset1 = h5f.create_dataset('features',
- shape=(num_lines, num_features),
- compression=None,
- dtype=dt)
- dset2 = h5f.create_dataset('labels',
- shape=(num_lines,),
- compression=None,
- dtype='int32')
- for i in range(0, num_lines, chunksize):
- df = pd.read_csv(csv_path,
- header=None,
- nrows = chunksize,
- skiprows=i) # 跳过读取的行数
- features = df.values[:,:4]
- labels = df.values[:,4]
- dset1[i:i+chunksize, :] = features.astype(dt)
- dset2[i:i+chunksize] = labels.astype('int32') # 这里一定要做数据类型转换(默认是string)
- print("i/Num_Line : {}/{}".format(i,num_lines))
我们打印一下数据的大小,查看一下是否和我们想要的大小是一样的。
- # 打印一下测试数据
- with h5py.File('./train_data.h5', 'r') as h5f:
- print(h5f['features'].shape)
- print(h5f['labels'].shape)
- """
- (20000, 4)
- (20000,)
- """
接下来看一下打印一下具体的数据,来查看一下是否正确。
- # 打印一下具体数据
- with h5py.File('./train_data.h5', 'r') as h5f:
- print('First feature entry', h5f['features'][5000:5006])
- print('First label entry', h5f['labels'][5000:5006])
- """
- First feature entry [['733' '1451 64 7903 59 13 13904' '2'
- '5877 7 31 56 1947 15 1451 27 543 38038 31']
- ['733' '1451 64 7903 59 13 13904' '3'
- '22 110 14 497 193 37 86 36 10245 115 29278 5511 33 7903 136']
- ['734' '4107 9646 10397 8469 8469 7073' '1'
- '12 6625 24455 16 89117 27 702 296 2091 14164 3922 11']
- ['734' '4107 9646 10397 8469 8469 7073' '2'
- '10471 292 1526 580 15 27 4844 38616 25752 27 757 54 23229 220223']
- ['734' '4107 9646 10397 8469 8469 7073' '3'
- '4107 12 348 5019 27 16 26606 122 21373 27 5888 25746 27 93241']
- ['734' '4107 9646 10397 8469 8469 7073' '4'
- '1451 64 119 27 6278 22 15 4497 53598 31']]
- First label entry [1 0 0 0 0 1]
- """
重写Dataset类
在这里,我们需要继承Dataset, 重写抽象方法: len(), getitem()。其中这两个方法的作用分别如下:
下面是针对HDF5数据的读取方式。
- class MyDataset(torch.utils.data.Dataset):
- def __init__(self, fileName, features='features', labels='labels', transform=None):
- self.h5f = h5py.File(fileName, 'r')
- self.data_X = self.h5f[features] # 返回特征数据
- self.data_Y = self.h5f[labels] # 返回label数据
- self.size = self.data_X.shape[0] # 数据集的整个大小
- self.transform = transform # 对原始数据进行变换
- def __getitem__(self, idx):
- # self.data_X = transform(self.data_X)
- # self.data_Y = transform(self.data_Y)
- query = torch.tensor([int(i) for i in self.data_X[idx,1].split(' ')]).long()
- query_len = query.size(0) # 获取query填充前的真实长度
- # query = self.pad_sequences(query, 50)
- title = torch.tensor([int(i) for i in self.data_X[idx,3].split(' ')]).long()
- title_len = title.size(0) # 获取title填充前的真实长度
- # title = self.pad_sequences(title, 20)
- content = torch.cat([query, title],dim=0)
- content_len = query_len + title_len
- if content_len > 100: # 大于100的len也要修改为100
- content_len = 100
- content = self.pad_sequences(content, 100)
- labels = torch.tensor(self.data_Y[idx]).long()
- return content, content_len, labels
- def __len__(self):
- return self.size
- def pad_sequences(self, x, max_len):
- """定义自动填充的函数
- """
- padded = np.zeros((max_len), dtype=np.int64)
- if len(x) > max_len:
- padded[:] = x[:max_len]
- else:
- padded[:len(x)] = x
- return padded
- train_dataset = MyDataset(fileName='./train_data.h5', features='features', labels='labels')
我自己实验了一下,这种方法在数据量很大的时候,数据集进行转换需要消耗较长的时间。所以后来没有使用这种方式,使用了下面的方式。
图片的例子
这里提供一个图片的例子 : pytorch学习(四)—自定义数据集
对于灰度图, 我们首先需要进行转换, 因为datasets.ImageFolder默认是会按照RGB进行导入, 所以我们需要首先转换为grayscale, 我们可以实验transforms.Grayscale来进行转换, 如下面的例子所展示.
同时, 下面的例子展示了如何对灰度图进行Normalize.
- trans = transforms.Compose([
- # transforms.Resize(64),
- transforms.Grayscale(num_output_channels=1), # 转换为灰度图
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.5], std=[0.5])
- ])
方法二--将大文件切分
其实这个方法和Pytorch没什么关系,就是把大文件切分成小文件,之后使用小文件来进行训练。
- # 分别生成三组对应的数据
- csv_path = '/home/kesci/input/bytedance/first-round/train.csv'
- base = 1000000 # 每个文件都存储100万条记录
- names_list = ['train', 'stacking', 'test']
- # train, test, stacking中数据各占的比例
- iterations_dict = {'train':90, 'stacking':5, 'test':5}
- skiprow = 0
- for name in names_list:
- iterations = iterations_dict[name]
- for itera in range(1, iterations+1):
- df = pd.read_csv(csv_path,
- header=None,
- nrows = base,
- skiprows=skiprow) # 跳过读取的行数
- skiprow = skiprow + base
- df.to_csv('/home/kesci/work/data/{}/{}_{}.csv'.format(name,name,itera), header=False, index=False)
- print('文件{}_{}.csv导出成功. skiprow={}'.format(name, itera,skiprow))
参考链接
- 微信公众号
- 关注微信公众号
- QQ群
- 我们的QQ群号
2021年3月2日 上午11:31 1F
楼主您好,想咨询一下,采用第二种方式,如何加载多个小文件在一个模型上训练,我也在您的公众号上咨询了
2021年3月4日 下午6:02 B1
@ 葛一飞 你好,不知道我理解的是否正确,多个小文件,就类似于多张图片,训练时候同时加载多个图片。可以有下面两种(我觉得):
1. 如果使用 Pytroch 需要重写 Dataset 类
2. 将多个小文件先保存为 npy 格式再加载即可。
2021年3月4日 下午6:02 B1
@ 葛一飞 另外,你就直接在网站留言就行,公众号现在看的比较少。