Axes Shape Behavior¶
plt.subplots returns a single Axes, a 1D array, or a 2D array depending on the grid you request. Knowing which one you get is crucial for writing robust plotting code.
Single Axes¶
With no arguments or (1, 1), a single Axes object is returned (not an array):
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
print(type(ax)) # <class 'matplotlib.axes._axes.Axes'> in recent Matplotlib; older releases print ...AxesSubplot
fig, ax = plt.subplots(1, 1)
print(type(ax)) # same single Axes type, not an array
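Because the default return value is a bare Axes, code that indexes it like an array fails. The following is a minimal sketch of that check; the `Agg` backend line is an assumption added only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

fig, ax = plt.subplots()

is_axes = isinstance(ax, Axes)         # a single Axes instance
is_array = isinstance(ax, np.ndarray)  # not wrapped in an array

# Indexing a bare Axes raises TypeError because it is not subscriptable.
try:
    ax[0]
    subscriptable = True
except TypeError:
    subscriptable = False

print(is_axes, is_array, subscriptable)
plt.close(fig)
```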
1D Arrays (1×N or N×1)¶
When creating a single row or column, the result is a 1D array:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-2*np.pi, 2*np.pi, 100)
# 1 row, 2 columns -> shape (2,), not (1, 2)
fig, axes = plt.subplots(1, 2)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))
plt.show()
print(type(axes)) # numpy.ndarray
print(axes.shape) # (2,)
print(axes.dtype) # object
# 2 rows, 1 column -> shape (2,), not (2, 1)
fig, axes = plt.subplots(2, 1)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))
plt.show()
print(axes.shape) # (2,)
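One consequence of the squeezed 1D result is that a row layout and a column layout cannot be told apart by shape, and two-index access is invalid on either. A small sketch, again assuming the headless `Agg` backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt

fig_row, axes_row = plt.subplots(1, 2)  # one row, two columns
fig_col, axes_col = plt.subplots(2, 1)  # two rows, one column

# Both layouts are squeezed to the same 1D shape.
same_shape = axes_row.shape == axes_col.shape == (2,)

# Two-index access is invalid on a 1D array and raises IndexError.
try:
    axes_row[0, 1]
    two_index_ok = True
except IndexError:
    two_index_ok = False

print(same_shape, two_index_ok)
plt.close("all")
```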
2D Arrays (N×M)¶
A 2D array is returned only when both the number of rows and the number of columns are greater than 1:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-2*np.pi, 2*np.pi, 100)
fig, axes = plt.subplots(2, 2)
axes[0, 0].plot(x, np.sin(x))
axes[0, 1].plot(x, np.cos(x))
axes[1, 0].plot(x, np.sinh(x))
axes[1, 1].plot(x, np.cosh(x))
plt.show()
print(type(axes)) # numpy.ndarray
print(axes.shape) # (2, 2)
print(axes.dtype) # object
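When every panel gets the same treatment, the 2D array can also be walked in row-major order with its `flat` iterator instead of nested indices. A brief sketch (the panel titles are illustrative only):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)

# axes.flat visits the grid row by row: (0,0), (0,1), (0,2), (1,0), ...
for k, ax in enumerate(axes.flat):
    ax.set_title(f"panel {k}")

top_right = axes[0, 2].get_title()    # third panel visited
bottom_left = axes[1, 0].get_title()  # fourth panel visited
print(top_right, bottom_left)
plt.close(fig)
```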
The squeeze Parameter¶
Use squeeze=False to always get a 2D array:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0.001, 1, 100)
# Without squeeze=False
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
print(axs.shape) # (3,)
axs[0].plot(x, x**2)
axs[1].plot(x, np.sin(x))
axs[2].plot(x, np.exp(x))
# With squeeze=False
fig, axs = plt.subplots(1, 3, figsize=(12, 3), squeeze=False)
print(axs.shape) # (1, 3)
axs[0, 0].plot(x, x**2)
axs[0, 1].plot(x, np.sin(x))
axs[0, 2].plot(x, np.exp(x))
plt.tight_layout()
plt.show()
Why Use squeeze=False?¶
Consistent array access is useful when:
- Looping over subplots:

      fig, axes = plt.subplots(1, 3, squeeze=False)
      for i in range(1):
          for j in range(3):
              axes[i, j].plot([1, 2, 3])

- Writing generic functions:

      def setup_grid(nrows, ncols):
          fig, axes = plt.subplots(nrows, ncols, squeeze=False)
          # Always use axes[i, j] regardless of dimensions
          return fig, axes
Shape Summary Table¶
| Call | Returns | Shape | Access |
|---|---|---|---|
| plt.subplots() | Axes | N/A | ax |
| plt.subplots(1, 3) | 1D array | (3,) | axes[j] |
| plt.subplots(3, 1) | 1D array | (3,) | axes[i] |
| plt.subplots(2, 3) | 2D array | (2, 3) | axes[i, j] |
| plt.subplots(1, 3, squeeze=False) | 2D array | (1, 3) | axes[0, j] |
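The shapes in the table can be checked directly. The sketch below runs each call and records the resulting shape, using `None` for the bare Axes case; the `cases` list and the `Agg` backend are choices made for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt
import numpy as np

# (positional args, keyword args) for each row of the table
cases = [
    ((), {}),
    ((1, 3), {}),
    ((3, 1), {}),
    ((2, 3), {}),
    ((1, 3), {"squeeze": False}),
]

shapes = []
for args, kwargs in cases:
    fig, axes = plt.subplots(*args, **kwargs)
    # A bare Axes is not an ndarray, so record None for it.
    shapes.append(axes.shape if isinstance(axes, np.ndarray) else None)
    plt.close(fig)

print(shapes)  # [None, (3,), (3,), (2, 3), (1, 3)]
```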
Practical Implications¶
With squeeze=False, a generic function can index axes[0, j] uniformly instead of special-casing each return shape:
import matplotlib.pyplot as plt
import numpy as np
def plot_functions(funcs, figsize=(12, 3)):
    n = len(funcs)
    # squeeze=False keeps axes 2D with shape (1, n), even when n == 1
    fig, axes = plt.subplots(1, n, figsize=figsize, squeeze=False)
    x = np.linspace(0, 2*np.pi, 100)
    for j, func in enumerate(funcs):
        axes[0, j].plot(x, func(x))
    plt.tight_layout()
    return fig, axes
# Works for any positive number of functions, including a single one
plot_functions([np.sin, np.cos, np.tan])
plt.show()
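When the axes come from code that did not pass squeeze=False, one possible approach is to normalize them afterwards. The sketch below uses a hypothetical helper, `as_grid`, which is not part of Matplotlib; it relies on `np.atleast_1d` to wrap a bare Axes and needs the grid dimensions to be known, since a squeezed 1D array alone does not say whether it was a row or a column.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt
import numpy as np

def as_grid(axes, nrows, ncols):
    """Hypothetical helper: return axes as an (nrows, ncols) object array."""
    # atleast_1d wraps a bare Axes in a 1-element array and leaves arrays as-is.
    return np.atleast_1d(axes).reshape(nrows, ncols)

grid_shapes = []
for nrows, ncols in [(1, 1), (1, 3), (3, 1), (2, 3)]:
    fig, axes = plt.subplots(nrows, ncols)  # default squeeze=True
    grid = as_grid(axes, nrows, ncols)
    grid[nrows - 1, ncols - 1].plot([1, 2, 3])  # two-index access always works
    grid_shapes.append(grid.shape)
    plt.close(fig)

print(grid_shapes)  # [(1, 1), (1, 3), (3, 1), (2, 3)]
```

Passing squeeze=False at creation time remains the simpler option whenever the call to plt.subplots is under your control.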
Key Takeaways¶
- A single-Axes call returns an Axes object, not an array
- 1×N or N×1 returns a 1D array with shape (N,)
- N×M (both > 1) returns a 2D array with shape (N, M)
- Use squeeze=False for consistent 2D array access
- Handle shape variations when writing generic plotting functions