# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Recovery tests against a clean DGP that the IO-HMM models exactly."""
from __future__ import annotations
import numpy as np
import pytest
from iohmm_evac.diagnostics.alignment import align_states, apply_permutation
from iohmm_evac.diagnostics.recovery import (
align_fit_to_truth,
parameter_recovery,
state_recovery_accuracy,
)
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.fit import fit
from iohmm_evac.inference.fit_params import (
EmissionFitParams,
FitParameters,
InitialFitParams,
TransitionFitParams,
)
from tests._clean_dgp import generate
def _clean_truth() -> FitParameters:
    """Build the fixed 5-state, 3-feature ground-truth parameter set.

    ``alpha`` is ``-inf`` everywhere except the diagonal (0.0) and six
    off-diagonal entries; those six form a mostly forward-moving cascade
    (the only backward entry is 1 -> 0).  ``beta`` is zero except for one or
    two feature entries on the forward moves.
    """
    n_states, n_features = 5, 3

    # Finite off-diagonal alpha entries, keyed by (row, column).
    open_moves = {
        (0, 1): -1.5,
        (1, 0): -3.0,
        (1, 2): -2.0,
        (2, 3): -1.5,
        (2, 4): -3.0,
        (3, 4): -1.0,
    }
    alpha = np.full((n_states, n_states), -np.inf, dtype=np.float64)
    np.fill_diagonal(alpha, 0.0)
    for (row, col), value in open_moves.items():
        alpha[row, col] = value

    # Non-zero beta entries, keyed by (row, column, feature).
    active_slopes = {
        (0, 1, 0): 1.5,
        (1, 2, 0): 1.0,
        (1, 2, 1): 0.8,
        (2, 3, 1): 1.4,
        (2, 4, 2): 0.5,
        (3, 4, 2): 1.6,
    }
    beta = np.zeros((n_states, n_states, n_features))
    for index, value in active_slopes.items():
        beta[index] = value

    return FitParameters(
        initial=InitialFitParams(logits=np.array([0.0, -3.0, -3.0, -3.0, -3.0])),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=EmissionFitParams(
            p_departure=np.array([0.05, 0.10, 0.20, 0.95, 0.05]),
            mu_displacement=np.array([0.0, 0.5, 1.0, 5.0, 8.0]),
            sigma_displacement=np.array([0.5, 0.5, 0.5, 0.5, 0.5]),
            lambda_comm=np.array([0.5, 1.5, 3.0, 2.0, 1.0]),
        ),
    )
def _make_clean_data(
    rng: np.random.Generator, n: int, t_plus_1: int
) -> tuple[FitParameters, FitData]:
    """Simulate a dataset from the clean truth and return ``(truth, data)``.

    The input array has shape ``(n, t_plus_1, F)`` where ``F`` is the last
    axis of the truth's ``beta``.  Feature 0 is a 0 -> 1 linear ramp along
    axis 1, feature 1 switches to 1.0 from index ``t_plus_1 // 3`` onward,
    and feature 2 switches to 1.0 from index ``2 * t_plus_1 // 3`` onward.
    Gaussian noise scaled by 0.1 is then added to every entry.

    ``rng`` is consumed first for the input noise and then by ``generate``.
    """
    truth = _clean_truth()
    n_features = truth.transitions.beta.shape[2]

    covariates = np.zeros((n, t_plus_1, n_features), dtype=np.float64)
    ramp = np.linspace(0.0, 1.0, t_plus_1)
    covariates[:, :, 0] = ramp[None, :]

    first_step = t_plus_1 // 3
    second_step = 2 * t_plus_1 // 3
    covariates[:, first_step:, 1] = 1.0
    covariates[:, second_step:, 2] = 1.0

    # In-place jitter; must be drawn before generate() to keep the RNG stream.
    covariates += 0.1 * rng.standard_normal(covariates.shape)

    sample = generate(truth, covariates, rng)
    return truth, FitData(
        inputs=covariates,
        departure=sample.departure,
        displacement=sample.displacement,
        comm=sample.comm,
        true_states=sample.states,
    )
def _perturb_truth(truth: FitParameters, rng: np.random.Generator) -> FitParameters:
    """Return a randomly perturbed copy of *truth*.

    Perturbations, drawn from ``rng`` in this order:

    * ``alpha``: Gaussian noise (scale 0.5) on finite off-diagonal entries
      only; the diagonal and ``-inf`` entries are left untouched.
    * ``beta``: Gaussian noise (scale 0.3) on every entry.
    * ``p_departure``: noise scale 0.1, clipped to ``[1e-3, 1 - 1e-3]``.
    * ``mu_displacement``: noise scale 0.5.
    * ``sigma_displacement``: noise scale 0.2, floored at 0.2.
    * ``lambda_comm``: noise scale 0.3, floored at 0.1.

    The noise for each emission array is sized from that array's own shape,
    so the helper works for any number of states rather than only five.
    ``initial`` and ``feature_names`` are carried over unchanged, and *truth*
    itself is not modified.
    """
    alpha = truth.transitions.alpha.copy()
    beta = truth.transitions.beta.copy()

    # Only jitter transitions that are structurally allowed and off-diagonal.
    finite = np.isfinite(alpha) & ~np.eye(alpha.shape[0], dtype=bool)
    alpha = np.where(finite, alpha + 0.5 * rng.standard_normal(alpha.shape), alpha)
    beta = beta + 0.3 * rng.standard_normal(beta.shape)

    base = truth.emissions
    p = np.clip(
        base.p_departure + 0.1 * rng.standard_normal(np.shape(base.p_departure)),
        1e-3,
        1 - 1e-3,
    )
    mu = base.mu_displacement + 0.5 * rng.standard_normal(np.shape(base.mu_displacement))
    sigma = np.maximum(
        base.sigma_displacement
        + 0.2 * rng.standard_normal(np.shape(base.sigma_displacement)),
        0.2,
    )
    lam = np.maximum(
        base.lambda_comm + 0.3 * rng.standard_normal(np.shape(base.lambda_comm)),
        0.1,
    )

    emit = EmissionFitParams(
        p_departure=p, mu_displacement=mu, sigma_displacement=sigma, lambda_comm=lam
    )
    return FitParameters(
        initial=truth.initial,
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
        feature_names=truth.feature_names,
    )
@pytest.fixture(scope="module")
def fitted_run() -> tuple[FitParameters, FitData, FitParameters]:
    """Module-scoped fixture returning ``(truth, data, best fitted params)``.

    Data is simulated with ``n=2000`` and ``t_plus_1=61`` from a generator
    seeded with 0.  Five perturbed copies of the truth are drawn up front,
    ``fit`` is run once from each, and the run with the strictly highest
    ``final_log_likelihood`` supplies the returned parameters.
    """
    rng = np.random.default_rng(0)
    truth, data = _make_clean_data(rng, n=2000, t_plus_1=61)

    # Draw every starting point before any fit so the RNG stream is fixed.
    starting_points = [_perturb_truth(truth, rng) for _ in range(5)]

    winner = None
    winner_ll = -np.inf
    for start in starting_points:
        candidate = fit(
            data,
            n_restarts=1,
            em_config=EMConfig(max_iter=80, tol=1e-6),
            init="truth",
            rng=rng,
            truth_init=start,
        ).best
        if candidate.final_log_likelihood > winner_ll:
            winner, winner_ll = candidate, candidate.final_log_likelihood

    assert winner is not None
    return truth, data, winner.params
def test_state_recovery_accuracy(
    fitted_run: tuple[FitParameters, FitData, FitParameters],
) -> None:
    """Decoded states, relabelled against the true states, must score >= 0.85.

    The fitted path comes from ``viterbi``; it is relabelled with the
    permutation returned by ``align_states`` before being scored.
    """
    from iohmm_evac.diagnostics.decoding import viterbi

    _, data, fit_params = fitted_run
    decoded = viterbi(fit_params, data)
    true_path = data.true_states
    assert true_path is not None

    relabelled = apply_permutation(decoded, align_states(true_path, decoded, k=5))
    accuracy = state_recovery_accuracy(true_path, relabelled)
    assert accuracy >= 0.85, f"state recovery accuracy {accuracy:.3f} below 0.85 target"
def test_parameter_recovery_rmse(
    fitted_run: tuple[FitParameters, FitData, FitParameters],
) -> None:
    """Aligned fitted parameters must have β and μ RMSE of at most 0.5 each.

    The state permutation is obtained by comparing the Viterbi path with the
    true states; it is applied to the fitted parameters before they are
    compared with the truth via ``parameter_recovery``.
    """
    from iohmm_evac.diagnostics.decoding import viterbi

    truth, data, fit_params = fitted_run
    decoded = viterbi(fit_params, data)
    true_path = data.true_states
    assert true_path is not None

    relabelling = align_states(true_path, decoded, k=5)
    report = parameter_recovery(truth, align_fit_to_truth(fit_params, relabelling))

    assert report.transition_beta_rmse <= 0.5, (
        f"transition β RMSE {report.transition_beta_rmse:.3f} above 0.5"
    )
    assert report.emission_mu_rmse <= 0.5, (
        f"emission μ RMSE {report.emission_mu_rmse:.3f} above 0.5"
    )