nraptisss committed
Commit a2801bb · verified · 1 Parent(s): 734da09

Add zero-shot baseline mode: --adapter_path none skips adapter loading, --no_think suppresses Qwen3 thinking

Files changed (1):
  evaluate_v3.py  (+76 −65)
evaluate_v3.py CHANGED
@@ -7,50 +7,50 @@ Optimized evaluation with:
 2. Layer-aware max tokens — caps generation length per layer (saves ~40% time)
 3. Incremental saves — writes results after every sample (never lose progress)
 4. Resume support — skips already-evaluated IDs from a previous checkpoint
+5. Zero-shot baseline — --adapter_path none to evaluate base model without adapter
 
 All KPI checking logic is identical to v2 (standard-aware).
 
-Speed estimate (vs v2 on full 2521 test set):
-  v2 full test:  ~50h (2521 samples × ~70s avg)
-  v3 stratified: ~4-5h (400 samples × ~45s avg, with tighter max_tokens)
-
 Usage:
+    # Fine-tuned model evaluation
     python evaluate_v3.py --adapter_path ./output
-    python evaluate_v3.py --adapter_path ./output --per_layer 30 --output_file my_eval.json
-    python evaluate_v3.py --adapter_path ./output --resume  # pick up where you left off
+
+    # Zero-shot baseline (no adapter, base model only)
+    python evaluate_v3.py --adapter_path none --output_file eval_v3_baseline.json
+
+    # Zero-shot baseline with Qwen3 thinking disabled
+    python evaluate_v3.py --adapter_path none --no_think --output_file eval_v3_baseline.json
+
+    # Resume interrupted run
+    python evaluate_v3.py --adapter_path ./output --resume
 """
 
 import argparse, json, re, os, sys, math, time, random, torch
 from collections import defaultdict
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from peft import PeftModel
 
 
 # ═══════════════════════════════════════════════════════════════════════
 # LAYER-AWARE MAX TOKENS
 # ═══════════════════════════════════════════════════════════════════════
-# Derived from reference output lengths in the dataset.
-# Each value = ceil(max_reference_length_chars / 3.5) + 20% safety margin.
-# This prevents the model from wasting time generating 4096 tokens for a
-# CAMARA output that's always ~700 chars (~250 tokens).
 
 LAYER_MAX_TOKENS = {
-    "tmf921": 1600,                     # ref ~4200 chars → ~1200 tok + margin
-    "intent_3gpp": 900,                 # ref ~2100 chars → ~600 tok + margin
-    "camara": 400,                      # ref ~700 chars → ~200 tok + margin
-    "a1_policy": 350,                   # ref ~650 chars → ~185 tok + margin
-    "o1_nrm": 500,                      # ref ~1050 chars → ~300 tok + margin
-    "etsi_zsm": 1100,                   # ref ~2800 chars → ~800 tok + margin
-    "tmf921_lifecycle_activate": 250,   # ref ~300 chars
-    "tmf921_lifecycle_modify": 600,     # ref ~1250 chars
-    "tmf921_lifecycle_monitor": 500,    # ref ~1000 chars
-    "tmf921_lifecycle_report": 800,     # ref ~1800 chars
-    "tmf921_lifecycle_resume": 250,     # ref ~300 chars
-    "tmf921_lifecycle_scale": 350,      # ref ~660 chars
-    "tmf921_lifecycle_suspend": 250,    # ref ~300 chars
-    "tmf921_lifecycle_terminate": 250,  # ref ~320 chars
-    "adversarial_ambiguous": 200,       # ref ~200 chars
+    "tmf921": 1600,
+    "intent_3gpp": 900,
+    "camara": 400,
+    "a1_policy": 350,
+    "o1_nrm": 500,
+    "etsi_zsm": 1100,
+    "tmf921_lifecycle_activate": 250,
+    "tmf921_lifecycle_modify": 600,
+    "tmf921_lifecycle_monitor": 500,
+    "tmf921_lifecycle_report": 800,
+    "tmf921_lifecycle_resume": 250,
+    "tmf921_lifecycle_scale": 350,
+    "tmf921_lifecycle_suspend": 250,
+    "tmf921_lifecycle_terminate": 250,
+    "adversarial_ambiguous": 200,
     "adversarial_contradictory": 200,
     "adversarial_out_of_scope": 200,
 }
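Note on the budgets above: the per-entry comments this commit deletes encoded the heuristic ceil(max_reference_length_chars / 3.5) plus a ~20% safety margin. A minimal sketch of that arithmetic, with an illustrative helper name that is not part of the commit:

    import math

    def token_budget(max_ref_chars: int, chars_per_token: float = 3.5, margin: float = 1.2) -> int:
        # chars -> approximate tokens, with ~20% headroom on top
        return math.ceil(max_ref_chars / chars_per_token * margin)

    print(token_budget(700))  # 240, comfortably under the "camara": 400 cap

The table rounds up further than the raw heuristic; the values themselves are unchanged by this commit, only their provenance comments are dropped.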
@@ -68,7 +68,8 @@ def log(msg: str):
 def parse_args():
     p = argparse.ArgumentParser(description="TMF921 Evaluation v3 — stratified, incremental, fast")
     p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
-    p.add_argument("--adapter_path", type=str, default="./output")
+    p.add_argument("--adapter_path", type=str, default="./output",
+                   help="Path to LoRA adapter, or 'none' for zero-shot baseline")
     p.add_argument("--dataset", type=str, default="nraptisss/TMF921-intent-to-config-augmented")
     p.add_argument("--split", type=str, default="test")
     p.add_argument("--per_layer", type=int, default=50,
@@ -79,6 +80,8 @@ def parse_args():
                    help="Resume from existing output_file, skipping already-evaluated IDs")
     p.add_argument("--flash_attn", action="store_true", default=True)
     p.add_argument("--no_flash_attn", action="store_true", default=False)
+    p.add_argument("--no_think", action="store_true", default=False,
+                   help="Suppress Qwen3 thinking mode by adding /no_think to the last user message")
     p.add_argument("--save_generations", action="store_true", default=True)
     return p.parse_args()
 
@@ -88,18 +91,10 @@ def parse_args():
 # ═══════════════════════════════════════════════════════════════════════
 
 def stratified_sample(ds, per_layer: int, seed: int = 42):
-    """
-    Sample up to `per_layer` examples per target_layer.
-    Layers with fewer samples than `per_layer` are included in full.
-    Returns list of indices into the original dataset.
-    """
     rng = random.Random(seed)
-
-    # Group indices by target_layer
     layer_indices = defaultdict(list)
     for i in range(len(ds)):
         layer_indices[ds[i]["target_layer"]].append(i)
-
     selected = []
     layer_counts = {}
     for layer, indices in sorted(layer_indices.items()):
@@ -109,19 +104,20 @@ def stratified_sample(ds, per_layer: int, seed: int = 42):
             chosen = rng.sample(indices, per_layer)
         selected.extend(chosen)
         layer_counts[layer] = len(chosen)
-
-    # Shuffle so we don't evaluate all of one layer before the next
     rng.shuffle(selected)
-
     return selected, layer_counts
 
 
 # ═══════════════════════════════════════════════════════════════════════
-# JSON PARSING (identical to v2)
+# JSON PARSING
 # ═══════════════════════════════════════════════════════════════════════
 
 def try_parse_json(text: str):
     text = text.strip()
+    # Strip thinking tags if present
+    think_match = re.match(r"<think>[\s\S]*?</think>\s*", text)
+    if think_match:
+        text = text[think_match.end():].strip()
     if text.startswith("```"):
         text = re.sub(r"^```(?:json)?\s*\n?", "", text)
         text = re.sub(r"\n?```\s*$", "", text)
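The new <think> stripping only removes one leading block, which is the shape Qwen3 emits when thinking is on; JSON that merely mentions the tag elsewhere is untouched. A sketch of the behavior in isolation (helper name is illustrative):

    import json, re

    def strip_think(text: str) -> str:
        # mirrors the try_parse_json change: drop a leading <think>...</think>
        text = text.strip()
        m = re.match(r"<think>[\s\S]*?</think>\s*", text)
        return text[m.end():].strip() if m else text

    raw = '<think>map the SLA to TMF921 fields...</think>\n{"intent": {"id": "intent-001"}}'
    print(json.loads(strip_think(raw)))  # {'intent': {'id': 'intent-001'}}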
@@ -139,7 +135,7 @@ def try_parse_json(text: str):
 
 
 # ═══════════════════════════════════════════════════════════════════════
-# KPI CHECKING (identical to v2)
+# KPI CHECKING (standard-aware, identical to v2)
 # ═══════════════════════════════════════════════════════════════════════
 
 def _num_representations(val: float) -> list:
@@ -150,7 +146,6 @@ def _num_representations(val: float) -> list:
     reps.append(f"{val:.0f}")
     return list(set(reps))
 
-
 def _reliability_representations(rel_pct: float) -> list:
     reps = _num_representations(rel_pct)
     per = 1 - rel_pct / 100
@@ -169,7 +164,6 @@ def _reliability_representations(rel_pct: float) -> list:
     reps.append(f"{per:.10f}".rstrip("0").rstrip("."))
     return list(set(reps))
 
-
 def _find_all_numbers(parsed) -> list:
     nums = []
     if isinstance(parsed, dict):
@@ -186,7 +180,6 @@ def _find_all_numbers(parsed) -> list:
             nums.extend(_find_all_numbers(item))
     return nums
 
-
 def _check_kpi_direct(parsed, row, flat):
     return {
         "has_latency": any(r in flat for r in _num_representations(row["latency_ms"])),
@@ -196,7 +189,6 @@ def _check_kpi_direct(parsed, row, flat):
         "has_max_ues": any(r in flat for r in _num_representations(float(row["max_ues"]))),
     }
 
-
 def _check_kpi_a1_policy(parsed, row, flat):
     results = {}
     all_nums = _find_all_numbers(parsed)
@@ -216,7 +208,6 @@ def _check_kpi_a1_policy(parsed, row, flat):
     results["has_max_ues"] = '"scope"' in flat or '"groupid"' in flat
     return results
 
-
 def _check_kpi_o1_nrm(parsed, row, flat):
     return {
         "has_latency": '"rrmpolicy"' in flat or '"nrcelldu"' in flat,
@@ -227,7 +218,6 @@ def _check_kpi_o1_nrm(parsed, row, flat):
         "has_max_ues": '"rrmpolicymemberlist"' in flat or '"snssai"' in flat,
     }
 
-
 DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"}
 
 def check_kpi_fields(parsed, row, target_layer):
@@ -242,7 +232,7 @@ def check_kpi_fields(parsed, row, target_layer):
 
 
 # ═══════════════════════════════════════════════════════════════════════
-# STRUCTURE CHECKING (identical to v2)
+# STRUCTURE CHECKING
 # ═══════════════════════════════════════════════════════════════════════
 
 LAYER_ROOT_KEYS = {
@@ -290,20 +280,16 @@ def check_structure(parsed, target_layer):
 # ═══════════════════════════════════════════════════════════════════════
 
 def save_checkpoint(output_file, config, results):
-    """Write full results JSON after every sample."""
     n = len(results)
     if n == 0:
         return
-
     total_valid = sum(1 for r in results if r["json_valid"])
     total_struct = sum(1 for r in results if r["structure_correct"])
-
     overall = {
         "total_samples": n,
         "json_validity_rate": total_valid / n,
         "structure_correctness_rate": total_struct / n,
     }
-
     kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"]
     kpi_samples = [r for r in results if any(k in r for k in kpi_fields)]
     if kpi_samples:
@@ -312,7 +298,6 @@ def save_checkpoint(output_file, config, results):
             overall[field + "_rate"] = sum(vals) / len(vals)
         all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
         overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)
-
     per_layer = defaultdict(lambda: defaultdict(list))
     for r in results:
         layer = r["target_layer"]
@@ -321,7 +306,6 @@ def save_checkpoint(output_file, config, results):
         for k in kpi_fields:
             if k in r:
                 per_layer[layer][k].append(r[k])
-
     layer_summary = {}
     for layer, metrics in per_layer.items():
         ln = len(metrics["json_valid"])
@@ -333,22 +317,18 @@ def save_checkpoint(output_file, config, results):
         for k in kpi_fields:
             if k in metrics and metrics[k]:
                 layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])
-
     output = {
         "config": config,
         "overall": overall,
        "per_layer": layer_summary,
        "raw_results": results,
     }
-
     tmp = output_file + ".tmp"
     with open(tmp, "w") as f:
         json.dump(output, f, indent=2, default=str)
     os.replace(tmp, output_file)
 
-
 def load_checkpoint(output_file):
-    """Load previously evaluated IDs for resume."""
     if not os.path.exists(output_file):
         return [], set()
     try:
@@ -380,7 +360,6 @@ def compute_gt_baseline(ds):
         kpi = check_kpi_fields(parsed, row, layer)
         for k, v in kpi.items():
             gt_results[layer][k].append(v)
-
     log("\n Ground-truth baseline (metric ceiling):")
     log(f" {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}")
     log(" " + "─" * 55)
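The checkpoint write path above keeps the tmp-file-plus-rename idiom, which is what makes the per-sample saves crash-safe. The idiom in isolation (helper name is illustrative, the mechanics are straight from save_checkpoint):

    import json, os

    def atomic_json_dump(path: str, payload: dict) -> None:
        tmp = path + ".tmp"
        with open(tmp, "w") as f:
            json.dump(payload, f, indent=2, default=str)
        os.replace(tmp, path)  # atomic rename: a crash mid-write never truncates the report

    atomic_json_dump("demo_checkpoint.json", {"total_samples": 0})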
@@ -402,12 +381,18 @@ def main():
     if args.no_flash_attn:
         args.flash_attn = False
 
+    # Detect baseline mode
+    is_baseline = args.adapter_path.lower() in ("none", "baseline", "base", "")
+
     log("=" * 70)
     log("TMF921 Intent Translation — Evaluation v3")
+    if is_baseline:
+        log("  *** ZERO-SHOT BASELINE MODE (no adapter) ***")
     log("  Stratified sampling · Layer-aware max tokens · Incremental saves")
     log("=" * 70)
     log(f" Base model  : {args.base_model}")
-    log(f" Adapter     : {args.adapter_path}")
+    log(f" Adapter     : {'NONE (zero-shot baseline)' if is_baseline else args.adapter_path}")
+    log(f" No-think    : {args.no_think}")
     log(f" Dataset     : {args.dataset} [{args.split}]")
     log(f" Per-layer N : {args.per_layer} (-1 = all)")
     log(f" Output      : {args.output_file}")
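For reviewers, the sentinel spellings accepted by the baseline check are worth pinning down. A standalone repro of the argparse default plus the detection test, using only what the hunks above add:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--adapter_path", type=str, default="./output",
                   help="Path to LoRA adapter, or 'none' for zero-shot baseline")
    args = p.parse_args(["--adapter_path", "none"])

    # same test as main(): any of these spellings selects the base model, no adapter
    print(args.adapter_path.lower() in ("none", "baseline", "base", ""))  # True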
@@ -469,18 +454,31 @@ def main():
         model_kwargs["attn_implementation"] = "flash_attention_2"
 
     base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
-    model = PeftModel.from_pretrained(base_model, args.adapter_path)
+
+    if is_baseline:
+        model = base_model
+        log("  ✅ Base model loaded (zero-shot baseline — no adapter)")
+    else:
+        from peft import PeftModel
+        model = PeftModel.from_pretrained(base_model, args.adapter_path)
+        log(f"  ✅ Fine-tuned model loaded (adapter: {args.adapter_path})")
+
     model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    log("✅ Model loaded\n")
+    log("")
 
     # ── Inference loop ──
     results = list(prev_results)
-    config = {**vars(args), "total_selected": len(indices), "layer_counts": dict(layer_counts)}
+    config_dict = {
+        **vars(args),
+        "is_baseline": is_baseline,
+        "total_selected": len(indices),
+        "layer_counts": dict(layer_counts),
+    }
     t_start = time.time()
     total_to_do = len(remaining_indices)
 
@@ -494,7 +492,18 @@ def main():
         messages = row["messages"]
         reference_output = messages[-1]["content"]
 
+        # Build prompt messages (system + user, no assistant)
         prompt_messages = [m for m in messages if m["role"] != "assistant"]
+
+        # Optionally suppress Qwen3 thinking mode
+        if args.no_think and prompt_messages:
+            last_msg = prompt_messages[-1]
+            if last_msg["role"] == "user" and "/no_think" not in last_msg["content"]:
+                prompt_messages[-1] = {
+                    "role": last_msg["role"],
+                    "content": last_msg["content"] + " /no_think"
+                }
+
         input_text = tokenizer.apply_chat_template(
             prompt_messages, tokenize=False, add_generation_prompt=True
         )
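The /no_think suffix used here is Qwen3's soft, per-message switch. The Qwen3 model card also documents a template-level switch, enable_thinking=False on apply_chat_template; an equivalent sketch under that assumption (an alternative, not what this commit does):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
    messages = [{"role": "user", "content": "Translate this intent into TMF921 JSON."}]

    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
        enable_thinking=False,  # hard switch: no <think> block is opened at all
    )
    print(text)

The suffix approach keeps the rendered prompt identical between baseline and fine-tuned runs apart from one trailing marker, which is arguably the fairer comparison.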
@@ -537,7 +546,7 @@ def main():
         results.append(result)
 
         # ── Incremental save ──
-        save_checkpoint(args.output_file, config, results)
+        save_checkpoint(args.output_file, config_dict, results)
 
         # ── Progress ──
         elapsed = time.time() - t_start
@@ -552,7 +561,8 @@ def main():
         s = "✓" if has_correct_structure else "✗"
         progress_n = len(done_ids) + done_now
         progress_total = len(indices)
-        log(f"  [{progress_n:>4}/{progress_total}] {target_layer:<30} JSON:{j} Struct:{s} "
+        mode_tag = "[BASE]" if is_baseline else "[FT]"
+        log(f"  {mode_tag} [{progress_n:>4}/{progress_total}] {target_layer:<30} JSON:{j} Struct:{s} "
             f"| {sample_time:.1f}s (max_tok={max_tokens}) | ETA: {eta_h}h{eta_m:02d}m")
 
     # ── Final summary ──
@@ -562,7 +572,8 @@ def main():
     total_struct = sum(1 for r in results if r["structure_correct"])
 
     log(f"\n{'=' * 70}")
-    log(f"FINAL RESULTS ({n} samples, {total_time/3600:.1f}h)")
+    mode_str = "ZERO-SHOT BASELINE" if is_baseline else "FINE-TUNED"
+    log(f"FINAL RESULTS — {mode_str} ({n} samples, {total_time/3600:.1f}h)")
     log(f"{'=' * 70}")
     log(f"  JSON Validity:     {total_valid}/{n} ({total_valid/n*100:.1f}%)")
     log(f"  Structure Correct: {total_struct}/{n} ({total_struct/n*100:.1f}%)")
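Once a run finishes (or is interrupted), the checkpoint file is already the full report. Reading it back, using the eval_v3_baseline.json name from the docstring examples (the keys below come from save_checkpoint; pass --output_file to choose the path):

    import json

    with open("eval_v3_baseline.json") as f:
        report = json.load(f)

    print(report["overall"]["json_validity_rate"])
    print(report["overall"]["structure_correctness_rate"])
    for layer, metrics in report["per_layer"].items():
        print(layer, metrics)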
 