tests/test_decoding.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Tests for the Viterbi and posterior-mode decoders."""

from __future__ import annotations

import numpy as np

from iohmm_evac.diagnostics.decoding import posterior_mode, viterbi
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
    EmissionFitParams,
    FitParameters,
    InitialFitParams,
    TransitionFitParams,
)
from iohmm_evac.inference.forward_backward import forward_backward


def _deterministic_problem(seed: int = 0) -> tuple[FitParameters, FitData, np.ndarray]:
    """Build a small three-state problem whose emissions nearly pin down the state.

    Displacement means (0, 5, 10) sit many multiples of the 0.1 noise scale
    apart, so a MAP decoder is expected to reproduce the hand-written
    ground-truth path exactly.

    Args:
        seed: Seed for the ``numpy.random.Generator`` used to simulate
            the observations.

    Returns:
        ``(params, data, truth)``: the model parameters, the simulated
        observations (with ``true_states`` attached), and the ``(5, 12)``
        int64 state path used to generate them.
    """
    # Ground-truth paths: every row only moves forward 0 -> 1 -> 2 and
    # never returns to an earlier state.
    truth = np.array(
        [
            [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
            [0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2],
            [0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2],
            [0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2],
        ],
        dtype=np.int64,
    )
    n, t_plus_1 = truth.shape
    k, f = 3, 1

    # Baseline transition logits; presumably alpha[i, j] scores i -> j
    # (TODO confirm against TransitionFitParams).
    alpha = np.array(
        [
            [0.0, -0.5, -3.0],
            [-3.0, 0.0, -0.5],
            [-3.0, -3.0, 0.0],
        ],
        dtype=np.float64,
    )
    # The single input feature only loads on the (0, 1) and (1, 2) entries.
    beta = np.zeros((k, k, f))
    beta[0, 1, 0] = 1.0
    beta[1, 2, 0] = 1.0

    emissions = EmissionFitParams(
        p_departure=np.array([0.05, 0.5, 0.95]),
        mu_displacement=np.array([0.0, 5.0, 10.0]),
        sigma_displacement=np.array([0.1, 0.1, 0.1]),
        lambda_comm=np.array([0.5, 1.5, 3.0]),
    )
    params = FitParameters(
        initial=InitialFitParams(logits=np.array([0.0, -3.0, -3.0])),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emissions,
    )

    # Every sequence sees the same linear ramp from 0 to 1 on its one feature.
    ramp = np.linspace(0.0, 1.0, t_plus_1)
    inputs = np.zeros((n, t_plus_1, f), dtype=np.float64)
    inputs[..., 0] = ramp

    # Draw order (departure, displacement, comm) fixes the sample for a seed.
    rng = np.random.default_rng(seed)
    departure = (rng.random(truth.shape) < emissions.p_departure[truth]).astype(np.float64)
    displacement = rng.normal(
        emissions.mu_displacement[truth],
        emissions.sigma_displacement[truth],
    )
    comm = rng.poisson(emissions.lambda_comm[truth]).astype(np.float64)

    data = FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=truth,
    )
    return params, data, truth


def test_viterbi_returns_correct_shape() -> None:
    """Viterbi yields one int64 state label per (sequence, time step)."""
    problem_params, problem_data, _truth = _deterministic_problem()
    decoded = viterbi(problem_params, problem_data)
    expected_shape = (problem_data.n, problem_data.t_total + 1)
    assert decoded.shape == expected_shape
    assert decoded.dtype == np.int64


def test_viterbi_recovers_truth_under_low_noise() -> None:
    """With well-separated emissions, Viterbi should match the hand-built path.

    The fixture has 5 x 12 = 60 cells, so the 0.99 threshold effectively
    demands an exact match.
    """
    params, data, truth = _deterministic_problem(seed=0)
    matches = viterbi(params, data) == truth
    accuracy = float(np.mean(matches))
    assert accuracy >= 0.99, f"Viterbi accuracy {accuracy:.4f} too low"


def test_posterior_mode_shape_and_consistency() -> None:
    """Posterior mode matches the decoded-path shape and emits valid state labels."""
    params, data, _ = _deterministic_problem(seed=1)
    log_gamma = forward_backward(params, data).log_gamma
    mode = posterior_mode(log_gamma)
    n_states = params.emissions.p_departure.shape[0]

    assert mode.shape == (data.n, data.t_total + 1)
    # A per-step argmax can only land in the half-open range [0, n_states).
    assert mode.min() >= 0
    assert mode.max() < n_states