2D FFT for Image Processing¶
Why 2D FFT for Images?¶
A digital image is a 2D signal: a matrix of pixel intensities. Just as 1D FFT decomposes signals into frequency components, 2D FFT decomposes images into frequency components in both spatial dimensions.
Key insight:
- Low frequencies in an image represent smooth regions (large color/brightness areas)
- High frequencies represent edges, textures, and fine details
- Periodic patterns (like striped noise) show up as concentrated peaks in frequency space
This opens possibilities:
- Denoising: Suppress high-frequency noise while keeping as much edge detail as possible (edges are high-frequency content too)
- Blur detection: Analyze how much high-frequency content the image contains
- Remove periodic noise: Identify and eliminate repeating patterns
- Image restoration: Reverse certain degradations
- Efficient convolution: For large kernels, FFT-based convolution is usually faster than spatial convolution
2D FFT Fundamentals¶
Mathematical Definition¶
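The 2D DFT is:
\[ X(u, v) = \sum_{m=0}^{M-1} \sum_{n=0}^{N-1} x(m, n) \, e^{-j 2\pi \left( \frac{um}{M} + \frac{vn}{N} \right)} \]
where:
- \(x(m, n)\) is the pixel intensity at row \(m\), column \(n\)
- \(X(u, v)\) is the frequency component at frequency \((u, v)\)
- \(M, N\) are the image dimensions (rows, columns)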
In practice: Use np.fft.fft2() or scipy.fft.fft2() instead of implementing the sum manually. (scipy.fftpack is a legacy module; new code should prefer scipy.fft.)
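To make the definition concrete, here is a minimal sketch that evaluates the double sum directly (written as two matrix products) and compares it with np.fft.fft2(). The helper name dft2_naive is made up for illustration, and the approach is far too slow for real images.

```python
import numpy as np

def dft2_naive(x):
    """Evaluate the 2D DFT definition directly (illustrative helper, slow)."""
    M, N = x.shape
    m = np.arange(M)
    n = np.arange(N)
    # W_M[u, m] = exp(-2j*pi*u*m/M) and W_N[v, n] = exp(-2j*pi*v*n/N)
    W_M = np.exp(-2j * np.pi * np.outer(m, m) / M)
    W_N = np.exp(-2j * np.pi * np.outer(n, n) / N)
    # X[u, v] = sum_m sum_n W_M[u, m] * x[m, n] * W_N[v, n]
    return W_M @ x @ W_N.T

rng = np.random.default_rng(0)
small = rng.random((8, 6))
matches = np.allclose(dft2_naive(small), np.fft.fft2(small))
print(f"Direct sum matches np.fft.fft2: {matches}")
```

The agreement also confirms that NumPy applies no normalization factor in the forward transform; the 1/(MN) factor is applied by the inverse.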
Computing 2D FFT¶
import numpy as np
# Load or create an image
# For this example, create a simple geometric image
image = np.zeros((256, 256))
image[50:150, 50:150] = 1 # white square
# Compute 2D FFT
X = np.fft.fft2(image)
# Magnitude spectrum
X_mag = np.abs(X)
# Log scale for better visualization
X_mag_log = np.log1p(X_mag)
print(f"Image shape: {image.shape}")
print(f"FFT shape: {X.shape}")
print(f"FFT dtype: {X.dtype}") # complex128
Interpreting Frequency Components¶
The output X is a 2D array of complex numbers:
- Real part: Cosine components
- Imaginary part: Sine components
- Magnitude \(|X(u, v)|\): Amplitude at frequency \((u, v)\)
- Phase \(\angle X(u, v)\): Phase shift
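Frequency layout (without fftshift):
- Origin \((0, 0)\) (the DC component) is at the top-left
- Along each axis, indices below half the axis length hold non-negative frequencies; the remaining indices hold negative frequencies
- Low frequencies therefore sit in the four corners, and the highest frequencies sit near the center of the array
- For real images, each negative-frequency coefficient is the complex conjugate of its positive-frequency counterpart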
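The mapping from array index to frequency can be inspected with np.fft.fftfreq(). The short sketch below uses an 8-sample axis (an arbitrary size chosen for readability) and prints the frequency, in cycles per pixel, stored at each index before and after shifting.

```python
import numpy as np

N = 8
# Frequency (cycles per pixel) held at each index of an unshifted FFT axis
freqs = np.fft.fftfreq(N)
print("unshifted:", freqs)

# After fftshift the axis runs monotonically from negative to positive
shifted = np.fft.fftshift(freqs)
print("shifted:  ", shifted)
```

Index 0 holds frequency 0, the first half of the axis holds positive frequencies, and the second half holds negative ones, which is why the unshifted spectrum has its low frequencies in the corners.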
Using fftshift() to Center Zero Frequency¶
By default, the zero frequency (DC component) is at the top-left corner. To visualize conveniently, center it:
# Shift zero frequency to center
X_shifted = np.fft.fftshift(X)
# Now:
# - Center of the spectrum: low frequencies (DC)
# - Edges of the spectrum: high frequencies
Complete example:
import numpy as np
import matplotlib.pyplot as plt
# Simple test image: white square on black background
image = np.zeros((128, 128))
image[40:88, 40:88] = 255
# Compute FFT
X = np.fft.fft2(image)
# Magnitude and log scale
X_mag = np.abs(X)
X_mag_log = np.log1p(X_mag)
# Shift for visualization
X_shifted = np.fft.fftshift(X_mag_log)
# Plot
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# Original image
ax = axes[0]
ax.imshow(image, cmap='gray')
ax.set_title('Original Image')
ax.axis('off')
# FFT magnitude (unshifted)
ax = axes[1]
ax.imshow(X_mag_log, cmap='hot')
ax.set_title('FFT Magnitude (log) - Unshifted')
ax.axis('off')
# FFT magnitude (shifted)
ax = axes[2]
ax.imshow(X_shifted, cmap='hot')
ax.set_title('FFT Magnitude (log) - Shifted')
ax.axis('off')
plt.tight_layout()
plt.show()
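One detail worth checking when moving back from a shifted spectrum: np.fft.ifftshift() is the exact inverse of np.fft.fftshift(), and the two are not interchangeable for odd sizes. The sketch below uses small index arrays (sizes 6 and 7, chosen arbitrarily) to show where the DC bin lands.

```python
import numpy as np

for size in (6, 7):
    x = np.arange(size)                  # value 0 marks the DC bin
    shifted = np.fft.fftshift(x)
    dc_position = int(np.where(shifted == 0)[0][0])
    print(f"size={size}: shifted={shifted}, DC at index {dc_position}")
    # ifftshift undoes fftshift for both even and odd sizes
    assert np.array_equal(np.fft.ifftshift(shifted), x)

# For odd sizes, applying fftshift twice does NOT restore the input
odd = np.arange(7)
same_for_odd = np.array_equal(np.fft.fftshift(odd), np.fft.ifftshift(odd))
print(f"fftshift == ifftshift for odd size: {same_for_odd}")
```

In both cases the DC bin ends up at index `size // 2`, which is the "center" used when building filters on a shifted spectrum.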
Symmetry in Real Images
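For real-valued images (the typical case), the FFT has Hermitian symmetry:
\[ X(-u, -v) = X^*(u, v) \]
Indices are taken modulo the image dimensions, so \(X(-u, -v)\) is stored at index \(((M - u) \bmod M, (N - v) \bmod N)\). This means the negative frequencies contain redundant information (they're complex conjugates of the positive frequencies). In some applications you only need half the FFT output, which is what np.fft.rfft2() returns.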
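The symmetry is easy to verify numerically. The sketch below checks one arbitrarily chosen coefficient pair on a random 8×8 image and then shows that np.fft.rfft2() keeps only the non-redundant half along the last axis.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((8, 8))
M, N = img.shape
X = np.fft.fft2(img)

# X(-u, -v) is stored at index ((-u) % M, (-v) % N) in the unshifted layout
u, v = 2, 3
is_conjugate = bool(np.isclose(X[(-u) % M, (-v) % N], np.conj(X[u, v])))
print(f"X(-u, -v) == conj(X(u, v)): {is_conjugate}")

# rfft2 stores N // 2 + 1 columns instead of N
X_half = np.fft.rfft2(img)
print(f"fft2 shape: {X.shape}, rfft2 shape: {X_half.shape}")

# The half spectrum is enough to reconstruct the image
roundtrip_ok = np.allclose(np.fft.irfft2(X_half, s=img.shape), img)
print(f"irfft2 round trip OK: {roundtrip_ok}")
```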
Application: Removing Periodic Noise¶
Periodic patterns (like scan lines or striped artifacts) create concentrated peaks in frequency space. You can remove them by suppressing those peaks.
Workflow¶
- Compute FFT of the noisy image
- Identify noise peaks visually or algorithmically
- Create a notch filter (suppress specific frequencies)
- Apply the filter in frequency domain
- Inverse FFT to get the cleaned image
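For the "algorithmically" part of step 2, one simple heuristic is to flag spectrum bins that are far stronger than the typical bin, after excluding the region around DC where ordinary image content dominates. The sketch below is an assumption-laden illustration: the helper name find_spectral_peaks, the DC exclusion radius, and the threshold ratio are all made up for this example and would need tuning on real data.

```python
import numpy as np

def find_spectral_peaks(image, dc_radius=4, ratio=20.0):
    """Return (row, col) indices in the shifted spectrum whose magnitude
    exceeds `ratio` times the median magnitude, ignoring a disc of radius
    `dc_radius` around DC. Hypothetical helper with illustrative defaults."""
    mag = np.abs(np.fft.fftshift(np.fft.fft2(image)))
    M, N = image.shape
    yy, xx = np.ogrid[:M, :N]
    outside_dc = (yy - M // 2) ** 2 + (xx - N // 2) ** 2 > dc_radius ** 2
    threshold = ratio * np.median(mag[outside_dc])
    return np.argwhere(outside_dc & (mag > threshold))

# Synthetic test: random texture plus horizontal stripes, 8 cycles over 64 rows
rng = np.random.default_rng(0)
rows = np.arange(64)[:, None]
stripes = 10 * np.sin(2 * np.pi * 8 * rows / 64) * np.ones((1, 64))
noisy = rng.normal(size=(64, 64)) + stripes

peaks = find_spectral_peaks(noisy)
print(peaks)
```

The detected bins sit 8 rows above and below the center of the shifted spectrum, on the vertical frequency axis, which is where a notch filter for these stripes would go.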
Example: Removing Periodic Noise¶
import numpy as np
import matplotlib.pyplot as plt
# Create a clean image (sample): start from a black background
image_clean = np.zeros((256, 256))
# Add realistic content: a circle
y, x = np.ogrid[:256, :256]
mask = (x - 128)**2 + (y - 128)**2 <= 60**2
image_clean[mask] = 200
# Add periodic stripe noise (horizontal stripes, 2 pixels thick)
stripe_period = 16
image_noisy = image_clean.copy()
for i in range(0, 256, stripe_period):
    image_noisy[i:i+2, :] += 50
image_noisy = np.clip(image_noisy, 0, 255)
# Compute FFT
X = np.fft.fft2(image_noisy)
X_shifted = np.fft.fftshift(X)
# Visualize FFT to identify noise peaks
X_mag = np.log1p(np.abs(X_shifted))
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Row 1: Original and FFT
ax = axes[0, 0]
ax.imshow(image_noisy, cmap='gray')
ax.set_title('Noisy Image (stripe pattern)')
ax = axes[0, 1]
ax.imshow(X_mag, cmap='hot')
ax.set_title('FFT Magnitude (log)')
# Create notch filter to suppress stripe noise
# Stripes repeat every 16 pixels → fundamental peak at frequency 256/16 = 16,
# with harmonics at integer multiples of it
notch_filter = np.ones(X_shifted.shape)  # real-valued: 1 = pass, 0 = reject
# Horizontal stripes vary along y, so their peaks lie on the vertical
# frequency axis at (center_y ± k * freq_stripe, center_x)
freq_stripe = 256 // stripe_period
center = 128
sigma = 2.0  # notch width in frequency bins
yy, xx = np.ogrid[:256, :256]
# Create a Gaussian notch at each harmonic below the Nyquist frequency
for k in range(1, center // freq_stripe):
    for y_notch in (center - k * freq_stripe, center + k * freq_stripe):
        dist2 = (yy - y_notch)**2 + (xx - center)**2
        notch_filter *= 1 - np.exp(-dist2 / (2 * sigma**2))
ax = axes[0, 2]
ax.imshow(notch_filter, cmap='gray')
ax.set_title('Notch Filter (suppresses peaks)')
# Apply filter in frequency domain
X_filtered = X_shifted * notch_filter
# Inverse FFT
X_filtered_unshifted = np.fft.ifftshift(X_filtered)
image_restored = np.fft.ifft2(X_filtered_unshifted).real
# Clip to valid range
image_restored = np.clip(image_restored, 0, 255)
# Visualize results
ax = axes[1, 0]
ax.imshow(image_restored, cmap='gray')
ax.set_title('Restored Image')
ax = axes[1, 1]
ax.imshow(np.log1p(np.abs(X_filtered)), cmap='hot')
ax.set_title('Filtered FFT (log)')
# Difference
ax = axes[1, 2]
diff = np.abs(image_noisy - image_restored)
ax.imshow(diff, cmap='hot')
ax.set_title('Removed Noise (difference)')
for ax in axes.flat:
    ax.axis('off')
plt.suptitle('Periodic Noise Removal via Frequency Domain')
plt.tight_layout()
plt.show()
# Quantify improvement
noise_original = np.mean((image_noisy - image_clean)**2)
noise_restored = np.mean((image_restored - image_clean)**2)
print(f"Original MSE: {noise_original:.1f}")
print(f"Restored MSE: {noise_restored:.1f}")
print(f"Improvement: {noise_original / noise_restored:.1f}x")
More Sophisticated Filters
- Butterworth filter: Smooth roll-off (avoids sharp artifacts)
- Band-pass and band-reject filters: Enhance or suppress specific frequency bands
- Wiener filter: Optimal restoration for known noise statistics
- For real applications: Use scikit-image (restoration module)
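As an illustration of the first item, here is a minimal sketch of a Butterworth low-pass filter built directly on the unshifted FFT grid. The helper name butterworth_lowpass, the cutoff of 0.1 cycles per pixel, and the order of 2 are assumptions made for this example, not recommended settings.

```python
import numpy as np

def butterworth_lowpass(shape, cutoff, order=2):
    """Butterworth low-pass transfer function on the unshifted FFT grid.
    `cutoff` is in cycles per pixel (0 to 0.5). Illustrative helper."""
    fy = np.fft.fftfreq(shape[0])[:, None]
    fx = np.fft.fftfreq(shape[1])[None, :]
    radius = np.sqrt(fy ** 2 + fx ** 2)
    # Gain is 1 at DC, 0.5 at the cutoff, and rolls off smoothly beyond it
    return 1.0 / (1.0 + (radius / cutoff) ** (2 * order))

# Synthetic test: bright square plus white Gaussian noise
rng = np.random.default_rng(0)
clean = np.zeros((64, 64))
clean[16:48, 16:48] = 1.0
noisy = clean + 0.5 * rng.normal(size=clean.shape)

H = butterworth_lowpass(noisy.shape, cutoff=0.1, order=2)
smoothed = np.fft.ifft2(np.fft.fft2(noisy) * H).real

mse_before = np.mean((noisy - clean) ** 2)
mse_after = np.mean((smoothed - clean) ** 2)
print(f"MSE before: {mse_before:.4f}, after: {mse_after:.4f}")
```

Because the filter is built with np.fft.fftfreq(), it lines up with the unshifted spectrum and no fftshift()/ifftshift() pair is needed. A higher order gives a sharper cutoff at the price of more ringing near edges.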
FFT-Based Convolution¶
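For large kernels, FFT-based convolution is faster than spatial domain convolution. By the convolution theorem, convolution in the spatial domain equals pointwise multiplication in the frequency domain:
\[ x * h = \mathcal{F}^{-1}\{ \mathcal{F}\{x\} \cdot \mathcal{F}\{h\} \} \]
Direct convolution of an \(M \times N\) image with a \(K \times K\) kernel costs on the order of \(M N K^2\) operations, while the FFT route costs on the order of \(M N \log(M N)\), nearly independent of kernel size when the kernel is much smaller than the image.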
When to Use FFT Convolution¶
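- Kernel size > ~10×10 pixels: FFT is usually faster (the exact crossover depends on image size and implementation)
- Small kernels (3×3, 5×5): Spatial convolution is usually faster
- Repetitive convolution with the same kernel: Compute the kernel's FFT once and reuse it to amortize the cost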
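SciPy ships a heuristic for this decision: scipy.signal.choose_conv_method() estimates whether direct or FFT-based convolution will be faster for a given pair of inputs (it is designed for scipy.signal.convolve, so treat it as a guide rather than a guarantee for other routines). The sketch below queries it for a few kernel sizes; the image size and kernel sizes are arbitrary choices, and the answers may differ between machines and SciPy versions.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
image = rng.random((256, 256))

methods = {}
for k in (3, 9, 31):
    kernel = rng.random((k, k))
    # Returns the string 'direct' or 'fft'
    methods[k] = signal.choose_conv_method(image, kernel, mode='same')
    print(f"{k}x{k} kernel -> {methods[k]}")
```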
Example: Blur with FFT¶
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
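# Create test image (regular grid of 2×2 bright dots, spaced 8 pixels apart)
image = np.zeros((256, 256))
image[::8, ::8] = 1
image[1::8, 1::8] = 1
image[::8, 1::8] = 1
image[1::8, ::8] = 1
# Define blur kernel (Gaussian), built as the outer product of a 1D profile
kernel_size = 31
sigma = 5
offsets = np.arange(kernel_size) - kernel_size // 2
profile = np.exp(-offsets**2 / (2 * sigma**2))
kernel = np.outer(profile, profile)
kernel /= kernel.sum()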
print(f"Image shape: {image.shape}")
print(f"Kernel shape: {kernel.shape}")
# Method 1: Spatial convolution (zero boundary, to match the zero-padded FFT)
result_spatial = ndimage.convolve(image, kernel, mode='constant', cval=0.0)
# Method 2: FFT convolution
# Need to zero-pad to avoid circular convolution
M, N = image.shape
Mk, Nk = kernel.shape
# Pad image and kernel to (M + Mk - 1, N + Nk - 1)
image_padded = np.pad(image, ((0, Mk-1), (0, Nk-1)), mode='constant')
kernel_padded = np.pad(kernel, ((0, M-1), (0, N-1)), mode='constant')
# FFT-based convolution
X = np.fft.fft2(image_padded)
H = np.fft.fft2(kernel_padded)
Y = X * H
result_fft = np.fft.ifft2(Y).real
# Crop the central M×N region so the output aligns with the input
# (the full result is offset by half the kernel size)
top, left = Mk // 2, Nk // 2
result_fft = result_fft[top:top + M, left:left + N]
# Compare
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
ax = axes[0, 0]
ax.imshow(image, cmap='gray')
ax.set_title('Original Image')
ax.axis('off')
ax = axes[0, 1]
ax.imshow(kernel, cmap='gray')
ax.set_title('Blur Kernel')
ax.axis('off')
ax = axes[1, 0]
ax.imshow(result_spatial, cmap='gray')
ax.set_title('Spatial Convolution')
ax.axis('off')
ax = axes[1, 1]
ax.imshow(result_fft, cmap='gray')
ax.set_title('FFT Convolution')
ax.axis('off')
plt.tight_layout()
plt.show()
# Verify they produce the same result
difference = np.max(np.abs(result_spatial - result_fft))
print(f"Max difference: {difference:.2e} (should be ~0)")
print(f"Results match: {np.allclose(result_spatial, result_fft, atol=1e-10)}")
Performance Comparison¶
import time
import numpy as np
from scipy import ndimage
# Benchmark
image = np.random.rand(512, 512)
kernel = np.random.rand(64, 64)
kernel /= kernel.sum()
# Time spatial convolution
start = time.perf_counter()
for _ in range(10):
    _ = ndimage.convolve(image, kernel)
time_spatial = time.perf_counter() - start
# Time FFT convolution
start = time.perf_counter()
for _ in range(10):
    M, N = image.shape
    Mk, Nk = kernel.shape
    image_padded = np.pad(image, ((0, Mk-1), (0, Nk-1)))
    kernel_padded = np.pad(kernel, ((0, M-1), (0, N-1)))
    X = np.fft.fft2(image_padded)
    H = np.fft.fft2(kernel_padded)
    Y = X * H
    _ = np.fft.ifft2(Y).real[:M, :N]
time_fft = time.perf_counter() - start
print(f"Spatial convolution: {time_spatial:.3f} sec")
print(f"FFT convolution: {time_fft:.3f} sec")
print(f"Speedup: {time_spatial / time_fft:.1f}x")
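Typical speedup: 5-10x or more for 64×64 kernels, and larger still for bigger kernels. The exact figure depends on hardware, image size, and library versions, so treat the benchmark above as the reference for your own setup.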
Real-World FFT Convolution
SciPy provides optimized versions:
from scipy.signal import fftconvolve
result = fftconvolve(image, kernel, mode='same')
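As a sanity check, the sketch below compares fftconvolve() with ndimage.convolve() on a small random image. The sizes are arbitrary. The comparison assumes an odd-sized kernel and a zero boundary (mode='constant') on the spatial side, since that is the boundary handling fftconvolve() implies; with other boundary modes the two results differ near the image borders.

```python
import numpy as np
from scipy import ndimage
from scipy.signal import fftconvolve

rng = np.random.default_rng(0)
image = rng.random((64, 64))
kernel = rng.random((9, 9))
kernel /= kernel.sum()

# mode='same' returns an output the size of the image, centered on it
via_fft = fftconvolve(image, kernel, mode='same')
via_spatial = ndimage.convolve(image, kernel, mode='constant', cval=0.0)

max_diff = np.max(np.abs(via_fft - via_spatial))
print(f"Max difference: {max_diff:.2e}")
```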
Zero-Padding Considerations¶
Zero-padding affects FFT behavior in important ways:
1. Avoiding Circular Convolution¶
Multiplying two DFTs corresponds to circular convolution (the signal wraps around at the borders). To get linear convolution, pad with zeros:
# Without padding: circular convolution
# (the kernel is zero-extended only to the image size so the shapes match)
X = np.fft.fft2(image)
H = np.fft.fft2(kernel, s=image.shape)
Y = X * H
result_circular = np.fft.ifft2(Y).real
# With padding: linear convolution
M, N = image.shape
Mk, Nk = kernel.shape
image_padded = np.pad(image, ((0, Mk-1), (0, Nk-1)))
kernel_padded = np.pad(kernel, ((0, M-1), (0, N-1)))
X = np.fft.fft2(image_padded)
H = np.fft.fft2(kernel_padded)
Y = X * H
result_linear = np.fft.ifft2(Y).real[:M, :N]
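The wrap-around is easiest to see with an impulse placed in a corner. In the sketch below (an 8×8 image and a 3×3 box kernel, both arbitrary), the circular result leaks energy to the opposite corner while the zero-padded result does not. It uses the `s=` argument of np.fft.fft2(), which zero-pads at the end of each axis, in place of explicit np.pad() calls.

```python
import numpy as np

image = np.zeros((8, 8))
image[7, 7] = 1.0                  # impulse in the bottom-right corner
kernel = np.ones((3, 3)) / 9.0
M, N = image.shape
Mk, Nk = kernel.shape

# Circular: kernel zero-extended only to the image size, no extra room
circular = np.fft.ifft2(
    np.fft.fft2(image) * np.fft.fft2(kernel, s=(M, N))
).real

# Linear: both inputs zero-padded to (M + Mk - 1, N + Nk - 1), then cropped
full_shape = (M + Mk - 1, N + Nk - 1)
linear = np.fft.ifft2(
    np.fft.fft2(image, s=full_shape) * np.fft.fft2(kernel, s=full_shape)
).real[:M, :N]

# Energy at the top-left corner can only come from wrap-around
print(f"circular[0, 0] = {abs(circular[0, 0]):.4f}")
print(f"linear[0, 0]   = {abs(linear[0, 0]):.4f}")
```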
2. Frequency Resolution¶
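More zero-padding samples the spectrum on a finer frequency grid but doesn't add information: it interpolates between the existing frequency samples and cannot separate components that the original image size could not resolve: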
# Image: 128×128 pixels
image = np.random.rand(128, 128)
# No padding
X1 = np.fft.fft2(image)
freqs1 = np.fft.fftfreq(128, 1.0)
# 2x padding (256×256)
image_padded = np.pad(image, ((0, 128), (0, 128)))
X2 = np.fft.fft2(image_padded)
freqs2 = np.fft.fftfreq(256, 1.0)
print(f"Frequency bin spacing without padding: {1.0/128:.4f}")
print(f"Frequency bin spacing with 2x padding: {1.0/256:.4f}")
print("More padding = finer frequency grid (interpolation, not more info)")
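The "interpolation, not more info" point can be checked directly: when the padded size is exactly twice the original, every second sample of the padded spectrum equals the corresponding sample of the unpadded one. The sketch below uses a 16×16 random image for brevity and the `s=` argument of np.fft.fft2() to do the zero-padding.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((16, 16))

X1 = np.fft.fft2(image)                 # 16x16 spectrum
X2 = np.fft.fft2(image, s=(32, 32))     # zero-padded to 32x32 first

# The original samples reappear at the even indices of the padded spectrum;
# the odd indices are interpolated values in between
same_samples = np.allclose(X2[::2, ::2], X1)
print(f"Padded spectrum contains the original samples: {same_samples}")
```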
3. Edge Effects¶
Abrupt image boundaries create high-frequency artifacts. Options:
- Zero-padding: Simple but introduces discontinuity
- Mirroring: Reflects image edges
- Periodic extension: Assumes the image repeats (FFT default)
image = np.random.rand(64, 64)
# Pad with different modes
padded_zeros = np.pad(image, 32, mode='constant', constant_values=0)
padded_reflect = np.pad(image, 32, mode='reflect')
padded_wrap = np.pad(image, 32, mode='wrap')
# Compare frequency content
X_zeros = np.log1p(np.abs(np.fft.fft2(padded_zeros)))
X_reflect = np.log1p(np.abs(np.fft.fft2(padded_reflect)))
# Reflection typically gives lower high-frequency artifacts
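To put a number on the comparison, the sketch below measures the fraction of spectral energy above a cutoff frequency for a smooth test image (a horizontal brightness ramp) under zero and reflective padding. The helper high_freq_fraction and the cutoff of 0.25 cycles per pixel are illustrative assumptions, not a standard metric, and the size of the gap depends on the image content.

```python
import numpy as np

def high_freq_fraction(img, cutoff=0.25):
    """Fraction of spectral energy above `cutoff` cycles/pixel on either
    axis. Illustrative metric, not a standard one."""
    power = np.abs(np.fft.fft2(img)) ** 2
    fy = np.abs(np.fft.fftfreq(img.shape[0]))[:, None]
    fx = np.abs(np.fft.fftfreq(img.shape[1]))[None, :]
    high = (fy > cutoff) | (fx > cutoff)
    return power[high].sum() / power.sum()

# Smooth test image: brightness increases linearly from left to right
ramp = np.tile(np.linspace(0.0, 1.0, 64), (64, 1))

frac_zeros = high_freq_fraction(np.pad(ramp, 32, mode='constant'))
frac_reflect = high_freq_fraction(np.pad(ramp, 32, mode='reflect'))
print(f"High-frequency energy, zero padding:    {frac_zeros:.2e}")
print(f"High-frequency energy, reflect padding: {frac_reflect:.2e}")
```

For this smooth image the reflective padding keeps the border continuous, so far less energy ends up at high frequencies than with the step that zero-padding introduces.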
Complete Image Denoising Pipeline¶
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
def denoise_fft(image, threshold_percentile=90):
    """Denoise image by suppressing low-magnitude frequency components."""
    # Compute FFT
    X = np.fft.fft2(image)
    # Magnitude spectrum
    X_mag = np.abs(X)
    # Threshold: keep only coefficients above this percentile of magnitude,
    # i.e. the strongest (100 - threshold_percentile)% of frequencies
    threshold = np.percentile(X_mag, threshold_percentile)
    # Create binary mask
    mask = X_mag > threshold
    # Apply mask
    X_filtered = X * mask
    # Inverse FFT
    image_denoised = np.fft.ifft2(X_filtered).real
    return image_denoised, X, mask
# Test
image = ndimage.gaussian_filter(np.random.rand(128, 128), sigma=2)
image_noisy = image + 0.3 * np.random.randn(128, 128)
image_denoised, X, mask = denoise_fft(image_noisy, threshold_percentile=85)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].imshow(image_noisy, cmap='gray')
axes[0].set_title('Noisy Image')
axes[1].imshow(np.log1p(np.abs(X)), cmap='hot')
axes[1].set_title('FFT Magnitude')
axes[2].imshow(image_denoised, cmap='gray')
axes[2].set_title('Denoised (FFT threshold)')
for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()
Summary¶
- 2D FFT decomposes images into frequency components
- Low frequencies = smooth regions; high frequencies = edges and noise
- fftshift() centers zero frequency for intuitive visualization
- Frequency domain filtering can remove periodic noise effectively
- FFT convolution is faster for large kernels (>10×10)
- Zero-padding avoids circular convolution; reflective padding can reduce edge artifacts
- Real-world applications use SciPy (fftconvolve, ndimage) for reliability
Next steps:
- Explore convolution theorem for deeper frequency domain theory
- Study image restoration techniques (Wiener filtering, etc.)
- Apply to computer vision tasks: edge detection, blur analysis, feature extraction