src/iohmm_evac/dgp/transitions.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Multinomial-logit state transitions, vectorized per origin state.

For each origin state ``k``, this module computes the linear logits for every
allowed destination (with the self-loop pinned at 0), softmaxes them, and
samples the next state for every household currently in ``k``. SH is
absorbing: households already in SH keep that state.
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.random import Generator

from iohmm_evac.params import TransitionParams
from iohmm_evac.types import FloatArray, IntArray, State

# Names exported by a star-import of this module.
__all__ = ["StepInputs", "categorical", "sample_transitions", "softmax",
           "transition_probs_for_state"]


@dataclass(frozen=True, slots=True)
class StepInputs:
    """Per-step exogenous and feedback inputs needed by the transition model."""

    rho: FloatArray
    """Local risk per household, shape (N,)."""

    vol: int
    mand: int
    tau_norm: float
    """``(1 - τ_t / T) = t / T``; closer to landfall ⇒ closer to 1."""
    pi: float
    """Peer-departure share."""
    c: float
    """Network congestion."""


def softmax(logits: FloatArray, axis: int = -1) -> FloatArray:
    """Numerically stable softmax along ``axis``.

    The default ``axis=-1`` normalizes each row of the ``(n, k)`` logit
    matrices built in this module, which is the same as ``axis=1`` for 2-D
    input. It also accepts 1-D vectors and stacked batches.

    Parameters
    ----------
    logits:
        Array of logits.
    axis:
        Axis along which the probabilities are normalized.

    Returns
    -------
    float64 array with the same shape as ``logits``; every slice along
    ``axis`` sums to 1.
    """
    x = np.asarray(logits)
    # Subtracting the per-slice maximum keeps ``exp`` from overflowing for
    # large logits; the shift cancels in the ratio below.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    total = np.sum(weights, axis=axis, keepdims=True)
    return np.asarray(weights / total, dtype=np.float64)


def categorical(probs: FloatArray, rng: Generator) -> IntArray:
    """Sample one categorical draw per row of ``probs``.

    Each row is treated as non-negative (possibly unnormalized) weights over
    its columns: column ``j`` of row ``i`` is drawn with probability
    ``probs[i, j] / probs[i].sum()``. Columns with zero weight are never drawn.

    Parameters
    ----------
    probs:
        ``(n, k)`` array of non-negative weights whose rows have a positive,
        finite sum. Softmax output qualifies.
    rng:
        Source of randomness; ``rng.random(size=n)`` is called exactly once.

    Returns
    -------
    ``(n,)`` int64 array of sampled column indices.
    """
    cum: FloatArray = np.cumsum(probs, axis=1)
    # Scale each uniform by its row total instead of forcing the last
    # cumulative value to 1.0. The clamp moved any floating-point shortfall
    # onto the last column even when that column's probability was exactly 0.
    # With u in [0, 1), u * total < total always yields a hit, and it can never
    # land on a zero-weight column (its cumulative value equals its
    # predecessor's).
    total = cum[:, -1:]
    u = np.asarray(rng.random(size=probs.shape[0]), dtype=np.float64).reshape(-1, 1) * total
    return np.asarray(np.argmax(u < cum, axis=1), dtype=np.int64)


def _logits_from_ua(
    inputs: StepInputs, risk: FloatArray, idx: IntArray, params: TransitionParams
) -> FloatArray:
    """Logits for households currently in UA: columns [UA, AW]."""
    n = idx.shape[0]
    row = params.ua_to_aw
    eta_aw = (
        row.alpha
        + row.beta_vol * inputs.vol
        + row.beta_mand * inputs.mand
        + row.beta_rho * inputs.rho[idx]
        + row.beta_r * risk[idx]
        + row.beta_tau * inputs.tau_norm
    )
    return np.column_stack([np.zeros(n), eta_aw])


def _logits_from_aw(
    inputs: StepInputs,
    risk: FloatArray,
    vehicle: FloatArray,
    idx: IntArray,
    params: TransitionParams,
) -> FloatArray:
    """Logits for households currently in AW: columns [AW, UA, PR]."""
    n = idx.shape[0]
    eta_ua = np.full(n, params.aw_to_ua.alpha, dtype=np.float64)
    pr = params.aw_to_pr
    eta_pr = (
        pr.alpha
        + pr.beta_mand * inputs.mand
        + pr.beta_rho * inputs.rho[idx]
        + pr.beta_pi * inputs.pi
        + pr.beta_r * risk[idx]
        + pr.beta_v * vehicle[idx]
        + pr.beta_tau * inputs.tau_norm
    )
    return np.column_stack([np.zeros(n), eta_ua, eta_pr])


def _logits_from_pr(
    inputs: StepInputs,
    risk: FloatArray,
    vehicle: FloatArray,
    idx: IntArray,
    params: TransitionParams,
) -> FloatArray:
    """Logits for households currently in PR: columns [PR, ER, SH]."""
    n = idx.shape[0]
    er = params.pr_to_er
    eta_er = (
        er.alpha
        + er.beta_mand * inputs.mand
        + er.beta_tau * inputs.tau_norm
        + er.beta_negc * (-inputs.c)
        + er.beta_r * risk[idx]
        + er.beta_v * vehicle[idx]
    )
    sh = params.pr_to_sh
    eta_sh = sh.alpha + sh.beta_negr * (-risk[idx]) + sh.beta_negv * (-vehicle[idx])
    return np.column_stack([np.zeros(n), eta_er, eta_sh])


def _logits_from_er(
    inputs: StepInputs, tir: FloatArray, idx: IntArray, params: TransitionParams
) -> FloatArray:
    """Logits for households currently in ER: columns [ER, SH]."""
    n = idx.shape[0]
    sh = params.er_to_sh
    eta_sh = sh.alpha + sh.beta_tir * tir[idx] + sh.beta_negc * (-inputs.c)
    return np.column_stack([np.zeros(n), eta_sh])


# Destination state codes for each origin, in the column order of that
# origin's logit matrix (column 0 is the self-loop).
_DEST_UA = np.array([State.UA, State.AW], dtype=np.int64)
_DEST_AW = np.array([State.AW, State.UA, State.PR], dtype=np.int64)
_DEST_PR = np.array([State.PR, State.ER, State.SH], dtype=np.int64)
_DEST_ER = np.array([State.ER, State.SH], dtype=np.int64)

# ``transition_probs_for_state`` returns these shared module-level arrays to
# callers as-is. Making them read-only means an accidental in-place edit by a
# caller raises, instead of silently corrupting every later draw.
for _dest in (_DEST_UA, _DEST_AW, _DEST_PR, _DEST_ER):
    _dest.setflags(write=False)
del _dest


def transition_probs_for_state(
    origin: State,
    inputs: StepInputs,
    risk: FloatArray,
    vehicle: FloatArray,
    tir: FloatArray,
    idx: IntArray,
    params: TransitionParams,
) -> tuple[FloatArray, IntArray]:
    """Return (probabilities, destination-codes) for households at ``origin``.

    Useful for tests that need to inspect the probability matrix directly.

    Parameters
    ----------
    origin:
        Origin state, given as a ``State`` member or its integer code (for
        example, an element taken from a state vector).
    inputs, risk, vehicle, tir, params:
        Model inputs; only the ones used by ``origin``'s logit are read.
    idx:
        Indices of the households currently in ``origin``.

    Returns
    -------
    probs:
        ``(len(idx), k)`` matrix whose rows sum to 1; column ``j`` is the
        probability of moving to ``dests[j]``. For UA, AW, PR and ER, column 0
        is the self-loop. Any other origin (SH) gets a single SH column with
        probability 1.
    dests:
        ``(k,)`` int64 state codes labelling the columns of ``probs``.

    Raises
    ------
    ValueError
        If ``origin`` is not a valid ``State`` value.
    """
    # Normalize first. The identity checks below would otherwise send a plain
    # integer code (or any non-member value) silently to the absorbing SH
    # branch.
    origin = State(origin)
    if origin is State.UA:
        return softmax(_logits_from_ua(inputs, risk, idx, params)), _DEST_UA
    if origin is State.AW:
        return softmax(_logits_from_aw(inputs, risk, vehicle, idx, params)), _DEST_AW
    if origin is State.PR:
        return softmax(_logits_from_pr(inputs, risk, vehicle, idx, params)), _DEST_PR
    if origin is State.ER:
        return softmax(_logits_from_er(inputs, tir, idx, params)), _DEST_ER
    # SH is absorbing: every household stays in SH with probability 1.
    n = idx.shape[0]
    return np.ones((n, 1), dtype=np.float64), np.array([State.SH], dtype=np.int64)


def sample_transitions(
    prev_state: IntArray,
    inputs: StepInputs,
    risk: FloatArray,
    vehicle: FloatArray,
    tir: FloatArray,
    params: TransitionParams,
    rng: Generator,
) -> IntArray:
    """Draw every household's next state from its origin's multinomial logit.

    Origin groups are processed in the fixed order UA, AW, PR, ER; empty
    groups are skipped and draw nothing from ``rng``. Households in SH, or in
    any state code other than those four, keep their previous state. The input
    array is not modified; a new state vector is returned.
    """
    next_state = prev_state.copy()
    for origin in (State.UA, State.AW, State.PR, State.ER):
        members = np.flatnonzero(prev_state == origin)
        if members.size:
            probs, dest_codes = transition_probs_for_state(
                origin, inputs, risk, vehicle, tir, members, params
            )
            next_state[members] = dest_codes[categorical(probs, rng)]
    return next_state