# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Initialization strategies for the IO-HMM parameter dataclasses."""
from __future__ import annotations
import numpy as np
from numpy.random import Generator
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
ALLOWED_TRANSITIONS,
EmissionFitParams,
F,
FitParameters,
InitialFitParams,
K,
TransitionFitParams,
dgp_truth_to_fit_init,
)
from iohmm_evac.params import EmissionParams as DGPEmissionParams
from iohmm_evac.params import PopulationParams as DGPPopulationParams
from iohmm_evac.params import TransitionParams as DGPTransitionParams
from iohmm_evac.types import State
# Public API of this module: the three initialization strategies defined below.
__all__ = [
"from_dgp_truth",
"kmeans_init",
"random_initialization",
]
def from_dgp_truth(
    transitions: DGPTransitionParams,
    emissions: DGPEmissionParams,
    population: DGPPopulationParams | None = None,
) -> FitParameters:
    """Build fit parameters seeded from the data-generating process itself.

    This is a thin wrapper: all three arguments are forwarded, unchanged and
    in the same order, to :func:`dgp_truth_to_fit_init`, and its result is
    returned as-is.

    Args:
        transitions: DGP transition parameters.
        emissions: DGP emission parameters.
        population: Optional DGP population parameters; ``None`` (the
            default) is passed straight through.

    Returns:
        The :class:`FitParameters` produced by ``dgp_truth_to_fit_init``.
    """
    truth_seed = dgp_truth_to_fit_init(transitions, emissions, population)
    return truth_seed
def random_initialization(
    rng: Generator,
    *,
    alpha_scale: float = 1.0,
    beta_scale: float = 0.3,
    sigma: float = 1.0,
) -> FitParameters:
    """Draw a jittered, generic starting point for one EM restart.

    Self-transitions receive an intercept of exactly 0, while every other
    allowed transition receives an intercept centred on -3 plus Gaussian
    jitter, so the chain starts out biased toward staying in place.
    Transitions that are not allowed keep an intercept of ``-inf`` and
    all-zero slopes. Emission parameters start from fixed heuristic values
    perturbed with Gaussian noise.

    Args:
        rng: NumPy random generator. Draws are consumed in a fixed order
            (transitions, initial logits, departure probabilities,
            displacement means, communication rates), so a given generator
            state always yields the same parameters.
        alpha_scale: Multiplier on the standard-normal jitter added to the
            -3 intercept of each learnable off-diagonal transition.
        beta_scale: Multiplier on the standard-normal slopes drawn for each
            learnable off-diagonal transition.
        sigma: Multiplier on the standard-normal jitter of the displacement
            means; also used as the displacement scale, floored at 0.5.

    Returns:
        A :class:`FitParameters` holding the initial-state logits, the
        transition intercepts/slopes and the emission parameters.
    """
    # Transition model: everything forbidden until proven otherwise.
    intercepts = np.full((K, K), -np.inf, dtype=np.float64)
    slopes = np.zeros((K, K, F), dtype=np.float64)
    np.fill_diagonal(intercepts, 0.0)
    # Row-major walk over (source, destination) pairs; the intercept is drawn
    # before the slope vector for each learnable pair.
    for src, dst in np.ndindex(K, K):
        if src != dst and ALLOWED_TRANSITIONS[src, dst]:
            intercepts[src, dst] = -3.0 + alpha_scale * rng.normal(0.0, 1.0)
            slopes[src, dst] = beta_scale * rng.normal(0.0, 1.0, size=F)

    # Initial distribution: State.UA at logit 0, all others at -2, then jitter.
    start_logits = np.full(K, -2.0, dtype=np.float64)
    start_logits[int(State.UA)] = 0.0
    start_logits += 0.1 * rng.normal(size=K)

    # Departure probability: small everywhere except State.ER, which is
    # centred on 0.9; clipped away from 0 and 1 after each step.
    depart_prob = np.clip(0.05 + 0.05 * rng.normal(size=K), 1e-3, 0.999)
    depart_prob[int(State.ER)] = 0.9 + 0.05 * rng.standard_normal()
    depart_prob = np.clip(depart_prob, 1e-3, 1 - 1e-3)

    # Displacement location/scale per state.
    disp_anchor = np.array([0.5, 0.5, 0.5, 20.0, 60.0], dtype=np.float64)
    disp_mean = disp_anchor + sigma * rng.normal(size=K)
    disp_scale = np.full(K, max(sigma, 0.5), dtype=np.float64)

    # Communication rate per state, kept strictly positive.
    comm_anchor = np.array([0.5, 1.5, 3.5, 2.5, 1.0], dtype=np.float64)
    comm_rate = np.maximum(comm_anchor + 0.3 * rng.normal(size=K), 1e-3)

    emissions = EmissionFitParams(
        p_departure=np.asarray(depart_prob, dtype=np.float64),
        mu_displacement=np.asarray(disp_mean, dtype=np.float64),
        sigma_displacement=np.asarray(disp_scale, dtype=np.float64),
        lambda_comm=np.asarray(comm_rate, dtype=np.float64),
    )
    transitions = TransitionFitParams(alpha=intercepts, beta=slopes)
    initial = InitialFitParams(logits=np.asarray(start_logits, dtype=np.float64))
    return FitParameters(
        initial=initial,
        transitions=transitions,
        emissions=emissions,
    )
def kmeans_init(data: FitData, rng: Generator, *, n_iter: int = 20) -> FitParameters:
    """Seed emission parameters from a small K-means run on the observations.

    Every observation is treated as a 3-vector ``(departure, displacement,
    comm)`` and clustered into ``K`` groups with plain Lloyd iterations.
    The cluster centers, ordered by increasing displacement, then supply the
    departure probability, displacement mean and communication rate of each
    state. The displacement scale, the transitions and the initial
    distribution are taken from :func:`random_initialization`.

    Args:
        data: Observations; its ``departure``, ``displacement`` and ``comm``
            arrays are flattened and stacked as feature columns.
        rng: NumPy random generator. It is first consumed by
            :func:`random_initialization`, then used to pick the ``K``
            starting centers.
        n_iter: Maximum number of K-means iterations; the loop stops early
            once the centers move by no more than ``1e-6`` (absolute).

    Returns:
        A :class:`FitParameters`. If there are fewer than ``K`` observations
        the unmodified result of :func:`random_initialization` is returned.
    """
    seed = random_initialization(rng)
    points = np.column_stack(
        (
            data.departure.reshape(-1),
            data.displacement.reshape(-1),
            data.comm.reshape(-1),
        )
    )
    n_points = points.shape[0]
    if n_points < K:
        # Not enough observations to place K distinct starting centers.
        return seed

    # Start from K distinct observations chosen at random.
    centroids = points[rng.choice(n_points, size=K, replace=False)].copy()
    for _ in range(n_iter):
        offsets = points[:, np.newaxis, :] - centroids[np.newaxis, :, :]
        nearest = np.linalg.norm(offsets, axis=2).argmin(axis=1)
        # Clusters that attracted no points keep their previous center.
        updated = centroids.copy()
        for cluster in np.unique(nearest):
            updated[cluster] = points[nearest == cluster].mean(axis=0)
        converged = np.allclose(updated, centroids, atol=1e-6)
        centroids = updated
        if converged:
            break

    # Order clusters by displacement so they line up with the state indices.
    # NOTE(review): this assumes the canonical state order is increasing in
    # displacement -- confirm against the State enum.
    centroids = centroids[centroids[:, 1].argsort()]
    depart_col, disp_col, comm_col = centroids.T
    emissions = EmissionFitParams(
        p_departure=np.asarray(np.clip(depart_col, 1e-3, 1 - 1e-3), dtype=np.float64),
        mu_displacement=np.asarray(disp_col, dtype=np.float64),
        sigma_displacement=seed.emissions.sigma_displacement.copy(),
        lambda_comm=np.asarray(np.maximum(comm_col, 1e-3), dtype=np.float64),
    )
    return FitParameters(
        initial=seed.initial,
        transitions=seed.transitions,
        emissions=emissions,
        feature_names=seed.feature_names,
    )