Skip to content

Heatmaps with imshow

The ax.imshow() method displays 2D data as a color-coded image, ideal for matrices, correlation tables, and gridded data.

Mental Model

imshow() treats a 2D array as a grid of pixels, coloring each cell by its value. Row 0 is at the top (image convention), and cells are equally sized. It is the simplest way to visualize matrices and correlation tables. For non-uniform grids or coordinate-based data, use pcolormesh() instead.

A heatmap is a visualization of a scalar field: each cell represents a value at a location, and color shows how that value changes across space. In this framing, color is a third axis — x and y give position, color gives value. Heatmaps are discrete versions of contour and density plots.

Reading Heatmaps

Visual feature Interpretation
Bright/dark region High/low values
Color gradient Transition between values
Block of uniform color Plateau or cluster
Diagonal pattern Correlation or symmetry
Isolated bright cell Outlier or peak

Coordinate Convention

imshow() uses image coordinates by default: row 0 is at the top (origin='upper'). This matches image/pixel convention but is upside-down relative to mathematical plots. Set origin='lower' when your y-axis represents a quantity that increases upward (e.g., frequency, temperature). Forgetting this is one of the most common heatmap bugs.

For data on non-uniform or real-valued grids, use pcolormesh() instead — it accepts explicit coordinate arrays. For colormap guidance, see Colormap Selection.

Basic Heatmap

Create a simple heatmap from a 2D array.

1. Import and Setup

python import matplotlib.pyplot as plt import numpy as np

2. Create 2D Data

python np.random.seed(42) data = np.random.rand(10, 10)

3. Display with imshow

python fig, ax = plt.subplots() im = ax.imshow(data) plt.colorbar(im) plt.show()

Colormap Selection

The cmap keyword controls the color scheme.

1. Sequential Colormaps

```python fig, axes = plt.subplots(1, 3, figsize=(12, 4))

for ax, cmap in zip(axes, ['viridis', 'plasma', 'Blues']): im = ax.imshow(data, cmap=cmap) ax.set_title(cmap) plt.colorbar(im, ax=ax)

plt.tight_layout() plt.show() ```

2. Diverging Colormaps

```python data_centered = np.random.randn(10, 10)

fig, ax = plt.subplots() im = ax.imshow(data_centered, cmap='RdBu', vmin=-2, vmax=2) plt.colorbar(im) plt.show() ```

3. Common Colormaps

```python

Sequential: 'viridis', 'plasma', 'inferno', 'magma', 'Blues', 'Greens'

Diverging: 'RdBu', 'coolwarm', 'seismic', 'PiYG'

Qualitative: 'Set1', 'Set2', 'tab10', 'tab20'

```

Value Range

Control the mapping between data values and colors.

1. Auto Range (Default)

python ax.imshow(data) # Maps min to bottom, max to top of colormap

2. Fixed Range

python ax.imshow(data, vmin=0, vmax=1)

3. Centered at Zero

python max_abs = np.abs(data_centered).max() ax.imshow(data_centered, cmap='RdBu', vmin=-max_abs, vmax=max_abs)

Aspect Ratio

The aspect keyword controls pixel shape.

1. Equal Aspect (Default)

python ax.imshow(data, aspect='equal') # Square pixels

2. Auto Aspect

python ax.imshow(data, aspect='auto') # Fills axes, may stretch

3. Numeric Aspect

python ax.imshow(data, aspect=2) # Height = 2 × width per pixel

Axis Labels and Ticks

Customize tick positions and labels for matrix visualization.

1. Set Tick Positions

```python fig, ax = plt.subplots() im = ax.imshow(data)

ax.set_xticks(np.arange(10)) ax.set_yticks(np.arange(10)) plt.show() ```

2. Custom Labels

```python row_labels = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'] col_labels = [f'Col {i}' for i in range(10)]

ax.set_xticks(np.arange(10)) ax.set_yticks(np.arange(10)) ax.set_xticklabels(col_labels, rotation=45, ha='right') ax.set_yticklabels(row_labels) ```

3. Move X Labels to Top

python ax.xaxis.set_ticks_position('top') ax.xaxis.set_label_position('top')

Annotating Cells

Add text values to each cell.

Scalability

The nested-loop annotation pattern below works well for small matrices (up to ~20×20). For larger matrices, annotations become unreadable and the loop becomes slow. For large data, rely on the colorbar and skip cell labels.

1. Basic Annotation

```python fig, ax = plt.subplots() im = ax.imshow(data, cmap='Blues')

for i in range(data.shape[0]): for j in range(data.shape[1]): ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center', fontsize=8)

plt.show() ```

2. Contrast Text Color

```python threshold = data.max() / 2

for i in range(data.shape[0]): for j in range(data.shape[1]): color = 'white' if data[i, j] > threshold else 'black' ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center', color=color, fontsize=8) ```

3. Integer Annotation

```python int_data = np.random.randint(0, 100, (5, 5))

for i in range(int_data.shape[0]): for j in range(int_data.shape[1]): ax.text(j, i, int_data[i, j], ha='center', va='center') ```

Correlation Matrix

A common use case for imshow heatmaps.

1. Compute Correlation

python np.random.seed(42) df_data = np.random.randn(100, 5) corr_matrix = np.corrcoef(df_data.T)

2. Display Correlation Heatmap

```python fig, ax = plt.subplots(figsize=(6, 5))

im = ax.imshow(corr_matrix, cmap='RdBu', vmin=-1, vmax=1)

labels = ['Var A', 'Var B', 'Var C', 'Var D', 'Var E'] ax.set_xticks(np.arange(5)) ax.set_yticks(np.arange(5)) ax.set_xticklabels(labels, rotation=45, ha='right') ax.set_yticklabels(labels)

for i in range(5): for j in range(5): color = 'white' if abs(corr_matrix[i, j]) > 0.5 else 'black' ax.text(j, i, f'{corr_matrix[i, j]:.2f}', ha='center', va='center', color=color)

plt.colorbar(im, label='Correlation') plt.tight_layout() plt.show() ```

3. Mask Upper Triangle

```python mask = np.triu(np.ones_like(corr_matrix, dtype=bool), k=1) masked_corr = np.ma.masked_array(corr_matrix, mask)

ax.imshow(masked_corr, cmap='RdBu', vmin=-1, vmax=1) ```

Interpolation

Control how pixel boundaries are rendered.

1. No Interpolation (Default for Small Data)

python ax.imshow(data, interpolation='nearest')

2. Smooth Interpolation

python ax.imshow(data, interpolation='bilinear')

3. Common Options

```python

'nearest': Sharp pixel boundaries

'bilinear': Smooth linear interpolation

'bicubic': Smoother cubic interpolation

'gaussian': Gaussian smoothing

```


Runnable Example: seaborn_matrix_plots.py

```python """ Tutorial 06: Matrix Plots Heatmaps, cluster maps, correlation matrices Level: Intermediate """ import seaborn as sns import matplotlib.pyplot as plt import pandas as pd import numpy as np

=============================================================================

Main

=============================================================================

if name == "main":

sns.set_style("white")
tips = sns.load_dataset('tips')

# Correlation heatmap
plt.figure(figsize=(8, 6))
numeric_cols = tips.select_dtypes(include=[np.number])
corr = numeric_cols.corr()
sns.heatmap(corr, annot=True, cmap='coolwarm', center=0, square=True, linewidths=1)
plt.title('Correlation Heatmap', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Pivot table heatmap
pivot_data = tips.pivot_table(values='tip', index='day', columns='time', aggfunc='mean')
plt.figure(figsize=(8, 6))
sns.heatmap(pivot_data, annot=True, fmt='.2f', cmap='YlOrRd')
plt.title('Average Tip: Day vs Time', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Clustermap - with hierarchical clustering
plt.figure(figsize=(10, 8))
sns.clustermap(corr, cmap='coolwarm', center=0, linewidths=1, annot=True)
plt.show()

print("Tutorial 06 demonstrates matrix visualizations")
print("Key functions: heatmap(), clustermap()")

```

In Machine Learning

Heatmaps are used for confusion matrices (predicted vs actual class, with cell values = counts) and correlation matrices (feature-to-feature Pearson correlations). Annotating cells with ax.text() and choosing a diverging colormap ('RdBu' centered at 0 for correlations, sequential for counts) makes these instantly readable.


Exercises

Exercise 1. Write code that creates a 5x5 NumPy array of random values and displays it as a heatmap using ax.imshow(). Add a colorbar and set the colormap to 'viridis'.

Solution to Exercise 1

```python import matplotlib.pyplot as plt import numpy as np

np.random.seed(42) data = np.random.rand(5, 5)

fig, ax = plt.subplots() im = ax.imshow(data, cmap='viridis') fig.colorbar(im, ax=ax) ax.set_title('5x5 Random Heatmap') plt.show() ```


Exercise 2. Explain the role of origin='lower' vs origin='upper' in ax.imshow(). When should you use each?

Solution to Exercise 2

origin='upper' (default) places row 0 at the top of the image, matching how matrices and digital images are conventionally displayed. origin='lower' places row 0 at the bottom, matching mathematical coordinate systems where the y-axis increases upward.

Use 'upper' for actual image data and matrices. Use 'lower' for scientific data like KDE, contour-like heatmaps, or any data where the y-axis should increase upward.


Exercise 3. Write code that creates a heatmap of a 10x10 array and adds text annotations showing the value in each cell using a loop with ax.text().

Solution to Exercise 3

```python import matplotlib.pyplot as plt import numpy as np

np.random.seed(42) data = np.random.rand(10, 10)

fig, ax = plt.subplots(figsize=(8, 6)) im = ax.imshow(data, cmap='Blues') fig.colorbar(im, ax=ax)

for i in range(10): for j in range(10): ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center', fontsize=7)

ax.set_title('Annotated Heatmap') plt.tight_layout() plt.show() ```


Exercise 4. Create a heatmap with a custom extent parameter to set meaningful x and y axis ranges instead of pixel indices.

Solution to Exercise 4

```python import matplotlib.pyplot as plt import numpy as np

np.random.seed(42) data = np.random.rand(20, 30)

fig, ax = plt.subplots(figsize=(10, 6)) im = ax.imshow(data, cmap='hot', extent=[0, 6, 0, 4], aspect='auto') ax.set_xlabel('X (meters)') ax.set_ylabel('Y (meters)') ax.set_title('Heatmap with Custom Extent') fig.colorbar(im, ax=ax, label='Intensity') plt.show() ```