基于Pytorch的神经网络之Regression的实现

 更新时间:2022年03月15日 10:15:49   作者:ZDDWLIG  
本文主要介绍了基于Pytorch的神经网络之Regression的实现,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

1.引言

我们之前已经介绍了神经网络的基本知识,神经网络的主要作用就是预测与分类,现在让我们来搭建第一个用于拟合回归的神经网络吧。

2.神经网络搭建

2.1 准备工作

要搭建拟合神经网络并绘图我们需要使用python的几个库。

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)
y = x.pow(3) + 0.2 * torch.rand(x.size())

 既然是拟合,我们当然需要一些数据啦,我选取了在区间 [-5,5] 内的100个等间距点,并将它们排列成三次函数的图像。

2.2 搭建网络

我们定义一个类,继承了封装在torch中的一个模块,我们先分别确定输入层、隐藏层、输出层的神经元数目,继承父类后再使用torch中的.nn.Linear()函数进行输入层到隐藏层的线性变换,隐藏层也进行线性变换后传入输出层predict,接下来定义前向传播的函数forward(),使用relu()作为激活函数,最后输出predict()结果即可。

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.predict(x)
net = Net(1, 20, 1)
print(net)
optimizer = torch.optim.Adam(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()

网络的框架搭建完了,然后我们传入三层对应的神经元数目再定义优化器,这里我选取了Adam而随机梯度下降(SGD),因为它是SGD的优化版本,效果在大部分情况下比SGD好,我们要传入这个神经网络的参数(parameters),并定义学习率(learning rate),学习率通常选取小于1的数,需要凭借经验并不断调试。最后我们选取均方差法(MSE)来计算损失(loss)。

2.3 训练网络

接下来我们要对我们搭建好的神经网络进行训练,我训练了2000轮(epoch),先更新结果prediction再计算损失,接着清零梯度,然后根据loss反向传播(backward),最后进行优化,找出最优的拟合曲线。

for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

3.效果

使用如下绘图的代码展示效果。

for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy(), s=10)
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
        plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

最后的结果: 

4. 完整代码

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)
y = x.pow(3) + 0.2 * torch.rand(x.size())
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.predict(x)
net = Net(1, 20, 1)
print(net)
optimizer = torch.optim.Adam(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()
plt.ion()
for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy(), s=10)
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
        plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

到此这篇关于基于Pytorch的神经网络之Regression的实现的文章就介绍到这了,更多相关 Pytorch Regression内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 浅析NumPy 切片和索引

    浅析NumPy 切片和索引

    这篇文章主要介绍了NumPy 切片和索引的相关资料,帮助大家更好的理解和学习NumPy的相关知识,感兴趣的朋友可以了解下。
    2020-09-09
  • Python通过cron或schedule实现爬虫的自动定时运行

    Python通过cron或schedule实现爬虫的自动定时运行

    自动定时运行爬虫是很多数据采集项目的基本需求,通过 Python 实现定时任务,可以保证数据采集的高效和持续性,本文将带大家了解如何在 Python 中使用 cron 和 schedule 来实现爬虫的自动定时运行,需要的朋友可以参考下
    2024-12-12
  • Pandas提高数据分析效率的13个技巧汇总

    Pandas提高数据分析效率的13个技巧汇总

    这篇文章主要是为大家归纳整理了13个工作中常用到的pandas使用技巧,方便更高效地实现数据分析,感兴趣的小伙伴可以跟随小编一起学习一下
    2022-05-05
  • Numpy实现矩阵运算及线性代数应用

    Numpy实现矩阵运算及线性代数应用

    这篇文章主要介绍了Numpy实现矩阵运算及线性代数应用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-03-03
  • 利用Python来控制终端打印字体的颜色和格式

    利用Python来控制终端打印字体的颜色和格式

    使用python编程时,改变控制台或终端中输出字体的颜色和格式,会显著提升代码质量,快速帮助我们定位问题和锁定重要输出,但是一般情况下,python控制台输出的字体默认为白色,所以这篇文章给大家介绍了如何利用Python控制终端打印字体的颜色和格式,需要的朋友可以参考下
    2024-06-06
  • python实现批量监控网站

    python实现批量监控网站

    本文给大家分享的是一个非常实用的,python实现多网站的可用性监控的脚本,并附上核心点解释,有相同需求的小伙伴可以参考下
    2016-09-09
  • Python操作word文档的示例详解

    Python操作word文档的示例详解

    本文为大家介绍了Python操作docx文档相关知识点。主要涉及的内容为python-docx ,一款可以操作Word文档(仅支持docx)的第三方库。快跟随小编一起学习一下吧
    2022-01-01
  • Pycharm连接远程服务器并实现远程调试的实现

    Pycharm连接远程服务器并实现远程调试的实现

    这篇文章主要介绍了Pycharm连接远程服务器并实现远程调试的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-08-08
  • 深入解析Python中的集合类型操作符

    深入解析Python中的集合类型操作符

    这篇文章主要介绍了深入解析Python中的集合类型操作符,是Python入门学习中的基础知识,需要的朋友可以参考下
    2015-08-08
  • python实现web方式logview的方法

    python实现web方式logview的方法

    这篇文章主要介绍了python实现web方式logview的方法,涉及Python基于web模块操作Linux命令的技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-08-08

最新评论