Skip to content

Einsum Operations

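np.einsum evaluates the Einstein summation convention on its operands. It provides one compact notation for tensor contractions, transposes, reductions, and element-wise or outer products.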

Basic Syntax

1. Subscript Notation

import numpy as np

def main():
    """Show that the subscripts 'ij,jk->ik' express a matrix product."""
    left = np.array([[1, 2], [3, 4]])
    right = np.array([[5, 6], [7, 8]])

    # j occurs in both inputs but not after '->', so it is summed over:
    # product[i, k] = sum_j left[i, j] * right[j, k]
    product = np.einsum('ij,jk->ik', left, right)

    print("A @ B via einsum:", product, "", "Verify with @:", left @ right,
          sep="\n")

if __name__ == "__main__":
    main()

2. Index Convention

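  • An index that appears in the inputs but not in the output is summed over (contracted)
  • Output indices are listed after ->, in the order the result's axes should have
  • Free indices, i.e. those listed in the output, remain in the result and are not summed, even if repeated (e.g. 'i,i->i' is an element-wise product)
  • Without -> (implicit mode), the output consists of the indices that appear exactly once, in alphabetical order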

3. Implicit vs Explicit

import numpy as np

def main():
    """Compute a trace with explicit and with implicit einsum subscripts."""
    matrix = np.array([[1, 2], [3, 4]])

    # Explicit mode: nothing after '->' means every index is summed away,
    # leaving a scalar (the sum of the diagonal).
    explicit = np.einsum('ii->', matrix)
    # Implicit mode: with no '->', repeated indices are summed and only
    # indices appearing exactly once are kept, so 'ii' is also a scalar.
    implicit = np.einsum('ii', matrix)

    for line in (f"Trace (explicit): {explicit}",
                 f"Trace (implicit): {implicit}"):
        print(line)

if __name__ == "__main__":
    main()

Common Operations

1. Matrix Transpose

import numpy as np

def main():
    """Transpose a 2x3 matrix by reversing its einsum subscripts."""
    original = np.array([[1, 2, 3],
                         [4, 5, 6]])

    # Listing the output indices in reverse order ('ji') swaps the two axes.
    transposed = np.einsum('ij->ji', original)

    sections = [("Original:", original), ("Transposed:", transposed)]
    for idx, (title, arr) in enumerate(sections):
        if idx:
            print()
        print(title)
        print(arr)

if __name__ == "__main__":
    main()

2. Matrix Trace

import numpy as np

def main():
    """Compare an einsum trace of a 3x3 matrix with np.trace."""
    square = np.arange(1, 10).reshape(3, 3)

    # i is repeated and dropped from the (empty) output, so the diagonal
    # entries square[i, i] are summed.
    diagonal_sum = np.einsum('ii->', square)
    reference = np.trace(square)

    print(f"einsum trace: {diagonal_sum}")
    print(f"np.trace:     {reference}")

if __name__ == "__main__":
    main()

3. Matrix Diagonal

import numpy as np

def main():
    """Pull out the main diagonal of a 3x3 matrix with einsum."""
    grid = np.arange(1, 10).reshape(3, 3)

    # Keeping the repeated index i in the output selects the entries
    # grid[i, i] instead of summing them.
    diagonal = np.einsum('ii->i', grid)

    print(f"einsum diag: {diagonal}")
    print(f"np.diag:     {np.diag(grid)}")

if __name__ == "__main__":
    main()

Vector Operations

1. Dot Product

import numpy as np

def main():
    """Compute an inner product with einsum and compare it to np.dot."""
    u, v = np.array([1, 2, 3]), np.array([4, 5, 6])

    # Shared index i with an empty output: sum_i u[i] * v[i].
    inner = np.einsum('i,i->', u, v)

    print(f"einsum dot: {inner}")
    print(f"np.dot:     {np.dot(u, v)}")

if __name__ == "__main__":
    main()

2. Outer Product

import numpy as np

def main():
    """Form an outer product with einsum and compare it to np.outer."""
    col = np.array([1, 2, 3])
    row = np.array([4, 5])

    # No index is shared, so nothing is summed: result[i, j] = col[i] * row[j].
    result = np.einsum('i,j->ij', col, row)

    print("einsum outer:", result, "", "np.outer:", np.outer(col, row),
          sep="\n")

if __name__ == "__main__":
    main()

3. Element-wise Product

import numpy as np

def main():
    """Multiply two vectors element by element using einsum."""
    first = np.array([1, 2, 3])
    second = np.array([4, 5, 6])

    # i is repeated but also kept in the output, so it is NOT summed:
    # each output entry is first[i] * second[i] (a Hadamard product).
    elementwise = np.einsum('i,i->i', first, second)
    native = first * second

    print(f"einsum:  {elementwise}")
    print(f"a * b:   {native}")

if __name__ == "__main__":
    main()

Matrix Operations

1. Matrix Multiply

import numpy as np

def main():
    """Check einsum matrix multiplication against ``@`` on random data."""
    lhs, rhs = np.random.randn(3, 4), np.random.randn(4, 5)

    # product[i, k] = sum_j lhs[i, j] * rhs[j, k]
    product = np.einsum('ij,jk->ik', lhs, rhs)
    agrees = np.allclose(product, lhs @ rhs)

    print(f"Shape: {product.shape}")
    print(f"Matches A @ B: {agrees}")

if __name__ == "__main__":
    main()

2. Batch Matrix Multiply

import numpy as np

def main():
    """Multiply a batch of matrices pairwise with one einsum call.

    The batch index b appears in both inputs and in the output, so it is
    carried through rather than summed; only k is contracted. The result is
    checked against ``@``, which also multiplies 3-D arrays batch-wise.
    """
    # Batch of matrices: (batch, rows, cols)
    A = np.random.randn(10, 3, 4)
    B = np.random.randn(10, 4, 5)

    # Batch matmul: C_bij = sum_k A_bik * B_bkj
    C = np.einsum('bik,bkj->bij', A, B)

    print(f"A shape: {A.shape}")
    print(f"B shape: {B.shape}")
    print(f"C shape: {C.shape}")
    print(f"Matches A @ B: {np.allclose(C, A @ B)}")

if __name__ == "__main__":
    main()

3. Matrix-Vector Product

import numpy as np

def main():
    """Multiply a 2x3 matrix by a length-3 vector using einsum."""
    A = np.arange(1, 7).reshape(2, 3)
    x = np.arange(1, 4)

    # j is contracted: y[i] = sum_j A[i, j] * x[j]
    y = np.einsum('ij,j->i', A, x)

    print(f"einsum: {y}")
    print(f"A @ x:  {A @ x}")

if __name__ == "__main__":
    main()

Tensor Contractions

1. 3D Tensor Sum

import numpy as np

def main():
    """Sum every element of a random 3-D tensor via einsum."""
    tensor = np.random.randn(3, 4, 5)

    # Empty output: all three indices i, j, k are summed away.
    einsum_total = np.einsum('ijk->', tensor)
    numpy_total = np.sum(tensor)

    print(f"einsum sum: {einsum_total:.4f}")
    print(f"np.sum:     {numpy_total:.4f}")

if __name__ == "__main__":
    main()

2. Partial Contraction

import numpy as np

def main():
    """Contract only the middle axis of a random 3-D tensor."""
    tensor = np.random.randn(3, 4, 5)

    # j is dropped from the output, so it is summed:
    # reduced[i, k] = sum_j tensor[i, j, k]
    reduced = np.einsum('ijk->ik', tensor)

    report = (
        f"Original shape: {tensor.shape}",
        f"Result shape:   {reduced.shape}",
        f"Matches sum:    {np.allclose(reduced, tensor.sum(axis=1))}",
    )
    print("\n".join(report))

if __name__ == "__main__":
    main()

3. Tensor Product

import numpy as np

def main():
    """Build the 4-D tensor (outer) product of two matrices with einsum.

    No index is shared between 'ij' and 'kl', so nothing is summed and
    C[i, j, k, l] = A[i, j] * B[k, l]. Reordering the axes to (i, k, j, l)
    and flattening to 2-D gives exactly the Kronecker product np.kron(A, B);
    both relationships are checked and printed.
    """
    A = np.random.randn(2, 3)
    B = np.random.randn(4, 5)

    # Tensor product: every element of A times every element of B
    C = np.einsum('ij,kl->ijkl', A, B)

    # kron(A, B)[i*rows_B + k, j*cols_B + l] == A[i, j] * B[k, l], so put k
    # next to i and l next to j before flattening to 2-D.
    kron_from_einsum = np.einsum('ij,kl->ikjl', A, B).reshape(
        A.shape[0] * B.shape[0], A.shape[1] * B.shape[1])

    print(f"A shape: {A.shape}")
    print(f"B shape: {B.shape}")
    print(f"C shape: {C.shape}")
    print(f"Matches np.multiply.outer: {np.allclose(C, np.multiply.outer(A, B))}")
    print(f"Reshaped matches np.kron:  {np.allclose(kron_from_einsum, np.kron(A, B))}")

if __name__ == "__main__":
    main()

Performance

1. Optimize Flag

import numpy as np
import time

def main(size=100, repeats=10):
    """Time a three-operand einsum chain with and without ``optimize``.

    With ``optimize`` left at its default (False), all three operands are
    contracted in one pass over every (i, j, k, l) combination, which is
    O(size**4) work per call. With ``optimize=True``, einsum contracts the
    operands pairwise instead. Both results are compared at the end.

    Args:
        size: Side length of the three random square matrices.
        repeats: Number of timed calls per variant. The default is kept
            small because the unoptimized path is O(size**4) per call.

    Raises:
        ValueError: If ``repeats`` is less than 1.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    A = np.random.randn(size, size)
    B = np.random.randn(size, size)
    C = np.random.randn(size, size)
    subscripts = 'ij,jk,kl->il'

    # Without optimization
    start = time.perf_counter()
    for _ in range(repeats):
        D_naive = np.einsum(subscripts, A, B, C)
    time1 = time.perf_counter() - start

    # With optimization
    start = time.perf_counter()
    for _ in range(repeats):
        D_opt = np.einsum(subscripts, A, B, C, optimize=True)
    time2 = time.perf_counter() - start

    print(f"Without optimize: {time1:.4f} sec")
    print(f"With optimize:    {time2:.4f} sec")
    # Guard against a zero reading from a very fast run.
    if time2 > 0:
        print(f"Speedup:          {time1/time2:.1f}x")
    print(f"Results match:    {np.allclose(D_naive, D_opt)}")

if __name__ == "__main__":
    main()

2. Path Optimization

import numpy as np

def main():
    """Print the contraction order einsum_path chooses for a matrix chain."""
    shapes = [(10, 20), (20, 30), (30, 40)]
    operands = [np.random.randn(*shape) for shape in shapes]

    # 'optimal' searches contraction orders for the cheapest one; the second
    # return value is a human-readable cost report.
    path, info = np.einsum_path('ij,jk,kl->il', *operands, optimize='optimal')

    print("Contraction path:", path, "", info, sep="\n")

if __name__ == "__main__":
    main()

3. vs Native Operations

import numpy as np
import time

def main(size=1000, repeats=3):
    """Compare einsum (with ``optimize=True``) to ``@`` for a matrix product.

    Each variant is run ``repeats`` times and the fastest run is reported.
    A single cold measurement is dominated by noise such as first-call
    overhead and scheduler jitter.

    Args:
        size: Side length of the two random square matrices.
        repeats: Number of timed runs per variant; the minimum is kept.

    Raises:
        ValueError: If ``repeats`` is less than 1.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    A = np.random.randn(size, size)
    B = np.random.randn(size, size)

    def best_of(fn):
        # Return the last result and the fastest wall-clock time over repeats.
        best = float("inf")
        result = None
        for _ in range(repeats):
            start = time.perf_counter()
            result = fn()
            best = min(best, time.perf_counter() - start)
        return result, best

    # einsum
    C1, einsum_time = best_of(lambda: np.einsum('ij,jk->ik', A, B, optimize=True))

    # Native matmul
    C2, matmul_time = best_of(lambda: A @ B)

    print(f"einsum time: {einsum_time:.4f} sec")
    print(f"matmul time: {matmul_time:.4f} sec")
    print(f"Results match: {np.allclose(C1, C2)}")

if __name__ == "__main__":
    main()

Practical Examples

1. Attention Scores

import numpy as np

def main():
    """Compute scaled dot-product attention scores for a random batch.

    scores[b, i, j] = (sum_k Q[b, i, k] * K[b, j, k]) / sqrt(d_model): the
    dot product of query i with key j within each batch element, divided by
    sqrt(d_model) as in scaled dot-product attention. The result is checked
    against the equivalent ``Q @ K^T`` formulation.
    """
    # Query and key tensors: (batch, seq_len, d_model)
    batch, seq_len, d_model = 2, 10, 64

    Q = np.random.randn(batch, seq_len, d_model)
    K = np.random.randn(batch, seq_len, d_model)
    scale = np.sqrt(d_model)

    # Contract the feature axis k. Writing K's subscripts as 'bjk' uses K as
    # stored, so no explicit transpose is needed.
    scores = np.einsum('bik,bjk->bij', Q, K) / scale

    print(f"Q shape:      {Q.shape}")
    print(f"K shape:      {K.shape}")
    print(f"Scores shape: {scores.shape}")
    expected = Q @ K.transpose(0, 2, 1) / scale
    print(f"Matches Q @ K^T / sqrt(d): {np.allclose(scores, expected)}")

if __name__ == "__main__":
    main()

2. Bilinear Form

import numpy as np

def main():
    """Evaluate the bilinear form x^T A y in a single einsum call."""
    x = np.array([1, 2, 3])
    A = np.arange(1, 10).reshape(3, 3)
    y = np.ones(3, dtype=int)

    # Both i and j are absent from the output, so both are summed:
    # value = sum_ij x[i] * A[i, j] * y[j]
    value = np.einsum('i,ij,j->', x, A, y)

    print(f"einsum:     {value}")
    print(f"x @ A @ y:  {x @ A @ y}")

if __name__ == "__main__":
    main()

3. Covariance Matrix

import numpy as np

def main():
    """Compute a population covariance matrix with einsum and check np.cov.

    Rows of X are samples and columns are features. After each column is
    centered, C[i, j] = (1/n) * sum_k Xc[k, i] * Xc[k, j].
    """
    # Data matrix: (samples, features)
    X = np.random.randn(100, 5)
    X_centered = X - X.mean(axis=0)

    # The sample axis k is contracted. Dividing by n (not n - 1) gives the
    # biased estimate, matching bias=True in np.cov below.
    n = X.shape[0]
    cov_einsum = np.einsum('ki,kj->ij', X_centered, X_centered) / n
    # np.cov treats rows as variables by default, hence the transpose.
    cov_numpy = np.cov(X.T, bias=True)

    print("einsum covariance:")
    print(cov_einsum)
    print()
    print("np.cov (transposed input):")
    print(cov_numpy)
    print()
    print(f"Matches np.cov: {np.allclose(cov_einsum, cov_numpy)}")

if __name__ == "__main__":
    main()

Summary Table

1. Quick Reference

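| Operation       | einsum           | Equivalent        |
|-----------------|------------------|-------------------|
| Transpose       | `'ij->ji'`       | `A.T`             |
| Trace           | `'ii->'`         | `np.trace(A)`     |
| Diagonal        | `'ii->i'`        | `np.diag(A)`      |
| Dot product     | `'i,i->'`        | `np.dot(a, b)`    |
| Outer product   | `'i,j->ij'`      | `np.outer(a, b)`  |
| Matrix multiply | `'ij,jk->ik'`    | `A @ B`           |
| Batch matmul    | `'bij,bjk->bik'` | `np.matmul(A, B)` |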

2. When to Use einsum

  • Complex tensor contractions
  • Multiple simultaneous operations
  • Non-standard axis combinations

3. When to Avoid

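  • Simple operations with native equivalents (A.T, A @ B, a * b), which are usually clearer and often faster
  • Code where readability is paramount and readers may not know subscript notation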