# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""End-to-end simulator that ties population, timeline, transitions, and emissions together."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from numpy.random import Generator
from iohmm_evac.dgp.emissions import sample_emissions
from iohmm_evac.dgp.feedback import congestion, peer_share
from iohmm_evac.dgp.population import Population, synthesize_population
from iohmm_evac.dgp.timeline import Timeline, build_timeline, local_risk_at
from iohmm_evac.dgp.transitions import StepInputs, sample_transitions
from iohmm_evac.network.metrics import peak_enroute_share_and_hour
from iohmm_evac.params import SimulationConfig
from iohmm_evac.types import BoolArray, EvacPath, FloatArray, IntArray, State
__all__ = [
    "SimulationResult",
    "encode_evac_path",
    "simulate",
]

# Compact integer codes for EvacPath, used for the per-household ``evac_path``
# arrays (NONE -> 0, AWAY -> 1, HOME -> 2). ``encode_evac_path`` maps the
# codes back to their labels.
_EVAC_PATH_CODE = dict(
    zip((EvacPath.NONE, EvacPath.AWAY, EvacPath.HOME), (0, 1, 2))
)
def encode_evac_path(arr: IntArray) -> list[str]:
    """Translate integer ``evac_path`` codes into their label strings (for IO).

    Despite the name, this runs in the decoding direction of
    ``_EVAC_PATH_CODE``: each element is converted with ``int`` and looked up
    in the inverted mapping, yielding the ``.value`` of the matching
    ``EvacPath`` member. An unknown code raises ``KeyError``.
    """
    label_for_code = {code: path.value for path, code in _EVAC_PATH_CODE.items()}
    return [label_for_code[code] for code in map(int, arr)]
@dataclass(frozen=True, slots=True)
class SimulationResult:
    """Bundle of all arrays produced by a simulation run.

    Hour-indexed arrays have ``T + 1`` columns. Column 0 holds the initial
    conditions, and column ``t`` holds the values after the ``t``-th
    transition step of ``simulate``.
    """

    states: IntArray
    """Latent ``State`` codes, shape (N, T+1); ``simulate`` allocates int64."""
    departures: BoolArray
    """Noisy departure emission indicator, shape (N, T+1), bool."""
    displacements: FloatArray
    """Displacement emission in km, shape (N, T+1), float64."""
    communications: IntArray
    """Communication-count emission, shape (N, T+1); ``simulate`` allocates int64."""
    population: Population
    """Static covariates."""
    timeline: Timeline
    """Exogenous timeline."""
    evac_path: IntArray
    """Final evac_path code per household, shape (N,), values 0/1/2
    (NONE/AWAY/HOME); ``simulate`` allocates int64."""
    config: SimulationConfig
    """Resolved configuration that produced this run."""

    def summary(self) -> dict[str, float]:
        """Sanity-check metrics; informational, no thresholds enforced.

        See ``docs/reporting.md`` for a description of each metric.

        Shares computed here are fractions of households. With zero
        households they are all reported as ``0.0`` rather than NaN (which
        ``np.mean`` of an empty array would give, together with a
        RuntimeWarning). ``median_departure_hour`` is a column index into
        ``states``, or NaN when no household ever enters ER.
        ``peak_enroute_share`` / ``peak_enroute_hour`` come straight from
        ``peak_enroute_share_and_hour``.
        """
        n, t_plus_1 = self.states.shape
        t_total = t_plus_1 - 1
        sh = int(State.SH)
        er = int(State.ER)
        away = _EVAC_PATH_CODE[EvacPath.AWAY]
        home = _EVAC_PATH_CODE[EvacPath.HOME]

        def share(mask: BoolArray) -> float:
            # Guard every share uniformly: an empty population has no
            # meaningful mean, so report 0.0 instead of NaN + RuntimeWarning.
            return float(np.mean(mask)) if n > 0 else 0.0

        # Clamp to the horizon so short runs still report a value.
        t48 = min(48, t_total)
        final_state = self.states[:, t_total]
        in_sh = final_state == sh

        sheltered_at_t48 = share(self.states[:, t48] == sh)
        sheltered_at_landfall = share(in_sh)
        failed_evacuation = share(final_state == er)
        evacuated_away = share(in_sh & (self.evac_path == away))
        sheltered_in_place = share(in_sh & (self.evac_path == home))

        peak_enroute_share, peak_idx = peak_enroute_share_and_hour(self.states, er_code=er)
        peak_enroute_hour = float(peak_idx)

        # "Departure" here means the latent transition into ER (the
        # household has actually left), not the noisy ``departure``
        # emission flag — the latter fires under non-ER states with
        # probability ``p_departure_other`` and would be dominated by
        # measurement noise rather than evacuation timing.
        in_er = self.states == er
        ever_in_er = in_er.any(axis=1)
        if bool(ever_in_er.any()):
            # argmax over a boolean row returns the first True column.
            first_er_hour = in_er.argmax(axis=1)[ever_in_er]
            median_departure_hour = float(np.median(first_er_hour))
        else:
            median_departure_hour = float("nan")

        return {
            "share_sheltered_at_t48": sheltered_at_t48,
            "share_sheltered_at_landfall": sheltered_at_landfall,
            "share_failed_evacuation": failed_evacuation,
            "share_evacuated_away": evacuated_away,
            "share_sheltered_in_place": sheltered_in_place,
            "peak_enroute_share": peak_enroute_share,
            "peak_enroute_hour": peak_enroute_hour,
            "median_departure_hour": median_departure_hour,
        }
def _zone_multiplier(population: Population, params: SimulationConfig) -> FloatArray | None:
"""Per-household risk multiplier for the targeted-messaging scenario.
Applied to coastal zones A *and* B (codes 0 and 1) — zone C (code 2) is
untouched. Build 3.5 widened the targeted set from zone A only after the
aggregate impact at zone A's ~10% population share was below the noise
floor.
"""
mult = params.population.targeted_zone_multiplier
if mult == 1.0:
return None
out = np.ones(population.n, dtype=np.float64)
coastal = (population.zone == 0) | (population.zone == 1)
out[coastal] = mult
return out
def _update_evac_path(evac_path: IntArray, prev_state: IntArray, new_state: IntArray) -> None:
    """Record, in place, the route of households that just left PR.

    Only transitions out of PR write to ``evac_path``:

    * PR → ER sets the code to AWAY (they will subsequently shelter away).
    * PR → SH sets the code to HOME (sheltered in place).

    Every other transition leaves the stored code untouched. For example,
    ER → SH keeps whatever code the household already has, which is AWAY
    for households that reached ER from PR.
    """
    left_pr = prev_state == State.PR
    evac_path[left_pr & (new_state == State.ER)] = _EVAC_PATH_CODE[EvacPath.AWAY]
    evac_path[left_pr & (new_state == State.SH)] = _EVAC_PATH_CODE[EvacPath.HOME]
def _update_tir(tir: FloatArray, new_state: IntArray) -> FloatArray:
    """Advance the per-household time-in-ER counter by one step.

    Households whose new state is ER get ``tir + 1``; every other household
    is reset to ``0.0``. A new array is returned and ``tir`` is not modified.
    """
    return np.where(new_state == State.ER, tir + 1.0, 0.0)
def simulate(config: SimulationConfig, rng: Generator) -> SimulationResult:
    """Run the full data-generating process and collect every array it produces.

    All sampling draws from the single ``rng``, in this fixed order:

    1. Synthesize the household population, then the exogenous timeline.
    2. Start every household in UA at hour 0, with ``evac_path`` NONE and
       time-in-ER 0, and sample the hour-0 emissions with zero congestion.
    3. For each hour ``t`` in ``1..config.n_hours``:
       - compute congestion and peer share from the hour ``t-1`` states;
       - compute local risk from the hour-``t`` forecast;
       - sample the latent transitions;
       - update ``evac_path`` and time-in-ER;
       - sample the hour-``t`` emissions, using the congestion computed at
         the start of the step.

    Args:
        config: Resolved simulation configuration.
        rng: Random generator shared by every sampling step.

    Returns:
        A ``SimulationResult`` holding the latent states, the three emission
        panels, the population, the timeline, the final ``evac_path`` codes
        and ``config``.
    """
    n_households = config.n_households
    n_hours = config.n_hours
    pop = synthesize_population(n_households, rng, config.population)
    timeline = build_timeline(n_hours, rng, config.timeline)
    zone_mult = _zone_multiplier(pop, config)

    panel_shape = (n_households, n_hours + 1)
    states = np.zeros(panel_shape, dtype=np.int64)
    departures = np.zeros(panel_shape, dtype=bool)
    displacements = np.zeros(panel_shape, dtype=np.float64)
    communications = np.zeros(panel_shape, dtype=np.int64)
    # evac_path is only ever mutated in place; tir is rebound each step.
    evac_path = np.zeros(n_households, dtype=np.int64)
    tir = np.zeros(n_households, dtype=np.float64)

    def record_emissions(hour: int, latent: IntArray, tir_now: FloatArray, congestion_level: float) -> None:
        # Sample this hour's observed emissions and store them in column ``hour``.
        dep, disp, comm = sample_emissions(
            latent,
            evac_path,
            tir_now,
            pop.destination,
            congestion_t=congestion_level,
            emissions=config.emissions,
            rng=rng,
        )
        departures[:, hour] = dep
        displacements[:, hour] = disp
        communications[:, hour] = comm

    # Hour 0: everyone starts in UA, evac_path NONE, tir 0, no congestion.
    states[:, 0] = int(State.UA)
    record_emissions(0, states[:, 0], tir, 0.0)

    for hour in range(1, n_hours + 1):
        prev = states[:, hour - 1]
        c_t = congestion(prev, config.feedback.n_cap)
        pi_t = peer_share(prev, evac_path)
        rho_t = local_risk_at(timeline.forecast[hour], pop.distance, zone_mult)
        step_inputs = StepInputs(
            rho=rho_t,
            vol=int(timeline.voluntary[hour]),
            mand=int(timeline.mandatory[hour]),
            tau_norm=hour / n_hours,
            pi=pi_t,
            c=c_t,
        )
        new_state = sample_transitions(
            prev_state=prev,
            inputs=step_inputs,
            risk=pop.risk,
            vehicle=pop.vehicle,
            tir=tir,
            params=config.transitions,
            rng=rng,
        )
        _update_evac_path(evac_path, prev, new_state)
        tir = _update_tir(tir, new_state)
        record_emissions(hour, new_state, tir, c_t)
        states[:, hour] = new_state

    return SimulationResult(
        states=states,
        departures=departures,
        displacements=displacements,
        communications=communications,
        population=pop,
        timeline=timeline,
        evac_path=evac_path,
        config=config,
    )