Skip to content

Matrix Square Root

A square root of a matrix \(A\) is any matrix \(X\) satisfying \(X^2 = A\); in general it is not unique.

linalg.sqrtm

1. Basic Usage

import numpy as np
from scipy import linalg

def main():
    """Print a 2x2 diagonal matrix followed by the result of linalg.sqrtm on it."""
    matrix = np.array([[4, 0], [0, 9]])
    root = linalg.sqrtm(matrix)

    # Labelled sections are printed in order, separated by a single blank line.
    sections = [("A =", matrix), ("sqrt(A) =", root)]
    for position, (label, value) in enumerate(sections):
        if position:
            print()
        print(label)
        print(value)

if __name__ == "__main__":
    main()

2. Verify X² = A

import numpy as np
from scipy import linalg

def main():
    """Print sqrtm of a symmetric 2x2 matrix, its square, and the original input."""
    original = np.array([[4, 2], [2, 3]])
    root = linalg.sqrtm(original)

    # The product is rounded to 10 decimals before printing so that tiny
    # floating-point residue does not clutter the comparison with A.
    squared = np.matmul(root, root).round(10)

    print(
        "sqrt(A) =", root, "",
        "sqrt(A) @ sqrt(A) =", squared, "",
        "Original A =", original,
        sep="\n",
    )

if __name__ == "__main__":
    main()

3. Non-Unique

import numpy as np
from scipy import linalg

def main():
    """Show two different square roots of 4*I: the sqrtm result and its negation."""
    target = np.array([[4, 0], [0, 4]])

    principal = linalg.sqrtm(target)  # principal square root
    negated = -principal  # (-X)² = X² = A, so this is a second root

    def squares_to_target(candidate):
        """Return whether candidate @ candidate matches target within tolerance."""
        return np.allclose(candidate @ candidate, target)

    for heading, root in (("Principal sqrt(A):", principal),
                          ("Another sqrt(A):", negated)):
        print(heading)
        print(root)
        print()

    print("Both satisfy X² = A:")
    print(f"X1² = A: {squares_to_target(principal)}")
    print(f"X2² = A: {squares_to_target(negated)}")

if __name__ == "__main__":
    main()

Properties

1. Positive Definite Input

import numpy as np
from scipy import linalg

def main():
    """Print the real part of sqrtm for a symmetric positive definite 2x2 matrix."""
    # Positive definite input (leading minors 4 and 16 are both positive).
    spd = np.array([[4, 2], [2, 5]])

    root = linalg.sqrtm(spd)
    real_part = root.real
    shown = real_part.round(6)

    for item in ("A is positive definite", "sqrt(A) is real:", shown):
        print(item)

if __name__ == "__main__":
    main()

2. Negative Eigenvalues

import numpy as np
from scipy import linalg

def main():
    """Print a diagonal matrix with a negative entry and what sqrtm returns for it."""
    # The -1 on the diagonal is a negative eigenvalue, so the root is complex.
    indefinite = np.array([[-1, 0], [0, 4]])
    root = linalg.sqrtm(indefinite)

    report = [
        "A has negative eigenvalue:",
        str(indefinite),
        "",
        "sqrt(A) is complex:",
        str(root),
    ]
    print("\n".join(report))

if __name__ == "__main__":
    main()

Applications

1. Whitening Transform

import numpy as np
from scipy import linalg

def main():
    """Whiten correlated 2-D samples using the inverse square root of Sigma.

    Draws 1000 standard-normal samples with a fixed seed, colours them with
    Sigma^{1/2}, applies W = Sigma^{-1/2}, and prints the sample covariance
    (rounded to 2 decimals) before and after whitening.
    """
    np.random.seed(42)

    # Covariance matrix
    Sigma = np.array([[2, 1],
                      [1, 2]])

    # Sigma^{1/2} is computed once and shared by the colouring and whitening
    # steps, so both are guaranteed to use the same factor.
    Sigma_sqrt = linalg.sqrtm(Sigma)

    # Whitening: W = Sigma^{-1/2}
    W = np.linalg.inv(Sigma_sqrt)

    # Generate correlated data by colouring standard-normal rows
    data = np.random.randn(1000, 2) @ Sigma_sqrt

    # Whiten
    whitened = data @ W

    print("Original covariance:")
    print(np.cov(data.T).round(2))
    print()
    print("Whitened covariance:")
    print(np.cov(whitened.T).round(2))

if __name__ == "__main__":
    main()

2. Geometric Mean

import numpy as np
from scipy import linalg

def main():
    """Print the matrix geometric mean A#B of two diagonal 2x2 matrices."""
    first = np.array([[4, 0], [0, 9]])
    second = np.array([[1, 0], [0, 4]])

    # A#B = A^{1/2} (A^{-1/2} B A^{-1/2})^{1/2} A^{1/2}
    half = linalg.sqrtm(first)
    inv_half = linalg.sqrtm(np.linalg.inv(first))
    core = linalg.sqrtm(inv_half @ second @ inv_half)
    mean = half @ core @ half

    # Only the real part is shown, rounded to 6 decimals.
    print("Geometric mean A#B:", mean.real.round(6), sep="\n")

if __name__ == "__main__":
    main()

Summary

| Function | Description |
| --- | --- |
| `linalg.sqrtm(A)` | Principal matrix square root |

Key: `linalg.sqrtm` returns the principal square root, which may be complex if A has negative eigenvalues.