# tests/test_inference_modules.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Direct (in-process) coverage of inference and report-recovery modules."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from iohmm_evac.cli import main
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.diagnostics.alignment import align_states
from iohmm_evac.diagnostics.recovery import (
    align_fit_to_truth,
    parameter_recovery,
    state_recovery_accuracy,
    state_recovery_confusion,
)
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.fit import fit
from iohmm_evac.inference.fit_params import dgp_truth_to_fit_init
from iohmm_evac.inference.initialization import (
    from_dgp_truth,
    kmeans_init,
    random_initialization,
)
from iohmm_evac.inference.io import read_fit_bundle, write_fit_bundle
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.report.recovery_plots import (
    plot_log_likelihood_trace,
    plot_parameter_recovery,
    plot_state_recovery_confusion,
)


@pytest.fixture
def tiny_bundle_path(tmp_path: Path) -> Path:
    """Simulate a small run and write it to ``tiny.parquet`` under ``tmp_path``.

    Returns:
        The path that was handed to ``write_results``.
    """
    sim_config = SimulationConfig(n_households=80, n_hours=18, seed=0)
    # Seed the generator from the config so both carry the same seed value.
    simulated = simulate(sim_config, np.random.default_rng(sim_config.seed))
    bundle_file = tmp_path / "tiny.parquet"
    write_results(simulated, bundle_file)
    return bundle_file


def test_random_initialization_shapes() -> None:
    """Check the array shapes produced by ``random_initialization``."""
    initial = random_initialization(np.random.default_rng(0))
    alpha_shape = initial.transitions.alpha.shape
    assert alpha_shape == (5, 5)
    leading_beta_dim = initial.transitions.beta.shape[0]
    assert leading_beta_dim == 5
    departure_shape = initial.emissions.p_departure.shape
    assert departure_shape == (5,)


def test_kmeans_init_returns_valid_params(tiny_bundle_path: Path) -> None:
    """Check that ``kmeans_init`` yields all-finite ``mu_displacement`` values."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    generator = np.random.default_rng(0)
    initial = kmeans_init(fit_data, generator)
    finite_mask = np.isfinite(initial.emissions.mu_displacement)
    assert finite_mask.all()


def test_from_dgp_truth_matches_dgp_truth_adapter() -> None:
    """``from_dgp_truth`` and ``dgp_truth_to_fit_init`` agree on ``transitions.alpha``."""
    defaults = SimulationConfig()
    via_initialization = from_dgp_truth(defaults.transitions, defaults.emissions)
    via_adapter = dgp_truth_to_fit_init(defaults.transitions, defaults.emissions)
    np.testing.assert_array_equal(
        via_initialization.transitions.alpha,
        via_adapter.transitions.alpha,
    )


def test_fit_io_round_trip(tmp_path: Path, tiny_bundle_path: Path) -> None:
    """Write a fit bundle to disk, read it back, and compare what comes out."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    generator = np.random.default_rng(0)
    truth_config = SimulationConfig()
    fit_result = fit(
        fit_data,
        n_restarts=1,
        em_config=EMConfig(max_iter=2, tol=1e-3),
        init="truth",
        rng=generator,
        truth_init=dgp_truth_to_fit_init(truth_config.transitions, truth_config.emissions),
    )

    # An all-zero integer array is used as the posterior; only the round trip is checked.
    placeholder_states = np.zeros((fit_data.n, fit_data.t_total + 1), dtype=np.int64)
    target_dir = tmp_path / "fit"
    written = write_fit_bundle(fit_result, placeholder_states, target_dir)
    for label, artifact in written.items():
        assert artifact.exists(), label

    restored = read_fit_bundle(target_dir)
    assert restored.params.transitions.alpha.shape == (5, 5)
    np.testing.assert_array_equal(restored.posterior_states, placeholder_states)


def test_fit_random_init_runs(tiny_bundle_path: Path) -> None:
    """Two ``init="random"`` restarts give two recorded runs and a best index of 0 or 1."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    fit_kwargs = {
        "n_restarts": 2,
        "em_config": EMConfig(max_iter=2, tol=1e-3),
        "init": "random",
        "rng": np.random.default_rng(0),
    }
    outcome = fit(fit_data, **fit_kwargs)
    run_count = len(outcome.all_runs)
    assert run_count == 2
    assert outcome.best_index in {0, 1}


def test_fit_kmeans_init_runs(tiny_bundle_path: Path) -> None:
    """A single ``init="kmeans"`` fit reports at least one iteration for its best run."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    fit_kwargs = {
        "n_restarts": 1,
        "em_config": EMConfig(max_iter=2, tol=1e-3),
        "init": "kmeans",
        "rng": np.random.default_rng(0),
    }
    outcome = fit(fit_data, **fit_kwargs)
    iterations_used = outcome.best.iterations
    assert iterations_used >= 1


def test_fit_truth_init_requires_truth(tiny_bundle_path: Path) -> None:
    """``fit`` with ``init="truth"`` and no ``truth_init`` raises ``ValueError``."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    expect_error = pytest.raises(ValueError, match="truth")
    with expect_error:
        # ``truth_init`` is deliberately left out of this call.
        fit(
            fit_data,
            init="truth",
            n_restarts=1,
            rng=np.random.default_rng(0),
            em_config=EMConfig(max_iter=1),
        )


def test_fit_unknown_init_strategy(tiny_bundle_path: Path) -> None:
    """``fit`` raises ``ValueError`` when ``init`` is the string ``"bogus"``."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    expect_error = pytest.raises(ValueError, match="Unknown init strategy")
    with expect_error:
        fit(
            fit_data,
            init="bogus",
            n_restarts=1,
            rng=np.random.default_rng(0),
            em_config=EMConfig(max_iter=1),
        )


def test_fit_zero_restarts_rejected(tiny_bundle_path: Path) -> None:
    """``fit`` raises ``ValueError`` mentioning ``n_restarts`` when given zero restarts."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    expect_error = pytest.raises(ValueError, match="n_restarts")
    with expect_error:
        fit(fit_data, n_restarts=0)


def test_recovery_plots_render(tmp_path: Path, tiny_bundle_path: Path) -> None:
    """Render each recovery plot from a short ``init="truth"`` fit.

    Every figure is closed in a ``finally`` block, so a plot call that raises
    does not leave an open matplotlib figure behind for later tests.
    """
    # Imported locally, as in the original test, so they load only when this test runs.
    import matplotlib.pyplot as plt

    from iohmm_evac.diagnostics.decoding import viterbi

    bundle = load_bundle(tiny_bundle_path)
    data = bundle_to_fit_data(bundle)
    cfg_truth = SimulationConfig()
    truth_init = dgp_truth_to_fit_init(cfg_truth.transitions, cfg_truth.emissions)
    result = fit(
        data,
        n_restarts=1,
        em_config=EMConfig(max_iter=2, tol=1e-3),
        init="truth",
        rng=np.random.default_rng(0),
        truth_init=truth_init,
    )

    fit_path = viterbi(result.best.params, data)
    assert data.true_states is not None
    perm = align_states(data.true_states, fit_path, k=5)
    # Relabel the decoded path with the alignment permutation before comparing to truth.
    relabelled_path = perm[fit_path]
    confusion = state_recovery_confusion(data.true_states, relabelled_path)
    accuracy = state_recovery_accuracy(data.true_states, relabelled_path)
    aligned = align_fit_to_truth(result.best.params, perm)
    report = parameter_recovery(truth_init, aligned)

    fig, ax = plt.subplots()
    try:
        plot_state_recovery_confusion(confusion, ax=ax)
    finally:
        plt.close(fig)

    fig, ax = plt.subplots()
    try:
        plot_parameter_recovery(report, ax=ax)
    finally:
        plt.close(fig)

    fig, axes = plt.subplots(1, 2)
    try:
        plot_log_likelihood_trace(
            [list(r.log_likelihood_trace) for r in result.all_runs],
            ax=list(axes),
            best_index=result.best_index,
        )
    finally:
        plt.close(fig)

    assert 0.0 <= accuracy <= 1.0


def test_cli_fit_diagnose_inproc(tmp_path: Path, tiny_bundle_path: Path) -> None:
    """Run ``fit`` and then ``diagnose recovery`` through ``main`` in-process."""
    fit_dir = tmp_path / "fit"
    bundle_arg = str(tiny_bundle_path)
    fit_arg = str(fit_dir)

    fit_argv = ["fit", "--input", bundle_arg, "--output", fit_arg]
    fit_argv += ["--init", "truth", "--max-iter", "2", "--quiet"]
    fit_status = main(fit_argv)
    assert fit_status == 0
    assert (fit_dir / "theta.toml").exists()

    diagnose_argv = ["diagnose", "recovery", "--fit", fit_arg, "--truth", bundle_arg]
    diagnose_status = main(diagnose_argv)
    assert diagnose_status == 0
    assert (fit_dir / "recovery.toml").exists()


def test_cli_report_recovery_inproc(tmp_path: Path, tiny_bundle_path: Path) -> None:
    """Run ``fit`` and then each ``report`` subcommand through ``main`` in-process.

    The exit code of the initial ``fit`` call is asserted so a failed fit is
    reported at its source rather than as a later report failure. Each report
    invoked with ``--output`` is also checked for its output file.
    """
    fit_dir = tmp_path / "fit"
    rc = main(
        [
            "fit",
            "--input",
            str(tiny_bundle_path),
            "--output",
            str(fit_dir),
            "--init",
            "truth",
            "--max-iter",
            "2",
            "--quiet",
        ]
    )
    # Every report below reads from fit_dir, so stop here if the fit itself failed.
    assert rc == 0

    confusion_png = tmp_path / "c.png"
    rc = main(
        [
            "report",
            "recovery-confusion",
            "--fit",
            str(fit_dir),
            "--truth",
            str(tiny_bundle_path),
            "--output",
            str(confusion_png),
        ]
    )
    assert rc == 0
    assert confusion_png.exists()

    param_png = tmp_path / "p.png"
    rc = main(
        [
            "report",
            "parameter-recovery",
            "--fit",
            str(fit_dir),
            "--truth",
            str(tiny_bundle_path),
            "--output",
            str(param_png),
        ]
    )
    assert rc == 0
    assert param_png.exists()

    ll_png = tmp_path / "ll.png"
    rc = main(["report", "ll-trace", "--fit", str(fit_dir), "--output", str(ll_png)])
    assert rc == 0
    # Same check as the other two reports that are given an --output path.
    assert ll_png.exists()

    rc = main(["report", "fit-summary", "--fit", str(fit_dir)])
    assert rc == 0