src/iohmm_evac/report/recovery_plots.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Recovery-diagnostic plots: confusion matrix, parameter scatter, LL trace.

Same conventions as :mod:`iohmm_evac.report.plots`: each function takes an
optional :class:`matplotlib.axes.Axes`, returns the Axes, and never calls
``plt.show()`` or ``fig.savefig()`` itself.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from iohmm_evac.diagnostics.recovery import ParameterRecoveryReport
from iohmm_evac.inference.fit_params import learnable_indices
from iohmm_evac.report.constants import STATE_ORDER
from iohmm_evac.types import FloatArray

if TYPE_CHECKING:
    from matplotlib.axes import Axes

# Public plotting entry points of this module. The ``_scatter_*`` and
# ``_draw_one_ll_panel`` helpers are intentionally left out (module-private).
__all__ = [
    "plot_log_likelihood_trace",
    "plot_parameter_recovery",
    "plot_state_recovery_confusion",
]


def plot_state_recovery_confusion(confusion: FloatArray, ax: Axes | None = None) -> Axes:
    """Draw the row-normalized K×K state confusion matrix as an annotated heatmap.

    Rows are true states and columns are fitted (aligned) states, so cell
    ``(i, j)`` shows the share of true-row mass landing in fit-column ``j``.
    The colour scale is pinned to ``[0, 1]``; darker cells carry more mass.
    ``STATE_ORDER`` is used for tick labels only when its length equals K.

    Args:
        confusion: Square matrix of shares; its first dimension sets K.
        ax: Axes to draw into. A new 5×5-inch figure is created when omitted.

    Returns:
        The Axes that was drawn into.
    """
    target = ax if ax is not None else plt.subplots(figsize=(5, 5))[1]
    n_states = confusion.shape[0]
    image = target.imshow(confusion, cmap="Blues", vmin=0.0, vmax=1.0, aspect="equal")

    ticks = range(n_states)
    target.set_xticks(ticks)
    target.set_yticks(ticks)
    if len(STATE_ORDER) == n_states:
        target.set_xticklabels(list(STATE_ORDER))
        target.set_yticklabels(list(STATE_ORDER))

    target.set_xlabel("Fit (aligned)")
    target.set_ylabel("Truth")
    target.set_title("State recovery confusion")

    # Write each share into its cell (row-major order). Cells above 0.5 get
    # white text so the number stays legible on the dark end of the colormap.
    for row, col in np.ndindex(n_states, n_states):
        share = float(confusion[row, col])
        target.text(
            col,
            row,
            f"{share:.2f}",
            ha="center",
            va="center",
            color="white" if share > 0.5 else "black",
            fontsize=8,
        )

    if target.figure is not None:
        target.figure.colorbar(image, ax=target, fraction=0.046, pad=0.04)
    return target


def _scatter_alpha(report: ParameterRecoveryReport, ax: Axes) -> None:
    """Scatter true vs fitted transition intercepts (α) onto ``ax``.

    Only entries ``(k, j)`` flagged in the mask returned by
    ``learnable_indices()`` and whose true and fitted values are both finite
    are plotted. Each point is annotated with an ``α[...]`` label built from
    ``STATE_ORDER[k]`` and ``STATE_ORDER[j]``.

    When no entry qualifies, nothing is drawn. This matches ``_scatter_beta``
    and keeps an "α" legend entry from appearing with no points behind it.
    """
    learnable, _ = learnable_indices()
    truth = report.transition_alpha_true
    fit = report.transition_alpha_fit
    xs: list[float] = []
    ys: list[float] = []
    labels: list[str] = []
    for k in range(truth.shape[0]):
        for j in range(truth.shape[1]):
            if not learnable[k, j]:
                continue
            t_val = float(truth[k, j])
            f_val = float(fit[k, j])
            if not (np.isfinite(t_val) and np.isfinite(f_val)):
                continue
            xs.append(t_val)
            ys.append(f_val)
            labels.append(f"α[{STATE_ORDER[k]}{STATE_ORDER[j]}]")
    if not xs:
        # An empty scatter still registers a labelled artist, which would
        # surface as a point-less "α" entry in the caller's legend.
        return
    ax.scatter(xs, ys, marker="o", color="#1f77b4", s=40, label="α (transition intercept)")
    for x_, y_, lab in zip(xs, ys, labels, strict=True):
        ax.annotate(lab, (x_, y_), fontsize=6, alpha=0.6, xytext=(3, 3), textcoords="offset points")


def _scatter_beta(report: ParameterRecoveryReport, ax: Axes) -> None:
    """Scatter true vs fitted transition slopes (β) onto ``ax``.

    Every slope along the third axis of ``transition_beta_true`` is plotted
    for each ``(k, j)`` pair marked learnable by ``learnable_indices()``.
    Pairs whose true or fitted value is non-finite are dropped. When nothing
    survives, no artist (and therefore no legend entry) is added.
    """
    learnable, _ = learnable_indices()
    beta_true = report.transition_beta_true
    beta_fit = report.transition_beta_fit

    # Enumerate candidate cells lazily in (k, j, slope) row-major order,
    # skipping whole (k, j) rows that are not learnable.
    cells = (
        (src, dst, slope)
        for src in range(beta_true.shape[0])
        for dst in range(beta_true.shape[1])
        if learnable[src, dst]
        for slope in range(beta_true.shape[2])
    )

    true_vals: list[float] = []
    fit_vals: list[float] = []
    for cell in cells:
        t_val = float(beta_true[cell])
        f_val = float(beta_fit[cell])
        if np.isfinite(t_val) and np.isfinite(f_val):
            true_vals.append(t_val)
            fit_vals.append(f_val)

    if not true_vals:
        return
    ax.scatter(true_vals, fit_vals, marker="x", color="#d62728", s=20, label="β (slope)", alpha=0.7)


def _scatter_emissions(report: ParameterRecoveryReport, ax: Axes) -> None:
    """Scatter true vs fitted emission parameters, one marker style per group.

    The groups are p, μ, σ and λ, read from the matching
    ``report.emission_*_true`` / ``report.emission_*_fit`` attributes. Values
    are flattened before plotting. Consistent with ``_scatter_beta``, pairs
    where either value is non-finite are dropped, and a group with no
    remaining points is skipped so it does not leave an empty legend entry.

    Raises:
        ValueError: If a group's true and fitted values differ in size.
    """
    groups = (
        ("p", "#2ca02c", "s", report.emission_p_true, report.emission_p_fit),
        ("μ", "#9467bd", "^", report.emission_mu_true, report.emission_mu_fit),
        ("σ", "#8c564b", "v", report.emission_sigma_true, report.emission_sigma_fit),
        ("λ", "#e377c2", "P", report.emission_lambda_true, report.emission_lambda_fit),
    )
    for label, color, marker, true_raw, fit_raw in groups:
        true_vals = np.asarray(true_raw, dtype=float).ravel()
        fit_vals = np.asarray(fit_raw, dtype=float).ravel()
        if true_vals.size != fit_vals.size:
            msg = (
                f"emission {label}: truth has {true_vals.size} values "
                f"but fit has {fit_vals.size}"
            )
            raise ValueError(msg)
        keep = np.isfinite(true_vals) & np.isfinite(fit_vals)
        if not keep.any():
            # Nothing plottable; an empty scatter would still add a legend entry.
            continue
        ax.scatter(
            true_vals[keep],
            fit_vals[keep],
            marker=marker,
            color=color,
            s=40,
            label=f"{label} (emission)",
        )


def plot_parameter_recovery(report: ParameterRecoveryReport, ax: Axes | None = None) -> Axes:
    """Overlay truth-vs-fit scatters for every parameter group on one Axes.

    The plot covers transition intercepts (α), transition slopes (β) and
    emission parameters. A dashed ``y = x`` reference line is then drawn
    across the union of the current x- and y-limits; points on that line
    were recovered exactly.

    Args:
        report: Paired true/fitted parameter values to compare.
        ax: Axes to draw into. A new 7×7-inch figure is created when omitted.

    Returns:
        The Axes that was drawn into.
    """
    if ax is None:
        ax = plt.subplots(figsize=(7, 7))[1]
    for draw_group in (_scatter_alpha, _scatter_beta, _scatter_emissions):
        draw_group(report, ax)

    # The identity line must cover whatever range the scatters ended up using.
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    span = [min(x_lo, y_lo), max(x_hi, y_hi)]
    ax.plot(span, span, color="black", linestyle="--", alpha=0.4, label="identity")

    ax.set_xlabel("Truth")
    ax.set_ylabel("Fit (aligned)")
    ax.set_title("Parameter recovery")
    ax.legend(loc="upper left", fontsize=7, framealpha=0.9)
    return ax


def _draw_one_ll_panel(
    ax: Axes,
    traces: Sequence[Sequence[float]],
    *,
    best_index: int | None,
    skip_first_iter: bool,
    title: str,
) -> None:
    """Render a single LL-trace panel on ``ax``, optionally dropping iteration 1.

    Restart ``i`` is plotted against 1-based EM iteration numbers. With
    ``skip_first_iter`` the first value is dropped and the x-axis starts at 2,
    so restarts with fewer than two values are omitted. Empty traces are
    always omitted. The restart whose index equals ``best_index`` is drawn
    thick and blue; the others are thin and grey.

    A legend is added only when at least one restart was drawn. This avoids
    matplotlib's "No artists with labels found" warning and an empty legend
    box on panels where every restart was skipped.
    """
    n_dropped = 1 if skip_first_iter else 0  # leading values to leave out
    n_drawn = 0
    for i, trace in enumerate(traces):
        values = list(trace[n_dropped:])
        if not values:
            continue
        # Iterations are 1-based, so the first kept value sits at n_dropped + 1.
        x = np.arange(n_dropped + 1, n_dropped + 1 + len(values))
        if i == best_index:
            ax.plot(x, values, color="#1f77b4", lw=2.5, label=f"restart {i} (best)")
        else:
            ax.plot(x, values, color="#aaaaaa", lw=1.0, alpha=0.7, label=f"restart {i}")
        n_drawn += 1
    ax.set_xlabel("EM iteration")
    ax.set_ylabel("Log-likelihood")
    ax.set_title(title)
    if n_drawn:
        ax.legend(loc="lower right", fontsize=7, framealpha=0.9)


def plot_log_likelihood_trace(
    traces: Sequence[Sequence[float]],
    ax: Sequence[Axes] | None = None,
    *,
    best_index: int | None = None,
) -> Sequence[Axes]:
    """Two-panel log-likelihood traces over EM iterations.

    Left panel: every iteration, full y-range. It shows the initial-to-final
    jump that random init typically produces. Right panel: iteration 2
    onward only, auto-scaled. This is where convergence behavior is actually
    visible. Restarts that converged in a single iteration are skipped in
    the right panel.

    Y-units are *total* log-likelihood across all (i, t) observations. When
    this function creates the figure, a suptitle states the units. Supplied
    axes are not annotated.

    If ``best_index`` is omitted, the restart with the largest finite final
    log-likelihood is highlighted. Empty traces and non-finite finals are
    ignored when choosing it, and ties go to the lowest restart index.

    If ``ax`` is supplied it must be a length-2 sequence; otherwise a
    single figure with two side-by-side subplots is created.

    Raises:
        ValueError: If ``ax`` is supplied and does not contain exactly 2 axes.
    """
    if ax is None:
        fig, axes_arr = plt.subplots(1, 2, figsize=(12, 4))
        axes: list[Axes] = list(axes_arr)
        fig.suptitle("EM log-likelihood traces (units: total LL across all (i,t))", fontsize=10)
    else:
        axes = list(ax)
        if len(axes) != 2:
            msg = f"plot_log_likelihood_trace needs exactly 2 axes, got {len(axes)}"
            raise ValueError(msg)
    if not traces:
        for a in axes:
            a.set_title("Log-likelihood trace (no restarts)")
        return axes
    if best_index is None:
        # Keep each restart's own index next to its final value. Taking argmax
        # over a list with empty traces filtered out would return a position
        # in the *filtered* list and highlight the wrong restart. Non-finite
        # finals (e.g. NaN from a diverged run) cannot be "best".
        scored = [
            (i, float(t[-1])) for i, t in enumerate(traces) if len(t) > 0 and np.isfinite(t[-1])
        ]
        if scored:
            # max() returns the first maximal item, so ties go to the lowest index.
            best_index = max(scored, key=lambda pair: pair[1])[0]
    _draw_one_ll_panel(
        axes[0], traces, best_index=best_index, skip_first_iter=False, title="All iterations"
    )
    _draw_one_ll_panel(
        axes[1], traces, best_index=best_index, skip_first_iter=True, title="From iter 2"
    )
    return axes