tests/test_population.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import numpy as np
import pytest

from iohmm_evac.dgp.population import synthesize_population, zone_codes
from iohmm_evac.params import PopulationParams


def test_population_truncation_bounds() -> None:
    """Every synthesized distance lies inside the configured truncation window."""
    n_agents = 5_000
    cfg = PopulationParams()
    population = synthesize_population(n_agents, np.random.default_rng(0), cfg)

    dist = population.distance
    # Both ends of the window are inclusive.
    assert dist.min() >= cfg.distance_lo
    assert dist.max() <= cfg.distance_hi
    assert population.n == n_agents


def test_population_moments_approx() -> None:
    """Sample moments of a large synthetic population track the configured parameters."""
    cfg = PopulationParams()
    population = synthesize_population(20_000, np.random.default_rng(0), cfg)

    # Distance comes from a truncated distribution, so its sample mean can be
    # shifted away from ``distance_mu``; the loose tolerance absorbs that shift.
    assert abs(population.distance.mean() - cfg.distance_mu) < 2.0

    # Risk is expected to be centred on zero with spread ``risk_sigma``.
    risk = population.risk
    assert abs(risk.mean()) < 0.05
    assert abs(risk.std() - cfg.risk_sigma) < 0.1

    # The vehicle mean is compared directly against ``vehicle_p``.
    assert abs(population.vehicle.mean() - cfg.vehicle_p) < 0.02

    dest = population.destination
    assert dest.min() >= cfg.dest_lo
    assert dest.max() <= cfg.dest_hi


@pytest.mark.parametrize(
    ("distance", "expected_zone"),
    [
        # Zone 0: strictly below the first cut-off.
        (0.5, 0),
        (4.99, 0),
        # Zone 1: 5.0 itself is expected to land here.
        (5.0, 1),
        (10.0, 1),
        (19.99, 1),
        # Zone 2: 20.0 itself is expected to land here.
        (20.0, 2),
        (45.0, 2),
    ],
)
def test_zone_derivation(distance: float, expected_zone: int) -> None:
    """A single distance maps to the expected zone code, including at the cut-offs."""
    first_code = zone_codes(np.array([distance]), PopulationParams())[0]
    assert int(first_code) == expected_zone


def test_population_reproducibility() -> None:
    """Identical seeds yield identical populations; a different seed does not.

    The different-seed check guards against an implementation that ignores
    ``rng`` entirely, which would trivially pass the same-seed comparison.
    """
    p1 = synthesize_population(500, np.random.default_rng(42), PopulationParams())
    p2 = synthesize_population(500, np.random.default_rng(42), PopulationParams())
    for field in ("distance", "vehicle", "risk", "zone", "destination"):
        # err_msg names the field so a failure points at the offending attribute.
        np.testing.assert_array_equal(
            getattr(p1, field), getattr(p2, field), err_msg=field
        )

    other = synthesize_population(500, np.random.default_rng(43), PopulationParams())
    # 500 real-valued distance draws from distinct seeds should not all coincide.
    assert not np.array_equal(p1.distance, other.distance)


def test_population_rejection_sampling_fills_completely() -> None:
    """Rejection sampling still returns ``n`` in-bounds draws under a narrow window.

    With ``distance_sigma=20`` around a 2-unit window, the large majority of
    proposals fall outside the bounds, forcing many rejection iterations.
    """
    params = PopulationParams(distance_mu=15, distance_sigma=20, distance_lo=14, distance_hi=16)
    pop = synthesize_population(300, np.random.default_rng(7), params)
    assert pop.n == 300
    assert pop.distance.shape == (300,)
    # Unfilled slots (NaN/inf) would indicate the sampler stopped early.
    assert np.isfinite(pop.distance).all()
    # Read the bounds back from params so they cannot drift from the constructor args.
    in_bounds = (pop.distance >= params.distance_lo) & (pop.distance <= params.distance_hi)
    assert in_bounds.all()