Pytorch DataLoader shuffle验证方式
shuffle = False时,不打乱数据顺序
shuffle = True,随机打乱
import numpy as np import h5py import torch from torch.utils.data import DataLoader, Dataset h5f = h5py.File('train.h5', 'w'); data1 = np.array([[1,2,3], [2,5,6], [3,5,6], [4,5,6]]) data2 = np.array([[1,1,1], [1,2,6], [1,3,6], [1,4,6]]) h5f.create_dataset(str('data'), data=data1) h5f.create_dataset(str('label'), data=data2) class Dataset(Dataset): def __init__(self): h5f = h5py.File('train.h5', 'r') self.data = h5f['data'] self.label = h5f['label'] def __getitem__(self, index): data = torch.from_numpy(self.data[index]) label = torch.from_numpy(self.label[index]) return data, label def __len__(self): assert self.data.shape[0] == self.label.shape[0], "wrong data length" return self.data.shape[0] dataset_train = Dataset() loader_train = DataLoader(dataset=dataset_train, batch_size=2, shuffle = True) for i, data in enumerate(loader_train): train_data, label = data print(train_data)
pytorch DataLoader使用细节
背景:
我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,
数据变换共有以下内容
composed = transforms.Compose([transforms.Resize((448, 448)), # resize transforms.RandomCrop(300), # random crop transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], # normalize std=[0.5, 0.5, 0.5])])
简单的数据读取类, 进返回PIL格式的image:
class MyDataset(data.Dataset): def __init__(self, labels_file, root_dir, transform=None): with open(labels_file) as csvfile: self.labels_file = list(csv.reader(csvfile)) self.root_dir = root_dir self.transform = transform def __len__(self): return len(self.labels_file) def __getitem__(self, idx): im_name = os.path.join(root_dir, self.labels_file[idx][0]) im = Image.open(im_name) if self.transform: im = self.transform(im) return im
下面是主程序
labels_file = "F:/test_temp/labels.csv" root_dir = "F:/test_temp" dataset_transform = MyDataset(labels_file, root_dir, transform=composed) dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False) """原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张) """ for eopch in range(2): plt.figure(figsize=(6, 6)) for ind, i in enumerate(dataloader): a = i[0, :, :, :].numpy().transpose((1, 2, 0)) plt.subplot(1, 3, ind+1) plt.imshow(a)
从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
用Python自动清理电脑内重复文件,只要10行代码(自动脚本)
这篇文章主要介绍了用Python自动清理电脑内重复文件,只要10行代码,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下2021-01-01python3.6+selenium实现操作Frame中的页面元素
这篇文章主要为大家详细介绍了python3.6+selenium实现操作Frame中的页面元素,具有一定的参考价值,感兴趣的小伙伴们可以参考一下2019-07-07python3实现语音转文字(语音识别)和文字转语音(语音合成)
这篇文章主要介绍了python3实现语音转文字(语音识别)和文字转语音(语音合成),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧2020-10-10Python tkinter 多选按钮控件 Checkbutton方法
这篇文章主要介绍了Python tkinter 多选按钮控件 Checkbutton方法,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的朋友可以参考一下2022-07-07在PyCharm中遇到pip安装 失败问题及解决方案(pip失效时的解决方案)
这篇文章主要介绍了在PyCharm中遇到pip安装失败问题及解决方案(pip失效时的解决方案),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下2020-03-03
最新评论