返回最大值的index pytorch方式

 更新时间:2022年07月16日 15:37:45   作者:catbird233  
这篇文章主要介绍了返回最大值的index pytorch方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回最大值的index

import torch
a=torch.tensor([[.1,.2,.3],
                [1.1,1.2,1.3],
                [2.1,2.2,2.3],
                [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())

输出:

tensor([ 2,  2,  2,  2])
tensor(11)

pytorch 找最大值

题意:使用神经网络实现,从数组中找出最大值。

提供数据:两个 csv 文件,一个存训练集:n 个 m 维特征自然数数据,另一个存每条数据对应的 label ,就是每条数据中的最大值。

这里将随机构建训练集:

#%%
import numpy as np 
import pandas as pd 
import torch 
import random 
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim
  
def GetData(m, n):
    dataset = []
    for j in range(m):
        max_v = random.randint(0, 9)
        data = [random.randint(0, 9) for i in range(n)]
        dataset.append(data)
    label = [max(dataset[i]) for i in  range(len(dataset))]
    data_list = np.column_stack((dataset, label))
    data_list = data_list.astype(np.float32)
    return data_list
 
#%%
# 数据集封装 重载函数len, getitem
class GetMaxEle(Data.Dataset):
    def __init__(self, trainset):
        self.data = trainset 
 
    def __getitem__(self, index):
        item = self.data[index]
        x = item[:-1]
        y = item[-1]
        return x, y
    
    def __len__(self):
        return len(self.data)
 
# %% 定义网络模型
class SingleNN(nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(SingleNN, self).__init__()
        
        self.hidden = nn.Linear(n_feature, n_hidden)
        self.relu = nn.ReLU()
        self.predict = nn.Linear(n_hidden, n_output)
 
    def forward(self, x):
        x = self.hidden(x)
        x = self.relu(x)
        x = self.predict(x)
        return x
  
def train(m, n, batch_size, PATH):
    # 随机生成 m 个 n 个维度的训练样本
    data_list =GetData(m, n)
    dataset = GetMaxEle(data_list)
    trainset = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                      shuffle=True)
 
    net = SingleNN(n_feature=10, n_hidden=100,
                   n_output=10)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    #
    total_epoch = 100
    for epoch in range(total_epoch):
        for index, data in enumerate(trainset):
            input_x, labels = data
            labels = labels.long()
            optimizer.zero_grad()
 
            output = net(input_x)
            # print(output)
            # print(labels)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
 
        # scheduled_optimizer.step()
        print(f"Epoch {epoch}, loss:{loss.item()}")
 
    # %% 保存参数
    torch.save(net.state_dict(), PATH)
    #测试 
  
def test(m, n, batch_size, PATH):
    data_list = GetData(m, n)
    dataset = GetMaxEle(data_list)
    testloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    dataiter = iter(testloader)
    input_x, labels = dataiter.next()
    net = SingleNN(n_feature=10, n_hidden=100,
                   n_output=10)
    net.load_state_dict(torch.load(PATH))
    outputs = net(input_x)
 
    _, predicted = torch.max(outputs, 1)
    print("Ground_truth:",labels.numpy())
    print("predicted:",predicted.numpy())
  
if __name__ == "__main__":
    m = 1000
    n = 10
    batch_size = 64
    PATH = './max_list.pth'
    train(m, n, batch_size, PATH)
    test(m, n, batch_size, PATH)

初始的想法是使用全连接网络+分类来实现, 但是结果不尽人意,主要原因:不同类别之间的样本量差太大,几乎90%都是最大值。

比如代码中随机构建 10 个 0~9 的数字构成一个样本[2, 3, 5, 8, 9, 5, 3, 9, 3, 6], 该样本标签是9。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用python把Excel中的数据在页面中可视化

    使用python把Excel中的数据在页面中可视化

    最近学习数据分析,感觉Python做数据分析真的好用,下面这篇文章主要给大家介绍了关于如何使用python把Excel中的数据在页面中可视化的相关资料,需要的朋友可以参考下
    2022-03-03
  • Python简单删除列表中相同元素的方法示例

    Python简单删除列表中相同元素的方法示例

    这篇文章主要介绍了Python简单删除列表中相同元素的方法,结合具体实例形式分析了Python使用list、set方法针对列表元素的去重与排序操作实现技巧,非常简单实用,需要的朋友可以参考下
    2017-06-06
  • Python如何设置指定窗口为前台活动窗口

    Python如何设置指定窗口为前台活动窗口

    这篇文章主要介绍了Python如何设置指定窗口为前台活动窗口,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-08-08
  • Python使用ThreadPoolExecutor一次开启多个线程

    Python使用ThreadPoolExecutor一次开启多个线程

    通过使用ThreadPoolExecutor,您可以同时开启多个线程,从而提高程序的并发性能,本文就来介绍一下Python使用ThreadPoolExecutor一次开启多个线程,感兴趣的可以了解一下
    2023-11-11
  • Python实现TCP探测目标服务路由轨迹的原理与方法详解

    Python实现TCP探测目标服务路由轨迹的原理与方法详解

    这篇文章主要介绍了Python实现TCP探测目标服务路由轨迹的原理与方法,结合实例形式分析了Python TCP探测目标服务路由轨迹的原理、实现方法及相关操作注意事项,需要的朋友可以参考下
    2019-09-09
  • Python requests HTTP验证登录实现流程

    Python requests HTTP验证登录实现流程

    这篇文章主要介绍了Python requests HTTP验证登录实现流程,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • django和vue实现数据交互的方法

    django和vue实现数据交互的方法

    今天小编就为大家分享一篇django和vue实现数据交互的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • Jinja2实现模板渲染与访问对象属性流程详解

    Jinja2实现模板渲染与访问对象属性流程详解

    要了解jinja2,那么需要先理解模板的概念。模板在Python的web开发中广泛使用,它能够有效的将业务逻辑和页面逻辑分开,使代码可读性增强,并且更加容易理解和维护。模板简单来说就是一个其中包含占位变量表示动态部分的文,模板文件在经过动态赋值后,返回给用户
    2023-03-03
  • python利用rsa库做公钥解密的方法教程

    python利用rsa库做公钥解密的方法教程

    RSA是一种公钥密码算法,RSA的密文是对代码明文的数字的 E 次方求mod N 的结果。下面这篇文章主要给大家介绍了关于python利用rsa库做公钥解密的方法教程,文中通过示例代码介绍的非常详细,需要的朋友可以参考下。
    2017-12-12
  • python二进制文件的转译详解

    python二进制文件的转译详解

    这篇文章主要介绍了python二进制文件的转译详解的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07

最新评论