scikit-learn .predict() default threshold

长发绾君心 2020-12-02 05:00

I'm working on a classification problem with unbalanced classes (5% 1's). I want to predict the class, not the probability.

In a binary classification problem, is .predict() using 0.5 as the threshold by default?

我在风中等你 2020-12-02 05:41

    For most scikit-learn classifiers, the threshold is 0.5 for binary classification; for multiclass classification, .predict() returns whichever class has the greatest probability. In many problems, a much better result can be obtained by adjusting the threshold. However, this must be done with care: NOT on the holdout test data, but by cross-validation on the training data. If you adjust the threshold on your test data, you are just overfitting the test data.
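    A minimal sketch of what the default looks like in practice, assuming synthetic data from make_classification (about 5% positives, to mirror the question) and LogisticRegression as an example classifier. It checks that .predict() agrees with thresholding predict_proba at 0.5, then applies an arbitrary illustrative threshold of 0.2 by hand. For a real choice of threshold, use cross-validation on the training data as described above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic, imbalanced data (assumption): roughly 5% of samples are class 1
X, y = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

# Probability of class 1; column order follows clf.classes_ ([0, 1] here)
proba_pos = clf.predict_proba(X)[:, 1]

# Default behaviour: class 1 whenever P(class 1) > 0.5
default_pred = clf.predict(X)
manual_pred = (proba_pos > 0.5).astype(int)
print("predict() matches the 0.5 threshold:", np.array_equal(default_pred, manual_pred))

# Hand-applied threshold; 0.2 is an arbitrary illustrative value, not a recommendation
custom_threshold = 0.2
custom_pred = (proba_pos > custom_threshold).astype(int)
print("predicted positives at 0.5:", default_pred.sum(), "| at 0.2:", custom_pred.sum())
```

    Lowering the threshold can only add positive predictions, which is usually the direction you want to explore when the positive class is rare.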

    Most methods of adjusting the threshold are based on the receiver operating characteristic (ROC) curve and Youden's J statistic, but it can also be done by other methods, such as a search with a genetic algorithm.
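    For reference, Youden's J at a candidate threshold t is the standard quantity below (the usual textbook definition, not something specific to this answer). Since specificity = 1 - FPR, J is the vertical distance between the ROC curve and the diagonal, and the chosen cut-off t* maximizes it.

```latex
J(t) = \mathrm{sensitivity}(t) + \mathrm{specificity}(t) - 1
     = \mathrm{TPR}(t) - \mathrm{FPR}(t),
\qquad
t^{*} = \arg\max_{t} J(t)
```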

    Here is a peer-reviewed journal article describing this approach in medicine:

    http://www.ncbi.nlm.nih.gov/pmc/articles/PMC2515362/

    As far as I know, there is no package for doing this in Python, but it is relatively simple (if inefficient) to find the threshold with a brute-force search.
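    A brute-force search can be sketched as follows, under these assumptions: synthetic data, LogisticRegression, a grid of 99 candidate thresholds, and F1 on the positive class as the objective (all illustrative choices). The threshold is picked using out-of-fold probabilities from cross_val_predict on the training split only; the held-out test set is used once, at the end. Note that newer scikit-learn releases (1.5 and later) include sklearn.model_selection.TunedThresholdClassifierCV, which automates a cross-validated threshold search; the sketch below does not depend on it.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_predict, train_test_split

# Synthetic, imbalanced data (assumption): roughly 5% positives
X, y = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

clf = LogisticRegression(max_iter=1000)

# Out-of-fold probabilities on the TRAINING data only; the test set is not used here
oof_proba = cross_val_predict(
    clf, X_train, y_train, cv=5, method="predict_proba"
)[:, 1]

# Brute force: score every candidate threshold and keep the best one
thresholds = np.linspace(0.01, 0.99, 99)
scores = [
    f1_score(y_train, (oof_proba > t).astype(int), zero_division=0)
    for t in thresholds
]
best_threshold = thresholds[int(np.argmax(scores))]

# Refit on all training data, then evaluate ONCE on the held-out test set
clf.fit(X_train, y_train)
test_pred = (clf.predict_proba(X_test)[:, 1] > best_threshold).astype(int)
print(f"best threshold (cross-validated F1): {best_threshold:.2f}")
print(f"test F1 at best threshold: {f1_score(y_test, test_pred, zero_division=0):.3f}")
print(f"test F1 at default 0.5:    {f1_score(y_test, clf.predict(X_test), zero_division=0):.3f}")
```

    F1 is only a stand-in objective; swapping in balanced accuracy, Youden's J, or a cost-weighted score changes one line. The grid search is inefficient in the sense mentioned above, but with a few hundred thresholds it is rarely a bottleneck.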

    Here is some R code that does it using the pROC package:

    ## Load data: true class in column `tc`, predicted probability in column `prob`
    DD73OP <- read.table("/my_probabilites.txt", header = TRUE, quote = "\"")
    
    library("pROC")
    # ROC curve, no smoothing
    roc_OP <- roc(DD73OP$tc, DD73OP$prob)
    auc_OP <- auc(roc_OP)
    auc_OP
    # Area under the curve: 0.8909
    plot(roc_OP)
    
    # Best threshold
    # Method: Youden
    # Youden's J statistic (Youden, 1950) is employed. The optimal cut-off is the
    # threshold that maximizes the distance to the identity (diagonal) line.
    # Can be shortened to "y".
    # The optimality criterion is:
    # max(sensitivities + specificities)
    coords(roc_OP, "best", ret = c("threshold", "specificity", "sensitivity"), best.method = "youden")
    # threshold specificity sensitivity
    # 0.7276835   0.9092466   0.7559022
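    A rough Python counterpart of the R snippet, sketched with scikit-learn's roc_auc_score and roc_curve rather than pROC. The helper youden_threshold is a hypothetical name introduced here, and y_true / y_score are randomly generated stand-ins for the tc and prob columns. The reported threshold may differ slightly from pROC's output, because the two libraries place candidate thresholds differently.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve


def youden_threshold(y_true, y_score):
    """Hypothetical helper: return (threshold, specificity, sensitivity)
    at the ROC point maximizing Youden's J = TPR - FPR."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    best = int(np.argmax(tpr - fpr))
    return thresholds[best], 1.0 - fpr[best], tpr[best]


# Stand-in data (assumption): binary labels and noisy scores that lean toward the label
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)
y_score = np.clip(0.3 * y_true + rng.normal(0.35, 0.2, size=500), 0, 1)

auc = roc_auc_score(y_true, y_score)
threshold, specificity, sensitivity = youden_threshold(y_true, y_score)
print(f"AUC: {auc:.4f}")
print(f"threshold={threshold:.4f} specificity={specificity:.4f} sensitivity={sensitivity:.4f}")
```

    roc_curve labels a sample positive when its score is >= the threshold, so applying the returned threshold that way reproduces the reported sensitivity and specificity.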
    
