"""Utility functions for distributions."""
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve, cho_factor, cho_solve
from jax import jit
from ..types import Array, Scalar, Matrix
def symmetrize(matrix: Matrix) -> Matrix:
    """Return the symmetric part of one or more square matrices.

    The transpose is taken over the last two axes only, so every matrix in a
    stack of shape ``(..., d, d)`` is symmetrized independently.

    Args:
        matrix: Array whose last two axes form square matrices.

    Returns:
        ``(matrix + matrix^T) / 2`` with the same shape as ``matrix``.
    """
    transposed = jnp.swapaxes(matrix, -2, -1)
    return (matrix + transposed) * 0.5
@jit
def cho_inv(matrix: Matrix, diagonal_boost: float = 1e-12) -> Matrix:
    """Invert a positive-definite matrix (or a stack of them) via Cholesky.

    The input is symmetrized and its diagonal boosted before factorizing, so
    the result is the inverse of ``symmetrize(matrix) + diagonal_boost * I``
    rather than of ``matrix`` verbatim.

    Args:
        matrix: Positive-definite matrix with shape ``(..., d, d)``.
        diagonal_boost: Value added to every diagonal entry before the
            factorization. NOTE: the 1e-12 default is far below float32
            resolution for O(1) diagonal entries, so it is effectively a
            no-op for float32 inputs.

    Returns:
        Array with the same shape as ``matrix`` holding the inverse(s).
    """
    eye = jnp.eye(matrix.shape[-1], dtype=matrix.dtype)
    regularized = symmetrize(matrix) + diagonal_boost * eye
    # Lower-triangular factor, reused by cho_solve against identity columns.
    factor = cho_factor(regularized, lower=True)
    # One identity per batch element so the solve yields one inverse each.
    rhs = jnp.broadcast_to(eye, matrix.shape)
    return cho_solve(factor, rhs)
@jit
def safe_cholesky(X: Array, jitter: float = 1e-12) -> Array:
    """Compute a lower Cholesky factor after adding jitter to the diagonal.

    Args:
        X: Symmetric positive-definite matrix, or a stack of them with shape
            ``(..., n, n)``.
        jitter: Value added to every diagonal entry before factorizing, to
            absorb small negative eigenvalues caused by round-off.

    Returns:
        Lower-triangular ``L`` with ``L @ L^T ~= X + jitter * I``. For
        floating-point ``X`` the result keeps ``X``'s dtype.
    """
    n = X.shape[-1]
    # Build the identity in X's dtype: a default-dtype eye (float64 when
    # jax_enable_x64 is on) would silently promote float32 inputs.
    X = X + jitter * jnp.eye(n, dtype=X.dtype)
    return cholesky(X, lower=True)
@jit
def safe_cholesky_and_logdet(X: Array, jitter: float = 1e-12) -> tuple[Array, Scalar]:
    """Return a jittered Cholesky factor together with its log-determinant.

    Args:
        X: Symmetric positive-definite matrix, possibly with leading batch
            axes ``(..., n, n)``.
        jitter: Value added to the diagonal of ``X`` before factorizing;
            forwarded unchanged to ``safe_cholesky``.

    Returns:
        ``(L, logdet)`` where ``L`` is the lower-triangular factor returned by
        ``safe_cholesky`` and ``logdet`` is ``log det(X + jitter * I)``. For
        batched input ``logdet`` has one entry per matrix, shape ``(...)``.
    """
    chol = safe_cholesky(X, jitter=jitter)
    # det(L L^T) = prod(diag(L))**2, hence log det = 2 * sum(log(diag(L))).
    diag = jnp.diagonal(chol, axis1=-2, axis2=-1)
    logdet = 2.0 * jnp.log(diag).sum(axis=-1)
    return chol, logdet
@jit
def natural_to_moment(nat1: Array, nat2: Array) -> tuple[Array, Array]:
    """Convert multivariate-normal natural parameters to moment parameters.

    With precision ``P = -2 * nat2`` the moments are ``covariance = P^{-1}``
    and ``mean = P^{-1} @ nat1``. Both are obtained with ``solve`` using its
    default (general-matrix) assumption. Leading batch axes are supported.

    Args:
        nat1: First natural parameter (precision @ mean), shape ``(..., d)``.
        nat2: Second natural parameter (-0.5 * precision), shape
            ``(..., d, d)``.

    Returns:
        Tuple ``(mean, covariance)`` with shapes ``(..., d)`` and
        ``(..., d, d)``.
    """
    precision = -2.0 * nat2
    # Size the identity from the trailing (matrix) axis rather than axis 0 so
    # batched inputs work, and match precision's dtype so float32 inputs are
    # not promoted to float64 when jax_enable_x64 is on.
    identity = jnp.eye(precision.shape[-1], dtype=precision.dtype)
    covariance = solve(precision, jnp.broadcast_to(identity, precision.shape))
    # Solve against an explicit column so a batched nat1 of shape (..., d) is
    # treated as a stack of vectors, not as a matrix right-hand side.
    mean = solve(precision, nat1[..., None])[..., 0]
    return mean, covariance
@jit
def moment_to_natural(mean: Array, covariance: Array) -> tuple[Array, Array]:
    """Convert multivariate-normal moment parameters to natural parameters.

    With precision ``P = covariance^{-1}`` (computed via ``solve`` against the
    identity), the natural parameters are ``nat1 = P @ mean`` and
    ``nat2 = -0.5 * P``. Leading batch axes are supported.

    Args:
        mean: Mean vector(s), shape ``(..., d)``.
        covariance: Covariance matrix (or matrices), shape ``(..., d, d)``.

    Returns:
        Tuple ``(nat1, nat2)`` with shapes ``(..., d)`` and ``(..., d, d)``.
    """
    # Size the identity from the trailing (matrix) axis rather than axis 0 so
    # batched inputs work, and match covariance's dtype so float32 inputs are
    # not promoted to float64 when jax_enable_x64 is on.
    identity = jnp.eye(covariance.shape[-1], dtype=covariance.dtype)
    precision = solve(covariance, jnp.broadcast_to(identity, covariance.shape))
    # Multiply against an explicit column so batched means of shape (..., d)
    # pair with their own precision matrix.
    nat1 = (precision @ mean[..., None])[..., 0]
    nat2 = -0.5 * precision
    return nat1, nat2