How to Display Multiple Images in One Figure Correctly in Matplotlib

Suraj Joshi Feb 02, 2024
  1. Use Matplotlib add_subplot() in for Loop
  2. Define a Function Based on the Subplots in Matplotlib

The core idea for displaying multiple images in one figure is to create a grid of axes (subplots) and iterate over them, drawing one image on each axes with the imshow() method.

Use Matplotlib add_subplot() in for Loop

The simplest approach to displaying multiple images in a figure is to loop over the images, create a subplot for each one with add_subplot(), and draw the image inside that subplot with imshow().

Syntax of the add_subplot() method:

add_subplot(rows, columns, i)

where rows and columns are the total number of rows and columns in the composite figure, and i is the position of the subplot in that grid. The index i starts at 1 (not 0) and increases from left to right, then from top to bottom.
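Because i is 1-based and counts left to right, then top to bottom, the row and column of a subplot follow from integer division by the number of columns. The sketch below checks that mapping against the position Matplotlib records for each subplot through get_subplotspec(). The function index_to_row_col is a hypothetical helper written only for illustration, and the non-interactive Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def index_to_row_col(i, cols):
    """Hypothetical helper: map a 1-based add_subplot index to (row, col)."""
    return (i - 1) // cols, (i - 1) % cols


rows, cols = 2, 3
fig = plt.figure()
positions = {}
for i in range(1, rows * cols + 1):
    ax = fig.add_subplot(rows, cols, i)
    spec = ax.get_subplotspec()
    # rowspan/colspan are range objects; .start is the cell's row/column
    positions[i] = (spec.rowspan.start, spec.colspan.start)
    assert positions[i] == index_to_row_col(i, cols)

print(positions)
# {1: (0, 0), 2: (0, 1), 3: (0, 2), 4: (1, 0), 5: (1, 1), 6: (1, 2)}
```

So in a 2x3 grid, index 4 is the first subplot of the second row.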

import numpy as np
import matplotlib.pyplot as plt

width = 5
height = 5
rows = 2
cols = 2
axes = []
fig = plt.figure()

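for a in range(rows * cols):
    # Random 5x5 integer "image" with values from 0 to 6
    b = np.random.randint(7, size=(height, width))
    # add_subplot() uses 1-based indices, hence a + 1
    axes.append(fig.add_subplot(rows, cols, a + 1))
    subplot_title = "Subplot" + str(a)
    axes[-1].set_title(subplot_title)
    # Draw the image on the subplot that was just created
    axes[-1].imshow(b)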
fig.tight_layout()
plt.show()

Output:

display multiple images in a figure: simple approach
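The pyplot function plt.imshow() draws on the current axes, and add_subplot() makes each newly created subplot the current one, so calling plt.imshow() right after add_subplot() targets the new subplot. Calling imshow() on the Axes object itself avoids relying on that hidden state. A minimal sketch of both, assuming the Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax_left = fig.add_subplot(1, 2, 1)
ax_right = fig.add_subplot(1, 2, 2)

# ax_right was created last, so it is the current axes and plt.imshow() draws there
plt.imshow(np.zeros((5, 5)))
print(plt.gca() is ax_right)                      # True
print(len(ax_left.images), len(ax_right.images))  # 0 1

# Calling imshow() on an explicit Axes does not depend on which axes is current
ax_left.imshow(np.ones((5, 5)))
print(len(ax_left.images))                        # 1
```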

We can make the above code more flexible by keeping a reference to each subplot's axes, which lets us draw additional content, such as line plots, on selected subplots.

import numpy as np
import matplotlib.pyplot as plt

width = 5
height = 5
rows = 2
cols = 2
fig = plt.figure()

x = np.linspace(-3, 3, 100)
y1 = np.sin(x)
y2 = 1 / (1 + np.exp(-x))

axes = []

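for i in range(cols * rows):
    b = np.random.randint(10, size=(height, width))
    axes.append(fig.add_subplot(rows, cols, i + 1))
    subplot_title = "Subplot" + str(i)
    axes[-1].set_title(subplot_title)
    axes[-1].imshow(b)

# The axes list lets us draw extra content on chosen subplots after the loop
axes[1].plot(x, y1)  # sine curve over the second image
axes[3].plot(x, y2)  # sigmoid curve over the fourth image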
fig.tight_layout()

plt.show()

Output:

display multiple images in a figure: simple approach with flexibility

Here, the axes list gives access to each subplot, so we can draw extra content on any of them after the loop has finished.
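Matplotlib also tracks the axes itself: fig.axes returns the figure's axes in the order they were added, so maintaining a separate list is optional. A short sketch, again assuming the headless Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

rows, cols = 2, 2
fig = plt.figure()
for i in range(rows * cols):
    ax = fig.add_subplot(rows, cols, i + 1)
    ax.imshow(np.random.randint(10, size=(5, 5)))
    ax.set_title(f"Subplot{i}")

# fig.axes lists the axes in creation order, like the manual axes list
x = np.linspace(-3, 3, 100)
fig.axes[1].plot(x, np.sin(x))

titles = [ax.get_title() for ax in fig.axes]
print(titles)  # ['Subplot0', 'Subplot1', 'Subplot2', 'Subplot3']
```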

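Alternatively, we can create all the subplots at once with plt.subplots(), which returns the axes as a 2D NumPy array. Each subplot can then be accessed as axes[row_index][column_index], which is more convenient when the images themselves are arranged in a grid.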

import numpy as np
import matplotlib.pyplot as plt

width = 5
height = 5
rows = 2
cols = 2

x = np.linspace(0, 3, 100)
y1 = np.sin(x)
y2 = 1 / (1 + np.exp(-x))

figure, axes = plt.subplots(nrows=rows, ncols=cols)

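for i, ax in enumerate(axes.flat):
    image = np.random.randint(7, size=(height, width))
    # alpha=0.25 makes the image faint so curves plotted on top stay visible
    ax.imshow(image, alpha=0.25)
    # Recover the grid position from the flat index
    r = i // cols
    c = i % cols
    subtitle = "Row:" + str(r) + ", Col:" + str(c)
    ax.set_title(subtitle)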

axes[0][1].plot(x, y1)
axes[1][1].plot(x, y2)

figure.tight_layout()
plt.show()

Output:

access to each of the sub-plots with row index column index
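The [row_index][column_index] indexing above works because a 2x2 call to plt.subplots() returns a 2D array. By default, though, plt.subplots() squeezes out dimensions of length 1, so a single row or column comes back as a 1D array and axes[0][1] would fail. Passing squeeze=False keeps the 2D shape. A quick check, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

_, grid = plt.subplots(nrows=2, ncols=2)
_, strip = plt.subplots(nrows=1, ncols=3)
_, strip_2d = plt.subplots(nrows=1, ncols=3, squeeze=False)

print(grid.shape)      # (2, 2): index as grid[row][col] or grid[row, col]
print(strip.shape)     # (3,): a single row is squeezed to 1D, so use strip[col]
print(strip_2d.shape)  # (1, 3): squeeze=False keeps the row/column layout
```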

Define a Function Based on the Subplots in Matplotlib

We can also wrap this pattern in a reusable function. It calls plt.subplots() to create a grid of axes with the given number of rows and columns, then iterates over the flattened axes, drawing one image per axes and using its dictionary key as the title.

import numpy as np
import matplotlib.pyplot as plt


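def display_multiple_img(images, rows=1, cols=1):
    # squeeze=False always returns a 2D array of Axes, even for a 1x1 grid,
    # so ravel() works with the default rows=1, cols=1
    figure, ax = plt.subplots(nrows=rows, ncols=cols, squeeze=False)
    axes_flat = ax.ravel()
    for ind, title in enumerate(images):
        axes_flat[ind].imshow(images[title])
        axes_flat[ind].set_title(title)
        axes_flat[ind].set_axis_off()
    # Hide the frames of any grid cells left without an image
    for extra_ax in axes_flat[len(images):]:
        extra_ax.set_axis_off()
    plt.tight_layout()
    plt.show()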


total_images = 4
images = {"Image" + str(i): np.random.rand(100, 100) for i in range(total_images)}

display_multiple_img(images, 2, 2)

Output:

display multiple images in a figure by defining a function
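With display_multiple_img(), the caller has to choose rows and cols. One possible variant, sketched here as a hypothetical make_image_grid() helper (not part of Matplotlib), derives a near-square grid from the number of images and hides any leftover cells. It returns the figure instead of calling plt.show(), so the result can also be saved with fig.savefig(); the Agg backend is assumed only so the sketch runs headless.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def make_image_grid(images, cols=None):
    """Hypothetical helper: lay out a dict of {title: image} in a near-square grid."""
    n = len(images)
    if cols is None:
        cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    fig, axes = plt.subplots(nrows=rows, ncols=cols, squeeze=False)
    flat = axes.ravel()
    for ax, (title, img) in zip(flat, images.items()):
        ax.imshow(img)
        ax.set_title(title)
        ax.set_axis_off()
    # Hide the unused cells when n does not fill the grid exactly
    for ax in flat[n:]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig


images = {f"Image{i}": np.random.rand(100, 100) for i in range(5)}
fig = make_image_grid(images)
visible = [ax for ax in fig.axes if ax.get_visible()]
print(len(fig.axes), len(visible))  # 6 5
```

For five images this picks 3 columns and 2 rows, leaving one hidden cell.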

Author: Suraj Joshi

Suraj Joshi is a backend software engineer at Matrice.ai.
