# intent-translation-training / evaluate_v3.py
# nraptisss's picture
# Add zero-shot baseline mode: --adapter_path none skips adapter loading, --no_think suppresses Qwen3 thinking
# a2801bb verified
#!/usr/bin/env python3
"""
TMF921 Intent Translation — Evaluation Script v3
=================================================
Optimized evaluation with:
1. Stratified sampling — N samples per target_layer (covers all layers evenly)
2. Layer-aware max tokens — caps generation length per layer (saves ~40% time)
3. Incremental saves — writes results after every sample (never lose progress)
4. Resume support — skips already-evaluated IDs from a previous checkpoint
5. Zero-shot baseline — --adapter_path none to evaluate base model without adapter
All KPI checking logic is identical to v2 (standard-aware).
Usage:
# Fine-tuned model evaluation
python evaluate_v3.py --adapter_path ./output
# Zero-shot baseline (no adapter, base model only)
python evaluate_v3.py --adapter_path none --output_file eval_v3_baseline.json
# Zero-shot baseline with Qwen3 thinking disabled
python evaluate_v3.py --adapter_path none --no_think --output_file eval_v3_baseline.json
# Resume interrupted run
python evaluate_v3.py --adapter_path ./output --resume
"""
import argparse, json, re, os, sys, math, time, random, torch
from collections import defaultdict
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# ═══════════════════════════════════════════════════════════════════════
# LAYER-AWARE MAX TOKENS
# ═══════════════════════════════════════════════════════════════════════
# Upper bound on newly generated tokens, keyed by the dataset's target_layer.
# Any layer not listed here falls back to DEFAULT_MAX_TOKENS.
LAYER_MAX_TOKENS = dict(
    # Primary target layers
    tmf921=1600,
    intent_3gpp=900,
    camara=400,
    a1_policy=350,
    o1_nrm=500,
    etsi_zsm=1100,
    # TMF921 lifecycle operations
    tmf921_lifecycle_activate=250,
    tmf921_lifecycle_modify=600,
    tmf921_lifecycle_monitor=500,
    tmf921_lifecycle_report=800,
    tmf921_lifecycle_resume=250,
    tmf921_lifecycle_scale=350,
    tmf921_lifecycle_suspend=250,
    tmf921_lifecycle_terminate=250,
    # Adversarial prompts
    adversarial_ambiguous=200,
    adversarial_contradictory=200,
    adversarial_out_of_scope=200,
)
# Cap used for any target_layer missing from LAYER_MAX_TOKENS.
DEFAULT_MAX_TOKENS = 2048
# ═══════════════════════════════════════════════════════════════════════
# HELPERS
# ═══════════════════════════════════════════════════════════════════════
def log(msg: str):
    """Write *msg* plus a newline to the current stdout and flush right away.

    Flushing on every call keeps progress visible when output is piped or
    redirected to a log file.
    """
    print(msg, file=sys.stdout, flush=True)
def parse_args(argv=None):
    """Parse the command-line options for one evaluation run.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        argparse.Namespace with one attribute per option below.
        ``no_flash_attn`` is returned as parsed; it is not folded into
        ``flash_attn`` here.
    """
    p = argparse.ArgumentParser(description="TMF921 Evaluation v3 — stratified, incremental, fast")
    p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
    p.add_argument("--adapter_path", type=str, default="./output",
                   help="Path to LoRA adapter, or 'none' for zero-shot baseline")
    p.add_argument("--dataset", type=str, default="nraptisss/TMF921-intent-to-config-augmented")
    p.add_argument("--split", type=str, default="test")
    p.add_argument("--per_layer", type=int, default=50,
                   help="Samples per target_layer. Layers with fewer samples use all. -1 = all.")
    p.add_argument("--seed", type=int, default=42, help="Random seed for stratified sampling")
    p.add_argument("--output_file", type=str, default="eval_v3_results.json")
    p.add_argument("--resume", action="store_true",
                   help="Resume from existing output_file, skipping already-evaluated IDs")
    # Kept for backward compatibility. Because default=True, passing this flag
    # changes nothing; --no_flash_attn is the switch that turns it off.
    p.add_argument("--flash_attn", action="store_true", default=True)
    p.add_argument("--no_flash_attn", action="store_true", default=False)
    p.add_argument("--no_think", action="store_true", default=False,
                   help="Suppress Qwen3 thinking mode by adding /no_think to the last user message")
    # Saving generations is on by default; store_true with default=True alone
    # gave no way to turn it off. The opt-out below writes to the same dest,
    # so args.save_generations stays the single attribute callers read.
    p.add_argument("--save_generations", action="store_true", default=True)
    p.add_argument("--no_save_generations", dest="save_generations", action="store_false",
                   help="Do not store the generated text of each sample in the output file")
    return p.parse_args(argv)
# ═══════════════════════════════════════════════════════════════════════
# STRATIFIED SAMPLING
# ═══════════════════════════════════════════════════════════════════════
def stratified_sample(ds, per_layer: int, seed: int = 42):
    """Pick up to *per_layer* row indices from every ``target_layer`` of *ds*.

    Layers are processed in sorted name order. A layer with at most
    *per_layer* rows (or any layer when *per_layer* is negative) contributes
    all of its rows; otherwise a seeded random subset is drawn. The combined
    selection is shuffled with the same seeded generator.

    Returns:
        (indices, counts): the shuffled list of selected row indices and a
        dict mapping each layer to the number of rows selected from it.
    """
    rng = random.Random(seed)

    rows_by_layer = defaultdict(list)
    for position, example in enumerate(ds):
        rows_by_layer[example["target_layer"]].append(position)

    picked = []
    counts = {}
    for layer in sorted(rows_by_layer):
        pool = rows_by_layer[layer]
        keep_everything = per_layer < 0 or len(pool) <= per_layer
        chosen = pool if keep_everything else rng.sample(pool, per_layer)
        picked += chosen
        counts[layer] = len(chosen)

    rng.shuffle(picked)
    return picked, counts
# ═══════════════════════════════════════════════════════════════════════
# JSON PARSING
# ═══════════════════════════════════════════════════════════════════════
def try_parse_json(text: str):
    """Best-effort parse of a model generation as JSON.

    Steps, in order:
      1. strip surrounding whitespace;
      2. drop a leading ``<think>...</think>`` block if one is closed;
      3. if the text starts with a Markdown fence, remove the opening
         ```` ``` ````/```` ```json ```` marker and a trailing fence;
      4. try ``json.loads`` on the whole remainder;
      5. otherwise try the greedy span from the first ``{`` to the last ``}``.

    Returns:
        (value, True) on success, where value may be any JSON type (object,
        array, number, ...), or (None, False) if nothing parsed.
    """
    body = text.strip()

    closed_think = re.match(r"<think>[\s\S]*?</think>\s*", body)
    if closed_think is not None:
        body = body[closed_think.end():].strip()

    if body.startswith("```"):
        body = re.sub(r"^```(?:json)?\s*\n?", "", body)
        body = re.sub(r"\n?```\s*$", "", body)

    try:
        return json.loads(body), True
    except json.JSONDecodeError:
        pass

    brace_span = re.search(r"\{[\s\S]*\}", body)
    if brace_span is None:
        return None, False
    try:
        return json.loads(brace_span.group()), True
    except json.JSONDecodeError:
        return None, False
# ═══════════════════════════════════════════════════════════════════════
# KPI CHECKING (standard-aware, identical to v2)
# ═══════════════════════════════════════════════════════════════════════
def _num_representations(val: float) -> list:
reps = [str(val)]
if val == int(val):
reps.append(str(int(val)))
reps.append(f"{val:.1f}")
reps.append(f"{val:.0f}")
return list(set(reps))
def _reliability_representations(rel_pct: float) -> list:
    """Return string spellings of a reliability percentage and its error rate.

    Starts from ``_num_representations(rel_pct)``. When the error rate
    ``1 - rel_pct / 100`` is positive, adds scientific-notation forms
    (``1e-05``, ``1e-5``, ``{:.0e}``, the mantissa form, ``str(rate)``) and,
    for rates below 1, a plain fixed-point decimal with trailing zeros removed.
    The list order is unspecified.
    """
    forms = _num_representations(rel_pct)
    error_rate = 1 - rel_pct / 100
    if error_rate > 0:
        exponent = math.floor(math.log10(error_rate))
        mantissa = error_rate / (10 ** exponent)
        magnitude = abs(exponent)
        forms += [
            f"1e-{magnitude:02d}",
            f"1e-{magnitude}",
            f"{error_rate:.0e}",
            f"{error_rate}",
            f"1e-{magnitude:02d}" if mantissa == 1.0 else f"{mantissa:.1f}e-{magnitude:02d}",
        ]
        if error_rate < 1:
            forms.append(f"{error_rate:.10f}".rstrip("0").rstrip("."))
    return list(set(forms))
def _find_all_numbers(parsed) -> list:
nums = []
if isinstance(parsed, dict):
for v in parsed.values():
if isinstance(v, (int, float)) and not isinstance(v, bool):
nums.append(float(v))
elif isinstance(v, (dict, list)):
nums.extend(_find_all_numbers(v))
elif isinstance(parsed, list):
for item in parsed:
if isinstance(item, (int, float)) and not isinstance(item, bool):
nums.append(float(item))
elif isinstance(item, (dict, list)):
nums.extend(_find_all_numbers(item))
return nums
def _check_kpi_direct(parsed, row, flat):
    """Check that each KPI value from *row* appears literally in *flat*.

    *flat* is the lower-cased JSON dump of the parsed output; a KPI counts as
    present when any of its string spellings is a substring of it. *parsed*
    is unused here (kept for a uniform checker signature).

    Reads row["latency_ms"], row["reliability_pct"], row["dl_throughput_mbps"],
    row["ul_throughput_mbps"] and row["max_ues"].
    """
    def _mentioned(spellings):
        return any(s in flat for s in spellings)

    return {
        "has_latency": _mentioned(_num_representations(row["latency_ms"])),
        "has_reliability": _mentioned(_reliability_representations(row["reliability_pct"])),
        "has_dl_throughput": _mentioned(_num_representations(row["dl_throughput_mbps"])),
        "has_ul_throughput": _mentioned(_num_representations(row["ul_throughput_mbps"])),
        "has_max_ues": _mentioned(_num_representations(float(row["max_ues"]))),
    }
def _check_kpi_a1_policy(parsed, row, flat):
    """KPI presence check for A1 policy outputs.

    Reliability is matched numerically. It counts when any of these holds:
      - a string spelling of row["reliability_pct"] appears in *flat*;
      - some number in *parsed* is within 0.01 of the percentage;
      - some positive number is within 10% (relative) of the error rate
        ``1 - pct / 100``.

    The other KPIs are proxies for the presence of specific keys:
      - latency: "pdb" / "packetdelaybudget";
      - DL and UL throughput (same result): "gfbr" / "mfbr" /
        "guaranteedflowbitrate";
      - max UEs: "scope" / "groupid".
    """
    target_rel = row["reliability_pct"]
    rel_found = any(r in flat for r in _reliability_representations(target_rel))
    if not rel_found:
        per = 1 - target_rel / 100
        rel_found = any(
            abs(n - target_rel) < 0.01
            or (per > 0 and n > 0 and abs(n - per) / max(per, 1e-15) < 0.1)
            for n in _find_all_numbers(parsed)
        )

    has_bitrate = any(key in flat for key in ('"gfbr"', '"mfbr"', '"guaranteedflowbitrate"'))
    return {
        "has_reliability": rel_found,
        "has_latency": any(key in flat for key in ('"pdb"', '"packetdelaybudget"')),
        "has_dl_throughput": has_bitrate,
        "has_ul_throughput": has_bitrate,
        "has_max_ues": any(key in flat for key in ('"scope"', '"groupid"')),
    }
def _check_kpi_o1_nrm(parsed, row, flat):
return {
"has_latency": '"rrmpolicy"' in flat or '"nrcelldu"' in flat,
"has_reliability": '"operationalstate"' in flat or '"administrativestate"' in flat,
"has_dl_throughput": '"bschannelbwdl"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcndl"' in flat,
"has_ul_throughput": ('"bschannelbwul"' in flat or '"rrmpolicymaxratio"' in flat
or '"arfcnul"' in flat or '"rrmpolicydedicatedratio"' in flat),
"has_max_ues": '"rrmpolicymemberlist"' in flat or '"snssai"' in flat,
}
# Layers whose output is expected to carry the raw KPI values verbatim.
DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"}


def check_kpi_fields(parsed, row, target_layer):
    """Return the has_* KPI flags for *parsed* using the layer-specific checker.

    a1_policy and o1_nrm have dedicated proxy checkers. Every other layer,
    including those in DIRECT_KPI_LAYERS and any unknown layer, uses the
    direct value-matching checker. All checkers receive the lower-cased JSON
    dump of *parsed*.
    """
    flattened = json.dumps(parsed).lower()
    if target_layer == "a1_policy":
        checker = _check_kpi_a1_policy
    elif target_layer == "o1_nrm":
        checker = _check_kpi_o1_nrm
    else:
        checker = _check_kpi_direct
    return checker(parsed, row, flattened)
# ═══════════════════════════════════════════════════════════════════════
# STRUCTURE CHECKING
# ═══════════════════════════════════════════════════════════════════════
# Lower-cased top-level keys expected per primary target layer; an output
# passes the structure check if it has at least one of them.
LAYER_ROOT_KEYS = {
    "tmf921": ["id", "href", "name", "intentexpression"],
    "intent_3gpp": ["intent"],
    "camara": ["networkslicebooking"],
    "etsi_zsm": ["zsmintent"],
    "a1_policy": ["a1policy"],
    "o1_nrm": ["managedelement"],
}
# Accepted values of the top-level "status" field for adversarial layers.
ADVERSARIAL_STATUSES = {"CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"}
# Lower-cased top-level keys accepted per TMF921 lifecycle-operation layer.
LIFECYCLE_KEYS = {
    "tmf921_lifecycle_activate": ["intentpatch", "intentactivation"],
    "tmf921_lifecycle_modify": ["intentpatch", "intentupdate", "intentmodification"],
    "tmf921_lifecycle_suspend": ["intentpatch", "intentsuspension"],
    "tmf921_lifecycle_resume": ["intentpatch", "intentresumption"],
    "tmf921_lifecycle_terminate": ["intentpatch", "intenttermination"],
    "tmf921_lifecycle_scale": ["intentpatch", "intentscaling"],
    "tmf921_lifecycle_monitor": ["intentassurancereport", "intentmonitor", "intentfulfillmentreport",
                                 "monitoringreport", "fulfillmentinfo", "report"],
    "tmf921_lifecycle_report": ["intentassurancereport", "intentreport", "fulfillmentinfo", "report"],
}
# All lifecycle layers; derived from LIFECYCLE_KEYS so the two cannot drift apart.
LIFECYCLE_LAYERS = set(LIFECYCLE_KEYS)


def check_structure(parsed, target_layer):
    """Return True if *parsed* has the top-level shape expected for *target_layer*.

    Rules:
      - adversarial layers: a JSON object whose "status" is one of
        ADVERSARIAL_STATUSES;
      - lifecycle layers: a JSON object with at least one LIFECYCLE_KEYS key
        (case-insensitive);
      - other known layers: a JSON object with at least one LAYER_ROOT_KEYS
        key (case-insensitive);
      - layers with no expected keys are accepted unconditionally.

    *parsed* may be any JSON value (try_parse_json can return lists, numbers,
    ...). Where a JSON object is required, any non-object returns False
    instead of raising, and an unhashable "status" (e.g. a list) is rejected
    rather than raising TypeError.
    """
    is_object = isinstance(parsed, dict)

    if target_layer.startswith("adversarial"):
        if not is_object:
            return False
        status = parsed.get("status")
        return isinstance(status, str) and status in ADVERSARIAL_STATUSES

    if target_layer in LIFECYCLE_LAYERS:
        expected = LIFECYCLE_KEYS.get(target_layer, [])
    else:
        expected = LAYER_ROOT_KEYS.get(target_layer, [])
        if not expected:
            return True

    if not is_object:
        return False
    top_keys = {str(k).lower() for k in parsed}
    return any(k in top_keys for k in expected)
# ═══════════════════════════════════════════════════════════════════════
# INCREMENTAL SAVE / RESUME
# ═══════════════════════════════════════════════════════════════════════
def save_checkpoint(output_file, config, results):
    """Aggregate *results* and atomically write them, with *config*, to *output_file*.

    The JSON written contains:
      - "config";
      - "overall": sample count, JSON-validity and structure rates, and per-KPI
        and all-KPIs rates over the samples that carry any has_* key;
      - "per_layer": the same metrics grouped by target_layer;
      - "raw_results".

    The payload goes to ``<output_file>.tmp`` first and is then moved over
    *output_file* with os.replace. Nothing is written when *results* is empty.
    """
    total = len(results)
    if not total:
        return

    kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"]

    def frac(values):
        return sum(values) / len(values)

    overall = {
        "total_samples": total,
        "json_validity_rate": sum(bool(rec["json_valid"]) for rec in results) / total,
        "structure_correctness_rate": sum(bool(rec["structure_correct"]) for rec in results) / total,
    }
    # Only samples that went through KPI checking (they carry has_* keys)
    # contribute to the KPI rates.
    scored = [rec for rec in results if any(field in rec for field in kpi_fields)]
    if scored:
        for field in kpi_fields:
            overall[field + "_rate"] = frac([rec.get(field, False) for rec in scored])
        overall["all_kpis_correct_rate"] = frac(
            [all(rec.get(field, False) for field in kpi_fields) for rec in scored]
        )

    layer_buckets = defaultdict(lambda: defaultdict(list))
    for rec in results:
        bucket = layer_buckets[rec["target_layer"]]
        bucket["json_valid"].append(rec["json_valid"])
        bucket["structure_correct"].append(rec["structure_correct"])
        for field in kpi_fields:
            if field in rec:
                bucket[field].append(rec[field])

    per_layer = {}
    for layer, bucket in layer_buckets.items():
        summary = {
            "n": len(bucket["json_valid"]),
            "json_valid": frac(bucket["json_valid"]),
            "structure_correct": frac(bucket["structure_correct"]),
        }
        summary.update({field: frac(bucket[field]) for field in kpi_fields if bucket.get(field)})
        per_layer[layer] = summary

    payload = {
        "config": config,
        "overall": overall,
        "per_layer": per_layer,
        "raw_results": results,
    }
    staging_path = output_file + ".tmp"
    with open(staging_path, "w") as handle:
        json.dump(payload, handle, indent=2, default=str)
    os.replace(staging_path, output_file)
def load_checkpoint(output_file):
    """Load previously saved per-sample results from *output_file* for resuming.

    Returns:
        (results, done_ids): the stored "raw_results" list and the set of
        their "id" values. Returns ([], set()) when the file does not exist,
        or when it cannot be used as a checkpoint (invalid JSON, not a JSON
        object, "raw_results" not a list, or a row without a hashable "id").
        In the unusable case a warning is logged instead of raising.
    """
    if not os.path.exists(output_file):
        return [], set()
    try:
        with open(output_file, encoding="utf-8") as f:
            data = json.load(f)
        # AttributeError: top-level JSON is not an object.
        results = data.get("raw_results", [])
        if not isinstance(results, list):
            raise TypeError("raw_results is not a list")
        # KeyError: missing "id"; TypeError: non-dict row or unhashable id.
        done_ids = {r["id"] for r in results}
    except (ValueError, KeyError, TypeError, AttributeError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        log("  Warning: checkpoint file corrupted, starting fresh")
        return [], set()
    log(f"  Resuming from checkpoint: {len(results)} samples already done")
    return results, done_ids
# ═══════════════════════════════════════════════════════════════════════
# GROUND-TRUTH BASELINE
# ═══════════════════════════════════════════════════════════════════════
def compute_gt_baseline(ds):
    """Log (and return) KPI-check pass rates of the reference answers in *ds*.

    For every row whose target_layer is neither adversarial nor a lifecycle
    layer, the last message's content is parsed with try_parse_json and scored
    with check_kpi_fields, exactly as a model generation would be. Rows whose
    reference does not parse to a truthy JSON value are skipped. The resulting
    per-layer rates show the ceiling each KPI metric can reach.

    Returns:
        dict mapping layer -> {kpi field -> pass rate in percent}; 0 when the
        layer produced no values for that field. Existing callers that ignore
        the return value are unaffected.
    """
    kpi_fields = ("has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues")

    gt_results = defaultdict(lambda: defaultdict(list))
    for row in ds:
        layer = row["target_layer"]
        if layer.startswith("adversarial") or layer in LIFECYCLE_LAYERS:
            continue
        parsed, _ = try_parse_json(row["messages"][-1]["content"])
        if not parsed:
            continue
        for k, v in check_kpi_fields(parsed, row, layer).items():
            gt_results[layer][k].append(v)

    header = f" {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}"
    log("\n Ground-truth baseline (metric ceiling):")
    log(header)
    # Separator spans the full table width (it was previously 10 chars short).
    log(" " + "─" * (len(header) - 1))

    rates = {}
    for layer in sorted(gt_results):
        m = gt_results[layer]
        layer_rates = {k: (sum(m[k]) / len(m[k]) * 100 if m.get(k) else 0) for k in kpi_fields}
        rates[layer] = layer_rates
        log(f" {layer:<20} " + " ".join(f"{layer_rates[k]:>7.1f}%" for k in kpi_fields))
    return rates
# ═══════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════
def _build_prompt_messages(messages, no_think):
    """Return the chat turns to render as the generation prompt for one row.

    Every "assistant" turn is dropped; the reference answer is the last
    message. When *no_think* is true and the final remaining turn is a user
    turn without "/no_think", " /no_think" is appended to it (the Qwen3
    thinking suppression requested by --no_think). *messages* is not mutated.
    """
    # NOTE(review): in a multi-turn row this also drops earlier assistant turns;
    # assumes rows are system + user + assistant — confirm against the dataset.
    prompt_messages = [m for m in messages if m["role"] != "assistant"]
    if no_think and prompt_messages:
        last_msg = prompt_messages[-1]
        if last_msg["role"] == "user" and "/no_think" not in last_msg["content"]:
            prompt_messages[-1] = {
                "role": last_msg["role"],
                "content": last_msg["content"] + " /no_think",
            }
    return prompt_messages


def _load_model_and_tokenizer(args, is_baseline):
    """Load the 4-bit NF4-quantized base model, plus its tokenizer.

    Unless *is_baseline*, the LoRA adapter at args.adapter_path is attached.
    The model is put in eval mode. If the tokenizer has no pad token, the EOS
    token is used as the pad token.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    if args.flash_attn:
        model_kwargs["attn_implementation"] = "flash_attention_2"
    base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
    if is_baseline:
        model = base_model
        log(" ✅ Base model loaded (zero-shot baseline — no adapter)")
    else:
        # Imported lazily: peft is only needed when an adapter is evaluated.
        from peft import PeftModel
        model = PeftModel.from_pretrained(base_model, args.adapter_path)
        log(f" ✅ Fine-tuned model loaded (adapter: {args.adapter_path})")
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _log_final_summary(results, is_baseline, total_time):
    """Log overall and per-layer metrics for all collected *results*.

    KPI rates only cover samples that carry has_* keys, i.e. samples whose
    output parsed and whose layer is KPI-checked.
    """
    n = len(results)
    if n == 0:
        return
    total_valid = sum(1 for r in results if r["json_valid"])
    total_struct = sum(1 for r in results if r["structure_correct"])
    log(f"\n{'=' * 70}")
    mode_str = "ZERO-SHOT BASELINE" if is_baseline else "FINE-TUNED"
    log(f"FINAL RESULTS — {mode_str} ({n} samples, {total_time/3600:.1f}h)")
    log(f"{'=' * 70}")
    log(f" JSON Validity: {total_valid}/{n} ({total_valid/n*100:.1f}%)")
    log(f" Structure Correct: {total_struct}/{n} ({total_struct/n*100:.1f}%)")

    kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput", "has_ul_throughput", "has_max_ues"]
    kpi_samples = [r for r in results if any(k in r for k in kpi_fields)]
    if kpi_samples:
        all_kpi = sum(1 for r in kpi_samples if all(r.get(f, False) for f in kpi_fields))
        log(f" All KPIs Correct: {all_kpi}/{len(kpi_samples)} ({all_kpi/len(kpi_samples)*100:.1f}%)")

    per_layer_agg = defaultdict(lambda: defaultdict(list))
    for r in results:
        bucket = per_layer_agg[r["target_layer"]]
        bucket["json_valid"].append(r["json_valid"])
        bucket["structure_correct"].append(r["structure_correct"])
        for k in kpi_fields:
            if k in r:
                bucket[k].append(r[k])

    log(f"\n{'Layer':<30} {'N':>4} {'JSON':>6} {'Struct':>7} | {'Lat':>5} {'Rel':>5} {'DL':>5} {'UL':>5} {'UEs':>5}")
    log("─" * 85)
    for layer in sorted(per_layer_agg):
        m = per_layer_agg[layer]
        ln = len(m["json_valid"])
        jv = sum(m["json_valid"]) / ln * 100
        sc = sum(m["structure_correct"]) / ln * 100
        # "—" marks a KPI that was never checked for this layer.
        cells = [f"{sum(m[k]) / len(m[k]) * 100:.0f}%" if m.get(k) else "—" for k in kpi_fields]
        kpi_cols = " ".join(f"{c:>5}" for c in cells)
        log(f" {layer:<28} {ln:>4} {jv:>5.0f}% {sc:>6.0f}% | {kpi_cols}")


def main():
    """Run one evaluation pass end to end.

    Steps:
      1. Parse CLI options and log the configuration.
      2. Load the dataset split and log the ground-truth metric ceiling.
      3. Select samples (stratified per target_layer, or all when
         --per_layer is negative).
      4. Optionally skip samples already present in --output_file.
      5. Greedily generate an answer for each remaining sample and score it
         (JSON validity, structure, KPI presence).
      6. Checkpoint after every sample and log a final summary.
    """
    args = parse_args()
    if args.no_flash_attn:
        args.flash_attn = False
    # "none"/"baseline"/"base"/"" as adapter path selects the zero-shot baseline.
    is_baseline = args.adapter_path.lower() in ("none", "baseline", "base", "")

    log("=" * 70)
    log("TMF921 Intent Translation — Evaluation v3")
    if is_baseline:
        log(" *** ZERO-SHOT BASELINE MODE (no adapter) ***")
    log(" Stratified sampling · Layer-aware max tokens · Incremental saves")
    log("=" * 70)
    log(f" Base model : {args.base_model}")
    log(f" Adapter : {'NONE (zero-shot baseline)' if is_baseline else args.adapter_path}")
    log(f" No-think : {args.no_think}")
    log(f" Dataset : {args.dataset} [{args.split}]")
    log(f" Per-layer N : {args.per_layer} (-1 = all)")
    log(f" Output : {args.output_file}")
    log(f" Resume : {args.resume}")
    log("=" * 70)

    # ── Load dataset ──
    log("\nLoading dataset …")
    ds = load_dataset(args.dataset, split=args.split)
    log(f" Full test set: {len(ds)} samples")

    # ── Ground-truth baseline ──
    log("\nComputing ground-truth metric baseline …")
    compute_gt_baseline(ds)

    # ── Sample selection ──
    if args.per_layer < 0:
        indices = list(range(len(ds)))
        layer_counts = defaultdict(int)
        for i in indices:
            layer_counts[ds[i]["target_layer"]] += 1
        random.Random(args.seed).shuffle(indices)
    else:
        indices, layer_counts = stratified_sample(ds, args.per_layer, args.seed)
    log(f"\n Stratified sample: {len(indices)} total samples")
    for layer, cnt in sorted(layer_counts.items()):
        log(f" {layer:<35} {cnt:>4}")

    # ── Resume from checkpoint ──
    if args.resume:
        prev_results, done_ids = load_checkpoint(args.output_file)
    else:
        prev_results, done_ids = [], set()
    remaining_indices = [i for i in indices if ds[i]["id"] not in done_ids]
    # Count only checkpointed samples that belong to the current selection;
    # len(done_ids) overcounts (and pushed the progress counter past the total)
    # when the checkpoint came from a run with a different --per_layer/--seed.
    already_done = len(indices) - len(remaining_indices)
    log(f"\n Already evaluated: {already_done}")
    if len(done_ids) > already_done:
        log(f" (checkpoint also holds {len(done_ids) - already_done} results outside this selection)")
    log(f" Remaining: {len(remaining_indices)}")
    if not remaining_indices:
        log("\n All samples already evaluated! Nothing to do.")
        log(f" Results in: {args.output_file}")
        return

    # ── Load model ──
    log("\nLoading model …")
    model, tokenizer = _load_model_and_tokenizer(args, is_baseline)
    log("")

    # ── Inference loop ──
    # NOTE(review): prev_results is kept whole, so results outside the current
    # selection stay in the output file and the final summary.
    results = list(prev_results)
    config_dict = {
        **vars(args),
        "is_baseline": is_baseline,
        "total_selected": len(indices),
        "layer_counts": dict(layer_counts),
    }
    mode_tag = "[BASE]" if is_baseline else "[FT]"
    t_start = time.time()
    total_to_do = len(remaining_indices)
    for eval_idx, ds_idx in enumerate(remaining_indices):
        row = ds[ds_idx]
        target_layer = row["target_layer"]
        t0 = time.time()
        max_tokens = LAYER_MAX_TOKENS.get(target_layer, DEFAULT_MAX_TOKENS)
        messages = row["messages"]
        reference_output = messages[-1]["content"]

        prompt_messages = _build_prompt_messages(messages, args.no_think)
        input_text = tokenizer.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            # Greedy decoding; temperature/top_p set to None so sampling
            # defaults from the generation config are not applied.
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=False,
                temperature=None,
                top_p=None,
            )
        # Keep only the newly generated tokens (drop the echoed prompt).
        generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        parsed, is_valid_json = try_parse_json(generated_text)
        has_correct_structure = check_structure(parsed, target_layer) if parsed else False
        kpi_results = {}
        if parsed and not target_layer.startswith("adversarial") and target_layer not in LIFECYCLE_LAYERS:
            kpi_results = check_kpi_fields(parsed, row, target_layer)
        result = {
            "id": row["id"],
            "target_layer": target_layer,
            "slice_type": row["slice_type"],
            "lifecycle_operation": row["lifecycle_operation"],
            "json_valid": is_valid_json,
            "structure_correct": has_correct_structure,
            **kpi_results,
            "generated_length": len(generated_text),
            "reference_length": len(reference_output),
            "max_tokens_used": max_tokens,
        }
        if args.save_generations:
            result["generated_text"] = generated_text
        results.append(result)

        # ── Incremental save ──
        save_checkpoint(args.output_file, config_dict, results)

        # ── Progress ──
        elapsed = time.time() - t_start
        sample_time = time.time() - t0
        done_now = eval_idx + 1
        remaining_time = elapsed / done_now * (total_to_do - done_now)
        eta_h, eta_rem = divmod(int(remaining_time), 3600)
        eta_m = eta_rem // 60
        j = "✓" if is_valid_json else "✗"
        s = "✓" if has_correct_structure else "✗"
        progress_n = already_done + done_now
        log(f" {mode_tag} [{progress_n:>4}/{len(indices)}] {target_layer:<30} JSON:{j} Struct:{s} "
            f"| {sample_time:.1f}s (max_tok={max_tokens}) | ETA: {eta_h}h{eta_m:02d}m")

    # ── Final summary ──
    _log_final_summary(results, is_baseline, time.time() - t_start)
    log(f"\n✅ Results saved to {args.output_file}")
    log(" (incremental — every sample was saved as it completed)")


if __name__ == "__main__":
    main()