# Repository tags: PEFT, qlora, sft, trl, qwen3, tmf921,
# intent-based-networking, network-slicing, rtx-6000-ada, ml-intern
# Commit: "Add RTX 6000 Ada QLoRA training and evaluation repo" (d9ba941, verified)
import json
import math
import re
import hashlib
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
import yaml
# Maps slice-type labels to their numeric Slice/Service Type (SST) codes.
# NOTE(review): "HMTC" and "MPS" both map to 5 here — confirm this is intentional
# (3GPP TS 23.501 assigns distinct SST values to HMTC and MPS).
SST_MAP = {"eMBB": 1, "URLLC": 2, "mMTC": 3, "V2X": 4, "HMTC": 5, "MPS": 5, "N/A": 0}
# Plausible (min, max) bounds per KPI for each slice type, used as sanity ranges
# when checking generated intents. Units are encoded in the key names
# (milliseconds, percent, Mbps, UE count).
KPI_RANGES = {
"eMBB": {"latency_ms": (10, 100), "reliability_pct": (99.0, 99.99), "dl_throughput_mbps": (50, 1000), "ul_throughput_mbps": (10, 200), "max_ues": (50, 50000)},
"URLLC": {"latency_ms": (0.5, 10), "reliability_pct": (99.999, 99.9999), "dl_throughput_mbps": (10, 500), "ul_throughput_mbps": (5, 200), "max_ues": (5, 5000)},
"mMTC": {"latency_ms": (10, 1000), "reliability_pct": (90.0, 99.99), "dl_throughput_mbps": (0.1, 50), "ul_throughput_mbps": (0.05, 25), "max_ues": (1000, 1000000)},
"V2X": {"latency_ms": (1, 20), "reliability_pct": (99.99, 99.9999), "dl_throughput_mbps": (10, 300), "ul_throughput_mbps": (5, 150), "max_ues": (10, 10000)},
"MPS": {"latency_ms": (1, 30), "reliability_pct": (99.999, 99.9999), "dl_throughput_mbps": (10, 200), "ul_throughput_mbps": (5, 100), "max_ues": (5, 5000)},
"HMTC": {"latency_ms": (0.1, 5), "reliability_pct": (99.9999, 99.99999), "dl_throughput_mbps": (100, 2000), "ul_throughput_mbps": (50, 1000), "max_ues": (100, 100000)},
}
def load_config(path: str | Path) -> Dict[str, Any]:
    """Read a YAML configuration file and return its parsed contents."""
    with open(path, "r", encoding="utf-8") as handle:
        config = yaml.safe_load(handle)
    return config
def write_json(path: str | Path, obj: Any) -> None:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
json.dump(obj, f, indent=2, ensure_ascii=False)
def get_message(messages: List[Dict[str, str]], role: str) -> str:
    """Return the content of the first message matching *role*, or "" if none."""
    for message in messages:
        if message.get("role") == role:
            return message.get("content", "")
    return ""
def strip_code_fence(text: str) -> str:
    """Strip a surrounding markdown ``` fence (optionally tagged ``json``)."""
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    # Remove the opening fence (with optional language tag) and the closing fence.
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r"\s*```$", "", cleaned)
    return cleaned.strip()
def extract_json_text(text: str) -> str:
    """Return the outermost brace-delimited span of *text* after de-fencing.

    Generations may include role markers or extra chatter around the JSON
    object; keep only the first '{' through the last '}' when both exist.
    """
    body = strip_code_fence(text)
    if not body:
        return body
    first = body.find("{")
    last = body.rfind("}")
    if first < 0 or last <= first:
        return body
    return body[first : last + 1].strip()
def parse_json(text: str) -> Tuple[Optional[Any], Optional[str]]:
    """Try to parse *text* as JSON; return (object, None) or (None, error string)."""
    candidate = extract_json_text(text)
    try:
        parsed = json.loads(candidate)
    except Exception as exc:  # report any failure mode as a truncated message
        return None, f"{type(exc).__name__}: {str(exc)[:200]}"
    return parsed, None
def canonical_json(obj: Any) -> str:
    """Serialize *obj* deterministically: sorted keys, compact separators."""
    return json.dumps(
        obj,
        sort_keys=True,
        ensure_ascii=False,
        separators=(",", ":"),
    )
def flatten_json(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/lists into a {"a.b[0]": leaf} path-to-value map."""
    if isinstance(obj, dict):
        flat: Dict[str, Any] = {}
        for key, value in obj.items():
            path = f"{prefix}.{key}" if prefix else str(key)
            flat.update(flatten_json(value, path))
        return flat
    if isinstance(obj, list):
        flat = {}
        for index, value in enumerate(obj):
            flat.update(flatten_json(value, f"{prefix}[{index}]"))
        return flat
    # Leaf value: keyed by the accumulated path ("" for a bare top-level scalar).
    return {prefix: obj}
def json_exact_match(pred_obj: Any, gold_obj: Any) -> bool:
    """True when prediction and gold serialize to the same canonical JSON."""
    pred_key = canonical_json(pred_obj)
    gold_key = canonical_json(gold_obj)
    return pred_key == gold_key
def field_f1(pred_obj: Any, gold_obj: Any) -> Dict[str, float]:
    """Precision / recall / F1 over flattened (path, value) pairs.

    Each JSON object is flattened to path→value items; agreement is exact
    string match on both path and stringified value.
    """

    def _items(obj: Any) -> set:
        # Values are stringified; containers (defensively) use canonical JSON.
        return {
            (path, canonical_json(val) if isinstance(val, (dict, list)) else str(val))
            for path, val in flatten_json(obj).items()
        }

    pred_items = _items(pred_obj)
    gold_items = _items(gold_obj)
    tp = len(pred_items & gold_items)
    fp = len(pred_items - gold_items)
    fn = len(gold_items - pred_items)
    # Empty sets count as perfect precision/recall by convention.
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom else 0.0
    return {
        "field_precision": precision,
        "field_recall": recall,
        "field_f1": f1,
        "field_tp": tp,
        "field_fp": fp,
        "field_fn": fn,
    }
def metadata_constraint_pass(example: Dict[str, Any], pred_text: str, pred_obj: Any) -> Dict[str, bool]:
    """Diagnostic checks that expected metadata values appear in the prediction.

    Args:
        example: Gold example metadata (slice_type, sst, KPI fields, ...).
        pred_text: Raw generated text, used when *pred_obj* is None.
        pred_obj: Parsed prediction (canonicalized for matching) or None.

    Returns:
        Dict of three booleans: slice_sst_pass, kpi_text_presence_pass,
        adversarial_status_pass.

    Fixes over the original: ``int(example.get("sst"))`` raised TypeError when
    "sst" was absent, and ``float(v)`` raised ValueError for non-numeric KPI
    strings; both are now guarded.
    """
    text = canonical_json(pred_obj) if pred_obj is not None else pred_text
    slice_type = example.get("slice_type")
    target_layer = str(example.get("target_layer", ""))
    # Adversarial / out-of-scope examples only need a refusal-style status marker.
    if target_layer.startswith("adversarial") or slice_type == "N/A":
        return {
            "slice_sst_pass": True,
            "kpi_text_presence_pass": True,
            "adversarial_status_pass": any(
                s in text for s in ["CLARIFICATION_REQUIRED", "OUT_OF_SCOPE", "INTENT_VALIDATION_FAILED"]
            ),
        }
    sst = example.get("sst")
    if sst is None:
        # Nothing to verify; the original crashed here with int(None).
        sst_pass = True
    else:
        try:
            compact_hit = f'"sst":{int(sst)}' in text.replace(" ", "")
        except (TypeError, ValueError):
            compact_hit = False
        sst_pass = str(sst) in text or compact_hit
    # Basic exact value presence. Some layers encode values indirectly, so this
    # is a conservative diagnostic, not a hard schema validator.
    vals = [example.get("slice_type"), example.get("sd"), example.get("use_case"), example.get("region")]
    numeric = [example.get("latency_ms"), example.get("dl_throughput_mbps"), example.get("ul_throughput_mbps"), example.get("max_ues")]
    present = 0
    total = 0
    for v in vals:
        if v is None:
            continue
        total += 1
        if str(v) in text:
            present += 1
    for v in numeric:
        if v is None:
            continue
        total += 1
        candidates = {str(v)}
        try:
            fv = float(v)
        except (TypeError, ValueError):
            # Non-numeric KPI value (e.g. "100Mbps"): fall back to plain match.
            pass
        else:
            # Accept either "20" or "20.0" style renderings of integral values.
            candidates.add(str(int(fv)) if fv.is_integer() else str(fv))
        if any(c in text for c in candidates):
            present += 1
    return {
        "slice_sst_pass": bool(sst_pass),
        # At least half of the expected values must appear verbatim.
        "kpi_text_presence_pass": (present / total >= 0.5) if total else True,
        "adversarial_status_pass": True,
    }
def aggregate_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Average every numeric/bool metric named in the first row; add a count.

    Keys are taken from rows[0] only; rows missing a key (or holding None)
    are skipped for that key's mean.
    """
    if not rows:
        return {}
    metric_keys = [key for key, value in rows[0].items() if isinstance(value, (int, float, bool))]
    summary: Dict[str, Any] = {"num_examples": len(rows)}
    for key in metric_keys:
        values = [float(row[key]) for row in rows if row.get(key) is not None]
        if values:
            summary[key] = sum(values) / len(values)
    return summary
def short_hash(text: str, n: int = 12) -> str:
    """Return the first *n* hex characters of SHA-256(*text* as UTF-8)."""
    digest = hashlib.sha256(text.encode("utf-8", errors="ignore")).hexdigest()
    return digest[:n]