# Source: Hugging Face repo "tmf921-intent-training" (nraptisss) — scripts/normalize_eval_metrics.py
# Tags: PEFT, qlora, sft, trl, qwen3, tmf921, intent-based-networking,
#       network-slicing, rtx-6000-ada, ml-intern
# Commit: 63d52bc (verified) — "Add normalized evaluator for existing predictions"
#!/usr/bin/env python3
"""Re-score existing generation predictions with normalized JSON metrics.
This does NOT regenerate model outputs. It reads eval/<split>/predictions.json and writes:
- eval/<split>/normalized_metrics.json
- eval/all_normalized_metrics.json
Normalization removes volatile/generated fields (ids, hrefs, descriptions, timestamps, schema links),
sorts object lists deterministically, and computes normalized exact match + normalized field F1.
Use this alongside raw metrics: raw metrics show deterministic reproduction; normalized metrics better
estimate structural/semantic config agreement.
"""
import argparse
import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List
from tmf921_train.utils import aggregate_metrics, canonical_json, field_f1, parse_json, write_json
# Keys dropped during normalization when matched exactly (compared
# case-insensitively in is_volatile_key). These carry generated/volatile
# values — ids, links, free text, timestamps — rather than slice config.
VOLATILE_KEY_EXACT = {
    "id", "uuid", "href", "name", "description", "displayName", "label",
    "@schemaLocation", "schemaLocation", "version", "revision",
    "createdAt", "updatedAt", "modifiedAt", "lastModified", "timestamp",
    "creationDate", "lastUpdate", "requestedStartDate", "requestedCompletionDate",
    "startTime", "endTime", "validFrom", "validTo", "validFor",
    "correlationId", "requestId", "transactionId", "reservationId",
}
# Keys dropped when any fragment appears as a substring of the lowercased key.
VOLATILE_KEY_FRAGMENTS = ["href", "schema", "timestamp", "uuid", "correlation", "transaction"]
# Config-bearing keys that are never dropped, even if they look volatile.
PROTECTED_KEYS = {"sst", "sd", "sliceType", "slice_type", "latency", "reliability", "dl", "ul", "maxUEs", "maxNumberOfUEs"}
# Identifier-like tokens such as "intent-abc123" or "slice_001".
ID_LIKE_RE = re.compile(r"\b(?:intent|slice|policy|booking|cell|me|gnb|nsi|nssi|req|report|monitor|assurance)[-_][A-Za-z0-9._:-]+", re.IGNORECASE)
# Long hex runs (8+ chars), e.g. hashes or uuid fragments.
HEX_RE = re.compile(r"\b[0-9a-f]{8,}\b", re.IGNORECASE)
# ISO-ish timestamps: date, 'T' or space separator, then time/offset chars.
# BUG FIX: the original class "[0-9:.+-Z]" contained the RANGE '+'-'Z'
# (0x2B-0x5A), which also matched ',', '/', ';', '@' and A-Y, so the pattern
# could swallow unrelated text following a timestamp. '-' is now a literal
# placed at the end of the class.
ISO_TIME_RE = re.compile(r"\b\d{4}-\d{2}-\d{2}[T ][0-9:.+Z-]*\b")
# Precomputed lowercase view of VOLATILE_KEY_EXACT so is_volatile_key does not
# rebuild the set comprehension on every call (it is invoked once per dict key
# during recursive normalization).
_VOLATILE_KEY_EXACT_LOWER = frozenset(k.lower() for k in VOLATILE_KEY_EXACT)


def is_volatile_key(key: str) -> bool:
    """Return True if *key* should be stripped during normalization.

    Precedence: protected (config-bearing) keys are always kept, even if
    they would otherwise look volatile; then case-insensitive exact matches
    against VOLATILE_KEY_EXACT; finally substring fragments such as "href".
    (The original also did a case-sensitive exact check first; it is subsumed
    by the case-insensitive one, so behavior is unchanged.)
    """
    if key in PROTECTED_KEYS:
        return False
    lk = key.lower()
    if lk in _VOLATILE_KEY_EXACT_LOWER:
        return True
    return any(fragment in lk for fragment in VOLATILE_KEY_FRAGMENTS)
def normalize_string(s: str) -> str:
    """Mask volatile substrings in *s* with placeholder tokens and strip it.

    Applies, in fixed order: timestamps -> <TIME>, identifier-like tokens
    -> <ID>, long hex runs -> <HEX>.
    """
    masked = ISO_TIME_RE.sub("<TIME>", s)
    masked = ID_LIKE_RE.sub("<ID>", masked)
    masked = HEX_RE.sub("<HEX>", masked)
    return masked.strip()
def normalize_json(obj: Any) -> Any:
    """Return a normalized copy of *obj* for order/id-insensitive comparison.

    - dict: volatile keys removed, values normalized recursively, entries
      whose normalized value is None/{}/[] dropped, remaining keys sorted.
    - list: elements normalized, empties dropped, then sorted by their
      canonical JSON text so list order is deterministic.
    - str: volatile substrings masked via normalize_string.
    - anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        normalized = {}
        for key, value in obj.items():
            key_str = str(key)
            if is_volatile_key(key_str):
                continue
            child = normalize_json(value)
            if child is None or child == {} or child == []:
                continue
            normalized[key_str] = child
        return {k: normalized[k] for k in sorted(normalized)}
    if isinstance(obj, list):
        kept = []
        for element in obj:
            child = normalize_json(element)
            if child is None or child == {} or child == []:
                continue
            kept.append(child)
        kept.sort(key=canonical_json)
        return kept
    if isinstance(obj, str):
        return normalize_string(obj)
    return obj
def keyset_f1(pred_obj: Any, gold_obj: Any) -> Dict[str, float]:
    """Precision/recall/F1 over flattened JSON key paths (values ignored).

    Empty-vs-empty key sets score precision/recall of 1.0 by convention.
    """
    from tmf921_train.utils import flatten_json

    predicted = set(flatten_json(pred_obj).keys())
    expected = set(flatten_json(gold_obj).keys())
    tp = len(predicted & expected)
    fp = len(predicted - expected)
    fn = len(expected - predicted)
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {
        "norm_key_precision": precision,
        "norm_key_recall": recall,
        "norm_key_f1": f1,
        "norm_key_tp": tp,
        "norm_key_fp": fp,
        "norm_key_fn": fn,
    }
def score_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Score one prediction row with normalized exact-match / field / key metrics.

    Copies *row* minus the raw "prediction"/"gold" text, records whether each
    side parses as JSON, and — when both parse — compares their normalized
    forms. On any parse failure all normalized metrics are zeroed.
    """
    pred_obj, _pred_err = parse_json(row.get("prediction", ""))
    gold_obj, _gold_err = parse_json(row.get("gold", ""))

    result = {k: v for k, v in row.items() if k not in ("prediction", "gold")}
    result["norm_parse_json"] = pred_obj is not None
    result["norm_gold_parse_json"] = gold_obj is not None
    result["norm_exact_match"] = False

    if pred_obj is None or gold_obj is None:
        # NOTE(review): this branch omits the norm_key_tp/fp/fn counters that
        # the success path adds via keyset_f1 — confirm aggregate_metrics
        # tolerates the ragged per-row schema.
        for metric in (
            "norm_field_precision", "norm_field_recall", "norm_field_f1",
            "norm_key_precision", "norm_key_recall", "norm_key_f1",
        ):
            result[metric] = 0.0
        return result

    pred_norm = normalize_json(pred_obj)
    gold_norm = normalize_json(gold_obj)
    result["norm_exact_match"] = canonical_json(pred_norm) == canonical_json(gold_norm)
    for metric, value in field_f1(pred_norm, gold_norm).items():
        result["norm_" + metric] = value
    result.update(keyset_f1(pred_norm, gold_norm))
    return result
def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate scored rows overall and broken down by grouping fields.

    Adds "by_target_layer", "by_slice_type" and "by_lifecycle_operation"
    sub-summaries, each keyed by the stringified group value (sorted).
    """
    summary = aggregate_metrics(rows)
    for group_key in ("target_layer", "slice_type", "lifecycle_operation"):
        buckets: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for row in rows:
            buckets[str(row.get(group_key))].append(row)
        summary[f"by_{group_key}"] = {
            name: aggregate_metrics(members)
            for name, members in sorted(buckets.items())
        }
    return summary
def main():
    """CLI entry point: re-score each split's predictions.json with normalized metrics.

    Writes <split>/normalized_predictions_scored.json and
    <split>/normalized_metrics.json per split, plus a combined
    all_normalized_metrics.json at the eval-dir root.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_dir", required=True, help="Directory containing split/predictions.json files, e.g. runs/.../eval_merged")
    parser.add_argument("--splits", nargs="*", default=None, help="Optional split names. Defaults to all subdirs with predictions.json")
    args = parser.parse_args()

    eval_dir = Path(args.eval_dir)
    if args.splits:
        splits = args.splits
    else:
        # Default: every subdirectory that already holds a predictions.json.
        splits = sorted(
            child.name
            for child in eval_dir.iterdir()
            if child.is_dir() and (child / "predictions.json").exists()
        )

    all_metrics = {}
    for split in splits:
        split_dir = eval_dir / split
        pred_path = split_dir / "predictions.json"
        if not pred_path.exists():
            print(f"Skipping {split}: no {pred_path}")
            continue
        scored = [score_row(row) for row in json.loads(pred_path.read_text())]
        write_json(split_dir / "normalized_predictions_scored.json", scored)
        metrics = summarize(scored)
        write_json(split_dir / "normalized_metrics.json", metrics)
        all_metrics[split] = metrics
        print(
            f"{split}: norm_exact={metrics.get('norm_exact_match', 0):.4f} "
            f"norm_field_f1={metrics.get('norm_field_f1', 0):.4f} "
            f"norm_key_f1={metrics.get('norm_key_f1', 0):.4f}"
        )

    write_json(eval_dir / "all_normalized_metrics.json", all_metrics)
    print(f"Wrote {eval_dir / 'all_normalized_metrics.json'}")


if __name__ == "__main__":
    main()