# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Bootstrap runner tests: parallel EM fits + warm-start sanity check."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
from iohmm_evac.bootstrap.runner import (
BootstrapFit,
load_bootstrap_fits,
run_bootstrap_fits,
)
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.fit_params import dgp_truth_to_fit_init
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.loader import load_bundle
@pytest.fixture(scope="module")
def small_data(tmp_path_factory: pytest.TempPathFactory):  # type: ignore[no-untyped-def]
    """Simulate one small evacuation dataset per module.

    Returns a ``(fit_data, truth_init)`` pair: the bundled observations ready
    for EM fitting, plus the DGP-truth parameters converted to a fit
    initialization (used by the warm-start test).
    """
    work_dir = tmp_path_factory.mktemp("runner")
    sim_config = SimulationConfig(n_households=200, n_hours=24, seed=0)
    sim_result = simulate(sim_config, np.random.default_rng(sim_config.seed))
    obs_path = work_dir / "obs.parquet"
    write_results(sim_result, obs_path)
    fit_data = bundle_to_fit_data(load_bundle(obs_path))
    truth_init = dgp_truth_to_fit_init(
        sim_config.transitions, sim_config.emissions, sim_config.population
    )
    return fit_data, truth_init
def test_run_bootstrap_fits_writes_replicate_dirs(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
    """Each replicate gets its own directory with all expected artifact files."""
    data, _ = small_data
    fits_dir = tmp_path / "fits"
    fits = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=EMConfig(max_iter=3, tol=1e-3, verbose=False),
        base_seed=0,
        n_jobs=2,
        output_dir=fits_dir,
    )
    assert len(fits) == 2
    for fit in fits:
        # Replicate directories are zero-padded to three digits.
        rep_dir = fits_dir / f"replicate_{fit.replicate_id:03d}"
        for name in ("theta.toml", "metadata.toml", "indices.parquet"):
            assert (rep_dir / name).exists(), f"missing {name} in {rep_dir}"
def test_log_likelihoods_are_finite(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
    """Every replicate fit ends with a finite log-likelihood and at least one EM step."""
    data, _ = small_data
    em_config = EMConfig(max_iter=3, tol=1e-3, verbose=False)
    results = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=em_config,
        base_seed=0,
        n_jobs=1,
        output_dir=tmp_path / "fits",
    )
    for result in results:
        assert np.isfinite(result.final_log_likelihood)
        assert result.iterations >= 1
def test_load_bootstrap_fits_round_trips(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
    """Fits written to disk reload with identical ids, iteration counts, indices, and betas."""
    data, _ = small_data
    fits_dir = tmp_path / "fits"
    written = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=EMConfig(max_iter=2, tol=1e-3, verbose=False),
        base_seed=0,
        n_jobs=1,
        output_dir=fits_dir,
    )
    restored = load_bootstrap_fits(fits_dir)
    assert len(restored) == len(written)
    for saved, loaded in zip(written, restored, strict=True):
        assert saved.replicate_id == loaded.replicate_id
        assert saved.iterations == loaded.iterations
        np.testing.assert_array_equal(saved.indices, loaded.indices)
        # Betas round-trip through TOML serialization, so compare with tolerance.
        np.testing.assert_allclose(
            saved.params.transitions.beta, loaded.params.transitions.beta
        )
def test_warm_start_reduces_iteration_count(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
    """Warm-starting EM at the DGP truth needs no more iterations than a cold start.

    Both runs use the same data, seeds, and EM config, differing only in the
    initial parameters, so the warm start should converge at least as fast on
    average. The assertion uses ``<=`` (not ``<``) so the test stays stable
    when both starts happen to take the same number of iterations.
    """
    data, truth = small_data
    em_config = EMConfig(max_iter=20, tol=1e-4, verbose=False)
    cold = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=em_config,
        base_seed=0,
        n_jobs=1,
        output_dir=tmp_path / "cold",
    )
    warm = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=em_config,
        base_seed=0,
        n_jobs=1,
        output_dir=tmp_path / "warm",
        warm_start_theta=truth,
    )
    cold_avg = float(np.mean([f.iterations for f in cold]))
    warm_avg = float(np.mean([f.iterations for f in warm]))
    assert warm_avg <= cold_avg, (
        f"warm avg iters ({warm_avg}) should be <= cold avg iters ({cold_avg})"
    )
def test_n_replicates_must_be_positive(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
    """Requesting zero replicates raises a ValueError mentioning the valid range."""
    data, _ = small_data
    em_config = EMConfig(max_iter=1, tol=1e-3, verbose=False)
    with pytest.raises(ValueError, match=">= 1"):
        run_bootstrap_fits(
            data=data,
            n_replicates=0,
            em_config=em_config,
            base_seed=0,
            n_jobs=1,
            output_dir=tmp_path / "empty",
        )
def test_load_bootstrap_fits_missing_dir(tmp_path: Path) -> None:
    """Loading from a path that does not exist raises FileNotFoundError."""
    missing_dir = tmp_path / "no-such"
    with pytest.raises(FileNotFoundError):
        load_bootstrap_fits(missing_dir)
def test_load_bootstrap_fits_empty_dir(tmp_path: Path) -> None:
    """A directory that exists but holds no replicate subdirs raises a descriptive error."""
    empty_dir = tmp_path / "empty_fits"
    empty_dir.mkdir()
    with pytest.raises(FileNotFoundError, match="No replicate"):
        load_bootstrap_fits(empty_dir)
def test_bootstrap_fit_dataclass_is_frozen() -> None:
    """Assigning to a BootstrapFit field raises: the dataclass is immutable.

    dataclasses.FrozenInstanceError subclasses AttributeError, so expecting
    AttributeError covers the frozen-dataclass case.
    """
    fit = BootstrapFit(
        replicate_id=0,
        params=None,  # type: ignore[arg-type]
        final_log_likelihood=-1.0,
        iterations=1,
        converged=False,
        indices=np.zeros(3, dtype=np.int64),
    )
    with pytest.raises(AttributeError):
        fit.replicate_id = 99  # type: ignore[misc]