import json
import math
import re
import hashlib
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
import yaml
# Slice/Service Type (SST) identifiers keyed by service-category name.
# eMBB=1, URLLC=2, mMTC=3, V2X=4, HMTC=5 match the 3GPP-standardized SST
# values (TS 23.501 §5.15.2.2) — TODO confirm; MPS shares 5 with HMTC here,
# and "N/A" (adversarial/unknown examples) maps to the non-standard 0.
SST_MAP = {"eMBB": 1, "URLLC": 2, "mMTC": 3, "V2X": 4, "HMTC": 5, "MPS": 5, "N/A": 0}
# Plausible (min, max) bounds per slice type for each KPI field:
# latency_ms, reliability_pct, dl/ul_throughput_mbps, max_ues.
# NOTE(review): these look like heuristic sanity ranges per service class,
# not a normative 3GPP table — confirm against the dataset spec before
# using them as hard validators.
KPI_RANGES = {
    "eMBB": {"latency_ms": (10, 100), "reliability_pct": (99.0, 99.99), "dl_throughput_mbps": (50, 1000), "ul_throughput_mbps": (10, 200), "max_ues": (50, 50000)},
    "URLLC": {"latency_ms": (0.5, 10), "reliability_pct": (99.999, 99.9999), "dl_throughput_mbps": (10, 500), "ul_throughput_mbps": (5, 200), "max_ues": (5, 5000)},
    "mMTC": {"latency_ms": (10, 1000), "reliability_pct": (90.0, 99.99), "dl_throughput_mbps": (0.1, 50), "ul_throughput_mbps": (0.05, 25), "max_ues": (1000, 1000000)},
    "V2X": {"latency_ms": (1, 20), "reliability_pct": (99.99, 99.9999), "dl_throughput_mbps": (10, 300), "ul_throughput_mbps": (5, 150), "max_ues": (10, 10000)},
    "MPS": {"latency_ms": (1, 30), "reliability_pct": (99.999, 99.9999), "dl_throughput_mbps": (10, 200), "ul_throughput_mbps": (5, 100), "max_ues": (5, 5000)},
    "HMTC": {"latency_ms": (0.1, 5), "reliability_pct": (99.9999, 99.99999), "dl_throughput_mbps": (100, 2000), "ul_throughput_mbps": (50, 1000), "max_ues": (100, 100000)},
}
def load_config(path: str | Path) -> Dict[str, Any]:
    """Read the YAML configuration file at *path* and return its parsed contents."""
    config_path = Path(path)
    with config_path.open("r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)
def write_json(path: str | Path, obj: Any) -> None:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
json.dump(obj, f, indent=2, ensure_ascii=False)
def get_message(messages: List[Dict[str, str]], role: str) -> str:
    """Return the content of the first message with *role*, or "" if none matches."""
    for message in messages:
        if message.get("role") == role:
            return message.get("content", "")
    return ""
def strip_code_fence(text: str) -> str:
    """Strip a wrapping Markdown code fence (``` or ```json, any case) from *text*."""
    trimmed = text.strip()
    if not trimmed.startswith("```"):
        return trimmed
    opened = re.sub(r"^```(?:json)?\s*", "", trimmed, flags=re.IGNORECASE)
    closed = re.sub(r"\s*```$", "", opened)
    return closed.strip()
def extract_json_text(text: str) -> str:
    """Isolate the outermost {...} span after removing any code fence.

    Generations may wrap the JSON payload in role markers or extra prose; the
    first "{" through the last "}" is kept as the candidate object. Falls back
    to the fence-stripped text when no such span exists.
    """
    text = strip_code_fence(text)
    if not text:
        return text
    first_brace = text.find("{")
    last_brace = text.rfind("}")
    if 0 <= first_brace < last_brace:
        return text[first_brace : last_brace + 1].strip()
    return text
def parse_json(text: str) -> Tuple[Optional[Any], Optional[str]]:
    """Parse model output as JSON.

    Returns (parsed_object, None) on success, or (None, "ErrType: message")
    with the message truncated to 200 characters on failure.
    """
    candidate = extract_json_text(text)
    try:
        parsed = json.loads(candidate)
    except Exception as e:
        return None, f"{type(e).__name__}: {str(e)[:200]}"
    return parsed, None
def canonical_json(obj: Any) -> str:
    """Deterministic compact JSON encoding (sorted keys, no whitespace) for comparisons."""
    encoded = json.dumps(
        obj,
        sort_keys=True,
        ensure_ascii=False,
        separators=(",", ":"),
    )
    return encoded
def flatten_json(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/lists into a map of dotted/bracketed paths to leaves.

    Dict keys join with ".", list positions render as "[i]". Empty dicts and
    lists contribute no entries; a bare scalar maps from the prefix ("" at top
    level).
    """
    if isinstance(obj, dict):
        flat: Dict[str, Any] = {}
        for key, value in obj.items():
            child_path = f"{prefix}.{key}" if prefix else str(key)
            flat.update(flatten_json(value, child_path))
        return flat
    if isinstance(obj, list):
        flat = {}
        for index, item in enumerate(obj):
            flat.update(flatten_json(item, f"{prefix}[{index}]"))
        return flat
    return {prefix: obj}
def json_exact_match(pred_obj: Any, gold_obj: Any) -> bool:
    """True when both objects render to identical canonical JSON (sorted keys, compact)."""
    def _canon(value: Any) -> str:
        return json.dumps(value, sort_keys=True, ensure_ascii=False, separators=(",", ":"))
    return _canon(pred_obj) == _canon(gold_obj)
def field_f1(pred_obj: Any, gold_obj: Any) -> Dict[str, float]:
    """Precision/recall/F1 over flattened (path, value) pairs of pred vs gold.

    Values are compared as strings (container values would use canonical JSON,
    though flatten_json only yields leaves). An empty side counts as perfect on
    its axis (vacuous precision/recall = 1.0).
    """
    def _pairs(obj: Any) -> set:
        items = set()
        for path, value in flatten_json(obj).items():
            rendered = canonical_json(value) if isinstance(value, (dict, list)) else str(value)
            items.add((path, rendered))
        return items

    pred_items = _pairs(pred_obj)
    gold_items = _pairs(gold_obj)
    tp = len(pred_items & gold_items)
    fp = len(pred_items - gold_items)
    fn = len(gold_items - pred_items)
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"field_precision": precision, "field_recall": recall, "field_f1": f1, "field_tp": tp, "field_fp": fp, "field_fn": fn}
def metadata_constraint_pass(example: Dict[str, Any], pred_text: str, pred_obj: Any) -> Dict[str, bool]:
    """Conservative presence diagnostics of example metadata inside the prediction.

    Searches the canonical JSON rendering of *pred_obj* (or the raw *pred_text*
    when parsing failed) for the example's SST, categorical fields, and numeric
    KPI values. For adversarial / "N/A" examples only the refusal-status check
    is meaningful; the other two flags are forced True.

    Returns a dict with three booleans: "slice_sst_pass",
    "kpi_text_presence_pass", "adversarial_status_pass".
    """
    if pred_obj is not None:
        # Same rendering as canonical_json: compact, sorted keys.
        text = json.dumps(pred_obj, sort_keys=True, ensure_ascii=False, separators=(",", ":"))
    else:
        text = pred_text
    sl = example.get("slice_type")
    target_layer = str(example.get("target_layer", ""))
    if target_layer.startswith("adversarial") or sl == "N/A":
        return {
            "slice_sst_pass": True,
            "kpi_text_presence_pass": True,
            "adversarial_status_pass": any(s in text for s in ["CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"]),
        }
    # Bug fix: int(example.get("sst")) raised TypeError when "sst" was absent
    # (it was only saved by short-circuit when str(None) happened to match).
    sst = example.get("sst")
    sst_pass = False
    if sst is not None:
        try:
            sst_pass = str(sst) in text or f'"sst":{int(sst)}' in text.replace(" ", "")
        except (TypeError, ValueError):
            # Non-integer sst metadata: fall back to plain substring presence.
            sst_pass = str(sst) in text
    # Basic exact value presence. Some layers encode values indirectly, so this
    # is a conservative diagnostic, not a hard schema validator.
    vals = [example.get("slice_type"), example.get("sd"), example.get("use_case"), example.get("region")]
    numeric = [example.get("latency_ms"), example.get("dl_throughput_mbps"), example.get("ul_throughput_mbps"), example.get("max_ues")]
    present = 0
    total = 0
    for v in vals:
        if v is None:
            continue
        total += 1
        if str(v) in text:
            present += 1
    for v in numeric:
        if v is None:
            continue
        total += 1
        candidates = {str(v)}
        try:
            fv = float(v)
        except (TypeError, ValueError):
            # Bug fix: non-numeric metadata used to raise instead of being
            # checked as plain text.
            fv = None
        if fv is not None:
            # Also match the integer spelling ("20" for 20.0) emitted by JSON.
            candidates.add(str(int(fv)) if fv.is_integer() else str(fv))
        if any(c in text for c in candidates):
            present += 1
    return {
        "slice_sst_pass": bool(sst_pass),
        # At least half of the known metadata values must appear verbatim.
        "kpi_text_presence_pass": (present / total >= 0.5) if total else True,
        "adversarial_status_pass": True,
    }
def aggregate_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Mean of every numeric/boolean metric across *rows*, plus "num_examples".

    Bug fix: metric keys are collected from every row, not just rows[0], so a
    metric absent from the first row is no longer silently dropped (the per-key
    averaging already tolerated keys missing from other rows). Non-numeric
    values for a key are skipped rather than crashing float().

    Returns {} for an empty input.
    """
    if not rows:
        return {}
    # Ordered union of keys that carry a numeric/bool value in at least one row.
    keys: List[str] = []
    seen: set = set()
    for row in rows:
        for k, v in row.items():
            if k not in seen and isinstance(v, (int, float, bool)):
                seen.add(k)
                keys.append(k)
    agg: Dict[str, Any] = {"num_examples": len(rows)}
    for k in keys:
        vals = [float(r[k]) for r in rows if isinstance(r.get(k), (int, float, bool))]
        if vals:
            agg[k] = sum(vals) / len(vals)
    return agg
def short_hash(text: str, n: int = 12) -> str:
    """First *n* hex characters of the SHA-256 of *text* (unencodable chars dropped)."""
    data = text.encode("utf-8", errors="ignore")
    digest = hashlib.sha256(data).hexdigest()
    return digest[:n]