# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Bundle-level sanity-check metrics matching ``SimulationResult.summary()``."""
from __future__ import annotations
import numpy as np
from iohmm_evac.report.loader import SimulationBundle
# Public names exported by this module.
__all__ = ["bundle_summary", "format_summary"]

# Row order used by ``format_summary`` when rendering the table; the names
# match the keys of the dict returned by ``bundle_summary``.
_METRIC_ORDER: tuple[str, ...] = (
    "share_sheltered_at_t48",
    "share_sheltered_at_landfall",
    "share_failed_evacuation",
    "share_evacuated_away",
    "share_sheltered_in_place",
    "peak_enroute_share",
    "peak_enroute_hour",
    "median_departure_hour",
)
def bundle_summary(bundle: SimulationBundle) -> dict[str, float]:
    """Compute the eight diagnostic metrics from a loaded bundle.

    Mirrors :meth:`iohmm_evac.dgp.simulator.SimulationResult.summary` but
    operates on the long-format DataFrames written to disk.

    The function reads ``bundle.observations`` (columns ``household_id``,
    ``t`` and ``state``), ``bundle.population`` (columns ``household_id`` and
    ``evac_path``), ``bundle.t_landfall`` and ``bundle.n_households``.

    Returns:
        A dict of floats with these keys:

        * ``share_sheltered_at_t48`` -- share of rows in state ``"SH"`` at
          hour ``min(48, t_landfall)``.
        * ``share_sheltered_at_landfall`` -- share in ``"SH"`` at
          ``t_landfall``.
        * ``share_failed_evacuation`` -- share still in ``"ER"`` at
          ``t_landfall``.
        * ``share_evacuated_away`` / ``share_sheltered_in_place`` -- share in
          ``"SH"`` at ``t_landfall`` whose ``evac_path`` is ``"away"`` /
          ``"home"``.
        * ``peak_enroute_share`` / ``peak_enroute_hour`` -- the largest
          per-hour ``"ER"`` share and the first hour at which it occurs.
        * ``median_departure_hour`` -- median over households of the first
          hour spent in ``"ER"``; NaN when no household is ever in ``"ER"``.

        When the observations frame has no rows, every share is ``0.0`` and
        both hour metrics are NaN.

    Raises:
        KeyError: if the observations are non-empty but contain no rows for
            ``min(48, t_landfall)`` or for ``t_landfall``.
    """
    obs = bundle.observations
    t_total = bundle.t_landfall
    n = bundle.n_households

    # An empty observations frame has no per-hour rows to look up, so the
    # ``.loc`` lookups and ``np.argmax`` below would raise. Return the same
    # neutral values the non-empty path uses for "nothing to measure": 0.0
    # for shares, NaN for hours.
    if obs.empty:
        nan = float("nan")
        return {
            "share_sheltered_at_t48": 0.0,
            "share_sheltered_at_landfall": 0.0,
            "share_failed_evacuation": 0.0,
            "share_evacuated_away": 0.0,
            "share_sheltered_in_place": 0.0,
            "peak_enroute_share": 0.0,
            "peak_enroute_hour": nan,
            "median_departure_hour": nan,
        }

    # Build the state masks once; they feed both the per-hour shares and the
    # departure-hour computation further down.
    is_sh = obs["state"] == "SH"
    is_er = obs["state"] == "ER"

    # Per-hour share of rows in SH / ER, indexed by hour ``t``.
    state_pivot = (
        obs.assign(is_sh=is_sh, is_er=is_er)
        .groupby("t")[["is_sh", "is_er"]]
        .mean()
        .sort_index()
    )

    # Clamp the 48-hour checkpoint to landfall for shorter horizons.
    t48 = min(48, t_total)
    share_sh_t48 = float(state_pivot.loc[t48, "is_sh"])
    share_sh_landfall = float(state_pivot.loc[t_total, "is_sh"])
    share_failed = float(state_pivot.loc[t_total, "is_er"])

    # Attach each household's evacuation path to its state at landfall.
    final_state = obs[obs["t"] == t_total][["household_id", "state"]]
    pop = bundle.population[["household_id", "evac_path"]]
    final = final_state.merge(pop, on="household_id", how="left")
    in_sh = final["state"] == "SH"
    if n > 0:
        share_evac_away = float((in_sh & (final["evac_path"] == "away")).mean())
        share_in_place = float((in_sh & (final["evac_path"] == "home")).mean())
    else:
        share_evac_away = 0.0
        share_in_place = 0.0

    # ``np.argmax`` returns the first maximum, so ties resolve to the
    # earliest hour.
    er_share = state_pivot["is_er"].to_numpy(dtype=float)
    peak_idx = int(np.argmax(er_share))
    peak_share = float(er_share[peak_idx])
    peak_hour = float(int(state_pivot.index[peak_idx]))

    # "Departure" tracks the latent transition into ER, not the noisy
    # emission flag — see SimulationResult.summary for rationale.
    in_er = obs[is_er]
    if not in_er.empty:
        first_er = in_er.groupby("household_id")["t"].min()
        median_dep = float(np.median(first_er.to_numpy(dtype=float)))
    else:
        median_dep = float("nan")

    return {
        "share_sheltered_at_t48": share_sh_t48,
        "share_sheltered_at_landfall": share_sh_landfall,
        "share_failed_evacuation": share_failed,
        "share_evacuated_away": share_evac_away,
        "share_sheltered_in_place": share_in_place,
        "peak_enroute_share": peak_share,
        "peak_enroute_hour": peak_hour,
        "median_departure_hour": median_dep,
    }
def format_summary(metrics: dict[str, float]) -> str:
    """Render the metrics dict as a small two-column human-readable table.

    Rows are emitted in ``_METRIC_ORDER``. Keys listed there but absent from
    *metrics* are skipped, and keys in *metrics* that are not listed there
    are ignored. Values are printed with four decimal places.

    Returns:
        The header line, a dashed rule, and one line per metric, joined with
        newlines (no trailing newline).
    """
    # One set of widths drives the header, the rule and the data rows so the
    # three always line up.
    key_width, value_width = 32, 9

    header = f"{'metric':<{key_width}}{'value':>{value_width}}"
    rule = "-" * (key_width + value_width)
    rows = [
        f"{key:<{key_width}}{metrics[key]:>{value_width}.4f}"
        for key in _METRIC_ORDER
        if key in metrics
    ]
    return "\n".join([header, rule, *rows])