python深度学习之多标签分类器及pytorch实现源码

 更新时间:2022年01月30日 09:14:08   作者:鬼道2022  
这篇文章主要为大家介绍了python深度学习之多标签分类器的使用说明及pytorch的实现源码,有需要的朋友可以借鉴参考下,希望能够有所帮助

多标签分类器

多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:

  • 类标数量不确定,有些样本可能只有一个类标,有些样本的类标可能高达几十甚至上百个
  • 类标之间相互依赖,例如包含蓝天类标的样本很大概率上包含白云

如下图所示,即为一个多标签分类学习的一个例子,一张图片里有多个类别,房子,树,云等,深度学习模型需要将其一一分类识别出来。

多标签分类器损失函数

代码实现

针对图像的多标签分类器pytorch的简化代码实现如下所示。因为图像的多标签分类器的数据集比较难获取,所以可以通过对mnist数据集中的每个图片打上特定的多标签,例如类别1的多标签可以为[1,1,0,1,0,1,0,0,1],然后再利用重新打标后的数据集训练出一个mnist的多标签分类器。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.Sq1 = nn.Sequential(         
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),   # (16, 28, 28)                           #  output: (16, 28, 28)
            nn.ReLU(),                    
            nn.MaxPool2d(kernel_size=2),    # (16, 14, 14)
        )
        self.Sq2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),  # (32, 14, 14)
            nn.ReLU(),                      
            nn.MaxPool2d(2),                # (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 100)  
    def forward(self, x):
        x = self.Sq1(x)
        x = self.Sq2(x)
        x = x.view(x.size(0), -1)    
        x = self.out(x)
        ## Sigmoid activation   
        output = F.sigmoid(x)  # 1/(1+e**(-x))
        return output
def loss_fn(pred, target):
    return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum()
def multilabel_generate(label):
    Y1 = F.one_hot(label, num_classes = 100)
    Y2 = F.one_hot(label+10, num_classes = 100)
    Y3 = F.one_hot(label+50, num_classes = 100) 	
    multilabel = Y1+Y2+Y3
    return multilabel
        
# def multilabel_generate(label):
# 	multilabel_dict = {}
# 	multi_list = []
# 	for i in range(label.shape[0]):
# 		multi_list.append(multilabel_dict[label[i].item()])
# 	multilabel_tensor = torch.tensor(multi_list)
#     return multilabel
def train():
    epoches = 10
    mnist_net = CNN()
    mnist_net.train()
    opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True)
    for epoch in range(epoches):
    	loss = 0 
    	for batch_X, batch_Y in train_loader:
    		opitimizer.zero_grad()
    		outputs = mnist_net(batch_X)
    		loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0]
    		loss.backward()
    		opitimizer.step()
    		print(loss)
if __name__ == '__main__':
	train()

以上就是python深度学习之多标签分类器及pytorch源码的详细内容,更多关于多标签分类器pytorch源码的资料请关注脚本之家其它相关文章!

相关文章

  • Pandas根据条件实现替换列中的值

    Pandas根据条件实现替换列中的值

    在使用Pandas的Python中,DataFrame列中的值可以通过使用各种内置函数根据条件进行替换,本文主要来和大家讨论在Pandas中用条件替换数据集列中的值的各种方法,希望对大家有所帮助
    2024-01-01
  • python 经纬度求两点距离、三点面积操作

    python 经纬度求两点距离、三点面积操作

    这篇文章主要介绍了python 经纬度求两点距离、三点面积操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-06-06
  • python实例化对象的具体方法

    python实例化对象的具体方法

    在本篇文章里小编给大家整理的是关于python实例化对象的具体方法,有兴趣的朋友们可以学习下。
    2020-06-06
  • Python字典生成式、集合生成式、生成器用法实例分析

    Python字典生成式、集合生成式、生成器用法实例分析

    这篇文章主要介绍了Python字典生成式、集合生成式、生成器用法,结合实例形式分析了Python字典生成式、集合生成式、生成器相关原理、使用技巧与操作注意事项,需要的朋友可以参考下
    2020-01-01
  • python使用OS模块操作系统接口及常用功能详解

    python使用OS模块操作系统接口及常用功能详解

    os是 Python 标准库中的一个模块,提供了与操作系统交互的功能,在本节中,我们将介绍os模块的一些常用功能,并通过实例代码详细讲解每个知识点
    2023-06-06
  • Python实现破解猜数游戏算法示例

    Python实现破解猜数游戏算法示例

    这篇文章主要介绍了Python实现破解猜数游戏算法,简单描述了猜数游戏的原理,并结合具体实例形式分析了Python破解猜数游戏的相关实现技巧,需要的朋友可以参考下
    2017-09-09
  • Python爬虫入门案例之爬取去哪儿旅游景点攻略以及可视化分析

    Python爬虫入门案例之爬取去哪儿旅游景点攻略以及可视化分析

    读万卷书不如行万里路,学的扎不扎实要通过实战才能看出来,本篇文章手把手带你爬取去哪儿平台的旅游景点攻略并进行可视化分析,大家可以在过程中查缺补漏,看看自己掌握程度怎么样
    2021-10-10
  • Django-Rest-Framework 权限管理源码浅析(小结)

    Django-Rest-Framework 权限管理源码浅析(小结)

    这篇文章主要介绍了Django-Rest-Framework 权限管理源码浅析(小结),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-11-11
  • Python可视化函数plt.scatter详解

    Python可视化函数plt.scatter详解

    这篇文章主要介绍了Python可视化函数plt.scatter详解, 关于matplotlib的scatter函数有许多活动参数,如果不专门注解,是无法掌握精髓的,本文专门针对scatter的参数和调用说起,并配有若干案例,需要的朋友可以参考下
    2023-04-04
  • python的pip安装以及使用教程

    python的pip安装以及使用教程

    这篇文章主要为大家详细介绍了python的pip安装以及使用教程,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-09-09

最新评论