多项式回归与过拟合
import numpy as np
import matplotlib.pyplot as plt
X_simple = np.arange(1,11).reshape(-1,2)
X_simple
array([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10]])
# 导入多项式处理的类
from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(degree=2)
poly.fit(X_simple)
X2_simple = poly.transform(X_simple)
X2_simple
array([[ 1., 1., 2., 1., 2., 4.],
[ 1., 3., 4., 9., 12., 16.],
[ 1., 5., 6., 25., 30., 36.],
[ 1., 7., 8., 49., 56., 64.],
[ 1., 9., 10., 81., 90., 100.]])
可以看到,在sklearn中,对于有两个特征,,多项式回归预处理将特征变为:
- 第一列全部为1;
- 第二列和第三列是原始特征;
- 第四列和第六列分别是第二列原始特征和第三列原始特征的平方;
- 第五列是第二列原始特征和第三列原始特征的乘积。
poly = PolynomialFeatures(degree=3)
poly.fit(X_simple)
X3_simple = poly.transform(X_simple)
X3_simple
array([[ 1., 1., 2., 1., 2., 4., 1., 2., 4.,
8.],
[ 1., 3., 4., 9., 12., 16., 27., 36., 48.,
64.],
[ 1., 5., 6., 25., 30., 36., 125., 150., 180.,
216.],
[ 1., 7., 8., 49., 56., 64., 343., 392., 448.,
512.],
[ 1., 9., 10., 81., 90., 100., 729., 810., 900.,
1000.]])
如果原始特征为和,那么多项式回归特征变为:.可见,当改变PolynomialFeatures
的degree参数后,转换后样本数据的特征会成指数级的增长,它会尽可能地列出所有的多项式来丰富样本数据。
下面分别用一般线性回归和多项式回归进行训练,输出均方误差值大小。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
# 构建样本数据
x = np.random.uniform(-3, 3, size=100)
X = x.reshape(-1, 1)
y = 0.5 * x ** 2 + x + 2 + np.random.normal(0, 1, size=100)
# 先用线性回归进行拟合
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X, y)
y_predict = lr.predict(X) #得到的是一个预测值数组
plt.scatter(x,y) #画出散点图
plt.plot(np.sort(x),y_predict[np.argsort(x)],color='r')
plt.show()
# 导入均方误差
from sklearn.metrics import mean_squared_error
print(mean_squared_error(y, y_predict))
# 不经过归一化处理的多项式回归,指定最高次数为2
poly = PolynomialFeatures(degree=2)
poly.fit(X)
X2 = poly.transform(X)
lr.fit(X2,y)
y_predict_poly = lr.predict(X2)
plt.scatter(x,y)
plt.plot(np.sort(x),y_predict_poly[np.argsort(x)],color='r')
plt.show()
print(mean_squared_error(y, y_predict_poly))
3.0685100431552086
0.8039389169869919
可以看到,多项式回归相对于线性回归能够更好地拟合数据,在训练集上具有更好地效果(远远好于线性回归)。下面逐渐提高多项式回归最高项的次数,可以看到随着次数的提高,模型在训练集合上的误差越来越小,当时训练误差降到了0.62,可以看到这个曲线的复杂度变得越来越复杂,当次数足够大时,必然找到一条曲线经过所有的训练样本,此时的均方误差接近于0.
def PolynomialRegression(degree):
return Pipeline([
("poly", PolynomialFeatures(degree=degree)),
("std_scalar", StandardScaler()),
("lr", LinearRegression())
])
poly2_reg = PolynomialRegression(degree=5)
poly2_reg.fit(X, y)
y2_predict = poly2_reg.predict(X)
plt.scatter(x,y)
plt.plot(np.sort(x),y2_predict[np.argsort(x)],color='y')
plt.show()
print(mean_squared_error(y, y2_predict))
0.799071260369073
def PolynomialRegression(degree):
return Pipeline([
("poly", PolynomialFeatures(degree=degree)),
("std_scalar", StandardScaler()),
("lr", LinearRegression())
])
poly2_reg = PolynomialRegression(degree=100)
poly2_reg.fit(X, y)
y2_predict = poly2_reg.predict(X)
plt.scatter(x,y)
plt.plot(np.sort(x),y2_predict[np.argsort(x)],color='y')
plt.show()
print(mean_squared_error(y, y2_predict))
# 取100个新的样本点作为测试集,使用模型预测其y值并计算其均方误差
x_test = np.random.uniform(-3, 3, size=100)
X_test = x_test.reshape(-1, 1)
y_test = 0.5 * x_test ** 2 + x_test + 2 + np.random.normal(0, 1, size=100)
y2_predict_test = poly2_reg.predict(X_test)
print(mean_squared_error(y_test, y2_predict_test))
0.6223122856236626
1055551209.8359903
可以看到,此时尽管训练误差足够小,但测试误差却很大,出现了过拟合的情况,当前训练得到的模型的泛化能力很差。下面使用degrees=2重新训练,可以看到泛化性能是足够好的。
def PolynomialRegression(degree):
return Pipeline([
("poly", PolynomialFeatures(degree=degree)),
("std_scalar", StandardScaler()),
("lr", LinearRegression())
])
poly2_reg = PolynomialRegression(degree=2)
poly2_reg.fit(X, y)
y2_predict = poly2_reg.predict(X)
plt.scatter(x,y)
plt.plot(np.sort(x),y2_predict[np.argsort(x)],color='y')
plt.show()
print(mean_squared_error(y, y2_predict))
# 取100个新的样本点作为测试集,使用模型预测其y值并计算其均方误差
x_test = np.random.uniform(-3, 3, size=100)
X_test = x_test.reshape(-1, 1)
y_test = 0.5 * x_test ** 2 + x_test + 2 + np.random.normal(0, 1, size=100)
y2_predict_test = poly2_reg.predict(X_test)
print(mean_squared_error(y_test, y2_predict_test))
0.8039389169869918
0.9126747770972047
来源:CSDN
作者:hust_oluo
链接:https://blog.csdn.net/hbhgyu/article/details/103586179