Add zero-shot baseline mode: --adapter_path none skips adapter loading, --no_think suppresses Qwen3 thinking
evaluate_v3.py (CHANGED: +76 -65)
@@ -7,50 +7,50 @@ Optimized evaluation with:
 2. Layer-aware max tokens → caps generation length per layer (saves ~40% time)
 3. Incremental saves → writes results after every sample (never lose progress)
 4. Resume support → skips already-evaluated IDs from a previous checkpoint
+5. Zero-shot baseline → --adapter_path none to evaluate base model without adapter
 
 All KPI checking logic is identical to v2 (standard-aware).
 
-Speed estimate (vs v2 on full 2521 test set):
-    v2 full test:  ~50h (2521 samples × ~70s avg)
-    v3 stratified: ~4-5h (400 samples × ~45s avg, with tighter max_tokens)
-
 Usage:
+    # Fine-tuned model evaluation
     python evaluate_v3.py --adapter_path ./output
-
-
+
+    # Zero-shot baseline (no adapter, base model only)
+    python evaluate_v3.py --adapter_path none --output_file eval_v3_baseline.json
+
+    # Zero-shot baseline with Qwen3 thinking disabled
+    python evaluate_v3.py --adapter_path none --no_think --output_file eval_v3_baseline.json
+
+    # Resume interrupted run
+    python evaluate_v3.py --adapter_path ./output --resume
 """
 
 import argparse, json, re, os, sys, math, time, random, torch
 from collections import defaultdict
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from peft import PeftModel
 
 
 # ───────────────────────────────────────────────────────────────────────
 # LAYER-AWARE MAX TOKENS
 # ───────────────────────────────────────────────────────────────────────
-# Derived from reference output lengths in the dataset.
-# Each value = ceil(max_reference_length_chars / 3.5) + 20% safety margin.
-# This prevents the model from wasting time generating 4096 tokens for a
-# CAMARA output that's always ~700 chars (~250 tokens).
 
 LAYER_MAX_TOKENS = {
     "tmf921": 1600,
     "intent_3gpp": 900,
     "camara": 400,
     "a1_policy": 350,
     "o1_nrm": 500,
     "etsi_zsm": 1100,
     "tmf921_lifecycle_activate": 250,
     "tmf921_lifecycle_modify": 600,
     "tmf921_lifecycle_monitor": 500,
     "tmf921_lifecycle_report": 800,
     "tmf921_lifecycle_resume": 250,
     "tmf921_lifecycle_scale": 350,
     "tmf921_lifecycle_suspend": 250,
     "tmf921_lifecycle_terminate": 250,
     "adversarial_ambiguous": 200,
     "adversarial_contradictory": 200,
     "adversarial_out_of_scope": 200,
 }
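Reviewer note: the derivation comments deleted above still explain where these caps come from (each value is roughly ceil(max_reference_length_chars / 3.5) plus a ~20% safety margin). A minimal sketch of regenerating the table, assuming each dataset row stores the reference output as the final entry of its `messages` list, as the evaluation loop below does:

```python
import math
from collections import defaultdict

from datasets import load_dataset

ds = load_dataset("nraptisss/TMF921-intent-to-config-augmented", split="test")

# Longest reference output (in characters) observed per target layer.
max_chars = defaultdict(int)
for row in ds:
    ref = row["messages"][-1]["content"]  # assistant reference output
    max_chars[row["target_layer"]] = max(max_chars[row["target_layer"]], len(ref))

# ~3.5 chars per token for JSON-heavy text, plus a 20% safety margin.
layer_max_tokens = {layer: math.ceil(c / 3.5 * 1.2) for layer, c in max_chars.items()}
print(layer_max_tokens)
```

Recomputing the caps this way keeps them honest if the dataset is re-augmented.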
@@ -68,7 +68,8 @@ def log(msg: str):
 def parse_args():
     p = argparse.ArgumentParser(description="TMF921 Evaluation v3 - stratified, incremental, fast")
     p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
-    p.add_argument("--adapter_path", type=str, default="./output")
+    p.add_argument("--adapter_path", type=str, default="./output",
+                   help="Path to LoRA adapter, or 'none' for zero-shot baseline")
     p.add_argument("--dataset", type=str, default="nraptisss/TMF921-intent-to-config-augmented")
     p.add_argument("--split", type=str, default="test")
     p.add_argument("--per_layer", type=int, default=50,

@@ -79,6 +80,8 @@ def parse_args():
                    help="Resume from existing output_file, skipping already-evaluated IDs")
     p.add_argument("--flash_attn", action="store_true", default=True)
     p.add_argument("--no_flash_attn", action="store_true", default=False)
+    p.add_argument("--no_think", action="store_true", default=False,
+                   help="Suppress Qwen3 thinking mode by adding /no_think to the last user message")
     p.add_argument("--save_generations", action="store_true", default=True)
     return p.parse_args()
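Aside on the `--flash_attn`/`--no_flash_attn` pair, which is reconciled by hand later in `main()` (`if args.no_flash_attn: args.flash_attn = False`): on Python 3.9+ the same behavior fits in one declaration. A sketch, not part of this diff:

```python
import argparse

p = argparse.ArgumentParser()
# Generates both --flash_attn and --no-flash_attn, stored in args.flash_attn,
# so the manual override in main() would no longer be needed.
p.add_argument("--flash_attn", action=argparse.BooleanOptionalAction, default=True)

print(p.parse_args([]).flash_attn)                   # True
print(p.parse_args(["--no-flash_attn"]).flash_attn)  # False
```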
@@ -88,18 +91,10 @@ def parse_args():
 # ───────────────────────────────────────────────────────────────────────
 
 def stratified_sample(ds, per_layer: int, seed: int = 42):
-    """
-    Sample up to `per_layer` examples per target_layer.
-    Layers with fewer samples than `per_layer` are included in full.
-    Returns list of indices into the original dataset.
-    """
     rng = random.Random(seed)
-
-    # Group indices by target_layer
     layer_indices = defaultdict(list)
     for i in range(len(ds)):
         layer_indices[ds[i]["target_layer"]].append(i)
-
     selected = []
     layer_counts = {}
     for layer, indices in sorted(layer_indices.items()):

@@ -109,19 +104,20 @@ def stratified_sample(ds, per_layer: int, seed: int = 42):
             chosen = rng.sample(indices, per_layer)
         selected.extend(chosen)
         layer_counts[layer] = len(chosen)
-
-    # Shuffle so we don't evaluate all of one layer before the next
     rng.shuffle(selected)
-
     return selected, layer_counts
 
 
 # ───────────────────────────────────────────────────────────────────────
-# JSON PARSING
+# JSON PARSING
 # ───────────────────────────────────────────────────────────────────────
 
 def try_parse_json(text: str):
     text = text.strip()
+    # Strip thinking tags if present
+    think_match = re.match(r"<think>[\s\S]*?</think>\s*", text)
+    if think_match:
+        text = text[think_match.end():].strip()
     if text.startswith("```"):
         text = re.sub(r"^```(?:json)?\s*\n?", "", text)
         text = re.sub(r"\n?```\s*$", "", text)
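With the new `<think>` handling, a base Qwen3 generation that reasons before answering still parses. A quick check of that behavior (a sketch; it assumes `try_parse_json` falls through to `json.loads` after the fence stripping shown above, and builds the fences programmatically only to keep this snippet self-contained):

```python
fence = "```"
sample = (
    "<think>\nLow-latency slice, so latencyUpperBound = 10.\n</think>\n"
    f"{fence}json\n"
    '{"intent": {"id": "intent-001"}}\n'
    f"{fence}"
)
print(try_parse_json(sample))  # -> {'intent': {'id': 'intent-001'}}
```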
@@ -139,7 +135,7 @@ def try_parse_json(text: str):
 
 
 # ───────────────────────────────────────────────────────────────────────
-# KPI CHECKING (identical to v2)
+# KPI CHECKING (standard-aware, identical to v2)
 # ───────────────────────────────────────────────────────────────────────
 
 def _num_representations(val: float) -> list:

@@ -150,7 +146,6 @@ def _num_representations(val: float) -> list:
         reps.append(f"{val:.0f}")
     return list(set(reps))
 
-
 def _reliability_representations(rel_pct: float) -> list:
     reps = _num_representations(rel_pct)
     per = 1 - rel_pct / 100

@@ -169,7 +164,6 @@ def _reliability_representations(rel_pct: float) -> list:
         reps.append(f"{per:.10f}".rstrip("0").rstrip("."))
     return list(set(reps))
 
-
 def _find_all_numbers(parsed) -> list:
     nums = []
     if isinstance(parsed, dict):
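For context on why `_reliability_representations` exists: a reliability target can legitimately surface either as a percentage or as the complementary packet error rate, and the checker accepts any textual rendering found in the flattened JSON. A self-contained illustration of the idea (the flat string is hypothetical; the rstrip chain matches the code above):

```python
rel_pct = 99.999
per = 1 - rel_pct / 100  # complementary packet-error-rate form
forms = {f"{rel_pct}", f"{per:.10f}".rstrip("0").rstrip(".")}
print(forms)  # {'99.999', '0.00001'}

flat = '{"reliability": {"packeterrorrate": 0.00001}}'
print(any(f in flat for f in forms))  # True
```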
@@ -186,7 +180,6 @@ def _find_all_numbers(parsed) -> list:
             nums.extend(_find_all_numbers(item))
     return nums
 
-
 def _check_kpi_direct(parsed, row, flat):
     return {
         "has_latency": any(r in flat for r in _num_representations(row["latency_ms"])),

@@ -196,7 +189,6 @@ def _check_kpi_direct(parsed, row, flat):
         "has_max_ues": any(r in flat for r in _num_representations(float(row["max_ues"]))),
     }
 
-
 def _check_kpi_a1_policy(parsed, row, flat):
     results = {}
     all_nums = _find_all_numbers(parsed)

@@ -216,7 +208,6 @@ def _check_kpi_a1_policy(parsed, row, flat):
     results["has_max_ues"] = '"scope"' in flat or '"groupid"' in flat
     return results
 
-
 def _check_kpi_o1_nrm(parsed, row, flat):
     return {
         "has_latency": '"rrmpolicy"' in flat or '"nrcelldu"' in flat,

@@ -227,7 +218,6 @@ def _check_kpi_o1_nrm(parsed, row, flat):
         "has_max_ues": '"rrmpolicymemberlist"' in flat or '"snssai"' in flat,
     }
 
-
 DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"}
 
 def check_kpi_fields(parsed, row, target_layer):

@@ -242,7 +232,7 @@ def check_kpi_fields(parsed, row, target_layer):
 
 
 # ───────────────────────────────────────────────────────────────────────
-# STRUCTURE CHECKING
+# STRUCTURE CHECKING
 # ───────────────────────────────────────────────────────────────────────
 
 LAYER_ROOT_KEYS = {

@@ -290,20 +280,16 @@ def check_structure(parsed, target_layer):
 # ───────────────────────────────────────────────────────────────────────
 
 def save_checkpoint(output_file, config, results):
-    """Write full results JSON after every sample."""
     n = len(results)
     if n == 0:
         return
-
     total_valid = sum(1 for r in results if r["json_valid"])
     total_struct = sum(1 for r in results if r["structure_correct"])
-
     overall = {
         "total_samples": n,
         "json_validity_rate": total_valid / n,
         "structure_correctness_rate": total_struct / n,
     }
-
     kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"]
     kpi_samples = [r for r in results if any(k in r for k in kpi_fields)]
     if kpi_samples:

@@ -312,7 +298,6 @@ def save_checkpoint(output_file, config, results):
             overall[field + "_rate"] = sum(vals) / len(vals)
         all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
         overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)
-
     per_layer = defaultdict(lambda: defaultdict(list))
     for r in results:
         layer = r["target_layer"]

@@ -321,7 +306,6 @@ def save_checkpoint(output_file, config, results):
         for k in kpi_fields:
             if k in r:
                 per_layer[layer][k].append(r[k])
-
     layer_summary = {}
     for layer, metrics in per_layer.items():
         ln = len(metrics["json_valid"])

@@ -333,22 +317,18 @@ def save_checkpoint(output_file, config, results):
         for k in kpi_fields:
             if k in metrics and metrics[k]:
                 layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])
-
     output = {
         "config": config,
         "overall": overall,
         "per_layer": layer_summary,
         "raw_results": results,
     }
-
     tmp = output_file + ".tmp"
     with open(tmp, "w") as f:
         json.dump(output, f, indent=2, default=str)
     os.replace(tmp, output_file)
 
-
 def load_checkpoint(output_file):
-    """Load previously evaluated IDs for resume."""
     if not os.path.exists(output_file):
         return [], set()
     try:
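`load_checkpoint` pairs with the incremental saves: it returns the previously written records plus the set of finished IDs. A minimal sketch of the resume wiring, assuming records expose their dataset index (the exact key is outside this hunk, so the variable names below mirror the diff rather than guarantee the implementation):

```python
prev_results, done_ids = load_checkpoint(args.output_file)
indices, layer_counts = stratified_sample(ds, args.per_layer)
remaining_indices = [i for i in indices if i not in done_ids]
log(f"  Resuming: {len(done_ids)} done, {len(remaining_indices)} to go")
```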
@@ -380,7 +360,6 @@ def compute_gt_baseline(ds):
             kpi = check_kpi_fields(parsed, row, layer)
             for k, v in kpi.items():
                 gt_results[layer][k].append(v)
-
     log("\n Ground-truth baseline (metric ceiling):")
     log(f" {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}")
     log(" " + "─" * 55)

@@ -402,12 +381,18 @@ def main():
     if args.no_flash_attn:
         args.flash_attn = False
 
+    # Detect baseline mode
+    is_baseline = args.adapter_path.lower() in ("none", "baseline", "base", "")
+
     log("=" * 70)
     log("TMF921 Intent Translation - Evaluation v3")
+    if is_baseline:
+        log(" *** ZERO-SHOT BASELINE MODE (no adapter) ***")
     log(" Stratified sampling · Layer-aware max tokens · Incremental saves")
     log("=" * 70)
     log(f" Base model : {args.base_model}")
-    log(f" Adapter : {args.adapter_path}")
+    log(f" Adapter : {'NONE (zero-shot baseline)' if is_baseline else args.adapter_path}")
+    log(f" No-think : {args.no_think}")
     log(f" Dataset : {args.dataset} [{args.split}]")
     log(f" Per-layer N : {args.per_layer} (-1 = all)")
     log(f" Output : {args.output_file}")

@@ -469,18 +454,31 @@ def main():
         model_kwargs["attn_implementation"] = "flash_attention_2"
 
     base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
-    model = PeftModel.from_pretrained(base_model, args.adapter_path)
+
+    if is_baseline:
+        model = base_model
+        log(" ✓ Base model loaded (zero-shot baseline - no adapter)")
+    else:
+        from peft import PeftModel
+        model = PeftModel.from_pretrained(base_model, args.adapter_path)
+        log(f" ✓ Fine-tuned model loaded (adapter: {args.adapter_path})")
+
     model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    log("
+    log("")
 
     # ── Inference loop ──
     results = list(prev_results)
+    config_dict = {
+        **vars(args),
+        "is_baseline": is_baseline,
+        "total_selected": len(indices),
+        "layer_counts": dict(layer_counts),
+    }
     t_start = time.time()
     total_to_do = len(remaining_indices)
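An optional follow-up to the conditional adapter load, not part of this diff: when the base model is loaded in full or half precision, a LoRA adapter can be merged to drop the PEFT indirection during generation. Merging into 4-bit weights is lossy and best avoided, and there is nothing to merge in baseline mode. Sketch:

```python
# Assumption: LoRA-style adapter on an unquantized base model.
if not is_baseline:
    model = model.merge_and_unload()
model.eval()
```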
@@ -494,7 +492,18 @@ def main():
         messages = row["messages"]
         reference_output = messages[-1]["content"]
 
+        # Build prompt messages (system + user, no assistant)
         prompt_messages = [m for m in messages if m["role"] != "assistant"]
+
+        # Optionally suppress Qwen3 thinking mode
+        if args.no_think and prompt_messages:
+            last_msg = prompt_messages[-1]
+            if last_msg["role"] == "user" and "/no_think" not in last_msg["content"]:
+                prompt_messages[-1] = {
+                    "role": last_msg["role"],
+                    "content": last_msg["content"] + " /no_think",
+                }
+
         input_text = tokenizer.apply_chat_template(
             prompt_messages, tokenize=False, add_generation_prompt=True
         )
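The `/no_think` suffix is Qwen3's in-band soft switch. Recent transformers releases expose the same control out-of-band through the chat template; an equivalent sketch, assuming a Qwen3 tokenizer whose template honors `enable_thinking`:

```python
input_text = tokenizer.apply_chat_template(
    prompt_messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # Qwen3 template kwarg; suppresses the <think> block
)
```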
@@ -537,7 +546,7 @@ def main():
         results.append(result)
 
         # ── Incremental save ──
-        save_checkpoint(args.output_file,
+        save_checkpoint(args.output_file, config_dict, results)
 
         # ── Progress ──
         elapsed = time.time() - t_start

@@ -552,7 +561,8 @@ def main():
         s = "✓" if has_correct_structure else "✗"
         progress_n = len(done_ids) + done_now
         progress_total = len(indices)
-        log(f" [{progress_n:>4}/{progress_total}] {target_layer:<30} JSON:{j} Struct:{s} "
+        mode_tag = "[BASE]" if is_baseline else "[FT]"
+        log(f" {mode_tag} [{progress_n:>4}/{progress_total}] {target_layer:<30} JSON:{j} Struct:{s} "
             f"| {sample_time:.1f}s (max_tok={max_tokens}) | ETA: {eta_h}h{eta_m:02d}m")
 
     # ── Final summary ──

@@ -562,7 +572,8 @@ def main():
     total_struct = sum(1 for r in results if r["structure_correct"])
 
     log(f"\n{'=' * 70}")
-    log(f"FINAL RESULTS ({n} samples, {total_time/3600:.1f}h)")
+    mode_str = "ZERO-SHOT BASELINE" if is_baseline else "FINE-TUNED"
+    log(f"FINAL RESULTS - {mode_str} ({n} samples, {total_time/3600:.1f}h)")
     log(f"{'=' * 70}")
     log(f" JSON Validity: {total_valid}/{n} ({total_valid/n*100:.1f}%)")
     log(f" Structure Correct: {total_struct}/{n} ({total_struct/n*100:.1f}%)")
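Once both runs exist, the two checkpoint files share the schema written by `save_checkpoint` (`config` / `overall` / `per_layer` / `raw_results`), so comparing the zero-shot baseline against the fine-tuned adapter takes only a few lines. The baseline file name follows the docstring examples; the fine-tuned output name is an assumption:

```python
import json

with open("eval_v3.json") as f:  # assumed fine-tuned output file
    ft = json.load(f)
with open("eval_v3_baseline.json") as f:
    base = json.load(f)

for key in ("json_validity_rate", "structure_correctness_rate", "all_kpis_correct_rate"):
    b, t = base["overall"].get(key), ft["overall"].get(key)
    if b is not None and t is not None:
        print(f"{key:<28} base={b:.3f}  ft={t:.3f}  delta={t - b:+.3f}")
```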