Source code for sppcax.distributions.mvn_gamma

"""Multivariate Normal-Inverse-Gamma distribution implementation."""

from typing import Optional, Tuple

import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

from ..types import Array, PRNGKey, Shape
from .exponential_family import ExponentialFamily
from .gamma import InverseGamma, Gamma
from .mvn import MultivariateNormal


class MultivariateNormalInverseGamma(ExponentialFamily):
    """Multivariate Normal-InverseGamma distribution.

    This distribution combines:
        p(x|σ²) = N(x; μ, σ²Λ⁻¹)          # Multivariate Normal
        p(σ²) = InverseGamma(σ²; α, β)    # Variance scalar

    where:
        - x is a vector (can be batched)
        - μ is the location parameter
        - Λ is a base precision matrix
        - σ² is a variance parameter: one value shared by all rows when
          ``isotropic_noise=True``, one value per row otherwise
        - α, β are Inverse-Gamma distribution parameters

    The ``*_psi`` properties refer to the precision ψ = 1/σ², which is
    Gamma(α, β)-distributed when σ² ~ InverseGamma(α, β).
    """

    mvn: MultivariateNormal
    inv_gamma: InverseGamma

    def __init__(
        self,
        loc: Array,
        *,
        isotropic_noise: bool,
        mask: Optional[Array] = None,
        alpha0: float = 2.0,
        beta0: float = 1.0,
        scale_tril: Optional[Array] = None,
        covariance: Optional[Array] = None,
        precision: Optional[Array] = None,
    ):
        """Initialize MultivariateNormalInverseGamma distribution.

        Args:
            loc: Location parameter with shape (batch_dim, event_dim).
            isotropic_noise: If True, use a single shared InverseGamma (scalar batch).
                If False, use per-row InverseGamma (one per batch dimension).
            mask: Optional boolean mask for active dimensions.
            alpha0: Shape parameter for InverseGamma prior (default: 2.0).
            beta0: Scale parameter for InverseGamma prior (default: 1.0).
            scale_tril: Optional lower triangular scale matrix.
            covariance: Optional covariance matrix.
            precision: Optional precision matrix.

        Note:
            Only one of scale_tril, covariance, or precision should be provided.
            If none are provided, identity matrix is used.
        """
        # Base MVN component N(μ, Λ⁻¹); the σ² scaling is applied on demand.
        self.mvn = MultivariateNormal(
            loc=loc, mask=mask, scale_tril=scale_tril, covariance=covariance, precision=precision
        )

        # InverseGamma component: scalar when the noise is shared (isotropic),
        # otherwise one (α, β) pair per MVN batch entry.
        if isotropic_noise:
            self.inv_gamma = InverseGamma(alpha0=alpha0, beta0=beta0)
        else:
            self.inv_gamma = InverseGamma(
                alpha0=alpha0 * jnp.ones(self.mvn.batch_shape), beta0=beta0 * jnp.ones(self.mvn.batch_shape)
            )

        # The joint distribution takes its shapes from the MVN component.
        super().__init__(batch_shape=self.mvn.batch_shape, event_shape=self.mvn.event_shape)

    def log_prob(self, x: Tuple[Array, Array]) -> Array:
        """Compute the joint log density log p(w, σ²).

        Args:
            x: Tuple ``(sig_sqr, w)`` where ``sig_sqr`` is the variance sample
                and ``w`` is the MVN sample. If ``sig_sqr`` has the same rank
                as ``w``, its diagonal is taken as the per-row variances.

        Returns:
            Scalar log probability, summed over all batch entries.

        NOTE(review): with isotropic noise and a per-row ``sig_sqr`` (as
        returned by ``sample``), the shared InverseGamma term is summed once
        per row. Confirm that this is intended.
        """
        sig_sqr, w = x
        if sig_sqr.ndim == w.ndim:
            # Matrix input: the per-row variances sit on the diagonal.
            sig_sqr = jnp.diagonal(sig_sqr, axis1=-1, axis2=-2)

        # MVN term p(w | σ²). Dividing both natural parameters by σ² turns
        # N(μ, Λ⁻¹) into N(μ, σ²Λ⁻¹), assuming the usual (Λμ, -Λ/2)
        # parameterisation. ``nat1``/``nat2`` are written into the base slots
        # and the dnat offsets are zeroed.
        mvn = eqx.tree_at(
            lambda m: (m.nat1_0, m.dnat1, m.nat2_0, m.dnat2),
            self.mvn,
            (
                self.mvn.nat1 / sig_sqr[..., None],
                jnp.zeros_like(self.mvn.dnat1),
                self.mvn.nat2 / sig_sqr[..., None, None],
                jnp.zeros_like(self.mvn.dnat2),
            ),
        )
        mvn_log_prob = mvn.log_prob(w).sum()

        # InverseGamma term p(σ²).
        inv_gamma_log_prob = self.inv_gamma.log_prob(sig_sqr).sum()

        return mvn_log_prob + inv_gamma_log_prob

    def sample(self, seed: PRNGKey, sample_shape: Shape = ()) -> Tuple[Array, Array]:
        """Draw ``(σ², x)`` from the joint distribution.

        Args:
            seed: PRNG key.
            sample_shape: Shape of samples to draw.

        Returns:
            Tuple ``(sig_sqr, value)``. ``sig_sqr`` is broadcast to
            ``value.shape[:-1]``, giving one variance per row even when the
            noise is isotropic. An unbatched (1-D) ``loc`` yields a scalar.

        NOTE(review): for a non-empty ``sample_shape``, the rescaled natural
        parameters already carry the sample dimensions before ``mvn.sample``
        adds them again. Confirm the resulting shape.
        """
        key_g, key_mvn = jr.split(seed)

        # σ² ~ InverseGamma(α, β)
        sig_sqr = self.inv_gamma.sample(key_g, sample_shape=sample_shape)
        sig = jnp.sqrt(sig_sqr)

        # x | σ² ~ MVN(μ, σ²Λ⁻¹): draw from MVN(μ/σ, Λ⁻¹), then rescale by σ.
        mvn = eqx.tree_at(
            lambda m: (m.nat1_0, m.dnat1),
            self.mvn,
            (self.mvn.nat1 / sig[..., None], jnp.zeros_like(self.mvn.dnat1)),
        )
        value = mvn.sample(key_mvn, sample_shape=sample_shape) * sig[..., None]

        # Broadcast to every leading axis of ``value``. Indexing ``shape[-2]``
        # would fail for an unbatched (1-D) location.
        return jnp.broadcast_to(sig_sqr, value.shape[:-1]), value

    def mode(self) -> Tuple[Array, Array]:
        r"""Return the joint mode ``(sig_sqr*, x*)``.

        The joint density is

        .. math::

            p(x, \sigma^2) \propto \mathrm{N}(x \mid \mu, \sigma^2 \Lambda^{-1})
            \times \mathrm{IG}(\sigma^2 \mid \alpha, \beta).

        For any :math:`\sigma^2`, the Gaussian factor is maximised at
        :math:`x^* = \mu`. Substituting this in leaves

        .. math::

            p(x^*, \sigma^2) \propto \mathrm{IG}(\sigma^2 \mid \alpha + D/2, \beta),

        where :math:`D` is the event dimension of :math:`x`. The mode of this
        inverse gamma is

        .. math::

            (\sigma^2)^* = \frac{\beta}{\alpha + (D + 2)/2}.

        NOTE(review): with isotropic noise, one σ² is shared by every batch
        row, so the mode of the full joint would use ``D * n_rows`` in place
        of ``D``. This method uses the per-row ``D``. Confirm which convention
        is intended.

        Returns:
            Tuple ``(sig_sqr_mode, mean)``. ``sig_sqr_mode`` is broadcast to
            the batch shape, or to ``(1,)`` when unbatched.
        """
        dim = self.event_shape[-1]
        sigma_sqr_mode = self.beta / (self.alpha + (dim + 2) / 2)
        # ``batch_shape`` (rather than just its first axis) also supports
        # multi-dimensional batches; unbatched keeps the historical ``(1,)``.
        return jnp.broadcast_to(sigma_sqr_mode, self.batch_shape or (1,)), self.mean

    @property
    def alpha(self) -> Array:
        """Shape parameter α of the InverseGamma component."""
        return self.inv_gamma.alpha

    @property
    def beta(self) -> Array:
        """Scale parameter β of the InverseGamma component."""
        return self.inv_gamma.beta

    @property
    def precision(self) -> Array:
        """Base precision matrix Λ of the MVN component."""
        return self.mvn.precision

    @property
    def mean(self) -> Array:
        """Get mean of the marginal distribution p(x)."""
        return self.mvn.mean

    @property
    def covariance(self) -> Array:
        """Base covariance Λ⁻¹ of the MVN component (without noise scaling)."""
        return self.mvn.covariance

    @property
    def col_covariance(self) -> Array:
        """Column covariance Λ⁻¹ (alias for covariance)."""
        return self.mvn.covariance

    @property
    def expected_covariance(self) -> Array:
        """Expected covariance E[σ²] * Λ⁻¹, one matrix per batch entry."""
        exp_variance = jnp.broadcast_to(self.inv_gamma.mean, self.batch_shape)
        return self.mvn.covariance * exp_variance[..., None, None]

    @property
    def expected_psi(self) -> Array:
        """Expected precision E[ψ], broadcast to the batch shape."""
        return jnp.broadcast_to(self.inv_gamma.expected_psi, self.batch_shape)

    @property
    def expected_log_psi(self) -> Array:
        """Expected log precision E[log ψ], broadcast to the batch shape."""
        return jnp.broadcast_to(self.inv_gamma.expected_log_psi, self.batch_shape)

    @property
    def expected_sufficient_statistics_psi(self) -> Array:
        """Expected sufficient statistics of ψ ~ Gamma(α, β), shape ``batch_shape + (2,)``."""
        gamma = Gamma(self.alpha, self.beta)
        suff_stats = gamma.expected_sufficient_statistics
        return jnp.broadcast_to(suff_stats, self.batch_shape + (2,))
def mvnig_posterior_update(
    mvnig_prior: MultivariateNormalInverseGamma, sufficient_stats: tuple, props: object
) -> MultivariateNormalInverseGamma:
    """Update the multivariate normal inverse gamma (MVNIG) posterior from sufficient statistics.

    The data contributions are written into the ``dnat1``/``dnat2`` offsets of
    both components. The prior natural parameters (``nat1_0``/``nat2_0``) are
    left untouched, and any offsets already on the prior are replaced rather
    than accumulated.

    Args:
        mvnig_prior: Prior MVNIG distribution.
        sufficient_stats: Tuple of (SxxT, SxyT, SyyT, N) sufficient statistics.
            ``SyyT`` may be a vector of per-row sums or a matrix whose
            diagonal holds them.
        props: Parameter properties controlling which parameters are trainable.
            NOTE(review): currently unused by this function.

    Returns:
        Posterior MVNIG distribution.
    """
    mvn_prior = mvnig_prior.mvn
    SxxT, SxyT, SyyT, N = sufficient_stats

    # MVN data contributions; the prior is preserved in nat1_0/nat2_0.
    mvn_dnat2 = -0.5 * SxxT
    mvn_dnat1 = SxyT.mT

    # The β update needs Syy + μ0ᵀΛ0μ0 per row (assuming nat1 = Λμ). The prior
    # term uses the prior nat1, i.e. before the data offsets are written.
    syy_rows = SyyT if SyyT.ndim == 1 else jnp.diag(SyyT)
    Syy = syy_rows + jnp.sum(mvn_prior.mean * mvn_prior.nat1, -1)

    mvn_post = eqx.tree_at(lambda m: (m.dnat1, m.dnat2), mvn_prior, (mvn_dnat1, mvn_dnat2))

    # InverseGamma offsets: α += N/2 and β += (Syy - μnᵀΛnμn)/2, under the usual
    # IG natural parameterisation.
    dnat1 = -N / 2
    dnat2 = -(Syy - jnp.sum(mvn_post.mean * mvn_post.nat1, -1)) / 2

    # Isotropic noise (e.g. PCA): all rows share one variance, so the per-row
    # contributions are pooled. ``jnp.size`` (not ``shape[0]``) also handles a
    # scalar ``dnat2`` (unbatched MVN) and multi-dimensional batches.
    if mvnig_prior.inv_gamma.batch_shape == ():
        n_rows = jnp.size(dnat2)
        dnat1 = dnat1.sum() if jnp.ndim(dnat1) > 0 else dnat1 * n_rows
        dnat2 = dnat2.sum()

    inv_gamma_post = eqx.tree_at(lambda m: (m.dnat1, m.dnat2), mvnig_prior.inv_gamma, (dnat1, dnat2))
    return eqx.tree_at(lambda m: (m.mvn, m.inv_gamma), mvnig_prior, (mvn_post, inv_gamma_post))