# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import numpy as np
import pytest
from iohmm_evac.dgp.feedback import congestion, peer_share
from iohmm_evac.types import State
def test_congestion_basic() -> None:
    """Two ER agents against a capacity of 10 give a congestion of 0.2."""
    states = [State.ER, State.ER, State.UA, State.AW]
    lagged = np.array(states, dtype=np.int64)
    level = congestion(lagged, n_cap=10)
    assert level == pytest.approx(0.2)
def test_congestion_saturates_at_one() -> None:
    """Twice as many ER agents as capacity still yields exactly 1.0."""
    capacity = 1000
    everyone_en_route = np.full(
        shape=2 * capacity, fill_value=int(State.ER), dtype=np.int64
    )
    result = congestion(everyone_en_route, n_cap=capacity)
    assert result == 1.0
def test_congestion_invalid_cap() -> None:
    """A zero capacity is rejected with a ValueError mentioning 'positive'."""
    single_agent = np.array([State.ER], dtype=np.int64)
    expectation = pytest.raises(ValueError, match="positive")
    with expectation:
        congestion(single_agent, n_cap=0)
def test_peer_share_in_unit_interval() -> None:
    """Mixed states/paths give a share of 2/5, which lies within [0, 1].

    The `evac_path` codes below are per-agent path labels. The original note
    says SH with an "away" path counts toward the share and SH at "home" does
    not. Which integer code means "home" vs "away" is defined in iohmm_evac,
    not in this file — confirm there.
    """
    lagged = np.array(
        [State.UA, State.ER, State.SH, State.SH, State.PR], dtype=np.int64
    )
    paths = np.array([0, 0, 1, 2, 0], dtype=np.int64)
    observed = peer_share(lagged, paths)
    assert observed == pytest.approx(2 / 5)
    assert 0.0 <= observed <= 1.0
def test_peer_share_uses_lagged_state() -> None:
    """peer_share computes the share from the state vector it is given, as-is.

    'Lagged' simply means the caller passes the t-1 state vector; the function
    itself should treat its input as authoritative without modification. With
    two ER agents out of three (all evac_path codes 0), the share is 2/3.
    """
    prev = np.array([State.ER, State.ER, State.UA], dtype=np.int64)
    evac_path = np.zeros(3, dtype=np.int64)
    # Snapshot the inputs so the "without modification" contract can be checked.
    prev_before = prev.copy()
    evac_path_before = evac_path.copy()
    assert peer_share(prev, evac_path) == pytest.approx(2 / 3)
    # The caller owns the lagged vectors; peer_share must not mutate them.
    np.testing.assert_array_equal(prev, prev_before)
    np.testing.assert_array_equal(evac_path, evac_path_before)
def test_peer_share_empty() -> None:
    """An empty population yields a peer share of exactly 0.0."""
    no_states = np.empty(0, dtype=np.int64)
    no_paths = np.empty(0, dtype=np.int64)
    share = peer_share(no_states, no_paths)
    assert share == 0.0