# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.axes import Axes
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.constants import STATE_ORDER
from iohmm_evac.report.loader import SimulationBundle, load_bundle
from iohmm_evac.report.plots import (
plot_cumulative_departures,
plot_emission_summary,
plot_household_trajectories,
plot_state_occupancy,
)
@pytest.fixture(autouse=True)
def _close_figures() -> Iterator[None]:
    """Autouse fixture that closes all pyplot figures once a test has run.

    There is no setup step; the only work happens after the ``yield``.
    """
    # Setup phase is empty: hand control straight to the test.
    yield None
    # Teardown phase: close every figure pyplot is still tracking.
    plt.close("all")
@pytest.fixture
def bundle(tmp_path: Path) -> SimulationBundle:
    """Simulate a small run, write it to disk, and load it back as a bundle.

    The configuration uses 50 households, 24 hours and seed 0, and the random
    generator is seeded from the configuration's ``seed``. Results are written
    to ``fixture.parquet`` inside ``tmp_path`` and re-read with ``load_bundle``.
    """
    parquet_path = tmp_path / "fixture.parquet"
    sim_config = SimulationConfig(n_households=50, n_hours=24, seed=0)
    sim_result = simulate(sim_config, np.random.default_rng(sim_config.seed))
    write_results(sim_result, parquet_path)
    return load_bundle(parquet_path)
def test_plot_state_occupancy_returns_axes(bundle: SimulationBundle) -> None:
    """``plot_state_occupancy`` hands back the very axes it was given."""
    supplied = plt.subplots()[1]
    result = plot_state_occupancy(bundle, ax=supplied)
    assert isinstance(result, Axes)
    assert result is supplied
    # Expect one collection per entry in STATE_ORDER (original note: stackplot
    # creates one PolyCollection per state).
    assert len(supplied.collections) == len(STATE_ORDER)
def test_plot_state_occupancy_creates_axes_when_none(bundle: SimulationBundle) -> None:
    """Calling ``plot_state_occupancy`` with ``ax=None`` still yields an ``Axes``."""
    created = plot_state_occupancy(bundle, ax=None)
    assert isinstance(created, Axes)
def test_plot_cumulative_departures_returns_axes(bundle: SimulationBundle) -> None:
    """``plot_cumulative_departures`` returns an ``Axes`` and draws at least one line."""
    supplied = plt.subplots()[1]
    result = plot_cumulative_departures(bundle, ax=supplied)
    assert isinstance(result, Axes)
    # Original note: one plotted line plus three overlay axvlines. Only the
    # weaker lower bound (at least one line) is actually asserted here.
    assert len(supplied.lines) >= 1
def test_plot_cumulative_departures_default_axes(bundle: SimulationBundle) -> None:
    """Omitting ``ax`` still yields an ``Axes`` from ``plot_cumulative_departures``."""
    result = plot_cumulative_departures(bundle)
    assert isinstance(result, Axes)
def test_plot_household_trajectories(bundle: SimulationBundle) -> None:
    """One ``Axes`` comes back for each requested household id."""
    requested = [0, 5, 12]
    panels = plot_household_trajectories(bundle, household_ids=requested)
    assert len(panels) == len(requested)
    assert all(isinstance(panel, Axes) for panel in panels)
def test_plot_household_trajectories_with_supplied_axes(bundle: SimulationBundle) -> None:
    """Supplying one axes per household returns as many entries as households."""
    requested = [0, 1]
    # A single column with one row per household; the second element of the
    # subplots() result holds the axes, which are passed on as a plain list.
    grid = plt.subplots(nrows=len(requested), ncols=1)[1]
    supplied = list(grid)
    result = plot_household_trajectories(bundle, household_ids=requested, ax=supplied)
    assert len(result) == len(requested)
def test_plot_household_trajectories_single_household(bundle: SimulationBundle) -> None:
    """A single household id produces a result of length one."""
    only_household = [0]
    panels = plot_household_trajectories(bundle, household_ids=only_household)
    assert len(panels) == 1
def test_plot_household_trajectories_too_many(bundle: SimulationBundle) -> None:
    """Requesting ten households raises ``ValueError`` matching "at most"."""
    ten_ids = [*range(10)]
    with pytest.raises(ValueError, match="at most"):
        plot_household_trajectories(bundle, household_ids=ten_ids)
def test_plot_household_trajectories_empty(bundle: SimulationBundle) -> None:
    """An empty id list raises ``ValueError`` matching "at least one"."""
    no_ids: list[int] = []
    with pytest.raises(ValueError, match="at least one"):
        plot_household_trajectories(bundle, household_ids=no_ids)
def test_plot_household_trajectories_unknown_id(bundle: SimulationBundle) -> None:
    """Household id 99999 raises ``ValueError`` matching "not found"."""
    # The ``bundle`` fixture simulates only 50 households, so 99999 is presumably absent.
    missing_ids = [99999]
    with pytest.raises(ValueError, match="not found"):
        plot_household_trajectories(bundle, household_ids=missing_ids)
def test_plot_household_trajectories_axis_count_mismatch(bundle: SimulationBundle) -> None:
    """Two households with one supplied axes raises ``ValueError`` ("one entry per")."""
    lone_axes = plt.subplots()[1]
    requested = [0, 1]
    with pytest.raises(ValueError, match="one entry per"):
        plot_household_trajectories(bundle, household_ids=requested, ax=[lone_axes])
def test_plot_emission_summary_returns_axes(bundle: SimulationBundle) -> None:
    """``plot_emission_summary`` returns an ``Axes`` and draws at least three patches."""
    supplied = plt.subplots()[1]
    result = plot_emission_summary(bundle, ax=supplied)
    assert isinstance(result, Axes)
    # Original note: 3 metrics times at most 5 states gives at most 15 bars,
    # with at least one bar per metric, hence the lower bound of three.
    assert len(supplied.patches) >= 3
def test_plot_emission_summary_default_axes(bundle: SimulationBundle) -> None:
    """Omitting ``ax`` still yields an ``Axes`` from ``plot_emission_summary``."""
    result = plot_emission_summary(bundle)
    assert isinstance(result, Axes)