src/iohmm_evac/sweep_cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac sweep`` subcommands: run, summary."""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

from iohmm_evac.scenarios import list_scenarios
from iohmm_evac.sweep import (
    DEFAULT_SCENARIOS,
    SweepConfig,
    SweepResult,
    load_sweep,
    run_sweep,
)

__all__ = [
    "add_sweep_subparser",
    "format_summary_table",
    "run_sweep_command",
]


def _parse_scenarios(value: str) -> tuple[str, ...]:
    """Parse a comma-separated scenario list, validating against the registry."""
    if not value:
        msg = "--scenarios must not be empty"
        raise argparse.ArgumentTypeError(msg)
    raw = [p.strip() for p in value.split(",")]
    items = [p for p in raw if p]
    if not items:
        msg = "--scenarios must list at least one scenario"
        raise argparse.ArgumentTypeError(msg)
    known = set(list_scenarios())
    unknown = [p for p in items if p not in known]
    if unknown:
        msg = f"Unknown scenario(s): {unknown}. Known: {sorted(known)}"
        raise argparse.ArgumentTypeError(msg)
    return tuple(items)


def add_sweep_subparser(
    subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Attach the ``sweep`` command and its ``run``/``summary`` actions to *subparsers*.

    ``sweep run`` takes the scenario selection, seed, population size, horizon
    and an output directory. ``sweep summary`` takes the directory of a
    previous run. The chosen action is stored on the namespace as ``action``.
    """
    sweep_parser = subparsers.add_parser(
        "sweep", help="Run all scenarios and compute network metrics."
    )
    sweep_actions = sweep_parser.add_subparsers(dest="action", required=True)

    # ---- sweep run ----------------------------------------------------------
    run_parser = sweep_actions.add_parser("run", help="Run every scenario and write outputs.")
    scenarios_help = (
        "Comma-separated scenario names to run (default: all four). "
        "Use this to run a subset for faster iteration."
    )
    run_parser.add_argument(
        "--scenarios", type=_parse_scenarios, default=DEFAULT_SCENARIOS, help=scenarios_help
    )
    # Integer knobs, registered in the order they should appear in --help.
    int_options = (
        ("--seed", 0, "RNG seed (default 0)."),
        ("--n-households", 10_000, "Households per scenario (default 10000)."),
        ("--n-hours", 120, "Simulation horizon in hours (default 120)."),
    )
    for flag, default_value, help_text in int_options:
        run_parser.add_argument(flag, type=int, default=default_value, help=help_text)
    run_parser.add_argument(
        "--output-dir", type=Path, required=True, help="Directory to write the sweep into."
    )
    run_parser.add_argument("--quiet", action="store_true", help="Suppress non-error output.")

    # ---- sweep summary ------------------------------------------------------
    summary_parser = sweep_actions.add_parser(
        "summary", help="Print a per-scenario metric table."
    )
    summary_parser.add_argument(
        "--input-dir", type=Path, required=True, help="Sweep directory produced by sweep run."
    )


_TABLE_HEADER: tuple[str, ...] = (
    "scenario",
    "delay_hr",
    "peak_er_share",
    "peak_er_hour",
    "overflow",
    "failed_evac",
)


def format_summary_table(result: SweepResult) -> str:
    """Render a human-readable scenario × metric table."""
    rows: list[tuple[str, ...]] = [_TABLE_HEADER]
    for scenario in result.config.scenarios:
        m = result.network_metrics[scenario]
        rows.append(
            (
                scenario,
                f"{m.total_delay_hours:.1f}",
                f"{m.peak_enroute_share:.3f}",
                f"{m.peak_enroute_hour}",
                f"{m.shelter_overflow_count}",
                f"{m.failed_evacuation_count}",
            )
        )
    widths = [max(len(r[i]) for r in rows) for i in range(len(_TABLE_HEADER))]
    pad = 3
    lines: list[str] = []
    for r in rows:
        line = "".join(cell.ljust(width + pad) for cell, width in zip(r, widths, strict=True))
        lines.append(line.rstrip())
    return "\n".join(lines)


def _sweep_run(args: argparse.Namespace) -> int:
    """Handle ``sweep run``: build a config, execute the sweep, report bundle paths."""
    config = SweepConfig(
        output_dir=args.output_dir,
        scenarios=tuple(args.scenarios),
        seed=args.seed,
        n_households=args.n_households,
        n_hours=args.n_hours,
    )
    verbose = not args.quiet
    if verbose:
        banner = (
            f"Running sweep: scenarios={','.join(config.scenarios)}, "
            f"seed={config.seed}, N={config.n_households}, T={config.n_hours}"
        )
        print(banner, file=sys.stderr)
    result = run_sweep(config)
    if verbose:
        # Progress/paths go to stderr so stdout stays clean for piping.
        for name, bundle_path in result.bundles.items():
            print(f"{name}: {bundle_path}", file=sys.stderr)
    return 0


def _sweep_summary(args: argparse.Namespace) -> int:
    """Handle ``sweep summary``: load a sweep directory and print its metric table."""
    table = format_summary_table(load_sweep(args.input_dir))
    sys.stdout.write(f"{table}\n")
    return 0


def run_sweep_command(args: argparse.Namespace) -> int:
    """Dispatch ``iohmm-evac sweep <action>`` to the matching handler.

    Returns the handler's exit code (``0`` on success).

    Raises:
        ValueError: If ``args.action`` is neither ``"run"`` nor ``"summary"``.
    """
    handlers = {"run": _sweep_run, "summary": _sweep_summary}
    handler = handlers.get(args.action)
    if handler is None:
        msg = f"Unknown sweep action: {args.action!r}"
        raise ValueError(msg)
    return handler(args)