src/iohmm_evac/inference/io.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Serialize / deserialize :class:`FitResult` to a directory on disk."""

from __future__ import annotations

import tomllib
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import tomli_w

from iohmm_evac.inference.fit import FitResult
from iohmm_evac.inference.fit_params import (
    FEATURE_NAMES,
    EmissionFitParams,
    FitParameters,
    InitialFitParams,
    K,
    TransitionFitParams,
)
from iohmm_evac.types import FloatArray, IntArray

__all__ = [
    "FIT_FILE_NAMES",
    "FitBundle",
    "read_fit_bundle",
    "write_fit_bundle",
]


FIT_FILE_NAMES: dict[str, str] = {
    "theta": "theta.toml",  # fitted parameters of the best restart
    "trace": "log_likelihood_trace.parquet",  # per-iteration log-likelihood, all restarts
    "states": "posterior_states.parquet",  # decoded state code per (household_id, t)
    "metadata": "metadata.toml",  # best restart index and per-restart summaries
}


_INF = 1.0e30  # sentinel for ±inf so alpha round-trips through TOML as plain floats.


@dataclass(frozen=True, slots=True)
class FitBundle:
    """A fit reloaded from disk."""

    params: FitParameters
    log_likelihood_traces: tuple[tuple[float, ...], ...]
    posterior_states: IntArray
    best_index: int
    iterations_per_restart: tuple[int, ...]
    converged_per_restart: tuple[bool, ...]
    final_log_likelihoods: tuple[float, ...]


def _alpha_to_serializable(alpha: FloatArray) -> list[list[float]]:
    return [
        [float(_INF if x == np.inf else (-_INF if x == -np.inf else x)) for x in row]
        for row in alpha
    ]


def _alpha_from_serializable(rows: list[list[float]]) -> FloatArray:
    arr = np.array(rows, dtype=np.float64)
    arr = np.where(arr <= -0.99 * _INF, -np.inf, arr)
    arr = np.where(arr >= 0.99 * _INF, np.inf, arr)
    return arr
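
# Illustrative round trip for the ±inf sentinel (inputs chosen arbitrarily):
#
#     _alpha_to_serializable(np.array([[0.5, -np.inf]]))  # -> [[0.5, -1e+30]]
#     _alpha_from_serializable([[0.5, -1e30]])            # -> array([[0.5, -inf]])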


def _params_to_dict(p: FitParameters) -> dict[str, object]:
    return {
        "feature_names": list(p.feature_names),
        "initial": {"logits": [float(x) for x in p.initial.logits]},
        "transitions": {
            "alpha": _alpha_to_serializable(p.transitions.alpha),
            "beta": [
                [
                    [float(x) for x in p.transitions.beta[k, j]]
                    for j in range(p.transitions.beta.shape[1])
                ]
                for k in range(p.transitions.beta.shape[0])
            ],
        },
        "emissions": {
            "p_departure": [float(x) for x in p.emissions.p_departure],
            "mu_displacement": [float(x) for x in p.emissions.mu_displacement],
            "sigma_displacement": [float(x) for x in p.emissions.sigma_displacement],
            "lambda_comm": [float(x) for x in p.emissions.lambda_comm],
            "sigma_floor": float(p.emissions.sigma_floor),
        },
    }
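
# Written out with ``tomli_w``, theta.toml is laid out roughly as follows
# (array contents elided; exact formatting depends on the writer):
#
#     feature_names = [...]
#
#     [initial]
#     logits = [...]
#
#     [transitions]
#     alpha = [[...], ...]
#     beta = [[[...], ...], ...]
#
#     [emissions]
#     p_departure = [...]
#     mu_displacement = [...]
#     sigma_displacement = [...]
#     lambda_comm = [...]
#     sigma_floor = ...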


def _params_from_dict(d: dict[str, object]) -> FitParameters:
    init = d["initial"]
    assert isinstance(init, dict)
    logits = np.array(init["logits"], dtype=np.float64)
    trans = d["transitions"]
    assert isinstance(trans, dict)
    alpha = _alpha_from_serializable(trans["alpha"])
    beta = np.array(trans["beta"], dtype=np.float64)
    em = d["emissions"]
    assert isinstance(em, dict)
    emit = EmissionFitParams(
        p_departure=np.array(em["p_departure"], dtype=np.float64),
        mu_displacement=np.array(em["mu_displacement"], dtype=np.float64),
        sigma_displacement=np.array(em["sigma_displacement"], dtype=np.float64),
        lambda_comm=np.array(em["lambda_comm"], dtype=np.float64),
        sigma_floor=float(em["sigma_floor"]),
    )
    raw_feature_names = d.get("feature_names", FEATURE_NAMES)
    assert isinstance(raw_feature_names, list | tuple)
    feature_names: tuple[str, ...] = tuple(str(x) for x in raw_feature_names)
    return FitParameters(
        initial=InitialFitParams(logits=logits),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
        feature_names=feature_names,
    )


def write_fit_bundle(result: FitResult, posterior: IntArray, output_dir: Path) -> dict[str, Path]:
    """Write the four-file fit bundle into ``output_dir``."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    theta_path = output_dir / FIT_FILE_NAMES["theta"]
    trace_path = output_dir / FIT_FILE_NAMES["trace"]
    states_path = output_dir / FIT_FILE_NAMES["states"]
    metadata_path = output_dir / FIT_FILE_NAMES["metadata"]

    theta_path.write_bytes(tomli_w.dumps(_params_to_dict(result.best.params)).encode("utf-8"))

    rows: list[dict[str, float | int]] = []
    for restart_idx, run in enumerate(result.all_runs):
        for it_idx, ll in enumerate(run.log_likelihood_trace):
            rows.append({"restart": restart_idx, "iteration": it_idx + 1, "log_likelihood": ll})
    # One row per (restart, iteration); explicit Arrow types keep the schema
    # stable even when there are no rows.
    table = pa.table(
        {
            "restart": pa.array([int(r["restart"]) for r in rows], type=pa.int32()),
            "iteration": pa.array([int(r["iteration"]) for r in rows], type=pa.int32()),
            "log_likelihood": pa.array(
                [float(r["log_likelihood"]) for r in rows], type=pa.float64()
            ),
        }
    )
    pq.write_table(table, trace_path)  # type: ignore[no-untyped-call]

    n, t_plus_1 = posterior.shape
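    # Long format: each household contributes ``t_plus_1`` rows; ``household_id``
    # repeats while ``t`` cycles 0..t_plus_1-1, matching the row-major flatten below.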
    states_table = pa.table(
        {
            "household_id": pa.array(
                np.repeat(np.arange(n, dtype=np.int64), t_plus_1), type=pa.int64()
            ),
            "t": pa.array(np.tile(np.arange(t_plus_1, dtype=np.int64), n), type=pa.int64()),
            "state_code": pa.array(posterior.reshape(-1), type=pa.int64()),
        }
    )
    pq.write_table(states_table, states_path)  # type: ignore[no-untyped-call]

    metadata = {
        "k": int(K),
        "feature_names": list(result.best.params.feature_names),
        "best_index": int(result.best_index),
        "iterations_per_restart": [int(r.iterations) for r in result.all_runs],
        "converged_per_restart": [bool(r.converged) for r in result.all_runs],
        "final_log_likelihoods": [float(r.final_log_likelihood) for r in result.all_runs],
    }
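    # Serialized with ``tomli_w``, this ends up roughly like (values illustrative):
    #
    #     k = 4
    #     feature_names = ["...", "..."]
    #     best_index = 1
    #     iterations_per_restart = [52, 47]
    #     converged_per_restart = [true, true]
    #     final_log_likelihoods = [-1234.5, -1240.2]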
    metadata_path.write_bytes(tomli_w.dumps(metadata).encode("utf-8"))

    return {
        "theta": theta_path,
        "trace": trace_path,
        "states": states_path,
        "metadata": metadata_path,
    }


def read_fit_bundle(input_dir: Path) -> FitBundle:
    """Read a fit bundle written by :func:`write_fit_bundle`."""
    input_dir = Path(input_dir)
    theta_path = input_dir / FIT_FILE_NAMES["theta"]
    trace_path = input_dir / FIT_FILE_NAMES["trace"]
    states_path = input_dir / FIT_FILE_NAMES["states"]
    metadata_path = input_dir / FIT_FILE_NAMES["metadata"]
    for label, p in [
        ("theta", theta_path),
        ("trace", trace_path),
        ("states", states_path),
        ("metadata", metadata_path),
    ]:
        if not p.exists():
            msg = f"Missing fit bundle file: {label} -> {p}"
            raise FileNotFoundError(msg)

    with theta_path.open("rb") as f:
        params = _params_from_dict(tomllib.load(f))

    trace_table = pq.read_table(trace_path)  # type: ignore[no-untyped-call]
    restart_col = trace_table.column("restart").to_pylist()
    ll_col = trace_table.column("log_likelihood").to_pylist()
    grouped: dict[int, list[float]] = {}
    for r, ll in zip(restart_col, ll_col, strict=True):
        grouped.setdefault(int(r), []).append(float(ll))
    n_restarts = (max(grouped.keys()) + 1) if grouped else 0
    traces = tuple(tuple(grouped.get(i, [])) for i in range(n_restarts))

    states_table = pq.read_table(states_path)  # type: ignore[no-untyped-call]
    n_total = states_table.num_rows
    if n_total > 0:
        codes = np.array(states_table.column("state_code").to_pylist(), dtype=np.int64)
        t_col = np.array(states_table.column("t").to_pylist(), dtype=np.int64)
        t_plus_1 = int(t_col.max()) + 1
        n_households = n_total // t_plus_1
        posterior = codes.reshape(n_households, t_plus_1)
    else:
        posterior = np.zeros((0, 0), dtype=np.int64)

    with metadata_path.open("rb") as f:
        metadata = tomllib.load(f)

    return FitBundle(
        params=params,
        log_likelihood_traces=traces,
        posterior_states=posterior,
        best_index=int(metadata.get("best_index", 0)),
        iterations_per_restart=tuple(int(x) for x in metadata.get("iterations_per_restart", [])),
        converged_per_restart=tuple(bool(x) for x in metadata.get("converged_per_restart", [])),
        final_log_likelihoods=tuple(float(x) for x in metadata.get("final_log_likelihoods", [])),
    )
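
# A minimal round-trip sketch (``result`` is a FitResult and ``posterior`` an
# (n_households, t_plus_1) integer array from the decoding step; both are
# hypothetical here, as is the output directory):
#
#     paths = write_fit_bundle(result, posterior, Path("runs/fit_001"))
#     bundle = read_fit_bundle(Path("runs/fit_001"))
#     assert bundle.best_index == result.best_index
#     assert bundle.posterior_states.shape == posterior.shape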