src/iohmm_evac/inference/cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac fit`` subcommand: run EM and write a fit bundle to disk."""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any

import numpy as np

from iohmm_evac.diagnostics.decoding import viterbi
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.fit import fit
from iohmm_evac.inference.fit_params import dgp_truth_to_fit_init
from iohmm_evac.inference.io import write_fit_bundle
from iohmm_evac.params import (
    EmissionParams as DGPEmissionParams,
)
from iohmm_evac.params import (
    PopulationParams,
    SimulationConfig,
    TimelineParams,
    TransitionParams,
    TransitionRow,
)
from iohmm_evac.report.loader import load_bundle

# Public API of this module: the hook that registers the ``fit`` subcommand
# and the handler that executes it.
__all__ = ["add_fit_subparser", "run_fit"]


def _positive_int(text: str) -> int:
    """Parse a command-line value as an integer that is at least 1.

    Used as an argparse ``type=`` callable so meaningless counts (``0`` or
    negative) are rejected at parse time with a normal usage error instead of
    being handed on to EM.

    Raises:
        argparse.ArgumentTypeError: If *text* is not an integer or is < 1.
    """
    try:
        value = int(text)
    except ValueError:
        # Mirror argparse's own wording for ``type=int`` failures.
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def add_fit_subparser(
    subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Register the ``fit`` subcommand on the top-level CLI parser.

    ``--restarts`` and ``--max-iter`` must be positive integers; anything else
    is rejected by argparse with a usage error (exit status 2).

    Args:
        subparsers: The action returned by ``ArgumentParser.add_subparsers``.
    """
    p = subparsers.add_parser("fit", help="Fit an IO-HMM to a saved simulation bundle.")
    p.add_argument("--input", type=Path, required=True, help="Observations Parquet path.")
    p.add_argument(
        "--output",
        type=Path,
        default=Path("./output/fit/"),
        help="Directory to write the fit bundle into (default: ./output/fit/).",
    )
    p.add_argument(
        "--restarts",
        type=_positive_int,
        default=1,
        help="Number of EM restarts (default 1).",
    )
    p.add_argument(
        "--max-iter",
        type=_positive_int,
        default=200,
        help="Hard cap on EM iterations per restart.",
    )
    p.add_argument("--tol", type=float, default=1e-5, help="Relative-LL change stopping threshold.")
    p.add_argument(
        "--init",
        choices=("random", "kmeans", "truth"),
        default="random",
        help=(
            "Parameter initialization strategy (default: random). "
            "'truth' is for tests / EM-stability audits — NOT for honest "
            "recovery measurement, since it seeds EM at the answer."
        ),
    )
    p.add_argument("--seed", type=int, default=0, help="RNG seed (default 0).")
    p.add_argument("--quiet", action="store_true", help="Suppress per-restart progress.")


def _as_table(value: object, where: str) -> dict[str, Any]:
    """Return *value* unchanged if it is a TOML table (a ``dict``).

    An explicit check rather than ``assert``: this validates an on-disk config
    and must keep working under ``python -O``, which strips assert statements.

    Raises:
        TypeError: If *value* is not a dict; the message names *where*.
    """
    if not isinstance(value, dict):
        raise TypeError(f"config [{where}] must be a table, got {type(value).__name__}")
    return value


def _build_dgp_truth_from_config(cfg: dict[str, object]) -> SimulationConfig:
    """Reconstruct a (partial) :class:`SimulationConfig` from a TOML config dict.

    Used for the ``--init truth`` path and by recovery diagnostics. The
    transition / emission / timeline / population sections are read from
    the TOML; feedback parameters stay at the dataclass defaults. A missing
    section is treated as an empty table.

    Every section (and every row under ``[transitions]``) is shape-checked
    before any parameter object is constructed, so a malformed config fails
    early with a message naming the offending key.

    Args:
        cfg: Parsed TOML config dict.

    Returns:
        A :class:`SimulationConfig` whose transitions, emissions, timeline and
        population are built from *cfg*.

    Raises:
        TypeError: If a section, or a transition row, is not a table.
    """
    trans_dict = _as_table(cfg.get("transitions", {}), "transitions")
    emit_dict = _as_table(cfg.get("emissions", {}), "emissions")
    timeline_dict = _as_table(cfg.get("timeline", {}), "timeline")
    pop_dict = _as_table(cfg.get("population", {}), "population")
    trans_rows = {
        name: _as_table(row, f"transitions.{name}") for name, row in trans_dict.items()
    }

    transitions = TransitionParams(
        **{
            name: TransitionRow(**{k: float(v) for k, v in row.items()})
            for name, row in trans_rows.items()
        }
    )
    emissions = DGPEmissionParams(**{k: float(v) for k, v in emit_dict.items()})

    # TOML arrays arrive as Python lists; convert them to tuples before
    # building TimelineParams. Scalars pass through unchanged.
    timeline_kwargs = {
        k: (tuple(v) if isinstance(v, list) else v) for k, v in timeline_dict.items()
    }
    timeline = TimelineParams(**timeline_kwargs)  # type: ignore[arg-type]

    population = PopulationParams(**{k: float(v) for k, v in pop_dict.items()})

    return SimulationConfig(
        transitions=transitions,
        emissions=emissions,
        timeline=timeline,
        population=population,
    )


def run_fit(args: argparse.Namespace) -> int:
    """Run the ``fit`` subcommand end to end.

    Loads the saved bundle at ``args.input`` and converts it to fit data. It
    then runs ``args.restarts`` EM restarts with the requested initialization
    and Viterbi-decodes the data under the best restart's parameters. Finally
    it writes the result via :func:`write_fit_bundle` into ``args.output``.
    Progress, written paths and a best-restart summary go to stderr unless
    ``args.quiet`` is set.

    Args:
        args: Namespace produced by the parser registered in
            :func:`add_fit_subparser`.

    Returns:
        ``0``, used as the process exit status.
    """
    chatty = not args.quiet

    bundle = load_bundle(args.input)
    data = bundle_to_fit_data(bundle)
    rng = np.random.default_rng(args.seed)
    em_config = EMConfig(max_iter=args.max_iter, tol=args.tol, verbose=chatty)

    # Only ``--init truth`` needs the data-generating parameters; they are
    # rebuilt from the config dict carried on the loaded bundle.
    truth_init = None
    if args.init == "truth":
        dgp = _build_dgp_truth_from_config(bundle.config)
        truth_init = dgp_truth_to_fit_init(dgp.transitions, dgp.emissions, dgp.population)

    if chatty:
        header = (
            f"Fitting IO-HMM: N={data.n}, T={data.t_total}, "
            f"restarts={args.restarts}, init={args.init}"
        )
        print(header, file=sys.stderr)

    result = fit(
        data,
        n_restarts=args.restarts,
        em_config=em_config,
        init=args.init,
        rng=rng,
        truth_init=truth_init,
    )
    best = result.best
    decoded = viterbi(best.params, data)
    written = write_fit_bundle(result, decoded, args.output)

    if chatty:
        for label, path in written.items():
            print(f"{label}: {path}", file=sys.stderr)
        summary = (
            f"Best restart: #{result.best_index}, "
            f"final LL={best.final_log_likelihood:.4f}, "
            f"iters={best.iterations}, converged={best.converged}"
        )
        print(summary, file=sys.stderr)
    return 0