Spaces:
Running
Running
| """ | |
CLI driver: run a foundation model zero-shot on one ISOMORPH release.

Outputs:
    results/{model_short}_{dataset}.csv          long-format per-channel metrics
    results/{model_short}_{dataset}_summary.csv  cross-channel mean/median/quartiles/count
    results/{model_short}_{dataset}_tensors.npz  raw predictions and targets for post-hoc slicing
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| from data_utils import load_dataset, iter_test_windows | |
| from metrics import metrics_at_horizons, to_long_dataframe, HORIZONS | |
| from chronos_runner import ( | |
| load_chronos, predict_rolling_origin, collect_y_true, | |
| ) | |
def short_name(model_id: str) -> str:
    """Return a filename-friendly short name for a model identifier.

    Takes the last ``/``-separated component of ``model_id`` and replaces
    hyphens with underscores, e.g. ``"amazon/chronos-t5-base"`` becomes
    ``"chronos_t5_base"``.

    Trailing slashes are ignored so that an identifier written as a
    directory path (``"models/chronos-t5-base/"``) does not collapse to an
    empty name, which would produce output files named ``_{label}.csv``.
    """
    return model_id.rstrip("/").split("/")[-1].replace("-", "_")
def run(out_dir: Path, model_id: str, results_dir: Path,
        L: int, H: int, stride: int,
        num_samples: int, channel_batch: int,
        max_windows: int | None = None,
        label: str | None = None) -> None:
    """Evaluate one model on one dataset directory and write the results.

    The dataset is loaded from ``out_dir``, test windows are enumerated with
    ``iter_test_windows``, forecasts are produced by
    ``predict_rolling_origin`` and scored with ``metrics_at_horizons`` at
    ``HORIZONS``. Three files are written into ``results_dir`` (created if
    missing), where ``sn`` is ``short_name(model_id)``:

    * ``{sn}_{label}.csv``          long-format per-channel metrics
    * ``{sn}_{label}_tensors.npz``  predictions, targets and run metadata
    * ``{sn}_{label}_summary.csv``  mean/median/q25/q75/count per
      (model, dataset, metric, h)

    Progress messages and a median preview table go to stderr.

    Args:
        out_dir: Dataset directory handed to ``load_dataset``.
        model_id: Model identifier handed to ``load_chronos``; also recorded
            in every output.
        results_dir: Directory that receives the output files.
        L, H, stride: Windowing parameters forwarded to ``iter_test_windows``
            and ``predict_rolling_origin`` (presumably context length,
            forecast horizon and step between origins -- confirm against
            ``data_utils``).
        num_samples: Forwarded unchanged to ``predict_rolling_origin``.
        channel_batch: Forwarded unchanged to ``predict_rolling_origin``.
        max_windows: If given, only the first ``max_windows`` windows are
            evaluated (intended for smoke tests).
        label: If given, overrides ``split.label``, which is used in output
            filenames and in the ``dataset`` column.

    Raises:
        ValueError: If no rolling-origin windows are available to evaluate.
    """
    print(f"=== {out_dir.name} with {model_id} ===", file=sys.stderr)
    split = load_dataset(out_dir)
    if label is not None:
        split.label = label
    print(f" T={split.T} n_items={split.n_items} "
          f"train_end={split.train_end} val_end={split.val_end} "
          f"test=[{split.test_start}, {split.T})", file=sys.stderr)

    starts = list(iter_test_windows(split, L=L, H=H, stride=stride))
    if max_windows is not None:
        starts = starts[:max_windows]
    print(f" rolling-origin windows: {len(starts)} "
          f"(L={L}, H={H}, stride={stride})", file=sys.stderr)
    # Fail fast, before the expensive model load, when there is nothing to
    # evaluate (test span too short for L/H, or max_windows == 0).
    if not starts:
        raise ValueError(
            f"no rolling-origin windows to evaluate for {out_dir} "
            f"(L={L}, H={H}, stride={stride}, max_windows={max_windows})"
        )

    pipe = load_chronos(model_id)
    # perf_counter is unaffected by wall-clock adjustments, unlike time.time.
    t0 = time.perf_counter()
    y_pred = predict_rolling_origin(
        pipe, split.D, starts, L=L, H=H,
        num_samples=num_samples, channel_batch=channel_batch,
    )
    y_true = collect_y_true(split.D, starts, H)
    elapsed = time.perf_counter() - t0
    print(f" inference done in {elapsed/60:.1f} min", file=sys.stderr)

    metric_dict = metrics_at_horizons(y_true, y_pred, split.mase_denom,
                                      horizons=HORIZONS)
    long_df = to_long_dataframe(metric_dict, split.item_ids,
                                model=model_id, dataset=split.label)

    sn = short_name(model_id)
    results_dir.mkdir(parents=True, exist_ok=True)
    out_long = results_dir / f"{sn}_{split.label}.csv"
    long_df.to_csv(out_long, index=False)
    print(f" -> {out_long}", file=sys.stderr)

    # Persist raw tensors for post-hoc slicing (e.g. stationary-vs-shock).
    out_npz = results_dir / f"{sn}_{split.label}_tensors.npz"
    np.savez_compressed(
        out_npz, y_pred=y_pred, y_true=y_true,
        window_starts=np.asarray(starts, dtype=np.int64),
        item_ids=np.asarray(split.item_ids),
        L=L, H=H, stride=stride, model=model_id, dataset=split.label,
    )
    print(f" -> {out_npz}", file=sys.stderr)

    # Cross-channel summary at each (metric, h).
    summary = (long_df
               .groupby(["model", "dataset", "metric", "h"])["value"]
               .agg(mean="mean", median="median",
                    q25=lambda x: x.quantile(0.25),
                    q75=lambda x: x.quantile(0.75),
                    n="count")
               .reset_index())
    out_sum = results_dir / f"{sn}_{split.label}_summary.csv"
    summary.to_csv(out_sum, index=False)
    print(f" -> {out_sum}", file=sys.stderr)

    # Print a compact body-table preview.
    print("\n Headline (median across channels):", file=sys.stderr)
    print(summary.pivot_table(index="metric", columns="h",
                              values="median").to_string(),
          file=sys.stderr)
def main():
    """Parse command-line options and launch a single evaluation run.

    The dataset directory is ``--scenario_path`` when supplied, otherwise
    ``--root`` joined with ``--dataset``. All remaining options are passed
    straight through to ``run``.
    """
    # Defaults are resolved relative to the parent of this file's directory.
    repo_root = Path(__file__).resolve().parents[1]
    default_out = repo_root / "results" / "eval" / "baseline_and_scenarios"

    parser = argparse.ArgumentParser()
    parser.add_argument("--root", default=str(repo_root / "data"))
    parser.add_argument("--dataset", default="output_item50")
    parser.add_argument(
        "--scenario_path",
        default=None,
        help="Path to a scenario directory "
             "(e.g. data/output_mixture/baseline/seed2025). "
             "Overrides --root/--dataset when set.",
    )
    parser.add_argument(
        "--label",
        default=None,
        help="Output filename label. Defaults to the directory "
             "name. For scenario paths, set this to the scenario "
             "name so results from different scenarios don't collide.",
    )
    parser.add_argument("--model_id", default="amazon/chronos-t5-base")
    parser.add_argument("--out", default=str(default_out))
    parser.add_argument("--L", type=int, default=512)
    parser.add_argument("--H", type=int, default=30)
    parser.add_argument("--stride", type=int, default=30)
    parser.add_argument("--num_samples", type=int, default=20)
    parser.add_argument("--channel_batch", type=int, default=16)
    parser.add_argument(
        "--max_windows",
        type=int,
        default=None,
        help="cap for smoke testing",
    )
    opts = parser.parse_args()

    # An explicit scenario directory takes precedence over --root/--dataset.
    if opts.scenario_path is None:
        dataset_dir = Path(opts.root) / opts.dataset
    else:
        dataset_dir = Path(opts.scenario_path)

    run(
        dataset_dir,
        opts.model_id,
        Path(opts.out),
        L=opts.L,
        H=opts.H,
        stride=opts.stride,
        num_samples=opts.num_samples,
        channel_batch=opts.channel_batch,
        max_windows=opts.max_windows,
        label=opts.label,
    )
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()