# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
from dataclasses import replace
from pathlib import Path
import numpy as np
import pytest
from iohmm_evac.cli import main
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.report.summary import bundle_summary, format_summary
from iohmm_evac.scenarios import build_scenario
@pytest.fixture
def baseline_bundle_path(tmp_path: Path) -> Path:
    """Write a small, seeded baseline simulation to parquet and return its path.

    The baseline scenario is shrunk to 200 households over 120 hours and
    seeded with 0. The random generator is built from that seed.
    """
    small_config = replace(
        build_scenario("baseline"), n_households=200, n_hours=120, seed=0
    )
    parquet_path = tmp_path / "sum.parquet"
    simulated = simulate(small_config, np.random.default_rng(small_config.seed))
    write_results(simulated, parquet_path)
    return parquet_path
def test_bundle_summary_keys(baseline_bundle_path: Path) -> None:
    """``bundle_summary`` should report exactly the expected set of metric names."""
    expected_keys = {
        "share_sheltered_at_t48",
        "share_sheltered_at_landfall",
        "share_failed_evacuation",
        "share_evacuated_away",
        "share_sheltered_in_place",
        "peak_enroute_share",
        "peak_enroute_hour",
        "median_departure_hour",
    }
    summary = bundle_summary(load_bundle(baseline_bundle_path))
    assert set(summary.keys()) == expected_keys
def test_bundle_summary_matches_simulationresult(tmp_path: Path) -> None:
    """Bundle-derived metrics should match the in-memory ``result.summary()`` ones.

    Simulates a small seeded baseline scenario and round-trips it through
    ``write_results`` / ``load_bundle``. Checks that every metric reported by
    ``result.summary()`` is reproduced by ``bundle_summary`` to within an
    absolute tolerance of 1e-9. A value that is NaN on both sides counts as a
    match.
    """
    config = replace(build_scenario("baseline"), n_households=300, n_hours=120, seed=3)
    rng = np.random.default_rng(config.seed)
    result = simulate(config, rng)
    inmem = result.summary()
    out = tmp_path / "match.parquet"
    write_results(result, out)
    derived = bundle_summary(load_bundle(out))

    inmem_keys = set(inmem.keys())
    # Guard against a vacuous pass: with no in-memory metrics the comparison
    # loop below would check nothing.
    assert inmem_keys, "result.summary() returned no metrics"
    # Report all missing metrics at once instead of failing on a bare KeyError.
    missing = inmem_keys - set(derived.keys())
    assert not missing, f"bundle_summary is missing metrics: {sorted(missing)}"
    for k, v in inmem.items():
        # nan_ok: a metric that is undefined (NaN) in both places should compare
        # equal rather than fail spuriously.
        assert derived[k] == pytest.approx(v, abs=1e-9, nan_ok=True), k
def test_format_summary_lists_all_metrics(baseline_bundle_path: Path) -> None:
    """The formatted summary should mention every metric name and the word ``metric``.

    ``metric`` is presumably a column header of the rendered table. TODO confirm.
    """
    summary = bundle_summary(load_bundle(baseline_bundle_path))
    rendered = format_summary(summary)
    absent = [name for name in summary if name not in rendered]
    assert not absent, absent
    assert "metric" in rendered
def test_report_summary_cli_smoke(
    baseline_bundle_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """``report summary`` should return 0 and print metric names to stdout."""
    argv = ["report", "summary", "--input", str(baseline_bundle_path)]
    exit_code = main(argv)
    stdout = capsys.readouterr().out
    assert exit_code == 0
    for expected_name in ("share_sheltered_at_t48", "median_departure_hour"):
        assert expected_name in stdout