# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import numpy as np
import pytest
from iohmm_evac.dgp.population import synthesize_population, zone_codes
from iohmm_evac.params import PopulationParams
def test_population_truncation_bounds() -> None:
    """Every sampled distance lies inside ``[distance_lo, distance_hi]``.

    Also checks that the population reports the requested size.
    """
    params = PopulationParams()
    n_agents = 5_000
    population = synthesize_population(n_agents, np.random.default_rng(0), params)

    distances = population.distance
    assert distances.min() >= params.distance_lo
    assert distances.max() <= params.distance_hi
    assert population.n == n_agents
def test_population_moments_approx() -> None:
    """Sample moments of a large population land near the configured parameters.

    Tolerances are loose because the draws are random and (for distance)
    truncated to a window.
    """
    params = PopulationParams()
    population = synthesize_population(20_000, np.random.default_rng(0), params)

    # Truncating to [distance_lo, distance_hi] shifts the sample mean away from
    # distance_mu, so only approximate agreement is required.
    distance_gap = abs(population.distance.mean() - params.distance_mu)
    assert distance_gap < 2.0

    # NOTE(review): risk is checked against a mean of 0, not a params field —
    # presumably the generator draws it zero-centred; confirm.
    risk_values = population.risk
    assert abs(risk_values.mean()) < 0.05
    assert abs(risk_values.std() - params.risk_sigma) < 0.1

    # Presumably vehicle is a 0/1 indicator with success rate vehicle_p — confirm.
    assert abs(population.vehicle.mean() - params.vehicle_p) < 0.02

    destinations = population.destination
    assert destinations.min() >= params.dest_lo
    assert destinations.max() <= params.dest_hi
# (distance, expected zone code) pairs probing each side of the zone cut
# points under the default PopulationParams.
_ZONE_CASES = [
    (0.5, 0),
    (4.99, 0),
    (5.0, 1),
    (10.0, 1),
    (19.99, 1),
    (20.0, 2),
    (45.0, 2),
]


@pytest.mark.parametrize(("distance", "expected_zone"), _ZONE_CASES)
def test_zone_derivation(distance: float, expected_zone: int) -> None:
    """A single distance maps to the expected zone code."""
    single = np.asarray([distance])
    zone = zone_codes(single, PopulationParams())[0]
    assert int(zone) == expected_zone
def test_population_reproducibility() -> None:
    """Two draws with the same seed produce element-wise identical populations."""

    def _draw():
        # Fresh generator and fresh params each time, both seeded identically.
        return synthesize_population(500, np.random.default_rng(42), PopulationParams())

    first, second = _draw(), _draw()
    for field in ("distance", "vehicle", "risk", "zone", "destination"):
        np.testing.assert_array_equal(getattr(first, field), getattr(second, field))
def test_population_rejection_sampling_fills_completely() -> None:
    """Rejection sampling still fills every slot when the bounds are tight.

    The acceptance window [14, 16] is narrow compared with distance_sigma=20,
    so (assuming a normal proposal centred on distance_mu) most raw draws are
    rejected and several resampling rounds are needed. The result must still
    contain exactly ``n_agents`` in-bounds distances.
    """
    params = PopulationParams(
        distance_mu=15, distance_sigma=20, distance_lo=14, distance_hi=16
    )
    n_agents = 300
    pop = synthesize_population(n_agents, np.random.default_rng(7), params)

    assert pop.n == n_agents
    assert pop.distance.shape == (n_agents,)
    # Bounds come from params (not re-typed literals) so they cannot drift apart.
    in_bounds = (pop.distance >= params.distance_lo) & (
        pop.distance <= params.distance_hi
    )
    assert in_bounds.all()