src/iohmm_evac/diagnostics/alignment.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""State-permutation alignment via the Hungarian algorithm.

Latent state indices in a fitted IO-HMM are arbitrary up to permutation.
To compare a fit to the truth (or to another fit) we must pick the best of
the ``K!`` possible labelings, which amounts to solving the assignment
problem on the ``K x K`` co-occurrence (confusion) matrix.
"""

from __future__ import annotations

import numpy as np
from scipy.optimize import linear_sum_assignment

from iohmm_evac.types import IntArray

__all__ = ["align_states", "apply_permutation"]


def align_states(true_states: IntArray, fit_states: IntArray, k: int) -> IntArray:
    """Return a permutation ``perm`` of length ``K``.

    ``perm[fit_label]`` gives the canonical (true) label that ``fit_label``
    should be relabeled to. Solves a Hungarian assignment that maximizes
    the total mass on the diagonal of the confusion matrix.

    Parameters
    ----------
    true_states, fit_states:
        Integer state labels of identical shape (any rank; both are
        flattened before counting). Every label must lie in ``[0, k)``.
        Array-likes (e.g. lists) are accepted and converted with
        ``np.asarray``.
    k:
        Number of latent states (side length of the confusion matrix).

    Returns
    -------
    IntArray
        ``int64`` array of length ``k`` mapping fit labels to true labels;
        relabel a sequence with ``apply_permutation(fit_states, perm)``.

    Raises
    ------
    ValueError
        If the two inputs differ in shape, or any label falls outside
        ``[0, k)``.
    """
    true_arr = np.asarray(true_states)
    fit_arr = np.asarray(fit_states)
    if true_arr.shape != fit_arr.shape:
        msg = f"shape mismatch: true {true_arr.shape} vs fit {fit_arr.shape}"
        raise ValueError(msg)
    # Validate explicitly: ``np.add.at`` silently wraps negative indices onto
    # the last rows/columns, which would corrupt the confusion matrix and
    # produce a wrong permutation without any error.
    for which, arr in (("true", true_arr), ("fit", fit_arr)):
        if arr.size and (arr.min() < 0 or arr.max() >= k):
            msg = (
                f"{which} labels must lie in [0, {k}); "
                f"got range [{arr.min()}, {arr.max()}]"
            )
            raise ValueError(msg)
    confusion = np.zeros((k, k), dtype=np.int64)
    # Unbuffered scatter-add so repeated (true, fit) pairs are all counted.
    np.add.at(confusion, (true_arr.ravel(), fit_arr.ravel()), 1)
    # Hungarian minimizes cost; we want to maximize matches, so negate.
    row_ind, col_ind = linear_sum_assignment(-confusion)
    # The matrix is square, so every true label (row) is matched to exactly
    # one fit label (column); the scatter below fills every slot of ``perm``.
    perm = np.empty(k, dtype=np.int64)
    perm[col_ind] = row_ind
    return perm


def apply_permutation(states: IntArray, perm: IntArray) -> IntArray:
    """Map every label in ``states`` through ``perm``.

    The result satisfies ``result[t] == perm[states[t]]`` and keeps the shape
    of ``states``. It is returned as an ``int64`` array.
    """
    relabeled = perm[states]
    return np.asarray(relabeled, dtype=np.int64)