src/iohmm_evac/inference/em.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""EM loop with monotonicity tracking."""

from __future__ import annotations

import logging
from dataclasses import dataclass

import numpy as np

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.inference.forward_backward import forward_backward
from iohmm_evac.inference.m_step import m_step

# Public API of this module.
__all__ = ["EMConfig", "EMResult", "run_em"]

# Module logger used by run_em for per-iteration INFO lines (when verbose)
# and for the WARNING emitted when the log-likelihood decreases.
_log = logging.getLogger(__name__)


@dataclass(frozen=True, slots=True)
class EMConfig:
    """Tuning knobs for the EM loop."""

    max_iter: int = 200
    tol: float = 1e-5
    """Relative log-likelihood change stopping threshold."""
    verbose: bool = False
    sigma_floor: float = 1e-2
    transition_maxiter: int = 20
    """Per-row L-BFGS-B iteration cap. EM is GEM-style; partial M-steps are fine."""
    transition_tol: float = 1e-5


@dataclass(frozen=True, slots=True)
class EMResult:
    """Output of one EM run."""

    params: FitParameters
    log_likelihood_trace: tuple[float, ...]
    iterations: int
    converged: bool
    final_log_likelihood: float


def _total_log_likelihood(per_household: np.ndarray) -> float:
    """Return the grand total of ``per_household`` as a builtin ``float``.

    Every element is summed (over all axes), and the resulting NumPy scalar
    is converted to a plain Python float so it can be stored in the trace.
    """
    grand_total = per_household.sum()
    return float(grand_total)


def run_em(initial: FitParameters, data: FitData, config: EMConfig | None = None) -> EMResult:
    """Run EM to convergence; return the best parameters seen and the trace.

    Each iteration runs an E-step (``forward_backward``) on the current
    parameters and records the total log-likelihood. Unless the loop stops
    at that point, it then runs an M-step (``m_step``) that proposes the
    parameters for the next iteration.

    The loop stops on the first of:

    * the log-likelihood decreasing by more than ``1e-6`` versus the previous
      iteration, or becoming NaN. The previous parameters are restored and
      the last trace entry is overwritten with their log-likelihood.
    * the relative improvement ``(ll - prev_ll) / max(|ll|, 1)`` falling
      below ``cfg.tol``. The result then has ``converged=True``.
    * ``cfg.max_iter`` E-steps having been run. The trailing M-step is
      skipped, because its output would never be evaluated.

    In every case the returned ``params`` are the ones whose log-likelihood
    is ``final_log_likelihood``.

    Args:
        initial: Parameters used for the first E-step.
        data: Observations passed through to ``forward_backward`` and ``m_step``.
        config: Loop settings. ``EMConfig()`` defaults are used when ``None``.

    Returns:
        An ``EMResult`` with these fields:

        * ``log_likelihood_trace``: one entry per E-step.
        * ``iterations``: the number of E-steps run. This includes an E-step
          whose parameters were discarded after a decrease.
        * ``final_log_likelihood``: the last trace entry, or ``-inf`` when
          ``max_iter < 1`` and no E-step ran.
    """
    cfg = config if config is not None else EMConfig()
    params = initial
    trace: list[float] = []
    prev_ll = -np.inf
    iteration = 0
    converged = False
    # Parameters that produced ``prev_ll``; restored if the next step regresses.
    last_good = params

    for iteration in range(1, cfg.max_iter + 1):
        # E-step: posteriors and (presumably per-household) log-likelihoods.
        fb = forward_backward(params, data)
        ll = _total_log_likelihood(fb.log_likelihood)
        trace.append(ll)

        if cfg.verbose:
            _log.info("EM iter %d: log_likelihood=%.6f", iteration, ll)

        # Monotonicity / convergence checks need a previous value to compare to.
        if iteration > 1:
            # (G)EM should never lower the log-likelihood; allow 1e-6 of
            # numerical slack. NaN compares False against everything, so it
            # must be tested explicitly or it would poison later iterations.
            if np.isnan(ll) or ll < prev_ll - 1e-6:
                _log.warning(
                    "EM log-likelihood decreased at iter %d (%.6f -> %.6f). "
                    "Reverting to previous params.",
                    iteration,
                    prev_ll,
                    ll,
                )
                params = last_good
                # Keep the trace consistent with the parameters we return.
                trace[-1] = prev_ll
                break
            denom = max(abs(ll), 1.0)
            if (ll - prev_ll) / denom < cfg.tol:
                converged = True
                break

        if iteration == cfg.max_iter:
            # Budget exhausted. Skip the M-step: its output would never be
            # evaluated, so the returned params would not match trace[-1].
            break

        # M-step: propose the next parameters from the current posteriors.
        last_good = params
        prev_ll = ll
        params = m_step(
            fb.log_gamma,
            fb.log_xi,
            data,
            params,
            sigma_floor=cfg.sigma_floor,
            transition_maxiter=cfg.transition_maxiter,
            transition_tol=cfg.transition_tol,
        )

    return EMResult(
        params=params,
        log_likelihood_trace=tuple(trace),
        iterations=iteration,
        converged=converged,
        final_log_likelihood=trace[-1] if trace else float("-inf"),
    )