PyTorch零基础入门之逻辑斯蒂回归

 更新时间:2021年10月19日 10:47:57   作者:山顶夕景  
PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序,它是一个可续计算包,提供两个高级功能:1、具有强大的GPU加速的张量计算(如NumPy)。2、包含自动求导系统的深度神经网络

学习总结

(1)和上一讲的模型训练是类似的,只是在线性模型的基础上加个sigmoid,然后loss函数改为交叉熵BCE函数(当然也可以用其他函数),另外一开始的数据y_data也从数值改为类别0和1(本例为二分类,注意x_datay_data这里也是矩阵的形式)。

一、sigmoid函数

logistic function是一种sigmoid函数(还有其他sigmoid函数),但由于使用过于广泛,pytorch默认logistic function叫为sigmoid函数。还有如下的各种sigmoid函数:

在这里插入图片描述

二、和Linear的区别

逻辑斯蒂和线性模型的unit区别如下图:

在这里插入图片描述

sigmoid函数是不需要参数的,所以不用对其初始化(直接调用nn.functional.sigmoid即可)。
另外loss函数从MSE改用交叉熵BCE:尽可能和真实分类贴近。

在这里插入图片描述

如下图右方表格所示,当 y ^ \hat{y} y^​越接近y时则BCE Loss值越小。

在这里插入图片描述

三、逻辑斯蒂回归(分类)PyTorch实现

# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 08:35:00 2021

@author: 86493
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt  
import torch.nn.functional as F
import numpy as np

# 准备数据
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])


losslst = []

class LogisticRegressionModel(nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
        
    def forward(self, x):
    	# 和线性模型的网络的唯一区别在这句,多了F.sigmoid
        y_predict = F.sigmoid(self.linear(x))
        return y_predict
    
model = LogisticRegressionModel()

# 使用交叉熵作损失函数
criterion = torch.nn.BCELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(), 
                            lr = 0.01)

# 训练
for epoch in range(1000):
    y_predict = model(x_data)
    loss = criterion(y_predict, y_data)
    # 打印loss对象会自动调用__str__
    print(epoch, loss.item())
    losslst.append(loss.item())
    # 梯度清零后反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 画图
plt.plot(range(1000), losslst)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.show()


# test
# 每周学习的时间,200个点
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
# 画 probability of pass = 0.5的红色横线
plt.plot([0, 10], [0.5, 0.5], c = 'r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

在这里插入图片描述

可以看出处于通过和不通过的分界线是Hours=2.5。

在这里插入图片描述

Reference

pytorch官方文档

到此这篇关于PyTorch零基础入门之逻辑斯蒂回归的文章就介绍到这了,更多相关PyTorch 逻辑斯蒂回归内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python之如何将标签转化为one-hot(独热编码)

    python之如何将标签转化为one-hot(独热编码)

    这篇文章主要介绍了python之如何将标签转化为one-hot(独热编码)问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-06-06
  • Python中的复制操作及copy模块中的浅拷贝与深拷贝方法

    Python中的复制操作及copy模块中的浅拷贝与深拷贝方法

    浅拷贝和深拷贝是Python基础学习中必须辨析的知识点,这里我们将为大家解析Python中的复制操作及copy模块中的浅拷贝与深拷贝方法:
    2016-07-07
  • Python使用Selenium、PhantomJS爬取动态渲染页面

    Python使用Selenium、PhantomJS爬取动态渲染页面

    本文主要介绍了Python使用Selenium、PhantomJS爬取动态渲染页面,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-05-05
  • python程序中断然后接着中断代码继续运行问题

    python程序中断然后接着中断代码继续运行问题

    这篇文章主要介绍了python程序中断然后接着中断代码继续运行问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-02-02
  • 对python多线程与global变量详解

    对python多线程与global变量详解

    今天小编就为大家分享一篇对python多线程与global变量详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • python实现跨excel的工作表sheet之间的复制方法

    python实现跨excel的工作表sheet之间的复制方法

    今天小编就为大家分享一篇python实现跨excel的工作表sheet之间的复制方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python数据分析之 Pandas Dataframe应用自定义

    Python数据分析之 Pandas Dataframe应用自定义

    这篇文章主要介绍了Python数据分析之 Pandas Dataframe应用自定义,文章基于python的相关资料展开 Pandas Dataframe应用自定义的详细内容,需要的小伙伴可以参考一下
    2022-05-05
  • pandas如何将DataFrame 转为txt文本去除引号

    pandas如何将DataFrame 转为txt文本去除引号

    这篇文章主要介绍了pandas如何将DataFrame 转为txt文本去除引号,文中补充介绍了DataFrame导CSV txt || 每行有双引号的原因及解决办法,感兴趣的朋友跟随小编一起看看吧
    2024-01-01
  • Python遗传算法Geatpy工具箱使用介绍

    Python遗传算法Geatpy工具箱使用介绍

    这篇文章主要为大家介绍了Python遗传算法Geatpy工具箱使用介绍,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-09-09
  • PyChar学习教程之自定义文件与代码模板详解

    PyChar学习教程之自定义文件与代码模板详解

    pycharm默认的【新建】文件,格式很不友好,那么就需要改一下文件模板。下面这篇文章主要给大家介绍了关于PyChar学习教程之自定义文件与代码模板的相关资料,文中通过示例代码介绍的非常详细,需要的朋友们下面跟着小编来一起看看吧。
    2017-07-07

最新评论