tests/test_feedback.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import numpy as np
import pytest

from iohmm_evac.dgp.feedback import congestion, peer_share
from iohmm_evac.types import State


def test_congestion_basic() -> None:
    """A small vector with two ``State.ER`` entries and ``n_cap=10`` gives 0.2."""
    lagged_states = np.array(
        [State.ER, State.ER, State.UA, State.AW],
        dtype=np.int64,
    )
    # The expected value implies only the two State.ER entries are counted
    # against the capacity of 10.
    expected = 2 / 10
    assert congestion(lagged_states, n_cap=10) == pytest.approx(expected)


def test_congestion_saturates_at_one() -> None:
    """Congestion is capped at exactly 1.0 when ER entries outnumber ``n_cap``."""
    # 2000 State.ER entries against a capacity of 1000: the raw ratio would
    # exceed 1, so the result is expected to be clipped.
    crowd = np.full(2000, int(State.ER), dtype=np.int64)
    result = congestion(crowd, n_cap=1000)
    assert result == 1.0


def test_congestion_invalid_cap() -> None:
    """A zero capacity is rejected with a ValueError whose message mentions "positive"."""
    single_agent = np.array([State.ER], dtype=np.int64)
    zero_cap = 0
    with pytest.raises(ValueError, match="positive"):
        congestion(single_agent, n_cap=zero_cap)


def test_peer_share_in_unit_interval() -> None:
    """peer_share on a small mixed population equals 2/5 and lies in [0, 1]."""
    states = [State.UA, State.ER, State.SH, State.SH, State.PR]
    # An SH entry is meant to count when its path code marks "away" and not
    # when it marks "home"; exactly one of the two SH entries (codes 1 and 2)
    # counts here. Which code means "away" is not visible in this file —
    # TODO confirm against the evac_path encoding.
    paths = [0, 0, 1, 2, 0]
    prev = np.array(states, dtype=np.int64)
    evac_path = np.array(paths, dtype=np.int64)

    share = peer_share(prev, evac_path)

    assert share == pytest.approx(2 / 5)
    assert 0.0 <= share <= 1.0


def test_peer_share_uses_lagged_state() -> None:
    """peer_share computes its share directly from the state vector it is given.

    'Lagged' simply means the caller passes the t-1 state vector; the function
    itself should treat its input as authoritative without modification. This
    test checks both halves of that contract: the result reflects the vector
    exactly as passed, and the caller's arrays are unchanged after the call.
    """
    prev = np.array([State.ER, State.ER, State.UA], dtype=np.int64)
    evac_path = np.zeros(3, dtype=np.int64)
    # Snapshot the inputs so any in-place mutation (e.g. a shift/roll to
    # "lag" the vector internally) would be detected.
    prev_before = prev.copy()
    evac_path_before = evac_path.copy()

    assert peer_share(prev, evac_path) == pytest.approx(2 / 3)

    np.testing.assert_array_equal(prev, prev_before)
    np.testing.assert_array_equal(evac_path, evac_path_before)


def test_peer_share_empty() -> None:
    """An empty population yields a peer share of exactly 0.0."""
    no_agents = np.array([], dtype=np.int64)
    no_paths = np.array([], dtype=np.int64)
    assert peer_share(no_agents, no_paths) == 0.0