How to write a confusion matrix in Python?

Front-end · open · 14 answers · 1940 views

太阳男子 · 2020-12-04 06:48

I wrote some confusion matrix calculation code in Python:

def conf_mat(prob_arr, input_arr):
        # confusion matrix
        conf_arr = [[0, 0], [0, 0]]

14 Answers
  •  粉色の甜心
    2020-12-04 07:19

    Update

    Since writing this post, I've updated my library implementation to include a few other nice features. As with the code below, no third-party dependencies are required. The class can also output a nice tabulation table, similar to many commonly used statistical packages. See this Gist.

    Example usage of the above Gist

    # Example Usage
    actual      = ["A", "B", "C", "C", "B", "C", "C", "B", "A", "A", "B", "A", "B", "C", "A", "B", "C"]
    predicted   = ["A", "B", "B", "C", "A", "C", "A", "B", "C", "A", "B", "B", "B", "C", "A", "A", "C"]
    
    # Initialize Performance Class
    performance = Performance(actual, predicted)
    
    # Print Confusion Matrix
    performance.tabulate()
    

    Here's an example of the output:

    ===================================
            Aᴬ      Bᴬ      Cᴬ
    
    Aᴾ      3       2       1
    
    Bᴾ      1       4       1
    
    Cᴾ      1       0       4
    
    Note: classᴾ = Predicted, classᴬ = Actual
    ===================================
    

    In addition to raw counts, we can output a normalized confusion matrix (i.e. with proportions):

    # Print Normalized Confusion Matrix
    performance.tabulate(normalized = True)
    
    ===================================
            Aᴬ      Bᴬ      Cᴬ
    
    Aᴾ      17.65%  11.76%  5.88%
    
    Bᴾ      5.88%   23.53%  5.88%
    
    Cᴾ      5.88%   0.00%   23.53%
    
    Note: classᴾ = Predicted, classᴬ = Actual
    ===================================
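
    The Gist itself isn't reproduced here, so purely as an illustration, a small standalone helper along the following lines can render the same style of table from a raw count matrix like the one produced by the confusionmatrix function further down (the tabulate name and layout are my own sketch, not the Gist's API):

    # Illustrative sketch of a tabulation helper (hypothetical name, not the Gist's API)
    def tabulate(matrix, classes, normalized = False):
        total = sum(sum(row) for row in matrix)
        width = 8
        print("=" * 35)
        # Header row: actual classes across the columns
        print("".ljust(width) + "".join((str(c) + "ᴬ").ljust(width) for c in classes))
        # One row per predicted class, showing raw counts or proportions of the total
        for label, row in zip(classes, matrix):
            cells = ["%.2f%%" % (100 * n / total) if normalized else str(n) for n in row]
            print((str(label) + "ᴾ").ljust(width) + "".join(c.ljust(width) for c in cells))
        print("Note: classᴾ = Predicted, classᴬ = Actual")
        print("=" * 35)

    For example, tabulate(confusionmatrix(actual, predicted), sorted(set(actual))) prints the raw counts in the same format, and passing normalized = True prints the proportions instead.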
    

    A Simple Multiclass Implementation

    A multi-class confusion matrix can be computed very simply with vanilla Python in roughly O(N) time. All we need to do is map each unique class found in the actual vector to a row/column index of a 2-dimensional list of counts. From there, we simply iterate through the zipped actual and predicted vectors and increment the corresponding cell.

    # A Simple Confusion Matrix Implementation
    def confusionmatrix(actual, predicted, normalize = False):
        """
        Generate a confusion matrix for multiple classification
        @params:
            actual      - a list of integers or strings for known classes
            predicted   - a list of integers or strings for predicted classes
            normalize   - optional boolean for matrix normalization
        @return:
            matrix      - a 2-dimensional list of pairwise counts
        """
        unique = sorted(set(actual))
        matrix = [[0 for _ in unique] for _ in unique]
        imap   = {key: i for i, key in enumerate(unique)}
        # Generate Confusion Matrix
        for p, a in zip(predicted, actual):
            matrix[imap[p]][imap[a]] += 1
        # Matrix Normalization
        if normalize:
            sigma = sum([sum(matrix[imap[i]]) for i in unique])
            matrix = [row for row in map(lambda i: list(map(lambda j: j / sigma, i)), matrix)]
        return matrix
    

    Usage

    # Input Below Should Return: [[2, 1, 0], [0, 2, 1], [1, 2, 1]]
    cm = confusionmatrix(
        [1, 1, 2, 0, 1, 1, 2, 0, 0, 1], # actual
        [0, 1, 1, 0, 2, 1, 2, 2, 0, 2]  # predicted
    )
    
    # And The Output
    print(cm)
    [[2, 1, 0], [0, 2, 1], [1, 2, 1]]
    

    Note: the actual classes are along the columns and the predicted classes are along the rows.

    # Actual
    # 0  1  2
    [[2, 1, 0], # 0
     [0, 2, 1], # 1  Predicted
     [1, 2, 1]] # 2
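
    If you prefer the opposite orientation (actual classes along the rows, as used by scikit-learn's confusion_matrix, for example), the matrix can simply be transposed. A small illustrative snippet, not part of the function above:

    # Transpose so rows become actual classes and columns become predicted classes
    transposed = [list(row) for row in zip(*cm)]
    
    print(transposed)
    [[2, 0, 1], [1, 2, 2], [0, 1, 1]]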
    

    Class Names Can be Strings or Integers

    # Input Below Should Return: [[2, 1, 0], [0, 2, 1], [1, 2, 1]]
    cm = confusionmatrix(
        ["B", "B", "C", "A", "B", "B", "C", "A", "A", "B"], # actual
        ["A", "B", "B", "A", "C", "B", "C", "C", "A", "C"]  # predicted
    )
    
    # And The Output
    print(cm)
    [[2, 1, 0], [0, 2, 1], [1, 2, 1]]
    

    You Can Also Return The Matrix With Proportions (Normalization)

    # Input Below Should Return: [[0.2, 0.1, 0.0], [0.0, 0.2, 0.1], [0.1, 0.2, 0.1]]
    cm = confusionmatrix(
        ["B", "B", "C", "A", "B", "B", "C", "A", "A", "B"], # actual
        ["A", "B", "B", "A", "C", "B", "C", "C", "A", "C"], # predicted
        normalize = True
    )
    
    # And The Output
    print(cm)
    [[0.2, 0.1, 0.0], [0.0, 0.2, 0.1], [0.1, 0.2, 0.1]]
    

    Extracting Statistics From a Multiple Classification Confusion Matrix

    Once you have the matrix, you can compute a bunch of statistics to assess your classifier. That said, extracting the values out of a confusion matrix setup for multiple classification can be a bit of a headache. Here's a function that returns both the confusion matrix and statistics by class:

    # Not Required, But Nice For Legibility
    from collections import OrderedDict
    
    # A Simple Confusion Matrix Implementation
    def confusionmatrix(actual, predicted, normalize = False):
        """
        Generate a confusion matrix for multiple classification
        @params:
            actual      - a list of integers or strings for known classes
            predicted   - a list of integers or strings for predicted classes
            normalize   - optional boolean for matrix normalization
        @return:
            matrix      - a 2-dimensional list of pairwise counts
            statistics  - a dictionary of statistics for each class
        """
        unique = sorted(set(actual))
        matrix = [[0 for _ in unique] for _ in unique]
        imap   = {key: i for i, key in enumerate(unique)}
        # Generate Confusion Matrix
        for p, a in zip(predicted, actual):
            matrix[imap[p]][imap[a]] += 1
        # Get Confusion Matrix Sum
        sigma = sum([sum(matrix[imap[i]]) for i in unique])
        # Scaffold Statistics Data Structure
        statistics = OrderedDict(((i, {"counts" : OrderedDict(), "stats" : OrderedDict()}) for i in unique))
        # Iterate Through Classes & Compute Statistics
        for i in unique:
            loc = matrix[imap[i]][imap[i]]
            row = sum(matrix[imap[i]][:])
            col = sum([row[imap[i]] for row in matrix])
            # Get TP/TN/FP/FN
            tp  = loc
            fp  = row - loc
            fn  = col - loc
            tn  = sigma - row - col + loc
            # Populate Counts Dictionary
            statistics[i]["counts"]["tp"]   = tp
            statistics[i]["counts"]["fp"]   = fp
            statistics[i]["counts"]["tn"]   = tn
            statistics[i]["counts"]["fn"]   = fn
            statistics[i]["counts"]["pos"]  = tp + fn
            statistics[i]["counts"]["neg"]  = tn + fp
            statistics[i]["counts"]["n"]    = tp + tn + fp + fn
            # Populate Statistics Dictionary
            statistics[i]["stats"]["sensitivity"]   = tp / (tp + fn) if tp > 0 else 0.0
            statistics[i]["stats"]["specificity"]   = tn / (tn + fp) if tn > 0 else 0.0
            statistics[i]["stats"]["precision"]     = tp / (tp + fp) if tp > 0 else 0.0
            statistics[i]["stats"]["recall"]        = tp / (tp + fn) if tp > 0 else 0.0
            statistics[i]["stats"]["tpr"]           = tp / (tp + fn) if tp > 0 else 0.0
            statistics[i]["stats"]["tnr"]           = tn / (tn + fp) if tn > 0 else 0.0
            statistics[i]["stats"]["fpr"]           = fp / (fp + tn) if fp > 0 else 0.0
            statistics[i]["stats"]["fnr"]           = fn / (fn + tp) if fn > 0 else 0.0
            statistics[i]["stats"]["accuracy"]      = (tp + tn) / (tp + tn + fp + fn) if (tp + tn) > 0 else 0.0
            statistics[i]["stats"]["f1score"]       = (2 * tp) / ((2 * tp) + (fp + fn)) if tp > 0 else 0.0
            statistics[i]["stats"]["fdr"]           = fp / (fp + tp) if fp > 0 else 0.0
            statistics[i]["stats"]["for"]           = fn / (fn + tn) if fn > 0 else 0.0
            statistics[i]["stats"]["ppv"]           = tp / (tp + fp) if tp > 0 else 0.0
            statistics[i]["stats"]["npv"]           = tn / (tn + fn) if tn > 0 else 0.0
        # Matrix Normalization
        if normalize:
            matrix = [row for row in map(lambda i: list(map(lambda j: j / sigma, i)), matrix)]
        return matrix, statistics
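
    For example, calling it on the string-labeled vectors from earlier returns both pieces (a brief usage sketch):

    # Usage sketch with the string-labeled vectors from earlier
    actual      = ["B", "B", "C", "A", "B", "B", "C", "A", "A", "B"]
    predicted   = ["A", "B", "B", "A", "C", "B", "C", "C", "A", "C"]
    matrix, statistics = confusionmatrix(actual, predicted)
    
    # Individual figures can be pulled out of the nested dictionaries
    print(statistics["A"]["stats"]["precision"])
    0.6666666666666666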
    

    Computed Statistics

    Above, the confusion matrix is used to tabulate statistics for each class, which are returned in an OrderedDict with the following structure:

    OrderedDict(
        [
            ('A', {
                'stats' : OrderedDict([
                    ('sensitivity', 0.6666666666666666), 
                    ('specificity', 0.8571428571428571), 
                    ('precision', 0.6666666666666666), 
                    ('recall', 0.6666666666666666), 
                    ('tpr', 0.6666666666666666), 
                    ('tnr', 0.8571428571428571), 
                    ('fpr', 0.14285714285714285), 
                    ('fnr', 0.3333333333333333), 
                    ('accuracy', 0.8), 
                    ('f1score', 0.6666666666666666), 
                    ('fdr', 0.3333333333333333), 
                    ('for', 0.14285714285714285), 
                    ('ppv', 0.6666666666666666), 
                    ('npv', 0.8571428571428571)
                ]), 
                'counts': OrderedDict([
                    ('tp', 2), 
                    ('fp', 1), 
                    ('tn', 6), 
                    ('fn', 1), 
                    ('pos', 3), 
                    ('neg', 7), 
                    ('n', 10)
                ])
            }), 
            ('B', {
                'stats': OrderedDict([
                    ('sensitivity', 0.4), 
                    ('specificity', 0.8), 
                    ('precision', 0.6666666666666666), 
                    ('recall', 0.4), 
                    ('tpr', 0.4), 
                    ('tnr', 0.8), 
                    ('fpr', 0.2), 
                    ('fnr', 0.6), 
                    ('accuracy', 0.6), 
                    ('f1score', 0.5), 
                    ('fdr', 0.3333333333333333), 
                    ('for', 0.42857142857142855), 
                    ('ppv', 0.6666666666666666), 
                    ('npv', 0.5714285714285714)
                ]), 
                'counts': OrderedDict([
                    ('tp', 2), 
                    ('fp', 1), 
                    ('tn', 4), 
                    ('fn', 3), 
                    ('pos', 5), 
                    ('neg', 5), 
                    ('n', 10)
                ])
            }), 
            ('C', {
                'stats': OrderedDict([
                    ('sensitivity', 0.5), 
                    ('specificity', 0.625), 
                    ('precision', 0.25), 
                    ('recall', 0.5), 
                    ('tpr', 0.5), 
                    ('tnr', 0.625), 
                    ('fpr', 0.375), 
                    ('fnr', 0.5), 
                    ('accuracy', 0.6), 
                    ('f1score', 0.3333333333333333), 
                    ('fdr', 0.75), 
                    ('for', 0.16666666666666666), 
                    ('ppv', 0.25), 
                    ('npv', 0.8333333333333334)
                ]), 
                'counts': OrderedDict([
                    ('tp', 1), 
                    ('fp', 3), 
                    ('tn', 5), 
                    ('fn', 1), 
                    ('pos', 2), 
                    ('neg', 8), 
                    ('n', 10)
                ])
            })
        ]
    )
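
    Purely as an illustration, the per-class entries can also be rolled up into overall scores, e.g. overall accuracy from the matrix diagonal and a macro-averaged F1 from the per-class F1 scores (the variable names below are my own, reusing the matrix and statistics returned by the call above):

    # Overall accuracy: correct predictions (matrix diagonal) over the total count
    accuracy = sum(matrix[i][i] for i in range(len(matrix))) / sum(map(sum, matrix))
    
    # Macro-averaged F1: unweighted mean of the per-class F1 scores
    macrof1  = sum(s["stats"]["f1score"] for s in statistics.values()) / len(statistics)
    
    print(accuracy)   # 0.5
    print(macrof1)    # approximately 0.5, up to floating-point rounding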
    
