pytorch实现线性回归

 更新时间:2021年04月09日 14:35:06   作者:逝去〃年华  
这篇文章主要为大家详细介绍了pytorch实现线性回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

pytorch实现线性回归代码练习实例,供大家参考,具体内容如下

欢迎大家指正,希望可以通过小的练习提升对于pytorch的掌握

# 随机初始化一个二维数据集,使用朋友torch训练一个回归模型
import numpy as np
import random
import matplotlib.pyplot as plt

x = np.arange(20)
y = np.array([5*x[i] + random.randint(1,20) for i in range(len(x))])    # random.randint(参数1,参数2)函数返回参数1和参数2之间的任意整数
print('-'*50)
# 打印数据集
print(x)
print(y)

import torch
x_train = torch.from_numpy(x).float()
y_train = torch.from_numpy(y).float()

# model
class LinearRegression(torch.nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        # 输入与输出都是一维的
        self.linear = torch.nn.Linear(1,1)
    def forward(self,x):
        return self.linear(x)

# 新建模型,误差函数,优化器
model = LinearRegression()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),0.001)
# 开始训练
num_epoch = 20
for i in range(num_epoch):
    input_data = x_train.unsqueeze(1)
    target = y_train.unsqueeze(1)           # unsqueeze(1)在第二维增加一个维度
    out = model(input_data)
    loss = criterion(out,target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("Eopch:[{}/{},loss:[{:.4f}]".format(i+1,num_epoch,loss.item()))
    if ((i+1)%2 == 0):
        predict = model(input_data)
        plt.plot(x_train.data.numpy(),predict.squeeze(1).data.numpy(),"r")
        loss = criterion(predict,target)
        plt.title("Loss:{:.4f}".format(loss.item()))
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.scatter(x_train,y_train)
        plt.show()

实验结果:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • python实现提取COCO,VOC数据集中特定的类

    python实现提取COCO,VOC数据集中特定的类

    这篇文章主要介绍了python实现提取COCO,VOC数据集中特定的类,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • 基于Keras的扩展性使用

    基于Keras的扩展性使用

    这篇文章主要介绍了Keras的扩展性使用操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • Python基于pandas爬取网页表格数据

    Python基于pandas爬取网页表格数据

    这篇文章主要介绍了Python基于pandas获取网页表格数据,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • pip install urllib2不能安装的解决方法

    pip install urllib2不能安装的解决方法

    今天小编就为大家分享一篇pip install urllib2不能安装的解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • Python使用爬虫猜密码

    Python使用爬虫猜密码

    我们可以通过python 来实现这样一个简单的爬虫猜密码功能。下面就看看如何使用python来实现这样一个功能,对python爬虫猜密码相关知识感兴趣的朋友参考下吧
    2016-02-02
  • python argparse命令行参数解析(推荐)

    python argparse命令行参数解析(推荐)

    Python argparse模块是解析命令行参数的首选方法。解析命令行参数是一个非常常见的任务,Python脚本根据传递的值来执行和操作
    2021-06-06
  • pandas数值计算与排序方法

    pandas数值计算与排序方法

    下面小编就为大家分享一篇pandas数值计算与排序方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • python list是否包含另一个list所有元素的实例

    python list是否包含另一个list所有元素的实例

    今天小编就为大家分享一篇python list是否包含另一个list所有元素的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • 使用python 对验证码图片进行降噪处理

    使用python 对验证码图片进行降噪处理

    今天小编就为大家分享一篇使用python 对验证码图片进行降噪处理,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python实现动态条形图的示例详解

    Python实现动态条形图的示例详解

    这篇文章主要为大家详细介绍了如何利用Python中的pynimate模块实现动态条形图的绘制,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下
    2023-03-03

最新评论