# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""M-step solvers: closed-form for π/emissions, L-BFGS-B for transitions."""
from __future__ import annotations
import numpy as np
from scipy.optimize import minimize
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
EmissionFitParams,
FitParameters,
InitialFitParams,
TransitionFitParams,
)
from iohmm_evac.types import BoolArray, FloatArray
__all__ = [
"m_step",
"transition_neg_q_and_grad",
"update_emissions",
"update_initial",
"update_transitions",
]
def update_initial(log_gamma: FloatArray) -> InitialFitParams:
"""Closed-form M-step for the initial distribution."""
pi_at_zero = np.exp(log_gamma[:, 0, :]) # (N, K)
expected = pi_at_zero.sum(axis=0) # (K,)
expected = np.maximum(expected, 1e-12)
probs = expected / expected.sum()
logits = np.log(probs)
return InitialFitParams(logits=np.asarray(logits, dtype=np.float64))
def update_emissions(
log_gamma: FloatArray, data: FitData, sigma_floor: float = 1e-2
) -> EmissionFitParams:
"""Closed-form M-step for Bernoulli/Gaussian/Poisson emission parameters."""
gamma = np.exp(log_gamma) # (N, T+1, K)
weight_sum = gamma.sum(axis=(0, 1)) # (K,)
weight_sum_safe = np.maximum(weight_sum, 1e-12)
departure = data.departure
displacement = data.displacement
comm = data.comm
p_departure = (gamma * departure[:, :, None]).sum(axis=(0, 1)) / weight_sum_safe
p_departure = np.clip(p_departure, 1e-6, 1 - 1e-6)
mu = (gamma * displacement[:, :, None]).sum(axis=(0, 1)) / weight_sum_safe
diff_sq = (displacement[:, :, None] - mu[None, None, :]) ** 2
var = (gamma * diff_sq).sum(axis=(0, 1)) / weight_sum_safe
sigma = np.sqrt(np.maximum(var, sigma_floor * sigma_floor))
lam = (gamma * comm[:, :, None]).sum(axis=(0, 1)) / weight_sum_safe
lam = np.maximum(lam, 1e-6)
return EmissionFitParams(
p_departure=np.asarray(p_departure, dtype=np.float64),
mu_displacement=np.asarray(mu, dtype=np.float64),
sigma_displacement=np.asarray(sigma, dtype=np.float64),
lambda_comm=np.asarray(lam, dtype=np.float64),
sigma_floor=sigma_floor,
)
def _flatten_origin_params(
alpha_row: FloatArray, beta_row: FloatArray, learnable_j: BoolArray
) -> FloatArray:
"""Flatten the learnable subset of one origin row into a 1-D parameter vector.
Layout: ``[alpha_j1, alpha_j2, ..., beta_j1_f1, beta_j1_f2, ..., beta_j2_f1, ...]``
where the ``j*`` are the learnable destination indices in ascending order.
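
    Example (toy values):

    >>> import numpy as np
    >>> alpha = np.array([0.0, 1.0, 2.0])
    >>> beta = np.array([[0.0, 0.0], [10.0, 11.0], [20.0, 21.0]])
    >>> _flatten_origin_params(alpha, beta, np.array([False, True, True])).tolist()
    [1.0, 2.0, 10.0, 11.0, 20.0, 21.0]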
"""
js = np.flatnonzero(learnable_j)
f = beta_row.shape[1]
n_dest = js.shape[0]
out = np.empty(n_dest + n_dest * f, dtype=np.float64)
out[:n_dest] = alpha_row[js]
out[n_dest:] = beta_row[js].reshape(-1)
return out
def transition_neg_q_and_grad(
x: FloatArray,
*,
origin_k: int,
learnable_j: BoolArray,
xi_origin: FloatArray,
gamma_origin: FloatArray,
inputs_steps: FloatArray,
) -> tuple[float, FloatArray]:
"""Negative weighted multinomial-logit objective for one origin row.
Returns ``(-Q_k, -grad_k)`` so the result can be passed straight to
:func:`scipy.optimize.minimize` with ``jac=True``.
Performance: the destination axis is treated as ``[self_loop, *learnable]``
only (forbidden cells are dropped — their ``xi`` is zero in finite
precision so they contribute nothing to either ``Q`` or its gradient).
The β gradient is a single ``inputs.T @ diff`` matmul rather than the
naive ``(N*T, F, n_dest)`` outer-product sum.
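
    Example (random toy inputs; a finite-difference check via
    :func:`scipy.optimize.check_grad` that the analytic gradient matches;
    ``gamma_origin`` is set to the destination-sum of ``xi``, as EM guarantees):

    >>> import numpy as np
    >>> from scipy.optimize import check_grad
    >>> rng = np.random.default_rng(0)
    >>> steps, f = 5, 2
    >>> xi = rng.random((steps, 3))
    >>> kw = dict(origin_k=0, learnable_j=np.array([False, True, True]),
    ...           xi_origin=xi, gamma_origin=xi.sum(axis=1),
    ...           inputs_steps=rng.standard_normal((steps, f)))
    >>> x0 = rng.standard_normal(2 + 2 * f)
    >>> err = check_grad(lambda v: transition_neg_q_and_grad(v, **kw)[0],
    ...                  lambda v: transition_neg_q_and_grad(v, **kw)[1], x0)
    >>> bool(err < 1e-5)
    True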
"""
js = np.flatnonzero(learnable_j)
f = inputs_steps.shape[1]
n_dest = js.shape[0]
    if n_dest == 0:
        # Absorbing row: A_kk = 1 deterministically; Q = 0 and grad = 0.
        return 0.0, np.zeros_like(x)
    alpha_learn = x[:n_dest]
    beta_learn = x[n_dest:].reshape(n_dest, f)
learnable_logits = alpha_learn[None, :] + inputs_steps @ beta_learn.T # (N_steps, n_dest)
    # Compose with the self-loop logit (= 0) and log-softmax across the
    # joined ``[self, *learnable]`` axis only; forbidden destinations would
    # contribute exp(-inf) = 0 to the normalizer, so dropping them is exact.
    # The max below also covers the self logit (0) to keep log-sum-exp stable.
    m = np.maximum(learnable_logits.max(axis=1), 0.0)  # (N_steps,)
exp_self = np.exp(-m)
exp_learn = np.exp(learnable_logits - m[:, None]) # (N_steps, n_dest)
z = exp_self + exp_learn.sum(axis=1) # (N_steps,)
log_z = m + np.log(z) # (N_steps,)
log_a_learn = learnable_logits - log_z[:, None] # (N_steps, n_dest)
log_a_self = -log_z # (N_steps,)
a_learn = np.exp(log_a_learn) # (N_steps, n_dest)
    xi_self = xi_origin[:, origin_k]  # (N_steps,)
xi_learn = xi_origin[:, js] # (N_steps, n_dest)
q = float(np.sum(xi_self * log_a_self) + np.sum(xi_learn * log_a_learn))
# diff_d = xi[d] - gamma * A_d for each learnable destination d
diff_learn = xi_learn - gamma_origin[:, None] * a_learn # (N_steps, n_dest)
grad_alpha = diff_learn.sum(axis=0) # (n_dest,)
grad_beta = inputs_steps.T @ diff_learn # (F, n_dest)
grad = np.empty(n_dest + n_dest * f, dtype=np.float64)
grad[:n_dest] = grad_alpha
grad[n_dest:] = grad_beta.T.reshape(-1)
return -q, -grad
def update_transitions(
log_xi: FloatArray,
log_gamma: FloatArray,
inputs: FloatArray,
current: TransitionFitParams,
*,
maxiter: int = 50,
tol: float = 1e-6,
) -> TransitionFitParams:
"""L-BFGS-B M-step for the transition row of every non-absorbing origin."""
xi = np.exp(log_xi) # (N, T, K, K)
gamma = np.exp(log_gamma) # (N, T+1, K)
n, t_total, k_states, _ = xi.shape
f = inputs.shape[2]
    # Forbidden transitions are encoded as ``alpha = -inf``; only finite
    # off-diagonal cells are learnable (the self-loop logit is pinned at 0).
    allowed_mask = np.isfinite(current.alpha)
    self_mask = np.eye(k_states, dtype=bool)
    learnable_mask = allowed_mask & ~self_mask
    # The transition into time t+1 is conditioned on the input observed at t+1.
    inputs_steps = inputs[:, 1:, :].reshape(n * t_total, f)
xi_flat = xi.reshape(n * t_total, k_states, k_states)
gamma_flat = gamma[:, :-1, :].reshape(n * t_total, k_states)
new_alpha = current.alpha.copy()
new_beta = current.beta.copy()
    # The self-loop logit is the fixed reference (0) for every row.
    np.fill_diagonal(new_alpha, 0.0)
for k in range(k_states):
learnable_j = learnable_mask[k]
if not learnable_j.any():
continue # absorbing: nothing to optimize
x0 = _flatten_origin_params(current.alpha[k], current.beta[k], learnable_j)
def objective(
xx: FloatArray, k_local: int = k, lj: BoolArray = learnable_j
) -> tuple[float, FloatArray]:
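            # Default arguments bind the loop variables eagerly, avoiding
            # Python's late-binding closure pitfall.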
return transition_neg_q_and_grad(
xx,
origin_k=k_local,
learnable_j=lj,
xi_origin=xi_flat[:, k_local, :],
gamma_origin=gamma_flat[:, k_local],
inputs_steps=inputs_steps,
)
result = minimize(
objective,
x0,
jac=True,
method="L-BFGS-B",
options={"maxiter": maxiter, "gtol": tol},
)
x_opt = np.asarray(result.x, dtype=np.float64)
js = np.flatnonzero(learnable_j)
n_dest = js.shape[0]
new_alpha[k, js] = x_opt[:n_dest]
new_beta[k, js] = x_opt[n_dest:].reshape(n_dest, f)
        # Re-pin structural cells: forbidden destinations stay at -inf with
        # zero betas, and the self-loop reference logit stays at 0.
        forbidden = ~allowed_mask[k]
new_alpha[k, forbidden] = -np.inf
new_beta[k, forbidden] = 0.0
new_alpha[k, k] = 0.0
new_beta[k, k] = 0.0
return TransitionFitParams(alpha=new_alpha, beta=new_beta)
def m_step(
log_gamma: FloatArray,
log_xi: FloatArray,
data: FitData,
current: FitParameters,
*,
sigma_floor: float = 1e-2,
transition_maxiter: int = 50,
transition_tol: float = 1e-6,
) -> FitParameters:
"""One full M-step: closed-form initial+emissions, L-BFGS for transitions."""
initial = update_initial(log_gamma)
emissions = update_emissions(log_gamma, data, sigma_floor=sigma_floor)
transitions = update_transitions(
log_xi,
log_gamma,
data.inputs,
current.transitions,
maxiter=transition_maxiter,
tol=transition_tol,
)
return FitParameters(
initial=initial,
transitions=transitions,
emissions=emissions,
feature_names=current.feature_names,
)