PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
nraptisss committed on
Commit 6f5475f · verified · 1 Parent(s): 5a23de5

Speed up and resume OOD evaluation with batched dynamic generation

Files changed (1)
  1. scripts/evaluate_model.py +92 -37
scripts/evaluate_model.py CHANGED
@@ -1,16 +1,18 @@
 #!/usr/bin/env python3
-"""Generation evaluation for TMF921 intent-to-config models/adapters.
+"""Fast generation evaluation for TMF921 intent-to-config models/adapters.

 Metrics: JSON parse rate, exact canonical JSON match, field-level F1, simple metadata constraints,
 stratified by split/target_layer/slice_type/lifecycle_operation.
+
+The evaluator is resumable and batched. It periodically writes partial predictions so a long OOD
+run can be stopped/restarted without losing completed examples.
 """
 import argparse
-import csv
 import json
 import os
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple

 import torch
 from datasets import load_dataset
@@ -29,7 +31,12 @@ def parse_args():
     p.add_argument("--splits", nargs="+", default=["test_in_distribution", "test_template_ood", "test_use_case_ood", "test_sector_ood", "test_adversarial"])
     p.add_argument("--output_dir", default="eval_outputs")
     p.add_argument("--max_samples_per_split", type=int, default=None)
-    p.add_argument("--max_new_tokens", type=int, default=2048)
+    p.add_argument("--max_new_tokens", type=int, default=1536, help="Hard cap for generation. 1536 covers audited dataset outputs with margin.")
+    p.add_argument("--gold_length_buffer", type=int, default=96, help="Dynamic cap = max gold output tokens in batch + buffer, clipped by max_new_tokens")
+    p.add_argument("--batch_size", type=int, default=4, help="Batched generation size. Use 1 if OOM; 4 is usually safe on RTX 6000 Ada for Qwen3-8B QLoRA.")
+    p.add_argument("--save_every", type=int, default=25, help="Write partial predictions every N completed examples")
+    p.add_argument("--resume", action="store_true", default=True, help="Resume from existing predictions.json if present")
+    p.add_argument("--no_resume", dest="resume", action="store_false")
     p.add_argument("--temperature", type=float, default=0.0)
     p.add_argument("--top_p", type=float, default=1.0)
     p.add_argument("--load_in_4bit", action="store_true", help="Load base model in 4-bit for adapter evaluation")
@@ -39,25 +46,43 @@ def parse_args():


 def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
-    # Keep all messages before the final assistant answer. For standard rows this is system+user;
-    # for multi-turn rows this preserves earlier assistant turns but removes gold final answer.
     out = []
-    for m in messages:
-        if m.get("role") == "assistant" and m == messages[-1]:
+    for i, m in enumerate(messages):
+        if i == len(messages) - 1 and m.get("role") == "assistant":
             break
         out.append({"role": m.get("role"), "content": m.get("content", "")})
-    # Fallback for any unexpected format.
     if not out:
         out = [m for m in messages if m.get("role") != "assistant"]
     return out


-def generate_one(model, tokenizer, messages, args):
-    prompt_messages = make_prompt_messages(messages)
-    text = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(text, return_tensors="pt").to(model.device)
+def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
+    return tokenizer.apply_chat_template(make_prompt_messages(messages), tokenize=False, add_generation_prompt=True)
+
+
+def gold_text(example: Dict[str, Any]) -> str:
+    return example.get("completion") or get_message(example["messages"], "assistant")
+
+
+def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
+    lens = []
+    for ex in examples:
+        ids = tokenizer(gold_text(ex), add_special_tokens=False)["input_ids"]
+        lens.append(len(ids))
+    return max(16, min(int(args.max_new_tokens), max(lens) + int(args.gold_length_buffer)))
+
+
+def generate_batch(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
+    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]
+    old_padding_side = tokenizer.padding_side
+    tokenizer.padding_side = "left"
+    try:
+        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
+    finally:
+        tokenizer.padding_side = old_padding_side
+    max_new = dynamic_max_new_tokens(tokenizer, examples, args)
     gen_kwargs = dict(
-        max_new_tokens=args.max_new_tokens,
+        max_new_tokens=max_new,
         do_sample=args.temperature > 0,
         temperature=args.temperature if args.temperature > 0 else None,
         top_p=args.top_p,
@@ -65,16 +90,16 @@ def generate_one(model, tokenizer, messages, args):
         eos_token_id=tokenizer.eos_token_id,
     )
     gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}
-    with torch.no_grad():
+    with torch.inference_mode():
         out = model.generate(**inputs, **gen_kwargs)
-    new_tokens = out[0, inputs["input_ids"].shape[1]:]
-    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+    new_tokens = out[:, inputs["input_ids"].shape[1]:]
+    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)


 def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
-    gold_text = example.get("completion") or get_message(example["messages"], "assistant")
+    gold = gold_text(example)
     pred_obj, pred_err = parse_json(prediction)
-    gold_obj, gold_err = parse_json(gold_text)
+    gold_obj, gold_err = parse_json(gold)
     out: Dict[str, Any] = {
         "id": example.get("id"),
         "target_layer": example.get("target_layer"),
@@ -84,7 +109,7 @@ def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
         "gold_parse_json": gold_obj is not None,
         "exact_match": False,
         "prediction": prediction,
-        "gold": gold_text,
+        "gold": gold,
         "parse_error": pred_err,
     }
     if pred_obj is not None and gold_obj is not None:
@@ -108,6 +133,26 @@ def require_cuda():
     print(f"cuda device_count={torch.cuda.device_count()} gpu0={torch.cuda.get_device_name(0)}")


+def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
+    if path.exists():
+        rows = json.loads(path.read_text())
+        done = {str(r.get("id")) for r in rows}
+        return rows, done
+    return [], set()
+
+
+def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
+    write_json(split_dir / "predictions.json", rows)
+    summary = aggregate_metrics(rows)
+    for key in ["target_layer", "slice_type", "lifecycle_operation"]:
+        groups = defaultdict(list)
+        for r in rows:
+            groups[str(r.get(key))].append(r)
+        summary[f"by_{key}"] = {g: aggregate_metrics(v) for g, v in sorted(groups.items())}
+    write_json(split_dir / "metrics.json", summary)
+    return summary
+
+
 def main():
     args = parse_args()
     require_cuda()
@@ -143,26 +188,36 @@ def main():
         split_ds = ds[split]
         if args.max_samples_per_split:
             split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))
-        rows = []
-        print(f"Evaluating {split}: {len(split_ds)} examples")
-        for ex in tqdm(split_ds, desc=split):
-            pred = generate_one(model, tokenizer, ex["messages"], args)
-            rows.append(row_metrics(ex, pred))
-
         split_dir = out_dir / split
         split_dir.mkdir(parents=True, exist_ok=True)
-        write_json(split_dir / "predictions.json", rows)
-        summary = aggregate_metrics(rows)
-        # Stratified summaries.
-        for key in ["target_layer", "slice_type", "lifecycle_operation"]:
-            groups = defaultdict(list)
-            for r in rows:
-                groups[str(r.get(key))].append(r)
-            summary[f"by_{key}"] = {g: aggregate_metrics(v) for g, v in sorted(groups.items())}
-        write_json(split_dir / "metrics.json", summary)
+        pred_path = split_dir / "predictions.json"
+        rows, done_ids = load_existing_predictions(pred_path) if args.resume else ([], set())
+        todo = [ex for ex in split_ds if str(ex.get("id")) not in done_ids]
+        print(f"Evaluating {split}: total={len(split_ds)} already_done={len(done_ids)} remaining={len(todo)} batch_size={args.batch_size}")
+        pbar = tqdm(total=len(todo), desc=split)
+        completed_since_save = 0
+        for start in range(0, len(todo), args.batch_size):
+            batch = todo[start:start + args.batch_size]
+            try:
+                preds = generate_batch(model, tokenizer, batch, args)
+            except torch.cuda.OutOfMemoryError:
+                torch.cuda.empty_cache()
+                if args.batch_size == 1 or len(batch) == 1:
+                    raise
+                preds = []
+                for ex in batch:
+                    preds.extend(generate_batch(model, tokenizer, [ex], args))
+            for ex, pred in zip(batch, preds):
+                rows.append(row_metrics(ex, pred.strip()))
+            pbar.update(len(batch))
+            completed_since_save += len(batch)
+            if completed_since_save >= args.save_every:
+                write_split_outputs(split_dir, rows)
+                completed_since_save = 0
+        pbar.close()
+        summary = write_split_outputs(split_dir, rows)
         all_summary[split] = summary
-
-    write_json(out_dir / "all_metrics.json", all_summary)
+        write_json(out_dir / "all_metrics.json", all_summary)
     print(json.dumps(all_summary, indent=2)[:4000])

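For reference, the new flags make a long OOD pass both batched and restartable; a run might be launched as `python scripts/evaluate_model.py --batch_size 4 --save_every 25 --resume` together with the model/adapter/dataset arguments defined in the unchanged part of parse_args(). Below is a minimal sketch of the dynamic-cap idea behind dynamic_max_new_tokens(): generation is capped at the longest gold answer in the batch plus a buffer, never above the hard --max_new_tokens limit. The Qwen/Qwen3-8B tokenizer and the toy gold strings are assumptions for illustration, not taken from the repo.

from transformers import AutoTokenizer

# Assumed base model tokenizer; the repo configures its actual checkpoint elsewhere.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# Toy gold completions standing in for dataset rows (illustrative only).
golds = [
    '{"intent": {"id": "slice-001", "lifecycle_operation": "create"}}',
    '{"intent": {"id": "slice-002", "lifecycle_operation": "modify", "priority": 1}}',
]
hard_cap, buffer_tokens = 1536, 96  # mirrors the --max_new_tokens and --gold_length_buffer defaults

# Per-batch cap: longest gold output in tokens plus a safety buffer, clipped by the hard cap.
gold_lens = [len(tokenizer(g, add_special_tokens=False)["input_ids"]) for g in golds]
max_new = max(16, min(hard_cap, max(gold_lens) + buffer_tokens))
print(max_new)  # value that would be passed to model.generate(max_new_tokens=...)

Because most batches stop well below the 1536-token ceiling, this per-batch cap is where the bulk of the speed-up over the old fixed 2048-token limit comes from.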