import json
import math
import re
import hashlib
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import yaml

# Slice/Service Type codes keyed by slice-type label.
# NOTE(review): "MPS" and "HMTC" both map to 5 here — confirm this collision
# is intentional (3GPP TS 23.501 standardizes SST 5 as HMTC).
SST_MAP = {"eMBB": 1, "URLLC": 2, "mMTC": 3, "V2X": 4, "HMTC": 5, "MPS": 5, "N/A": 0}

# Plausible (min, max) KPI envelopes per slice type, used as diagnostic bounds.
KPI_RANGES = {
    "eMBB": {"latency_ms": (10, 100), "reliability_pct": (99.0, 99.99), "dl_throughput_mbps": (50, 1000), "ul_throughput_mbps": (10, 200), "max_ues": (50, 50000)},
    "URLLC": {"latency_ms": (0.5, 10), "reliability_pct": (99.999, 99.9999), "dl_throughput_mbps": (10, 500), "ul_throughput_mbps": (5, 200), "max_ues": (5, 5000)},
    "mMTC": {"latency_ms": (10, 1000), "reliability_pct": (90.0, 99.99), "dl_throughput_mbps": (0.1, 50), "ul_throughput_mbps": (0.05, 25), "max_ues": (1000, 1000000)},
    "V2X": {"latency_ms": (1, 20), "reliability_pct": (99.99, 99.9999), "dl_throughput_mbps": (10, 300), "ul_throughput_mbps": (5, 150), "max_ues": (10, 10000)},
    "MPS": {"latency_ms": (1, 30), "reliability_pct": (99.999, 99.9999), "dl_throughput_mbps": (10, 200), "ul_throughput_mbps": (5, 100), "max_ues": (5, 5000)},
    "HMTC": {"latency_ms": (0.1, 5), "reliability_pct": (99.9999, 99.99999), "dl_throughput_mbps": (100, 2000), "ul_throughput_mbps": (50, 1000), "max_ues": (100, 100000)},
}


def load_config(path: str | Path) -> Dict[str, Any]:
    """Load and return a YAML config file as a dict."""
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def write_json(path: str | Path, obj: Any) -> None:
    """Write *obj* as pretty-printed UTF-8 JSON, creating parent dirs as needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)


def get_message(messages: List[Dict[str, str]], role: str) -> str:
    """Return the content of the first message with the given role, or ""."""
    vals = [m.get("content", "") for m in messages if m.get("role") == role]
    return vals[0] if vals else ""


def strip_code_fence(text: str) -> str:
    """Remove a leading/trailing markdown ``` fence (optionally tagged json)."""
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
        text = re.sub(r"\s*```$", "", text)
    return text.strip()


def extract_json_text(text: str) -> str:
    """Best-effort extraction of a JSON object substring from generated text.

    Strips code fences, then — if generation includes role markers or extra
    text — returns the span from the first "{" to the last "}".
    """
    text = strip_code_fence(text)
    if not text:
        return text
    start = text.find("{")
    end = text.rfind("}")
    if start >= 0 and end > start:
        return text[start : end + 1].strip()
    return text


def parse_json(text: str) -> Tuple[Optional[Any], Optional[str]]:
    """Parse *text* as JSON; return (obj, None) on success or (None, error)."""
    candidate = extract_json_text(text)
    try:
        return json.loads(candidate), None
    except Exception as e:
        # Error string is truncated so logs stay bounded on huge inputs.
        return None, f"{type(e).__name__}: {str(e)[:200]}"


def canonical_json(obj: Any) -> str:
    """Deterministic compact JSON serialization (sorted keys, no spaces)."""
    return json.dumps(obj, sort_keys=True, ensure_ascii=False, separators=(",", ":"))


def flatten_json(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/lists into {"a.b[0]": leaf} dotted/indexed paths."""
    out: Dict[str, Any] = {}
    if isinstance(obj, dict):
        for k, v in obj.items():
            p = f"{prefix}.{k}" if prefix else str(k)
            out.update(flatten_json(v, p))
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            p = f"{prefix}[{i}]"
            out.update(flatten_json(v, p))
    else:
        out[prefix] = obj
    return out


def json_exact_match(pred_obj: Any, gold_obj: Any) -> bool:
    """True iff prediction and gold serialize to the same canonical JSON."""
    return canonical_json(pred_obj) == canonical_json(gold_obj)


def field_f1(pred_obj: Any, gold_obj: Any) -> Dict[str, float]:
    """Field-level precision/recall/F1 over flattened (path, value) pairs.

    Values that are themselves containers are compared via canonical JSON;
    scalars via str(). Empty sides count as perfect precision/recall (1.0).
    """
    pred = flatten_json(pred_obj)
    gold = flatten_json(gold_obj)
    pred_items = set((k, canonical_json(v) if isinstance(v, (dict, list)) else str(v)) for k, v in pred.items())
    gold_items = set((k, canonical_json(v) if isinstance(v, (dict, list)) else str(v)) for k, v in gold.items())
    tp = len(pred_items & gold_items)
    fp = len(pred_items - gold_items)
    fn = len(gold_items - pred_items)
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"field_precision": precision, "field_recall": recall, "field_f1": f1, "field_tp": tp, "field_fp": fp, "field_fn": fn}


def metadata_constraint_pass(example: Dict[str, Any], pred_text: str, pred_obj: Any) -> Dict[str, bool]:
    """Diagnostic text-presence checks of example metadata in the prediction.

    Adversarial/N-A examples only check for a refusal/clarification status
    marker. Otherwise this conservatively checks that the SST and at least
    half of the known metadata values appear verbatim in the output text.
    Basic exact value presence: some layers encode values indirectly, so this
    is a conservative diagnostic, not a hard schema validator.
    """
    text = canonical_json(pred_obj) if pred_obj is not None else pred_text
    sl = example.get("slice_type")
    target_layer = str(example.get("target_layer", ""))
    if target_layer.startswith("adversarial") or sl == "N/A":
        return {
            "slice_sst_pass": True,
            "kpi_text_presence_pass": True,
            "adversarial_status_pass": any(s in text for s in ["CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"]),
        }
    sst = example.get("sst")
    if sst is None:
        # Fix: original evaluated int(None) and raised TypeError when the
        # example had no "sst" field; treat a missing SST as a failed check.
        sst_pass = False
    else:
        sst_pass = str(sst) in text or f'"sst":{int(sst)}' in text.replace(" ", "")
    vals = [example.get("slice_type"), example.get("sd"), example.get("use_case"), example.get("region")]
    numeric = [example.get("latency_ms"), example.get("dl_throughput_mbps"), example.get("ul_throughput_mbps"), example.get("max_ues")]
    present = 0
    total = 0
    for v in vals:
        if v is None:
            continue
        total += 1
        if str(v) in text:
            present += 1
    for v in numeric:
        if v is None:
            continue
        total += 1
        try:
            fv = float(v)
        except (TypeError, ValueError):
            # Fix: non-numeric metadata used to raise; fall back to a plain
            # string match instead of crashing the whole evaluation row.
            candidates = {str(v)}
        else:
            # Match both "5" and "5.0"-style renderings of integral values.
            candidates = {str(v), str(int(fv)) if fv.is_integer() else str(fv)}
        if any(c in text for c in candidates):
            present += 1
    return {
        "slice_sst_pass": bool(sst_pass),
        "kpi_text_presence_pass": (present / total >= 0.5) if total else True,
        "adversarial_status_pass": True,
    }


def aggregate_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Mean of every numeric/boolean metric across rows, plus num_examples.

    Per key, the mean is taken over the rows that carry a numeric value for
    that key (rows missing the key, or holding a non-numeric value, are
    skipped for that key only).
    """
    if not rows:
        return {}
    # Fix: keys were previously taken from rows[0] only, silently dropping
    # any metric that first appears on a later row (e.g. when the first row
    # failed to parse and lacks field-level metrics). Collect from all rows,
    # preserving first-seen order for stable output.
    keys: List[str] = []
    seen: set = set()
    for r in rows:
        for k, v in r.items():
            if isinstance(v, (int, float, bool)) and k not in seen:
                seen.add(k)
                keys.append(k)
    agg: Dict[str, Any] = {"num_examples": len(rows)}
    for k in keys:
        vals = [float(r[k]) for r in rows if isinstance(r.get(k), (int, float, bool))]
        if vals:
            agg[k] = sum(vals) / len(vals)
    return agg


def short_hash(text: str, n: int = 12) -> str:
    """First *n* hex chars of the SHA-256 of *text* (UTF-8, errors ignored)."""
    return hashlib.sha256(text.encode("utf-8", errors="ignore")).hexdigest()[:n]