# tests/test_recovery.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Recovery tests against a clean DGP that the IO-HMM models exactly."""

from __future__ import annotations

import numpy as np
import pytest

from iohmm_evac.diagnostics.alignment import align_states, apply_permutation
from iohmm_evac.diagnostics.recovery import (
    align_fit_to_truth,
    parameter_recovery,
    state_recovery_accuracy,
)
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.fit import fit
from iohmm_evac.inference.fit_params import (
    EmissionFitParams,
    FitParameters,
    InitialFitParams,
    TransitionFitParams,
)
from tests._clean_dgp import generate


def _clean_truth() -> FitParameters:
    """Build the ground-truth parameters of the clean 5-state, 3-feature DGP.

    Transition intercepts are ``-inf`` (structural zeros) everywhere except on
    the diagonal, which is 0, and on the six off-diagonal pairs listed below.
    Feature slopes are zero except for the listed ``(from, to, feature)``
    entries. State 4 has no finite off-diagonal intercept.
    """
    n_states, n_features = 5, 3

    # Finite off-diagonal intercepts, keyed by (from_state, to_state).
    intercepts = {
        (0, 1): -1.5,
        (1, 0): -3.0,
        (1, 2): -2.0,
        (2, 3): -1.5,
        (2, 4): -3.0,
        (3, 4): -1.0,
    }
    # Non-zero slopes, keyed by (from_state, to_state, feature).
    slopes = {
        (0, 1, 0): 1.5,
        (1, 2, 0): 1.0,
        (1, 2, 1): 0.8,
        (2, 3, 1): 1.4,
        (2, 4, 2): 0.5,
        (3, 4, 2): 1.6,
    }

    alpha = np.full((n_states, n_states), -np.inf, dtype=np.float64)
    np.fill_diagonal(alpha, 0.0)
    for (src, dst), value in intercepts.items():
        alpha[src, dst] = value

    beta = np.zeros((n_states, n_states, n_features))
    for index, value in slopes.items():
        beta[index] = value

    emissions = EmissionFitParams(
        p_departure=np.array([0.05, 0.10, 0.20, 0.95, 0.05]),
        mu_displacement=np.array([0.0, 0.5, 1.0, 5.0, 8.0]),
        sigma_displacement=np.array([0.5, 0.5, 0.5, 0.5, 0.5]),
        lambda_comm=np.array([0.5, 1.5, 3.0, 2.0, 1.0]),
    )
    initial = InitialFitParams(logits=np.array([0.0, -3.0, -3.0, -3.0, -3.0]))
    transitions = TransitionFitParams(alpha=alpha, beta=beta)
    return FitParameters(
        initial=initial,
        transitions=transitions,
        emissions=emissions,
    )


def _make_clean_data(
    rng: np.random.Generator, n: int, t_plus_1: int
) -> tuple[FitParameters, FitData]:
    """Sample a dataset from the clean truth and wrap it as ``FitData``.

    The inputs have shape ``(n, t_plus_1, n_features)``. Feature 0 is a linear
    ramp from 0 to 1 over time, feature 1 switches to 1 from index
    ``t_plus_1 // 3`` onward, and feature 2 switches to 1 from index
    ``2 * t_plus_1 // 3`` onward. Gaussian noise with standard deviation 0.1 is
    then added to every entry.

    Returns:
        The truth parameters and the sampled data, including the true states.
    """
    truth = _clean_truth()
    n_features = truth.transitions.beta.shape[2]
    first_break = t_plus_1 // 3
    second_break = 2 * t_plus_1 // 3

    inputs = np.zeros((n, t_plus_1, n_features), dtype=np.float64)
    ramp = np.linspace(0.0, 1.0, t_plus_1)
    inputs[:, :, 0] = ramp[None, :]
    inputs[:, first_break:, 1] = 1.0
    inputs[:, second_break:, 2] = 1.0
    # The noise is drawn before sampling so the generator is consumed in a fixed order.
    inputs += 0.1 * rng.standard_normal(inputs.shape)

    sample = generate(truth, inputs, rng)
    return truth, FitData(
        inputs=inputs,
        departure=sample.departure,
        displacement=sample.displacement,
        comm=sample.comm,
        true_states=sample.states,
    )


def _perturb_truth(truth: FitParameters, rng: np.random.Generator) -> FitParameters:
    """Return a noisy copy of ``truth`` to use as an EM starting point.

    Gaussian noise is added to the finite off-diagonal transition intercepts,
    to every slope, and to each emission parameter. Diagonal and ``-inf``
    intercepts are left untouched, and the initial-state parameters and feature
    names are passed through unchanged. ``truth`` itself is not modified.

    The per-state noise vectors are sized from ``truth`` rather than a
    hard-coded state count, so the helper works for any number of states. For
    the 5-state truth the random draws are the same as before.

    Args:
        truth: Parameters to perturb.
        rng: Generator supplying the noise; it is advanced by this call.

    Returns:
        A new ``FitParameters`` holding the perturbed values.
    """
    alpha = truth.transitions.alpha
    beta = truth.transitions.beta
    emissions = truth.emissions
    n_states = alpha.shape[0]

    # Only finite off-diagonal intercepts are jittered: structural -inf entries
    # and the diagonal keep their exact values.
    finite = np.isfinite(alpha) & ~np.eye(n_states, dtype=bool)
    alpha = np.where(finite, alpha + 0.5 * rng.standard_normal(alpha.shape), alpha)
    beta = beta + 0.3 * rng.standard_normal(beta.shape)

    # Clip or floor the emission parameters so they stay in their valid ranges
    # (probability strictly inside (0, 1), positive scale and rate).
    p = np.clip(
        emissions.p_departure + 0.1 * rng.standard_normal(n_states), 1e-3, 1 - 1e-3
    )
    mu = emissions.mu_displacement + 0.5 * rng.standard_normal(n_states)
    sigma = np.maximum(
        emissions.sigma_displacement + 0.2 * rng.standard_normal(n_states), 0.2
    )
    lam = np.maximum(emissions.lambda_comm + 0.3 * rng.standard_normal(n_states), 0.1)

    emit = EmissionFitParams(
        p_departure=p, mu_displacement=mu, sigma_displacement=sigma, lambda_comm=lam
    )
    return FitParameters(
        initial=truth.initial,
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
        feature_names=truth.feature_names,
    )


@pytest.fixture(scope="module")
def fitted_run() -> tuple[FitParameters, FitData, FitParameters]:
    """Fit the clean DGP from five perturbed starts and keep the best run.

    The fixture is module-scoped, so the fits run once and are shared by the
    tests in this module. One seeded generator drives data generation, the
    perturbations and the fits, in that order.

    Returns:
        ``(truth, data, fitted_params)``, where ``fitted_params`` come from the
        run with the highest final log-likelihood.
    """
    rng = np.random.default_rng(0)
    truth, data = _make_clean_data(rng, n=2000, t_plus_1=61)

    # All starting points are drawn before any fit runs, which keeps the
    # generator's consumption order fixed.
    starts = [_perturb_truth(truth, rng) for _ in range(5)]

    winner = None
    winner_ll = -np.inf
    for start in starts:
        outcome = fit(
            data,
            n_restarts=1,
            em_config=EMConfig(max_iter=80, tol=1e-6),
            init="truth",
            rng=rng,
            truth_init=start,
        )
        # Strict comparison: on ties the earliest run is kept.
        if outcome.best.final_log_likelihood > winner_ll:
            winner, winner_ll = outcome.best, outcome.best.final_log_likelihood

    assert winner is not None
    return truth, data, winner.params


def test_state_recovery_accuracy(
    fitted_run: tuple[FitParameters, FitData, FitParameters],
) -> None:
    """Check that the aligned decoded path matches the true states at >= 85%.

    The path returned by ``viterbi`` for the fitted parameters is relabelled
    with the permutation from ``align_states`` before it is scored against
    ``data.true_states``.
    """
    from iohmm_evac.diagnostics.decoding import viterbi

    _, data, fit_params = fitted_run
    decoded = viterbi(fit_params, data)
    assert data.true_states is not None

    relabelled = apply_permutation(
        decoded, align_states(data.true_states, decoded, k=5)
    )
    accuracy = state_recovery_accuracy(data.true_states, relabelled)
    assert accuracy >= 0.85, f"state recovery accuracy {accuracy:.3f} below 0.85 target"


def test_parameter_recovery_rmse(
    fitted_run: tuple[FitParameters, FitData, FitParameters],
) -> None:
    """Check that transition-slope and emission-mean RMSEs are each <= 0.5.

    The fitted parameters are first relabelled to the truth's state order,
    using the permutation that ``align_states`` derives from the true states
    and the path returned by ``viterbi``. ``parameter_recovery`` then compares
    them with the truth.
    """
    from iohmm_evac.diagnostics.decoding import viterbi

    truth, data, fit_params = fitted_run
    decoded = viterbi(fit_params, data)
    assert data.true_states is not None

    state_perm = align_states(data.true_states, decoded, k=5)
    report = parameter_recovery(truth, align_fit_to_truth(fit_params, state_perm))

    beta_msg = "transition β RMSE {:.3f} above 0.5"
    assert report.transition_beta_rmse <= 0.5, (
        f"transition β RMSE {report.transition_beta_rmse:.3f} above 0.5"
    ) if beta_msg else None
    assert report.emission_mu_rmse <= 0.5, (
        f"emission μ RMSE {report.emission_mu_rmse:.3f} above 0.5"
    )