Parameter “stratify” of function “train_test_split” (scikit-learn)

我寻月下人不归 2020-12-22 21:37

I am trying to use train_test_split from the scikit-learn package, but I am having trouble with the stratify parameter. Here is the code:



        
5 answers
  •  无人及你
    2020-12-22 22:04

    For my future self who comes here via Google:

    train_test_split is now in model_selection, hence:

    from sklearn.model_selection import train_test_split
    
    # given:
    # features: xs
    # ground truth: ys
    
    x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                        test_size=0.33,
                                                        random_state=0,
                                                        stratify=ys)
    

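    is the way to use it. Passing stratify=ys keeps the class proportions of ys approximately the same in the training and test sets. Setting random_state makes the split reproducible.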
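
    To check what stratification does, here is a minimal sketch on made-up data. The toy xs and ys (90 samples of class 0, 10 of class 1) and the 0.2 test size are assumptions chosen for illustration, not part of the original question. It counts the labels on each side of the split:

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# Hypothetical toy data: 90 samples of class 0 and 10 samples of class 1.
xs = [[i] for i in range(100)]
ys = [0] * 90 + [1] * 10

# Stratify on the labels so both subsets keep the 90/10 class ratio.
x_train, x_test, y_train, y_test = train_test_split(
    xs, ys, test_size=0.2, random_state=0, stratify=ys
)

# Count how many samples of each class ended up in each subset.
train_counts = Counter(y_train)
test_counts = Counter(y_test)
print("train:", dict(train_counts))
print("test:", dict(test_counts))
```

    With these sizes the minority class should make up 10% of both subsets. Without stratify, a small test set could by chance contain very few or even no samples of the rare class.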
