Scikit Learn SVC decision_function and predict

攒了一身酷 2020-11-30 17:15

I'm trying to understand the relationship between decision_function and predict, which are instance methods of SVC (http://scikit-learn.org/stable/modules/generated/sklear

6 answers
  •  失恋的感觉
    2020-11-30 17:57

    For those interested, here is a quick example of the predict function translated from C++ (here) to Python:

    import math
    
    import numpy as np
    
    # I've only implemented the linear and rbf kernels
    def kernel(params, sv, X):
        if params['kernel'] == 'linear':
            return [np.dot(vi, X) for vi in sv]
        elif params['kernel'] == 'rbf':
            return [math.exp(-params['gamma'] * np.dot(vi - X, vi - X)) for vi in sv]
        raise ValueError('unsupported kernel: %r' % params['kernel'])
    
    # This computes libsvm's one-vs-one decision values for a single sample X.
    # With 3+ classes they equal clf.decision_function(X) for a model built with
    # decision_function_shape='ovo'; for 2 classes scikit-learn negates them.
    def decision_function(params, sv, nv, a, b, X):
        # calculate the kernel between X and every support vector
        k = kernel(params, sv, X)
    
        # define the start and end index for support vectors for each class
        start = [sum(nv[:i]) for i in range(len(nv))]
        end = [start[i] + nv[i] for i in range(len(nv))]
    
        # calculate: sum(a_p * k(x_p, x)) between every 2 classes i < j
        c = [ sum(a[ i ][p] * k[p] for p in range(start[j], end[j])) +
              sum(a[j-1][p] * k[p] for p in range(start[i], end[i]))
                    for i in range(len(nv)) for j in range(i+1,len(nv))]
    
        # add the intercept
        return [sum(x) for x in zip(c, b)]
    
    # This replicates clf.predict for a single sample X
    def predict(params, sv, nv, a, b, cs, X):
        ''' params = model parameters (dict returned by clf.get_params())
            sv = support vectors
            nv = # of support vectors per class
            a  = dual coefficients
            b  = intercepts
            cs = list of class names
            X  = feature vector of the one sample to predict
        '''
        decision = decision_function(params, sv, nv, a, b, X)
        # one vote per pair (i, j): i wins if the decision value is positive
        votes = [(i if decision[p] > 0 else j) for p,(i,j) in enumerate((i,j)
                                               for i in range(len(cs))
                                               for j in range(i+1,len(cs)))]
    
        # most votes wins; max() keeps the first maximum, so ties go to the
        # lowest class index, as in libsvm
        return cs[max(range(len(cs)), key=votes.count)]
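
    The index bookkeeping in decision_function is the least obvious part, so here is the same quantity written out in my own notation, following the code above (which mirrors libsvm's svm_predict_values). For a pair of classes i < j, let S_m be the support-vector indices start[m] .. end[m]-1 of class m, A the dual coefficient matrix a, and b_ij the intercept entry for that pair:

```latex
f_{ij}(x) = \sum_{p \in S_i} A_{j-1,\,p}\, K(x_p, x) \;+\; \sum_{p \in S_j} A_{i,\,p}\, K(x_p, x) \;+\; b_{ij}
```

    The row of A changes with the pair because each support vector carries n_classes - 1 coefficients, one per opposing class, stored in class order with its own class skipped: for a class-i vector the opponent j > i sits in row j-1, and for a class-j vector the opponent i < j sits in row i. Each pair then casts one vote, for i if f_ij(x) > 0 and for j otherwise, and predict returns the class with the most votes.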
    

    There are a lot of input arguments to predict and decision_function, but these are all values the model uses internally when you call clf.predict(X). In fact, all of them are available as attributes of the fitted model:

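    from sklearn import svm
    
    # Create model (gamma is set explicitly, so get_params() holds a number)
    clf = svm.SVC(gamma=0.001, C=100.)
    
    # Fit model using features, X, and labels, y.
    clf.fit(X, y)
    
    # Get parameters from model
    params = clf.get_params()
    sv = clf.support_vectors_
    nv = clf.n_support_
    # _dual_coef_ and _intercept_ are the private, unflipped libsvm values;
    # for 2 classes the public dual_coef_ and intercept_ have the opposite sign
    a  = clf._dual_coef_
    b  = clf._intercept_
    cs = clf.classes_
    
    # Use the functions to predict (they handle one sample at a time)
    print([predict(params, sv, nv, a, b, cs, x) for x in X])
    
    # Compare with the builtin predict
    print(clf.predict(X))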
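
    To check the whole pipeline against scikit-learn, here is a vectorized NumPy sketch of the same one-vs-one computation. It is limited to the RBF kernel; ovo_decision and ovo_predict are hypothetical helper names, the iris data and the gamma/C values are arbitrary choices for illustration, and it assumes (as the functions above do) that support_vectors_ is grouped by class in classes_ order. It reads the public dual_coef_ and intercept_, which is only safe for 3+ classes; the last lines cover the two-class case, where scikit-learn flips the signs so that a positive decision_function value means classes_[1].

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.svm import SVC


def ovo_decision(clf, X):
    """One-vs-one decision values for every row of X (RBF kernel only).

    Assumes a fitted SVC with 3+ classes, where dual_coef_ and intercept_
    are not sign-flipped (scikit-learn flips them only for 2 classes).
    """
    sv = clf.support_vectors_
    # RBF kernel matrix: K[n, p] = exp(-gamma * ||X[n] - sv[p]||^2)
    sq_dist = ((X[:, None, :] - sv[None, :, :]) ** 2).sum(axis=2)
    K = np.exp(-clf.gamma * sq_dist)

    # start[m]:start[m + 1] are the columns of class m's support vectors
    start = np.concatenate([[0], np.cumsum(clf.n_support_)])
    a, b = clf.dual_coef_, clf.intercept_
    n_classes = len(clf.classes_)
    cols = []
    for i in range(n_classes):
        for j in range(i + 1, n_classes):
            si = slice(start[i], start[i + 1])
            sj = slice(start[j], start[j + 1])
            cols.append(K[:, si] @ a[j - 1, si] + K[:, sj] @ a[i, sj])
    # one column per pair (0,1), (0,2), ..., plus that pair's intercept
    return np.column_stack(cols) + b


def ovo_predict(clf, X):
    """Majority vote over the one-vs-one decision values."""
    dec = ovo_decision(clf, X)
    n_classes = len(clf.classes_)
    votes = np.zeros((len(X), n_classes), dtype=int)
    rows = np.arange(len(X))
    p = 0
    for i in range(n_classes):
        for j in range(i + 1, n_classes):
            # positive value: vote for i, otherwise vote for j
            winner = np.where(dec[:, p] > 0, i, j)
            votes[rows, winner] += 1
            p += 1
    # argmax returns the first maximum, so ties go to the lowest class index
    return clf.classes_[votes.argmax(axis=1)]


X, y = load_iris(return_X_y=True)
clf = SVC(kernel='rbf', gamma=0.1, C=10.0, decision_function_shape='ovo').fit(X, y)

print(np.allclose(ovo_decision(clf, X), clf.decision_function(X)))
print(np.array_equal(ovo_predict(clf, X), clf.predict(X)))

# Two classes (setosa vs versicolor): positive scores map to classes_[1]
mask = y < 2
clf2 = SVC(kernel='rbf', gamma=0.1, C=10.0).fit(X[mask], y[mask])
dec2 = clf2.decision_function(X[mask])
print(np.array_equal(clf2.classes_[(dec2 > 0).astype(int)], clf2.predict(X[mask])))
```

    Each print compares a hand-rolled result with scikit-learn's own output, so a False points at a layout assumption that does not hold for your version.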
    
