# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""EM loop with monotonicity tracking."""
from __future__ import annotations
import logging
from dataclasses import dataclass
import numpy as np
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.inference.forward_backward import forward_backward
from iohmm_evac.inference.m_step import m_step
# Names exported by ``from <this module> import *``.
__all__ = ["EMConfig", "EMResult", "run_em"]
# Module-level logger, named after this module's import path.
_log = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class EMConfig:
"""Tuning knobs for the EM loop."""
max_iter: int = 200
tol: float = 1e-5
"""Relative log-likelihood change stopping threshold."""
verbose: bool = False
sigma_floor: float = 1e-2
transition_maxiter: int = 20
"""Per-row L-BFGS-B iteration cap. EM is GEM-style; partial M-steps are fine."""
transition_tol: float = 1e-5
@dataclass(frozen=True, slots=True)
class EMResult:
"""Output of one EM run."""
params: FitParameters
log_likelihood_trace: tuple[float, ...]
iterations: int
converged: bool
final_log_likelihood: float
def _total_log_likelihood(per_household: np.ndarray) -> float:
return float(per_household.sum())
def run_em(initial: FitParameters, data: FitData, config: EMConfig | None = None) -> EMResult:
    """Run EM from ``initial`` and return the last scored parameters and the trace.

    Every iteration runs ``forward_backward`` on the current parameters, sums
    the resulting log-likelihood array and appends the total to the trace.
    From the second iteration on, two checks run before the next M-step:

    * Monotonicity: if the total is NaN, or is more than ``1e-6`` (absolute)
      below the previous total, a warning is logged, the previous iteration's
      parameters are restored, the trace entry is overwritten with the
      previous total, and the loop stops with ``converged=False``.
    * Convergence: if the improvement divided by ``max(|ll|, 1)`` is below
      ``config.tol``, the loop stops with ``converged=True``.

    When ``config.max_iter`` is reached the loop stops right after scoring,
    without another M-step, so the returned parameters are always the ones
    that ``final_log_likelihood`` was computed for.

    Args:
        initial: Starting parameters.
        data: Passed unchanged to ``forward_backward`` and ``m_step``.
        config: Loop settings; a default ``EMConfig()`` is used when None.

    Returns:
        An ``EMResult``. ``iterations`` is the index of the last iteration
        that ran (0 when ``max_iter`` is below 1) and ``final_log_likelihood``
        is the last trace entry, or ``-inf`` when the trace is empty.
    """
    cfg = config if config is not None else EMConfig()
    # Absolute slack tolerated before a drop counts as a monotonicity violation.
    decrease_slack = 1e-6

    params = initial
    last_good = params
    trace: list[float] = []
    prev_ll = -np.inf
    iteration = 0
    converged = False

    for iteration in range(1, cfg.max_iter + 1):
        fb = forward_backward(params, data)
        ll = _total_log_likelihood(fb.log_likelihood)
        trace.append(ll)
        if cfg.verbose:
            _log.info("EM iter %d: log_likelihood=%.6f", iteration, ll)

        # Monotonicity / convergence checks need a previous value to compare to.
        if iteration > 1:
            # NaN compares false against everything, so test for it explicitly;
            # otherwise a broken M-step would slip past both checks below.
            if np.isnan(ll) or ll < prev_ll - decrease_slack:
                _log.warning(
                    "EM log-likelihood decreased at iter %d (%.6f -> %.6f). "
                    "Reverting to previous params.",
                    iteration,
                    prev_ll,
                    ll,
                )
                params = last_good
                # Keep the trace consistent with the parameters being returned.
                trace[-1] = prev_ll
                break
            denom = max(abs(ll), 1.0)
            if (ll - prev_ll) / denom < cfg.tol:
                converged = True
                break

        # Iteration budget exhausted: stop before another M-step so that the
        # returned params are the ones just scored, not an unchecked update.
        if iteration == cfg.max_iter:
            break

        last_good = params
        prev_ll = ll
        params = m_step(
            fb.log_gamma,
            fb.log_xi,
            data,
            params,
            sigma_floor=cfg.sigma_floor,
            transition_maxiter=cfg.transition_maxiter,
            transition_tol=cfg.transition_tol,
        )

    return EMResult(
        params=params,
        log_likelihood_trace=tuple(trace),
        iterations=iteration,
        converged=converged,
        final_log_likelihood=trace[-1] if trace else float("-inf"),
    )