"""Persisted "repair library" — the model's accumulated knowledge of known breakage -> repair pairs. Curated from successful rollouts during training. Loaded at inference time as a few-shot prefix when the agent recognises a familiar error class. """ from __future__ import annotations import json from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Any, Optional @dataclass class RepairExample: primitive_type: str breakage_params: dict[str, Any] error_signature: str repair_diff: str visible_reward: float held_out: dict[str, float] task_id: str = "" def signature_key(self) -> str: return f"{self.primitive_type}::{self.error_signature[:80]}" @dataclass class RepairLibrary: examples: list[RepairExample] = field(default_factory=list) def add(self, example: RepairExample) -> None: self.examples.append(example) def best_match(self, primitive_type: str, error_text: str) -> Optional[RepairExample]: """Return the highest-reward example whose primitive_type matches and whose error text overlaps.""" candidates = [ e for e in self.examples if e.primitive_type == primitive_type ] if not candidates: return None scored = sorted( candidates, key=lambda e: ( _ngram_overlap(e.error_signature, error_text), e.visible_reward, ), reverse=True, ) return scored[0] if scored else None def to_dict(self) -> dict: return { "version": "1", "examples": [asdict(e) for e in self.examples], "size": len(self.examples), "by_primitive": _count_by_primitive(self.examples), } def save(self, path: str | Path) -> None: path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8") @classmethod def load(cls, path: str | Path) -> "RepairLibrary": data = json.loads(Path(path).read_text(encoding="utf-8")) examples = [RepairExample(**e) for e in data.get("examples", [])] return cls(examples=examples) def _ngram_overlap(a: str, b: str, n: int = 3) -> float: if not a or not b: return 0.0 def grams(text: str) -> set[str]: text = text.lower() return {text[i : i + n] for i in range(len(text) - n + 1)} ga, gb = grams(a), grams(b) if not ga or not gb: return 0.0 return len(ga & gb) / max(1, len(ga | gb)) def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]: counts: dict[str, int] = {} for e in examples: counts[e.primitive_type] = counts.get(e.primitive_type, 0) + 1 return counts def curate_from_rollouts( rollout_results: list, min_reward: float = 0.6, min_held_out_clean: float = 0.5, ) -> RepairLibrary: """Build a RepairLibrary from a list of rollout dicts/RolloutResults.""" lib = RepairLibrary() for r in rollout_results: get = r.get if isinstance(r, dict) else lambda k, default=None: getattr(r, k, default) if float(get("visible_reward", 0.0) or 0.0) < min_reward: continue if float(get("held_out_breakdown", {}).get("executed_cleanly", 0.0)) < min_held_out_clean: continue lib.add( RepairExample( primitive_type=str(get("primitive_type", "unknown")), breakage_params=dict(get("info", {}).get("breakage_spec", {}).get("params", {})) if isinstance(get("info", {}), dict) else {}, error_signature=str(get("error_trace", "") or "")[:160], repair_diff=str(get("repair_completion", "") or get("info", {}).get("repair_diff", ""))[:2000], visible_reward=float(get("visible_reward", 0.0) or 0.0), held_out=dict(get("held_out_breakdown", {}) or {}), task_id=str(get("task_id", "")), ) ) return lib