# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Shift-sweep tests: simulate-with-iohmm-transitions + per-cell deterministic seeds."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
from iohmm_evac.bootstrap.runner import BootstrapFit
from iohmm_evac.bootstrap.shift_sweep import (
load_sweep_result,
run_shift_sweep,
shift_timeline,
simulate_with_iohmm_transitions,
write_sweep_result,
)
from iohmm_evac.inference.fit_params import dgp_truth_to_fit_init
from iohmm_evac.params import SimulationConfig
from iohmm_evac.scenarios import build_scenario
def _make_fits(n_replicates: int) -> list[BootstrapFit]:
    """Build ``n_replicates`` bootstrap fits that all share the DGP-truth-projected params.

    Every replicate reuses the same ``truth`` params object (standing in for θ̂),
    produced by projecting a 200-household, 24-hour ``SimulationConfig``'s
    transitions/emissions/population through ``dgp_truth_to_fit_init``.

    Args:
        n_replicates: Number of ``BootstrapFit`` entries to create; replicate ids
            run ``0 .. n_replicates - 1``.

    Returns:
        A list of ``n_replicates`` ``BootstrapFit`` objects with placeholder
        likelihood/iteration metadata and ``indices`` covering every household once.
    """
    base = SimulationConfig(n_households=200, n_hours=24)
    truth = dgp_truth_to_fit_init(base.transitions, base.emissions, base.population)
    return [
        BootstrapFit(
            replicate_id=replicate_id,
            params=truth,
            # Placeholder fit diagnostics: the sweep tests only consume ``params``.
            final_log_likelihood=-1.0,
            iterations=1,
            converged=True,
            # A fresh index array per replicate, one entry per household.
            indices=np.arange(base.n_households, dtype=np.int64),
        )
        for replicate_id in range(n_replicates)
    ]
def test_shift_timeline_offsets_voluntary_and_mandatory() -> None:
    """``shift_timeline`` moves both the voluntary and mandatory order hours by the offset."""
    delta = 10
    original = SimulationConfig().timeline
    moved = shift_timeline(original, delta)
    for attr in ("voluntary_hour", "mandatory_hour"):
        assert getattr(moved, attr) == getattr(original, attr) + delta
def test_simulate_with_iohmm_transitions_runs() -> None:
    """Simulating under θ̂ transitions yields correctly shaped outputs with SH absorbing."""
    n_households, n_hours = 200, 24
    config = SimulationConfig(n_households=n_households, n_hours=n_hours, seed=0)
    truth = dgp_truth_to_fit_init(config.transitions, config.emissions, config.population)
    rng = np.random.default_rng(0)
    sim = simulate_with_iohmm_transitions(config, truth, rng)

    # Per-hour trajectories carry one extra column beyond ``n_hours`` (24 hours -> 25 columns).
    assert sim.states.shape == (n_households, n_hours + 1)
    assert sim.displacements.shape == (n_households, n_hours + 1)
    assert sim.evac_path.shape == (n_households,)

    # SH must be absorbing: once a household enters SH it never leaves. Equivalently,
    # the running OR of the SH indicator along the time axis equals the indicator itself.
    sh_state = 4  # integer code of State.SH
    is_sh = sim.states == sh_state
    np.testing.assert_array_equal(np.logical_or.accumulate(is_sh, axis=1), is_sh)
def test_simulate_with_iohmm_transitions_is_deterministic() -> None:
    """Identically seeded RNGs reproduce every simulation output, not just the states."""
    config = SimulationConfig(n_households=200, n_hours=24, seed=42)
    truth = dgp_truth_to_fit_init(config.transitions, config.emissions, config.population)
    a = simulate_with_iohmm_transitions(config, truth, np.random.default_rng(7))
    b = simulate_with_iohmm_transitions(config, truth, np.random.default_rng(7))
    np.testing.assert_array_equal(a.states, b.states)
    # Displacements and evacuation paths are also RNG-driven outputs; check them too so
    # a determinism regression outside ``states`` is caught.
    np.testing.assert_array_equal(a.displacements, b.displacements)
    np.testing.assert_array_equal(a.evac_path, b.evac_path)
def test_run_shift_sweep_row_count() -> None:
    """A sweep emits one row per (replicate, shift) cell and echoes its inputs."""
    replicate_fits = _make_fits(2)
    offsets = (-8, 0, 8)
    sweep = run_shift_sweep(
        bootstrap_fits=replicate_fits,
        shifts=offsets,
        scenario_base=build_scenario("baseline"),
        n_households=200,
        n_hours=24,
        base_seed=0,
    )
    expected_rows = len(replicate_fits) * len(offsets)
    assert len(sweep.rows) == expected_rows
    assert sweep.shifts == offsets
    assert sweep.n_replicates == len(replicate_fits)
def test_run_shift_sweep_seeds_are_deterministic() -> None:
    """Two sweeps with identical inputs and ``base_seed`` produce identical rows."""
    fits = _make_fits(2)
    shifts = (-8, 0, 8)

    def sweep():
        # Build a fresh baseline scenario for each run, as two independent calls would.
        return run_shift_sweep(
            bootstrap_fits=fits,
            shifts=shifts,
            scenario_base=build_scenario("baseline"),
            n_households=200,
            n_hours=24,
            base_seed=0,
        )

    first = sweep()
    second = sweep()
    for row_first, row_second in zip(first.rows, second.rows, strict=True):
        assert row_first == row_second
def test_negative_shift_yields_fewer_failed_evacuations() -> None:
    """Orders issued 16 hours earlier (δ=-16) leave fewer failed evacuations than δ=+16."""
    early, late = -16, 16
    result = run_shift_sweep(
        bootstrap_fits=_make_fits(1),
        shifts=(early, late),
        scenario_base=build_scenario("baseline"),
        n_households=2_000,
        n_hours=120,
        base_seed=0,
    )
    by_shift = {}
    for row in result.rows:
        by_shift[row.shift] = row.failed_evacuation_count
    assert by_shift[early] < by_shift[late], (
        f"earlier shift should have fewer failed evacs: {by_shift}"
    )
def test_run_shift_sweep_validates_inputs() -> None:
    """An empty ``bootstrap_fits`` or empty ``shifts`` is rejected with ``ValueError``."""
    base = build_scenario("baseline")
    # Arguments that are valid in both cases; only the emptied argument differs.
    common = {"scenario_base": base, "n_households": 10, "n_hours": 4, "base_seed": 0}
    with pytest.raises(ValueError, match="bootstrap_fits"):
        run_shift_sweep(bootstrap_fits=[], shifts=(0,), **common)
    with pytest.raises(ValueError, match="shifts"):
        run_shift_sweep(bootstrap_fits=_make_fits(1), shifts=(), **common)
def test_sweep_result_round_trip(tmp_path: Path) -> None:
    """Writing a sweep result to parquet and loading it back preserves its metadata and rows."""
    original = run_shift_sweep(
        bootstrap_fits=_make_fits(2),
        shifts=(0, 8),
        scenario_base=build_scenario("baseline"),
        n_households=200,
        n_hours=24,
        base_seed=0,
    )
    parquet_path = tmp_path / "sweep.parquet"
    write_sweep_result(original, parquet_path)
    restored = load_sweep_result(parquet_path)

    assert restored.n_replicates == original.n_replicates
    assert restored.shifts == original.shifts
    assert len(restored.rows) == len(original.rows)
    for expected, actual in zip(original.rows, restored.rows, strict=True):
        assert expected == actual