# tests/test_shift_sweep.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Shift-sweep tests: simulate-with-iohmm-transitions + per-cell deterministic seeds."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from iohmm_evac.bootstrap.runner import BootstrapFit
from iohmm_evac.bootstrap.shift_sweep import (
    load_sweep_result,
    run_shift_sweep,
    shift_timeline,
    simulate_with_iohmm_transitions,
    write_sweep_result,
)
from iohmm_evac.inference.fit_params import dgp_truth_to_fit_init
from iohmm_evac.params import SimulationConfig
from iohmm_evac.scenarios import build_scenario


def _make_fits(n_replicates: int) -> list[BootstrapFit]:
    """Build ``n_replicates`` bootstrap fits that all share one parameter set.

    Each fit carries the DGP-truth-projected parameters (used as θ̂) of a
    200-household, 24-hour ``SimulationConfig``. Only ``replicate_id`` varies
    between fits; the likelihood, iteration count and convergence flag are
    fixed values.
    """
    config = SimulationConfig(n_households=200, n_hours=24)
    theta_hat = dgp_truth_to_fit_init(config.transitions, config.emissions, config.population)
    return [
        BootstrapFit(
            replicate_id=replicate_id,
            params=theta_hat,
            final_log_likelihood=-1.0,
            iterations=1,
            converged=True,
            # A fresh 0..n_households-1 index array for every replicate.
            indices=np.arange(config.n_households, dtype=np.int64),
        )
        for replicate_id in range(n_replicates)
    ]


def test_shift_timeline_offsets_voluntary_and_mandatory() -> None:
    """Shifting the default timeline by δ moves both order hours by δ."""
    delta = 10
    original = SimulationConfig().timeline
    moved = shift_timeline(original, delta)
    for hour_field in ("voluntary_hour", "mandatory_hour"):
        assert getattr(moved, hour_field) == getattr(original, hour_field) + delta


def test_simulate_with_iohmm_transitions_runs() -> None:
    """Smoke-test the simulator: output shapes, and state 4 (SH) is absorbing."""
    n_households, n_hours = 200, 24
    sh_code = 4  # State.SH
    config = SimulationConfig(n_households=n_households, n_hours=n_hours, seed=0)
    theta = dgp_truth_to_fit_init(config.transitions, config.emissions, config.population)
    generator = np.random.default_rng(0)
    sim = simulate_with_iohmm_transitions(config, theta, generator)

    # Per-step arrays have one more column than the number of simulated hours.
    per_step_shape = (n_households, n_hours + 1)
    assert sim.states.shape == per_step_shape
    assert sim.displacements.shape == per_step_shape
    assert sim.evac_path.shape == (n_households,)

    # SH must be absorbing: the running "has been in SH so far" flag along
    # each row has to coincide with "is in SH now" at every column.
    in_sh = sim.states == sh_code
    reached_sh = np.logical_or.accumulate(in_sh, axis=1)
    assert (in_sh == reached_sh).all()


def test_simulate_with_iohmm_transitions_is_deterministic() -> None:
    """Two runs driven by identically seeded generators give equal states."""
    rng_seed = 7
    config = SimulationConfig(n_households=200, n_hours=24, seed=42)
    theta = dgp_truth_to_fit_init(config.transitions, config.emissions, config.population)
    # Each run gets its own generator, freshly created from the same seed.
    first, second = [
        simulate_with_iohmm_transitions(config, theta, np.random.default_rng(rng_seed))
        for _ in range(2)
    ]
    np.testing.assert_array_equal(first.states, second.states)


def test_run_shift_sweep_row_count() -> None:
    """Row count equals replicates × shifts, and the result echoes its inputs."""
    shift_grid = (-8, 0, 8)
    replicate_fits = _make_fits(2)
    sweep = run_shift_sweep(
        bootstrap_fits=replicate_fits,
        shifts=shift_grid,
        scenario_base=build_scenario("baseline"),
        n_households=200,
        n_hours=24,
        base_seed=0,
    )
    n_fits = len(replicate_fits)
    assert len(sweep.rows) == n_fits * len(shift_grid)
    assert sweep.shifts == shift_grid
    assert sweep.n_replicates == n_fits


def test_run_shift_sweep_seeds_are_deterministic() -> None:
    """Repeating a sweep with the same inputs and ``base_seed`` repeats every row."""
    replicate_fits = _make_fits(2)
    shift_grid = (-8, 0, 8)
    # Run the identical sweep twice; the scenario is rebuilt for each run.
    first, second = [
        run_shift_sweep(
            bootstrap_fits=replicate_fits,
            shifts=shift_grid,
            scenario_base=build_scenario("baseline"),
            n_households=200,
            n_hours=24,
            base_seed=0,
        )
        for _ in range(2)
    ]
    # strict=True also fails the test if the two runs differ in row count.
    for first_row, second_row in zip(first.rows, second.rows, strict=True):
        assert first_row == second_row


def test_negative_shift_yields_fewer_failed_evacuations() -> None:
    """Earlier orders (δ=-16) should beat later orders (δ=+16).

    Uses a single replicate on a 2,000-household, 120-hour run and compares
    the failed-evacuation counts of the two shifts.
    """
    earlier, later = -16, 16
    sweep = run_shift_sweep(
        bootstrap_fits=_make_fits(1),
        shifts=(earlier, later),
        scenario_base=build_scenario("baseline"),
        n_households=2_000,
        n_hours=120,
        base_seed=0,
    )
    # With one replicate there is a single row per shift, so the shift is a unique key.
    by_shift = {row.shift: row.failed_evacuation_count for row in sweep.rows}
    assert by_shift[earlier] < by_shift[later], (
        f"earlier shift should have fewer failed evacs: {by_shift}"
    )


def test_run_shift_sweep_validates_inputs() -> None:
    """``run_shift_sweep`` rejects empty ``bootstrap_fits`` and empty ``shifts``.

    Each case expects a ``ValueError`` whose message mentions the offending
    argument name.
    """
    base = build_scenario("baseline")
    # Build the fixture outside ``pytest.raises`` so that only the call under
    # test can satisfy the expected exception; a ValueError raised while
    # constructing the fits must surface as a genuine test error.
    single_fit = _make_fits(1)

    with pytest.raises(ValueError, match="bootstrap_fits"):
        run_shift_sweep(
            bootstrap_fits=[],
            shifts=(0,),
            scenario_base=base,
            n_households=10,
            n_hours=4,
            base_seed=0,
        )
    with pytest.raises(ValueError, match="shifts"):
        run_shift_sweep(
            bootstrap_fits=single_fit,
            shifts=(),
            scenario_base=base,
            n_households=10,
            n_hours=4,
            base_seed=0,
        )


def test_sweep_result_round_trip(tmp_path: Path) -> None:
    """A sweep result survives ``write_sweep_result`` → ``load_sweep_result``.

    The reloaded result must match the original in replicate count, shifts,
    and every row.
    """
    original = run_shift_sweep(
        bootstrap_fits=_make_fits(2),
        shifts=(0, 8),
        scenario_base=build_scenario("baseline"),
        n_households=200,
        n_hours=24,
        base_seed=0,
    )
    target = tmp_path / "sweep.parquet"
    write_sweep_result(original, target)
    restored = load_sweep_result(target)

    assert restored.n_replicates == original.n_replicates
    assert restored.shifts == original.shifts
    assert len(restored.rows) == len(original.rows)
    for written_row, read_row in zip(original.rows, restored.rows, strict=True):
        assert written_row == read_row