#!/usr/bin/env python3
"""
TMF921 Intent Translation β€” Evaluation Script
==============================================
Evaluates a fine-tuned QLoRA model on the test split with metrics:
1. JSON Schema Validity β€” is the output valid JSON?
2. KPI Field Extraction β€” are latency/throughput/reliability/UEs present & correct?
3. Cross-Standard Output β€” correct structure per target_layer?
4. Adversarial Accuracy — correct rejection of bad intents
5. Lifecycle Accuracy — correct lifecycle operation format
Usage:
    python evaluate.py --adapter_path ./output --num_samples 200
    python evaluate.py --adapter_path nraptisss/Qwen3-8B-TMF921-Intent-QLora --num_samples -1
"""
import argparse, json, re, os, sys, torch
from collections import defaultdict
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import PeftModel

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
    p.add_argument("--adapter_path", type=str, default="./output",
                   help="Path or HF id of LoRA adapter")
    p.add_argument("--dataset", type=str,
                   default="nraptisss/TMF921-intent-to-config-augmented")
    p.add_argument("--split", type=str, default="test")
    p.add_argument("--num_samples", type=int, default=200,
                   help="Number of samples to evaluate (-1 for all)")
    p.add_argument("--max_new_tokens", type=int, default=4096)
    p.add_argument("--output_file", type=str, default="eval_results.json")
p.add_argument("--flash_attn", action="store_true", default=True)
return p.parse_args()

# ── Validation helpers ───────────────────────────────────────────────
def try_parse_json(text: str) -> tuple[dict | None, bool]:
    """Try to parse JSON from model output, handling markdown fences."""
    text = text.strip()
    # Remove markdown code fences
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*\n?", "", text)
        text = re.sub(r"\n?```\s*$", "", text)
    # Try direct parse
    try:
        return json.loads(text), True
    except json.JSONDecodeError:
        pass
    # Try to find JSON object in text
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group()), True
        except json.JSONDecodeError:
            pass
    return None, False
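
# For example, try_parse_json('```json\n{"id": "intent-001"}\n```') strips the fence
# and returns ({"id": "intent-001"}, True); free text containing no JSON object
# returns (None, False).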

def check_kpi_fields(parsed: dict, row: dict) -> dict:
    """Check if the generated config contains correct KPI values."""
    flat = json.dumps(parsed).lower()
    results = {}
    # Check latency
    target_latency = row["latency_ms"]
    results["has_latency"] = str(int(target_latency)) in flat or str(target_latency) in flat
    # Check reliability
    target_rel = row["reliability_pct"]
    results["has_reliability"] = str(target_rel) in flat
    # Check DL throughput
    target_dl = row["dl_throughput_mbps"]
    results["has_dl_throughput"] = str(int(target_dl)) in flat or str(target_dl) in flat
    # Check UL throughput
    target_ul = row["ul_throughput_mbps"]
    results["has_ul_throughput"] = str(int(target_ul)) in flat or str(target_ul) in flat
    # Check max UEs
    target_ues = row["max_ues"]
    results["has_max_ues"] = str(target_ues) in flat
    return results

LAYER_ROOT_KEYS = {
    "tmf921": ["id", "href", "name", "intentexpression"],
    "intent_3gpp": ["intent"],
    "camara": ["networkslicebooking"],
    "etsi_zsm": ["zsmintent"],
    "a1_policy": ["a1policy"],
    "o1_nrm": ["managedelement"],
}

ADVERSARIAL_STATUSES = {"CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"}

LIFECYCLE_LAYERS = {
    "tmf921_lifecycle_activate", "tmf921_lifecycle_modify",
    "tmf921_lifecycle_suspend", "tmf921_lifecycle_resume",
    "tmf921_lifecycle_terminate", "tmf921_lifecycle_scale",
    "tmf921_lifecycle_monitor", "tmf921_lifecycle_report",
}

def check_structure(parsed: dict, target_layer: str) -> bool:
    """Check if the JSON has the expected root keys for the target standard."""
    if target_layer.startswith("adversarial"):
        return parsed.get("status") in ADVERSARIAL_STATUSES
    if target_layer in LIFECYCLE_LAYERS:
        flat_keys = {k.lower() for k in parsed.keys()}
        return "intentpatch" in flat_keys or "intentassurancereport" in flat_keys or "intentupdate" in flat_keys
    expected = LAYER_ROOT_KEYS.get(target_layer, [])
    if not expected:
        return True
    flat_keys = {k.lower() for k in parsed.keys()}
    return any(k in flat_keys for k in expected)
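
# For example, a "tmf921" sample passes the structure check if any expected root key
# (id / href / name / intentExpression, case-insensitive) appears at the top level;
# adversarial samples must instead produce a refusal whose "status" is drawn from
# ADVERSARIAL_STATUSES, and lifecycle samples an intentPatch / intentUpdate /
# intentAssuranceReport object.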

# ── Main evaluation ──────────────────────────────────────────────────
def main():
    args = parse_args()

    print("=" * 70)
    print("TMF921 Intent Translation — Evaluation")
    print("=" * 70)
    print(f"Base model : {args.base_model}")
    print(f"Adapter : {args.adapter_path}")
    print(f"Dataset : {args.dataset} [{args.split}]")
    print(f"Num samples : {args.num_samples}")
    print("=" * 70)

    # Load dataset
    print("\nLoading dataset …")
    ds = load_dataset(args.dataset, split=args.split)
    if args.num_samples > 0:
        ds = ds.select(range(min(args.num_samples, len(ds))))
    print(f" Evaluating on {len(ds)} samples")
    # Load model
    print("\nLoading model …")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    if args.flash_attn:
        model_kwargs["attn_implementation"] = "flash_attention_2"
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model, **model_kwargs
    )
    model = PeftModel.from_pretrained(base_model, args.adapter_path)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Evaluate
    print("\nRunning inference …")
    results = []
    per_layer = defaultdict(lambda: defaultdict(list))

    for i, row in enumerate(ds):
        if (i + 1) % 20 == 0 or i == 0:
            print(f" [{i+1}/{len(ds)}] …")

        messages = row["messages"]
        target_layer = row["target_layer"]
        reference_output = messages[-1]["content"]  # ground truth

        # Build prompt (system + user only)
        prompt_messages = [m for m in messages if m["role"] != "assistant"]
        input_text = tokenizer.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                temperature=None,
                top_p=None,
            )

        # Decode only the new tokens
        generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Parse & validate
        parsed, is_valid_json = try_parse_json(generated_text)
        has_correct_structure = check_structure(parsed, target_layer) if parsed else False
        kpi_results = {}
        if parsed and not target_layer.startswith("adversarial") and target_layer not in LIFECYCLE_LAYERS:
            kpi_results = check_kpi_fields(parsed, row)

        result = {
            "id": row["id"],
            "target_layer": target_layer,
            "slice_type": row["slice_type"],
            "lifecycle_operation": row["lifecycle_operation"],
            "json_valid": is_valid_json,
            "structure_correct": has_correct_structure,
            **kpi_results,
            "generated_length": len(generated_text),
            "reference_length": len(reference_output),
        }
        results.append(result)

        # Accumulate per-layer
        layer_key = target_layer  # adversarial and lifecycle layers keep their own buckets
        per_layer[layer_key]["json_valid"].append(is_valid_json)
        per_layer[layer_key]["structure_correct"].append(has_correct_structure)
        for k, v in kpi_results.items():
            per_layer[layer_key][k].append(v)
    # ── Aggregate metrics ────────────────────────────────────────────
    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)

    total_valid = sum(1 for r in results if r["json_valid"])
    total_struct = sum(1 for r in results if r["structure_correct"])
    n = len(results)

    # Overall
    overall = {
        "total_samples": n,
        "json_validity_rate": total_valid / n,
        "structure_correctness_rate": total_struct / n,
    }

    # KPI accuracy (only for create operations on standard layers)
    kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"]
    kpi_samples = [r for r in results if any(k in r for k in kpi_fields)]
    if kpi_samples:
        for field in kpi_fields:
            vals = [r.get(field, False) for r in kpi_samples]
            overall[field + "_rate"] = sum(vals) / len(vals) if vals else 0.0
        all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
        overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)

    # Adversarial
    adv_results = [r for r in results if r["target_layer"].startswith("adversarial")]
    if adv_results:
        adv_correct = sum(1 for r in adv_results if r["json_valid"] and r["structure_correct"])
        overall["adversarial_accuracy"] = adv_correct / len(adv_results)
        overall["adversarial_samples"] = len(adv_results)
    # Per-layer breakdown
    layer_summary = {}
    for layer, metrics in sorted(per_layer.items()):
        layer_n = len(metrics["json_valid"])
        layer_summary[layer] = {
            "n": layer_n,
            "json_valid": sum(metrics["json_valid"]) / layer_n,
            "structure_correct": sum(metrics["structure_correct"]) / layer_n,
        }
        for k in kpi_fields:
            if k in metrics and metrics[k]:
                layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])
    # Print
    print(f"\n{'Metric':<35} {'Value':>10}")
    print("─" * 47)
    for k, v in overall.items():
        if isinstance(v, float):
            print(f" {k:<33} {v:>9.1%}")
        else:
            print(f" {k:<33} {v:>9}")
print(f"\n{'Layer':<35} {'N':>5} {'JSON%':>7} {'Struct%':>8} {'AllKPI%':>8}")
print("─" * 65)
for layer, m in layer_summary.items():
kpi_str = f"{m.get('has_latency', 0):.0%}" if "has_latency" in m else "β€”"
print(f" {layer:<33} {m['n']:>5} {m['json_valid']:>6.1%} "
f"{m['structure_correct']:>7.1%} {kpi_str:>6}")
    # Save
    output = {
        "config": vars(args),
        "overall": overall,
        "per_layer": layer_summary,
        "raw_results": results,
    }
    with open(args.output_file, "w") as f:
        json.dump(output, f, indent=2, default=str)
    print(f"\n✅ Results saved to {args.output_file}")

if __name__ == "__main__":
    main()
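
# Quick way to inspect a saved report afterwards (illustrative):
#   import json
#   report = json.load(open("eval_results.json"))
#   print(report["overall"])      # aggregate rates
#   print(report["per_layer"])    # per-standard breakdown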