How does the class_weight parameter in scikit-learn work?

Backend · Unresolved · 2 answers · 1801 views
[愿得一人] asked on 2020-11-29 15:23

I am having a lot of trouble understanding how the class_weight parameter in scikit-learn's Logistic Regression operates.

The Situation

2 Answers
  •  情歌与酒
    2020-11-29 15:48

    The first answer is good for understanding how the mechanism works, but I wanted to understand how I should be using it in practice.

    SUMMARY

    • for moderately imbalanced data WITHOUT noise, there is not much difference in applying class weights
    • for moderately imbalanced data WITH noise, and for strongly imbalanced data, it is better to apply class weights
    • class_weight="balanced" works decently if you don't want to optimize the weights manually
    • with class_weight="balanced" you capture more true events (higher TRUE recall), but you are also more likely to get false alerts (lower TRUE precision)
      • as a result, the total % predicted TRUE may be higher than the actual % TRUE because of all the false positives
      • AUC may be misleading here if false alarms are a concern
    • there is no need to move the decision threshold to the imbalance %; even for strong imbalance it is fine to keep 0.5 (or somewhere around that, depending on what you need)
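    For context on what "balanced" actually does: scikit-learn's documented heuristic sets each class weight to n_samples / (n_classes * count(class)), so the minority class gets proportionally more weight. A minimal sketch reproducing that formula (the `balanced_weights` helper is mine, not part of sklearn):

```python
import numpy as np

def balanced_weights(y):
    # Reproduces the class_weight="balanced" heuristic:
    # w_c = n_samples / (n_classes * count(c))
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

y = np.array([0] * 80 + [1] * 20)  # 80/20 moderate imbalance
print(balanced_weights(y))  # minority class gets the larger weight
```

    Passing the resulting dict as class_weight={0: 0.625, 1: 2.5} should therefore behave like class_weight="balanced" on this y.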

    NB

    The results might differ when using RF or GBM. sklearn's GBM does not have class_weight="balanced", but lightgbm has LGBMClassifier(is_unbalance=False)

    CODE

    # scikit-learn==0.21.3
    from sklearn import datasets
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import roc_auc_score, classification_report
    import numpy as np
    import pandas as pd
    
    # case: moderate imbalance
    X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
    np.mean(y) # 0.2
    
    LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
    (LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
    LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
    LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
    LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?
    
    roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
    roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
    roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same
    
    # case: strong imbalance
    X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
    np.mean(y) # 0.06
    
    LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
    (LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
    LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
    LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
    LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
    (LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last
    
    roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
    roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
    roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
    roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last
    
    print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
    pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
    pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few predicted TRUE: only 28% TRUE recall at 86% TRUE precision, so 6%*28% ~= 2%
    
    print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
    pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
    pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also a lot of false positives with only 23% TRUE precision, making the total predicted % TRUE > the actual % TRUE
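    One caveat about the AUC numbers above: they are computed on hard 0/1 predictions. Computed on predicted probabilities, AUC is a ranking metric and does not depend on the 0.5 threshold at all, so the gap between weighted and unweighted models should mostly disappear. A quick check on the same strongly imbalanced setup (max_iter raised by me to ensure convergence):

```python
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

X, y = datasets.make_classification(n_samples=750, n_features=5, n_informative=2,
                                    n_redundant=0, random_state=1, weights=[0.95])

aucs = {}
for cw in (None, "balanced"):
    # score on probabilities, not hard labels, so the threshold drops out
    proba = LogisticRegression(C=1e9, class_weight=cw, max_iter=1000).fit(X, y).predict_proba(X)[:, 1]
    aucs[cw] = roc_auc_score(y, proba)
    print(cw, round(aucs[cw], 3))
```

    If false alarms matter, the classification_report / crosstab views above are the more informative diagnostics.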
    
