# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Parquet output and TOML config sidecar serialization."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import tomli_w
from iohmm_evac.dgp.simulator import SimulationResult, encode_evac_path
from iohmm_evac.params import SimulationConfig, to_nested_dict
from iohmm_evac.types import State
__all__ = ["write_results"]

# Object-dtype array holding the name of every ``State`` member, in enum
# iteration order. ``_long_table`` indexes it with the raw state codes, which
# presumably are 0-based positions in that same order -- TODO confirm against
# the ``State`` definition.
_STATE_LABELS = np.array([member.name for member in State], dtype=object)
def _long_table(result: SimulationResult) -> pa.Table:
    """Build the long-format observation table, one row per (household, t).

    ``result.states`` is unpacked as a 2-D array; every per-observation array
    is flattened with ``reshape(-1)`` so that each household's rows are
    contiguous and ``t`` counts up from 0 within each household.

    Columns: household_id, t, state, departure, displacement, comm_count.
    """
    n_households, n_times = result.states.shape

    # Row-major flattening: each household id is repeated once per time step,
    # while the time index cycles once per household.
    household_ids = np.repeat(np.arange(n_households, dtype=np.int64), n_times)
    time_index = np.tile(np.arange(n_times, dtype=np.int64), n_households)

    # Translate the flattened state codes into their enum member names.
    state_names = _STATE_LABELS[result.states.reshape(-1)]

    columns = {
        "household_id": pa.array(household_ids),
        "t": pa.array(time_index),
        "state": pa.array(state_names),
        "departure": pa.array(result.departures.reshape(-1)),
        "displacement": pa.array(result.displacements.reshape(-1)),
        "comm_count": pa.array(result.communications.reshape(-1)),
    }
    return pa.table(columns)
def _population_table(result: SimulationResult) -> pa.Table:
    """Build the per-household attribute table, one row per household.

    Columns: household_id, distance_km, vehicle, risk, zone, destination_km,
    evac_path. Zone codes are mapped to the labels "A", "B", "C" by position,
    ``vehicle`` is cast to bool, and ``evac_path`` is whatever
    ``encode_evac_path`` produces from ``result.evac_path``.
    """
    population = result.population

    # Integer zone codes index into the label array (0 -> "A", 1 -> "B", ...).
    zone_names = np.array(["A", "B", "C"], dtype=object)
    zone_labels = zone_names[population.zone]

    columns = {
        "household_id": pa.array(np.arange(population.n, dtype=np.int64)),
        "distance_km": pa.array(population.distance),
        "vehicle": pa.array(population.vehicle.astype(bool)),
        "risk": pa.array(population.risk),
        "zone": pa.array(zone_labels),
        "destination_km": pa.array(population.destination),
        "evac_path": pa.array(encode_evac_path(result.evac_path)),
    }
    return pa.table(columns)
def _timeline_table(result: SimulationResult) -> pa.Table:
    """Build the timeline table, one row per entry of the forecast series.

    Columns: t, forecast, voluntary, mandatory, time_since_order. The row
    count is taken from the first axis of ``timeline.forecast``; the
    ``voluntary`` and ``mandatory`` series are cast to bool.
    """
    timeline = result.timeline
    step_count = timeline.forecast.shape[0]

    columns = {
        "t": pa.array(np.arange(step_count, dtype=np.int64)),
        "forecast": pa.array(timeline.forecast),
        "voluntary": pa.array(timeline.voluntary.astype(bool)),
        "mandatory": pa.array(timeline.mandatory.astype(bool)),
        "time_since_order": pa.array(timeline.time_since_order),
    }
    return pa.table(columns)
def _write_table(table: pa.Table, path: Path) -> None:
    """Write ``table`` to ``path`` as a Parquet file using pyarrow defaults."""
    pq.write_table(table, where=path)  # type: ignore[no-untyped-call]
def _config_to_toml(config: SimulationConfig) -> bytes:
    """Serialize ``config`` to UTF-8 encoded TOML bytes.

    The config is first converted with ``to_nested_dict`` and the resulting
    mapping is dumped with ``tomli_w``.

    Raises:
        TypeError: if ``to_nested_dict`` does not return a ``dict``.
    """
    nested = to_nested_dict(config)
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and the isinstance test still narrows the type for
    # static checkers.
    if not isinstance(nested, dict):
        raise TypeError(
            f"to_nested_dict returned {type(nested).__name__}, expected dict"
        )
    return tomli_w.dumps(nested).encode("utf-8")
def write_results(result: SimulationResult, output: Path) -> dict[str, Path]:
    """Persist a simulation run as three Parquet files plus a TOML config.

    The observation panel is written to ``output`` exactly as given; the
    sidecar files are placed in the same directory and named from
    ``output.stem``:

    * ``output`` — observation panel (one row per (household, t)).
    * ``<stem>.population.parquet``
    * ``<stem>.timeline.parquet``
    * ``<stem>.config.toml``

    Missing parent directories of ``output`` are created before writing.

    Returns a mapping with the keys ``"observations"``, ``"population"``,
    ``"timeline"`` and ``"config"`` giving the path of each file written.
    """
    target = Path(output)
    folder = target.parent
    base = target.stem

    written: dict[str, Path] = {
        "observations": target,
        "population": folder / f"{base}.population.parquet",
        "timeline": folder / f"{base}.timeline.parquet",
        "config": folder / f"{base}.config.toml",
    }

    folder.mkdir(parents=True, exist_ok=True)

    # Write order: observation panel, then the sidecars, config last.
    _write_table(_long_table(result), written["observations"])
    _write_table(_population_table(result), written["population"])
    _write_table(_timeline_table(result), written["timeline"])
    written["config"].write_bytes(_config_to_toml(result.config))

    return written