tests/test_alignment.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Tests for the Hungarian-based state-alignment helper."""

from __future__ import annotations

import itertools

import numpy as np

from iohmm_evac.diagnostics.alignment import align_states, apply_permutation


def test_align_recovers_identity_permutation() -> None:
    """Aligning a labelling against itself must yield the identity mapping."""
    generator = np.random.default_rng(0)
    labels = generator.integers(0, 4, size=(20, 30))
    expected = np.arange(4)
    mapping = align_states(labels, labels, k=4)
    assert np.all(mapping == expected)


def test_align_recovers_known_permutation() -> None:
    """Alignment must exactly undo a known, noise-free relabelling of the truth."""
    rng = np.random.default_rng(0)
    truth = rng.integers(0, 5, size=(50, 40))
    # Relabel the truth through a known permutation, so that
    # fit_label = known[true_label]. The mapping from fit labels back to
    # true labels is therefore the inverse of ``known``.
    known = np.array([2, 4, 0, 3, 1], dtype=np.int64)
    fit = known[truth]
    perm = align_states(truth, fit, k=5)
    # Relabelling ``fit`` with the recovered permutation must reproduce the
    # ground truth at every position.
    relabeled = apply_permutation(fit, perm)
    assert (relabeled == truth).all()


def test_align_maximizes_diagonal_mass() -> None:
    """With noisy labels, the chosen alignment must be the best possible one.

    ``fit`` equals ``truth`` except where roughly 20% of entries are shifted to
    the next label (mod k). The permutation returned by ``align_states`` must
    achieve an agreement with ``truth`` at least as high as any of the k!
    candidate relabellings, i.e. it maximizes the aligned diagonal mass.
    """
    k = 4
    rng = np.random.default_rng(0)
    truth = rng.integers(0, k, size=(100, 50))
    flips = rng.random(truth.shape) < 0.2
    fit = truth.copy()
    fit[flips] = (fit[flips] + 1) % k
    perm = align_states(truth, fit, k=k)
    aligned = apply_permutation(fit, perm)
    accuracy = float((aligned == truth).mean())
    # Brute-force every relabelling (4! = 24 candidates) to find the best
    # achievable agreement. The set of all permutations is closed under
    # inversion, so this does not depend on which direction
    # ``apply_permutation`` maps labels.
    best = max(
        float(
            (apply_permutation(fit, np.array(candidate, dtype=np.int64)) == truth).mean()
        )
        for candidate in itertools.permutations(range(k))
    )
    assert accuracy >= best - 1e-12
    # Random labeling would give 25%; alignment should clear 70%+.
    assert accuracy > 0.7


def test_align_shape_mismatch_errors() -> None:
    """Label arrays of different shapes must be rejected with ValueError."""
    reference = np.zeros((10, 5), dtype=np.int64)
    candidate = np.zeros((10, 6), dtype=np.int64)
    try:
        align_states(reference, candidate, k=3)
    except ValueError:
        pass  # expected: the shapes (10, 5) and (10, 6) disagree
    else:
        raise AssertionError("expected shape-mismatch ValueError")