pytorch使用nn.Moudle实现逻辑回归

 更新时间:2022年07月30日 15:42:35   作者:ALEN.Z  
这篇文章主要为大家详细介绍了pytorch使用nn.Moudle实现逻辑回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了pytorch使用nn.Moudle实现逻辑回归的具体代码,供大家参考,具体内容如下

内容

pytorch使用nn.Moudle实现逻辑回归

问题

loss下降不明显

解决方法

#源代码 out的数据接收方式
     if torch.cuda.is_available():
         x_data=Variable(x).cuda()
         y_data=Variable(y).cuda()
     else:
         x_data=Variable(x)
         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
#源代码 out的数据有拼装数据直接输入
#     if torch.cuda.is_available():
#         x_data=Variable(x).cuda()
#         y_data=Variable(y).cuda()
#     else:
#         x_data=Variable(x)
#         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
    print_loss=loss.data.item()  #得出损失函数值

源代码

import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np

#生成数据
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums, 2)
x0 = torch.normal(mean_value * n_data, 1) + bias      # 类别0 数据 shape=(100, 2)
y0 = torch.zeros(sample_nums)                         # 类别0 标签 shape=(100, 1)
x1 = torch.normal(-mean_value * n_data, 1) + bias     # 类别1 数据 shape=(100, 2)
y1 = torch.ones(sample_nums)                          # 类别1 标签 shape=(100, 1)
x_data = torch.cat((x0, x1), 0)  #按维数0行拼接
y_data = torch.cat((y0, y1), 0)

#画图
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()

# 利用torch.nn实现逻辑回归
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.lr = nn.Linear(2, 1)
        self.sm = nn.Sigmoid()

    def forward(self, x):
        x = self.lr(x)
        x = self.sm(x)
        return x
    
logistic_model = LogisticRegression()
# if torch.cuda.is_available():
#     logistic_model.cuda()

#loss函数和优化
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=0.01, momentum=0.9)
#开始训练
#训练10000次
for epoch in range(10000):
#     if torch.cuda.is_available():
#         x_data=Variable(x).cuda()
#         y_data=Variable(y).cuda()
#     else:
#         x_data=Variable(x)
#         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
    print_loss=loss.data.item()  #得出损失函数值
    #反向传播
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    mask=out.ge(0.5).float()  #以0.5为阈值进行分类
    correct=(mask==y_data).sum().squeeze()  #计算正确预测的样本个数
    acc=correct.item()/x_data.size(0)  #计算精度
    #每隔20轮打印一下当前的误差和精度
    if (epoch+1)%100==0:
        print('*'*10)
        print('epoch {}'.format(epoch+1))  #误差
        print('loss is {:.4f}'.format(print_loss))
        print('acc is {:.4f}'.format(acc))  #精度
        
        
w0, w1 = logistic_model.lr.weight[0]
w0 = float(w0.item())
w1 = float(w1.item())
b = float(logistic_model.lr.bias.item())
plot_x = np.arange(-7, 7, 0.1)
plot_y = (-w0 * plot_x - b) / w1
plt.xlim(-5, 7)
plt.ylim(-7, 7)
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=logistic_model(x_data)[:,0].cpu().data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.plot(plot_x, plot_y)
plt.show()

输出结果

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • Python中的TCP socket写法示例

    Python中的TCP socket写法示例

    最近在学习脚本语言python,所以下面这篇文章主要给大家介绍了关于Python中TCP socket写法的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或工作具有一定的参考学习价值,需要的朋友们一起来看看吧
    2018-05-05
  • Python 命令行非阻塞输入的小例子

    Python 命令行非阻塞输入的小例子

    很久很久以前,系windows平台下,用C语言写过一款贪食蛇游戏,cmd界面,用kbhit()函数实现非阻塞输入。系windows平台下用python依然可以调用msvcrt.khbit实现非阻塞监听。但系喺linux下面就冇呢支歌仔唱
    2013-09-09
  • PyCharm-错误-找不到指定文件python.exe的解决方法

    PyCharm-错误-找不到指定文件python.exe的解决方法

    今天小编就为大家分享一篇PyCharm-错误-找不到指定文件python.exe的解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python随机函数random随机获取数字、字符串、列表等使用详解

    Python随机函数random随机获取数字、字符串、列表等使用详解

    这篇文章主要介绍了Python随机函数random使用详解包含了Python随机数字,Python随机字符串,Python随机列表等,需要的朋友可以参考下
    2021-04-04
  • python学习入门细节知识点

    python学习入门细节知识点

    我们整理了关于python入门学习的一些细节知识点,对于学习python的初学者很有用,一起学习下。
    2018-03-03
  • Python3实现飞机大战游戏

    Python3实现飞机大战游戏

    这篇文章主要为大家详细介绍了Python3实现飞机大战游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-04-04
  • Python numpy.transpose使用详解

    Python numpy.transpose使用详解

    本文主要介绍了Python numpy.transpose使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-08-08
  • python字符串验证的几种实现方法

    python字符串验证的几种实现方法

    字符串的验证是确保数据符合特定要求的关键步骤之一,本文主要介绍了python字符串验证的几种实现方法,具有一定的参考价值,感兴趣的可以了解一下
    2024-07-07
  • Python实现的数据结构与算法之队列详解

    Python实现的数据结构与算法之队列详解

    这篇文章主要介绍了Python实现的数据结构与算法之队列,详细分析了队列的定义、功能与Python实现队列的相关技巧,以及具体的用法,需要的朋友可以参考下
    2015-04-04
  • 详解 Python中LEGB和闭包及装饰器

    详解 Python中LEGB和闭包及装饰器

    这篇文章主要介绍了详解 Python中LEGB和闭包及装饰器的相关资料,主要介绍了函数作用域和闭包的理解和使用方法及Python中的装饰器,需要的朋友可以参考下
    2017-08-08

最新评论