sppcax.distributions package
Submodules
sppcax.distributions.base module
Base distribution class.
- class sppcax.distributions.base.Distribution(batch_shape: Tuple[int, ...], event_shape: Tuple[int, ...])[source]
Bases: Module
Base distribution class in natural parameters.
- broadcast_to_shape(x: Array, ignore_event: bool = False) Array[source]
Broadcast array to match distribution shape.
- Parameters
x – Array to broadcast.
ignore_event – If True, only broadcast batch dimensions.
- Returns
Broadcasted array.
- entropy() Array[source]
Compute entropy of the distribution.
- Returns
Entropy with shape batch_shape.
- log_prob(x: Array) Array[source]
Compute log probability of x.
- Parameters
x – Value to compute log probability for. Should have shape: batch_shape + event_shape
- Returns
Log probability with shape batch_shape.
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Optional[Tuple[int, ...]] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Additional sample dimensions.
- Returns
Samples with shape sample_shape + batch_shape + event_shape.
sppcax.distributions.beta module
Beta distribution implementation.
- class sppcax.distributions.beta.Beta(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]
Bases: ExponentialFamily
Beta distribution in natural parameters.
The beta distribution has density: p(x|α,β) = x^(α-1) * (1-x)^(β-1) / B(α,β) for x ∈ [0,1]
In exponential family form:
h(x) = 1
η = [α-1, β-1]
T(x) = [log(x), log(1-x)]
A(η) = log(B(η₁+1, η₂+1))
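The exponential family form can be checked numerically. The following is a minimal sketch using only the Python standard library, not the package's implementation; the function names `beta_log_normalizer` and `beta_log_prob` are hypothetical and chosen for illustration.

```python
import math


def beta_log_normalizer(eta1: float, eta2: float) -> float:
    """A(eta) = log B(eta1 + 1, eta2 + 1), written with log-gamma functions."""
    a, b = eta1 + 1.0, eta2 + 1.0
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_log_prob(x: float, alpha: float, beta: float) -> float:
    """log p(x) = eta . T(x) - A(eta), using h(x) = 1."""
    eta = (alpha - 1.0, beta - 1.0)
    suff = (math.log(x), math.log1p(-x))  # T(x) = [log(x), log(1 - x)]
    return eta[0] * suff[0] + eta[1] * suff[1] - beta_log_normalizer(*eta)


# Beta(2, 3) has density 12 * x * (1 - x)^2, so both routes should agree.
x = 0.3
direct = math.log(12.0 * x * (1.0 - x) ** 2)
via_natural = beta_log_prob(x, alpha=2.0, beta=3.0)
```

Both values are the same log density; the natural-parameter route only needs η, T(x) and A(η).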
- property alpha: Array
Get first shape parameter α.
- property beta: Array
Get second shape parameter β.
- dnat1: Array
- dnat2: Array
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [ψ(α) - ψ(α+β), ψ(β) - ψ(α+β)].
- Returns
Expected sufficient statistics [E[log(x)], E[log(1-x)]] with shape batch_shape + (2,).
- classmethod from_natural_parameters(eta: Array) Beta[source]
Create beta distribution from natural parameters.
- Parameters
eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
- Returns
Beta distribution.
- property kl_divergence_from_prior: Array
Compute KL divergence KL(post||prior).
- Returns
KL divergence KL(post||prior) with shape batch_shape.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
zero
- property log_normalizer: Array
Compute log normalizer A(η) = log(B(η₁+1, η₂+1)).
- Returns
Log normalizer with shape batch_shape.
- property mean: Array
Get mean of the distribution.
- property nat1: Array
First natural parameter η₁ = α - 1.
- nat1_0: Array
- property nat2: Array
Second natural parameter η₂ = β - 1.
- nat2_0: Array
- property natural_parameters: Array
Get natural parameters η = [α-1, β-1].
- Returns
Natural parameters [η₁, η₂] with shape batch_shape + (2,).
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape.
- sufficient_statistics(x: Array) Array[source]
Compute sufficient statistics T(x) = [log(x), log(1-x)].
- Parameters
x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape
- Returns
Sufficient statistics [log(x), log(1-x)] with shape batch_shape + (2,).
- property variance: Array
Get variance of the distribution.
sppcax.distributions.categorical module
Categorical distribution in natural parameterization.
- class sppcax.distributions.categorical.Categorical(logits: Array)[source]
Bases: ExponentialFamily
Categorical distribution parameterized by logits.
The Categorical distribution has K-1 natural parameters η_k = log(p_k/p_K) for k=1,…,K-1 where p_K is the probability of the last category.
- property expected_sufficient_statistics: Array
Compute E[T(x)] = p_1,…,p_{K-1}.
- Returns
Expected probabilities for first K-1 categories.
- classmethod from_natural_parameters(eta: Array) Categorical[source]
Create Categorical from natural parameters.
- Parameters
eta – Log-odds relative to last category. Shape (…, K-1) where K is number of categories.
- Returns
Categorical instance.
- property full_logits: Array
- property log_normalizer: Array
Compute log normalizer A(η).
For categorical, A(η) = log(1 + sum(exp(η_k))).
- Returns
Log normalizer A(η).
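As an illustration of the log normalizer, the sketch below maps K-1 natural parameters to K probabilities. It is a standard-library sketch, not the package's code; `categorical_probs` is a hypothetical name, and treating the last category as having logit 0 is an assumption that follows from η_k = log(p_k/p_K).

```python
import math


def categorical_probs(eta: list[float]) -> list[float]:
    """Map K-1 natural parameters to the K category probabilities."""
    # A(eta) = log(1 + sum_k exp(eta_k)); the 1 is exp(0) for the last category.
    shift = max(0.0, max(eta))  # subtract the largest logit for stability
    log_norm = shift + math.log(
        math.exp(-shift) + sum(math.exp(e - shift) for e in eta)
    )
    logits_with_reference = list(eta) + [0.0]  # reference category gets logit 0
    return [math.exp(logit - log_norm) for logit in logits_with_reference]


# Log-odds log(2) and log(3) against the last category give odds 2 : 3 : 1.
probs = categorical_probs([math.log(2.0), math.log(3.0)])
```

The first K-1 entries of `probs` are the expected sufficient statistics p_1,…,p_{K-1}.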
- nat1: Array
- property natural_parameters: Array
Get natural parameters (logits).
- Returns
Natural parameters η.
- property probs: Array
Get probability parameters.
sppcax.distributions.delta module
Delta distribution implementation.
- class sppcax.distributions.delta.Delta(location: Array, sufficient_statistics_fn: Optional[Callable] = None)[source]
Bases: Distribution
Delta distribution (Dirac delta) concentrated at a single point.
- property covariance: Array
Covariance matrix (always zero for delta distribution).
- entropy() Array[source]
Compute entropy (always 0 for delta distribution).
- Returns
Entropy with shape batch_shape.
- property expected_psi: Array
Expected precision for Delta: matrix inverse for square matrices, element-wise otherwise.
- property expected_second_moment: Array
Expected second moment E[XX^T] = location @ location^T (no variance).
- property expected_sufficient_statistics: Array
Compute expected sufficient statistics.
For delta distribution, this is just sufficient_statistics(location) since all probability mass is concentrated at location.
- Returns
Expected sufficient statistics with shape determined by sufficient_statistics_fn.
- log_prob(x: Array) Array[source]
Compute log probability.
- Parameters
x – Value to compute log probability for. Shape: batch_shape + event_shape
- Returns
Log probability with shape batch_shape. Returns 0 at location, -inf elsewhere.
- mean: Array
- property precision: Array
Precision (infinite for delta, but return large finite value for compatibility).
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution (always returns location).
- Parameters
key – PRNG key (unused).
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape. All samples are equal to location.
sppcax.distributions.exponential_family module
Base class for exponential family distributions.
- class sppcax.distributions.exponential_family.ExponentialFamily(batch_shape: Tuple[int, ...] = (), event_shape: Tuple[int, ...] = ())[source]
Bases: Distribution
Base class for exponential family distributions in natural parameterization.
The exponential family has the form: p(x|η) = h(x) exp(η^T T(x) - A(η)), where:
- η: natural parameters
- T(x): sufficient statistics
- A(η): log normalizer
- h(x): base measure
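The general form translates directly into a log-probability routine. The sketch below is a standard-library illustration, not the class's actual method; `exp_family_log_prob` and its callable arguments are hypothetical. It is instantiated with the Poisson case documented in this package (η = log(λ), T(x) = x, A(η) = exp(η), log h(x) = -log(x!)).

```python
import math
from typing import Callable, Sequence


def exp_family_log_prob(
    x: float,
    eta: Sequence[float],
    suff_stats: Callable[[float], Sequence[float]],
    log_normalizer: Callable[[Sequence[float]], float],
    log_base_measure: Callable[[float], float],
) -> float:
    """log p(x | eta) = log h(x) + eta . T(x) - A(eta)."""
    inner = sum(e * t for e, t in zip(eta, suff_stats(x)))
    return log_base_measure(x) + inner - log_normalizer(eta)


# Poisson with rate 4, evaluated at the count x = 2.
rate = 4.0
log_p = exp_family_log_prob(
    x=2.0,
    eta=[math.log(rate)],
    suff_stats=lambda x: [x],
    log_normalizer=lambda eta: math.exp(eta[0]),
    log_base_measure=lambda x: -math.lgamma(x + 1.0),  # -log(x!)
)
```

The result matches the usual Poisson log mass x·log(λ) - λ - log(x!).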
- property entropy: Array
Compute entropy of the distribution.
- Returns
Entropy with shape batch_shape.
- property expected_log_base_measure: Array
Compute the expectation of the log base measure E_{p(x)}[log(h(x))]
- Returns
Expectation E_{p(x)}[log(h(x))] with shape batch_shape.
- property expected_sufficient_statistics: Array
Compute expected sufficient statistics E[T(x)].
- Returns
Expected sufficient statistics E[T(x)] with shape batch_shape + natural_param_shape.
- classmethod from_natural_parameters(eta: Array) ExponentialFamily[source]
Create distribution from natural parameters.
- Parameters
eta – Natural parameters with shape: batch_shape + natural_param_shape
- Returns
Distribution instance.
- kl_divergence(other: ExponentialFamily) Array[source]
Compute KL divergence KL(self||other).
- Parameters
other – Other distribution to compute KL divergence with.
- Returns
KL divergence KL(self||other) with shape batch_shape.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
Log base measure log(h(x)) with shape batch_shape.
- property log_normalizer: Array
Compute log normalizer A(η).
- Returns
Log normalizer A(η) with shape batch_shape.
- log_prob(x: Array) Array[source]
Compute log probability.
- Parameters
x – Data to compute log probability for. Shape: batch_shape + event_shape
- Returns
Log probability log p(x|η) with shape batch_shape. Returns -inf for values outside the support.
- property natural_parameters: Array
Get natural parameters of the distribution.
- Returns
Natural parameters η with shape batch_shape + natural_param_shape.
sppcax.distributions.gamma module
Gamma distribution implementation.
- class sppcax.distributions.gamma.Gamma(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]
Bases: ExponentialFamily
Gamma distribution in natural parameters.
The gamma distribution has density: p(x|α,β) = β^α * x^(α-1) * exp(-βx) / Γ(α)
In exponential family form:
h(x) = 1
η = [α-1, -β]
T(x) = [log(x), x]
A(η) = log(Γ(η₁ + 1)) - (η₁ + 1)*log(-η₂)
- property alpha: Array
Get shape parameter α.
- property beta: Array
Get rate parameter β.
- dnat1: Array
- dnat2: Array
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [ψ(α) - log(β), α/β].
- Returns
Expected sufficient statistics [E[log(x)], E[x]] with shape batch_shape + (2,).
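In an exponential family the expected sufficient statistics are the gradient of the log normalizer, so E[x] = α/β should equal ∂A/∂η₂. The sketch below checks this with a central finite difference; it is a standard-library illustration with hypothetical helper names, not how the package computes the expectation.

```python
import math


def gamma_log_normalizer(eta1: float, eta2: float) -> float:
    """A(eta) = log Gamma(eta1 + 1) - (eta1 + 1) * log(-eta2)."""
    return math.lgamma(eta1 + 1.0) - (eta1 + 1.0) * math.log(-eta2)


def central_difference(f, x: float, h: float = 1e-6) -> float:
    """Second-order finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


alpha, beta = 3.0, 2.0
eta1, eta2 = alpha - 1.0, -beta  # natural parameters [alpha - 1, -beta]

# dA/d(eta2) approximates E[x]; the closed form is alpha / beta.
expected_x = central_difference(lambda e2: gamma_log_normalizer(eta1, e2), eta2)
closed_form = alpha / beta
```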
- classmethod from_natural_parameters(eta: Array) Gamma[source]
Create gamma distribution from natural parameters.
- Parameters
eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
- Returns
Gamma distribution.
- property kl_divergence_from_prior: Array
Compute KL divergence KL(post||prior).
- Returns
KL divergence KL(post||prior) with shape batch_shape.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
zero
- property log_normalizer: Array
Compute log normalizer A(η) = log(Γ(η₁ + 1)) - (η₁ + 1)*log(-η₂).
- Returns
Log normalizer with shape batch_shape.
- property mean: Array
Get mean E[x] = α/β.
- property nat1: Array
First natural parameter η₁ = α - 1.
- nat1_0: Array
- property nat2: Array
Second natural parameter η₂ = -β.
- nat2_0: Array
- property natural_parameters: Array
Get natural parameters η = [α-1, -β].
- Returns
Natural parameters [η₁, η₂] with shape batch_shape + (2,).
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape.
- class sppcax.distributions.gamma.InverseGamma(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]
Bases: ExponentialFamily
Inverse Gamma distribution in natural parameters.
The inverse gamma distribution has density: p(x|α,β) = β^α * x^(-α-1) * exp(-β/x) / Γ(α)
In exponential family form:
h(x) = 1
η = [-α-1, -β]
T(x) = [log(x), 1/x]
A(η) = log(Γ(-η₁ - 1)) + (η₁ + 1)*log(-η₂)
- property alpha: Array
Get shape parameter α.
- property beta: Array
Get scale parameter β.
- dnat1: Array
- dnat2: Array
- property expected_psi: Array
Expected precision E[1/x] = alpha/beta for InverseGamma.
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [log(β) - ψ(α), α/β].
- Returns
Expected sufficient statistics [E[log(x)], E[1/x]] with shape batch_shape + (2,).
- classmethod from_natural_parameters(eta: Array) InverseGamma[source]
Create inverse gamma distribution from natural parameters.
- Parameters
eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
- Returns
InverseGamma distribution.
- property kl_divergence_from_prior: Array
Compute KL divergence KL(post||prior).
- Returns
KL divergence KL(post||prior) with shape batch_shape.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
zero
- property log_normalizer: Array
Compute log normalizer A(η) = log(Γ(-η₁ - 1)) + (η₁ + 1)*log(-η₂)
- Returns
Log normalizer with shape batch_shape.
- property mean: Array
Get mean E[x] = β/(α-1), defined for α > 1.
- mf_update(stats: tuple, partner_expectations: dict) InverseGamma[source]
Mean-field coordinate ascent update for InverseGamma noise.
Given partner (weights) expectations, compute the residual and update the InverseGamma natural parameters.
- Parameters
stats – Sufficient statistics tuple (SxxT, SxyT, SyyT, N).
partner_expectations – Dict with ‘mean’ and ‘second_moment’ from weights.
- Returns
Updated InverseGamma distribution.
- property nat1: Array
First natural parameter η₁ = -α - 1.
- nat1_0: Array
- property nat2: Array
Second natural parameter η₂ = -β.
- nat2_0: Array
- property natural_parameters: Array
Get natural parameters η = [-α-1, -β].
- Returns
Natural parameters [η₁, η₂] with shape batch_shape + (2,).
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape.
sppcax.distributions.inverse_wishart module
Inverse Wishart distribution with mean-field interface.
Wraps the dynamax InverseWishart with natural parameters and coordinate ascent update methods for use in MeanField composites.
- class sppcax.distributions.inverse_wishart.InverseWishart(df0: float | jax.Array, scale0: Array)[source]
Bases: ExponentialFamily
Inverse Wishart distribution in natural parameters.
- For Sigma ~ IW(df, Psi):
p(Sigma | df, Psi) ∝ |Sigma|^{-(df+k+1)/2} exp(-tr(Psi Sigma^{-1})/2)
- Exponential family form:
eta1 = -(df + k + 1) / 2 (scalar)
eta2 = -Psi / 2 (k x k matrix)
T(Sigma) = [log|Sigma|, Sigma^{-1}]
A(eta) = -(df/2) log|Psi/2| + multigammaln(df/2, k)
- Key moments:
E[Sigma^{-1}] = df * Psi^{-1}
E[log|Sigma^{-1}|] = multidigamma(df/2, k) + k*log(2) - log|Psi|
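A quick sanity check on the parameterization: for k = 1 the Inverse Wishart IW(df, Psi) coincides with an inverse gamma with α = df/2 and β = Psi/2, so the natural parameters and the expected precision should match those of InverseGamma. The sketch below uses plain Python floats and hypothetical helper names; it does not call the package.

```python
def iw_natural_parameters_1d(df: float, psi: float) -> tuple[float, float]:
    """Natural parameters of IW(df, Psi) when k = 1 (Psi is a scalar)."""
    k = 1
    return (-(df + k + 1) / 2.0, -psi / 2.0)


def inverse_gamma_natural_parameters(alpha: float, beta: float) -> tuple[float, float]:
    """Natural parameters [-alpha - 1, -beta] of InverseGamma(alpha, beta)."""
    return (-alpha - 1.0, -beta)


df, psi = 5.0, 3.0
iw_eta = iw_natural_parameters_1d(df, psi)
ig_eta = inverse_gamma_natural_parameters(df / 2.0, psi / 2.0)

# E[Sigma^{-1}] = df * Psi^{-1} versus E[1/x] = alpha / beta.
expected_precision_iw = df / psi
expected_precision_ig = (df / 2.0) / (psi / 2.0)
```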
- property df: Array
Degrees of freedom: df = -2*nat1 - k - 1.
- dnat1: Array
- dnat2: Array
- property entropy: Array
Entropy of IW(df, Psi).
- property expected_psi: Array
E[Sigma^{-1}] = df * Psi^{-1}.
- property inv_scale: Array
Inverse scale matrix Psi^{-1}.
- property kl_divergence_from_prior: Array
KL(posterior || prior) using stored prior natural parameters.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
Log base measure log(h(x)) with shape batch_shape.
- property log_normalizer: Array
Log normalizer A(df, Psi).
- property mean: Array
Mean E[Sigma] = Psi / (df - k - 1), requires df > k + 1.
- mf_update(stats: tuple, partner_expectations: dict) InverseWishart[source]
Mean-field coordinate ascent update for InverseWishart noise.
- For IW noise in a linear model y = W x + e, e ~ N(0, Sigma):
df_post = df_prior + N
Psi_post = Psi_prior + E[sum_t (y_t - W x_t)(y_t - W x_t)^T]
= Psi_prior + SyyT - SxyT^T E[W]^T - E[W] SxyT + E[W SxxT W^T]
- The quadratic term E[W SxxT W^T] decomposes under mean-field q(W) = prod_d N(w_d; m_d, V_d):
off-diag (i != j): m_i^T SxxT m_j (independent rows)
diag (i = j): tr(E[w_i w_i^T] SxxT) = m_i^T SxxT m_i + tr(V_i SxxT)
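In the scalar case (one output, one input, w ~ N(m, v)) the update reduces to a few sums. The sketch below verifies that the sufficient-statistics form of the expected residual matches the per-sample expectation. It is an illustration under that scalar assumption with hypothetical function names, not the `mf_update` implementation.

```python
def expected_residual_direct(xs, ys, m: float, v: float) -> float:
    """Sum over t of E[(y_t - w x_t)^2] for w ~ N(m, v)."""
    # E[w] = m and E[w^2] = m^2 + v, applied sample by sample.
    return sum(y * y - 2.0 * m * x * y + (m * m + v) * x * x for x, y in zip(xs, ys))


def expected_residual_from_stats(sxx: float, sxy: float, syy: float, m: float, v: float) -> float:
    """Same quantity from the sufficient statistics (Sxx, Sxy, Syy)."""
    return syy - 2.0 * m * sxy + (m * m + v) * sxx


xs = [1.0, 2.0, -1.0]
ys = [0.5, 1.5, -1.0]
sxx = sum(x * x for x in xs)
sxy = sum(x * y for x, y in zip(xs, ys))
syy = sum(y * y for y in ys)

m, v = 0.8, 0.1                # mean and variance of q(w)
df_prior, psi_prior = 3.0, 1.0

df_post = df_prior + len(xs)   # df_post = df_prior + N
psi_post = psi_prior + expected_residual_from_stats(sxx, sxy, syy, m, v)
```

The `v * sxx` term is the scalar counterpart of tr(V_i SxxT): uncertainty in the weights inflates the expected residual.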
- property nat1: Array
First natural parameter: -(df + k + 1) / 2.
- nat1_0: Array
- property nat2: Array
Second natural parameter: -Psi / 2.
- nat2_0: Array
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from IW(df, Psi) via Bartlett decomposition.
- property scale: Array
Scale matrix Psi = -2 * nat2.
sppcax.distributions.mean_field module
Mean-field composite distribution with independent components.
- class sppcax.distributions.mean_field.MeanField(weights: Distribution, noise: Distribution)[source]
Bases: Distribution
Mean-field composite distribution: q(W, Sigma) = q(W) x q(Sigma).
Components are independent. Posterior updates use coordinate ascent (alternating updates of weights and noise components).
- weights
Distribution over weight parameters (MVN or Delta if frozen).
- noise
Distribution over noise parameters (InverseGamma, Delta, etc.).
- n_iter
Number of coordinate ascent iterations for posterior updates.
- property alpha: Array
Shape parameter from InverseGamma noise (MVNIG compat).
- property beta: Array
Scale parameter from InverseGamma noise (MVNIG compat).
- property col_covariance: Array
Column covariance (base covariance, same as covariance).
- property covariance: Array
Covariance of the weights component (base, not scaled by noise).
- property expected_covariance: Array
Expected covariance E[sigma^2] * base_covariance.
For InverseGamma noise: scalar E[sigma^2] per row. For InverseWishart noise: full matrix E[Sigma], not factored with weights cov. For Delta noise: fixed value.
- property expected_psi: Array
Expected noise precision E[1/sigma^2] from noise component.
- property expected_sufficient_statistics_psi: Array
Expected sufficient statistics of noise precision (MVNIG compat).
- property inv_gamma
InverseGamma component (for MVNIG compatibility).
- log_prob(x: Tuple[Array, Array]) Array[source]
Compute log probability.
- Parameters
x – Tuple of (cov, w), where cov is the value of the sample covariance and w is the value of the sample state.
- Returns
Log probability
- property mask: Array
Mask from weights component (if available).
- property mean: Array
Mean of the weights component.
- mode() Tuple[Array, Array][source]
Compute joint mode (noise_cov_matrix, weights_mean).
- Returns
Tuple of (noise covariance as matrix, weights mean).
- property mvn: Distribution
Weights component (for ARD compatibility).
- noise: Distribution
- property precision: Array
Precision of the weights component.
- sample(seed: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Tuple[Array, Array][source]
Sample from both components independently.
- Parameters
seed – PRNG key.
sample_shape – Additional sample dimensions.
- Returns
Tuple of (noise_sample, weights_sample).
- weights: Distribution
sppcax.distributions.mvn module
Multivariate normal distribution implementation.
- class sppcax.distributions.mvn.MultivariateNormal(loc: Array, scale_tril: Optional[Array] = None, covariance: Optional[Array] = None, precision: Optional[Array] = None, mask: Optional[Array] = None)[source]
Bases: ExponentialFamily
Multivariate normal distribution in natural parameters.
- apply_mask_matrix(x: Array, zeromask: bool = False) Array[source]
Apply mask to a matrix, zeroing out masked rows and columns.
- Parameters
x – Matrix with shape (…, d, d)
zeromask – If True, set masked entries to 0. If False (default), set masked entries to identity matrix values.
- Returns
Masked matrix with same shape.
- apply_mask_vector(x: Array) Array[source]
Apply mask to a vector, zeroing out masked dimensions.
- Parameters
x – Vector with shape (…, d)
- Returns
Masked vector with same shape
- property covariance: Array
- dnat1: Float[Array, 'dim']
- dnat2: Float[Array, 'rows cols']
- property expected_second_moment: Array
Expected second moment E[xx^T] = Cov + mean @ mean^T, per row.
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [μ, vec(μμ^T + Σ)].
- Returns
Expected sufficient statistics [E[x], vec(E[xx^T])].
- classmethod from_natural_parameters(nat1: Array, nat2: Array, mask: Optional[Array] = None) MultivariateNormal[source]
Create MVN from natural parameters.
- Parameters
nat1 – First natural parameter (precision * mean).
nat2 – Second natural parameter (-0.5 * precision).
mask – Optional boolean mask with shape matching nat1
- Returns
MultivariateNormal instance.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
Log base measure log(h(x)) with shape batch_shape.
- property log_normalizer: Array
Compute log normalizer A(η).
- Returns
Log normalizer A(η) with shape batch_shape.
- mask: Float[Array, 'dim']
- property mean: Array
Get mean parameter.
- mf_update(stats: tuple, partner_expectations: dict) MultivariateNormal[source]
Mean-field coordinate ascent update for MVN weights.
Scales data sufficient statistics by partner’s expected precision E[1/sigma^2], then performs standard natural parameter update.
- Parameters
stats – Sufficient statistics tuple (SxxT, SxyT, SyyT, N).
partner_expectations – Dict with ‘expected_precision’ from noise component. E[precision] can be scalar (isotropic), (D,) vector (per-feature IG), or (D, D) matrix (InverseWishart). For matrix precision, the diagonal is used for per-row weight updates.
- Returns
Updated MVN distribution (posterior).
- property nat1: Array
First natural parameter (precision * mean), masked.
- nat1_0: Float[Array, 'dim']
- property nat2: Array
Second natural parameter (-0.5 * precision).
- nat2_0: Float[Array, 'rows cols']
- property natural_parameters: Array
Get natural parameters η = [precision*mean, -0.5*precision].
- Returns
Natural parameters [η₁, vec(η₂)].
- property precision: Array
Get precision parameter.
sppcax.distributions.mvn_gamma module
Multivariate Normal-Gamma distribution implementation.
- class sppcax.distributions.mvn_gamma.MultivariateNormalInverseGamma(loc: Array, *, isotropic_noise: bool, mask: Optional[Array] = None, alpha0: float = 2.0, beta0: float = 1.0, scale_tril: Optional[Array] = None, covariance: Optional[Array] = None, precision: Optional[Array] = None)[source]
Bases: ExponentialFamily
Multivariate Normal-Inverse-Gamma distribution.
This distribution combines:
p(x|σ²) = N(x; μ, σ²Λ⁻¹) # Multivariate Normal
p(σ²) = InverseGamma(σ²; α, β) # Variance scalar
where:
- x is a vector (can be batched)
- μ is the location parameter
- Λ is a base precision matrix
- σ² is a scalar variance parameter
- α, β are Inverse-Gamma distribution parameters
- property alpha: Array
Shape parameter α of the InverseGamma component.
- property beta: Array
Scale parameter β of the InverseGamma component.
- property col_covariance: Array
Column covariance Λ⁻¹ (alias for covariance).
- property covariance: Array
Base covariance Λ⁻¹ of the MVN component (without noise scaling).
- property expected_covariance: Array
Expected covariance E[σ²] * Λ⁻¹.
- property expected_log_psi: Array
Compute expected log precision E[log(psi)].
- property expected_psi: Array
Compute expected precision E[psi].
- property expected_sufficient_statistics_psi: Array
Compute expected sufficient statistics of psi.
- inv_gamma: InverseGamma
- log_prob(x: Tuple[Array, Array]) Array[source]
Compute log probability.
- Parameters
x – Tuple of (sig_sqr, w), where sig_sqr is the value of the sample variance and w is the value of the sample state.
- Returns
Log probability
- property mean: Array
Get mean of the marginal distribution p(x).
- mode() Tuple[Array, Array][source]
Solve for the mode. Recall that
\[p(\mu, \sigma^2) \propto \mathrm{N}(x \mid \mu, \sigma^2 \Sigma) \times \mathrm{IG}(\sigma^2 \mid \alpha, \beta)\]
The optimal mean is \(x^* = \mu_0\). Substituting this in gives
\[p(\mu^*, \sigma^2) \propto \mathrm{IG}(\sigma^2 \mid \alpha + D/2, \beta)\]
where D is the dimensionality of x. The mode of this inverse gamma distribution is at
\[(\sigma^2)^* = \beta / (\alpha + (D + 2)/2)\]
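The closed-form mode can be checked against the unnormalized log density of an inverse gamma with shape α + D/2 and scale β. This is a standard-library sketch with hypothetical function names, not the body of `mode()`.

```python
import math


def sigma_sq_mode(alpha: float, beta: float, dim: int) -> float:
    """Mode of IG(alpha + D/2, beta): beta / (alpha + (D + 2) / 2)."""
    return beta / (alpha + (dim + 2) / 2.0)


def inverse_gamma_log_density_unnormalized(s: float, shape: float, scale: float) -> float:
    """log of s^(-shape - 1) * exp(-scale / s), dropping the normalizing constant."""
    return -(shape + 1.0) * math.log(s) - scale / s


alpha, beta, dim = 2.0, 1.0, 4
mode = sigma_sq_mode(alpha, beta, dim)
shape = alpha + dim / 2.0

# The density at the mode should exceed the density just below and just above it.
at_mode = inverse_gamma_log_density_unnormalized(mode, shape, beta)
below = inverse_gamma_log_density_unnormalized(0.99 * mode, shape, beta)
above = inverse_gamma_log_density_unnormalized(1.01 * mode, shape, beta)
```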
- mvn: MultivariateNormal
- property precision: Array
Base precision matrix Λ of the MVN component.
- sppcax.distributions.mvn_gamma.mvnig_posterior_update(mvnig_prior: MultivariateNormalInverseGamma, sufficient_stats: tuple, props: object) MultivariateNormalInverseGamma[source]
Update the multivariate normal inverse gamma (MVNIG) posterior from sufficient statistics.
- Parameters
mvnig_prior – Prior MVNIG distribution.
sufficient_stats – Tuple of (SxxT, SxyT, SyyT, N) sufficient statistics.
props – Parameter properties controlling which parameters are trainable.
- Returns
Posterior MVNIG distribution.
sppcax.distributions.normal module
Normal distribution implementations.
- class sppcax.distributions.normal.Normal(loc: Array = 0.0, scale: Array = 1.0)[source]
Bases: ExponentialFamily
Univariate normal distribution in natural parameters.
The normal distribution has density: p(x|μ,σ) = 1/√(2πσ²) * exp(-(x-μ)²/(2σ²))
In exponential family form:
η = [μ/σ², -1/(2σ²)]
T(x) = [x, x²]
A(η) = -η₁²/(4η₂) - (1/2)log(-2η₂) + (1/2)log(2π)
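The mapping between (μ, σ) and the natural parameters is invertible. The sketch below shows the round trip in plain Python; `normal_to_natural` and `natural_to_normal` are hypothetical names for illustration and are not the package's `moment_to_natural` or `natural_to_moment` utilities.

```python
import math


def normal_to_natural(loc: float, scale: float) -> tuple[float, float]:
    """eta = [mu / sigma^2, -1 / (2 sigma^2)]."""
    precision = 1.0 / scale**2
    return loc * precision, -0.5 * precision


def natural_to_normal(eta1: float, eta2: float) -> tuple[float, float]:
    """Invert the mapping: precision = -2 eta2, mu = eta1 / precision."""
    precision = -2.0 * eta2
    return eta1 / precision, math.sqrt(1.0 / precision)


eta1, eta2 = normal_to_natural(loc=1.5, scale=2.0)
loc, scale = natural_to_normal(eta1, eta2)
```

The second natural parameter must stay negative, otherwise the implied precision is not positive.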
- property entropy: Array
Compute entropy of the distribution.
- Returns
Entropy with shape batch_shape.
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [μ, μ² + σ²].
- Returns
Expected sufficient statistics [E[x], E[x²]] with shape batch_shape + (2,).
- classmethod from_natural_parameters(eta: Array) Normal[source]
Create normal distribution from natural parameters.
- Parameters
eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
- Returns
Normal distribution.
- property loc: Array
Get location parameter.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
Log base measure log(h(x)) with shape batch_shape.
- property log_normalizer: Array
Compute log normalizer A(η).
- Returns
Log normalizer with shape batch_shape.
- nat1: Array
- nat2: Array
- property natural_parameters: Array
Get natural parameters η = [precision*mean, -0.5*precision].
- Returns
Natural parameters [η₁, η₂] with shape batch_shape + (2,).
- property precision: Array
Get precision parameter
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape.
- property scale: Array
Get scale parameter.
sppcax.distributions.poisson module
Poisson distribution in natural parameterization.
- class sppcax.distributions.poisson.Poisson(log_rate: Array)[source]
Bases: ExponentialFamily
Poisson distribution parameterized by log rate.
The Poisson distribution has natural parameter η = log(λ) where λ is the rate parameter, and sufficient statistic T(x) = x.
- property entropy: Array
Compute entropy of Poisson distribution.
- Returns
Entropy H(λ) = λ(1 - log(λ)) + exp(-λ)sum_{k=0}^∞ λ^k log(k!)/k!
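The series form of the entropy can be compared against the definition -Σ p(k) log p(k). The sketch below truncates both sums at a fixed number of terms, which is an assumption that only suits moderate rates; the function names are hypothetical and this is not the package's implementation.

```python
import math


def poisson_entropy_series(rate: float, terms: int = 200) -> float:
    """H = rate * (1 - log(rate)) + exp(-rate) * sum_k rate^k log(k!) / k!."""
    total = 0.0
    for k in range(terms):
        log_factorial = math.lgamma(k + 1.0)
        # rate^k / k! is evaluated in log space to avoid overflow.
        total += math.exp(k * math.log(rate) - log_factorial) * log_factorial
    return rate * (1.0 - math.log(rate)) + math.exp(-rate) * total


def poisson_entropy_direct(rate: float, terms: int = 200) -> float:
    """H = -sum_k p(k) log p(k), truncated at the same number of terms."""
    entropy = 0.0
    for k in range(terms):
        log_p = k * math.log(rate) - rate - math.lgamma(k + 1.0)
        entropy -= math.exp(log_p) * log_p
    return entropy


series_value = poisson_entropy_series(3.0)
direct_value = poisson_entropy_direct(3.0)
```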
- property expected_sufficient_statistics: Array
Compute E[T(x)] = E[x] = λ = exp(η).
- Returns
Expected sufficient statistics E[x].
- classmethod from_natural_parameters(eta: Array) Poisson[source]
Create Poisson from natural parameters.
- Parameters
eta – Natural parameter η = log(λ).
- Returns
Poisson instance.
- log_base_measure(x: Array) Array[source]
Compute log base measure log(h(x)) = -log(x!).
- Parameters
x – Count data.
- Returns
Log base measure -log(x!).
- property log_normalizer: Array
Compute log normalizer A(η) = exp(η).
- Returns
Log normalizer A(η).
- property log_rate: Array
Get log(rate) parameter.
- nat1: Array
- property natural_parameters: Array
Get natural parameters (log rate).
- Returns
Natural parameters η = log(λ).
- property rate: Array
Get rate parameter.
sppcax.distributions.updates module
sppcax.distributions.utils module
Utility functions for distributions.
- sppcax.distributions.utils.cho_inv(matrix: Float[Array, 'rows cols'], diagonal_boost: float = 1e-12) Float[Array, 'rows cols'][source]
Invert a positive-definite matrix via Cholesky decomposition.
- Parameters
matrix – Positive-definite matrix with shape (…, d, d).
diagonal_boost – Small value added to diagonal for numerical stability.
- Returns
Inverse matrix with same shape.
- sppcax.distributions.utils.moment_to_natural(mean: Array, covariance: Array) tuple[jax.Array, jax.Array][source]
Convert moment parameters to natural parameters for multivariate normal.
- Parameters
mean – Mean vector.
covariance – Covariance matrix.
- Returns
Tuple of (nat1, nat2).
- sppcax.distributions.utils.natural_to_moment(nat1: Array, nat2: Array) tuple[jax.Array, jax.Array][source]
Convert natural parameters to moment parameters for multivariate normal.
- Parameters
nat1 – First natural parameter (precision * mean).
nat2 – Second natural parameter (-0.5 * precision).
- Returns
Tuple of (mean, covariance).
- sppcax.distributions.utils.safe_cholesky(X: Array, jitter: float = 1e-12) Array[source]
Compute Cholesky decomposition with added diagonal jitter for numerical stability.
- Parameters
X – Symmetric positive definite matrix.
jitter – Small positive value to add to diagonal for stability.
- Returns
Lower triangular Cholesky factor L such that X ≈ L @ L^T.
- sppcax.distributions.utils.safe_cholesky_and_logdet(X: Array, jitter: float = 1e-12) tuple[jax.Array, jaxtyping.Float[Array, '']][source]
Compute Cholesky decomposition and log determinant with added diagonal jitter.
- Parameters
X – Symmetric positive definite matrix.
jitter – Small positive value to add to diagonal for stability.
- Returns
Tuple of (L, logdet) where L is the lower triangular Cholesky factor and logdet is the log determinant of X.
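To illustrate the idea behind the two Cholesky helpers, the sketch below adds jitter to the diagonal, factorizes with the textbook Cholesky recurrence, and reads the log determinant off the diagonal of L. It uses nested Python lists and a hypothetical function name; the package's versions operate on JAX arrays and may differ in detail.

```python
import math


def cholesky_and_logdet_with_jitter(matrix, jitter: float = 1e-12):
    """Return (L, logdet) for a symmetric positive-definite matrix plus jitter * I."""
    n = len(matrix)
    # Add jitter to the diagonal only.
    a = [[matrix[i][j] + (jitter if i == j else 0.0) for j in range(n)] for i in range(n)]
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            partial = a[i][j] - sum(lower[i][p] * lower[j][p] for p in range(j))
            if i == j:
                lower[i][j] = math.sqrt(partial)       # diagonal entry
            else:
                lower[i][j] = partial / lower[j][j]    # below-diagonal entry
    # log det(X) = 2 * sum_i log(L_ii)
    logdet = 2.0 * sum(math.log(lower[i][i]) for i in range(n))
    return lower, logdet


L, logdet = cholesky_and_logdet_with_jitter([[4.0, 2.0], [2.0, 3.0]])
```

For this matrix the determinant is 8, so `logdet` is close to log(8), up to the small perturbation introduced by the jitter.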
Module contents
Distribution classes.
- class sppcax.distributions.Beta(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]
Bases: ExponentialFamily
Beta distribution in natural parameters.
The beta distribution has density: p(x|α,β) = x^(α-1) * (1-x)^(β-1) / B(α,β) for x ∈ [0,1]
In exponential family form:
h(x) = 1
η = [α-1, β-1]
T(x) = [log(x), log(1-x)]
A(η) = log(B(η₁+1, η₂+1))
- property alpha: Array
Get first shape parameter α.
- property beta: Array
Get second shape parameter β.
- dnat1: Array
- dnat2: Array
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [ψ(α) - ψ(α+β), ψ(β) - ψ(α+β)].
- Returns
Expected sufficient statistics [E[log(x)], E[log(1-x)]] with shape batch_shape + (2,).
- classmethod from_natural_parameters(eta: Array) Beta[source]
Create beta distribution from natural parameters.
- Parameters
eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
- Returns
Beta distribution.
- property kl_divergence_from_prior: Array
Compute KL divergence KL(post||prior).
- Returns
KL divergence KL(post||prior) with shape batch_shape.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
zero
- property log_normalizer: Array
Compute log normalizer A(η) = log(B(η₁+1, η₂+1)).
- Returns
Log normalizer with shape batch_shape.
- property mean: Array
Get mean of the distribution.
- property nat1: Array
First natural parameter η₁ = α - 1.
- nat1_0: Array
- property nat2: Array
Second natural parameter η₂ = β - 1.
- nat2_0: Array
- property natural_parameters: Array
Get natural parameters η = [α-1, β-1].
- Returns
Natural parameters [η₁, η₂] with shape batch_shape + (2,).
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape.
- sufficient_statistics(x: Array) Array[source]
Compute sufficient statistics T(x) = [log(x), log(1-x)].
- Parameters
x – Value to compute sufficient statistics for. Shape: batch_shape + event_shape
- Returns
Sufficient statistics [log(x), log(1-x)] with shape batch_shape + (2,)
- property variance: Array
Get variance of the distribution.
- class sppcax.distributions.Categorical(logits: Array)[source]
Bases: ExponentialFamily
Categorical distribution parameterized by logits.
The Categorical distribution has K-1 natural parameters η_k = log(p_k/p_K) for k=1,…,K-1 where p_K is the probability of the last category.
- property expected_sufficient_statistics: Array
Compute E[T(x)] = p_1,…,p_{K-1}.
- Returns
Expected probabilities for first K-1 categories.
- classmethod from_natural_parameters(eta: Array) Categorical[source]
Create Categorical from natural parameters.
- Parameters
eta – Log-odds relative to last category. Shape (…, K-1) where K is number of categories.
- Returns
Categorical instance.
- property full_logits: Array
- property log_normalizer: Array
Compute log normalizer A(η).
For categorical, A(η) = log(1 + sum(exp(η_k))).
- Returns
Log normalizer A(η).
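As an illustration of the K-1 parameterization, here is a standalone sketch (standard library only, not the sppcax implementation; `categorical_from_natural` is a hypothetical helper). It appends a zero logit for the reference category, recovers all K probabilities, and returns A(η) = log(1 + sum(exp(η_k))).

```python
import math


def categorical_from_natural(eta):
    """Map K-1 log-odds (relative to the last category) to K probabilities."""
    # The last category is the reference, so its logit is fixed at 0.
    full_logits = list(eta) + [0.0]
    m = max(full_logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(l - m) for l in full_logits))
    probs = [math.exp(l - log_norm) for l in full_logits]
    # log_norm equals log(1 + sum_k exp(eta_k)), i.e. A(eta).
    return probs, log_norm


probs = [0.2, 0.3, 0.5]
eta = [math.log(p / probs[-1]) for p in probs[:-1]]  # eta_k = log(p_k / p_K)
recovered, log_normalizer = categorical_from_natural(eta)
```

Note that A(η) = −log(p_K), so the log normalizer directly encodes the probability of the reference category.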
- nat1: Array
- property natural_parameters: Array
Get natural parameters (logits).
- Returns
Natural parameters η.
- property probs: Array
Get probability parameters.
- class sppcax.distributions.Delta(location: Array, sufficient_statistics_fn: Optional[Callable] = None)[source]
Bases: Distribution
Delta distribution (Dirac delta) concentrated at a single point.
- property covariance: Array
Covariance matrix (always zero for delta distribution).
- entropy() Array[source]
Compute entropy (always 0 for delta distribution).
- Returns
Entropy with shape batch_shape
- property expected_psi: Array
Expected precision for Delta: matrix inverse for square matrices, element-wise otherwise.
- property expected_second_moment: Array
Expected second moment E[XX^T] = location @ location^T (no variance).
- property expected_sufficient_statistics: Array
Compute expected sufficient statistics.
For delta distribution, this is just sufficient_statistics(location) since all probability mass is concentrated at location.
- Returns
Expected sufficient statistics with shape determined by sufficient_statistics_fn.
- log_prob(x: Array) Array[source]
Compute log probability.
- Parameters
x – Value to compute log probability for. Shape: batch_shape + event_shape
- Returns
Log probability with shape batch_shape. Returns 0 at location, -inf elsewhere.
- mean: Array
- property precision: Array
Precision (infinite for delta, but return large finite value for compatibility).
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution (always returns location).
- Parameters
key – PRNG key (unused).
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape. All samples are equal to location.
- class sppcax.distributions.Distribution(batch_shape: Tuple[int, ...], event_shape: Tuple[int, ...])[source]
Bases: Module
Base distribution class in natural parameters.
- broadcast_to_shape(x: Array, ignore_event: bool = False) Array[source]
Broadcast array to match distribution shape.
- Parameters
x – Array to broadcast.
ignore_event – If True, only broadcast batch dimensions.
- Returns
Broadcasted array.
- entropy() Array[source]
Compute entropy of the distribution.
- Returns
Entropy with shape batch_shape
- log_prob(x: Array) Array[source]
Compute log probability of x.
- Parameters
x – Value to compute log probability for. Should have shape: batch_shape + event_shape
- Returns
Log probability with shape batch_shape
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Optional[Tuple[int, ...]] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Additional sample dimensions.
- Returns
Samples with shape sample_shape + batch_shape + event_shape
- class sppcax.distributions.ExponentialFamily(batch_shape: Tuple[int, ...] = (), event_shape: Tuple[int, ...] = ())[source]
Bases: Distribution
Base class for exponential family distributions in natural parameterization.
The exponential family has the form: p(x|η) = h(x)exp(η^T T(x) - A(η)) where:
- η: natural parameters
- T(x): sufficient statistics
- A(η): log normalizer
- h(x): base measure
- property entropy: Array
Compute entropy of the distribution.
- Returns
Entropy with shape batch_shape
- property expected_log_base_measure: Array
Compute the expectation of the log base measure E_{p(x)}[log(h(x))]
- Returns
Expectation E_{p(x)}[log(h(x))] with shape batch_shape
- property expected_sufficient_statistics: Array
Compute expected sufficient statistics E[T(x)].
- Returns
Expected sufficient statistics E[T(x)] with shape batch_shape + natural_param_shape
- classmethod from_natural_parameters(eta: Array) ExponentialFamily[source]
Create distribution from natural parameters.
- Parameters
eta – Natural parameters with shape: batch_shape + natural_param_shape
- Returns
Distribution instance.
- kl_divergence(other: ExponentialFamily) Array[source]
Compute KL divergence KL(self||other).
- Parameters
other – Other distribution to compute KL divergence with.
- Returns
KL divergence KL(self||other) with shape batch_shape
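For two members of the same exponential family with a shared base measure, the KL divergence satisfies the standard identity KL(p||q) = (η_p − η_q)·E_p[T(x)] − A(η_p) + A(η_q). Whether sppcax computes it exactly this way is an assumption; the sketch below (standard library only, hypothetical helper names) simply verifies the identity for two univariate normals against the familiar closed form.

```python
import math


def normal_natural(mu, sigma):
    """Natural parameters [mu / sigma^2, -1 / (2 sigma^2)]."""
    return (mu / sigma ** 2, -0.5 / sigma ** 2)


def normal_log_normalizer(eta):
    """A(eta) = -eta1^2 / (4 eta2) - 0.5 log(-2 eta2) + 0.5 log(2 pi)."""
    eta1, eta2 = eta
    return (-eta1 ** 2 / (4.0 * eta2)
            - 0.5 * math.log(-2.0 * eta2)
            + 0.5 * math.log(2.0 * math.pi))


def kl_exponential_family(eta_p, eta_q, expected_stats_p, log_normalizer):
    """KL(p||q) = (eta_p - eta_q) . E_p[T] - A(eta_p) + A(eta_q)."""
    inner = sum((a - b) * t for a, b, t in zip(eta_p, eta_q, expected_stats_p))
    return inner - log_normalizer(eta_p) + log_normalizer(eta_q)


mu_p, sigma_p = 0.5, 1.5
mu_q, sigma_q = -1.0, 2.0
kl = kl_exponential_family(
    normal_natural(mu_p, sigma_p),
    normal_natural(mu_q, sigma_q),
    (mu_p, mu_p ** 2 + sigma_p ** 2),  # E_p[T(x)] = [mu, mu^2 + sigma^2]
    normal_log_normalizer,
)
closed_form = (math.log(sigma_q / sigma_p)
               + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q ** 2)
               - 0.5)
```

The two values should match up to rounding, and the result is non-negative as expected for a KL divergence.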
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
Log base measure log(h(x)) with shape batch_shape
- property log_normalizer: Array
Compute log normalizer A(η).
- Returns
Log normalizer A(η) with shape batch_shape
- log_prob(x: Array) Array[source]
Compute log probability.
- Parameters
x – Data to compute log probability for. Shape: batch_shape + event_shape
- Returns
Log probability log p(x|η) with shape batch_shape. Returns -inf for values outside the support.
- property natural_parameters: Array
Get natural parameters of the distribution.
- Returns
Natural parameters η with shape batch_shape + natural_param_shape
- class sppcax.distributions.Gamma(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]
Bases: ExponentialFamily
Gamma distribution in natural parameters.
The gamma distribution has density: p(x|α,β) = β^α * x^(α-1) * exp(-βx) / Γ(α)
In exponential family form:
h(x) = 1
η = [α-1, -β]
T(x) = [log(x), x]
A(η) = log(Γ(η₁ + 1)) - (η₁ + 1)*log(-η₂)
- property alpha: Array
Get shape parameter α.
- property beta: Array
Get rate parameter β.
- dnat1: Array
- dnat2: Array
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [ψ(α) - log(β), α/β].
- Returns
Expected sufficient statistics [E[log(x)], E[x]] with shape batch_shape + (2,)
- classmethod from_natural_parameters(eta: Array) Gamma[source]
Create gamma distribution from natural parameters.
- Parameters
eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
- Returns
Gamma distribution.
- property kl_divergence_from_prior: Array
Compute KL divergence KL(post||prior).
- Returns
KL divergence KL(post||prior) with shape batch_shape
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
zero
- property log_normalizer: Array
Compute log normalizer A(η) = log(Γ(η₁ + 1)) - (η₁ + 1)*log(-η₂).
- Returns
Log normalizer with shape batch_shape
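In any exponential family the gradient of A(η) gives the expected sufficient statistics, so ∂A/∂η₂ should equal E[x] = α/β here. The sketch below checks this numerically with a central finite difference; it uses only the standard library, and `gamma_log_normalizer` and `central_difference` are hypothetical helpers rather than sppcax functions.

```python
import math


def gamma_log_normalizer(eta1, eta2):
    """A(eta) = log Gamma(eta1 + 1) - (eta1 + 1) * log(-eta2)."""
    return math.lgamma(eta1 + 1.0) - (eta1 + 1.0) * math.log(-eta2)


def central_difference(f, x, h=1e-6):
    """Second-order finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


alpha, beta = 3.0, 2.0
eta1, eta2 = alpha - 1.0, -beta  # natural parameters [alpha - 1, -beta]

# dA/d(eta2) should reproduce E[T_2(x)] = E[x] = alpha / beta.
mean_from_gradient = central_difference(
    lambda e2: gamma_log_normalizer(eta1, e2), eta2
)
```

The same check applied to η₁ would recover E[log(x)] = ψ(α) − log(β), though the standard library has no digamma function to compare against.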
- property mean: Array
Get mean E[x] = α/β.
- property nat1: Array
First natural parameter η₁ = α - 1.
- nat1_0: Array
- property nat2: Array
Second natural parameter η₂ = -β.
- nat2_0: Array
- property natural_parameters: Array
Get natural parameters η = [α-1, -β].
- Returns
Natural parameters [η₁, η₂] with shape batch_shape + (2,)
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape
- class sppcax.distributions.InverseGamma(alpha0: float | jax.Array = 1.0, beta0: float | jax.Array = 1.0)[source]
Bases: ExponentialFamily
Inverse Gamma distribution in natural parameters.
The inverse gamma distribution has density: p(x|α,β) = β^α * x^(-α-1) * exp(-β/x) / Γ(α)
In exponential family form:
h(x) = 1
η = [-α-1, -β]
T(x) = [log(x), 1/x]
A(η) = log(Γ(-η₁ - 1)) + (η₁ + 1)*log(-η₂)
- property alpha: Array
Get shape parameter α.
- property beta: Array
Get scale parameter β.
- dnat1: Array
- dnat2: Array
- property expected_psi: Array
Expected precision E[1/x] = alpha/beta for InverseGamma.
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [log(β) - ψ(α), α/β].
- Returns
Expected sufficient statistics [E[log(x)], E[1/x]] with shape batch_shape + (2,)
- classmethod from_natural_parameters(eta: Array) InverseGamma[source]
Create inverse gamma distribution from natural parameters.
- Parameters
eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
- Returns
InverseGamma distribution.
- property kl_divergence_from_prior: Array
Compute KL divergence KL(post||prior).
- Returns
KL divergence KL(post||prior) with shape batch_shape
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
zero
- property log_normalizer: Array
Compute log normalizer A(η) = log(Γ(-η₁ - 1)) + (η₁ + 1)*log(-η₂)
- Returns
Log normalizer with shape batch_shape
- property mean: Array
Get mean E[x] = β/(α-1).
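The moments E[x] = β/(α-1) and E[1/x] = α/β can be sanity-checked by simulation, using the fact that the reciprocal of a Gamma(α, rate β) draw is InverseGamma(α, β). This is a rough Monte Carlo sketch with the standard library only; it is not how sppcax samples, and `sample_inverse_gamma` is a hypothetical helper.

```python
import random


def sample_inverse_gamma(rng, alpha, beta):
    """Draw x ~ InverseGamma(alpha, beta) as the reciprocal of a gamma draw."""
    # random.gammavariate takes (shape, scale), so rate beta means scale 1/beta.
    return 1.0 / rng.gammavariate(alpha, 1.0 / beta)


rng = random.Random(0)  # fixed seed for reproducibility
alpha, beta = 5.0, 2.0
draws = [sample_inverse_gamma(rng, alpha, beta) for _ in range(200_000)]

mc_mean = sum(draws) / len(draws)                                 # ~ beta / (alpha - 1)
mc_expected_precision = sum(1.0 / x for x in draws) / len(draws)  # ~ alpha / beta
```

With α = 5 the variance is finite, so the sample averages should settle close to 0.5 and 2.5 respectively.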
- mf_update(stats: tuple, partner_expectations: dict) InverseGamma[source]
Mean-field coordinate ascent update for InverseGamma noise.
Given partner (weights) expectations, compute the residual and update the InverseGamma natural parameters.
- Parameters
stats – Sufficient statistics tuple (SxxT, SxyT, SyyT, N).
partner_expectations – Dict with ‘mean’ and ‘second_moment’ from weights.
- Returns
Updated InverseGamma distribution.
- property nat1: Array
First natural parameter η₁ = -α - 1.
- nat1_0: Array
- property nat2: Array
Second natural parameter η₂ = -β.
- nat2_0: Array
- property natural_parameters: Array
Get natural parameters η = [-α-1, -β].
- Returns
Natural parameters [η₁, η₂] with shape batch_shape + (2,)
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape
- class sppcax.distributions.InverseWishart(df0: float | jax.Array, scale0: Array)[source]
Bases: ExponentialFamily
Inverse Wishart distribution in natural parameters.
- For Sigma ~ IW(df, Psi):
p(Sigma | df, Psi) propto |Sigma|^{-(df+k+1)/2} exp(-tr(Psi Sigma^{-1})/2)
- Exponential family form:
eta1 = -(df + k + 1) / 2 (scalar)
eta2 = -Psi / 2 (k x k matrix)
T(Sigma) = [log|Sigma|, Sigma^{-1}]
A(eta) = -(df/2) log|Psi/2| + multigammaln(df/2, k) + (df*k/2) log(2)
- Key moments:
E[Sigma^{-1}] = df * Psi^{-1}
E[log|Sigma^{-1}|] = multidigamma(df/2, k) + k*log(2) - log|Psi|
- property df: Array
Degrees of freedom: df = -2*nat1 - k - 1.
- dnat1: Array
- dnat2: Array
- property entropy: Array
Entropy of IW(df, Psi).
- property expected_psi: Array
E[Sigma^{-1}] = df * Psi^{-1}.
- property inv_scale: Array
Inverse scale matrix Psi^{-1}.
- property kl_divergence_from_prior: Array
KL(posterior || prior) using stored prior natural parameters.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
Log base measure log(h(x)) with shape batch_shape
- property log_normalizer: Array
Log normalizer A(df, Psi).
- property mean: Array
Mean E[Sigma] = Psi / (df - k - 1), requires df > k + 1.
- mf_update(stats: tuple, partner_expectations: dict) InverseWishart[source]
Mean-field coordinate ascent update for InverseWishart noise.
- For IW noise in a linear model y = W x + e, e ~ N(0, Sigma):
df_post = df_prior + N
Psi_post = Psi_prior + E[sum_t (y_t - W x_t)(y_t - W x_t)^T]
= Psi_prior + SyyT - SxyT^T E[W]^T - E[W] SxyT + E[W SxxT W^T]
- The quadratic term E[W SxxT W^T] decomposes under mean-field q(W) = prod_d N(w_d; m_d, V_d):
off-diag (i != j): m_i^T SxxT m_j (independent rows)
diag (i = j): tr(E[w_i w_i^T] SxxT) = m_i^T SxxT m_i + tr(V_i SxxT)
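The row-wise decomposition of the quadratic term can be written out directly. The following sketch uses plain nested lists and the standard library only; the helper names (`quad_form`, `trace_product`, `expected_quadratic`) and the toy numbers are illustrative assumptions, not part of sppcax.

```python
def quad_form(a, S, b):
    """Return a^T S b for vectors a, b and a square matrix S."""
    return sum(a[i] * S[i][j] * b[j]
               for i in range(len(a)) for j in range(len(b)))


def trace_product(V, S):
    """Return tr(V S) without forming the product matrix."""
    n = len(V)
    return sum(V[i][j] * S[j][i] for i in range(n) for j in range(n))


def expected_quadratic(means, covs, S):
    """E[W S W^T] for independent rows w_d ~ N(m_d, V_d)."""
    D = len(means)
    # Every entry gets the mean term m_i^T S m_j.
    Q = [[quad_form(means[i], S, means[j]) for j in range(D)] for i in range(D)]
    # Only the diagonal picks up the covariance term tr(V_i S).
    for i in range(D):
        Q[i][i] += trace_product(covs[i], S)
    return Q


SxxT = [[2.0, 1.0], [1.0, 3.0]]
means = [[1.0, 0.0], [0.0, 2.0]]              # m_0, m_1
covs = [[[0.5, 0.0], [0.0, 0.5]],             # V_0
        [[1.0, 0.0], [0.0, 0.0]]]             # V_1
Q = expected_quadratic(means, covs, SxxT)
```

For these inputs the off-diagonal entries come from the means alone, while each diagonal entry adds the trace correction for that row's uncertainty.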
- property nat1: Array
First natural parameter: -(df + k + 1) / 2.
- nat1_0: Array
- property nat2: Array
Second natural parameter: -Psi / 2.
- nat2_0: Array
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from IW(df, Psi) via Bartlett decomposition.
- property scale: Array
Scale matrix Psi = -2 * nat2.
- class sppcax.distributions.MeanField(weights: Distribution, noise: Distribution)[source]
Bases: Distribution
Mean-field composite distribution: q(W, Sigma) = q(W) x q(Sigma).
Components are independent. Posterior updates use coordinate ascent (alternating updates of weights and noise components).
- weights
Distribution over weight parameters (MVN or Delta if frozen).
- noise
Distribution over noise parameters (InverseGamma, Delta, etc.).
- n_iter
Number of coordinate ascent iterations for posterior updates.
- property alpha: Array
Shape parameter from InverseGamma noise (MVNIG compat).
- property beta: Array
Scale parameter from InverseGamma noise (MVNIG compat).
- property col_covariance: Array
Column covariance (base covariance, same as covariance).
- property covariance: Array
Covariance of the weights component (base, not scaled by noise).
- property expected_covariance: Array
Expected covariance E[sigma^2] * base_covariance.
For InverseGamma noise: scalar E[sigma^2] per row. For InverseWishart noise: full matrix E[Sigma], not factored with weights cov. For Delta noise: fixed value.
- property expected_psi: Array
Expected noise precision E[1/sigma^2] from noise component.
- property expected_sufficient_statistics_psi: Array
Expected sufficient statistics of noise precision (MVNIG compat).
- property inv_gamma
InverseGamma component (for MVNIG compatibility).
- log_prob(x: Tuple[Array, Array]) Array[source]
Compute log probability.
- Parameters
x – Tuple of (cov, w), where cov is the value of the sample covariance and w is the value of the sample state.
- Returns
Log probability
- property mask: Array
Mask from weights component (if available).
- property mean: Array
Mean of the weights component.
- mode() Tuple[Array, Array][source]
Compute joint mode (noise_cov_matrix, weights_mean).
- Returns
Tuple of (noise covariance as matrix, weights mean).
- property mvn: Distribution
Weights component (for ARD compatibility).
- noise: Distribution
- property precision: Array
Precision of the weights component.
- sample(seed: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Tuple[Array, Array][source]
Sample from both components independently.
- Parameters
seed – PRNG key.
sample_shape – Additional sample dimensions.
- Returns
Tuple of (noise_sample, weights_sample).
- weights: Distribution
- class sppcax.distributions.MultivariateNormal(loc: Array, scale_tril: Optional[Array] = None, covariance: Optional[Array] = None, precision: Optional[Array] = None, mask: Optional[Array] = None)[source]
Bases: ExponentialFamily
Multivariate normal distribution in natural parameters.
- apply_mask_matrix(x: Array, zeromask: bool = False) Array[source]
Apply mask to a matrix, zeroing out masked rows and columns.
- Parameters
x – Matrix with shape (…, d, d)
zeromask – If True, set masked entries to 0. If False (default), set masked entries to identity matrix values.
- Returns
Masked matrix with same shape.
- apply_mask_vector(x: Array) Array[source]
Apply mask to a vector, zeroing out masked dimensions.
- Parameters
x – Vector with shape (…, d)
- Returns
Masked vector with same shape
- property covariance: Array
- dnat1: Float[Array, 'dim']
- dnat2: Float[Array, 'rows cols']
- property expected_second_moment: Array
Expected second moment E[xx^T] = Cov + mean @ mean^T, per row.
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [μ, vec(μμ^T + Σ)].
- Returns
Expected sufficient statistics [E[x], vec(E[xx^T])].
- classmethod from_natural_parameters(nat1: Array, nat2: Array, mask: Optional[Array] = None) MultivariateNormal[source]
Create MVN from natural parameters.
- Parameters
nat1 – First natural parameter (precision * mean).
nat2 – Second natural parameter (-0.5 * precision).
mask – Optional boolean mask with shape matching nat1
- Returns
MultivariateNormal instance.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
Log base measure log(h(x)) with shape batch_shape
- property log_normalizer: Array
Compute log normalizer A(η).
- Returns
Log normalizer A(η) with shape batch_shape
- mask: Float[Array, 'dim']
- property mean: Array
Get mean parameter.
- mf_update(stats: tuple, partner_expectations: dict) MultivariateNormal[source]
Mean-field coordinate ascent update for MVN weights.
Scales data sufficient statistics by partner’s expected precision E[1/sigma^2], then performs standard natural parameter update.
- Parameters
stats – Sufficient statistics tuple (SxxT, SxyT, SyyT, N).
partner_expectations – Dict with ‘expected_precision’ from noise component. E[precision] can be scalar (isotropic), (D,) vector (per-feature IG), or (D, D) matrix (InverseWishart). For matrix precision, the diagonal is used for per-row weight updates.
- Returns
Updated MVN distribution (posterior).
- property nat1: Array
First natural parameter (precision * mean), masked.
- nat1_0: Float[Array, 'dim']
- property nat2: Array
Second natural parameter (-0.5 * precision).
- nat2_0: Float[Array, 'rows cols']
- property natural_parameters: Array
Get natural parameters η = [precision*mean, -0.5*precision].
- Returns
Natural parameters [η₁, vec(η₂)].
- property precision: Array
Get precision parameter.
- class sppcax.distributions.MultivariateNormalInverseGamma(loc: Array, *, isotropic_noise: bool, mask: Optional[Array] = None, alpha0: float = 2.0, beta0: float = 1.0, scale_tril: Optional[Array] = None, covariance: Optional[Array] = None, precision: Optional[Array] = None)[source]
Bases: ExponentialFamily
Multivariate Normal-Inverse-Gamma distribution.
This distribution combines:
p(x|σ²) = N(x; μ, σ²Λ⁻¹) (multivariate normal)
p(σ²) = InverseGamma(σ²; α, β) (scalar variance)
where:
- x is a vector (can be batched)
- μ is the location parameter
- Λ is a base precision matrix
- σ² is a scalar variance parameter
- α, β are Inverse-Gamma distribution parameters
- property alpha: Array
Shape parameter α of the InverseGamma component.
- batch_shape: Shape
- property beta: Array
Scale parameter β of the InverseGamma component.
- property col_covariance: Array
Column covariance Λ⁻¹ (alias for covariance).
- property covariance: Array
Base covariance Λ⁻¹ of the MVN component (without noise scaling).
- event_shape: Shape
- property expected_covariance: Array
Expected covariance E[σ²] * Λ⁻¹.
- property expected_log_psi: Array
Compute expected log precision E[log(psi)].
- property expected_psi: Array
Compute expected precision E[psi].
- property expected_sufficient_statistics_psi: Array
Compute expected sufficient statistics of psi.
- inv_gamma: InverseGamma
- log_prob(x: Tuple[Array, Array]) Array[source]
Compute log probability.
- Parameters
x – Tuple of (sig_sqr, w), where sig_sqr is the value of the sample variance and w is the value of the sample state.
- Returns
Log probability
- property mean: Array
Get mean of the marginal distribution p(x).
- mode() Tuple[Array, Array][source]
Solve for the mode. Recall that
p(x, \sigma^2) \propto \mathrm{N}(x \mid \mu, \sigma^2 \Lambda^{-1}) \times \mathrm{IG}(\sigma^2 \mid \alpha, \beta)
The optimal location is \(x^* = \mu\). Substituting this in,
p(x^*, \sigma^2) \propto \mathrm{IG}(\sigma^2 \mid \alpha + D/2, \beta)
where D is the dimensionality of x, and the mode of this inverse gamma distribution is at
(\sigma^2)^* = \beta / (\alpha + (D + 2)/2)
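The closed-form variance mode can be checked against a brute-force search. This is a small sketch with the standard library only; the function names are hypothetical and the grid search is purely illustrative, not the method sppcax uses.

```python
import math


def nig_mode_variance(alpha, beta, dim):
    """Closed-form mode of the variance: beta / (alpha + (dim + 2) / 2)."""
    return beta / (alpha + (dim + 2) / 2.0)


def log_inverse_gamma_kernel(s2, shape, scale):
    """Unnormalized log density of InverseGamma(shape, scale) at s2."""
    return -(shape + 1.0) * math.log(s2) - scale / s2


alpha, beta, dim = 2.0, 1.0, 4
mode = nig_mode_variance(alpha, beta, dim)

# Brute-force check: maximize the IG(alpha + D/2, beta) kernel on a fine grid.
grid = [0.001 * k for k in range(1, 2001)]
best = max(grid, key=lambda s2: log_inverse_gamma_kernel(s2, alpha + dim / 2.0, beta))
```

The grid maximizer should land within one grid step of the closed-form value.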
- mvn: MultivariateNormal
- property precision: Array
Base precision matrix Λ of the MVN component.
- class sppcax.distributions.Normal(loc: Array = 0.0, scale: Array = 1.0)[source]
Bases: ExponentialFamily
Univariate normal distribution in natural parameters.
The normal distribution has density: p(x|μ,σ) = 1/√(2πσ²) * exp(-(x-μ)²/(2σ²))
In exponential family form:
η = [μ/σ², -1/(2σ²)]
T(x) = [x, x²]
A(η) = -η₁²/(4η₂) - (1/2)log(-2η₂) + (1/2)log(2π)
- property entropy: Array
Compute entropy of the distribution.
- Returns
Entropy with shape batch_shape
- property expected_sufficient_statistics: Array
Compute E[T(x)] = [μ, μ² + σ²].
- Returns
Expected sufficient statistics [E[x], E[x²]] with shape batch_shape + (2,)
- classmethod from_natural_parameters(eta: Array) Normal[source]
Create normal distribution from natural parameters.
- Parameters
eta – Natural parameters [η₁, η₂] with shape: batch_shape + (2,)
- Returns
Normal distribution.
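The mapping between (μ, σ) and the natural parameters is invertible: σ² = −1/(2η₂) and μ = η₁σ². A minimal round-trip sketch follows, using the standard library only and hypothetical helper names rather than the sppcax classmethod itself.

```python
import math


def normal_to_natural(mu, sigma):
    """(mu, sigma) -> [mu / sigma^2, -1 / (2 sigma^2)]."""
    return mu / sigma ** 2, -0.5 / sigma ** 2


def natural_to_normal(eta1, eta2):
    """Invert the map: sigma^2 = -1 / (2 eta2), mu = eta1 * sigma^2."""
    variance = -1.0 / (2.0 * eta2)
    return eta1 * variance, math.sqrt(variance)


eta1, eta2 = normal_to_natural(1.5, 0.5)
loc, scale = natural_to_normal(eta1, eta2)
```

The second natural parameter must be negative for the variance to be well defined.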
- property loc: Array
Get location parameter.
- log_base_measure(x: Array = None) Array[source]
Compute log of base measure h(x).
- Parameters
x – Data to compute base measure for. Shape: batch_shape + event_shape
- Returns
Log base measure log(h(x)) with shape batch_shape
- property log_normalizer: Array
Compute log normalizer A(η).
- Returns
Log normalizer with shape batch_shape
- nat1: Array
- nat2: Array
- property natural_parameters: Array
Get natural parameters η = [precision*mean, -0.5*precision].
- Returns
Natural parameters [η₁, η₂] with shape batch_shape + (2,)
- property precision: Array
Get precision parameter
- sample(key: Union[Key[Array, ''], UInt32[Array, '2']], sample_shape: Tuple[int, ...] = ()) Array[source]
Sample from the distribution.
- Parameters
key – PRNG key for random sampling.
sample_shape – Shape of samples to draw.
- Returns
Samples with shape sample_shape + batch_shape + event_shape
- property scale: Array
Get scale parameter.
- class sppcax.distributions.Poisson(log_rate: Array)[source]
Bases: ExponentialFamily
Poisson distribution parameterized by log rate.
The Poisson distribution has natural parameter η = log(λ) where λ is the rate parameter, and sufficient statistic T(x) = x.
- property entropy: Array
Compute entropy of Poisson distribution.
- Returns
Entropy H(λ) = λ(1 - log(λ)) + exp(-λ)sum_{k=0}^∞ λ^k log(k!)/k!
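The series form of the entropy can be compared with the definition H = −Σ p(k) log p(k). The sketch below truncates both sums at a fixed number of terms, which is an assumption that holds for moderate rates; it uses only the standard library, and the function names are hypothetical rather than part of sppcax.

```python
import math


def poisson_entropy_series(rate, terms=200):
    """H = rate * (1 - log(rate)) + exp(-rate) * sum_k rate^k log(k!) / k!."""
    series = sum(
        math.exp(k * math.log(rate) - math.lgamma(k + 1.0)) * math.lgamma(k + 1.0)
        for k in range(terms)
    )
    return rate * (1.0 - math.log(rate)) + math.exp(-rate) * series


def poisson_entropy_direct(rate, terms=200):
    """H = -sum_k p(k) log p(k), truncated at `terms`."""
    total = 0.0
    for k in range(terms):
        log_p = k * math.log(rate) - rate - math.lgamma(k + 1.0)
        total -= math.exp(log_p) * log_p
    return total


h_series = poisson_entropy_series(3.0)
h_direct = poisson_entropy_direct(3.0)
```

Working with log(k!) through `math.lgamma` avoids overflow from large factorials in the tail of the sum.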
- property expected_sufficient_statistics: Array
Compute E[T(x)] = E[x] = λ = exp(η).
- Returns
Expected sufficient statistics E[x].
- classmethod from_natural_parameters(eta: Array) Poisson[source]
Create Poisson from natural parameters.
- Parameters
eta – Natural parameter η = log(λ).
- Returns
Poisson instance.
- log_base_measure(x: Array) Array[source]
Compute log base measure log(h(x)) = -log(x!).
- Parameters
x – Count data.
- Returns
Log base measure -log(x!).
- property log_normalizer: Array
Compute log normalizer A(η) = exp(η).
- Returns
Log normalizer A(η).
- property log_rate: Array
Get log(rate) parameter.
- nat1: Array
- property natural_parameters: Array
Get natural parameters (log rate).
- Returns
Natural parameters η = log(λ).
- property rate: Array
Get rate parameter.