# src/iohmm_evac/bootstrap/resample.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Household-level bootstrap resampling.

The resample is just an index vector; downstream code consumes
``(observations[idx], inputs[idx], population[idx])``. We never duplicate
the underlying parquet rows.
"""

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass

import numpy as np

from iohmm_evac.inference.data import FitData
from iohmm_evac.types import IntArray

__all__ = [
    "ResampledFitData",
    "index_fit_data",
    "resample_indices",
]


@dataclass(frozen=True, slots=True)
class ResampledFitData:
    """A bootstrap-resampled :class:`FitData` plus the index vector that produced it."""

    indices: IntArray
    """The household indices used for this resample, shape (N,)."""
    data: FitData
    """The :class:`FitData` produced by indexing into the source bundle."""


def resample_indices(n: int, n_resamples: int, seed: int) -> Iterator[IntArray]:
    """Return an iterator over ``n_resamples`` index vectors drawn with replacement.

    Each vector has length ``n`` and holds ``int64`` values in ``[0, n)``.
    Vectors are produced lazily, one per step of the iterator, from a single
    random generator seeded with ``seed``.

    Reproducible: ``seed=k`` always produces the same sequence of resamples,
    independent of ``n_resamples`` (a longer run extends a shorter one).

    Arguments are validated when this function is called, not when the
    returned iterator is first advanced, so a bad argument fails at the call
    site. Raises :class:`ValueError` if ``n`` is not positive or if
    ``n_resamples`` is negative.
    """
    if n <= 0:
        msg = f"n must be a positive integer, got {n}"
        raise ValueError(msg)
    if n_resamples < 0:
        msg = f"n_resamples must be non-negative, got {n_resamples}"
        raise ValueError(msg)
    rng = np.random.default_rng(seed)
    # Returning a generator expression (instead of using ``yield`` here) keeps
    # the draws lazy while letting the checks above run at call time.
    return (rng.integers(0, n, size=n, dtype=np.int64) for _ in range(n_resamples))


def index_fit_data(data: FitData, indices: IntArray) -> FitData:
    """Build the :class:`FitData` obtained by selecting households ``indices`` from ``data``.

    ``indices`` must be a 1-D integer vector of length ``M`` whose entries lie
    in ``[0, data.n)``; entries may repeat. The returned :class:`FitData` has
    ``M`` households and the same time axis. Every selected array is returned
    as a C-contiguous array, and ``true_states`` stays ``None`` when the source
    has none.

    Raises :class:`ValueError` if ``indices`` is not 1-D or if any entry falls
    outside ``[0, data.n)``.
    """
    if indices.ndim != 1:
        msg = f"indices must be 1-D, got shape {indices.shape}"
        raise ValueError(msg)
    # An empty index vector has no min/max, so bounds are only checked when
    # there is at least one entry.
    if indices.size:
        lowest, highest = int(indices.min()), int(indices.max())
        if lowest < 0 or highest >= data.n:
            msg = f"indices out of bounds: min={lowest}, max={highest}, n={data.n}"
            raise ValueError(msg)
    states = data.true_states
    return FitData(
        inputs=np.ascontiguousarray(data.inputs[indices]),
        departure=np.ascontiguousarray(data.departure[indices]),
        displacement=np.ascontiguousarray(data.displacement[indices]),
        comm=np.ascontiguousarray(data.comm[indices]),
        true_states=None if states is None else np.ascontiguousarray(states[indices]),
    )