#!/usr/bin/env python3
"""Sample publication-friendly success/failure examples from evaluation predictions.

Reads raw predictions and normalized scored predictions from an eval directory,
then writes (under --output_dir, default ``analysis``):
- analysis/failure_examples.json
- analysis/failure_examples.md

Designed to support qualitative error analysis in a paper.
"""

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

from tmf921_train.utils import parse_json, write_json

# Target layers sampled by default (overridable with --layers).
DEFAULT_LAYERS = [
    "o1_nrm",
    "a1_policy",
    "tmf921_lifecycle_report",
    "tmf921_lifecycle_monitor",
    "tmf921",
    "camara",
    "intent_3gpp",
    "adversarial_ambiguous",
    "adversarial_out_of_scope",
]


def load_rows(eval_dir: Path, split: str) -> List[Dict[str, Any]]:
    """Load the prediction rows of one split, enriched with normalized scores.

    Reads ``<eval_dir>/<split>/predictions.json`` and, if it exists,
    ``<eval_dir>/<split>/normalized_predictions_scored.json``. Normalized rows
    are joined to raw rows on their ``"id"`` value; keys already present on a
    raw row win, and only keys missing from it are copied over.

    Returns an empty list when the split has no ``predictions.json``.
    """
    split_dir = eval_dir / split
    pred_path = split_dir / "predictions.json"
    norm_path = split_dir / "normalized_predictions_scored.json"
    if not pred_path.exists():
        return []

    # Explicit UTF-8 so decoding does not depend on the platform locale.
    pred = json.loads(pred_path.read_text(encoding="utf-8"))
    if not norm_path.exists():
        return pred

    # Rows without an id cannot be joined reliably; indexing them under a
    # shared ``None`` key would attach one arbitrary normalized row to every
    # id-less prediction, so they are left out of the index.
    norm = {
        r["id"]: r
        for r in json.loads(norm_path.read_text(encoding="utf-8"))
        if r.get("id") is not None
    }

    out = []
    for r in pred:
        merged = dict(r)
        for k, v in norm.get(r.get("id"), {}).items():
            merged.setdefault(k, v)
        out.append(merged)
    return out


def summarize_text(text: Optional[str], max_chars: int = 1800) -> str:
    """Return ``text`` stripped and truncated for display.

    ``None`` becomes an empty string; other values are converted with
    ``str()``. Text longer than ``max_chars`` characters is cut to
    ``max_chars`` and followed by a newline and a ``......`` marker.
    """
    if text is None:
        return ""
    text = str(text).strip()
    if len(text) <= max_chars:
        return text
    return text[:max_chars] + "\n......"
def infer_error_label(row: Dict[str, Any]) -> str: if not row.get("parse_json", False) or not row.get("norm_parse_json", True): return "invalid_or_unparseable_json" layer = row.get("target_layer") nf1 = row.get("norm_field_f1", row.get("field_f1", 0.0)) or 0.0 kf1 = row.get("norm_key_f1", 0.0) or 0.0 if kf1 > 0.95 and nf1 < 0.5: return "correct_structure_wrong_values" if kf1 < 0.8: return "structural_mismatch_or_extra_missing_keys" if layer == "o1_nrm": return "o1_value_fidelity_error" if layer == "a1_policy": return "a1_policy_value_error" if "lifecycle_report" in str(layer): return "lifecycle_report_measurement_mismatch" if "lifecycle_monitor" in str(layer): return "lifecycle_monitor_measurement_mismatch" return "value_level_mismatch" def choose_examples(rows: List[Dict[str, Any]], layer: str, n_fail: int, n_success: int) -> Dict[str, List[Dict[str, Any]]]: layer_rows = [r for r in rows if r.get("target_layer") == layer] if not layer_rows: return {"failures": [], "successes": []} failures = sorted(layer_rows, key=lambda r: (r.get("norm_field_f1", r.get("field_f1", 0.0)) or 0.0, r.get("exact_match", False)))[:n_fail] successes = sorted(layer_rows, key=lambda r: (r.get("norm_field_f1", r.get("field_f1", 0.0)) or 0.0, r.get("exact_match", False)), reverse=True)[:n_success] return {"failures": failures, "successes": successes} def compact_row(row: Dict[str, Any], split: str, kind: str) -> Dict[str, Any]: pred_obj, _ = parse_json(row.get("prediction", "")) gold_obj, _ = parse_json(row.get("gold", "")) return { "split": split, "kind": kind, "id": row.get("id"), "target_layer": row.get("target_layer"), "slice_type": row.get("slice_type"), "lifecycle_operation": row.get("lifecycle_operation"), "parse_json": row.get("parse_json"), "exact_match": row.get("exact_match"), "field_f1": row.get("field_f1"), "norm_field_f1": row.get("norm_field_f1"), "norm_key_f1": row.get("norm_key_f1"), "error_label": infer_error_label(row) if kind == "failure" else 
"success_or_high_scoring_example", "gold_json_keys": list(gold_obj.keys()) if isinstance(gold_obj, dict) else None, "prediction_json_keys": list(pred_obj.keys()) if isinstance(pred_obj, dict) else None, "gold": summarize_text(row.get("gold", "")), "prediction": summarize_text(row.get("prediction", "")), } def main(): ap = argparse.ArgumentParser() ap.add_argument("--eval_dir", required=True, help="Eval dir containing split/predictions.json and normalized_predictions_scored.json") ap.add_argument("--output_dir", default="analysis") ap.add_argument("--splits", nargs="+", default=["test_in_distribution", "test_template_ood", "test_use_case_ood", "test_sector_ood", "test_adversarial"]) ap.add_argument("--layers", nargs="+", default=DEFAULT_LAYERS) ap.add_argument("--failures_per_layer", type=int, default=3) ap.add_argument("--successes_per_layer", type=int, default=1) args = ap.parse_args() eval_dir = Path(args.eval_dir) out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) examples: List[Dict[str, Any]] = [] for split in args.splits: rows = load_rows(eval_dir, split) for layer in args.layers: picked = choose_examples(rows, layer, args.failures_per_layer, args.successes_per_layer) for r in picked["failures"]: examples.append(compact_row(r, split, "failure")) for r in picked["successes"]: examples.append(compact_row(r, split, "success")) write_json(out_dir / "failure_examples.json", examples) lines = [] A = lines.append A("# Qualitative Success and Failure Examples") A("") A(f"Source eval dir: `{eval_dir}`") A("") A("These examples are sampled to support qualitative error analysis. 
Long JSON objects are truncated for readability; full examples are in `failure_examples.json`.") A("") for i, ex in enumerate(examples, start=1): A(f"## Example {i}: {ex['kind']} — `{ex['target_layer']}` — `{ex['split']}`") A("") A(f"- id: `{ex['id']}`") A(f"- slice type: `{ex.get('slice_type')}`") A(f"- lifecycle: `{ex.get('lifecycle_operation')}`") A(f"- error label: `{ex['error_label']}`") A(f"- raw field F1: `{ex.get('field_f1')}`") A(f"- normalized field F1: `{ex.get('norm_field_f1')}`") A(f"- normalized key F1: `{ex.get('norm_key_f1')}`") A("") A("### Gold") A("```json") A(ex["gold"]) A("```") A("") A("### Prediction") A("```json") A(ex["prediction"]) A("```") A("") (out_dir / "failure_examples.md").write_text("\n".join(lines), encoding="utf-8") print(out_dir / "failure_examples.md") print(out_dir / "failure_examples.json") if __name__ == "__main__": main()