【Pytorch】5. Pytorch搭建多项式回归模型

mac2026-05-14  6

一、理论介绍

对于一般的线性回归模型,由于该函数拟合出来的是一条直线,所以精度欠佳,我们可以考虑多项式回归,构造非线性特征,用的是高次多项式而不是简单的一 次线性多项式。所谓多项式回归,其本质也是线性回归。也就是说,我们采取的方法是,提高每个特征的次数来增加维度数。

1、需要拟合的方程:

y = 2.4 x 3 + 3 x 2 + 0.2 x + 0.9 y = 2.4x^3+3x^2+0.2x+0.9 y=2.4x3+3x2+0.2x+0.9

2、构建多项式回归方程:

对于输入变量x和输出值y,我们只需要增加其平方项、三次方项系数即可。目标是将每一个参数都能够学习到和真实参数很接近的结果 y = w 3 x 3 + w 2 x 2 + w 1 x + b y = w_{3}x^3+w_{2}x^2+w_{1}x+b y=w3x3+w2x2+w1x+b

3、数据预处理

多项式回归方程与线性回归方程并没有本质区别。可以采用线性回归的方式来进行多项式的拟合。所以需要先对数据进行预处理,将数据变为如下矩阵形式:

二、代码实现

import torch import matplotlib.pyplot as plt import torch.optim as optim from torch.autograd import Variable import torch.nn as nn import numpy as np #构造数据x def make_features(x): x = x.unsqueeze(1) #给x加一个维度 return torch.cat([x**i for i in range(1,4)],1) #定义真实函数 w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1) b_target = torch.FloatTensor([0.9]) def f(x): return x.mm(w_target) + b_target #获取每批输入模型的x,和通过函数计算的实际标签y def get_batch(batch_size=32): random = torch.randn(batch_size) #随机产生32个,从标准正太分布中抽取的一组随机数。 #注意这里一定要对x进行排序 random = np.sort(random) random = torch.Tensor(random) x = make_features(random) y = f(x) if torch.cuda.is_available(): return Variable(x).cuda(),Variable(y).cuda() else: return Variable(x),Variable(y) #定义模型 class poly_model(nn.Module): def __init__(self): super(poly_model,self).__init__() self.poly = nn.Linear(3,1) #输入3维,输出1个 def forward(self,x): out = self.poly(x) return out if torch.cuda.is_available(): model = poly_model().cuda() else: model = poly_model() #定义损失和优化函数 criterion = nn.MSELoss() optimizer = optim.SGD(model.parameters(),lr=1e-3) #模型训练 epoch = 0 for i in range(20000): #获取数据 batch_x,batch_y = get_batch() #前向传播 output = model(batch_x) loss = criterion(output,batch_y) print_loss = loss.item() print(f'epoch:{epoch},print_loss:{print_loss}') #梯度清零 optimizer.zero_grad() #反向运算 loss.backward() #参数更新 optimizer.step() epoch += 1 if print_loss < 1e-3: break #预测 model.eval() predict = model(batch_x) predict = predict.data.numpy() plt.plot(batch_x.numpy()[:,0],batch_y.numpy(),'ro',label='Original data') plt.plot(batch_x.numpy()[:,0],predict,label='Fitting Line') plt.show() #打印拟合的回归模型 print( f'====> Learned function:y = {model.poly.bias[0]:.2f} + {model.poly.weight[0][0]:.2f}*x + {model.poly.weight[0][1]:.2f}*x^2 + {model.poly.weight[0][2]:.2f}*x^3') print( f'====> Actual function:y = {b_target[0]:.2f} + {w_target[0][0]:.2f}*x + {w_target[1][0]:.2f}*x^2 + {w_target[2][0]:.2f}*x^3')

拟合结果如下图: 拟合的多项式回归方程:

最新回复(0)