PyTorch加载数据集梯度下降优化

 更新时间:2022年03月09日 17:16:49   作者:心️升明月  
这篇文章主要介绍了PyTorch加载数据集梯度下降优化,使用DataLoader方法,并继承DataSet抽象类,可实现对数据集进行mini_batch梯度下降优化,需要的小伙伴可以参考一下

一、实现过程

1、准备数据

PyTorch实现多维度特征输入的逻辑回归的方法不同的是:本文使用DataLoader方法,并继承DataSet抽象类,可实现对数据集进行mini_batch梯度下降优化。

代码如下:

import torch
import numpy as np
from torch.utils.data import Dataset,DataLoader

class DiabetesDataSet(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:,:-1])
        self.y_data = torch.from_numpy(xy[:,[-1]])
        
    def __getitem__(self, index):
        return self.x_data[index],self.y_data[index]
    
    def __len__(self):
        return self.len

dataset = DiabetesDataSet('G:/datasets/diabetes/diabetes.csv')
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)

2、设计模型

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear1 = torch.nn.Linear(8,6)
        self.linear2 = torch.nn.Linear(6,4)
        self.linear3 = torch.nn.Linear(4,1)
        self.activate = torch.nn.Sigmoid()
    
    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        x = self.activate(self.linear3(x))
        return x
model = Model()

3、构造损失函数和优化器

criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)

4、训练过程

每次拿出mini_batch个样本进行训练,代码如下:

epoch_list = []
loss_list = []
for epoch in range(100):
    count = 0
    loss1 = 0
    for i, data in enumerate(train_loader,0):
        # 1.Prepare data
        inputs, labels = data
        # 2.Forward
        y_pred = model(inputs)
        loss = criterion(y_pred,labels)
        print(epoch,i,loss.item())
        count += 1
        loss1 += loss.item()
        # 3.Backward
        optimizer.zero_grad()
        loss.backward()
        # 4.Update
        optimizer.step()
        
    epoch_list.append(epoch)
    loss_list.append(loss1/count)

5、结果展示

plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.grid()
plt.show()

二、参考文献

 到此这篇关于PyTorch加载数据集梯度下降优化的文章就介绍到这了,更多相关PyTorch加载数据集内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • pytorch 实现冻结部分参数训练另一部分

    pytorch 实现冻结部分参数训练另一部分

    这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • python爬虫之利用selenium模块自动登录CSDN

    python爬虫之利用selenium模块自动登录CSDN

    这篇文章主要介绍了python爬虫之利用selenium模块自动登录CSDN,文中有非常详细的代码示例,对正在学习python的小伙伴们有很好地帮助,需要的朋友可以参考下
    2021-04-04
  • Python字符串本身作为bytes进行解码的问题

    Python字符串本身作为bytes进行解码的问题

    这篇文章主要介绍了解决Python字符串本身作为bytes进行解码的问题,文末给大家补充介绍了,Python字符串如何转为bytes对象?Python字符串和bytes类型怎么互转,需要的朋友可以参考下
    2022-11-11
  • python使用yield压平嵌套字典的超简单方法

    python使用yield压平嵌套字典的超简单方法

    这篇文章主要给大家介绍了关于python使用yield压平嵌套字典的超简单方法,文中通过示例代码介绍的非常详细,对大家的学习或者使用python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-11-11
  • Python做文本按行去重的实现方法

    Python做文本按行去重的实现方法

    每行在promotion后面包含一些数字,如果这些数字是相同的,则认为是相同的行,对于相同的行,只保留一行。接下来通过本文给大家介绍Python做文本按行去重的实现方法,感兴趣的朋友一起看看吧
    2016-10-10
  • 在PyCharm下使用 ipython 交互式编程的方法

    在PyCharm下使用 ipython 交互式编程的方法

    今天小编就为大家分享一篇在PyCharm下使用 ipython 交互式编程的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • Python实现识别图像中人物的示例代码

    Python实现识别图像中人物的示例代码

    这篇文章主要介绍了通过face_recognition提供的demo代码,简单调整了一下,从而实现识别图像中人物的功能,感兴趣的可以跟随小编一起试试
    2022-01-01
  • python制作英语翻译小工具代码实例

    python制作英语翻译小工具代码实例

    这篇文章主要介绍了python制作英语翻译小工具代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • Python3.4实现远程控制电脑开关机

    Python3.4实现远程控制电脑开关机

    这篇文章主要为大家详细介绍了Python3.4实现远程控制电脑开关机的方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-02-02
  • Python 求数组局部最大值的实例

    Python 求数组局部最大值的实例

    今天小编就为大家分享一篇Python 求数组局部最大值的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11

最新评论