nraptisss committed · verified
Commit 15addfa · 1 Parent(s): aee8025

Upload evaluate.py

Files changed (1):
  evaluate.py  +313  -0
evaluate.py ADDED
@@ -0,0 +1,313 @@
+ #!/usr/bin/env python3
+ """
+ TMF921 Intent Translation – Evaluation Script
+ =============================================
+ Evaluates a fine-tuned QLoRA model on the test split with five metrics:
+ 1. JSON Schema Validity  – is the output valid JSON?
+ 2. KPI Field Extraction  – are latency/throughput/reliability/UEs present & correct?
+ 3. Cross-Standard Output – is the structure correct for each target_layer?
+ 4. Adversarial Accuracy  – are bad intents correctly rejected?
+ 5. Lifecycle Accuracy    – is the lifecycle operation format correct?
+
+ Usage:
+     python evaluate.py --adapter_path ./output --num_samples 200
+     python evaluate.py --adapter_path nraptisss/Qwen3-8B-TMF921-Intent-QLora --num_samples -1
+ """
+
+ import argparse
+ import json
+ import re
+
+ import torch
+ from collections import defaultdict
+ from datasets import load_dataset
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+ )
+ from peft import PeftModel
+
+
+ def parse_args():
+     p = argparse.ArgumentParser()
+     p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
+     p.add_argument("--adapter_path", type=str, default="./output",
+                    help="Path or HF id of LoRA adapter")
+     p.add_argument("--dataset", type=str,
+                    default="nraptisss/TMF921-intent-to-config-augmented")
+     p.add_argument("--split", type=str, default="test")
+     p.add_argument("--num_samples", type=int, default=200,
+                    help="Number of samples to evaluate (-1 for all)")
+     p.add_argument("--max_new_tokens", type=int, default=4096)
+     p.add_argument("--output_file", type=str, default="eval_results.json")
+     # Flash attention is on by default; pass --no-flash_attn to disable (Python 3.9+)
+     p.add_argument("--flash_attn", action=argparse.BooleanOptionalAction, default=True)
+     return p.parse_args()
+
+
+ # ── Validation helpers ───────────────────────────────────────────────
+ def try_parse_json(text: str) -> tuple[dict | None, bool]:
+     """Try to parse JSON from model output, handling markdown fences."""
+     text = text.strip()
+     # Remove markdown code fences
+     if text.startswith("```"):
+         text = re.sub(r"^```(?:json)?\s*\n?", "", text)
+         text = re.sub(r"\n?```\s*$", "", text)
+     # Try direct parse
+     try:
+         return json.loads(text), True
+     except json.JSONDecodeError:
+         pass
+     # Try to find a JSON object embedded in the text
+     match = re.search(r"\{[\s\S]*\}", text)
+     if match:
+         try:
+             return json.loads(match.group()), True
+         except json.JSONDecodeError:
+             pass
+     return None, False
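+
+ # For example:
+ #   try_parse_json('```json\n{"id": 1}\n```')  -> ({"id": 1}, True)
+ #   try_parse_json('Sure: {"id": 1}')          -> ({"id": 1}, True)
+ #   try_parse_json("not json")                 -> (None, False)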
+
+
+ def check_kpi_fields(parsed: dict, row: dict) -> dict:
+     """Check if the generated config contains the correct KPI values."""
+     flat = json.dumps(parsed).lower()
+     results = {}
+
+     # Check latency
+     target_latency = row["latency_ms"]
+     results["has_latency"] = str(int(target_latency)) in flat or str(target_latency) in flat
+
+     # Check reliability
+     target_rel = row["reliability_pct"]
+     results["has_reliability"] = str(target_rel) in flat
+
+     # Check DL throughput
+     target_dl = row["dl_throughput_mbps"]
+     results["has_dl_throughput"] = str(int(target_dl)) in flat or str(target_dl) in flat
+
+     # Check UL throughput
+     target_ul = row["ul_throughput_mbps"]
+     results["has_ul_throughput"] = str(int(target_ul)) in flat or str(target_ul) in flat
+
+     # Check max UEs
+     target_ues = row["max_ues"]
+     results["has_max_ues"] = str(target_ues) in flat
+
+     return results
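+
+ # Note: these checks use substring matching on the serialized JSON, so they
+ # confirm the target number appears somewhere in the output rather than in a
+ # specific field; short values (e.g. "99") can match spuriously.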
+
+
+ LAYER_ROOT_KEYS = {
+     "tmf921": ["id", "href", "name", "intentexpression"],
+     "intent_3gpp": ["intent"],
+     "camara": ["networkslicebooking"],
+     "etsi_zsm": ["zsmintent"],
+     "a1_policy": ["a1policy"],
+     "o1_nrm": ["managedelement"],
+ }
+
+ ADVERSARIAL_STATUSES = {"CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"}
+
+ LIFECYCLE_LAYERS = {
+     "tmf921_lifecycle_activate", "tmf921_lifecycle_modify",
+     "tmf921_lifecycle_suspend", "tmf921_lifecycle_resume",
+     "tmf921_lifecycle_terminate", "tmf921_lifecycle_scale",
+     "tmf921_lifecycle_monitor", "tmf921_lifecycle_report",
+ }
+
+
+ def check_structure(parsed: dict, target_layer: str) -> bool:
+     """Check if the JSON has the expected root keys for the target standard."""
+     if target_layer.startswith("adversarial"):
+         return parsed.get("status") in ADVERSARIAL_STATUSES
+     if target_layer in LIFECYCLE_LAYERS:
+         flat_keys = {k.lower() for k in parsed.keys()}
+         return bool(flat_keys & {"intentpatch", "intentassurancereport", "intentupdate"})
+     expected = LAYER_ROOT_KEYS.get(target_layer, [])
+     if not expected:
+         return True
+     flat_keys = {k.lower() for k in parsed.keys()}
+     return any(k in flat_keys for k in expected)
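+
+ # For example:
+ #   check_structure({"Intent": {}}, "intent_3gpp")               -> True
+ #   check_structure({"status": "OUT_OF_SCOPE"}, "adversarial_1") -> True
+ #   check_structure({"foo": 1}, "camara")                        -> False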
+
+
+ # ── Main evaluation ──────────────────────────────────────────────────
+ def main():
+     args = parse_args()
+
+     print("=" * 70)
+     print("TMF921 Intent Translation – Evaluation")
+     print("=" * 70)
+     print(f"Base model  : {args.base_model}")
+     print(f"Adapter     : {args.adapter_path}")
+     print(f"Dataset     : {args.dataset} [{args.split}]")
+     print(f"Num samples : {args.num_samples}")
+     print("=" * 70)
+
+     # Load dataset
+     print("\nLoading dataset …")
+     ds = load_dataset(args.dataset, split=args.split)
+     if args.num_samples > 0:
+         ds = ds.select(range(min(args.num_samples, len(ds))))
+     print(f"  Evaluating on {len(ds)} samples")
+
+     # Load the base model in 4-bit NF4 with double quantization
+     print("\nLoading model …")
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+         bnb_4bit_use_double_quant=True,
+     )
+
+     model_kwargs = {
+         "quantization_config": bnb_config,
+         "device_map": "auto",
+         "trust_remote_code": True,
+     }
+     if args.flash_attn:
+         model_kwargs["attn_implementation"] = "flash_attention_2"
+
+     base_model = AutoModelForCausalLM.from_pretrained(
+         args.base_model, **model_kwargs
+     )
+     model = PeftModel.from_pretrained(base_model, args.adapter_path)
+     model.eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.base_model, trust_remote_code=True
+     )
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Evaluate
+     print("\nRunning inference …")
+     results = []
+     per_layer = defaultdict(lambda: defaultdict(list))
+
+     for i, row in enumerate(ds):
+         if (i + 1) % 20 == 0 or i == 0:
+             print(f"  [{i+1}/{len(ds)}] …")
+
+         messages = row["messages"]
+         target_layer = row["target_layer"]
+         reference_output = messages[-1]["content"]  # ground truth
+
+         # Build prompt (system + user only)
+         prompt_messages = [m for m in messages if m["role"] != "assistant"]
+         input_text = tokenizer.apply_chat_template(
+             prompt_messages, tokenize=False, add_generation_prompt=True
+         )
+         inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+
+         # Greedy decoding; temperature/top_p are set to None to silence
+         # sampling warnings when do_sample=False
+         with torch.no_grad():
+             output_ids = model.generate(
+                 **inputs,
+                 max_new_tokens=args.max_new_tokens,
+                 do_sample=False,
+                 temperature=None,
+                 top_p=None,
+             )
+
+         # Decode only the newly generated tokens
+         generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
+         generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+         # Parse & validate
+         parsed, is_valid_json = try_parse_json(generated_text)
+         has_correct_structure = check_structure(parsed, target_layer) if parsed else False
+
+         kpi_results = {}
+         if parsed and not target_layer.startswith("adversarial") and target_layer not in LIFECYCLE_LAYERS:
+             kpi_results = check_kpi_fields(parsed, row)
+
+         result = {
+             "id": row["id"],
+             "target_layer": target_layer,
+             "slice_type": row["slice_type"],
+             "lifecycle_operation": row["lifecycle_operation"],
+             "json_valid": is_valid_json,
+             "structure_correct": has_correct_structure,
+             **kpi_results,
+             "generated_length": len(generated_text),
+             "reference_length": len(reference_output),
+         }
+         results.append(result)
+
+         # Accumulate per-layer metrics, keyed directly by target_layer
+         per_layer[target_layer]["json_valid"].append(is_valid_json)
+         per_layer[target_layer]["structure_correct"].append(has_correct_structure)
+         for k, v in kpi_results.items():
+             per_layer[target_layer][k].append(v)
+
+     # ── Aggregate metrics ────────────────────────────────────────────
+     print("\n" + "=" * 70)
+     print("RESULTS")
+     print("=" * 70)
+
+     total_valid = sum(1 for r in results if r["json_valid"])
+     total_struct = sum(1 for r in results if r["structure_correct"])
+     n = len(results)
+
+     # Overall
+     overall = {
+         "total_samples": n,
+         "json_validity_rate": total_valid / n,
+         "structure_correctness_rate": total_struct / n,
+     }
+
+     # KPI accuracy (only for create operations on standard layers)
+     kpi_fields = ["has_latency", "has_reliability", "has_dl_throughput",
+                   "has_ul_throughput", "has_max_ues"]
+     kpi_samples = [r for r in results if any(k in r for k in kpi_fields)]
+     if kpi_samples:
+         for field in kpi_fields:
+             vals = [r.get(field, False) for r in kpi_samples]
+             overall[field + "_rate"] = sum(vals) / len(vals)
+         all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
+         overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)
+
+     # Adversarial: counted correct only if the output is valid JSON
+     # with a recognised rejection status
+     adv_results = [r for r in results if r["target_layer"].startswith("adversarial")]
+     if adv_results:
+         adv_correct = sum(1 for r in adv_results if r["json_valid"] and r["structure_correct"])
+         overall["adversarial_accuracy"] = adv_correct / len(adv_results)
+         overall["adversarial_samples"] = len(adv_results)
+
+     # Per-layer breakdown
+     layer_summary = {}
+     for layer, metrics in sorted(per_layer.items()):
+         layer_n = len(metrics["json_valid"])
+         layer_summary[layer] = {
+             "n": layer_n,
+             "json_valid": sum(metrics["json_valid"]) / layer_n,
+             "structure_correct": sum(metrics["structure_correct"]) / layer_n,
+         }
+         for k in kpi_fields:
+             if k in metrics and metrics[k]:
+                 layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])
+
+     # Print
+     print(f"\n{'Metric':<35} {'Value':>10}")
+     print("─" * 47)
+     for k, v in overall.items():
+         if isinstance(v, float):
+             print(f"  {k:<33} {v:>9.1%}")
+         else:
+             print(f"  {k:<33} {v:>9}")
+
+     # The last column reports the latency-KPI hit rate; "–" marks layers
+     # with no KPI samples
+     print(f"\n{'Layer':<35} {'N':>5} {'JSON%':>7} {'Struct%':>8} {'Latency%':>9}")
+     print("─" * 68)
+     for layer, m in layer_summary.items():
+         kpi_str = f"{m.get('has_latency', 0):.0%}" if "has_latency" in m else "–"
+         print(f"  {layer:<33} {m['n']:>5} {m['json_valid']:>6.1%} "
+               f"{m['structure_correct']:>7.1%} {kpi_str:>9}")
+
+     # Save config, aggregate metrics, and raw per-sample records
+     output = {
+         "config": vars(args),
+         "overall": overall,
+         "per_layer": layer_summary,
+         "raw_results": results,
+     }
+     with open(args.output_file, "w") as f:
+         json.dump(output, f, indent=2, default=str)
+     print(f"\n✅ Results saved to {args.output_file}")
+
+
+ if __name__ == "__main__":
+     main()
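+
+ # To inspect the saved metrics afterwards (assuming jq is available):
+ #   jq '.overall' eval_results.json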