scikit-learn .predict() default threshold

长发绾君心 2020-12-02 05:00

I'm working on a classification problem with imbalanced classes (only 5% of the labels are 1). I want to predict the class, not the probability.

In a binary classification problem, does scikit-learn's .predict() use a default threshold of 0.5?

5 answers
  •  不思量自难忘°
    2020-12-02 05:33

    You can apply your own threshold to the class probabilities returned by clf.predict_proba() instead of calling clf.predict().

    For example:

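    from sklearn.tree import DecisionTreeClassifier

    clf = DecisionTreeClassifier(random_state=2)
    clf.fit(X_train, y_train)
    # y_pred = clf.predict(X_test)  # picks the most probable class, i.e. a 0.5 threshold for binary problems
    # Column 1 of predict_proba() is the probability of clf.classes_[1] (label 1 when labels are 0/1)
    y_pred = (clf.predict_proba(X_test)[:, 1] >= 0.3).astype(int)  # set threshold to 0.3; int matches 0/1 labels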
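
A self-contained sketch of the same approach, under some assumptions: it uses synthetic data from make_classification with roughly 5% positives, and it swaps in LogisticRegression for the answer's DecisionTreeClassifier, because a fully grown tree mostly outputs probabilities of exactly 0 or 1, so changing the threshold would show little. predict_with_threshold is a hypothetical helper, not a scikit-learn function. The sketch checks that .predict() agrees with a 0.5 cut-off on this model and shows how a 0.3 threshold changes the number of predicted positives.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the question's data: roughly 5% positives (assumption)
X, y = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Look up the probability column for label 1 instead of hard-coding index 1
pos_col = int(np.flatnonzero(clf.classes_ == 1)[0])
proba_pos = clf.predict_proba(X_test)[:, pos_col]


def predict_with_threshold(proba, threshold):
    """Hypothetical helper: return 0/1 labels, 1 wherever P(y == 1) >= threshold."""
    return (proba >= threshold).astype(int)


default_pred = clf.predict(X_test)           # scikit-learn's own class prediction
pred_05 = predict_with_threshold(proba_pos, 0.5)
pred_03 = predict_with_threshold(proba_pos, 0.3)

print("predict() matches a 0.5 threshold:", np.array_equal(default_pred, pred_05))
print("positives predicted at 0.5:", int(pred_05.sum()))
print("positives predicted at 0.3:", int(pred_03.sum()))
```

Lowering the threshold can only add positive predictions, which is usually the goal with a 5% minority class. The right value depends on the precision/recall trade-off you need, and it is best picked on held-out validation data rather than on the test set.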
    
