"""Multivariate normal distribution implementation."""
from typing import ClassVar, Optional
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
from jax.scipy.linalg import solve, solve_triangular
from ..types import Array, Matrix, PRNGKey, Shape, Vector
from .exponential_family import ExponentialFamily
from .utils import safe_cholesky, safe_cholesky_and_logdet, cho_inv
class MultivariateNormal(ExponentialFamily):
    """Multivariate normal distribution in natural parameters.

    The distribution is stored as a prior part (``nat1_0``, ``nat2_0``) plus a
    data contribution (``dnat1``, ``dnat2``); the effective natural parameters
    are their sums.  A per-dimension ``mask`` marks which event dimensions are
    active: masked entries of vectors are zeroed, and masked rows/columns of the
    precision are replaced by identity entries so it stays invertible.
    """

    nat1_0: Vector  # Prior: precision_0 * mean_0 (unmasked)
    nat2_0: Matrix  # Prior: -0.5 * precision_0 (unmasked)
    dnat1: Vector  # Data contribution to nat1
    dnat2: Matrix  # Data contribution to nat2
    mask: Vector  # Mask indicating active dimensions.

    natural_param_shape: ClassVar[Shape] = (1,)  # [nat1, nat2]

    def __init__(
        self,
        loc: Array,
        scale_tril: Optional[Array] = None,
        covariance: Optional[Array] = None,
        precision: Optional[Array] = None,
        mask: Optional[Array] = None,
    ):
        """Initialize multivariate normal with standard parameters.

        Args:
            loc: Mean vector with shape (..., d)
            scale_tril: Optional lower triangular scale matrix with shape (..., d, d)
            covariance: Optional covariance matrix with shape (..., d, d)
            precision: Optional precision matrix with shape (..., d, d)
            mask: Optional boolean mask with shape (..., d) where True indicates
                active dimensions. Defaults to all dimensions active.

        Raises:
            ValueError: If more than one of scale_tril, covariance, or precision
                is given, if the mask shape differs from the loc shape, or if the
                supplied matrix cannot be broadcast to (..., d, d).

        Note:
            Only one of scale_tril, covariance, or precision should be provided.
            If none are provided, identity matrix is used as the scale.
        """
        # Batch and event shapes are derived from loc.
        *batch_shape, dim = loc.shape
        super().__init__(batch_shape=tuple(batch_shape), event_shape=(dim,))

        # At most one scale parameterisation may be supplied.
        n_given = sum(x is not None for x in (scale_tril, covariance, precision))
        if n_given > 1:
            raise ValueError("Only one of scale_tril, covariance, or precision should be provided")

        # Mask: validate against loc, or default to "everything active".
        if mask is not None:
            if mask.shape != loc.shape:
                raise ValueError(f"Mask shape {mask.shape} must match loc shape {loc.shape}")
            self.mask = mask
        else:
            self.mask = jnp.ones_like(loc, dtype=bool)

        matrix_shape = (*batch_shape, dim, dim)

        def _broadcast(value, label):
            # jnp.broadcast_to raises its own ValueError on incompatible shapes,
            # so a shape check placed after it can never fire.  Convert that
            # failure into the intended message instead.
            try:
                return jnp.broadcast_to(value, matrix_shape)
            except ValueError as err:
                raise ValueError(f"{label} shape {jnp.shape(value)} must match loc batch shape") from err

        # Reduce whichever parameterisation was given to a precision matrix P.
        if precision is not None:
            P = _broadcast(precision, "Precision")
        elif covariance is not None:
            P = cho_inv(_broadcast(covariance, "Covariance"))
        elif scale_tril is not None:
            scale_tril = _broadcast(scale_tril, "Scale_tril")
            P = cho_inv(scale_tril @ scale_tril.mT)
        else:
            # Default to identity matrix with proper broadcasting
            P = jnp.broadcast_to(jnp.eye(dim), matrix_shape)

        # Set prior natural parameters (unmasked); data contributions start at zero.
        self.nat1_0 = jnp.squeeze(P @ loc[..., None], -1)
        self.nat2_0 = -0.5 * P
        self.dnat1 = jnp.zeros_like(self.nat1_0)
        self.dnat2 = jnp.zeros_like(self.nat2_0)

    def apply_mask_vector(self, x: Array) -> Array:
        """Apply mask to a vector, zeroing out masked dimensions.

        Args:
            x: Vector with shape (..., d)

        Returns:
            Masked vector with same shape
        """
        return jnp.where(self.mask, x, 0.0)

    def apply_mask_matrix(self, x: Array, zeromask: bool = False) -> Array:
        """Apply mask to a matrix, replacing masked rows and columns.

        Args:
            x: Matrix with shape (..., d, d)
            zeromask: If True, set masked entries to 0. If False (default),
                set masked entries to identity matrix values.

        Returns:
            Masked matrix with same shape.
        """
        # Entry (i, j) is kept only when both dimension i and dimension j are active.
        mask_mat = self.mask[..., None] * self.mask[..., None, :]
        if zeromask:
            return jnp.where(mask_mat, x, 0.0)
        return jnp.where(mask_mat, x, jnp.eye(x.shape[-1]))

    @property
    def nat1(self) -> Array:
        """First natural parameter (precision * mean), masked."""
        return self.apply_mask_vector(self.nat1_0 + self.dnat1)

    @property
    def nat2(self) -> Array:
        """Second natural parameter (-0.5 * precision); no mask is applied here."""
        return self.nat2_0 + self.dnat2

    @classmethod
    def from_natural_parameters(cls, nat1: Array, nat2: Array, mask: Optional[Array] = None) -> "MultivariateNormal":
        """Create MVN from natural parameters.

        Args:
            nat1: First natural parameter (precision * mean).
            nat2: Second natural parameter (-0.5 * precision).
            mask: Optional boolean mask with shape matching nat1

        Returns:
            MultivariateNormal instance.
        """
        precision = -2.0 * nat2
        # mean = precision^{-1} @ nat1; the precision is treated as positive definite.
        loc = solve(precision, nat1, assume_a="pos")
        return cls(loc=loc, precision=precision, mask=mask)

    @property
    def mean(self) -> Array:
        """Get mean parameter (zero on masked dimensions)."""
        mean = solve(self.precision, self.nat1, assume_a="pos")
        return self.apply_mask_vector(mean)

    @property
    def precision(self) -> Array:
        """Get precision parameter, with identity entries on masked rows/columns."""
        return self.apply_mask_matrix(-2.0 * self.nat2)

    @property
    def covariance(self) -> Array:
        """Get covariance parameter, with zeros on masked rows/columns."""
        # cho_inv is a project helper; presumably a Cholesky-based inverse.
        return self.apply_mask_matrix(cho_inv(self.precision), zeromask=True)

    def sufficient_statistics(self, x: Array) -> Array:
        """Compute sufficient statistics T(x) = [x, xx^T].

        Args:
            x: Value to compute sufficient statistics for.

        Returns:
            Sufficient statistics [x, vec(xx^T)], computed from the masked x.
        """
        x = self.apply_mask_vector(x)
        xx = x[..., None] * x[..., None, :]  # Outer product
        return jnp.concatenate([x, xx.reshape(*x.shape[:-1], -1)], axis=-1)

    @property
    def expected_sufficient_statistics(self) -> Array:
        """Compute E[T(x)] = [μ, vec(μμ^T + Σ)].

        Returns:
            Expected sufficient statistics [E[x], vec(E[xx^T])].
        """
        E_x = self.mean
        cov = self.covariance
        mean_outer = E_x[..., None] * E_x[..., None, :]
        E_xx = mean_outer + cov
        return jnp.concatenate([E_x, E_xx.reshape(*E_x.shape[:-1], -1)], axis=-1)

    @property
    def natural_parameters(self) -> Array:
        """Get natural parameters η = [precision*mean, -0.5*precision].

        The second part is built from the masked ``precision`` property (identity
        entries on masked dimensions), not from the unmasked ``nat2``.

        Returns:
            Natural parameters [η₁, vec(η₂)].
        """
        return jnp.concatenate([self.nat1, -0.5 * self.precision.reshape(*self.batch_shape, -1)], axis=-1)

    @property
    def log_normalizer(self) -> Array:
        """Compute log normalizer A(η).

        Returns:
            Log normalizer A(η) with shape: batch_shape
        """
        precision = self.precision
        # Project helper; presumably returns the lower Cholesky factor of the
        # precision together with its log-determinant.
        L, logdet = safe_cholesky_and_logdet(precision)
        # m = L^{-1} η₁, so sum(m**2) = η₁ᵀ P^{-1} η₁ when P = L Lᵀ.
        m = self.apply_mask_vector(solve_triangular(L, self.nat1[..., None], lower=True)[..., 0])
        return 0.5 * jnp.sum(jnp.square(m), -1) - 0.5 * logdet

    def log_base_measure(self, x: Optional[Array] = None) -> Array:
        """Compute log of base measure h(x).

        Args:
            x: Data to compute base measure for (unused by this implementation).
                Shape: batch_shape + event_shape

        Returns:
            Log base measure log(h(x)) with shape: batch_shape
        """
        d = jnp.sum(self.mask, axis=-1)  # Count active dimensions
        return self.broadcast_to_shape(-0.5 * d * jnp.log(2.0 * jnp.pi), ignore_event=True)

    def sample(self, key: PRNGKey, sample_shape: Shape = ()) -> Array:
        """Sample from the distribution.

        Args:
            key: PRNG key for random sampling.
            sample_shape: Shape of samples to draw.

        Returns:
            Samples from the distribution, zeroed on masked dimensions.
        """
        # Accept any sequence so the shape concatenations below cannot fail on a list.
        sample_shape = tuple(sample_shape)
        precision = self.precision
        # Use Cholesky for sampling (safe_cholesky presumably returns the lower factor).
        L = safe_cholesky(precision)
        _eta = solve_triangular(L, self.nat1, lower=True)
        L = jnp.broadcast_to(L, sample_shape + L.shape)
        # Generate standard normal samples and transform:
        # x = L^{-T} (L^{-1} η₁ + z) has mean P^{-1} η₁ and covariance P^{-1}.
        shape = sample_shape + self.shape
        z = jr.normal(key, shape)
        smpl = _eta + z
        smpl = solve_triangular(L.mT, smpl[..., None], lower=False)[..., 0]
        return self.apply_mask_vector(smpl)

    @property
    def expected_second_moment(self) -> Array:
        """Expected second moment E[xx^T] = Cov + mean @ mean^T, per row."""
        m = self.mean
        return self.covariance + m[..., :, None] * m[..., None, :]

    def mf_expectations(self) -> dict:
        """Return expectations for mean-field coordinate ascent partner."""
        return {
            "mean": self.mean,
            "second_moment": self.expected_second_moment,
        }

    def mf_update(self, stats: tuple, partner_expectations: dict) -> "MultivariateNormal":
        """Mean-field coordinate ascent update for MVN weights.

        Scales data sufficient statistics by partner's expected precision E[1/sigma^2],
        then performs standard natural parameter update.

        Args:
            stats: Sufficient statistics tuple (SxxT, SxyT, SyyT, N). Only the
                first two entries are used.
            partner_expectations: Dict with 'expected_precision' from noise component.
                E[precision] can be scalar (isotropic), (D,) vector (per-feature IG),
                or (D, D) matrix (InverseWishart). For matrix precision, the diagonal
                is used for per-row weight updates.

        Returns:
            Updated MVN distribution (posterior).
        """
        # asarray lets a plain Python scalar through; it has no ``.ndim`` otherwise.
        E_psi = jnp.asarray(partner_expectations["expected_precision"])  # scalar, (D,), or (D, D)
        SxxT, SxyT, *_ = stats
        # For full matrix precision (IW noise), extract diagonal for row-wise updates.
        # jnp.diagonal over the last two axes equals jnp.diag for 2-D input and
        # also accepts leading batch dimensions.
        if E_psi.ndim >= 2:
            E_psi = jnp.diagonal(E_psi, axis1=-2, axis2=-1)  # (D,)
        # Data contributions only — prior is preserved in nat1_0/nat2_0
        dnat2 = -0.5 * E_psi[..., None, None] * SxxT  # (D, dim, dim) broadcast
        dnat1 = E_psi[..., None] * SxyT.mT  # (D, dim) broadcast
        return eqx.tree_at(lambda d: (d.dnat1, d.dnat2), self, (dnat1, dnat2))