#!/usr/bin/env python3
"""Re-score existing generation predictions with normalized JSON metrics.
This does NOT regenerate model outputs. It reads eval/<split>/predictions.json and writes:
- eval/<split>/normalized_metrics.json
- eval/all_normalized_metrics.json
Normalization removes volatile/generated fields (ids, hrefs, descriptions, timestamps, schema links),
sorts object lists deterministically, and computes normalized exact match + normalized field F1.
Use this alongside raw metrics: raw metrics show deterministic reproduction; normalized metrics better
estimate structural/semantic config agreement.
"""
import argparse
import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List
from tmf921_train.utils import aggregate_metrics, canonical_json, field_f1, parse_json, write_json
# Keys stripped from every object during normalization: generated identifiers,
# free-text labels, timestamps, and schema links.  Membership here is exact;
# is_volatile_key() additionally applies a case-insensitive pass and fragment
# matching on top of this set.
VOLATILE_KEY_EXACT = {
    "id", "uuid", "href", "name", "description", "displayName", "label",
    "@schemaLocation", "schemaLocation", "version", "revision",
    "createdAt", "updatedAt", "modifiedAt", "lastModified", "timestamp",
    "creationDate", "lastUpdate", "requestedStartDate", "requestedCompletionDate",
    "startTime", "endTime", "validFrom", "validTo", "validFor",
    "correlationId", "requestId", "transactionId", "reservationId",
}

# Substrings that mark a key as volatile (checked case-insensitively).
VOLATILE_KEY_FRAGMENTS = ["href", "schema", "timestamp", "uuid", "correlation", "transaction"]

# Keys that must never be stripped even if they look volatile: they carry the
# slice-configuration semantics the normalized comparison is meant to measure.
PROTECTED_KEYS = {"sst", "sd", "sliceType", "slice_type", "latency", "reliability", "dl", "ul", "maxUEs", "maxNumberOfUEs"}

# Generated identifiers such as "intent-abc123" or "slice_42".
ID_LIKE_RE = re.compile(r"\b(?:intent|slice|policy|booking|cell|me|gnb|nsi|nssi|req|report|monitor|assurance)[-_][A-Za-z0-9._:-]+", re.IGNORECASE)
# Long hexadecimal runs (hashes, raw UUID fragments).
HEX_RE = re.compile(r"\b[0-9a-f]{8,}\b", re.IGNORECASE)
# ISO-8601-ish timestamps.  Bug fix: the original class "[0-9:.+-Z]" contained
# the accidental range "+-Z" (codepoints 0x2B..0x5A), which also matched
# letters A-Z and punctuation like "@", "?", ";".  The "-" now sits at the end
# of the class so only the intended characters (digits, ":", ".", "+", "-",
# "Z") are matched.
ISO_TIME_RE = re.compile(r"\b\d{4}-\d{2}-\d{2}[T ][0-9:.+Z-]*\b")
def is_volatile_key(key: str) -> bool:
    """Return True if *key* should be stripped during normalization.

    Precedence:
      1. Protected keys (exact, case-sensitive match) are always kept.
      2. Keys matching VOLATILE_KEY_EXACT case-insensitively are dropped.
      3. Keys containing any VOLATILE_KEY_FRAGMENTS substring
         (case-insensitively) are dropped.
    """
    if key in PROTECTED_KEYS:
        return False
    lk = key.lower()
    # A single case-insensitive membership test suffices: the original extra
    # exact check (key in VOLATILE_KEY_EXACT) was redundant, since membership
    # in the set implies membership in its lowercased image.
    if lk in {k.lower() for k in VOLATILE_KEY_EXACT}:
        return True
    return any(fragment in lk for fragment in VOLATILE_KEY_FRAGMENTS)
def normalize_string(s: str) -> str:
    """Replace volatile substrings (timestamps, ids, hex runs) with stable
    placeholder tokens and strip surrounding whitespace."""
    # Order matters: timestamps first, then id-like tokens, then hex runs.
    for pattern, placeholder in ((ISO_TIME_RE, "<TIME>"), (ID_LIKE_RE, "<ID>"), (HEX_RE, "<HEX>")):
        s = pattern.sub(placeholder, s)
    return s.strip()
def normalize_json(obj: Any) -> Any:
    """Recursively normalize a parsed JSON value for comparison.

    Drops volatile keys and empty values ({} / [] / None), sorts dict keys
    and list elements deterministically, and scrubs volatile substrings from
    strings via normalize_string. Scalars pass through unchanged.
    """
    if isinstance(obj, dict):
        kept: Dict[str, Any] = {}
        for key, value in obj.items():
            name = str(key)
            if is_volatile_key(name):
                continue
            normalized = normalize_json(value)
            if normalized is None or normalized == {} or normalized == []:
                continue
            kept[name] = normalized
        # Rebuild with deterministic key order regardless of insertion order.
        return {k: kept[k] for k in sorted(kept)}
    if isinstance(obj, list):
        elements = []
        for item in obj:
            candidate = normalize_json(item)
            if candidate not in ({}, [], None):
                elements.append(candidate)
        # Sort by canonical serialization so permuted lists compare equal.
        elements.sort(key=canonical_json)
        return elements
    if isinstance(obj, str):
        return normalize_string(obj)
    return obj
def keyset_f1(pred_obj: Any, gold_obj: Any) -> Dict[str, float]:
    """Precision/recall/F1 over the sets of flattened JSON key paths.

    Convention: an empty predicted (resp. gold) key set yields precision
    (resp. recall) of 1.0 rather than dividing by zero.
    """
    from tmf921_train.utils import flatten_json
    predicted = set(flatten_json(pred_obj))
    expected = set(flatten_json(gold_obj))
    tp = len(predicted & expected)
    fp = len(predicted) - tp
    fn = len(expected) - tp
    precision = tp / len(predicted) if predicted else 1.0
    recall = tp / len(expected) if expected else 1.0
    denominator = precision + recall
    f1 = (2 * precision * recall / denominator) if denominator else 0.0
    return {
        "norm_key_precision": precision,
        "norm_key_recall": recall,
        "norm_key_f1": f1,
        "norm_key_tp": tp,
        "norm_key_fp": fp,
        "norm_key_fn": fn,
    }
def score_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Score one prediction row with normalized metrics.

    Returns a copy of *row* without the bulky "prediction"/"gold" texts, with
    norm_* parse flags, a normalized exact-match flag, and field/key F1
    metrics added. If either side fails to parse as JSON, all scores are 0.
    """
    prediction, _ = parse_json(row.get("prediction", ""))
    gold, _ = parse_json(row.get("gold", ""))
    result = {k: v for k, v in row.items() if k not in ("prediction", "gold")}
    result["norm_parse_json"] = prediction is not None
    result["norm_gold_parse_json"] = gold is not None
    result["norm_exact_match"] = False
    if prediction is None or gold is None:
        # Unparseable on either side: zero every normalized score.
        result.update({
            "norm_field_precision": 0.0, "norm_field_recall": 0.0, "norm_field_f1": 0.0,
            "norm_key_precision": 0.0, "norm_key_recall": 0.0, "norm_key_f1": 0.0,
        })
        return result
    pred_norm = normalize_json(prediction)
    gold_norm = normalize_json(gold)
    result["norm_exact_match"] = canonical_json(pred_norm) == canonical_json(gold_norm)
    for metric, value in field_f1(pred_norm, gold_norm).items():
        result[f"norm_{metric}"] = value
    result.update(keyset_f1(pred_norm, gold_norm))
    return result
def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate scored rows overall and broken down by grouping columns."""
    summary = aggregate_metrics(rows)
    for group_key in ("target_layer", "slice_type", "lifecycle_operation"):
        buckets: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for row in rows:
            buckets[str(row.get(group_key))].append(row)
        # Sorted group names keep the JSON output deterministic.
        summary[f"by_{group_key}"] = {
            name: aggregate_metrics(members) for name, members in sorted(buckets.items())
        }
    return summary
def main():
    """CLI entry point: re-score each split's predictions.json and write
    per-split and aggregate normalized metric files under --eval_dir."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_dir", required=True, help="Directory containing split/predictions.json files, e.g. runs/.../eval_merged")
    parser.add_argument("--splits", nargs="*", default=None, help="Optional split names. Defaults to all subdirs with predictions.json")
    args = parser.parse_args()
    eval_dir = Path(args.eval_dir)
    # Default to every subdirectory that already holds a predictions.json.
    split_names = args.splits or sorted(
        p.name for p in eval_dir.iterdir() if p.is_dir() and (p / "predictions.json").exists()
    )
    all_metrics = {}
    for split in split_names:
        split_dir = eval_dir / split
        pred_path = split_dir / "predictions.json"
        if not pred_path.exists():
            print(f"Skipping {split}: no {pred_path}")
            continue
        rows = json.loads(pred_path.read_text())
        scored = [score_row(row) for row in rows]
        write_json(split_dir / "normalized_predictions_scored.json", scored)
        metrics = summarize(scored)
        write_json(split_dir / "normalized_metrics.json", metrics)
        all_metrics[split] = metrics
        print(f"{split}: norm_exact={metrics.get('norm_exact_match', 0):.4f} norm_field_f1={metrics.get('norm_field_f1', 0):.4f} norm_key_f1={metrics.get('norm_key_f1', 0):.4f}")
    write_json(eval_dir / "all_normalized_metrics.json", all_metrics)
    print(f"Wrote {eval_dir / 'all_normalized_metrics.json'}")


if __name__ == "__main__":
    main()