How to Add a Legend to a Scatter Plot in Matplotlib

Suraj Joshi Feb 02, 2024

A legend describes the elements of a figure, for example which marker belongs to which data series. We can generate a legend for a scatter plot with the `matplotlib.pyplot.legend` function.

Add a Legend to a 2D Scatter Plot in Matplotlib

``````python
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]

y1 = [i ** 2 for i in x]
y2 = [2 * i + 1 for i in x]

# label= sets the text shown for each scatter plot in the legend
plt.scatter(x, y1, marker="x", color="r", label="x**2")
plt.scatter(x, y2, marker="o", color="b", label="2*x+1")
plt.legend()  # builds the legend from the labeled plots
plt.show()
``````

The figure contains two scatter plots: one drawn with `x` markers and the other with `o` markers. The `label` argument of each `scatter()` call sets the text that the legend shows for that plot. The `legend()` function then builds the legend from the labeled plots, and the `show()` function displays the figure.
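To see which plots `legend()` actually picks up, the sketch below repeats the example with the object-oriented API (`plt.subplots()` and `Axes` methods, a choice made here rather than in the original) and asks the axes for its labeled artists with `get_legend_handles_labels()`. The extra unlabeled scatter is a hypothetical addition showing that plots without a `label` are skipped. The non-interactive `Agg` backend is selected only so the snippet runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y1 = [i ** 2 for i in x]
y2 = [2 * i + 1 for i in x]

fig, ax = plt.subplots()
ax.scatter(x, y1, marker="x", color="r", label="x**2")
ax.scatter(x, y2, marker="o", color="b", label="2*x+1")
# Hypothetical unlabeled plot: it gets no legend entry
ax.scatter(x, [0] * len(x), marker=".", color="k")

# The (handle, label) pairs that legend() uses when called without arguments
handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['x**2', '2*x+1']

legend = ax.legend()
print([text.get_text() for text in legend.get_texts()])
```

Only artists whose label does not start with an underscore are collected, which is why the black points do not appear in the legend.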

We can also pass the labels directly to the `legend()` method as a tuple and change the legend's position with its `loc` parameter. In that case, the labels are assigned to the plots in the order in which the plots were drawn.

``````python
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]

y1 = [i ** 2 for i in x]
y2 = [2 * i + 1 for i in x]

plt.scatter(x, y1, marker="x", color="r")
plt.scatter(x, y2, marker="o", color="b")
# Labels are matched to the plots in drawing order; loc sets the legend position
plt.legend(("x**2", "2*x+1"), loc="center left")
plt.show()
``````

This creates a figure with two scatter plots and a legend placed inside the axes at the middle of the left edge, as requested by `loc="center left"`.
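Because a bare tuple of labels is paired with the plots only by drawing order, the Matplotlib documentation discourages that form. The sketch below is one alternative, not the author's method: it keeps the objects returned by `scatter()` and passes them to `legend()` together with their labels, so each label is tied to a specific plot. It also combines `loc` with `bbox_to_anchor` to move the legend outside the axes; the anchor `(1, 0.5)` and the variable names `squares` and `linear` are illustrative choices.

```python
import io

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y1 = [i ** 2 for i in x]
y2 = [2 * i + 1 for i in x]

fig, ax = plt.subplots()
squares = ax.scatter(x, y1, marker="x", color="r")
linear = ax.scatter(x, y2, marker="o", color="b")

# Handles and labels are paired explicitly, so the drawing order no longer matters.
# bbox_to_anchor=(1, 0.5) is the midpoint of the right border in axes coordinates;
# with loc="center left", the legend is placed just to the right of that point.
legend = ax.legend(
    [linear, squares],
    ["2*x+1", "x**2"],
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)

# bbox_inches="tight" enlarges the saved area so the outside legend is not cropped
buffer = io.BytesIO()
fig.savefig(buffer, format="png", bbox_inches="tight")
print([text.get_text() for text in legend.get_texts()])
```

Here the legend lists `2*x+1` first even though that plot was drawn second, because the order of the `handles` list decides the entries.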

Add a Legend to a 3D Scatter Plot in Matplotlib

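``````python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the "3d" projection on Matplotlib < 3.2)

x = [1, 2, 3, 4, 5]
y = [2, 1, 4, 5, 6]

z1 = [i + j for (i, j) in zip(x, y)]
z2 = [3 * i - j for (i, j) in zip(x, y)]

axes = plt.subplot(111, projection="3d")
# A marker-only format string ("x", "o") makes plot() draw unconnected points
axes.plot(x, y, z1, "x", label="x+y")
axes.plot(x, y, z2, "o", label="3*x-y")

axes.legend(loc="upper left")
plt.show()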
``````

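To create a legend for a 3D scatter plot, this example uses the `plot()` method with marker-only format strings instead of the `scatter()` method. In older Matplotlib versions, the `scatter()` method of an `Axes3D` instance returned a `Patch3DCollection`, which the `legend()` method did not support, so labeled 3D scatter plots were left out of the legend. `plot()` returns `Line2D` objects, which the legend can draw.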
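In current Matplotlib releases, `Axes3D.scatter()` returns a `Path3DCollection`, a subclass of the 2D `PathCollection` that the legend already knows how to draw, so a labeled 3D scatter should appear in the legend directly. The sketch below assumes such a release (it uses `fig.add_subplot(projection="3d")`, which needs Matplotlib 3.2 or later without the `mpl_toolkits.mplot3d` import); the exact version in which the older limitation went away is not stated here, so check it against your installed version.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y = [2, 1, 4, 5, 6]
z1 = [i + j for (i, j) in zip(x, y)]
z2 = [3 * i - j for (i, j) in zip(x, y)]

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# Labeled 3D scatter plots instead of marker-only plot() calls
ax.scatter(x, y, z1, marker="x", color="r", label="x+y")
ax.scatter(x, y, z2, marker="o", color="b", label="3*x-y")

handles, labels = ax.get_legend_handles_labels()
print([type(h).__name__ for h in handles])  # the artist type scatter() returned

legend = ax.legend(loc="upper left")
print([text.get_text() for text in legend.get_texts()])
```

If your Matplotlib version still warns that the legend does not support the 3D collection, the `plot()` approach above remains the portable choice.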

Author: Suraj Joshi

Suraj Joshi is a backend software engineer at Matrice.ai.