# tests/test_sweep.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
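"""Tests for iohmm_evac.sweep: output directory layout, the sweep.toml
marker, load_sweep round-tripping, cross-scenario metric differences, and
seeded bit-for-bit reproducibility."""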
from __future__ import annotations

import filecmp
import tomllib
from pathlib import Path

import pytest

from iohmm_evac.scenarios import list_scenarios
from iohmm_evac.sweep import (
    DEFAULT_SCENARIOS,
    SweepConfig,
    load_sweep,
    run_sweep,
)


@pytest.fixture(scope="module")
def small_sweep(tmp_path_factory: pytest.TempPathFactory) -> Path:
    # N=2000 keeps the sweep under 2s while making the peak-hour estimate
    # stable enough for the early-warning vs. baseline ordering check.
    # Module-scoped: every test in this file consumes the same on-disk
    # sweep, since none mutate it.
    out = tmp_path_factory.mktemp("small_sweep")
    config = SweepConfig(
        output_dir=out / "sweep",
        scenarios=DEFAULT_SCENARIOS,
        seed=0,
        n_households=2000,
        n_hours=120,
    )
    run_sweep(config)
    return config.output_dir


def test_run_sweep_writes_expected_directory_structure(small_sweep: Path) -> None:
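    # On-disk layout checked below (a sketch assembled from these
    # assertions, not an exhaustive contract for run_sweep):
    #   <output_dir>/
    #     sweep.toml
    #     <scenario>/
    #       observations.parquet
    #       observations.population.parquet
    #       observations.timeline.parquet
    #       observations.config.toml
    #       network_metrics.toml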
    assert (small_sweep / "sweep.toml").exists()
    for scenario in DEFAULT_SCENARIOS:
        scenario_dir = small_sweep / scenario
        assert scenario_dir.is_dir()
        for name in (
            "observations.parquet",
            "observations.population.parquet",
            "observations.timeline.parquet",
            "observations.config.toml",
            "network_metrics.toml",
        ):
            target = scenario_dir / name
            assert target.exists(), f"missing {target}"
            assert target.stat().st_size > 0


def test_load_sweep_round_trips_run_sweep(small_sweep: Path) -> None:
    loaded = load_sweep(small_sweep)
    assert tuple(loaded.config.scenarios) == DEFAULT_SCENARIOS
    assert loaded.config.seed == 0
    assert set(loaded.bundles.keys()) == set(DEFAULT_SCENARIOS)
    assert set(loaded.network_metrics.keys()) == set(DEFAULT_SCENARIOS)
    for scenario in DEFAULT_SCENARIOS:
        assert loaded.bundles[scenario].exists()
        m = loaded.network_metrics[scenario]
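        # Each per-hour series carries n_hours + 1 entries, presumably one
        # per hour including hour 0 (inferred from these shape checks, not
        # from the sweep API itself).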
        assert m.delay_per_hour.shape == (loaded.config.n_hours + 1,)
        assert m.enroute_count_per_hour.shape == (loaded.config.n_hours + 1,)
        assert m.arrivals_away_per_hour.shape == (loaded.config.n_hours + 1,)


def test_top_level_marker_contents(small_sweep: Path) -> None:
    with (small_sweep / "sweep.toml").open("rb") as f:
        data = tomllib.load(f)
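    # Illustrative shape of the parsed marker (a sketch; only the keys
    # asserted below are treated as part of the contract):
    #   {"version": ..., "seed": 0, "n_hours": 120, "n_households": 2000,
    #    "scenarios": ["baseline", ...]}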
    assert data["seed"] == 0
    assert data["n_hours"] == 120
    assert data["n_households"] == 2000
    assert tuple(data["scenarios"]) == DEFAULT_SCENARIOS
    assert "version" in data


def test_scenarios_produce_distinct_metrics(small_sweep: Path) -> None:
    loaded = load_sweep(small_sweep)
    metrics_by_scenario = loaded.network_metrics
    # At least one numerical metric must differ between baseline and another
    # scenario — otherwise the scenarios are accidentally identical.
    baseline = metrics_by_scenario["baseline"]
    diffs: dict[str, float] = {}
    for name in DEFAULT_SCENARIOS:
        if name == "baseline":
            continue
        other = metrics_by_scenario[name]
        diffs[name] = (
            abs(other.total_delay_hours - baseline.total_delay_hours)
            + abs(other.failed_evacuation_count - baseline.failed_evacuation_count)
        )
    assert any(d > 0 for d in diffs.values()), f"non-baseline scenarios matched baseline metrics exactly: {diffs}"


def test_early_warning_shifts_peak_enroute_earlier(small_sweep: Path) -> None:
    loaded = load_sweep(small_sweep)
    baseline = loaded.network_metrics["baseline"]
    early = loaded.network_metrics["early-warning"]
    # Earlier orders should pull the road-network peak earlier. Require a
    # margin of more than 5 hours, well below the expected shift, so the
    # check stays robust to stochastic variation at the small N used by
    # the test fixture.
    assert early.peak_enroute_hour < baseline.peak_enroute_hour - 5, (
        f"early-warning peak ({early.peak_enroute_hour}) should be "
        f">5h earlier than baseline ({baseline.peak_enroute_hour})"
    )


def test_same_seed_reproduces_baseline_bit_for_bit(tmp_path: Path) -> None:
    cfg_a = SweepConfig(
        output_dir=tmp_path / "a",
        scenarios=("baseline",),
        seed=7,
        n_households=200,
        n_hours=24,
    )
    cfg_b = SweepConfig(
        output_dir=tmp_path / "b",
        scenarios=("baseline",),
        seed=7,
        n_households=200,
        n_hours=24,
    )
    run_sweep(cfg_a)
    run_sweep(cfg_b)
    a = cfg_a.output_dir / "baseline" / "observations.parquet"
    b = cfg_b.output_dir / "baseline" / "observations.parquet"
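    # shallow=False forces a byte-for-byte content comparison; the default
    # (shallow=True) would only compare os.stat() signatures.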
    assert filecmp.cmp(a, b, shallow=False)


def test_run_sweep_subset_of_scenarios(tmp_path: Path) -> None:
    config = SweepConfig(
        output_dir=tmp_path / "subset",
        scenarios=("baseline", "contraflow"),
        seed=0,
        n_households=200,
        n_hours=24,
    )
    result = run_sweep(config)
    assert set(result.bundles.keys()) == {"baseline", "contraflow"}
    assert set(result.network_metrics.keys()) == {"baseline", "contraflow"}
    assert not (config.output_dir / "early-warning").exists()


def test_unknown_scenario_raises(tmp_path: Path) -> None:
    config = SweepConfig(
        output_dir=tmp_path / "bad",
        scenarios=("baseline", "does-not-exist"),
        seed=0,
        n_households=10,
        n_hours=2,
    )
    with pytest.raises(ValueError, match="Unknown scenario"):
        run_sweep(config)


def test_load_sweep_missing_marker(tmp_path: Path) -> None:
    with pytest.raises(FileNotFoundError, match=r"sweep\.toml"):
        load_sweep(tmp_path / "no-such")


def test_default_scenarios_match_registry() -> None:
    assert set(DEFAULT_SCENARIOS) == set(list_scenarios())