Subplots and Grids¶
Create multiple plots in a single figure using plt.subplots() for organized, comparative visualizations.
Basic Subplots¶
Create a grid of axes with plt.subplots().
1. Single Row¶
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))

# One panel per trigonometric function, laid out left to right
panels = [(np.sin, 'Sine'), (np.cos, 'Cosine'), (np.tan, 'Tangent')]
for ax, (func, title) in zip(axes, panels):
    ax.plot(x, func(x))
    ax.set_title(title)

# tan(x) blows up near its asymptotes, so clip the visible y-range
axes[2].set_ylim(-5, 5)
plt.show()
2. Single Column¶
# Three stacked panels (x is the array defined in the single-row example)
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(6, 10))
for ax, y in zip(axes, (np.sin(x), np.cos(x), x**2)):
    ax.plot(x, y)
plt.show()
3. Grid Layout¶
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
# Unpack the 2x2 array row by row so each panel has a readable name
(top_left, top_right), (bottom_left, bottom_right) = axes
top_left.plot(x, np.sin(x))
top_right.plot(x, np.cos(x))
bottom_left.plot(x, np.exp(x/10))
bottom_right.plot(x, np.log(x + 1))
plt.show()
Axes Indexing¶
Access individual axes in different grid configurations.
1. 1D Array (Single Row or Column)¶
# One row of three: axes is a 1D array -> axes[0], axes[1], axes[2]
fig, axes = plt.subplots(nrows=1, ncols=3)
# One column of three: still a 1D array -> axes[0], axes[1], axes[2]
fig, axes = plt.subplots(nrows=3, ncols=1)
2. 2D Array (Grid)¶
fig, axes = plt.subplots(nrows=2, ncols=3)
# axes is a 2D array indexed as axes[row, col]:
#   row 0 -> axes[0, 0], axes[0, 1], axes[0, 2]
#   row 1 -> axes[1, 0], axes[1, 1], axes[1, 2]
3. Flatten for Iteration¶
fig, axes = plt.subplots(2, 3)
# axes.flat visits every Axes of the grid in turn, so no row/col indices are needed
for ax in axes.flat:
    ax.plot(np.random.randn(50))
plt.show()
Figure Size¶
Control overall figure dimensions.
1. figsize Parameter¶
# figsize is (width, height), measured in inches
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
2. Aspect Ratio¶
# Square figure: width equals height
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
# Wide figure: a single row of four panels
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(16, 4))
# Tall figure: a single column of four panels
fig, axes = plt.subplots(nrows=4, ncols=1, figsize=(6, 12))
3. DPI Setting¶
# 8 in x 8 in at 100 dots per inch gives an 800 x 800 pixel figure
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 8), dpi=100)
Shared Axes¶
Link axes across subplots for consistent scales.
1. Share X-Axis¶
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(8, 8), sharex=True)
curves = (np.sin(x), np.cos(x), np.sin(x) * np.cos(x))
for ax, y in zip(axes, curves):
    ax.plot(x, y)
# With sharex=True only the bottom subplot shows x-tick labels
plt.show()
2. Share Y-Axis¶
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4), sharey=True)
curves = (np.sin(x), np.cos(x), np.sin(2*x))
for ax, y in zip(axes, curves):
    ax.plot(x, y)
# With sharey=True only the left subplot shows y-tick labels
plt.show()
3. Share Both¶
# Every panel shares both its x and y scale with the others
fig, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True, sharey=True)
for ax in axes.flat:
    # A random walk: cumulative sum of 50 standard-normal steps
    ax.plot(np.random.randn(50).cumsum())
plt.show()
Spacing Control¶
Adjust space between subplots.
1. Default Spacing¶
# No layout arguments: Matplotlib's default subplot spacing is used
fig, axes = plt.subplots(nrows=2, ncols=2)
2. Tight Layout¶
fig, axes = plt.subplots(2, 2)
for ax in axes.flat:
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
# Re-pack the subplots so the axis labels added above do not overlap
fig.tight_layout()
plt.show()
3. Constrained Layout¶
# Layout is requested at creation time, so no tight_layout() call follows
fig, axes = plt.subplots(2, 2, constrained_layout=True)
for ax in axes.flat:
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
plt.show()
Adding Titles¶
Add titles to figure and subplots.
1. Subplot Titles¶
fig, axes = plt.subplots(nrows=2, ncols=2)
# Label the panels A-D in reading order (left to right, then top to bottom)
for ax, letter in zip(axes.flat, 'ABCD'):
    ax.set_title(f'Plot {letter}')
plt.show()
2. Figure Super Title¶
fig, axes = plt.subplots(nrows=2, ncols=2)
fig.suptitle('Main Title', fontsize=16)
# Restrict the subplots to the lower 95% of the figure so the suptitle has room
fig.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()
3. Combined Titles¶
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
fig.suptitle('Comparison of Functions', fontsize=16, fontweight='bold')
titles = ['Sine', 'Cosine', 'Exponential', 'Logarithm']
# Pair each Axes with its title in reading order
for ax, title in zip(axes.flat, titles):
    ax.set_title(title)
# Keep the top 5% of the figure free for the suptitle
fig.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
Removing Empty Subplots¶
Handle grids with fewer plots than cells.
1. Turn Off Unused Axes¶
fig, axes = plt.subplots(2, 3)
# Only use 5 of the 6 subplots
for ax in axes.flat[:5]:
    ax.plot(np.random.randn(50))
# Turn off the 6th: hides its frame, ticks and labels but keeps the cell
axes.flat[5].axis('off')
plt.show()
2. Remove Completely¶
fig, axes = plt.subplots(2, 3)
# Only use 5 of the 6 subplots
for ax in axes.flat[:5]:
    ax.plot(np.random.randn(50))
# Remove the 6th Axes from the figure altogether
fig.delaxes(axes.flat[5])
fig.tight_layout()
plt.show()
3. Set Visibility¶
fig, axes = plt.subplots(2, 3)
# Only use 5 of the 6 subplots
for ax in axes.flat[:5]:
    ax.plot(np.random.randn(50))
# Hide the unused 6th Axes without removing it from the figure
axes.flat[5].set_visible(False)
plt.show()
Practical Example¶
Create a complete multi-panel figure.
1. Setup Figure¶
# Shared x data for the line plots of the dashboard
x = np.linspace(0, 10, 100)
# 2 x 3 grid of panels under one overall title
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(14, 8))
fig.suptitle('Data Analysis Dashboard', fontsize=16)
2. Populate Subplots¶
# Row 1: line plot, histogram, scatter
axes[0, 0].plot(x, np.sin(x), 'b-')
axes[0, 0].set_title('Sine Wave')
axes[0, 0].set_xlabel('Time')
axes[0, 0].set_ylabel('Amplitude')
axes[0, 1].hist(np.random.randn(1000), bins=30, color='green', alpha=0.7)
axes[0, 1].set_title('Distribution')
axes[0, 2].scatter(np.random.rand(50), np.random.rand(50), c='red')
axes[0, 2].set_title('Scatter')
# Row 2: bar chart, random walk, heatmap
axes[1, 0].bar(['A', 'B', 'C', 'D'], [23, 45, 56, 78])
axes[1, 0].set_title('Categories')
# One random step per x sample, so the walk always matches the length of x
axes[1, 1].plot(x, np.cumsum(np.random.randn(x.size)), color='purple')
axes[1, 1].set_title('Random Walk')
axes[1, 2].imshow(np.random.rand(10, 10), cmap='viridis')
axes[1, 2].set_title('Heatmap')
3. Finalize Layout¶
# Pack the panels into the lower 95% of the figure, leaving room for the suptitle
fig.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()
Runnable Example: subplots_tutorial.py¶
"""
Matplotlib Tutorial - Beginner Level
=====================================
Topic: Subplots and Understanding Axes as NumPy Array
Author: Educational Python Course
Level: Beginner
Learning Objectives:
-------------------
1. Create multiple subplots using plt.subplots()
2. Understand that axes is a NumPy array
3. Learn the counter-intuitive shape of the axes array
4. Access individual axes for plotting
5. Master the indexing of axes arrays
Prerequisites:
-------------
- Completion of 01_introduction_and_first_plot.py
- Completion of 02_two_plotting_styles.py
- Basic NumPy array indexing knowledge
CRITICAL CONCEPT:
----------------
When using fig, axes = plt.subplots(nrows, ncols) with more than one
subplot, the 'axes' object is a NumPy array (for a single subplot it is a
plain Axes object). Its shape and indexing can be counter-intuitive at
first, but once you understand it, it becomes very handy!
"""
import matplotlib.pyplot as plt
import numpy as np
# ============================================================================
# SECTION 1: Creating a Single Subplot (Review)
# ============================================================================
if __name__ == "__main__":
    # When you create a single plot with plt.subplots(), you get:
    #   - fig: a Figure object (the container)
    #   - ax: a single Axes object (the plot area)
    fig, ax = plt.subplots()
    print(f"Type of ax (single plot): {type(ax)}")
    print(f"Is ax a numpy array? {isinstance(ax, np.ndarray)}")
    # Output: False - with single subplot, ax is NOT an array
    x = np.linspace(0, 10, 100)
    ax.plot(x, np.sin(x))
    ax.set_title('Single Subplot')
    plt.show()
# ============================================================================
# SECTION 2: Creating Multiple Subplots (1 Row, Multiple Columns)
# ============================================================================
# When you create multiple subplots, plt.subplots() returns:
#   - fig: a Figure object
#   - axes: a NumPy array of Axes objects (NOT a single Axes!)

# 1 row, 3 columns of subplots
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
banner = "=" * 70
print("\n" + banner)
print("1 ROW, 3 COLUMNS")
print(banner)
print(f"Type of axes: {type(axes)}")
print(f"Is axes a numpy array? {isinstance(axes, np.ndarray)}")
print(f"Shape of axes: {axes.shape}")  # (3,) - a 1D array with 3 elements
print(f"Number of axes: {len(axes)}")

# axes is a 1D NumPy array [ax0, ax1, ax2]; zip walks it left to right,
# which is the same as indexing axes[0], axes[1], axes[2]
x = np.linspace(0, 10, 100)
trig_panels = [('sin(x)', np.sin), ('cos(x)', np.cos), ('tan(x)', np.tan)]
for ax, (label, func) in zip(axes, trig_panels):
    ax.plot(x, func(x))
    ax.set_title(label)
axes[2].set_ylim(-5, 5)  # Limit y-axis for tan

plt.tight_layout()  # Automatically adjust spacing to prevent overlap
plt.show()
# ============================================================================
# SECTION 3: Creating Multiple Subplots (Multiple Rows, 1 Column)
# ============================================================================
# 3 rows, 1 column of subplots
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(6, 8))
banner = "=" * 70
print("\n" + banner)
print("3 ROWS, 1 COLUMN")
print(banner)
print(f"Shape of axes: {axes.shape}")  # (3,) - still a 1D array!
print(f"Number of axes: {len(axes)}")

# axes is still a 1D array [ax0, ax1, ax2], even though the visual layout
# is vertical: index 0 is the top panel, index 2 the bottom one
x = np.linspace(0, 10, 100)
power_panels = [
    ('Linear: y = x', x),
    ('Quadratic: y = x²', x**2),
    ('Cubic: y = x³', x**3),
]
for ax, (label, y) in zip(axes, power_panels):
    ax.plot(x, y)
    ax.set_title(label)

plt.tight_layout()
plt.show()
# ============================================================================
# SECTION 4: The Counter-Intuitive Part - 2D Grid of Subplots
# ============================================================================
# CRITICAL: Understanding the shape of axes in a 2D grid
#
# When you create a 2D grid with plt.subplots(nrows, ncols),
# the axes array has shape (nrows, ncols).
#
# COUNTER-INTUITIVE PART:
#   - First index  = row number    (vertical position)
#   - Second index = column number (horizontal position)
#   - This is like matrix indexing: axes[row, col]
#
#   Visual Layout:            Array Indexing:
#   --------------            ---------------
#   [plot1] [plot2]    ==>    axes[0,0]  axes[0,1]
#   [plot3] [plot4]    ==>    axes[1,0]  axes[1,1]
#
# The first index (row) changes the VERTICAL position.
# The second index (col) changes the HORIZONTAL position.
# This matches NumPy array conventions but can feel backwards at first!

# Create 2 rows, 3 columns
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
print("\n" + "=" * 70)
print("2 ROWS, 3 COLUMNS - THE IMPORTANT CASE")
print("=" * 70)
print(f"Type of axes: {type(axes)}")
print(f"Shape of axes: {axes.shape}")  # (2, 3) - a 2D array
print("axes is a 2D array with shape (rows, cols)")
print()
print("Visual Layout:")
# Column headers are padded so they sit directly above the row entries
print("        Col 0   Col 1   Col 2")
print("Row 0:  [0,0]   [0,1]   [0,2]")
print("Row 1:  [1,0]   [1,1]   [1,2]")

# Write each subplot's own index into it to make the indexing visible
n_rows, n_cols = axes.shape
for i in range(n_rows):
    for j in range(n_cols):
        ax = axes[i, j]
        # transAxes places the text in axes-relative coordinates,
        # so (0.5, 0.5) is the centre of the panel
        ax.text(0.5, 0.5, f'axes[{i},{j}]',
                ha='center', va='center',
                fontsize=20, transform=ax.transAxes)
        ax.set_title(f'Row {i}, Col {j}')
plt.tight_layout()
plt.show()
# ============================================================================
# SECTION 5: Practical Example - Plotting Data in 2D Grid
# ============================================================================
# A 2x2 grid showing four different functions
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
banner = "=" * 70
print("\n" + banner)
print("2x2 GRID EXAMPLE")
print(banner)
print(f"axes.shape = {axes.shape}")
x = np.linspace(0, 10, 100)

# Unpack row by row: axes[0, 0], axes[0, 1] / axes[1, 0], axes[1, 1]
(top_left, top_right), (bottom_left, bottom_right) = axes

# axes[0, 0] (row 0, col 0)
top_left.plot(x, np.sin(x), 'r-')
top_left.set_title('Top-Left: sin(x)')
top_left.set_ylabel('Row 0')

# axes[0, 1] (row 0, col 1)
top_right.plot(x, np.cos(x), 'b-')
top_right.set_title('Top-Right: cos(x)')

# axes[1, 0] (row 1, col 0)
bottom_left.plot(x, np.exp(-x/5) * np.sin(x), 'g-')
bottom_left.set_title('Bottom-Left: Damped sine')
bottom_left.set_ylabel('Row 1')
bottom_left.set_xlabel('Col 0')

# axes[1, 1] (row 1, col 1)
bottom_right.plot(x, np.log(x + 1), 'm-')
bottom_right.set_title('Bottom-Right: log(x+1)')
bottom_right.set_xlabel('Col 1')

plt.tight_layout()
plt.show()
# ============================================================================
# SECTION 6: Why This Shape Convention Makes Sense
# ============================================================================
# Why axes[row, col] instead of axes[col, row]?
# Because it matches NumPy and mathematical matrix conventions:
#   - First index  = row    (vertical position)
#   - Second index = column (horizontal position)
#   - This is standard in linear algebra
# Example with a NumPy array:

# 3x4 array holding 1..12, filled row by row
arr = np.arange(1, 13).reshape(3, 4)
rule = "=" * 70
print("\n" + rule)
print("NUMPY ARRAY INDEXING ANALOGY")
print(rule)
print("Array:")
print(arr)
print()
# Show the four corner elements together with the index that reaches them
corners = [
    ((0, 0), "top-left"),
    ((0, 3), "top-right"),
    ((2, 0), "bottom-left"),
    ((2, 3), "bottom-right"),
]
for (row, col), where in corners:
    print(f"arr[{row}, {col}] = {arr[row, col]} ({where})")
print()
print("Same logic applies to axes[row, col]!")
# ============================================================================
# SECTION 7: Flattening Axes Array for Easy Iteration
# ============================================================================
# Sometimes you want to iterate over all subplots without worrying about
# row/column indexing. You can flatten the axes array!

# Create a 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
print("\n" + "=" * 70)
print("FLATTENING AXES ARRAY")
print("=" * 70)
print(f"Original shape: {axes.shape}")  # (2, 3)
# Flatten to 1D array
axes_flat = axes.flatten()
print(f"Flattened shape: {axes_flat.shape}")  # (6,)
# Now you can iterate easily: one (title, y-values) pair per subplot
x = np.linspace(0, 10, 100)
functions = [
    ('sin(x)', np.sin(x)),
    ('cos(x)', np.cos(x)),
    ('sin(2x)', np.sin(2*x)),
    ('cos(2x)', np.cos(2*x)),
    ('sin(x/2)', np.sin(x/2)),
    ('cos(x/2)', np.cos(x/2)),
]
for ax, (name, y) in zip(axes_flat, functions):
    ax.plot(x, y)
    ax.set_title(name)
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Alternative: Use ravel() instead of flatten()
# ravel() is similar but returns a view when possible (more efficient)
axes_flat = axes.ravel()
# ============================================================================
# SECTION 8: Common Patterns and Best Practices
# ============================================================================
# Common Patterns
# ---------------

# Pattern 1: single subplot -> ax is one Axes object, NOT an array
fig, ax = plt.subplots()

# Pattern 2: one row or one column -> 1D array, index with axes[i]
fig, axes = plt.subplots(nrows=1, ncols=3)  # shape (3,)
fig, axes = plt.subplots(nrows=3, ncols=1)  # shape (3,)

# Pattern 3: grid -> 2D array, index with axes[i, j]
fig, axes = plt.subplots(nrows=2, ncols=3)  # shape (2, 3)

# Pattern 4: squeeze=False keeps axes 2D even for a single row/column
fig, axes = plt.subplots(nrows=1, ncols=3, squeeze=False)  # shape (1, 3)
fig, axes = plt.subplots(nrows=3, ncols=1, squeeze=False)  # shape (3, 1)

banner = "=" * 70
print("\n" + banner)
print("EFFECT OF squeeze PARAMETER")
print(banner)
fig, axes1 = plt.subplots(nrows=1, ncols=3)  # Default: squeeze=True
fig, axes2 = plt.subplots(nrows=1, ncols=3, squeeze=False)
print(f"With squeeze=True (default): axes.shape = {axes1.shape}")  # (3,)
print(f"With squeeze=False: axes.shape = {axes2.shape}")  # (1, 3)
plt.close('all')  # Close the figures we just created
# ============================================================================
# SECTION 9: Handling Different Cases with Robust Code
# ============================================================================
# Problem: You might not know in advance if axes will be:
#   - A single Axes object (1 subplot)
#   - A 1D array (one row or column)
#   - A 2D array (grid)
# Solution: normalise with np.ravel(), which turns all three into a 1D array
def plot_on_grid(nrows, ncols):
    """Plot sin(k*x) in the k-th cell (k = 1, 2, ...) of an nrows x ncols grid.

    Demonstrates robust handling of axes regardless of shape.

    Args:
        nrows: Number of subplot rows.
        ncols: Number of subplot columns.

    Returns:
        The (fig, axes) pair exactly as plt.subplots() returned it, so axes
        is a single Axes for a 1x1 grid and an array otherwise.
    """
    fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 3*nrows))
    # np.ravel() accepts a lone Axes, a 1D array or a 2D array and always
    # yields a 1D array, so no special case for the 1x1 grid is needed
    x = np.linspace(0, 10, 100)
    for k, ax in enumerate(np.ravel(axes), start=1):
        ax.plot(x, np.sin(k * x))
        ax.set_title(f'Subplot {k}')
    fig.tight_layout()
    return fig, axes
# Exercise plot_on_grid with a lone subplot and with a grid
banner = "=" * 70
print("\n" + banner)
print("TESTING ROBUST CODE")
print(banner)
fig1, ax1 = plot_on_grid(nrows=1, ncols=1)  # Single plot
print(f"1x1: type(axes) = {type(ax1)}")
fig2, ax2 = plot_on_grid(nrows=2, ncols=2)  # 2x2 grid
print(f"2x2: axes.shape = {ax2.shape}")
plt.show()
# ============================================================================
# KEY TAKEAWAYS
# ============================================================================
"""
1. When using plt.subplots(nrows, ncols) with more than one subplot, axes is a NumPy array
2. Shape of axes array for a grid: (nrows, ncols)
3. Indexing: axes[row, col] where row is vertical, col is horizontal
4. This matches NumPy/matrix conventions (row first, column second)
5. For single row/column, axes is 1D: use axes[i]
6. For grids, axes is 2D: use axes[i, j]
7. Use flatten() or ravel() to convert to 1D for easy iteration
8. Use squeeze=False to keep axes as 2D even for single row/column
9. Once you understand the convention, it's very handy!
Common Gotchas:
--------------
✗ axes[col, row] # WRONG! Don't think horizontally first
✓ axes[row, col] # CORRECT! Think vertically first (like matrices)
Quick Reference:
---------------
plt.subplots(1, 1) → ax (single Axes object)
plt.subplots(1, n) → axes (1D array, shape (n,))
plt.subplots(n, 1) → axes (1D array, shape (n,))
plt.subplots(m, n) → axes (2D array, shape (m, n))
"""
# Closing cheat-sheet: how axes[row, col] maps onto the visual grid
print("\n" + "=" * 70)
print("VISUAL SUMMARY: axes[row, col]")
print("=" * 70)
# Column headers are padded so they sit directly above the row entries
print("        Col 0   Col 1   Col 2")
print("Row 0:  [0,0]   [0,1]   [0,2]")
print("Row 1:  [1,0]   [1,1]   [1,2]")
print()
print("First index increases DOWN (rows)")
print("Second index increases RIGHT (columns)")
print("=" * 70)