Einsum Operations
np.einsum provides a powerful notation for tensor contractions and array operations.
Mental Model
`einsum` uses the Einstein summation convention: label each axis with a letter, and any letter that appears in the inputs but not in the output is summed over. `'ij,jk->ik'` is matrix multiplication, `'ii->'` is a trace, and `'ij->ji'` is a transpose. Once you know the notation, a single `einsum` call can express most combinations of products, sums, and axis reorderings.
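To make the summation rule concrete, the sketch below spells out `'ij,jk->ik'` as explicit Python loops. The helper name `einsum_ij_jk_ik` is hypothetical and exists only for illustration; it is far slower than `np.einsum` and is meant only to show which letters index the output and which are summed.

```python
import numpy as np


def einsum_ij_jk_ik(A, B):
    """Loop version of np.einsum('ij,jk->ik', A, B) (illustrative helper)."""
    I, J = A.shape
    J2, K = B.shape
    assert J == J2, "the shared letter j must have the same length in both inputs"
    C = np.zeros((I, K), dtype=np.result_type(A, B))
    for i in range(I):          # i appears in the output: kept
        for k in range(K):      # k appears in the output: kept
            for j in range(J):  # j is missing from the output: summed
                C[i, k] += A[i, j] * B[j, k]
    return C


def main():
    A = np.array([[1, 2], [3, 4]])
    B = np.array([[5, 6], [7, 8]])
    print(einsum_ij_jk_ik(A, B))
    print(np.array_equal(einsum_ij_jk_ik(A, B), np.einsum('ij,jk->ik', A, B)))


if __name__ == "__main__":
    main()
```

Each letter in the subscript string corresponds to one loop; the letters to the right of `->` decide which loops index the result.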
Basic Syntax
1. Subscript Notation
```python
import numpy as np


def main():
    A = np.array([[1, 2], [3, 4]])
    B = np.array([[5, 6], [7, 8]])

    # Matrix multiplication: C_ik = sum_j A_ij * B_jk
    C = np.einsum('ij,jk->ik', A, B)

    print("A @ B via einsum:")
    print(C)
    print()
    print("Verify with @:")
    print(A @ B)


if __name__ == "__main__":
    main()
```
2. Index Convention
- Indices absent from the output are summed (contracted); without `->`, every repeated index is summed
- Output indices appear after `->`
- Free indices (those listed in the output) remain in the result
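One refinement to the first bullet: in explicit mode a repeated index is summed only when it is absent from the output. The short sketch below, using made-up example vectors, contrasts the two cases for the same pair of inputs.

```python
import numpy as np


def main():
    a = np.array([1, 2, 3])
    b = np.array([4, 5, 6])

    # i appears in the output, so it is kept: element-wise product
    kept = np.einsum('i,i->i', a, b)

    # i is absent from the output, so it is summed: dot product
    summed = np.einsum('i,i->', a, b)

    print(kept)
    print(summed)
    return kept, summed


if __name__ == "__main__":
    main()
```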
3. Implicit vs Explicit
```python
import numpy as np


def main():
    A = np.array([[1, 2], [3, 4]])

    # Explicit output
    trace1 = np.einsum('ii->', A)

    # Implicit (same result)
    trace2 = np.einsum('ii', A)

    print(f"Trace (explicit): {trace1}")
    print(f"Trace (implicit): {trace2}")


if __name__ == "__main__":
    main()
```
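Implicit mode has one more rule worth knowing: when `->` is omitted, letters that appear exactly once are kept and placed in the output in alphabetical order. The sketch below, with an arbitrary example matrix, shows the consequence: `'ij'` returns the array unchanged, while `'ji'` returns its transpose.

```python
import numpy as np


def main():
    A = np.array([[1, 2, 3], [4, 5, 6]])

    # Implicit mode: output letters are sorted alphabetically.
    same = np.einsum('ij', A)        # output order is (i, j): unchanged
    transposed = np.einsum('ji', A)  # output order is (i, j) again, so the axes swap

    print(same.shape)
    print(transposed.shape)
    return same, transposed


if __name__ == "__main__":
    main()
```

Writing the output explicitly (`'ji->ij'` or `'ij->ji'`) avoids relying on this ordering rule.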
Common Operations
1. Matrix Transpose
```python
import numpy as np


def main():
    A = np.array([[1, 2, 3], [4, 5, 6]])

    # Transpose: swap indices
    AT = np.einsum('ij->ji', A)

    print("Original:")
    print(A)
    print()
    print("Transposed:")
    print(AT)


if __name__ == "__main__":
    main()
```
2. Matrix Trace
```python
import numpy as np


def main():
    A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

    # Trace: sum of diagonal
    trace = np.einsum('ii->', A)

    print(f"einsum trace: {trace}")
    print(f"np.trace: {np.trace(A)}")


if __name__ == "__main__":
    main()
```
3. Matrix Diagonal
```python
import numpy as np


def main():
    A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

    # Extract diagonal
    diag = np.einsum('ii->i', A)

    print(f"einsum diag: {diag}")
    print(f"np.diag: {np.diag(A)}")


if __name__ == "__main__":
    main()
```
Vector Operations
1. Dot Product
```python
import numpy as np


def main():
    a = np.array([1, 2, 3])
    b = np.array([4, 5, 6])

    # Dot product: sum_i a_i * b_i
    dot = np.einsum('i,i->', a, b)

    print(f"einsum dot: {dot}")
    print(f"np.dot: {np.dot(a, b)}")


if __name__ == "__main__":
    main()
```
2. Outer Product
```python
import numpy as np


def main():
    a = np.array([1, 2, 3])
    b = np.array([4, 5])

    # Outer product: C_ij = a_i * b_j
    outer = np.einsum('i,j->ij', a, b)

    print("einsum outer:")
    print(outer)
    print()
    print("np.outer:")
    print(np.outer(a, b))


if __name__ == "__main__":
    main()
```
3. Element-wise Product
```python
import numpy as np


def main():
    a = np.array([1, 2, 3])
    b = np.array([4, 5, 6])

    # Hadamard product: c_i = a_i * b_i
    hadamard = np.einsum('i,i->i', a, b)

    print(f"einsum: {hadamard}")
    print(f"a * b: {a * b}")


if __name__ == "__main__":
    main()
```
Matrix Operations
1. Matrix Multiply
```python
import numpy as np


def main():
    A = np.random.randn(3, 4)
    B = np.random.randn(4, 5)

    # C_ik = sum_j A_ij * B_jk
    C = np.einsum('ij,jk->ik', A, B)

    print(f"Shape: {C.shape}")
    print(f"Matches A @ B: {np.allclose(C, A @ B)}")


if __name__ == "__main__":
    main()
```
2. Batch Matrix Multiply
```python
import numpy as np


def main():
    # Batch of matrices: (batch, rows, cols)
    A = np.random.randn(10, 3, 4)
    B = np.random.randn(10, 4, 5)

    # Batch matmul: C_bij = sum_k A_bik * B_bkj
    C = np.einsum('bik,bkj->bij', A, B)

    print(f"A shape: {A.shape}")
    print(f"B shape: {B.shape}")
    print(f"C shape: {C.shape}")


if __name__ == "__main__":
    main()
```
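When the number of batch axes is not fixed, the ellipsis form of the subscripts can stand in for them. The sketch below is one way to write that; `batched_matmul` is a hypothetical helper name, and the result is checked against `@`, which also multiplies over the last two axes.

```python
import numpy as np


def batched_matmul(A, B):
    """Multiply the last two axes; '...' stands for any leading batch axes."""
    return np.einsum('...ik,...kj->...ij', A, B)


def main():
    rng = np.random.default_rng(0)
    A = rng.standard_normal((10, 3, 4))
    B = rng.standard_normal((10, 4, 5))

    C = batched_matmul(A, B)
    print(f"C shape: {C.shape}")
    print(f"Matches A @ B: {np.allclose(C, A @ B)}")


if __name__ == "__main__":
    main()
```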
3. Matrix-Vector Product
```python
import numpy as np


def main():
    A = np.array([[1, 2, 3], [4, 5, 6]])
    x = np.array([1, 2, 3])

    # y_i = sum_j A_ij * x_j
    y = np.einsum('ij,j->i', A, x)

    print(f"einsum: {y}")
    print(f"A @ x: {A @ x}")


if __name__ == "__main__":
    main()
```
Tensor Contractions
1. 3D Tensor Sum
```python
import numpy as np


def main():
    T = np.random.randn(3, 4, 5)

    # Sum over all indices
    total = np.einsum('ijk->', T)

    print(f"einsum sum: {total:.4f}")
    print(f"np.sum: {np.sum(T):.4f}")


if __name__ == "__main__":
    main()
```
2. Partial Contraction
```python
import numpy as np


def main():
    T = np.random.randn(3, 4, 5)

    # Sum over middle index: R_ik = sum_j T_ijk
    R = np.einsum('ijk->ik', T)

    print(f"Original shape: {T.shape}")
    print(f"Result shape: {R.shape}")
    print(f"Matches sum: {np.allclose(R, T.sum(axis=1))}")


if __name__ == "__main__":
    main()
```
3. Tensor Product
```python
import numpy as np


def main():
    A = np.random.randn(2, 3)
    B = np.random.randn(4, 5)

    # Tensor (Kronecker-like) product
    C = np.einsum('ij,kl->ijkl', A, B)

    print(f"A shape: {A.shape}")
    print(f"B shape: {B.shape}")
    print(f"C shape: {C.shape}")


if __name__ == "__main__":
    main()
```
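The "Kronecker-like" label can be made precise. As a sketch for 2-D inputs, reordering the output axes to `ikjl` and merging them in pairs reproduces `np.kron`, while the `ijkl` ordering used above matches `np.tensordot(A, B, axes=0)`. The helper name `kron_via_einsum` is hypothetical.

```python
import numpy as np


def kron_via_einsum(A, B):
    """Build np.kron(A, B) for 2-D inputs from an einsum outer product."""
    m, n = A.shape
    p, q = B.shape
    # T[i, k, j, l] = A[i, j] * B[k, l]; merging (i, k) and (j, l)
    # gives the block layout of the Kronecker product.
    T = np.einsum('ij,kl->ikjl', A, B)
    return T.reshape(m * p, n * q)


def main():
    A = np.arange(6).reshape(2, 3)
    B = np.arange(20).reshape(4, 5)

    print(np.array_equal(kron_via_einsum(A, B), np.kron(A, B)))
    print(np.array_equal(np.einsum('ij,kl->ijkl', A, B),
                         np.tensordot(A, B, axes=0)))


if __name__ == "__main__":
    main()
```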
Performance
1. Optimize Flag
```python
import numpy as np
import time


def main():
    A = np.random.randn(100, 100)
    B = np.random.randn(100, 100)
    C = np.random.randn(100, 100)

    # Without optimization
    start = time.perf_counter()
    for _ in range(100):
        D = np.einsum('ij,jk,kl->il', A, B, C)
    time1 = time.perf_counter() - start

    # With optimization
    start = time.perf_counter()
    for _ in range(100):
        D = np.einsum('ij,jk,kl->il', A, B, C, optimize=True)
    time2 = time.perf_counter() - start

    print(f"Without optimize: {time1:.4f} sec")
    print(f"With optimize: {time2:.4f} sec")
    print(f"Speedup: {time1/time2:.1f}x")


if __name__ == "__main__":
    main()
```
2. Path Optimization
```python
import numpy as np


def main():
    A = np.random.randn(10, 20)
    B = np.random.randn(20, 30)
    C = np.random.randn(30, 40)

    # Get optimized contraction path
    path, info = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='optimal')

    print("Contraction path:")
    print(path)
    print()
    print(info)


if __name__ == "__main__":
    main()
```
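A path returned by `np.einsum_path` can be passed back to `np.einsum` through its `optimize` argument, so the path search is done once rather than on every call. The sketch below assumes the operand shapes stay the same between calls; the shapes and seed are arbitrary.

```python
import numpy as np


def main():
    rng = np.random.default_rng(0)
    A = rng.standard_normal((10, 20))
    B = rng.standard_normal((20, 30))
    C = rng.standard_normal((30, 40))

    subscripts = 'ij,jk,kl->il'

    # Compute the contraction order once...
    path, _ = np.einsum_path(subscripts, A, B, C, optimize='optimal')

    # ...then reuse it for calls with same-shaped operands.
    D = np.einsum(subscripts, A, B, C, optimize=path)

    print(path)
    print(np.allclose(D, A @ B @ C))
    return path, D, A @ B @ C


if __name__ == "__main__":
    main()
```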
3. vs Native Operations
```python
import numpy as np
import time


def main():
    A = np.random.randn(1000, 1000)
    B = np.random.randn(1000, 1000)

    # einsum
    start = time.perf_counter()
    C1 = np.einsum('ij,jk->ik', A, B, optimize=True)
    einsum_time = time.perf_counter() - start

    # Native matmul
    start = time.perf_counter()
    C2 = A @ B
    matmul_time = time.perf_counter() - start

    print(f"einsum time: {einsum_time:.4f} sec")
    print(f"matmul time: {matmul_time:.4f} sec")
    print(f"Results match: {np.allclose(C1, C2)}")


if __name__ == "__main__":
    main()
```
Practical Examples
1. Attention Scores
```python
import numpy as np


def main():
    # Query and Key tensors: (batch, seq_len, d_model)
    batch, seq_len, d_model = 2, 10, 64
    Q = np.random.randn(batch, seq_len, d_model)
    K = np.random.randn(batch, seq_len, d_model)

    # Attention scores: S_bij = sum_k Q_bik * K_bjk
    scores = np.einsum('bik,bjk->bij', Q, K)

    print(f"Q shape: {Q.shape}")
    print(f"K shape: {K.shape}")
    print(f"Scores shape: {scores.shape}")


if __name__ == "__main__":
    main()
```
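The example above stops at the raw scores. As a hedged sketch of how they are typically used, the block below follows the common scaled dot-product attention formulation: scale by the square root of the feature size, apply a softmax over the key axis, and contract the weights with a value tensor. The `attention` helper, the `V` tensor, and the scaling step are assumptions for illustration and do not come from the example above.

```python
import numpy as np


def attention(Q, K, V):
    """Scaled dot-product attention sketch for (batch, seq, d) inputs."""
    d = Q.shape[-1]
    # Scores: S_bij = sum_k Q_bik * K_bjk, scaled by sqrt(d)
    scores = np.einsum('bik,bjk->bij', Q, K) / np.sqrt(d)
    # Softmax over the key axis j (subtract the max for numerical stability)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Output: O_bid = sum_j W_bij * V_bjd
    return weights, np.einsum('bij,bjd->bid', weights, V)


def main():
    rng = np.random.default_rng(0)
    batch, seq_len, d_model = 2, 10, 64
    Q = rng.standard_normal((batch, seq_len, d_model))
    K = rng.standard_normal((batch, seq_len, d_model))
    V = rng.standard_normal((batch, seq_len, d_model))

    weights, out = attention(Q, K, V)
    print(f"Weights shape: {weights.shape}")
    print(f"Output shape: {out.shape}")


if __name__ == "__main__":
    main()
```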
2. Bilinear Form
```python
import numpy as np


def main():
    # x^T A y
    x = np.array([1, 2, 3])
    A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    y = np.array([1, 1, 1])

    # Bilinear: sum_ij x_i * A_ij * y_j
    result = np.einsum('i,ij,j->', x, A, y)

    print(f"einsum: {result}")
    print(f"x @ A @ y: {x @ A @ y}")


if __name__ == "__main__":
    main()
```
3. Covariance Matrix
```python
import numpy as np


def main():
    # Data matrix: (samples, features)
    X = np.random.randn(100, 5)
    X_centered = X - X.mean(axis=0)

    # Covariance: C_ij = (1/n) sum_k X_ki * X_kj
    n = X.shape[0]
    cov_einsum = np.einsum('ki,kj->ij', X_centered, X_centered) / n

    print("einsum covariance:")
    print(cov_einsum)
    print()
    print("np.cov (transposed input):")
    print(np.cov(X.T, bias=True))


if __name__ == "__main__":
    main()
```
Summary Table
1. Quick Reference
| Operation | einsum | Equivalent |
|---|---|---|
| Transpose | `'ij->ji'` | `A.T` |
| Trace | `'ii->'` | `np.trace(A)` |
| Diagonal | `'ii->i'` | `np.diag(A)` |
| Dot product | `'i,i->'` | `np.dot(a, b)` |
| Outer product | `'i,j->ij'` | `np.outer(a, b)` |
| Matrix multiply | `'ij,jk->ik'` | `A @ B` |
| Batch matmul | `'bij,bjk->bik'` | `np.matmul(A, B)` |
2. When to Use einsum
- Complex tensor contractions
- Multiple simultaneous operations
- Non-standard axis combinations
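As an illustration of a contraction with no single native call, the sketch below evaluates a quadratic form for every row of a batch in one `einsum`. The helper name `batched_quadratic_form` and the example shapes are hypothetical; the result is checked against a plain Python loop.

```python
import numpy as np


def batched_quadratic_form(X, A):
    """Return q_b = sum_ij X_bi * A_ij * X_bj for every row b of X."""
    return np.einsum('bi,ij,bj->b', X, A, X)


def main():
    rng = np.random.default_rng(0)
    X = rng.standard_normal((8, 3))   # 8 row vectors of length 3
    A = rng.standard_normal((3, 3))

    q = batched_quadratic_form(X, A)
    q_loop = np.array([x @ A @ x for x in X])

    print(f"q shape: {q.shape}")
    print(f"Matches loop: {np.allclose(q, q_loop)}")


if __name__ == "__main__":
    main()
```

The same result can be written as `np.sum((X @ A) * X, axis=1)`; the `einsum` form states the index pattern directly.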
3. When to Avoid
- Simple operations with native equivalents
- When readability is paramount
Exercises
Exercise 1. Write a vectorized NumPy solution and a pure Python loop solution for the same computation. Measure and compare their performance using time.perf_counter().
Solution to Exercise 1
```python
import numpy as np
import time

n = 1_000_000
data = np.random.default_rng(42).random(n)

# Python loop
start = time.perf_counter()
result_py = [x ** 2 for x in data]
py_time = time.perf_counter() - start

# NumPy vectorized
start = time.perf_counter()
result_np = data ** 2
np_time = time.perf_counter() - start

print(f"Python: {py_time:.4f}s, NumPy: {np_time:.6f}s")
print(f"Speedup: {py_time / np_time:.0f}x")
```
Exercise 2. Identify a potential performance pitfall in the following code and rewrite it using NumPy vectorization:
```python
result = []
for i in range(len(data)):
    result.append(data[i] ** 2 + 2 * data[i] + 1)
```
Solution to Exercise 2
```python
import numpy as np

data = np.random.default_rng(42).random(100000)

# Vectorized (fast)
result = data ** 2 + 2 * data + 1
```
The loop version creates Python objects for each element and calls append repeatedly. The vectorized version computes everything in compiled C code on contiguous memory.
Exercise 3. Explain why NumPy vectorized operations are faster than Python loops. Reference memory layout, type checking overhead, and SIMD instructions in your answer.
Solution to Exercise 3
NumPy vectorized operations are faster because:
- Contiguous memory: NumPy arrays store elements in a contiguous block, enabling efficient CPU cache usage.
- No type checking: Python loops check types at each iteration; NumPy knows the dtype in advance.
- Compiled C loops: The actual computation runs in compiled C/Fortran code, not interpreted Python.
- SIMD instructions: Modern CPUs can process multiple array elements simultaneously using SIMD (Single Instruction, Multiple Data).
Exercise 4. Apply the concepts from this page to a practical problem: given a large array of temperatures in Celsius, convert them all to Fahrenheit and find the maximum. Compare vectorized and loop approaches.
Solution to Exercise 4
```python
import numpy as np
import time

rng = np.random.default_rng(42)
celsius = rng.uniform(-40, 50, 1_000_000)

# Vectorized
start = time.perf_counter()
fahrenheit = celsius * 9/5 + 32
max_f = fahrenheit.max()
vec_time = time.perf_counter() - start

# Loop
start = time.perf_counter()
max_f_loop = max(c * 9/5 + 32 for c in celsius)
loop_time = time.perf_counter() - start

print(f"Vectorized: {vec_time:.6f}s, max={max_f:.1f}F")
print(f"Loop: {loop_time:.4f}s, max={max_f_loop:.1f}F")
```