| |
| """Run fixed-feature head-selection regret for global and top-k fire-prone scopes.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import importlib.util |
| import json |
| import math |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
|
|
|
|
# Dynamically load the shared base runner by file path (it lives under
# task_scripts/ and is not an importable package module), then reuse its
# helpers throughout this script.
BASE_RUNNER = Path(__file__).resolve().parent / "task_scripts" / "run_all_backbone_selection_regret_20260504.py"
spec = importlib.util.spec_from_file_location("selection_regret_base_20260504", BASE_RUNNER)
if spec is None or spec.loader is None:
    raise RuntimeError(f"Cannot import base runner: {BASE_RUNNER}")
base = importlib.util.module_from_spec(spec)
spec.loader.exec_module(base)


# Convenience alias for the base runner's head-control helper module.
head_control = base.head_control


# Default fire-prone top fractions swept by this script (5%, 10%, 20%).
SCOPE_FRACS = (0.05, 0.10, 0.20)
# Canonical scope ordering and display labels for the default fractions.
# NOTE(review): main() rebuilds the effective scope order from the CLI
# fractions, so these two tables reflect the defaults only.
SCOPE_ORDER = ("global", "top5", "top10", "top20")
SCOPE_LABELS = {
    "global": "global",
    "top5": "top 5%",
    "top10": "top 10%",
    "top20": "top 20%",
}
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the selection-regret scope sweep."""
    default_thresholds = [
        1e-5,
        2e-5,
        5e-5,
        1e-4,
        2e-4,
        5e-4,
        1e-3,
        2e-3,
        5e-3,
        1e-2,
        2e-2,
        5e-2,
        8e-2,
        1e-1,
        1.5e-1,
        2e-1,
        3e-1,
        5e-1,
    ]
    parser = argparse.ArgumentParser(description="Selection-regret scope sweep.")
    # Data source and filesystem locations.
    parser.add_argument("--source-kind", choices=("reference", "attached", "spatial", "alphaearth"), required=True)
    parser.add_argument("--feature-root", type=Path, required=True)
    parser.add_argument("--daily-rows-csv", type=Path)
    parser.add_argument("--support-dir", type=Path)
    parser.add_argument("--alphaearth-cache-root", type=Path)
    parser.add_argument("--output-dir", type=Path, required=True)
    # Run identity.
    parser.add_argument("--fm-family", type=str, required=True)
    parser.add_argument("--model-tag", type=str, required=True)
    parser.add_argument("--seed", type=int, required=True)
    # Head training configuration.
    parser.add_argument("--heads", nargs="+", choices=base.HEADS, default=["linear", "pixel_mlp", "shallow"])
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=8e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-5)
    parser.add_argument("--pos-weight-cap", type=float, default=150.0)
    parser.add_argument("--device", choices=("cpu", "cuda", "auto"), default="cpu")
    # Post-hoc evaluation configuration.
    parser.add_argument("--metric-thresholds", nargs="+", type=float, default=default_thresholds)
    parser.add_argument("--variants", nargs="+", default=["identity"])
    parser.add_argument("--fire-prone-top-fracs", nargs="+", type=float, default=list(SCOPE_FRACS))
    parser.add_argument("--temporal-steps", type=int, default=3)
    parser.add_argument("--spatial-radius", type=int, default=8)
    parser.add_argument("--buffer-radius", type=int, default=8)
    parser.add_argument("--boundary-radius", type=int, default=8)
    parser.add_argument("--coarse-factor", type=int, default=8)
    parser.add_argument("--time-step-hours", type=int, default=6)
    return parser.parse_args()
|
|
|
|
def scope_name(top_frac: float) -> str:
    """Machine-friendly scope key for a top fraction, e.g. ``0.05`` -> ``"top5"``."""
    percent = int(round(float(top_frac) * 100.0))
    return "top{}".format(percent)
|
|
|
|
def scope_label(top_frac: float) -> str:
    """Human-readable scope label for a top fraction, e.g. ``0.05`` -> ``"top 5%"``."""
    percent = int(round(float(top_frac) * 100.0))
    return "top " + str(percent) + "%"
|
|
|
|
def build_scope_masks(
    split_rows: dict[str, list[dict[str, str]]],
    store: Any,
    top_fracs: list[float],
) -> tuple[dict[str, np.ndarray | None], dict[str, dict[str, Any]]]:
    """Build region masks and metadata for the global scope plus fire-prone scopes.

    The "global" scope has no mask (``None`` means evaluate everywhere); each
    requested top fraction gets a fire-prone mask derived from the training
    rows via ``head_control.build_fire_prone_mask``.
    """
    scope_masks: dict[str, np.ndarray | None] = {"global": None}
    scope_info: dict[str, dict[str, Any]] = {
        "global": {
            "scope_name": "global",
            "reported_as": "global",
            "top_fraction": None,
        }
    }
    train_rows = split_rows["train"]
    for fraction in top_fracs:
        key = scope_name(fraction)
        region_mask, extra_meta = head_control.build_fire_prone_mask(train_rows, store, float(fraction))
        scope_masks[key] = region_mask
        scope_info[key] = {
            "scope_name": key,
            "reported_as": scope_label(fraction),
            **extra_meta,
        }
    return scope_masks, scope_info
|
|
|
|
def build_posthoc_rows_for_scopes(
    probs: np.ndarray,
    targets: np.ndarray,
    sample_times: np.ndarray,
    split: str,
    scope_masks: dict[str, np.ndarray | None],
    args: argparse.Namespace,
) -> list[dict[str, object]]:
    """Evaluate every (threshold, variant, scope) combination for one split.

    For each probability threshold and prediction variant, the binarized
    predictions are scored once via ``head_control.evaluate_threshold_variant``;
    the resulting tensors are then restricted to each scope's region mask to
    produce one metrics row per scope.
    """
    thresholds = [float(v) for v in args.metric_thresholds]
    tolerance_hours = args.temporal_steps * args.time_step_hours
    results: list[dict[str, object]] = []
    for threshold in thresholds:
        thresholded = probs >= threshold
        for variant in args.variants:
            variant_binary = head_control.apply_variant(thresholded, variant)
            tensors = head_control.evaluate_threshold_variant(
                binary_np=variant_binary,
                target_np=targets,
                sample_times=sample_times,
                time_step_hours=args.time_step_hours,
                temporal_steps=args.temporal_steps,
                spatial_radius=args.spatial_radius,
                buffer_radius=args.buffer_radius,
                boundary_radius=args.boundary_radius,
                coarse_factor=args.coarse_factor,
                tolerance_hours=tolerance_hours,
            )
            for scope, region_mask in scope_masks.items():
                record: dict[str, object] = {
                    "split": split,
                    "scope": scope,
                    "threshold": float(threshold),
                    "variant": variant,
                    "time_step_hours": int(args.time_step_hours),
                    "temporal_steps": int(args.temporal_steps),
                    "tolerance_hours": int(tolerance_hours),
                    "spatial_radius": int(args.spatial_radius),
                    "buffer_radius": int(args.buffer_radius),
                    "boundary_radius": int(args.boundary_radius),
                    "coarse_factor": int(args.coarse_factor),
                }
                record.update(head_control.metrics_for_scope(tensors, region_mask))
                results.append(record)
    return results
|
|
|
|
def read_csv(path: Path) -> list[dict[str, str]]:
    """Load a CSV file as a list of header-keyed string dicts."""
    with path.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        return [row for row in reader]
|
|
|
|
def load_head_summary(
    head_dir: Path,
    head_arch: str,
    scopes: tuple[str, ...],
) -> tuple[list[dict[str, object]], dict[str, dict[str, float]], dict[str, object]] | None:
    """Load cached per-head artifacts, returning None if anything is missing or stale.

    A cache hit requires both ``posthoc_rows.csv`` (non-empty) and a parseable
    ``summary.json`` whose ``head_arch`` matches and whose ``raw_pr_auc`` covers
    every requested scope for both the val and test splits.
    """
    rows_path = head_dir / "posthoc_rows.csv"
    meta_path = head_dir / "summary.json"
    if not (rows_path.exists() and meta_path.exists()):
        return None
    cached_rows = [dict(row) for row in read_csv(rows_path)]
    if not cached_rows:
        return None
    try:
        summary = json.loads(meta_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return None
    if str(summary.get("head_arch")) != str(head_arch):
        return None
    pr_auc_blob = summary.get("raw_pr_auc")
    if not isinstance(pr_auc_blob, dict):
        return None
    parsed_pr_auc: dict[str, dict[str, float]] = {}
    try:
        # Any missing split/scope key or non-numeric value invalidates the cache.
        for split in ("val", "test"):
            parsed_pr_auc[split] = {scope: float(pr_auc_blob[split][scope]) for scope in scopes}
    except Exception:
        return None
    return cached_rows, parsed_pr_auc, summary
|
|
|
|
def finite_json(value: Any) -> Any:
    """Recursively replace non-finite floats (NaN/±inf) with None for strict JSON.

    ``json.dumps`` serializes non-finite floats as the literals ``NaN`` /
    ``Infinity``, which are not valid JSON, so this sanitizer maps them to
    ``null`` instead. Dicts and lists are walked recursively; tuples are now
    walked too (the original implementation let non-finite floats inside tuples
    slip through) and come back as lists, matching how ``json.dumps`` would
    serialize a tuple anyway. All other values pass through unchanged.
    """
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: finite_json(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        return [finite_json(val) for val in value]
    return value
|
|
|
|
def main() -> None:
    """Entry point: train or reuse each head per scope, then write regret summaries."""
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    base.set_seed(int(args.seed))
    device = base.choose_device(args.device)


    # Deduplicate/sort the requested fire-prone fractions; the effective scope
    # order is "global" followed by each topN scope in ascending fraction order.
    top_fracs = sorted({float(v) for v in args.fire_prone_top_fracs})
    scope_order = ("global",) + tuple(scope_name(frac) for frac in top_fracs)
    # NOTE(review): monkey-patches the base module's SCOPE_ORDER so its helpers
    # (e.g. summarize_head_scores) see this run's scopes — presumably they read
    # the module-level constant; confirm against the base runner.
    base.SCOPE_ORDER = scope_order


    # Load split row lists and build the feature store for the chosen source.
    split_rows = {
        split: base.read_rows(args.feature_root / "splits" / f"{split}.csv")
        for split in ("train", "val", "test")
    }
    if args.source_kind == "reference":
        store = base.build_reference_store(split_rows)
    elif args.source_kind == "attached":
        store = base.build_attached_store(args, split_rows)
    elif args.source_kind == "spatial":
        store = base.build_spatial_store(args, split_rows)
    else:
        # Only "alphaearth" remains — argparse restricts --source-kind choices.
        store = base.build_alphaearth_store(args, split_rows)


    loaders = base.make_loaders(split_rows, store, int(args.batch_size), device, int(args.seed))
    # Peek one training batch to discover the input channel count.
    first = next(iter(loaders["train"]))
    in_ch = int(first["x"].shape[1])
    prior_prob = base.total_positive_rate(split_rows["train"])
    scope_masks, scope_meta = build_scope_masks(split_rows, store, top_fracs)


    head_metrics: list[dict[str, object]] = []
    head_artifacts: dict[str, str] = {}
    for head_index, head_arch in enumerate(args.heads):
        head_dir = args.output_dir / head_arch
        head_dir.mkdir(parents=True, exist_ok=True)
        # Reuse cached per-head artifacts when a compatible summary exists;
        # otherwise train the head and regenerate them.
        cached = load_head_summary(head_dir, head_arch, scope_order)
        if cached is not None:
            posthoc_rows, raw_pr_auc, _ = cached
            print(f"[scope-sweep] reuse {args.fm_family} seed={args.seed} head={head_arch}", flush=True)
        else:
            print(f"[scope-sweep] training {args.fm_family} seed={args.seed} head={head_arch}", flush=True)
            model, history = base.train_one_head(
                head_arch=head_arch,
                in_ch=in_ch,
                prior_prob=prior_prob,
                loaders=loaders,
                args=args,
                device=device,
                # Distinct deterministic seed offset per head position.
                seed_offset=1009 * (head_index + 1),
            )
            posthoc_rows = []
            raw_pr_auc: dict[str, dict[str, float]] = {}
            for split in ("val", "test"):
                probs, targets = base.collect_predictions(model, loaders[split], device)
                sample_times = base.build_sample_times(split_rows[split])
                # Thresholdless PR-AUC per scope mask (None mask = global).
                raw_pr_auc[split] = {
                    scope: head_control._masked_average_precision(probs, targets, region_mask=mask)
                    for scope, mask in scope_masks.items()
                }
                posthoc_rows.extend(
                    build_posthoc_rows_for_scopes(
                        probs=probs,
                        targets=targets,
                        sample_times=sample_times,
                        split=split,
                        scope_masks=scope_masks,
                        args=args,
                    )
                )
            base.write_csv(posthoc_rows, head_dir / "posthoc_rows.csv")
            head_summary = {
                "head_arch": head_arch,
                "head_label": head_control.HEAD_LABELS[head_arch],
                "history": history,
                "raw_pr_auc": raw_pr_auc,
                "scope_meta": scope_meta,
                "posthoc_rows_csv": str(head_dir / "posthoc_rows.csv"),
            }
            (head_dir / "summary.json").write_text(json.dumps(finite_json(head_summary), indent=2), encoding="utf-8")
        # Both the cached and freshly-trained paths record the artifact path and
        # fold this head's rows into the cross-head metrics table.
        head_artifacts[head_arch] = str(head_dir / "summary.json")
        base.append_head_metrics(head_metrics, posthoc_rows, raw_pr_auc, head_arch, args)


    # Summarize head-selection scores and tag each row with this run's identity.
    selection_rows = base.summarize_head_scores(head_metrics)
    for row in selection_rows:
        row["model_tag"] = args.model_tag
        row["family"] = args.fm_family
        row["seed"] = int(args.seed)


    base.write_csv(head_metrics, args.output_dir / "head_metrics.csv")
    base.write_csv(selection_rows, args.output_dir / "selection_rows.csv")
    # Top-level run summary: configuration, scope metadata, metrics, artifacts.
    summary = {
        "experiment": "fixed-feature head-selection regret scope sweep",
        "task": "wildfire_occupancy",
        "model_tag": args.model_tag,
        "fm_family": args.fm_family,
        "source_kind": args.source_kind,
        "seed": int(args.seed),
        "feature_root": str(args.feature_root),
        "daily_rows_csv": str(args.daily_rows_csv) if args.daily_rows_csv else None,
        "support_dir": str(args.support_dir) if args.support_dir else None,
        "alphaearth_cache_root": str(args.alphaearth_cache_root) if args.alphaearth_cache_root else None,
        "device": str(device),
        "heads": list(args.heads),
        "scope_order": list(scope_order),
        "scope_meta": scope_meta,
        "input_channels": int(in_ch),
        "prior_prob": float(prior_prob),
        "metrics": base.METRICS,
        "head_metrics": head_metrics,
        "selection_rows": selection_rows,
        "head_artifacts": head_artifacts,
        "artifacts": {
            "head_metrics_csv": str(args.output_dir / "head_metrics.csv"),
            "selection_rows_csv": str(args.output_dir / "selection_rows.csv"),
        },
    }
    # Sanitize non-finite floats before serializing (strict JSON has no NaN).
    (args.output_dir / "summary.json").write_text(json.dumps(finite_json(summary), indent=2), encoding="utf-8")
    print(json.dumps(finite_json(summary), indent=2), flush=True)
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|