src/iohmm_evac/report/summary.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Bundle-level sanity-check metrics matching ``SimulationResult.summary()``."""

from __future__ import annotations

import numpy as np

from iohmm_evac.report.loader import SimulationBundle

__all__ = ["bundle_summary", "format_summary"]


# Row order used by ``format_summary``; every entry is a key produced by
# ``bundle_summary``.
_METRIC_ORDER: tuple[str, ...] = (
    # Fraction of households in shelter / still en route at fixed hours.
    "share_sheltered_at_t48",
    "share_sheltered_at_landfall",
    "share_failed_evacuation",
    # Landfall shelter share split by the population's ``evac_path``.
    "share_evacuated_away",
    "share_sheltered_in_place",
    # En-route (ER) peak and departure timing.
    "peak_enroute_share",
    "peak_enroute_hour",
    "median_departure_hour",
)


def bundle_summary(bundle: SimulationBundle) -> dict[str, float]:
    """Compute the eight diagnostic metrics from a loaded bundle.

    Mirrors :meth:`iohmm_evac.dgp.simulator.SimulationResult.summary` but
    operates on the long-format DataFrames written to disk.

    Args:
        bundle: Loaded bundle. This function reads ``observations`` (with
            ``household_id``, ``t`` and ``state`` columns), ``population``
            (with ``household_id`` and ``evac_path`` columns),
            ``t_landfall`` and ``n_households``.

    Returns:
        Dict with the keys ``share_sheltered_at_t48``,
        ``share_sheltered_at_landfall``, ``share_failed_evacuation``,
        ``share_evacuated_away``, ``share_sheltered_in_place``,
        ``peak_enroute_share``, ``peak_enroute_hour`` and
        ``median_departure_hour``. State shares are the fraction of the
        observation rows at a given hour whose ``state`` is ``"SH"`` or
        ``"ER"``. If ``observations`` is empty, every share is ``0.0`` and
        both hour metrics are ``nan``.

    Raises:
        KeyError: if ``observations`` is non-empty but has no rows at hour
            ``min(48, t_landfall)`` or at hour ``t_landfall``.
    """
    obs = bundle.observations
    t_total = bundle.t_landfall
    n = bundle.n_households

    if obs.empty:
        # With no observations, the per-hour ``.loc`` lookups below would
        # raise KeyError and ``np.argmax`` would reject an empty array, so
        # the ``n > 0`` guard further down could never be reached. Use the
        # same fallbacks as the rest of this function instead: shares ->
        # 0.0, undefined hours -> nan.
        nan = float("nan")
        return {
            "share_sheltered_at_t48": 0.0,
            "share_sheltered_at_landfall": 0.0,
            "share_failed_evacuation": 0.0,
            "share_evacuated_away": 0.0,
            "share_sheltered_in_place": 0.0,
            "peak_enroute_share": 0.0,
            "peak_enroute_hour": nan,
            "median_departure_hour": nan,
        }

    # One row per hour ``t``: the fraction of that hour's rows in SH / ER.
    state_pivot = (
        obs.assign(is_sh=(obs["state"] == "SH"), is_er=(obs["state"] == "ER"))
        .groupby("t")[["is_sh", "is_er"]]
        .mean()
        .sort_index()
    )

    # Clamp so that horizons shorter than 48h still resolve to a real hour.
    t48 = min(48, t_total)
    share_sh_t48 = float(state_pivot.loc[t48, "is_sh"])
    share_sh_landfall = float(state_pivot.loc[t_total, "is_sh"])
    share_failed = float(state_pivot.loc[t_total, "is_er"])

    # Split the landfall shelter share by each household's evac_path.
    final_state = obs[obs["t"] == t_total][["household_id", "state"]]
    pop = bundle.population[["household_id", "evac_path"]]
    final = final_state.merge(pop, on="household_id", how="left")
    in_sh = final["state"] == "SH"
    if n > 0:
        share_evac_away = float(((in_sh) & (final["evac_path"] == "away")).mean())
        share_in_place = float(((in_sh) & (final["evac_path"] == "home")).mean())
    else:
        share_evac_away = 0.0
        share_in_place = 0.0

    # np.argmax returns the first maximum, i.e. the earliest peak hour.
    er_share = state_pivot["is_er"].to_numpy(dtype=float)
    peak_idx = int(np.argmax(er_share))
    peak_share = float(er_share[peak_idx])
    peak_hour = float(int(state_pivot.index[peak_idx]))

    # "Departure" tracks the latent transition into ER, not the noisy
    # emission flag — see SimulationResult.summary for rationale.
    in_er = obs[obs["state"] == "ER"]
    if not in_er.empty:
        first_er = in_er.groupby("household_id")["t"].min()
        median_dep = float(np.median(first_er.to_numpy(dtype=float)))
    else:
        median_dep = float("nan")

    return {
        "share_sheltered_at_t48": share_sh_t48,
        "share_sheltered_at_landfall": share_sh_landfall,
        "share_failed_evacuation": share_failed,
        "share_evacuated_away": share_evac_away,
        "share_sheltered_in_place": share_in_place,
        "peak_enroute_share": peak_share,
        "peak_enroute_hour": peak_hour,
        "median_departure_hour": median_dep,
    }


def format_summary(metrics: dict[str, float]) -> str:
    """Lay out ``metrics`` as a fixed-width two-column text table.

    The output is a header line, a 41-character dashed rule, then one row per
    name in ``_METRIC_ORDER`` that is present in ``metrics``. Names absent from
    ``metrics`` are skipped, and keys not listed in ``_METRIC_ORDER`` are never
    printed. Each row is the metric name left-justified to 32 columns, followed
    by its value right-justified to 9 columns with four decimals.
    """
    header = "metric                          value"
    separator = "-" * 41
    rows = [
        f"{name:<32}{metrics[name]:>9.4f}"
        for name in _METRIC_ORDER
        if name in metrics
    ]
    return "\n".join([header, separator, *rows])