Axes Shape Behavior
Understanding how plt.subplots returns differently shaped axes objects is essential for writing robust plotting code.
Mental Model
plt.subplots(r, c) returns axes shaped like the grid: a single object for 1x1, a 1D array for a single row or column, and a 2D array for multi-row, multi-column grids. Use squeeze=False to always get a 2D array, which makes indexing uniform and avoids shape-related bugs.
Why This Matters
Shape inconsistency is a common source of subplot bugs. Code that works
for a 2x2 grid (axes[0, 1]) crashes on a 1x2 grid: there axes is 1D, so
axes[0, 1] raises IndexError. The fix is to always pass squeeze=False
in functions that accept variable grid sizes.
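A minimal sketch of the crash described above. It assumes a headless environment, so it selects the non-interactive Agg backend (not part of the original examples). 2D-style indexing fails on the default 1x2 result and succeeds once squeeze=False is passed.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Default squeeze=True: a 1x2 grid comes back as a 1D array of shape (2,)
fig, axes = plt.subplots(1, 2)
try:
    axes[0, 1].plot([1, 2, 3])  # 2D indexing that would work for a 2x2 grid
    raised = False
except IndexError as exc:
    raised = True
    print("IndexError:", exc)
plt.close(fig)

# squeeze=False: the same grid comes back as a 2D array of shape (1, 2)
fig, axes = plt.subplots(1, 2, squeeze=False)
axes[0, 1].plot([1, 2, 3])  # the same indexing is now valid
print(axes.shape)  # (1, 2)
plt.close(fig)
```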
```text
Rule of thumb:
1 subplot        → ax (scalar)
1 row/column     → axes[i] (1D array)
Grid             → axes[i, j] (2D array)
Need consistency → squeeze=False (always 2D)
```
Single Axes
With no arguments or (1, 1), a single Axes object is returned (not an array):
```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
print(type(ax))  # a single Axes object, not an array

fig, ax = plt.subplots(1, 1)
print(type(ax))  # same: a single Axes object
```
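Because the 1x1 result is a bare Axes rather than an array, it cannot be indexed or looped over. The check below is a small sketch (again assuming the Agg backend for headless runs) that makes the distinction explicit.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

fig, ax = plt.subplots()
is_axes = isinstance(ax, Axes)         # True: a single Axes instance
is_array = isinstance(ax, np.ndarray)  # False: no shape, no indexing
print(is_axes, is_array)  # True False

ax.plot([1, 2, 3])  # call plotting methods on ax directly
plt.close(fig)
```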
1D Arrays (1×N or N×1)
When creating a single row or column, the result is a 1D array:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2 * np.pi, 2 * np.pi, 100)

# 1 row, 2 columns -> shape (2,), not (1, 2)
fig, axes = plt.subplots(1, 2)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))
plt.show()

print(type(axes))  # numpy.ndarray
print(axes.shape)  # (2,)
print(axes.dtype)  # object
```
```python
# Continues the previous example (reuses plt, np and x)
# 2 rows, 1 column -> shape (2,), not (2, 1)
fig, axes = plt.subplots(2, 1)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))
plt.show()

print(axes.shape)  # (2,)
```
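Since a single row and a single column both collapse to shape (N,), the shape alone does not reveal the grid's orientation. The upside is that both can be iterated the same way. A short sketch, assuming the Agg backend for headless runs:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2 * np.pi, 2 * np.pi, 100)

fig_row, row_axes = plt.subplots(1, 2)  # one row, two columns
fig_col, col_axes = plt.subplots(2, 1)  # two rows, one column
print(row_axes.shape, col_axes.shape)   # (2,) (2,)

# A 1D array of Axes can be looped over directly, whatever its orientation
for axes in (row_axes, col_axes):
    for ax, func in zip(axes, (np.sin, np.cos)):
        ax.plot(x, func(x))

plt.close("all")
```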
2D Arrays (N×M)
A 2D array is returned only when both dimensions are greater than 1:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2 * np.pi, 2 * np.pi, 100)

fig, axes = plt.subplots(2, 2)

axes[0, 0].plot(x, np.sin(x))
axes[0, 1].plot(x, np.cos(x))
axes[1, 0].plot(x, np.sinh(x))
axes[1, 1].plot(x, np.cosh(x))

plt.show()

print(type(axes))  # numpy.ndarray
print(axes.shape)  # (2, 2)
print(axes.dtype)  # object
```
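For a 2D grid, nested loops over rows and columns work. When every panel is handled the same way, though, it is often simpler to walk the array with axes.flat, which yields the Axes in row-major order (the same attribute Exercise 2 uses). A sketch assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2 * np.pi, 2 * np.pi, 100)
funcs = [np.sin, np.cos, np.sinh, np.cosh]

fig, axes = plt.subplots(2, 2)
# axes.flat visits axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1] in that order
for ax, func in zip(axes.flat, funcs):
    ax.plot(x, func(x))
    ax.set_title(func.__name__)

titles = [ax.get_title() for ax in axes.flat]
print(titles)  # ['sin', 'cos', 'sinh', 'cosh']
plt.close(fig)
```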
The squeeze Parameter
Use squeeze=False to always get a 2D array:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0.001, 1, 100)

# Without squeeze=False
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
print(axs.shape)  # (3,)
axs[0].plot(x, x**2)
axs[1].plot(x, np.sin(x))
axs[2].plot(x, np.exp(x))

# With squeeze=False
fig, axs = plt.subplots(1, 3, figsize=(12, 3), squeeze=False)
print(axs.shape)  # (1, 3)
axs[0, 0].plot(x, x**2)
axs[0, 1].plot(x, np.sin(x))
axs[0, 2].plot(x, np.exp(x))

plt.tight_layout()
plt.show()
```
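squeeze=False also covers the two cases that would otherwise not be 2D: a single subplot becomes a (1, 1) array and a single column becomes (N, 1), so axs[i, j] indexing works in every case. A minimal check, assuming the Agg backend; the grid sizes are arbitrary examples:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

shapes = {}
for nrows, ncols in [(1, 1), (3, 1), (1, 3), (2, 3)]:
    fig, axs = plt.subplots(nrows, ncols, squeeze=False)
    shapes[(nrows, ncols)] = axs.shape
    axs[nrows - 1, ncols - 1].plot([0, 1])  # bottom-right panel, same syntax everywhere
    plt.close(fig)

print(shapes)
# {(1, 1): (1, 1), (3, 1): (3, 1), (1, 3): (1, 3), (2, 3): (2, 3)}
```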
Why Use squeeze=False?
Consistent array access is useful when:

- Looping over subplots:

  ```python
  import matplotlib.pyplot as plt

  fig, axes = plt.subplots(1, 3, squeeze=False)
  # axes.shape is always (nrows, ncols), so both loop bounds come from it
  for i in range(axes.shape[0]):
      for j in range(axes.shape[1]):
          axes[i, j].plot([1, 2, 3])
  ```

- Writing generic functions:

  ```python
  import matplotlib.pyplot as plt

  def setup_grid(nrows, ncols):
      fig, axes = plt.subplots(nrows, ncols, squeeze=False)
      # Always use axes[i, j] regardless of dimensions
      return fig, axes
  ```
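As a usage sketch for the setup_grid helper above, the loop below reads the grid size from axes.shape instead of hard-coding it, so one body handles a single panel, a row, a column, or a full grid. The Agg backend and the particular grid sizes are illustrative assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt


def setup_grid(nrows, ncols):
    fig, axes = plt.subplots(nrows, ncols, squeeze=False)
    # Always use axes[i, j] regardless of dimensions
    return fig, axes


panel_counts = []
for nrows, ncols in [(1, 1), (1, 3), (3, 1), (2, 2)]:
    fig, axes = setup_grid(nrows, ncols)
    n_rows, n_cols = axes.shape  # always a 2-tuple thanks to squeeze=False
    for i in range(n_rows):
        for j in range(n_cols):
            axes[i, j].plot([1, 2, 3])
    panel_counts.append(axes.size)
    plt.close(fig)

print(panel_counts)  # [1, 3, 3, 4]
```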
Shape Summary Table
| Call | Returns | Shape | Access |
|---|---|---|---|
| plt.subplots() | Axes | N/A | ax |
| plt.subplots(1, 3) | 1D array | (3,) | axes[j] |
| plt.subplots(3, 1) | 1D array | (3,) | axes[i] |
| plt.subplots(2, 3) | 2D array | (2, 3) | axes[i, j] |
| plt.subplots(1, 3, squeeze=False) | 2D array | (1, 3) | axes[0, j] |
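Each row of the table can be checked directly. This sketch (Agg backend assumed) asserts the returned type and shape for every call listed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

# A single subplot: a bare Axes, not an array
fig, ax = plt.subplots()
assert isinstance(ax, Axes) and not isinstance(ax, np.ndarray)

# (args, kwargs, expected shape) for the array-returning rows of the table
cases = [
    ((1, 3), {}, (3,)),
    ((3, 1), {}, (3,)),
    ((2, 3), {}, (2, 3)),
    ((1, 3), {"squeeze": False}, (1, 3)),
]
for args, kwargs, expected in cases:
    fig, axes = plt.subplots(*args, **kwargs)
    assert isinstance(axes, np.ndarray) and axes.shape == expected
    print(args, kwargs, axes.shape)

plt.close("all")
```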
Practical Implications
Passing squeeze=False lets a function handle any number of subplots with the same indexing code:
```python
import matplotlib.pyplot as plt
import numpy as np

def plot_functions(funcs, figsize=(12, 3)):
    n = len(funcs)
    # squeeze=False keeps axes 2D even when n == 1
    fig, axes = plt.subplots(1, n, figsize=figsize, squeeze=False)

    x = np.linspace(0, 2 * np.pi, 100)

    for j, func in enumerate(funcs):
        axes[0, j].plot(x, func(x))

    plt.tight_layout()
    return fig, axes

# Works for any number of functions
plot_functions([np.sin, np.cos, np.tan])
plt.show()
```
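plot_functions always uses a single row. If a grid layout is wanted instead, the same squeeze=False idea extends naturally: compute the number of rows, map each function index to (row, col) with divmod, and switch off any leftover panels. The name plot_functions_grid and its ncols parameter are hypothetical, chosen for illustration; the Agg backend is assumed for headless runs.

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np


def plot_functions_grid(funcs, ncols=3):
    """Hypothetical variant of plot_functions that wraps into several rows."""
    n = len(funcs)
    nrows = math.ceil(n / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows), squeeze=False)

    x = np.linspace(0, 2 * np.pi, 100)
    for k, func in enumerate(funcs):
        i, j = divmod(k, ncols)  # row-major position of the k-th function
        axes[i, j].plot(x, func(x))
    for k in range(n, nrows * ncols):
        i, j = divmod(k, ncols)
        axes[i, j].axis("off")  # hide panels with nothing to plot

    fig.tight_layout()
    return fig, axes


fig, axes = plot_functions_grid([np.sin, np.cos, np.tan, np.exp])
print(axes.shape)  # (2, 3)
plt.close(fig)
```

With a single function the same call still returns a 2D (1, 3) array, so axes[i, j] indexing never needs a special case.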
Key Takeaways
- A single subplot returns an Axes object, not an array
- 1×N or N×1 returns a 1D array with shape (N,)
- N×M (both > 1) returns a 2D array with shape (N, M)
- Use squeeze=False for consistent 2D array access
- Handle shape variations when writing generic plotting functions
Exercises
Exercise 1.
Call plt.subplots() with no arguments and print the type and shape of the returned axes object. Then call plt.subplots(1, 3) and print the type and shape. Finally call plt.subplots(2, 3) and print the type and shape. Explain the differences.
Solution to Exercise 1
```python
import matplotlib.pyplot as plt
import numpy as np

fig1, ax1 = plt.subplots()
print(f"No args: type={type(ax1)}, not an array")
plt.close(fig1)

fig2, ax2 = plt.subplots(1, 3)
print(f"(1, 3): type={type(ax2)}, shape={ax2.shape}")
plt.close(fig2)

fig3, ax3 = plt.subplots(2, 3)
print(f"(2, 3): type={type(ax3)}, shape={ax3.shape}")
plt.close(fig3)

# No args -> single Axes object
# (1, 3) -> 1D array of shape (3,)
# (2, 3) -> 2D array of shape (2, 3)
```
Exercise 2.
Create a 3x3 grid using plt.subplots(3, 3). Use axes.flat to iterate over all 9 axes and plot y = sin(n*x) where n goes from 1 to 9. Title each subplot with the value of n.
Solution to Exercise 2
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 200)
fig, axes = plt.subplots(3, 3, figsize=(10, 10))

for n, ax in enumerate(axes.flat, 1):
    ax.plot(x, np.sin(n * x))
    ax.set_title(f'n = {n}')

plt.tight_layout()
plt.show()
```
Exercise 3.
Create a 2x3 subplot grid with squeeze=False and verify the return shape is always 2D by printing axes.shape. Then access axes using 2D indexing axes[row, col] to plot different functions. Compare this with the default squeeze=True behavior for a 1x3 grid.
Solution to Exercise 3
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 200)

# With squeeze=False, always 2D
fig, axes = plt.subplots(2, 3, squeeze=False, figsize=(12, 6))
print(f"squeeze=False, (2,3): shape={axes.shape}")

funcs = [np.sin, np.cos, np.tan, np.exp, np.log1p, np.sqrt]
names = ['sin', 'cos', 'tan', 'exp', 'log1p', 'sqrt']

for i in range(2):
    for j in range(3):
        idx = i * 3 + j
        axes[i, j].plot(x, funcs[idx](x))
        axes[i, j].set_title(names[idx])

plt.tight_layout()
plt.show()

# Compare with squeeze=True (default) for 1x3
fig2, axes2 = plt.subplots(1, 3, squeeze=True)
print(f"squeeze=True, (1,3): shape={axes2.shape}")  # (3,) not (1,3)

fig3, axes3 = plt.subplots(1, 3, squeeze=False)
print(f"squeeze=False, (1,3): shape={axes3.shape}")  # (1,3)

plt.close('all')
```