# Source code for sppcax.distributions.beta

"""Beta distribution implementation."""

from typing import ClassVar

import jax.numpy as jnp
import jax.random as jr
import jax.scipy.special as jsp

from ..types import Array, PRNGKey, Shape
from .exponential_family import ExponentialFamily


class Beta(ExponentialFamily):
    """Beta distribution in natural parameters.

    The beta distribution has density:
        p(x|α,β) = x^(α-1) * (1-x)^(β-1) / B(α,β)    for x ∈ (0, 1)

    In exponential family form:
        h(x) = 1
        η = [α-1, β-1]
        T(x) = [log(x), log(1-x)]
        A(η) = log(B(η₁+1, η₂+1))

    The natural parameters are stored as a fixed prior part (``nat1_0``,
    ``nat2_0``) plus an additive change (``dnat1``, ``dnat2``). The effective
    parameters are ``nat_i = nat_i_0 + dnat_i``. ``kl_divergence_from_prior``
    compares the effective parameters against the prior part.
    """

    nat1_0: Array  # prior value of the first natural parameter (α0 - 1)
    nat2_0: Array  # prior value of the second natural parameter (β0 - 1)
    dnat1: Array  # change in the first natural parameter (α - 1)
    dnat2: Array  # change in the second natural parameter (β - 1)

    natural_param_shape: ClassVar[Shape] = (2,)  # [η₁, η₂]

    def __init__(self, alpha0: float | Array = 1.0, beta0: float | Array = 1.0):
        """Initialize a beta distribution from its shape parameters.

        ``alpha0`` and ``beta0`` are broadcast against each other, and the
        broadcast shape becomes ``batch_shape``. The parameter changes
        ``dnat1``/``dnat2`` start at zero, so a freshly built distribution
        coincides with its prior.

        Args:
            alpha0: First shape parameter α (default: 1.0).
            beta0: Second shape parameter β (default: 1.0).
        """
        alpha = jnp.asarray(alpha0)
        beta = jnp.asarray(beta0)

        batch_shape = jnp.broadcast_shapes(jnp.shape(alpha), jnp.shape(beta))
        super().__init__(batch_shape=batch_shape, event_shape=())

        # Prior natural parameters η = [α - 1, β - 1], broadcast to the batch.
        self.nat1_0 = jnp.broadcast_to(alpha - 1, self.batch_shape)
        self.nat2_0 = jnp.broadcast_to(beta - 1, self.batch_shape)

        # No update applied yet.
        self.dnat1 = jnp.zeros(self.batch_shape)
        self.dnat2 = jnp.zeros(self.batch_shape)

    @property
    def nat1(self) -> Array:
        """First natural parameter η₁ = α - 1."""
        return self.nat1_0 + self.dnat1

    @property
    def nat2(self) -> Array:
        """Second natural parameter η₂ = β - 1."""
        return self.nat2_0 + self.dnat2

    @property
    def alpha(self) -> Array:
        """First shape parameter α = η₁ + 1."""
        return self.nat1 + 1.0

    @property
    def beta(self) -> Array:
        """Second shape parameter β = η₂ + 1."""
        return self.nat2 + 1.0

    @property
    def mean(self) -> Array:
        """Mean α / (α + β)."""
        alpha = self.alpha
        beta = self.beta
        return alpha / (alpha + beta)

    @property
    def variance(self) -> Array:
        """Variance αβ / ((α + β)² (α + β + 1))."""
        alpha = self.alpha
        beta = self.beta
        total = alpha + beta
        return (alpha * beta) / (total**2 * (total + 1))

    @property
    def natural_parameters(self) -> Array:
        """Natural parameters η = [α-1, β-1].

        Returns:
            Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
        """
        return jnp.stack([self.nat1, self.nat2], axis=-1)

    def sufficient_statistics(self, x: Array) -> Array:
        """Compute sufficient statistics T(x) = [log(x), log(1-x)].

        Args:
            x: Value to compute sufficient statistics for.
               Shape: batch_shape + event_shape

        Returns:
            Sufficient statistics [log(x), log(1-x)] with shape: batch_shape + (2,)
        """
        # log1p(-x) is the accurate form of log(1 - x): for small x, the
        # difference 1 - x rounds to 1 and log(1 - x) collapses to 0.
        return jnp.stack([jnp.log(x), jnp.log1p(-x)], axis=-1)

    @property
    def expected_sufficient_statistics(self) -> Array:
        """Compute E[T(x)] = [ψ(α) - ψ(α+β), ψ(β) - ψ(α+β)].

        Returns:
            Expected sufficient statistics [E[log(x)], E[log(1-x)]]
            with shape: batch_shape + (2,)
        """
        alpha = self.alpha
        beta = self.beta
        digamma_sum = jsp.digamma(alpha + beta)
        expected_log_x = jsp.digamma(alpha) - digamma_sum
        expected_log_1_minus_x = jsp.digamma(beta) - digamma_sum
        return jnp.stack([expected_log_x, expected_log_1_minus_x], axis=-1)

    @property
    def log_normalizer(self) -> Array:
        """Compute log normalizer A(η) = log(B(η₁+1, η₂+1)).

        Returns:
            Log normalizer with shape: batch_shape
        """
        return jsp.betaln(self.alpha, self.beta)

    def _check_support(self, x: Array) -> Array:
        """Check whether values lie in the open interval (0, 1).

        Args:
            x: Values to check. Shape: batch_shape + event_shape

        Returns:
            Boolean mask of valid values.
        """
        return (x > 0) & (x < 1)

    def log_base_measure(self, x: Array | None = None) -> Array:
        """Compute the log of the base measure h(x).

        Since h(x) = 1 for the beta family, this is always a scalar zero. A
        scalar broadcasts against any batch shape.

        Args:
            x: Data to compute the base measure for (unused).
               Shape: batch_shape + event_shape

        Returns:
            A 0-d array holding zero.
        """
        return jnp.zeros(())

    def sample(self, key: PRNGKey, sample_shape: Shape = ()) -> Array:
        """Sample from the distribution.

        Args:
            key: PRNG key for random sampling.
            sample_shape: Shape of samples to draw (any sequence of ints).

        Returns:
            Samples with shape: sample_shape + batch_shape + event_shape
        """
        # tuple() lets callers pass lists as well as tuples.
        shape = tuple(sample_shape) + tuple(self.batch_shape) + tuple(self.event_shape)
        return jr.beta(key, self.alpha, self.beta, shape=shape)

    @classmethod
    def from_natural_parameters(cls, eta: Array) -> "Beta":
        """Create a beta distribution from natural parameters.

        Args:
            eta: Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

        Returns:
            Beta distribution with α = η₁ + 1 and β = η₂ + 1.
        """
        return cls(alpha0=eta[..., 0] + 1.0, beta0=eta[..., 1] + 1.0)

    @property
    def kl_divergence_from_prior(self) -> Array:
        """Compute KL divergence KL(post||prior).

        The posterior uses η = nat_0 + dnat and the prior uses η₀ = nat_0. For
        exponential families this is
            KL = (η - η₀)·E_post[T(x)] - A(η) + A(η₀).

        Returns:
            KL divergence KL(post||prior) with shape: batch_shape
        """
        eta_post = self.natural_parameters
        eta_prior = jnp.stack([self.nat1_0, self.nat2_0], axis=-1)

        # A(η₀) = log B(α0, β0), where α0 = η₀₁ + 1 and β0 = η₀₂ + 1.
        prior_log_normalizer = jsp.betaln(self.nat1_0 + 1.0, self.nat2_0 + 1.0)

        # Contract over the trailing natural-parameter axes.
        inner_product = jnp.sum(
            (eta_post - eta_prior) * self.expected_sufficient_statistics,
            axis=tuple(range(-len(self.natural_param_shape), 0)),
        )
        return -self.log_normalizer + prior_log_normalizer + inner_product