Pytorch实现逻辑回归分类

 更新时间:2022年07月30日 11:10:06   作者:远方与你  
这篇文章主要为大家详细介绍了Pytorch实现逻辑回归分类,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了Pytorch实现逻辑回归分类的具体代码,供大家参考,具体内容如下

1、代码实现

步骤:

1.获得数据
2.建立逻辑回归模型
3.定义损失函数
4.计算损失函数
5.求解梯度
6.梯度更新
7.预测测试集

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

input_size = 784  # 输入到逻辑回归模型中的输入大小
num_classes = 10  # 分类的类别个数
num_epochs = 10  # 迭代次数
batch_size = 50  # 批量训练个数
learning_rate = 0.01  # 学习率


# 下载训练数据和测试数据
train_dataset = dataset.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dataset.MNIST(root='./data',train=False, transform=transforms.ToTensor)

# 使用DataLoader形成批处理文件
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 创建逻辑回归类模型  (sigmoid(wx+b))
class LogisticRegression(nn.Module):
    def __init__(self,input_size,num_classes):
        super(LogisticRegression,self).__init__()
        self.linear = nn.Linear(input_size,num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

# 设定模型参数
model = LogisticRegression(input_size, num_classes)
# 定义损失函数,分类任务,使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化算法,随机梯度下降,lr为学习率,获得模型需要更新的参数值
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)


# 使用训练数据训练模型
for epoch in range(num_epochs):
    # 批量数据进行模型训练
    for i, (images, labels) in enumerate(train_loader):
        # 需要将数据转换为张量Variable
        images = Variable(images.view(-1, 28*28))
        labels = Variable(labels)
        
        # 梯度更新前需要进行梯度清零
        optimizer.zero_grad()

        # 获得模型的训练数据结果
        outputs = model(images)
        
        # 计算损失函数用于计算梯度
        loss = criterion(outputs, labels)

        # 计算梯度
        loss.backward()
    
        # 进行梯度更新
        optimizer.step()

        # 每隔一段时间输出一个训练结果
        if (i+1) % 100 == 0:
            print('Epoch:[%d %d], Step:[%d/%d], Loss: %.4f' % (epoch+1,num_epochs,i+1,len(train_dataset)//batch_size,loss.item()))

# 训练好的模型预测测试数据集
correct = 0
total = 0
for images, labels in test_loader:
    images = Variable(images.view(-1, 28*28))  # 形式为(batch_size,28*28)
    outputs = model(images)
    _,predicts = torch.max(outputs.data,1)  # _输出的是最大概率的值,predicts输出的是最大概率值所在位置,max()函数中的1表示维度,意思是计算某一行的最大值
    total += labels.size(0)
    correct += (predicts==labels).sum()

print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))

2、踩过的坑

1.在代码中下载训练数据和测试数据的时候,两段代码是有区别的:

train_dataset = dataset.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dataset.MNIST(root='./data',train=False, transform=transforms.ToTensor)

第一段代码中多了一个download=True,这个的作用是,如果为True,则从Internet下载数据集并将其存放在根目录中。如果数据已经下载,则不会再次下载。

在第二段代码中没有加download=True,加了的话在使用测试数据进行预测的时候会报错。

代码中transform=transforms.ToTensor()的作用是将PIL图像转换为Tensor,同时已经进行归一化处理。

2.代码中设置损失函数:

criterion = nn.CrossEntropyLoss()
loss = criterion(outputs, labels)

一开始的时候直接使用:

loss = nn.CrossEntropyLoss()
loss = loss(outputs, labels)

这样也会报错,因此需要将loss改为criterion。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • 详解Python中os.path与pathlib的用法和性能对比

    详解Python中os.path与pathlib的用法和性能对比

    pathlib 模块是在Python3.4版本中首次被引入到标准库中的,这篇文章主要来和大家介绍一下Python中os.path与pathlib再用法和性能上的区别,感兴趣的可以了解下
    2024-03-03
  • Python探索之自定义实现线程池

    Python探索之自定义实现线程池

    这篇文章主要介绍了Python探索之自定义实现线程池,使用queue实现线程池的方法,具有一定参考价值,需要的朋友可以了解下。
    2017-10-10
  • Python实战之异步获取中国天气信息

    Python实战之异步获取中国天气信息

    这篇文章主要介绍了如何利用Python爬虫异步获取天气信息,用的API是中国天气网。文中的示例代码讲解详细,感兴趣的小伙伴可以动手试一试
    2022-03-03
  • Python实现的微信红包提醒功能示例

    Python实现的微信红包提醒功能示例

    这篇文章主要介绍了Python实现的微信红包提醒功能,结合实例形式分析了Python使用微信模块itchat实现微信红包提醒操作的相关实现技巧,需要的朋友可以参考下
    2019-08-08
  • Python把png转成jpg的项目实践

    Python把png转成jpg的项目实践

    本文主要介绍了Python把png转成jpg的项目实践,可以使用PIL库来将PNG图片转换为JPG格式,具有一定的参考价值,感兴趣的可以了解一下
    2024-02-02
  • Python模拟FTP文件服务器的操作方法

    Python模拟FTP文件服务器的操作方法

    这篇文章主要介绍了Python_模拟FTP文件服务器的操作方法,分为服务端和客户端,要求可以有多个客户端同时操作。本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友参考下吧
    2020-02-02
  • python将音频进行变速的操作方法

    python将音频进行变速的操作方法

    这篇文章主要介绍了python将音频进行变速的操作方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04
  • opencv python 图像轮廓/检测轮廓/绘制轮廓的方法

    opencv python 图像轮廓/检测轮廓/绘制轮廓的方法

    这篇文章主要介绍了opencv python 图像轮廓/检测轮廓/绘制轮廓的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • 使用Python实现首页通知功能

    使用Python实现首页通知功能

    这篇文章主要为大家详细介绍了如何使用Python实现首页通知功能,文中的示例代码讲解详细,具有一定的借鉴价值,有需要的小伙伴可以跟随小编一起学习一下
    2024-02-02
  • pycharm远程调试openstack的图文教程

    pycharm远程调试openstack的图文教程

    这篇文章主要为大家详细介绍了pycharm远程调试openstack的图文教程,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-11-11

最新评论