tests/test_sweep_plots.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

from collections.abc import Iterator

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.axes import Axes

from iohmm_evac.report.plots import plot_sweep_departures, plot_sweep_network
from iohmm_evac.sweep import DEFAULT_SCENARIOS, SweepConfig, SweepResult, run_sweep

# Select the non-interactive Agg backend so the plotting tests need no display.
matplotlib.use("Agg")


@pytest.fixture(autouse=True)
def _close_figures() -> Iterator[None]:
    """Close all matplotlib figures after every test in this module."""
    # No setup is needed; the test body runs while we are suspended here.
    yield None
    # Teardown: drop every open figure so they do not pile up across tests.
    plt.close("all")


@pytest.fixture(scope="module")
def tiny_sweep(tmp_path_factory: pytest.TempPathFactory) -> SweepResult:
    """Run one small sweep and share its result with every test in this module.

    All ``DEFAULT_SCENARIOS`` are run with 200 households, ``n_hours=24`` and
    ``seed=0``. Output goes under a fresh temporary directory.
    """
    sweep_dir = tmp_path_factory.mktemp("tiny_sweep") / "sweep"
    return run_sweep(
        SweepConfig(
            output_dir=sweep_dir,
            scenarios=DEFAULT_SCENARIOS,
            seed=0,
            n_households=200,
            n_hours=24,
        )
    )


def test_plot_sweep_departures_returns_axes(tiny_sweep: SweepResult) -> None:
    """Drawing onto a supplied Axes returns an Axes; that Axes gets one line per scenario plus one."""
    ax = plt.subplots()[1]
    result = plot_sweep_departures(tiny_sweep, ax=ax)
    assert isinstance(result, Axes)
    # Expect one departure curve per scenario, plus the landfall axvline.
    expected_lines = len(tiny_sweep.config.scenarios) + 1
    assert len(ax.lines) == expected_lines


def test_plot_sweep_departures_creates_axes_when_none(tiny_sweep: SweepResult) -> None:
    """Passing ``ax=None`` still yields an Axes instance."""
    created = plot_sweep_departures(tiny_sweep, ax=None)
    assert isinstance(created, Axes)


def test_plot_sweep_network_2x2_panel(tiny_sweep: SweepResult) -> None:
    """Drawing onto a 2x2 grid yields a 2x2 result and one bar per scenario in each panel."""
    grid = plt.subplots(2, 2)[1]
    result = plot_sweep_network(tiny_sweep, ax=grid)
    assert result.shape == (2, 2)
    bars_expected = len(tiny_sweep.config.scenarios)
    # ``.flat`` walks the grid row by row, i.e. (0,0), (0,1), (1,0), (1,1).
    for panel in grid.flat:
        assert len(panel.patches) == bars_expected


def test_plot_sweep_network_creates_figure_when_none(tiny_sweep: SweepResult) -> None:
    """Without an ``ax`` argument the helper still hands back a 2x2 array of panels."""
    panels = plot_sweep_network(tiny_sweep)
    assert panels.shape == (2, 2)


def test_plot_sweep_network_rejects_wrong_shape(tiny_sweep: SweepResult) -> None:
    """A 1x4 row of axes is refused with a ValueError whose message mentions "2x2"."""
    row_axes = plt.subplots(1, 4)[1]
    with pytest.raises(ValueError, match="2x2"):
        plot_sweep_network(tiny_sweep, ax=row_axes)


def test_plot_sweep_departures_legend_includes_warning_hours(tiny_sweep: SweepResult) -> None:
    """Every scenario's legend label mentions its voluntary and mandatory order hours.

    Labels are expected to start with the scenario name and contain ``vol=``
    and ``mand=``.
    """
    _, ax = plt.subplots()
    plot_sweep_departures(tiny_sweep, ax=ax)
    legend = ax.get_legend()
    assert legend is not None
    legend_texts = [t.get_text() for t in legend.get_texts()]
    scenarios = list(tiny_sweep.config.scenarios)
    # Attribute each label to the *longest* scenario name it starts with.
    # A plain ``startswith`` search lets a scenario whose name is a prefix of
    # another's (e.g. "x" vs "x_late") borrow that other scenario's label.
    # That would mask a missing entry of its own.
    label_for: dict[str, str] = {}
    for text in legend_texts:
        owners = [s for s in scenarios if text.startswith(s)]
        if owners:
            # Keep the first label per scenario, as the original lookup did.
            label_for.setdefault(max(owners, key=len), text)
    for scenario in scenarios:
        match = label_for.get(scenario)
        assert match is not None, f"missing legend label for {scenario}"
        assert "vol=" in match
        assert "mand=" in match


def test_plot_sweep_network_uses_distinct_colors_per_scenario(tiny_sweep: SweepResult) -> None:
    """Every panel draws exactly one bar per scenario, each with its own face colour."""
    _, axes = plt.subplots(2, 2)
    plot_sweep_network(tiny_sweep, ax=axes)
    n_scenarios = len(tiny_sweep.config.scenarios)
    # Check all four panels, not just the top-left one. Also pin the bar count:
    # otherwise 2n bars drawn in n colours would still pass the set-size check.
    for panel in axes.flat:
        colors = [tuple(p.get_facecolor()) for p in panel.patches]
        assert len(colors) == n_scenarios, panel.get_title()
        assert len(set(colors)) == n_scenarios, panel.get_title()


def test_plot_sweep_network_accepts_ndarray(tiny_sweep: SweepResult) -> None:
    """A caller-assembled object ndarray of Axes is accepted as the 2x2 grid."""
    _, axes = plt.subplots(2, 2)
    # ``np.asarray(axes, dtype=object)`` would return ``axes`` itself, because
    # it is already an object ndarray. The test would then duplicate the 2x2
    # panel test. Build a distinct array holding the same Axes instead.
    arr = np.empty((2, 2), dtype=object)
    arr[:, :] = axes
    assert arr is not axes
    returned = plot_sweep_network(tiny_sweep, ax=arr)
    assert returned.shape == (2, 2)
    # The helper must actually have drawn onto the Axes held by ``arr``.
    n_scenarios = len(tiny_sweep.config.scenarios)
    for panel in arr.flat:
        assert len(panel.patches) == n_scenarios


def test_plot_sweep_network_panel_titles(tiny_sweep: SweepResult) -> None:
    """Each of the four panels carries its fixed metric title."""
    grid = plt.subplots(2, 2)[1]
    plot_sweep_network(tiny_sweep, ax=grid)
    # Keyed by (row, col). Insertion order keeps the row-major check order.
    titles = {
        (0, 0): "Total delay (hours)",
        (0, 1): "Peak EnRoute share",
        (1, 0): "Shelter overflow (count)",
        (1, 1): "Failed evacuations (count)",
    }
    for (row, col), title in titles.items():
        assert grid[row, col].get_title() == title


def test_plot_sweep_network_peak_panel_labels_include_hour(tiny_sweep: SweepResult) -> None:
    """The peak-share panel has one text annotation per scenario, each containing ``@ t=``."""
    grid = plt.subplots(2, 2)[1]
    plot_sweep_network(tiny_sweep, ax=grid)
    labels = [artist.get_text() for artist in grid[0, 1].texts]
    assert len(labels) == len(tiny_sweep.config.scenarios)
    for label in labels:
        assert "@ t=" in label, label