# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
from collections.abc import Iterator
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.axes import Axes
from iohmm_evac.report.plots import plot_sweep_departures, plot_sweep_network
from iohmm_evac.sweep import DEFAULT_SCENARIOS, SweepConfig, SweepResult, run_sweep
# Select the non-interactive Agg backend so figures render without a display
# (e.g. on headless CI machines).
# NOTE(review): this runs after pyplot has been imported. Matplotlib permits
# switching to a non-interactive backend at that point, but calling it before
# the pyplot import is the more conventional placement.
matplotlib.use("Agg")
@pytest.fixture(autouse=True)
def _close_figures() -> Iterator[None]:
    """Close every open pyplot figure once each test in this module finishes.

    The fixture is autouse, so it wraps every test. This stops figures created
    by one test from accumulating across the run.
    """
    yield None
    # Teardown: runs after the test body has completed.
    plt.close(fig="all")
@pytest.fixture(scope="module")
def tiny_sweep(tmp_path_factory: pytest.TempPathFactory) -> SweepResult:
    """Run a small, seeded sweep once per module and share the result.

    The sweep uses the default scenario set, 200 households, a 24-hour horizon
    and seed 0. Its output directory lives under a pytest-managed temporary
    directory.
    """
    sweep_root = tmp_path_factory.mktemp("tiny_sweep") / "sweep"
    return run_sweep(
        SweepConfig(
            output_dir=sweep_root,
            scenarios=DEFAULT_SCENARIOS,
            seed=0,
            n_households=200,
            n_hours=24,
        )
    )
def test_plot_sweep_departures_returns_axes(tiny_sweep: SweepResult) -> None:
    """Drawing onto a supplied Axes returns an Axes with one line per scenario."""
    _fig, target = plt.subplots()
    result = plot_sweep_departures(tiny_sweep, ax=target)
    assert isinstance(result, Axes)
    # Expect one departure curve per scenario, plus the single landfall axvline.
    expected_line_count = len(tiny_sweep.config.scenarios) + 1
    assert len(target.lines) == expected_line_count
def test_plot_sweep_departures_creates_axes_when_none(tiny_sweep: SweepResult) -> None:
    """Passing ax=None still yields an Axes instance from the helper."""
    assert isinstance(plot_sweep_departures(tiny_sweep, ax=None), Axes)
def test_plot_sweep_network_2x2_panel(tiny_sweep: SweepResult) -> None:
    """A supplied 2x2 grid is returned intact, with one bar per scenario per panel."""
    _, grid = plt.subplots(2, 2)
    returned = plot_sweep_network(tiny_sweep, ax=grid)
    assert returned.shape == (2, 2)
    expected_bars = len(tiny_sweep.config.scenarios)
    # .flat walks the grid row-major: (0,0), (0,1), (1,0), (1,1).
    for panel in grid.flat:
        assert len(panel.patches) == expected_bars
def test_plot_sweep_network_creates_figure_when_none(tiny_sweep: SweepResult) -> None:
    """Calling without an ax argument still returns a 2x2 grid of Axes."""
    grid = plot_sweep_network(tiny_sweep)
    assert grid.shape == (2, 2)
def test_plot_sweep_network_rejects_wrong_shape(tiny_sweep: SweepResult) -> None:
    """A 1x4 row of Axes raises ValueError whose message mentions "2x2"."""
    row_of_axes = plt.subplots(1, 4)[1]
    with pytest.raises(ValueError, match="2x2"):
        plot_sweep_network(tiny_sweep, ax=row_of_axes)
def test_plot_sweep_departures_legend_includes_warning_hours(tiny_sweep: SweepResult) -> None:
    """Each scenario's legend label must mention voluntary and mandatory order hours.

    A legend label is attributed to the *longest* scenario name it starts with.
    This matters when one scenario's name is a prefix of another's (e.g. "x" and
    "x_late"): the longer scenario's label cannot then stand in for a missing
    label of the shorter one.
    """
    _, ax = plt.subplots()
    plot_sweep_departures(tiny_sweep, ax=ax)
    legend = ax.get_legend()
    assert legend is not None
    legend_texts = [t.get_text() for t in legend.get_texts()]
    scenarios = list(tiny_sweep.config.scenarios)

    def _owner(label: str) -> str | None:
        # Longest matching prefix wins, so prefix-sharing scenario names don't collide.
        prefixes = [name for name in scenarios if label.startswith(name)]
        return max(prefixes, key=len) if prefixes else None

    for scenario in scenarios:
        match = next((s for s in legend_texts if _owner(s) == scenario), None)
        assert match is not None, f"missing legend label for {scenario}"
        assert "vol=" in match
        assert "mand=" in match
def test_plot_sweep_network_uses_distinct_colors_per_scenario(tiny_sweep: SweepResult) -> None:
    """Every panel draws one bar per scenario, and each bar has a distinct face colour."""
    _, axes = plt.subplots(2, 2)
    plot_sweep_network(tiny_sweep, ax=axes)
    n_scenarios = len(tiny_sweep.config.scenarios)
    # Check all four panels, not just the top-left one.
    for panel in axes.flat:
        colors = [tuple(p.get_facecolor()) for p in panel.patches]
        # Pin the bar count as well: a set-size check alone would still pass
        # if extra bars repeated existing colours.
        assert len(colors) == n_scenarios, panel.get_title()
        assert len(set(colors)) == n_scenarios, panel.get_title()
def test_plot_sweep_network_accepts_ndarray(tiny_sweep: SweepResult) -> None:
    """An explicit object-dtype numpy array of Axes is accepted as the ax argument.

    NOTE(review): plt.subplots(2, 2) presumably already returns an object
    ndarray, so np.asarray here is likely a no-op. Confirm whether this test
    adds coverage beyond test_plot_sweep_network_2x2_panel.
    """
    axes_grid = np.asarray(plt.subplots(2, 2)[1], dtype=object)
    assert plot_sweep_network(tiny_sweep, ax=axes_grid).shape == (2, 2)
def test_plot_sweep_network_panel_titles(tiny_sweep: SweepResult) -> None:
    """The four network panels carry their metric titles in row-major order."""
    _, grid = plt.subplots(2, 2)
    plot_sweep_network(tiny_sweep, ax=grid)
    # Row-major order: top-left, top-right, bottom-left, bottom-right.
    expected_titles = [
        "Total delay (hours)",
        "Peak EnRoute share",
        "Shelter overflow (count)",
        "Failed evacuations (count)",
    ]
    actual_titles = [panel.get_title() for panel in grid.flat]
    assert actual_titles == expected_titles
def test_plot_sweep_network_peak_panel_labels_include_hour(tiny_sweep: SweepResult) -> None:
    """The top-right (peak EnRoute) panel has one text label per scenario.

    Each label must contain the "@ t=" hour marker.
    """
    _, grid = plt.subplots(2, 2)
    plot_sweep_network(tiny_sweep, ax=grid)
    annotations = [artist.get_text() for artist in grid[0, 1].texts]
    assert len(annotations) == len(tiny_sweep.config.scenarios)
    assert all("@ t=" in label for label in annotations)