| """Persisted "repair library" — the model's accumulated knowledge of
|
| known breakage -> repair pairs. Curated from successful rollouts during
|
| training. Loaded at inference time as a few-shot prefix when the agent
|
| recognises a familiar error class.
|
| """
|
from __future__ import annotations

import json
from collections.abc import Mapping
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Optional
|
|
|
|
|
@dataclass
class RepairExample:
    """A single curated breakage -> repair pair.

    Instances are plain records; ``RepairLibrary`` stores them and
    ``curate_from_rollouts`` builds them from rollout results.
    """

    # Category of the broken primitive; library lookups filter on exact equality.
    primitive_type: str
    # Breakage parameters (curation copies them from info["breakage_spec"]["params"]).
    breakage_params: dict[str, Any]
    # Error text observed for this breakage (curation truncates it to 160 chars).
    error_signature: str
    # The repair text that was applied (curation truncates it to 2000 chars).
    repair_diff: str
    # Visible reward the rollout earned; used as a tie-breaker when matching.
    visible_reward: float
    # Held-out metric name -> score.
    held_out: dict[str, float]
    # Identifier of the originating task; empty string when unknown.
    task_id: str = ""

    def signature_key(self) -> str:
        """Return ``"<primitive_type>::<first 80 chars of error_signature>"``."""
        head = self.error_signature[:80]
        return "{}::{}".format(self.primitive_type, head)
|
|
|
|
|
@dataclass
class RepairLibrary:
    """In-memory collection of RepairExample entries with lookup and JSON persistence."""

    examples: list[RepairExample] = field(default_factory=list)

    def add(self, example: RepairExample) -> None:
        """Append *example* to the library (no de-duplication is performed)."""
        self.examples.append(example)

    def best_match(
        self,
        primitive_type: str,
        error_text: str,
        min_overlap: float = 0.0,
    ) -> Optional[RepairExample]:
        """Return the stored example that best matches an observed error.

        Only examples whose ``primitive_type`` equals *primitive_type* are
        considered. Among them, the one whose ``error_signature`` has the
        highest character-trigram overlap with *error_text* wins; ties are
        broken by higher ``visible_reward``, then by earliest insertion.

        Args:
            primitive_type: Primitive type a candidate must match exactly.
            error_text: Observed error text compared against each signature.
            min_overlap: Minimum overlap score a candidate must reach. The
                default (0.0) accepts any same-type example, even one whose
                error text shares nothing with *error_text*.

        Returns:
            The best-matching example, or None if no candidate qualifies.
        """
        ranked = [
            ((_ngram_overlap(e.error_signature, error_text), e.visible_reward), e)
            for e in self.examples
            if e.primitive_type == primitive_type
        ]
        eligible = [item for item in ranked if item[0][0] >= min_overlap]
        if not eligible:
            return None
        # max() returns the first of several equal keys, so the earliest-added
        # example wins exact ties.
        return max(eligible, key=lambda item: item[0])[1]

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of the library (format version "1").

        Besides the serialized examples, the snapshot carries ``size`` and
        ``by_primitive`` summary counts; ``load`` reads only ``examples``.
        """
        return {
            "version": "1",
            "examples": [asdict(e) for e in self.examples],
            "size": len(self.examples),
            "by_primitive": _count_by_primitive(self.examples),
        }

    def save(self, path: str | Path) -> None:
        """Write the library to *path* as indented UTF-8 JSON.

        Parent directories are created as needed. The JSON is written to a
        sibling ``<name>.tmp`` file first and then moved over *path* with
        ``Path.replace``. An interrupted save therefore never leaves a
        truncated library at *path*.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Serialize before touching the filesystem so encoding errors leave no files.
        payload = json.dumps(self.to_dict(), indent=2)
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(path)
        except BaseException:
            # Best effort: don't leave a half-written temp file behind.
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    @classmethod
    def load(cls, path: str | Path) -> "RepairLibrary":
        """Read a library previously written by ``save``.

        Only the ``examples`` list is used; each entry is passed as keyword
        arguments to RepairExample. Unknown or missing required keys therefore
        raise TypeError. A file without ``examples`` yields an empty library.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        examples = [RepairExample(**e) for e in data.get("examples", [])]
        return cls(examples=examples)
|
|
|
|
|
def _ngram_overlap(a: str, b: str, n: int = 3) -> float:
    """Jaccard similarity between the lower-cased character n-gram sets of *a* and *b*.

    Returns 0.0 when either input is empty, or when either lower-cased string
    is shorter than *n* and so yields no n-grams.
    """
    if not (a and b):
        return 0.0

    def shingles(raw: str) -> set[str]:
        lowered = raw.lower()
        last_start = len(lowered) - n
        return {lowered[start : start + n] for start in range(last_start + 1)}

    left = shingles(a)
    right = shingles(b)
    if not (left and right):
        return 0.0
    shared = len(left & right)
    # Union size via inclusion-exclusion.
    union = len(left) + len(right) - shared
    return shared / max(1, union)
|
|
|
|
|
def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]:
    """Tally examples per ``primitive_type``.

    Keys appear in order of first occurrence.
    """
    tally: dict[str, int] = {}
    for example in examples:
        key = example.primitive_type
        tally[key] = tally[key] + 1 if key in tally else 1
    return tally
|
|
|
|
|
def _field_getter(rollout: Any):
    """Return a one-argument accessor that reads *rollout* as a mapping or by attribute.

    Missing keys or attributes read as None.
    """
    if isinstance(rollout, Mapping):
        return rollout.get
    return lambda key: getattr(rollout, key, None)


def _as_dict(value: Any) -> dict:
    """Return a shallow dict copy of *value* if it is a mapping, else an empty dict."""
    return dict(value) if isinstance(value, Mapping) else {}


def _as_float(value: Any) -> float:
    """Coerce *value* to float, treating None and other falsy values as 0.0."""
    return float(value or 0.0)


def _as_text(value: Any, default: str = "") -> str:
    """Return ``str(value)``, or *default* when *value* is None (never the string "None")."""
    return default if value is None else str(value)


def curate_from_rollouts(
    rollout_results: list,
    min_reward: float = 0.6,
    min_held_out_clean: float = 0.5,
) -> RepairLibrary:
    """Build a RepairLibrary from a list of rollout dicts/RolloutResults.

    Each rollout may be a mapping or an object that exposes the same names as
    attributes. A rollout is kept only if its ``visible_reward`` is at least
    *min_reward* and its ``held_out_breakdown["executed_cleanly"]`` score is at
    least *min_held_out_clean*. Missing or None fields fall back to neutral
    defaults (0.0, {}, "" or "unknown") instead of raising.

    Args:
        rollout_results: Rollout mappings and/or attribute-style result objects.
        min_reward: Minimum visible reward required to curate a rollout.
        min_held_out_clean: Minimum held-out ``executed_cleanly`` score required.

    Returns:
        A new RepairLibrary with one RepairExample per accepted rollout, in
        input order.
    """
    lib = RepairLibrary()
    for r in rollout_results:
        get = _field_getter(r)

        reward = _as_float(get("visible_reward"))
        if reward < min_reward:
            continue
        held_out = _as_dict(get("held_out_breakdown"))
        if _as_float(held_out.get("executed_cleanly")) < min_held_out_clean:
            continue

        info = _as_dict(get("info"))
        breakage_spec = _as_dict(info.get("breakage_spec"))
        # Prefer the model's own completion; fall back to the diff recorded in info.
        repair_diff = _as_text(get("repair_completion")) or _as_text(info.get("repair_diff"))

        lib.add(
            RepairExample(
                primitive_type=_as_text(get("primitive_type"), "unknown"),
                breakage_params=_as_dict(breakage_spec.get("params")),
                error_signature=_as_text(get("error_trace"))[:160],
                repair_diff=repair_diff[:2000],
                visible_reward=reward,
                held_out=held_out,
                task_id=_as_text(get("task_id")),
            )
        )
    return lib
|
|
|