# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Test-only clean DGP that matches the IO-HMM exactly.
The production DGP in ``iohmm_evac.dgp`` has endogenous feedback (``c``,
``pi``, ``tir``) and an ``evac_path`` bookkeeping flag that the IO-HMM
deliberately does not model. To get clean recovery tests, we need a DGP
where the IO-HMM is the *correct* model — that's the role of this module.
It exposes :func:`generate` which, given an initial distribution, transition
parameters, emission parameters, and a fixed input sequence ``u``, samples a
``(states, departure, displacement, comm)`` panel by exactly the IO-HMM
generative process.
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from numpy.random import Generator
from scipy.special import softmax
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.types import FloatArray, IntArray
# Public API of this test helper: the sample container and the sampler.
# ``_categorical`` is a private helper and intentionally not exported.
__all__ = ["CleanDGPSample", "generate"]
@dataclass(frozen=True, slots=True)
class CleanDGPSample:
    """The output of one :func:`generate` call.

    Every panel array is indexed ``[household, time]``, with ``N`` households
    and ``T+1`` time steps. ``inputs`` is the exogenous input panel that was
    passed to :func:`generate`, stored as-is (it is not copied).
    """
    states: IntArray # (N, T+1) hidden-state indices in [0, K); int64 as produced by generate()
    departure: FloatArray # (N, T+1) Bernoulli draws as 0/1 floats
    displacement: FloatArray # (N, T+1) Gaussian draws
    comm: FloatArray # (N, T+1) Poisson draws as floats
    inputs: FloatArray # (N, T+1, F) the same array object passed to generate()
def generate(params: FitParameters, inputs: FloatArray, rng: Generator) -> CleanDGPSample:
    """Draw one sample from the IO-HMM generative process.

    ``inputs`` has shape ``(N, T+1, F)`` and is treated as exogenous; the
    transition at step ``t`` uses ``inputs[:, t, :]`` (so step 0 is unused
    by transitions but kept for symmetry with the production DGP).

    Random numbers are consumed from ``rng`` in a fixed order:

    1. the initial states;
    2. one categorical draw per household per transition step;
    3. the departure uniforms;
    4. the displacement normals;
    5. the comm Poisson counts.

    A given seed therefore reproduces the same panel.

    Args:
        params: IO-HMM parameters. Reads ``initial.probs()`` (length ``K``),
            ``transitions.alpha`` (``(K, K)`` intercepts, origin row ->
            destination column), ``transitions.beta`` (``(K, K, F)`` input
            weights) and the per-state emission arrays
            ``emissions.p_departure``, ``mu_displacement``,
            ``sigma_displacement`` and ``lambda_comm``.
            Non-finite entries of ``alpha`` mark forbidden transitions.
        inputs: Exogenous input panel of shape ``(N, T+1, F)``.
        rng: Source of randomness.

    Returns:
        A :class:`CleanDGPSample` whose panels all have shape ``(N, T+1)``.

    Raises:
        ValueError: If ``inputs`` is not 3-D, or has no time steps.
    """
    n, t_plus_1, _f = inputs.shape
    if t_plus_1 < 1:
        raise ValueError("inputs must contain at least one time step (T+1 >= 1)")

    states = np.zeros((n, t_plus_1), dtype=np.int64)
    init_probs = params.initial.probs()
    k_states = int(init_probs.shape[0])
    states[:, 0] = rng.choice(k_states, size=n, p=init_probs)

    alpha = params.transitions.alpha  # (K, K)
    beta = params.transitions.beta  # (K, K, F)
    # Non-finite intercepts mark forbidden transitions. The mask is
    # loop-invariant, so compute it (and whether it is needed) once.
    forbidden = ~np.isfinite(alpha)
    has_forbidden = bool(forbidden.any())

    for t in range(1, t_plus_1):
        prev = states[:, t - 1]  # (N,) origin state of each household
        u_t = inputs[:, t, :]  # (N, F)
        # Gather each household's origin-state row *before* contracting over
        # features: this computes only the (N, K) logits actually used,
        # instead of the full (N, K, K) tensor.
        logits = alpha[prev] + np.einsum("nf,njf->nj", u_t, beta[prev])
        if has_forbidden:
            # Use a large finite negative value rather than -inf so that
            # exp() underflows to exactly 0 without risking inf - inf = NaN
            # in softmax's max-shift.
            logits = np.where(forbidden[prev], -1e30, logits)
        probs = softmax(logits, axis=1)
        states[:, t] = _categorical(probs, rng)

    # Emissions are conditionally independent given the state; each per-state
    # parameter array is indexed by the (N, T+1) state panel.
    p = params.emissions.p_departure
    mu = params.emissions.mu_displacement
    sigma = params.emissions.sigma_displacement
    lam = params.emissions.lambda_comm
    departure = (rng.random(size=(n, t_plus_1)) < p[states]).astype(np.float64)
    displacement = rng.normal(mu[states], sigma[states])
    comm = rng.poisson(lam[states]).astype(np.float64)
    return CleanDGPSample(
        states=states,
        departure=np.asarray(departure, dtype=np.float64),
        displacement=np.asarray(displacement, dtype=np.float64),
        comm=np.asarray(comm, dtype=np.float64),
        inputs=inputs,
    )
def _categorical(probs: FloatArray, rng: Generator) -> IntArray:
"""One categorical draw per row of ``probs``."""
cum = np.cumsum(probs, axis=1)
cum[:, -1] = 1.0
draws = np.asarray(rng.random(size=probs.shape[0]), dtype=np.float64)
u = draws[:, None]
return np.asarray(np.argmax(u < cum, axis=1), dtype=np.int64)