src/iohmm_evac/io.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Parquet output and TOML config sidecar serialization."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import tomli_w

from iohmm_evac.dgp.simulator import SimulationResult, encode_evac_path
from iohmm_evac.params import SimulationConfig, to_nested_dict
from iohmm_evac.types import State

# Public API of this module: only ``write_results`` is exported via ``import *``.
__all__ = ["write_results"]


# Lookup table from integer state code to the ``State`` member's name, built in
# ``State`` iteration order; ``_long_table`` fancy-indexes it with the raw codes
# stored in ``result.states``. ``dtype=object`` keeps the entries as Python str.
# NOTE(review): this positional lookup assumes each State member's integer code
# equals its position in the enum (0, 1, 2, ...) — confirm against
# iohmm_evac.types.State.
_STATE_LABELS = np.array([s.name for s in State], dtype=object)


def _long_table(result: SimulationResult) -> pa.Table:
    """Build the long-format observation table (one row per household per step).

    Rows are ordered household-major: every time step of household 0, then
    every time step of household 1, and so on. This matches a C-order flatten
    of the ``(n_households, n_steps)`` array ``result.states``. The
    departure/displacement/communication arrays are presumably the same shape
    as ``states``; pyarrow rejects columns of unequal length.

    Columns: household_id, t, state, departure, displacement, comm_count.
    """
    n_households, n_times = result.states.shape
    # Household ids repeat once per time step; time indices cycle per household.
    ids = np.repeat(np.arange(n_households, dtype=np.int64), n_times)
    steps = np.tile(np.arange(n_times, dtype=np.int64), n_households)
    labels = _STATE_LABELS[result.states.reshape(-1)]

    columns = {
        "household_id": ids,
        "t": steps,
        "state": labels,
        "departure": result.departures.reshape(-1),
        "displacement": result.displacements.reshape(-1),
        "comm_count": result.communications.reshape(-1),
    }
    return pa.table({name: pa.array(values) for name, values in columns.items()})


def _population_table(result: SimulationResult) -> pa.Table:
    """Build the per-household population table (one row per household).

    Integer zone codes 0/1/2 are rendered as the letters "A"/"B"/"C", the
    vehicle column is coerced to bool, and the evacuation path column holds
    whatever ``encode_evac_path`` produces for ``result.evac_path``.
    """
    population = result.population
    # Map zone codes to letters first (raises IndexError on an unknown code).
    zone_letters = np.array(["A", "B", "C"], dtype=object)[population.zone]

    columns = {
        "household_id": np.arange(population.n, dtype=np.int64),
        "distance_km": population.distance,
        "vehicle": population.vehicle.astype(bool),
        "risk": population.risk,
        "zone": zone_letters,
        "destination_km": population.destination,
        "evac_path": encode_evac_path(result.evac_path),
    }
    return pa.table({key: pa.array(col) for key, col in columns.items()})


def _timeline_table(result: SimulationResult) -> pa.Table:
    """Build the per-time-step timeline table (one row per step).

    The row count comes from the length of the ``forecast`` series. The
    ``voluntary`` and ``mandatory`` flags are coerced to booleans; all other
    series are passed to pyarrow unchanged.
    """
    timeline = result.timeline
    step_index = np.arange(timeline.forecast.shape[0], dtype=np.int64)

    t_col = pa.array(step_index)
    forecast_col = pa.array(timeline.forecast)
    voluntary_col = pa.array(timeline.voluntary.astype(bool))
    mandatory_col = pa.array(timeline.mandatory.astype(bool))
    since_order_col = pa.array(timeline.time_since_order)

    return pa.table(
        {
            "t": t_col,
            "forecast": forecast_col,
            "voluntary": voluntary_col,
            "mandatory": mandatory_col,
            "time_since_order": since_order_col,
        }
    )


def _write_table(table: pa.Table, path: Path) -> None:
    """Serialize ``table`` to a Parquet file at ``path`` with pyarrow's default options."""
    pq.write_table(table, where=path)  # type: ignore[no-untyped-call]


def _config_to_toml(config: SimulationConfig) -> bytes:
    """Serialize ``config`` as UTF-8 encoded TOML bytes.

    The config is first converted with ``to_nested_dict``. The result must be a
    ``dict``, because the top level of a TOML document is always a table.

    Raises:
        TypeError: if ``to_nested_dict`` does not return a ``dict``.
    """
    nested = to_nested_dict(config)
    # An explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a non-dict reach ``tomli_w.dumps`` and fail
    # there with a less helpful error. The isinstance check still narrows the
    # type for static checkers, as the assert did.
    if not isinstance(nested, dict):
        raise TypeError(
            "to_nested_dict() must return a dict for TOML output, "
            f"got {type(nested).__name__}"
        )
    # NOTE(review): TOML has no null value; presumably to_nested_dict omits or
    # converts None fields — confirm, since tomli_w would reject them.
    return tomli_w.dumps(nested).encode("utf-8")


def write_results(result: SimulationResult, output: Path) -> dict[str, Path]:
    """Write the simulation observations plus population, timeline and config sidecars.

    The observation panel (one row per (household, t)) is written to ``output``
    exactly as given, conventionally ``<stem>.parquet``. The sidecars go in the
    same directory and are named after ``output``'s stem (its last suffix removed):

    * ``<stem>.population.parquet`` — per-household attributes.
    * ``<stem>.timeline.parquet`` — per-time-step timeline.
    * ``<stem>.config.toml`` — the simulation config serialized as TOML.

    The parent directory of ``output`` is created if it does not exist.

    Returns:
        A mapping from ``"observations"``, ``"population"``, ``"timeline"`` and
        ``"config"`` to the path written for each.
    """
    obs_path = Path(output)
    directory = obs_path.parent
    directory.mkdir(parents=True, exist_ok=True)

    base = obs_path.stem
    paths = {
        "observations": obs_path,
        "population": directory / f"{base}.population.parquet",
        "timeline": directory / f"{base}.timeline.parquet",
        "config": directory / f"{base}.config.toml",
    }

    _write_table(_long_table(result), paths["observations"])
    _write_table(_population_table(result), paths["population"])
    _write_table(_timeline_table(result), paths["timeline"])
    paths["config"].write_bytes(_config_to_toml(result.config))

    return paths