#!/usr/bin/env python3 """Run fixed-feature head-selection regret for global and top-k fire-prone scopes.""" from __future__ import annotations import argparse import csv import importlib.util import json import math from pathlib import Path from typing import Any import numpy as np BASE_RUNNER = Path(__file__).resolve().parent / "task_scripts" / "run_all_backbone_selection_regret_20260504.py" spec = importlib.util.spec_from_file_location("selection_regret_base_20260504", BASE_RUNNER) if spec is None or spec.loader is None: raise RuntimeError(f"Cannot import base runner: {BASE_RUNNER}") base = importlib.util.module_from_spec(spec) spec.loader.exec_module(base) head_control = base.head_control SCOPE_FRACS = (0.05, 0.10, 0.20) SCOPE_ORDER = ("global", "top5", "top10", "top20") SCOPE_LABELS = { "global": "global", "top5": "top 5%", "top10": "top 10%", "top20": "top 20%", } def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Selection-regret scope sweep.") parser.add_argument("--source-kind", choices=("reference", "attached", "spatial", "alphaearth"), required=True) parser.add_argument("--feature-root", type=Path, required=True) parser.add_argument("--daily-rows-csv", type=Path) parser.add_argument("--support-dir", type=Path) parser.add_argument("--alphaearth-cache-root", type=Path) parser.add_argument("--output-dir", type=Path, required=True) parser.add_argument("--fm-family", type=str, required=True) parser.add_argument("--model-tag", type=str, required=True) parser.add_argument("--seed", type=int, required=True) parser.add_argument("--heads", nargs="+", choices=base.HEADS, default=["linear", "pixel_mlp", "shallow"]) parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--epochs", type=int, default=2) parser.add_argument("--learning-rate", type=float, default=8e-4) parser.add_argument("--weight-decay", type=float, default=1e-5) parser.add_argument("--pos-weight-cap", type=float, default=150.0) parser.add_argument("--device", choices=("cpu", "cuda", "auto"), default="cpu") parser.add_argument( "--metric-thresholds", nargs="+", type=float, default=[ 1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 8e-2, 1e-1, 1.5e-1, 2e-1, 3e-1, 5e-1, ], ) parser.add_argument("--variants", nargs="+", default=["identity"]) parser.add_argument("--fire-prone-top-fracs", nargs="+", type=float, default=list(SCOPE_FRACS)) parser.add_argument("--temporal-steps", type=int, default=3) parser.add_argument("--spatial-radius", type=int, default=8) parser.add_argument("--buffer-radius", type=int, default=8) parser.add_argument("--boundary-radius", type=int, default=8) parser.add_argument("--coarse-factor", type=int, default=8) parser.add_argument("--time-step-hours", type=int, default=6) return parser.parse_args() def scope_name(top_frac: float) -> str: pct = int(round(float(top_frac) * 100.0)) return f"top{pct}" def scope_label(top_frac: float) -> str: pct = int(round(float(top_frac) * 100.0)) return f"top {pct}%" def build_scope_masks( split_rows: dict[str, list[dict[str, str]]], store: Any, top_fracs: list[float], ) -> tuple[dict[str, np.ndarray | None], dict[str, dict[str, Any]]]: masks: dict[str, np.ndarray | None] = {"global": None} meta: dict[str, dict[str, Any]] = { "global": { "scope_name": "global", "reported_as": "global", "top_fraction": None, } } for frac in top_fracs: name = scope_name(frac) mask, mask_meta = head_control.build_fire_prone_mask(split_rows["train"], store, float(frac)) masks[name] = mask meta[name] = { "scope_name": name, "reported_as": scope_label(frac), **mask_meta, } return masks, meta def build_posthoc_rows_for_scopes( probs: np.ndarray, targets: np.ndarray, sample_times: np.ndarray, split: str, scope_masks: dict[str, np.ndarray | None], args: argparse.Namespace, ) -> list[dict[str, object]]: rows_out: list[dict[str, object]] = [] for threshold in [float(v) for v in args.metric_thresholds]: base_binary = probs >= threshold for variant in args.variants: binary = head_control.apply_variant(base_binary, variant) tensors = head_control.evaluate_threshold_variant( binary_np=binary, target_np=targets, sample_times=sample_times, time_step_hours=args.time_step_hours, temporal_steps=args.temporal_steps, spatial_radius=args.spatial_radius, buffer_radius=args.buffer_radius, boundary_radius=args.boundary_radius, coarse_factor=args.coarse_factor, tolerance_hours=args.temporal_steps * args.time_step_hours, ) for scope, region_mask in scope_masks.items(): row: dict[str, object] = { "split": split, "scope": scope, "threshold": float(threshold), "variant": variant, "time_step_hours": int(args.time_step_hours), "temporal_steps": int(args.temporal_steps), "tolerance_hours": int(args.temporal_steps * args.time_step_hours), "spatial_radius": int(args.spatial_radius), "buffer_radius": int(args.buffer_radius), "boundary_radius": int(args.boundary_radius), "coarse_factor": int(args.coarse_factor), } row.update(head_control.metrics_for_scope(tensors, region_mask)) rows_out.append(row) return rows_out def read_csv(path: Path) -> list[dict[str, str]]: with path.open("r", encoding="utf-8", newline="") as fh: return list(csv.DictReader(fh)) def load_head_summary( head_dir: Path, head_arch: str, scopes: tuple[str, ...], ) -> tuple[list[dict[str, object]], dict[str, dict[str, float]], dict[str, object]] | None: posthoc_path = head_dir / "posthoc_rows.csv" summary_path = head_dir / "summary.json" if not posthoc_path.exists() or not summary_path.exists(): return None rows = [dict(row) for row in read_csv(posthoc_path)] if not rows: return None try: summary = json.loads(summary_path.read_text(encoding="utf-8")) except json.JSONDecodeError: return None if str(summary.get("head_arch")) != str(head_arch): return None raw_pr_auc = summary.get("raw_pr_auc") if not isinstance(raw_pr_auc, dict): return None try: parsed_pr_auc = { split: {scope: float(raw_pr_auc[split][scope]) for scope in scopes} for split in ("val", "test") } except Exception: return None return rows, parsed_pr_auc, summary def finite_json(value: Any) -> Any: if isinstance(value, float): return value if math.isfinite(value) else None if isinstance(value, dict): return {key: finite_json(val) for key, val in value.items()} if isinstance(value, list): return [finite_json(val) for val in value] return value def main() -> None: args = parse_args() args.output_dir.mkdir(parents=True, exist_ok=True) base.set_seed(int(args.seed)) device = base.choose_device(args.device) top_fracs = sorted({float(v) for v in args.fire_prone_top_fracs}) scope_order = ("global",) + tuple(scope_name(frac) for frac in top_fracs) base.SCOPE_ORDER = scope_order split_rows = { split: base.read_rows(args.feature_root / "splits" / f"{split}.csv") for split in ("train", "val", "test") } if args.source_kind == "reference": store = base.build_reference_store(split_rows) elif args.source_kind == "attached": store = base.build_attached_store(args, split_rows) elif args.source_kind == "spatial": store = base.build_spatial_store(args, split_rows) else: store = base.build_alphaearth_store(args, split_rows) loaders = base.make_loaders(split_rows, store, int(args.batch_size), device, int(args.seed)) first = next(iter(loaders["train"])) in_ch = int(first["x"].shape[1]) prior_prob = base.total_positive_rate(split_rows["train"]) scope_masks, scope_meta = build_scope_masks(split_rows, store, top_fracs) head_metrics: list[dict[str, object]] = [] head_artifacts: dict[str, str] = {} for head_index, head_arch in enumerate(args.heads): head_dir = args.output_dir / head_arch head_dir.mkdir(parents=True, exist_ok=True) cached = load_head_summary(head_dir, head_arch, scope_order) if cached is not None: posthoc_rows, raw_pr_auc, _ = cached print(f"[scope-sweep] reuse {args.fm_family} seed={args.seed} head={head_arch}", flush=True) else: print(f"[scope-sweep] training {args.fm_family} seed={args.seed} head={head_arch}", flush=True) model, history = base.train_one_head( head_arch=head_arch, in_ch=in_ch, prior_prob=prior_prob, loaders=loaders, args=args, device=device, seed_offset=1009 * (head_index + 1), ) posthoc_rows = [] raw_pr_auc: dict[str, dict[str, float]] = {} for split in ("val", "test"): probs, targets = base.collect_predictions(model, loaders[split], device) sample_times = base.build_sample_times(split_rows[split]) raw_pr_auc[split] = { scope: head_control._masked_average_precision(probs, targets, region_mask=mask) for scope, mask in scope_masks.items() } posthoc_rows.extend( build_posthoc_rows_for_scopes( probs=probs, targets=targets, sample_times=sample_times, split=split, scope_masks=scope_masks, args=args, ) ) base.write_csv(posthoc_rows, head_dir / "posthoc_rows.csv") head_summary = { "head_arch": head_arch, "head_label": head_control.HEAD_LABELS[head_arch], "history": history, "raw_pr_auc": raw_pr_auc, "scope_meta": scope_meta, "posthoc_rows_csv": str(head_dir / "posthoc_rows.csv"), } (head_dir / "summary.json").write_text(json.dumps(finite_json(head_summary), indent=2), encoding="utf-8") head_artifacts[head_arch] = str(head_dir / "summary.json") base.append_head_metrics(head_metrics, posthoc_rows, raw_pr_auc, head_arch, args) selection_rows = base.summarize_head_scores(head_metrics) for row in selection_rows: row["model_tag"] = args.model_tag row["family"] = args.fm_family row["seed"] = int(args.seed) base.write_csv(head_metrics, args.output_dir / "head_metrics.csv") base.write_csv(selection_rows, args.output_dir / "selection_rows.csv") summary = { "experiment": "fixed-feature head-selection regret scope sweep", "task": "wildfire_occupancy", "model_tag": args.model_tag, "fm_family": args.fm_family, "source_kind": args.source_kind, "seed": int(args.seed), "feature_root": str(args.feature_root), "daily_rows_csv": str(args.daily_rows_csv) if args.daily_rows_csv else None, "support_dir": str(args.support_dir) if args.support_dir else None, "alphaearth_cache_root": str(args.alphaearth_cache_root) if args.alphaearth_cache_root else None, "device": str(device), "heads": list(args.heads), "scope_order": list(scope_order), "scope_meta": scope_meta, "input_channels": int(in_ch), "prior_prob": float(prior_prob), "metrics": base.METRICS, "head_metrics": head_metrics, "selection_rows": selection_rows, "head_artifacts": head_artifacts, "artifacts": { "head_metrics_csv": str(args.output_dir / "head_metrics.csv"), "selection_rows_csv": str(args.output_dir / "selection_rows.csv"), }, } (args.output_dir / "summary.json").write_text(json.dumps(finite_json(summary), indent=2), encoding="utf-8") print(json.dumps(finite_json(summary), indent=2), flush=True) if __name__ == "__main__": main()