# src/iohmm_evac/network/metrics.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Network/shelter metrics computed post-hoc from saved simulation arrays.

The IO-HMM proper produces only the behavioral state trajectory; the metrics
here translate that trajectory into operational outcomes (peak load,
overflow, failed evacuations, delay). They are pure functions of the saved
arrays — no resimulation, no stochastic content.
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from iohmm_evac.types import FloatArray, IntArray

# Public API of this module: the names exported by a star-import. Helpers
# prefixed with an underscore are intentionally left out.
__all__ = [
    "NetworkMetrics",
    "compute_metrics_from_arrays",
    "peak_enroute_share_and_hour",
]


@dataclass(frozen=True, slots=True)
class NetworkMetrics:
    """Summary network/shelter outcomes plus per-hour diagnostic arrays."""

    total_delay_hours: float
    """Sum across (household, hour) of congestion-attributable delay."""

    peak_enroute_share: float
    """Maximum share of households simultaneously in ER across all hours."""

    peak_enroute_hour: int
    """Hour at which ``peak_enroute_share`` is attained."""

    shelter_overflow_count: int
    """Cumulative SH-away arrivals beyond shelter capacity ``K``."""

    failed_evacuation_count: int
    """Households still in ER at the final hour ``T``."""

    delay_per_hour: FloatArray
    """Δ_t per hour, shape ``(T+1,)``."""

    enroute_count_per_hour: IntArray
    """ER count per hour, shape ``(T+1,)``."""

    arrivals_away_per_hour: IntArray
    """New SH-away arrivals per hour, shape ``(T+1,)``."""


def peak_enroute_share_and_hour(states: IntArray, er_code: int = 3) -> tuple[float, int]:
    """Find the hour with the largest en-route share and report both.

    ``states`` is indexed ``[household, hour]``; for each hour the fraction
    of households whose state equals ``er_code`` is computed, and the first
    hour attaining the maximum is selected.

    This is the single implementation meant to back both
    ``SimulationResult.summary`` (defined elsewhere) and
    :func:`compute_metrics_from_arrays`, so that the two code paths agree on
    the same input array.

    Returns:
        ``(peak_share, peak_hour)`` as a plain ``float`` and ``int``.
    """
    share_by_hour = np.mean(states == er_code, axis=0)
    # argmax returns the earliest index on ties.
    hour = int(share_by_hour.argmax())
    peak = float(share_by_hour[hour])
    return peak, hour


def _state_codes_to_int(states_long: IntArray) -> IntArray:
    return states_long.astype(np.int64, copy=False)


def compute_metrics_from_arrays(
    states: IntArray,
    displacements: FloatArray,
    evac_path: IntArray,
    *,
    n_cap: int,
    shelter_capacity: int,
    v_free: float,
    congestion_penalty: float = 0.6,
    er_code: int = 3,
    sh_code: int = 4,
    away_code: int = 1,
) -> NetworkMetrics:
    """Compute :class:`NetworkMetrics` from raw simulation arrays.

    Parameters mirror what the simulator wrote and the sidecar TOML stores.
    Pure function — no I/O, no RNG.

    Args:
        states: State codes, shape ``(N, T+1)`` — one row per household, one
            column per hour.
        displacements: Displacement ``X_{i,t}`` per household and hour; must
            have exactly the same shape as ``states``.
        evac_path: Per-household path code, shape ``(N,)``. Only households
            whose entry equals ``away_code`` count toward shelter arrivals.
        n_cap: Capacity used to normalise the lagged ER count into the
            congestion level ``c_t = min(ER_{t-1} / n_cap, 1)``.
        shelter_capacity: Capacity ``K`` that cumulative SH-away arrivals are
            compared against.
        v_free: Free-flow speed (presumably displacement units per hour, so
            that the resulting delay is in hours — confirm against the
            simulator).
        congestion_penalty: ``α`` in ``v_eff_t = v_free * (1 - α * c_t)``;
            must lie in ``[0, 1)``.
        er_code: State code treated as ER.
        sh_code: State code treated as SH.
        away_code: ``evac_path`` code treated as AWAY.

    Returns:
        A :class:`NetworkMetrics` with scalar summaries and per-hour arrays
        of length ``T+1``.

    Raises:
        ValueError: If ``n_cap`` or ``v_free`` is not positive, if
            ``congestion_penalty`` is outside ``[0, 1)``, if ``states`` is not
            2-D with at least one hour, or if ``displacements`` /
            ``evac_path`` do not match the shape of ``states``.
    """
    if n_cap <= 0:
        msg = "n_cap must be a positive integer"
        raise ValueError(msg)
    if v_free <= 0:
        msg = "v_free must be positive"
        raise ValueError(msg)
    # α >= 1 lets v_eff reach zero at saturation (1/0 -> inf, and 0 * inf ->
    # NaN in the delay sums); α < 0 makes the floor applied below exceed every
    # real v_eff. Written as ``not (...)`` so that NaN is rejected as well.
    if not (0.0 <= congestion_penalty < 1.0):
        msg = "congestion_penalty must satisfy 0 <= congestion_penalty < 1"
        raise ValueError(msg)

    # Shape checks: without them a mis-shaped input either broadcasts
    # silently into wrong metrics or fails deep inside NumPy.
    if states.ndim != 2 or states.shape[1] == 0:
        msg = "states must be 2-D with at least one time column"
        raise ValueError(msg)
    n_households, t_plus_1 = states.shape
    if displacements.shape != states.shape:
        msg = "displacements must have the same shape as states"
        raise ValueError(msg)
    if evac_path.shape != (n_households,):
        msg = "evac_path must be 1-D with one entry per household"
        raise ValueError(msg)

    is_er = states == er_code
    enroute_count_per_hour = is_er.sum(axis=0).astype(np.int64)

    # Δ_{i,t} = max(X_{i,t} - X_{i,t-1}, 0) — distance covered during hour t.
    # Hour 0 has no predecessor and stays at zero.
    deltas = np.zeros_like(displacements)
    deltas[:, 1:] = np.maximum(displacements[:, 1:] - displacements[:, :-1], 0.0)

    # c_t recomputed from the lagged state vector, matching the DGP's
    # feedback: hour t sees the ER count of hour t-1. With a single hour the
    # slices are empty and c stays all-zero.
    c_per_hour = np.zeros(t_plus_1, dtype=np.float64)
    c_per_hour[1:] = np.minimum(enroute_count_per_hour[:-1] / float(n_cap), 1.0)

    # v_eff_t = v_free * (1 - α * c_t). Given 0 <= α < 1 and c_t <= 1 this is
    # already >= v_free * (1 - α) > 0; the floor is kept as a cheap safeguard.
    v_eff = v_free * (1.0 - congestion_penalty * c_per_hour)
    v_eff = np.maximum(v_eff, v_free * (1.0 - congestion_penalty))
    # Extra travel time per unit distance relative to free flow.
    inv_diff = (1.0 / v_eff) - (1.0 / v_free)

    # Only ER-state household-hours contribute.
    er_mask = is_er.astype(np.float64)
    delay_matrix = deltas * er_mask * inv_diff[np.newaxis, :]
    delay_per_hour = delay_matrix.sum(axis=0)
    total_delay_hours = float(delay_per_hour.sum())

    peak_share, peak_hour = peak_enroute_share_and_hour(states, er_code=er_code)

    # Arrivals into SH-away per hour: state[t]==SH AND state[t-1]!=SH AND
    # evac_path==AWAY. Hour 0 has no prior, so contributes zero arrivals.
    is_sh = states == sh_code
    new_sh = is_sh[:, 1:] & ~is_sh[:, :-1]
    away_mask = (evac_path == away_code)[:, np.newaxis]
    arrivals_away_per_hour = np.zeros(t_plus_1, dtype=np.int64)
    arrivals_away_per_hour[1:] = (new_sh & away_mask).sum(axis=0).astype(np.int64)

    cumulative_away = int(arrivals_away_per_hour.sum())
    shelter_overflow_count = max(0, cumulative_away - int(shelter_capacity))

    # Households still en route at the final hour T.
    failed_evacuation_count = int(enroute_count_per_hour[-1])

    return NetworkMetrics(
        total_delay_hours=total_delay_hours,
        peak_enroute_share=peak_share,
        peak_enroute_hour=peak_hour,
        shelter_overflow_count=shelter_overflow_count,
        failed_evacuation_count=failed_evacuation_count,
        delay_per_hour=delay_per_hour,
        enroute_count_per_hour=enroute_count_per_hour,
        arrivals_away_per_hour=arrivals_away_per_hour,
    )