How to use warm_start

栀梦 2020-12-25 15:51

I'd like to use the warm_start parameter to add training data to my random forest classifier. I expected it to be used like this:

clf = RandomF         


        
4 Answers
  •  借酒劲吻你
    2020-12-25 15:57

    All warm_start does boils down to preserving the state of the previous call to fit.


    It differs from partial_fit in that the idea is not to learn incrementally on small batches of data, but rather to reuse a trained model in its previous state. The difference between a regular call to fit and a call with warm_start=True is that the estimator state is not cleared. In the gradient boosting estimators, for example, fit guards the reset like this (see _clear_state):

    if not self.warm_start:
        self._clear_state()
    

    Which, among other attributes, resets the fitted estimators:

    if hasattr(self, 'estimators_'):
        self.estimators_ = np.empty((0, 0), dtype=object)  # np.object was removed in NumPy 1.24
    

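    So with warm_start=True, each subsequent call to fit does not reset the fitted estimators; instead it starts from their previous state and adds new estimators to the model. For a random forest, new trees are only added when n_estimators has been increased since the previous fit; if it is unchanged, scikit-learn warns and fits nothing new.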
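
    As a minimal sketch of that behaviour, the snippet below grows a forest in two calls to fit. The synthetic data, the two-batch split, and the tree counts (50, then 80) are assumptions chosen for illustration; they are not taken from the question.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data split into two batches (illustrative assumption).
X, y = make_classification(n_samples=400, n_features=10, random_state=0)
X_a, y_a = X[:200], y[:200]
X_b, y_b = X[200:], y[200:]

# First fit: 50 trees trained on the first batch.
clf = RandomForestClassifier(n_estimators=50, warm_start=True, random_state=0)
clf.fit(X_a, y_a)
first_trees = list(clf.estimators_)

# Raise n_estimators before refitting; only the 30 extra trees are trained,
# and they see only the data passed to this second call.
clf.set_params(n_estimators=80)
clf.fit(X_b, y_b)

n_trees = len(clf.estimators_)
# The original 50 tree objects are still at the front of the list, untouched.
kept = all(old is new for old, new in zip(first_trees, clf.estimators_))
print(n_trees, kept)
```

    Note that the earlier trees are never retrained on the new batch, and both batches should contain the same set of classes. Only the added trees are trained in the second call, so a later fit costs less than refitting the whole forest.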


    Which means that one could do:

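    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import GridSearchCV

    # Hyperparameter grid; n_estimators is left at its default for now.
    grid1 = {'bootstrap': [True, False],
             'max_depth': [10, 20, 30, 40, 50, 60],
             # 'auto' (an alias of 'sqrt' for classifiers) was removed in
             # scikit-learn 1.3; 'log2' is swapped in here as a second option.
             'max_features': ['sqrt', 'log2'],
             'min_samples_leaf': [1, 2, 4],
             'min_samples_split': [2, 5, 10]}

    # GridSearchCV takes param_grid and has no random_state argument.
    rf_grid_search1 = GridSearchCV(estimator=RandomForestClassifier(),
                                   param_grid=grid1,
                                   cv=3)
    rf_grid_search1.fit(X_train, y_train)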
    

    Then fit a model on the best parameters and set warm_start=True:

    rf = RandomForestClassifier(**rf_grid_search1.best_params_, warm_start=True)
    rf.fit(X_train, y_train)
    

    Then we could run a grid search over just n_estimators, say:

    grid2 = {'n_estimators': [200, 400, 600, 800, 1000]}
    # GridSearchCV takes param_grid and has no random_state or n_iter arguments.
    rf_grid_search2 = GridSearchCV(estimator=rf,
                                   param_grid=grid2,
                                   cv=3)
    rf_grid_search2.fit(X_train, y_train)
    

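    The idea is that the estimators are already fit with the previous parameter setting, so each subsequent call to fit starts from the existing trees and we only analyze whether adding new estimators benefits the model. One caveat: GridSearchCV fits clones of the estimator it is given, and a clone does not keep fitted trees, so the reuse only happens when fit is called again on rf itself after raising n_estimators.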
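
    A quick way to see what state a search would start from is to clone a fitted forest by hand. The sketch below assumes synthetic data and small tree counts (20, then 40) purely for illustration; sklearn.base.clone is the real helper that the search classes apply to their estimator.

```python
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data (illustrative assumption).
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

rf = RandomForestClassifier(n_estimators=20, warm_start=True, random_state=0)
rf.fit(X, y)

# clone() copies the hyperparameters (including warm_start=True)
# but not the fitted state, so the copy has no trees yet.
rf_copy = clone(rf)
copy_is_fitted = hasattr(rf_copy, "estimators_")
copy_warm_start = rf_copy.get_params()["warm_start"]

# Reusing the existing trees means refitting the same object
# after raising n_estimators.
rf.set_params(n_estimators=40)
rf.fit(X, y)
n_trees = len(rf.estimators_)
print(copy_is_fitted, copy_warm_start, n_trees)
```

    So a loop that raises n_estimators on rf and scores it on a held-out set after each fit is one way to get the incremental comparison while actually reusing the earlier trees.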
