Scikit Learn SVC decision_function and predict

攒了一身酷 2020-11-30 17:15

I'm trying to understand the relationship between decision_function and predict, which are instance methods of SVC (http://scikit-learn.org/stable/modules/generated/sklear

6 answers
  •  野性不改 · 2020-11-30 17:46

    There's a really nice Q&A for the multi-class one-vs-one scenario at datascience.sx:

    Question

    I have a multiclass SVM classifier with labels 'A', 'B', 'C', 'D'.

    This is the code I'm running:

    >>>print clf.predict([predict_this])
    ['A']
    >>>print clf.decision_function([predict_this])
    [[ 185.23220833   43.62763596  180.83305074  -93.58628288   62.51448055  173.43335293]]
    

    How can I use the output of the decision function to predict the class (A/B/C/D) with the highest probability and, if possible, its value? I have visited https://stackoverflow.com/a/20114601/7760998, but it is for binary classifiers, and I could not find a good resource explaining the output of decision_function for multiclass classifiers with shape ovo (one-vs-one).

    Edit:

    The above example is for class 'A'. For another input the classifier predicted 'C' and gave the following result from decision_function:

    [[ 96.42193513 -11.13296606 111.47424538 -88.5356536 44.29272494 141.0069203 ]]
    

    Another input, which the classifier also predicted as 'C', gave the following result from decision_function:

    [[ 290.54180354 -133.93467605  116.37068951 -392.32251314 -130.84421412   284.87653043]]
    

    Had it been ovr (one-vs-rest), it would be easier: just pick the class with the highest value. But in ovo (one-vs-one) there are n * (n - 1) / 2 values in the resulting list.

    How to deduce which class would be selected based on the decision function?

    Answer

    Your link has sufficient resources, so let's go through them:

    When you call decision_function(), you get the output from each of the pairwise classifiers (n*(n-1)/2 numbers total). See pages 127 and 128 of "Support Vector Machines for Pattern Classification".

    Click on the "pages 127 and 128" link (not shown here, but in the Stack Overflow answer). You should see:

    • scikit-learn's SVC uses one-vs-one for multi-class classification. That's exactly what the book is talking about.
    • For each pairwise comparison, we measure the decision function.
    • The decision function is just the regular binary SVM decision boundary (a small sketch follows below).
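    For concreteness, here is a minimal sketch (my own toy data and variable names, not from the question) showing that with four classes and decision_function_shape='ovo', decision_function returns 4*3/2 = 6 pairwise values per sample. Note that recent scikit-learn versions default to decision_function_shape='ovr', so 'ovo' has to be requested explicitly to see the raw pairwise scores:

    import numpy as np
    from sklearn.svm import SVC

    # Four well-separated clusters labelled A-D -> 4*(4-1)/2 = 6 pairwise classifiers
    rng = np.random.RandomState(0)
    offsets = [(0, 0), (6, 0), (0, 6), (6, 6)]
    X = np.vstack([rng.randn(20, 2) + o for o in offsets])
    y = np.repeat(['A', 'B', 'C', 'D'], 20)

    # 'ovo' exposes the raw pairwise decision values (the default 'ovr' aggregates them per class)
    clf = SVC(kernel='linear', decision_function_shape='ovo').fit(X, y)

    print(clf.decision_function(X[:1]).shape)  # (1, 6): one value per pair AB, AC, AD, BC, BD, CD
    print(clf.predict(X[:1]))                  # ['A'] for this toy data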

    What does that have to do with your question?

    • clf.decision_function() gives you the decision value $D$ for each pairwise comparison
    • The class with the most votes wins

    For instance,

    [[ 96.42193513 -11.13296606 111.47424538 -88.5356536 44.29272494 141.0069203 ]]

    is comparing:

    [AB, AC, AD, BC, BD, CD]

    We label each of them by the sign. We get:

    [A, C, A, C, B, C]

    For instance, 96.42193513 is positive and thus A is the label for AB.

    Now we have three votes for C, so C would be your prediction. If you repeat this procedure for the other two examples, you will get the same label that predict() returns. Try it!
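    To make the procedure concrete, here is a minimal sketch (my own code, not part of the original answer) that turns the six pairwise values into votes and recovers the predicted label. The helper name ovo_vote is hypothetical; the pair ordering (AB, AC, AD, BC, BD, CD) assumes the classes are taken in the order of clf.classes_, which is how scikit-learn lays out the one-vs-one columns. On a fitted SVC you would pass clf.decision_function(x)[0] (with decision_function_shape='ovo') as pairwise_scores:

    from itertools import combinations

    def ovo_vote(pairwise_scores, classes):
        """Positive score -> vote for the first class of the pair,
        negative score -> vote for the second; the class with the most votes wins."""
        votes = dict.fromkeys(classes, 0)
        for (first, second), score in zip(combinations(classes, 2), pairwise_scores):
            votes[first if score > 0 else second] += 1
        return max(votes, key=votes.get)  # ties broken arbitrarily in this sketch

    # The second example from the question: the predicted class was 'C'
    scores = [96.42193513, -11.13296606, 111.47424538,
              -88.5356536, 44.29272494, 141.0069203]
    print(ovo_vote(scores, ['A', 'B', 'C', 'D']))  # 'C' -- three votes, as derived above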
