"""Multivariate Normal-Inverse-Gamma distribution implementation."""
from typing import Optional, Tuple
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
from ..types import Array, PRNGKey, Shape
from .exponential_family import ExponentialFamily
from .gamma import InverseGamma, Gamma
from .mvn import MultivariateNormal
class MultivariateNormalInverseGamma(ExponentialFamily):
    """Multivariate Normal-Inverse-Gamma distribution.

    This distribution combines::

        p(x | σ²) = N(x; μ, σ² Λ⁻¹)        # Multivariate Normal
        p(σ²)     = InverseGamma(σ²; α, β)  # variance

    where:
        - x is a vector (can be batched)
        - μ is the location parameter
        - Λ is a base precision matrix
        - σ² is a variance parameter: a single scalar shared by all batch rows
          when ``isotropic_noise=True``, otherwise one per batch row
        - α, β are the Inverse-Gamma shape and scale parameters
    """

    mvn: MultivariateNormal  # base Gaussian over x; σ² is applied on top of it
    inv_gamma: InverseGamma  # distribution over σ² (scalar or one per batch row)

    def __init__(
        self,
        loc: Array,
        *,
        isotropic_noise: bool,
        mask: Optional[Array] = None,
        alpha0: float = 2.0,
        beta0: float = 1.0,
        scale_tril: Optional[Array] = None,
        covariance: Optional[Array] = None,
        precision: Optional[Array] = None,
    ):
        """Initialize MultivariateNormalInverseGamma distribution.

        Args:
            loc: Location parameter with shape (batch_dim, event_dim).
            isotropic_noise: If True, use a single shared InverseGamma (scalar batch).
                If False, use per-row InverseGamma (one per batch dimension).
            mask: Optional boolean mask for active dimensions.
            alpha0: Shape parameter for InverseGamma prior (default: 2.0).
            beta0: Scale parameter for InverseGamma prior (default: 1.0).
            scale_tril: Optional lower triangular scale matrix.
            covariance: Optional covariance matrix.
            precision: Optional precision matrix.

        Note:
            Only one of scale_tril, covariance, or precision should be provided.
            If none are provided, identity matrix is used.
        """
        # Base MVN over x; σ² is not part of its parameters.
        self.mvn = MultivariateNormal(
            loc=loc, mask=mask, scale_tril=scale_tril, covariance=covariance, precision=precision
        )
        # InverseGamma over σ²: scalar when isotropic, otherwise one per batch row.
        if isotropic_noise:
            self.inv_gamma = InverseGamma(alpha0=alpha0, beta0=beta0)
        else:
            self.inv_gamma = InverseGamma(
                alpha0=alpha0 * jnp.ones(self.mvn.batch_shape),
                beta0=beta0 * jnp.ones(self.mvn.batch_shape),
            )
        # Batch/event shapes are taken from the MVN component.
        super().__init__(batch_shape=self.mvn.batch_shape, event_shape=self.mvn.event_shape)

    def log_prob(self, x: Tuple[Array, Array]) -> Array:
        """Compute the joint log probability log p(w, σ²), summed over the batch.

        Args:
            x: Tuple of (sig_sqr, w) where:
                sig_sqr: Value of the sample variance. If it has the same rank as
                    ``w`` it is treated as a matrix and only its diagonal is used.
                w: Value of the sample state.

        Returns:
            Scalar log probability (MVN and InverseGamma terms each summed).
        """
        sig_sqr, w = x
        sig_sqr = jnp.diagonal(sig_sqr, axis1=-1, axis2=-2) if sig_sqr.ndim == w.ndim else sig_sqr
        # MVN term p(w | σ²): scale the base natural parameters by 1/σ², giving
        # N(μ, σ² Λ⁻¹), and clear the data deltas so only the scaled totals remain.
        mvn = eqx.tree_at(
            lambda m: (m.nat1_0, m.dnat1, m.nat2_0, m.dnat2),
            self.mvn,
            (
                self.mvn.nat1 / sig_sqr[..., None],
                jnp.zeros_like(self.mvn.dnat1),
                self.mvn.nat2 / sig_sqr[..., None, None],
                jnp.zeros_like(self.mvn.dnat2),
            ),
        )
        mvn_log_prob = mvn.log_prob(w).sum()
        # InverseGamma term p(σ²).
        inv_gamma_log_prob = self.inv_gamma.log_prob(sig_sqr).sum()
        return mvn_log_prob + inv_gamma_log_prob

    def sample(self, seed: PRNGKey, sample_shape: Shape = ()) -> Tuple[Array, Array]:
        """Sample (σ², x) from the joint distribution.

        Args:
            seed: PRNG key
            sample_shape: Shape of samples to draw.
                NOTE(review): only ``()`` appears to be supported. For a non-empty
                sample_shape, the σ² broadcasting against the MVN parameters and
                the final ``broadcast_to((n,))`` likely fail. Confirm this against
                ``MultivariateNormal.sample``.

        Returns:
            Tuple of (sig_sqr, value) samples. ``sig_sqr`` is broadcast to shape
            ``(n,)`` with ``n = value.shape[-2]`` (assumes a batched ``value``).
        """
        key_g, key_mvn = jr.split(seed)
        # Sample σ² ~ InverseGamma(α, β).
        sig_sqr = self.inv_gamma.sample(key_g, sample_shape=sample_shape)
        sig = jnp.sqrt(sig_sqr)
        # Sample x | σ² ~ N(μ, σ² Λ⁻¹): draw z ~ N(μ/σ, Λ⁻¹) by dividing the
        # base nat1 by σ, then return σ·z.
        mvn = eqx.tree_at(
            lambda m: (m.nat1_0, m.dnat1),
            self.mvn,
            (self.mvn.nat1 / sig[..., None], jnp.zeros_like(self.mvn.dnat1)),
        )
        value = mvn.sample(key_mvn, sample_shape=sample_shape) * sig[..., None]
        n = value.shape[-2]
        return jnp.broadcast_to(sig_sqr, (n,)), value

    def mode(self) -> Tuple[Array, Array]:
        r"""Solve for the joint mode. Recall,

        .. math::

            p(x, \sigma^2) \propto
            \mathrm{N}(x \mid \mu, \sigma^2 \Lambda^{-1}) \times
            \mathrm{IG}(\sigma^2 \mid \alpha, \beta)

        The optimal mean is :math:`x^* = \mu`. Substituting this in,

        .. math::

            p(x^*, \sigma^2) \propto \mathrm{IG}(\sigma^2 \mid \alpha + D/2, \beta)

        where D is the number of Gaussian entries that share one σ². That is the
        event dimension when each batch row has its own σ², or the event
        dimension times the number of batch rows for isotropic (shared) noise.
        The mode of this inverse gamma distribution is at

        .. math::

            (\sigma^2)^* = \beta / (\alpha + (D + 2)/2)

        Returns:
            Tuple ``(sigma_sqr_mode, mean)``. ``sigma_sqr_mode`` is broadcast to
            shape ``(n,)``, where ``n`` is the leading batch size (1 if unbatched).
        """
        dim = self.event_shape[-1]
        if len(self.inv_gamma.batch_shape) == 0:
            # Isotropic noise: a single σ² is shared by every batch row, so each
            # row's Gaussian normalizer contributes its own (σ²)^(-D/2) factor.
            for size in self.batch_shape:
                dim *= size
        sigma_sqr_mode = self.beta / (self.alpha + (dim + 2) / 2)
        n = self.batch_shape[0] if self.batch_shape else 1
        return jnp.broadcast_to(sigma_sqr_mode, (n,)), self.mean

    @property
    def alpha(self) -> Array:
        """Shape parameter α of the InverseGamma component."""
        return self.inv_gamma.alpha

    @property
    def beta(self) -> Array:
        """Scale parameter β of the InverseGamma component."""
        return self.inv_gamma.beta

    @property
    def precision(self) -> Array:
        """Base precision matrix Λ of the MVN component."""
        return self.mvn.precision

    @property
    def mean(self) -> Array:
        """Mean μ of the MVN component (location of x)."""
        return self.mvn.mean

    @property
    def covariance(self) -> Array:
        """Base covariance Λ⁻¹ of the MVN component (without noise scaling)."""
        return self.mvn.covariance

    @property
    def col_covariance(self) -> Array:
        """Column covariance Λ⁻¹ (alias for covariance)."""
        return self.mvn.covariance

    @property
    def expected_covariance(self) -> Array:
        """Expected covariance E[σ²] * Λ⁻¹, with E[σ²] broadcast over the batch."""
        exp_variance = jnp.broadcast_to(self.inv_gamma.mean, self.batch_shape)
        return self.mvn.covariance * exp_variance[..., None, None]

    @property
    def expected_psi(self) -> Array:
        """Expected precision E[psi] from the InverseGamma component, broadcast over the batch."""
        return jnp.broadcast_to(self.inv_gamma.expected_psi, self.batch_shape)

    @property
    def expected_log_psi(self) -> Array:
        """Expected log precision E[log psi], broadcast over the batch."""
        return jnp.broadcast_to(self.inv_gamma.expected_log_psi, self.batch_shape)

    @property
    def expected_sufficient_statistics_psi(self) -> Array:
        """Expected sufficient statistics of psi under Gamma(α, β), shape batch_shape + (2,)."""
        gamma = Gamma(self.alpha, self.beta)
        suff_stats = gamma.expected_sufficient_statistics
        return jnp.broadcast_to(suff_stats, self.batch_shape + (2,))
def mvnig_posterior_update(
    mvnig_prior: MultivariateNormalInverseGamma, sufficient_stats: tuple, props: object
) -> MultivariateNormalInverseGamma:
    """Update the multivariate normal inverse gamma (MVNIG) posterior from sufficient statistics.

    The posterior is built by *replacing* the data deltas (``dnat1``/``dnat2``)
    of both the MVN and the InverseGamma components. The prior stays in the
    base parameters (``nat1_0``/``nat2_0`` for the MVN). Any deltas already on
    ``mvnig_prior`` are therefore discarded, and the prior-side terms below are
    also computed from the base prior only, so the result stays consistent.

    Args:
        mvnig_prior: Prior MVNIG distribution.
        sufficient_stats: Tuple of (SxxT, SxyT, SyyT, N) sufficient statistics.
            ``SyyT`` may be a vector or a matrix; for a matrix only its diagonal
            is used.
        props: Parameter properties controlling which parameters are trainable.
            NOTE(review): currently unused by this function.

    Returns:
        Posterior MVNIG distribution.
    """
    mvn_base = mvnig_prior.mvn
    # Clear any previously accumulated data, so the "prior" terms below really
    # come from the base prior. This matches the delta replacement further down.
    mvn_prior = eqx.tree_at(
        lambda m: (m.dnat1, m.dnat2),
        mvn_base,
        (jnp.zeros_like(mvn_base.dnat1), jnp.zeros_like(mvn_base.dnat2)),
    )
    SxxT, SxyT, SyyT, N = sufficient_stats
    # MVN data contributions to the natural parameters (prior kept in nat1_0/nat2_0).
    mvn_dnat2 = -0.5 * SxxT
    mvn_dnat1 = SxyT.mT
    # Syy plus the row-wise prior quadratic term sum(μ₀ * nat1₀).
    syy_diag = SyyT if SyyT.ndim == 1 else jnp.diag(SyyT)
    Syy = syy_diag + jnp.sum(mvn_prior.mean * mvn_prior.nat1, -1)
    mvn_post = eqx.tree_at(lambda m: (m.dnat1, m.dnat2), mvn_base, (mvn_dnat1, mvn_dnat2))
    M_pos = mvn_post.mean
    nat1_post = mvn_post.nat1
    # InverseGamma deltas. Presumably these map to α += N/2 and
    # β += (Syy + prior term - posterior term)/2 under the InverseGamma
    # natural-parameter convention; confirm against InverseGamma.
    dnat1 = -N / 2
    dnat2 = -(Syy - jnp.sum(M_pos * nat1_post, -1)) / 2
    # For isotropic noise (PCA), sum dnat across features since all share one variance.
    if mvnig_prior.inv_gamma.batch_shape == ():
        D = dnat2.shape[0]
        dnat1 = dnat1.sum() if jnp.ndim(dnat1) > 0 else dnat1 * D
        dnat2 = dnat2.sum()
    inv_gamma_post = eqx.tree_at(lambda m: (m.dnat1, m.dnat2), mvnig_prior.inv_gamma, (dnat1, dnat2))
    return eqx.tree_at(lambda m: (m.mvn, m.inv_gamma), mvnig_prior, (mvn_post, inv_gamma_post))