nraptisss committed on
Commit
f34fb3a
Β·
verified Β·
1 Parent(s): f1d77cf

Fix: flush stdout for nohup, log every sample, add timestamps

Browse files
Files changed (1) hide show
  1. evaluate_v2.py +88 -171
evaluate_v2.py CHANGED
@@ -24,13 +24,14 @@ Changes from v1:
24
  - Standard-specific KPI checking (3 strategies)
25
  - Expanded lifecycle operation key matching
26
  - Saves generated text for error analysis
 
27
 
28
  Usage:
29
  python evaluate_v2.py --adapter_path ./output --num_samples 200
30
  python evaluate_v2.py --adapter_path ./output --num_samples -1
31
  """
32
 
33
- import argparse, json, re, os, sys, math, torch
34
  from collections import defaultdict
35
  from datasets import load_dataset
36
  from transformers import (
@@ -41,6 +42,11 @@ from transformers import (
41
  from peft import PeftModel
42
 
43
 
 
 
 
 
 
44
  def parse_args():
45
  p = argparse.ArgumentParser()
46
  p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
@@ -83,10 +89,8 @@ def try_parse_json(text: str) -> tuple[dict | None, bool]:
83
  def _num_representations(val: float) -> list[str]:
84
  """Generate multiple string representations of a numeric value."""
85
  reps = [str(val)]
86
- # Integer form: 99.0 β†’ "99"
87
  if val == int(val):
88
  reps.append(str(int(val)))
89
- # Also try with fewer/more decimal places
90
  reps.append(f"{val:.1f}")
91
  reps.append(f"{val:.0f}")
92
  return list(set(reps))
@@ -95,25 +99,20 @@ def _num_representations(val: float) -> list[str]:
95
  def _reliability_representations(rel_pct: float) -> list[str]:
96
  """Generate all plausible encodings of a reliability percentage."""
97
  reps = _num_representations(rel_pct)
98
-
99
- # Packet error rate: 99.999% β†’ 1e-05
100
  per = 1 - rel_pct / 100
101
  if per > 0:
102
- # Scientific notation forms
103
  exp = math.floor(math.log10(per))
104
  mantissa = per / (10 ** exp)
105
- reps.append(f"1e-{abs(exp):02d}") # "1e-07"
106
- reps.append(f"1e-{abs(exp)}") # "1e-7"
107
- reps.append(f"{per:.0e}") # "1e-02"
108
- reps.append(f"{per}") # "0.01"
109
  if mantissa == 1.0:
110
  reps.append(f"1e-{abs(exp):02d}")
111
  else:
112
  reps.append(f"{mantissa:.1f}e-{abs(exp):02d}")
113
- # Also check as fraction
114
  if per < 1:
115
  reps.append(f"{per:.10f}".rstrip("0").rstrip("."))
116
-
117
  return list(set(reps))
118
 
119
 
@@ -136,54 +135,21 @@ def _find_all_numbers(parsed: dict) -> list[float]:
136
 
137
 
138
  def _check_kpi_direct(parsed: dict, row: dict, flat: str) -> dict:
139
- """
140
- Direct KPI matching for standards that embed values as-is.
141
- Works for: TMF921, intent_3gpp, CAMARA, ETSI ZSM.
142
- Handles int/float representation differences (99 vs 99.0).
143
- """
144
  results = {}
145
-
146
- # Latency
147
- target_lat = row["latency_ms"]
148
- results["has_latency"] = any(rep in flat for rep in _num_representations(target_lat))
149
-
150
- # Reliability (also check PER encoding e.g. 99.999% β†’ 1e-05)
151
- target_rel = row["reliability_pct"]
152
- results["has_reliability"] = any(rep in flat for rep in _reliability_representations(target_rel))
153
-
154
- # DL Throughput
155
- target_dl = row["dl_throughput_mbps"]
156
- results["has_dl_throughput"] = any(rep in flat for rep in _num_representations(target_dl))
157
-
158
- # UL Throughput
159
- target_ul = row["ul_throughput_mbps"]
160
- results["has_ul_throughput"] = any(rep in flat for rep in _num_representations(target_ul))
161
-
162
- # Max UEs
163
- target_ues = row["max_ues"]
164
- results["has_max_ues"] = any(rep in flat for rep in _num_representations(float(target_ues)))
165
-
166
  return results
167
 
168
 
169
  def _check_kpi_a1_policy(parsed: dict, row: dict, flat: str) -> dict:
170
- """
171
- A1 Policy KPI checking.
172
-
173
- A1 policies encode KPIs as 3GPP QoS parameters:
174
- - reliability_pct β†’ per (packet error rate): 99.999% β†’ 1e-05
175
- - latency_ms β†’ pdb (packet delay budget): mapped via 5QI table, NOT same value
176
- - throughput β†’ gfbr/mfbr (guaranteed/maximum flow bitrate): combined, not DL/UL
177
- - max_ues β†’ not directly encoded (scope uses groupId)
178
-
179
- Strategy: check PER for reliability, check gfbr/mfbr presence for throughput,
180
- check pdb presence for latency. These are TRANSFORMED values β€” the model correctly
181
- maps intent KPIs to standards-specific parameters.
182
- """
183
  results = {}
184
  all_nums = _find_all_numbers(parsed)
185
 
186
- # Reliability: check PER encoding
187
  target_rel = row["reliability_pct"]
188
  rel_found = any(rep in flat for rep in _reliability_representations(target_rel))
189
  if not rel_found:
@@ -194,77 +160,34 @@ def _check_kpi_a1_policy(parsed: dict, row: dict, flat: str) -> dict:
194
  if per > 0 and n > 0 and abs(n - per) / max(per, 1e-15) < 0.1:
195
  rel_found = True; break
196
  results["has_reliability"] = rel_found
197
-
198
- # Latency: A1 uses pdb (packet delay budget) β€” check field exists
199
  results["has_latency"] = '"pdb"' in flat or '"packetdelaybudget"' in flat
200
-
201
- # Throughput: A1 uses gfbr/mfbr β€” check fields exist
202
  has_tput = '"gfbr"' in flat or '"mfbr"' in flat or '"guaranteedflowbitrate"' in flat
203
  results["has_dl_throughput"] = has_tput
204
  results["has_ul_throughput"] = has_tput
205
-
206
- # Max UEs: A1 uses scope.groupId β€” check scope exists
207
  results["has_max_ues"] = '"scope"' in flat or '"groupid"' in flat
208
-
209
  return results
210
 
211
 
212
  def _check_kpi_o1_nrm(parsed: dict, row: dict, flat: str) -> dict:
213
- """
214
- O1 NRM KPI checking.
215
-
216
- O-RAN O1 NRM translates intent KPIs into radio resource management configs:
217
- - No direct KPI values β€” they become RRM policy ratios, cell parameters, etc.
218
- - The correct evaluation is: does the output have the right ManagedElement structure
219
- with appropriate NRCellDU, rrmPolicyMemberList, and frequency configs?
220
-
221
- Strategy: check for presence of key O1 NRM structural elements rather than
222
- attempting value matching (which is fundamentally impossible for this standard).
223
- """
224
  results = {}
225
-
226
- # Check for key O1 NRM QoS-related structural elements
227
- results["has_latency"] = (
228
- '"rrmpolicy"' in flat or '"nrcelldu"' in flat
229
- )
230
- results["has_reliability"] = (
231
- '"operationalstate"' in flat or '"administrativestate"' in flat
232
- )
233
- results["has_dl_throughput"] = (
234
- '"bschannelbwdl"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcndl"' in flat
235
- )
236
  results["has_ul_throughput"] = (
237
  '"bschannelbwul"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcnul"' in flat
238
- or '"rrmpolicydedicatedratio"' in flat # UL often uses dedicated ratio
239
- )
240
- results["has_max_ues"] = (
241
- '"rrmpolicymemberlist"' in flat or '"snssai"' in flat
242
  )
243
-
244
  return results
245
 
246
 
247
- # Standards where KPIs are directly embedded as numeric values
248
  DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"}
249
 
250
 
251
  def check_kpi_fields(parsed: dict, row: dict, target_layer: str) -> dict:
252
- """
253
- Standard-aware KPI checking with three strategies:
254
-
255
- 1. Direct layers (TMF921, 3GPP, CAMARA, ETSI ZSM):
256
- KPI values appear directly in JSON β€” use value matching with int/float tolerance.
257
-
258
- 2. A1 Policy:
259
- KPIs are transformed to 3GPP QoS parameters (PER, pdb, gfbr/mfbr).
260
- Check transformed encodings + structural field presence.
261
-
262
- 3. O1 NRM:
263
- KPIs are translated to radio resource configs (RRM policies, cell parameters).
264
- No direct numeric correspondence β€” evaluate via structural element presence.
265
- """
266
  flat = json.dumps(parsed).lower()
267
-
268
  if target_layer in DIRECT_KPI_LAYERS:
269
  return _check_kpi_direct(parsed, row, flat)
270
  elif target_layer == "a1_policy":
@@ -272,7 +195,6 @@ def check_kpi_fields(parsed: dict, row: dict, target_layer: str) -> dict:
272
  elif target_layer == "o1_nrm":
273
  return _check_kpi_o1_nrm(parsed, row, flat)
274
  else:
275
- # Unknown layer β€” fall back to direct matching
276
  return _check_kpi_direct(parsed, row, flat)
277
 
278
 
@@ -295,7 +217,6 @@ LIFECYCLE_LAYERS = {
295
  "tmf921_lifecycle_monitor", "tmf921_lifecycle_report",
296
  }
297
 
298
- # Expanded lifecycle key matching β€” more flexible than v1
299
  LIFECYCLE_KEYS = {
300
  "tmf921_lifecycle_activate": ["intentpatch", "intentactivation"],
301
  "tmf921_lifecycle_modify": ["intentpatch", "intentupdate", "intentmodification"],
@@ -313,12 +234,10 @@ def check_structure(parsed: dict, target_layer: str) -> bool:
313
  """Check if the JSON has the expected root keys for the target standard."""
314
  if target_layer.startswith("adversarial"):
315
  return parsed.get("status") in ADVERSARIAL_STATUSES
316
-
317
  if target_layer in LIFECYCLE_LAYERS:
318
  flat_keys = {k.lower() for k in parsed.keys()}
319
  expected = LIFECYCLE_KEYS.get(target_layer, [])
320
  return any(k in flat_keys for k in expected)
321
-
322
  expected = LAYER_ROOT_KEYS.get(target_layer, [])
323
  if not expected:
324
  return True
@@ -328,38 +247,30 @@ def check_structure(parsed: dict, target_layer: str) -> bool:
328
 
329
  # ── Ground-truth baseline ────────────────────────────────────────────
330
  def compute_gt_baseline(ds):
331
- """
332
- Run the KPI checker against ground truth outputs to establish metric ceiling.
333
- This tells us the maximum score our metric CAN give, even for perfect outputs.
334
- """
335
  gt_results = defaultdict(lambda: defaultdict(list))
336
-
337
  for row in ds:
338
  layer = row["target_layer"]
339
  if layer.startswith("adversarial") or layer in LIFECYCLE_LAYERS:
340
  continue
341
-
342
  gt_text = row["messages"][-1]["content"]
343
  parsed, valid = try_parse_json(gt_text)
344
  if not parsed:
345
  continue
346
-
347
  kpi = check_kpi_fields(parsed, row, layer)
348
  for k, v in kpi.items():
349
  gt_results[layer][k].append(v)
350
-
351
- print("\n Ground-truth baseline (metric ceiling β€” should be 100% for all):")
352
- print(f" {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}")
353
- print(" " + "─" * 55)
354
-
355
  for layer in sorted(gt_results.keys()):
356
  metrics = gt_results[layer]
357
  def rate(key):
358
  vals = metrics.get(key, [])
359
  return sum(vals) / len(vals) * 100 if vals else 0
360
- print(f" {layer:<20} {rate('has_latency'):>7.1f}% {rate('has_reliability'):>7.1f}% "
361
- f"{rate('has_dl_throughput'):>7.1f}% {rate('has_ul_throughput'):>7.1f}% {rate('has_max_ues'):>7.1f}%")
362
-
363
  return gt_results
364
 
365
 
@@ -367,37 +278,36 @@ def compute_gt_baseline(ds):
367
  def main():
368
  args = parse_args()
369
 
370
- print("=" * 70)
371
- print("TMF921 Intent Translation β€” Evaluation v2")
372
- print("=" * 70)
373
- print(f"Base model : {args.base_model}")
374
- print(f"Adapter : {args.adapter_path}")
375
- print(f"Dataset : {args.dataset} [{args.split}]")
376
- print(f"Num samples : {args.num_samples}")
377
- print(f"KPI checking : standard-aware (v2)")
378
- print("=" * 70)
379
 
380
  # Load dataset
381
- print("\nLoading dataset …")
382
  ds = load_dataset(args.dataset, split=args.split)
383
-
384
  # Compute ground-truth baseline on full test set
385
- print("\nComputing ground-truth metric baseline …")
386
  gt_baseline = compute_gt_baseline(ds)
387
-
388
  if args.num_samples > 0:
389
  ds = ds.select(range(min(args.num_samples, len(ds))))
390
- print(f"\n Evaluating on {len(ds)} samples")
391
 
392
  # Load model
393
- print("\nLoading model …")
394
  bnb_config = BitsAndBytesConfig(
395
  load_in_4bit=True,
396
  bnb_4bit_quant_type="nf4",
397
  bnb_4bit_compute_dtype=torch.bfloat16,
398
  bnb_4bit_use_double_quant=True,
399
  )
400
-
401
  model_kwargs = {
402
  "quantization_config": bnb_config,
403
  "device_map": "auto",
@@ -406,32 +316,30 @@ def main():
406
  if args.flash_attn:
407
  model_kwargs["attn_implementation"] = "flash_attention_2"
408
 
409
- base_model = AutoModelForCausalLM.from_pretrained(
410
- args.base_model, **model_kwargs
411
- )
412
  model = PeftModel.from_pretrained(base_model, args.adapter_path)
413
  model.eval()
414
 
415
- tokenizer = AutoTokenizer.from_pretrained(
416
- args.base_model, trust_remote_code=True
417
- )
418
  if tokenizer.pad_token is None:
419
  tokenizer.pad_token = tokenizer.eos_token
420
 
 
 
 
 
421
  # Evaluate
422
- print("\nRunning inference …")
423
  results = []
424
  per_layer = defaultdict(lambda: defaultdict(list))
 
425
 
426
  for i, row in enumerate(ds):
427
- if (i + 1) % 20 == 0 or i == 0:
428
- print(f" [{i+1}/{len(ds)}] …")
429
 
430
  messages = row["messages"]
431
  target_layer = row["target_layer"]
432
  reference_output = messages[-1]["content"]
433
 
434
- # Build prompt (system + user only)
435
  prompt_messages = [m for m in messages if m["role"] != "assistant"]
436
  input_text = tokenizer.apply_chat_template(
437
  prompt_messages, tokenize=False, add_generation_prompt=True
@@ -450,7 +358,6 @@ def main():
450
  generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
451
  generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
452
 
453
- # Parse & validate
454
  parsed, is_valid_json = try_parse_json(generated_text)
455
  has_correct_structure = check_structure(parsed, target_layer) if parsed else False
456
 
@@ -472,20 +379,35 @@ def main():
472
  if args.save_generations:
473
  result["generated_text"] = generated_text
474
  result["reference_text"] = reference_output
475
-
476
  results.append(result)
477
 
478
- # Accumulate per-layer
479
  layer_key = target_layer
480
  per_layer[layer_key]["json_valid"].append(is_valid_json)
481
  per_layer[layer_key]["structure_correct"].append(has_correct_structure)
482
  for k, v in kpi_results.items():
483
  per_layer[layer_key][k].append(v)
484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  # ── Aggregate metrics ────────────────────────────────────────────
486
- print("\n" + "=" * 70)
487
- print("RESULTS (v2 β€” standard-aware KPI matching)")
488
- print("=" * 70)
 
 
 
489
 
490
  total_valid = sum(1 for r in results if r["json_valid"])
491
  total_struct = sum(1 for r in results if r["structure_correct"])
@@ -506,14 +428,12 @@ def main():
506
  all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
507
  overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)
508
 
509
- # Adversarial
510
  adv_results = [r for r in results if r["target_layer"].startswith("adversarial")]
511
  if adv_results:
512
  adv_correct = sum(1 for r in adv_results if r["json_valid"] and r["structure_correct"])
513
  overall["adversarial_accuracy"] = adv_correct / len(adv_results)
514
  overall["adversarial_samples"] = len(adv_results)
515
 
516
- # Per-layer breakdown
517
  layer_summary = {}
518
  for layer, metrics in sorted(per_layer.items()):
519
  layer_n = len(metrics["json_valid"])
@@ -526,34 +446,31 @@ def main():
526
  if k in metrics and metrics[k]:
527
  layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])
528
 
529
- # Print overall
530
- print(f"\n{'Metric':<35} {'Value':>10}")
531
- print("─" * 47)
532
  for k, v in overall.items():
533
  if isinstance(v, float):
534
- print(f" {k:<33} {v:>9.1%}")
535
  else:
536
- print(f" {k:<33} {v:>9}")
537
 
538
- # Print per-layer with all KPI columns
539
- print(f"\n{'Layer':<25} {'N':>4} {'JSON':>6} {'Struct':>7} {'Lat':>6} {'Rel':>6} {'DL':>6} {'UL':>6} {'UEs':>6} {'All':>6}")
540
- print("─" * 85)
541
  for layer, m in layer_summary.items():
542
  def fmt(key):
543
  return f"{m[key]*100:.0f}%" if key in m else "β€”"
544
- print(f" {layer:<23} {m['n']:>4} {m['json_valid']*100:>5.0f}% {m['structure_correct']*100:>6.0f}% "
545
- f"{fmt('has_latency'):>6} {fmt('has_reliability'):>6} {fmt('has_dl_throughput'):>6} "
546
- f"{fmt('has_ul_throughput'):>6} {fmt('has_max_ues'):>6} ", end="")
547
- # All KPIs correct for this layer
548
  layer_results = [r for r in results if r["target_layer"] == layer]
549
  layer_kpi = [r for r in layer_results if any(k in r for k in kpi_fields)]
550
  if layer_kpi:
551
  all_correct = sum(1 for r in layer_kpi if all(r.get(f, False) for f in kpi_fields))
552
- print(f"{all_correct/len(layer_kpi)*100:>4.0f}%")
553
  else:
554
- print(f"{'β€”':>5}")
 
555
 
556
- # Save
557
  output = {
558
  "config": vars(args),
559
  "overall": overall,
@@ -562,7 +479,7 @@ def main():
562
  }
563
  with open(args.output_file, "w") as f:
564
  json.dump(output, f, indent=2, default=str)
565
- print(f"\nβœ… Results saved to {args.output_file}")
566
 
567
 
568
  if __name__ == "__main__":
 
24
  - Standard-specific KPI checking (3 strategies)
25
  - Expanded lifecycle operation key matching
26
  - Saves generated text for error analysis
27
+ - Flushes stdout on every print (fixes nohup buffering)
28
 
29
  Usage:
30
  python evaluate_v2.py --adapter_path ./output --num_samples 200
31
  python evaluate_v2.py --adapter_path ./output --num_samples -1
32
  """
33
 
34
+ import argparse, json, re, os, sys, math, time, torch
35
  from collections import defaultdict
36
  from datasets import load_dataset
37
  from transformers import (
 
42
  from peft import PeftModel
43
 
44
 
45
def log(msg: str):
    """Write *msg* to stdout followed by a newline and flush immediately.

    Flushing keeps nohup/redirected log files updating in real time instead
    of waiting for the stdio buffer to fill.
    """
    sys.stdout.write(f"{msg}\n")
    sys.stdout.flush()
48
+
49
+
50
  def parse_args():
51
  p = argparse.ArgumentParser()
52
  p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B")
 
89
  def _num_representations(val: float) -> list[str]:
90
  """Generate multiple string representations of a numeric value."""
91
  reps = [str(val)]
 
92
  if val == int(val):
93
  reps.append(str(int(val)))
 
94
  reps.append(f"{val:.1f}")
95
  reps.append(f"{val:.0f}")
96
  return list(set(reps))
 
99
def _reliability_representations(rel_pct: float) -> list[str]:
    """Return every plausible textual encoding of a reliability percentage.

    Includes the direct numeric spellings plus packet-error-rate (PER)
    encodings, e.g. 99.999% -> 1e-05 in several scientific and decimal forms.
    Order of the returned list is unspecified (set-deduplicated).
    """
    reps = _num_representations(rel_pct)
    per = 1 - rel_pct / 100  # implied packet error rate
    if per > 0:
        exponent = math.floor(math.log10(per))
        mantissa = per / (10 ** exponent)
        reps += [
            f"1e-{abs(exponent):02d}",  # zero-padded, e.g. "1e-05"
            f"1e-{abs(exponent)}",      # unpadded, e.g. "1e-5"
            f"{per:.0e}",               # printf scientific, e.g. "1e-02"
            f"{per}",                   # plain decimal, e.g. "0.01"
        ]
        if mantissa == 1.0:
            reps.append(f"1e-{abs(exponent):02d}")
        else:
            reps.append(f"{mantissa:.1f}e-{abs(exponent):02d}")
    if per < 1:
        # Fixed-point fraction with trailing zeros (and a bare dot) stripped.
        reps.append(f"{per:.10f}".rstrip("0").rstrip("."))
    return list(set(reps))
117
 
118
 
 
135
 
136
 
137
def _check_kpi_direct(parsed: dict, row: dict, flat: str) -> dict:
    """Direct KPI matching for TMF921, intent_3gpp, CAMARA, ETSI ZSM.

    These standards embed the intent KPI values verbatim, so each check is a
    substring search for any accepted spelling of the target value inside the
    flattened (lower-cased JSON dump) generation.
    """
    def found(spellings) -> bool:
        # True when any textual spelling of the value occurs in the flat JSON.
        return any(s in flat for s in spellings)

    return {
        "has_latency": found(_num_representations(row["latency_ms"])),
        "has_reliability": found(_reliability_representations(row["reliability_pct"])),
        "has_dl_throughput": found(_num_representations(row["dl_throughput_mbps"])),
        "has_ul_throughput": found(_num_representations(row["ul_throughput_mbps"])),
        "has_max_ues": found(_num_representations(float(row["max_ues"]))),
    }
146
 
147
 
148
  def _check_kpi_a1_policy(parsed: dict, row: dict, flat: str) -> dict:
149
+ """A1 Policy: reliability→PER, latency→pdb, throughput→gfbr/mfbr."""
 
 
 
 
 
 
 
 
 
 
 
 
150
  results = {}
151
  all_nums = _find_all_numbers(parsed)
152
 
 
153
  target_rel = row["reliability_pct"]
154
  rel_found = any(rep in flat for rep in _reliability_representations(target_rel))
155
  if not rel_found:
 
160
  if per > 0 and n > 0 and abs(n - per) / max(per, 1e-15) < 0.1:
161
  rel_found = True; break
162
  results["has_reliability"] = rel_found
 
 
163
  results["has_latency"] = '"pdb"' in flat or '"packetdelaybudget"' in flat
 
 
164
  has_tput = '"gfbr"' in flat or '"mfbr"' in flat or '"guaranteedflowbitrate"' in flat
165
  results["has_dl_throughput"] = has_tput
166
  results["has_ul_throughput"] = has_tput
 
 
167
  results["has_max_ues"] = '"scope"' in flat or '"groupid"' in flat
 
168
  return results
169
 
170
 
171
  def _check_kpi_o1_nrm(parsed: dict, row: dict, flat: str) -> dict:
172
+ """O1 NRM: structural element presence (KPIs→RRM policies, not direct values)."""
 
 
 
 
 
 
 
 
 
 
173
  results = {}
174
+ results["has_latency"] = '"rrmpolicy"' in flat or '"nrcelldu"' in flat
175
+ results["has_reliability"] = '"operationalstate"' in flat or '"administrativestate"' in flat
176
+ results["has_dl_throughput"] = '"bschannelbwdl"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcndl"' in flat
 
 
 
 
 
 
 
 
177
  results["has_ul_throughput"] = (
178
  '"bschannelbwul"' in flat or '"rrmpolicymaxratio"' in flat or '"arfcnul"' in flat
179
+ or '"rrmpolicydedicatedratio"' in flat
 
 
 
180
  )
181
+ results["has_max_ues"] = '"rrmpolicymemberlist"' in flat or '"snssai"' in flat
182
  return results
183
 
184
 
 
185
  DIRECT_KPI_LAYERS = {"tmf921", "intent_3gpp", "camara", "etsi_zsm"}
186
 
187
 
188
  def check_kpi_fields(parsed: dict, row: dict, target_layer: str) -> dict:
189
+ """Standard-aware KPI checking: direct / A1 Policy / O1 NRM strategies."""
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  flat = json.dumps(parsed).lower()
 
191
  if target_layer in DIRECT_KPI_LAYERS:
192
  return _check_kpi_direct(parsed, row, flat)
193
  elif target_layer == "a1_policy":
 
195
  elif target_layer == "o1_nrm":
196
  return _check_kpi_o1_nrm(parsed, row, flat)
197
  else:
 
198
  return _check_kpi_direct(parsed, row, flat)
199
 
200
 
 
217
  "tmf921_lifecycle_monitor", "tmf921_lifecycle_report",
218
  }
219
 
 
220
  LIFECYCLE_KEYS = {
221
  "tmf921_lifecycle_activate": ["intentpatch", "intentactivation"],
222
  "tmf921_lifecycle_modify": ["intentpatch", "intentupdate", "intentmodification"],
 
234
  """Check if the JSON has the expected root keys for the target standard."""
235
  if target_layer.startswith("adversarial"):
236
  return parsed.get("status") in ADVERSARIAL_STATUSES
 
237
  if target_layer in LIFECYCLE_LAYERS:
238
  flat_keys = {k.lower() for k in parsed.keys()}
239
  expected = LIFECYCLE_KEYS.get(target_layer, [])
240
  return any(k in flat_keys for k in expected)
 
241
  expected = LAYER_ROOT_KEYS.get(target_layer, [])
242
  if not expected:
243
  return True
 
247
 
248
  # ── Ground-truth baseline ────────────────────────────────────────────
249
def compute_gt_baseline(ds):
    """Run the KPI checker against ground truth to establish metric ceiling.

    Scores the reference (assistant) outputs with the same standard-aware KPI
    checks applied to model generations; a layer scoring below 100% here
    indicates a metric limitation rather than a model failure.
    """
    gt_results = defaultdict(lambda: defaultdict(list))
    for row in ds:
        layer = row["target_layer"]
        # Adversarial and lifecycle samples carry no KPI values to check.
        if layer.startswith("adversarial") or layer in LIFECYCLE_LAYERS:
            continue
        # Ground truth is the last (assistant) message of the conversation.
        gt_text = row["messages"][-1]["content"]
        parsed, valid = try_parse_json(gt_text)
        if not parsed:
            continue
        kpi = check_kpi_fields(parsed, row, layer)
        for k, v in kpi.items():
            gt_results[layer][k].append(v)

    log("\n Ground-truth baseline (metric ceiling — should be 100% for all):")
    log(f" {'Layer':<20} {'latency':>8} {'reliab':>8} {'dl_tput':>8} {'ul_tput':>8} {'max_ues':>8}")
    log(" " + "─" * 55)
    for layer in sorted(gt_results.keys()):
        metrics = gt_results[layer]
        def rate(key):
            # Percentage of this layer's samples where the KPI was found.
            vals = metrics.get(key, [])
            return sum(vals) / len(vals) * 100 if vals else 0
        log(f" {layer:<20} {rate('has_latency'):>7.1f}% {rate('has_reliability'):>7.1f}% "
            f"{rate('has_dl_throughput'):>7.1f}% {rate('has_ul_throughput'):>7.1f}% {rate('has_max_ues'):>7.1f}%")
    return gt_results
275
 
276
 
 
278
  def main():
279
  args = parse_args()
280
 
281
+ log("=" * 70)
282
+ log("TMF921 Intent Translation β€” Evaluation v2")
283
+ log("=" * 70)
284
+ log(f"Base model : {args.base_model}")
285
+ log(f"Adapter : {args.adapter_path}")
286
+ log(f"Dataset : {args.dataset} [{args.split}]")
287
+ log(f"Num samples : {args.num_samples}")
288
+ log(f"KPI checking : standard-aware (v2)")
289
+ log("=" * 70)
290
 
291
  # Load dataset
292
+ log("\nLoading dataset …")
293
  ds = load_dataset(args.dataset, split=args.split)
294
+
295
  # Compute ground-truth baseline on full test set
296
+ log("\nComputing ground-truth metric baseline …")
297
  gt_baseline = compute_gt_baseline(ds)
298
+
299
  if args.num_samples > 0:
300
  ds = ds.select(range(min(args.num_samples, len(ds))))
301
+ log(f"\n Evaluating on {len(ds)} samples")
302
 
303
  # Load model
304
+ log("\nLoading model …")
305
  bnb_config = BitsAndBytesConfig(
306
  load_in_4bit=True,
307
  bnb_4bit_quant_type="nf4",
308
  bnb_4bit_compute_dtype=torch.bfloat16,
309
  bnb_4bit_use_double_quant=True,
310
  )
 
311
  model_kwargs = {
312
  "quantization_config": bnb_config,
313
  "device_map": "auto",
 
316
  if args.flash_attn:
317
  model_kwargs["attn_implementation"] = "flash_attention_2"
318
 
319
+ base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
 
 
320
  model = PeftModel.from_pretrained(base_model, args.adapter_path)
321
  model.eval()
322
 
323
+ tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
 
 
324
  if tokenizer.pad_token is None:
325
  tokenizer.pad_token = tokenizer.eos_token
326
 
327
+ log("βœ… Model loaded successfully")
328
+ log(f"\nStarting inference on {len(ds)} samples …")
329
+ log(f" (First sample may take 1-2 min for CUDA warmup)\n")
330
+
331
  # Evaluate
 
332
  results = []
333
  per_layer = defaultdict(lambda: defaultdict(list))
334
+ t_start = time.time()
335
 
336
  for i, row in enumerate(ds):
337
+ t0 = time.time()
 
338
 
339
  messages = row["messages"]
340
  target_layer = row["target_layer"]
341
  reference_output = messages[-1]["content"]
342
 
 
343
  prompt_messages = [m for m in messages if m["role"] != "assistant"]
344
  input_text = tokenizer.apply_chat_template(
345
  prompt_messages, tokenize=False, add_generation_prompt=True
 
358
  generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
359
  generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
360
 
 
361
  parsed, is_valid_json = try_parse_json(generated_text)
362
  has_correct_structure = check_structure(parsed, target_layer) if parsed else False
363
 
 
379
  if args.save_generations:
380
  result["generated_text"] = generated_text
381
  result["reference_text"] = reference_output
382
+
383
  results.append(result)
384
 
 
385
  layer_key = target_layer
386
  per_layer[layer_key]["json_valid"].append(is_valid_json)
387
  per_layer[layer_key]["structure_correct"].append(has_correct_structure)
388
  for k, v in kpi_results.items():
389
  per_layer[layer_key][k].append(v)
390
 
391
+ # Progress logging β€” every sample with ETA
392
+ elapsed = time.time() - t_start
393
+ sample_time = time.time() - t0
394
+ avg_time = elapsed / (i + 1)
395
+ remaining = avg_time * (len(ds) - i - 1)
396
+ eta_h, eta_m = divmod(int(remaining), 3600)
397
+ eta_m = eta_m // 60
398
+
399
+ json_ok = "βœ“" if is_valid_json else "βœ—"
400
+ struct_ok = "βœ“" if has_correct_structure else "βœ—"
401
+ log(f" [{i+1:>4}/{len(ds)}] {target_layer:<25} JSON:{json_ok} Struct:{struct_ok} "
402
+ f"| {sample_time:.1f}s | ETA: {eta_h}h{eta_m:02d}m")
403
+
404
  # ── Aggregate metrics ────────────────────────────────────────────
405
+ total_time = time.time() - t_start
406
+ log(f"\n Total inference time: {total_time/3600:.1f}h ({total_time/len(ds):.1f}s/sample)")
407
+
408
+ log("\n" + "=" * 70)
409
+ log("RESULTS (v2 β€” standard-aware KPI matching)")
410
+ log("=" * 70)
411
 
412
  total_valid = sum(1 for r in results if r["json_valid"])
413
  total_struct = sum(1 for r in results if r["structure_correct"])
 
428
  all_kpi = [all(r.get(f, False) for f in kpi_fields) for r in kpi_samples]
429
  overall["all_kpis_correct_rate"] = sum(all_kpi) / len(all_kpi)
430
 
 
431
  adv_results = [r for r in results if r["target_layer"].startswith("adversarial")]
432
  if adv_results:
433
  adv_correct = sum(1 for r in adv_results if r["json_valid"] and r["structure_correct"])
434
  overall["adversarial_accuracy"] = adv_correct / len(adv_results)
435
  overall["adversarial_samples"] = len(adv_results)
436
 
 
437
  layer_summary = {}
438
  for layer, metrics in sorted(per_layer.items()):
439
  layer_n = len(metrics["json_valid"])
 
446
  if k in metrics and metrics[k]:
447
  layer_summary[layer][k] = sum(metrics[k]) / len(metrics[k])
448
 
449
+ log(f"\n{'Metric':<35} {'Value':>10}")
450
+ log("─" * 47)
 
451
  for k, v in overall.items():
452
  if isinstance(v, float):
453
+ log(f" {k:<33} {v:>9.1%}")
454
  else:
455
+ log(f" {k:<33} {v:>9}")
456
 
457
+ log(f"\n{'Layer':<25} {'N':>4} {'JSON':>6} {'Struct':>7} {'Lat':>6} {'Rel':>6} {'DL':>6} {'UL':>6} {'UEs':>6} {'All':>6}")
458
+ log("─" * 85)
 
459
  for layer, m in layer_summary.items():
460
  def fmt(key):
461
  return f"{m[key]*100:.0f}%" if key in m else "β€”"
462
+ line = (f" {layer:<23} {m['n']:>4} {m['json_valid']*100:>5.0f}% {m['structure_correct']*100:>6.0f}% "
463
+ f"{fmt('has_latency'):>6} {fmt('has_reliability'):>6} {fmt('has_dl_throughput'):>6} "
464
+ f"{fmt('has_ul_throughput'):>6} {fmt('has_max_ues'):>6} ")
 
465
  layer_results = [r for r in results if r["target_layer"] == layer]
466
  layer_kpi = [r for r in layer_results if any(k in r for k in kpi_fields)]
467
  if layer_kpi:
468
  all_correct = sum(1 for r in layer_kpi if all(r.get(f, False) for f in kpi_fields))
469
+ line += f"{all_correct/len(layer_kpi)*100:>4.0f}%"
470
  else:
471
+ line += f"{'β€”':>5}"
472
+ log(line)
473
 
 
474
  output = {
475
  "config": vars(args),
476
  "overall": overall,
 
479
  }
480
  with open(args.output_file, "w") as f:
481
  json.dump(output, f, indent=2, default=str)
482
+ log(f"\nβœ… Results saved to {args.output_file}")
483
 
484
 
485
  if __name__ == "__main__":