# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Serialize / deserialize :class:`FitResult` to a directory on disk."""
from __future__ import annotations
import tomllib
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import tomli_w
from iohmm_evac.inference.fit import FitResult
from iohmm_evac.inference.fit_params import (
FEATURE_NAMES,
EmissionFitParams,
FitParameters,
InitialFitParams,
K,
TransitionFitParams,
)
from iohmm_evac.types import FloatArray, IntArray
__all__ = [
"FIT_FILE_NAMES",
"FitBundle",
"read_fit_bundle",
"write_fit_bundle",
]
FIT_FILE_NAMES: dict[str, str] = {
"theta": "theta.toml",
"trace": "log_likelihood_trace.parquet",
"states": "posterior_states.parquet",
"metadata": "metadata.toml",
}
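# A bundle is four files: the fitted parameters (TOML), the per-iteration
# log-likelihood trace for every restart (Parquet), the decoded posterior state
# sequence for every household (Parquet), and run metadata (TOML).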
_INF = 1.0e30  # Finite sentinel standing in for +/-inf when round-tripping through TOML.
@dataclass(frozen=True, slots=True)
class FitBundle:
"""A fit reloaded from disk."""
params: FitParameters
log_likelihood_traces: tuple[tuple[float, ...], ...]
posterior_states: IntArray
best_index: int
iterations_per_restart: tuple[int, ...]
converged_per_restart: tuple[bool, ...]
final_log_likelihoods: tuple[float, ...]
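# ``alpha`` may legitimately contain +/-inf entries; replace them with the finite
# sentinel so the matrix can be written as plain TOML arrays of floats.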
def _alpha_to_serializable(alpha: FloatArray) -> list[list[float]]:
return [
[float(_INF if x == np.inf else (-_INF if x == -np.inf else x)) for x in row]
for row in alpha
]
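# Inverse of ``_alpha_to_serializable``: values at or beyond 99% of the sentinel
# magnitude are restored to +/-inf, which tolerates float formatting round-off.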
def _alpha_from_serializable(rows: list[list[float]]) -> FloatArray:
arr = np.array(rows, dtype=np.float64)
arr = np.where(arr <= -0.99 * _INF, -np.inf, arr)
arr = np.where(arr >= 0.99 * _INF, np.inf, arr)
return arr
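# Flatten ``FitParameters`` into builtin containers that ``tomli_w`` can emit;
# the 3-D ``beta`` coefficient array is unrolled into nested lists.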
def _params_to_dict(p: FitParameters) -> dict[str, object]:
return {
"feature_names": list(p.feature_names),
"initial": {"logits": [float(x) for x in p.initial.logits]},
"transitions": {
"alpha": _alpha_to_serializable(p.transitions.alpha),
"beta": [
[
[float(x) for x in p.transitions.beta[k, j]]
for j in range(p.transitions.beta.shape[1])
]
for k in range(p.transitions.beta.shape[0])
],
},
"emissions": {
"p_departure": [float(x) for x in p.emissions.p_departure],
"mu_displacement": [float(x) for x in p.emissions.mu_displacement],
"sigma_displacement": [float(x) for x in p.emissions.sigma_displacement],
"lambda_comm": [float(x) for x in p.emissions.lambda_comm],
"sigma_floor": float(p.emissions.sigma_floor),
},
}
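# Rebuild ``FitParameters`` from a parsed TOML mapping. ``feature_names`` falls
# back to the package default ``FEATURE_NAMES`` when the key is absent.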
def _params_from_dict(d: dict[str, object]) -> FitParameters:
init = d["initial"]
assert isinstance(init, dict)
logits = np.array(init["logits"], dtype=np.float64)
trans = d["transitions"]
assert isinstance(trans, dict)
alpha = _alpha_from_serializable(trans["alpha"])
beta = np.array(trans["beta"], dtype=np.float64)
em = d["emissions"]
assert isinstance(em, dict)
emit = EmissionFitParams(
p_departure=np.array(em["p_departure"], dtype=np.float64),
mu_displacement=np.array(em["mu_displacement"], dtype=np.float64),
sigma_displacement=np.array(em["sigma_displacement"], dtype=np.float64),
lambda_comm=np.array(em["lambda_comm"], dtype=np.float64),
sigma_floor=float(em["sigma_floor"]),
)
raw_feature_names = d.get("feature_names", FEATURE_NAMES)
assert isinstance(raw_feature_names, list | tuple)
feature_names: tuple[str, ...] = tuple(str(x) for x in raw_feature_names)
return FitParameters(
initial=InitialFitParams(logits=logits),
transitions=TransitionFitParams(alpha=alpha, beta=beta),
emissions=emit,
feature_names=feature_names,
)
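# Both Parquet files use a long layout: one row per (restart, iteration) in the
# trace file and one row per (household_id, t) in the posterior-states file.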
def write_fit_bundle(result: FitResult, posterior: IntArray, output_dir: Path) -> dict[str, Path]:
"""Write the four-file fit bundle into ``output_dir``."""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
theta_path = output_dir / FIT_FILE_NAMES["theta"]
trace_path = output_dir / FIT_FILE_NAMES["trace"]
states_path = output_dir / FIT_FILE_NAMES["states"]
metadata_path = output_dir / FIT_FILE_NAMES["metadata"]
theta_path.write_bytes(tomli_w.dumps(_params_to_dict(result.best.params)).encode("utf-8"))
rows: list[dict[str, float | int]] = []
for restart_idx, run in enumerate(result.all_runs):
for it_idx, ll in enumerate(run.log_likelihood_trace):
rows.append({"restart": restart_idx, "iteration": it_idx + 1, "log_likelihood": ll})
if rows:
table = pa.table(
{
"restart": pa.array([int(r["restart"]) for r in rows], type=pa.int32()),
"iteration": pa.array([int(r["iteration"]) for r in rows], type=pa.int32()),
"log_likelihood": pa.array(
[float(r["log_likelihood"]) for r in rows], type=pa.float64()
),
}
)
else:
table = pa.table(
{
"restart": pa.array([], type=pa.int32()),
"iteration": pa.array([], type=pa.int32()),
"log_likelihood": pa.array([], type=pa.float64()),
}
)
pq.write_table(table, trace_path) # type: ignore[no-untyped-call]
n, t_plus_1 = posterior.shape
states_table = pa.table(
{
"household_id": pa.array(
np.repeat(np.arange(n, dtype=np.int64), t_plus_1), type=pa.int64()
),
"t": pa.array(np.tile(np.arange(t_plus_1, dtype=np.int64), n), type=pa.int64()),
"state_code": pa.array(posterior.reshape(-1), type=pa.int64()),
}
)
pq.write_table(states_table, states_path) # type: ignore[no-untyped-call]
metadata = {
"k": int(K),
"feature_names": list(result.best.params.feature_names),
"best_index": int(result.best_index),
"iterations_per_restart": [int(r.iterations) for r in result.all_runs],
"converged_per_restart": [bool(r.converged) for r in result.all_runs],
"final_log_likelihoods": [float(r.final_log_likelihood) for r in result.all_runs],
}
metadata_path.write_bytes(tomli_w.dumps(metadata).encode("utf-8"))
return {
"theta": theta_path,
"trace": trace_path,
"states": states_path,
"metadata": metadata_path,
}
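# Illustrative round trip (the directory name and the ``result`` / ``posterior``
# objects are placeholders supplied by the caller, not part of this module):
#
#     paths = write_fit_bundle(result, posterior, Path("fits/run_001"))
#     bundle = read_fit_bundle(Path("fits/run_001"))
#     assert bundle.best_index == result.best_index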
def read_fit_bundle(input_dir: Path) -> FitBundle:
"""Read a fit bundle written by :func:`write_fit_bundle`."""
input_dir = Path(input_dir)
theta_path = input_dir / FIT_FILE_NAMES["theta"]
trace_path = input_dir / FIT_FILE_NAMES["trace"]
states_path = input_dir / FIT_FILE_NAMES["states"]
metadata_path = input_dir / FIT_FILE_NAMES["metadata"]
for label, p in [
("theta", theta_path),
("trace", trace_path),
("states", states_path),
("metadata", metadata_path),
]:
if not p.exists():
msg = f"Missing fit bundle file: {label} -> {p}"
raise FileNotFoundError(msg)
with theta_path.open("rb") as f:
params = _params_from_dict(tomllib.load(f))
trace_table = pq.read_table(trace_path) # type: ignore[no-untyped-call]
restart_col = trace_table.column("restart").to_pylist()
ll_col = trace_table.column("log_likelihood").to_pylist()
grouped: dict[int, list[float]] = {}
for r, ll in zip(restart_col, ll_col, strict=True):
grouped.setdefault(int(r), []).append(float(ll))
n_restarts = (max(grouped.keys()) + 1) if grouped else 0
traces = tuple(tuple(grouped.get(i, [])) for i in range(n_restarts))
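    # Rows were written household-major, then t-minor, and Parquet preserves row
    # order, so the flat state codes reshape directly to (n_households, t_plus_1).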
states_table = pq.read_table(states_path) # type: ignore[no-untyped-call]
n_total = states_table.num_rows
if n_total > 0:
codes = np.array(states_table.column("state_code").to_pylist(), dtype=np.int64)
t_col = np.array(states_table.column("t").to_pylist(), dtype=np.int64)
t_plus_1 = int(t_col.max()) + 1
n_households = n_total // t_plus_1
posterior = codes.reshape(n_households, t_plus_1)
else:
posterior = np.zeros((0, 0), dtype=np.int64)
with metadata_path.open("rb") as f:
metadata = tomllib.load(f)
return FitBundle(
params=params,
log_likelihood_traces=traces,
posterior_states=posterior,
best_index=int(metadata.get("best_index", 0)),
        iterations_per_restart=tuple(int(x) for x in metadata.get("iterations_per_restart", [])),
        converged_per_restart=tuple(bool(x) for x in metadata.get("converged_per_restart", [])),
final_log_likelihoods=tuple(float(x) for x in metadata.get("final_log_likelihoods", [])),
)