Sklearn: ROC for multiclass classification

Anonymous (unverified), submitted 2019-12-03 08:28:06

Question:

I'm doing different text classification experiments. Now I need to calculate the AUC-ROC for each task. For the binary classifications, I already made it work with this code:

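    from sklearn.feature_extraction import DictVectorizer
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    from sklearn.feature_selection import SelectKBest, mutual_info_classif
    from sklearn import linear_model, model_selection
    from sklearn.pipeline import Pipeline
    from sklearn.metrics import roc_curve, auc
    import matplotlib.pyplot as plt

    # labels: one class label per text; instances: list of feature dictionaries
    enc = LabelEncoder()
    y = enc.fit_transform(labels)

    feat_sel = SelectKBest(mutual_info_classif, k=200)
    clf = linear_model.LogisticRegression()

    pipe = Pipeline([('vectorizer', DictVectorizer()),
                     ('scaler', StandardScaler(with_mean=False)),
                     ('mutual_info', feat_sel),
                     ('logistregress', clf)])

    # ROC needs a continuous score, so request out-of-fold probabilities
    # instead of hard 0/1 predictions
    y_proba = model_selection.cross_val_predict(pipe, instances, y, cv=10,
                                                method='predict_proba')
    y_score = y_proba[:, 1]  # probability of the class encoded as 1

    # visualisation ROC-AUC
    fpr, tpr, thresholds = roc_curve(y, y_score)
    roc_auc = auc(fpr, tpr)  # new name, so the auc() function is not overwritten
    print('auc =', roc_auc)

    plt.figure()
    plt.title('Receiver Operating Characteristic')
    plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([-0.1, 1.2])
    plt.ylim([-0.1, 1.2])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()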
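A note on the binary case: roc_curve expects a continuous score per sample (for example predict_proba(...)[:, 1] or decision_function output). Given hard 0/1 predictions it can only trace a single interior point, so the curve and its AUC are cruder. A minimal sketch with made-up toy numbers (not from the question) shows the difference:

```python
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score

# Hypothetical toy data: true labels and a continuous score for class 1
y_true = np.array([0, 0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7])
y_hard = (y_score >= 0.5).astype(int)  # what a plain .predict() would give

fpr_s, tpr_s, _ = roc_curve(y_true, y_score)
fpr_h, tpr_h, _ = roc_curve(y_true, y_hard)

# Hard labels yield only 3 points: (0, 0), one interior point, (1, 1)
print(len(fpr_s), len(fpr_h))

auc_s = roc_auc_score(y_true, y_score)
auc_h = roc_auc_score(y_true, y_hard)
print(f"{auc_s:.3f} {auc_h:.3f}")  # 0.889 0.833
```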

But now I need to do it for the multiclass classification task. I read somewhere that I need to binarize the labels, but I really don't get how to calculate ROC for multiclass classification. Tips?
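Binarizing the labels means turning a single column of class labels into one 0/1 indicator column per class, which scikit-learn's label_binarize does. A small sketch, assuming hypothetical string labels for a three-class text task:

```python
from sklearn.preprocessing import label_binarize

# Hypothetical labels; classes fixes the column order
labels = ['neg', 'pos', 'neu', 'pos']
classes = ['neg', 'neu', 'pos']

Y = label_binarize(labels, classes=classes)
print(Y)
# [[1 0 0]
#  [0 0 1]
#  [0 1 0]
#  [0 0 1]]
```

Column i is the binary "class classes[i] vs. all others" target, so each column can be passed to roc_curve together with the classifier's score for that same class.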

Answer 1:

As people mentioned in the comments, you have to convert your problem into several binary ones using the one-vs-all (also called one-vs-rest) approach, so you'll get one ROC curve per class (n_classes curves in total). A simple example:

    from sklearn.metrics import roc_curve, auc
    from sklearn import datasets
    from sklearn.multiclass import OneVsRestClassifier
    from sklearn.svm import LinearSVC
    from sklearn.preprocessing import label_binarize
    # sklearn.cross_validation was removed in 0.20; use model_selection
    from sklearn.model_selection import train_test_split
    import matplotlib.pyplot as plt

    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    # Binarize the output: one indicator column per class
    y = label_binarize(y, classes=[0, 1, 2])
    n_classes = y.shape[1]

    # shuffle and split training and test sets
    X_train, X_test, y_train, y_test = \
        train_test_split(X, y, test_size=0.33, random_state=0)

    # classifier: one binary LinearSVC per class (one-vs-rest)
    clf = OneVsRestClassifier(LinearSVC(random_state=0))
    y_score = clf.fit(X_train, y_train).decision_function(X_test)

    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Plot one ROC curve per class
    for i in range(n_classes):
        plt.figure()
        plt.plot(fpr[i], tpr[i], label='ROC curve (area = %0.2f)' % roc_auc[i])
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic (class %d)' % i)
        plt.legend(loc="lower right")
        plt.show()
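The answer uses a train/test split and decision_function. A sketch closer to the question's setup, assuming out-of-fold probabilities from cross_val_predict(..., method='predict_proba'); LogisticRegression on the iris data stands in for the question's text pipeline and `instances` (an assumption for illustration only):

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc, roc_auc_score

# Stand-in data (assumption): replace with your own features and encoded labels
X, y = datasets.load_iris(return_X_y=True)
classes = np.unique(y)

clf = LogisticRegression(max_iter=1000)
# Out-of-fold class probabilities; column i belongs to classes[i]
y_proba = cross_val_predict(clf, X, y, cv=10, method='predict_proba')

# One-vs-rest targets: one 0/1 column per class
y_bin = label_binarize(y, classes=classes)

fpr, tpr, roc_auc = {}, {}, {}
for i, c in enumerate(classes):
    fpr[c], tpr[c], _ = roc_curve(y_bin[:, i], y_proba[:, i])
    roc_auc[c] = auc(fpr[c], tpr[c])

# Unweighted mean of the per-class AUCs (macro average)
macro_auc = np.mean(list(roc_auc.values()))

# Same quantity in one call (scikit-learn >= 0.22)
ovr_auc = roc_auc_score(y, y_proba, multi_class='ovr', average='macro')
print(roc_auc, macro_auc, ovr_auc)
```

The per-class fpr[c] and tpr[c] arrays can be plotted exactly as in the loop above. roc_auc_score with multi_class='ovr' requires score rows that sum to 1, which predict_proba provides.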
