#!/usr/bin/env python3 """ TMF921 Intent Translation — Evaluation Script v2 ================================================= Standard-aware KPI checking that correctly handles how each telecom standard encodes network parameters: TMF921, 3GPP TS 28.312, CAMARA, ETSI ZSM: → KPI values embedded directly (with int/float tolerance: 99 vs 99.0) O-RAN A1 Policy: → reliability → packet error rate (PER): 99.999% → 1e-05 → latency → packet delay budget (pdb): mapped via 5QI table → throughput → gfbr/mfbr (guaranteed/maximum flow bitrate) O-RAN O1 NRM (3GPP TS 28.541): → KPIs translated to radio resource management configs (RRM policies, cell parameters, frequency allocations). No direct numeric values. → Evaluated via structural element presence. Changes from v1: - Fixes metric bug where 92% of "reliability failures" were false negatives - Adds ground-truth baseline (metric ceiling) printed before evaluation - Standard-specific KPI checking (3 strategies) - Expanded lifecycle operation key matching - Saves generated text for error analysis - Flushes stdout on every print (fixes nohup buffering) Usage: python evaluate_v2.py --adapter_path ./output --num_samples 200 python evaluate_v2.py --adapter_path ./output --num_samples -1 """ import argparse, json, re, os, sys, math, time, torch from collections import defaultdict from datasets import load_dataset from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, ) from peft import PeftModel def log(msg: str): """Print with flush so nohup logs update in real time.""" print(msg, flush=True) def parse_args(): p = argparse.ArgumentParser() p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B") p.add_argument("--adapter_path", type=str, default="./output", help="Path or HF id of LoRA adapter") p.add_argument("--dataset", type=str, default="nraptisss/TMF921-intent-to-config-augmented") p.add_argument("--split", type=str, default="test") p.add_argument("--num_samples", type=int, default=200, help="Number of samples to evaluate (-1 for all)") p.add_argument("--max_new_tokens", type=int, default=4096) p.add_argument("--output_file", type=str, default="eval_results_v2.json") p.add_argument("--flash_attn", action="store_true", default=True) p.add_argument("--save_generations", action="store_true", default=True, help="Save generated text in results for error analysis") return p.parse_args() # ── JSON Parsing ───────────────────────────────────────────────────── def try_parse_json(text: str) -> tuple[dict | None, bool]: """Try to parse JSON from model output, handling markdown fences.""" text = text.strip() if text.startswith("```"): text = re.sub(r"^```(?:json)?\s*\n?", "", text) text = re.sub(r"\n?```\s*$", "", text) try: return json.loads(text), True except json.JSONDecodeError: pass match = re.search(r"\{[\s\S]*\}", text) if match: try: return json.loads(match.group()), True except json.JSONDecodeError: pass return None, False # ── Standard-aware KPI checking ───────────────────────────────────── def _num_representations(val: float) -> list[str]: """Generate multiple string representations of a numeric value.""" reps = [str(val)] if val == int(val): reps.append(str(int(val))) reps.append(f"{val:.1f}") reps.append(f"{val:.0f}") return list(set(reps)) def _reliability_representations(rel_pct: float) -> list[str]: """Generate all plausible encodings of a reliability percentage.""" reps = _num_representations(rel_pct) per = 1 - rel_pct / 100 if per > 0: exp = math.floor(math.log10(per)) mantissa = per / (10 ** exp) reps.append(f"1e-{abs(exp):02d}") reps.append(f"1e-{abs(exp)}") reps.append(f"{per:.0e}") reps.append(f"{per}") if mantissa == 1.0: reps.append(f"1e-{abs(exp):02d}") else: reps.append(f"{mantissa:.1f}e-{abs(exp):02d}") if per < 1: reps.append(f"{per:.10f}".rstrip("0").rstrip(".")) return list(set(reps)) def _find_all_numbers(parsed: dict) -> list[float]: """Extract all numeric values from a nested JSON structure.""" nums = [] if isinstance(parsed, dict): for v in parsed.values(): if isinstance(v, (int, float)) and not isinstance(v, bool): nums.append(float(v)) elif isinstance(v, (dict, list)): nums.extend(_find_all_numbers(v)) elif isinstance(parsed, list): for item in parsed: if isinstance(item, (int, float)) and not isinstance(item, bool): nums.append(float(item)) elif isinstance(item, (dict, list)): nums.extend(_find_all_numbers(item)) return nums def _check_kpi_direct(parsed: dict, row: dict, flat: str) -> dict: """Direct KPI matching for TMF921, intent_3gpp, CAMARA, ETSI ZSM.""" results = {} results["has_latency"] = any(rep in flat for rep in _num_representations(row["latency_ms"])) results["has_reliability"] = any(rep in flat for rep in _reliability_representations(row["reliability_pct"])) results["has_dl_throughput"] = any(rep in flat for rep in _num_representations(row["dl_throughput_mbps"])) results["has_ul_throughput"] = any(rep in flat for rep in _num_representations(row["ul_throughput_mbps"])) results["has_max_ues"] = any(rep in flat for rep in _num_representations(float(row["max_ues"]))) return results def _check_kpi_a1_policy(parsed: dict, row: dict, flat: str) -> dict: """A1 Policy: reliability→PER, latency→pdb, throughput→gfbr/mfbr.""" results = {} all_nums = _find_all_numbers(parsed) target_rel = row["reliability_pct"] rel_found = any(rep in flat for rep in _reliability_representations(target_rel)) if not rel_found: per = 1 - target_rel / 100 for n in all_nums: if abs(n - target_rel) < 0.01: rel_found = True; break if per > 0 and n > 0 and abs(n - per) / max(per, 1e-15) < 0.1: rel_found = True; break results["has_reliability"] = rel_found results["has_latency"] = '"pdb"' in flat or '"packetdelaybudget"' in flat has_tput = '"gfbr"' in flat or '"mfbr"' in flat or '"guaranteedflowbitrate"' in flat results["has_dl_throughput"] = has_tput results["has_ul_throughput"] = has_tput results["has_max_ues"] = '"scope"' in flat or '"groupid"' in flat return results def _check_kpi_o1_nrm(parsed: dict, row: dict, flat: str) -> dict: """O1 NRM: structural element presence (KPIs→RRM policies, not direct values).""" results = {} results["has_latency"] = '"rrmpolicy"' in flat or '"nrcelldu"' in flat results["has_reliability"] = '"operationalstate"' in flat or '"administrativestate"' in flat results["has_dl_throughput"] = '"bschannelbwdl"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcndl"' in flat results["has_ul_throughput"] = ( '"bschannelbwul"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcnul"' in flat or '"rrmpolicydedicatedratio"' in flat ) results["has_max_ues"] = '"rrmpolicymemberlist"' in flat or '"snssai"' in flat return results DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"} def check_kpi_fields(parsed: dict, row: dict, target_layer: str) -> dict: """Standard-aware KPI checking: direct / A1 Policy / O1 NRM strategies.""" flat = json.dumps(parsed).lower() if target_layer in DIRECT_KPI_LAYERS: return _check_kpi_direct(parsed, row, flat) elif target_layer == "a1_policy": return _check_kpi_a1_policy(parsed, row, flat) elif target_layer == "o1_nrm": return _check_kpi_o1_nrm(parsed, row, flat) else: return _check_kpi_direct(parsed, row, flat) # ── Structure checking ─────────────────────────────────────────────── LAYER_ROOT_KEYS = { "tmf921": ["id", "href", "name", "intentexpression"], "intent_3gpp": ["intent"], "camara": ["networkslicebooking"], "etsi_zsm": ["zsmintent"], "a1_policy": ["a1policy"], "o1_nrm": ["managedelement"], } ADVERSARIAL_STATUSES = {"CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"} LIFECYCLE_LAYERS = { "tmf921_lifecycle_activate", "tmf921_lifecycle_modify", "tmf921_lifecycle_suspend", "tmf921_lifecycle_resume", "tmf921_lifecycle_terminate", "tmf921_lifecycle_scale", "tmf921_lifecycle_monitor", "tmf921_lifecycle_report", } LIFECYCLE_KEYS = { "tmf921_lifecycle_activate": ["intentpatch", "intentactivation"], "tmf921_lifecycle_modify": ["intentpatch", "intentupdate", "intentmodification"], "tmf921_lifecycle_suspend": ["intentpatch", "intentsuspension"], "tmf921_lifecycle_resume": ["intentpatch", "intentresumption"], "tmf921_lifecycle_terminate": ["intentpatch", "intenttermination"], "tmf921_lifecycle_scale": ["intentpatch", "intentscaling"], "tmf921_lifecycle_monitor": ["intentassurancereport", "intentmonitor", "intentfulfillmentreport", "monitoringreport", "fulfillmentinfo", "report"], "tmf921_lifecycle_report": ["intentassurancereport", "intentreport", "fulfillmentinfo", "report"], } def check_structure(parsed: dict, target_layer: str) -> bool: """Check if the JSON has the expected root keys for the target standard.""" if target_layer.startswith("adversarial"): return parsed.get("status") in ADVERSARIAL_STATUSES if target_layer in LIFECYCLE_LAYERS: flat_keys = {k.lower() for k in parsed.keys()} expected = LIFECYCLE_KEYS.get(target_layer, []) return any(k in flat_keys for k in expected) expected = LAYER_ROOT_KEYS.get(target_layer, []) if not expected: return True flat_keys = {k.lower() for k in parsed.keys()} return any(k in flat_keys for k in expected) # ── Ground-truth baseline ──────────────────────────────────────────── def compute_gt_baseline(ds): """Run the KPI checker against ground truth to establish metric ceiling.""" gt_results = defaultdict(lambda: defaultdict(list)) for row in ds: layer = row["target_layer"] if layer.startswith("adversarial") or layer in LIFECYCLE_LAYERS: continue gt_text = row["messages"][-1]["content"] parsed, valid = try_parse_json(gt_text) if not parsed: continue kpi = check_kpi_fields(parsed, row, layer) for k, v in kpi.items(): gt_results[layer][k].append(v) log("\n Ground-truth baseline (metric ceiling — should be 100% for all):") log(f" {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}") log(" " + "─" * 55) for layer in sorted(gt_results.keys()): metrics = gt_results[layer] def rate(key): vals = metrics.get(key, []) return sum(vals) / len(vals) * 100 if vals else 0 log(f" {layer:<20} {rate('has_latency'):>7.1f}% {rate('has_reliability'):>7.1f}% " f"{rate('has_dl_throughput'):>7.1f}% {rate('has_ul_throughput'):>7.1f}% {rate('has_max_ues'):>7.1f}%") return gt_results # ── Main evaluation ────────────────────────────────────────────────── def main(): args = parse_args() log("=" * 70) log("TMF921 Intent Translation — Evaluation v2") log("=" * 70) log(f"Base model : {args.base_model}") log(f"Adapter : {args.adapter_path}") log(f"Dataset : {args.dataset} [{args.split}]") log(f"Num samples : {args.num_samples}") log(f"KPI checking : standard-aware (v2)") log("=" * 70) # Load dataset log("\nLoading dataset …") ds = load_dataset(args.dataset, split=args.split) # Compute ground-truth baseline on full test set log("\nComputing ground-truth metric baseline …") gt_baseline = compute_gt_baseline(ds) if args.num_samples > 0: ds = ds.select(range(min(args.num_samples, len(ds)))) log(f"\n Evaluating on {len(ds)} samples") # Load model log("\nLoading model …") bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) model_kwargs = { "quantization_config": bnb_config, "device_map": "auto", "trust_remote_code": True, } if args.flash_attn: model_kwargs["attn_implementation"] = "flash_attention_2" base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs) model = PeftModel.from_pretrained(base_model, args.adapter_path) model.eval() tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token log("✅ Model loaded successfully") log(f"\nStarting inference on {len(ds)} samples …") log(f" (First sample may take 1-2 min for CUDA warmup)\n") # Evaluate results = [] per_layer = defaultdict(lambda: defaultdict(list)) t_start = time.time() for i, row in enumerate(ds): t0 = time.time() messages = row["messages"] target_layer = row["target_layer"] reference_output = messages[-1]["content"] prompt_messages = [m for m in messages if m["role"] != "assistant"] input_text = tokenizer.apply_chat_template( prompt_messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer(input_text, return_tensors="pt").to(model.device) with torch.no_grad(): output_ids = model.generate( **inputs, max_new_tokens=args.max_new_tokens, do_sample=False, temperature=None, top_p=None, ) generated_ids = output_ids[0][inputs["input_ids"].shape[1]:] generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) parsed, is_valid_json = try_parse_json(generated_text) has_correct_structure = check_structure(parsed, target_layer) if parsed else False kpi_results = {} if parsed and not target_layer.startswith("adversarial") and target_layer not in LIFECYCLE_LAYERS: kpi_results = check_kpi_fields(parsed, row, target_layer) result = { "id": row["id"], "target_layer": target_layer, "slice_type": row["slice_type"], "lifecycle_operation": row["lifecycle_operation"], "json_valid": is_valid_json, "structure_correct": has_correct_structure, **kpi_results, "generated_length": len(generated_text), "reference_length": len(reference_output), } if args.save_generations: result["generated_text"] = generated_text result["reference_text"] = reference_output results.append(result) layer_key = target_layer per_layer[layer_key]["json_valid"].append(is_valid_json) per_layer[layer_key]["structure_correct"].append(has_correct_structure) for k, v in kpi_results.items(): per_layer[layer_key][k].append(v) # Progress logging — every sample with ETA elapsed = time.time() - t_start sample_time = time.time() - t0 avg_time = elapsed / (i + 1) remaining = avg_time * (len(ds) - i - 1) eta_h, eta_m = divmod(int(remaining), 3600) eta_m = eta_m // 60 json_ok = "✓" if is_valid_json else "✗" struct_ok = "✓" if has_correct_structure else "✗" log(f" [{i+1:>4}/{len(ds)}] {target_layer:<25} JSON:{json_ok} Struct:{struct_ok} " f"| {sample_time:.1f}s | ETA: {eta_h}h{eta_m:02d}m") # ── Aggregate metrics ──────────────────────────────────────────── total_time = time.time() - t_start log(f"\n Total inference time: {total_time/3600:.1f}h ({total_time/len(ds):.1f}s/sample)") log("\n" + "=" * 70) log("RESULTS (v2 — standard-aware KPI matching)") log("=" * 70) total_valid = sum(1 for r in results if r["json_valid"]) total_struct = sum(1 for r in results if r["structure_correct"]) n = len(results) overall = { "total_samples": n, "json_validity_rate": total_valid / n, "structure_correctness_rate": total_struct / n, } kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"] kpi_samples = [r for r in results if any(k in r for k in kpi_fields)] if kpi_samples: for field in kpi_fields: vals = [r.get(field, False) for r in kpi_samples] overall[field + "_rate"] = sum(vals) / len(vals) if vals else 0.0 all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples] overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi) adv_results = [r for r in results if r["target_layer"].startswith("adversarial")] if adv_results: adv_correct = sum(1 for r in adv_results if r["json_valid"] and r["structure_correct"]) overall["adversarial_accuracy"] = adv_correct / len(adv_results) overall["adversarial_samples"] = len(adv_results) layer_summary = {} for layer, metrics in sorted(per_layer.items()): layer_n = len(metrics["json_valid"]) layer_summary[layer] = { "n": layer_n, "json_valid": sum(metrics["json_valid"]) / layer_n, "structure_correct": sum(metrics["structure_correct"]) / layer_n, } for k in kpi_fields: if k in metrics and metrics[k]: layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k]) log(f"\n{'Metric':<35} {'Value':>10}") log("─" * 47) for k, v in overall.items(): if isinstance(v, float): log(f" {k:<33} {v:>9.1%}") else: log(f" {k:<33} {v:>9}") log(f"\n{'Layer':<25} {'N':>4} {'JSON':>6} {'Struct':>7} {'Lat':>6} {'Rel':>6} {'DL':>6} {'UL':>6} {'UEs':>6} {'All':>6}") log("─" * 85) for layer, m in layer_summary.items(): def fmt(key): return f"{m[key]*100:.0f}%" if key in m else "—" line = (f" {layer:<23} {m['n']:>4} {m['json_valid']*100:>5.0f}% {m['structure_correct']*100:>6.0f}% " f"{fmt('has_latency'):>6} {fmt('has_reliability'):>6} {fmt('has_dl_throughput'):>6} " f"{fmt('has_ul_throughput'):>6} {fmt('has_max_ues'):>6} ") layer_results = [r for r in results if r["target_layer"] == layer] layer_kpi = [r for r in layer_results if any(k in r for k in kpi_fields)] if layer_kpi: all_correct = sum(1 for r in layer_kpi if all(r.get(f, False) for f in kpi_fields)) line += f"{all_correct/len(layer_kpi)*100:>4.0f}%" else: line += f"{'—':>5}" log(line) output = { "config": vars(args), "overall": overall, "per_layer": layer_summary, "raw_results": results, } with open(args.output_file, "w") as f: json.dump(output, f, indent=2, default=str) log(f"\n✅ Results saved to {args.output_file}") if __name__ == "__main__": main()