tests/test_network_metrics.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import numpy as np

from iohmm_evac.network.metrics import (
    NetworkMetrics,
    compute_metrics_from_arrays,
    peak_enroute_share_and_hour,
)
from iohmm_evac.types import State

# Integer codes for each household state (``State`` members coerced with
# ``int``); these are the cell values of the state matrices built below.
ER, SH, UA, PR, AW = (
    int(member) for member in (State.ER, State.SH, State.UA, State.PR, State.AW)
)


def _states(rows: list[list[int]]) -> np.ndarray:
    return np.asarray(rows, dtype=np.int64)


def _disp(rows: list[list[float]]) -> np.ndarray:
    return np.asarray(rows, dtype=np.float64)


def test_total_delay_matches_hand_computed_value() -> None:
    """Total delay matches a closed-form value derived by hand.

    Only ER hours contribute. Each contribution is the displacement made in
    that hour times ``1/v_eff_t - 1/v_free``.
    """
    # Two households, three hours. Household 0 is ER at t=1 and t=2;
    # household 1 is ER only at t=2.
    states = _states(
        [
            [PR, ER, ER],
            [PR, PR, ER],
        ]
    )
    displacements = _disp(
        [
            [0.0, 5.0, 12.0],
            [0.0, 0.0, 4.0],
        ]
    )
    evac_path = np.array([1, 1], dtype=np.int64)  # both AWAY

    n_cap = 2
    v_free = 40.0
    alpha = 0.6

    # c_t derived from states[:, t-1]:
    # c_0 = 0 (no lag); c_1 = #ER at t=0 / n_cap = 0/2 = 0;
    # c_2 = #ER at t=1 / n_cap = 1/2 = 0.5
    # v_eff_0 = 40, v_eff_1 = 40, v_eff_2 = 40 * (1 - 0.6 * 0.5) = 28
    # 1/v_eff_2 - 1/v_free = 1/28 - 1/40
    # Only ER hours contribute.
    # Household 0: t=1 ER, delta=5, c_1=0 → 0; t=2 ER, delta=7, contrib 7*(1/28-1/40)
    # Household 1: t=2 ER, delta=4, contrib 4*(1/28-1/40)
    expected = (7 + 4) * (1.0 / 28.0 - 1.0 / 40.0)

    metrics = compute_metrics_from_arrays(
        states=states,
        displacements=displacements,
        evac_path=evac_path,
        n_cap=n_cap,
        shelter_capacity=0,
        v_free=v_free,
        congestion_penalty=alpha,
    )
    # The implementation may add up the per-household / per-hour terms in a
    # different order than this closed form. Exact float equality is
    # therefore brittle, so compare with a tight relative tolerance.
    np.testing.assert_allclose(metrics.total_delay_hours, expected, rtol=1e-12, atol=0.0)


def test_peak_enroute_share_and_hour_helper() -> None:
    """The helper returns the largest ER share and the earliest hour reaching it."""
    household_states = _states(
        [
            [UA, ER, ER, SH],
            [UA, AW, ER, ER],
            [UA, AW, AW, ER],
        ]
    )
    # Per-hour ER counts over 3 households are 0, 1, 2, 2. Hours 2 and 3 tie,
    # and the first maximum (t=2) is the one expected.
    expected_hour = 2
    expected_share = 2.0 / 3.0

    peak_share, peak_hour = peak_enroute_share_and_hour(household_states)

    assert peak_hour == expected_hour
    assert peak_share == expected_share


def test_peak_enroute_in_metrics_picks_correct_hour() -> None:
    """NetworkMetrics reports the hour and share of peak en-route occupancy."""
    household_states = _states(
        [
            [PR, ER, ER, SH],
            [PR, ER, SH, SH],
        ]
    )
    no_motion = np.zeros(household_states.shape, dtype=np.float64)
    both_away = np.array([1, 1], dtype=np.int64)

    result = compute_metrics_from_arrays(
        states=household_states,
        displacements=no_motion,
        evac_path=both_away,
        n_cap=10,
        shelter_capacity=10,
        v_free=40.0,
    )

    # Per-hour ER counts are 0, 2, 1, 0: everyone is en route at t=1.
    assert result.peak_enroute_hour == 1
    assert result.peak_enroute_share == 1.0


def test_peak_enroute_share_is_at_most_one() -> None:
    """The peak ER share is a fraction and saturates at exactly 1.0."""
    # Saturated scenario: all three households are ER at t=1.
    household_states = _states(
        [
            [PR, ER],
            [PR, ER],
            [PR, ER],
        ]
    )
    result = compute_metrics_from_arrays(
        states=household_states,
        displacements=np.zeros(household_states.shape, dtype=np.float64),
        evac_path=np.array([1, 1, 1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=10,
        v_free=40.0,
    )

    share = result.peak_enroute_share
    assert 0.0 <= share <= 1.0
    assert share == 1.0


def test_failed_evacuation_count_is_er_at_horizon() -> None:
    """Households still ER in the final hour count as failed evacuations."""
    household_states = _states(
        [
            [PR, ER, ER],
            [PR, ER, SH],
            [PR, PR, ER],
        ]
    )
    result = compute_metrics_from_arrays(
        states=household_states,
        displacements=np.zeros(household_states.shape, dtype=np.float64),
        evac_path=np.array([1, 1, 1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=10,
        v_free=40.0,
    )

    # In the last hour (t=2), households 0 and 2 are still ER.
    stranded = 2
    assert result.failed_evacuation_count == stranded


def test_shelter_overflow_with_zero_capacity_counts_every_arrival() -> None:
    """With zero shelter capacity, every AWAY arrival into SH overflows."""
    household_states = _states(
        [
            [PR, ER, SH],
            [PR, ER, SH],
            [PR, PR, SH],
        ]
    )
    # Households 0 and 1 take the AWAY path; household 2 takes the HOME path.
    paths = np.array([1, 1, 2], dtype=np.int64)

    result = compute_metrics_from_arrays(
        states=household_states,
        displacements=np.zeros(household_states.shape, dtype=np.float64),
        evac_path=paths,
        n_cap=10,
        shelter_capacity=0,
        v_free=40.0,
    )

    # Only the two AWAY households (0 and 1) arrive into SH, and both overflow.
    assert result.shelter_overflow_count == 2


def test_shelter_overflow_with_large_capacity_is_zero() -> None:
    """An effectively unlimited shelter capacity never overflows."""
    household_states = _states(
        [
            [PR, ER, SH],
            [PR, ER, SH],
        ]
    )
    roomy = 1_000_000

    result = compute_metrics_from_arrays(
        states=household_states,
        displacements=np.zeros(household_states.shape, dtype=np.float64),
        evac_path=np.array([1, 1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=roomy,
        v_free=40.0,
    )

    assert result.shelter_overflow_count == 0


def test_total_delay_zero_with_huge_n_cap() -> None:
    """A near-infinite n_cap makes congestion, and hence total delay, negligible."""
    # c_t = (#ER at t-1) / n_cap. Here c_1 = 0 exactly, because nobody is ER
    # at t=0. But c_2 = 2 / 10**9 is tiny and non-zero, so v_eff stays
    # essentially equal to v_free. The total delay is therefore negligible,
    # not exactly zero.
    states = _states(
        [
            [PR, ER, ER],
            [PR, ER, ER],
        ]
    )
    displacements = _disp(
        [
            [0.0, 5.0, 10.0],
            [0.0, 6.0, 11.0],
        ]
    )
    evac_path = np.array([1, 1], dtype=np.int64)
    metrics = compute_metrics_from_arrays(
        states=states,
        displacements=displacements,
        evac_path=evac_path,
        n_cap=10**9,
        shelter_capacity=10,
        v_free=40.0,
    )
    # Bound the magnitude. A one-sided `< 1e-6` would also accept a large
    # negative (i.e. nonsensical) delay.
    assert abs(metrics.total_delay_hours) < 1e-6


def test_diagnostic_arrays_shape_and_values() -> None:
    """Per-hour diagnostic arrays have one entry per hour with the right counts."""
    household_states = _states(
        [
            [UA, AW, PR, ER, SH],
            [UA, AW, AW, PR, ER],
        ]
    )
    result = compute_metrics_from_arrays(
        states=household_states,
        displacements=np.zeros(household_states.shape, dtype=np.float64),
        evac_path=np.array([1, 1], dtype=np.int64),
        n_cap=5,
        shelter_capacity=10,
        v_free=40.0,
    )

    for per_hour in (
        result.delay_per_hour,
        result.enroute_count_per_hour,
        result.arrivals_away_per_hour,
    ):
        assert per_hour.shape == (5,)

    # One household is ER at t=3 (id 0) and one at t=4 (id 1).
    er_by_hour = np.array([0, 0, 0, 1, 1], dtype=np.int64)
    assert np.array_equal(result.enroute_count_per_hour, er_by_hour)

    # Only household 0 enters SH (at t=4); hour 0 can never register an arrival.
    arrivals_by_hour = np.array([0, 0, 0, 0, 1], dtype=np.int64)
    assert np.array_equal(result.arrivals_away_per_hour, arrivals_by_hour)


def test_metrics_dataclass_is_frozen() -> None:
    """NetworkMetrics is an immutable (frozen) dataclass.

    Assigning to a field must raise ``dataclasses.FrozenInstanceError``.
    A modified copy is made with ``dataclasses.replace``, which must leave
    the original instance untouched.
    """
    import dataclasses

    states = _states([[PR, ER]])
    metrics = compute_metrics_from_arrays(
        states=states,
        displacements=np.zeros_like(states, dtype=np.float64),
        evac_path=np.array([1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=10,
        v_free=40.0,
    )
    assert isinstance(metrics, NetworkMetrics)
    assert dataclasses.is_dataclass(metrics)

    original_delay = metrics.total_delay_hours

    # The actual frozenness check: in-place mutation must be rejected.
    try:
        metrics.total_delay_hours = 99.0  # type: ignore[misc]
    except dataclasses.FrozenInstanceError:
        pass
    else:  # pragma: no cover
        raise AssertionError("NetworkMetrics should reject attribute assignment")
    assert metrics.total_delay_hours == original_delay

    # replace() builds a new instance; the original keeps its value.
    updated = dataclasses.replace(metrics, total_delay_hours=99.0)
    assert updated.total_delay_hours == 99.0
    assert metrics.total_delay_hours == original_delay


def test_n_cap_must_be_positive() -> None:
    """A zero ``n_cap`` is rejected with a ValueError that names the parameter."""
    household_states = _states([[PR, ER]])
    caught: ValueError | None = None
    try:
        compute_metrics_from_arrays(
            states=household_states,
            displacements=np.zeros(household_states.shape, dtype=np.float64),
            evac_path=np.array([1], dtype=np.int64),
            n_cap=0,
            shelter_capacity=10,
            v_free=40.0,
        )
    except ValueError as err:
        caught = err
    if caught is None:  # pragma: no cover
        raise AssertionError("expected ValueError")
    assert "n_cap" in str(caught)


def test_v_free_must_be_positive() -> None:
    """A zero free-flow speed is rejected with a ValueError naming ``v_free``."""
    household_states = _states([[PR, ER]])
    caught: ValueError | None = None
    try:
        compute_metrics_from_arrays(
            states=household_states,
            displacements=np.zeros(household_states.shape, dtype=np.float64),
            evac_path=np.array([1], dtype=np.int64),
            n_cap=10,
            shelter_capacity=10,
            v_free=0.0,
        )
    except ValueError as err:
        caught = err
    if caught is None:  # pragma: no cover
        raise AssertionError("expected ValueError")
    assert "v_free" in str(caught)