2D FFT for Image Processing

Why 2D FFT for Images?

A digital image is a 2D signal: a matrix of pixel intensities. Just as 1D FFT decomposes signals into frequency components, 2D FFT decomposes images into frequency components in both spatial dimensions.

Key insight:

  • Low frequencies in an image represent smooth regions (large areas of similar color or brightness)
  • High frequencies represent edges, textures, and fine details
  • Periodic patterns (like striped noise) show up as concentrated peaks in frequency space
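To make the last point concrete, here is a minimal sketch using a synthetic stripe pattern. The image size and the stripe frequency (`cycles`) are illustrative assumptions, not values from a real image.

```python
import numpy as np

N = 64
cycles = 8  # hypothetical stripe frequency: 8 cycles across the image width

# Vertical stripes: intensity varies along x (columns) only
x = np.arange(N)
stripes = np.tile(np.cos(2 * np.pi * cycles * x / N), (N, 1))

# Magnitude spectrum; essentially all energy sits in two bins
mag = np.abs(np.fft.fft2(stripes))
peak_bins = np.argwhere(mag > 0.5 * mag.max())
print(peak_bins)  # bins (0, 8) and (0, 56), as (row, col) pairs
```

Column 56 is `N - cycles`, the negative-frequency mirror of column 8. The row index is 0 because the pattern does not vary vertically.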

This opens possibilities:

  • Denoising: Suppress high-frequency noise (at some cost to edge sharpness, since edges are also high-frequency)
  • Blur detection: Blurred images contain little high-frequency energy, so the spectrum reveals blur
  • Remove periodic noise: Identify and eliminate repeating patterns
  • Image restoration: Reverse certain degradations
  • Efficient convolution: For large kernels, FFT-based convolution is faster than spatial convolution

2D FFT Fundamentals

Mathematical Definition

The 2D DFT is:

\[X(u, v) = \sum_{m=0}^{M-1} \sum_{n=0}^{N-1} x(m, n) \cdot e^{-j2\pi(um/M + vn/N)}\]

where:

  • \(x(m, n)\) is the pixel intensity at row \(m\), column \(n\)
  • \(X(u, v)\) is the frequency component at frequency \((u, v)\)
  • \(M, N\) are the image dimensions (rows and columns)

In practice: Use np.fft.fft2() or scipy.fft.fft2() instead of implementing the sum manually (scipy.fftpack is SciPy's legacy interface).
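As a sanity check on the definition, the double sum can be evaluated directly with two DFT matrices and compared against the library result. This is a sketch for small arrays only; the helper name `dft2_naive` and the test array size are made up for illustration.

```python
import numpy as np

def dft2_naive(x):
    """Evaluate the 2D DFT sum directly using two DFT matrices."""
    M, N = x.shape
    m = np.arange(M)
    n = np.arange(N)
    # W_M[u, m] = exp(-j 2 pi u m / M), W_N[v, n] = exp(-j 2 pi v n / N)
    W_M = np.exp(-2j * np.pi * np.outer(m, m) / M)
    W_N = np.exp(-2j * np.pi * np.outer(n, n) / N)
    # X(u, v) = sum_m sum_n W_M[u, m] * x(m, n) * W_N[v, n]
    return W_M @ x @ W_N.T

rng = np.random.default_rng(0)
x = rng.random((8, 12))
print(np.allclose(dft2_naive(x), np.fft.fft2(x)))  # True
```

The direct version costs far more operations than the FFT as the image grows, which is why it is only useful for checking your understanding.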

Computing 2D FFT

import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage

# Load or create an image
# For this example, create a simple geometric image
image = np.zeros((256, 256))
image[50:150, 50:150] = 1  # white square

# Compute 2D FFT
X = np.fft.fft2(image)

# Magnitude spectrum
X_mag = np.abs(X)

# Log scale for better visualization
X_mag_log = np.log1p(X_mag)

print(f"Image shape: {image.shape}")
print(f"FFT shape: {X.shape}")
print(f"FFT dtype: {X.dtype}")  # complex128

Interpreting Frequency Components

The output X is a 2D array of complex numbers:

  • Real part: Cosine components
  • Imaginary part: Sine components
  • Magnitude \(|X(u, v)|\): Amplitude at frequency \((u, v)\)
  • Phase \(\angle X(u, v)\): Phase shift

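Frequency layout (without fftshift):

  • Origin \((0, 0)\) (the DC component) is at top-left
  • Along each axis, indices below half the axis length hold non-negative frequencies; the remaining indices hold negative frequencies
  • Low frequencies therefore sit in all four corners, and the highest frequencies sit near the center of the array
  • For real images, each negative-frequency entry is the complex conjugate of its positive-frequency mirror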
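The mapping from array index to frequency along one axis can be inspected with np.fft.fftfreq. The size `N = 8` below is an arbitrary small value chosen so the output is easy to read.

```python
import numpy as np

N = 8  # small hypothetical axis length
freqs = np.fft.fftfreq(N)            # cycles per pixel for each index

# Multiply by N to show frequencies as integer bin numbers
print(freqs * N)                     # [ 0.  1.  2.  3. -4. -3. -2. -1.]
print(np.fft.fftshift(freqs) * N)    # [-4. -3. -2. -1.  0.  1.  2.  3.]
```

The same mapping applies independently to the row axis and the column axis of a 2D spectrum.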

Using fftshift() to Center Zero Frequency

By default, the zero frequency (DC component) is at the top-left corner. To visualize conveniently, center it:

# Shift zero frequency to center
X_shifted = np.fft.fftshift(X)

# Now:
# - Center of image: low frequencies (DC)
# - Edges of image: high frequencies

Complete example:

import numpy as np
import matplotlib.pyplot as plt

# Simple test image: white square on black background
image = np.zeros((128, 128))
image[40:88, 40:88] = 255

# Compute FFT
X = np.fft.fft2(image)

# Magnitude and log scale
X_mag = np.abs(X)
X_mag_log = np.log1p(X_mag)

# Shift for visualization
X_shifted = np.fft.fftshift(X_mag_log)

# Plot
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original image
ax = axes[0]
ax.imshow(image, cmap='gray')
ax.set_title('Original Image')
ax.axis('off')

# FFT magnitude (unshifted)
ax = axes[1]
ax.imshow(X_mag_log, cmap='hot')
ax.set_title('FFT Magnitude (log) - Unshifted')
ax.axis('off')

# FFT magnitude (shifted)
ax = axes[2]
ax.imshow(X_shifted, cmap='hot')
ax.set_title('FFT Magnitude (log) - Shifted')
ax.axis('off')

plt.tight_layout()
plt.show()

Symmetry in Real Images

For real-valued images (typical case), the FFT has Hermitian symmetry:

\[X(-u, -v) = X^*(u, v)\]

This means the negative frequencies contain redundant information (they're complex conjugates of positive frequencies). In some applications, you only need half the FFT output.
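A short sketch that checks the symmetry numerically and shows np.fft.rfft2, which stores only the non-redundant half. The random 6×8 test image is an assumption for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((6, 8))  # hypothetical real-valued test image
X = np.fft.fft2(image)

# X(-u, -v) lives at index ((-u) mod M, (-v) mod N)
M, N = image.shape
u = np.arange(M)[:, None]
v = np.arange(N)[None, :]
X_neg = X[(-u) % M, (-v) % N]
print(np.allclose(X_neg, np.conj(X)))   # True: Hermitian symmetry holds

# rfft2 keeps only N//2 + 1 columns along the last axis
X_half = np.fft.rfft2(image)
print(X.shape, X_half.shape)            # (6, 8) (6, 5)

# The half spectrum is enough to recover the image
print(np.allclose(np.fft.irfft2(X_half, s=image.shape), image))  # True
```

Passing `s=image.shape` to irfft2 matters when the last axis has odd length, since the half spectrum alone does not record it.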

Application: Removing Periodic Noise

Periodic patterns (like scan lines or striped artifacts) create concentrated peaks in frequency space. You can remove them by suppressing those peaks.

Workflow

  1. Compute FFT of the noisy image
  2. Identify noise peaks visually or algorithmically
  3. Create a notch filter (suppress specific frequencies)
  4. Apply the filter in frequency domain
  5. Inverse FFT to get the cleaned image
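Step 2 can be automated in simple cases. The sketch below picks the largest off-center bin of the shifted spectrum; the helper name `find_strongest_peak` and its `dc_radius` parameter are hypothetical, and real images usually need a more careful rule (for example, comparing each bin against its local neighborhood).

```python
import numpy as np

def find_strongest_peak(image, dc_radius=4):
    """Return (row, col) of the largest off-center bin in the shifted spectrum.

    Bins within dc_radius of the center are ignored so that the DC term
    and slow image variation do not win.
    """
    mag = np.abs(np.fft.fftshift(np.fft.fft2(image)))
    M, N = mag.shape
    rows, cols = np.ogrid[:M, :N]
    near_dc = (rows - M // 2) ** 2 + (cols - N // 2) ** 2 <= dc_radius ** 2
    mag[near_dc] = 0
    return np.unravel_index(np.argmax(mag), mag.shape)

# Synthetic check: flat image with a 2-pixel-high bright stripe every 16 rows
image = np.full((256, 256), 100.0)
for i in range(0, 256, 16):
    image[i:i + 2, :] += 50

row, col = find_strongest_peak(image)
print(row - 128, col - 128)  # offset from the center, in frequency bins
```

For this synthetic input the peak lies 16 bins above or below the center (256 / 16 = 16) in the center column, because horizontal stripes vary only along the vertical axis.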

Example: Removing Periodic Noise

import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy.ndimage import gaussian_filter

# Create a clean image (sample): start from a black background
image_clean = np.zeros((256, 256))

# Add realistic content: a circle
y, x = np.ogrid[:256, :256]
mask = (x - 128)**2 + (y - 128)**2 <= 60**2
image_clean[mask] = 200

# Add periodic stripe noise (horizontal)
stripe_period = 16
image_noisy = image_clean.copy()
for i in range(0, 256, stripe_period):
    image_noisy[i:i+2, :] += 50

image_noisy = np.clip(image_noisy, 0, 255)

# Compute FFT
X = np.fft.fft2(image_noisy)
X_shifted = np.fft.fftshift(X)

# Visualize FFT to identify noise peaks
X_mag = np.log1p(np.abs(X_shifted))

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Row 1: Original and FFT
ax = axes[0, 0]
ax.imshow(image_noisy, cmap='gray')
ax.set_title('Noisy Image (stripe pattern)')

ax = axes[0, 1]
ax.imshow(X_mag, cmap='hot')
ax.set_title('FFT Magnitude (log)')

# Create notch filter to suppress stripe noise
# Stripes repeat every 16 pixels → peaks at multiples of 256/16 = 16
# Real-valued filter so it can be displayed with imshow
notch_filter = np.ones(X_shifted.shape)

# Horizontal stripes vary along y, so their peaks lie on the vertical
# frequency axis: (center_y ± k * freq_stripe, center_x)
freq_stripe = 256 // stripe_period
center = 128
sigma = 2.0  # notch width in frequency bins

# Gaussian notches: 0 at each noise peak, rising back to 1 away from it
rows, cols = np.ogrid[:256, :256]
for k in range(1, center // freq_stripe + 1):
    for dy in [-k * freq_stripe, k * freq_stripe]:
        y_notch = center + dy
        if 0 <= y_notch < 256:
            dist = (rows - y_notch)**2 + (cols - center)**2
            notch_filter *= 1 - np.exp(-dist / (2 * sigma**2))

ax = axes[0, 2]
ax.imshow(notch_filter, cmap='gray')
ax.set_title('Notch Filter (suppresses peaks)')

# Apply filter in frequency domain
X_filtered = X_shifted * notch_filter

# Inverse FFT
X_filtered_unshifted = np.fft.ifftshift(X_filtered)
image_restored = np.fft.ifft2(X_filtered_unshifted).real

# Clip to valid range
image_restored = np.clip(image_restored, 0, 255)

# Visualize results
ax = axes[1, 0]
ax.imshow(image_restored, cmap='gray')
ax.set_title('Restored Image')

ax = axes[1, 1]
ax.imshow(np.log1p(np.abs(X_filtered)), cmap='hot')
ax.set_title('Filtered FFT')

# Difference
ax = axes[1, 2]
diff = np.abs(image_noisy - image_restored)
ax.imshow(diff, cmap='hot')
ax.set_title('Removed Noise (difference)')

for ax in axes.flat:
    ax.axis('off')

plt.suptitle('Periodic Noise Removal via Frequency Domain')
plt.tight_layout()
plt.show()

# Quantify improvement
noise_original = np.mean((image_noisy - image_clean)**2)
noise_restored = np.mean((image_restored - image_clean)**2)

print(f"Original MSE: {noise_original:.1f}")
print(f"Restored MSE: {noise_restored:.1f}")
print(f"Improvement: {noise_original / noise_restored:.1f}x")

More Sophisticated Filters

  • Butterworth filter: Smooth roll-off (avoids sharp artifacts)
  • Band-pass and band-reject filters: Keep or suppress specific frequency bands
  • Wiener filter: Optimal restoration for known noise statistics
  • For real applications: Use scikit-image (restoration module)
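As an illustration of the first item, here is a sketch of a Butterworth notch-reject filter using the textbook form 1 / (1 + (D0 / D)^(2n)) around each notch. The function name, the default radius `d0`, and the default `order` are assumptions to be tuned per image.

```python
import numpy as np

def butterworth_notch_reject(shape, notch_centers, d0=3.0, order=2):
    """Butterworth notch-reject filter for an fftshift-ed spectrum.

    notch_centers: list of (row, col) peak positions in the shifted spectrum.
    Each notch is mirrored through the center so the filter stays symmetric
    and the filtered image stays real.
    """
    M, N = shape
    rows, cols = np.ogrid[:M, :N]
    H = np.ones(shape)
    for r, c in notch_centers:
        mirrored = (2 * (M // 2) - r, 2 * (N // 2) - c)
        for rr, cc in ((r, c), mirrored):
            if not (0 <= rr < M and 0 <= cc < N):
                continue
            d = np.sqrt((rows - rr) ** 2 + (cols - cc) ** 2)
            d = np.maximum(d, 1e-12)  # avoid division by zero at the notch center
            H *= 1.0 / (1.0 + (d0 / d) ** (2 * order))
    return H

# Example: notch at 16 bins above the center of a 256×256 spectrum
H = butterworth_notch_reject((256, 256), [(112, 128)])
print(H[112, 128], H[144, 128], H[0, 0])
```

The filter is close to 0 at the notch and its mirror and close to 1 far away. Multiply it with the shifted spectrum in the same way as a Gaussian notch.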

FFT-Based Convolution

For large kernels, FFT-based convolution is faster than spatial domain convolution:

\[y = \text{IFFT}(\text{FFT}(x) \cdot \text{FFT}(h))\]
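Taken literally, with no padding, this identity yields circular convolution, where indices wrap around the array edges. The sketch below checks that on a tiny example; the array sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 5))   # hypothetical tiny "image"
h = rng.random((4, 5))   # kernel of the same size (no padding)

# Direct circular convolution: indices wrap modulo the array size
M, N = x.shape
y_direct = np.zeros((M, N))
for m in range(M):
    for n in range(N):
        for k in range(M):
            for l in range(N):
                y_direct[m, n] += h[k, l] * x[(m - k) % M, (n - l) % N]

# Same result via the frequency domain
y_fft = np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(h)).real
print(np.allclose(y_direct, y_fft))  # True
```

The four nested loops make the cost of the direct approach visible: every output pixel touches every kernel element.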

When to Use FFT Convolution

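  • Large kernels (roughly 10×10 pixels and up): FFT is usually faster, though the exact crossover depends on image size and implementation
  • Small kernels (3×3, 5×5): Spatial convolution is faster
  • Repeated convolution with the same kernel: Compute the kernel's FFT once and reuse it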
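Rather than relying on a fixed threshold, you can ask SciPy for its own estimate. scipy.signal.choose_conv_method predicts which method scipy.signal.convolve would pick; the image and kernel sizes below are arbitrary, and the answer may differ between machines and SciPy versions.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
image = rng.random((256, 256))

for size in (3, 9, 31):
    kernel = rng.random((size, size))
    # Returns the string 'fft' or 'direct'
    method = signal.choose_conv_method(image, kernel, mode='same')
    print(f"{size}x{size} kernel -> {method}")
```

Passing `measure=True` makes the function time both methods instead of estimating, at the cost of actually running them.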

Example: Blur with FFT

import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt

# Create test image (grid of 2×2 bright dots, one every 8 pixels)
image = np.zeros((256, 256))
image[::8, ::8] = 1
image[1::8, 1::8] = 1
image[::8, 1::8] = 1
image[1::8, ::8] = 1

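# Define blur kernel (Gaussian)
# Filtering a unit impulse yields the Gaussian itself; filtering an
# array of ones would just return ones (a box kernel)
kernel_size = 31
impulse = np.zeros((kernel_size, kernel_size))
impulse[kernel_size // 2, kernel_size // 2] = 1.0
kernel = ndimage.gaussian_filter(impulse, sigma=5, mode='constant')
kernel /= kernel.sum()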

print(f"Image shape: {image.shape}")
print(f"Kernel shape: {kernel.shape}")

# Method 1: Spatial convolution
# mode='constant' zero-fills outside the image, matching the zero-padded FFT below
result_spatial = ndimage.convolve(image, kernel, mode='constant', cval=0.0)

# Method 2: FFT convolution
# Need to zero-pad to avoid circular convolution
M, N = image.shape
Mk, Nk = kernel.shape

# Pad to M + Mk - 1, N + Nk - 1
output_shape = (M + Mk - 1, N + Nk - 1)

# Pad image and kernel
image_padded = np.pad(image, ((0, Mk-1), (0, Nk-1)), mode='constant')
kernel_padded = np.pad(kernel, ((0, M-1), (0, N-1)), mode='constant')

# FFT-based convolution
X = np.fft.fft2(image_padded)
H = np.fft.fft2(kernel_padded)
Y = X * H
result_fft = np.fft.ifft2(Y).real

# Crop the centered M×N region of the full result so it lines up with
# ndimage.convolve, which centers the (odd-sized) kernel on each pixel
result_fft = result_fft[Mk//2:Mk//2 + M, Nk//2:Nk//2 + N]

# Compare
fig, axes = plt.subplots(2, 2, figsize=(12, 12))

ax = axes[0, 0]
ax.imshow(image, cmap='gray')
ax.set_title('Original Image')
ax.axis('off')

ax = axes[0, 1]
ax.imshow(kernel, cmap='gray')
ax.set_title('Blur Kernel')
ax.axis('off')

ax = axes[1, 0]
ax.imshow(result_spatial, cmap='gray')
ax.set_title('Spatial Convolution')
ax.axis('off')

ax = axes[1, 1]
ax.imshow(result_fft, cmap='gray')
ax.set_title('FFT Convolution')
ax.axis('off')

plt.tight_layout()
plt.show()

# Verify they produce the same result
difference = np.max(np.abs(result_spatial - result_fft))
print(f"Max difference: {difference:.2e} (should be ~0)")
print(f"Results match: {np.allclose(result_spatial, result_fft, atol=1e-10)}")

Performance Comparison

import time
import numpy as np
from scipy import ndimage

# Benchmark
image = np.random.rand(512, 512)
kernel = np.random.rand(64, 64)
kernel /= kernel.sum()

# Time spatial convolution
start = time.perf_counter()
for _ in range(10):
    _ = ndimage.convolve(image, kernel)
time_spatial = time.perf_counter() - start

# Time FFT convolution
start = time.perf_counter()
for _ in range(10):
    M, N = image.shape
    Mk, Nk = kernel.shape
    image_padded = np.pad(image, ((0, Mk-1), (0, Nk-1)))
    kernel_padded = np.pad(kernel, ((0, M-1), (0, N-1)))
    X = np.fft.fft2(image_padded)
    H = np.fft.fft2(kernel_padded)
    Y = X * H
    _ = np.fft.ifft2(Y).real[:M, :N]
time_fft = time.perf_counter() - start

print(f"Spatial convolution: {time_spatial:.3f} sec")
print(f"FFT convolution: {time_fft:.3f} sec")
print(f"Speedup: {time_spatial / time_fft:.1f}x")

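The measured speedup depends on hardware, image size, and library versions, so run the benchmark on your own machine. At this kernel size the FFT path is typically several times faster, and the gap widens as the kernel grows.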

Real-World FFT Convolution

SciPy provides optimized versions:

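from scipy.signal import fftconvolve
result = fftconvolve(image, kernel, mode='same')

This handles zero-padding and cropping automatically. The mode argument ('full', 'valid', or 'same') selects the output size; values outside the image are always treated as zero.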
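A quick way to convince yourself that fftconvolve agrees with spatial convolution is to compare it against ndimage.convolve with zero-filled borders. The sizes below are arbitrary; an odd-sized kernel is assumed so that both functions center it the same way.

```python
import numpy as np
from scipy import ndimage
from scipy.signal import fftconvolve

rng = np.random.default_rng(0)
image = rng.random((64, 64))
kernel = rng.random((7, 7))      # odd-sized, so "centered" is unambiguous
kernel /= kernel.sum()

result_fft = fftconvolve(image, kernel, mode='same')
# mode='constant' with cval=0 matches fftconvolve's zero boundary
result_spatial = ndimage.convolve(image, kernel, mode='constant', cval=0.0)

print(np.allclose(result_fft, result_spatial))  # True
```

With ndimage.convolve's default mode ('reflect'), the two results would differ near the borders.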

Zero-Padding Considerations

Zero-padding affects FFT behavior in important ways:

1. Avoiding Circular Convolution

Multiplying two DFTs computes a circular convolution: the DFT treats the signal as periodic, so the kernel wraps around the image edges. To get linear convolution, zero-pad both arrays first:

# Without padding: circular convolution
X = np.fft.fft2(image)
H = np.fft.fft2(kernel)
Y = X * H
result_circular = np.fft.ifft2(Y).real

# With padding: linear convolution
M, N = image.shape
Mk, Nk = kernel.shape
image_padded = np.pad(image, ((0, Mk-1), (0, Nk-1)))
kernel_padded = np.pad(kernel, ((0, M-1), (0, N-1)))
X = np.fft.fft2(image_padded)
H = np.fft.fft2(kernel_padded)
Y = X * H
result_linear = np.fft.ifft2(Y).real[:M, :N]
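The wrap-around is easiest to see in one dimension. The sketch below uses a made-up four-sample signal and a two-tap kernel.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
h = np.array([1.0, 1.0, 0.0, 0.0])   # two-tap kernel, zero-filled to len(x)

# Circular: the last sample wraps around into the first output
circular = np.fft.ifft(np.fft.fft(x) * np.fft.fft(h)).real
print(np.round(circular, 6))          # [5. 3. 5. 7.]

# Linear: pad both to len(x) + 2 - 1 = 5 samples before transforming
L = len(x) + 2 - 1
linear = np.fft.ifft(np.fft.fft(x, n=L) * np.fft.fft(h[:2], n=L)).real
print(np.round(linear, 6))            # [1. 3. 5. 7. 4.]
```

The first circular output is 1 + 4 = 5 because x[3] wraps around; the linear result starts at 1 and is one sample longer.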

2. Frequency Resolution

More zero-padding gives a finer frequency grid (smaller bin spacing), but it only interpolates the spectrum. It does not add information or make nearby frequencies easier to separate:

# Image: 128×128 pixels
image = np.random.rand(128, 128)

# No padding
X1 = np.fft.fft2(image)
freqs1 = np.fft.fftfreq(128, 1.0)

# 2x padding (256×256)
image_padded = np.pad(image, ((0, 128), (0, 128)))
X2 = np.fft.fft2(image_padded)
freqs2 = np.fft.fftfreq(256, 1.0)

print(f"Frequency bin spacing without padding: {1.0/128:.4f}")
print(f"Frequency bin spacing with 2x padding: {1.0/256:.4f}")
print("More padding = finer frequency grid (interpolation, not more info)")
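One way to see that padding only interpolates: with 2x padding, every second bin of the padded spectrum equals a bin of the unpadded spectrum. A small sketch with an arbitrary 16×16 random image:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((16, 16))

X1 = np.fft.fft2(image)
X2 = np.fft.fft2(image, s=(32, 32))  # s= zero-pads at the end of each axis

# Even-indexed bins of the padded spectrum reproduce the original bins
print(np.allclose(X2[::2, ::2], X1))  # True
```

The odd-indexed bins are new samples of the same underlying spectrum, taken between the original ones.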

3. Edge Effects

Abrupt image boundaries create high-frequency artifacts. Options:

  • Zero-padding: Simple but introduces discontinuity
  • Mirroring: Reflects image edges
  • Periodic extension: Assumes the image repeats (FFT default)

image = np.random.rand(64, 64)

# Pad with different modes
padded_zeros = np.pad(image, 32, mode='constant', constant_values=0)
padded_reflect = np.pad(image, 32, mode='reflect')
padded_wrap = np.pad(image, 32, mode='wrap')

# Compare frequency content
X_zeros = np.log1p(np.abs(np.fft.fft2(padded_zeros)))
X_reflect = np.log1p(np.abs(np.fft.fft2(padded_reflect)))

# Reflection typically gives lower high-frequency artifacts
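To put a number on the comparison, one option is to measure how much spectral energy lies above some cutoff frequency. The helper `high_freq_fraction`, the cutoff of 0.25 cycles per pixel, and the smooth ramp test image are all assumptions chosen for illustration.

```python
import numpy as np

def high_freq_fraction(img, cutoff=0.25):
    """Fraction of non-DC spectral energy above `cutoff` cycles/pixel."""
    power = np.abs(np.fft.fft2(img)) ** 2
    power[0, 0] = 0.0  # ignore the mean
    fy = np.fft.fftfreq(img.shape[0])[:, None]
    fx = np.fft.fftfreq(img.shape[1])[None, :]
    high = np.sqrt(fy ** 2 + fx ** 2) > cutoff
    return power[high].sum() / power.sum()

# Smooth hypothetical test image: horizontal brightness ramp from 1 to 2
image = np.tile(np.linspace(1.0, 2.0, 64), (64, 1))
padded_zeros = np.pad(image, 32, mode='constant', constant_values=0)
padded_reflect = np.pad(image, 32, mode='reflect')

print(f"zeros:   {high_freq_fraction(padded_zeros):.6f}")
print(f"reflect: {high_freq_fraction(padded_reflect):.6f}")
```

For this smooth image the reflect-padded version has the smaller high-frequency fraction, because zero-padding adds a sharp step at the border. On noisy or highly textured images the difference is less pronounced.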

Complete Image Denoising Pipeline

import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt

def denoise_fft(image, threshold_percentile=90):
    """Denoise image by suppressing low-magnitude frequency components."""

    # Compute FFT
    X = np.fft.fft2(image)

    # Magnitude spectrum
    X_mag = np.abs(X)

    # Threshold: keep only the strongest (100 - threshold_percentile)% of components
    threshold = np.percentile(X_mag, threshold_percentile)

    # Create binary mask
    mask = X_mag > threshold

    # Apply mask
    X_filtered = X * mask

    # Inverse FFT
    image_denoised = np.fft.ifft2(X_filtered).real

    return image_denoised, X, mask

# Test
image = ndimage.gaussian_filter(np.random.rand(128, 128), sigma=2)
image_noisy = image + 0.3 * np.random.randn(128, 128)

image_denoised, X, mask = denoise_fft(image_noisy, threshold_percentile=85)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].imshow(image_noisy, cmap='gray')
axes[0].set_title('Noisy Image')

axes[1].imshow(np.log1p(np.abs(X)), cmap='hot')
axes[1].set_title('FFT Magnitude')

axes[2].imshow(image_denoised, cmap='gray')
axes[2].set_title('Denoised (FFT threshold)')

for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

Summary

  • 2D FFT decomposes images into frequency components
  • Low frequencies = smooth regions; high frequencies = edges and noise
  • fftshift() centers zero frequency for intuitive visualization
  • Frequency domain filtering can remove periodic noise effectively
  • FFT convolution is faster for large kernels (roughly >10×10; the exact crossover varies)
  • Zero-padding avoids circular convolution; reflect padding can reduce edge artifacts
  • Real-world applications use SciPy (fftconvolve, ndimage) for reliability

Next steps:

  • Explore the convolution theorem for deeper frequency domain theory
  • Study image restoration techniques (Wiener filtering, etc.)
  • Apply to computer vision tasks: edge detection, blur analysis, feature extraction