src/iohmm_evac/scenarios.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Predefined scenarios that produce :class:`SimulationConfig` instances."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import replace

from iohmm_evac.params import (
    FeedbackParams,
    PopulationParams,
    SimulationConfig,
    TimelineParams,
)

# Public API: the scenario registry plus its two accessor functions.
__all__ = ["SCENARIO_BUILDERS", "build_scenario", "list_scenarios"]


def _baseline() -> SimulationConfig:
    """Build the reference scenario: a default-constructed :class:`SimulationConfig`."""
    config = SimulationConfig()
    return config


def _early_warning() -> SimulationConfig:
    """Return the baseline config with ``voluntary_hour=48`` and ``mandatory_hour=72``.

    Only those two timeline fields are changed. Every other timeline field is
    carried over from the baseline's own timeline rather than being reset to
    the ``TimelineParams`` class defaults. The intent, suggested by the scenario
    name, is earlier-than-baseline orders; confirm against the ``TimelineParams``
    defaults.
    """
    base = _baseline()
    # Overlay onto base.timeline (not a fresh TimelineParams()) so that any
    # non-default timeline settings carried by the baseline survive.
    timeline: TimelineParams = replace(
        base.timeline,
        voluntary_hour=48,
        mandatory_hour=72,
    )
    return replace(base, timeline=timeline)


def _targeted_messaging() -> SimulationConfig:
    """Return the baseline config with ``targeted_zone_multiplier=1.5``.

    Only that population field is changed. All other population settings are
    taken from the baseline's population params rather than being reset to
    the ``PopulationParams`` class defaults.
    """
    base = _baseline()
    # Overlay onto base.population so that other baseline population
    # settings are preserved instead of silently reverting to class defaults.
    population: PopulationParams = replace(
        base.population,
        targeted_zone_multiplier=1.5,
    )
    return replace(base, population=population)


def _contraflow() -> SimulationConfig:
    """Return the baseline config with ``n_cap=2500`` in its feedback params.

    Only ``n_cap`` is changed. Other feedback settings are taken from the
    baseline's feedback params rather than reset to ``FeedbackParams`` class
    defaults. NOTE(review): ``n_cap`` is presumably a throughput/capacity cap
    that contraflow lanes raise; confirm its meaning in ``FeedbackParams``.
    """
    base = _baseline()
    # Overlay onto base.feedback so that other baseline feedback settings
    # are preserved instead of silently reverting to class defaults.
    feedback: FeedbackParams = replace(base.feedback, n_cap=2500)
    return replace(base, feedback=feedback)


# Registry of scenario name -> zero-argument builder returning a fresh
# SimulationConfig. Names are the strings accepted by build_scenario().
SCENARIO_BUILDERS: dict[str, Callable[[], SimulationConfig]] = {
    # Default-constructed SimulationConfig.
    "baseline": _baseline,
    # Baseline with voluntary_hour=48, mandatory_hour=72.
    "early-warning": _early_warning,
    # Baseline with targeted_zone_multiplier=1.5.
    "targeted-messaging": _targeted_messaging,
    # Baseline with feedback n_cap=2500 (presumably a capacity cap; confirm).
    "contraflow": _contraflow,
}


def list_scenarios() -> list[str]:
    """List every registered scenario name, sorted alphabetically.

    Sorting makes the result independent of the registry's insertion order.
    """
    names = list(SCENARIO_BUILDERS)
    names.sort()
    return names


def build_scenario(name: str) -> SimulationConfig:
    """Construct the :class:`SimulationConfig` registered under *name*.

    Raises:
        ValueError: if *name* is not a key of ``SCENARIO_BUILDERS``; the
            message lists the known scenario names.
    """
    builder = SCENARIO_BUILDERS.get(name)
    if builder is None:
        msg = f"Unknown scenario: {name!r}. Known: {list_scenarios()}"
        raise ValueError(msg)
    return builder()