(Python - sklearn) How to pass parameters to a custom ModelTransformer class with GridSearchCV

Submitted by 末鹿安然 on 2019-11-28 04:00:57

GridSearchCV has a special naming convention for nested objects: each double underscore steps one level down. In your case, ess__rfc__n_estimators stands for ess.rfc.n_estimators and, according to the definition of the pipeline, it points to the n_estimators parameter of

ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1, n_estimators=100))

ModelTransformer instances have no such parameter, so the grid search cannot set it.
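To see where the double-underscore names lead, the sketch below rebuilds a pipeline of the shape the answer describes and checks which parameter names GridSearchCV could use. The layout (a FeatureUnion step named ess holding a step named rfc, followed by a LogisticRegression named clf) and the ModelTransformer body are hypothetical reconstructions, not the asker's code; the class inherits BaseEstimator so that get_params can report nested names.

```python
# Hypothetical reconstruction of the question's setup, not the asker's code.
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import FeatureUnion, Pipeline


class ModelTransformer(TransformerMixin, BaseEstimator):
    """Wraps a model so that its predictions become a feature column."""

    def __init__(self, model):
        self.model = model

    def fit(self, X, y=None):
        self.model.fit(X, y)
        return self

    def transform(self, X):
        # One 2-D column of predictions, the shape FeatureUnion expects.
        return np.asarray(self.model.predict(X)).reshape(-1, 1)


# Assumed layout: "ess" (FeatureUnion) -> "rfc" (ModelTransformer) -> model.
pipe = Pipeline([
    ("ess", FeatureUnion([
        ("rfc", ModelTransformer(
            RandomForestClassifier(n_jobs=-1, random_state=1, n_estimators=100))),
    ])),
    ("clf", LogisticRegression()),
])

names = pipe.get_params(deep=True).keys()
print("ess__rfc__n_estimators" in names)         # False
print("ess__rfc__model__n_estimators" in names)  # True

# The wrong name is rejected once the lookup reaches ModelTransformer.
wrong_name_rejected = False
try:
    pipe.set_params(ess__rfc__n_estimators=200)
except ValueError as exc:
    wrong_name_rejected = True
    print("rejected:", exc)
```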

The fix is easy: to reach the object wrapped by ModelTransformer, go through its model attribute. The grid parameters then become:

parameters = {
  'ess__rfc__model__n_estimators': (100, 200),
}
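A minimal end-to-end sketch of the corrected grid follows. It assumes the same hypothetical pipeline layout and ModelTransformer reconstruction as above, and uses the iris dataset purely as stand-in data.

```python
# Hypothetical pipeline and ModelTransformer; iris is stand-in data.
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import FeatureUnion, Pipeline


class ModelTransformer(TransformerMixin, BaseEstimator):
    """Wraps a model so that its predictions become a feature column."""

    def __init__(self, model):
        self.model = model

    def fit(self, X, y=None):
        self.model.fit(X, y)
        return self

    def transform(self, X):
        # One 2-D column of predictions, the shape FeatureUnion expects.
        return np.asarray(self.model.predict(X)).reshape(-1, 1)


X, y = load_iris(return_X_y=True)

pipe = Pipeline([
    ("ess", FeatureUnion([
        ("rfc", ModelTransformer(
            RandomForestClassifier(n_jobs=-1, random_state=1, n_estimators=100))),
    ])),
    ("clf", LogisticRegression(max_iter=1000)),
])

# The corrected grid: the extra "model__" hop reaches the wrapped forest.
parameters = {
    "ess__rfc__model__n_estimators": (100, 200),
}

search = GridSearchCV(pipe, parameters, cv=3)
search.fit(X, y)
print(search.best_params_)
```

Other parameters of the wrapped forest follow the same pattern, for example ess__rfc__model__max_depth.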

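P.S. This is not the only problem with your code. GridSearchCV clones the estimator before every fit (whether or not it runs multiple jobs) and applies each grid point through set_params, so all objects you use must be clonable. This is achieved by implementing the get_params and set_params methods; you can inherit them from sklearn.base.BaseEstimator, whose set_params also understands nested names such as model__n_estimators.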
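The difference BaseEstimator makes can be checked directly with sklearn.base.clone, the function GridSearchCV uses to copy estimators. In this sketch, PlainModelTransformer is a hypothetical wrapper without get_params/set_params, and ModelTransformer is the same wrapper inheriting them from BaseEstimator; fit and transform are left out because only parameter handling is shown.

```python
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.ensemble import RandomForestClassifier


class PlainModelTransformer(TransformerMixin):
    """Hypothetical wrapper WITHOUT get_params/set_params (fit/transform omitted)."""

    def __init__(self, model):
        self.model = model


class ModelTransformer(TransformerMixin, BaseEstimator):
    """The same wrapper; BaseEstimator supplies get_params/set_params."""

    def __init__(self, model):
        self.model = model


# clone() relies on get_params, so the plain wrapper cannot be cloned.
plain_clone_failed = False
try:
    clone(PlainModelTransformer(RandomForestClassifier()))
except TypeError as exc:
    plain_clone_failed = True
    print("clone failed:", exc)

# With BaseEstimator: cloning works and nested names reach the wrapped model.
original = ModelTransformer(RandomForestClassifier(n_estimators=100))
cloned = clone(original)
cloned.set_params(model__n_estimators=200)
print(cloned.model.n_estimators)    # 200
print(original.model.n_estimators)  # 100: clone() copied the forest as well
```

The mixin is listed before BaseEstimator, which is the order scikit-learn's developer guide recommends for custom estimators.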
