PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
File size: 6,340 Bytes
d9ba941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import json
import math
import re
import hashlib
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import yaml

# Slice-type label -> 3GPP Slice/Service Type (SST) numeric value.
# NOTE(review): MPS and HMTC both map to 5 here, and "N/A" maps to 0 —
# presumably a dataset convention rather than the TS 23.501 registry; confirm.
SST_MAP = {"eMBB": 1, "URLLC": 2, "mMTC": 3, "V2X": 4, "HMTC": 5, "MPS": 5, "N/A": 0}
# Per slice type: plausible (min, max) ranges for key KPIs.
# Keys: latency_ms, reliability_pct, dl/ul throughput in Mbps, max UE count.
# Used as a diagnostic reference for generated intents, not a hard validator.
KPI_RANGES = {
    "eMBB": {"latency_ms": (10, 100), "reliability_pct": (99.0, 99.99), "dl_throughput_mbps": (50, 1000), "ul_throughput_mbps": (10, 200), "max_ues": (50, 50000)},
    "URLLC": {"latency_ms": (0.5, 10), "reliability_pct": (99.999, 99.9999), "dl_throughput_mbps": (10, 500), "ul_throughput_mbps": (5, 200), "max_ues": (5, 5000)},
    "mMTC": {"latency_ms": (10, 1000), "reliability_pct": (90.0, 99.99), "dl_throughput_mbps": (0.1, 50), "ul_throughput_mbps": (0.05, 25), "max_ues": (1000, 1000000)},
    "V2X": {"latency_ms": (1, 20), "reliability_pct": (99.99, 99.9999), "dl_throughput_mbps": (10, 300), "ul_throughput_mbps": (5, 150), "max_ues": (10, 10000)},
    "MPS": {"latency_ms": (1, 30), "reliability_pct": (99.999, 99.9999), "dl_throughput_mbps": (10, 200), "ul_throughput_mbps": (5, 100), "max_ues": (5, 5000)},
    "HMTC": {"latency_ms": (0.1, 5), "reliability_pct": (99.9999, 99.99999), "dl_throughput_mbps": (100, 2000), "ul_throughput_mbps": (50, 1000), "max_ues": (100, 100000)},
}


def load_config(path: str | Path) -> Dict[str, Any]:
    """Read a YAML file at *path* and return its parsed content."""
    with Path(path).open("r", encoding="utf-8") as handle:
        return yaml.safe_load(handle)


def write_json(path: str | Path, obj: Any) -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)


def get_message(messages: List[Dict[str, str]], role: str) -> str:
    """Return the content of the first message matching *role*, or ""."""
    for message in messages:
        if message.get("role") == role:
            return message.get("content", "")
    return ""


def strip_code_fence(text: str) -> str:
    """Strip a surrounding markdown ``` fence (optionally tagged ```json)."""
    trimmed = text.strip()
    if not trimmed.startswith("```"):
        return trimmed
    trimmed = re.sub(r"^```(?:json)?\s*", "", trimmed, flags=re.IGNORECASE)
    trimmed = re.sub(r"\s*```$", "", trimmed)
    return trimmed.strip()


def extract_json_text(text: str) -> str:
    """Extract the outermost {...} span from *text* after de-fencing.

    If generation includes role markers or surrounding prose, this keeps
    only the first complete-looking object; otherwise the text is returned
    unchanged.
    """
    cleaned = strip_code_fence(text)
    if not cleaned:
        return cleaned
    lo = cleaned.find("{")
    hi = cleaned.rfind("}")
    if 0 <= lo < hi:
        return cleaned[lo : hi + 1].strip()
    return cleaned


def parse_json(text: str) -> Tuple[Optional[Any], Optional[str]]:
    """Try to parse JSON from *text*; return (object, None) or (None, error)."""
    snippet = extract_json_text(text)
    try:
        parsed = json.loads(snippet)
    except Exception as exc:
        return None, f"{type(exc).__name__}: {str(exc)[:200]}"
    return parsed, None


def canonical_json(obj: Any) -> str:
    """Serialize *obj* with sorted keys and no whitespace for stable equality checks."""
    return json.dumps(
        obj,
        sort_keys=True,
        ensure_ascii=False,
        separators=(",", ":"),
    )


def flatten_json(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/lists into a {"a.b[0]": leaf} path->value map.

    Dict keys join with "."; list positions render as "[i]". Scalars map
    from the accumulated path (empty string for a bare scalar input).
    """
    flat: Dict[str, Any] = {}
    if isinstance(obj, dict):
        for key, value in obj.items():
            child = f"{prefix}.{key}" if prefix else str(key)
            flat.update(flatten_json(value, child))
    elif isinstance(obj, list):
        for idx, value in enumerate(obj):
            flat.update(flatten_json(value, f"{prefix}[{idx}]"))
    else:
        flat[prefix] = obj
    return flat


def json_exact_match(pred_obj: Any, gold_obj: Any) -> bool:
    """True when both objects serialize to the same canonical JSON string."""
    left = canonical_json(pred_obj)
    right = canonical_json(gold_obj)
    return left == right


def field_f1(pred_obj: Any, gold_obj: Any) -> Dict[str, float]:
    """Micro precision/recall/F1 over flattened (path, value) pairs.

    Values are compared as strings (canonical JSON for containers).
    Empty sides default precision/recall to 1.0, matching the original
    convention that "nothing predicted, nothing expected" is perfect.
    """

    def as_pairs(obj: Any) -> set:
        # One (path, rendered-value) tuple per flattened leaf.
        pairs = set()
        for key, val in flatten_json(obj).items():
            rendered = canonical_json(val) if isinstance(val, (dict, list)) else str(val)
            pairs.add((key, rendered))
        return pairs

    pred_pairs = as_pairs(pred_obj)
    gold_pairs = as_pairs(gold_obj)
    tp = len(pred_pairs & gold_pairs)
    fp = len(pred_pairs - gold_pairs)
    fn = len(gold_pairs - pred_pairs)
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {
        "field_precision": precision,
        "field_recall": recall,
        "field_f1": f1,
        "field_tp": tp,
        "field_fp": fp,
        "field_fn": fn,
    }


def metadata_constraint_pass(example: Dict[str, Any], pred_text: str, pred_obj: Any) -> Dict[str, bool]:
    """Diagnostic checks that the prediction text mentions the example's metadata.

    Args:
        example: metadata record with optional keys like slice_type, sst, sd,
            target_layer, and KPI fields (latency_ms, throughputs, max_ues).
        pred_text: raw model output, used when ``pred_obj`` is None.
        pred_obj: parsed prediction; if present, its canonical JSON is searched.

    Returns:
        Dict of three booleans: slice_sst_pass, kpi_text_presence_pass,
        adversarial_status_pass. For adversarial or N/A examples only the
        status-token check applies; otherwise this is a conservative
        value-presence diagnostic, not a hard schema validator.
    """
    text = canonical_json(pred_obj) if pred_obj is not None else pred_text
    sl = example.get("slice_type")
    target_layer = str(example.get("target_layer", ""))
    if target_layer.startswith("adversarial") or sl == "N/A":
        # Adversarial/N-A cases only need one of the rejection status tokens.
        return {
            "slice_sst_pass": True,
            "kpi_text_presence_pass": True,
            "adversarial_status_pass": any(s in text for s in ["CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"]),
        }
    # BUGFIX: the original called int(example.get("sst")) unconditionally on
    # the short-circuit miss, raising TypeError when "sst" is absent.
    sst_val = example.get("sst")
    sst_pass = str(sst_val) in text
    if not sst_pass and sst_val is not None:
        try:
            sst_pass = f'"sst":{int(sst_val)}' in text.replace(" ", "")
        except (TypeError, ValueError):
            sst_pass = False
    # Basic exact value presence. Some layers encode values indirectly, so this
    # is a conservative diagnostic not a hard schema validator.
    vals = [example.get("slice_type"), example.get("sd"), example.get("use_case"), example.get("region")]
    numeric = [example.get("latency_ms"), example.get("dl_throughput_mbps"), example.get("ul_throughput_mbps"), example.get("max_ues")]
    present = 0
    total = 0
    for v in vals:
        if v is None:
            continue
        total += 1
        if str(v) in text:
            present += 1
    for v in numeric:
        if v is None:
            continue
        total += 1
        # Accept either the raw rendering or the integer rendering of whole
        # floats (e.g. 5.0 matches "5"). BUGFIX: non-numeric values no longer
        # raise ValueError from float(); they fall back to string matching.
        try:
            fv = float(v)
            candidates = {str(v), str(int(fv)) if fv.is_integer() else str(fv)}
        except (TypeError, ValueError):
            candidates = {str(v)}
        if any(c in text for c in candidates):
            present += 1
    return {
        "slice_sst_pass": bool(sst_pass),
        # At least half of the provided metadata values must appear verbatim.
        "kpi_text_presence_pass": (present / total >= 0.5) if total else True,
        "adversarial_status_pass": True,
    }


def aggregate_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Average every numeric/bool metric keyed off the first row.

    Returns {} for an empty input; otherwise {"num_examples": N} plus the
    mean of each metric (bools averaged as 0/1), skipping None entries.
    """
    if not rows:
        return {}
    metric_keys = [key for key, value in rows[0].items() if isinstance(value, (int, float, bool))]
    summary: Dict[str, Any] = {"num_examples": len(rows)}
    for key in metric_keys:
        observed = [float(row[key]) for row in rows if key in row and row[key] is not None]
        if observed:
            summary[key] = sum(observed) / len(observed)
    return summary


def short_hash(text: str, n: int = 12) -> str:
    """Return the first *n* hex characters of SHA-256(text), ignoring encode errors."""
    digest = hashlib.sha256(text.encode("utf-8", errors="ignore")).hexdigest()
    return digest[:n]