# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Tests for plot_bootstrap_bands."""
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.collections import PolyCollection
from iohmm_evac.bootstrap.aggregate import compute_bands
from iohmm_evac.bootstrap.shift_sweep import ShiftSweepResult, SweepRow
from iohmm_evac.report.plots import plot_bootstrap_bands
def _tiny_sweep(
    shifts: tuple[int, ...] = (-8, 0, 8),
    n_replicates: int = 5,
    seed: int = 0,
) -> ShiftSweepResult:
    """Build a small, deterministic synthetic shift sweep for plotting tests.

    Each replicate contributes one ``SweepRow`` per shift, in ``shifts`` order.
    All metrics are simple linear functions of ``shift``. Only
    ``failed_evacuation_count`` also gets integer noise, drawn uniformly from
    ``[-3, 3]``, so replicates differ in that metric.

    Args:
        shifts: Shift values to sweep, in row order.
        n_replicates: Number of replicates; ids run ``0 .. n_replicates - 1``.
        seed: Seed for the NumPy generator that produces the noise.

    Returns:
        A ``ShiftSweepResult`` whose ``shifts`` and ``n_replicates`` metadata
        come from the same values used to build the rows, so the metadata
        cannot drift from the data.
    """
    rng = np.random.default_rng(seed)
    shifts = tuple(shifts)
    rows: list[SweepRow] = []
    for rep in range(n_replicates):
        for shift in shifts:
            # Generator.integers excludes the upper bound, so noise is in [-3, 3].
            noise = int(rng.integers(-3, 4))
            rows.append(
                SweepRow(
                    replicate_id=rep,
                    shift=shift,
                    failed_evacuation_count=int(120 - 4 * shift + noise),
                    peak_enroute_share=0.1 + 0.001 * shift,
                    total_delay_hours=10.0 - 0.1 * shift,
                    shelter_overflow_count=int(50 - shift),
                )
            )
    return ShiftSweepResult(rows=tuple(rows), shifts=shifts, n_replicates=n_replicates)
def test_plot_bootstrap_bands_renders() -> None:
    """Plotting onto a supplied Axes returns that Axes and draws filled bands."""
    sweep = _tiny_sweep()
    bands = compute_bands(sweep, percentiles=(5, 25, 50, 75, 95))
    fig, ax = plt.subplots()
    try:
        out = plot_bootstrap_bands(bands, metric="failed_evacuation_count", ax=ax)
        assert out is ax
        fills = [c for c in ax.collections if isinstance(c, PolyCollection)]
        # Expect an outer and an inner shaded band; presumably the 5-95 and
        # 25-75 percentile ranges -- confirm against plot_bootstrap_bands.
        assert len(fills) >= 2
    finally:
        # Close even when an assertion fails, so figures don't pile up
        # across the test session.
        plt.close(fig)
def test_plot_bootstrap_bands_creates_axes_when_none() -> None:
    """With no ``ax`` argument the plotter creates and labels its own Axes."""
    sweep = _tiny_sweep()
    bands = compute_bands(sweep, percentiles=(5, 25, 50, 75, 95))
    try:
        ax = plot_bootstrap_bands(bands, metric="peak_enroute_share")
        assert ax.get_xlabel() != ""
    finally:
        # The figure is created inside plot_bootstrap_bands, so close every
        # open figure -- also when the call or the assertion fails.
        plt.close("all")
def test_plot_bootstrap_bands_unknown_metric() -> None:
    """Rejects a metric absent from the bands with a KeyError."""
    percentile_grid = (5, 25, 50, 75, 95)
    band_result = compute_bands(_tiny_sweep(), percentiles=percentile_grid)
    with pytest.raises(KeyError, match="not in BandResult"):
        plot_bootstrap_bands(band_result, metric="not-a-metric")
def test_plot_bootstrap_bands_missing_percentiles() -> None:
    """Bands computed with only the 50th percentile are rejected (ValueError)."""
    median_only = compute_bands(_tiny_sweep(), percentiles=(50,))
    with pytest.raises(ValueError, match="missing percentiles"):
        plot_bootstrap_bands(median_only, metric="failed_evacuation_count")