# src/iohmm_evac/dgp/simulator.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""End-to-end simulator that ties population, timeline, transitions, and emissions together."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.random import Generator

from iohmm_evac.dgp.emissions import sample_emissions
from iohmm_evac.dgp.feedback import congestion, peer_share
from iohmm_evac.dgp.population import Population, synthesize_population
from iohmm_evac.dgp.timeline import Timeline, build_timeline, local_risk_at
from iohmm_evac.dgp.transitions import StepInputs, sample_transitions
from iohmm_evac.network.metrics import peak_enroute_share_and_hour
from iohmm_evac.params import SimulationConfig
from iohmm_evac.types import BoolArray, EvacPath, FloatArray, IntArray, State

# Names re-exported by ``from ... import *``.
__all__ = ["SimulationResult", "encode_evac_path", "simulate"]


# Integer code stored in ``evac_path`` arrays for each ``EvacPath`` member.
# NONE maps to 0, which is also the value of the zero-initialised
# ``evac_path`` array that ``simulate`` starts from.
_EVAC_PATH_CODE = {EvacPath.NONE: 0, EvacPath.AWAY: 1, EvacPath.HOME: 2}


def encode_evac_path(arr: IntArray) -> list[str]:
    """Translate integer evac_path codes into their ``EvacPath`` values (for IO).

    Each element of ``arr`` is coerced with ``int`` and looked up in the
    inverse of ``_EVAC_PATH_CODE``; a code that is not in that table raises
    ``KeyError``. The result has one entry per element, in input order.
    """
    label_by_code = {code: path.value for path, code in _EVAC_PATH_CODE.items()}
    return [label_by_code[code] for code in map(int, arr)]


@dataclass(frozen=True, slots=True)
class SimulationResult:
    """Bundle of all arrays produced by a simulation run.

    Time-indexed arrays hold one row per household and one column per hour
    ``0..T``. NOTE(review): the field docstrings quote int8/int32 dtypes,
    whereas ``simulate`` in this module allocates int64 buffers — confirm
    which is authoritative before relying on the narrower widths.
    """

    states: IntArray
    """Latent states, shape (N, T+1), int8."""
    departures: BoolArray
    """Departure indicator, shape (N, T+1), bool."""
    displacements: FloatArray
    """Displacement in km, shape (N, T+1), float64."""
    communications: IntArray
    """Communication counts, shape (N, T+1), int32."""
    population: Population
    """Static covariates."""
    timeline: Timeline
    """Exogenous timeline."""
    evac_path: IntArray
    """Final evac_path code per household, shape (N,), int8 (0/1/2)."""
    config: SimulationConfig
    """Resolved configuration that produced this run."""

    def summary(self) -> dict[str, float]:
        """Sanity-check metrics; informational, no thresholds enforced.

        See ``docs/reporting.md`` for a description of each metric.

        Every ``share_*`` value is a fraction of all households and is
        reported as 0.0 when the population is empty.
        ``median_departure_hour`` is NaN when no household ever enters ER.
        """
        n, t_plus_1 = self.states.shape
        t_total = t_plus_1 - 1
        sh = int(State.SH)
        er = int(State.ER)
        away = _EVAC_PATH_CODE[EvacPath.AWAY]
        home = _EVAC_PATH_CODE[EvacPath.HOME]

        def share(mask: BoolArray) -> float:
            # np.mean over an empty mask yields NaN plus a RuntimeWarning, so
            # every share metric uses the same empty-population fallback.
            return float(np.mean(mask)) if n > 0 else 0.0

        # Hour 48 is clamped to the last simulated hour for shorter runs.
        t48 = min(48, t_total)
        final_state = self.states[:, t_total]
        in_sh = final_state == sh

        sheltered_at_t48 = share(self.states[:, t48] == sh)
        sheltered_at_landfall = share(in_sh)
        failed_evacuation = share(final_state == er)
        evacuated_away = share(in_sh & (self.evac_path == away))
        sheltered_in_place = share(in_sh & (self.evac_path == home))

        peak_enroute_share, peak_idx = peak_enroute_share_and_hour(self.states, er_code=er)
        peak_enroute_hour = float(peak_idx)

        # "Departure" here means the latent transition into ER (the
        # household has actually left), not the noisy ``departure``
        # emission flag — the latter fires under non-ER states with
        # probability ``p_departure_other`` and would be dominated by
        # measurement noise rather than evacuation timing.
        is_er = self.states == er
        ever_er = is_er.any(axis=1)
        if bool(ever_er.any()):
            # argmax on a boolean row returns the index of its first True,
            # i.e. the first hour spent in ER; rows with no ER are dropped.
            first_er_hour = is_er.argmax(axis=1)[ever_er]
            median_departure_hour = float(np.median(first_er_hour))
        else:
            median_departure_hour = float("nan")

        return {
            "share_sheltered_at_t48": sheltered_at_t48,
            "share_sheltered_at_landfall": sheltered_at_landfall,
            "share_failed_evacuation": failed_evacuation,
            "share_evacuated_away": evacuated_away,
            "share_sheltered_in_place": sheltered_in_place,
            "peak_enroute_share": peak_enroute_share,
            "peak_enroute_hour": peak_enroute_hour,
            "median_departure_hour": median_departure_hour,
        }


def _zone_multiplier(population: Population, params: SimulationConfig) -> FloatArray | None:
    """Build the per-household risk multiplier for targeted messaging.

    Households in coastal zones A and B (zone codes 0 and 1) receive
    ``params.population.targeted_zone_multiplier``; zone C (code 2) keeps a
    multiplier of 1.0. Build 3.5 widened the targeted set from zone A only
    after the aggregate impact at zone A's ~10% population share was below
    the noise floor.

    Returns ``None`` when the configured multiplier is exactly 1.0 (nothing
    to scale), otherwise a float64 array of length ``population.n``.
    """
    factor = params.population.targeted_zone_multiplier
    if factor != 1.0:
        multipliers = np.ones(population.n, dtype=np.float64)
        # Zone codes 0 and 1 are the targeted coastal zones.
        multipliers[np.isin(population.zone, (0, 1))] = factor
        return multipliers
    return None


def _update_evac_path(evac_path: IntArray, prev_state: IntArray, new_state: IntArray) -> None:
    """Record the evacuation path of households that just left PR (in place).

    PR → ER  ⇒ AWAY (will subsequently shelter away)
    PR → SH  ⇒ HOME (sheltered in place)
    ER → SH  ⇒ unchanged (already AWAY)

    Only entries of ``evac_path`` whose household made one of the first two
    transitions this step are overwritten; nothing is returned.
    """
    was_pr = prev_state == State.PR
    # The two destination masks are disjoint, so assignment order is irrelevant.
    for landed_in, path in ((State.ER, EvacPath.AWAY), (State.SH, EvacPath.HOME)):
        evac_path[was_pr & (new_state == landed_in)] = _EVAC_PATH_CODE[path]


def _update_tir(tir: FloatArray, new_state: IntArray) -> FloatArray:
    """Advance the time-in-ER counter by one step.

    Households whose new state is ER get ``tir + 1``; every other household
    is reset to 0. A new array is returned and the input ``tir`` is left
    unmodified.
    """
    still_enroute = new_state == State.ER
    advanced = tir + 1.0
    return np.where(still_enroute, advanced, 0.0)


def simulate(config: SimulationConfig, rng: Generator) -> SimulationResult:
    """Run the full DGP and return all observed and latent arrays.

    The population and timeline are drawn first. Hour 0 is fixed to state UA
    with emissions sampled at zero congestion. For each later hour the next
    latent state is sampled from the previous hour's states, ``evac_path``
    and time-in-ER are updated, and emissions are sampled for the new state.

    ``rng`` is handed to every sampling helper in a fixed order; keep that
    order stable if runs must stay reproducible for a given seed.
    """
    n_households = config.n_households
    horizon = config.n_hours
    n_columns = horizon + 1

    population = synthesize_population(n_households, rng, config.population)
    storm = build_timeline(horizon, rng, config.timeline)
    risk_scale = _zone_multiplier(population, config)

    # One row per household, one column per hour 0..horizon.
    panel = (n_households, n_columns)
    state_hist = np.zeros(panel, dtype=np.int64)
    departure_hist = np.zeros(panel, dtype=bool)
    displacement_hist = np.zeros(panel, dtype=np.float64)
    comm_hist = np.zeros(panel, dtype=np.int64)
    path_code = np.zeros(n_households, dtype=np.int64)
    hours_enroute = np.zeros(n_households, dtype=np.float64)

    # Hour 0: every household is UA, with evac_path NONE and zero time-in-ER.
    state_hist[:, 0] = int(State.UA)
    initial_obs = sample_emissions(
        state_hist[:, 0],
        path_code,
        hours_enroute,
        population.destination,
        congestion_t=0.0,
        emissions=config.emissions,
        rng=rng,
    )
    departure_hist[:, 0], displacement_hist[:, 0], comm_hist[:, 0] = initial_obs

    for hour in range(1, n_columns):
        previous = state_hist[:, hour - 1]

        # Feedback terms are derived from the previous hour's states.
        road_load = congestion(previous, config.feedback.n_cap)
        peers = peer_share(previous, path_code)
        local_risk = local_risk_at(storm.forecast[hour], population.distance, risk_scale)

        current = sample_transitions(
            prev_state=previous,
            inputs=StepInputs(
                rho=local_risk,
                vol=int(storm.voluntary[hour]),
                mand=int(storm.mandatory[hour]),
                tau_norm=hour / horizon,
                pi=peers,
                c=road_load,
            ),
            risk=population.risk,
            vehicle=population.vehicle,
            tir=hours_enroute,
            params=config.transitions,
            rng=rng,
        )

        # Bookkeeping must precede the emission draw, which reads both arrays.
        _update_evac_path(path_code, previous, current)
        hours_enroute = _update_tir(hours_enroute, current)

        # The same congestion value used for the transition feeds the emissions.
        observed = sample_emissions(
            current,
            path_code,
            hours_enroute,
            population.destination,
            congestion_t=road_load,
            emissions=config.emissions,
            rng=rng,
        )
        state_hist[:, hour] = current
        departure_hist[:, hour], displacement_hist[:, hour], comm_hist[:, hour] = observed

    return SimulationResult(
        states=state_hist,
        departures=departure_hist,
        displacements=displacement_hist,
        communications=comm_hist,
        population=population,
        timeline=storm,
        evac_path=path_code,
        config=config,
    )