sppcax.distributions package

Submodules

sppcax.distributions.base module

Base distribution class.

class sppcax.distributions.base.Distribution(batch_shape: Tuple[int, ...], event_shape: Tuple[int, ...])[source]

Bases: Module

Base distribution class in natural parameters.

batch_shape

Shape of batch dimensions.

Type

Tuple[int, …]

event_shape

Shape of event dimensions.

Type

Tuple[int, …]

batch_shape: Tuple[int, ...]
broadcast_to_shape(x: Array, ignore_event: bool = False) Array[source]

Broadcast array to match distribution shape.

Parameters
  • x – Array to broadcast.

  • ignore_event – If True, only broadcast batch dimensions.

Returns

Broadcasted array.

entropy() Array[source]

Compute entropy of the distribution.

Returns

Entropy with shape batch_shape.

event_shape: Tuple[int, ...]
log_prob(x: Array) Array[source]

Compute log probability of x.

Parameters

x – Value to compute log probability for. Should have shape: batch_shape + event_shape

Returns

Log probability with shape batch_shape.

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Optional[Tuple[int, ...]] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Additional sample dimensions.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

property shape: Tuple[int, ...]

Full shape (batch_shape + event_shape).
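
The shape conventions above can be illustrated with plain tuples. This is a minimal sketch that does not import sppcax; the helper names full_shape and sample_result_shape are hypothetical and only mirror the documented concatenation rules.

```python
from typing import Tuple


def full_shape(batch_shape: Tuple[int, ...], event_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # Documented rule: shape = batch_shape + event_shape.
    return batch_shape + event_shape


def sample_result_shape(
    sample_shape: Tuple[int, ...],
    batch_shape: Tuple[int, ...],
    event_shape: Tuple[int, ...],
) -> Tuple[int, ...]:
    # Documented rule: samples have shape sample_shape + batch_shape + event_shape.
    return sample_shape + batch_shape + event_shape


batch_shape, event_shape = (3,), (2,)
shape = full_shape(batch_shape, event_shape)
drawn = sample_result_shape((5, 4), batch_shape, event_shape)
```

With a batch of 3 distributions over 2-dimensional events, drawing a (5, 4) block of samples would give an array of shape (5, 4, 3, 2).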

sppcax.distributions.beta module

Beta distribution implementation.

class sppcax.distributions.beta.Beta(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]

Bases: ExponentialFamily

Beta distribution in natural parameters.

The beta distribution has density: p(x|α,β) = x^(α-1) * (1-x)^(β-1) / B(α,β) for x ∈ [0,1]

In exponential family form:

  • h(x) = 1
  • η = [α-1, β-1]
  • T(x) = [log(x), log(1-x)]
  • A(η) = log(B(η₁+1, η₂+1))
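
As a sanity check of the exponential family form, the log density can be evaluated as η·T(x) − A(η) and compared against the direct formula. This is a stdlib-only sketch; the function names are hypothetical and it does not call sppcax.

```python
import math


def log_beta_fn(a: float, b: float) -> float:
    # log B(a, b) = log Γ(a) + log Γ(b) − log Γ(a + b)
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_log_prob_natural(x: float, eta1: float, eta2: float) -> float:
    # log p(x) = η·T(x) − A(η), with h(x) = 1
    t1, t2 = math.log(x), math.log1p(-x)
    log_normalizer = log_beta_fn(eta1 + 1.0, eta2 + 1.0)
    return eta1 * t1 + eta2 * t2 - log_normalizer


def beta_log_prob_direct(x: float, alpha: float, beta: float) -> float:
    return (alpha - 1.0) * math.log(x) + (beta - 1.0) * math.log(1.0 - x) - log_beta_fn(alpha, beta)


alpha, beta, x = 2.5, 4.0, 0.3
lp_nat = beta_log_prob_natural(x, alpha - 1.0, beta - 1.0)
lp_dir = beta_log_prob_direct(x, alpha, beta)

# Midpoint-rule check that the density integrates to roughly 1 on [0, 1].
n = 2000
total = sum(
    math.exp(beta_log_prob_natural((i + 0.5) / n, alpha - 1.0, beta - 1.0))
    for i in range(n)
) / n
```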

property alpha: Array

Get first shape parameter α.

property beta: Array

Get second shape parameter β.

dnat1: Array
dnat2: Array
property expected_sufficient_statistics: Array

Compute E[T(x)] = [ψ(α) - ψ(α+β), ψ(β) - ψ(α+β)].

Returns

Expected sufficient statistics [E[log(x)], E[log(1-x)]] with shape batch_shape + (2,).

classmethod from_natural_parameters(eta: Array) Beta[source]

Create beta distribution from natural parameters.

Parameters

eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

Returns

Beta distribution.

property kl_divergence_from_prior: Array

Compute KL divergence KL(post||prior).

Returns

KL divergence KL(post||prior) with shape batch_shape.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

zero

property log_normalizer: Array

Compute log normalizer A(η) = log(B(η₁+1, η₂+1)).

Returns

Log normalizer with shape batch_shape.

property mean: Array

Get mean of the distribution.

property nat1: Array

First natural parameter η₁ = α - 1.

nat1_0: Array
property nat2: Array

Second natural parameter η₂ = β - 1.

nat2_0: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (2,)
property natural_parameters: Array

Get natural parameters η = [α-1, β-1].

Returns

Natural parameters [η₁, η₂] with shape batch_shape + (2,).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [log(x), log(1-x)].

Parameters

x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics [log(x), log(1-x)] with shape batch_shape + (2,).

property variance: Array

Get variance of the distribution.

sppcax.distributions.categorical module

Categorical distribution in natural parameterization.

class sppcax.distributions.categorical.Categorical(logits: Array)[source]

Bases: ExponentialFamily

Categorical distribution parameterized by logits.

The Categorical distribution has K-1 natural parameters η_k = log(p_k/p_K) for k=1,…,K-1 where p_K is the probability of the last category.
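
The K-1 parameterization can be sketched in a few lines: log-odds against the last category go one way, and a softmax with an implicit zero logit for category K goes back. The helper names are hypothetical and the snippet is independent of sppcax.

```python
import math


def probs_to_eta(probs):
    # η_k = log(p_k / p_K) for k = 1, ..., K-1
    p_last = probs[-1]
    return [math.log(p / p_last) for p in probs[:-1]]


def log_normalizer(eta):
    # A(η) = log(1 + Σ_k exp(η_k))
    return math.log(1.0 + sum(math.exp(e) for e in eta))


def eta_to_probs(eta):
    # p_k = exp(η_k − A(η)) for k < K, and p_K = exp(−A(η))
    a = log_normalizer(eta)
    return [math.exp(e - a) for e in eta] + [math.exp(-a)]


probs = [0.2, 0.5, 0.3]
eta = probs_to_eta(probs)
recovered = eta_to_probs(eta)
```

Note that A(η) equals −log(p_K) under this parameterization, which is why the last probability can be recovered without storing a K-th logit.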

property expected_sufficient_statistics: Array

Compute E[T(x)] = p_1,…,p_{K-1}.

Returns

Expected probabilities for first K-1 categories.

classmethod from_natural_parameters(eta: Array) Categorical[source]

Create Categorical from natural parameters.

Parameters

eta – Log-odds relative to last category. Shape (…, K-1) where K is number of categories.

Returns

Categorical instance.

property full_logits: Array
property log_normalizer: Array

Compute log normalizer A(η).

For categorical, A(η) = log(1 + sum(exp(η_k))).

Returns

Log normalizer A(η).

nat1: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (1,)
property natural_parameters: Array

Get natural parameters (logits).

Returns

Natural parameters η.

property probs: Array

Get probability parameters.

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Optional[Tuple[int, ...]] = None) Array[source]

Sample from Categorical distribution.

Parameters
  • key – PRNG key.

  • sample_shape – Shape of samples to draw.

Returns

Category indices sampled from distribution.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x).

For categorical data, T(x) is a one-hot vector with the last category omitted (since probabilities sum to 1).

Parameters

x – Category indices.

Returns

One-hot encoded data (excluding last category).
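
A small sketch of the truncated one-hot encoding and how it combines with η and A(η) into a log probability. The names one_hot_drop_last and categorical_log_prob are hypothetical; the actual sppcax implementation may differ in layout and dtype.

```python
import math


def one_hot_drop_last(index: int, num_categories: int):
    # One-hot vector of length K-1; the last category maps to all zeros.
    return [1.0 if k == index else 0.0 for k in range(num_categories - 1)]


def categorical_log_prob(index: int, eta):
    # log p(x) = η·T(x) − A(η), with A(η) = log(1 + Σ exp(η_k))
    num_categories = len(eta) + 1
    t = one_hot_drop_last(index, num_categories)
    a = math.log(1.0 + sum(math.exp(e) for e in eta))
    return sum(e * s for e, s in zip(eta, t)) - a


# Log-odds for probabilities (0.2, 0.5, 0.3) relative to the last category.
eta = [math.log(0.2 / 0.3), math.log(0.5 / 0.3)]
log_probs = [categorical_log_prob(i, eta) for i in range(3)]
```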

sppcax.distributions.delta module

Delta distribution implementation.

class sppcax.distributions.delta.Delta(location: Array, sufficient_statistics_fn: Optional[Callable] = None)[source]

Bases: Distribution

Delta distribution (Dirac delta) concentrated at a single point.

property covariance: Array

Covariance matrix (always zero for delta distribution).

entropy() Array[source]

Compute entropy (always 0 for delta distribution).

Returns

Entropy with shape batch_shape.

property expected_psi: Array

Expected precision for Delta: matrix inverse for square matrices, element-wise otherwise.

property expected_second_moment: Array

Expected second moment E[XX^T] = location @ location^T (no variance).

property expected_sufficient_statistics: Array

Compute expected sufficient statistics.

For delta distribution, this is just sufficient_statistics(location) since all probability mass is concentrated at location.

Returns

Expected sufficient statistics with shape determined by sufficient_statistics_fn.

log_prob(x: Array) Array[source]

Compute log probability.

Parameters

x – Value to compute log probability for. Shape: batch_shape + event_shape

Returns

Log probability with shape batch_shape. Returns 0 at location, -inf elsewhere.

mean: Array
mf_expectations() dict[source]

Return expectations for mean-field coordinate ascent partner.

mf_update(*args) Delta[source]

Mean-field update for Delta is a no-op (fixed component).

mode() Array[source]
property precision: Array

Precision (infinite for delta, but return large finite value for compatibility).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution (always returns location).

Parameters
  • key – PRNG key (unused).

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape. All samples are equal to location.

sufficient_statistics: Callable
sppcax.distributions.delta.default_ss(x: Array) Array[source]

Compute default sufficient statistics [x, vec(xx^T)] for MVN.

Parameters

x – Input vector with shape (…, d).

Returns

Concatenation of x and vectorized outer product with shape (…, d + d*d).
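
For a single vector, the statistic [x, vec(xx^T)] can be sketched with plain lists. The function default_ss_sketch is a hypothetical stand-in that handles one unbatched vector and assumes row-major flattening; because xx^T is symmetric, row-major and column-major orderings give the same result.

```python
def default_ss_sketch(x):
    # vec(xx^T), flattened row by row: entry (i, j) is x[i] * x[j].
    outer = [xi * xj for xi in x for xj in x]
    # Concatenate x with the flattened outer product: length d + d*d.
    return list(x) + outer


x = [1.0, 2.0, 3.0]
stats = default_ss_sketch(x)
```

For d = 3 the result has 3 + 9 = 12 entries.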

sppcax.distributions.exponential_family module

Base class for exponential family distributions.

class sppcax.distributions.exponential_family.ExponentialFamily(batch_shape: Tuple[int, ...] = (), event_shape: Tuple[int, ...] = ())[source]

Bases: Distribution

Base class for exponential family distributions in natural parameterization.

The exponential family has the form:

  p(x|η) = h(x) exp(η^T T(x) - A(η))

where:

  • η: natural parameters
  • T(x): sufficient statistics
  • A(η): log normalizer
  • h(x): base measure

natural_param_shape

Shape of natural parameters (class variable).

Type

ClassVar[Tuple[int, …]]

property entropy: Array

Compute entropy of the distribution.

Returns

Entropy with shape batch_shape.

property expected_log_base_measure: Array

Compute the expectation of the log base measure E_{p(x)}[log(h(x))]

Returns

Expectation E_{p(x)}[log(h(x))] with shape batch_shape.

property expected_sufficient_statistics: Array

Compute expected sufficient statistics E[T(x)].

Returns

Expected sufficient statistics E[T(x)] with shape batch_shape + natural_param_shape.

classmethod from_natural_parameters(eta: Array) ExponentialFamily[source]

Create distribution from natural parameters.

Parameters

eta – Natural parameters with shape: batch_shape + natural_param_shape

Returns

Distribution instance.

kl_divergence(other: ExponentialFamily) Array[source]

Compute KL divergence KL(self||other).

Parameters

other – Other distribution to compute KL divergence with.

Returns

KL divergence KL(self||other) with shape batch_shape.
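
For two members of the same family with the same base measure, the KL divergence has the standard closed form KL(p||q) = (η_p − η_q)·E_p[T(x)] − A(η_p) + A(η_q). Whether sppcax computes it exactly this way is an assumption; the sketch below only checks the identity against the textbook formula for two univariate normals, using hypothetical helper names.

```python
import math


def normal_natural(loc: float, scale: float):
    # η = [μ/σ², −1/(2σ²)]
    return (loc / scale**2, -0.5 / scale**2)


def normal_log_normalizer(eta):
    eta1, eta2 = eta
    return -eta1**2 / (4.0 * eta2) - 0.5 * math.log(-2.0 * eta2) + 0.5 * math.log(2.0 * math.pi)


def normal_expected_stats(loc: float, scale: float):
    # E[T(x)] = [μ, μ² + σ²]
    return (loc, loc**2 + scale**2)


def kl_exponential_family(eta_p, eta_q, expected_t_p, log_normalizer):
    inner = sum((p - q) * t for p, q, t in zip(eta_p, eta_q, expected_t_p))
    return inner - log_normalizer(eta_p) + log_normalizer(eta_q)


loc_p, scale_p, loc_q, scale_q = 0.5, 1.2, -1.0, 2.0
kl_nat = kl_exponential_family(
    normal_natural(loc_p, scale_p),
    normal_natural(loc_q, scale_q),
    normal_expected_stats(loc_p, scale_p),
    normal_log_normalizer,
)
# Textbook closed form for KL(N(μ_p, σ_p²) || N(μ_q, σ_q²)).
kl_closed = (
    math.log(scale_q / scale_p)
    + (scale_p**2 + (loc_p - loc_q) ** 2) / (2.0 * scale_q**2)
    - 0.5
)
```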

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

Log base measure log(h(x)) with shape batch_shape.

property log_normalizer: Array

Compute log normalizer A(η).

Returns

Log normalizer A(η) with shape batch_shape.

log_prob(x: Array) Array[source]

Compute log probability.

Parameters

x – Data to compute log probability for. Shape: batch_shape + event_shape

Returns

Log probability log p(x|η) with shape batch_shape. Returns -inf for values outside the support.

natural_param_shape: ClassVar[Tuple[int, ...]] = ()
property natural_parameters: Array

Get natural parameters of the distribution.

Returns

Natural parameters η with shape batch_shape + natural_param_shape.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x).

Parameters

x – Data to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics T(x) with shape batch_shape + natural_param_shape.

sppcax.distributions.gamma module

Gamma distribution implementation.

class sppcax.distributions.gamma.Gamma(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]

Bases: ExponentialFamily

Gamma distribution in natural parameters.

The gamma distribution has density: p(x|α,β) = β^α * x^(α-1) * exp(-βx) / Γ(α)

In exponential family form:

  • h(x) = 1
  • η = [α-1, -β]
  • T(x) = [log(x), x]
  • A(η) = log(Γ(η₁ + 1)) - (η₁ + 1)*log(-η₂)

property alpha: Array

Get shape parameter α.

property beta: Array

Get rate parameter β.

dnat1: Array
dnat2: Array
property expected_sufficient_statistics: Array

Compute E[T(x)] = [ψ(α) - log(β), α/β].

Returns

Expected sufficient statistics [E[log(x)], E[x]] with shape batch_shape + (2,).
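
In an exponential family, E[T(x)] is the gradient of the log normalizer A(η). The sketch below checks this for the gamma distribution with central finite differences, using only the standard library; the function names are hypothetical, and sppcax itself may compute these expectations analytically via the digamma function.

```python
import math

EULER_GAMMA = 0.5772156649015329


def gamma_log_normalizer(eta1: float, eta2: float) -> float:
    # A(η) = log Γ(η₁ + 1) − (η₁ + 1) log(−η₂)
    return math.lgamma(eta1 + 1.0) - (eta1 + 1.0) * math.log(-eta2)


def grad_log_normalizer(eta1: float, eta2: float, h: float = 1e-5):
    # Central finite differences in each natural parameter.
    d1 = (gamma_log_normalizer(eta1 + h, eta2) - gamma_log_normalizer(eta1 - h, eta2)) / (2.0 * h)
    d2 = (gamma_log_normalizer(eta1, eta2 + h) - gamma_log_normalizer(eta1, eta2 - h)) / (2.0 * h)
    return d1, d2


alpha, beta = 3.0, 2.0
e_log_x, e_x = grad_log_normalizer(alpha - 1.0, -beta)

# Reference values: ψ(3) = 1 + 1/2 − γ, so E[log x] = ψ(3) − log(2) and E[x] = 3/2.
digamma_3 = 1.0 + 0.5 - EULER_GAMMA
```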

classmethod from_natural_parameters(eta: Array) Gamma[source]

Create gamma distribution from natural parameters.

Parameters

eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

Returns

Gamma distribution.

property kl_divergence_from_prior: Array

Compute KL divergence KL(post||prior).

Returns

KL divergence KL(post||prior) with shape batch_shape.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

zero

property log_normalizer: Array

Compute log normalizer A(η) = log(Γ(η₁ + 1)) - (η₁ + 1)*log(-η₂).

Returns

Log normalizer with shape batch_shape.

property mean: Array

Get mean E[x] = α/β.

property nat1: Array

First natural parameter η₁ = α - 1.

nat1_0: Array
property nat2: Array

Second natural parameter η₂ = -β.

nat2_0: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (2,)
property natural_parameters: Array

Get natural parameters η = [α-1, -β].

Returns

Natural parameters [η₁, η₂] with shape batch_shape + (2,).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [log(x), x].

Parameters

x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics [log(x), x] with shape batch_shape + (2,).

class sppcax.distributions.gamma.InverseGamma(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]

Bases: ExponentialFamily

Inverse Gamma distribution in natural parameters.

The inverse gamma distribution has density: p(x|α,β) = β^α * x^(-α-1) * exp(-β/x) / Γ(α)

In exponential family form:

  • h(x) = 1
  • η = [-α-1, -β]
  • T(x) = [log(x), 1/x]
  • A(η) = log(Γ(-η₁ - 1)) + (η₁ + 1)*log(-η₂)
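
One way to check the inverse gamma natural form is through its relation to the gamma distribution: if x is InverseGamma(α, β), then 1/x is Gamma(α, rate β), so log p_IG(x) = log p_Gamma(1/x) − 2 log(x). The sketch below verifies this with hypothetical stdlib-only helpers and does not call sppcax.

```python
import math


def inverse_gamma_log_prob(x: float, alpha: float, beta: float) -> float:
    # η = [−α − 1, −β], T(x) = [log x, 1/x]
    eta1, eta2 = -alpha - 1.0, -beta
    log_normalizer = math.lgamma(-eta1 - 1.0) + (eta1 + 1.0) * math.log(-eta2)
    return eta1 * math.log(x) + eta2 / x - log_normalizer


def gamma_log_prob(y: float, alpha: float, beta: float) -> float:
    # η = [α − 1, −β], T(y) = [log y, y]
    eta1, eta2 = alpha - 1.0, -beta
    log_normalizer = math.lgamma(eta1 + 1.0) - (eta1 + 1.0) * math.log(-eta2)
    return eta1 * math.log(y) + eta2 * y - log_normalizer


alpha, beta, x = 3.0, 2.0, 0.7
lp_ig = inverse_gamma_log_prob(x, alpha, beta)
# Change of variables y = 1/x has Jacobian |dy/dx| = 1/x².
lp_via_gamma = gamma_log_prob(1.0 / x, alpha, beta) - 2.0 * math.log(x)
```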

property alpha: Array

Get shape parameter α.

property beta: Array

Get scale parameter β.

dnat1: Array
dnat2: Array
property expected_psi: Array

Expected precision E[1/x] = alpha/beta for InverseGamma.

property expected_sufficient_statistics: Array

Compute E[T(x)] = [log(β) - ψ(α), α/β].

Returns

Expected sufficient statistics [E[log(x)], E[1/x]] with shape batch_shape + (2,).

classmethod from_natural_parameters(eta: Array) InverseGamma[source]

Create inverse gamma distribution from natural parameters.

Parameters

eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

Returns

InverseGamma distribution.

property kl_divergence_from_prior: Array

Compute KL divergence KL(post||prior).

Returns

KL divergence KL(post||prior) with shape batch_shape.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

zero

property log_normalizer: Array

Compute log normalizer A(η) = log(Γ(-η₁ - 1)) + (η₁ + 1)*log(-η₂)

Returns

Log normalizer with shape batch_shape.

property mean: Array

Get mean E[x] = β/(α-1).

mf_expectations() dict[source]

Return expectations for mean-field coordinate ascent partner.

mf_update(stats: tuple, partner_expectations: dict) InverseGamma[source]

Mean-field coordinate ascent update for InverseGamma noise.

Given partner (weights) expectations, compute the residual and update the InverseGamma natural parameters.

Parameters
  • stats – Sufficient statistics tuple (SxxT, SxyT, SyyT, N).

  • partner_expectations – Dict with ‘mean’ and ‘second_moment’ from weights.

Returns

Updated InverseGamma distribution.

mode() Array[source]
property nat1: Array

First natural parameter η₁ = -α - 1.

nat1_0: Array
property nat2: Array

Second natural parameter η₂ = -β.

nat2_0: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (2,)
property natural_parameters: Array

Get natural parameters η = [-α-1, -β].

Returns

Natural parameters [η₁, η₂] with shape batch_shape + (2,).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [log(x), 1/x].

Parameters

x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics [log(x), 1/x] with shape batch_shape + (2,).

sppcax.distributions.inverse_wishart module

Inverse Wishart distribution with mean-field interface.

Wraps the dynamax InverseWishart with natural parameters and coordinate ascent update methods for use in MeanField composites.

class sppcax.distributions.inverse_wishart.InverseWishart(df0: float | jax.Array, scale0: Array)[source]

Bases: ExponentialFamily

Inverse Wishart distribution in natural parameters.

For Sigma ~ IW(df, Psi):

p(Sigma | df, Psi) propto |Sigma|^{-(df+k+1)/2} exp(-tr(Psi Sigma^{-1})/2)

Exponential family form:

  • eta1 = -(df + k + 1) / 2 (scalar)
  • eta2 = -Psi / 2 (k x k matrix)
  • T(Sigma) = [log|Sigma|, Sigma^{-1}]
  • A(eta) = -(df/2) log|Psi| + multigammaln(df/2, k) + (df*k/2) log(2)

Key moments:

  • E[Sigma^{-1}] = df * Psi^{-1}
  • E[log|Sigma^{-1}|] = multidigamma(df/2, k) + k*log(2) - log|Psi|
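
The mapping between (df, Psi) and the natural parameters, together with the first moments, can be traced by hand for k = 2. This sketch uses plain nested lists and a hypothetical inv2x2 helper; it does not use sppcax or dynamax.

```python
def inv2x2(m):
    # Closed-form inverse of a 2 x 2 matrix.
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


k = 2
df = 5.0
psi = [[2.0, 0.5], [0.5, 1.0]]

# Natural parameters: eta1 = -(df + k + 1) / 2, eta2 = -Psi / 2.
eta1 = -(df + k + 1) / 2.0
eta2 = [[-0.5 * v for v in row] for row in psi]

# Inverse mapping: df = -2 * eta1 - k - 1, Psi = -2 * eta2.
df_back = -2.0 * eta1 - k - 1
psi_back = [[-2.0 * v for v in row] for row in eta2]

# Moments: E[Sigma^{-1}] = df * Psi^{-1}, and E[Sigma] = Psi / (df - k - 1) when df > k + 1.
psi_inv = inv2x2(psi_back)
expected_precision = [[df_back * v for v in row] for row in psi_inv]
mean_sigma = [[v / (df_back - k - 1) for v in row] for row in psi_back]
```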

property df: Array

Degrees of freedom: df = -2*nat1 - k - 1.

property dim: int

Dimension k of the k x k matrices.

dnat1: Array
dnat2: Array
property entropy: Array

Entropy of IW(df, Psi).

property expected_psi: Array

E[Sigma^{-1}] = df * Psi^{-1}.

property expected_sufficient_statistics: tuple

E[T(Sigma)] = (E[log|Sigma|], E[Sigma^{-1}]).

property inv_scale: Array

Inverse scale matrix Psi^{-1}.

property kl_divergence_from_prior: Array

KL(posterior || prior) using stored prior natural parameters.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

Log base measure log(h(x)) with shape batch_shape.

property log_normalizer: Array

Log normalizer A(df, Psi).

log_prob(x: Array) Array[source]

Log probability of x under IW(df, Psi).

property mean: Array

Mean E[Sigma] = Psi / (df - k - 1), requires df > k + 1.

mf_expectations() dict[source]

Return expectations for mean-field coordinate ascent partner.

mf_update(stats: tuple, partner_expectations: dict) InverseWishart[source]

Mean-field coordinate ascent update for InverseWishart noise.

For IW noise in a linear model y = W x + e, e ~ N(0, Sigma):

  df_post = df_prior + N
  Psi_post = Psi_prior + E[sum_t (y_t - W x_t)(y_t - W x_t)^T]
           = Psi_prior + SyyT - SxyT^T E[W]^T - E[W] SxyT + E[W SxxT W^T]

The quadratic term E[W SxxT W^T] decomposes under mean-field q(W) = prod_d N(w_d; m_d, V_d). Its (i, j) entry is:

  • off-diagonal (i != j): m_i^T SxxT m_j (independent rows)
  • diagonal (i = j): tr(E[w_i w_i^T] SxxT) = m_i^T SxxT m_i + tr(V_i SxxT)
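
A scalar version (one output, one input feature) makes the update easier to follow. In that case W is a single weight w with q(w) = N(m, v), and the expected residual sum reduces to SyyT − 2 m SxyT + (m² + v) SxxT. The sketch below is a hand-rolled illustration with hypothetical names and made-up numbers, not the sppcax implementation.

```python
def expected_residual(sxx: float, sxy: float, syy: float, mean_w: float, var_w: float) -> float:
    # E_q[sum_t (y_t - w x_t)^2] with E[w] = mean_w and E[w^2] = mean_w^2 + var_w.
    return syy - 2.0 * mean_w * sxy + (mean_w**2 + var_w) * sxx


xs = [0.5, -1.0, 2.0]
ys = [1.0, 0.2, 3.5]
sxx = sum(x * x for x in xs)
sxy = sum(x * y for x, y in zip(xs, ys))
syy = sum(y * y for y in ys)

m, v = 1.5, 0.1
closed = expected_residual(sxx, sxy, syy, m, v)
# Direct evaluation: residual at the mean weight plus the variance contribution.
direct = sum((y - m * x) ** 2 for x, y in zip(xs, ys)) + v * sxx

# Posterior hyperparameters for the k = 1 case.
df_prior, psi_prior = 3.0, 1.0
df_post = df_prior + len(xs)
psi_post = psi_prior + closed
```

The variance term v * SxxT is what distinguishes the mean-field update from simply plugging in the posterior mean of the weights.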

mode() Array[source]

Mode = Psi / (df + k + 1).

property nat1: Array

First natural parameter: -(df + k + 1) / 2.

nat1_0: Array
property nat2: Array

Second natural parameter: -Psi / 2.

nat2_0: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = ()
property natural_parameters: tuple

Natural parameters (eta1, eta2).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from IW(df, Psi) via Bartlett decomposition.

property scale: Array

Scale matrix Psi = -2 * nat2.

sufficient_statistics(x: Array) tuple[source]

T(Sigma) = (log|Sigma|, Sigma^{-1}).

sppcax.distributions.inverse_wishart.multidigamma(a: Array, p: int) Array[source]

Multivariate digamma: psi_p(a) = sum_{i=1}^{p} psi(a + (1-i)/2).
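
The sum can be written out directly once a scalar digamma is available. The sketch below uses a small stdlib-only digamma (recurrence plus asymptotic series) as a stand-in; the names digamma and multidigamma_sketch are hypothetical, and a JAX implementation would more likely rely on jax.scipy.special.digamma.

```python
import math


def digamma(x: float) -> float:
    # Valid for x > 0: shift x upward with ψ(x) = ψ(x + 1) − 1/x,
    # then apply the asymptotic expansion for large x.
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv = 1.0 / x
    inv2 = inv * inv
    # ψ(x) ≈ log x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶)
    result += math.log(x) - 0.5 * inv - inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
    return result


def multidigamma_sketch(a: float, p: int) -> float:
    # psi_p(a) = sum_{i=1}^{p} psi(a + (1 - i) / 2)
    return sum(digamma(a + (1 - i) / 2.0) for i in range(1, p + 1))


value = multidigamma_sketch(2.0, 2)  # ψ(2) + ψ(1.5)
```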

sppcax.distributions.mean_field module

Mean-field composite distribution with independent components.

class sppcax.distributions.mean_field.MeanField(weights: Distribution, noise: Distribution)[source]

Bases: Distribution

Mean-field composite distribution: q(W, Sigma) = q(W) x q(Sigma).

Components are independent. Posterior updates use coordinate ascent (alternating updates of weights and noise components).

weights

Distribution over weight parameters (MVN or Delta if frozen).

Type

sppcax.distributions.base.Distribution

noise

Distribution over noise parameters (InverseGamma, Delta, etc.).

Type

sppcax.distributions.base.Distribution

n_iter

Number of coordinate ascent iterations for posterior updates.

property alpha: Array

Shape parameter from InverseGamma noise (MVNIG compat).

property beta: Array

Scale parameter from InverseGamma noise (MVNIG compat).

property col_covariance: Array

Column covariance (base covariance, same as covariance).

property covariance: Array

Covariance of the weights component (base, not scaled by noise).

entropy() Array[source]

Entropy of the mean-field distribution (sum of component entropies).

property expected_covariance: Array

Expected covariance E[sigma^2] * base_covariance.

For InverseGamma noise: scalar E[sigma^2] per row. For InverseWishart noise: full matrix E[Sigma], not factored with weights cov. For Delta noise: fixed value.

property expected_psi: Array

Expected noise precision E[1/sigma^2] from noise component.

property expected_sufficient_statistics_psi: Array

Expected sufficient statistics of noise precision (MVNIG compat).

property inv_gamma

InverseGamma component (for MVNIG compatibility).

log_prob(x: Tuple[Array, Array]) Array[source]

Compute log probability.

Parameters

x – Tuple of (cov, w), where cov is the value of the sample covariance and w is the value of the sample state.

Returns

Log probability

property mask: Array

Mask from weights component (if available).

property mean: Array

Mean of the weights component.

mode() Tuple[Array, Array][source]

Compute joint mode (noise_cov_matrix, weights_mean).

Returns

Tuple of (noise covariance as matrix, weights mean).

property mvn: Distribution

Weights component (for ARD compatibility).

noise: Distribution
property precision: Array

Precision of the weights component.

sample(seed: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Tuple[Array, Array][source]

Sample from both components independently.

Parameters
  • seed – PRNG key.

  • sample_shape – Additional sample dimensions.

Returns

Tuple of (noise_sample, weights_sample).

weights: Distribution

sppcax.distributions.mvn module

Multivariate normal distribution implementation.

class sppcax.distributions.mvn.MultivariateNormal(loc: Array, scale_tril: Optional[Array] = None, covariance: Optional[Array] = None, precision: Optional[Array] = None, mask: Optional[Array] = None)[source]

Bases: ExponentialFamily

Multivariate normal distribution in natural parameters.

apply_mask_matrix(x: Array, zeromask: bool = False) Array[source]

Apply mask to a matrix, zeroing out masked rows and columns.

Parameters
  • x – Matrix with shape (…, d, d)

  • zeromask – If True, set masked entries to 0. If False (default), set masked entries to identity matrix values.

Returns

Masked matrix with same shape.

apply_mask_vector(x: Array) Array[source]

Apply mask to a vector, zeroing out masked dimensions.

Parameters

x – Vector with shape (…, d)

Returns

Masked vector with same shape

property covariance: Array
dnat1: Float[Array, 'dim']
dnat2: Float[Array, 'rows cols']
property expected_second_moment: Array

Expected second moment E[xx^T] = Cov + mean @ mean^T, per row.

property expected_sufficient_statistics: Array

Compute E[T(x)] = [μ, vec(μμ^T + Σ)].

Returns

Expected sufficient statistics [E[x], vec(E[xx^T])].

classmethod from_natural_parameters(nat1: Array, nat2: Array, mask: Optional[Array] = None) MultivariateNormal[source]

Create MVN from natural parameters.

Parameters
  • nat1 – First natural parameter (precision * mean).

  • nat2 – Second natural parameter (-0.5 * precision).

  • mask – Optional boolean mask with shape matching nat1

Returns

MultivariateNormal instance.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

Log base measure log(h(x)) with shape batch_shape.

property log_normalizer: Array

Compute log normalizer A(η).

Returns

Log normalizer A(η) with shape batch_shape.

mask: Float[Array, 'dim']
property mean: Array

Get mean parameter.

mf_expectations() dict[source]

Return expectations for mean-field coordinate ascent partner.

mf_update(stats: tuple, partner_expectations: dict) MultivariateNormal[source]

Mean-field coordinate ascent update for MVN weights.

Scales data sufficient statistics by partner’s expected precision E[1/sigma^2], then performs standard natural parameter update.

Parameters
  • stats – Sufficient statistics tuple (SxxT, SxyT, SyyT, N).

  • partner_expectations – Dict with ‘expected_precision’ from noise component. E[precision] can be scalar (isotropic), (D,) vector (per-feature IG), or (D, D) matrix (InverseWishart). For matrix precision, the diagonal is used for per-row weight updates.

Returns

Updated MVN distribution (posterior).

property nat1: Array

First natural parameter (precision * mean), masked.

nat1_0: Float[Array, 'dim']
property nat2: Array

Second natural parameter (-0.5 * precision).

nat2_0: Float[Array, 'rows cols']
natural_param_shape: ClassVar[Tuple[int, ...]] = (1,)
property natural_parameters: Array

Get natural parameters η = [precision*mean, -0.5*precision].

Returns

Natural parameters [η₁, vec(η₂)].

property precision: Array

Get precision parameter.

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples from the distribution.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [x, xx^T].

Parameters

x – Value to compute sufficient statistics for.

Returns

Sufficient statistics [x, vec(xx^T)].

sppcax.distributions.mvn_gamma module

Multivariate Normal-Inverse-Gamma distribution implementation.

class sppcax.distributions.mvn_gamma.MultivariateNormalInverseGamma(loc: Array, *, isotropic_noise: bool, mask: Optional[Array] = None, alpha0: float = 2.0, beta0: float = 1.0, scale_tril: Optional[Array] = None, covariance: Optional[Array] = None, precision: Optional[Array] = None)[source]

Bases: ExponentialFamily

Multivariate Normal-Inverse-Gamma distribution.

This distribution combines:

  p(x|σ²) = N(x; μ, σ²Λ⁻¹)   (multivariate normal)
  p(σ²) = InverseGamma(σ²; α, β)   (scalar variance)

where:

  • x is a vector (can be batched)
  • μ is the location parameter
  • Λ is a base precision matrix
  • σ² is a scalar variance parameter
  • α, β are Inverse-Gamma distribution parameters

property alpha: Array

Shape parameter α of the InverseGamma component.

property beta: Array

Scale parameter β of the InverseGamma component.

property col_covariance: Array

Column covariance Λ⁻¹ (alias for covariance).

property covariance: Array

Base covariance Λ⁻¹ of the MVN component (without noise scaling).

property expected_covariance: Array

Expected covariance E[σ²] * Λ⁻¹.

property expected_log_psi: Array

Compute expected log precision E[log(psi)].

property expected_psi: Array

Compute expected precision E[psi].

property expected_sufficient_statistics_psi: Array

Compute expected sufficient statistics of psi.

inv_gamma: InverseGamma
log_prob(x: Tuple[Array, Array]) Array[source]

Compute log probability.

Parameters

x – Tuple of (sig_sqr, w), where sig_sqr is the value of the sample variance and w is the value of the sample state.

Returns

Log probability

property mean: Array

Get mean of the marginal distribution p(x).

mode() Tuple[Array, Array][source]

Solve for the mode. Recall that

    p(x, \sigma^2) \propto
        \mathrm{N}(x \mid \mu, \sigma^2 \Lambda^{-1}) \times
        \mathrm{IG}(\sigma^2 \mid \alpha, \beta)

The optimal location is \(x^* = \mu\). Substituting this in,

    p(x^*, \sigma^2) \propto \mathrm{IG}(\sigma^2 \mid \alpha + D/2, \beta)

where D is the dimensionality of x, and the mode of this inverse gamma distribution is at

    (\sigma^2)^* = \beta / (\alpha + (D + 2)/2)
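
The closed-form mode can be checked numerically. At the optimal location, the joint log density as a function of σ² is, up to an additive constant, −(D/2 + α + 1) log σ² − β/σ². The sketch below uses hypothetical helper names and confirms that the formula gives a local maximum.

```python
import math


def joint_log_density_at_mean(sig_sqr: float, alpha: float, beta: float, dim: int) -> float:
    # Normal term at its mean contributes -(D/2) log σ²;
    # the inverse gamma term contributes -(α + 1) log σ² - β/σ².
    return -(0.5 * dim + alpha + 1.0) * math.log(sig_sqr) - beta / sig_sqr


def mode_sig_sqr(alpha: float, beta: float, dim: int) -> float:
    # (σ²)* = β / (α + (D + 2)/2)
    return beta / (alpha + (dim + 2) / 2.0)


alpha, beta, dim = 2.0, 1.0, 4
s_star = mode_sig_sqr(alpha, beta, dim)
at_mode = joint_log_density_at_mean(s_star, alpha, beta, dim)
below = joint_log_density_at_mean(0.99 * s_star, alpha, beta, dim)
above = joint_log_density_at_mean(1.01 * s_star, alpha, beta, dim)
```
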
mvn: MultivariateNormal
property precision: Array

Base precision matrix Λ of the MVN component.

sample(seed: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Tuple[Array, Array][source]

Sample from the distribution.

Parameters
  • seed – PRNG key

  • sample_shape – Shape of samples to draw

Returns

Tuple of (sig_sqr, value) samples

sppcax.distributions.mvn_gamma.mvnig_posterior_update(mvnig_prior: MultivariateNormalInverseGamma, sufficient_stats: tuple, props: object) MultivariateNormalInverseGamma[source]

Update the multivariate normal inverse gamma (MVNIG) posterior from sufficient statistics.

Parameters
  • mvnig_prior – Prior MVNIG distribution.

  • sufficient_stats – Tuple of (SxxT, SxyT, SyyT, N) sufficient statistics.

  • props – Parameter properties controlling which parameters are trainable.

Returns

Posterior MVNIG distribution.

sppcax.distributions.normal module

Normal distribution implementations.

class sppcax.distributions.normal.Normal(loc: Array = 0.0, scale: Array = 1.0)[source]

Bases: ExponentialFamily

Univariate normal distribution in natural parameters.

The normal distribution has density: p(x|μ,σ) = 1/√(2πσ²) * exp(-(x-μ)²/(2σ²))

In exponential family form:

  • η = [μ/σ², -1/(2σ²)]
  • T(x) = [x, x²]
  • A(η) = -η₁²/(4η₂) - (1/2)log(-2η₂) + (1/2)log(2π)
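
A short round trip between (μ, σ) and the natural parameters, plus a check that η·T(x) − A(η) reproduces the usual log density, is sketched below. The function names are hypothetical and the snippet does not import sppcax; it assumes a base measure of h(x) = 1, which matches the 2π term appearing in A(η).

```python
import math


def to_natural(loc: float, scale: float):
    # η = [μ/σ², −1/(2σ²)]
    return loc / scale**2, -0.5 / scale**2


def from_natural(eta1: float, eta2: float):
    variance = -0.5 / eta2
    return eta1 * variance, math.sqrt(variance)


def log_prob_natural(x: float, eta1: float, eta2: float) -> float:
    log_normalizer = (
        -eta1**2 / (4.0 * eta2) - 0.5 * math.log(-2.0 * eta2) + 0.5 * math.log(2.0 * math.pi)
    )
    return eta1 * x + eta2 * x * x - log_normalizer


def log_prob_direct(x: float, loc: float, scale: float) -> float:
    return -0.5 * math.log(2.0 * math.pi * scale**2) - (x - loc) ** 2 / (2.0 * scale**2)


loc, scale, x = 1.5, 0.8, 2.1
eta1, eta2 = to_natural(loc, scale)
loc_back, scale_back = from_natural(eta1, eta2)
lp_nat = log_prob_natural(x, eta1, eta2)
lp_dir = log_prob_direct(x, loc, scale)
```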

property entropy: Array

Compute entropy of the distribution.

Returns

Entropy with shape batch_shape.

property expected_sufficient_statistics: Array

Compute E[T(x)] = [μ, μ² + σ²].

Returns

Expected sufficient statistics [E[x], E[x²]] with shape batch_shape + (2,).

classmethod from_natural_parameters(eta: Array) Normal[source]

Create normal distribution from natural parameters.

Parameters

eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

Returns

Normal distribution.

property loc: Array

Get location parameter.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

Log base measure log(h(x)) with shape batch_shape.

property log_normalizer: Array

Compute log normalizer A(η).

Returns

Log normalizer with shape batch_shape.

nat1: Array
nat2: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (2,)
property natural_parameters: Array

Get natural parameters η = [precision*mean, -0.5*precision].

Returns

Natural parameters [η₁, η₂] with shape batch_shape + (2,).

property precision: Array

Get precision parameter

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

property scale: Array

Get scale parameter.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [x, x²].

Parameters

x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics [x, x²] with shape batch_shape + (2,).

sppcax.distributions.poisson module

Poisson distribution in natural parameterization.

class sppcax.distributions.poisson.Poisson(log_rate: Array)[source]

Bases: ExponentialFamily

Poisson distribution parameterized by log rate.

The Poisson distribution has natural parameter η = log(λ) where λ is the rate parameter, and sufficient statistic T(x) = x.

property entropy: Array

Compute entropy of Poisson distribution.

Returns

Entropy H(λ) = λ(1 - log(λ)) + exp(-λ)sum_{k=0}^∞ λ^k log(k!)/k!

property expected_sufficient_statistics: Array

Compute E[T(x)] = E[x] = λ = exp(η).

Returns

Expected sufficient statistics E[x].

classmethod from_natural_parameters(eta: Array) Poisson[source]

Create Poisson from natural parameters.

Parameters

eta – Natural parameter η = log(λ).

Returns

Poisson instance.

log_base_measure(x: Array) Array[source]

Compute log base measure log(h(x)) = -log(x!).

Parameters

x – Count data.

Returns

Log base measure -log(x!).

property log_normalizer: Array

Compute log normalizer A(η) = exp(η).

Returns

Log normalizer A(η).
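Combining the pieces documented for this class, the log probability is log p(x|η) = η·x - A(η) + log(h(x)). The sketch below assembles it in plain Python; poisson_log_prob is a hypothetical helper, not a function of the package.

```python
import math

def poisson_log_prob(x, log_rate):
    """log p(x | η) = η x - exp(η) - log(x!)."""
    log_h = -math.lgamma(x + 1)          # log base measure, -log(x!)
    log_normalizer = math.exp(log_rate)  # A(η) = exp(η) = λ
    return log_rate * x - log_normalizer + log_h

# The probabilities should sum to one over the support (truncated at 100 here).
total = sum(math.exp(poisson_log_prob(k, math.log(2.5))) for k in range(100))
```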

property log_rate: Array

Get log(rate) parameter.

nat1: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = ()
property natural_parameters: Array

Get natural parameters (log rate).

Returns

Natural parameters η = log(λ).

property rate: Array

Get rate parameter.

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Optional[Tuple[int, ...]] = ()) Array[source]

Sample from Poisson distribution.

Parameters
  • key – PRNG key.

  • sample_shape – Shape of samples to draw.

Returns

Count samples from distribution.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = x.

Parameters

x – Count data.

Returns

Sufficient statistics T(x) = x.

sppcax.distributions.updates module

sppcax.distributions.utils module

Utility functions for distributions.

sppcax.distributions.utils.cho_inv(matrix: Float[Array, 'rows cols'], diagonal_boost: float = 1e-12) Float[Array, 'rows cols'][source]

Invert a positive-definite matrix via Cholesky decomposition.

Parameters
  • matrix – Positive-definite matrix with shape (…, d, d).

  • diagonal_boost – Small value added to diagonal for numerical stability.

Returns

Inverse matrix with same shape.

sppcax.distributions.utils.moment_to_natural(mean: Array, covariance: Array) tuple[jax.Array, jax.Array][source]

Convert moment parameters to natural parameters for multivariate normal.

Parameters
  • mean – Mean vector.

  • covariance – Covariance matrix.

Returns

Tuple of (nat1, nat2).

sppcax.distributions.utils.natural_to_moment(nat1: Array, nat2: Array) tuple[jax.Array, jax.Array][source]

Convert natural parameters to moment parameters for multivariate normal.

Parameters
  • nat1 – First natural parameter (precision * mean).

  • nat2 – Second natural parameter (-0.5 * precision).

Returns

Tuple of (mean, covariance).
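A round trip between the two parameterizations can be sketched for the 2×2 case in plain Python. The helpers below (inv2, moment_to_natural_2d, natural_to_moment_2d) are hypothetical and only follow the conventions stated above: nat1 = precision * mean and nat2 = -0.5 * precision.

```python
def inv2(m):
    """Inverse of a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def moment_to_natural_2d(mean, cov):
    precision = inv2(cov)
    nat1 = [sum(precision[i][j] * mean[j] for j in range(2)) for i in range(2)]
    nat2 = [[-0.5 * precision[i][j] for j in range(2)] for i in range(2)]
    return nat1, nat2

def natural_to_moment_2d(nat1, nat2):
    precision = [[-2.0 * nat2[i][j] for j in range(2)] for i in range(2)]
    cov = inv2(precision)
    mean = [sum(cov[i][j] * nat1[j] for j in range(2)) for i in range(2)]
    return mean, cov

# Hypothetical example values.
mean, cov = [1.0, -2.0], [[2.0, 0.5], [0.5, 1.0]]
nat1, nat2 = moment_to_natural_2d(mean, cov)
mean_back, cov_back = natural_to_moment_2d(nat1, nat2)
```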

sppcax.distributions.utils.safe_cholesky(X: Array, jitter: float = 1e-12) Array[source]

Compute Cholesky decomposition with added diagonal jitter for numerical stability.

Parameters
  • X – Symmetric positive definite matrix.

  • jitter – Small positive value to add to diagonal for stability.

Returns

Lower triangular Cholesky factor L such that X ≈ L @ L^T.

sppcax.distributions.utils.safe_cholesky_and_logdet(X: Array, jitter: float = 1e-12) tuple[jax.Array, jaxtyping.Float[Array, '']][source]

Compute Cholesky decomposition and log determinant with added diagonal jitter.

Parameters
  • X – Symmetric positive definite matrix.

  • jitter – Small positive value to add to diagonal for stability.

Returns

Tuple of (L, logdet) where L is the lower triangular Cholesky factor and logdet is the log determinant of X.
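The idea behind both helpers can be sketched without any array library: add the jitter to the diagonal, factorize, and read the log determinant off the diagonal of the factor as 2 Σᵢ log(Lᵢᵢ). This is an assumption-laden illustration (the function names are hypothetical, and the package presumably relies on JAX linear algebra instead of explicit loops).

```python
import math

def cholesky_with_jitter(x, jitter=1e-12):
    """Lower-triangular factor of x + jitter * I for a small SPD matrix (nested lists)."""
    n = len(x)
    a = [[x[i][j] + (jitter if i == j else 0.0) for j in range(n)] for i in range(n)]
    chol = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(chol[i][k] * chol[j][k] for k in range(j))
            chol[i][j] = math.sqrt(s) if i == j else s / chol[j][j]
    return chol

def logdet_from_cholesky(chol):
    """log det(L L^T) = 2 * sum of the log diagonal entries of L."""
    return 2.0 * sum(math.log(chol[i][i]) for i in range(len(chol)))

x = [[4.0, 2.0], [2.0, 3.0]]  # determinant is 8
chol = cholesky_with_jitter(x)
logdet = logdet_from_cholesky(chol)
```

Because the jitter perturbs the diagonal, the factor satisfies X ≈ L @ L^T rather than exact equality, which matches the "≈" used in the descriptions above.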

sppcax.distributions.utils.symmetrize(matrix: Float[Array, 'rows cols']) Float[Array, 'rows cols'][source]

Symmetrize one or more matrices.

Module contents

Distribution classes.

class sppcax.distributions.Beta(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]

Bases: ExponentialFamily

Beta distribution in natural parameters.

The beta distribution has density: p(x|α,β) = x^(α-1) * (1-x)^(β-1) / B(α,β) for x ∈ [0,1]

In exponential family form:

  • h(x) = 1

  • η = [α-1, β-1]

  • T(x) = [log(x), log(1-x)]

  • A(η) = log(B(η₁+1, η₂+1))
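As a check on the exponential family form, the log density can be assembled as η·T(x) - A(η) using only the standard library. The helper names below are hypothetical and the sketch is independent of the package.

```python
import math

def log_beta_fn(a, b):
    """log B(a, b) via log-gamma functions."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_log_prob(x, alpha, beta):
    eta = (alpha - 1.0, beta - 1.0)                        # natural parameters
    stats = (math.log(x), math.log1p(-x))                  # T(x) = [log(x), log(1-x)]
    log_normalizer = log_beta_fn(eta[0] + 1.0, eta[1] + 1.0)
    return eta[0] * stats[0] + eta[1] * stats[1] - log_normalizer

# Beta(2, 3) at x = 0.5: x (1-x)^2 / B(2, 3) with B(2, 3) = 1/12.
density = math.exp(beta_log_prob(0.5, 2.0, 3.0))
```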

property alpha: Array

Get first shape parameter α.

batch_shape: Tuple[int, ...]
property beta: Array

Get second shape parameter β.

dnat1: Array
dnat2: Array
event_shape: Tuple[int, ...]
property expected_sufficient_statistics: Array

Compute E[T(x)] = [ψ(α) - ψ(α+β), ψ(β) - ψ(α+β)].

Returns

Expected sufficient statistics [E[log(x)], E[log(1-x)]] with shape batch_shape + (2,).

classmethod from_natural_parameters(eta: Array) Beta[source]

Create beta distribution from natural parameters.

Parameters

eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

Returns

Beta distribution.

property kl_divergence_from_prior: Array

Compute KL divergence KL(post||prior).

Returns

KL divergence KL(post||prior) with shape batch_shape.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

zero

property log_normalizer: Array

Compute log normalizer A(η) = log(B(η₁+1, η₂+1)).

Returns

Log normalizer with shape batch_shape.

property mean: Array

Get mean of the distribution.

property nat1: Array

First natural parameter η₁ = α - 1.

nat1_0: Array
property nat2: Array

Second natural parameter η₂ = β - 1.

nat2_0: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (2,)
property natural_parameters: Array

Get natural parameters η = [α-1, β-1].

Returns

Natural parameters [η₁, η₂] with shape batch_shape + (2,).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [log(x), log(1-x)].

Parameters

x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics [log(x), log(1-x)] with shape batch_shape + (2,).

property variance: Array

Get variance of the distribution.

class sppcax.distributions.Categorical(logits: Array)[source]

Bases: ExponentialFamily

Categorical distribution parameterized by logits.

The Categorical distribution has K-1 natural parameters η_k = log(p_k/p_K) for k=1,…,K-1 where p_K is the probability of the last category.

batch_shape: Tuple[int, ...]
event_shape: Tuple[int, ...]
property expected_sufficient_statistics: Array

Compute E[T(x)] = p_1,…,p_{K-1}.

Returns

Expected probabilities for first K-1 categories.

classmethod from_natural_parameters(eta: Array) Categorical[source]

Create Categorical from natural parameters.

Parameters

eta – Log-odds relative to last category. Shape (…, K-1) where K is number of categories.

Returns

Categorical instance.

property full_logits: Array
property log_normalizer: Array

Compute log normalizer A(η).

For categorical, A(η) = log(1 + sum(exp(η_k))).

Returns

Log normalizer A(η).
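Since the last category has an implicit logit of zero, all K probabilities can be recovered from the K-1 natural parameters through A(η). A minimal sketch in plain Python, with hypothetical helper names and a max-shift added for numerical stability (an assumption, not a statement about the library):

```python
import math

def categorical_log_normalizer(eta):
    """A(η) = log(1 + Σ_k exp(η_k)), computed stably by factoring out the maximum."""
    m = max(0.0, max(eta))
    return m + math.log(math.exp(-m) + sum(math.exp(e - m) for e in eta))

def categorical_probs(eta):
    """All K probabilities; the last category has implicit logit 0."""
    log_norm = categorical_log_normalizer(eta)
    return [math.exp(e - log_norm) for e in eta] + [math.exp(-log_norm)]

# η = [log 2, log 3] corresponds to unnormalized weights (2, 3, 1).
probs = categorical_probs([math.log(2.0), math.log(3.0)])
```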

nat1: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (1,)
property natural_parameters: Array

Get natural parameters (logits).

Returns

Natural parameters η.

property probs: Array

Get probability parameters.

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Optional[Tuple[int, ...]] = None) Array[source]

Sample from Categorical distribution.

Parameters
  • key – PRNG key.

  • sample_shape – Shape of samples to draw.

Returns

Category indices sampled from distribution.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x).

For categorical data, T(x) is a one-hot vector with the last category omitted (since probabilities sum to 1).

Parameters

x – Category indices.

Returns

One-hot encoded data (excluding last category).
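The reduced one-hot encoding can be illustrated as follows. The sketch assumes zero-based category indices and uses hypothetical helper names; the last category maps to the all-zeros vector, so its log probability is just -A(η).

```python
import math

def one_hot_drop_last(x, num_categories):
    """T(x): one-hot over the first K-1 categories; category K-1 maps to all zeros."""
    return [1.0 if x == k else 0.0 for k in range(num_categories - 1)]

def categorical_log_prob(x, eta):
    """log p(x | η) = η · T(x) - A(η), with A(η) = log(1 + Σ_k exp(η_k))."""
    stats = one_hot_drop_last(x, len(eta) + 1)
    log_norm = math.log1p(sum(math.exp(e) for e in eta))
    return sum(e * t for e, t in zip(eta, stats)) - log_norm

eta = [math.log(2.0), math.log(3.0)]
last_prob = math.exp(categorical_log_prob(2, eta))
```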

class sppcax.distributions.Delta(location: Array, sufficient_statistics_fn: Optional[Callable] = None)[source]

Bases: Distribution

Delta distribution (Dirac delta) concentrated at a single point.

batch_shape: Tuple[int, ...]
property covariance: Array

Covariance matrix (always zero for delta distribution).

entropy() Array[source]

Compute entropy (always 0 for delta distribution).

Returns

Entropy with shape batch_shape.

event_shape: Tuple[int, ...]
property expected_psi: Array

Expected precision for Delta: matrix inverse for square matrices, element-wise otherwise.

property expected_second_moment: Array

Expected second moment E[XX^T] = location @ location^T (no variance).

property expected_sufficient_statistics: Array

Compute expected sufficient statistics.

For delta distribution, this is just sufficient_statistics(location) since all probability mass is concentrated at location.

Returns

Expected sufficient statistics with shape determined by sufficient_statistics_fn.

log_prob(x: Array) Array[source]

Compute log probability.

Parameters

x – Value to compute log probability for. Shape: batch_shape + event_shape

Returns

Log probability with shape batch_shape: 0 at location, -inf elsewhere.

mean: Array
mf_expectations() dict[source]

Return expectations for mean-field coordinate ascent partner.

mf_update(*args) Delta[source]

Mean-field update for Delta is a no-op (fixed component).

mode() Array[source]
property precision: Array

Precision (infinite for delta, but return large finite value for compatibility).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution (always returns location).

Parameters
  • key – PRNG key (unused).

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape. All samples are equal to location.

sufficient_statistics: Callable
class sppcax.distributions.Distribution(batch_shape: Tuple[int, ...], event_shape: Tuple[int, ...])[source]

Bases: Module

Base distribution class in natural parameters.

batch_shape

Shape of batch dimensions.

Type

Tuple[int, …]

event_shape

Shape of event dimensions.

Type

Tuple[int, …]

batch_shape: Tuple[int, ...]
broadcast_to_shape(x: Array, ignore_event: bool = False) Array[source]

Broadcast array to match distribution shape.

Parameters
  • x – Array to broadcast.

  • ignore_event – If True, only broadcast batch dimensions.

Returns

Broadcasted array.

entropy() Array[source]

Compute entropy of the distribution.

Returns

Entropy with shape batch_shape.

event_shape: Tuple[int, ...]
log_prob(x: Array) Array[source]

Compute log probability of x.

Parameters

x – Value to compute log probability for. Should have shape: batch_shape + event_shape

Returns

Log probability with shape batch_shape.

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Optional[Tuple[int, ...]] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Additional sample dimensions.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

property shape: Tuple[int, ...]

Full shape (batch_shape + event_shape).

class sppcax.distributions.ExponentialFamily(batch_shape: Tuple[int, ...] = (), event_shape: Tuple[int, ...] = ())[source]

Bases: Distribution

Base class for exponential family distributions in natural parameterization.

The exponential family has the form p(x|η) = h(x)exp(η^T T(x) - A(η)), where:

  • η: natural parameters

  • T(x): sufficient statistics

  • A(η): log normalizer

  • h(x): base measure

natural_param_shape

Shape of natural parameters (class variable).

Type

ClassVar[Tuple[int, …]]

batch_shape: Tuple[int, ...]
property entropy: Array

Compute entropy of the distribution.

Returns

Entropy with shape batch_shape.

event_shape: Tuple[int, ...]
property expected_log_base_measure: Array

Compute the expectation of the log base measure E_{p(x)}[log(h(x))].

Returns

Expectation E_{p(x)}[log(h(x))] with shape batch_shape.
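For any exponential family, taking the expectation of -log p(x|η) gives the identity H = A(η) - η·E[T(x)] - E[log(h(x))], which is presumably why this expectation is exposed. Whether the package computes entropy this way is an assumption; the sketch below only verifies the identity for a univariate normal with h(x) = 1, using a hypothetical helper.

```python
import math

def normal_entropy_from_natural(eta1, eta2):
    """Entropy of a univariate normal via H = A(η) - η · E[T(x)] - E[log h(x)]."""
    var = -1.0 / (2.0 * eta2)
    mean = eta1 * var
    expected_stats = (mean, mean**2 + var)           # E[T(x)] = [E[x], E[x²]]
    log_normalizer = (-eta1**2 / (4.0 * eta2)
                      - 0.5 * math.log(-2.0 * eta2)
                      + 0.5 * math.log(2.0 * math.pi))
    expected_log_h = 0.0                             # h(x) = 1 in this parameterization
    dot = eta1 * expected_stats[0] + eta2 * expected_stats[1]
    return log_normalizer - dot - expected_log_h

# loc = 1.5, scale = 2.0 gives η = (0.375, -0.125).
entropy = normal_entropy_from_natural(0.375, -0.125)
```

The result matches the textbook value 0.5 * log(2πe σ²) and does not depend on the mean.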

property expected_sufficient_statistics: Array

Compute expected sufficient statistics E[T(x)].

Returns

Expected sufficient statistics E[T(x)] with shape batch_shape + natural_param_shape.

classmethod from_natural_parameters(eta: Array) ExponentialFamily[source]

Create distribution from natural parameters.

Parameters

eta – Natural parameters with shape: batch_shape + natural_param_shape

Returns

Distribution instance.

kl_divergence(other: ExponentialFamily) Array[source]

Compute KL divergence KL(self||other).

Parameters

other – Other distribution to compute KL divergence with.

Returns

KL divergence KL(self||other) with shape batch_shape.
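When both distributions belong to the same family and share a base measure, the KL divergence reduces to KL(p||q) = A(η_q) - A(η_p) - (η_q - η_p)·E_p[T(x)]. This is the standard identity, stated here as an illustration and not as a description of the method's internals. A sketch for the Poisson case, with a hypothetical helper name:

```python
import math

def poisson_kl(log_rate_p, log_rate_q):
    """KL(p||q) for two Poisson distributions given their natural parameters."""
    expected_stat = math.exp(log_rate_p)            # E_p[x] = λ_p
    log_norm_p = math.exp(log_rate_p)               # A(η_p) = λ_p
    log_norm_q = math.exp(log_rate_q)               # A(η_q) = λ_q
    return log_norm_q - log_norm_p - (log_rate_q - log_rate_p) * expected_stat

# Compare against the closed form λ_p log(λ_p/λ_q) - λ_p + λ_q.
kl = poisson_kl(math.log(2.0), math.log(5.0))
closed_form = 2.0 * math.log(2.0 / 5.0) - 2.0 + 5.0
```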

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

Log base measure log(h(x)) with shape batch_shape.

property log_normalizer: Array

Compute log normalizer A(η).

Returns

Log normalizer A(η) with shape batch_shape.

log_prob(x: Array) Array[source]

Compute log probability.

Parameters

x – Data to compute log probability for. Shape: batch_shape + event_shape

Returns

Log probability log p(x|η) with shape batch_shape. Returns -inf for values outside the support.

natural_param_shape: ClassVar[Tuple[int, ...]] = ()
property natural_parameters: Array

Get natural parameters of the distribution.

Returns

Natural parameters η with shape batch_shape + natural_param_shape.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x).

Parameters

x – Data to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics T(x) with shape batch_shape + natural_param_shape.

class sppcax.distributions.Gamma(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]

Bases: ExponentialFamily

Gamma distribution in natural parameters.

The gamma distribution has density: p(x|α,β) = β^α * x^(α-1) * exp(-βx) / Γ(α)

In exponential family form:

  • h(x) = 1

  • η = [α-1, -β]

  • T(x) = [log(x), x]

  • A(η) = log(Γ(η₁ + 1)) - (η₁ + 1)*log(-η₂)
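The exponential family form can be checked against the density by evaluating η·T(x) - A(η) directly. The function below is a hypothetical standalone helper written with the standard library only.

```python
import math

def gamma_log_prob(x, alpha, beta):
    """log p(x | α, β) assembled from η = [α-1, -β] and T(x) = [log(x), x]."""
    eta1, eta2 = alpha - 1.0, -beta
    log_normalizer = math.lgamma(eta1 + 1.0) - (eta1 + 1.0) * math.log(-eta2)
    return eta1 * math.log(x) + eta2 * x - log_normalizer

# Gamma(1, 2) is an exponential distribution with rate 2: log p(0.5) = log(2) - 1.
value = gamma_log_prob(0.5, 1.0, 2.0)
```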

property alpha: Array

Get shape parameter α.

batch_shape: Tuple[int, ...]
property beta: Array

Get rate parameter β.

dnat1: Array
dnat2: Array
event_shape: Tuple[int, ...]
property expected_sufficient_statistics: Array

Compute E[T(x)] = [ψ(α) - log(β), α/β].

Returns

Expected sufficient statistics [E[log(x)], E[x]] with shape batch_shape + (2,).

classmethod from_natural_parameters(eta: Array) Gamma[source]

Create gamma distribution from natural parameters.

Parameters

eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

Returns

Gamma distribution.

property kl_divergence_from_prior: Array

Compute KL divergence KL(post||prior).

Returns

KL divergence KL(post||prior) with shape batch_shape.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

zero

property log_normalizer: Array

Compute log normalizer A(η) = log(Γ(η₁ + 1)) - (η₁ + 1)*log(-η₂).

Returns

Log normalizer with shape batch_shape.

property mean: Array

Get mean E[x] = α/β.

property nat1: Array

First natural parameter η₁ = α - 1.

nat1_0: Array
property nat2: Array

Second natural parameter η₂ = -β.

nat2_0: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (2,)
property natural_parameters: Array

Get natural parameters η = [α-1, -β].

Returns

Natural parameters [η₁, η₂] with shape batch_shape + (2,).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [log(x), x].

Parameters

x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics [log(x), x] with shape batch_shape + (2,).

class sppcax.distributions.InverseGamma(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]

Bases: ExponentialFamily

Inverse Gamma distribution in natural parameters.

The inverse gamma distribution has density: p(x|α,β) = β^α * x^(-α-1) * exp(-β/x) / Γ(α)

In exponential family form:

  • h(x) = 1

  • η = [-α-1, -β]

  • T(x) = [log(x), 1/x]

  • A(η) = log(Γ(-η₁ - 1)) + (η₁ + 1)*log(-η₂)
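A numerical sanity check of this parameterization: build the log density from η, T(x), and A(η), then integrate 1/x against it to recover the expected precision E[1/x] = α/β that this class documents. The helpers and the midpoint-rule settings (upper, n_steps) are hypothetical choices for illustration.

```python
import math

def inverse_gamma_log_prob(x, alpha, beta):
    """log p(x | α, β) from η = [-α-1, -β] and T(x) = [log(x), 1/x]."""
    eta1, eta2 = -alpha - 1.0, -beta
    log_normalizer = math.lgamma(-eta1 - 1.0) + (eta1 + 1.0) * math.log(-eta2)
    return eta1 * math.log(x) + eta2 / x - log_normalizer

def expected_inverse(alpha, beta, upper=100.0, n_steps=200_000):
    """Midpoint-rule estimate of E[1/x] on (0, upper]."""
    step = upper / n_steps
    total = 0.0
    for i in range(n_steps):
        x = (i + 0.5) * step
        total += math.exp(inverse_gamma_log_prob(x, alpha, beta)) / x * step
    return total

# For α = 3, β = 2 the estimate should be close to α/β = 1.5.
estimate = expected_inverse(3.0, 2.0)
```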

property alpha: Array

Get shape parameter α.

batch_shape: Tuple[int, ...]
property beta: Array

Get scale parameter β.

dnat1: Array
dnat2: Array
event_shape: Tuple[int, ...]
property expected_psi: Array

Expected precision E[1/x] = alpha/beta for InverseGamma.

property expected_sufficient_statistics: Array

Compute E[T(x)] = [log(β) - ψ(α), α/β].

Returns

Expected sufficient statistics [E[log(x)], E[1/x]] with shape batch_shape + (2,).

classmethod from_natural_parameters(eta: Array) InverseGamma[source]

Create inverse gamma distribution from natural parameters.

Parameters

eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

Returns

InverseGamma distribution.

property kl_divergence_from_prior: Array

Compute KL divergence KL(post||prior).

Returns

KL divergence KL(post||prior) with shape batch_shape.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

zero

property log_normalizer: Array

Compute log normalizer A(η) = log(Γ(-η₁ - 1)) + (η₁ + 1)*log(-η₂).

Returns

Log normalizer with shape batch_shape.

property mean: Array

Get mean E[x] = β/(α-1).

mf_expectations() dict[source]

Return expectations for mean-field coordinate ascent partner.

mf_update(stats: tuple, partner_expectations: dict) InverseGamma[source]

Mean-field coordinate ascent update for InverseGamma noise.

Given partner (weights) expectations, compute the residual and update the InverseGamma natural parameters.

Parameters
  • stats – Sufficient statistics tuple (SxxT, SxyT, SyyT, N).

  • partner_expectations – Dict with ‘mean’ and ‘second_moment’ from weights.

Returns

Updated InverseGamma distribution.

mode() Array[source]
property nat1: Array

First natural parameter η₁ = -α - 1.

nat1_0: Array
property nat2: Array

Second natural parameter η₂ = -β.

nat2_0: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (2,)
property natural_parameters: Array

Get natural parameters η = [-α-1, -β].

Returns

Natural parameters [η₁, η₂] with shape batch_shape + (2,).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [log(x), 1/x].

Parameters

x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics [log(x), 1/x] with shape batch_shape + (2,).

class sppcax.distributions.InverseWishart(df0: float | jax.Array, scale0: Array)[source]

Bases: ExponentialFamily

Inverse Wishart distribution in natural parameters.

For Sigma ~ IW(df, Psi):

p(Sigma | df, Psi) propto |Sigma|^{-(df+k+1)/2} exp(-tr(Psi Sigma^{-1})/2)

Exponential family form:

  • eta1 = -(df + k + 1) / 2 (scalar)

  • eta2 = -Psi / 2 (k x k matrix)

  • T(Sigma) = [log|Sigma|, Sigma^{-1}]

  • A(eta) = -(df/2) log|Psi| + multigammaln(df/2, k) + (df*k/2) log(2)

Key moments:

  • E[Sigma^{-1}] = df * Psi^{-1}

  • E[log|Sigma^{-1}|] = multidigamma(df/2, k) + k*log(2) - log|Psi|
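For k = 1 the Inverse Wishart density |Sigma|^{-(df+2)/2} exp(-Psi/(2 Sigma)) is an Inverse Gamma with α = df/2 and β = Psi/2, which gives a quick consistency check on the moments. The sketch compares the scalar summaries of both families; the mean and mode expressions for IW are the ones documented for this class, the Inverse Gamma mode β/(α+1) is the standard result, and the helper names are hypothetical.

```python
import math

def iw_scalar_summaries(df, psi):
    """k = 1 summaries of IW(df, Psi)."""
    k = 1
    return {
        "expected_precision": df / psi,   # E[Sigma^{-1}] = df * Psi^{-1}
        "mean": psi / (df - k - 1),       # requires df > k + 1
        "mode": psi / (df + k + 1),
    }

def ig_summaries(alpha, beta):
    """Matching summaries of InverseGamma(α, β)."""
    return {
        "expected_precision": alpha / beta,
        "mean": beta / (alpha - 1.0),     # requires α > 1
        "mode": beta / (alpha + 1.0),
    }

df, psi = 5.0, 3.0
iw = iw_scalar_summaries(df, psi)
ig = ig_summaries(alpha=df / 2.0, beta=psi / 2.0)
matches = all(math.isclose(iw[name], ig[name]) for name in iw)
```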

batch_shape: Tuple[int, ...]
property df: Array

Degrees of freedom: df = -2*nat1 - k - 1.

property dim: int

Dimension k of the k x k matrices.

dnat1: Array
dnat2: Array
property entropy: Array

Entropy of IW(df, Psi).

event_shape: Tuple[int, ...]
property expected_psi: Array

E[Sigma^{-1}] = df * Psi^{-1}.

property expected_sufficient_statistics: tuple

E[T(Sigma)] = (E[log|Sigma|], E[Sigma^{-1}]).

property inv_scale: Array

Inverse scale matrix Psi^{-1}.

property kl_divergence_from_prior: Array

KL(posterior || prior) using stored prior natural parameters.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

Log base measure log(h(x)) with shape batch_shape.

property log_normalizer: Array

Log normalizer A(df, Psi).

log_prob(x: Array) Array[source]

Log probability of x under IW(df, Psi).

property mean: Array

Mean E[Sigma] = Psi / (df - k - 1), requires df > k + 1.

mf_expectations() dict[source]

Return expectations for mean-field coordinate ascent partner.

mf_update(stats: tuple, partner_expectations: dict) InverseWishart[source]

Mean-field coordinate ascent update for InverseWishart noise.

For IW noise in a linear model y = W x + e, e ~ N(0, Sigma):

df_post = df_prior + N

Psi_post = Psi_prior + E[sum_t (y_t - W x_t)(y_t - W x_t)^T]
         = Psi_prior + SyyT - SxyT^T E[W]^T - E[W] SxyT + E[W SxxT W^T]

The quadratic term E[W SxxT W^T] decomposes under mean-field q(W) = prod_d N(w_d; m_d, V_d). Its (i, j) entry is:

  • off-diagonal (i != j): m_i^T SxxT m_j (independent rows)

  • diagonal (i = j): tr(E[w_i w_i^T] SxxT) = m_i^T SxxT m_i + tr(V_i SxxT)
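The decomposition of the quadratic term can be written out for small matrices in plain Python. This is a sketch of the formula only, with hypothetical helper names and nested lists standing in for arrays; it is not the update code used by the package.

```python
def quad_form(a, s, b):
    """a^T S b for vectors a, b and a square matrix S."""
    n = len(s)
    return sum(a[p] * s[p][q] * b[q] for p in range(n) for q in range(n))

def trace_product(v, s):
    """tr(V S) for square matrices V and S."""
    n = len(s)
    return sum(v[p][q] * s[q][p] for p in range(n) for q in range(n))

def expected_w_sxx_wt(means, covs, sxx):
    """E[W SxxT W^T] when the rows w_d ~ N(m_d, V_d) are independent."""
    d = len(means)
    out = [[quad_form(means[i], sxx, means[j]) for j in range(d)] for i in range(d)]
    for i in range(d):
        out[i][i] += trace_product(covs[i], sxx)  # extra variance term on the diagonal
    return out

# Hypothetical example with two rows of dimension two.
means = [[1.0, 0.0], [0.0, 2.0]]
covs = [[[0.5, 0.0], [0.0, 0.5]], [[0.25, 0.0], [0.0, 0.25]]]
sxx = [[2.0, 1.0], [1.0, 3.0]]
quad = expected_w_sxx_wt(means, covs, sxx)
```

Only the diagonal entries pick up a trace term, because the covariance between distinct rows is zero under the mean-field factorization.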

mode() Array[source]

Mode = Psi / (df + k + 1).

property nat1: Array

First natural parameter: -(df + k + 1) / 2.

nat1_0: Array
property nat2: Array

Second natural parameter: -Psi / 2.

nat2_0: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = ()
property natural_parameters: tuple

Natural parameters (eta1, eta2).

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from IW(df, Psi) via Bartlett decomposition.

property scale: Array

Scale matrix Psi = -2 * nat2.

sufficient_statistics(x: Array) tuple[source]

T(Sigma) = (log|Sigma|, Sigma^{-1}).

class sppcax.distributions.MeanField(weights: Distribution, noise: Distribution)[source]

Bases: Distribution

Mean-field composite distribution: q(W, Sigma) = q(W) x q(Sigma).

Components are independent. Posterior updates use coordinate ascent (alternating updates of weights and noise components).

weights

Distribution over weight parameters (MVN or Delta if frozen).

Type

sppcax.distributions.base.Distribution

noise

Distribution over noise parameters (InverseGamma, Delta, etc.).

Type

sppcax.distributions.base.Distribution

n_iter

Number of coordinate ascent iterations for posterior updates.

property alpha: Array

Shape parameter from InverseGamma noise (MVNIG compat).

batch_shape: Tuple[int, ...]
property beta: Array

Scale parameter from InverseGamma noise (MVNIG compat).

property col_covariance: Array

Column covariance (base covariance, same as covariance).

property covariance: Array

Covariance of the weights component (base, not scaled by noise).

entropy() Array[source]

Entropy of the mean-field distribution (sum of component entropies).

event_shape: Tuple[int, ...]
property expected_covariance: Array

Expected covariance E[sigma^2] * base_covariance.

For InverseGamma noise: scalar E[sigma^2] per row. For InverseWishart noise: full matrix E[Sigma], not factored with weights cov. For Delta noise: fixed value.

property expected_psi: Array

Expected noise precision E[1/sigma^2] from noise component.

property expected_sufficient_statistics_psi: Array

Expected sufficient statistics of noise precision (MVNIG compat).

property inv_gamma

InverseGamma component (for MVNIG compatibility).

log_prob(x: Tuple[Array, Array]) Array[source]

Compute log probability.

Parameters

x – Tuple of (cov, w), where cov is the value of the sample covariance and w is the value of the sample state.

Returns

Log probability

property mask: Array

Mask from weights component (if available).

property mean: Array

Mean of the weights component.

mode() Tuple[Array, Array][source]

Compute joint mode (noise_cov_matrix, weights_mean).

Returns

Tuple of (noise covariance as matrix, weights mean).

property mvn: Distribution

Weights component (for ARD compatibility).

noise: Distribution
property precision: Array

Precision of the weights component.

sample(seed: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Tuple[Array, Array][source]

Sample from both components independently.

Parameters
  • seed – PRNG key.

  • sample_shape – Additional sample dimensions.

Returns

Tuple of (noise_sample, weights_sample).

weights: Distribution
class sppcax.distributions.MultivariateNormal(loc: Array, scale_tril: Optional[Array] = None, covariance: Optional[Array] = None, precision: Optional[Array] = None, mask: Optional[Array] = None)[source]

Bases: ExponentialFamily

Multivariate normal distribution in natural parameters.

apply_mask_matrix(x: Array, zeromask: bool = False) Array[source]

Apply mask to a matrix, zeroing out masked rows and columns.

Parameters
  • x – Matrix with shape (…, d, d)

  • zeromask – If True, set masked entries to 0. If False (default), set masked entries to identity matrix values.

Returns

Masked matrix with same shape.

apply_mask_vector(x: Array) Array[source]

Apply mask to a vector, zeroing out masked dimensions.

Parameters

x – Vector with shape (…, d)

Returns

Masked vector with same shape.

batch_shape: Tuple[int, ...]
property covariance: Array
dnat1: Float[Array, 'dim']
dnat2: Float[Array, 'rows cols']
event_shape: Tuple[int, ...]
property expected_second_moment: Array

Expected second moment E[xx^T] = Cov + mean @ mean^T, per row.

property expected_sufficient_statistics: Array

Compute E[T(x)] = [μ, vec(μμ^T + Σ)].

Returns

Expected sufficient statistics [E[x], vec(E[xx^T])].

classmethod from_natural_parameters(nat1: Array, nat2: Array, mask: Optional[Array] = None) MultivariateNormal[source]

Create MVN from natural parameters.

Parameters
  • nat1 – First natural parameter (precision * mean).

  • nat2 – Second natural parameter (-0.5 * precision).

  • mask – Optional boolean mask with shape matching nat1.

Returns

MultivariateNormal instance.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

Log base measure log(h(x)) with shape batch_shape.

property log_normalizer: Array

Compute log normalizer A(η).

Returns

Log normalizer A(η) with shape batch_shape.

mask: Float[Array, 'dim']
property mean: Array

Get mean parameter.

mf_expectations() dict[source]

Return expectations for mean-field coordinate ascent partner.

mf_update(stats: tuple, partner_expectations: dict) MultivariateNormal[source]

Mean-field coordinate ascent update for MVN weights.

Scales data sufficient statistics by partner’s expected precision E[1/sigma^2], then performs standard natural parameter update.

Parameters
  • stats – Sufficient statistics tuple (SxxT, SxyT, SyyT, N).

  • partner_expectations – Dict with ‘expected_precision’ from noise component. E[precision] can be scalar (isotropic), (D,) vector (per-feature IG), or (D, D) matrix (InverseWishart). For matrix precision, the diagonal is used for per-row weight updates.

Returns

Updated MVN distribution (posterior).

property nat1: Array

First natural parameter (precision * mean), masked.

nat1_0: Float[Array, 'dim']
property nat2: Array

Second natural parameter (-0.5 * precision).

nat2_0: Float[Array, 'rows cols']
natural_param_shape: ClassVar[Tuple[int, ...]] = (1,)
property natural_parameters: Array

Get natural parameters η = [precision*mean, -0.5*precision].

Returns

Natural parameters [η₁, vec(η₂)].

property precision: Array

Get precision parameter.

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples from the distribution.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [x, xx^T].

Parameters

x – Value to compute sufficient statistics for.

Returns

Sufficient statistics [x, vec(xx^T)].

class sppcax.distributions.MultivariateNormalInverseGamma(loc: Array, *, isotropic_noise: bool, mask: Optional[Array] = None, alpha0: float = 2.0, beta0: float = 1.0, scale_tril: Optional[Array] = None, covariance: Optional[Array] = None, precision: Optional[Array] = None)[source]

Bases: ExponentialFamily

Multivariate Normal-Inverse-Gamma distribution.

This distribution combines:

  • p(x|σ²) = N(x; μ, σ²Λ⁻¹) (multivariate normal)

  • p(σ²) = InverseGamma(σ²; α, β) (scalar variance)

where:

  • x is a vector (can be batched)

  • μ is the location parameter

  • Λ is a base precision matrix

  • σ² is a scalar variance parameter

  • α, β are the Inverse-Gamma distribution parameters
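The two-stage structure suggests ancestral sampling: draw σ² from the Inverse-Gamma, then draw x given σ². The sketch below uses only the standard library and assumes, for simplicity, a diagonal base precision Λ; the function name, the diagonal restriction, and the use of random.Random instead of a JAX PRNG key are all assumptions made for illustration.

```python
import math
import random

def sample_mvn_inverse_gamma(rng, loc, base_precision_diag, alpha, beta):
    """Draw (sig_sqr, value) with sig_sqr ~ IG(α, β) and value ~ N(loc, sig_sqr * Λ⁻¹)."""
    # The reciprocal of a Gamma(shape=α, rate=β) draw is InverseGamma(α, β).
    # random.gammavariate takes (shape, scale), so the scale is 1/β.
    sig_sqr = 1.0 / rng.gammavariate(alpha, 1.0 / beta)
    # With diagonal Λ each coordinate is independent with variance sig_sqr / Λ_ii.
    value = [m + math.sqrt(sig_sqr / lam) * rng.gauss(0.0, 1.0)
             for m, lam in zip(loc, base_precision_diag)]
    return sig_sqr, value

rng = random.Random(0)
sig_sqr, value = sample_mvn_inverse_gamma(rng, [0.0, 1.0], [1.0, 4.0], alpha=2.0, beta=1.0)
```

The return order (variance first, then the vector) mirrors the (sig_sqr, value) tuple documented for this class's sample method.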

property alpha: Array

Shape parameter α of the InverseGamma component.

batch_shape: Shape
property beta: Array

Scale parameter β of the InverseGamma component.

property col_covariance: Array

Column covariance Λ⁻¹ (alias for covariance).

property covariance: Array

Base covariance Λ⁻¹ of the MVN component (without noise scaling).

event_shape: Shape
property expected_covariance: Array

Expected covariance E[σ²] * Λ⁻¹.

property expected_log_psi: Array

Compute expected log precision E[log(psi)].

property expected_psi: Array

Compute expected precision E[psi].

property expected_sufficient_statistics_psi: Array

Compute expected sufficient statistics of psi.

inv_gamma: InverseGamma
log_prob(x: Tuple[Array, Array]) Array[source]

Compute log probability.

Parameters

x – Tuple of (sig_sqr, w), where sig_sqr is the value of the sample variance and w is the value of the sample state.

Returns

Log probability

property mean: Array

Get mean of the marginal distribution p(x).

mode() Tuple[Array, Array][source]

Solve for the mode. Recall that

    p(x, \sigma^2) \propto \mathrm{N}(x \mid \mu, \sigma^2 \Lambda^{-1}) \times \mathrm{IG}(\sigma^2 \mid \alpha, \beta)

The optimal value of x is x^* = \mu. Substituting this in,

    p(x^*, \sigma^2) \propto \mathrm{IG}(\sigma^2 \mid \alpha + D/2, \beta)

where D is the dimensionality of x, and the mode of this inverse gamma distribution is at

    (\sigma^2)^* = \beta / (\alpha + (D + 2)/2)

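The closed-form mode can be checked numerically. At the optimal mean the normal factor contributes (σ²)^(-D/2), so the joint is proportional to (σ²)^(-(α + D/2 + 1)) exp(-β/σ²). The sketch below maximizes that expression on a grid and compares it with β / (α + (D + 2)/2); the helper names and the grid are hypothetical.

```python
import math

def log_joint_at_mean(sig_sqr, alpha, beta, dim):
    """log of (σ²)^(-D/2) · IG(σ² | α, β), up to an additive constant."""
    return -(dim / 2.0 + alpha + 1.0) * math.log(sig_sqr) - beta / sig_sqr

def closed_form_mode(alpha, beta, dim):
    return beta / (alpha + (dim + 2) / 2.0)

alpha, beta, dim = 2.0, 1.0, 3
grid = [0.001 * i for i in range(1, 5001)]
grid_mode = max(grid, key=lambda s: log_joint_at_mean(s, alpha, beta, dim))
exact_mode = closed_form_mode(alpha, beta, dim)
```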
mvn: MultivariateNormal
property precision: Array

Base precision matrix Λ of the MVN component.

sample(seed: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Tuple[Array, Array][source]

Sample from the distribution.

Parameters
  • seed – PRNG key.

  • sample_shape – Shape of samples to draw.

Returns

Tuple of (sig_sqr, value) samples.

class sppcax.distributions.Normal(loc: Array = 0.0, scale: Array = 1.0)[source]

Bases: ExponentialFamily

Univariate normal distribution in natural parameters.

The normal distribution has density: p(x|μ,σ) = 1/√(2πσ²) * exp(-(x-μ)²/(2σ²))

In exponential family form:

  • η = [μ/σ², -1/(2σ²)]

  • T(x) = [x, x²]

  • A(η) = -η₁²/(4η₂) - (1/2)log(-2η₂) + (1/2)log(2π)

batch_shape: Tuple[int, ...]
property entropy: Array

Compute entropy of the distribution.

Returns

Entropy with shape batch_shape.

event_shape: Tuple[int, ...]
property expected_sufficient_statistics: Array

Compute E[T(x)] = [μ, μ² + σ²].

Returns

Expected sufficient statistics [E[x], E[x²]] with shape batch_shape + (2,).

classmethod from_natural_parameters(eta: Array) Normal[source]

Create normal distribution from natural parameters.

Parameters

eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)

Returns

Normal distribution.

property loc: Array

Get location parameter.

log_base_measure(x: Array = None) Array[source]

Compute log of base measure h(x).

Parameters

x – Data to compute base measure for. Shape: batch_shape + event_shape

Returns

Log base measure log(h(x)) with shape batch_shape.

property log_normalizer: Array

Compute log normalizer A(η).

Returns

Log normalizer with shape batch_shape.

nat1: Array
nat2: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = (2,)
property natural_parameters: Array

Get natural parameters η = [precision*mean, -0.5*precision].

Returns

Natural parameters [η₁, η₂] with shape batch_shape + (2,).

property precision: Array

Get precision parameter 1/σ².

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]

Sample from the distribution.

Parameters
  • key – PRNG key for random sampling.

  • sample_shape – Shape of samples to draw.

Returns

Samples with shape sample_shape + batch_shape + event_shape.
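
The shape convention can be illustrated with plain `jax.random`. The sketch below is an assumption about how such a sampler might be written, not the library's code; it treats the univariate case as having an empty `event_shape`.

```python
import jax
import jax.numpy as jnp

loc = jnp.array([0.0, 1.0, -2.0])    # batch_shape = (3,)
scale = jnp.array([1.0, 0.5, 2.0])
sample_shape = (4, 5)

key = jax.random.PRNGKey(0)
# Reparameterized draw: standard normal noise, then shift and scale.
noise = jax.random.normal(key, sample_shape + loc.shape)
samples = loc + scale * noise        # shape: sample_shape + batch_shape = (4, 5, 3)
```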

property scale: Array

Get scale parameter.

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = [x, x²].

Parameters

x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape

Returns

Sufficient statistics [x, x²] with shape batch_shape + (2,).

class sppcax.distributions.Poisson(log_rate: Array)[source]

Bases: ExponentialFamily

Poisson distribution parameterized by log rate.

The Poisson distribution has natural parameter η = log(λ) where λ is the rate parameter, and sufficient statistic T(x) = x.
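
Combined with the base measure h(x) = 1/x! and log normalizer A(η) = exp(η) documented for this class, the log probability is η·x - exp(η) - log(x!). The sketch below evaluates this in plain JAX and compares it with `jax.scipy.stats.poisson.logpmf`; `poisson_log_prob` is an illustrative helper, not a method of the class.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson

def poisson_log_prob(x, log_rate):
    """log p(x) = eta * T(x) - A(eta) + log h(x), with eta = log(lambda)."""
    log_base_measure = -gammaln(x + 1.0)  # log h(x) = -log(x!)
    log_normalizer = jnp.exp(log_rate)    # A(eta) = exp(eta) = lambda
    return log_rate * x - log_normalizer + log_base_measure

counts = jnp.array([0.0, 1.0, 4.0])
log_rate = jnp.log(2.5)

lp = poisson_log_prob(counts, log_rate)
reference = poisson.logpmf(counts, mu=2.5)  # should agree with lp
```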

batch_shape: Tuple[int, ...]
property entropy: Array

Compute entropy of Poisson distribution.

Returns

Entropy H(λ) = λ(1 - log(λ)) + exp(-λ) Σ_{k=0}^∞ λ^k log(k!)/k!
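
The infinite series has no closed form, so in practice it must be approximated. The sketch below truncates the sum at a fixed number of terms; the truncation point `max_count` and the function name `poisson_entropy` are assumptions made for illustration, and the library may approximate the series differently.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def poisson_entropy(rate, max_count=200):
    """Approximate H(lambda) by truncating the series after max_count terms."""
    k = jnp.arange(max_count, dtype=jnp.float32)
    log_factorial = gammaln(k + 1.0)  # log(k!)
    # log of exp(-lambda) * lambda^k / k!, kept in log space for stability.
    log_pmf = k * jnp.log(rate) - rate - log_factorial
    series = jnp.sum(jnp.exp(log_pmf) * log_factorial)
    return rate * (1.0 - jnp.log(rate)) + series

entropy = poisson_entropy(3.0)
```

The truncation is accurate when `max_count` is well above the rate; larger rates need more terms.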

event_shape: Tuple[int, ...]
property expected_sufficient_statistics: Array

Compute E[T(x)] = E[x] = λ = exp(η).

Returns

Expected sufficient statistics E[x].

classmethod from_natural_parameters(eta: Array) Poisson[source]

Create Poisson from natural parameters.

Parameters

eta – Natural parameter η = log(λ).

Returns

Poisson instance.

log_base_measure(x: Array) Array[source]

Compute log base measure log(h(x)) = -log(x!).

Parameters

x – Count data.

Returns

Log base measure -log(x!).

property log_normalizer: Array

Compute log normalizer A(η) = exp(η).

Returns

Log normalizer A(η).

property log_rate: Array

Get log(rate) parameter.

nat1: Array
natural_param_shape: ClassVar[Tuple[int, ...]] = ()
property natural_parameters: Array

Get natural parameters (log rate).

Returns

Natural parameters η = log(λ).

property rate: Array

Get rate parameter.

sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Optional[Tuple[int, ...]] = ()) Array[source]

Sample from Poisson distribution.

Parameters
  • key – PRNG key.

  • sample_shape – Shape of samples to draw.

Returns

Count samples from distribution.
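
For reference, drawing counts from a log-rate parameterization with plain `jax.random` might look like the sketch below. This illustrates the shape convention only and is not the library's implementation; the rates are arbitrary.

```python
import jax
import jax.numpy as jnp

log_rate = jnp.log(jnp.array([0.5, 4.0]))  # batch_shape = (2,)
sample_shape = (1000,)

key = jax.random.PRNGKey(0)
# Convert the natural parameter back to a rate before sampling.
samples = jax.random.poisson(key, jnp.exp(log_rate), shape=sample_shape + log_rate.shape)
```

The result is an integer array of shape `sample_shape + batch_shape`.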

sufficient_statistics(x: Array) Array[source]

Compute sufficient statistics T(x) = x.

Parameters

x – Count data.

Returns

Sufficient statistics T(x) = x.