关于pytorch处理类别不平衡的问题

 更新时间:2019年12月31日 09:09:22   作者:NAAE  
今天小编就为大家分享一篇关于pytorch处理类别不平衡的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python实现LRU热点缓存及原理

    python实现LRU热点缓存及原理

    LRU算法根据数据的历史访问记录来进行淘汰数据,其核心思想是“如果数据最近被访问过,那么将来被访问的几率也更高”。 。这篇文章主要介绍了python实现LRU热点缓存,需要的朋友可以参考下
    2019-10-10
  • 基于Python绘制一个会动的3D立体粽子

    基于Python绘制一个会动的3D立体粽子

    下周就要到端午节了,所以本文小编就来和大家分享一个有趣的Python项目——绘制会动的3D立体粽子,文中的示例代码讲解详细,感兴趣的可以了解一下
    2023-06-06
  • python实现吃苹果小游戏

    python实现吃苹果小游戏

    这篇文章主要为大家详细介绍了python实现吃苹果小游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-03-03
  • pycharm设置虚拟环境与更换镜像教程

    pycharm设置虚拟环境与更换镜像教程

    这篇文章主要介绍了pycharm设置虚拟环境与更换镜像教程,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-09-09
  • python实现小球弹跳效果

    python实现小球弹跳效果

    这篇文章主要为大家详细介绍了python实现小球弹跳效果,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-05-05
  • python 接口实现 供第三方调用的例子

    python 接口实现 供第三方调用的例子

    今天小编就为大家分享一篇python 接口实现 供第三方调用的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • python 在mysql中插入null空值的操作

    python 在mysql中插入null空值的操作

    这篇文章主要介绍了python 在mysql中插入null空值的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • python2.7 安装pip的方法步骤(管用)

    python2.7 安装pip的方法步骤(管用)

    这篇文章主要介绍了python2.7 安装pip的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-05-05
  • Django修改端口号与地址的三种方式

    Django修改端口号与地址的三种方式

    Django是一个开放源代码的Web应用框架,由Python写成,下面这篇文章主要给大家介绍了关于Django修改端口号与地址的三种方式,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2023-02-02
  • Python进行密码学反向密码教程

    Python进行密码学反向密码教程

    这篇文章主要为大家介绍了Python进行密码学反向密码的教程详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05

最新评论