Source code for sppcax.distributions.delta

"""Delta distribution implementation."""

from typing import Callable, Optional

import jax.numpy as jnp

from ..types import Array, PRNGKey, Shape
from .base import Distribution
from .utils import cho_inv


def default_ss(x: Array) -> Array:
    """Build the default sufficient statistics ``[x, vec(x x^T)]``.

    Args:
        x: Input vector(s) with shape ``(..., d)``.

    Returns:
        Array of shape ``(..., d + d*d)``: ``x`` followed by its outer
        product with itself, flattened over the trailing two axes.
    """
    # Outer product over the trailing axis: (..., d, 1) * (..., 1, d) -> (..., d, d).
    outer = x[..., :, None] * x[..., None, :]
    # Collapse the (d, d) block into one axis, leaving leading dimensions untouched.
    flat_outer = outer.reshape(*x.shape[:-1], -1)
    return jnp.concatenate((x, flat_outer), axis=-1)
class Delta(Distribution):
    """Delta distribution (Dirac delta) concentrated at a single point.

    The point is stored in ``mean``. The last axis of the location is treated
    as the event dimension and all leading axes as batch dimensions.
    """

    # Point at which all probability mass is concentrated.
    mean: Array
    # Callable mapping a point to its sufficient statistics.
    sufficient_statistics: Callable

    def __init__(
        self,
        location: Array,
        sufficient_statistics_fn: Optional[Callable] = None,
    ):
        """Initialize delta distribution.

        Args:
            location: Point where probability mass is concentrated.
                Shape: batch_shape + event_shape, with at least one axis.
            sufficient_statistics_fn: Optional function to compute sufficient
                statistics. If None, ``default_ss`` is used.

        Raises:
            ValueError: If ``location`` is 0-dimensional (there is no event axis).
        """
        shape = tuple(location.shape)
        if not shape:
            # Unpacking an empty shape would fail with an opaque message;
            # raise the same exception type with an explicit explanation.
            raise ValueError(
                "Delta location must have at least one (event) dimension; got a 0-d array."
            )
        *batch_shape, event_dim = shape
        super().__init__(batch_shape=tuple(batch_shape), event_shape=(event_dim,))
        self.mean = location
        self.sufficient_statistics = (
            sufficient_statistics_fn if sufficient_statistics_fn is not None else default_ss
        )

    def log_prob(self, x: Array) -> Array:
        """Compute log probability.

        Args:
            x: Value to compute log probability for.
                Shape: batch_shape + event_shape

        Returns:
            Log probability with shape: batch_shape.
            0 where ``x`` equals the location exactly, -inf elsewhere.
        """
        # Exact (not approximate) equality, reduced over the event dimensions.
        event_axes = tuple(range(-len(self.event_shape), 0))
        equal = jnp.all(x == self.mean, axis=event_axes)
        return jnp.where(equal, 0.0, -jnp.inf)

    def mode(self) -> Array:
        """Return the mode, which is the location itself."""
        return self.mean

    @property
    def covariance(self) -> Array:
        """Covariance matrix (always zero for delta distribution).

        Returns:
            Array of zeros with shape ``self.shape + (self.shape[-1],)``.
        """
        zeros = jnp.zeros(self.shape)
        return zeros[..., None] * jnp.eye(zeros.shape[-1])

    def sample(self, key: PRNGKey, sample_shape: Shape = ()) -> Array:
        """Sample from the distribution (always returns location).

        Args:
            key: PRNG key (unused).
            sample_shape: Shape of samples to draw; any sequence of ints.

        Returns:
            Samples with shape: sample_shape + self.shape.
            All samples are equal to location.
        """
        # Coerce to tuples so a list-valued sample_shape concatenates cleanly.
        target_shape = tuple(sample_shape) + tuple(self.shape)
        return jnp.broadcast_to(self.mean, target_shape)

    def entropy(self) -> Array:
        """Compute entropy (always 0 for delta distribution).

        Returns:
            Entropy with shape: batch_shape
        """
        return jnp.zeros(self.batch_shape)

    @property
    def precision(self) -> Array:
        """Precision of the delta distribution.

        Returns:
            Array with the same shape as ``covariance`` in which every entry
            (off-diagonal entries included) is ``+inf``. No finite surrogate
            value is substituted.
        """
        return jnp.inf * jnp.ones_like(self.covariance)

    @property
    def expected_psi(self) -> Array:
        """Expected precision for Delta.

        Uses ``cho_inv`` when the last two axes of ``mean`` have equal length,
        and the element-wise reciprocal otherwise.

        NOTE(review): the square check looks only at the trailing two axes, so
        a batch of ``d`` vectors of length ``d`` also takes the ``cho_inv``
        branch -- confirm callers never pass such a location.
        """
        mean = self.mean
        is_square = mean.ndim >= 2 and mean.shape[-1] == mean.shape[-2]
        if is_square:
            # presumably a Cholesky-based matrix inverse; see .utils.cho_inv
            return cho_inv(mean)
        return 1.0 / mean

    @property
    def expected_second_moment(self) -> Array:
        """Expected second moment E[XX^T]: outer product of location with itself (no variance)."""
        return self.mean[..., :, None] * self.mean[..., None, :]

    def mf_expectations(self) -> dict:
        """Return expectations for mean-field coordinate ascent partner.

        Returns:
            Dict with keys ``"mean"``, ``"second_moment"`` and
            ``"expected_precision"``.
        """
        return {
            "mean": self.mean,
            "second_moment": self.expected_second_moment,
            "expected_precision": self.expected_psi,
        }

    def mf_update(self, *args) -> "Delta":
        """Mean-field update for Delta is a no-op (fixed component).

        Args:
            *args: Ignored.

        Returns:
            This same instance, unchanged.
        """
        return self

    @property
    def expected_sufficient_statistics(self) -> Array:
        """Compute expected sufficient statistics.

        For delta distribution, this is just sufficient_statistics(location)
        since all probability mass is concentrated at location.

        Returns:
            Expected sufficient statistics with shape determined by
            sufficient_statistics_fn.
        """
        return self.sufficient_statistics(self.mean)