plt.subplot vs plt.subplots
Understanding the difference between plt.subplot (singular) and plt.subplots (plural) is essential for effective Matplotlib usage.
Mental Model
plt.subplot (singular, no "s") is the older MATLAB-style interface. You create one subplot at a time by giving its grid position, either as three arguments or as a compact 3-digit code such as 221. plt.subplots (plural, with "s") is the modern approach. It creates the entire grid at once and returns the Figure together with all the Axes in an array. Prefer plt.subplots for cleaner, more Pythonic code.
plt.subplots() is a factory function: it creates a Figure and a structured collection of Axes objects in one call. This defines the structure of the visualization before any data is plotted, which is why it belongs to the setup layer rather than the plotting layer.
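The "setup layer" idea can be checked directly. This sketch assumes a recent Matplotlib 3.x. It confirms that the Figure already holds all four Axes, none of them containing data, before any plotting call runs:

```python
import matplotlib.pyplot as plt

# Setup layer: one call builds the whole 2x2 structure
fig, axes = plt.subplots(2, 2)
n_axes_after_setup = len(fig.axes)  # all four Axes already exist
empty_before_plotting = all(len(ax.lines) == 0 for ax in axes.flat)

# Plotting layer: data is added to Axes that already exist
for ax in axes.flat:
    ax.plot([1, 2, 3])

print(n_axes_after_setup, empty_before_plotting)  # 4 True
plt.close(fig)
```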
plt.subplot (MATLAB Style)
plt.subplot creates a single subplot in a grid layout:
```python
import matplotlib.pyplot as plt

plt.subplot(2, 2, 1)  # Create the first subplot in a 2x2 grid
plt.plot([1, 2, 3, 4])
plt.subplot(2, 2, 2)  # Create the second subplot
plt.plot([4, 3, 2, 1])
plt.subplot(2, 2, 3)  # Create the third subplot
plt.plot([1, 1, 1, 1])
plt.subplot(2, 2, 4)  # Create the fourth subplot
plt.plot([1, 2, 1, 2])

plt.show()
```
Syntax: plt.subplot(nrows, ncols, index)
- Index starts at 1 (not 0)
- Counts left-to-right, top-to-bottom
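The counting order can be verified programmatically. This sketch assumes a Matplotlib version in which `SubplotSpec` exposes the `rowspan` and `colspan` properties. For each 1-based index in a 2x3 grid, it records the 0-based row and column where the subplot lands:

```python
import matplotlib.pyplot as plt

fig = plt.figure()
cells = {}
for index in range(1, 7):  # 1-based, as plt.subplot expects
    ax = plt.subplot(2, 3, index)
    spec = ax.get_subplotspec()
    # rowspan/colspan are range objects over 0-based grid rows/columns
    cells[index] = (spec.rowspan.start, spec.colspan.start)

print(cells)
# {1: (0, 0), 2: (0, 1), 3: (0, 2), 4: (1, 0), 5: (1, 1), 6: (1, 2)}
plt.close(fig)
```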
plt.subplots (OOP Style)
plt.subplots creates a Figure and all Axes objects at once:
```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(4, 4))

axes[0, 0].plot([1, 2, 3, 4])
axes[0, 1].plot([4, 3, 2, 1])
axes[1, 0].plot([1, 1, 1, 1])
axes[1, 1].plot([1, 2, 1, 2])

plt.show()

print(type(axes))  # numpy.ndarray
print(axes.shape)  # (2, 2)
print(axes.dtype)  # object
```
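Returns: a Figure object and a NumPy array of Axes objects. This holds for any grid larger than 1x1; the exact array shape is covered under Axes Shape Behavior.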
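This small sketch, assuming Matplotlib 3.x, checks each part of that return value. The first item is a Figure. The second is a NumPy array whose elements are Axes, and those Axes are the same objects the Figure tracks in `fig.axes`:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

fig, axes = plt.subplots(2, 2)

# First return value: the Figure that owns every panel
is_figure = isinstance(fig, Figure)

# Second return value: an object-dtype NumPy array whose elements are Axes
is_array = isinstance(axes, np.ndarray)
all_axes = all(isinstance(ax, Axes) for ax in axes.flat)

# The array holds the same Axes objects the Figure keeps in fig.axes
same_objects = set(axes.flat) == set(fig.axes)

print(is_figure, is_array, all_axes, same_objects)  # True True True True
plt.close(fig)
```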
Unpacking Axes
You can unpack the axes array for cleaner code:
```python
import matplotlib.pyplot as plt

fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2, figsize=(4, 4))

ax0.plot([1, 2, 3, 4])
ax1.plot([4, 3, 2, 1])
ax2.plot([1, 1, 1, 1])
ax3.plot([1, 2, 1, 2])

plt.show()
```
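The unpacking pattern has to mirror the array's shape. This sketch shows the flat pattern for a 1x2 grid. It also shows a quiet pitfall: unpacking a 2x2 grid into only two names raises no error. Each name receives a whole row, which is itself a 1D array, not an Axes:

```python
import matplotlib.pyplot as plt
import numpy as np

# A 1x2 grid returns a 1D array, so a flat pair of names works
fig, (ax_left, ax_right) = plt.subplots(1, 2)

# A 2x2 grid returns a 2D array; two names each receive one row (an array)
fig2, (row0, row1) = plt.subplots(2, 2)

print(type(row0).__name__, row0.shape)  # ndarray (2,)
plt.close(fig)
plt.close(fig2)
```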
Axes Shape Behavior
The returned axes array shape depends on the grid dimensions:
```python
import matplotlib.pyplot as plt

# Single axes returns an Axes object, not an array
fig, ax = plt.subplots()
print(type(ax))  # Axes (AxesSubplot in Matplotlib versions before 3.7)

# 1xN or Nx1 returns a 1D array
fig, axes = plt.subplots(1, 3)
print(axes.shape)  # (3,)

# NxM returns a 2D array
fig, axes = plt.subplots(2, 3)
print(axes.shape)  # (2, 3)
```
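To test these rules on several grids in one pass, the hypothetical helper `returned_shape` below reports the shape of the second value returned by `plt.subplots`. It returns `None` when that value is a bare Axes rather than an array:

```python
import matplotlib.pyplot as plt
import numpy as np

def returned_shape(nrows, ncols):
    """Hypothetical helper: shape of plt.subplots' axes value, or None for a bare Axes."""
    fig, axes = plt.subplots(nrows, ncols)
    shape = axes.shape if isinstance(axes, np.ndarray) else None
    plt.close(fig)
    return shape

for grid in [(1, 1), (1, 3), (3, 1), (2, 3)]:
    print(grid, returned_shape(*grid))
# (1, 1) None
# (1, 3) (3,)
# (3, 1) (3,)
# (2, 3) (2, 3)
```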
The squeeze Keyword
Use squeeze=False to always get a 2D array:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0.001, 1, 100)

# Without squeeze=False: shape is (3,)
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
axs[0].plot(x, x**2)
axs[1].plot(x, np.sin(x))
axs[2].plot(x, np.exp(x))

# With squeeze=False: shape is (1, 3)
fig, axs = plt.subplots(1, 3, figsize=(12, 3), squeeze=False)
axs[0, 0].plot(x, x**2)
axs[0, 1].plot(x, np.sin(x))
axs[0, 2].plot(x, np.exp(x))

plt.tight_layout()
plt.show()
```
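squeeze=False pays off when the grid size is not known in advance. The hypothetical helper `plot_grid` below always indexes with `axs[r, c]`. This works for single-row and single-column grids as well, because the array is never squeezed down to 1D or to a bare Axes:

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_grid(nrows, ncols):
    """Hypothetical helper: draw one curve per panel for any grid size."""
    # squeeze=False guarantees a 2D array, so axs[r, c] is always valid
    fig, axs = plt.subplots(nrows, ncols, squeeze=False)
    x = np.linspace(0, 1, 50)
    for r in range(nrows):
        for c in range(ncols):
            axs[r, c].plot(x, x ** (r * ncols + c + 1))
    return fig, axs

shapes = []
for grid in [(1, 1), (1, 3), (3, 1), (2, 2)]:
    fig, axs = plot_grid(*grid)
    shapes.append(axs.shape)
    plt.close(fig)

print(shapes)  # [(1, 1), (1, 3), (3, 1), (2, 2)]
```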
Comparison Summary
| Feature | plt.subplot | plt.subplots |
|---|---|---|
| Style | MATLAB | OOP |
| Returns | A single Axes | (Figure, Axes or array of Axes) |
| Index starts at | 1 | 0 |
| Creates | One axes at a time | All axes at once |
| Flexibility | Limited | High |
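This sketch checks the "index starts at" row of the table. It assumes default layout settings, with no tight or constrained layout applied. It places panels with `plt.subplot(2, 3, k)` in one Figure and with `plt.subplots(2, 3)` in another, then compares panel positions. The conversion `divmod(k - 1, ncols)` is plain row-major arithmetic, not a Matplotlib function:

```python
import matplotlib.pyplot as plt
import numpy as np

nrows, ncols = 2, 3

# Singular form: 1-based index k, one call per panel on the current Figure
fig_a = plt.figure()
singular = [plt.subplot(nrows, ncols, k) for k in range(1, nrows * ncols + 1)]

# Plural form: 0-based [row, col] indexing into the returned array
fig_b, axes_b = plt.subplots(nrows, ncols)

matches = []
for k, ax_a in enumerate(singular, start=1):
    r, c = divmod(k - 1, ncols)  # 1-based k -> 0-based (row, col)
    same = np.allclose(ax_a.get_position().bounds,
                       axes_b[r, c].get_position().bounds)
    matches.append(same)

print(all(matches))  # True
plt.close(fig_a)
plt.close(fig_b)
```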
Decision Rule
Always use plt.subplots() unless you have a very specific reason not to. It returns both the Figure and all Axes at once, supports the OOP interface, and hands you a Figure object for tight_layout(), savefig(), and every other Figure method. The one case where plt.subplot() is convenient is quick interactive one-liners in the REPL.
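To show what holding the Figure handle buys you, this sketch builds a two-panel figure and then calls `suptitle()`, `tight_layout()`, and `savefig()` directly on `fig`. It writes to an in-memory `io.BytesIO` buffer rather than a file purely to keep the example self-contained:

```python
import io
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
axes[0].plot([1, 2, 3], [1, 4, 9])
axes[1].bar(["a", "b"], [3, 5])

# Figure-level methods are called on the handle returned by plt.subplots
fig.suptitle("Two panels, one Figure handle")
fig.tight_layout()

buffer = io.BytesIO()
fig.savefig(buffer, format="png")
png_bytes = buffer.getvalue()

print(png_bytes[:8] == b"\x89PNG\r\n\x1a\n")  # True
plt.close(fig)
```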
Key Takeaways
- plt.subplot: singular, MATLAB style, 1-based indexing
- plt.subplots: plural, OOP style, returns Figure and Axes array
- Use plt.subplots for new code
- Use squeeze=False for consistent array shapes
Exercises
Exercise 1. Write the same 2x2 subplot figure using both plt.subplot() (singular) and plt.subplots() (plural). Plot sin, cos, tan, and exp respectively. Which approach requires fewer lines of code?
Solution to Exercise 1
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
funcs = [np.sin, np.cos, np.tan, np.exp]
titles = ['sin', 'cos', 'tan', 'exp']

# --- plt.subplot (singular) ---
plt.figure(figsize=(8, 6))
for i, (f, t) in enumerate(zip(funcs, titles), start=1):
    plt.subplot(2, 2, i)
    plt.plot(x, f(x))
    plt.title(t)
plt.tight_layout()
plt.show()

# --- plt.subplots (plural) ---
fig, axes = plt.subplots(2, 2, figsize=(8, 6))
for ax, f, t in zip(axes.flat, funcs, titles):
    ax.plot(x, f(x))
    ax.set_title(t)
plt.tight_layout()
plt.show()

# plt.subplots is more concise: one creation call + flat iteration.
```
Exercise 2. Explain the difference between plt.subplot(nrows, ncols, index) and plt.subplots(nrows, ncols). What does each return?
Solution to Exercise 2
plt.subplot(nrows, ncols, index) creates one Axes at the given position in the grid and returns that single Axes object. You call it repeatedly to build up a figure.
plt.subplots(nrows, ncols) creates the entire grid at once and returns a tuple (fig, axes) where fig is the Figure object and axes is a NumPy array of all Axes objects. For a 1x1 grid, axes is a single Axes (not an array); for 1xN or Nx1, it is a 1D array; for NxM, it is a 2D array.
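The contrast in return values can also be observed in code. This sketch, assuming Matplotlib 3.x, records how `fig.axes` grows by one with each `plt.subplot` call. It then shows that a single `plt.subplots` call returns a two-element tuple whose Figure already contains the full grid:

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# plt.subplot: each call adds one Axes to the current Figure and returns it
fig_a = plt.figure()
counts = []
for index in range(1, 5):
    ax = plt.subplot(2, 2, index)
    counts.append(len(fig_a.axes))
single_is_axes = isinstance(ax, Axes)

# plt.subplots: one call returns a (Figure, axes) tuple with the grid complete
result = plt.subplots(2, 2)
fig_b, axes_b = result

print(counts, single_is_axes, type(result).__name__, len(fig_b.axes))
# [1, 2, 3, 4] True tuple 4
plt.close(fig_a)
plt.close(fig_b)
```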
Exercise 3. Write code that creates a 2x2 figure with plt.subplots(2, 2) and iterates over axes.flat to plot y = x**n for n = 1, 2, 3, 4.
Solution to Exercise 3
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 3, 50)
fig, axes = plt.subplots(2, 2, figsize=(8, 6))
for n, ax in enumerate(axes.flat, start=1):
    ax.plot(x, x**n)
    ax.set_title(f'y = x^{n}')
    ax.grid(True)
plt.tight_layout()
plt.show()
```
Exercise 4. Demonstrate that plt.subplot(2, 2, 1) and plt.subplot(221) are equivalent calls by creating two figures and verifying both produce the same grid position.
Solution to Exercise 4
```python
import matplotlib.pyplot as plt

# Three-argument form
fig1 = plt.figure()
ax1 = plt.subplot(2, 2, 1)
ax1.set_title('subplot(2, 2, 1)')
# get_geometry() on the SubplotSpec returns (nrows, ncols, start, stop), 0-based
print(f"Position (3-arg): {ax1.get_subplotspec().get_geometry()}")

# Compact integer form
fig2 = plt.figure()
ax2 = plt.subplot(221)
ax2.set_title('subplot(221)')
print(f"Position (int): {ax2.get_subplotspec().get_geometry()}")

# Both print (2, 2, 0, 0): the same grid cell, reported with 0-based indices
plt.show()
```