I wrote code to calculate a confusion matrix in Python:
def conf_mat(prob_arr, input_arr):
    # confusion matrix
    conf_arr = [[0, 0], [0, 0]]
Update
Since writing this post, I've updated my library implementation to include a few other useful features. As with the code below, it requires no third-party dependencies. The class can also print a formatted table, similar to the output of many commonly used statistical packages. See this Gist.
Example usage of the Gist above:
# Example Usage
actual = ["A", "B", "C", "C", "B", "C", "C", "B", "A", "A", "B", "A", "B", "C", "A", "B", "C"]
predicted = ["A", "B", "B", "C", "A", "C", "A", "B", "C", "A", "B", "B", "B", "C", "A", "A", "C"]
# Initialize Performance Class
performance = Performance(actual, predicted)
# Print Confusion Matrix
performance.tabulate()
Here's an example of the output:
===================================
     Aᴬ   Bᴬ   Cᴬ
Aᴾ    3    2    1
Bᴾ    1    4    1
Cᴾ    1    0    4
Note: classᴾ = Predicted, classᴬ = Actual
===================================
In addition to raw counts, we can output a normalized confusion matrix, where each cell is expressed as a proportion of all observations:
# Print Normalized Confusion Matrix
performance.tabulate(normalized = True)
===================================
         Aᴬ       Bᴬ       Cᴬ
Aᴾ   17.65%   11.76%    5.88%
Bᴾ    5.88%   23.53%    5.88%
Cᴾ    5.88%    0.00%   23.53%
Note: classᴾ = Predicted, classᴬ = Actual
===================================
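As a rough illustration of how a table like this can be produced without third-party dependencies, here is a minimal sketch. tabulate is a hypothetical function, not the Gist's Performance.tabulate; it assumes rows are predicted classes, columns are actual classes, and that normalized values are proportions of all observations, matching the output above.

```python
def tabulate(actual, predicted, normalized=False):
    """Hypothetical helper (not the Gist's code): render a confusion matrix as text.

    Rows are predicted classes (suffix ᴾ), columns are actual classes (suffix ᴬ).
    """
    classes = sorted(set(actual) | set(predicted))
    index = {c: i for i, c in enumerate(classes)}
    counts = [[0] * len(classes) for _ in classes]
    for a, p in zip(actual, predicted):
        counts[index[p]][index[a]] += 1
    total = sum(map(sum, counts))

    def fmt(n):
        # Proportions are taken over all observations, as in the output above.
        return f"{100 * n / total:.2f}%" if normalized else str(n)

    rows = [[""] + [f"{c}ᴬ" for c in classes]]
    for c, row in zip(classes, counts):
        rows.append([f"{c}ᴾ"] + [fmt(n) for n in row])

    # Right-align every cell to the widest one so the columns line up.
    width = max(len(cell) for r in rows for cell in r)
    rule = "=" * 35
    lines = [rule]
    lines += ["  ".join(cell.rjust(width) for cell in r) for r in rows]
    lines += ["Note: classᴾ = Predicted, classᴬ = Actual", rule]
    return "\n".join(lines)
```

With the actual and predicted lists from the example usage, print(tabulate(actual, predicted, normalized=True)) should reproduce the percentages shown above, though the exact spacing may differ.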
A Simple Multiclass Implementation
A multiclass confusion matrix can be computed simply with vanilla Python in a single pass over the data, roughly O(N) time for N observations (plus O(K²) to allocate the matrix for K classes). All we need to do is build a square 2-dimensional list with one row and one column per unique class label. From there, we iterate through the zipped actual and predicted vectors and increment the matching cell for each pair.
# A Simple Confusion Matrix Implementation
def confusionmatrix(actual, predicted, normalize = False):
    """
    Generate a confusion matrix for multiple classification
    @params:
        actual    - a list of integers or strings for known classes
        predicted - a list of integers or strings for predicted classes
        normalize - optional boolean for matrix normalization
    @return:
        matrix    - a 2-dimensional list of pairwise counts
    """
    # Include labels that only appear in predicted to avoid a KeyError
    unique = sorted(set(actual) | set(predicted))
    matrix = [[0 for _ in unique] for _ in unique]
    imap = {key: i for i, key in enumerate(unique)}
    # Generate Confusion Matrix (rows: predicted, columns: actual)
    for p, a in zip(predicted, actual):
        matrix[imap[p]][imap[a]] += 1
    # Matrix Normalization (divide every cell by the total count)
    if normalize:
        sigma = sum(map(sum, matrix))
        matrix = [[count / sigma for count in row] for row in matrix]
    return matrix
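Because the work is just counting (predicted, actual) pairs, collections.Counter can do the tallying. This is an alternative sketch rather than the original implementation; confusion_counts is a hypothetical name, and it returns raw counts only, in the same orientation (rows = predicted, columns = actual).

```python
from collections import Counter


def confusion_counts(actual, predicted):
    """Hypothetical alternative: tally (predicted, actual) pairs with Counter."""
    # Take labels from both vectors so every observed label gets a row and a column.
    classes = sorted(set(actual) | set(predicted))
    pairs = Counter(zip(predicted, actual))  # one pass over the data
    # Counter returns 0 for pairs that never occurred.
    return [[pairs[(p, a)] for a in classes] for p in classes]
```

Building the class list from both vectors means a label that only ever appears in predicted gets its own row and column instead of raising a KeyError.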
Usage
# Input Below Should Return: [[2, 1, 0], [0, 2, 1], [1, 2, 1]]
cm = confusionmatrix(
[1, 1, 2, 0, 1, 1, 2, 0, 0, 1], # actual
[0, 1, 1, 0, 2, 1, 2, 2, 0, 2] # predicted
)
# And The Output
print(cm)
[[2, 1, 0], [0, 2, 1], [1, 2, 1]]
Note: the actual classes are along the columns and the predicted classes are along the rows.
#    Actual
# 0  1  2
# |  |  |
[[2, 1, 0],  # 0
 [0, 2, 1],  # 1  Predicted
 [1, 2, 1]]  # 2
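Other tools may use the opposite orientation; scikit-learn's sklearn.metrics.confusion_matrix, for example, puts true (actual) labels on the rows and predicted labels on the columns. A minimal sketch for converting between the two layouts, where transpose is a hypothetical helper name:

```python
def transpose(matrix):
    """Swap rows and columns, e.g. to get rows = actual, columns = predicted."""
    return [list(column) for column in zip(*matrix)]


cm = [[2, 1, 0], [0, 2, 1], [1, 2, 1]]  # rows: predicted, columns: actual
print(transpose(cm))  # [[2, 0, 1], [1, 2, 2], [0, 1, 1]]
```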
Class Names Can Be Strings or Integers
# Input Below Should Return: [[2, 1, 0], [0, 2, 1], [1, 2, 1]]
cm = confusionmatrix(
["B", "B", "C", "A", "B", "B", "C", "A", "A", "B"], # actual
["A", "B", "B", "A", "C", "B", "C", "C", "A", "C"] # predicted
)
# And The Output
print(cm)
[[2, 1, 0], [0, 2, 1], [1, 2, 1]]
You Can Also Return The Matrix With Proportions (Normalization)
# Input Below Should Return: [[0.2, 0.1, 0.0], [0.0, 0.2, 0.1], [0.1, 0.2, 0.1]]
cm = confusionmatrix(
["B", "B", "C", "A", "B", "B", "C", "A", "A", "B"], # actual
["A", "B", "B", "A", "C", "B", "C", "C", "A", "C"], # predicted
normalize = True
)
# And The Output
print(cm)
[[0.2, 0.1, 0.0], [0.0, 0.2, 0.1], [0.1, 0.2, 0.1]]
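The normalization above divides every cell by the total number of observations. Another common choice, not part of the code above, is to normalize each column by its actual-class total; with rows = predicted and columns = actual, the diagonal then holds each class's recall. A minimal sketch, where normalize_by_actual is a hypothetical name:

```python
def normalize_by_actual(matrix):
    """Divide each column (one actual class) by its total so columns sum to 1.

    Assumes rows = predicted and columns = actual; the diagonal then holds
    the per-class recall. Empty columns are left as zeros.
    """
    totals = [sum(column) for column in zip(*matrix)]
    return [
        [count / total if total else 0.0 for count, total in zip(row, totals)]
        for row in matrix
    ]
```

For the count matrix [[2, 1, 0], [0, 2, 1], [1, 2, 1]] the diagonal becomes roughly 0.667, 0.4 and 0.5, the same values as the per-class recall figures in the statistics further below.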
Extracting Statistics From a Multiple Classification Confusion Matrix
Once you have the matrix, you can compute a range of statistics to assess your classifier. That said, extracting the values from a confusion matrix set up for multiclass classification can be a bit of a headache. Here's a function that returns both the confusion matrix and per-class statistics:
# Not Required, But Nice For Legibility
from collections import OrderedDict

# A Simple Confusion Matrix Implementation
def confusionmatrix(actual, predicted, normalize = False):
    """
    Generate a confusion matrix for multiple classification
    @params:
        actual     - a list of integers or strings for known classes
        predicted  - a list of integers or strings for predicted classes
        normalize  - optional boolean for matrix normalization
    @return:
        matrix     - a 2-dimensional list of pairwise counts
        statistics - a dictionary of statistics for each class
    """
    # Include labels that only appear in predicted to avoid a KeyError
    unique = sorted(set(actual) | set(predicted))
    matrix = [[0 for _ in unique] for _ in unique]
    imap = {key: i for i, key in enumerate(unique)}
    # Generate Confusion Matrix (rows: predicted, columns: actual)
    for p, a in zip(predicted, actual):
        matrix[imap[p]][imap[a]] += 1
    # Get Confusion Matrix Sum
    sigma = sum(map(sum, matrix))
    # Scaffold Statistics Data Structure
    statistics = OrderedDict((i, {"counts": OrderedDict(), "stats": OrderedDict()}) for i in unique)
    # Iterate Through Classes & Compute Statistics
    for i in unique:
        loc = matrix[imap[i]][imap[i]]          # diagonal cell
        row = sum(matrix[imap[i]])              # everything predicted as i
        col = sum(r[imap[i]] for r in matrix)   # everything actually i
        # Get TP/TN/FP/FN
        tp = loc
        fp = row - loc
        fn = col - loc
        tn = sigma - row - col + loc
        # Populate Counts Dictionary
        statistics[i]["counts"]["tp"] = tp
        statistics[i]["counts"]["fp"] = fp
        statistics[i]["counts"]["tn"] = tn
        statistics[i]["counts"]["fn"] = fn
        statistics[i]["counts"]["pos"] = tp + fn
        statistics[i]["counts"]["neg"] = tn + fp
        statistics[i]["counts"]["n"] = tp + tn + fp + fn
        # Populate Statistics Dictionary (guards avoid division by zero)
        statistics[i]["stats"]["sensitivity"] = tp / (tp + fn) if tp > 0 else 0.0
        statistics[i]["stats"]["specificity"] = tn / (tn + fp) if tn > 0 else 0.0
        statistics[i]["stats"]["precision"] = tp / (tp + fp) if tp > 0 else 0.0
        statistics[i]["stats"]["recall"] = tp / (tp + fn) if tp > 0 else 0.0
        statistics[i]["stats"]["tpr"] = tp / (tp + fn) if tp > 0 else 0.0
        statistics[i]["stats"]["tnr"] = tn / (tn + fp) if tn > 0 else 0.0
        statistics[i]["stats"]["fpr"] = fp / (fp + tn) if fp > 0 else 0.0
        statistics[i]["stats"]["fnr"] = fn / (fn + tp) if fn > 0 else 0.0
        statistics[i]["stats"]["accuracy"] = (tp + tn) / (tp + tn + fp + fn) if (tp + tn) > 0 else 0.0
        statistics[i]["stats"]["f1score"] = (2 * tp) / ((2 * tp) + (fp + fn)) if tp > 0 else 0.0
        statistics[i]["stats"]["fdr"] = fp / (fp + tp) if fp > 0 else 0.0
        statistics[i]["stats"]["for"] = fn / (fn + tn) if fn > 0 else 0.0
        statistics[i]["stats"]["ppv"] = tp / (tp + fp) if tp > 0 else 0.0
        statistics[i]["stats"]["npv"] = tn / (tn + fn) if tn > 0 else 0.0
    # Matrix Normalization
    if normalize:
        matrix = [[count / sigma for count in r] for r in matrix]
    return matrix, statistics
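The four per-class counts in the loop follow directly from the matrix layout. Writing $M$ for the count matrix (rows = predicted, columns = actual), $N$ for the sum of all its cells, and $k$ for a class index:

$$
\begin{aligned}
TP_k &= M_{kk} \\
FP_k &= \sum_{j} M_{kj} - M_{kk} \qquad \text{(row } k\text{: predicted } k\text{, actually another class)} \\
FN_k &= \sum_{i} M_{ik} - M_{kk} \qquad \text{(column } k\text{: actually } k\text{, predicted as another class)} \\
TN_k &= N - \sum_{j} M_{kj} - \sum_{i} M_{ik} + M_{kk}
\end{aligned}
$$

For class A in the string example ($M_{AA} = 2$, row sum 3, column sum 3, $N = 10$) this gives $TP = 2$, $FP = 1$, $FN = 1$ and $TN = 10 - 3 - 3 + 2 = 6$, matching the counts listed for class A.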
Computed Statistics
Above, the confusion matrix is used to tabulate statistics for each class. For the 10-observation string example used earlier, they are returned in an OrderedDict with the following structure (formatted here for readability):
OrderedDict([
    ('A', {
        'stats': OrderedDict([
            ('sensitivity', 0.6666666666666666),
            ('specificity', 0.8571428571428571),
            ('precision', 0.6666666666666666),
            ('recall', 0.6666666666666666),
            ('tpr', 0.6666666666666666),
            ('tnr', 0.8571428571428571),
            ('fpr', 0.14285714285714285),
            ('fnr', 0.3333333333333333),
            ('accuracy', 0.8),
            ('f1score', 0.6666666666666666),
            ('fdr', 0.3333333333333333),
            ('for', 0.14285714285714285),
            ('ppv', 0.6666666666666666),
            ('npv', 0.8571428571428571)
        ]),
        'counts': OrderedDict([
            ('tp', 2),
            ('fp', 1),
            ('tn', 6),
            ('fn', 1),
            ('pos', 3),
            ('neg', 7),
            ('n', 10)
        ])
    }),
    ('B', {
        'stats': OrderedDict([
            ('sensitivity', 0.4),
            ('specificity', 0.8),
            ('precision', 0.6666666666666666),
            ('recall', 0.4),
            ('tpr', 0.4),
            ('tnr', 0.8),
            ('fpr', 0.2),
            ('fnr', 0.6),
            ('accuracy', 0.6),
            ('f1score', 0.5),
            ('fdr', 0.3333333333333333),
            ('for', 0.42857142857142855),
            ('ppv', 0.6666666666666666),
            ('npv', 0.5714285714285714)
        ]),
        'counts': OrderedDict([
            ('tp', 2),
            ('fp', 1),
            ('tn', 4),
            ('fn', 3),
            ('pos', 5),
            ('neg', 5),
            ('n', 10)
        ])
    }),
    ('C', {
        'stats': OrderedDict([
            ('sensitivity', 0.5),
            ('specificity', 0.625),
            ('precision', 0.25),
            ('recall', 0.5),
            ('tpr', 0.5),
            ('tnr', 0.625),
            ('fpr', 0.375),
            ('fnr', 0.5),
            ('accuracy', 0.6),
            ('f1score', 0.3333333333333333),
            ('fdr', 0.75),
            ('for', 0.16666666666666666),
            ('ppv', 0.25),
            ('npv', 0.8333333333333334)
        ]),
        'counts': OrderedDict([
            ('tp', 1),
            ('fp', 3),
            ('tn', 5),
            ('fn', 1),
            ('pos', 2),
            ('neg', 8),
            ('n', 10)
        ])
    })
])
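The per-class accuracy values above (0.8, 0.6 and 0.6) are one-vs-rest figures, not the overall accuracy, which is the fraction of all observations on the diagonal. A minimal sketch for computing overall accuracy and a macro-averaged F1 score directly from the count matrix, assuming the same orientation (rows = predicted, columns = actual) and at least one observation; summarize is a hypothetical helper, not part of the code above:

```python
def summarize(matrix):
    """Hypothetical helper: overall accuracy and macro-averaged F1 from counts.

    Assumes rows = predicted, columns = actual, and a non-empty matrix.
    """
    k = len(matrix)
    total = sum(map(sum, matrix))
    correct = sum(matrix[i][i] for i in range(k))  # diagonal = correct predictions
    f1_scores = []
    for i in range(k):
        tp = matrix[i][i]
        fp = sum(matrix[i]) - tp                  # row i minus the diagonal
        fn = sum(row[i] for row in matrix) - tp   # column i minus the diagonal
        f1_scores.append(2 * tp / (2 * tp + fp + fn) if tp else 0.0)
    # Macro averaging weights every class equally, regardless of its size.
    return {"accuracy": correct / total, "macro_f1": sum(f1_scores) / k}
```

For the count matrix [[2, 1, 0], [0, 2, 1], [1, 2, 1]] this gives an accuracy of 0.5 (5 of 10 observations on the diagonal) and a macro F1 of about 0.5, the mean of the per-class f1score values listed above.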