Searching Gradients

Are your models calibrated?

A calibration curve (sometimes called a “reliability diagram”) tells you whether your model’s predicted probabilities accurately reflect the real chance of your model being right.

It’s very common for neural networks to overestimate the confidence in their predictions, and this type of diagram helps us detect when this phenomenon occurs. Here’s an example:

On the x-axis we have our model’s predicted confidence. On the y-axis we plot the model’s accuracy at that confidence. We can see from this particular diagram that the model is “overconfident” when its predicted confidence is between 0.5 and 0.7: its accuracy in that range is lower than the confidence it reports.

Calibration curves for multiclass classifiers

Scikit-learn provides a function to compute calibration curves for binary classification problems. However, in many cases we want the calibration curve for a model that makes predictions for more than two classes.

We can look to Guo et al. to see how they generate their calibration curve plots.

They propose “binning” all predicted confidences into $M$ equally wide bins, where $B_m$ is the bin containing the set of indices of samples whose predicted confidence falls into the interval $I_m = \left(\frac{m-1}{M}, \frac{m}{M}\right]$.
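To make the binning step concrete, here is a minimal sketch that assigns each confidence to one of the equally wide bins and then collects the sample indices that fall into each bin. The helper name `assign_bins` and the sample confidences are made up for illustration, and the half-open intervals (open on the left, closed on the right) follow the convention in Guo et al.

```python
import numpy as np

def assign_bins(confidences, n_bins=10):
    """Return the 0-based bin index of each confidence.

    Bin m (0-based) covers the interval (m / n_bins, (m + 1) / n_bins].
    """
    confidences = np.asarray(confidences, dtype=np.float64)
    # ceil maps a value in (m/M, (m+1)/M] to m + 1; subtract 1 for a 0-based index.
    idx = np.ceil(confidences * n_bins).astype(np.int64) - 1
    # A confidence of exactly 0 would land in bin -1, so clip it into the first bin.
    return np.clip(idx, 0, n_bins - 1)

# Hypothetical confidences, chosen to sit away from the bin edges.
confidences = np.array([0.05, 0.55, 0.62, 0.97, 1.0])
bin_idx = assign_bins(confidences, n_bins=10)

# For every bin, the indices of the samples that fall into it.
bin_members = [np.flatnonzero(bin_idx == m) for m in range(10)]
print(bin_idx)
```

One caveat: a confidence that sits exactly on a bin edge can land on either side of it because of floating-point rounding, which rarely matters for real model outputs.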

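For each bin we can compute the bin accuracy (which is the y-axis on our graph) using the following formula:

$$\mathrm{acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbf{1}(\hat{y}_i = y_i)$$

Here, $|B_m|$ is the number of samples in bin $m$, and $\mathbf{1}(\hat{y}_i = y_i)$ is 1 if the example label $y_i$ belongs to the same class as the prediction $\hat{y}_i$, and 0 otherwise.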
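The bin accuracy can be sketched in a few lines of NumPy. This version follows the top-label reading of Guo et al., where the prediction is the argmax class and the confidence is its probability. The function name `bin_accuracies` and the toy data are assumptions for illustration, not code from the paper.

```python
import numpy as np

def bin_accuracies(probs, labels, n_bins=10):
    """Per-bin accuracy using the top predicted class of each example."""
    probs = np.asarray(probs, dtype=np.float64)
    labels = np.asarray(labels)
    predictions = probs.argmax(axis=1)   # predicted class of each example
    confidences = probs.max(axis=1)      # probability of that class
    correct = (predictions == labels).astype(np.float64)  # indicator, 1 or 0

    # Bin m (0-based) covers (m / n_bins, (m + 1) / n_bins].
    bin_idx = np.clip(np.ceil(confidences * n_bins).astype(np.int64) - 1, 0, n_bins - 1)
    counts = np.bincount(bin_idx, minlength=n_bins)
    hits = np.bincount(bin_idx, weights=correct, minlength=n_bins)

    # Empty bins have no defined accuracy, so they are reported as NaN.
    acc = np.full(n_bins, np.nan)
    nonempty = counts > 0
    acc[nonempty] = hits[nonempty] / counts[nonempty]
    return acc, counts

# Toy data: five examples, two classes.
probs = np.array([[0.90, 0.10],
                  [0.85, 0.15],
                  [0.35, 0.65],
                  [0.30, 0.70],
                  [0.25, 0.75]])
labels = np.array([0, 1, 1, 0, 1])
acc, counts = bin_accuracies(probs, labels, n_bins=5)
print(acc, counts)
```

In this toy data the three examples with confidence between 0.6 and 0.8 are right two times out of three, and the two examples above 0.8 are right half the time, so the top bin is the overconfident one.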

To sum up, we compute the y-axis of our plot by first segmenting all predicted confidence scores into M bins. Each of these prediction scores is associated with a class. For each bin, we count the number of examples whose labels match the class associated with the predicted score and divide by the total count of items in the bin.
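As a small worked example of this counting, the sketch below bins every class score of every example (not only the top one) and uses a one-hot label matrix to check whether a score's class matches the label. That reading is an assumption based on the summary above, and the numbers are made up.

```python
import numpy as np

# Three examples, three classes; each row sums to 1.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.6, 0.3],
                  [0.2, 0.2, 0.6]])
labels = np.array([0, 2, 2])

# One-hot labels: entry (i, k) is 1 when example i belongs to class k.
labels_ohe = np.eye(probs.shape[1])[labels]

# Scores that fall into the bin [0.5, 1.0): 0.7, 0.6 and 0.6.
mask = (probs >= 0.5) & (probs < 1.0)
bin_count = mask.sum()
# 0.7 (class 0) matches label 0, 0.6 (class 1) misses label 2,
# and 0.6 (class 2) matches label 2, so two of the three scores are hits.
bin_hits = labels_ohe[mask].sum()
bin_accuracy = bin_hits / bin_count

# The remaining six scores fall into [0.0, 0.5); only one of them is a hit.
low_mask = probs < 0.5
low_accuracy = labels_ohe[low_mask].sum() / low_mask.sum()
print(bin_accuracy, low_accuracy)
```

The low bin ends up with a low accuracy as well, which is what a calibrated model should show: scores near zero should rarely belong to the true class.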


Here is the code to compute and plot the calibration curves for your models in matplotlib.

import numpy as np
import matplotlib.pyplot as plt

def multiclass_calibration_curve(probs, labels, bins=10):
    """
    Args:
        probs (ndarray):
            NxM predicted probabilities for N examples and M classes.
        labels (ndarray):
            Vector of size N where each entry is an integer class label.
        bins (int):
            Number of bins to divide the prediction probabilities into.
    Returns:
        midpoints (list):
            Midpoint value of each bin
        accuracies (list):
            Fraction of examples that are positive in bin
        mean_confidences (list):
            Average predicted confidences in each bin
    """
    step_size = 1.0 / bins
    n_classes = probs.shape[1]
    labels_ohe = np.eye(n_classes)[labels.astype(np.int64)]

    midpoints = []
    mean_confidences = []
    accuracies = []

    for i in range(bins):
        beg = i * step_size
        end = (i + 1) * step_size

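        # The last bin is closed on the right so a probability of exactly 1.0 is counted.
        if i == bins - 1:
            bin_mask = probs >= beg
        else:
            bin_mask = (probs >= beg) & (probs < end)
        bin_cnt = bin_mask.astype(np.float32).sum()
        bin_confs = probs[bin_mask]

        # An empty bin has no defined accuracy or mean confidence.
        if bin_cnt == 0:
            bin_acc, bin_conf = np.nan, np.nan
        else:
            bin_acc = labels_ohe[bin_mask].sum() / bin_cnt
            bin_conf = bin_confs.mean()

        midpoints.append((beg + end) / 2.0)
        mean_confidences.append(bin_conf)
        accuracies.append(bin_acc)
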
    return midpoints, accuracies, mean_confidences

def plot_multiclass_calibration_curve(probs, labels, bins=10, title=None):
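    """Plot calibration curve."""
    title = 'Reliability Diagram' if title is None else title
    midpoints, accuracies, mean_confidences = multiclass_calibration_curve(probs, labels, bins=bins)
    # One bar per bin: the bar height is the accuracy inside that bin.
    plt.bar(midpoints, accuracies, width=1.0/float(bins), align='center', lw=1, ec='#000000', fc='#2233aa', alpha=1, label='Model', zorder=0)
    plt.title(title)
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')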
    plt.scatter(midpoints, accuracies, lw=2, ec='black', fc="#ffffff", zorder=2)
    plt.plot(np.linspace(0, 1.0, 20), np.linspace(0, 1.0, 20), '--', lw=2, alpha=.7, color='gray', label='Perfectly calibrated', zorder=1)
    plt.xlim(0.0, 1.0)
    plt.ylim(0.0, 1.0)
    plt.xticks(midpoints, rotation=-45)
    plt.legend(loc='upper left')
    return midpoints, accuracies, mean_confidences


  1. Obtaining Well Calibrated Probabilities Using Bayesian Binning
  2. On Calibration of Modern Neural Networks
  3. Scikit Learn Calibration Curve

Hey there! I'm Huy and I do research in computer vision, visual search, and AI. You can get updates on new essays by subscribing to my RSS feed. Occasionally, I will send out interesting links on Twitter, so follow me if you like this kind of stuff.