PyTorch加载数据集梯度下降优化

 更新时间:2022年03月09日 17:16:49   作者:心️升明月  
这篇文章主要介绍了PyTorch加载数据集梯度下降优化,使用DataLoader方法,并继承DataSet抽象类,可实现对数据集进行mini_batch梯度下降优化,需要的小伙伴可以参考一下

一、实现过程

1、准备数据

PyTorch实现多维度特征输入的逻辑回归的方法不同的是:本文使用DataLoader方法,并继承DataSet抽象类,可实现对数据集进行mini_batch梯度下降优化。

代码如下:

import torch
import numpy as np
from torch.utils.data import Dataset,DataLoader

class DiabetesDataSet(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:,:-1])
        self.y_data = torch.from_numpy(xy[:,[-1]])
        
    def __getitem__(self, index):
        return self.x_data[index],self.y_data[index]
    
    def __len__(self):
        return self.len

dataset = DiabetesDataSet('G:/datasets/diabetes/diabetes.csv')
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)

2、设计模型

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear1 = torch.nn.Linear(8,6)
        self.linear2 = torch.nn.Linear(6,4)
        self.linear3 = torch.nn.Linear(4,1)
        self.activate = torch.nn.Sigmoid()
    
    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        x = self.activate(self.linear3(x))
        return x
model = Model()

3、构造损失函数和优化器

criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)

4、训练过程

每次拿出mini_batch个样本进行训练,代码如下:

epoch_list = []
loss_list = []
for epoch in range(100):
    count = 0
    loss1 = 0
    for i, data in enumerate(train_loader,0):
        # 1.Prepare data
        inputs, labels = data
        # 2.Forward
        y_pred = model(inputs)
        loss = criterion(y_pred,labels)
        print(epoch,i,loss.item())
        count += 1
        loss1 += loss.item()
        # 3.Backward
        optimizer.zero_grad()
        loss.backward()
        # 4.Update
        optimizer.step()
        
    epoch_list.append(epoch)
    loss_list.append(loss1/count)

5、结果展示

plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.grid()
plt.show()

二、参考文献

 到此这篇关于PyTorch加载数据集梯度下降优化的文章就介绍到这了,更多相关PyTorch加载数据集内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 对django 2.x版本中models.ForeignKey()外键说明介绍

    对django 2.x版本中models.ForeignKey()外键说明介绍

    这篇文章主要介绍了对django 2.x版本中models.ForeignKey()外键说明介绍,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • python 命令行界面的用户交互及优化

    python 命令行界面的用户交互及优化

    这篇文章主要为大家介绍了python 命令行界面的用户交互及优化方法详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-11-11
  • Python configparser模块配置文件过程解析

    Python configparser模块配置文件过程解析

    这篇文章主要介绍了Python configparser模块配置文件过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-03-03
  • Pandas如何对Categorical类型字段数据统计实战案例

    Pandas如何对Categorical类型字段数据统计实战案例

    这篇文章主要介绍了Pandas如何对Categorical类型字段数据统计实战案例,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-08-08
  • 关于Series的index的方法和属性使用说明

    关于Series的index的方法和属性使用说明

    这篇文章主要介绍了关于Series的index的方法和属性使用说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-06-06
  • PyQt5实现五子棋游戏(人机对弈)

    PyQt5实现五子棋游戏(人机对弈)

    这篇文章主要为大家详细介绍了PyQt5实现五子棋游戏,人机对弈,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-01-01
  • python基础之函数

    python基础之函数

    这篇文章主要介绍了python的函数,实例分析了Python中返回一个返回值与多个返回值的方法,需要的朋友可以参考下
    2021-10-10
  • 解决Django no such table: django_session的问题

    解决Django no such table: django_session的问题

    这篇文章主要介绍了解决Django no such table: django_session的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • python模拟斗地主发牌

    python模拟斗地主发牌

    这篇文章主要为大家详细介绍了python模拟斗地主发牌,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-04-04
  • 快速进修Python指南之迭代器Iterator与生成器

    快速进修Python指南之迭代器Iterator与生成器

    这篇文章主要为大家介绍了Java开发者快速进修Python指南之迭代器Iterator与生成器示例解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-12-12

最新评论