| """Evaluate a LoRA adapter against the locked test split. |
| |
| Single source of truth for publishable test metrics per ../EVAL.md. |
| Uses Lightning's trainer.test() against the SemanticSegmentationTask |
| so all the metric plumbing matches what was used during training — |
| this is required because the task's forward() does pre/post-processing |
| that a hand-rolled loop diverges from. See dev notes in TRAINING.md. |
| |
| Writes: |
| eval/metrics_{mode}.json — full metrics dict |
| eval/test_results.txt — pretty-printed Lightning summary |
| |
| Usage: |
| python3 shared/eval_adapter.py --adapter adapters/lulc_nyc |
| python3 shared/eval_adapter.py --adapter adapters/lulc_nyc --mode full_ft \ |
| --ckpt-override adapters/lulc_nyc/output/ckpt/last.ckpt |
| |
| Modes: |
| lora load adapter_model.safetensors + decoder_head.safetensors |
| full_ft load a complete Lightning .ckpt (Phase 2/3/4 baseline) |
| zero_shot no fine-tune; freshly built task with pretrained base only |
| """ |
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import lightning.pytorch as pl
import torch
import yaml
from safetensors.torch import load_file

# Make the sibling train_lora module importable when run as a script.
sys.path.insert(0, str(Path(__file__).parent))
from train_lora import build_task, build_datamodule


def load_adapter_into_task(task, adapter_dir: Path):
    """Restore LoRA Δ + decoder/neck/head weights into a fresh task.

    Uses state_dict() format (parameters + buffers, including BatchNorm
    running stats; those matter for inference accuracy and were the
    cause of an earlier eval failure when omitted).
    """
    lora = load_file(adapter_dir / "adapter_model.safetensors")
    head = load_file(adapter_dir / "decoder_head.safetensors")

    model = task.model

    # Encoder tensors are saved under an "encoder." prefix; strip it so the
    # keys line up with model.encoder's own state_dict namespace.
    enc_state = {k.removeprefix("encoder."): v
                 for k, v in lora.items() if k.startswith("encoder.")}
    missing, unexpected = model.encoder.load_state_dict(
        enc_state, strict=False)
    # `missing` is expected to be non-empty: the adapter file carries only
    # the fine-tuned tensors, so untouched base weights report as missing
    # and keep their pretrained values. Only `unexpected` keys signal a
    # real save/load mismatch.
    if unexpected:
        print(f"WARN: {len(unexpected)} unexpected encoder keys; "
              f"first: {unexpected[:3]}", file=sys.stderr)

    # Decoder/neck/head tensors arrive flat ("decoder.conv.weight");
    # regroup them by top-level submodule so each group can be loaded
    # into the matching attribute on the model.
    head_grouped: dict[str, dict] = {}
    for k, v in head.items():
        sub, _, rest = k.partition(".")
        head_grouped.setdefault(sub, {})[rest] = v
    for sub, state in head_grouped.items():
        # Tolerate model variants that lack one of the saved submodules
        # (e.g. a task built without a neck) instead of crashing.
        m = getattr(model, sub, None)
        if m is None:
            continue
        m.load_state_dict(state, strict=False)


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--adapter", required=True, type=Path)
    ap.add_argument("--mode", choices=["lora", "full_ft", "zero_shot"],
                    default="lora")
    ap.add_argument("--ckpt-override", type=Path, default=None)
    args = ap.parse_args()

    cfg = yaml.safe_load((args.adapter / "config.yaml").read_text())
    pl.seed_everything(cfg.get("seed", 42), workers=True)

    task = build_task(cfg)
    if args.mode == "lora":
        adapter_dir = args.adapter / "output"
        load_adapter_into_task(task, adapter_dir)
    elif args.mode == "full_ft":
        if not args.ckpt_override:
            raise SystemExit("--mode full_ft requires --ckpt-override")
        ckpt = torch.load(args.ckpt_override, map_location="cpu",
                          weights_only=False)
        task.load_state_dict(ckpt["state_dict"], strict=True)
    # zero_shot: nothing to load; the freshly built task keeps its
    # pretrained base weights, which is exactly what that mode measures.

    dm = build_datamodule(cfg["data"])
    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision=cfg.get("precision", "16-mixed"),
        logger=False,
        enable_progress_bar=False,
    )
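    # devices=1 is deliberate: under multi-GPU DDP, Lightning's distributed
    # sampler can pad/duplicate samples so every rank gets an equal share,
    # which would quietly skew metrics on the locked test split.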
    results = trainer.test(task, datamodule=dm)
    metrics = results[0] if results else {}
    # Stamp provenance into the dict so each JSON file is self-describing.
    metrics["mode"] = args.mode
    metrics["task_name"] = cfg.get("task_name", args.adapter.name)
    metrics["num_classes"] = cfg["num_classes"]

    out_dir = args.adapter / "eval"
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / f"metrics_{args.mode}.json").write_text(
        json.dumps(metrics, indent=2))
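    # The {mode} suffix keeps lora / full_ft / zero_shot results side by
    # side under eval/ for later comparison without re-running anything.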
| print(f"\n=== {cfg.get('task_name')} :: {args.mode} ===") |
| keys = ["test/mIoU", "test/loss", "test/Pixel_Accuracy", |
| "test/F1_Score", "test/Boundary_mIoU"] |
| for k in keys: |
| if k in metrics: |
| print(f" {k:24s} {metrics[k]:.4f}") |
| print(f" per-class IoU: " |
| f"{[f'{metrics.get(f'test/IoU_{i}', float('nan')):.4f}' for i in range(cfg['num_classes'])]}") |
|
|
|
|


if __name__ == "__main__":
    sys.exit(main())