# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Tests for aggregating sweep rows into per-shift quantile bands."""
from __future__ import annotations
import numpy as np
import pytest
from iohmm_evac.bootstrap.aggregate import (
BAND_METRICS,
BandResult,
compute_bands,
metric_matrix,
)
from iohmm_evac.bootstrap.shift_sweep import ShiftSweepResult, SweepRow
def _hand_built_sweep() -> ShiftSweepResult:
    """Build a small deterministic sweep: 3 replicates crossed with shifts (-8, 0).

    For replicate ``rep`` the failed-evacuation count is ``100 + 10 * rep``
    at shift 0 and 20 higher at shift -8. The remaining metrics depend only
    on ``rep``. Rows are ordered replicate-major, shift-minor.
    """
    shift_values = (-8, 0)
    replicate_count = 3
    # Any non-zero shift adds a flat 20 failed evacuations.
    sweep_rows = tuple(
        SweepRow(
            replicate_id=rep,
            shift=shift,
            failed_evacuation_count=100 + rep * 10 + (20 if shift != 0 else 0),
            peak_enroute_share=0.1 + 0.01 * rep,
            total_delay_hours=10.0 + rep,
            shelter_overflow_count=rep * 2,
        )
        for rep in range(replicate_count)
        for shift in shift_values
    )
    return ShiftSweepResult(
        rows=sweep_rows,
        shifts=shift_values,
        n_replicates=replicate_count,
    )
def test_metric_matrix_has_expected_shape() -> None:
    """metric_matrix yields a (3, 2) grid for the hand-built sweep.

    The fixture has 3 replicates and 2 shifts, so the grid is presumably
    laid out as (replicates, shifts) -- confirm against ``metric_matrix``.
    """
    expected_shape = (3, 2)
    matrix = metric_matrix(_hand_built_sweep(), "failed_evacuation_count")
    assert matrix.shape == expected_shape
def test_metric_matrix_unknown_metric() -> None:
    """metric_matrix raises ValueError when given a metric name it does not know.

    The fixture is built before entering ``pytest.raises`` so that only the
    call under test can satisfy the expected exception; an error raised while
    constructing the fixture surfaces as a failure instead of being absorbed.
    """
    sweep = _hand_built_sweep()
    with pytest.raises(ValueError, match="Unknown metric"):
        metric_matrix(sweep, "not-a-metric")
def test_compute_bands_quantile_values() -> None:
    """The 50th-percentile band matches the hand-computed medians per shift."""
    metric = "failed_evacuation_count"
    bands = compute_bands(_hand_built_sweep(), percentiles=(50,))
    # (index into the shift axis, expected median). The fixture's shifts are
    # (-8, 0), so index 1 is shift 0 and index 0 is shift -8.
    expected_medians = (
        (1, 110.0),  # shift 0: replicate counts {100, 110, 120}
        (0, 130.0),  # shift -8: replicate counts {120, 130, 140}
    )
    for shift_index, median in expected_medians:
        np.testing.assert_allclose(bands.quantile(metric, 50)[shift_index], median)
def test_compute_bands_orders_quantiles() -> None:
    """Quantile curves are non-decreasing in percentile for every band metric."""
    levels = (5, 25, 50, 75, 95)
    slack = 1e-9  # tolerance for floating-point noise between adjacent bands
    bands = compute_bands(_hand_built_sweep(), percentiles=levels)
    for metric in BAND_METRICS:
        # Fetch every curve first, then compare each adjacent pair in order.
        curves = [bands.quantile(metric, level) for level in levels]
        for lower, upper in zip(curves, curves[1:]):
            assert (lower <= upper + slack).all()
def test_compute_bands_rejects_empty_sweep() -> None:
    """compute_bands raises ValueError mentioning "empty" for a sweep with no rows."""
    no_rows = ShiftSweepResult(
        rows=(),
        shifts=(),
        n_replicates=0,
    )
    with pytest.raises(ValueError, match="empty"):
        compute_bands(no_rows)
def test_compute_bands_rejects_invalid_percentiles() -> None:
    """compute_bands raises ValueError for empty or out-of-range percentiles.

    Each case gets a fresh fixture built outside the ``pytest.raises`` body,
    so only the ``compute_bands`` call itself can satisfy the expected
    exception.
    """
    invalid_cases = (
        ((), "percentiles must not be empty"),
        ((150,), r"in \[0, 100\]"),
    )
    for bad_percentiles, message_pattern in invalid_cases:
        sweep = _hand_built_sweep()
        with pytest.raises(ValueError, match=message_pattern):
            compute_bands(sweep, percentiles=bad_percentiles)
def test_band_result_quantile_lookup_errors() -> None:
    """BandResult.quantile raises KeyError for an unknown metric or percentile."""
    bands = compute_bands(_hand_built_sweep(), percentiles=(5, 50, 95))
    # (metric, percentile, expected message pattern)
    bad_lookups = (
        ("nope", 50, "Unknown metric"),
        ("failed_evacuation_count", 99, "Percentile"),  # 99 was not requested
    )
    for metric, percentile, message_pattern in bad_lookups:
        with pytest.raises(KeyError, match=message_pattern):
            bands.quantile(metric, percentile)
def test_band_result_dataclass_shape() -> None:
    """compute_bands returns a BandResult with the requested percentiles.

    Every stored band must be keyed by a known metric and have shape (3, 2);
    with 3 percentiles and 2 shifts this is presumably (percentiles, shifts)
    -- confirm against ``compute_bands``.
    """
    requested = (5, 50, 95)
    result = compute_bands(_hand_built_sweep(), percentiles=requested)
    assert isinstance(result, BandResult)
    assert result.percentiles == requested
    expected_shape = (3, 2)
    for metric_name, band_array in result.bands.items():
        assert metric_name in BAND_METRICS
        assert band_array.shape == expected_shape