Linear regression fit and diagnosing overfitting

Submitted by Anonymous (unverified) on 2019-12-03 00:15:02
import matplotlib.pyplot as plt
import mglearn
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# wave dataset
# The wave dataset has a single feature, so the model is
#     y = w[0] * x[0] + b
# where w[0] is the slope and b is the intercept (the offset on the y-axis).
# scikit-learn exposes them as coef_[0] and intercept_.
mglearn.plots.plot_linear_regression_wave()
plt.show()

# Extended Boston housing dataset
# 506 samples and 104 features (13 original features plus their
# degree-2 interaction terms)
X, y = mglearn.datasets.load_extended_boston()
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Ordinary least squares; score() returns R^2
lr = LinearRegression().fit(X_train, y_train)
print("Training set score: {:.2f}".format(lr.score(X_train, y_train)))
print("Test set score: {:.2f}".format(lr.score(X_test, y_test)))

Output:

w[0]: 0.393906 b: -0.031804
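To make the link between the formula y = w[0]x[0] + b and the scikit-learn attributes concrete, here is a minimal sketch on synthetic, noise-free data. It does not use the wave dataset; the slope 2.0 and intercept 1.0 are arbitrary values chosen for illustration. After fitting, coef_[0] should recover the slope and intercept_ the offset.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic single-feature data that follows y = 2.0 * x + 1.0 exactly
# (illustrative values, not the wave dataset).
X = np.linspace(-3, 3, 50).reshape(-1, 1)   # shape (n_samples, 1)
y = 2.0 * X[:, 0] + 1.0

lr = LinearRegression().fit(X, y)

w0 = lr.coef_[0]     # slope w[0]
b = lr.intercept_    # intercept b
print("w[0]: {:.3f}  b: {:.3f}".format(w0, b))
```

Because the data contain no noise, the fitted values match the generating slope and intercept up to floating-point error. On real data such as wave they are only estimates.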

Source code of plot_linear_regression_wave
import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from .datasets import make_wave
from .plot_helpers import cm2


def plot_linear_regression_wave():
    X, y = make_wave(n_samples=60)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    line = np.linspace(-3, 3, 100).reshape(-1, 1)

    lr = LinearRegression().fit(X_train, y_train)
    print("w[0]: %f  b: %f" % (lr.coef_[0], lr.intercept_))

    plt.figure(figsize=(8, 8))
    plt.plot(line, lr.predict(line))
    plt.plot(X, y, 'o', c=cm2(0))
    ax = plt.gca()
    ax.spines['left'].set_position('center')
    ax.spines['right'].set_color('none')
    ax.spines['bottom'].set_position('center')
    ax.spines['top'].set_color('none')
    ax.set_ylim(-3, 3)
    #ax.set_xlabel("Feature")
    #ax.set_ylabel("Target")
    ax.legend(["model", "training data"], loc="best")
    ax.grid(True)
    ax.set_aspect('equal')

 

Output 2:


Training set score: 0.95
Test set score: 0.61

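The gap between the training score (R² = 0.95) and the test score (R² = 0.61) shows that the model is overfitting. The extended Boston data has 104 features but only 379 training samples, and ordinary least squares has no way to limit model complexity, so it fits the training set too closely. What is needed is "regularization", an explicit constraint on the coefficients. Ridge regression is one way to reduce the overfitting.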
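The same check can be reproduced without mglearn. The following is a minimal sketch that assumes a synthetic dataset from make_regression as a stand-in for a "many features, few samples" problem; the sizes and noise level are arbitrary choices, not the Boston data. It compares the training and test R² of an unregularized fit.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 100 samples, 60 features, only 5 of which carry signal.
X, y = make_regression(n_samples=100, n_features=60, n_informative=5,
                       noise=50.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

lr = LinearRegression().fit(X_train, y_train)
train_score = lr.score(X_train, y_train)   # R^2 on data the model has seen
test_score = lr.score(X_test, y_test)      # R^2 on held-out data

# A training score well above the test score is the usual sign of overfitting.
gap = train_score - test_score
print("Training set score: {:.2f}".format(train_score))
print("Test set score: {:.2f}".format(test_score))
print("Gap: {:.2f}".format(gap))
```

How large a gap counts as a problem is a judgment call. The comparison is only meaningful when the test set was not used for fitting.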

Ridge regression uses L2 regularization.
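As a rough illustration of what the L2 penalty does, the sketch below fits LinearRegression and Ridge on the same synthetic data (generated with make_regression, an assumption made for the example; it is not the Boston data) and compares scores and coefficient norms. The value alpha=10.0 is arbitrary and would normally be tuned.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.model_selection import train_test_split

# Synthetic data with many features relative to the number of samples.
X, y = make_regression(n_samples=100, n_features=60, n_informative=5,
                       noise=50.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

lr = LinearRegression().fit(X_train, y_train)
# Ridge minimizes ||y - Xw||^2 + alpha * ||w||^2, shrinking w toward zero.
ridge = Ridge(alpha=10.0).fit(X_train, y_train)

for name, model in [("LinearRegression", lr), ("Ridge(alpha=10)", ridge)]:
    print("{}: train R^2 = {:.2f}, test R^2 = {:.2f}, ||w|| = {:.1f}".format(
        name,
        model.score(X_train, y_train),
        model.score(X_test, y_test),
        np.linalg.norm(model.coef_)))
```

Ridge always gives a smaller coefficient norm and a training score no higher than ordinary least squares. Whether the test score improves depends on the data and on alpha, so alpha is usually chosen by cross-validation.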

Lasso regression uses L1 regularization. ElasticNet, which combines the L1 and L2 penalties, is another option.
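One practical difference from ridge is that the L1 penalty can set coefficients exactly to zero. The sketch below counts the nonzero coefficients of Lasso and ElasticNet on synthetic data from make_regression (again an assumption for illustration, not the Boston data). The values alpha=5.0 and l1_ratio=0.5 are arbitrary.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNet, Lasso
from sklearn.model_selection import train_test_split

# Synthetic data: 60 features, only 5 of which carry signal.
X, y = make_regression(n_samples=100, n_features=60, n_informative=5,
                       noise=50.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# L1 penalty only.
lasso = Lasso(alpha=5.0, max_iter=10000).fit(X_train, y_train)
# Mix of L1 and L2; l1_ratio=0.5 weights the two penalties equally.
enet = ElasticNet(alpha=5.0, l1_ratio=0.5, max_iter=10000).fit(X_train, y_train)

n_lasso = int(np.sum(lasso.coef_ != 0))
n_enet = int(np.sum(enet.coef_ != 0))

print("Lasso: test R^2 = {:.2f}, features used = {} of {}".format(
    lasso.score(X_test, y_test), n_lasso, X.shape[1]))
print("ElasticNet: test R^2 = {:.2f}, features used = {} of {}".format(
    enet.score(X_test, y_test), n_enet, X.shape[1]))
```

A smaller alpha keeps more features and moves the fit closer to ordinary least squares. A larger alpha removes more features and can underfit.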

 