#!/usr/bin/env python3 """ TMF921 Intent Translation — Evaluation Script v3 ================================================= Optimized evaluation with: 1. Stratified sampling — N samples per target_layer (covers all layers evenly) 2. Layer-aware max tokens — caps generation length per layer (saves ~40% time) 3. Incremental saves — writes results after every sample (never lose progress) 4. Resume support — skips already-evaluated IDs from a previous checkpoint 5. Zero-shot baseline — --adapter_path none to evaluate base model without adapter All KPI checking logic is identical to v2 (standard-aware). Usage: # Fine-tuned model evaluation python evaluate_v3.py --adapter_path ./output # Zero-shot baseline (no adapter, base model only) python evaluate_v3.py --adapter_path none --output_file eval_v3_baseline.json # Zero-shot baseline with Qwen3 thinking disabled python evaluate_v3.py --adapter_path none --no_think --output_file eval_v3_baseline.json # Resume interrupted run python evaluate_v3.py --adapter_path ./output --resume """ import argparse, json, re, os, sys, math, time, random, torch from collections import defaultdict from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # ═══════════════════════════════════════════════════════════════════════ # LAYER-AWARE MAX TOKENS # ═══════════════════════════════════════════════════════════════════════ LAYER_MAX_TOKENS = { "tmf921": 1600, "intent_3gpp": 900, "camara": 400, "a1_policy": 350, "o1_nrm": 500, "etsi_zsm": 1100, "tmf921_lifecycle_activate": 250, "tmf921_lifecycle_modify": 600, "tmf921_lifecycle_monitor": 500, "tmf921_lifecycle_report": 800, "tmf921_lifecycle_resume": 250, "tmf921_lifecycle_scale": 350, "tmf921_lifecycle_suspend": 250, "tmf921_lifecycle_terminate": 250, "adversarial_ambiguous": 200, "adversarial_contradictory": 200, "adversarial_out_of_scope": 200, } DEFAULT_MAX_TOKENS = 2048 # ═══════════════════════════════════════════════════════════════════════ # HELPERS # ═══════════════════════════════════════════════════════════════════════ def log(msg: str): print(msg, flush=True) def parse_args(): p = argparse.ArgumentParser(description="TMF921 Evaluation v3 — stratified, incremental, fast") p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B") p.add_argument("--adapter_path", type=str, default="./output", help="Path to LoRA adapter, or 'none' for zero-shot baseline") p.add_argument("--dataset", type=str, default="nraptisss/TMF921-intent-to-config-augmented") p.add_argument("--split", type=str, default="test") p.add_argument("--per_layer", type=int, default=50, help="Samples per target_layer. Layers with fewer samples use all. -1 = all.") p.add_argument("--seed", type=int, default=42, help="Random seed for stratified sampling") p.add_argument("--output_file", type=str, default="eval_v3_results.json") p.add_argument("--resume", action="store_true", help="Resume from existing output_file, skipping already-evaluated IDs") p.add_argument("--flash_attn", action="store_true", default=True) p.add_argument("--no_flash_attn", action="store_true", default=False) p.add_argument("--no_think", action="store_true", default=False, help="Suppress Qwen3 thinking mode by adding /no_think to the last user message") p.add_argument("--save_generations", action="store_true", default=True) return p.parse_args() # ═══════════════════════════════════════════════════════════════════════ # STRATIFIED SAMPLING # ═══════════════════════════════════════════════════════════════════════ def stratified_sample(ds, per_layer: int, seed: int = 42): rng = random.Random(seed) layer_indices = defaultdict(list) for i in range(len(ds)): layer_indices[ds[i]["target_layer"]].append(i) selected = [] layer_counts = {} for layer, indices in sorted(layer_indices.items()): if per_layer < 0 or len(indices) <= per_layer: chosen = indices else: chosen = rng.sample(indices, per_layer) selected.extend(chosen) layer_counts[layer] = len(chosen) rng.shuffle(selected) return selected, layer_counts # ═══════════════════════════════════════════════════════════════════════ # JSON PARSING # ═══════════════════════════════════════════════════════════════════════ def try_parse_json(text: str): text = text.strip() # Strip thinking tags if present think_match = re.match(r"[\s\S]*?\s*", text) if think_match: text = text[think_match.end():].strip() if text.startswith("```"): text = re.sub(r"^```(?:json)?\s*\n?", "", text) text = re.sub(r"\n?```\s*$", "", text) try: return json.loads(text), True except json.JSONDecodeError: pass match = re.search(r"\{[\s\S]*\}", text) if match: try: return json.loads(match.group()), True except json.JSONDecodeError: pass return None, False # ═══════════════════════════════════════════════════════════════════════ # KPI CHECKING (standard-aware, identical to v2) # ═══════════════════════════════════════════════════════════════════════ def _num_representations(val: float) -> list: reps = [str(val)] if val == int(val): reps.append(str(int(val))) reps.append(f"{val:.1f}") reps.append(f"{val:.0f}") return list(set(reps)) def _reliability_representations(rel_pct: float) -> list: reps = _num_representations(rel_pct) per = 1 - rel_pct / 100 if per > 0: exp = math.floor(math.log10(per)) mantissa = per / (10 ** exp) reps.append(f"1e-{abs(exp):02d}") reps.append(f"1e-{abs(exp)}") reps.append(f"{per:.0e}") reps.append(f"{per}") if mantissa == 1.0: reps.append(f"1e-{abs(exp):02d}") else: reps.append(f"{mantissa:.1f}e-{abs(exp):02d}") if per < 1: reps.append(f"{per:.10f}".rstrip("0").rstrip(".")) return list(set(reps)) def _find_all_numbers(parsed) -> list: nums = [] if isinstance(parsed, dict): for v in parsed.values(): if isinstance(v, (int, float)) and not isinstance(v, bool): nums.append(float(v)) elif isinstance(v, (dict, list)): nums.extend(_find_all_numbers(v)) elif isinstance(parsed, list): for item in parsed: if isinstance(item, (int, float)) and not isinstance(item, bool): nums.append(float(item)) elif isinstance(item, (dict, list)): nums.extend(_find_all_numbers(item)) return nums def _check_kpi_direct(parsed, row, flat): return { "has_latency": any(r in flat for r in _num_representations(row["latency_ms"])), "has_reliability": any(r in flat for r in _reliability_representations(row["reliability_pct"])), "has_dl_throughput": any(r in flat for r in _num_representations(row["dl_throughput_mbps"])), "has_ul_throughput": any(r in flat for r in _num_representations(row["ul_throughput_mbps"])), "has_max_ues": any(r in flat for r in _num_representations(float(row["max_ues"]))), } def _check_kpi_a1_policy(parsed, row, flat): results = {} all_nums = _find_all_numbers(parsed) target_rel = row["reliability_pct"] rel_found = any(r in flat for r in _reliability_representations(target_rel)) if not rel_found: per = 1 - target_rel / 100 for n in all_nums: if abs(n - target_rel) < 0.01: rel_found = True; break if per > 0 and n > 0 and abs(n - per) / max(per, 1e-15) < 0.1: rel_found = True; break results["has_reliability"] = rel_found results["has_latency"] = '"pdb"' in flat or '"packetdelaybudget"' in flat results["has_dl_throughput"] = '"gfbr"' in flat or '"mfbr"' in flat or '"guaranteedflowbitrate"' in flat results["has_ul_throughput"] = results["has_dl_throughput"] results["has_max_ues"] = '"scope"' in flat or '"groupid"' in flat return results def _check_kpi_o1_nrm(parsed, row, flat): return { "has_latency": '"rrmpolicy"' in flat or '"nrcelldu"' in flat, "has_reliability": '"operationalstate"' in flat or '"administrativestate"' in flat, "has_dl_throughput": '"bschannelbwdl"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcndl"' in flat, "has_ul_throughput": ('"bschannelbwul"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcnul"' in flat or '"rrmpolicydedicatedratio"' in flat), "has_max_ues": '"rrmpolicymemberlist"' in flat or '"snssai"' in flat, } DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"} def check_kpi_fields(parsed, row, target_layer): flat = json.dumps(parsed).lower() if target_layer in DIRECT_KPI_LAYERS: return _check_kpi_direct(parsed, row, flat) elif target_layer == "a1_policy": return _check_kpi_a1_policy(parsed, row, flat) elif target_layer == "o1_nrm": return _check_kpi_o1_nrm(parsed, row, flat) return _check_kpi_direct(parsed, row, flat) # ═══════════════════════════════════════════════════════════════════════ # STRUCTURE CHECKING # ═══════════════════════════════════════════════════════════════════════ LAYER_ROOT_KEYS = { "tmf921": ["id", "href", "name", "intentexpression"], "intent_3gpp": ["intent"], "camara": ["networkslicebooking"], "etsi_zsm": ["zsmintent"], "a1_policy": ["a1policy"], "o1_nrm": ["managedelement"], } ADVERSARIAL_STATUSES = {"CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"} LIFECYCLE_LAYERS = { "tmf921_lifecycle_activate", "tmf921_lifecycle_modify", "tmf921_lifecycle_suspend", "tmf921_lifecycle_resume", "tmf921_lifecycle_terminate", "tmf921_lifecycle_scale", "tmf921_lifecycle_monitor", "tmf921_lifecycle_report", } LIFECYCLE_KEYS = { "tmf921_lifecycle_activate": ["intentpatch", "intentactivation"], "tmf921_lifecycle_modify": ["intentpatch", "intentupdate", "intentmodification"], "tmf921_lifecycle_suspend": ["intentpatch", "intentsuspension"], "tmf921_lifecycle_resume": ["intentpatch", "intentresumption"], "tmf921_lifecycle_terminate": ["intentpatch", "intenttermination"], "tmf921_lifecycle_scale": ["intentpatch", "intentscaling"], "tmf921_lifecycle_monitor": ["intentassurancereport", "intentmonitor", "intentfulfillmentreport", "monitoringreport", "fulfillmentinfo", "report"], "tmf921_lifecycle_report": ["intentassurancereport", "intentreport", "fulfillmentinfo", "report"], } def check_structure(parsed, target_layer): if target_layer.startswith("adversarial"): return parsed.get("status") in ADVERSARIAL_STATUSES if target_layer in LIFECYCLE_LAYERS: flat_keys = {k.lower() for k in parsed.keys()} return any(k in flat_keys for k in LIFECYCLE_KEYS.get(target_layer, [])) expected = LAYER_ROOT_KEYS.get(target_layer, []) if not expected: return True flat_keys = {k.lower() for k in parsed.keys()} return any(k in flat_keys for k in expected) # ═══════════════════════════════════════════════════════════════════════ # INCREMENTAL SAVE / RESUME # ═══════════════════════════════════════════════════════════════════════ def save_checkpoint(output_file, config, results): n = len(results) if n == 0: return total_valid = sum(1 for r in results if r["json_valid"]) total_struct = sum(1 for r in results if r["structure_correct"]) overall = { "total_samples": n, "json_validity_rate": total_valid / n, "structure_correctness_rate": total_struct / n, } kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"] kpi_samples = [r for r in results if any(k in r for k in kpi_fields)] if kpi_samples: for field in kpi_fields: vals = [r.get(field, False) for r in kpi_samples] overall[field + "_rate"] = sum(vals) / len(vals) all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples] overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi) per_layer = defaultdict(lambda: defaultdict(list)) for r in results: layer = r["target_layer"] per_layer[layer]["json_valid"].append(r["json_valid"]) per_layer[layer]["structure_correct"].append(r["structure_correct"]) for k in kpi_fields: if k in r: per_layer[layer][k].append(r[k]) layer_summary = {} for layer, metrics in per_layer.items(): ln = len(metrics["json_valid"]) layer_summary[layer] = { "n": ln, "json_valid": sum(metrics["json_valid"]) / ln, "structure_correct": sum(metrics["structure_correct"]) / ln, } for k in kpi_fields: if k in metrics and metrics[k]: layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k]) output = { "config": config, "overall": overall, "per_layer": layer_summary, "raw_results": results, } tmp = output_file + ".tmp" with open(tmp, "w") as f: json.dump(output, f, indent=2, default=str) os.replace(tmp, output_file) def load_checkpoint(output_file): if not os.path.exists(output_file): return [], set() try: with open(output_file) as f: data = json.load(f) results = data.get("raw_results", []) done_ids = {r["id"] for r in results} log(f" Resuming from checkpoint: {len(results)} samples already done") return results, done_ids except (json.JSONDecodeError, KeyError): log(" Warning: checkpoint file corrupted, starting fresh") return [], set() # ═══════════════════════════════════════════════════════════════════════ # GROUND-TRUTH BASELINE # ═══════════════════════════════════════════════════════════════════════ def compute_gt_baseline(ds): gt_results = defaultdict(lambda: defaultdict(list)) for row in ds: layer = row["target_layer"] if layer.startswith("adversarial") or layer in LIFECYCLE_LAYERS: continue gt_text = row["messages"][-1]["content"] parsed, valid = try_parse_json(gt_text) if not parsed: continue kpi = check_kpi_fields(parsed, row, layer) for k, v in kpi.items(): gt_results[layer][k].append(v) log("\n Ground-truth baseline (metric ceiling):") log(f" {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}") log(" " + "─" * 55) for layer in sorted(gt_results.keys()): m = gt_results[layer] def rate(key): vals = m.get(key, []) return sum(vals) / len(vals) * 100 if vals else 0 log(f" {layer:<20} {rate('has_latency'):>7.1f}% {rate('has_reliability'):>7.1f}% " f"{rate('has_dl_throughput'):>7.1f}% {rate('has_ul_throughput'):>7.1f}% {rate('has_max_ues'):>7.1f}%") # ═══════════════════════════════════════════════════════════════════════ # MAIN # ═══════════════════════════════════════════════════════════════════════ def main(): args = parse_args() if args.no_flash_attn: args.flash_attn = False # Detect baseline mode is_baseline = args.adapter_path.lower() in ("none", "baseline", "base", "") log("=" * 70) log("TMF921 Intent Translation — Evaluation v3") if is_baseline: log(" *** ZERO-SHOT BASELINE MODE (no adapter) ***") log(" Stratified sampling · Layer-aware max tokens · Incremental saves") log("=" * 70) log(f" Base model : {args.base_model}") log(f" Adapter : {'NONE (zero-shot baseline)' if is_baseline else args.adapter_path}") log(f" No-think : {args.no_think}") log(f" Dataset : {args.dataset} [{args.split}]") log(f" Per-layer N : {args.per_layer} (-1 = all)") log(f" Output : {args.output_file}") log(f" Resume : {args.resume}") log("=" * 70) # ── Load dataset ── log("\nLoading dataset …") ds = load_dataset(args.dataset, split=args.split) log(f" Full test set: {len(ds)} samples") # ── Ground-truth baseline ── log("\nComputing ground-truth metric baseline …") compute_gt_baseline(ds) # ── Stratified sampling ── if args.per_layer < 0: indices = list(range(len(ds))) layer_counts = defaultdict(int) for i in indices: layer_counts[ds[i]["target_layer"]] += 1 random.Random(args.seed).shuffle(indices) else: indices, layer_counts = stratified_sample(ds, args.per_layer, args.seed) log(f"\n Stratified sample: {len(indices)} total samples") for layer, cnt in sorted(layer_counts.items()): log(f" {layer:<35} {cnt:>4}") # ── Resume from checkpoint ── if args.resume: prev_results, done_ids = load_checkpoint(args.output_file) else: prev_results, done_ids = [], set() remaining_indices = [i for i in indices if ds[i]["id"] not in done_ids] log(f"\n Already evaluated: {len(done_ids)}") log(f" Remaining: {len(remaining_indices)}") if not remaining_indices: log("\n All samples already evaluated! Nothing to do.") log(f" Results in: {args.output_file}") return # ── Load model ── log("\nLoading model …") bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) model_kwargs = { "quantization_config": bnb_config, "device_map": "auto", "trust_remote_code": True, } if args.flash_attn: model_kwargs["attn_implementation"] = "flash_attention_2" base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs) if is_baseline: model = base_model log(" ✅ Base model loaded (zero-shot baseline — no adapter)") else: from peft import PeftModel model = PeftModel.from_pretrained(base_model, args.adapter_path) log(f" ✅ Fine-tuned model loaded (adapter: {args.adapter_path})") model.eval() tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token log("") # ── Inference loop ── results = list(prev_results) config_dict = { **vars(args), "is_baseline": is_baseline, "total_selected": len(indices), "layer_counts": dict(layer_counts), } t_start = time.time() total_to_do = len(remaining_indices) for eval_idx, ds_idx in enumerate(remaining_indices): row = ds[ds_idx] target_layer = row["target_layer"] t0 = time.time() max_tokens = LAYER_MAX_TOKENS.get(target_layer, DEFAULT_MAX_TOKENS) messages = row["messages"] reference_output = messages[-1]["content"] # Build prompt messages (system + user, no assistant) prompt_messages = [m for m in messages if m["role"] != "assistant"] # Optionally suppress Qwen3 thinking mode if args.no_think and prompt_messages: last_msg = prompt_messages[-1] if last_msg["role"] == "user" and "/no_think" not in last_msg["content"]: prompt_messages[-1] = { "role": last_msg["role"], "content": last_msg["content"] + " /no_think" } input_text = tokenizer.apply_chat_template( prompt_messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer(input_text, return_tensors="pt").to(model.device) with torch.no_grad(): output_ids = model.generate( **inputs, max_new_tokens=max_tokens, do_sample=False, temperature=None, top_p=None, ) generated_ids = output_ids[0][inputs["input_ids"].shape[1]:] generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) parsed, is_valid_json = try_parse_json(generated_text) has_correct_structure = check_structure(parsed, target_layer) if parsed else False kpi_results = {} if parsed and not target_layer.startswith("adversarial") and target_layer not in LIFECYCLE_LAYERS: kpi_results = check_kpi_fields(parsed, row, target_layer) result = { "id": row["id"], "target_layer": target_layer, "slice_type": row["slice_type"], "lifecycle_operation": row["lifecycle_operation"], "json_valid": is_valid_json, "structure_correct": has_correct_structure, **kpi_results, "generated_length": len(generated_text), "reference_length": len(reference_output), "max_tokens_used": max_tokens, } if args.save_generations: result["generated_text"] = generated_text results.append(result) # ── Incremental save ── save_checkpoint(args.output_file, config_dict, results) # ── Progress ── elapsed = time.time() - t_start sample_time = time.time() - t0 done_now = eval_idx + 1 avg_time = elapsed / done_now remaining_time = avg_time * (total_to_do - done_now) eta_h, eta_rem = divmod(int(remaining_time), 3600) eta_m = eta_rem // 60 j = "✓" if is_valid_json else "✗" s = "✓" if has_correct_structure else "✗" progress_n = len(done_ids) + done_now progress_total = len(indices) mode_tag = "[BASE]" if is_baseline else "[FT]" log(f" {mode_tag} [{progress_n:>4}/{progress_total}] {target_layer:<30} JSON:{j} Struct:{s} " f"| {sample_time:.1f}s (max_tok={max_tokens}) | ETA: {eta_h}h{eta_m:02d}m") # ── Final summary ── total_time = time.time() - t_start n = len(results) total_valid = sum(1 for r in results if r["json_valid"]) total_struct = sum(1 for r in results if r["structure_correct"]) log(f"\n{'=' * 70}") mode_str = "ZERO-SHOT BASELINE" if is_baseline else "FINE-TUNED" log(f"FINAL RESULTS — {mode_str} ({n} samples, {total_time/3600:.1f}h)") log(f"{'=' * 70}") log(f" JSON Validity: {total_valid}/{n} ({total_valid/n*100:.1f}%)") log(f" Structure Correct: {total_struct}/{n} ({total_struct/n*100:.1f}%)") kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"] kpi_samples = [r for r in results if any(k in r for k in kpi_fields)] if kpi_samples: all_kpi = sum(1 for r in kpi_samples if all(r.get(f, False) for f in kpi_fields)) log(f" All KPIs Correct: {all_kpi}/{len(kpi_samples)} ({all_kpi/len(kpi_samples)*100:.1f}%)") per_layer_agg = defaultdict(lambda: defaultdict(list)) for r in results: layer = r["target_layer"] per_layer_agg[layer]["json_valid"].append(r["json_valid"]) per_layer_agg[layer]["structure_correct"].append(r["structure_correct"]) for k in kpi_fields: if k in r: per_layer_agg[layer][k].append(r[k]) log(f"\n{'Layer':<30} {'N':>4} {'JSON':>6} {'Struct':>7} | {'Lat':>5} {'Rel':>5} {'DL':>5} {'UL':>5} {'UEs':>5}") log("─" * 85) for layer in sorted(per_layer_agg.keys()): m = per_layer_agg[layer] ln = len(m["json_valid"]) jv = sum(m["json_valid"]) / ln * 100 sc = sum(m["structure_correct"]) / ln * 100 def fmt(k): return f"{sum(m[k])/len(m[k])*100:.0f}%" if k in m and m[k] else "—" log(f" {layer:<28} {ln:>4} {jv:>5.0f}% {sc:>6.0f}% | " f"{fmt('has_latency'):>5} {fmt('has_reliability'):>5} {fmt('has_dl_throughput'):>5} " f"{fmt('has_ul_throughput'):>5} {fmt('has_max_ues'):>5}") log(f"\n✅ Results saved to {args.output_file}") log(f" (incremental — every sample was saved as it completed)") if __name__ == "__main__": main()