Unpacking Axes
Python's tuple unpacking provides a clean way to name individual axes when creating subplots.
Mental Model
`fig, ax = plt.subplots()` is just Python tuple unpacking: the function returns a pair and you name each element. For grids, `fig, (ax1, ax2) = plt.subplots(1, 2)` unpacks the axes array in the same way. Giving each Axes a descriptive name makes your code read like a sentence instead of an index lookup.
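You can confirm this by inspecting what `plt.subplots` returns before unpacking it. A minimal sketch (the `matplotlib.use("Agg")` line only selects a non-interactive backend so it runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

# One panel: plt.subplots() returns a 2-tuple (Figure, Axes)
result = plt.subplots()
print(type(result).__name__, len(result))  # tuple 2
fig, ax = result
print(isinstance(fig, Figure), isinstance(ax, Axes))  # True True

# A 1x2 grid: the second element is a NumPy array holding two Axes
fig2, axes = plt.subplots(1, 2)
print(isinstance(axes, np.ndarray), axes.shape)  # True (2,)
ax1, ax2 = axes  # the same unpacking, applied to the array
```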
The deeper principle: naming axes means naming meaning. Compare `axes[0]` with `ax_price`: the first is a position, the second is a concept. Descriptive names make code self-documenting and prevent mix-ups when a layout has many panels:
```python
import matplotlib.pyplot as plt
fig, (ax_price, ax_volume, ax_returns) = plt.subplots(3, 1)
```
Unpacking is Fragile
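Unpacking hardcodes the number of axes. If the layout changes (e.g., from 2 to 3 panels), the unpacking line breaks immediately: leaving it as `(ax1, ax2)` raises a `ValueError` (too many values to unpack), so it has to be edited along with the layout:
```python
import matplotlib.pyplot as plt

# This works for 2 panels
fig, (ax1, ax2) = plt.subplots(1, 2)

# Adding a third panel requires changing the unpack line too
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
```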
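The failure is easy to reproduce. In this small sketch the layout has grown to three panels while the unpack line still names two; the inner unpacking of the Axes array raises `ValueError` before any plotting happens:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

message = None
try:
    # The layout grew to 3 panels, but the unpack line still expects 2
    fig, (ax1, ax2) = plt.subplots(1, 3)
except ValueError as err:
    message = str(err)
    print("ValueError:", message)
finally:
    plt.close("all")  # discard the figure created before the failure
```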
For functions that accept a variable number of panels, use indexing
(`axes[i]`) or `axes.flat` instead. Reserve unpacking for fixed, known
layouts where readability outweighs flexibility.
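As a sketch of the flexible alternative, here is a hypothetical helper, `plot_stacked` (not part of Matplotlib), that takes any number of 1D series. It relies on `squeeze=False`, which makes `plt.subplots` always return a 2D array of Axes (even a single panel comes back with shape `(1, 1)`), so the same loop works for one panel or many:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_stacked(series_list):
    """Hypothetical helper: draw one panel per series, stacked vertically."""
    n = len(series_list)
    # squeeze=False always returns a 2D array of Axes, even when n == 1,
    # so the loop below never receives a bare Axes object
    fig, axes = plt.subplots(n, 1, squeeze=False, sharex=True)
    for ax, series in zip(axes.flat, series_list):
        ax.plot(series)
    return fig, axes


rng = np.random.default_rng(0)
fig1, axes1 = plot_stacked([rng.normal(size=50)])
fig3, axes3 = plot_stacked([rng.normal(size=50) for _ in range(3)])
print(axes1.shape, axes3.shape)  # (1, 1) (3, 1)
```

Without `squeeze=False`, the one-series call would return a single Axes, and `axes.flat` would fail because an Axes object has no `.flat` attribute.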
Unpacking a Single Axes
The most common pattern:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2*np.pi, 2*np.pi, 100)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y)
plt.show()
```
Unpacking 1D Arrays
For single rows or columns:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2*np.pi, 2*np.pi, 100)

# Unpack 1x2 grid
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.plot(x, np.sin(x))
ax1.plot(x, np.cos(x))
plt.show()
```
```python
# Unpack 2x1 grid
fig, (ax0, ax1) = plt.subplots(2, 1)
ax0.plot(x, np.sin(x))
ax1.plot(x, np.cos(x))
plt.show()
```
```python
# Unpack 1x3 grid
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(12, 3))
ax0.plot(x, x**2)
ax1.plot(x, np.sin(x))
ax2.plot(x, np.exp(x))
plt.tight_layout()
plt.show()
```
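Both 1×N and N×1 grids unpack as a flat tuple because, with the default `squeeze=True`, `plt.subplots` drops the length-1 dimension and returns a 1D array in either orientation. A quick check:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

_, row_axes = plt.subplots(1, 2)   # one row, two columns
_, col_axes = plt.subplots(2, 1)   # two rows, one column
_, wide_axes = plt.subplots(1, 3)  # one row, three columns

print(row_axes.shape, col_axes.shape, wide_axes.shape)  # (2,) (2,) (3,)
```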
Unpacking 2D Arrays
Use nested tuple unpacking for grids:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2*np.pi, 2*np.pi, 100)

# Unpack 2x2 grid
fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)

ax0.plot(x, np.sin(x))
ax0.set_title("sin")

ax1.plot(x, np.cos(x))
ax1.set_title("cos")

ax2.plot(x, np.sinh(x))
ax2.set_title("sinh")

ax3.plot(x, np.cosh(x))
ax3.set_title("cosh")

plt.tight_layout()
plt.show()
```
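Nested unpacking follows the array's row-major layout: the outer tuple walks rows and the inner tuples walk columns, so `ax1` is the top-right panel, not the bottom-left one. One way to check which grid cell each name landed in is to ask each Axes for its `SubplotSpec` (a sketch; the `rowspan`/`colspan` properties are available in recent Matplotlib releases):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)

# Record the (row, column) cell that each unpacked name refers to
positions = {}
for name, ax in [("ax0", ax0), ("ax1", ax1), ("ax2", ax2), ("ax3", ax3)]:
    spec = ax.get_subplotspec()
    positions[name] = (spec.rowspan.start, spec.colspan.start)
    print(name, positions[name])
```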
2×3 Grid Unpacking
```python
import matplotlib.pyplot as plt
import numpy as np

# Start just above 0 so that np.log(x) stays finite
x = np.linspace(0.001, 1, 100)

fig, ((ax0, ax1, ax2), (ax3, ax4, ax5)) = plt.subplots(2, 3, figsize=(12, 6))

ax0.plot(x, x**2)
ax1.plot(x, np.sin(x))
ax2.plot(x, np.exp(x))
ax3.plot(x, np.log(x))
ax4.plot(x, np.sin(x) / np.exp(x))
ax5.plot(x, np.log(x) / np.exp(x))

plt.tight_layout()
plt.show()
```
When to Use Array Indexing vs Unpacking
Use unpacking when:
- Fixed number of subplots
- Each subplot has distinct content
- Want descriptive variable names
```python
# Assumes dates, prices, and volume are existing sequences of equal length
fig, (ax_price, ax_volume) = plt.subplots(2, 1, sharex=True)
ax_price.plot(dates, prices)
ax_volume.bar(dates, volume)
```
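The snippet above assumes `dates`, `prices`, and `volume` already exist. Here is a self-contained version with made-up data (a random walk for prices and random daily volumes, purely for illustration, with integer day numbers standing in for dates). It also shows what `sharex=True` buys you: changing the x-limits on one panel moves the other as well.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins for dates, prices, and volume (illustration only)
rng = np.random.default_rng(0)
days = np.arange(60)
prices = 100 + np.cumsum(rng.normal(0, 1, size=60))
volume = rng.integers(1_000, 5_000, size=60)

fig, (ax_price, ax_volume) = plt.subplots(2, 1, sharex=True)
ax_price.plot(days, prices)
ax_price.set_ylabel("Price")
ax_volume.bar(days, volume)
ax_volume.set_ylabel("Volume")
ax_volume.set_xlabel("Day")

# sharex=True links the x-axes, so zooming one panel zooms both
ax_price.set_xlim(10, 30)
print(ax_volume.get_xlim())
```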
Use array indexing when:
- Dynamic number of subplots
- Applying the same operation to all
- Looping over subplots
```python
# Assumes data is a sequence with at least 12 entries (one per panel)
fig, axes = plt.subplots(3, 4)
for i, ax in enumerate(axes.flat):
    ax.plot(data[i])
```
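The loop above indexes `data[i]` for every panel, so it assumes `data` has at least 12 entries; with fewer it would raise `IndexError`. One way to handle a count that does not fill the grid, sketched here with synthetic data, is to hide the leftover panels:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt
import numpy as np

# Synthetic data: 10 series, fewer than the 12 panels in a 3x4 grid
rng = np.random.default_rng(1)
data = [rng.normal(size=50).cumsum() for _ in range(10)]

fig, axes = plt.subplots(3, 4, figsize=(12, 8))
for i, ax in enumerate(axes.flat):
    if i < len(data):
        ax.plot(data[i])
        ax.set_title(f"series {i}")
    else:
        ax.set_visible(False)  # hide panels with nothing to show
```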
Combining Both Approaches
Sometimes a hybrid approach is clearest:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2*np.pi, 100)

# Create the grid
fig, axes = plt.subplots(2, 2, figsize=(8, 6))

# Unpack for clarity
(ax_sin, ax_cos), (ax_tan, ax_exp) = axes

ax_sin.plot(x, np.sin(x))
ax_sin.set_title("sin(x)")

ax_cos.plot(x, np.cos(x))
ax_cos.set_title("cos(x)")

ax_tan.plot(x, np.tan(x))
ax_tan.set_ylim(-5, 5)
ax_tan.set_title("tan(x)")

ax_exp.plot(x, np.exp(np.sin(x)))
ax_exp.set_title("exp(sin(x))")

plt.tight_layout()
plt.show()
```
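In the hybrid pattern the names are aliases for elements of `axes`, not copies, so you can still apply shared settings through the array after unpacking. A small sketch:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(8, 6))
(ax_sin, ax_cos), (ax_tan, ax_exp) = axes

# The names point at the same Axes objects stored in the array
print(ax_sin is axes[0, 0], ax_exp is axes[1, 1])  # True True

# So a bulk operation through the array also affects the named panels
for ax in axes.flat:
    ax.set_xlabel("x")
print(ax_tan.get_xlabel())  # x
```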
Using axes.flat
Iterate over all axes regardless of shape:
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2*np.pi, 100)

fig, axes = plt.subplots(2, 3, figsize=(12, 6))

for i, ax in enumerate(axes.flat):
    ax.plot(x, np.sin((i+1) * x))
    ax.set_title(f"sin({i+1}x)")

plt.tight_layout()
plt.show()
```
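`axes.flat` is NumPy's flat iterator, which visits elements in row-major (C) order: across the first row, then across the second. A short sketch confirming which panel receives which index:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)
for i, ax in enumerate(axes.flat):
    ax.set_title(f"panel {i}")

# The second row starts at index 3 in a 2x3 grid
print(axes[0, 2].get_title(), axes[1, 0].get_title())  # panel 2 panel 3
```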
Key Takeaways
- `fig, ax = plt.subplots()` for a single Axes
- `fig, (ax0, ax1) = plt.subplots(1, 2)` for 1D arrays
- `fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)` for 2D arrays
- Use unpacking for fixed layouts with distinct content
- Use array indexing for dynamic or looped operations
- `axes.flat` flattens any shape for iteration
Exercises
Exercise 1.
Use tuple unpacking to create a 1x2 subplot layout: `fig, (ax_left, ax_right) = plt.subplots(1, 2)`. Plot a histogram on the left and a box plot on the right using 1000 samples from a normal distribution.
Solution to Exercise 1
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
data = np.random.randn(1000)

fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(10, 4))

ax_left.hist(data, bins=30, color='steelblue', edgecolor='white')
ax_left.set_title('Histogram')

ax_right.boxplot(data)
ax_right.set_title('Box Plot')

plt.tight_layout()
plt.show()
```
Exercise 2.
Create a 2x2 subplot grid and use nested unpacking: `fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)`. Plot sin, cos, tan, and exp on the four axes respectively. Add titles and use `plt.tight_layout()`.
Solution to Exercise 2
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 200)

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 8))

ax1.plot(x, np.sin(x))
ax1.set_title('sin(x)')

ax2.plot(x, np.cos(x), color='red')
ax2.set_title('cos(x)')

ax3.plot(x, np.tan(x), color='green')
ax3.set_ylim(-5, 5)
ax3.set_title('tan(x)')

ax4.plot(x, np.exp(x / 3), color='orange')
ax4.set_title('exp(x/3)')

plt.tight_layout()
plt.show()
```
Exercise 3.
Create a 3x1 layout and unpack as `fig, (ax_top, ax_mid, ax_bot) = plt.subplots(3, 1, figsize=(8, 10), sharex=True)`. Plot a stock-like random walk on the top, its daily returns in the middle, and a cumulative return on the bottom. Use descriptive y-labels for each.
Solution to Exercise 3
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
returns = np.random.randn(200) * 0.02
prices = 100 * np.cumprod(1 + returns)
cum_returns = np.cumprod(1 + returns) - 1

fig, (ax_top, ax_mid, ax_bot) = plt.subplots(3, 1, figsize=(8, 10), sharex=True)

ax_top.plot(prices, color='navy')
ax_top.set_ylabel('Price')
ax_top.set_title('Stock Dashboard')

ax_mid.bar(range(len(returns)), returns, color=['green' if r > 0 else 'red' for r in returns], width=1)
ax_mid.set_ylabel('Daily Return')

ax_bot.plot(cum_returns, color='purple')
ax_bot.set_ylabel('Cumulative Return')
ax_bot.set_xlabel('Day')

plt.tight_layout()
plt.show()
```