Speed up and resume OOD evaluation with batched dynamic generation
- scripts/evaluate_model.py +92 -37
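The core of the speed-up is a per-batch dynamic generation cap: instead of always decoding the full --max_new_tokens budget, each batch is capped at the longest gold completion it contains plus --gold_length_buffer. A toy sketch of that policy follows (not the evaluator's exact code: a whitespace token count stands in for the real tokenizer; the numbers match the new defaults):

# Toy sketch of the dynamic cap policy, not the evaluator's exact code.
# A real run tokenizes gold outputs with the model's HF tokenizer; here a
# whitespace split stands in so the arithmetic is easy to follow.
def toy_token_len(text: str) -> int:
    return len(text.split())

def dynamic_cap(gold_texts, max_new_tokens=1536, buffer=96):
    longest = max(toy_token_len(g) for g in gold_texts)
    # Floor of 16 tokens, never above the hard --max_new_tokens cap.
    return max(16, min(max_new_tokens, longest + buffer))

print(dynamic_cap(["{}"]))            # 97: the buffer dominates for short golds
print(dynamic_cap(["tok " * 2000]))   # 1536: clipped by the hard cap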
scripts/evaluate_model.py CHANGED

@@ -1,16 +1,18 @@
 #!/usr/bin/env python3
-"""
+"""Fast generation evaluation for TMF921 intent-to-config models/adapters.
 
 Metrics: JSON parse rate, exact canonical JSON match, field-level F1, simple metadata constraints,
 stratified by split/target_layer/slice_type/lifecycle_operation.
+
+The evaluator is resumable and batched. It periodically writes partial predictions so a long OOD
+run can be stopped/restarted without losing completed examples.
 """
 import argparse
-import csv
 import json
 import os
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple
 
 import torch
 from datasets import load_dataset
@@ -29,7 +31,12 @@ def parse_args():
     p.add_argument("--splits", nargs="+", default=["test_in_distribution", "test_template_ood", "test_use_case_ood", "test_sector_ood", "test_adversarial"])
     p.add_argument("--output_dir", default="eval_outputs")
     p.add_argument("--max_samples_per_split", type=int, default=None)
-    p.add_argument("--max_new_tokens", type=int, default=
+    p.add_argument("--max_new_tokens", type=int, default=1536, help="Hard cap for generation. 1536 covers audited dataset outputs with margin.")
+    p.add_argument("--gold_length_buffer", type=int, default=96, help="Dynamic cap = max gold output tokens in batch + buffer, clipped by max_new_tokens")
+    p.add_argument("--batch_size", type=int, default=4, help="Batched generation size. Use 1 if OOM; 4 is usually safe on RTX 6000 Ada for Qwen3-8B QLoRA.")
+    p.add_argument("--save_every", type=int, default=25, help="Write partial predictions every N completed examples")
+    p.add_argument("--resume", action="store_true", default=True, help="Resume from existing predictions.json if present")
+    p.add_argument("--no_resume", dest="resume", action="store_false")
     p.add_argument("--temperature", type=float, default=0.0)
     p.add_argument("--top_p", type=float, default=1.0)
     p.add_argument("--load_in_4bit", action="store_true", help="Load base model in 4-bit for adapter evaluation")
@@ -39,25 +46,43 @@ def parse_args():
 
 
 def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
-    # Keep all messages before the final assistant answer. For standard rows this is system+user;
-    # for multi-turn rows this preserves earlier assistant turns but removes gold final answer.
     out = []
-    for m in messages:
-        if m.get("role") == "assistant":
+    for i, m in enumerate(messages):
+        if i == len(messages) - 1 and m.get("role") == "assistant":
             break
         out.append({"role": m.get("role"), "content": m.get("content", "")})
-    # Fallback for any unexpected format.
     if not out:
         out = [m for m in messages if m.get("role") != "assistant"]
     return out
 
 
-def generate_one(model, tokenizer, messages, args):
-
-
-
+def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
+    return tokenizer.apply_chat_template(make_prompt_messages(messages), tokenize=False, add_generation_prompt=True)
+
+
+def gold_text(example: Dict[str, Any]) -> str:
+    return example.get("completion") or get_message(example["messages"], "assistant")
+
+
+def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
+    lens = []
+    for ex in examples:
+        ids = tokenizer(gold_text(ex), add_special_tokens=False)["input_ids"]
+        lens.append(len(ids))
+    return max(16, min(int(args.max_new_tokens), max(lens) + int(args.gold_length_buffer)))
+
+
+def generate_batch(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
+    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]
+    old_padding_side = tokenizer.padding_side
+    tokenizer.padding_side = "left"
+    try:
+        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
+    finally:
+        tokenizer.padding_side = old_padding_side
+    max_new = dynamic_max_new_tokens(tokenizer, examples, args)
     gen_kwargs = dict(
-        max_new_tokens=
+        max_new_tokens=max_new,
         do_sample=args.temperature > 0,
         temperature=args.temperature if args.temperature > 0 else None,
         top_p=args.top_p,
@@ -65,16 +90,16 @@ def generate_one(model, tokenizer, messages, args):
         eos_token_id=tokenizer.eos_token_id,
     )
     gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}
-    with torch.
+    with torch.inference_mode():
         out = model.generate(**inputs, **gen_kwargs)
-    new_tokens = out[
-    return tokenizer.
+    new_tokens = out[:, inputs["input_ids"].shape[1]:]
+    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
 
 
 def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
-
+    gold = gold_text(example)
     pred_obj, pred_err = parse_json(prediction)
-    gold_obj, gold_err = parse_json(
+    gold_obj, gold_err = parse_json(gold)
     out: Dict[str, Any] = {
         "id": example.get("id"),
         "target_layer": example.get("target_layer"),
@@ -84,7 +109,7 @@ def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
         "gold_parse_json": gold_obj is not None,
         "exact_match": False,
         "prediction": prediction,
-        "gold":
+        "gold": gold,
         "parse_error": pred_err,
     }
     if pred_obj is not None and gold_obj is not None:
@@ -108,6 +133,26 @@ def require_cuda():
     print(f"cuda device_count={torch.cuda.device_count()} gpu0={torch.cuda.get_device_name(0)}")
 
 
+def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
+    if path.exists():
+        rows = json.loads(path.read_text())
+        done = {str(r.get("id")) for r in rows}
+        return rows, done
+    return [], set()
+
+
+def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
+    write_json(split_dir / "predictions.json", rows)
+    summary = aggregate_metrics(rows)
+    for key in ["target_layer", "slice_type", "lifecycle_operation"]:
+        groups = defaultdict(list)
+        for r in rows:
+            groups[str(r.get(key))].append(r)
+        summary[f"by_{key}"] = {g: aggregate_metrics(v) for g, v in sorted(groups.items())}
+    write_json(split_dir / "metrics.json", summary)
+    return summary
+
+
 def main():
     args = parse_args()
     require_cuda()
@@ -143,26 +188,36 @@ def main():
         split_ds = ds[split]
         if args.max_samples_per_split:
            split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))
-        rows = []
-        print(f"Evaluating {split}: {len(split_ds)} examples")
-        for ex in tqdm(split_ds, desc=split):
-            pred = generate_one(model, tokenizer, ex["messages"], args)
-            rows.append(row_metrics(ex, pred))
-
         split_dir = out_dir / split
         split_dir.mkdir(parents=True, exist_ok=True)
-
-
-
-
-
-
-
-
-
+        pred_path = split_dir / "predictions.json"
+        rows, done_ids = load_existing_predictions(pred_path) if args.resume else ([], set())
+        todo = [ex for ex in split_ds if str(ex.get("id")) not in done_ids]
+        print(f"Evaluating {split}: total={len(split_ds)} already_done={len(done_ids)} remaining={len(todo)} batch_size={args.batch_size}")
+        pbar = tqdm(total=len(todo), desc=split)
+        completed_since_save = 0
+        for start in range(0, len(todo), args.batch_size):
+            batch = todo[start:start + args.batch_size]
+            try:
+                preds = generate_batch(model, tokenizer, batch, args)
+            except torch.cuda.OutOfMemoryError:
+                torch.cuda.empty_cache()
+                if args.batch_size == 1 or len(batch) == 1:
+                    raise
+                preds = []
+                for ex in batch:
+                    preds.extend(generate_batch(model, tokenizer, [ex], args))
+            for ex, pred in zip(batch, preds):
+                rows.append(row_metrics(ex, pred.strip()))
+            pbar.update(len(batch))
+            completed_since_save += len(batch)
+            if completed_since_save >= args.save_every:
+                write_split_outputs(split_dir, rows)
+                completed_since_save = 0
+        pbar.close()
+        summary = write_split_outputs(split_dir, rows)
         all_summary[split] = summary
-
-    write_json(out_dir / "all_metrics.json", all_summary)
+        write_json(out_dir / "all_metrics.json", all_summary)
     print(json.dumps(all_summary, indent=2)[:4000])
 
 
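For reference, a minimal sketch of the resume behavior under these changes, assuming predictions.json holds the list of row dicts keyed by "id" that write_split_outputs produces (toy data and a hypothetical output path):

import json
from pathlib import Path

# Rows already present in predictions.json are skipped on the next run:
# their ids land in the done set, so only unfinished examples are queued.
split_dir = Path("eval_outputs/test_template_ood")  # hypothetical path
split_dir.mkdir(parents=True, exist_ok=True)
pred_path = split_dir / "predictions.json"
pred_path.write_text(json.dumps([{"id": "ex-001", "exact_match": True}]))

rows = json.loads(pred_path.read_text())
done_ids = {str(r.get("id")) for r in rows}
dataset = [{"id": "ex-001"}, {"id": "ex-002"}]
todo = [ex for ex in dataset if str(ex.get("id")) not in done_ids]
print(todo)  # [{'id': 'ex-002'}]: only the unfinished example is re-generated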