# tests/test_sweep_cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path

import pytest

from iohmm_evac.cli import main


def _inherit_env() -> dict[str, str]:
    """Return the whitelisted subset of the current process environment.

    Only ``PATH``, ``HOME``, ``USER``, ``VIRTUAL_ENV`` and ``PYTHONPATH`` are
    copied, and each one only when it is actually set in ``os.environ``.
    Every other variable is left out of the result.
    """
    inherited: dict[str, str] = {}
    for name in ("PATH", "HOME", "USER", "VIRTUAL_ENV", "PYTHONPATH"):
        value = os.environ.get(name)
        # Unset variables are skipped rather than mapped to a placeholder.
        if value is not None:
            inherited[name] = value
    return inherited


def _run_tiny_sweep(out_dir: Path) -> int:
    """Invoke ``sweep run`` in-process with a small, fixed configuration.

    The sweep writes into ``out_dir`` using seed 0, 200 households, 24 hours
    and ``--quiet``. The value returned by ``main`` is passed straight back
    to the caller.
    """
    argv = ["sweep", "run"]
    argv += ["--output-dir", str(out_dir)]
    argv += ["--seed", "0"]
    argv += ["--n-households", "200"]
    argv += ["--n-hours", "24"]
    argv.append("--quiet")
    return main(argv)


def test_sweep_run_writes_directory_layout(tmp_path: Path) -> None:
    """A tiny sweep returns 0 and writes the expected per-scenario files.

    Each of the four scenario directories must contain
    ``observations.parquet`` and ``network_metrics.toml``, and the sweep root
    must contain ``sweep.toml``.
    """
    sweep_root = tmp_path / "sweep"
    assert _run_tiny_sweep(sweep_root) == 0

    scenario_names = ("baseline", "early-warning", "targeted-messaging", "contraflow")
    per_scenario_files = ("observations.parquet", "network_metrics.toml")
    for scenario_name in scenario_names:
        for filename in per_scenario_files:
            assert (sweep_root / scenario_name / filename).exists()

    assert (sweep_root / "sweep.toml").exists()


def test_sweep_summary_prints_table(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """``sweep summary`` returns 0 and prints a table naming every scenario.

    Captured stdout must contain the ``scenario`` and ``delay_hr`` headings
    as well as each of the four scenario names.
    """
    sweep_root = tmp_path / "sweep"
    assert _run_tiny_sweep(sweep_root) == 0

    exit_code = main(["sweep", "summary", "--input-dir", str(sweep_root)])
    assert exit_code == 0

    stdout = capsys.readouterr().out
    # Headings first, then scenario names.
    expected_tokens = (
        "scenario",
        "delay_hr",
        "baseline",
        "early-warning",
        "targeted-messaging",
        "contraflow",
    )
    for token in expected_tokens:
        assert token in stdout


def test_sweep_run_subset(tmp_path: Path) -> None:
    """``--scenarios baseline,contraflow`` writes only the requested scenarios.

    Both selected scenarios must have an ``observations.parquet`` file and
    the ``early-warning`` directory must not exist.
    """
    subset_root = tmp_path / "sweep-subset"
    argv = ["sweep", "run"]
    argv += ["--output-dir", str(subset_root)]
    argv += ["--scenarios", "baseline,contraflow"]
    argv += ["--n-households", "100"]
    argv += ["--n-hours", "12"]
    argv.append("--quiet")

    assert main(argv) == 0

    for selected in ("baseline", "contraflow"):
        assert (subset_root / selected / "observations.parquet").exists()
    assert not (subset_root / "early-warning").exists()


def test_sweep_run_unknown_scenario_rejected(tmp_path: Path) -> None:
    """Requesting a scenario name that does not exist raises ``SystemExit``."""
    argv = [
        "sweep",
        "run",
        "--output-dir",
        str(tmp_path / "bad"),
        "--scenarios",
        "no-such-scenario",
        "--quiet",
    ]
    # Only the CLI call sits inside the context so nothing else can satisfy it.
    with pytest.raises(SystemExit):
        main(argv)


def test_report_sweep_all_writes_pngs(tmp_path: Path) -> None:
    """``report sweep-all`` returns 0 and writes two non-empty PNG files.

    The figures ``sweep_departures.png`` and ``sweep_network.png`` must both
    appear in the directory passed as ``--output-dir``.
    """
    sweep_root = tmp_path / "sweep"
    assert _run_tiny_sweep(sweep_root) == 0

    figures_root = tmp_path / "figures"
    argv = ["report", "sweep-all"]
    argv += ["--input-dir", str(sweep_root)]
    argv += ["--output-dir", str(figures_root)]
    exit_code = main(argv)
    assert exit_code == 0

    for figure_name in ("sweep_departures.png", "sweep_network.png"):
        figure_path = figures_root / figure_name
        assert figure_path.exists()
        assert figure_path.stat().st_size > 0


def test_report_sweep_departures_default_output(tmp_path: Path) -> None:
    """Without an explicit output, the departures figure lands in the input dir.

    ``report sweep-departures`` must return 0 and leave a non-empty
    ``sweep_departures.png`` inside the sweep directory.
    """
    sweep_root = tmp_path / "sweep"
    assert _run_tiny_sweep(sweep_root) == 0

    argv = ["report", "sweep-departures", "--input-dir", str(sweep_root)]
    assert main(argv) == 0

    figure_path = sweep_root / "sweep_departures.png"
    assert figure_path.exists()
    assert figure_path.stat().st_size > 0


def test_report_sweep_network_explicit_output(tmp_path: Path) -> None:
    """``report sweep-network --output`` writes a non-empty file at that path."""
    sweep_root = tmp_path / "sweep"
    assert _run_tiny_sweep(sweep_root) == 0

    figure_path = tmp_path / "net.png"
    argv = ["report", "sweep-network"]
    argv += ["--input-dir", str(sweep_root)]
    argv += ["--output", str(figure_path)]
    exit_code = main(argv)

    assert exit_code == 0
    assert figure_path.exists()
    assert figure_path.stat().st_size > 0


def test_sweep_subprocess_end_to_end(tmp_path: Path) -> None:
    """Run ``sweep run`` then ``report sweep-all`` through ``python -m``.

    Both child processes must exit with return code 0 (stderr is shown on
    failure), and both sweep figures must exist in the figures directory
    afterwards.
    """

    def invoke(*cli_args: str) -> subprocess.CompletedProcess[str]:
        # Each call gets its own environment: MPLBACKEND=Agg plus the
        # whitelisted variables returned by _inherit_env().
        child_env = {"MPLBACKEND": "Agg"}
        child_env.update(_inherit_env())
        command = [sys.executable, "-m", "iohmm_evac.cli", *cli_args]
        return subprocess.run(command, capture_output=True, text=True, env=child_env)

    sweep_root = tmp_path / "sweep-sub"
    sweep_proc = invoke(
        "sweep",
        "run",
        "--output-dir",
        str(sweep_root),
        "--seed",
        "0",
        "--n-households",
        "200",
        "--n-hours",
        "24",
        "--quiet",
    )
    assert sweep_proc.returncode == 0, sweep_proc.stderr

    figures_root = tmp_path / "figures-sub"
    report_proc = invoke(
        "report",
        "sweep-all",
        "--input-dir",
        str(sweep_root),
        "--output-dir",
        str(figures_root),
    )
    assert report_proc.returncode == 0, report_proc.stderr

    for figure_name in ("sweep_departures.png", "sweep_network.png"):
        assert (figures_root / figure_name).exists()