Skip to content

Image I/O

This document covers reading, loading, and displaying images in matplotlib.

Reading Web Images

Load images from URLs directly into NumPy arrays for visualization and manipulation.

PIL and urllib Approach

import matplotlib.pyplot as plt
import numpy as np
import PIL
import urllib

def main():
    url = "https://upload.wikimedia.org/wikipedia/en/4/43/Pok%C3%A9mon_Mewtwo_art.png"
    img = np.array(PIL.Image.open(urllib.request.urlopen(url)))
    print(f"{type(img) = }, {img.shape = }, {img.dtype = }")

    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis('off')
    plt.show()

if __name__ == "__main__":
    main()

Understanding Image Arrays

import numpy as np
import PIL
import urllib

url = "https://upload.wikimedia.org/wikipedia/en/4/43/Pok%C3%A9mon_Mewtwo_art.png"
img = np.array(PIL.Image.open(urllib.request.urlopen(url)))

print(f"Shape: {img.shape}")      # (height, width, channels)
print(f"Dtype: {img.dtype}")      # uint8 (0-255)
print(f"Min value: {img.min()}")  # 0
print(f"Max value: {img.max()}")  # 255

Channel Interpretation

Channels Format Description
1 Grayscale Single intensity value
3 RGB Red, Green, Blue
4 RGBA RGB + Alpha (transparency)

Alternative: requests Library

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
from io import BytesIO

def load_image_requests(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    return np.array(img)

url = "https://upload.wikimedia.org/wikipedia/en/4/43/Pok%C3%A9mon_Mewtwo_art.png"
img = load_image_requests(url)

fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
plt.show()

Error Handling

import matplotlib.pyplot as plt
import numpy as np
import PIL
import urllib
from urllib.error import URLError, HTTPError

def safe_load_image(url):
    try:
        img = np.array(PIL.Image.open(urllib.request.urlopen(url)))
        return img
    except HTTPError as e:
        print(f"HTTP Error: {e.code}")
        return None
    except URLError as e:
        print(f"URL Error: {e.reason}")
        return None
    except Exception as e:
        print(f"Error: {e}")
        return None

url = "https://upload.wikimedia.org/wikipedia/en/4/43/Pok%C3%A9mon_Mewtwo_art.png"
img = safe_load_image(url)

if img is not None:
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis('off')
    plt.show()

imread - Reading Local Images

The plt.imread() function reads image files from disk into NumPy arrays.

Basic Usage

import matplotlib.pyplot as plt

def main():
    img = plt.imread('img/mewtwo.jpg')

    fig, ax = plt.subplots()
    ax.imshow(img)
    plt.show()

if __name__ == "__main__":
    main()

Supported Formats

Format Extension Notes
PNG .png Lossless, supports transparency
JPEG .jpg, .jpeg Lossy compression
GIF .gif Limited colors
TIFF .tiff, .tif High quality
BMP .bmp Uncompressed

Format-Specific Behavior

import matplotlib.pyplot as plt

# PNG: Returns float32 (0.0-1.0) or uint8 (0-255)
img_png = plt.imread('image.png')
print(f"PNG dtype: {img_png.dtype}")

# JPEG: Returns uint8 (0-255)
img_jpg = plt.imread('image.jpg')
print(f"JPEG dtype: {img_jpg.dtype}")

Batch Loading from Directory

import matplotlib.pyplot as plt
import os
import glob

def load_images_from_directory(directory, extension='*.jpg'):
    images = []
    filenames = []
    for filepath in glob.glob(os.path.join(directory, extension)):
        img = plt.imread(filepath)
        images.append(img)
        filenames.append(os.path.basename(filepath))
    return images, filenames

images, names = load_images_from_directory('img/')
print(f"Loaded {len(images)} images")

Comparison: imread vs PIL

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

filepath = 'img/mewtwo.jpg'

# Using plt.imread
img_mpl = plt.imread(filepath)

# Using PIL
img_pil = np.array(Image.open(filepath))

print(f"plt.imread shape: {img_mpl.shape}, dtype: {img_mpl.dtype}")
print(f"PIL shape: {img_pil.shape}, dtype: {img_pil.dtype}")

imshow - Displaying Images

The ax.imshow() method displays image data on an Axes.

Basic Usage

import matplotlib.pyplot as plt
import numpy as np

img = np.random.rand(100, 100, 3)

fig, ax = plt.subplots()
ax.imshow(img)
plt.show()

Image Data Types

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# Float array (0-1)
img_float = np.random.rand(50, 50, 3)
axes[0].imshow(img_float)
axes[0].set_title('Float [0, 1]')

# Uint8 array (0-255)
img_uint8 = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8)
axes[1].imshow(img_uint8)
axes[1].set_title('Uint8 [0, 255]')

# Grayscale (2D array)
img_gray = np.random.rand(50, 50)
axes[2].imshow(img_gray, cmap='gray')
axes[2].set_title('Grayscale')

for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()

Key Parameters

Parameter Description Default
X Image data (array-like) Required
cmap Colormap None
aspect Aspect ratio 'equal'
interpolation Interpolation method 'antialiased'
alpha Transparency None
vmin, vmax Value range Data min/max
origin Origin position 'upper'

Aspect Ratio

import matplotlib.pyplot as plt
import numpy as np

img = np.random.rand(50, 100)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

axes[0].imshow(img, aspect='equal')
axes[0].set_title("aspect='equal'")

axes[1].imshow(img, aspect='auto')
axes[1].set_title("aspect='auto'")

axes[2].imshow(img, aspect=0.5)
axes[2].set_title("aspect=0.5")

plt.tight_layout()
plt.show()

Interpolation Methods

import matplotlib.pyplot as plt
import numpy as np

img = np.random.rand(10, 10)
methods = ['nearest', 'bilinear', 'bicubic', 'spline16']

fig, axes = plt.subplots(1, 4, figsize=(12, 3))

for ax, method in zip(axes, methods):
    ax.imshow(img, interpolation=method, cmap='viridis')
    ax.set_title(method)
    ax.axis('off')

plt.tight_layout()
plt.show()

Origin Position

import matplotlib.pyplot as plt
import numpy as np

img = np.arange(25).reshape(5, 5)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

axes[0].imshow(img, origin='upper')
axes[0].set_title("origin='upper' (default)")

axes[1].imshow(img, origin='lower')
axes[1].set_title("origin='lower'")

plt.tight_layout()
plt.show()

Adding Colorbar

import matplotlib.pyplot as plt
import numpy as np

img = np.random.rand(50, 50)

fig, ax = plt.subplots()
im = ax.imshow(img, cmap='viridis')
fig.colorbar(im, ax=ax)
plt.show()

Deep Learning Examples

FashionMNIST (PyTorch)

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

def main():
    training_data = datasets.FashionMNIST(
        root="data", train=True, download=True, transform=ToTensor()
    )

    labels_map = {
        0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
        5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot",
    }

    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

    fig, axes = plt.subplots(1, 10, figsize=(15, 5))

    for imgs, labels in train_dataloader:
        for i, (img, label) in enumerate(zip(imgs, labels)):
            axes[i].imshow(img.squeeze(), cmap='binary')
            axes[i].set_title(labels_map[label.item()])
            axes[i].axis('off')
            if i == 9:
                break
        break

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()

MNIST (TensorFlow)

import matplotlib.pyplot as plt
import tensorflow as tf

def main():
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    fig, axes = plt.subplots(nrows=2, ncols=10, figsize=(10, 2))
    for i in range(2):
        for j in range(10):
            axes[i, j].imshow(x_train[i*10+j], cmap=plt.cm.gray)
            axes[i, j].set_title(f'Label {y_train[i*10+j]}')
            axes[i, j].axis('off')
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()

PyTorch Tensor Handling

import matplotlib.pyplot as plt
import torch

# PyTorch: (C, H, W) -> (H, W, C)
tensor = torch.rand(3, 64, 64)
img = tensor.permute(1, 2, 0).numpy()

fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
plt.show()

import matplotlib.pyplot as plt
import os

def create_image_gallery(image_dir, ncols=4):
    files = [f for f in os.listdir(image_dir) 
             if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    nrows = (len(files) + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
    axes = axes.flatten() if nrows > 1 else [axes] if ncols == 1 else axes

    for ax, filename in zip(axes, files):
        img = plt.imread(os.path.join(image_dir, filename))
        ax.imshow(img)
        ax.set_title(filename, fontsize=8)
        ax.axis('off')

    # Hide empty subplots
    for ax in axes[len(files):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

create_image_gallery('img/')