【机器学习 模型调参】GridSearchCV模型调参利器

生来就可爱ヽ(ⅴ<●) 提交于 2019-12-26 14:44:46

导入模块sklearn.model_selection

from sklearn.model_selection import GridSearchCV

GridSearchCV 称为网格搜索交叉验证调参,它通过遍历传入的参数的所有排列组合,通过交叉验证的方式,返回所有参数组合下的评价指标得分,GridSearchCV 函数的参数详细解释如下:

class sklearn.model_selection.GridSearchCV(estimator,param_grid,scoring = None,n_jobs = None,iid ='deprecated',refit = True,cv = None,verbose = 0,pre_dispatch ='2 * n_jobs',error_score = nan,return_train_score = False )


GridSearchCV官方说明

参数:

estimator:scikit-learn 库里的算法模型;
param_grid:需要搜索调参的参数字典;
scoring:评价指标,可以是 auc, rmse,logloss等;
n_jobs:并行计算线程个数,可以设置为 -1,这样可以充分使用机器的所有处理器,并行数量越多,有利于缩短调参时间;
iid:如果设置为True,则默认假设数据在每折中具有相同地分布,并且最小化的损失是每个样本的总损失,而不是每折的平均损失。简单点说,就是如果你可以确定 cv 中每折数据分布一致就设置为 True,否则设置为 False;
cv:交叉验证的折数,默认为3折;

常用属性:
cv_results_:用来输出cv结果的,可以是字典形式也可以是numpy形式,还可以转换成DataFrame格式
best_estimator_:通过搜索参数得到的最好的估计器,当参数refit=False时该对象不可用
best_score_:float类型,输出最好的成绩
best_params_:通过网格搜索得到的score最好对应的参数
best_index_:对应于最佳候选参数设置的索引(cv_results_数组)。cv_results _ [‘params’] [search.best_index_]中的dict给出了最佳模型的参数设置,给出了最高的平均分数(search.best_score_)。
scorer_:评分函数
n_splits_:交叉验证的数量
refit_time_:refit所用的时间,当参数refit=False时该对象不可用


常用函数:

decision_function(X):返回决策函数值(比如svm中的决策距离)
fit(X,y=None,groups=None,fit_params):在数据集上运行所有的参数组合
get_params(deep=True):返回估计器的参数
inverse_transform(Xt):Call inverse_transform on the estimator with the best found params.
predict(X):返回预测结果值(0/1predict_log_proba(X): Call predict_log_proba on the estimator with the best found parameters.
predict_proba(X):返回每个类别的概率值(有几类就返回几列值)
score(X, y=None):返回函数
set_params(**params):Set the parameters of this estimator.
transform(X):在X上使用训练好的参数

属性grid_scores_已经被删除,改用:

means = grid_search.cv_results_['mean_test_score']
params = grid_search.cv_results_['params']

例子:

from sklearn.model_selection import GridSearchCV
param_gbdt3 = {'learning_rate':[0.06,0.07,0.08,0.09,0.1],
               'n_estimators':[100,150,200,250,300]}


gbdt_search2 = GridSearchCV(estimator=GradientBoostingRegressor(loss='ls',max_depth=9,max_features=9,
                                                                subsample=0.8,min_samples_leaf=4, min_samples_split=6),n_jobs=-1,
                            param_grid=param_gbdt3,scoring='neg_mean_squared_error',iid=False,cv=5)
gbdt_search2.fit(X_train,y_train)
print(gbdt_search2.best_params_)
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!