A calibration curve (sometimes called a “reliability diagram”) tells you whether your model’s predicted probabilities accurately reflect the real chance of your model being right.
It’s very common for neural networks to be overconfident in their predictions, and this type of diagram helps us detect when that happens. Here’s an example:
On the x-axis we have our model’s predicted confidence. On the y-axis we plot the model’s accuracy given its predicted confidence. We can see from this particular diagram that the model is “overconfident” when it makes a prediction in the range between 0.5 and 0.7.
Calibration curves for multiclass classifiers
Scikit-learn provides a function to compute calibration curves for binary classification problems (a quick sketch of it is shown below). However, in many cases we want to obtain the calibration curve for a model that makes predictions over more than two classes.
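For reference, the binary case looks roughly like this with scikit-learn’s built-in helper. This is a minimal sketch; the toy `y_true` and `y_prob` arrays are placeholders for your actual binary labels and predicted positive-class probabilities.

import numpy as np
from sklearn.calibration import calibration_curve

# Toy binary example: ground-truth labels and predicted positive-class probabilities.
y_true = np.array([0, 0, 1, 1, 0, 1, 1, 0])
y_prob = np.array([0.1, 0.35, 0.4, 0.8, 0.3, 0.75, 0.9, 0.2])

# prob_true: observed fraction of positives per bin;
# prob_pred: mean predicted probability per bin.
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=5)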
We can look to Guo et al. (“On Calibration of Modern Neural Networks”) to see how they generate their calibration curve plots.
They propose “binning” all predicted confidences into $M$ equally wide bins, where $B_m$ is the set of indices of samples whose predicted confidence falls into the interval $I_m = \left(\frac{m-1}{M}, \frac{m}{M}\right]$.
For each bin $B_m$ we can compute the bin accuracy (which is the y-axis on our graph) using the following formula:

$$\mathrm{acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbf{1}(\hat{y}_i = y_i)$$

Here, $\mathbf{1}(\hat{y}_i = y_i)$ is $1$ if the example label $y_i$ belongs to the same class as the prediction $\hat{y}_i$, and $0$ otherwise.
To sum up, we compute the y-axis of our plot by first segmenting all predicted confidence scores into $M$ bins. Each of these prediction scores is associated with a class $k$. For each bin, we count the number of scores whose example label matches the class associated with the predicted score and divide by the total count of items in the bin.
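As a hypothetical worked example: suppose the bin $(0.6, 0.7]$ ends up containing ten (example, class) scores with a mean confidence of about $0.65$, but only five of those ten scores correspond to the class that turned out to be the true label. The bin accuracy is $5/10 = 0.5$, which sits below the diagonal, so the model is overconfident in that range.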
Code
Here is the code to compute and plot calibration curves for your models with NumPy and Matplotlib.
import numpy as np
import matplotlib.pyplot as plt


def multiclass_calibration_curve(probs, labels, bins=10):
    '''
    Args:
        probs (ndarray):
            NxM predicted probabilities for N examples and M classes.
        labels (ndarray):
            Vector of size N where each entry is an integer class label.
        bins (int):
            Number of bins to divide the prediction probabilities into.
    Returns:
        midpoints (ndarray):
            Midpoint value of each bin.
        accuracies (ndarray):
            Fraction of (example, class) scores in each bin whose class is the true label.
        mean_confidences (ndarray):
            Average predicted confidence in each bin.
    '''
    step_size = 1.0 / bins
    n_classes = probs.shape[1]
    # One-hot encode the labels so each (example, class) probability can be
    # checked against whether that class is the true label.
    labels_ohe = np.eye(n_classes)[labels.astype(np.int64)]
    midpoints = []
    mean_confidences = []
    accuracies = []
    for i in range(bins):
        beg = i * step_size
        end = (i + 1) * step_size
        if i == bins - 1:
            # Include probabilities equal to 1.0 in the last bin.
            bin_mask = (probs >= beg) & (probs <= end)
        else:
            bin_mask = (probs >= beg) & (probs < end)
        bin_cnt = bin_mask.astype(np.float32).sum()
        bin_confs = probs[bin_mask]
        # Guard against empty bins to avoid dividing by zero.
        if bin_cnt == 0:
            bin_acc, mean_conf = 0.0, 0.0
        else:
            bin_acc = labels_ohe[bin_mask].sum() / bin_cnt
            mean_conf = np.mean(bin_confs)
        midpoints.append((beg + end) / 2.)
        mean_confidences.append(mean_conf)
        accuracies.append(bin_acc)
    return np.array(midpoints), np.array(accuracies), np.array(mean_confidences)
def plot_multiclass_calibration_curve(probs, labels, bins=10, title=None):
    '''
    Plot a reliability diagram (calibration curve) for a multiclass model.
    '''
    title = 'Reliability Diagram' if title is None else title
    midpoints, accuracies, mean_confidences = multiclass_calibration_curve(probs, labels, bins=bins)
    # Bars show the observed accuracy in each confidence bin.
    plt.bar(midpoints, accuracies, width=1.0 / float(bins), align='center',
            lw=1, ec='#000000', fc='#2233aa', alpha=1, label='Model', zorder=0)
    plt.scatter(midpoints, accuracies, lw=2, ec='black', fc='#ffffff', zorder=2)
    # A perfectly calibrated model lies on the diagonal.
    plt.plot(np.linspace(0, 1.0, 20), np.linspace(0, 1.0, 20), '--', lw=2,
             alpha=.7, color='gray', label='Perfectly calibrated', zorder=1)
    plt.xlim(0.0, 1.0)
    plt.ylim(0.0, 1.0)
    plt.xlabel('\nconfidence')
    plt.ylabel('accuracy\n')
    plt.title(title + '\n')
    plt.xticks(midpoints, rotation=-45)
    plt.legend(loc='upper left')
    plt.tight_layout()
    return midpoints, accuracies, mean_confidences
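As a quick sanity check, here is a usage sketch that continues from the listing above (it reuses the `np` and `plt` imports). The random softmax outputs and labels are stand-ins for your model’s actual predictions and ground truth.

# Usage sketch: random "model" outputs standing in for real softmax predictions.
rng = np.random.default_rng(0)
n_examples, n_classes = 1000, 5

logits = rng.normal(size=(n_examples, n_classes))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax
labels = rng.integers(0, n_classes, size=n_examples)

plot_multiclass_calibration_curve(probs, labels, bins=10, title='Example model')
plt.show()

With labels that are independent of the predictions, each bin’s accuracy hovers around the average class frequency rather than tracking the diagonal; swap in your model’s softmax outputs and true labels to get a meaningful diagram.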
Discussion