# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Numerical-gradient and closed-form sanity checks for the M-step."""
from __future__ import annotations
import numpy as np
import pytest
from scipy.optimize import check_grad
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
EmissionFitParams,
FitParameters,
InitialFitParams,
TransitionFitParams,
)
from iohmm_evac.inference.forward_backward import forward_backward
from iohmm_evac.inference.m_step import (
transition_neg_q_and_grad,
update_emissions,
update_initial,
)
def _random_problem(seed: int = 0) -> tuple[FitParameters, FitData]:
    """Build a small synthetic K=3, T=10, N=20 problem for gradient/M-step checks."""
    rng = np.random.default_rng(seed)
    n_states = 3
    n_features = 4
    n_agents = 20
    horizon = 11  # T + 1 observation slots per agent
    covariates = rng.normal(size=(n_agents, horizon, n_features)) * 0.5
    alpha = np.array(
        [
            [0.0, -1.0, -2.0],
            [-1.5, 0.0, -1.0],
            [-2.0, -2.0, 0.0],
        ],
        dtype=np.float64,
    )
    beta = rng.normal(size=(n_states, n_states, n_features)) * 0.2
    # Zero out the self-transition entries so they act as fixed reference values.
    np.fill_diagonal(alpha, 0.0)
    diag = np.arange(n_states)
    beta[diag, diag] = 0.0
    emissions = EmissionFitParams(
        p_departure=np.array([0.1, 0.5, 0.9]),
        mu_displacement=np.array([0.0, 1.0, 4.0]),
        sigma_displacement=np.array([1.0, 1.0, 1.0]),
        lambda_comm=np.array([0.5, 1.5, 3.0]),
    )
    params = FitParameters(
        initial=InitialFitParams(logits=np.array([0.0, -0.5, -1.0])),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emissions,
    )
    departure = (rng.random((n_agents, horizon)) < 0.4).astype(np.float64)
    displacement = rng.normal(size=(n_agents, horizon)) * 0.5 + 1.0
    comm = rng.poisson(1.5, size=(n_agents, horizon)).astype(np.float64)
    data = FitData(
        inputs=covariates,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=None,
    )
    return params, data
def test_transition_gradient_numerical_check() -> None:
    """Check transition_neg_q_and_grad's analytic gradient against finite differences.

    For each origin state, the posterior xi/gamma arrays are flattened over all
    (agent, step) pairs and the returned gradient is compared to scipy's
    ``check_grad`` finite-difference estimate at a random starting point.
    """
    params, data = _random_problem(seed=42)
    fb = forward_backward(params, data)
    # Posteriors come back in log space; the M-step objective works on probabilities.
    xi = np.exp(fb.log_xi)
    gamma = np.exp(fb.log_gamma)
    n, t_total, k, _ = xi.shape
    f = data.inputs.shape[2]
    # Inputs at steps 1..T pair with the state held at the previous step.
    inputs_steps = data.inputs[:, 1:, :].reshape(n * t_total, f)
    xi_flat = xi.reshape(n * t_total, k, k)
    gamma_flat = gamma[:, :-1, :].reshape(n * t_total, k)
    rng = np.random.default_rng(0)
    for origin_k in range(k):
        # Only off-diagonal destinations (j != origin_k) are treated as learnable here.
        learnable_j = np.array([j != origin_k for j in range(k)], dtype=bool)
        n_dest = int(learnable_j.sum())
        # Parameter vector: one alpha per learnable destination plus an F-vector beta each.
        x0 = rng.normal(size=n_dest + n_dest * f) * 0.5
        # Default arguments pin origin_k/learnable_j per iteration (avoids the
        # late-binding-closure pitfall of defining functions inside a loop).
        def fun(x: np.ndarray, ok: int = origin_k, lj: np.ndarray = learnable_j) -> float:
            val, _grad = transition_neg_q_and_grad(
                x,
                origin_k=ok,
                learnable_j=lj,
                xi_origin=xi_flat[:, ok, :],
                gamma_origin=gamma_flat[:, ok],
                inputs_steps=inputs_steps,
            )
            return val
        def grad(x: np.ndarray, ok: int = origin_k, lj: np.ndarray = learnable_j) -> np.ndarray:
            _val, g = transition_neg_q_and_grad(
                x,
                origin_k=ok,
                learnable_j=lj,
                xi_origin=xi_flat[:, ok, :],
                gamma_origin=gamma_flat[:, ok],
                inputs_steps=inputs_steps,
            )
            return g
        err = check_grad(fun, grad, x0)
        # Tolerance: 1e-4 on a tiny problem — check_grad uses forward differences,
        # so its own error floor is roughly sqrt(machine eps).
        assert err < 1e-4, f"origin {origin_k}: numerical-gradient error {err:.3e}"
def test_emission_closed_form_matches_brute_force() -> None:
    """With known posteriors, weighted-MLE emission updates match a hand-rolled MLE."""
    params, data = _random_problem(seed=1)
    fb = forward_backward(params, data)
    new_emit = update_emissions(fb.log_gamma, data, sigma_floor=1e-6)
    post = np.exp(fb.log_gamma)
    denom = post.sum(axis=(0, 1))  # total posterior mass per state, shape (K,)

    def weighted_mean(obs: np.ndarray) -> np.ndarray:
        """Posterior-weighted mean of a per-(agent, step) observation, per state."""
        return (post * obs[:, :, None]).sum(axis=(0, 1)) / denom

    expected_p = weighted_mean(data.departure)
    expected_mu = weighted_mean(data.displacement)
    expected_lam = weighted_mean(data.comm)
    assert new_emit.p_departure == pytest.approx(np.clip(expected_p, 1e-6, 1 - 1e-6), abs=1e-9)
    assert new_emit.mu_displacement == pytest.approx(expected_mu, abs=1e-9)
    assert new_emit.lambda_comm == pytest.approx(np.maximum(expected_lam, 1e-6), abs=1e-9)
def test_initial_closed_form_normalizes() -> None:
    """The updated initial-state distribution must be a valid probability vector."""
    params, data = _random_problem(seed=2)
    fb = forward_backward(params, data)
    probs = update_initial(fb.log_gamma).probs()
    # Sums to one (up to float tolerance) and has no negative entries.
    assert probs.sum() == pytest.approx(1.0, abs=1e-9)
    assert bool(np.all(probs >= 0))