PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
tmf921-intent-training / scripts /sample_failure_examples.py
nraptisss's picture
Add qualitative failure example sampler
77fad9d verified
#!/usr/bin/env python3
"""Sample publication-friendly success/failure examples from evaluation predictions.
Reads raw predictions and normalized scored predictions from an eval directory, then writes:
- analysis/failure_examples.json
- analysis/failure_examples.md
Designed to support qualitative error analysis in a paper.
"""
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

from tmf921_train.utils import parse_json, write_json
# Target layers sampled when --layers is not given; iteration order here is the
# order in which layers appear in the generated report.
DEFAULT_LAYERS = [
    "o1_nrm",
    "a1_policy",
    "tmf921_lifecycle_report",
    "tmf921_lifecycle_monitor",
    "tmf921",
    "camara",
    "intent_3gpp",
    "adversarial_ambiguous",
    "adversarial_out_of_scope",
]
def load_rows(eval_dir: Path, split: str) -> List[Dict[str, Any]]:
    """Load prediction rows for one split, enriched with normalized scores if present.

    Reads ``<eval_dir>/<split>/predictions.json`` (a JSON list of row dicts). When
    ``normalized_predictions_scored.json`` exists in the same directory, each
    prediction row is joined to the normalized row sharing its ``id``; normalized
    fields are only added for keys the prediction row does not already define, so
    raw prediction values always win.

    Rows whose ``id`` is missing or ``None`` are never joined, because a ``None``
    key cannot identify a specific normalized row.

    Returns an empty list when ``predictions.json`` does not exist.
    """
    split_dir = eval_dir / split
    pred_path = split_dir / "predictions.json"
    norm_path = split_dir / "normalized_predictions_scored.json"
    if not pred_path.exists():
        return []
    # Explicit UTF-8: the default text encoding is locale-dependent (e.g. cp1252 on
    # Windows) and can mis-decode or reject non-ASCII content.
    pred = json.loads(pred_path.read_text(encoding="utf-8"))
    if not norm_path.exists():
        return pred
    norm_by_id = {
        r.get("id"): r
        for r in json.loads(norm_path.read_text(encoding="utf-8"))
        if r.get("id") is not None
    }
    merged_rows = []
    for row in pred:
        row_id = row.get("id")
        norm_row = norm_by_id.get(row_id, {}) if row_id is not None else {}
        merged = dict(row)
        # Normalized fields only fill gaps; existing raw fields are preserved.
        for key, value in norm_row.items():
            merged.setdefault(key, value)
        merged_rows.append(merged)
    return merged_rows
def summarize_text(text: str, max_chars: int = 1800) -> str:
    """Return *text* stripped of surrounding whitespace, capped at *max_chars* chars.

    ``None`` becomes an empty string; any other value is converted with ``str``.
    Text longer than *max_chars* is cut and a ``...<truncated>...`` marker is
    appended on a new line.
    """
    if text is None:
        return ""
    cleaned = str(text).strip()
    if len(cleaned) > max_chars:
        return f"{cleaned[:max_chars]}\n...<truncated>..."
    return cleaned
def infer_error_label(row: Dict[str, Any]) -> str:
    """Assign a coarse, heuristic error category to a (low-scoring) prediction row.

    Checks run in priority order:

    1. JSON parsing: ``parse_json`` falsy/missing, or ``norm_parse_json``
       explicitly falsy -> ``invalid_or_unparseable_json``.
    2. Key structure, using ``norm_key_f1`` -- only when that score is present.
       A missing score means normalized scoring was not run for the row, so no
       structural verdict is drawn from it.
    3. A per-``target_layer`` label, falling back to ``value_level_mismatch``.

    The 0.95 / 0.8 / 0.5 thresholds are heuristic cut-offs for qualitative
    analysis, not calibrated values.
    """
    if not row.get("parse_json", False) or not row.get("norm_parse_json", True):
        return "invalid_or_unparseable_json"
    layer = row.get("target_layer")
    nf1 = row.get("norm_field_f1", row.get("field_f1", 0.0)) or 0.0
    kf1 = row.get("norm_key_f1")
    # Treating an absent key score as 0.0 would label every failure from a run
    # without normalized scores as a structural mismatch.
    if kf1 is not None:
        if kf1 > 0.95 and nf1 < 0.5:
            return "correct_structure_wrong_values"
        if kf1 < 0.8:
            return "structural_mismatch_or_extra_missing_keys"
    if layer == "o1_nrm":
        return "o1_value_fidelity_error"
    if layer == "a1_policy":
        return "a1_policy_value_error"
    if "lifecycle_report" in str(layer):
        return "lifecycle_report_measurement_mismatch"
    if "lifecycle_monitor" in str(layer):
        return "lifecycle_monitor_measurement_mismatch"
    return "value_level_mismatch"
def choose_examples(rows: List[Dict[str, Any]], layer: str, n_fail: int, n_success: int) -> Dict[str, List[Dict[str, Any]]]:
    """Pick the lowest- and highest-scoring rows of one target layer.

    Rows are ranked by ``norm_field_f1`` (falling back to ``field_f1``, then 0.0),
    with ``exact_match`` as a tie-breaker. The *n_fail* lowest-ranked rows become
    failures. Successes are the *n_success* highest-ranked rows among those NOT
    already chosen as failures, so a small layer never shows the same example as
    both a failure and a success. Negative counts are treated as 0.

    Returns ``{"failures": [...], "successes": [...]}``; both lists are empty when
    no row has ``target_layer == layer``.
    """

    def _rank_key(r: Dict[str, Any]):
        score = r.get("norm_field_f1", r.get("field_f1", 0.0)) or 0.0
        # bool() keeps ties comparable when exact_match is present but None.
        return (score, bool(r.get("exact_match", False)))

    layer_rows = [r for r in rows if r.get("target_layer") == layer]
    if not layer_rows:
        return {"failures": [], "successes": []}
    # A negative slice bound would mean "all but the last k rows".
    n_fail = max(n_fail, 0)
    n_success = max(n_success, 0)
    failures = sorted(layer_rows, key=_rank_key)[:n_fail]
    taken = {id(r) for r in failures}
    remaining = [r for r in layer_rows if id(r) not in taken]
    successes = sorted(remaining, key=_rank_key, reverse=True)[:n_success]
    return {"failures": failures, "successes": successes}
def compact_row(row: Dict[str, Any], split: str, kind: str) -> Dict[str, Any]:
    """Flatten one prediction row into the record stored in ``failure_examples.json``.

    The record carries *split* and *kind*, copies identifying metadata and scores
    from *row* (``None`` when absent), and attaches an error label. The label is
    computed by ``infer_error_label`` only when *kind* is ``"failure"``. The record
    also lists the top-level keys of the parsed gold and prediction objects
    (``None`` unless the parse result is a dict) and length-capped copies of both
    raw texts.
    """
    raw_prediction = row.get("prediction", "")
    raw_gold = row.get("gold", "")
    pred_obj, _ = parse_json(raw_prediction)
    gold_obj, _ = parse_json(raw_gold)

    record: Dict[str, Any] = {"split": split, "kind": kind}
    copied_fields = (
        "id",
        "target_layer",
        "slice_type",
        "lifecycle_operation",
        "parse_json",
        "exact_match",
        "field_f1",
        "norm_field_f1",
        "norm_key_f1",
    )
    for field in copied_fields:
        record[field] = row.get(field)

    if kind == "failure":
        record["error_label"] = infer_error_label(row)
    else:
        record["error_label"] = "success_or_high_scoring_example"

    record["gold_json_keys"] = list(gold_obj) if isinstance(gold_obj, dict) else None
    record["prediction_json_keys"] = list(pred_obj) if isinstance(pred_obj, dict) else None
    record["gold"] = summarize_text(raw_gold)
    record["prediction"] = summarize_text(raw_prediction)
    return record
def _collect_examples(eval_dir: Path, splits: List[str], layers: List[str], n_fail: int, n_success: int) -> List[Dict[str, Any]]:
    """Sample and compact failure/success examples for every (split, layer) pair."""
    examples: List[Dict[str, Any]] = []
    for split in splits:
        rows = load_rows(eval_dir, split)
        if not rows:
            # Surface missing or empty splits instead of silently reporting fewer examples.
            print(f"warning: no predictions loaded for split '{split}' under {eval_dir}", file=sys.stderr)
            continue
        for layer in layers:
            picked = choose_examples(rows, layer, n_fail, n_success)
            examples.extend(compact_row(r, split, "failure") for r in picked["failures"])
            examples.extend(compact_row(r, split, "success") for r in picked["successes"])
    return examples


def _render_markdown(eval_dir: Path, examples: List[Dict[str, Any]]) -> str:
    """Render the sampled examples as a Markdown report (ends with a newline)."""
    lines: List[str] = [
        "# Qualitative Success and Failure Examples",
        "",
        f"Source eval dir: `{eval_dir}`",
        "",
        "These examples are sampled to support qualitative error analysis. Long JSON objects are truncated for readability; full examples are in `failure_examples.json`.",
        "",
    ]
    for i, ex in enumerate(examples, start=1):
        lines += [
            f"## Example {i}: {ex['kind']} — `{ex['target_layer']}` — `{ex['split']}`",
            "",
            f"- id: `{ex['id']}`",
            f"- slice type: `{ex.get('slice_type')}`",
            f"- lifecycle: `{ex.get('lifecycle_operation')}`",
            f"- error label: `{ex['error_label']}`",
            f"- raw field F1: `{ex.get('field_f1')}`",
            f"- normalized field F1: `{ex.get('norm_field_f1')}`",
            f"- normalized key F1: `{ex.get('norm_key_f1')}`",
            "",
            "### Gold",
            "```json",
            ex["gold"],
            "```",
            "",
            "### Prediction",
            "```json",
            ex["prediction"],
            "```",
            "",
        ]
    # Terminate the file with a newline so it is a well-formed POSIX text file.
    return "\n".join(lines) + "\n"


def main():
    """CLI entry point: sample examples and write ``failure_examples.{json,md}``.

    Both files go to ``--output_dir``, which is created if needed. Their paths are
    printed to stdout. Splits with no loadable predictions produce a warning on
    stderr and are skipped.
    """
    ap = argparse.ArgumentParser(description="Sample qualitative success/failure examples from evaluation predictions.")
    ap.add_argument("--eval_dir", required=True, help="Eval dir containing split/predictions.json and normalized_predictions_scored.json")
    ap.add_argument("--output_dir", default="analysis")
    ap.add_argument("--splits", nargs="+", default=["test_in_distribution", "test_template_ood", "test_use_case_ood", "test_sector_ood", "test_adversarial"])
    ap.add_argument("--layers", nargs="+", default=DEFAULT_LAYERS)
    ap.add_argument("--failures_per_layer", type=int, default=3)
    ap.add_argument("--successes_per_layer", type=int, default=1)
    args = ap.parse_args()
    # Negative counts would otherwise turn list slices into "all but the last k".
    if args.failures_per_layer < 0 or args.successes_per_layer < 0:
        ap.error("--failures_per_layer and --successes_per_layer must be non-negative")

    eval_dir = Path(args.eval_dir)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    examples = _collect_examples(eval_dir, args.splits, args.layers, args.failures_per_layer, args.successes_per_layer)

    json_path = out_dir / "failure_examples.json"
    md_path = out_dir / "failure_examples.md"
    write_json(json_path, examples)
    md_path.write_text(_render_markdown(eval_dir, examples), encoding="utf-8")
    print(md_path)
    print(json_path)


if __name__ == "__main__":
    main()