PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
File size: 6,377 Bytes
63d52bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""Re-score existing generation predictions with normalized JSON metrics.

This does NOT regenerate model outputs. It reads eval/<split>/predictions.json and writes:
- eval/<split>/normalized_metrics.json
- eval/all_normalized_metrics.json

Normalization removes volatile/generated fields (ids, hrefs, descriptions, timestamps, schema links),
sorts object lists deterministically, and computes normalized exact match + normalized field F1.
Use this alongside raw metrics: raw metrics show deterministic reproduction; normalized metrics better
estimate structural/semantic config agreement.
"""
import argparse
import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List

from tmf921_train.utils import aggregate_metrics, canonical_json, field_f1, parse_json, write_json

# Keys whose values are volatile (ids, links, timestamps) and must be dropped
# before comparing predicted vs. gold JSON structures.
VOLATILE_KEY_EXACT = {
    "id", "uuid", "href", "name", "description", "displayName", "label",
    "@schemaLocation", "schemaLocation", "version", "revision",
    "createdAt", "updatedAt", "modifiedAt", "lastModified", "timestamp",
    "creationDate", "lastUpdate", "requestedStartDate", "requestedCompletionDate",
    "startTime", "endTime", "validFrom", "validTo", "validFor",
    "correlationId", "requestId", "transactionId", "reservationId",
}
# Substrings that mark a key as volatile (checked case-insensitively).
VOLATILE_KEY_FRAGMENTS = ["href", "schema", "timestamp", "uuid", "correlation", "transaction"]
# Keys that carry real slice-config semantics and must never be dropped,
# even if they match a volatile fragment.
PROTECTED_KEYS = {"sst", "sd", "sliceType", "slice_type", "latency", "reliability", "dl", "ul", "maxUEs", "maxNumberOfUEs"}

# Domain-prefixed identifiers like "slice-abc123" or "intent_42".
ID_LIKE_RE = re.compile(r"\b(?:intent|slice|policy|booking|cell|me|gnb|nsi|nssi|req|report|monitor|assurance)[-_][A-Za-z0-9._:-]+", re.IGNORECASE)
# Long hex blobs (8+ chars) typical of generated ids/hashes.
HEX_RE = re.compile(r"\b[0-9a-f]{8,}\b", re.IGNORECASE)
# ISO-8601-ish timestamps: date, then "T" or space, then time/offset chars.
# BUGFIX: the class was [0-9:.+-Z], where "+-Z" is a character RANGE
# (0x2B-0x5A) that accidentally matched uppercase letters and punctuation,
# swallowing arbitrary text after a date. "-" now sits last so it is literal.
ISO_TIME_RE = re.compile(r"\b\d{4}-\d{2}-\d{2}[T ][0-9:.+Z-]*\b")


def is_volatile_key(key: str) -> bool:
    """Return True if *key* should be stripped during normalization.

    Protected keys always win; otherwise a key is volatile if it matches
    VOLATILE_KEY_EXACT (case-insensitively) or contains any fragment from
    VOLATILE_KEY_FRAGMENTS.

    PERF: the original rebuilt ``{k.lower() for k in VOLATILE_KEY_EXACT}`` on
    every call, and this runs once per key per row. The lowered set is now
    computed lazily once and cached on the function object. The exact-case
    membership test was dropped because the lowercase test subsumes it.
    """
    if key in PROTECTED_KEYS:
        return False
    lowered = getattr(is_volatile_key, "_lowered_exact", None)
    if lowered is None:
        lowered = frozenset(k.lower() for k in VOLATILE_KEY_EXACT)
        is_volatile_key._lowered_exact = lowered  # cache for subsequent calls
    lk = key.lower()
    if lk in lowered:
        return True
    return any(fragment in lk for fragment in VOLATILE_KEY_FRAGMENTS)


def normalize_string(s: str) -> str:
    """Mask volatile substrings (timestamps, prefixed ids, hex blobs) in *s*.

    Substitution order is significant: timestamps are masked first so their
    digits cannot be partially consumed by the id/hex patterns.
    """
    masked = ISO_TIME_RE.sub("<TIME>", s)
    masked = ID_LIKE_RE.sub("<ID>", masked)
    masked = HEX_RE.sub("<HEX>", masked)
    return masked.strip()


def normalize_json(obj: Any) -> Any:
    """Recursively normalize a parsed JSON value for structural comparison.

    Volatile keys are dropped, empty values ({} / [] / None) are pruned,
    strings are masked via normalize_string, dict keys are sorted, and
    lists are sorted by their canonical JSON rendering so ordering is
    deterministic.
    """
    if isinstance(obj, str):
        return normalize_string(obj)
    if isinstance(obj, dict):
        kept: Dict[str, Any] = {}
        for key, value in obj.items():
            name = str(key)
            if is_volatile_key(name):
                continue
            normalized = normalize_json(value)
            # Prune values that became empty after normalization.
            if normalized is None or normalized == {} or normalized == []:
                continue
            kept[name] = normalized
        return {name: kept[name] for name in sorted(kept)}
    if isinstance(obj, list):
        normalized_items = []
        for element in obj:
            normalized = normalize_json(element)
            if normalized is None or normalized == {} or normalized == []:
                continue
            normalized_items.append(normalized)
        normalized_items.sort(key=canonical_json)
        return normalized_items
    # Numbers / booleans pass through unchanged.
    return obj


def keyset_f1(pred_obj: Any, gold_obj: Any) -> Dict[str, float]:
    """Compute precision/recall/F1 over the sets of flattened JSON key paths.

    Both empty key sets count as a perfect match (precision = recall = 1.0).
    """
    from tmf921_train.utils import flatten_json

    predicted = set(flatten_json(pred_obj))
    expected = set(flatten_json(gold_obj))
    tp = len(predicted & expected)
    fp = len(predicted) - tp
    fn = len(expected) - tp
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {
        "norm_key_precision": precision,
        "norm_key_recall": recall,
        "norm_key_f1": f1,
        "norm_key_tp": tp,
        "norm_key_fp": fp,
        "norm_key_fn": fn,
    }


def score_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Score one prediction row with normalized metrics.

    Returns a copy of *row* with the raw 'prediction'/'gold' texts removed
    and norm_* metric fields added. If either side fails to parse as JSON,
    all F1-style metrics are zeroed.
    """
    pred_obj, _ = parse_json(row.get("prediction", ""))
    gold_obj, _ = parse_json(row.get("gold", ""))

    scored = {k: v for k, v in row.items() if k not in ("prediction", "gold")}
    scored["norm_parse_json"] = pred_obj is not None
    scored["norm_gold_parse_json"] = gold_obj is not None
    scored["norm_exact_match"] = False

    if pred_obj is None or gold_obj is None:
        for metric in ("norm_field_precision", "norm_field_recall", "norm_field_f1",
                       "norm_key_precision", "norm_key_recall", "norm_key_f1"):
            scored[metric] = 0.0
        return scored

    pred_norm = normalize_json(pred_obj)
    gold_norm = normalize_json(gold_obj)
    scored["norm_exact_match"] = canonical_json(pred_norm) == canonical_json(gold_norm)
    for metric, value in field_f1(pred_norm, gold_norm).items():
        scored[f"norm_{metric}"] = value
    scored.update(keyset_f1(pred_norm, gold_norm))
    return scored


def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate metrics over all rows plus per-group breakdowns.

    Adds one "by_<key>" section per grouping key, with group names sorted
    for deterministic output.
    """
    result = aggregate_metrics(rows)
    for group_key in ("target_layer", "slice_type", "lifecycle_operation"):
        buckets: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for row in rows:
            buckets[str(row.get(group_key))].append(row)
        result[f"by_{group_key}"] = {
            name: aggregate_metrics(members)
            for name, members in sorted(buckets.items())
        }
    return result


def main():
    """CLI entry point: re-score each split's predictions.json with normalized metrics.

    Writes <split>/normalized_predictions_scored.json and
    <split>/normalized_metrics.json per split, plus an aggregated
    all_normalized_metrics.json at the eval_dir root.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_dir", required=True, help="Directory containing split/predictions.json files, e.g. runs/.../eval_merged")
    parser.add_argument("--splits", nargs="*", default=None, help="Optional split names. Defaults to all subdirs with predictions.json")
    opts = parser.parse_args()

    eval_dir = Path(opts.eval_dir)
    if opts.splits:
        split_names = opts.splits
    else:
        # Auto-discover: every subdirectory that already holds predictions.
        split_names = sorted(
            child.name
            for child in eval_dir.iterdir()
            if child.is_dir() and (child / "predictions.json").exists()
        )

    all_metrics: Dict[str, Any] = {}
    for split in split_names:
        split_dir = eval_dir / split
        pred_path = split_dir / "predictions.json"
        if not pred_path.exists():
            print(f"Skipping {split}: no {pred_path}")
            continue
        rows = json.loads(pred_path.read_text())
        scored = [score_row(row) for row in rows]
        write_json(split_dir / "normalized_predictions_scored.json", scored)
        metrics = summarize(scored)
        write_json(split_dir / "normalized_metrics.json", metrics)
        all_metrics[split] = metrics
        print(f"{split}: norm_exact={metrics.get('norm_exact_match', 0):.4f} norm_field_f1={metrics.get('norm_field_f1', 0):.4f} norm_key_f1={metrics.get('norm_key_f1', 0):.4f}")

    write_json(eval_dir / "all_normalized_metrics.json", all_metrics)
    print(f"Wrote {eval_dir / 'all_normalized_metrics.json'}")


# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()