Sensitivity specificity plot python

允我心安 posted on 2020-01-16 03:39:50

Question


I'm trying to reproduce a sensitivity/specificity plot similar to this one, where the X axis is the threshold:

But I have not found how to do it. Some sklearn metrics, such as the ROC curve, return the true positive and false positive rates, but I have not found any option there to make this plot.

I tried comparing each predicted probability with the actual label and keeping a count, but the plot I get looks like this:

So the X axis has to be normalized somehow so that the curves can actually go up and down.


Answer 1:


Building upon @ApproachingDarknessFish's answer, you can fit a variety of distributions to the resulting histograms, and not all of them extend outside [0,1]. For example, the beta distribution does a decent job of capturing most unimodal distributions on [0,1], at least for visualization's sake:

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats

test_y = np.array([0]*100 + [1]*100)
predicted_y_probs = np.concatenate((np.random.beta(2,5,100), np.random.beta(8,3,100)))

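def estimate_beta(X):
    # Method-of-moments estimates of the beta shape parameters
    # from the sample mean and the (unbiased) sample variance.
    xbar = np.mean(X)
    vbar = np.var(X,ddof=1)
    alphahat = xbar*(xbar*(1-xbar)/vbar - 1)
    betahat = (1-xbar)*(xbar*(1-xbar)/vbar - 1)
    return alphahat, betahat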

positive_beta_estimates = estimate_beta(predicted_y_probs[test_y == 1])
negative_beta_estimates = estimate_beta(predicted_y_probs[test_y == 0])

unit_interval = np.linspace(0,1,100)
plt.plot(unit_interval, scipy.stats.beta.pdf(unit_interval, *positive_beta_estimates), c='r', label="positive")
plt.plot(unit_interval, scipy.stats.beta.pdf(unit_interval, *negative_beta_estimates), c='g', label="negative")

# Show the threshold.
plt.axvline(0.5, c='black', ls='dashed')
plt.xlim(0,1)

# Add labels
plt.legend()
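
Once the two beta distributions are fitted, the sensitivity and specificity curves the question asks about can be read off them directly: sensitivity at a threshold is the share of the positive distribution at or above it, and specificity is the share of the negative distribution below it. The following is a minimal sketch under that assumption, using the same kind of synthetic scores as above with a fixed seed; the names beta_moments, pos_scores and neg_scores are illustrative and not part of the original answer.

```python
import numpy as np
from scipy import stats

def beta_moments(x):
    """Method-of-moments estimates of the beta shape parameters."""
    mean = np.mean(x)
    var = np.var(x, ddof=1)
    common = mean * (1 - mean) / var - 1
    return mean * common, (1 - mean) * common

# Synthetic scores for illustration only (assumption, not real model output).
rng = np.random.default_rng(0)
neg_scores = rng.beta(2, 5, 100)
pos_scores = rng.beta(8, 3, 100)

pos_a, pos_b = beta_moments(pos_scores)
neg_a, neg_b = beta_moments(neg_scores)

thresholds = np.linspace(0, 1, 101)
# Sensitivity: mass of the fitted positive distribution at or above the threshold.
sensitivity = stats.beta.sf(thresholds, pos_a, pos_b)
# Specificity: mass of the fitted negative distribution below the threshold.
specificity = stats.beta.cdf(thresholds, neg_a, neg_b)
```

Passing thresholds with sensitivity and then with specificity to plt.plot gives two smooth curves that cross somewhere between the two modes. These are model-based estimates, so they are only as good as the beta fit.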




Answer 2:


I don't think that plot is showing what you think it's showing. As the threshold drops to zero, the sensitivity will approach 1, since 100% of the observations will be categorized as positive and the false negative rate will drop to zero. Likewise, the specificity will approach 1 as the threshold approaches 1, since every observation will be categorized as negative and the false positive rate will be zero. So this plot is not showing sensitivity or specificity.

To plot sensitivity and specificity as functions of the threshold, with the threshold on the x-axis, we can compute both with scikit-learn's built-in metric functions over a grid of thresholds and plot the values ourselves. Given a vector of binary labels test_y, a matrix of associated predictors test_x, and a fitted RandomForestClassifier object rfc:
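
The limiting behaviour described above can be checked by hand on a tiny example. This is only a sketch: the labels and scores are made up, and sensitivity_specificity is a hypothetical helper written for this illustration.

```python
import numpy as np

def sensitivity_specificity(y_true, scores, threshold):
    """Return (sensitivity, specificity) when scores >= threshold count as positive."""
    y_true = np.asarray(y_true).astype(bool)
    predicted_positive = np.asarray(scores) >= threshold
    tp = np.sum(predicted_positive & y_true)
    fn = np.sum(~predicted_positive & y_true)
    tn = np.sum(~predicted_positive & ~y_true)
    fp = np.sum(predicted_positive & ~y_true)
    return tp / (tp + fn), tn / (tn + fp)

# Hypothetical labels and scores, made up for illustration.
y_true = [0, 0, 0, 1, 1, 1]
scores = [0.1, 0.4, 0.6, 0.3, 0.7, 0.9]

low = sensitivity_specificity(y_true, scores, 0.0)   # everything called positive
mid = sensitivity_specificity(y_true, scores, 0.5)
high = sensitivity_specificity(y_true, scores, 1.0)  # everything called negative
```

At threshold 0.0 the sensitivity is 1 and the specificity is 0; at threshold 1.0 it is the other way round, so the two curves must cross somewhere in between.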

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import recall_score

# Get the estimated probability of each observation being categorized as positive.
# Columns follow rfc.classes_, so with labels 0/1 column 1 is the positive class
# ([:, 0] would give the probabilities of the negative class).
predicted_y_probs = rfc.predict_proba(test_x)[:, 1]

thresholds = np.linspace(0, 1, 20)  # or however many points you want

# Sensitivity is recall on the positive class; specificity is recall on the negative class.
sensitivities = [recall_score(test_y, (predicted_y_probs >= t).astype(int)) for t in thresholds]
specificities = [recall_score(test_y, (predicted_y_probs >= t).astype(int), pos_label=0) for t in thresholds]
plt.plot(thresholds, sensitivities, label='sensitivity')
plt.plot(thresholds, specificities, label='specificity')
plt.legend()
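
An alternative that stays closer to the ROC metrics mentioned in the question is to take the values straight from sklearn.metrics.roc_curve, which returns the false positive rate, the true positive rate and the thresholds at which they were evaluated. Sensitivity is the true positive rate and specificity is one minus the false positive rate. A minimal sketch with made-up labels and scores (in practice these would be test_y and the predicted probabilities):

```python
import numpy as np
from sklearn.metrics import roc_curve

# Hypothetical labels and scores, made up for illustration.
y_true = np.array([0, 0, 0, 1, 1, 1])
scores = np.array([0.1, 0.4, 0.6, 0.3, 0.7, 0.9])

fpr, tpr, roc_thresholds = roc_curve(y_true, scores)
sensitivity = tpr
specificity = 1 - fpr

# roc_curve prepends an artificial threshold above every score (the
# "nothing is predicted positive" point). Its value depends on the
# scikit-learn version, so drop it before using thresholds as the x-axis.
keep = roc_thresholds <= scores.max()
plot_thresholds = roc_thresholds[keep]
plot_sensitivity = sensitivity[keep]
plot_specificity = specificity[keep]
```

Note that roc_curve only evaluates thresholds at observed scores, in decreasing order, so the x values are not evenly spaced the way a np.linspace grid is.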

However, this will not recreate the plot you've provided as a reference, which seems to be showing the distribution of estimated probabilities of each observation being categorized as positive. In other words, the threshold in that plot is a constant, and the x-axis shows us where each prediction falls relative to that (stationary) threshold. It does not directly tell us either sensitivity or specificity. If you really want a plot that looks like that, keep reading.

I can't think of a way to reconstruct those smooth curves, since a density plot will extend below 0 and above 1, but we can show the same information using histograms. Using the same variables as before:

# Specify range to ensure both groups show up the same width.
bins = np.linspace(0,1,10)

# Show distributions of estimated probabilities for the two classes.
plt.hist(predicted_y_probs[test_y == 1], alpha=0.5, color='red', label='positive', bins=bins)
plt.hist(predicted_y_probs[test_y == 0], alpha=0.5, color='green', label='negative', bins=bins)

# Show the threshold.
plt.axvline(0.5, c='black', ls='dashed')

# Add labels
plt.legend()
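
Although the histogram does not plot sensitivity or specificity, the bar areas on either side of the dashed line determine them for that one threshold. The sketch below makes this concrete with np.histogram. It assumes synthetic scores rather than a fitted model, and it uses 11 bin edges so that there are 10 equal-width bins and the threshold 0.5 falls exactly on an edge (np.linspace(0,1,10) would give 9 bins, with 0.5 inside one of them).

```python
import numpy as np

# Synthetic labels and scores for illustration only (assumption).
rng = np.random.default_rng(0)
test_y = np.array([0] * 100 + [1] * 100)
predicted_y_probs = np.concatenate((rng.beta(2, 5, 100), rng.beta(8, 3, 100)))

# 11 edges give 10 equal-width bins on [0, 1].
bins = np.linspace(0, 1, 11)
pos_counts, _ = np.histogram(predicted_y_probs[test_y == 1], bins=bins)
neg_counts, _ = np.histogram(predicted_y_probs[test_y == 0], bins=bins)

threshold = 0.5
above = bins[:-1] >= threshold  # bins whose left edge is at or above the threshold

# Share of positives to the right of the line, and of negatives to the left.
sensitivity = pos_counts[above].sum() / pos_counts.sum()
specificity = neg_counts[~above].sum() / neg_counts.sum()
```

This only works cleanly when the threshold coincides with a bin edge; otherwise the bin that straddles the threshold mixes observations from both sides.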

I ran this code on the classic Iris dataset using only two of the three species, and got the following output. Versicolor is "positive", virginica is "negative", and setosa was ignored to produce a binary classification. Note that my model had perfect recall, so the probabilities for versicolor are all very close to 1.0. It's rather blocky due to only having 100 samples, most of which were correctly categorized, but hopefully it gets the idea across.



Source: https://stackoverflow.com/questions/57280577/sensitivity-specificity-plot-python
