tests/test_emissions.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import numpy as np

from iohmm_evac.dgp.emissions import sample_emissions
from iohmm_evac.params import EmissionParams
from iohmm_evac.types import State


def _make_state_vector(s: State, n: int) -> np.ndarray:
    return np.full(n, int(s), dtype=np.int64)


def test_emission_shapes() -> None:
    """All three emission streams come back as length-``n`` vectors; departures are boolean."""
    n = 50
    generator = np.random.default_rng(0)
    aw_states = _make_state_vector(State.AW, n)
    departed, displacement, counts = sample_emissions(
        aw_states,
        np.zeros(n, dtype=np.int64),  # evac_path
        np.zeros(n),  # tir
        np.full(n, 60.0),  # destination
        0.0,
        EmissionParams(),
        generator,
    )
    assert departed.shape == (n,)
    assert departed.dtype == bool
    for stream in (displacement, counts):
        assert stream.shape == (n,)


def test_departure_rate_matches_state() -> None:
    """A population entirely in state ER departs at a rate close to 0.95."""
    n = 5_000
    er_states = _make_state_vector(State.ER, n)
    evac_path = np.zeros(n, np.int64)
    tir = np.zeros(n)
    destination = np.full(n, 60.0)
    departed, _, _ = sample_emissions(
        er_states, evac_path, tir, destination, 0.0, EmissionParams(), np.random.default_rng(1)
    )
    # If p = 0.95, the standard error of the mean over 5,000 draws is ~0.003,
    # so a 0.02 band is roughly six standard errors wide.
    assert abs(departed.mean() - 0.95) < 0.02


def test_displacement_increases_with_tir_for_er() -> None:
    """For ER households, a larger ``tir`` yields a larger mean displacement."""
    n = 2_000
    er_states = _make_state_vector(State.ER, n)
    destination = np.full(n, 80.0)
    evac_path = np.zeros(n, dtype=np.int64)

    def _mean_displacement(tir_value: float) -> float:
        # Identical seed on every call so the only difference between draws is ``tir``.
        _, displacement, _ = sample_emissions(
            er_states,
            evac_path,
            np.full(n, tir_value),
            destination,
            0.0,
            EmissionParams(),
            np.random.default_rng(3),
        )
        return displacement.mean()

    low = _mean_displacement(0.5)
    high = _mean_displacement(1.5)
    assert high > low


def test_displacement_capped_at_destination() -> None:
    """An ER household with an enormous ``tir`` is reported at (about) its destination."""
    n = 1_000
    destination = 50.0
    generator = np.random.default_rng(4)
    _, displacement, _ = sample_emissions(
        _make_state_vector(State.ER, n),
        np.zeros(n, np.int64),
        np.full(n, 100.0),
        np.full(n, destination),
        0.0,
        EmissionParams(),
        generator,
    )
    # 100 h at 40 km/h would be 4000 of raw progress; clipping against the
    # destination leaves 50, so only emission noise separates the mean from 50.
    assert abs(displacement.mean() - destination) < 0.2


def test_poisson_rates_per_state() -> None:
    """Each state's count emission averages to that state's configured Poisson rate.

    One RNG is shared across all states and the states are visited in a fixed
    (dict insertion) order, so the draws are reproducible.
    """
    rng = np.random.default_rng(5)
    n = 4_000
    params = EmissionParams()
    expected = {
        State.UA: params.lambda_ua,
        State.AW: params.lambda_aw,
        State.PR: params.lambda_pr,
        State.ER: params.lambda_er,
        State.SH: params.lambda_sh,
    }
    for s, lam in expected.items():
        state = _make_state_vector(s, n)
        _, _, c = sample_emissions(
            state,
            np.zeros(n, np.int64),
            np.zeros(n),
            np.full(n, 60.0),
            0.0,
            params,
            rng,
        )
        observed = c.mean()
        # Name the failing state: inside this loop pytest's assertion rewrite
        # would otherwise report only the numbers, not which state they belong to.
        assert abs(observed - lam) < 0.2, (
            f"{s!r}: mean count {observed:.3f}, expected approx {lam}"
        )


def test_sh_displacement_branches_on_evac_path() -> None:
    """SH displacement depends on ``evac_path``.

    The test treats path code 1 as "away" (displacement near the destination)
    and code 2 as "home" (displacement near zero).
    """
    n = 1_000
    sh_states = _make_state_vector(State.SH, n)
    destination = np.full(n, 70.0)

    away_rng = np.random.default_rng(6)
    _, x_away, _ = sample_emissions(
        sh_states,
        np.full(n, 1, dtype=np.int64),
        np.zeros(n),
        destination,
        0.0,
        EmissionParams(),
        away_rng,
    )
    _, x_home, _ = sample_emissions(
        sh_states,
        np.full(n, 2, dtype=np.int64),
        np.zeros(n),
        destination,
        0.0,
        EmissionParams(),
        np.random.default_rng(7),
    )
    assert abs(x_away.mean() - 70.0) < 0.5
    assert abs(x_home.mean()) < 1.0