# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Tests for the Hungarian-based state-alignment helper."""
from __future__ import annotations

import itertools

import numpy as np

from iohmm_evac.diagnostics.alignment import align_states, apply_permutation
def test_align_recovers_identity_permutation() -> None:
    """Aligning a labeling against itself must return the identity permutation."""
    n_states = 4
    labels = np.random.default_rng(0).integers(0, n_states, size=(20, 30))
    recovered = align_states(labels, labels, k=n_states)
    identity = np.arange(n_states)
    assert (recovered == identity).all()
def test_align_recovers_known_permutation() -> None:
    """A noise-free relabeling must be undone exactly by the recovered permutation."""
    rng = np.random.default_rng(0)
    truth = rng.integers(0, 5, size=(50, 40))
    # Relabel truth -> fit with a known bijection: fit_label = known[true_label].
    # ``known`` is deliberately NOT its own inverse (known[known] != identity).
    # An involution such as [2, 4, 0, 3, 1] maps back onto itself, so it cannot
    # distinguish a permutation from its inverse. It would let a direction
    # mix-up between align_states and apply_permutation pass unnoticed.
    known = np.array([3, 0, 4, 1, 2], dtype=np.int64)
    assert not (known[known] == np.arange(5)).all()
    fit = known[truth]
    perm = align_states(truth, fit, k=5)
    # Applying the recovered permutation to fit should reconstruct truth exactly.
    relabeled = apply_permutation(fit, perm)
    assert (relabeled == truth).all()
def test_align_maximizes_diagonal_mass() -> None:
    """Even with noise, the aligned diagonal should be the largest possible.

    With k=4 there are only 24 relabelings, so optimality is checked directly
    against a brute-force search rather than only a loose accuracy threshold.
    """
    rng = np.random.default_rng(0)
    truth = rng.integers(0, 4, size=(100, 50))
    # Corrupt ~20% of labels by shifting them to the next state (mod 4).
    flips = rng.random(truth.shape) < 0.2
    fit = truth.copy()
    fit[flips] = (fit[flips] + 1) % 4
    perm = align_states(truth, fit, k=4)
    aligned = apply_permutation(fit, perm)
    accuracy = float((aligned == truth).mean())
    # Random labeling would give 25%; alignment should clear 70%+.
    assert accuracy > 0.7
    # Brute force over every relabeling. Candidates go through apply_permutation
    # itself, so this check is independent of the permutation-direction convention.
    best = 0.0
    for candidate in itertools.permutations(range(4)):
        relabeled = apply_permutation(fit, np.array(candidate, dtype=np.int64))
        best = max(best, float((relabeled == truth).mean()))
    assert accuracy >= best - 1e-12
def test_align_shape_mismatch_errors() -> None:
    """Label arrays with different shapes must be rejected with ValueError."""
    truth = np.zeros((10, 5), dtype=np.int64)
    fit = np.zeros((10, 6), dtype=np.int64)
    try:
        align_states(truth, fit, k=3)
    except ValueError:
        pass
    else:
        raise AssertionError("expected shape-mismatch ValueError")