Multilabel-indicator is not supported for confusion matrix

野性不改 2020-12-05 18:35

"multilabel-indicator is not supported" is the error message I get when trying to run:

confusion_matrix(y_test, predictions)

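The error can be reproduced in a few lines, assuming (as the answers imply) that both y_test and predictions hold one-hot rows. The arrays below are hypothetical stand-ins. type_of_target is the scikit-learn helper that classifies the input, and confusion_matrix only accepts "binary" or "multiclass" targets.

```python
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import type_of_target

# Hypothetical one-hot targets and predictions: one row per sample, one column per class
y_test = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0]])
predictions = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# A 2-D array of 0/1 values with several columns is classified as "multilabel-indicator"
print(type_of_target(y_test))  # multilabel-indicator

try:
    confusion_matrix(y_test, predictions)
except ValueError as err:
    print(err)  # multilabel-indicator is not supported
```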

3 answers
  • 2020-12-05 19:02
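    from sklearn.metrics import confusion_matrix
    
    # model, test_data and labels_one_hot come from your own setup
    predictions_one_hot = model.predict(test_data)
    # argmax(axis=1) turns each one-hot (or score) row into a single class index
    cm = confusion_matrix(labels_one_hot.argmax(axis=1), predictions_one_hot.argmax(axis=1))
    print(cm)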
    

    Output would be something like this:

    [[298   2  47  15  77   3  49]
     [ 14  31   2   0   5   1   2]
     [ 64   5 262  22  94  38  43]
     [ 16   1  20 779  15  14  34]
     [ 49   0  71  33 316   7 118]
     [ 14   0  42  23   5 323   9]
     [ 20   1  27  32  97  13 436]]
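Because model, test_data and labels_one_hot are not defined in the answer, here is a self-contained sketch of the same argmax conversion using made-up arrays. argmax also works when the model outputs probabilities rather than strict one-hot rows, since it simply picks the highest-scoring class in each row.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical one-hot ground truth: 4 samples, 3 classes
labels_one_hot = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [0, 0, 1],
])
# Hypothetical model scores (e.g. softmax outputs); each row sums to 1
predictions_one_hot = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.6, 0.3],
    [0.0, 0.1, 0.9],
])

# argmax(axis=1) maps each row to the index of its largest entry
y_true = labels_one_hot.argmax(axis=1)       # [0, 1, 2, 2]
y_pred = predictions_one_hot.argmax(axis=1)  # [0, 1, 1, 2]

cm = confusion_matrix(y_true, y_pred)
print(cm)
# [[1 0 0]
#  [0 1 0]
#  [0 1 1]]
```

Rows are true classes and columns are predicted classes, so the 7×7 matrix in the answer is read the same way.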
    
  • 2020-12-05 19:12

    confusion_matrix expects a vector of class labels, not one-hot encodings. You should run:

    confusion_matrix(y_test.values.argmax(axis=1), predictions.argmax(axis=1))
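The .values call suggests that y_test is a pandas DataFrame with one column per class, while predictions is a NumPy array. A minimal sketch under that assumption; the column names and values are hypothetical. pandas documents to_numpy() as the recommended alternative to .values, and the DataFrame columns can double as row and column labels for the result.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

# Hypothetical one-hot targets in a DataFrame, one column per class
y_test = pd.DataFrame(
    [[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]],
    columns=["class_a", "class_b", "class_c"],
)
# Hypothetical one-hot predictions as a NumPy array
predictions = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])

# to_numpy() returns the same array as .values; argmax(axis=1) gives class indices
y_true = y_test.to_numpy().argmax(axis=1)  # [0, 1, 1, 2]
y_pred = predictions.argmax(axis=1)        # [0, 1, 2, 2]

cm = confusion_matrix(y_true, y_pred)

# Every class occurs in this data, so cm is 3x3 and lines up with the columns
print(pd.DataFrame(cm, index=y_test.columns, columns=y_test.columns))
```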
    
  • 2020-12-05 19:13

    No, the inputs to confusion_matrix must be class labels, not OHEs (one-hot encodings). Call argmax on your y_test and predictions, and you should get what you expect:

    confusion_matrix(
        y_test.values.argmax(axis=1), predictions.argmax(axis=1))
    
    array([[1, 0],
           [0, 2]])
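A 2×2 result like the one above is what confusion_matrix returns when only two class indices occur in the data: by default it builds a row and column only for labels that appear in y_true or y_pred. The sketch below uses hypothetical three-column one-hot NumPy arrays (so .values is not needed) in which the third class never occurs. Passing labels= forces one row and column per class.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical one-hot data: 3 class columns, but class 2 never occurs
y_test = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0]])
predictions = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0]])

y_true = y_test.argmax(axis=1)       # [0, 1, 1]
y_pred = predictions.argmax(axis=1)  # [0, 1, 1]

# Default: only labels present in y_true or y_pred get a row/column
print(confusion_matrix(y_true, y_pred))
# [[1 0]
#  [0 2]]

# labels= fixes the class set, so the matrix stays 3x3
print(confusion_matrix(y_true, y_pred, labels=[0, 1, 2]))
# [[1 0 0]
#  [0 2 0]
#  [0 0 0]]
```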
    