# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac fit`` subcommand: run EM and write a fit bundle to disk."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import numpy as np
from iohmm_evac.diagnostics.decoding import viterbi
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.fit import fit
from iohmm_evac.inference.fit_params import dgp_truth_to_fit_init
from iohmm_evac.inference.io import write_fit_bundle
from iohmm_evac.params import (
EmissionParams as DGPEmissionParams,
)
from iohmm_evac.params import (
PopulationParams,
SimulationConfig,
TimelineParams,
TransitionParams,
TransitionRow,
)
from iohmm_evac.report.loader import load_bundle
__all__ = ["add_fit_subparser", "run_fit"]
def add_fit_subparser(
    subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Attach the ``fit`` subcommand and its options to the top-level CLI parser.

    Options registered, in order: ``--input`` (required), ``--output``,
    ``--restarts``, ``--max-iter``, ``--tol``, ``--init``, ``--seed`` and
    ``--quiet``.
    """
    fit_parser = subparsers.add_parser(
        "fit",
        help="Fit an IO-HMM to a saved simulation bundle.",
    )

    # Where to read from and where to write to.
    fit_parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Observations Parquet path.",
    )
    default_output_dir = Path("./output/fit/")
    fit_parser.add_argument(
        "--output",
        type=Path,
        default=default_output_dir,
        help="Directory to write the fit bundle into (default: ./output/fit/).",
    )

    # EM budget and stopping rule.
    fit_parser.add_argument(
        "--restarts",
        type=int,
        default=1,
        help="Number of EM restarts (default 1).",
    )
    fit_parser.add_argument(
        "--max-iter",
        type=int,
        default=200,
        help="Hard cap on EM iterations per restart.",
    )
    fit_parser.add_argument(
        "--tol",
        type=float,
        default=1e-5,
        help="Relative-LL change stopping threshold.",
    )

    # Initialization strategy; the help text warns against using 'truth'
    # for recovery measurement.
    init_strategies = ("random", "kmeans", "truth")
    init_help = (
        "Parameter initialization strategy (default: random). 'truth' is for tests / "
        "EM-stability audits — NOT for honest recovery measurement, "
        "since it seeds EM at the answer."
    )
    fit_parser.add_argument(
        "--init",
        choices=init_strategies,
        default="random",
        help=init_help,
    )

    # Reproducibility and verbosity.
    fit_parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="RNG seed (default 0).",
    )
    fit_parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress per-restart progress.",
    )
def _section_type_error(where: str, value: object) -> TypeError:
    """Build the error raised when config entry *where* is not a table (dict)."""
    return TypeError(
        f"config entry {where!r} must be a table (dict), got {type(value).__name__}"
    )


def _build_dgp_truth_from_config(cfg: dict[str, object]) -> SimulationConfig:
    """Reconstruct a (partial) :class:`SimulationConfig` from a TOML config dict.

    Used for the ``--init truth`` path and by recovery diagnostics. Only the
    ``transitions``, ``emissions``, ``timeline`` and ``population`` sections
    are read from *cfg*; every other :class:`SimulationConfig` field (e.g. the
    feedback parameters) is left at its dataclass default. A section missing
    from *cfg* is treated as an empty table, so the corresponding params
    class is built with no keyword arguments.

    Raises:
        TypeError: If one of the four sections, or a row inside
            ``transitions``, is present but is not a dict. Raised explicitly
            rather than via ``assert`` so the check survives ``python -O``.
    """
    trans_dict = cfg.get("transitions", {})
    if not isinstance(trans_dict, dict):
        raise _section_type_error("transitions", trans_dict)
    rows: dict[str, TransitionRow] = {}
    for name, row in trans_dict.items():
        if not isinstance(row, dict):
            raise _section_type_error(f"transitions.{name}", row)
        # Every coefficient in a transition row is coerced to float.
        rows[name] = TransitionRow(**{k: float(v) for k, v in row.items()})
    transitions = TransitionParams(**rows)

    emit_dict = cfg.get("emissions", {})
    if not isinstance(emit_dict, dict):
        raise _section_type_error("emissions", emit_dict)
    emissions = DGPEmissionParams(**{k: float(v) for k, v in emit_dict.items()})

    timeline_dict = cfg.get("timeline", {})
    if not isinstance(timeline_dict, dict):
        raise _section_type_error("timeline", timeline_dict)
    # List values are converted to tuples; all other timeline values are
    # passed through unchanged.
    timeline_kwargs = {
        k: (tuple(v) if isinstance(v, list) else v) for k, v in timeline_dict.items()
    }
    timeline = TimelineParams(**timeline_kwargs)  # type: ignore[arg-type]

    pop_dict = cfg.get("population", {})
    if not isinstance(pop_dict, dict):
        raise _section_type_error("population", pop_dict)
    population = PopulationParams(**{k: float(v) for k, v in pop_dict.items()})

    return SimulationConfig(
        transitions=transitions,
        emissions=emissions,
        timeline=timeline,
        population=population,
    )
def run_fit(args: argparse.Namespace) -> int:
    """Run the ``fit`` subcommand end to end and return the process exit code.

    Loads the bundle at ``args.input``, converts it to fit data, runs
    :func:`fit` with the requested restarts / initialization, decodes the best
    parameters with :func:`viterbi`, and writes the result under
    ``args.output``. Unless ``args.quiet`` is set, a header, the written
    paths and a summary of the best restart are printed to stderr.

    Returns:
        Always ``0``.
    """
    verbose = not args.quiet

    bundle = load_bundle(args.input)
    data = bundle_to_fit_data(bundle)
    rng = np.random.default_rng(args.seed)
    em_config = EMConfig(max_iter=args.max_iter, tol=args.tol, verbose=verbose)

    # Only the 'truth' strategy needs the data-generating parameters, which
    # are rebuilt from the config stored alongside the bundle.
    if args.init == "truth":
        truth_cfg = _build_dgp_truth_from_config(bundle.config)
        truth_init = dgp_truth_to_fit_init(
            truth_cfg.transitions,
            truth_cfg.emissions,
            truth_cfg.population,
        )
    else:
        truth_init = None

    if verbose:
        header = (
            f"Fitting IO-HMM: N={data.n}, T={data.t_total}, "
            f"restarts={args.restarts}, init={args.init}"
        )
        print(header, file=sys.stderr)

    result = fit(
        data,
        n_restarts=args.restarts,
        em_config=em_config,
        init=args.init,
        rng=rng,
        truth_init=truth_init,
    )
    best = result.best
    decoded = viterbi(best.params, data)
    written = write_fit_bundle(result, decoded, args.output)

    if verbose:
        for label, location in written.items():
            print(f"{label}: {location}", file=sys.stderr)
        summary = (
            f"Best restart: #{result.best_index}, final LL={best.final_log_likelihood:.4f},"
            f" iters={best.iterations}, converged={best.converged}"
        )
        print(summary, file=sys.stderr)
    return 0