#!/usr/bin/env python3
"""Sample publication-friendly success/failure examples from evaluation predictions.
Reads raw predictions and normalized scored predictions from an eval directory, then writes:
- analysis/failure_examples.json
- analysis/failure_examples.md
Designed to support qualitative error analysis in a paper.
"""
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List
from tmf921_train.utils import parse_json, write_json
# Target layers sampled when --layers is not given on the command line.
DEFAULT_LAYERS = [
    "o1_nrm",
    "a1_policy",
    "tmf921_lifecycle_report",
    "tmf921_lifecycle_monitor",
    "tmf921",
    "camara",
    "intent_3gpp",
    "adversarial_ambiguous",
    "adversarial_out_of_scope",
]
def load_rows(eval_dir: Path, split: str) -> List[Dict[str, Any]]:
    """Load prediction rows for one split, enriched with normalized scores.

    Reads ``<eval_dir>/<split>/predictions.json`` and, when it exists,
    ``<eval_dir>/<split>/normalized_predictions_scored.json``. Rows from the
    two files are matched on their ``"id"`` value.

    Args:
        eval_dir: Root evaluation directory.
        split: Name of the split sub-directory.

    Returns:
        An empty list when the predictions file is missing. Otherwise one
        dict per raw prediction; when the normalized file exists, each dict
        additionally carries every normalized field that the raw row does
        not already define (raw values always win on a key clash).
    """
    split_dir = eval_dir / split
    pred_path = split_dir / "predictions.json"
    norm_path = split_dir / "normalized_predictions_scored.json"

    if not pred_path.exists():
        return []

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    pred = json.loads(pred_path.read_text(encoding="utf-8"))
    if not norm_path.exists():
        return pred

    norm_by_id = {
        r.get("id"): r for r in json.loads(norm_path.read_text(encoding="utf-8"))
    }

    out = []
    for raw in pred:
        merged = dict(raw)
        # setdefault keeps the raw value on a clash and appends new keys
        # after the raw ones, so the key order stays raw-first.
        for key, value in norm_by_id.get(raw.get("id"), {}).items():
            merged.setdefault(key, value)
        out.append(merged)
    return out
def summarize_text(text: str, max_chars: int = 1800) -> str:
    """Return a whitespace-stripped, length-capped string form of *text*.

    Args:
        text: Value to summarize; ``None`` yields an empty string and any
            other value is converted with ``str()``.
        max_chars: Maximum number of characters kept before truncation.

    Returns:
        The stripped text unchanged when it fits within *max_chars*,
        otherwise its first *max_chars* characters followed by a
        truncation marker on a new line.
    """
    if text is None:
        return ""
    cleaned = str(text).strip()
    over_limit = len(cleaned) > max_chars
    return (cleaned[:max_chars] + "\n...<truncated>...") if over_limit else cleaned
def infer_error_label(row: Dict[str, Any]) -> str:
    """Assign a coarse, heuristic error category to a scored prediction row.

    Rules are applied in order and the first match wins:

    1. Raw output not parsed (``parse_json`` falsy or missing), or the
       normalized output explicitly not parsed (``norm_parse_json`` falsy).
    2. Key-level F1 above 0.95 but field-level F1 below 0.5.
    3. Key-level F1 below 0.8.
    4. A layer-specific label derived from ``target_layer``.
    5. A generic value-mismatch label.

    Rules 2 and 3 are only evaluated when ``norm_key_f1`` is actually
    present. A row without normalized scores carries no structural
    information, so it falls through to the layer-based labels instead of
    being reported as a structural mismatch.

    Args:
        row: Prediction row, optionally merged with normalized scores.

    Returns:
        The error label string.
    """
    if not row.get("parse_json", False) or not row.get("norm_parse_json", True):
        return "invalid_or_unparseable_json"

    # Prefer the normalized field F1, fall back to the raw one; None -> 0.0.
    field_f1 = row.get("norm_field_f1", row.get("field_f1", 0.0)) or 0.0
    key_f1 = row.get("norm_key_f1")

    # A missing key F1 means "unknown", not "zero": skip the structural rules.
    if key_f1 is not None:
        if key_f1 > 0.95 and field_f1 < 0.5:
            return "correct_structure_wrong_values"
        if key_f1 < 0.8:
            return "structural_mismatch_or_extra_missing_keys"

    layer = row.get("target_layer")
    if layer == "o1_nrm":
        return "o1_value_fidelity_error"
    if layer == "a1_policy":
        return "a1_policy_value_error"

    # Substring match so that prefixed layer names are covered as well.
    layer_name = str(layer)
    if "lifecycle_report" in layer_name:
        return "lifecycle_report_measurement_mismatch"
    if "lifecycle_monitor" in layer_name:
        return "lifecycle_monitor_measurement_mismatch"
    return "value_level_mismatch"
def choose_examples(rows: List[Dict[str, Any]], layer: str, n_fail: int, n_success: int) -> Dict[str, List[Dict[str, Any]]]:
    """Pick the lowest- and highest-scoring rows of one target layer.

    Rows are ranked by field F1 (``norm_field_f1``, falling back to
    ``field_f1``; missing or ``None`` counts as 0.0) with ``exact_match``
    as the tie-breaker.

    Args:
        rows: All rows of a split.
        layer: Value of ``target_layer`` to select.
        n_fail: Maximum number of failure examples; values below zero are
            treated as zero.
        n_success: Maximum number of success examples; values below zero
            are treated as zero.

    Returns:
        A dict with two lists:

        * ``"failures"``: worst-first rows, never including a row whose
          ``exact_match`` is truthy.
        * ``"successes"``: best-first rows, never including a row already
          chosen as a failure.
    """
    layer_rows = [r for r in rows if r.get("target_layer") == layer]
    if not layer_rows:
        return {"failures": [], "successes": []}

    def rank_key(r: Dict[str, Any]):
        f1 = r.get("norm_field_f1", r.get("field_f1", 0.0)) or 0.0
        # bool() keeps the tuple comparable even when exact_match is None.
        return (f1, bool(r.get("exact_match", False)))

    # An exact match is not a failure, however few rows the layer has.
    worst_first = sorted(layer_rows, key=rank_key)
    failures = [r for r in worst_first if not r.get("exact_match", False)][: max(n_fail, 0)]

    # Rows are dicts (unhashable), so track the chosen ones by identity to
    # keep the same row from being reported as both failure and success.
    already_picked = {id(r) for r in failures}
    best_first = sorted(layer_rows, key=rank_key, reverse=True)
    successes = [r for r in best_first if id(r) not in already_picked][: max(n_success, 0)]

    return {"failures": failures, "successes": successes}
def compact_row(row: Dict[str, Any], split: str, kind: str) -> Dict[str, Any]:
    """Build the compact, report-ready record for one sampled row.

    Args:
        row: Prediction row (optionally merged with normalized scores).
        split: Name of the split the row came from.
        kind: ``"failure"`` triggers error-label inference; any other value
            gets the fixed success label.

    Returns:
        A dict with the split/kind, selected metadata and score fields
        copied from *row*, the error label, the top-level keys of the gold
        and predicted objects (``None`` when the parsed value is not a
        dict), and the summarized gold and prediction values.
    """
    raw_prediction = row.get("prediction", "")
    raw_gold = row.get("gold", "")

    # parse_json returns a pair; only its first element is used here.
    pred_obj, _unused = parse_json(raw_prediction)
    gold_obj, _unused = parse_json(raw_gold)

    compact: Dict[str, Any] = {"split": split, "kind": kind}

    # Fields copied verbatim from the row (None when absent).
    passthrough = (
        "id",
        "target_layer",
        "slice_type",
        "lifecycle_operation",
        "parse_json",
        "exact_match",
        "field_f1",
        "norm_field_f1",
        "norm_key_f1",
    )
    for field in passthrough:
        compact[field] = row.get(field)

    if kind == "failure":
        compact["error_label"] = infer_error_label(row)
    else:
        compact["error_label"] = "success_or_high_scoring_example"

    compact["gold_json_keys"] = list(gold_obj.keys()) if isinstance(gold_obj, dict) else None
    compact["prediction_json_keys"] = list(pred_obj.keys()) if isinstance(pred_obj, dict) else None
    compact["gold"] = summarize_text(raw_gold)
    compact["prediction"] = summarize_text(raw_prediction)
    return compact
def main():
    """Command-line entry point: sample examples and write both reports.

    Parses the CLI options, collects failure and success examples for every
    requested split and layer, then writes ``failure_examples.json`` and
    ``failure_examples.md`` into the output directory and prints both paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_dir", required=True, help="Eval dir containing split/predictions.json and normalized_predictions_scored.json")
    parser.add_argument("--output_dir", default="analysis")
    parser.add_argument(
        "--splits",
        nargs="+",
        default=[
            "test_in_distribution",
            "test_template_ood",
            "test_use_case_ood",
            "test_sector_ood",
            "test_adversarial",
        ],
    )
    parser.add_argument("--layers", nargs="+", default=DEFAULT_LAYERS)
    parser.add_argument("--failures_per_layer", type=int, default=3)
    parser.add_argument("--successes_per_layer", type=int, default=1)
    args = parser.parse_args()

    eval_dir = Path(args.eval_dir)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    json_path = out_dir / "failure_examples.json"
    md_path = out_dir / "failure_examples.md"

    # Collect examples split by split, layer by layer; failures come first
    # within each (split, layer) group.
    examples: List[Dict[str, Any]] = []
    for split in args.splits:
        rows = load_rows(eval_dir, split)
        for layer in args.layers:
            picked = choose_examples(rows, layer, args.failures_per_layer, args.successes_per_layer)
            examples.extend(compact_row(r, split, "failure") for r in picked["failures"])
            examples.extend(compact_row(r, split, "success") for r in picked["successes"])
    write_json(json_path, examples)

    # Markdown report: a short header followed by one section per example.
    md = [
        "# Qualitative Success and Failure Examples",
        "",
        f"Source eval dir: `{eval_dir}`",
        "",
        "These examples are sampled to support qualitative error analysis. Long JSON objects are truncated for readability; full examples are in `failure_examples.json`.",
        "",
    ]
    for number, ex in enumerate(examples, start=1):
        md.extend(
            [
                f"## Example {number}: {ex['kind']} — `{ex['target_layer']}` — `{ex['split']}`",
                "",
                f"- id: `{ex['id']}`",
                f"- slice type: `{ex.get('slice_type')}`",
                f"- lifecycle: `{ex.get('lifecycle_operation')}`",
                f"- error label: `{ex['error_label']}`",
                f"- raw field F1: `{ex.get('field_f1')}`",
                f"- normalized field F1: `{ex.get('norm_field_f1')}`",
                f"- normalized key F1: `{ex.get('norm_key_f1')}`",
                "",
                "### Gold",
                "```json",
                ex["gold"],
                "```",
                "",
                "### Prediction",
                "```json",
                ex["prediction"],
                "```",
                "",
            ]
        )
    md_path.write_text("\n".join(md), encoding="utf-8")

    print(md_path)
    print(json_path)
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
|