# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import numpy as np
from iohmm_evac.dgp.transitions import (
StepInputs,
sample_transitions,
softmax,
transition_probs_for_state,
)
from iohmm_evac.params import TransitionParams
from iohmm_evac.types import State
def _zero_inputs(n: int, *, c: float = 0.0) -> StepInputs:
    """Build a ``StepInputs`` whose fields are all zero, apart from an optional ``c``.

    Args:
        n: Length of the all-zero ``rho`` array.
        c: Value passed through to the ``c`` field; defaults to ``0.0``.

    Returns:
        A ``StepInputs`` with ``rho`` equal to ``n`` zeros, ``vol`` and ``mand`` set
        to ``0``, and ``tau_norm`` and ``pi`` set to ``0.0``.
    """
    flat_rho = np.zeros(n)
    return StepInputs(rho=flat_rho, vol=0, mand=0, tau_norm=0.0, pi=0.0, c=c)
def test_softmax_rows_sum_to_one() -> None:
    """Check that every softmax output row sums to one and has no negative entry."""
    scores = np.array(
        [
            [0.0, 1.0, -1.0],
            [10.0, -10.0, 0.0],
        ]
    )
    result = softmax(scores)

    row_totals = result.sum(axis=1)
    np.testing.assert_allclose(row_totals, 1.0, atol=1e-12)

    non_negative = result >= 0
    assert non_negative.all()
def test_transition_probs_per_state_shapes_and_sums() -> None:
    """Check the output shapes and row sums of ``transition_probs_for_state``.

    For each origin state the probability matrix must be ``(n, width)`` with rows
    summing to one, and the destination array must be ``(width,)``, where ``width``
    is the number of destinations expected for that state.
    """
    size = 50
    # Three separate arrays (not one shared object) so each covariate stays independent.
    risk, vehicle, tir = (np.zeros(size) for _ in range(3))
    rows = np.arange(size, dtype=np.int64)
    step = _zero_inputs(size)
    params = TransitionParams()

    # (origin state, expected number of destinations), checked in this order.
    cases = (
        (State.UA, 2),
        (State.AW, 3),
        (State.PR, 3),
        (State.ER, 2),
        (State.SH, 1),
    )
    for origin, n_dest in cases:
        probs, dests = transition_probs_for_state(
            origin, step, risk, vehicle, tir, rows, params
        )
        assert probs.shape == (size, n_dest)
        assert dests.shape == (n_dest,)
        row_totals = probs.sum(axis=1)
        np.testing.assert_allclose(row_totals, 1.0)
def test_strong_inputs_dominate_for_ua_to_aw() -> None:
    """With strong inputs, over 95% of an all-UA population should land in AW."""
    size = 5_000
    ua_code = int(State.UA)
    aw_code = int(State.AW)

    # mand=1 together with a large rho: the UA→AW transition should dominate.
    strong = StepInputs(
        rho=np.full(size, 20.0),
        vol=1,
        mand=1,
        tau_norm=1.0,
        pi=0.0,
        c=0.0,
    )
    everyone_ua = np.full(size, ua_code, dtype=np.int64)

    outcome = sample_transitions(
        prev_state=everyone_ua,
        inputs=strong,
        risk=np.zeros(size),
        vehicle=np.zeros(size),
        tir=np.zeros(size),
        params=TransitionParams(),
        rng=np.random.default_rng(0),
    )

    reached_aw = outcome == aw_code
    assert reached_aw.mean() > 0.95
def test_no_one_in_origin_state_is_a_noop() -> None:
    """A population that is entirely in SH must come back from sampling unchanged."""
    size = 100
    sh_code = int(State.SH)
    all_sh = np.full(size, sh_code, dtype=np.int64)  # all absorbed
    zeros_risk = np.zeros(size)
    zeros_vehicle = np.zeros(size)
    zeros_tir = np.zeros(size)

    outcome = sample_transitions(
        prev_state=all_sh,
        inputs=_zero_inputs(size),
        risk=zeros_risk,
        vehicle=zeros_vehicle,
        tir=zeros_tir,
        params=TransitionParams(),
        rng=np.random.default_rng(0),
    )

    np.testing.assert_array_equal(outcome, all_sh)
def test_only_allowed_destinations_appear() -> None:
    """Starting everyone in PR, only PR, ER and SH may show up after one step."""
    size = 1_000
    permitted = {int(State.PR), int(State.ER), int(State.SH)}
    all_pr = np.full(size, int(State.PR), dtype=np.int64)

    outcome = sample_transitions(
        prev_state=all_pr,
        inputs=_zero_inputs(size),
        risk=np.zeros(size),
        vehicle=np.zeros(size),
        tir=np.zeros(size),
        params=TransitionParams(),
        rng=np.random.default_rng(1),
    )

    observed = set(np.unique(outcome).tolist())
    assert observed <= permitted
def test_high_tir_drives_er_to_sh() -> None:
    """With ``tir`` at 10.0 for everyone, over 70% of an ER population should reach SH."""
    size = 4_000
    er_code = int(State.ER)
    sh_code = int(State.SH)
    all_er = np.full(size, er_code, dtype=np.int64)
    high_tir = np.full(size, 10.0)

    outcome = sample_transitions(
        prev_state=all_er,
        inputs=_zero_inputs(size),
        risk=np.zeros(size),
        vehicle=np.zeros(size),
        tir=high_tir,
        params=TransitionParams(),
        rng=np.random.default_rng(2),
    )

    reached_sh = outcome == sh_code
    assert reached_sh.mean() > 0.7
def test_mixed_population_routes_through_correct_destinations() -> None:
    """Each half of a UA/AW population must stay within its own destination set."""
    size = 500
    half = size // 2
    ua_code = int(State.UA)
    aw_code = int(State.AW)
    pr_code = int(State.PR)

    # First half starts in UA, the remainder in AW.
    start = np.concatenate(
        [
            np.full(half, ua_code, dtype=np.int64),
            np.full(size - half, aw_code, dtype=np.int64),
        ]
    )

    outcome = sample_transitions(
        prev_state=start,
        inputs=_zero_inputs(size),
        risk=np.zeros(size),
        vehicle=np.zeros(size),
        tir=np.zeros(size),
        params=TransitionParams(),
        rng=np.random.default_rng(3),
    )

    # UA can go to {UA, AW}; AW can go to {UA, AW, PR}.
    seen_from_ua = set(np.unique(outcome[:half]).tolist())
    seen_from_aw = set(np.unique(outcome[half:]).tolist())
    assert seen_from_ua <= {ua_code, aw_code}
    assert seen_from_aw <= {ua_code, aw_code, pr_code}