# Source code for sppcax.distributions.poisson

"""Poisson distribution in natural parameterization."""

from typing import ClassVar, Optional

import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import gammaln

from ..types import Array, PRNGKey, Shape
from .exponential_family import ExponentialFamily


class Poisson(ExponentialFamily):
    """Poisson distribution parameterized by log rate.

    The Poisson distribution has natural parameter η = log(λ) where λ is the
    rate parameter, and sufficient statistic T(x) = x. The log normalizer is
    A(η) = exp(η) and the base measure is h(x) = 1 / x!.
    """

    nat1: Array
    natural_param_shape: ClassVar[Shape] = ()  # [log_rate]

    def __init__(self, log_rate: Array):
        """Initialize Poisson distribution.

        Args:
            log_rate: Natural parameter η = log(λ). Its shape becomes the
                batch shape; the event shape is scalar.
        """
        super().__init__(batch_shape=jnp.shape(log_rate), event_shape=())
        self.nat1 = log_rate

    @classmethod
    def from_natural_parameters(cls, eta: Array) -> "Poisson":
        """Create Poisson from natural parameters.

        Args:
            eta: Natural parameter η = log(λ).

        Returns:
            Poisson instance.
        """
        return cls(log_rate=eta)

    @property
    def log_rate(self) -> Array:
        """Get log(rate) parameter."""
        return self.nat1

    @property
    def rate(self) -> Array:
        """Get rate parameter λ = exp(η)."""
        return jnp.exp(self.log_rate)

    @property
    def natural_parameters(self) -> Array:
        """Get natural parameters (log rate).

        Returns:
            Natural parameters η = log(λ).
        """
        return self.nat1

    def sufficient_statistics(self, x: Array) -> Array:
        """Compute sufficient statistics T(x) = x.

        Args:
            x: Count data.

        Returns:
            Sufficient statistics T(x) = x.
        """
        return x

    @property
    def expected_sufficient_statistics(self) -> Array:
        """Compute E[T(x)] = E[x] = λ = exp(η).

        Returns:
            Expected sufficient statistics E[x].
        """
        return jnp.exp(self.nat1)

    @property
    def log_normalizer(self) -> Array:
        """Compute log normalizer A(η) = exp(η).

        Returns:
            Log normalizer A(η).
        """
        return jnp.exp(self.nat1)

    def _check_support(self, x: Array) -> Array:
        """Check if values are within distribution support.

        The Poisson support is the non-negative integers {0, 1, 2, ...}.

        Args:
            x: Values to check. Shape: batch_shape + event_shape

        Returns:
            Boolean mask that is True where x is a non-negative integer value
            (float-typed integral values such as 3.0 are accepted).
        """
        x = jnp.asarray(x)
        # Non-integer values would otherwise get a finite (but meaningless)
        # log-density via gammaln(x + 1) in log_base_measure.
        return (x >= 0) & (x == jnp.floor(x))

    def log_base_measure(self, x: Array) -> Array:
        """Compute log base measure log(h(x)) = -log(x!).

        Args:
            x: Count data.

        Returns:
            Log base measure -log(x!).
        """
        return -gammaln(x + 1)

    def sample(self, key: PRNGKey, sample_shape: Optional[Shape] = ()) -> Array:
        """Sample from Poisson distribution.

        Args:
            key: PRNG key.
            sample_shape: Shape of samples to draw. ``None`` is treated as
                ``()``; any sequence of ints is accepted.

        Returns:
            Count samples with shape sample_shape + self.shape.
        """
        # Normalize so that `None` (allowed by the annotation) and lists do
        # not break the tuple concatenation below.
        sample_shape = () if sample_shape is None else tuple(sample_shape)
        rate = jnp.broadcast_to(jnp.exp(self.log_rate), sample_shape + self.shape)
        return jr.poisson(key, rate)

    @property
    def entropy(self) -> Array:
        """Compute entropy of Poisson distribution (in nats).

        The exact entropy is

            H(λ) = λ(1 - log λ) + exp(-λ) Σ_{k≥0} λ^k log(k!) / k!.

        For λ < 20 this series is evaluated directly, truncated at k = 99.
        For λ ≥ 20 the large-λ asymptotic expansion

            H(λ) ≈ ½ log(2πeλ) - 1/(12λ) - 1/(24λ²) - 19/(360λ³)

        is used instead. The asymptotic expansion alone is unusable for small λ:
        it diverges to -inf as λ → 0, while the true entropy tends to 0.

        Returns:
            Entropy H(λ), with the same shape as log_rate.
        """
        threshold = 20.0  # switch point between exact series and asymptotics
        # For λ < 20 the Poisson tail mass beyond k = 99 is ~exp(-80), so the
        # truncation error is far below floating-point resolution.
        num_terms = 100

        log_rate = jnp.asarray(self.log_rate)
        rate = jnp.exp(log_rate)

        # --- Exact series: E[log k!] = Σ_k pmf(k) log(k!) ---
        # Terms k = 0 and k = 1 vanish (log 0! = log 1! = 0). Starting at k = 2
        # also avoids 0 * (-inf) = nan when log_rate == -inf.
        k = jnp.arange(2, num_terms, dtype=rate.dtype)
        k = k.reshape(k.shape + (1,) * rate.ndim)  # broadcast over batch dims
        log_fact = gammaln(k + 1.0)
        log_pmf = k * log_rate - rate - log_fact
        expected_log_fact = jnp.sum(jnp.exp(log_pmf) * log_fact, axis=0)
        # λ(1 - log λ) → 0 as λ → 0; guard the 0 * inf case explicitly.
        linear_term = jnp.where(rate > 0, rate * (1.0 - log_rate), 0.0)
        exact = linear_term + expected_log_fact

        # --- Large-λ asymptotic expansion ---
        # Clamp so the branch not selected by jnp.where stays finite
        # (avoids inf/nan leaking into gradients for small λ).
        big_rate = jnp.maximum(rate, threshold)
        inv = 1.0 / big_rate
        asymptotic = (
            0.5 * jnp.log(2.0 * jnp.pi * jnp.e * big_rate)
            - inv / 12.0
            - inv**2 / 24.0
            - 19.0 * inv**3 / 360.0
        )

        return jnp.where(rate < threshold, exact, asymptotic)