[1]:
import os

# These environment variables are assigned before `import jax` below, so they
# are already in place when JAX and its CUDA/XLA backend are first loaded.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # GPU indices follow PCI bus order
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only device index 0 to this process
# Per the JAX GPU memory allocation docs: "false" turns off up-front preallocation
# of GPU memory, and the "platform" allocator allocates/frees memory on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax

# NOTE(review): a GPU is made visible above, yet the default device is set to
# 'cpu' here -- confirm this combination is intentional.
jax.config.update('jax_default_device', 'cpu')
# Request the highest available precision for matrix multiplications.
jax.config.update('jax_default_matmul_precision', 'highest')
[2]:
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial
from jax import vmap
from scipy.linalg import orthogonal_procrustes, solve
from sppcax.distributions import MultivariateNormalInverseGamma as MVNIG, MultivariateNormal
from sppcax.models import BayesianDynamicFactorAnalysis as BDFA
from sppcax.models.dynamic_factor_analysis import (
ParamsLGSSM, ParamsLGSSMInitial,
)
from sppcax.inference.utils import ParamsLGSSMVB
from dynamax.linear_gaussian_ssm import parallel_lgssm_smoother
Testing PX-VBEM for Dynamic Factor Analysis
This notebook tests the Parameter-Expanded VBEM (PX-VBEM) algorithm applied to Bayesian Dynamic Factor Analysis (DFA).
The PX-VB rotation matrix \(\mathbf{R}\) is found by numerically minimizing the full objective \(\mathbb{E}_q[-\ln p(\tilde{\mathbf{H}}, \tilde{\mathbf{F}}, \tilde{\mathbf{x}} \mid \mathbf{R})]\) using gradient descent with Anderson acceleration (m=1).
The rotation \(\mathbf{R}\) transforms the posteriors as follows:
- Emissions: \(\tilde{\mathbf{H}} = \mathbf{H} \mathbf{R}_\text{block}\), where \(\mathbf{R}_\text{block} = \text{blkdiag}(\mathbf{R}, \mathbf{I})\)
- Dynamics: \(\tilde{\mathbf{F}} = \mathbf{R}^{-1} \mathbf{F} \mathbf{R}_\text{block}\)
- Initial state: \(\tilde{\boldsymbol{\mu}}_0 = \mathbf{R}^{-1} \boldsymbol{\mu}_0\), \(\tilde{\boldsymbol{\Sigma}}_0 = \mathbf{R}^{-1} \boldsymbol{\Sigma}_0 \mathbf{R}^{-\top}\)
- Process noise: \(\mathbf{Q} = \mathbf{I}\) (kept fixed)
1. Generate Synthetic DFA Data
[3]:
# ---- Problem size ----
T = 200  # number of time steps
D = 20   # number of observed features
K = 2    # number of true latent components

key = jr.PRNGKey(137)

# ---- Ground-truth dynamics: 2D rotation by `theta`, scaled by `rho` ----
theta = jnp.pi / 6
rho = 0.95
F_true = rho * jnp.array([
    [jnp.cos(theta), -jnp.sin(theta)],
    [jnp.sin(theta), jnp.cos(theta)],
])
Q_true = jnp.eye(K)

# ---- Ground-truth loadings: Gaussian entries masked by Bernoulli(0.5) draws ----
key, k1, k2 = jr.split(key, 3)
H_true = jr.normal(k1, (D, K)) * jr.bernoulli(k2, p=0.5, shape=(D, K))

# ---- Ground-truth diagonal observation-noise covariance ----
key, k1 = jr.split(key)
R_true_diag = jr.uniform(k1, (D,), minval=0.25, maxval=1.0)
R_true = jnp.diag(R_true_diag)

# ---- Initial-state distribution ----
mu0_true = jnp.zeros(K)
Sigma0_true = jnp.eye(K)

# ---- Simulate latent states and observations ----
# One key split per time step: k1 drives the latent draw, k2 the observation.
z_steps, y_steps = [], []
for t in range(T):
    key, k1, k2 = jr.split(key, 3)
    if t == 0:
        # First state comes from the initial-state distribution.
        z_curr = jr.multivariate_normal(k1, mu0_true, Sigma0_true)
    else:
        # Later states follow the linear dynamics from the previous state.
        z_curr = jr.multivariate_normal(k1, F_true @ z_curr, Q_true)
    z_steps.append(z_curr)
    y_steps.append(jr.multivariate_normal(k2, H_true @ z_curr, R_true))

z_true = jnp.stack(z_steps)  # shape (T, K)
y_obs = jnp.stack(y_steps)   # shape (T, D)
# Numeric summary of the simulated system.
print(f"y_obs shape: {y_obs.shape}, H_true shape: {H_true.shape}")
print(f"F_true eigenvalues: {jnp.linalg.eigvals(F_true)}")
print(f"H_true sparsity: {jnp.isclose(H_true, 0.).mean():.1%} zeros")

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax_phase, ax_obs, ax_load = axes

# Panel 1: trajectory of the first latent component against the second.
ax_phase.plot(z_true[:, 0], z_true[:, 1], alpha=0.5, lw=0.5)
ax_phase.set(xlabel='$z_1$', ylabel='$z_2$', title='Latent Phase Portrait')
ax_phase.set_aspect('equal')

# Panel 2: heatmap of the first 100 observation rows.
sns.heatmap(y_obs[:100], ax=ax_obs, cmap='RdBu_r', center=0)
ax_obs.set_title('Observations (first 100)')

# Panel 3: heatmap of the true loading matrix.
sns.heatmap(H_true, ax=ax_load, cmap='RdBu_r', center=0)
ax_load.set_title('True Loading Matrix H')

plt.tight_layout()
y_obs shape: (200, 20), H_true shape: (20, 2)
F_true eigenvalues: [0.8227241+0.475j 0.8227241-0.475j]
H_true sparsity: 42.5% zeros
2. Experiments Without BMR
[4]:
NUM_ITERS = 20

# ---- Standard EM ----
key, k1, k2 = jr.split(key, 3)
model = BDFA(2 * K, D, use_px=False)
params_init, props = model.initialize(k1)
params_em, elbos_em = model.fit_em(
    params_init, props, y_obs, k2, num_iters=NUM_ITERS
)
print(f"EM final ELBO: {elbos_em[-1]:.1f}")

# ---- PX-EM: same initial parameters and same key (k2) as the EM run ----
model_px = BDFA(2 * K, D, use_px=True)
params_pxl, elbos_pxl = model_px.fit_em(
    params_init, props, y_obs, k2, num_iters=NUM_ITERS
)
print(f"PX-EM final ELBO: {elbos_pxl[-1]:.1f}")

# ---- Standard VBEM: fresh keys and a variational-Bayes initialization ----
key, k1, k2 = jr.split(key, 3)
model = BDFA(2 * K, D, use_px=False)
params_vbem_init, props_vbem = model.initialize(k1, variational_bayes=True)
params_vbem, elbos_vbem = model.fit_vbem(
    params_vbem_init, props_vbem, y_obs, k2, num_iters=NUM_ITERS
)
print(f"VBEM final ELBO: {elbos_vbem[-1]:.1f}")

# ---- PX-VBEM: same initial parameters and same key (k2) as the VBEM run ----
model_px = BDFA(2 * K, D, use_px=True)
params_pxl_vb, elbos_pxl_vb = model_px.fit_vbem(
    params_vbem_init, props_vbem, y_obs, k2, num_iters=NUM_ITERS
)
print(f"PX-VBEM final ELBO: {elbos_pxl_vb[-1]:.1f}")
EM final ELBO: -5607.1
PX-EM final ELBO: -5361.8
VBEM final ELBO: -5638.9
PX-VBEM final ELBO: -5414.5
Convergence Comparison (No BMR)
[5]:
# ELBO traces of the four fits without BMR, keyed by legend label.
elbo_traces = {
    'EM': elbos_em,
    'PX-EM': elbos_pxl,
    'VBEM': elbos_vbem,
    'PX-VBEM': elbos_pxl_vb,
}

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: every entry after index 0. Right panel: entries 1..19 only.
panels = [
    (axes[0], slice(1, None), 'ELBO Convergence (No BMR)'),
    (axes[1], slice(1, 20), 'ELBO Convergence \u2014 First 20 Iterations'),
]
for ax, window, title in panels:
    for label, trace in elbo_traces.items():
        ax.plot(trace[window], label=label, alpha=0.8)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('ELBO')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
Loading Matrix & Latent State Recovery (No BMR)
[6]:
def procrustes_similarity(H_est, H_true):
    """Align an estimated loading matrix to a reference via orthogonal Procrustes.

    The two matrices may have different numbers of columns: the narrower one is
    right-padded with zero columns so that both have
    ``K_max = max(K_est, K_true)`` columns before the alignment is computed.

    Args:
        H_est: Estimated loadings, 2-D array of shape ``(D, K_est)``.
        H_true: Reference loadings, 2-D array of shape ``(D, K_true)``.

    Returns:
        A 3-tuple ``(H_aligned, disparity, R)``:

        * ``H_aligned`` -- the (padded) ``H_est`` multiplied by ``R``,
          shape ``(D, K_max)``.
        * ``disparity`` -- Python ``float``; Frobenius norm of
          ``H_aligned`` minus the (padded) ``H_true``.
        * ``R`` -- the ``(K_max, K_max)`` orthogonal matrix returned by
          ``scipy.linalg.orthogonal_procrustes``.
    """
    K_est, K_true = H_est.shape[1], H_true.shape[1]
    K_max = max(K_est, K_true)
    # A zero-width pad is a no-op, so equal widths need no special case.
    A = jnp.pad(H_est, [(0, 0), (0, K_max - K_est)])
    B = jnp.pad(H_true, [(0, 0), (0, K_max - K_true)])
    # R minimizes ||A @ R - B||_F over orthogonal matrices.
    R, _ = orthogonal_procrustes(A, B)
    H_aligned = A @ R
    disparity = jnp.linalg.norm(H_aligned - B, 'fro')
    return H_aligned, float(disparity), R
# Fitted parameters of the four runs without BMR, keyed by method name.
results_no_bmr = {
    'EM': params_em,
    'PX-EM': params_pxl,
    'VBEM': params_vbem,
    'PX-VBEM': params_pxl_vb,
}

# ---- Loading recovery ----
fig, axes = plt.subplots(1, 5, figsize=(20, 6))
vmax = jnp.abs(H_true).max()
heat_kw = dict(cmap='RdBu_r', center=0, vmin=-vmax, vmax=vmax)  # shared colour scale

# Panel 0: the true loadings, widened by K zero columns.
sns.heatmap(jnp.pad(H_true, [(0, 0), (0, K)]), ax=axes[0], **heat_kw)
axes[0].set_title('True H')

Ezs = {}            # method name -> smoothed latent means after alignment
procrustes_Rs = {}  # method name -> Procrustes rotation for that method
for ax, (name, p) in zip(axes[1:], results_no_bmr.items()):
    H_aligned, disparity, R = procrustes_similarity(p.emissions.weights, H_true)
    procrustes_Rs[name] = R
    sns.heatmap(H_aligned, ax=ax, **heat_kw)
    ax.set_title(f'{name}\n||.||={disparity:.2f}')

    # Smoothed means are mapped through R^{-1}: solve R x = E[z_t] for every t.
    post = parallel_lgssm_smoother(p, y_obs)
    Ezs[name] = jax.scipy.linalg.solve(R, post.smoothed_means[..., None])[..., 0]

plt.suptitle('Loading Matrix Recovery (No BMR)', y=1.02)
plt.tight_layout()
# ---- Transition recovery: each estimate is rotated as R^T @ F_est @ R ----
K_true = H_true.shape[1]
K_est = params_em.emissions.weights.shape[1]
K_max = max(K_est, K_true)

# Zero-pad the true transition matrix (bottom and right) to K_max x K_max.
pad_f = K_max - K_true
F_true_pad = jnp.pad(F_true, [(0, pad_f), (0, pad_f)])

fig, axes = plt.subplots(1, 5, figsize=(20, 4))
vmax_f = max(jnp.abs(F_true).max(), 0.5)
f_heat_kw = dict(cmap='RdBu_r', center=0, vmin=-vmax_f, vmax=vmax_f, annot=True, fmt='.2f')

sns.heatmap(F_true_pad, ax=axes[0], **f_heat_kw)
axes[0].set_title('True F')

for ax, (name, p) in zip(axes[1:], results_no_bmr.items()):
    R = procrustes_Rs[name]
    F_aligned = R.T @ p.dynamics.weights @ R
    err = jnp.linalg.norm(F_aligned - F_true_pad, 'fro')
    sns.heatmap(F_aligned, ax=ax, **f_heat_kw)
    ax.set_title(f'{name}\n||.||={err:.3f}')

plt.suptitle('Transition Matrix Recovery (No BMR)', y=1.02)
plt.tight_layout()
# ---- Latent state recovery for one selected method ----
name = 'PX-VBEM'
fig, axes = plt.subplots(K, 1, figsize=(12, 3 * K), sharex=True)

# One row per true latent component: truth (dashed) against the aligned estimate.
for i, ax in enumerate(axes):
    ax.plot(z_true[:, i], label=f'True $z_{i}$', ls='--', alpha=0.7)
    ax.plot(Ezs[name][:, i], label=f'{name} $E[z_{i}]$', alpha=0.7)
    ax.legend()
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel('Time')
plt.suptitle(f'Latent State Recovery ({name}, No BMR)')
plt.tight_layout()
3. Experiments With BMR
[7]:
NUM_ITERS_BMR = 20
use_ard = True
comp_dim = 2 * K

# Keyword arguments common to both models / both fits in this cell.
em_bmr_model_kw = dict(has_ard=use_ard, use_bmr=True)
em_bmr_fit_kw = dict(num_iters=NUM_ITERS_BMR, bmr_start_iter=8)

# ---- Standard EM + BMR ----
key, k1, k2 = jr.split(key, 3)
model_bmr = BDFA(comp_dim, D, use_px=False, **em_bmr_model_kw)
params_init_bmr, props_bmr = model_bmr.initialize(k1)
params_em_bmr, elbos_em_bmr = model_bmr.fit_em(
    params_init_bmr, props_bmr, y_obs, k2, **em_bmr_fit_kw
)
print(f"EM+BMR final ELBO: {elbos_em_bmr[-1]:.1f}")

# ---- PX-EM + BMR: same initial parameters and same key (k2) as above ----
model_bmr2 = BDFA(comp_dim, D, use_px=True, **em_bmr_model_kw)
params_pxl_bmr, elbos_pxl_bmr = model_bmr2.fit_em(
    params_init_bmr, props_bmr, y_obs, k2, **em_bmr_fit_kw
)
print(f"PX-EM+BMR final ELBO: {elbos_pxl_bmr[-1]:.1f}")
EM+BMR final ELBO: -5412.8
PX-EM+BMR final ELBO: -5243.1
[8]:
# Keyword arguments common to both models / both fits in this cell.
vb_bmr_model_kw = dict(has_ard=use_ard, use_bmr=True)
vb_bmr_fit_kw = dict(num_iters=NUM_ITERS_BMR, bmr_start_iter=8)

# ---- Standard VBEM + BMR ----
key, k1, k2 = jr.split(key, 3)
model_bmr3 = BDFA(comp_dim, D, use_px=False, **vb_bmr_model_kw)
params_vbem_bmr_init, props_vbem_bmr = model_bmr3.initialize(k1, variational_bayes=True)
params_vbem_bmr, elbos_vbem_bmr = model_bmr3.fit_vbem(
    params_vbem_bmr_init, props_vbem_bmr, y_obs, k2, **vb_bmr_fit_kw
)
print(f"VBEM+BMR final ELBO: {elbos_vbem_bmr[-1]:.1f}")

# ---- PX-VBEM + BMR: same initial parameters and same key (k2) as above ----
model_bmr4 = BDFA(comp_dim, D, use_px=True, **vb_bmr_model_kw)
params_pxl_vb_bmr, elbos_pxl_vb_bmr = model_bmr4.fit_vbem(
    params_vbem_bmr_init, props_vbem_bmr, y_obs, k2, **vb_bmr_fit_kw
)
print(f"PX-VBEM+BMR final ELBO: {elbos_pxl_vb_bmr[-1]:.1f}")
VBEM+BMR final ELBO: -5540.4
PX-VBEM+BMR final ELBO: -5259.5
Convergence Comparison (With BMR)
[9]:
# ELBO traces of the four BMR fits, keyed by legend label. A single mapping
# feeds both panels so every method carries the same "PX-" label in each legend.
elbo_traces_bmr = {
    'EM+BMR': elbos_em_bmr,
    'PX-EM+BMR': elbos_pxl_bmr,
    'VBEM+BMR': elbos_vbem_bmr,
    'PX-VBEM+BMR': elbos_pxl_vb_bmr,
}

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: every entry after index 0. Right panel: entries 1..19 only.
panels_bmr = [
    (axes[0], slice(1, None), 'ELBO Convergence (With BMR)'),
    (axes[1], slice(1, 20), 'ELBO Convergence \u2014 First 20 Iterations'),
]
for ax, window, title in panels_bmr:
    for label, trace in elbo_traces_bmr.items():
        ax.plot(trace[window], label=label, alpha=0.8)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('ELBO')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
Loading Matrix Recovery (With BMR)
[10]:
# Fitted parameters of the four BMR runs, keyed by method name.
results_bmr = {
    'EM+BMR': params_em_bmr,
    'PX-EM+BMR': params_pxl_bmr,
    'VBEM+BMR': params_vbem_bmr,
    'PX-VBEM+BMR': params_pxl_vb_bmr,
}

# ---- Loading recovery ----
fig, axes = plt.subplots(1, 5, figsize=(20, 6))
vmax = jnp.abs(H_true).max()
heat_kw_bmr = dict(cmap='RdBu_r', center=0, vmin=-vmax, vmax=vmax)  # shared colour scale

# Panel 0: the true loadings, zero-padded on the right to comp_dim columns.
H_true_wide = jnp.pad(H_true, [(0, 0), (0, comp_dim - K)])
sns.heatmap(H_true_wide, ax=axes[0], **heat_kw_bmr)
axes[0].set_title('True H')

procrustes_Rs_bmr = {}  # method name -> Procrustes rotation for that method
for ax, (name, p) in zip(axes[1:], results_bmr.items()):
    H_aligned, disparity, R = procrustes_similarity(p.emissions.weights, H_true)
    procrustes_Rs_bmr[name] = R
    # Fraction of un-aligned loading entries with magnitude below 0.05.
    sparsity = jnp.mean(jnp.abs(p.emissions.weights) < 0.05)
    sns.heatmap(H_aligned, ax=ax, **heat_kw_bmr)
    ax.set_title(f'{name}\n||.||={disparity:.2f}, Sp={sparsity:.0%}')

plt.suptitle('Loading Matrix Recovery (With BMR)', y=1.02)
plt.tight_layout()
# ---- Transition recovery ----
# Crop or zero-pad the true transition matrix to the fitted size K_bmr x K_bmr.
# K_true is the true latent dimension computed in the no-BMR recovery cell.
K_bmr = params_em_bmr.dynamics.weights.shape[0]
if K_bmr <= K_true:
    F_true_bmr = F_true[:K_bmr, :K_bmr]
else:
    extra = K_bmr - K_true
    F_true_bmr = jnp.pad(F_true, [(0, extra), (0, extra)])

fig, axes = plt.subplots(1, 5, figsize=(20, 4))
vmax_f = max(jnp.abs(F_true).max(), 0.5)
f_heat_kw_bmr = dict(cmap='RdBu_r', center=0, vmin=-vmax_f, vmax=vmax_f, annot=True, fmt='.2f')

sns.heatmap(F_true_bmr, ax=axes[0], **f_heat_kw_bmr)
axes[0].set_title('True F')

# Each estimate is rotated as R^T @ F_est @ R with that method's Procrustes R.
for ax, (name, p) in zip(axes[1:], results_bmr.items()):
    R = procrustes_Rs_bmr[name]
    F_aligned = R.T @ p.dynamics.weights @ R
    err = jnp.linalg.norm(F_aligned - F_true_bmr, 'fro')
    sns.heatmap(F_aligned, ax=ax, **f_heat_kw_bmr)
    ax.set_title(f'{name}\n||.||={err:.3f}')

plt.suptitle('Transition Matrix Recovery (With BMR)', y=1.02)
plt.tight_layout()
4. Summary
DFA PX-VB rotation: The rotation \(\mathbf{R}\) is found by numerically minimizing the full objective \(\mathcal{L}(\mathbf{R}) = \mathbb{E}_q[-\ln p(\tilde{\mathbf{H}}, \tilde{\mathbf{F}}, \tilde{\mathbf{x}} \mid \mathbf{R})]\) via gradient descent with Anderson acceleration (m=1). The loss has four terms:
- \(\mathcal{L}_\text{init}\): initial state prior
- \(\mathcal{L}_\text{em}\): emission prior
- \(\mathcal{L}_\text{dyn-prior}\): dynamics prior
- \(\mathcal{L}_\text{dyn-lik}\): dynamics likelihood
The converged \(\mathbf{R}\) transforms the posteriors as follows:
- Emissions: \(\tilde{\mathbf{H}} = \mathbf{H} \mathbf{R}_\text{block}\), where \(\mathbf{R}_\text{block} = \text{blkdiag}(\mathbf{R}, \mathbf{I})\)
- Dynamics: \(\tilde{\mathbf{F}} = \mathbf{R}^{-1} \mathbf{F} \mathbf{R}_\text{block}\)
- Initial state: \(\tilde{\boldsymbol{\mu}}_0 = \mathbf{R}^{-1} \boldsymbol{\mu}_0\), \(\tilde{\boldsymbol{\Sigma}}_0 = \mathbf{R}^{-1} \boldsymbol{\Sigma}_0 \mathbf{R}^{-\top}\)
- Process noise: \(\mathbf{Q} = \mathbf{I}\) (kept fixed)
If the optimization does not reduce the loss, \(\mathbf{R}\) falls back to identity.
Expected findings:
- PX-EM converges faster than standard EM
- The 2D rotation dynamics test whether PX-EM can correctly handle coupled latent dynamics
- BMR enables sparse loading recovery