pytorch实现线性回归

 更新时间:2021年04月09日 14:35:06   作者:逝去〃年华  
这篇文章主要为大家详细介绍了pytorch实现线性回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

pytorch实现线性回归代码练习实例,供大家参考,具体内容如下

欢迎大家指正,希望可以通过小的练习提升对于pytorch的掌握

# 随机初始化一个二维数据集,使用朋友torch训练一个回归模型
import numpy as np
import random
import matplotlib.pyplot as plt

x = np.arange(20)
y = np.array([5*x[i] + random.randint(1,20) for i in range(len(x))])    # random.randint(参数1,参数2)函数返回参数1和参数2之间的任意整数
print('-'*50)
# 打印数据集
print(x)
print(y)

import torch
x_train = torch.from_numpy(x).float()
y_train = torch.from_numpy(y).float()

# model
class LinearRegression(torch.nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        # 输入与输出都是一维的
        self.linear = torch.nn.Linear(1,1)
    def forward(self,x):
        return self.linear(x)

# 新建模型,误差函数,优化器
model = LinearRegression()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),0.001)
# 开始训练
num_epoch = 20
for i in range(num_epoch):
    input_data = x_train.unsqueeze(1)
    target = y_train.unsqueeze(1)           # unsqueeze(1)在第二维增加一个维度
    out = model(input_data)
    loss = criterion(out,target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("Eopch:[{}/{},loss:[{:.4f}]".format(i+1,num_epoch,loss.item()))
    if ((i+1)%2 == 0):
        predict = model(input_data)
        plt.plot(x_train.data.numpy(),predict.squeeze(1).data.numpy(),"r")
        loss = criterion(predict,target)
        plt.title("Loss:{:.4f}".format(loss.item()))
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.scatter(x_train,y_train)
        plt.show()

实验结果:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • python选择排序算法的实现代码

    python选择排序算法的实现代码

    这篇文章主要介绍了python选择排序算法的实现代码,大家参考
    2013-11-11
  • Python实战之生成有关联单选问卷

    Python实战之生成有关联单选问卷

    这篇文章主要为大家分享了一个Python实战小案例——生成有关联单选问卷,并且能根据问卷总分数生成对应判断文案结果,感兴趣的可以了解一下
    2023-04-04
  • Python Spyder 调出缩进对齐线的操作

    Python Spyder 调出缩进对齐线的操作

    这篇文章主要介绍了Python Spyder 调出缩进对齐线的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-02-02
  • python sort、sorted高级排序技巧分享(key的使用)

    python sort、sorted高级排序技巧分享(key的使用)

    这篇文章主要介绍了python sort、sorted高级排序技巧(key的使用),具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-03-03
  • Python编码爬坑指南(必看)

    Python编码爬坑指南(必看)

    下面小编就为大家带来一篇Python编码爬坑指南(必看)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2016-06-06
  • 使用Python实现Office文档(Word/Excel/PowerPoint)批量转换为PDF

    使用Python实现Office文档(Word/Excel/PowerPoint)批量转换为PDF

    在处理不同格式的Office文档(如Word、Excel和PowerPoint)时,将其转换为PDF格式是常见的需求,本文就跟随小编来看看如何使用Python将Word/Excel/PowerPoint批量转换为PDF吧
    2024-10-10
  • Python 忽略warning的输出方法

    Python 忽略warning的输出方法

    今天小编就为大家分享一篇Python 忽略warning的输出方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Flask搭建虚拟环境并运行第一个flask程序

    Flask搭建虚拟环境并运行第一个flask程序

    这篇文章主要介绍了Flask搭建虚拟环境并运行第一个flask程序,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • Python装饰器的两种使用心得

    Python装饰器的两种使用心得

    装饰器(Decorators)是 Python 的一个重要部分。简单地说:他们是修改其他函数的功能的函数。他们有助于让我们的代码更简短,也更Pythonic(Python范儿),今天通过本文给大家分享Python装饰器使用小结,感兴趣的朋友一起看看吧
    2021-09-09
  • Python调用C语言开发的共享库方法实例

    Python调用C语言开发的共享库方法实例

    这篇文章主要介绍了Python调用C语言开发的共享库方法实例,本文同时给出了C语言和Python调用简单实例,需要的朋友可以参考下
    2015-03-03

最新评论