tests/test_em.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""EM-loop tests: monotonicity and convergence on a synthetic problem."""

from __future__ import annotations

import numpy as np

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.em import EMConfig, run_em
from iohmm_evac.inference.fit_params import (
    EmissionFitParams,
    FitParameters,
    InitialFitParams,
    TransitionFitParams,
)
from iohmm_evac.inference.forward_backward import forward_backward
from tests._clean_dgp import generate


def _make_clean_truth(rng: np.random.Generator) -> tuple[FitParameters, FitData]:
    """A small K=3, T=20, N=50 problem generated from a clean DGP."""
    k = 3
    f = 2
    n = 50
    t_plus_1 = 21

    alpha = np.array(
        [
            [0.0, -1.5, -3.0],
            [-3.0, 0.0, -1.5],
            [-3.0, -3.0, 0.0],
        ],
        dtype=np.float64,
    )
    beta = np.zeros((k, k, f))
    beta[0, 1] = np.array([1.5, 0.0])
    beta[1, 2] = np.array([0.0, 1.5])
    emit = EmissionFitParams(
        p_departure=np.array([0.05, 0.5, 0.95]),
        mu_displacement=np.array([0.0, 1.0, 5.0]),
        sigma_displacement=np.array([0.5, 0.5, 0.5]),
        lambda_comm=np.array([0.5, 1.5, 3.0]),
    )
    truth = FitParameters(
        initial=InitialFitParams(logits=np.array([0.0, -2.0, -2.0])),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
    )

    inputs = np.zeros((n, t_plus_1, f), dtype=np.float64)
    # u[:, t, 0] ramps with time; u[:, t, 1] turns on at t >= T/2.
    inputs[:, :, 0] = np.linspace(0.0, 1.0, t_plus_1)[None, :]
    inputs[:, t_plus_1 // 2 :, 1] = 1.0

    sample = generate(truth, inputs, rng)
    data = FitData(
        inputs=inputs,
        departure=sample.departure,
        displacement=sample.displacement,
        comm=sample.comm,
        true_states=sample.states,
    )
    return truth, data


def _perturb(truth: FitParameters, rng: np.random.Generator) -> FitParameters:
    """Move ``truth`` slightly so EM has to actually do work to recover it."""
    alpha = truth.transitions.alpha.copy()
    beta = truth.transitions.beta.copy()
    finite_mask = np.isfinite(alpha) & ~np.eye(alpha.shape[0], dtype=bool)
    alpha = np.where(finite_mask, alpha + 0.3 * rng.standard_normal(alpha.shape), alpha)
    beta = beta + 0.2 * rng.standard_normal(beta.shape)
    init_logits = truth.initial.logits + 0.2 * rng.standard_normal(truth.initial.logits.shape)
    p = np.clip(truth.emissions.p_departure + 0.05 * rng.standard_normal(3), 1e-3, 1 - 1e-3)
    mu = truth.emissions.mu_displacement + 0.3 * rng.standard_normal(3)
    sigma = np.maximum(truth.emissions.sigma_displacement + 0.2 * rng.standard_normal(3), 0.1)
    lam = np.maximum(truth.emissions.lambda_comm + 0.2 * rng.standard_normal(3), 0.1)
    emit = EmissionFitParams(
        p_departure=p,
        mu_displacement=mu,
        sigma_displacement=sigma,
        lambda_comm=lam,
    )
    return FitParameters(
        initial=InitialFitParams(logits=init_logits),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
        feature_names=truth.feature_names,
    )


def test_em_log_likelihood_is_non_decreasing() -> None:
    rng = np.random.default_rng(7)
    truth, data = _make_clean_truth(rng)
    init = _perturb(truth, rng)
    em = run_em(init, data, EMConfig(max_iter=20, tol=1e-8))
    trace = em.log_likelihood_trace
    assert len(trace) >= 2
    diffs = np.diff(np.array(trace))
    assert (diffs > -1e-9).all(), f"non-monotone trace: {trace}"


def test_em_converges_close_to_truth_log_likelihood() -> None:
    rng = np.random.default_rng(11)
    truth, data = _make_clean_truth(rng)
    init = _perturb(truth, rng)
    em = run_em(init, data, EMConfig(max_iter=100, tol=1e-8))

    truth_ll = float(forward_backward(truth, data).log_likelihood.sum())
    fit_ll = em.final_log_likelihood
    n_obs = data.n * (data.t_total + 1)
    per_obs_gap = (truth_ll - fit_ll) / n_obs
    # Relax slightly from the design target of ``1e-3`` since the optimization
    # is finite-N and may exceed the truth's likelihood by overfitting.
    assert per_obs_gap < 5e-2, (
        f"per-observation LL gap {per_obs_gap:.4e} too large "
        f"(truth={truth_ll:.3f}, fit={fit_ll:.3f}, n_obs={n_obs})"
    )