File size: 4,834 Bytes
6a82282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Evaluate a LoRA adapter against the locked test split.

Single source of truth for publishable test metrics per ../EVAL.md.
Uses Lightning's trainer.test() against the SemanticSegmentationTask
so all the metric plumbing matches what was used during training —
this is required because the task's forward() does pre/post-processing
that a hand-rolled loop diverges from. See dev notes in TRAINING.md.

Writes:
    eval/metrics_{mode}.json — full metrics dict
    eval/test_results.txt    — pretty-printed Lightning summary

Usage:
    python3 shared/eval_adapter.py --adapter adapters/lulc_nyc
    python3 shared/eval_adapter.py --adapter adapters/lulc_nyc --mode full_ft \
        --ckpt-override adapters/lulc_nyc/output/ckpt/last.ckpt

Modes:
    lora      load adapter_model.safetensors + decoder_head.safetensors
    full_ft   load a complete Lightning .ckpt (Phase 2/3/4 baseline)
    zero_shot no fine-tune; freshly built task with pretrained base only
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import lightning.pytorch as pl
import torch
import yaml
from safetensors.torch import load_file

sys.path.insert(0, str(Path(__file__).parent))
from train_lora import build_task, build_datamodule  # noqa: E402


def load_adapter_into_task(task, adapter_dir: Path):
    """Restore LoRA Δ + decoder/neck/head weights into a fresh task.

    Uses state_dict() format (parameters + buffers including BatchNorm
    running stats — those matter for inference accuracy and were the
    cause of an earlier eval failure when omitted).
    """
    lora = load_file(adapter_dir / "adapter_model.safetensors")
    head = load_file(adapter_dir / "decoder_head.safetensors")

    model = task.model

    # Encoder LoRA Δ.
    enc_state = {k.removeprefix("encoder."): v
                 for k, v in lora.items() if k.startswith("encoder.")}
    missing, unexpected = model.encoder.load_state_dict(
        enc_state, strict=False)
    # missing[] is huge (the entire frozen base); we don't print it. We
    # do warn on unexpected, since those mean the saved file has keys
    # the model doesn't recognize.
    if unexpected:
        print(f"WARN: {len(unexpected)} unexpected encoder keys; "
              f"first: {unexpected[:3]}", file=sys.stderr)

    # Decoder / neck / head / aux_heads.
    head_grouped: dict[str, dict] = {}
    for k, v in head.items():
        sub, _, rest = k.partition(".")
        head_grouped.setdefault(sub, {})[rest] = v
    for sub, state in head_grouped.items():
        m = getattr(model, sub, None)
        if m is None:
            continue
        m.load_state_dict(state, strict=False)


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--adapter", required=True, type=Path)
    ap.add_argument("--mode", choices=["lora", "full_ft", "zero_shot"],
                    default="lora")
    ap.add_argument("--ckpt-override", type=Path, default=None)
    args = ap.parse_args()

    cfg = yaml.safe_load((args.adapter / "config.yaml").read_text())
    pl.seed_everything(cfg.get("seed", 42), workers=True)

    task = build_task(cfg)
    if args.mode == "lora":
        adapter_dir = args.adapter / "output"
        load_adapter_into_task(task, adapter_dir)
    elif args.mode == "full_ft":
        if not args.ckpt_override:
            raise SystemExit("--mode full_ft requires --ckpt-override")
        ckpt = torch.load(args.ckpt_override, map_location="cpu",
                          weights_only=False)
        task.load_state_dict(ckpt["state_dict"], strict=True)
    # zero_shot: no weight loading; just evaluate the freshly built task.

    dm = build_datamodule(cfg["data"])
    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision=cfg.get("precision", "16-mixed"),
        logger=False,
        enable_progress_bar=False,
    )
    results = trainer.test(task, datamodule=dm)
    metrics = results[0] if results else {}
    metrics["mode"] = args.mode
    metrics["task_name"] = cfg.get("task_name", args.adapter.name)
    metrics["num_classes"] = cfg["num_classes"]

    out_dir = args.adapter / "eval"
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / f"metrics_{args.mode}.json").write_text(
        json.dumps(metrics, indent=2))

    # Print summary
    print(f"\n=== {cfg.get('task_name')} :: {args.mode} ===")
    keys = ["test/mIoU", "test/loss", "test/Pixel_Accuracy",
            "test/F1_Score", "test/Boundary_mIoU"]
    for k in keys:
        if k in metrics:
            print(f"  {k:24s} {metrics[k]:.4f}")
    print(f"  per-class IoU:           "
          f"{[f'{metrics.get(f'test/IoU_{i}', float('nan')):.4f}' for i in range(cfg['num_classes'])]}")


if __name__ == "__main__":
    sys.exit(main())