Remove Legend From Seaborn Plots in Python
- Use the legend Parameter to Remove the Legend From a Seaborn Plot in Python
- Use the legend() Function to Remove the Legend From a Seaborn Plot in Python
- Use the remove() Function to Remove the Legend From a Seaborn Plot in Python

In this tutorial, we will learn how to remove the legend from a seaborn plot in Python.
Use the legend Parameter to Remove the Legend From a Seaborn Plot in Python
Most seaborn plotting functions, such as scatterplot() and lineplot(), accept a legend parameter. Setting it to False draws the plot without a legend.
For example,
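import random
import seaborn as sns
import matplotlib.pyplot as plt
# 20 random x and y values, plus two categories (0 and 1) for the hue
s_x = random.sample(range(0, 100), 20)
s_y = random.sample(range(0, 100), 20)
cat = [i for i in range(2)] * 10
# legend=False draws the scatter plot without a legend
sns.scatterplot(x=s_x, y=s_y, hue=cat, legend=False)
plt.show()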
Use the legend() Function to Remove the Legend From a Seaborn Plot in Python
The matplotlib.pyplot.legend() function adds a custom legend to the current axes. Because seaborn is built on top of matplotlib, this function also works on seaborn plots. If we pass empty lists of handles and labels, the function replaces seaborn's legend with an empty one, and frameon=False removes its frame. As a result, no legend is visible in the final figure.
The following code snippet implements this.
import random
import seaborn as sns
import matplotlib.pyplot as plt
s_x = random.sample(range(0, 100), 20)
s_y = random.sample(range(0, 100), 20)
cat = [i for i in range(2)] * 10
sns.scatterplot(x=s_x, y=s_y, hue=cat)
# Empty handles and labels plus frameon=False leave nothing visible
plt.legend([], [], frameon=False)
plt.show()
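This approach hides the legend rather than deleting it: seaborn's legend is replaced by a new legend that has no entries and no frame. The sketch below checks this by inspecting the object returned by Axes.get_legend(); the variable names are only illustrative, and the Agg backend is used so the snippet runs without a display.

```python
import random

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import seaborn as sns

s_x = random.sample(range(0, 100), 20)
s_y = random.sample(range(0, 100), 20)
cat = [i for i in range(2)] * 10

ax = sns.scatterplot(x=s_x, y=s_y, hue=cat)
# Replace seaborn's legend with an empty, frameless one
plt.legend([], [], frameon=False)

leg = ax.get_legend()             # a legend object still exists on the Axes
n_entries = len(leg.get_texts())  # but it has no labels
frame_on = leg.get_frame_on()     # and its frame is not drawn
print(n_entries, frame_on)
plt.close("all")
```

If the goal is to delete the legend object itself, the remove() approach is the more direct choice.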
If the figure contains several subplots and we want to remove the legend from each of them, we can iterate over the array of axes and call the legend() method of every axis with the same empty handles, empty labels, and frameon=False.
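A minimal sketch of the subplot case, assuming the figure is created with plt.subplots(); the two panels and their data are made up for illustration. Each Axes object has its own legend() method, so the loop hides the legend on every subplot.

```python
import random

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import seaborn as sns

s_x = random.sample(range(0, 100), 20)
s_y = random.sample(range(0, 100), 20)
cat = [i for i in range(2)] * 10

# Two side-by-side subplots, each with its own seaborn legend
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
sns.scatterplot(x=s_x, y=s_y, hue=cat, ax=axes[0])
sns.scatterplot(x=s_y, y=s_x, hue=cat, ax=axes[1])

# axes is a NumPy array; .flat visits every Axes regardless of the grid shape
for ax in axes.flat:
    ax.legend([], [], frameon=False)

entries_per_axis = [len(ax.get_legend().get_texts()) for ax in axes.flat]
print(entries_per_axis)
plt.close(fig)
```

Using axes.flat instead of indexing keeps the same loop working for grids with more than one row.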
Use the remove() Function to Remove the Legend From a Seaborn Plot in Python
This method also works with seaborn's grid objects, such as instances of the PairGrid class. A grid stores the legend it draws in its _legend attribute, so we can access the legend through that attribute and delete it by calling its remove() method.
See the code below.
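import random
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
# Put the data in a DataFrame, since pairplot() works with tabular data
s_x = random.sample(range(0, 100), 20)
s_y = random.sample(range(0, 100), 20)
cat = [i for i in range(2)] * 10
df = pd.DataFrame({'s_x': s_x, 's_y': s_y, 'cat': cat})
# pairplot() returns a PairGrid; hue='cat' makes it draw a legend
g = sns.pairplot(data=df, x_vars='s_x', y_vars='s_y', hue='cat')
# The grid keeps its legend in _legend; remove() deletes it from the figure
g._legend.remove()
plt.show()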
The pairplot() function returns a PairGrid object. The same approach also works for FacetGrid objects, such as those returned by the figure-level functions relplot() and catplot().
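For a FacetGrid the steps are the same. The sketch below is an illustration that uses relplot(), which returns a FacetGrid, and assumes a seaborn version recent enough to expose the grid's figure attribute (older releases only have fig). The None check is a precaution for grids drawn without a hue, where no legend is created. With the default legend_out=True, the grid attaches its legend to the figure, so the figure's legends list is used to confirm the removal.

```python
import random

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

s_x = random.sample(range(0, 100), 20)
s_y = random.sample(range(0, 100), 20)
cat = [i for i in range(2)] * 10
df = pd.DataFrame({'s_x': s_x, 's_y': s_y, 'cat': cat})

# relplot() is figure-level and returns a FacetGrid
g = sns.relplot(data=df, x='s_x', y='s_y', hue='cat')
legends_before = len(g.figure.legends)  # the grid's legend sits on the figure

# Guard against grids that were drawn without a legend
if g._legend is not None:
    g._legend.remove()

legends_after = len(g.figure.legends)
print(legends_before, legends_after)
plt.close("all")
```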
Manav is an IT professional with extensive experience as a core developer on many live projects. He is an avid learner who enjoys learning new things and sharing his findings whenever possible.