"""
TMF921 Intent Translation – Evaluation Script v2
=================================================
Standard-aware KPI checking that correctly handles how each telecom standard
encodes network parameters:

TMF921, 3GPP TS 28.312, CAMARA, ETSI ZSM:
  • KPI values embedded directly (with int/float tolerance: 99 vs 99.0)

O-RAN A1 Policy:
  • reliability → packet error rate (PER): 99.999% → 1e-05
  • latency → packet delay budget (pdb): mapped via 5QI table
  • throughput → gfbr/mfbr (guaranteed/maximum flow bitrate)

O-RAN O1 NRM (3GPP TS 28.541):
  • KPIs translated to radio resource management configs (RRM policies,
    cell parameters, frequency allocations). No direct numeric values.
  • Evaluated via structural element presence.

Changes from v1:
- Fixes metric bug where 92% of "reliability failures" were false negatives
- Adds ground-truth baseline (metric ceiling) printed before evaluation
- Standard-specific KPI checking (3 strategies)
- Expanded lifecycle operation key matching
- Saves generated text for error analysis
- Flushes stdout on every print (fixes nohup buffering)

Usage:
    python evaluate_v2.py --adapter_path ./output --num_samples 200
    python evaluate_v2.py --adapter_path ./output --num_samples -1
"""

import argparse, json, re, math, time, torch
from collections import defaultdict
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import PeftModel


def log(msg: str):
    """Print with flush so nohup logs update in real time."""
    print(msg, flush=True)


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
    p.add_argument("--adapter_path", type=str, default="./output",
                   help="Path or HF id of LoRA adapter")
    p.add_argument("--dataset", type=str,
                   default="nraptisss/TMF921-intent-to-config-augmented")
    p.add_argument("--split", type=str, default="test")
    p.add_argument("--num_samples", type=int, default=200,
                   help="Number of samples to evaluate (-1 for all)")
    p.add_argument("--max_new_tokens", type=int, default=4096)
    p.add_argument("--output_file", type=str, default="eval_results_v2.json")
    # BooleanOptionalAction (Python 3.9+) so the True defaults can be
    # disabled on the command line via --no-flash_attn / --no-save_generations.
    p.add_argument("--flash_attn", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--save_generations", action=argparse.BooleanOptionalAction, default=True,
                   help="Save generated text in results for error analysis")
    return p.parse_args()


def try_parse_json(text: str) -> tuple[dict | None, bool]:
    """Try to parse JSON from model output, handling markdown fences."""
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*\n?", "", text)
        text = re.sub(r"\n?```\s*$", "", text)
    try:
        return json.loads(text), True
    except json.JSONDecodeError:
        pass
    # Fallback: try the outermost brace-delimited span in the raw text.
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group()), True
        except json.JSONDecodeError:
            pass
    return None, False


def _num_representations(val: float) -> list[str]:
    """Generate multiple string representations of a numeric value."""
    reps = [str(val)]
    if val == int(val):
        reps.append(str(int(val)))
        reps.append(f"{val:.1f}")
        reps.append(f"{val:.0f}")
    return list(set(reps))
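
# Illustrative example (not exercised by the script itself): an integer-valued
# KPI gets both JSON spellings, e.g.
#   _num_representations(99.0) → {"99.0", "99"}
# while a fractional value such as 12.5 keeps the single form "12.5".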


def _reliability_representations(rel_pct: float) -> list[str]:
    """Generate all plausible encodings of a reliability percentage."""
    reps = _num_representations(rel_pct)
    per = 1 - rel_pct / 100  # packet error rate implied by the reliability %
    if per > 0:
        exp = math.floor(math.log10(per))
        mantissa = per / (10 ** exp)
        reps.append(f"1e-{abs(exp):02d}")   # zero-padded exponent: "1e-05"
        reps.append(f"1e-{abs(exp)}")       # bare exponent: "1e-5"
        reps.append(f"{per:.0e}")
        reps.append(f"{per}")
        if mantissa != 1.0:
            reps.append(f"{mantissa:.1f}e-{abs(exp):02d}")
        if per < 1:
            # Positional form with trailing zeros stripped, e.g. "0.00001".
            reps.append(f"{per:.10f}".rstrip("0").rstrip("."))
    return list(set(reps))
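
# Worked example (assuming standard IEEE-754 rounding): rel_pct=99.999 gives
# per ≈ 1e-05, so the set includes "99.999", "1e-05", "1e-5", "0.00001", and
# the repr of the exact float — any of these counts as a reliability match.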


def _find_all_numbers(parsed: dict) -> list[float]:
    """Extract all numeric values from a nested JSON structure."""
    nums = []
    if isinstance(parsed, dict):
        for v in parsed.values():
            if isinstance(v, (int, float)) and not isinstance(v, bool):
                nums.append(float(v))
            elif isinstance(v, (dict, list)):
                nums.extend(_find_all_numbers(v))
    elif isinstance(parsed, list):
        for item in parsed:
            if isinstance(item, (int, float)) and not isinstance(item, bool):
                nums.append(float(item))
            elif isinstance(item, (dict, list)):
                nums.extend(_find_all_numbers(item))
    return nums
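
# Illustrative: _find_all_numbers({"qos": {"pdb": 20, "per": 1e-05}, "ids": [1, 2]})
# → [20.0, 1e-05, 1.0, 2.0]. Booleans are deliberately excluded because
# isinstance(True, int) is True in Python.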


def _check_kpi_direct(parsed: dict, row: dict, flat: str) -> dict:
    """Direct KPI matching for TMF921, intent_3gpp, CAMARA, ETSI ZSM."""
    results = {}
    results["has_latency"] = any(rep in flat for rep in _num_representations(row["latency_ms"]))
    results["has_reliability"] = any(rep in flat for rep in _reliability_representations(row["reliability_pct"]))
    results["has_dl_throughput"] = any(rep in flat for rep in _num_representations(row["dl_throughput_mbps"]))
    results["has_ul_throughput"] = any(rep in flat for rep in _num_representations(row["ul_throughput_mbps"]))
    results["has_max_ues"] = any(rep in flat for rep in _num_representations(float(row["max_ues"])))
    return results


def _check_kpi_a1_policy(parsed: dict, row: dict, flat: str) -> dict:
    """A1 Policy: reliability → PER, latency → pdb, throughput → gfbr/mfbr."""
    results = {}
    all_nums = _find_all_numbers(parsed)

    target_rel = row["reliability_pct"]
    rel_found = any(rep in flat for rep in _reliability_representations(target_rel))
    if not rel_found:
        # String matching missed: fall back to numeric comparison, accepting
        # either the raw percentage or a PER within 10% relative tolerance.
        per = 1 - target_rel / 100
        for n in all_nums:
            if abs(n - target_rel) < 0.01:
                rel_found = True
                break
            if per > 0 and n > 0 and abs(n - per) / max(per, 1e-15) < 0.1:
                rel_found = True
                break
    results["has_reliability"] = rel_found
    results["has_latency"] = '"pdb"' in flat or '"packetdelaybudget"' in flat
    has_tput = '"gfbr"' in flat or '"mfbr"' in flat or '"guaranteedflowbitrate"' in flat
    results["has_dl_throughput"] = has_tput
    results["has_ul_throughput"] = has_tput
    results["has_max_ues"] = '"scope"' in flat or '"groupid"' in flat
    return results
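
# Illustrative tolerance check: a 99.9% target implies PER 0.001, so a
# generated 0.00095 passes (|0.00095 - 0.001| / 0.001 = 0.05 < 0.1) while
# 0.002 fails. Latency and throughput are judged by key presence only, since
# pdb and gfbr/mfbr values depend on the 5QI mapping and on units.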


def _check_kpi_o1_nrm(parsed: dict, row: dict, flat: str) -> dict:
    """O1 NRM: structural element presence (KPIs → RRM policies, not direct values)."""
    results = {}
    results["has_latency"] = '"rrmpolicy"' in flat or '"nrcelldu"' in flat
    results["has_reliability"] = '"operationalstate"' in flat or '"administrativestate"' in flat
    results["has_dl_throughput"] = '"bschannelbwdl"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcndl"' in flat
    results["has_ul_throughput"] = (
        '"bschannelbwul"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcnul"' in flat
        or '"rrmpolicydedicatedratio"' in flat
    )
    results["has_max_ues"] = '"rrmpolicymemberlist"' in flat or '"snssai"' in flat
    return results


DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"}


def check_kpi_fields(parsed: dict, row: dict, target_layer: str) -> dict:
    """Standard-aware KPI checking: direct / A1 Policy / O1 NRM strategies."""
    flat = json.dumps(parsed).lower()
    if target_layer in DIRECT_KPI_LAYERS:
        return _check_kpi_direct(parsed, row, flat)
    elif target_layer == "a1_policy":
        return _check_kpi_a1_policy(parsed, row, flat)
    elif target_layer == "o1_nrm":
        return _check_kpi_o1_nrm(parsed, row, flat)
    else:
        # Unknown layers fall back to direct value matching.
        return _check_kpi_direct(parsed, row, flat)


LAYER_ROOT_KEYS = {
    "tmf921": ["id", "href", "name", "intentexpression"],
    "intent_3gpp": ["intent"],
    "camara": ["networkslicebooking"],
    "etsi_zsm": ["zsmintent"],
    "a1_policy": ["a1policy"],
    "o1_nrm": ["managedelement"],
}

ADVERSARIAL_STATUSES = {"CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"}

LIFECYCLE_LAYERS = {
    "tmf921_lifecycle_activate", "tmf921_lifecycle_modify",
    "tmf921_lifecycle_suspend", "tmf921_lifecycle_resume",
    "tmf921_lifecycle_terminate", "tmf921_lifecycle_scale",
    "tmf921_lifecycle_monitor", "tmf921_lifecycle_report",
}

LIFECYCLE_KEYS = {
    "tmf921_lifecycle_activate": ["intentpatch", "intentactivation"],
    "tmf921_lifecycle_modify": ["intentpatch", "intentupdate", "intentmodification"],
    "tmf921_lifecycle_suspend": ["intentpatch", "intentsuspension"],
    "tmf921_lifecycle_resume": ["intentpatch", "intentresumption"],
    "tmf921_lifecycle_terminate": ["intentpatch", "intenttermination"],
    "tmf921_lifecycle_scale": ["intentpatch", "intentscaling"],
    "tmf921_lifecycle_monitor": ["intentassurancereport", "intentmonitor", "intentfulfillmentreport",
                                 "monitoringreport", "fulfillmentinfo", "report"],
    "tmf921_lifecycle_report": ["intentassurancereport", "intentreport", "fulfillmentinfo", "report"],
}
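
# Note: root-key matching is case-insensitive (check_structure lowercases the
# generated keys), so e.g. a payload of {"intentPatch": {...}} satisfies the
# "intentpatch" entry for any lifecycle operation.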


def check_structure(parsed: dict, target_layer: str) -> bool:
    """Check if the JSON has the expected root keys for the target standard."""
    if target_layer.startswith("adversarial"):
        return parsed.get("status") in ADVERSARIAL_STATUSES
    if target_layer in LIFECYCLE_LAYERS:
        flat_keys = {k.lower() for k in parsed.keys()}
        expected = LIFECYCLE_KEYS.get(target_layer, [])
        return any(k in flat_keys for k in expected)
    expected = LAYER_ROOT_KEYS.get(target_layer, [])
    if not expected:
        return True
    flat_keys = {k.lower() for k in parsed.keys()}
    return any(k in flat_keys for k in expected)


def compute_gt_baseline(ds):
    """Run the KPI checker against ground truth to establish the metric ceiling."""
    gt_results = defaultdict(lambda: defaultdict(list))
    for row in ds:
        layer = row["target_layer"]
        if layer.startswith("adversarial") or layer in LIFECYCLE_LAYERS:
            continue
        gt_text = row["messages"][-1]["content"]
        parsed, _ = try_parse_json(gt_text)
        if not parsed:
            continue
        kpi = check_kpi_fields(parsed, row, layer)
        for k, v in kpi.items():
            gt_results[layer][k].append(v)

    log("\nGround-truth baseline (metric ceiling; should be 100% for all):")
    log(f"  {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}")
    log("  " + "─" * 55)
    for layer in sorted(gt_results.keys()):
        metrics = gt_results[layer]

        def rate(key):
            vals = metrics.get(key, [])
            return sum(vals) / len(vals) * 100 if vals else 0

        log(f"  {layer:<20} {rate('has_latency'):>7.1f}% {rate('has_reliability'):>7.1f}% "
            f"{rate('has_dl_throughput'):>7.1f}% {rate('has_ul_throughput'):>7.1f}% {rate('has_max_ues'):>7.1f}%")
    return gt_results
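
# Reading the baseline table: a cell below 100% means the checker itself
# cannot recover that KPI from the ground-truth output, so model scores for
# that layer/KPI are capped at the printed ceiling regardless of quality.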


def main():
    args = parse_args()

    log("=" * 70)
    log("TMF921 Intent Translation – Evaluation v2")
    log("=" * 70)
    log(f"Base model   : {args.base_model}")
    log(f"Adapter      : {args.adapter_path}")
    log(f"Dataset      : {args.dataset} [{args.split}]")
    log(f"Num samples  : {args.num_samples}")
    log("KPI checking : standard-aware (v2)")
    log("=" * 70)

    log("\nLoading dataset …")
    ds = load_dataset(args.dataset, split=args.split)

    log("\nComputing ground-truth metric baseline …")
    gt_baseline = compute_gt_baseline(ds)

    if args.num_samples > 0:
        ds = ds.select(range(min(args.num_samples, len(ds))))
    log(f"\nEvaluating on {len(ds)} samples")

    log("\nLoading model …")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    if args.flash_attn:
        model_kwargs["attn_implementation"] = "flash_attention_2"

    base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
    model = PeftModel.from_pretrained(base_model, args.adapter_path)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    log("✅ Model loaded successfully")
    log(f"\nStarting inference on {len(ds)} samples …")
    log("  (First sample may take 1-2 min for CUDA warmup)\n")

    results = []
    per_layer = defaultdict(lambda: defaultdict(list))
    t_start = time.time()

    for i, row in enumerate(ds):
        t0 = time.time()

        messages = row["messages"]
        target_layer = row["target_layer"]
        reference_output = messages[-1]["content"]

        # Drop the assistant turn: the model must regenerate it from the prompt.
        prompt_messages = [m for m in messages if m["role"] != "assistant"]
        input_text = tokenizer.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                temperature=None,  # explicitly unset so greedy decoding does
                top_p=None,        # not trigger sampling-config warnings
            )

        generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        parsed, is_valid_json = try_parse_json(generated_text)
        has_correct_structure = check_structure(parsed, target_layer) if parsed else False

        kpi_results = {}
        if parsed and not target_layer.startswith("adversarial") and target_layer not in LIFECYCLE_LAYERS:
            kpi_results = check_kpi_fields(parsed, row, target_layer)

        result = {
            "id": row["id"],
            "target_layer": target_layer,
            "slice_type": row["slice_type"],
            "lifecycle_operation": row["lifecycle_operation"],
            "json_valid": is_valid_json,
            "structure_correct": has_correct_structure,
            **kpi_results,
            "generated_length": len(generated_text),
            "reference_length": len(reference_output),
        }
        if args.save_generations:
            result["generated_text"] = generated_text
            result["reference_text"] = reference_output

        results.append(result)

        per_layer[target_layer]["json_valid"].append(is_valid_json)
        per_layer[target_layer]["structure_correct"].append(has_correct_structure)
        for k, v in kpi_results.items():
            per_layer[target_layer][k].append(v)

        elapsed = time.time() - t_start
        sample_time = time.time() - t0
        avg_time = elapsed / (i + 1)
        remaining = avg_time * (len(ds) - i - 1)
        eta_h, eta_rem = divmod(int(remaining), 3600)
        eta_m = eta_rem // 60

        json_ok = "✓" if is_valid_json else "✗"
        struct_ok = "✓" if has_correct_structure else "✗"
        log(f"  [{i+1:>4}/{len(ds)}] {target_layer:<25} JSON:{json_ok} Struct:{struct_ok} "
            f"| {sample_time:.1f}s | ETA: {eta_h}h{eta_m:02d}m")

    total_time = time.time() - t_start
    log(f"\nTotal inference time: {total_time/3600:.1f}h ({total_time/len(ds):.1f}s/sample)")

    log("\n" + "=" * 70)
    log("RESULTS (v2 – standard-aware KPI matching)")
    log("=" * 70)

    total_valid = sum(1 for r in results if r["json_valid"])
    total_struct = sum(1 for r in results if r["structure_correct"])
    n = len(results)

    overall = {
        "total_samples": n,
        "json_validity_rate": total_valid / n,
        "structure_correctness_rate": total_struct / n,
    }

    kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"]
    kpi_samples = [r for r in results if any(k in r for k in kpi_fields)]
    if kpi_samples:
        for field in kpi_fields:
            vals = [r.get(field, False) for r in kpi_samples]
            overall[field + "_rate"] = sum(vals) / len(vals)
        all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
        overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)

    # Adversarial prompts count as correct only when the model returns valid
    # JSON and uses one of the expected refusal statuses.
    adv_results = [r for r in results if r["target_layer"].startswith("adversarial")]
    if adv_results:
        adv_correct = sum(1 for r in adv_results if r["json_valid"] and r["structure_correct"])
        overall["adversarial_accuracy"] = adv_correct / len(adv_results)
        overall["adversarial_samples"] = len(adv_results)

    layer_summary = {}
    for layer, metrics in sorted(per_layer.items()):
        layer_n = len(metrics["json_valid"])
        layer_summary[layer] = {
            "n": layer_n,
            "json_valid": sum(metrics["json_valid"]) / layer_n,
            "structure_correct": sum(metrics["structure_correct"]) / layer_n,
        }
        for k in kpi_fields:
            if metrics.get(k):
                layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])

    log(f"\n{'Metric':<35} {'Value':>10}")
    log("─" * 47)
    for k, v in overall.items():
        if isinstance(v, float):
            log(f"  {k:<33} {v:>9.1%}")
        else:
            log(f"  {k:<33} {v:>9}")

    log(f"\n{'Layer':<25} {'N':>4} {'JSON':>6} {'Struct':>7} {'Lat':>6} {'Rel':>6} {'DL':>6} {'UL':>6} {'UEs':>6} {'All':>6}")
    log("─" * 85)
    for layer, m in layer_summary.items():
        def fmt(key):
            return f"{m[key]*100:.0f}%" if key in m else "-"
        line = (f"  {layer:<23} {m['n']:>4} {m['json_valid']*100:>5.0f}% {m['structure_correct']*100:>6.0f}% "
                f"{fmt('has_latency'):>6} {fmt('has_reliability'):>6} {fmt('has_dl_throughput'):>6} "
                f"{fmt('has_ul_throughput'):>6} {fmt('has_max_ues'):>6} ")
        layer_results = [r for r in results if r["target_layer"] == layer]
        layer_kpi = [r for r in layer_results if any(k in r for k in kpi_fields)]
        if layer_kpi:
            all_correct = sum(1 for r in layer_kpi if all(r.get(f, False) for f in kpi_fields))
            line += f"{all_correct/len(layer_kpi)*100:>4.0f}%"
        else:
            line += f"{'-':>5}"
        log(line)

    output = {
        "config": vars(args),
        "overall": overall,
        "per_layer": layer_summary,
        "raw_results": results,
    }
    with open(args.output_file, "w") as f:
        json.dump(output, f, indent=2, default=str)
    log(f"\n✅ Results saved to {args.output_file}")


if __name__ == "__main__":
    main()