pytorch实现多项式回归

 更新时间:2021年04月09日 17:33:03   作者:逝去〃年华  
这篇文章主要为大家详细介绍了pytorch实现多项式回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

pytorch实现多项式回归,供大家参考,具体内容如下

一元线性回归模型虽然能拟合出一条直线,但精度依然欠佳,拟合的直线并不能穿过每个点,对于复杂的拟合任务需要多项式回归拟合,提高精度。多项式回归拟合就是将特征的次数提高,线性回归的次数使一次的,实际我们可以使用二次、三次、四次甚至更高的次数进行拟合。由于模型的复杂度增加会带来过拟合的风险,因此需要采取正则化损失的方式减少过拟合,提高模型泛化能力。希望大家可以自己动手,通过一些小的训练掌握pytorch(案例中有些观察数据格式的代码,大家可以自己注释掉)

# 相较于一元线性回归模型,多项式回归可以很好的提高拟合精度,但要注意过拟合风险
# 多项式回归方程 f(x) = -1.13x-2.14x^2+3.12x^3-0.01x^4+0.512
import torch
import matplotlib.pyplot as plt
import numpy as np
# 数据准备(测试数据)
x = torch.linspace(-2,2,50)
print(x.shape)
y = -1.13*x - 2.14*torch.pow(x,2) + 3.15*torch.pow(x,3) - 0.01*torch.pow(x,4) + 0.512
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

# 此时输入维度为4维
# 为了拼接输入数据,需要编写辅助数据,输入标量x,使其变为矩阵,使用torch.cat拼接
def features(x): # 生成矩阵
    # [x,x^2,x^3,x^4]
    x = x.unsqueeze(1)
    print(x.shape)
    return torch.cat([x ** i for i in range(1,5)], 1)
result = features(x)
print(result.shape)
# 目标公式用于计算输入特征对应的标准输出
# 目标公式的权重如下
x_weight = torch.Tensor([-1.13,-2.14,3.15,-0.01]).unsqueeze(1)
b = torch.Tensor([0.512])
# 得到x数据对应的标准输出
def target(x):
    return x.mm(x_weight) + b.item()

# 新建一个随机生成输入数据和输出数据的函数,用于生成训练数据

def get_batch_data(batch_size):
    # 生成batch_size个随机的x
    batch_x = torch.randn(batch_size)
    # 对于每个x要生成一个矩阵
    features_x = features(batch_x)
    target_y = target(features_x)
    return features_x,target_y

# 创建模型
class PolynomialRegression(torch.nn.Module):
    def __init__(self):
        super(PolynomialRegression, self).__init__()
        # 输入四维度 输出一维度
        self.poly = torch.nn.Linear(4,1)

    def forward(self, x):
        return self.poly(x)

# 开始训练模型
epochs = 10000
batch_size = 32
model = PolynomialRegression()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),0.001)

for epoch in range(epochs):
    print("{}/{}".format(epoch+1,epochs))
    batch_x,batch_y = get_batch_data(batch_size)
    out = model(batch_x)
    loss = criterion(out,batch_y)
    optimizer.zero_grad()
    loss.backward()
    # 更新梯度
    optimizer.step()
    if (epoch % 100 == 0):
        print("Epoch:[{}/{}],loss:{:.6f}".format(epoch,epochs,loss.item()))
    if (epoch % 1000 == 0):
        predict = model(features(x))
        print(x.shape)
        print(predict.shape)
        print(predict.squeeze(1).shape)
        plt.plot(x.data.numpy(),predict.squeeze(1).data.numpy(),"r")
        loss = criterion(predict,y)
        plt.title("Loss:{:.4f}".format(loss.item()))
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.scatter(x,y)
        plt.show()

拟合结果:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • python树的双亲存储结构的实现示例

    python树的双亲存储结构的实现示例

    本文主要介绍了python树的双亲存储结构,这种存储结构是一种顺序存储结构,采用元素形如“[结点值,双亲结点索引]”的列表表示,感兴趣的可以了解一下
    2023-11-11
  • python使用lxml xpath模块解析XML遇到的坑及解决

    python使用lxml xpath模块解析XML遇到的坑及解决

    这篇文章主要介绍了python使用lxml xpath模块解析XML遇到的坑及解决,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-05-05
  • Python httplib模块使用实例

    Python httplib模块使用实例

    这篇文章主要介绍了Python httplib模块使用实例,httplib模块是一个底层基础模块,本文讲解了httplib模块的常用方法及使用实例,需要的朋友可以参考下
    2015-04-04
  • 快速排序的算法思想及Python版快速排序的实现示例

    快速排序的算法思想及Python版快速排序的实现示例

    快速排序算法来源于分治法的思想策略,这里我们将来为大家简单解析一下快速排序的算法思想及Python版快速排序的实现示例:
    2016-07-07
  • 使用keras实现densenet和Xception的模型融合

    使用keras实现densenet和Xception的模型融合

    这篇文章主要介绍了使用keras实现densenet和Xception的模型融合,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • python判断变量是否为列表的方法

    python判断变量是否为列表的方法

    在本篇文章里小编给大家整理了关于python判断变量是否为列表的方法,有需要的朋友们可以学习下。
    2020-09-09
  • Python猜解网站数据库管理员密码的脚本

    Python猜解网站数据库管理员密码的脚本

    这篇文章主要和大家分享一个Python脚本,可以实现猜解网站数据库管理员的密码。文中的示例代码讲解详细,需要的小伙伴可以参考一下
    2022-02-02
  • 详解Python中高阶函数(map,filter,reduce,sorted)的使用

    详解Python中高阶函数(map,filter,reduce,sorted)的使用

    高阶函数就是能够把函数当成参数传递的函数就是高阶函数,换句话说如果一个函数的参数是函数,那么这个函数就是一个高阶函数。本文为大家详细讲解了Python中常用的四个高阶函数,感兴趣的可以了解一下
    2022-04-04
  • 通过Python脚本+Jenkins实现项目重启

    通过Python脚本+Jenkins实现项目重启

    Jenkins是一个流行的开源自动化服务器,用于快速构建、测试和部署软件,本文主要介绍了通过Python脚本+Jenkins实现项目重启,具有一定的参考价值,感兴趣的可以了解一下
    2023-10-10
  • 7个Python中的隐藏小技巧分享

    7个Python中的隐藏小技巧分享

    Python 是每个程序员都喜欢的语言,因为它易于编码和易于阅读的语法。但是,你知道 python 有一些很酷的技巧可以用来让事情变得更简单吗?在今天的内容中,我将与你分享7 个你可能从未使用过的Python 技巧
    2023-03-03

最新评论