How to Make the Legend of the Scatter Plot in Matplotlib

Suraj Joshi Feb 02, 2024
  1. Add a Legend to the 2D Scatter Plot in Matplotlib
  2. Add a Legend to the 3D Scatter Plot in Matplotlib

A legend describes the elements of a figure, for example which marker belongs to which data series. We can generate a legend for a scatter plot with the matplotlib.pyplot.legend() function.

Add a Legend to the 2D Scatter Plot in Matplotlib

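import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]

# Two data series evaluated at the same x values
y1 = [i ** 2 for i in x]
y2 = [2 * i + 1 for i in x]

# The label of each scatter() call becomes that series' legend entry
plt.scatter(x, y1, marker="x", color="r", label="x**2")
plt.scatter(x, y2, marker="o", color="b", label="2*x+1")
plt.legend()  # build the legend from the labelled scatter plots
plt.show()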

Output: [figure: 2D scatter plot with a legend]

The figure contains two scatter plots: one drawn with x markers and the other with o markers. We pass a label to each scatter() call, and legend() uses these labels as the text of the legend entries. We then create the legend with the legend() function and display the figure with the show() function.
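The same example can be written with Matplotlib's object-oriented API, which makes the label mechanism easier to inspect. The sketch below uses Axes.get_legend_handles_labels() to show which (artist, label) pairs legend() will collect. The Agg backend is selected only so the sketch runs without a display; in an interactive session you would call plt.show() instead.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for this sketch; use plt.show() interactively
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y1 = [i**2 for i in x]
y2 = [2 * i + 1 for i in x]

# Object-oriented equivalent: each scatter() call carries its own label
fig, ax = plt.subplots()
ax.scatter(x, y1, marker="x", color="r", label="x**2")
ax.scatter(x, y2, marker="o", color="b", label="2*x+1")

# The (artist, label) pairs that legend() will use, in the order the artists were added
handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['x**2', '2*x+1']

# Axes.legend() builds the legend from those labelled artists
leg = ax.legend()
```

Artists whose label is left unset, or starts with an underscore, are skipped by get_legend_handles_labels(), which is why only labelled scatter plots appear in the legend.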

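We can also pass a tuple of labels directly to the legend() function. The labels are then assigned to the scatter plots in the order in which they were created, so the Matplotlib documentation discourages this form: the pairing is only implicit and easy to mix up. The loc parameter of legend() sets the legend's position.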

import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]

y1 = [i ** 2 for i in x]
y2 = [2 * i + 1 for i in x]

# No label= here; the labels are supplied to legend() below
plt.scatter(x, y1, marker="x", color="r")
plt.scatter(x, y2, marker="o", color="b")
# The labels are matched to the scatter plots in creation order
plt.legend(("x**2", "2*x+1"), loc="center left")
plt.show()

Output: [figure: 2D scatter plot with the legend at the center left]

This code creates a figure with two scatter plots and a legend placed inside the axes, vertically centered along the left edge.
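To avoid relying on creation order, the legend entries can also be paired explicitly by passing the objects returned by scatter() as handles. The sketch below does that and also uses the bbox_to_anchor parameter to move the legend outside the axes: with loc="center left", the legend's center-left point is placed at the axes coordinates (1, 0.5), the middle of the right edge. The variable names and the right=0.75 margin are illustrative choices, not values from the original example.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for this sketch; use plt.show() interactively
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y1 = [i**2 for i in x]
y2 = [2 * i + 1 for i in x]

fig, ax = plt.subplots()
# Keep references to the PathCollection objects that scatter() returns
squares = ax.scatter(x, y1, marker="x", color="r")
lines = ax.scatter(x, y2, marker="o", color="b")

# Pair each handle with its label explicitly, then anchor the legend's
# center-left point at the middle of the right edge of the axes
leg = ax.legend(
    handles=[squares, lines],
    labels=["x**2", "2*x+1"],
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)

# Leave room on the right of the figure so the outside legend is not cut off
fig.subplots_adjust(right=0.75)
```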

Add a Legend to the 3D Scatter Plot in Matplotlib

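import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y = [2, 1, 4, 5, 6]

z1 = [i + j for (i, j) in zip(x, y)]
z2 = [3 * i - j for (i, j) in zip(x, y)]

# projection="3d" creates an Axes3D (Matplotlib >= 3.2 needs no extra import)
axes = plt.subplot(111, projection="3d")
# Marker-only format strings ("x", "o") draw points without connecting lines
axes.plot(x, y, z1, "x", label="x+y")
axes.plot(x, y, z2, "o", label="3*x-y")

plt.legend(loc="upper left")
plt.show()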

Output: [figure: 3D scatter plot with a legend]

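To create a legend for a 3D scatter plot, this example uses the plot() method with marker-only format strings instead of the scatter() method. In older Matplotlib releases, legend() did not support the Patch3DCollection returned by the scatter() method of an Axes3D instance, so labelled 3D scatter plots got no legend entries. plot() returns Line2D objects, which legend() handles, so this approach works across versions.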
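In recent Matplotlib 3.x releases, Axes3D.scatter() returns a Path3DCollection, a subclass of the PathCollection used by 2D scatter plots, and legend() can handle it. The sketch below assumes such a release and labels the scatter() calls directly. If your version warns that the collection is not supported by the legend, fall back to the plot() approach above.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for this sketch; use plt.show() interactively
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y = [2, 1, 4, 5, 6]
z1 = [i + j for (i, j) in zip(x, y)]
z2 = [3 * i - j for (i, j) in zip(x, y)]

fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # "3d" is registered automatically in Matplotlib >= 3.2

# Assumption: a recent Matplotlib 3.x, where legend() accepts the
# collection returned by Axes3D.scatter(), so label= works directly
ax.scatter(x, y, z1, marker="x", color="r", label="x+y")
ax.scatter(x, y, z2, marker="o", color="b", label="3*x-y")

leg = ax.legend(loc="upper left")
```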

Author: Suraj Joshi

Suraj Joshi is a backend software engineer at Matrice.ai.
