pytorch使用Variable实现线性回归

 更新时间:2021年04月09日 14:23:09   作者:东城青年  
这篇文章主要为大家详细介绍了pytorch使用Variable实现线性回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

显示的最后一张图如下所示:

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

显示的最后一张图如下所示:

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • Python中拆分字符串的操作方法

    Python中拆分字符串的操作方法

    这篇文章主要介绍了Python中拆分字符串的操作方法,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-07-07
  • 利用Python绘制酷炫的3D地图

    利用Python绘制酷炫的3D地图

    pyecharts是一款将python与echarts结合的强大的数据可视化工具。本文将为大家介绍如何利用pyecharts绘制酷炫的3D地图,感兴趣的小伙伴可以试一试
    2022-03-03
  • Python StringIO及BytesIO包使用方法解析

    Python StringIO及BytesIO包使用方法解析

    这篇文章主要介绍了Python StringIO及BytesIO包使用方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • python使用Pillow将照片转换为1寸报名照片的教程分享

    python使用Pillow将照片转换为1寸报名照片的教程分享

    在现代科技时代,我们经常需要调整和处理照片以适应特定的需求和用途,本文将介绍如何使用wxPython和Pillow库,通过一个简单的图形界面程序,将选择的照片转换为指定尺寸的JPG格式,并保存在桌面上,需要的朋友可以参考下
    2023-09-09
  • pandas 条件搜索返回列表的方法

    pandas 条件搜索返回列表的方法

    今天小编就为大家分享一篇pandas 条件搜索返回列表的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Pycharm插件(Grep Console)自定义规则输出颜色日志的方法

    Pycharm插件(Grep Console)自定义规则输出颜色日志的方法

    这篇文章主要介绍了Pycharm插件(Grep Console)自定义规则输出颜色日志的方法,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-05-05
  • Python math库 ln(x)运算的实现及原理

    Python math库 ln(x)运算的实现及原理

    这篇文章主要介绍了Python math库 ln(x)运算的实现及原理,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • 浅谈django框架集成swagger以及自定义参数问题

    浅谈django框架集成swagger以及自定义参数问题

    这篇文章主要介绍了浅谈django框架集成swagger以及自定义参数问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07
  • Python中字典的setdefault()方法教程

    Python中字典的setdefault()方法教程

    在学习python字典操作方法时,感觉setdefault()方法,比字典的其它基本操作方法更难理解的同学比较多,所以想着总结以下,下面这篇文章主要给大家介绍了Python中字典的setdefault()方法,需要的朋友可以参考借鉴,下面来一起看看吧。
    2017-02-02
  • 基于Python实现ComicReaper漫画自动爬取脚本过程解析

    基于Python实现ComicReaper漫画自动爬取脚本过程解析

    这篇文章主要介绍了基于Python实现ComicReaper漫画自动爬取脚本过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-11-11

最新评论