Source code for sppcax.distributions.categorical

"""Categorical distribution in natural parameterization."""

from typing import ClassVar, Optional

import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr

from ..types import Array, PRNGKey, Shape
from .exponential_family import ExponentialFamily

MAXF = jnp.finfo(jnp.dtype(jnp.array(0.0).dtype)).max


[docs]class Categorical(ExponentialFamily): """Categorical distribution parameterized by logits. The Categorical distribution has K-1 natural parameters η_k = log(p_k/p_K) for k=1,...,K-1 where p_K is the probability of the last category. """ nat1: Array # Shape (..., K-1) natural_param_shape: ClassVar[Shape] = (1,) def __init__(self, logits: Array): """Initialize Categorical distribution. Args: logits: Log-probs Shape (..., K) where K is number of categories. """ super().__init__(batch_shape=logits.shape[:-1], event_shape=()) self.nat1 = logits[..., :-1] - logits[..., -1:]
[docs] @classmethod def from_natural_parameters(cls, eta: Array) -> "Categorical": """Create Categorical from natural parameters. Args: eta: Log-odds relative to last category. Shape (..., K-1) where K is number of categories. Returns: Categorical instance. """ pad_width = [(0, 0)] * eta.ndim pad_width[-1] = (0, 1) return cls(logits=jnp.pad(eta, pad_width))
@property def full_logits(self) -> Array: pad_width = [(0, 0)] * self.nat1.ndim pad_width[-1] = (0, 1) return jnp.pad(self.nat1, pad_width) @property def probs(self) -> Array: """Get probability parameters.""" return jnn.softmax(self.full_logits) @property def natural_parameters(self) -> Array: """Get natural parameters (logits). Returns: Natural parameters η. """ return self.nat1 def _check_support(self, x: Array) -> Array: """Check if values are within distribution support.""" n_categories = self.nat1.shape[-1] + 1 return (x >= 0) & (x < n_categories)
[docs] def sufficient_statistics(self, x: Array) -> Array: """Compute sufficient statistics T(x). For categorical data, T(x) is a one-hot vector with the last category omitted (since probabilities sum to 1). Args: x: Category indices. Returns: One-hot encoded data (excluding last category). """ n_categories = self.nat1.shape[-1] + 1 one_hot = jnn.one_hot(x, n_categories) return one_hot[..., :-1] # Omit last category
@property def expected_sufficient_statistics(self) -> Array: """Compute E[T(x)] = p_1,...,p_{K-1}. Returns: Expected probabilities for first K-1 categories. """ probs = jnn.softmax(self.full_logits) return probs[..., :-1] # Return first K-1 probabilities @property def log_normalizer(self) -> Array: """Compute log normalizer A(η). For categorical, A(η) = log(1 + sum(exp(η_k))). Returns: Log normalizer A(η). """ return jnn.logsumexp(self.full_logits)
[docs] def sample(self, key: PRNGKey, sample_shape: Optional[Shape] = None) -> Array: """Sample from Categorical distribution. Args: key: PRNG key. sample_shape: Shape of samples to draw. Returns: Category indices sampled from distribution. """ return jr.categorical(key, self.full_logits, shape=sample_shape + self.shape)