# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Log-space forward-backward recursions for the IO-HMM.
All quantities are kept in log-space; ``logsumexp`` is the only place where
exponentials are taken explicitly.
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.inference.log_space import LOG_EPS, logsumexp, safe_log
from iohmm_evac.types import FloatArray
# Public API of this module: the result container plus the emission,
# transition and full forward-backward entry points.
__all__ = [
    "ForwardBackwardResult",
    "emission_log_prob",
    "forward_backward",
    "log_transition_matrix",
]
@dataclass(frozen=True, slots=True)
class ForwardBackwardResult:
"""Output of a single forward-backward pass.
All log-arrays are well-defined (no ``NaN``); forbidden cells in
``log_xi`` are at :data:`~iohmm_evac.inference.log_space.LOG_EPS`.
"""
log_gamma: FloatArray
"""Posterior over states, shape ``(N, T+1, K)``."""
log_xi: FloatArray
"""Pairwise posterior, shape ``(N, T, K, K)`` for transitions ``t -> t+1``."""
log_alpha: FloatArray
"""Forward messages, shape ``(N, T+1, K)``."""
log_beta: FloatArray
"""Backward messages, shape ``(N, T+1, K)``."""
log_likelihood: FloatArray
"""Per-household log-likelihood, shape ``(N,)``."""
def log_transition_matrix(
    inputs: FloatArray,
    transitions_alpha: FloatArray,
    transitions_beta: FloatArray,
) -> FloatArray:
    """Return ``log A_{kj}(u_{i,t})`` for every ``(i, t, k, j)``.

    The logits are ``alpha[k, j] + sum_f u[i, t, f] * beta[k, j, f]`` and each
    row over the destination ``j`` is log-softmax normalized.

    Destinations whose ``transitions_alpha`` entry is non-finite are treated
    as forbidden: their logit is replaced by the finite stand-in
    :data:`~iohmm_evac.inference.log_space.LOG_EPS` *before* normalization.
    After normalization they therefore sit at ``LOG_EPS`` minus the row's
    log-normalizer, which is close to, but not exactly, ``LOG_EPS``.

    Args:
        inputs: Covariates of shape ``(N, T+1, F)``.
        transitions_alpha: Intercepts of shape ``(K, K)``.
        transitions_beta: Input weights of shape ``(K, K, F)``.

    Returns:
        Array of shape ``(N, T+1, K, K)``. Only slots ``t = 1..T`` (the
        transition ``t-1 -> t``) are consumed by the recursions; slot
        ``t = 0`` is materialized but never read.
    """
    # Input-driven component of the logits: sum_f u[n, t, f] * beta[k, j, f].
    driven = np.einsum("ntf,kjf->ntkj", inputs, transitions_beta)
    scores = transitions_alpha[np.newaxis, np.newaxis, :, :] + driven

    blocked = np.logical_not(np.isfinite(transitions_alpha))
    if np.any(blocked):
        # Keep forbidden cells finite so the max-shift below cannot hit inf - inf.
        scores = np.where(blocked[np.newaxis, np.newaxis, :, :], LOG_EPS, scores)

    # Row-wise log-softmax via a max-shifted log-sum-exp.
    row_max = np.max(scores, axis=-1, keepdims=True)
    shifted_total = np.exp(scores - row_max).sum(axis=-1, keepdims=True)
    log_norm = row_max + np.log(shifted_total)
    return np.asarray(scores - log_norm, dtype=np.float64)
def _gaussian_log_pdf(x: FloatArray, mu: float, sigma: float) -> FloatArray:
"""Vectorized univariate Gaussian log-pdf with a finite ``sigma`` floor."""
var = max(sigma * sigma, 1e-12)
diff = x - mu
out = -0.5 * (np.log(2.0 * np.pi * var) + diff * diff / var)
return np.asarray(out, dtype=np.float64)
def emission_log_prob(data: FitData, params: FitParameters) -> FloatArray:
    """Compute ``log b_k(y_{i,t})`` for every ``(i, t, k)``.

    Channels are assumed conditionally independent given the state, so the
    log-density is the sum of three per-channel terms:

    * departure: Bernoulli with probability ``p_departure[k]``, clipped to
      ``[1e-12, 1 - 1e-12]`` so both logs stay finite;
    * displacement: Gaussian with mean ``mu_displacement[k]`` and standard
      deviation ``sigma_displacement[k]``;
    * comm: Poisson with rate ``lambda_comm[k]``, floored at ``1e-9``.

    Args:
        data: Observations. ``departure`` fixes the ``(N, T+1)`` shape;
            ``displacement`` and ``comm`` must broadcast to it.
        params: Model parameters; only ``params.emissions`` is read.

    Returns:
        Float64 array of shape ``(N, T+1, K)`` with ``K = len(p_departure)``.
    """
    # Lazy import: scipy is only loaded when emissions are actually evaluated.
    from scipy.special import gammaln

    emissions = params.emissions
    p = emissions.p_departure
    mu = emissions.mu_displacement
    sigma = emissions.sigma_displacement
    lam = emissions.lambda_comm

    departure = data.departure
    comm = data.comm
    n, t_plus_1 = departure.shape
    k_states = int(p.shape[0])

    # log(c!) does not depend on the state: compute it once instead of K times.
    log_comm_factorial = np.asarray(gammaln(comm + 1.0), dtype=np.float64)

    log_b = np.zeros((n, t_plus_1, k_states), dtype=np.float64)
    for k in range(k_states):
        log_p = float(np.log(np.clip(p[k], 1e-12, 1 - 1e-12)))
        log_1mp = float(np.log(np.clip(1.0 - p[k], 1e-12, 1 - 1e-12)))
        bern = departure * log_p + (1.0 - departure) * log_1mp
        gauss = _gaussian_log_pdf(data.displacement, float(mu[k]), float(sigma[k]))
        lam_k = float(max(lam[k], 1e-9))
        pois = comm * np.log(lam_k) - lam_k - log_comm_factorial
        log_b[:, :, k] = bern + gauss + pois
    return log_b
def _logsumexp_axis(x: FloatArray, axis: int) -> FloatArray:
"""Hand-rolled log-sum-exp along one axis; faster than scipy in tight loops."""
m = np.max(x, axis=axis, keepdims=True)
out = np.log(np.exp(x - m).sum(axis=axis)) + np.squeeze(m, axis=axis)
return np.asarray(out, dtype=np.float64)
def _forward_pass(
    log_b: FloatArray,
    log_initial: FloatArray,
    log_a: FloatArray,
) -> tuple[FloatArray, FloatArray]:
    """Run the forward recursion in log space.

    Args:
        log_b: Emission log-densities, shape ``(N, T+1, K)``.
        log_initial: Log initial-state distribution, shape ``(K,)``.
        log_a: Log transition tensor, shape ``(N, T+1, K, K)``. Slot ``t``
            holds the transition ``t-1 -> t`` indexed ``[source, dest]``.

    Returns:
        ``(log_alpha, log_likelihood)``. ``log_alpha`` has shape
        ``(N, T+1, K)``; ``log_likelihood`` has shape ``(N,)`` and is the
        log-sum-exp of the final forward message.
    """
    n_units, n_steps, n_states = log_b.shape
    log_alpha = np.empty((n_units, n_steps, n_states), dtype=np.float64)
    log_alpha[:, 0, :] = log_b[:, 0, :] + log_initial[np.newaxis, :]

    for step in range(1, n_steps):
        # candidate[i, j, k] = log_alpha[i, step-1, j] + log A_{jk}(u_{i,step});
        # marginalize the source state j (axis 1), then add the emission term.
        candidate = log_alpha[:, step - 1, :][:, :, np.newaxis] + log_a[:, step]
        log_alpha[:, step, :] = _logsumexp_axis(candidate, axis=1) + log_b[:, step, :]

    log_likelihood = logsumexp(log_alpha[:, -1, :], axis=1)
    return log_alpha, np.asarray(log_likelihood, dtype=np.float64)
def _backward_pass(log_b: FloatArray, log_a: FloatArray) -> FloatArray:
    """Run the backward recursion in log space.

    ``log_beta[:, T, :]`` is initialized to ``0`` (log 1). Earlier slots are
    filled right-to-left with
    ``log_beta[t, k] = logsumexp_j(log_a[t+1, k, j] + log_b[t+1, j] + log_beta[t+1, j])``.

    Args:
        log_b: Emission log-densities, shape ``(N, T+1, K)``.
        log_a: Log transition tensor, shape ``(N, T+1, K, K)``. Slot ``t``
            holds the transition ``t-1 -> t``.

    Returns:
        Backward messages of shape ``(N, T+1, K)``.
    """
    n_units, n_steps, n_states = log_b.shape
    log_beta = np.zeros((n_units, n_steps, n_states), dtype=np.float64)
    for step in reversed(range(n_steps - 1)):
        # Everything downstream of step+1: its emission plus its backward message.
        downstream = log_b[:, step + 1, :] + log_beta[:, step + 1, :]
        # Sum over the destination state j (last axis) of A_{kj}(u_{step+1}).
        log_beta[:, step, :] = _logsumexp_axis(
            log_a[:, step + 1, :, :] + downstream[:, np.newaxis, :], axis=2
        )
    return log_beta
def forward_backward(params: FitParameters, data: FitData) -> ForwardBackwardResult:
    """Run forward-backward and return posteriors plus log-likelihood.

    Steps:
    1. Build the input-dependent log transition tensor and the emission
       log-densities.
    2. Run the forward and backward recursions.
    3. Combine the messages into state posteriors (``log_gamma``) and pairwise
       posteriors (``log_xi``). Both are normalized by each household's
       log-likelihood.

    Args:
        params: Model parameters (transitions, emissions, initial distribution).
        data: Observations and input covariates.

    Returns:
        A :class:`ForwardBackwardResult` with all messages and posteriors.
    """
    log_trans = log_transition_matrix(
        data.inputs, params.transitions.alpha, params.transitions.beta
    )
    log_emit = emission_log_prob(data, params)
    log_start = safe_log(params.initial.probs())

    log_alpha, log_likelihood = _forward_pass(log_emit, log_start, log_trans)
    log_beta = _backward_pass(log_emit, log_trans)

    # Per-household normalizer, broadcast over the (T+1, K) axes.
    log_norm = log_likelihood[:, np.newaxis, np.newaxis]
    log_gamma = log_alpha + log_beta - log_norm

    # xi[i, t, k, j] pairs alpha at t with transition/emission/beta at t+1.
    # When T == 0 every slice is empty and this is an empty (N, 0, K, K) array.
    log_xi = (
        log_alpha[:, :-1, :, np.newaxis]
        + log_trans[:, 1:, :, :]
        + log_emit[:, 1:, np.newaxis, :]
        + log_beta[:, 1:, np.newaxis, :]
        - log_norm[..., np.newaxis]
    )

    return ForwardBackwardResult(
        log_gamma=np.asarray(log_gamma, dtype=np.float64),
        log_xi=np.asarray(log_xi, dtype=np.float64),
        log_alpha=log_alpha,
        log_beta=log_beta,
        log_likelihood=log_likelihood,
    )