"""
TMF921 Intent Translation – Evaluation Script
==============================================
Evaluates a fine-tuned QLoRA model on the test split with five metrics:
1. JSON Validity         – does the output parse as valid JSON?
2. KPI Field Extraction  – are latency/throughput/reliability/UE values present & correct?
3. Cross-Standard Output – does the structure match the target_layer?
4. Adversarial Accuracy  – are bad intents correctly rejected?
5. Lifecycle Accuracy    – is the lifecycle operation format correct?

Usage:
    python evaluate.py --adapter_path ./output --num_samples 200
    python evaluate.py --adapter_path nraptisss/Qwen3-8B-TMF921-Intent-QLora --num_samples -1
"""

import argparse
import json
import re
from collections import defaultdict

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import PeftModel
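
# Each dataset row is expected to provide the fields read below. An
# illustrative sketch (values made up, not a real record):
#
#   {
#     "id": "sample-0001",
#     "target_layer": "tmf921",   # or intent_3gpp / camara / etsi_zsm /
#                                 # a1_policy / o1_nrm / adversarial_* /
#                                 # tmf921_lifecycle_*
#     "slice_type": "urllc",
#     "lifecycle_operation": "create",
#     "latency_ms": 5,
#     "reliability_pct": 99.999,
#     "dl_throughput_mbps": 100,
#     "ul_throughput_mbps": 50,
#     "max_ues": 1000,
#     "messages": [
#       {"role": "user", "content": "..."},
#       {"role": "assistant", "content": "{ ...reference JSON... }"},
#     ],
#   }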


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
    p.add_argument("--adapter_path", type=str, default="./output",
                   help="Path or HF id of LoRA adapter")
    p.add_argument("--dataset", type=str,
                   default="nraptisss/TMF921-intent-to-config-augmented")
    p.add_argument("--split", type=str, default="test")
    p.add_argument("--num_samples", type=int, default=200,
                   help="Number of samples to evaluate (-1 for all)")
    p.add_argument("--max_new_tokens", type=int, default=4096)
    p.add_argument("--output_file", type=str, default="eval_results.json")
    # BooleanOptionalAction (Python 3.9+) adds a --no-flash_attn switch;
    # store_true with default=True could never be turned off.
    p.add_argument("--flash_attn", action=argparse.BooleanOptionalAction, default=True)
    return p.parse_args()
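
# Example (hypothetical invocation): disable FlashAttention on GPUs that
# don't support it:
#   python evaluate.py --adapter_path ./output --no-flash_attn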


def try_parse_json(text: str) -> tuple[dict | None, bool]:
    """Try to parse JSON from model output, handling markdown fences.

    Returns (parsed, is_valid_json). Output that parses as valid JSON but is
    not an object (e.g. a bare list) yields (None, True), since the structure
    and KPI checks below can only inspect dicts.
    """
    text = text.strip()

    # Strip a surrounding ``` / ```json markdown fence, if present.
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*\n?", "", text)
        text = re.sub(r"\n?```\s*$", "", text)

    try:
        obj = json.loads(text)
        return (obj if isinstance(obj, dict) else None), True
    except json.JSONDecodeError:
        pass

    # Fall back to the outermost brace-delimited span (greedy match).
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group()), True
        except json.JSONDecodeError:
            pass
    return None, False
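
# Sanity examples (illustrative):
#   try_parse_json('```json\n{"a": 1}\n```')  ->  ({"a": 1}, True)
#   try_parse_json('noise {"a": 1} noise')    ->  ({"a": 1}, True)
#   try_parse_json('not json at all')         ->  (None, False)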


def check_kpi_fields(parsed: dict, row: dict) -> dict:
    """Check whether the generated config contains the target KPI values.

    Uses lowercase substring matching on the serialized JSON – a deliberately
    loose heuristic that tolerates formatting differences (e.g. 10 vs 10.0)
    at the cost of occasional false positives.
    """
    flat = json.dumps(parsed).lower()
    results = {}

    # Latency (ms): accept the integer or the original representation.
    target_latency = row["latency_ms"]
    results["has_latency"] = str(int(target_latency)) in flat or str(target_latency) in flat

    # Reliability (%): matched verbatim, so e.g. 99.999 must appear exactly.
    target_rel = row["reliability_pct"]
    results["has_reliability"] = str(target_rel) in flat

    # Downlink throughput (Mbps).
    target_dl = row["dl_throughput_mbps"]
    results["has_dl_throughput"] = str(int(target_dl)) in flat or str(target_dl) in flat

    # Uplink throughput (Mbps).
    target_ul = row["ul_throughput_mbps"]
    results["has_ul_throughput"] = str(int(target_ul)) in flat or str(target_ul) in flat

    # Maximum number of UEs.
    target_ues = row["max_ues"]
    results["has_max_ues"] = str(target_ues) in flat

    return results
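
# Note on the heuristic (illustrative): with latency_ms=5, both
# {"latency": 5} and {"latencyMs": "5ms"} pass, but so would an unrelated
# "15" elsewhere in the output – treat these rates as upper bounds.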


# Expected root keys (lowercased) per target standard.
LAYER_ROOT_KEYS = {
    "tmf921": ["id", "href", "name", "intentexpression"],
    "intent_3gpp": ["intent"],
    "camara": ["networkslicebooking"],
    "etsi_zsm": ["zsmintent"],
    "a1_policy": ["a1policy"],
    "o1_nrm": ["managedelement"],
}

# Statuses a well-behaved model should return for adversarial intents.
ADVERSARIAL_STATUSES = {"CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"}

LIFECYCLE_LAYERS = {
    "tmf921_lifecycle_activate", "tmf921_lifecycle_modify",
    "tmf921_lifecycle_suspend", "tmf921_lifecycle_resume",
    "tmf921_lifecycle_terminate", "tmf921_lifecycle_scale",
    "tmf921_lifecycle_monitor", "tmf921_lifecycle_report",
}


def check_structure(parsed: dict, target_layer: str) -> bool:
    """Check whether the JSON has the expected root keys for the target standard."""
    if target_layer.startswith("adversarial"):
        return parsed.get("status") in ADVERSARIAL_STATUSES
    if target_layer in LIFECYCLE_LAYERS:
        flat_keys = {k.lower() for k in parsed.keys()}
        return bool(flat_keys & {"intentpatch", "intentassurancereport", "intentupdate"})
    expected = LAYER_ROOT_KEYS.get(target_layer, [])
    if not expected:
        return True
    flat_keys = {k.lower() for k in parsed.keys()}
    return any(k in flat_keys for k in expected)
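
# Example (illustrative): for target_layer="camara", a generation whose root
# object contains a "NetworkSliceBooking" key (any casing) passes; layers
# with no entry in LAYER_ROOT_KEYS pass unconditionally.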


def main():
    args = parse_args()

    print("=" * 70)
    print("TMF921 Intent Translation – Evaluation")
    print("=" * 70)
    print(f"Base model  : {args.base_model}")
    print(f"Adapter     : {args.adapter_path}")
    print(f"Dataset     : {args.dataset} [{args.split}]")
    print(f"Num samples : {args.num_samples}")
    print("=" * 70)

    print("\nLoading dataset …")
    ds = load_dataset(args.dataset, split=args.split)
    if args.num_samples > 0:
        ds = ds.select(range(min(args.num_samples, len(ds))))
    print(f"  Evaluating on {len(ds)} samples")

    print("\nLoading model …")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    model_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    if args.flash_attn:
        model_kwargs["attn_implementation"] = "flash_attention_2"

    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model, **model_kwargs
    )
    model = PeftModel.from_pretrained(base_model, args.adapter_path)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("\nRunning inference …")
    results = []
    per_layer = defaultdict(lambda: defaultdict(list))

    for i, row in enumerate(ds):
        if (i + 1) % 20 == 0 or i == 0:
            print(f"  [{i+1}/{len(ds)}] …")

        messages = row["messages"]
        target_layer = row["target_layer"]
        reference_output = messages[-1]["content"]

        # Build the prompt without the reference assistant turn.
        prompt_messages = [m for m in messages if m["role"] != "assistant"]
        input_text = tokenizer.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,   # greedy decoding for reproducibility
                temperature=None,  # explicitly unset sampling params to
                top_p=None,        # silence generation-config warnings
            )

        # Decode only the newly generated tokens.
        generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Score the generation.
        parsed, is_valid_json = try_parse_json(generated_text)
        has_correct_structure = check_structure(parsed, target_layer) if parsed else False

        kpi_results = {}
        if parsed and not target_layer.startswith("adversarial") and target_layer not in LIFECYCLE_LAYERS:
            kpi_results = check_kpi_fields(parsed, row)

        result = {
            "id": row["id"],
            "target_layer": target_layer,
            "slice_type": row["slice_type"],
            "lifecycle_operation": row["lifecycle_operation"],
            "json_valid": is_valid_json,
            "structure_correct": has_correct_structure,
            **kpi_results,
            "generated_length": len(generated_text),
            "reference_length": len(reference_output),
        }
        results.append(result)

        # Aggregate per target layer.
        per_layer[target_layer]["json_valid"].append(is_valid_json)
        per_layer[target_layer]["structure_correct"].append(has_correct_structure)
        for k, v in kpi_results.items():
            per_layer[target_layer][k].append(v)

    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)

    total_valid = sum(1 for r in results if r["json_valid"])
    total_struct = sum(1 for r in results if r["structure_correct"])
    n = len(results)

    overall = {
        "total_samples": n,
        "json_validity_rate": total_valid / n,
        "structure_correctness_rate": total_struct / n,
    }

    # KPI extraction rates over samples that were actually KPI-checked.
    kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"]
    kpi_samples = [r for r in results if any(k in r for k in kpi_fields)]
    if kpi_samples:
        for field in kpi_fields:
            vals = [r.get(field, False) for r in kpi_samples]
            overall[field + "_rate"] = sum(vals) / len(vals)
        all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
        overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)

    # Adversarial accuracy: valid JSON carrying one of the rejection statuses.
    adv_results = [r for r in results if r["target_layer"].startswith("adversarial")]
    if adv_results:
        adv_correct = sum(1 for r in adv_results if r["json_valid"] and r["structure_correct"])
        overall["adversarial_accuracy"] = adv_correct / len(adv_results)
        overall["adversarial_samples"] = len(adv_results)

    # Per-layer breakdown.
    layer_summary = {}
    for layer, metrics in sorted(per_layer.items()):
        layer_n = len(metrics["json_valid"])
        layer_summary[layer] = {
            "n": layer_n,
            "json_valid": sum(metrics["json_valid"]) / layer_n,
            "structure_correct": sum(metrics["structure_correct"]) / layer_n,
        }
        for k in kpi_fields:
            if k in metrics and metrics[k]:
                layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])

    print(f"\n{'Metric':<35} {'Value':>10}")
    print("-" * 47)
    for k, v in overall.items():
        if isinstance(v, float):
            print(f" {k:<33} {v:>9.1%}")
        else:
            print(f" {k:<33} {v:>9}")

    # The per-layer KPI column shows the latency-extraction rate.
    print(f"\n{'Layer':<35} {'N':>5} {'JSON%':>7} {'Struct%':>8} {'Lat%':>8}")
    print("-" * 65)
    for layer, m in layer_summary.items():
        kpi_str = f"{m.get('has_latency', 0):.0%}" if "has_latency" in m else "-"
        print(f" {layer:<33} {m['n']:>5} {m['json_valid']:>6.1%} "
              f"{m['structure_correct']:>7.1%} {kpi_str:>8}")

    # Save full results to disk.
    output = {
        "config": vars(args),
        "overall": overall,
        "per_layer": layer_summary,
        "raw_results": results,
    }
    with open(args.output_file, "w") as f:
        json.dump(output, f, indent=2, default=str)
    print(f"\n✅ Results saved to {args.output_file}")
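
# To load the saved metrics later (illustrative):
#   import json
#   with open("eval_results.json") as f:
#       print(json.load(f)["overall"])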


if __name__ == "__main__":
    main()