Wildfire-FM / experiments /raw_reference /run_selection_regret_scope_sweep_20260505.py
yx21e's picture
Initial FireWx-FM artifact release
80ef3b2 verified
#!/usr/bin/env python3
"""Run fixed-feature head-selection regret for global and top-k fire-prone scopes."""
from __future__ import annotations

import argparse
import csv
import importlib.util
import json
import math
import sys
from pathlib import Path
from typing import Any

import numpy as np
# Load the shared 2026-05-04 selection-regret runner from task_scripts/ by file path and
# reuse its helpers (``base``) and its ``head_control`` attribute.
BASE_RUNNER = Path(__file__).resolve().parent / "task_scripts" / "run_all_backbone_selection_regret_20260504.py"
spec = importlib.util.spec_from_file_location("selection_regret_base_20260504", BASE_RUNNER)
if spec is None or spec.loader is None:
    raise RuntimeError(f"Cannot import base runner: {BASE_RUNNER}")
base = importlib.util.module_from_spec(spec)
# Register the module before executing it (as in the importlib "import a source file
# directly" recipe) so code that looks the module up in sys.modules while it is loading
# or afterwards (e.g. dataclasses, pickling) can find it.
sys.modules[spec.name] = base
try:
    spec.loader.exec_module(base)
except BaseException:
    # Do not leave a half-initialised module registered under this name.
    sys.modules.pop(spec.name, None)
    raise
head_control = base.head_control
# Default top-k fire-prone fractions (the ``--fire-prone-top-fracs`` default) and the
# scope keys / human-readable labels they correspond to. main() derives the scope order
# it actually uses from the command line; these constants describe the defaults.
SCOPE_FRACS = (0.05, 0.10, 0.20)
SCOPE_ORDER = ("global",) + tuple(f"top{round(frac * 100)}" for frac in SCOPE_FRACS)
SCOPE_LABELS = {
    "global": "global",
    **{f"top{round(frac * 100)}": f"top {round(frac * 100)}%" for frac in SCOPE_FRACS},
}
def top_fraction(value: str) -> float:
    """argparse ``type=`` converter for ``--fire-prone-top-fracs``: a float in (0, 1].

    Raises:
        argparse.ArgumentTypeError: for non-numeric, NaN, or out-of-range input.
    """
    try:
        frac = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a number, got {value!r}") from None
    # Negated range check so NaN (which fails every comparison) is rejected too.
    if not 0.0 < frac <= 1.0:
        raise argparse.ArgumentTypeError(f"top fraction must be in (0, 1], got {value!r}")
    return frac


def positive_int(value: str) -> int:
    """argparse ``type=`` converter for integer settings that must be >= 1.

    Raises:
        argparse.ArgumentTypeError: for non-integer or non-positive input.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the scope-sweep command line.

    Args:
        argv: argument list to parse; ``None`` (the default) lets argparse use
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``--fire-prone-top-fracs`` values are validated to lie
        in (0, 1]; ``--batch-size`` and ``--coarse-factor`` must be positive.
    """
    parser = argparse.ArgumentParser(description="Selection-regret scope sweep.")
    parser.add_argument("--source-kind", choices=("reference", "attached", "spatial", "alphaearth"), required=True)
    parser.add_argument("--feature-root", type=Path, required=True)
    parser.add_argument("--daily-rows-csv", type=Path)
    parser.add_argument("--support-dir", type=Path)
    parser.add_argument("--alphaearth-cache-root", type=Path)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--fm-family", type=str, required=True)
    parser.add_argument("--model-tag", type=str, required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--heads", nargs="+", choices=base.HEADS, default=["linear", "pixel_mlp", "shallow"])
    parser.add_argument("--batch-size", type=positive_int, default=8)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=8e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-5)
    parser.add_argument("--pos-weight-cap", type=float, default=150.0)
    parser.add_argument("--device", choices=("cpu", "cuda", "auto"), default="cpu")
    parser.add_argument(
        "--metric-thresholds",
        nargs="+",
        type=float,
        default=[
            1e-5,
            2e-5,
            5e-5,
            1e-4,
            2e-4,
            5e-4,
            1e-3,
            2e-3,
            5e-3,
            1e-2,
            2e-2,
            5e-2,
            8e-2,
            1e-1,
            1.5e-1,
            2e-1,
            3e-1,
            5e-1,
        ],
    )
    parser.add_argument("--variants", nargs="+", default=["identity"])
    parser.add_argument("--fire-prone-top-fracs", nargs="+", type=top_fraction, default=list(SCOPE_FRACS))
    parser.add_argument("--temporal-steps", type=int, default=3)
    parser.add_argument("--spatial-radius", type=int, default=8)
    parser.add_argument("--buffer-radius", type=int, default=8)
    parser.add_argument("--boundary-radius", type=int, default=8)
    parser.add_argument("--coarse-factor", type=positive_int, default=8)
    parser.add_argument("--time-step-hours", type=int, default=6)
    return parser.parse_args(argv)
def scope_name(top_frac: float) -> str:
    """Return the scope key for a top-``top_frac`` mask, e.g. ``0.05 -> "top5"``.

    The fraction is converted to a whole percentage by rounding, so nearby fractions
    (e.g. 0.05 and 0.051) share a key.
    """
    percent = round(100.0 * float(top_frac))
    return "top" + str(percent)
def scope_label(top_frac: float) -> str:
    """Return the human-readable label for a top-``top_frac`` scope, e.g. ``0.1 -> "top 10%"``."""
    return "top {}%".format(round(100.0 * float(top_frac)))
def build_scope_masks(
    split_rows: dict[str, list[dict[str, str]]],
    store: Any,
    top_fracs: list[float],
) -> tuple[dict[str, np.ndarray | None], dict[str, dict[str, Any]]]:
    """Build the evaluation scopes: ``global`` plus one fire-prone mask per top fraction.

    The ``global`` scope has no mask (``None``). Each fire-prone mask comes from
    ``head_control.build_fire_prone_mask`` applied to the *train* split rows only, and is
    keyed by ``scope_name(frac)``.

    Args:
        split_rows: rows per split; only ``split_rows["train"]`` is used.
        store: feature store passed through to ``build_fire_prone_mask``.
        top_fracs: fire-prone fractions, one scope each.

    Returns:
        ``(masks, meta)``, both keyed by scope name. ``meta`` holds the scope name, its
        report label and whatever metadata ``build_fire_prone_mask`` returned.

    Raises:
        ValueError: if two fractions map to the same scope name (``scope_name`` rounds to
            whole percent, so e.g. 0.05 and 0.051 both become ``top5``). Without this,
            the later mask would silently replace the earlier one.
    """
    masks: dict[str, np.ndarray | None] = {"global": None}
    meta: dict[str, dict[str, Any]] = {
        "global": {
            "scope_name": "global",
            "reported_as": "global",
            "top_fraction": None,
        }
    }
    for frac in top_fracs:
        name = scope_name(frac)
        if name in masks:
            raise ValueError(f"top fraction {frac!r} maps to scope {name!r}, which is already defined")
        mask, mask_meta = head_control.build_fire_prone_mask(split_rows["train"], store, float(frac))
        masks[name] = mask
        meta[name] = {
            "scope_name": name,
            "reported_as": scope_label(frac),
            **mask_meta,
        }
    return masks, meta
def build_posthoc_rows_for_scopes(
    probs: np.ndarray,
    targets: np.ndarray,
    sample_times: np.ndarray,
    split: str,
    scope_masks: dict[str, np.ndarray | None],
    args: argparse.Namespace,
) -> list[dict[str, object]]:
    """Score every (threshold, variant, scope) combination for one split's predictions.

    For each ``args.metric_thresholds`` value the predictions are binarized with
    ``probs >= threshold`` and transformed by each ``args.variants`` entry via
    ``head_control.apply_variant``. The result is evaluated once against ``targets``
    with ``head_control.evaluate_threshold_variant`` and then reduced per scope with
    ``head_control.metrics_for_scope``. Each output row holds the split, scope,
    threshold, variant and tolerance settings, followed by that scope's metrics.
    """
    thresholds = [float(t) for t in args.metric_thresholds]
    tolerance = args.temporal_steps * args.time_step_hours
    # Settings shared by every row, listed in the row's column order.
    settings = {
        "time_step_hours": int(args.time_step_hours),
        "temporal_steps": int(args.temporal_steps),
        "tolerance_hours": int(tolerance),
        "spatial_radius": int(args.spatial_radius),
        "buffer_radius": int(args.buffer_radius),
        "boundary_radius": int(args.boundary_radius),
        "coarse_factor": int(args.coarse_factor),
    }
    collected: list[dict[str, object]] = []
    for thr in thresholds:
        above = probs >= thr
        for variant in args.variants:
            transformed = head_control.apply_variant(above, variant)
            evaluated = head_control.evaluate_threshold_variant(
                binary_np=transformed,
                target_np=targets,
                sample_times=sample_times,
                time_step_hours=args.time_step_hours,
                temporal_steps=args.temporal_steps,
                spatial_radius=args.spatial_radius,
                buffer_radius=args.buffer_radius,
                boundary_radius=args.boundary_radius,
                coarse_factor=args.coarse_factor,
                tolerance_hours=tolerance,
            )
            collected.extend(
                {
                    "split": split,
                    "scope": scope,
                    "threshold": thr,
                    "variant": variant,
                    **settings,
                    **head_control.metrics_for_scope(evaluated, region_mask),
                }
                for scope, region_mask in scope_masks.items()
            )
    return collected
def read_csv(path: Path) -> list[dict[str, str]]:
    """Load ``path`` as a list of header-keyed row dicts; every value is a string."""
    with path.open(mode="r", newline="", encoding="utf-8") as handle:
        rows = list(csv.DictReader(handle))
    return rows
def parse_cached_pr_auc(
    raw_pr_auc: Any,
    scopes: tuple[str, ...],
    splits: tuple[str, ...] = ("val", "test"),
) -> dict[str, dict[str, float]] | None:
    """Parse a cached ``raw_pr_auc`` mapping into ``{split: {scope: float}}``.

    ``finite_json`` writes non-finite floats as JSON ``null``. Those entries are mapped
    back to NaN here, so a head whose PR-AUC was NaN/inf still yields a valid cache
    entry instead of forcing a retrain on every run.

    Returns:
        The parsed mapping, or None when ``raw_pr_auc`` is not a dict, lacks one of
        ``splits``/``scopes``, or holds a value that cannot be converted to float.
    """
    if not isinstance(raw_pr_auc, dict):
        return None
    parsed: dict[str, dict[str, float]] = {}
    for split in splits:
        per_scope = raw_pr_auc.get(split)
        if not isinstance(per_scope, dict):
            return None
        values: dict[str, float] = {}
        for scope in scopes:
            if scope not in per_scope:
                return None
            value = per_scope[scope]
            if value is None:
                # Written as null by finite_json(); restore as NaN.
                values[scope] = math.nan
                continue
            try:
                values[scope] = float(value)
            except (TypeError, ValueError):
                return None
        parsed[split] = values
    return parsed


def load_head_summary(
    head_dir: Path,
    head_arch: str,
    scopes: tuple[str, ...],
) -> tuple[list[dict[str, object]], dict[str, dict[str, float]], dict[str, object]] | None:
    """Return a previously written head result from ``head_dir``, or None if it is unusable.

    Reads ``posthoc_rows.csv`` and ``summary.json`` from ``head_dir`` (the files ``main``
    writes per head). Returns ``(rows, raw_pr_auc, summary)``:

    - ``rows``: the CSV rows as string-valued dicts.
    - ``raw_pr_auc``: ``{split: {scope: float}}`` for val/test and every name in
      ``scopes``.
    - ``summary``: the decoded JSON object.

    Returns None, treating the cache as missing, in any of these cases:

    - either file is absent, unreadable or malformed;
    - the CSV has no rows;
    - the summary is for a different ``head_arch``;
    - ``raw_pr_auc`` cannot be parsed for all splits and scopes.
    """
    posthoc_path = head_dir / "posthoc_rows.csv"
    summary_path = head_dir / "summary.json"
    if not (posthoc_path.is_file() and summary_path.is_file()):
        return None
    try:
        rows = [dict(row) for row in read_csv(posthoc_path)]
    except (OSError, ValueError, csv.Error):
        # Unreadable / undecodable / malformed CSV: treat as a cache miss.
        return None
    if not rows:
        return None
    try:
        summary = json.loads(summary_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None
    if not isinstance(summary, dict) or str(summary.get("head_arch")) != str(head_arch):
        return None
    parsed_pr_auc = parse_cached_pr_auc(summary.get("raw_pr_auc"), scopes)
    if parsed_pr_auc is None:
        return None
    return rows, parsed_pr_auc, summary
def finite_json(value: Any) -> Any:
    """Recursively replace non-finite floats with None so ``json.dumps`` emits strict JSON.

    Dicts are rebuilt with the same keys. Lists and tuples become lists, which is how
    ``json`` serializes both anyway. NumPy arrays and scalars are converted to the
    equivalent Python objects first, so NaN/inf inside them is also replaced and
    non-float64 scalars such as ``np.float32`` become serializable. Any other value is
    returned unchanged.
    """
    if isinstance(value, np.ndarray):
        return finite_json(value.tolist())
    if isinstance(value, np.generic):
        value = value.item()
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: finite_json(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        return [finite_json(val) for val in value]
    return value
def _build_store(args: argparse.Namespace, split_rows: dict[str, list[dict[str, str]]]) -> Any:
    """Build the feature store for ``args.source_kind`` with the matching base-runner builder."""
    if args.source_kind == "reference":
        return base.build_reference_store(split_rows)
    if args.source_kind == "attached":
        return base.build_attached_store(args, split_rows)
    if args.source_kind == "spatial":
        return base.build_spatial_store(args, split_rows)
    # argparse restricts --source-kind, so the only remaining choice is "alphaearth".
    return base.build_alphaearth_store(args, split_rows)


def _threshold_key(value: object) -> str:
    """Normalize a threshold given as a float or a CSV string so equal thresholds compare equal.

    Six significant digits tolerate a CSV writer that does not round-trip floats exactly.
    """
    return f"{float(value):.6g}"


def _train_config(args: argparse.Namespace, seed_offset: int) -> dict[str, object]:
    """Training settings stored in each head's ``summary.json`` and checked before reuse.

    ``seed_offset`` depends on the head's position in ``--heads``, so reordering heads
    changes it and invalidates the cached result.
    """
    return {
        "seed": int(args.seed),
        "seed_offset": int(seed_offset),
        "epochs": int(args.epochs),
        "batch_size": int(args.batch_size),
        "learning_rate": float(args.learning_rate),
        "weight_decay": float(args.weight_decay),
        "pos_weight_cap": float(args.pos_weight_cap),
    }


def _cache_matches_args(
    rows: list[dict[str, object]],
    cached_summary: dict[str, object],
    scope_order: tuple[str, ...],
    args: argparse.Namespace,
    seed_offset: int,
) -> bool:
    """Return True if a cached head result was produced with the current settings.

    The cached post-hoc rows must cover exactly the current
    ``(split, scope, variant, threshold)`` grid, and every row must carry the current
    tolerance/radius settings. If the cached summary records a ``train_config``, it must
    equal the current one. Summaries without that key are accepted, since there is
    nothing to compare.
    """
    expected_grid = {
        (split, scope, str(variant), _threshold_key(threshold))
        for split in ("val", "test")
        for scope in scope_order
        for variant in args.variants
        for threshold in args.metric_thresholds
    }
    expected_settings = {
        "time_step_hours": int(args.time_step_hours),
        "temporal_steps": int(args.temporal_steps),
        "tolerance_hours": int(args.temporal_steps * args.time_step_hours),
        "spatial_radius": int(args.spatial_radius),
        "buffer_radius": int(args.buffer_radius),
        "boundary_radius": int(args.boundary_radius),
        "coarse_factor": int(args.coarse_factor),
    }
    try:
        cached_grid = {
            (str(row["split"]), str(row["scope"]), str(row["variant"]), _threshold_key(row["threshold"]))
            for row in rows
        }
        # Cached rows come from read_csv, so numeric values are strings such as "8".
        settings_match = all(
            int(float(row[column])) == expected
            for row in rows
            for column, expected in expected_settings.items()
        )
    except (KeyError, TypeError, ValueError):
        # Missing columns or unparseable values: treat the cache as stale.
        return False
    if cached_grid != expected_grid or not settings_match:
        return False
    recorded = cached_summary.get("train_config")
    if recorded is None:
        return True
    # Round-trip through JSON so both sides have the types that were stored on disk.
    return recorded == json.loads(json.dumps(_train_config(args, seed_offset)))


def _train_and_evaluate_head(
    head_arch: str,
    head_dir: Path,
    seed_offset: int,
    *,
    args: argparse.Namespace,
    device: Any,
    in_ch: int,
    prior_prob: float,
    loaders: Any,
    split_rows: dict[str, list[dict[str, str]]],
    scope_masks: dict[str, np.ndarray | None],
    scope_meta: dict[str, dict[str, Any]],
) -> tuple[list[dict[str, object]], dict[str, dict[str, float]]]:
    """Train one head, evaluate it on val/test and write its per-head artifacts.

    Writes ``posthoc_rows.csv`` first and ``summary.json`` last. ``load_head_summary``
    needs both files, so an interrupted run is not mistaken for a complete cache entry.

    Returns:
        ``(posthoc_rows, raw_pr_auc)`` where ``raw_pr_auc`` is ``{split: {scope: AP}}``.
    """
    model, history = base.train_one_head(
        head_arch=head_arch,
        in_ch=in_ch,
        prior_prob=prior_prob,
        loaders=loaders,
        args=args,
        device=device,
        seed_offset=seed_offset,
    )
    posthoc_rows: list[dict[str, object]] = []
    raw_pr_auc: dict[str, dict[str, float]] = {}
    for split in ("val", "test"):
        probs, targets = base.collect_predictions(model, loaders[split], device)
        sample_times = base.build_sample_times(split_rows[split])
        raw_pr_auc[split] = {
            scope: head_control._masked_average_precision(probs, targets, region_mask=mask)
            for scope, mask in scope_masks.items()
        }
        posthoc_rows.extend(
            build_posthoc_rows_for_scopes(
                probs=probs,
                targets=targets,
                sample_times=sample_times,
                split=split,
                scope_masks=scope_masks,
                args=args,
            )
        )
    posthoc_path = head_dir / "posthoc_rows.csv"
    base.write_csv(posthoc_rows, posthoc_path)
    head_summary = {
        "head_arch": head_arch,
        "head_label": head_control.HEAD_LABELS[head_arch],
        "history": history,
        "raw_pr_auc": raw_pr_auc,
        "scope_meta": scope_meta,
        "posthoc_rows_csv": str(posthoc_path),
        "train_config": _train_config(args, seed_offset),
    }
    (head_dir / "summary.json").write_text(json.dumps(finite_json(head_summary), indent=2), encoding="utf-8")
    return posthoc_rows, raw_pr_auc


def main() -> None:
    """Run the scope sweep end to end.

    1. Load the train/val/test split rows and the feature store.
    2. Build the global and top-k fire-prone scope masks from the train split.
    3. For each ``--heads`` entry, reuse a cached ``<output-dir>/<head>/`` result if
       its settings match the current arguments; otherwise train the head and
       evaluate it on val/test.
    4. Write ``head_metrics.csv``, ``selection_rows.csv`` and a run ``summary.json``
       under ``--output-dir``, and print the summary.

    Raises:
        SystemExit: if two ``--fire-prone-top-fracs`` values map to the same scope name.
        RuntimeError: if the train loader yields no batches.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    base.set_seed(int(args.seed))
    device = base.choose_device(args.device)

    top_fracs = sorted({float(v) for v in args.fire_prone_top_fracs})
    scope_order = ("global",) + tuple(scope_name(frac) for frac in top_fracs)
    if len(set(scope_order)) != len(scope_order):
        # e.g. 0.05 and 0.051 both round to "top5"; their masks would overwrite each other.
        raise SystemExit(f"--fire-prone-top-fracs produce duplicate scope names: {list(scope_order)}")
    # Override the base runner's module-level scope order.
    # NOTE(review): presumably read by base.append_head_metrics / summarize_head_scores — confirm.
    base.SCOPE_ORDER = scope_order

    split_rows = {
        split: base.read_rows(args.feature_root / "splits" / f"{split}.csv")
        for split in ("train", "val", "test")
    }
    store = _build_store(args, split_rows)
    loaders = base.make_loaders(split_rows, store, int(args.batch_size), device, int(args.seed))
    try:
        first = next(iter(loaders["train"]))
    except StopIteration:
        raise RuntimeError("train loader yielded no batches; cannot infer input channels") from None
    in_ch = int(first["x"].shape[1])
    prior_prob = base.total_positive_rate(split_rows["train"])
    scope_masks, scope_meta = build_scope_masks(split_rows, store, top_fracs)

    head_metrics: list[dict[str, object]] = []
    head_artifacts: dict[str, str] = {}
    for head_index, head_arch in enumerate(args.heads):
        seed_offset = 1009 * (head_index + 1)
        head_dir = args.output_dir / head_arch
        head_dir.mkdir(parents=True, exist_ok=True)
        cached = load_head_summary(head_dir, head_arch, scope_order)
        if cached is not None and _cache_matches_args(cached[0], cached[2], scope_order, args, seed_offset):
            posthoc_rows, raw_pr_auc, _ = cached
            print(f"[scope-sweep] reuse {args.fm_family} seed={args.seed} head={head_arch}", flush=True)
        else:
            if cached is not None:
                print(
                    f"[scope-sweep] stale cache {args.fm_family} seed={args.seed} head={head_arch}",
                    flush=True,
                )
            print(f"[scope-sweep] training {args.fm_family} seed={args.seed} head={head_arch}", flush=True)
            posthoc_rows, raw_pr_auc = _train_and_evaluate_head(
                head_arch,
                head_dir,
                seed_offset,
                args=args,
                device=device,
                in_ch=in_ch,
                prior_prob=prior_prob,
                loaders=loaders,
                split_rows=split_rows,
                scope_masks=scope_masks,
                scope_meta=scope_meta,
            )
        head_artifacts[head_arch] = str(head_dir / "summary.json")
        base.append_head_metrics(head_metrics, posthoc_rows, raw_pr_auc, head_arch, args)

    selection_rows = base.summarize_head_scores(head_metrics)
    for row in selection_rows:
        row["model_tag"] = args.model_tag
        row["family"] = args.fm_family
        row["seed"] = int(args.seed)
    base.write_csv(head_metrics, args.output_dir / "head_metrics.csv")
    base.write_csv(selection_rows, args.output_dir / "selection_rows.csv")
    summary = {
        "experiment": "fixed-feature head-selection regret scope sweep",
        "task": "wildfire_occupancy",
        "model_tag": args.model_tag,
        "fm_family": args.fm_family,
        "source_kind": args.source_kind,
        "seed": int(args.seed),
        "feature_root": str(args.feature_root),
        "daily_rows_csv": str(args.daily_rows_csv) if args.daily_rows_csv else None,
        "support_dir": str(args.support_dir) if args.support_dir else None,
        "alphaearth_cache_root": str(args.alphaearth_cache_root) if args.alphaearth_cache_root else None,
        "device": str(device),
        "heads": list(args.heads),
        "scope_order": list(scope_order),
        "scope_meta": scope_meta,
        "input_channels": int(in_ch),
        "prior_prob": float(prior_prob),
        "metrics": base.METRICS,
        "head_metrics": head_metrics,
        "selection_rows": selection_rows,
        "head_artifacts": head_artifacts,
        "artifacts": {
            "head_metrics_csv": str(args.output_dir / "head_metrics.csv"),
            "selection_rows_csv": str(args.output_dir / "selection_rows.csv"),
        },
    }
    (args.output_dir / "summary.json").write_text(json.dumps(finite_json(summary), indent=2), encoding="utf-8")
    print(json.dumps(finite_json(summary), indent=2), flush=True)
if __name__ == "__main__":
    # Script entry point. main() returns None, so a successful run exits with status 0.
    raise SystemExit(main())