PyTorch如何创建自己的数据集

 更新时间:2022年11月28日 15:06:09   作者:ZQ_ZHU  
这篇文章主要介绍了PyTorch如何创建自己的数据集,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

PyTorch创建自己的数据集

图片文件在同一的文件夹下

思路是继承 torch.utils.data.Dataset,并重点重写其 __getitem__方法,示例代码如下:

class ImageFolder(Dataset):
    def __init__(self, folder_path):
        self.files = sorted(glob.glob('%s/*.*' % folder_path))

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        img = np.array(Image.open(path))
        h, w, c = img.shape
        pad = ((40, 40), (4, 4), (0, 0))

        # img = np.pad(img, pad, 'constant', constant_values=0) / 255
        img = np.pad(img, pad, mode='edge') / 255.0
        img = torch.from_numpy(img).float()
        patches = np.reshape(img, (3, 10, 128, 11, 128))
        patches = np.transpose(patches, (0, 1, 3, 2, 4))

        return img, patches, path

    def __len__(self):
        return len(self.files)

图片文件在不同的文件夹下

比如我们有数据如下:

─── data
├── train
│ ├── 0.jpg
│ └── 1.jpg
├── test
│ ├── 0.jpg
│ └── 1.jpg
└── val
├── 1.jpg
└── 2.jpg

此时我们只需要将以上代码稍作修改即可,修改的代码如下:

self.files = sorted(glob.glob('%s/**/*.*' % folder_path, recursive=True))

其他代码不变。

pytorch常用数据集的使用

对于pytorch数据集的使用,示例代码如下:

from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from torchvision import transforms
import torchvision
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

dataset_transform = Compose([transforms.ToTensor()])


# 关于官方数据集的使用还是关键要看pytorch的官方文档
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=dataset_transform,download=True)

# 查看测试数据集中的第一个数据
# print(test_set[0])
# 查看测试数据集中的分类情况
# print(test_set.classes)
#
# 取出第一个数据中的图片(img)和分类结果(target)
# img,target = test_set[0]
# 查看图片数据的类型
# print(img)
# print(target)
# 输出类别
# print(test_set.classes[target])
# 查看图片
# img.show()

# 使用tensorboard显示tensor数据类型的图片
writer = SummaryWriter("logs")
for i in range(10):
	# 取出数据中的图片(img)和分类结果(target)
    img,target = test_set[i]
    writer.add_image("test_set",img,i)

writer.close()

上述代码运行结果在tensorboard可视化:

代码

train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)

常用参数讲解

  • root:根目录,存放数据集的位置
  • train:若为True,则划分为训练数据集,若为False,则划分为测试数据集
  • transform:指定输入数据集处理方式
  • download:若为True,则会将数据集下载到root指定的目录下,否则不会下载

官方文档对参数的解释:

root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.

train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.

transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop

target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

注意:

  • 关于官方数据集的使用还是关键要看pytorch的官方文档
  • 下载数据集的细节之处:知道下载链接(下载链接可以在源码中查看)之后可以不用使用代码下载了,使用迅雷来下载可能会更快。
  • 要学会使用Pycharm中的ctrl+p和ctrl+alt这两个快捷键
  • pytorch官网
  • pytorch官方数据集(下载数据集方法)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python光学仿真wxpython透镜演示系统计算与绘图

    Python光学仿真wxpython透镜演示系统计算与绘图

    这篇文章主要为大家介绍了Python光学仿真wxpython透镜演示系统计算与绘图的实现示例。有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-10-10
  • python中isdigit() isalpha()用于判断字符串的类型问题

    python中isdigit() isalpha()用于判断字符串的类型问题

    这篇文章主要介绍了python中isdigit() isalpha()用于判断字符串的类型问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • 浅谈python中str字符串和unicode对象字符串的拼接问题

    浅谈python中str字符串和unicode对象字符串的拼接问题

    今天小编就为大家分享一篇浅谈python中str字符串和unicode对象字符串的拼接问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • Python面向对象程序设计类的多态用法详解

    Python面向对象程序设计类的多态用法详解

    这篇文章主要介绍了Python面向对象程序设计类的多态用法,结合实例形式详细分析了Python面向对象程序设计中类的多态概念、原理、用法及相关操作注意事项,需要的朋友可以参考下
    2019-04-04
  • PyTorch权值初始化原理解析

    PyTorch权值初始化原理解析

    这篇文章主要为大家介绍了PyTorch权值初始化原理示例解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-07-07
  • python3编写ThinkPHP命令执行Getshell的方法

    python3编写ThinkPHP命令执行Getshell的方法

    这篇文章主要介绍了python3编写ThinkPHP命令执行Getshell的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-02-02
  • 深入理解python中sort()与sorted()的区别

    深入理解python中sort()与sorted()的区别

    Python list内置sort()方法用来排序,也可以用python内置的全局sorted()方法来对可迭代的序列排序生成新的序列。这篇文章主要介绍了python中sort()与sorted()的区别,需要的朋友可以参考下
    2018-08-08
  • OpenCv实现绘图功能

    OpenCv实现绘图功能

    这篇文章主要为大家详细介绍了OpenCv实现绘图功能,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-05-05
  • python中gevent库的用法详情

    python中gevent库的用法详情

    这篇文章主要介绍了python中gevent库的用法详情,Greenlet全部运行在主程序操作系统的过程中,但是它们是协作调度的,文章围绕主题展开详细的内容介绍,具有一定的参考价值
    2022-07-07
  • 使用Python机器学习降低静态日志噪声

    使用Python机器学习降低静态日志噪声

    今天小编就为大家分享一篇关于使用Python和机器学习的静态日志噪声的文章,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2018-09-09

最新评论