tests/test_baseline_shape.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Soft-shape regression test for the baseline scenario.

Catches regressions of the kind that appeared after Build 1.5: under the
default parameters the population should *not* mostly evacuate before any
warning order has fired.
"""

from __future__ import annotations

from dataclasses import replace

import numpy as np
import pytest

from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.scenarios import build_scenario


@pytest.fixture(scope="module")
def baseline_summary() -> dict[str, float]:
    """Simulate the baseline scenario once per test module and return its summary.

    The stock "baseline" scenario is narrowed to 2000 households over 120
    hours with seed 0. The NumPy generator is seeded from that config's
    ``seed``, so every test in the module shares one simulated run.
    """
    scenario = replace(
        build_scenario("baseline"),
        n_households=2000,
        n_hours=120,
        seed=0,
    )
    generator = np.random.default_rng(scenario.seed)
    return simulate(scenario, generator).summary()


def test_baseline_shape(baseline_summary: dict[str, float]) -> None:
    """Soft-shape constraints from the baseline calibration table.

    Values are stamped into the assertion message verbatim so the chapter
    author can tell which metric drifted out of band on a regression.
    """
    s = baseline_summary
    failures: list[str] = []
    if not (s["share_sheltered_at_t48"] < 0.05):
        failures.append(f"share_sheltered_at_t48={s['share_sheltered_at_t48']:.4f} not < 0.05")
    if not (s["share_sheltered_at_landfall"] > 0.40):
        failures.append(
            f"share_sheltered_at_landfall={s['share_sheltered_at_landfall']:.4f} not > 0.40"
        )
    if not (0.10 < s["peak_enroute_share"] < 0.50):
        failures.append(f"peak_enroute_share={s['peak_enroute_share']:.4f} outside (0.10, 0.50)")
    if not (60 < s["peak_enroute_hour"] < 105):
        failures.append(f"peak_enroute_hour={s['peak_enroute_hour']:.0f} outside (60, 105)")
    if not (60 < s["median_departure_hour"] < 100):
        failures.append(f"median_departure_hour={s['median_departure_hour']:.0f} outside (60, 100)")
    if not (s["share_failed_evacuation"] < 0.03):
        failures.append(f"share_failed_evacuation={s['share_failed_evacuation']:.4f} not < 0.03")
    assert not failures, "Baseline shape regressions:\n  " + "\n  ".join(failures)


def test_summary_contains_all_eight_metrics(baseline_summary: dict[str, float]) -> None:
    expected = {
        "share_sheltered_at_t48",
        "share_sheltered_at_landfall",
        "share_failed_evacuation",
        "share_evacuated_away",
        "share_sheltered_in_place",
        "peak_enroute_share",
        "peak_enroute_hour",
        "median_departure_hour",
    }
    assert set(baseline_summary.keys()) == expected