tests/test_transitions.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import numpy as np

from iohmm_evac.dgp.transitions import (
    StepInputs,
    sample_transitions,
    softmax,
    transition_probs_for_state,
)
from iohmm_evac.params import TransitionParams
from iohmm_evac.types import State


def _zero_inputs(n: int, *, c: float = 0.0) -> StepInputs:
    """Build a ``StepInputs`` for ``n`` agents with every field zero except ``c``.

    ``rho`` is an all-zeros array of length ``n``; ``vol``, ``mand``,
    ``tau_norm`` and ``pi`` are zero. ``c`` (keyword-only) defaults to 0.0
    and is passed through unchanged.
    """
    zero_rho = np.zeros(n)
    return StepInputs(
        rho=zero_rho,
        vol=0,
        mand=0,
        tau_norm=0.0,
        pi=0.0,
        c=c,
    )


def test_softmax_rows_sum_to_one() -> None:
    """``softmax`` turns each row of logits into a non-negative vector summing to 1."""
    logit_rows = np.array(
        [
            [0.0, 1.0, -1.0],
            [10.0, -10.0, 0.0],  # widely spread logits in one row
        ]
    )
    out = softmax(logit_rows)
    row_totals = out.sum(axis=1)
    np.testing.assert_allclose(row_totals, 1.0, atol=1e-12)
    assert np.all(out >= 0)


def test_transition_probs_per_state_shapes_and_sums() -> None:
    """Each origin state yields a valid distribution over its destination set.

    For every state, ``transition_probs_for_state`` must return one row per
    agent and one column per allowed destination. Every row must be
    non-negative and sum to 1, and the destination vector must list ``width``
    distinct states.
    """
    n = 50
    risk = np.zeros(n)
    vehicle = np.zeros(n)
    tir = np.zeros(n)
    idx = np.arange(n, dtype=np.int64)
    inputs = _zero_inputs(n)
    params = TransitionParams()
    # Number of destinations (staying in place included) for each origin state.
    expected_widths = {State.UA: 2, State.AW: 3, State.PR: 3, State.ER: 2, State.SH: 1}
    for state, width in expected_widths.items():
        probs, dests = transition_probs_for_state(state, inputs, risk, vehicle, tir, idx, params)
        assert probs.shape == (n, width), state
        assert dests.shape == (width,), state
        # Rows must be proper distributions: summing to 1 alone would still
        # accept rows with negative entries.
        assert np.all(probs >= 0), state
        np.testing.assert_allclose(probs.sum(axis=1), 1.0, err_msg=str(state))
        # Each column should correspond to a different destination state.
        assert len(np.unique(dests)) == width, state


def test_strong_inputs_dominate_for_ua_to_aw() -> None:
    """With ``vol``/``mand`` set and a large ``rho``, almost every UA agent moves to AW."""
    n = 5_000
    # mandatory is on, rho is high → UA→AW should dominate.
    strong = StepInputs(
        rho=np.full(n, 20.0),
        vol=1,
        mand=1,
        tau_norm=1.0,
        pi=0.0,
        c=0.0,
    )
    start = np.full(n, int(State.UA), dtype=np.int64)
    new = sample_transitions(
        prev_state=start,
        inputs=strong,
        risk=np.zeros(n),
        vehicle=np.zeros(n),
        tir=np.zeros(n),
        params=TransitionParams(),
        rng=np.random.default_rng(0),
    )
    aw_share = np.mean(new == int(State.AW))
    assert aw_share > 0.95


def test_no_one_in_origin_state_is_a_noop() -> None:
    """A population sitting entirely in SH comes back unchanged."""
    n = 100
    generator = np.random.default_rng(0)
    # Every agent starts in SH (absorbed), so nobody should move.
    absorbed = np.full(n, int(State.SH), dtype=np.int64)
    result = sample_transitions(
        prev_state=absorbed,
        inputs=_zero_inputs(n),
        risk=np.zeros(n),
        vehicle=np.zeros(n),
        tir=np.zeros(n),
        params=TransitionParams(),
        rng=generator,
    )
    np.testing.assert_array_equal(result, absorbed)


def test_only_allowed_destinations_appear() -> None:
    """Agents starting in PR can only end up in PR, ER or SH."""
    n = 1_000
    generator = np.random.default_rng(1)
    start = np.full(n, int(State.PR), dtype=np.int64)
    new = sample_transitions(
        prev_state=start,
        inputs=_zero_inputs(n),
        risk=np.zeros(n),
        vehicle=np.zeros(n),
        tir=np.zeros(n),
        params=TransitionParams(),
        rng=generator,
    )
    allowed = {int(State.PR), int(State.ER), int(State.SH)}
    observed = set(np.unique(new).tolist())
    assert observed <= allowed


def test_high_tir_drives_er_to_sh() -> None:
    """A large ``tir`` sends most ER agents to SH."""
    n = 4_000
    generator = np.random.default_rng(2)
    start = np.full(n, int(State.ER), dtype=np.int64)
    high_tir = np.full(n, 10.0)
    new = sample_transitions(
        prev_state=start,
        inputs=_zero_inputs(n),
        risk=np.zeros(n),
        vehicle=np.zeros(n),
        tir=high_tir,
        params=TransitionParams(),
        rng=generator,
    )
    sh_share = np.mean(new == int(State.SH))
    assert sh_share > 0.7


def test_mixed_population_routes_through_correct_destinations() -> None:
    """In a mixed UA/AW population, each agent lands only in its own origin's destination set."""
    generator = np.random.default_rng(3)
    n = 500
    half = n // 2
    # First half starts in UA, second half in AW.
    prev = np.concatenate(
        [
            np.full(half, int(State.UA), dtype=np.int64),
            np.full(n - half, int(State.AW), dtype=np.int64),
        ]
    )
    new = sample_transitions(
        prev_state=prev,
        inputs=_zero_inputs(n),
        risk=np.zeros(n),
        vehicle=np.zeros(n),
        tir=np.zeros(n),
        params=TransitionParams(),
        rng=generator,
    )
    # UA can go to {UA, AW}; AW can go to {UA, AW, PR}.
    expectations = (
        (new[:half], {int(State.UA), int(State.AW)}),
        (new[half:], {int(State.UA), int(State.AW), int(State.PR)}),
    )
    for group, allowed in expectations:
        assert set(np.unique(group).tolist()) <= allowed