src/iohmm_evac/inference/fit_params.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Parameter dataclasses for IO-HMM inference, decoupled from the DGP side.

The IO-HMM views the world as ``K`` latent states, an input vector ``u_t``
of dimension ``F``, and three emission channels (Bernoulli ``D``, Gaussian
``X``, Poisson ``C``). Transitions are multinomial-logit:

    A_kj(u) = exp(alpha_kj + beta_kj^T u) / sum_l exp(alpha_kl + beta_kl^T u)

with ``(alpha_kk, beta_kk) = (0, 0)`` for identifiability and
``alpha_kj = -inf`` for forbidden destinations.

Self-loops and forbidden-destination cells in the ``alpha`` and ``beta``
arrays are pinned to fixed values (``alpha = 0`` and ``beta = 0`` for
self-loops; ``alpha = -inf`` and ``beta = 0`` for forbidden destinations)
and never enter the optimization parameter vector — see ``learnable_indices``.
"""

from __future__ import annotations

from dataclasses import dataclass, field, replace

import numpy as np

from iohmm_evac.params import (
    EmissionParams as DGPEmissionParams,
)
from iohmm_evac.params import (
    PopulationParams as DGPPopulationParams,
)
from iohmm_evac.params import TransitionParams as DGPTransitionParams
from iohmm_evac.types import BoolArray, FloatArray, State

# Names exposed by ``from <this module> import *``.
# NOTE(review): ``F`` and ``with_initial`` are defined in this module but are
# not listed here — confirm that the omission is intentional.
__all__ = [
    "ALLOWED_TRANSITIONS",
    "FEATURE_NAMES",
    "EmissionFitParams",
    "FitParameters",
    "InitialFitParams",
    "K",
    "TransitionFitParams",
    "allowed_mask",
    "dgp_truth_to_fit_init",
    "learnable_indices",
]


# Sizes the state axis of every array built in this module.
K: int = 5
"""Number of latent states. Matches the DGP's ``State`` enum."""


# The order of this tuple fixes the slot order of each transition ``beta`` vector.
FEATURE_NAMES: tuple[str, ...] = ("vol", "mand", "rho", "r", "v", "tau")
"""Exogenous IO-HMM input features, in order. ``F = len(FEATURE_NAMES)``.

Endogenous DGP features (``pi``, ``c``, ``tir``) are intentionally absent —
they depend on the latent state path and so cannot enter the inputs of an
inference model that does not know that path. This mis-specification is a
deliberate design choice; see ``docs/inference.md``.
"""

# Input dimension; derived from FEATURE_NAMES so the two cannot drift apart.
F: int = len(FEATURE_NAMES)


# Directed (source, destination) state pairs that are permitted in addition to
# self-loops; ``allowed_mask`` turns this table into the K×K boolean mask.
# ``State.SH`` appears only as a destination, so it has no outgoing non-self
# transition.
_ALLOWED_PAIRS: tuple[tuple[State, State], ...] = (
    (State.UA, State.AW),
    (State.AW, State.UA),
    (State.AW, State.PR),
    (State.PR, State.ER),
    (State.PR, State.SH),
    (State.ER, State.SH),
)


def allowed_mask() -> BoolArray:
    """Build the K×K boolean matrix of permitted transitions.

    Entry ``[k, j]`` is True when ``k -> j`` is a self-loop or one of the
    pairs listed in ``_ALLOWED_PAIRS``; every other entry is False. A fresh
    array is allocated on each call.
    """
    grid = np.zeros((K, K), dtype=bool)

    # Staying in the current state is always permitted.
    diagonal = list(range(K))
    grid[diagonal, diagonal] = True

    # Switch on every explicitly allowed (source, destination) cell at once.
    sources = [int(src) for src, _ in _ALLOWED_PAIRS]
    targets = [int(dst) for _, dst in _ALLOWED_PAIRS]
    grid[sources, targets] = True
    return grid


# Computed once at import time. NumPy arrays are mutable, so callers should
# treat this shared mask as read-only.
ALLOWED_TRANSITIONS: BoolArray = allowed_mask()
"""Module-level constant. Index ``[k, j]`` is True iff ``k -> j`` is allowed."""


def learnable_indices() -> tuple[BoolArray, BoolArray]:
    """Split the transition grid into learnable cells and self-loop cells.

    Returns:
        A pair ``(allowed_non_self_mask, self_mask)`` of K×K boolean arrays.
        ``allowed_non_self_mask[k, j]`` is True for allowed transitions with
        ``k != j`` — the cells that participate in L-BFGS optimization.
        ``self_mask[k, j]`` is True iff ``k == j``.
    """
    on_diagonal = np.eye(K, dtype=bool)
    # Allowed cells minus the diagonal: self-loops are pinned, not optimized.
    off_diagonal_allowed = ALLOWED_TRANSITIONS & ~on_diagonal
    return off_diagonal_allowed, on_diagonal


@dataclass(frozen=True, slots=True)
class InitialFitParams:
    """Initial-state distribution.

    ``logits`` are unnormalized scores; the actual distribution is the
    softmax. Free parameters: ``K - 1`` (one anchored to 0).
    """

    logits: FloatArray

    def probs(self) -> FloatArray:
        """Return the normalized initial-state distribution."""
        m = np.max(self.logits)
        e = np.exp(self.logits - m)
        out = e / e.sum()
        return np.asarray(out, dtype=np.float64)


@dataclass(frozen=True, slots=True)
class TransitionFitParams:
    """K×K transition logit parameters under the IO-HMM input ``u``.

    Both arrays are indexed ``[source, destination]``. ``frozen=True`` only
    prevents rebinding the fields; the arrays themselves can still be
    modified in place.

    Attributes:
        alpha: Intercepts, shape ``(K, K)``. Self-loop cells hold ``0`` and
            forbidden cells hold ``-inf``.
        beta: Input coefficients, shape ``(K, K, F)``, with the last axis in
            ``FEATURE_NAMES`` order. Self-loop and forbidden cells hold zeros.
    """

    alpha: FloatArray  # shape (K, K); self-loops 0, forbidden -inf
    beta: FloatArray  # shape (K, K, F); self-loops zeros, forbidden zeros


@dataclass(frozen=True, slots=True)
class EmissionFitParams:
    """State-conditional emission parameters.

    One entry per latent state for each of the three emission channels
    named in the module docstring (Bernoulli ``D``, Gaussian ``X``,
    Poisson ``C``).

    Attributes:
        p_departure: Bernoulli rates, shape ``(K,)``.
        mu_displacement: Gaussian means, shape ``(K,)``.
        sigma_displacement: Gaussian standard deviations, shape ``(K,)``.
            Expected to be at least ``sigma_floor``; this class does not
            enforce that bound itself.
        lambda_comm: Poisson rates, shape ``(K,)``.
        sigma_floor: Lower bound intended for ``sigma_displacement``.
    """

    p_departure: FloatArray  # shape (K,) Bernoulli rates
    mu_displacement: FloatArray  # shape (K,)
    sigma_displacement: FloatArray  # shape (K,) std-devs (>= sigma_floor)
    lambda_comm: FloatArray  # shape (K,) Poisson rates
    sigma_floor: float = 1e-2


@dataclass(frozen=True, slots=True)
class FitParameters:
    """Top-level IO-HMM parameter bundle.

    Attributes:
        initial: Initial-state distribution parameters.
        transitions: Input-dependent transition logit parameters.
        emissions: State-conditional emission parameters.
        feature_names: Names of the input features, in order; defaults to
            the module-level ``FEATURE_NAMES``.
    """

    initial: InitialFitParams
    transitions: TransitionFitParams
    emissions: EmissionFitParams
    # A tuple is immutable, so sharing one default across instances is safe.
    feature_names: tuple[str, ...] = field(default=FEATURE_NAMES)


def _trans_pair_to_io(
    dgp_alpha: float,
    dgp_betas: dict[str, float],
) -> tuple[float, FloatArray]:
    """Project one DGP transition row onto an IO-HMM ``(alpha, beta_vec)`` pair.

    The intercept passes through unchanged. Coefficients are re-keyed onto
    ``FEATURE_NAMES`` order, and any key missing from ``dgp_betas`` counts
    as ``0.0``.

    Sign handling: the DGP multiplies ``beta_negr`` and ``beta_negv`` by the
    *negated* features ``-r`` and ``-v`` (plain negation, not ``1 - x``),
    whereas the IO-HMM input carries ``r`` and ``v`` themselves. The net
    coefficients are therefore ``beta_r - beta_negr`` and
    ``beta_v - beta_negv``.

    ``beta_negc`` multiplies ``-c``. Since ``c`` is endogenous and has no
    slot in the IO-HMM input, that coefficient is dropped; every other key
    of ``dgp_betas`` that is not read below is ignored as well.

    Returns:
        ``(dgp_alpha, beta_vec)`` where ``beta_vec`` is a float64 vector of
        length ``F``.
    """

    def coef(key: str) -> float:
        return dgp_betas.get(key, 0.0)

    # Net coefficient per IO-HMM feature, keyed by feature name.
    net = {
        "vol": coef("beta_vol"),
        "mand": coef("beta_mand"),
        "rho": coef("beta_rho"),
        "r": coef("beta_r") - coef("beta_negr"),
        "v": coef("beta_v") - coef("beta_negv"),
        "tau": coef("beta_tau"),
    }

    slot_of = {name: pos for pos, name in enumerate(FEATURE_NAMES)}
    beta_vec = np.zeros(F, dtype=np.float64)
    for name, value in net.items():
        beta_vec[slot_of[name]] = value
    return dgp_alpha, beta_vec


def _row_dict(row: object) -> dict[str, float]:
    """Copy the coefficient attributes of ``row`` into a plain dict of floats.

    Each attribute named below is read with ``getattr`` and converted with
    ``float``; a missing attribute raises ``AttributeError``. Keys appear in
    the order listed.
    """
    coefficient_names = (
        "alpha",
        "beta_vol",
        "beta_mand",
        "beta_rho",
        "beta_pi",
        "beta_r",
        "beta_v",
        "beta_tau",
        "beta_negc",
        "beta_negr",
        "beta_negv",
        "beta_tir",
    )
    values: dict[str, float] = {}
    for name in coefficient_names:
        values[name] = float(getattr(row, name))
    return values


def _dgp_displacement_moments(
    emissions: DGPEmissionParams, population: DGPPopulationParams
) -> tuple[FloatArray, FloatArray]:
    """Return DGP-implied per-state displacement ``(μ, σ)`` for the IO-HMM init.

    Reads ``emissions.displacement_idle_sigma`` and
    ``population.dest_lo`` / ``population.dest_hi``. Three regimes are used:

    * **Half-normal** (first three states — UA, AW, PR):
      ``|N(0, s²)|`` with ``s = displacement_idle_sigma``, so
      ``E[X] = s √(2/π)`` and ``Var(X) = s²(1 - 2/π)``.
    * **Mid-evacuation** (fourth state — ER): an approximation with
      ``μ = (dest_lo + dest_hi) / 4`` and
      ``σ = √((dest_hi - dest_lo)² / 12 + 5²)``. The ``+5²`` term is a
      conservative inflation for ``tir``/``c_t`` variability the IO-HMM
      does not see.
    * **Uniform** (fifth state — SH): ``Uniform(dest_lo, dest_hi)`` for the
      dominant ``away`` mode, with the small ``home`` component folded into
      the same Gaussian as a first approximation, so
      ``E[X] = (dest_lo + dest_hi) / 2`` and
      ``Var(X) = (dest_hi - dest_lo)² / 12``.

    Returns:
        ``(mu, sigma)``, two float64 arrays of length 5 in the state order
        given above.
    """
    idle_sd = float(emissions.displacement_idle_sigma)
    lo = float(population.dest_lo)
    hi = float(population.dest_hi)

    # Half-normal moments for the states that have not left yet.
    idle_mean = idle_sd * np.sqrt(2.0 / np.pi)
    idle_spread = idle_sd * np.sqrt(1.0 - 2.0 / np.pi)

    # Variance of Uniform(lo, hi), shared by the ER and SH formulas.
    width = hi - lo
    uniform_var = width * width / 12.0

    en_route_mean = (lo + hi) / 4.0
    en_route_spread = float(np.sqrt(uniform_var + 25.0))

    sheltered_mean = (lo + hi) / 2.0
    sheltered_spread = float(np.sqrt(uniform_var))

    mu = np.array([idle_mean] * 3 + [en_route_mean, sheltered_mean], dtype=np.float64)
    sigma = np.array(
        [idle_spread] * 3 + [en_route_spread, sheltered_spread],
        dtype=np.float64,
    )
    return mu, sigma


def dgp_truth_to_fit_init(
    transitions: DGPTransitionParams,
    emissions: DGPEmissionParams,
    population: DGPPopulationParams | None = None,
) -> FitParameters:
    """Build a :class:`FitParameters` initialized at the DGP's true values.

    Used for the ``--init truth`` testing path and as a reference point for
    parameter recovery diagnostics. Endogenous-feedback DGP coefficients
    (``beta_pi``, ``beta_negc``, ``beta_tir``) have no IO-HMM image and are
    discarded by this projection.

    Args:
        transitions: DGP transition rows, one per allowed non-self pair.
        emissions: DGP emission parameters (departure probabilities,
            communication rates, idle displacement sigma).
        population: Source of ``dest_lo`` / ``dest_hi`` for the DGP-implied
            displacement moments (see :func:`_dgp_displacement_moments`).
            When ``None``, a default-constructed :class:`DGPPopulationParams`
            is used.

    Returns:
        The assembled parameter bundle, with ``feature_names`` left at its
        default.
    """
    if population is None:
        population = DGPPopulationParams()

    # Transitions: start with every destination forbidden, then pin the
    # self-loops to 0 and fill in the allowed non-self cells from the DGP.
    alpha = np.full((K, K), -np.inf, dtype=np.float64)
    beta = np.zeros((K, K, F), dtype=np.float64)
    np.fill_diagonal(alpha, 0.0)

    # Needed only for the annotation on ``edges`` below.
    from iohmm_evac.params import TransitionRow as _TransitionRow

    edges: list[tuple[State, State, _TransitionRow]] = [
        (State.UA, State.AW, transitions.ua_to_aw),
        (State.AW, State.UA, transitions.aw_to_ua),
        (State.AW, State.PR, transitions.aw_to_pr),
        (State.PR, State.ER, transitions.pr_to_er),
        (State.PR, State.SH, transitions.pr_to_sh),
        (State.ER, State.SH, transitions.er_to_sh),
    ]
    for src, dst, row in edges:
        cell = (int(src), int(dst))
        alpha[cell], beta[cell] = _trans_pair_to_io(float(row.alpha), _row_dict(row))
    transition_fit = TransitionFitParams(alpha=alpha, beta=beta)

    # Initial distribution: logit 0 for UA against -10 elsewhere, so the
    # softmax places almost all of the mass on UA.
    start_logits = np.full(K, -10.0, dtype=np.float64)
    start_logits[int(State.UA)] = 0.0
    initial_fit = InitialFitParams(logits=start_logits)

    # Emissions: ER gets its own departure probability; all other states
    # share ``p_departure_other``.
    depart = np.full(K, emissions.p_departure_other, dtype=np.float64)
    depart[int(State.ER)] = emissions.p_departure_er
    disp_mu, disp_sigma = _dgp_displacement_moments(emissions, population)
    # NOTE(review): this positional order assumes the integer values of
    # ``State`` run UA, AW, PR, ER, SH — confirm against the enum.
    comm_rates = np.array(
        [
            emissions.lambda_ua,
            emissions.lambda_aw,
            emissions.lambda_pr,
            emissions.lambda_er,
            emissions.lambda_sh,
        ],
        dtype=np.float64,
    )
    emission_fit = EmissionFitParams(
        p_departure=depart,
        mu_displacement=disp_mu,
        sigma_displacement=disp_sigma,
        lambda_comm=comm_rates,
    )

    return FitParameters(
        initial=initial_fit,
        transitions=transition_fit,
        emissions=emission_fit,
    )


def with_initial(params: FitParameters, initial: InitialFitParams) -> FitParameters:
    """Return a copy of ``params`` whose ``initial`` field is ``initial``.

    ``params`` itself is left untouched and every other field is carried
    over as-is. Intended as a small helper for the E/M-step.
    """
    updated = replace(params, initial=initial)
    return updated