tests/test_bootstrap_aggregate.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Aggregate sweep rows into per-shift quantile bands."""

from __future__ import annotations

import numpy as np
import pytest

from iohmm_evac.bootstrap.aggregate import (
    BAND_METRICS,
    BandResult,
    compute_bands,
    metric_matrix,
)
from iohmm_evac.bootstrap.shift_sweep import ShiftSweepResult, SweepRow


def _hand_built_sweep() -> ShiftSweepResult:
    """Build a small deterministic sweep: three replicates crossed with two shifts.

    Rows are emitted replicate-major (replicate 0 at shifts -8 and 0, then
    replicate 1, ...). The failed-evacuation count is ``100 + 10 * replicate``
    at shift 0 and 20 higher at shift -8, i.e. {100, 110, 120} and
    {120, 130, 140} respectively. The remaining metrics depend only on the
    replicate index, not on the shift.
    """
    shifts = (-8, 0)
    n_replicates = 3
    rows = tuple(
        SweepRow(
            replicate_id=rep,
            shift=shift,
            # Only the non-zero shift carries the +20 penalty.
            failed_evacuation_count=100 + rep * 10 + (20 if shift != 0 else 0),
            peak_enroute_share=0.1 + 0.01 * rep,
            total_delay_hours=10.0 + rep,
            shelter_overflow_count=rep * 2,
        )
        for rep in range(n_replicates)
        for shift in shifts
    )
    return ShiftSweepResult(rows=rows, shifts=shifts, n_replicates=n_replicates)


def test_metric_matrix_has_expected_shape() -> None:
    """``metric_matrix`` returns a (3, 2) grid for the hand-built sweep.

    The fixture has three replicates and two shifts, which is presumably
    what the two axes correspond to.
    """
    expected_shape = (3, 2)
    matrix = metric_matrix(_hand_built_sweep(), "failed_evacuation_count")
    assert matrix.shape == expected_shape


def test_metric_matrix_unknown_metric() -> None:
    """An unrecognised metric name makes ``metric_matrix`` raise ``ValueError``."""
    sweep = _hand_built_sweep()
    bogus_metric = "not-a-metric"
    with pytest.raises(ValueError, match="Unknown metric"):
        metric_matrix(sweep, bogus_metric)


def test_compute_bands_quantile_values() -> None:
    """The median band matches values worked out by hand from the fixture."""
    bands = compute_bands(_hand_built_sweep(), percentiles=(50,))
    # (position in the returned array, expected median):
    #   position 1 is shift 0, where the counts are {100, 110, 120};
    #   position 0 is shift -8, where the counts are {120, 130, 140}.
    hand_computed = ((1, 110.0), (0, 130.0))
    for position, expected_median in hand_computed:
        observed = bands.quantile("failed_evacuation_count", 50)[position]
        np.testing.assert_allclose(observed, expected_median)


def test_compute_bands_orders_quantiles() -> None:
    """For every band metric, lower percentiles never exceed higher ones."""
    levels = (5, 25, 50, 75, 95)
    slack = 1e-9  # tolerance for floating-point noise between adjacent bands
    bands = compute_bands(_hand_built_sweep(), percentiles=levels)
    for metric in BAND_METRICS:
        curves = [bands.quantile(metric, level) for level in levels]
        # Compare each band with its immediate successor: p5<=p25, p25<=p50, ...
        for lower, upper in zip(curves, curves[1:]):
            assert (lower <= upper + slack).all()


def test_compute_bands_rejects_empty_sweep() -> None:
    """A sweep with no rows, shifts or replicates is rejected with ``ValueError``."""
    no_data = ShiftSweepResult(n_replicates=0, shifts=(), rows=())
    with pytest.raises(ValueError, match="empty"):
        compute_bands(no_data)


def test_compute_bands_rejects_invalid_percentiles() -> None:
    """Empty and out-of-range percentile tuples each raise ``ValueError``."""
    # (bad percentiles argument, regex the error message must match)
    bad_inputs = (
        ((), "percentiles must not be empty"),
        ((150,), r"in \[0, 100\]"),
    )
    for percentiles, message_pattern in bad_inputs:
        with pytest.raises(ValueError, match=message_pattern):
            compute_bands(_hand_built_sweep(), percentiles=percentiles)


def test_band_result_quantile_lookup_errors() -> None:
    """``BandResult.quantile`` raises ``KeyError`` for bad metric or percentile."""
    bands = compute_bands(_hand_built_sweep(), percentiles=(5, 50, 95))
    # (metric, percentile, regex the KeyError message must match)
    bad_lookups = (
        ("nope", 50, "Unknown metric"),
        # Valid metric, but 99 was not among the requested percentiles.
        ("failed_evacuation_count", 99, "Percentile"),
    )
    for metric, percentile, message_pattern in bad_lookups:
        with pytest.raises(KeyError, match=message_pattern):
            bands.quantile(metric, percentile)


def test_band_result_dataclass_shape() -> None:
    """``compute_bands`` returns a ``BandResult`` with the expected layout.

    Checks the result type, that the requested percentiles are echoed back
    unchanged, and that every stored band is a known metric with a (3, 2)
    array.
    """
    requested = (5, 50, 95)
    # NOTE(review): presumably (n_percentiles, n_shifts); the fixture also has
    # three replicates, so this shape alone cannot tell the two apart - confirm.
    expected_shape = (3, 2)

    result = compute_bands(_hand_built_sweep(), percentiles=requested)

    assert isinstance(result, BandResult)
    assert result.percentiles == requested
    for metric, band_array in result.bands.items():
        assert metric in BAND_METRICS
        assert band_array.shape == expected_shape