tests/test_simulator.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import numpy as np

from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.params import SimulationConfig
from iohmm_evac.types import State


def test_simulator_end_to_end_shapes() -> None:
    """Every simulator output has the expected shape.

    Time-indexed arrays have ``n_hours + 1`` columns (column 0 is the
    initial hour). ``evac_path`` has one entry per household.
    """
    n_households, n_hours = 200, 24
    config = SimulationConfig(n_households=n_households, n_hours=n_hours, seed=0)
    result = simulate(config, np.random.default_rng(0))

    panel_shape = (n_households, n_hours + 1)
    for field_name in ("states", "departures", "displacements", "communications"):
        assert getattr(result, field_name).shape == panel_shape, field_name
    assert result.evac_path.shape == (n_households,)


def test_simulator_initial_state_is_ua() -> None:
    """Every household is in the ``UA`` state at hour 0."""
    result = simulate(
        SimulationConfig(n_households=100, n_hours=12, seed=0),
        np.random.default_rng(0),
    )
    initial_states = result.states[:, 0]
    assert np.all(initial_states == int(State.UA))


def test_simulator_states_visit_full_space() -> None:
    """A large, long simulation visits every ``State`` at least once.

    The population and horizon are presumably sized so that every state is
    reached. On failure, the message lists which state codes were never
    visited and which codes appeared that are not members of ``State``.
    """
    config = SimulationConfig(n_households=2_000, n_hours=120, seed=0)
    result = simulate(config, np.random.default_rng(0))
    visited = {int(x) for x in np.unique(result.states)}
    expected = {int(s) for s in State}
    assert visited == expected, (
        f"missing states: {sorted(expected - visited)}, "
        f"unexpected codes: {sorted(visited - expected)}"
    )


def test_simulator_terminal_sh_share_positive() -> None:
    """A positive share of households is in ``SH`` at the final hour."""
    config = SimulationConfig(n_households=1_000, n_hours=120, seed=0)
    final_states = simulate(config, np.random.default_rng(0)).states[:, -1]
    in_sh = final_states == int(State.SH)
    assert float(in_sh.mean()) > 0.0


def test_simulator_sh_is_absorbing() -> None:
    """``SH`` is absorbing: a household in ``SH`` stays there.

    The test checks the implication "SH at t => SH at t+1" for every pair of
    adjacent hours. It does so by asserting that no household steps from SH
    to a non-SH state.
    """
    config = SimulationConfig(n_households=500, n_hours=120, seed=1)
    states = simulate(config, np.random.default_rng(1)).states
    is_sh = states == int(State.SH)
    # True where a household is in SH at hour t but not at hour t+1.
    left_sh = is_sh[:, :-1] & ~is_sh[:, 1:]
    assert not left_sh.any()


def test_simulator_evac_path_is_set_at_first_sh_entry() -> None:
    """``evac_path`` is non-NONE exactly for households that ever evacuated.

    Despite the test name, ``evac_path`` is assigned at the PR -> {ER, SH}
    transition. A household therefore counts as evacuated if it was ever in
    ER *or* SH, not only on SH entry.

    ``evac_path`` holds one value per household, so this test checks
    membership only, not timing. The value 0 encodes NONE.
    """
    config = SimulationConfig(n_households=500, n_hours=120, seed=2)
    result = simulate(config, np.random.default_rng(2))
    reached_sh = np.any(result.states == int(State.SH), axis=1)
    reached_er = np.any(result.states == int(State.ER), axis=1)
    has_path = result.evac_path != 0

    # Reaching SH always implies a recorded evacuation path.
    assert has_path[reached_sh].all()

    # More strongly: a path is recorded iff the household ever reached ER or SH.
    evacuated = reached_sh | reached_er
    assert not has_path[~evacuated].any()
    assert has_path[evacuated].all()


def test_simulator_reproducible_under_seed() -> None:
    """Two runs with the same config and seed produce identical outputs.

    Every output array is compared, including ``evac_path``, which was
    previously unchecked. Displacements use ``assert_allclose``'s default
    tolerance; all other outputs must match exactly.
    """
    config = SimulationConfig(n_households=200, n_hours=24, seed=0)
    r1 = simulate(config, np.random.default_rng(config.seed))
    r2 = simulate(config, np.random.default_rng(config.seed))
    np.testing.assert_array_equal(r1.states, r2.states)
    np.testing.assert_array_equal(r1.departures, r2.departures)
    np.testing.assert_array_equal(r1.communications, r2.communications)
    np.testing.assert_array_equal(r1.evac_path, r2.evac_path)
    np.testing.assert_allclose(r1.displacements, r2.displacements)