pytorch 自定义数据集加载方法

 更新时间:2019年08月18日 08:51:08   作者:xholes  
今天小编就为大家分享一篇pytorch 自定义数据集加载方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) – 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size – batch的大小(可选项,默认值为1)。

shuffle – 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler – 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers – 表示读取样本的线程数, 0表示只有主线程。

collate_fn – 合并一个样本列表称为一个batch。

pin_memory – 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) – 设置是否丢弃最后一个不完整的batch,默认为False。

timeout – 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python适合做数据挖掘吗

    python适合做数据挖掘吗

    在本篇文章里小编给各位分享的是一篇关于python做数据挖掘的相关知识点内容,有兴趣的朋友们可以学习下。
    2020-06-06
  • python中的属性管理机制详解

    python中的属性管理机制详解

    这篇文章主要介绍了python中的属性管理机制,主要包括私有属性和属性限制-__slots__方法,文中详细介绍了python中如何去声明变量的相关知识,需要的朋友可以参考下
    2022-06-06
  • python 脚本生成随机 字母 + 数字密码功能

    python 脚本生成随机 字母 + 数字密码功能

    本文通过一小段简单的代码给大家分享基于python 脚本生成随机 字母 + 数字密码功能,感兴趣的朋友跟随脚本之家小编一起学习吧
    2018-05-05
  • python使用列表的最佳方案

    python使用列表的最佳方案

    这篇文章主要介绍了python使用列表的最佳方式,帮助大家更好的理解和学习python,感兴趣的朋友可以了解下
    2020-08-08
  • Python3搜索及替换文件中文本的方法

    Python3搜索及替换文件中文本的方法

    这篇文章主要介绍了Python3搜索及替换文件中文本的方法,涉及Python操作文件及字符串的相关技巧,需要的朋友可以参考下
    2015-05-05
  • 完美解决python中ndarray 默认用科学计数法显示的问题

    完美解决python中ndarray 默认用科学计数法显示的问题

    今天小编就为大家分享一篇完美解决python中ndarray 默认用科学计数法显示的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • python爬虫爬取微博评论案例详解

    python爬虫爬取微博评论案例详解

    这篇文章主要介绍了python爬虫爬取微博评论,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • 用python编写第一个IDA插件的实例

    用python编写第一个IDA插件的实例

    今天小编就为大家分享一篇用python编写第一个IDA插件的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • PyCharm利用pydevd-pycharm实现Python远程调试的详细过程

    PyCharm利用pydevd-pycharm实现Python远程调试的详细过程

    这篇文章主要介绍了PyCharm利用pydevd-pycharm实现Python远程调试,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-09-09
  • Pyspark读取parquet数据过程解析

    Pyspark读取parquet数据过程解析

    这篇文章主要介绍了pyspark读取parquet数据过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-03-03

最新评论