# tests/_clean_dgp.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Test-only clean DGP that matches the IO-HMM exactly.

The production DGP in ``iohmm_evac.dgp`` has endogenous feedback (``c``,
``pi``, ``tir``) and an ``evac_path`` bookkeeping flag that the IO-HMM
deliberately does not model. To get clean recovery tests, we need a DGP
where the IO-HMM is the *correct* model — that's the role of this module.

It exposes :func:`generate` which, given a ``FitParameters`` bundle (initial
distribution, transition parameters, emission parameters), a fixed input
array ``u`` and a random generator, samples a
``(states, departure, displacement, comm)`` panel by exactly the IO-HMM
generative process.
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.random import Generator
from scipy.special import softmax

from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.types import FloatArray, IntArray

__all__ = ["CleanDGPSample", "generate"]


@dataclass(frozen=True, slots=True)
class CleanDGPSample:
    """The output of one :func:`generate` call.

    Immutable container bundling the sampled hidden state path, the three
    emission channels drawn conditional on that path, and the input array the
    sample was generated from. ``N`` is the number of households, ``T+1`` the
    number of time indices and ``F`` the number of input features.
    """

    # Hidden state index per household and time index (int64 in `generate`).
    states: IntArray  # (N, T+1)
    # Departure indicator: 1.0 where a uniform draw fell below the state's
    # departure probability, else 0.0.
    departure: FloatArray  # (N, T+1) Bernoulli draws as 0/1 floats
    # Displacement drawn from a normal with the state's mean and std. dev.
    displacement: FloatArray  # (N, T+1) Gaussian draws
    # Communication counts drawn from a Poisson with the state's rate,
    # stored as floats.
    comm: FloatArray  # (N, T+1) Poisson draws as floats
    # The exogenous input array handed to `generate`, passed through as-is.
    inputs: FloatArray  # (N, T+1, F)


def generate(params: FitParameters, inputs: FloatArray, rng: Generator) -> CleanDGPSample:
    """Sample one panel by running the IO-HMM forward in time.

    The hidden path is drawn first: time 0 from ``params.initial.probs()``,
    then every later time ``t`` from a softmax over the logits
    ``alpha[k, j] + beta[k, j, :] . inputs[i, t, :]``, where ``k`` is
    household ``i``'s state at ``t - 1`` and ``j`` ranges over destination
    states. Destinations whose ``alpha`` entry is not finite are treated as
    forbidden: their logit is replaced by ``-1e30`` before the softmax.
    Emissions are drawn afterwards, each depending only on the state at the
    same time index.

    Args:
        params: Parameter bundle. Reads ``initial.probs()``,
            ``transitions.alpha`` (indexed ``[k, j]``), ``transitions.beta``
            (indexed ``[k, j, f]``) and the per-state emission arrays
            ``p_departure``, ``mu_displacement``, ``sigma_displacement`` and
            ``lambda_comm``.
        inputs: Exogenous inputs of shape ``(N, T+1, F)``. The transition
            into time ``t`` uses ``inputs[:, t, :]``, so ``inputs[:, 0, :]``
            is never read by the transition loop (it is kept for symmetry
            with the production DGP).
        rng: Generator supplying every random draw.

    Returns:
        A :class:`CleanDGPSample` holding the state path, the three emission
        arrays (each of shape ``(N, T+1)``) and ``inputs`` itself.
    """
    n_households, n_times, _n_features = inputs.shape

    # Initial states.
    start_probs = params.initial.probs()
    n_states = int(start_probs.shape[0])
    path = np.zeros((n_households, n_times), dtype=np.int64)
    path[:, 0] = rng.choice(n_states, size=n_households, p=start_probs)

    intercepts = params.transitions.alpha
    slopes = params.transitions.beta
    blocked = ~np.isfinite(intercepts)
    # Both of these are the same at every step, so compute them once.
    any_blocked = bool(blocked.any())
    household_idx = np.arange(n_households)

    for step in range(1, n_times):
        covariates = inputs[:, step, :]  # (N, F)
        origin = path[:, step - 1]
        # all_logits[i, k, j]: logit of k -> j under household i's inputs.
        all_logits = intercepts[np.newaxis, :, :] + np.einsum(
            "nf,kjf->nkj", covariates, slopes
        )
        # Keep only the row for the state each household is leaving.
        row_logits = all_logits[household_idx, origin]
        if any_blocked:
            row_logits = np.where(blocked[origin], -1e30, row_logits)
        path[:, step] = _categorical(softmax(row_logits, axis=1), rng)

    emissions = params.emissions
    depart_prob = emissions.p_departure
    disp_mean = emissions.mu_displacement
    disp_sd = emissions.sigma_displacement
    comm_rate = emissions.lambda_comm

    # Draw order (uniform, normal, poisson) fixes the stream consumed from rng.
    uniforms = rng.random(size=(n_households, n_times))
    departed = (uniforms < depart_prob[path]).astype(np.float64)
    moved = rng.normal(disp_mean[path], disp_sd[path])
    messages = rng.poisson(comm_rate[path]).astype(np.float64)

    return CleanDGPSample(
        states=path,
        departure=np.asarray(departed, dtype=np.float64),
        displacement=np.asarray(moved, dtype=np.float64),
        comm=np.asarray(messages, dtype=np.float64),
        inputs=inputs,
    )


def _categorical(probs: FloatArray, rng: Generator) -> IntArray:
    """One categorical draw per row of ``probs``."""
    cum = np.cumsum(probs, axis=1)
    cum[:, -1] = 1.0
    draws = np.asarray(rng.random(size=probs.shape[0]), dtype=np.float64)
    u = draws[:, None]
    return np.asarray(np.argmax(u < cum, axis=1), dtype=np.int64)