#!/usr/bin/env python3
"""
TMF921 Intent Translation β€” Evaluation Script v2
=================================================
Standard-aware KPI checking that correctly handles how each telecom standard
encodes network parameters:
TMF921, 3GPP TS 28.312, CAMARA, ETSI ZSM:
β†’ KPI values embedded directly (with int/float tolerance: 99 vs 99.0)
O-RAN A1 Policy:
β†’ reliability β†’ packet error rate (PER): 99.999% β†’ 1e-05
β†’ latency β†’ packet delay budget (pdb): mapped via 5QI table
β†’ throughput β†’ gfbr/mfbr (guaranteed/maximum flow bitrate)
O-RAN O1 NRM (3GPP TS 28.541):
β†’ KPIs translated to radio resource management configs (RRM policies,
cell parameters, frequency allocations). No direct numeric values.
β†’ Evaluated via structural element presence.
Changes from v1:
- Fixes metric bug where 92% of "reliability failures" were false negatives
- Adds ground-truth baseline (metric ceiling) printed before evaluation
- Standard-specific KPI checking (3 strategies)
- Expanded lifecycle operation key matching
- Saves generated text for error analysis
- Flushes stdout on every print (fixes nohup buffering)
Usage:
python evaluate_v2.py --adapter_path ./output --num_samples 200
python evaluate_v2.py --adapter_path ./output --num_samples -1
"""
import argparse
import json
import math
import re
import time

import torch
from collections import defaultdict
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import PeftModel
def log(msg: str):
"""Print with flush so nohup logs update in real time."""
print(msg, flush=True)
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
p.add_argument("--adapter_path", type=str, default="./output",
help="Path or HF id of LoRA adapter")
p.add_argument("--dataset", type=str,
default="nraptisss/TMF921-intent-to-config-augmented")
p.add_argument("--split", type=str, default="test")
p.add_argument("--num_samples", type=int, default=200,
help="Number of samples to evaluate (-1 for all)")
p.add_argument("--max_new_tokens", type=int, default=4096)
p.add_argument("--output_file", type=str, default="eval_results_v2.json")
p.add_argument("--flash_attn", action="store_true", default=True)
p.add_argument("--save_generations", action="store_true", default=True,
help="Save generated text in results for error analysis")
return p.parse_args()
# ── JSON Parsing ─────────────────────────────────────────────────────
def try_parse_json(text: str) -> tuple[dict | None, bool]:
"""Try to parse JSON from model output, handling markdown fences."""
text = text.strip()
if text.startswith("```"):
text = re.sub(r"^```(?:json)?\s*\n?", "", text)
text = re.sub(r"\n?```\s*$", "", text)
try:
return json.loads(text), True
except json.JSONDecodeError:
pass
match = re.search(r"\{[\s\S]*\}", text)
if match:
try:
return json.loads(match.group()), True
except json.JSONDecodeError:
pass
return None, False
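# Illustrative behaviour of try_parse_json (hand-written strings, not taken
# from the dataset):
#   try_parse_json('```json\n{"a": 1}\n```')  -> ({"a": 1}, True)   # fenced
#   try_parse_json('Sure: {"a": 1} done')     -> ({"a": 1}, True)   # embedded
#   try_parse_json('no json here')            -> (None, False)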
# ── Standard-aware KPI checking ─────────────────────────────────────
def _num_representations(val: float) -> list[str]:
    """Generate multiple string representations of a numeric value."""
    reps = [str(val)]
    if val == int(val):
        # int/float tolerance: 99.0 also matches "99" (and vice versa).
        # Non-integer values are deliberately NOT rounded here, so a target
        # like 12.5 never produces a spurious "12"/"13" substring match.
        reps.append(str(int(val)))
        reps.append(f"{val:.1f}")
        reps.append(f"{val:.0f}")
    return list(set(reps))
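# For example, _num_representations(99.0) covers {"99.0", "99"}, while
# _num_representations(12.5) stays {"12.5"}.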
def _reliability_representations(rel_pct: float) -> list[str]:
    """Generate all plausible encodings of a reliability percentage."""
    reps = _num_representations(rel_pct)
    per = 1 - rel_pct / 100  # packet error rate, e.g. 99.999% -> ~1e-05
    if per > 0:
        exp = math.floor(math.log10(per))
        mantissa = per / (10 ** exp)
        reps.append(f"1e-{abs(exp):02d}")
        reps.append(f"1e-{abs(exp)}")
        reps.append(f"{per:.0e}")
        reps.append(f"{per}")
        # Float noise means mantissa is rarely exactly 1.0 (1 - 99.999/100 is
        # 1.0000000000287557e-05), so format it instead of comparing with ==.
        reps.append(f"{mantissa:.1f}e-{abs(exp):02d}")
    if 0 < per < 1:
        # Decimal form, e.g. "0.00001". The lower bound guards rel_pct == 100,
        # where per == 0 would otherwise yield the string "0" and match
        # almost any generated output.
        reps.append(f"{per:.10f}".rstrip("0").rstrip("."))
    return list(set(reps))
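# Worked example: _reliability_representations(99.999) includes, among others,
#   "99.999"                      (raw percentage)
#   "1e-05", "1e-5", "1.0e-05"    (scientific PER forms)
#   "0.00001"                     (decimal PER form)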
def _find_all_numbers(parsed: dict) -> list[float]:
"""Extract all numeric values from a nested JSON structure."""
nums = []
if isinstance(parsed, dict):
for v in parsed.values():
if isinstance(v, (int, float)) and not isinstance(v, bool):
nums.append(float(v))
elif isinstance(v, (dict, list)):
nums.extend(_find_all_numbers(v))
elif isinstance(parsed, list):
for item in parsed:
if isinstance(item, (int, float)) and not isinstance(item, bool):
nums.append(float(item))
elif isinstance(item, (dict, list)):
nums.extend(_find_all_numbers(item))
return nums
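# e.g. _find_all_numbers({"qos": {"pdb": 10, "flags": [True, 2.5]}}) -> [10.0, 2.5];
# booleans are skipped explicitly because bool is a subclass of int in Python.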
def _check_kpi_direct(parsed: dict, row: dict, flat: str) -> dict:
"""Direct KPI matching for TMF921, intent_3gpp, CAMARA, ETSI ZSM."""
results = {}
results["has_latency"] = any(rep in flat for rep in _num_representations(row["latency_ms"]))
results["has_reliability"] = any(rep in flat for rep in _reliability_representations(row["reliability_pct"]))
results["has_dl_throughput"] = any(rep in flat for rep in _num_representations(row["dl_throughput_mbps"]))
results["has_ul_throughput"] = any(rep in flat for rep in _num_representations(row["ul_throughput_mbps"]))
results["has_max_ues"] = any(rep in flat for rep in _num_representations(float(row["max_ues"])))
return results
def _check_kpi_a1_policy(parsed: dict, row: dict, flat: str) -> dict:
"""A1 Policy: reliability→PER, latency→pdb, throughput→gfbr/mfbr."""
results = {}
all_nums = _find_all_numbers(parsed)
target_rel = row["reliability_pct"]
rel_found = any(rep in flat for rep in _reliability_representations(target_rel))
    if not rel_found:
        # Numeric fallback: accept the raw percentage (within 0.01 absolute)
        # or the derived PER (within 10% relative error).
        per = 1 - target_rel / 100
        for n in all_nums:
            if abs(n - target_rel) < 0.01:
                rel_found = True
                break
            if per > 0 and n > 0 and abs(n - per) / max(per, 1e-15) < 0.1:
                rel_found = True
                break
results["has_reliability"] = rel_found
results["has_latency"] = '"pdb"' in flat or '"packetdelaybudget"' in flat
has_tput = '"gfbr"' in flat or '"mfbr"' in flat or '"guaranteedflowbitrate"' in flat
results["has_dl_throughput"] = has_tput
results["has_ul_throughput"] = has_tput
results["has_max_ues"] = '"scope"' in flat or '"groupid"' in flat
return results
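# An A1 payload that would pass all five checks might look like this
# (illustrative shape only, not from the dataset; real nesting varies by
# policy type):
#   {"a1policy": {"scope": {...}, "qosObjectives":
#       {"pdb": 10, "per": 1e-05, "gfbr": 100, "mfbr": 200}}}
# After json.dumps(...).lower() it exposes the '"pdb"', '"gfbr"'/'"mfbr"' and
# '"scope"' substrings the checks above look for.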
def _check_kpi_o1_nrm(parsed: dict, row: dict, flat: str) -> dict:
"""O1 NRM: structural element presence (KPIs→RRM policies, not direct values)."""
results = {}
results["has_latency"] = '"rrmpolicy"' in flat or '"nrcelldu"' in flat
results["has_reliability"] = '"operationalstate"' in flat or '"administrativestate"' in flat
results["has_dl_throughput"] = '"bschannelbwdl"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcndl"' in flat
results["has_ul_throughput"] = (
'"bschannelbwul"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcnul"' in flat
or '"rrmpolicydedicatedratio"' in flat
)
results["has_max_ues"] = '"rrmpolicymemberlist"' in flat or '"snssai"' in flat
return results
DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"}
def check_kpi_fields(parsed: dict, row: dict, target_layer: str) -> dict:
"""Standard-aware KPI checking: direct / A1 Policy / O1 NRM strategies."""
flat = json.dumps(parsed).lower()
if target_layer in DIRECT_KPI_LAYERS:
return _check_kpi_direct(parsed, row, flat)
elif target_layer == "a1_policy":
return _check_kpi_a1_policy(parsed, row, flat)
elif target_layer == "o1_nrm":
return _check_kpi_o1_nrm(parsed, row, flat)
else:
return _check_kpi_direct(parsed, row, flat)
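# Note the fallback: a target_layer this script does not know about is scored
# with direct matching rather than raising, so new standards in the dataset
# degrade gracefully.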
# ── Structure checking ───────────────────────────────────────────────
LAYER_ROOT_KEYS = {
"tmf921": ["id", "href", "name", "intentexpression"],
"intent_3gpp": ["intent"],
"camara": ["networkslicebooking"],
"etsi_zsm": ["zsmintent"],
"a1_policy": ["a1policy"],
"o1_nrm": ["managedelement"],
}
ADVERSARIAL_STATUSES = {"CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"}
LIFECYCLE_LAYERS = {
"tmf921_lifecycle_activate", "tmf921_lifecycle_modify",
"tmf921_lifecycle_suspend", "tmf921_lifecycle_resume",
"tmf921_lifecycle_terminate", "tmf921_lifecycle_scale",
"tmf921_lifecycle_monitor", "tmf921_lifecycle_report",
}
LIFECYCLE_KEYS = {
"tmf921_lifecycle_activate": ["intentpatch", "intentactivation"],
"tmf921_lifecycle_modify": ["intentpatch", "intentupdate", "intentmodification"],
"tmf921_lifecycle_suspend": ["intentpatch", "intentsuspension"],
"tmf921_lifecycle_resume": ["intentpatch", "intentresumption"],
"tmf921_lifecycle_terminate": ["intentpatch", "intenttermination"],
"tmf921_lifecycle_scale": ["intentpatch", "intentscaling"],
"tmf921_lifecycle_monitor": ["intentassurancereport", "intentmonitor", "intentfulfillmentreport",
"monitoringreport", "fulfillmentinfo", "report"],
"tmf921_lifecycle_report": ["intentassurancereport", "intentreport", "fulfillmentinfo", "report"],
}
def check_structure(parsed: dict, target_layer: str) -> bool:
"""Check if the JSON has the expected root keys for the target standard."""
if target_layer.startswith("adversarial"):
return parsed.get("status") in ADVERSARIAL_STATUSES
if target_layer in LIFECYCLE_LAYERS:
flat_keys = {k.lower() for k in parsed.keys()}
expected = LIFECYCLE_KEYS.get(target_layer, [])
return any(k in flat_keys for k in expected)
expected = LAYER_ROOT_KEYS.get(target_layer, [])
if not expected:
return True
flat_keys = {k.lower() for k in parsed.keys()}
return any(k in flat_keys for k in expected)
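# e.g. check_structure({"Intent": {...}}, "intent_3gpp") is True, since root
# keys are compared case-insensitively; an adversarial sample passes only when
# the model returns a refusal such as {"status": "OUT_OF_SCOPE"}.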
# ── Ground-truth baseline ────────────────────────────────────────────
def compute_gt_baseline(ds):
"""Run the KPI checker against ground truth to establish metric ceiling."""
gt_results = defaultdict(lambda: defaultdict(list))
for row in ds:
layer = row["target_layer"]
if layer.startswith("adversarial") or layer in LIFECYCLE_LAYERS:
continue
gt_text = row["messages"][-1]["content"]
parsed, valid = try_parse_json(gt_text)
if not parsed:
continue
kpi = check_kpi_fields(parsed, row, layer)
for k, v in kpi.items():
gt_results[layer][k].append(v)
log("\n Ground-truth baseline (metric ceiling β€” should be 100% for all):")
log(f" {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}")
log(" " + "─" * 55)
for layer in sorted(gt_results.keys()):
metrics = gt_results[layer]
def rate(key):
vals = metrics.get(key, [])
return sum(vals) / len(vals) * 100 if vals else 0
log(f" {layer:<20} {rate('has_latency'):>7.1f}% {rate('has_reliability'):>7.1f}% "
f"{rate('has_dl_throughput'):>7.1f}% {rate('has_ul_throughput'):>7.1f}% {rate('has_max_ues'):>7.1f}%")
return gt_results
# ── Main evaluation ──────────────────────────────────────────────────
def main():
args = parse_args()
log("=" * 70)
log("TMF921 Intent Translation β€” Evaluation v2")
log("=" * 70)
log(f"Base model : {args.base_model}")
log(f"Adapter : {args.adapter_path}")
log(f"Dataset : {args.dataset} [{args.split}]")
log(f"Num samples : {args.num_samples}")
log(f"KPI checking : standard-aware (v2)")
log("=" * 70)
# Load dataset
log("\nLoading dataset …")
ds = load_dataset(args.dataset, split=args.split)
# Compute ground-truth baseline on full test set
log("\nComputing ground-truth metric baseline …")
gt_baseline = compute_gt_baseline(ds)
if args.num_samples > 0:
ds = ds.select(range(min(args.num_samples, len(ds))))
log(f"\n Evaluating on {len(ds)} samples")
# Load model
log("\nLoading model …")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
model_kwargs = {
"quantization_config": bnb_config,
"device_map": "auto",
"trust_remote_code": True,
}
if args.flash_attn:
model_kwargs["attn_implementation"] = "flash_attention_2"
base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
model = PeftModel.from_pretrained(base_model, args.adapter_path)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
log("βœ… Model loaded successfully")
log(f"\nStarting inference on {len(ds)} samples …")
log(f" (First sample may take 1-2 min for CUDA warmup)\n")
# Evaluate
results = []
per_layer = defaultdict(lambda: defaultdict(list))
t_start = time.time()
for i, row in enumerate(ds):
t0 = time.time()
messages = row["messages"]
target_layer = row["target_layer"]
reference_output = messages[-1]["content"]
prompt_messages = [m for m in messages if m["role"] != "assistant"]
input_text = tokenizer.apply_chat_template(
prompt_messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=args.max_new_tokens,
do_sample=False,
temperature=None,
top_p=None,
)
generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
parsed, is_valid_json = try_parse_json(generated_text)
has_correct_structure = check_structure(parsed, target_layer) if parsed else False
kpi_results = {}
if parsed and not target_layer.startswith("adversarial") and target_layer not in LIFECYCLE_LAYERS:
kpi_results = check_kpi_fields(parsed, row, target_layer)
result = {
"id": row["id"],
"target_layer": target_layer,
"slice_type": row["slice_type"],
"lifecycle_operation": row["lifecycle_operation"],
"json_valid": is_valid_json,
"structure_correct": has_correct_structure,
**kpi_results,
"generated_length": len(generated_text),
"reference_length": len(reference_output),
}
if args.save_generations:
result["generated_text"] = generated_text
result["reference_text"] = reference_output
results.append(result)
layer_key = target_layer
per_layer[layer_key]["json_valid"].append(is_valid_json)
per_layer[layer_key]["structure_correct"].append(has_correct_structure)
for k, v in kpi_results.items():
per_layer[layer_key][k].append(v)
        # Progress logging: every sample, with ETA
elapsed = time.time() - t_start
sample_time = time.time() - t0
avg_time = elapsed / (i + 1)
remaining = avg_time * (len(ds) - i - 1)
        eta_h, eta_rem = divmod(int(remaining), 3600)
        eta_m = eta_rem // 60
        json_ok = "✓" if is_valid_json else "✗"
        struct_ok = "✓" if has_correct_structure else "✗"
        log(f" [{i+1:>4}/{len(ds)}] {target_layer:<25} JSON:{json_ok} Struct:{struct_ok} "
            f"| {sample_time:.1f}s | ETA: {eta_h}h{eta_m:02d}m")
# ── Aggregate metrics ────────────────────────────────────────────
total_time = time.time() - t_start
log(f"\n Total inference time: {total_time/3600:.1f}h ({total_time/len(ds):.1f}s/sample)")
log("\n" + "=" * 70)
log("RESULTS (v2 β€” standard-aware KPI matching)")
log("=" * 70)
total_valid = sum(1 for r in results if r["json_valid"])
total_struct = sum(1 for r in results if r["structure_correct"])
n = len(results)
overall = {
"total_samples": n,
"json_validity_rate": total_valid / n,
"structure_correctness_rate": total_struct / n,
}
kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"]
kpi_samples = [r for r in results if any(k in r for k in kpi_fields)]
if kpi_samples:
for field in kpi_fields:
vals = [r.get(field, False) for r in kpi_samples]
overall[field + "_rate"] = sum(vals) / len(vals) if vals else 0.0
all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)
adv_results = [r for r in results if r["target_layer"].startswith("adversarial")]
if adv_results:
adv_correct = sum(1 for r in adv_results if r["json_valid"] and r["structure_correct"])
overall["adversarial_accuracy"] = adv_correct / len(adv_results)
overall["adversarial_samples"] = len(adv_results)
layer_summary = {}
for layer, metrics in sorted(per_layer.items()):
layer_n = len(metrics["json_valid"])
layer_summary[layer] = {
"n": layer_n,
"json_valid": sum(metrics["json_valid"]) / layer_n,
"structure_correct": sum(metrics["structure_correct"]) / layer_n,
}
for k in kpi_fields:
if k in metrics and metrics[k]:
layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])
log(f"\n{'Metric':<35} {'Value':>10}")
log("─" * 47)
for k, v in overall.items():
if isinstance(v, float):
log(f" {k:<33} {v:>9.1%}")
else:
log(f" {k:<33} {v:>9}")
log(f"\n{'Layer':<25} {'N':>4} {'JSON':>6} {'Struct':>7} {'Lat':>6} {'Rel':>6} {'DL':>6} {'UL':>6} {'UEs':>6} {'All':>6}")
log("─" * 85)
for layer, m in layer_summary.items():
        def fmt(key):
            return f"{m[key]*100:.0f}%" if key in m else "-"
line = (f" {layer:<23} {m['n']:>4} {m['json_valid']*100:>5.0f}% {m['structure_correct']*100:>6.0f}% "
f"{fmt('has_latency'):>6} {fmt('has_reliability'):>6} {fmt('has_dl_throughput'):>6} "
f"{fmt('has_ul_throughput'):>6} {fmt('has_max_ues'):>6} ")
layer_results = [r for r in results if r["target_layer"] == layer]
layer_kpi = [r for r in layer_results if any(k in r for k in kpi_fields)]
if layer_kpi:
all_correct = sum(1 for r in layer_kpi if all(r.get(f, False) for f in kpi_fields))
line += f"{all_correct/len(layer_kpi)*100:>4.0f}%"
        else:
            line += f"{'-':>5}"
log(line)
output = {
"config": vars(args),
"overall": overall,
"per_layer": layer_summary,
"raw_results": results,
}
with open(args.output_file, "w") as f:
json.dump(output, f, indent=2, default=str)
log(f"\nβœ… Results saved to {args.output_file}")
if __name__ == "__main__":
main()