GridSearchCV - access to predicted values across tests?

ぐ巨炮叔叔 submitted on 2019-12-11 00:24:07

Question


Is there a way to get access to the predicted values calculated within a GridSearchCV process?

I'd like to be able to plot the predicted y values against their actual values (from the test/validation set).

Once the grid search is complete, I can predict on some other data using

 ypred = grid.predict(xv)

but I'd like to be able to plot the values calculated during the grid search. Maybe there's a way of saving the points as a pandas dataframe?

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import (GridSearchCV, KFold,
                                     cross_val_score, train_test_split)
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

scaler = StandardScaler()
svr_rbf = SVR(kernel='rbf')
pipe = Pipeline(steps=[('scaler', scaler), ('svr_rbf', svr_rbf)])
# parameters, splits, msescorer, xt and yt are defined elsewhere in the asker's code (not shown)
grid = GridSearchCV(pipe, param_grid=parameters, cv=splits, refit=True, verbose=3, scoring=msescorer, n_jobs=4)
grid.fit(xt, yt)
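For reference, GridSearchCV itself keeps only per-split scores in cv_results_ (split0_test_score, ...), not the predictions behind them. If only the winning parameter setting matters, one option is to re-run the same folds with sklearn.model_selection.cross_val_predict, which returns, for every sample, the prediction made while that sample was in a test fold. This option is not part of the answer below. The following is a minimal sketch on made-up toy data; X, y and the C grid are assumptions. It also assumes an unshuffled KFold, so that the folds match the ones the search used:

```python
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV, KFold, cross_val_predict
from sklearn.svm import SVR

# Toy data, invented for illustration only.
rng = np.random.RandomState(0)
X = rng.rand(90, 1)
y = np.sin(4 * X[:, 0]) + 0.1 * rng.randn(90)

splits = KFold(n_splits=3)  # unshuffled, so it yields the same folds on every call
grid = GridSearchCV(SVR(kernel="rbf"), {"C": [0.1, 1.0, 10.0]},
                    scoring="neg_mean_squared_error", cv=splits)
grid.fit(X, y)

# cross_val_predict clones the estimator, so only best_estimator_'s parameters matter.
y_oof = cross_val_predict(grid.best_estimator_, X, y, cv=splits)

# One row per sample: the actual value and its out-of-fold prediction.
df = pd.DataFrame({"y_true": y, "y_pred": y_oof})

# With equal-sized folds, this overall MSE equals -grid.best_score_
# (the mean of the per-fold MSEs for the best setting).
oof_mse = mean_squared_error(df["y_true"], df["y_pred"])
print(oof_mse, -grid.best_score_)
```

The resulting df can be plotted directly, for example with df.plot.scatter(x="y_true", y="y_pred") if matplotlib is installed.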

Answer 1:


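One solution is to make a custom scorer that appends the predictions it receives to a global list. This only works when scoring runs in the same process as the search (the default n_jobs=None). With n_jobs > 1, as in the question, the scorer runs in worker processes and the list in the main process stays empty: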

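import numpy as np
from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error, make_scorer

X, y = np.random.rand(2, 200)  # two random 1-D arrays of 200 values each
clf = SVR()  # note: 'degree' below only affects kernel='poly', not the default 'rbf'

ys = []  # predictions collected by the scorer, one array per (candidate, fold)

def MSE(y_true, y_pred):
    # appending to a module-level list needs no `global` statement
    ys.append(y_pred)
    mse = mean_squared_error(y_true, y_pred)
    return mse

def scorer():
    # greater_is_better=False makes GridSearchCV minimize the MSE
    return make_scorer(MSE, greater_is_better=False)

n_splits = 3
cv = GridSearchCV(clf, {'degree': [1, 2, 3]}, scoring=scorer(), cv=n_splits)
cv.fit(X.reshape(-1, 1), y)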
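The trick works because of how make_scorer scorers are called. GridSearchCV calls such a scorer as scorer(estimator, X, y). The scorer calls estimator.predict(X), passes the ground truth and those predictions to the wrapped metric, and multiplies the result by -1 when greater_is_better=False. Below is a small stand-alone check of that behaviour. LinearRegression and the random data are chosen only for illustration, and mse_and_capture is a hypothetical name:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import make_scorer, mean_squared_error

captured = []  # filled by the metric below each time the scorer runs

def mse_and_capture(y_true, y_pred):
    # make_scorer hands this function the ground truth and estimator.predict(X)
    captured.append(np.asarray(y_pred))
    return mean_squared_error(y_true, y_pred)

scorer = make_scorer(mse_and_capture, greater_is_better=False)

rng = np.random.RandomState(0)
X = rng.rand(20, 1)
y = 3 * X[:, 0] + 0.1 * rng.randn(20)
est = LinearRegression().fit(X, y)

# Called as scorer(estimator, X, y): it predicts, calls the metric,
# and flips the sign because greater_is_better=False.
score = scorer(est, X, y)
print(score, len(captured))
```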

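Then collect the per-fold predictions of each parameter setting into one array:

idxs = range(0, len(ys) + 1, n_splits)
# e.g. [0, 3, 6, 9]
# GridSearchCV evaluates all folds of one candidate before moving to the next,
# so every consecutive group of n_splits entries belongs to one parameter setting
new = [ys[start:stop] for start, stop in zip(idxs, idxs[1:])]
# concatenate each group's fold predictions; with cv=int (an unshuffled KFold
# for regressors) the result lines up with the original order of y
ys = [np.concatenate(group, axis=0) for group in new]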
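The index arithmetic above depends on the order in which GridSearchCV calls the scorer. A variant that avoids this passes GridSearchCV a plain callable with the signature scorer(estimator, X, y), which scikit-learn accepts for scoring. This is a swapped-in technique, not the original answer's code. Because the callable sees the fitted estimator, it can record that estimator's parameters next to the fold's actual and predicted values, and the results go straight into a pandas DataFrame. The toy data and param_grid are hypothetical, as are the names recording_scorer and records:

```python
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

# Hypothetical grid; the step name matches the question's pipeline.
param_grid = {"svr_rbf__C": [0.1, 1.0, 10.0]}
records = []  # one dict per (parameter setting, test fold)

def recording_scorer(estimator, X_test, y_test):
    """Callable scorer: GridSearchCV calls it with the fitted estimator and a test fold."""
    y_pred = estimator.predict(X_test)
    records.append({
        "params": {name: estimator.get_params()[name] for name in param_grid},
        "y_true": np.asarray(y_test),
        "y_pred": y_pred,
    })
    # GridSearchCV maximizes the score, so return the negative MSE.
    return -mean_squared_error(y_test, y_pred)

# Toy data, invented for illustration only.
rng = np.random.RandomState(0)
X = rng.rand(90, 1)
y = np.sin(4 * X[:, 0]) + 0.1 * rng.randn(90)

pipe = Pipeline(steps=[("scaler", StandardScaler()), ("svr_rbf", SVR(kernel="rbf"))])
# n_jobs is left at its default so the scorer runs in this process and fills `records`.
grid = GridSearchCV(pipe, param_grid, scoring=recording_scorer, cv=KFold(n_splits=3))
grid.fit(X, y)

# One row per held-out point: parameter value, actual value, prediction.
rows = [
    {**rec["params"], "y_true": t, "y_pred": p}
    for rec in records
    for t, p in zip(rec["y_true"], rec["y_pred"])
]
df = pd.DataFrame(rows)
print(df.groupby("svr_rbf__C").size())
```

Each parameter setting then has its own subset of rows, for example df[df["svr_rbf__C"] == 1.0], which can be plotted as predicted against actual values.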


Source: https://stackoverflow.com/questions/49633465/gridsearchcv-access-to-predicted-values-across-tests
