Usage
This section provides examples of how to use the sppcax package for Bayesian Factor Analysis and Model Reduction.
Basic Example
Here’s a simple example of using the Bayesian Factor Analysis model:
import jax.numpy as jnp
import jax.random as jr
from sppcax.models import PPCA, fit, transform, inverse_transform
# Set up the PRNG key and problem dimensions
key = jr.PRNGKey(0)
n_samples, n_features, n_components = 100, 20, 5
# Create a model with 5 components
model = PPCA(n_components=n_components,
             n_features=n_features,
             random_state=key)
# Placeholder data (all ones); replace with your own (n_samples, n_features) array
X = jnp.ones((n_samples, n_features))
# Fit the model
key, _key = jr.split(key)
model, elbos = fit(model, X, n_iter=50, key=_key)
# Transform data to latent space
qz = transform(model, X)
# Reconstruct the data
reconstructed_X = inverse_transform(model, qz).mean
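An array of ones carries no latent structure, so it is of limited use for checking a fit. The following is a minimal sketch of generating data with a known low-rank structure. It uses plain NumPy rather than sppcax, and the helper name `make_low_rank_data` and its `noise_std` default are illustrative assumptions, not part of the package:

```python
import numpy as np

def make_low_rank_data(n_samples, n_features, n_components, noise_std=0.1, seed=0):
    """Draw X = Z @ W.T + noise from a linear-Gaussian factor model.

    Hypothetical helper for illustration only; it is not part of sppcax.
    """
    rng = np.random.default_rng(seed)
    Z = rng.standard_normal((n_samples, n_components))   # latent factors
    W = rng.standard_normal((n_features, n_components))  # loading matrix
    noise = noise_std * rng.standard_normal((n_samples, n_features))
    return Z @ W.T + noise, Z, W

X, Z, W = make_low_rank_data(n_samples=100, n_features=20, n_components=5)
print(X.shape)  # (100, 20)
```

The result can be converted with `jnp.asarray(X)` before being passed to `fit`.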
Factor Analysis vs. PPCA
sppcax provides two main variants of Bayesian Factor Analysis:
Probabilistic PCA (PPCA): Uses isotropic noise (same precision for all features)
from sppcax.models import PPCA
model = PPCA(n_components=5, n_features=20)
Factor Analysis (FA): Uses diagonal noise (different precision for each feature)
from sppcax.models import PFA
model = PFA(n_components=5, n_features=20)
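The difference between the two noise models can be made concrete with a small NumPy sketch of the noise covariance each one implies. The precision values are illustrative, and the code shows the general definitions rather than how sppcax stores these quantities internally:

```python
import numpy as np

n_features = 4

# PPCA-style isotropic noise: a single precision shared by every feature.
tau = 2.0  # illustrative value
cov_isotropic = np.eye(n_features) / tau

# FA-style diagonal noise: a separate precision for each feature.
tau_per_feature = np.array([0.5, 1.0, 2.0, 4.0])  # illustrative values
cov_diagonal = np.diag(1.0 / tau_per_feature)

# Every diagonal entry of the isotropic covariance equals 1 / tau,
# while the diagonal covariance has one variance per feature.
print(np.diag(cov_isotropic))
print(np.diag(cov_diagonal))
```

The isotropic model has one noise parameter regardless of dimensionality, whereas the diagonal model has one per feature, which matters when features are measured on different scales.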
Handling Missing Data (Partial Observations)
Both models can handle missing data when given a boolean mask in which True
marks observed values and False marks missing ones. The mask can have shape
(T, D) for per-element masking, where T is the number of observations and D
the number of features, or shape (T,) to mask entire observation vectors at
specific timesteps.
For a comprehensive worked example comparing masked models against subset baselines across EM, VB-EM, and Gibbs inference, see the Masked Observations notebook.
Here is a minimal example:
import jax.numpy as jnp
from sppcax.models import PPCA, fit, transform
# Data with some missing values (marked as False in the mask)
data = jnp.ones((100, 20))
mask = jnp.ones((100, 20), dtype=bool)
# Set some values as missing
mask = mask.at[10:20, 5:10].set(False)
# Create a model with the mask
model = PPCA(n_components=5,
             n_features=20,
             data_mask=mask)
# Fit the model
model, elbos = fit(model, data)
# Transform can use the mask to handle missing values in new data
latent = transform(model, data, use_data_mask=True)
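In practice, missing values often arrive as NaN entries rather than as a ready-made mask. The sketch below uses plain NumPy to derive both mask shapes described in this section from such an array. Filling the missing entries with a finite placeholder assumes that masked entries are ignored by the model, which is not confirmed here for sppcax:

```python
import numpy as np

T, D = 6, 3
data = np.arange(T * D, dtype=float).reshape(T, D)
data[1, 2] = np.nan  # a single missing entry
data[4, :] = np.nan  # an entire missing observation vector

# Per-element mask of shape (T, D): True where a value is observed.
mask_elementwise = ~np.isnan(data)

# Per-timestep mask of shape (T,): True only where the whole row is observed.
mask_timestep = mask_elementwise.all(axis=1)

# The (T, D) mask that a (T,) mask corresponds to: each row shares one flag.
mask_expanded = np.broadcast_to(mask_timestep[:, None], (T, D))

# Assumption: masked entries are ignored, so a finite placeholder is safe.
data_filled = np.where(mask_elementwise, data, 0.0)

print(mask_elementwise.shape, mask_timestep.shape)  # (6, 3) (6,)
```

Note that the per-timestep mask is coarser: row 1 has only one missing entry, yet it is marked as missing in full.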
Bayesian Model Reduction
The Bayesian Model Reduction (BMR) algorithm can be used to prune unnecessary parameters in the loading matrix:
import jax.numpy as jnp
import jax.random as jr
from sppcax.models import PPCA, fit
from sppcax.bmr.delta_f import compute_delta_f
key = jr.PRNGKey(0)
# Create a model with BMR enabled
model = PPCA(
    n_components=5,
    n_features=20,
    optimize_with_bmr=True,
    bmr_e_step=True,
    bmr_m_step=True,
    bmr_e_step_opts=('max_iter', 2, 'pi', 0.2)
)
# optimize_with_bmr controls empirical-Bayes-like hyperparameter optimization.
# bmr_e_step controls BMR during the VBE-step, where the posterior over latents is pruned.
# bmr_m_step controls BMR during the VBM-step, where the posterior over loading matrix elements is pruned.
data = jnp.ones((100, 20))
key, _key = jr.split(key)
fitted_model, elbos = fit(model, data, n_iter=256, bmr_frequency=16, key=_key)
# bmr_frequency sets how often BMR pruning runs during the VBM-step; here,
# pruning is performed every 16 steps.
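To give some intuition for what a pruning decision is based on, here is a sketch of the standard Bayesian Model Reduction result for a single parameter with a zero-mean Gaussian prior. It is the textbook scalar form, not necessarily the computation performed by `sppcax.bmr.delta_f.compute_delta_f`; the function name `delta_f_gaussian` and its arguments are hypothetical:

```python
import math

def delta_f_gaussian(mu, post_prec, prior_prec, reduced_prior_prec):
    """Change in log evidence when a zero-mean Gaussian prior with precision
    `prior_prec` is replaced by one with precision `reduced_prior_prec`,
    given the full model's posterior N(mu, 1 / post_prec).

    Hypothetical helper for illustration only; it is not part of sppcax.
    """
    reduced_post_prec = post_prec + reduced_prior_prec - prior_prec
    if reduced_post_prec <= 0:
        raise ValueError("reduced posterior precision must be positive")
    log_term = 0.5 * math.log(
        reduced_prior_prec * post_prec / (prior_prec * reduced_post_prec)
    )
    quad_term = (
        0.5 * mu**2 * post_prec * (reduced_prior_prec - prior_prec) / reduced_post_prec
    )
    return log_term - quad_term

# Posterior centered near zero: a much tighter prior raises the evidence.
print(delta_f_gaussian(0.01, 50.0, 1.0, 1e4) > 0)  # True
# Posterior well away from zero: the same reduction lowers the evidence.
print(delta_f_gaussian(1.0, 50.0, 1.0, 1e4) > 0)   # False
```

A positive value favors the reduced (pruned) model and a negative value favors keeping the parameter. Only the full model's posterior is needed, so candidate reductions can be scored without refitting.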
Advanced Usage
For more advanced usage, refer to the Jupyter notebooks provided in the examples directory.