tests/test_scenarios.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import numpy as np
import pytest

from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.params import SimulationConfig
from iohmm_evac.scenarios import build_scenario, list_scenarios
from iohmm_evac.types import State


def test_scenarios_listed() -> None:
    """The scenario registry exposes exactly the four known scenario names."""
    registered = frozenset(list_scenarios())
    known = frozenset(
        ("baseline", "early-warning", "targeted-messaging", "contraflow")
    )
    assert registered == known


def test_each_scenario_builds_a_valid_config() -> None:
    """Every registered scenario builds a SimulationConfig with positive sizes."""
    for scenario_name in list_scenarios():
        config = build_scenario(scenario_name)
        assert isinstance(config, SimulationConfig)
        # Both the cohort size and the horizon must be non-degenerate.
        assert config.n_households > 0
        assert config.n_hours > 0


def test_unknown_scenario_raises() -> None:
    """Requesting an unregistered scenario name raises ValueError."""
    bogus_name = "not-a-scenario"
    with pytest.raises(ValueError, match="Unknown scenario"):
        build_scenario(bogus_name)


def _first_departure_times(states: np.ndarray) -> np.ndarray:
    """Return the first hour at which each departing household is in ``State.ER``.

    ``states`` is indexed ``(household, hour)``: the first-hit search runs
    along axis 1. Households that never reach ``State.ER`` are dropped, so the
    result has one entry per household that departs within the horizon.
    """
    on_route = states == int(State.ER)
    ever = on_route.any(axis=1)
    # On a boolean row, argmax gives the index of the first True. Rows with no
    # True would yield a spurious 0, but the ``ever`` mask removes them.
    return np.asarray(on_route.argmax(axis=1)[ever])


def _with_cohort(
    cfg: SimulationConfig, *, n_households: int, n_hours: int, seed: int
) -> SimulationConfig:
    """Copy ``cfg``'s model components into a config with a new size and seed."""
    return SimulationConfig(
        n_households=n_households,
        n_hours=n_hours,
        seed=seed,
        population=cfg.population,
        timeline=cfg.timeline,
        transitions=cfg.transitions,
        emissions=cfg.emissions,
        feedback=cfg.feedback,
    )


def test_early_warning_pulls_departures_earlier() -> None:
    """Early warning lowers the mean first-departure hour relative to baseline.

    Both scenarios are re-run on a reduced cohort for several seeds. Only
    households that actually depart contribute to each per-seed mean.
    """
    base = build_scenario("baseline")
    early = build_scenario("early-warning")
    # Use a smallish cohort and average three seeds to keep the test fast.
    n = 1_000
    t_total = 96  # before landfall hour
    departures_base: list[float] = []
    departures_early: list[float] = []
    for seed in (0, 1, 2):
        b = _with_cohort(base, n_households=n, n_hours=t_total, seed=seed)
        e = _with_cohort(early, n_households=n, n_hours=t_total, seed=seed)
        # The seed goes both into the config and into the RNG handed to
        # simulate(), matching the original test's setup.
        rb = simulate(b, np.random.default_rng(seed))
        re_ = simulate(e, np.random.default_rng(seed))

        t_base = _first_departure_times(rb.states)
        t_early = _first_departure_times(re_.states)
        # With no departures, .mean() would return NaN. NaN comparisons are
        # always False, so the final assert would fail for a misleading reason.
        assert t_base.size > 0, f"no baseline departures for seed {seed}"
        assert t_early.size > 0, f"no early-warning departures for seed {seed}"

        departures_base.append(float(t_base.mean()))
        departures_early.append(float(t_early.mean()))

    assert np.mean(departures_early) < np.mean(departures_base)


def test_contraflow_raises_capacity() -> None:
    """Contraflow sets ``feedback.n_cap`` to 2500, above the baseline capacity."""
    cfg = build_scenario("contraflow")
    baseline = build_scenario("baseline")
    assert cfg.feedback.n_cap == 2500
    # The test name promises a *raise*, so also pin the relationship to the
    # baseline. A baseline change then cannot silently invert the scenario.
    assert cfg.feedback.n_cap > baseline.feedback.n_cap


def test_targeted_messaging_zone_multiplier() -> None:
    """The targeted-messaging scenario sets ``targeted_zone_multiplier`` to 1.5."""
    population = build_scenario("targeted-messaging").population
    assert population.targeted_zone_multiplier == 1.5