How to Create a Single Legend for All Subplots in Matplotlib

Jinku Hu Feb 02, 2024
  1. Make a Single Legend for All Subplots With figure.legend Method in Matplotlib
  2. Make a Single Legend for All Subplots With figure.legend Method When Lines and Labels Differ Among Subplots in Matplotlib

The Matplotlib Figure class has a legend method that places a legend at the figure level rather than on an individual subplot. This is especially convenient when all the subplots share the same line styles and labels.

Make a Single Legend for All Subplots With figure.legend Method in Matplotlib

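import matplotlib.pyplot as plt

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

# Draw the same labeled line in every subplot
for ax in fig.axes:
    ax.plot([0, 10], [0, 10], label="linear")

# All subplots share the same lines and labels, so the last Axes is enough
lines, labels = fig.axes[-1].get_legend_handles_labels()

# Place one legend on the figure instead of one per subplot
fig.legend(lines, labels, loc="upper center")

plt.show()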

Figure: a single figure-level legend built from get_legend_handles_labels

lines, labels = fig.axes[-1].get_legend_handles_labels()

Because we assume that all the subplots contain the same lines and labels, the handles and labels of the last Axes, fig.axes[-1], can be used for the whole figure.
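The sketch below checks that reasoning. It assumes the non-interactive Agg backend so it can run without a display. The last Axes contributes a single "linear" entry, whereas calling fig.legend() with no arguments gathers handles from every Axes in the figure and therefore repeats the identical label once per subplot.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2)
for ax in fig.axes:
    ax.plot([0, 10], [0, 10], label="linear")

# Handles and labels from the last Axes only: one entry
lines, labels = fig.axes[-1].get_legend_handles_labels()
print(labels)  # ['linear']

# With no arguments, fig.legend() collects from every Axes,
# so the shared label shows up once per subplot
auto_legend = fig.legend(loc="upper center")
auto_labels = [text.get_text() for text in auto_legend.get_texts()]
print(auto_labels)  # ['linear', 'linear', 'linear', 'linear']

plt.close(fig)
```

This duplication is why the example passes the last Axes' handles and labels explicitly instead of relying on the automatic collection.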

Make a Single Legend for All Subplots With figure.legend Method When Lines and Labels Differ Among Subplots in Matplotlib

If the line styles and labels differ among subplots but a single legend is still required for all of them, we need to collect the line handles and labels from every subplot.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 501)

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

# Each subplot gets a different curve, color, and label
axes[0, 0].plot(x, np.sin(x), color="k", label="sin(x)")
axes[0, 1].plot(x, np.cos(x), color="b", label="cos(x)")
axes[1, 0].plot(x, np.sin(x) + np.cos(x), color="r", label="sin(x)+cos(x)")
axes[1, 1].plot(x, np.sin(x) - np.cos(x), color="m", label="sin(x)-cos(x)")

# Collect the handles and labels from every subplot
lines = []
labels = []

for ax in fig.axes:
    axLine, axLabel = ax.get_legend_handles_labels()
    lines.extend(axLine)
    labels.extend(axLabel)

# One figure-level legend lists all four curves
fig.legend(lines, labels, loc="upper right")

plt.show()

Figure: a figure-level legend listing the labels collected from all subplots with get_legend_handles_labels

for ax in fig.axes:
    axLine, axLabel = ax.get_legend_handles_labels()
    lines.extend(axLine)
    labels.extend(axLabel)

All the line handles and labels are appended to the lines and labels lists with the list extend method, so subplots that contain more than one line are handled as well.
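When some subplots share a label, collecting from every Axes produces duplicate legend entries. One possible extension, which is not part of the method above, is to keep only the first handle seen for each label before calling fig.legend. The sketch assumes the Agg backend and a hypothetical two-subplot layout in which "sin(x)" is drawn twice.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 501)
fig, axes = plt.subplots(nrows=1, ncols=2)

# Hypothetical layout: both subplots draw a "sin(x)" line
axes[0].plot(x, np.sin(x), color="k", label="sin(x)")
axes[1].plot(x, np.sin(x), color="k", label="sin(x)")
axes[1].plot(x, np.cos(x), color="b", label="cos(x)")

# Same collection loop as in the example above
lines = []
labels = []
for ax in fig.axes:
    axLine, axLabel = ax.get_legend_handles_labels()
    lines.extend(axLine)
    labels.extend(axLabel)

# Keep the first handle for each label; dicts preserve insertion order
unique = {}
for handle, label in zip(lines, labels):
    unique.setdefault(label, handle)

legend = fig.legend(list(unique.values()), list(unique.keys()), loc="upper right")
print([text.get_text() for text in legend.get_texts()])  # ['sin(x)', 'cos(x)']

plt.close(fig)
```

Using setdefault rather than dict(zip(labels, lines)) matters here: the latter would keep the last handle for a repeated label, not the first.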

Author: Jinku Hu

Founder of DelftStack.com. Jinku has worked in the robotics and automotive industries for over 8 years. He sharpened his coding skills while doing automated testing, collecting data from remote servers, and creating reports from endurance tests. He comes from an electrical/electronics engineering background but has expanded his interests to embedded electronics, embedded programming, and front-end and back-end programming.

