# Source code for sppcax.distributions.exponential_family

"""Base class for exponential family distributions."""

from typing import ClassVar

import jax.numpy as jnp

from ..types import Array, Shape
from .base import Distribution


class ExponentialFamily(Distribution):
    """Base class for exponential family distributions in natural parameterization.

    The exponential family has the form:

        p(x|η) = h(x) exp(η^T T(x) - A(η))

    where:
        - η: natural parameters
        - T(x): sufficient statistics
        - A(η): log normalizer
        - h(x): base measure

    Subclasses are expected to override ``natural_parameters``,
    ``sufficient_statistics``, ``expected_sufficient_statistics``,
    ``log_normalizer`` and ``from_natural_parameters``; the versions defined
    here raise ``NotImplementedError``.

    Attributes:
        natural_param_shape: Shape of natural parameters (class variable).
    """

    natural_param_shape: ClassVar[Shape] = ()  # Override in subclasses

    def __init__(self, batch_shape: Shape = (), event_shape: Shape = ()):
        """Initialize exponential family distribution.

        Args:
            batch_shape: Shape of batch dimensions.
            event_shape: Shape of event dimensions.
        """
        super().__init__(batch_shape, event_shape)

    @property
    def natural_parameters(self) -> Array:
        """Get natural parameters of the distribution.

        Returns:
            Natural parameters η with shape: batch_shape + natural_param_shape

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def sufficient_statistics(self, x: Array) -> Array:
        """Compute sufficient statistics T(x).

        Args:
            x: Data to compute sufficient statistics for.
                Shape: batch_shape + event_shape

        Returns:
            Sufficient statistics T(x) with shape: batch_shape + natural_param_shape

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @property
    def expected_sufficient_statistics(self) -> Array:
        """Compute expected sufficient statistics E[T(x)].

        Returns:
            Expected sufficient statistics E[T(x)] with shape:
            batch_shape + natural_param_shape

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @property
    def log_normalizer(self) -> Array:
        """Compute log normalizer A(η).

        Returns:
            Log normalizer A(η) with shape: batch_shape

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def log_base_measure(self, x: Array = None) -> Array:
        """Compute log of base measure h(x).

        The default implementation ignores ``x`` and returns a scalar zero,
        i.e. h(x) = 1.

        Args:
            x: Data to compute base measure for. May be omitted (``None``).
                Shape: batch_shape + event_shape

        Returns:
            Log base measure log(h(x)) with shape: batch_shape
            (a scalar zero array in this default implementation).
        """
        return jnp.zeros(())  # Default to h(x) = 1

    def _check_support(self, x: Array) -> Array:
        """Check if values are within distribution support.

        The default implementation treats every value as valid.

        Args:
            x: Values to check. Shape: batch_shape + event_shape

        Returns:
            Boolean mask of valid values with shape matching x's batch dimensions.
        """
        # Count the batch dimensions explicitly instead of slicing with
        # ``-len(event_shape)``: for an empty event shape that slice becomes
        # ``x.shape[:0]`` and would drop every batch dimension.
        n_batch_dims = max(len(x.shape) - len(self.event_shape), 0)
        return jnp.ones(x.shape[:n_batch_dims], dtype=bool)  # Default: all values valid

    def _natural_inner_product(self, eta: Array, stats: Array) -> Array:
        """Sum ``eta * stats`` over the trailing natural-parameter dimensions.

        Args:
            eta: Natural-parameter-like array.
                Shape: batch_shape + natural_param_shape
            stats: Statistics array multiplied elementwise with ``eta``.

        Returns:
            The product reduced over the last ``len(natural_param_shape)`` axes.
        """
        # An empty ``natural_param_shape`` yields an empty axis tuple, so the
        # product is passed to ``jnp.sum`` with no reduction axes.
        natural_axes = tuple(range(-len(self.natural_param_shape), 0))
        return jnp.sum(eta * stats, axis=natural_axes)

    def log_prob(self, x: Array) -> Array:
        """Compute log probability.

        Args:
            x: Data to compute log probability for.
                Shape: batch_shape + event_shape

        Returns:
            Log probability log p(x|η) with shape: batch_shape
            Returns -inf for values outside the support.
        """
        valid = self._check_support(x)
        eta = self.natural_parameters  # batch_shape + natural_param_shape
        T_x = self.sufficient_statistics(x)  # batch_shape + natural_param_shape

        # log p(x|η) = η^T T(x) - A(η) + log h(x)
        inner_product = self._natural_inner_product(eta, T_x)
        log_density = inner_product - self.log_normalizer + self.log_base_measure(x)
        return jnp.where(valid, log_density, -jnp.inf)

    @property
    def expected_log_base_measure(self) -> Array:
        """Compute the expectation of the log base measure E_{p(x)}[log(h(x))].

        Returns:
            Expectation E_{p(x)}[log(h(x))] with shape: batch_shape
        """
        # We assume that by default the base measure is not a function of x.
        # This has to be overridden for distributions where that is not the case.
        return self.log_base_measure()

    @property
    def entropy(self) -> Array:
        """Compute entropy of the distribution.

        Evaluated as -E[log h(x)] - η^T E[T(x)] + A(η).

        Returns:
            Entropy with shape: batch_shape
        """
        eta = self.natural_parameters
        expected_T = self.expected_sufficient_statistics
        inner_product = self._natural_inner_product(eta, expected_T)
        return -self.expected_log_base_measure - inner_product + self.log_normalizer

    def kl_divergence(self, other: "ExponentialFamily") -> Array:
        """Compute KL divergence KL(self||other).

        Evaluated as (η_self - η_other)^T E_self[T(x)] - A(η_self) + A(η_other).
        The base measures of the two distributions do not enter this
        expression.

        Args:
            other: Other distribution to compute KL divergence with.

        Returns:
            KL divergence KL(self||other) with shape: batch_shape
        """
        eta_self = self.natural_parameters
        eta_other = other.natural_parameters
        expected_T = self.expected_sufficient_statistics

        # Sum over natural parameter dimensions
        inner_product = self._natural_inner_product(eta_self - eta_other, expected_T)
        return -self.log_normalizer + other.log_normalizer + inner_product

    @classmethod
    def from_natural_parameters(cls, eta: Array) -> "ExponentialFamily":
        """Create distribution from natural parameters.

        Args:
            eta: Natural parameters with shape: batch_shape + natural_param_shape

        Returns:
            Distribution instance.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError