| """Persisted "repair library" — the model's accumulated knowledge of |
| known breakage -> repair pairs. Curated from successful rollouts during |
| training. Loaded at inference time as a few-shot prefix when the agent |
| recognises a familiar error class. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| from dataclasses import asdict, dataclass, field |
| from pathlib import Path |
| from typing import Any, Optional |
|
|
|
|
@dataclass
class RepairExample:
    """One curated breakage -> repair pair taken from a successful rollout."""

    # Primitive the breakage was injected into; RepairLibrary.best_match filters on it.
    primitive_type: str
    # Parameters describing the injected breakage.
    breakage_params: dict[str, Any]
    # Error text the breakage produced.
    error_signature: str
    # The repair that was applied.
    repair_diff: str
    # Visible reward earned by the repairing rollout.
    visible_reward: float
    # Held-out metric breakdown for the rollout (metric name -> score).
    held_out: dict[str, float]
    # Originating task identifier; empty string when unknown.
    task_id: str = ""

    def signature_key(self) -> str:
        """Return a coarse grouping key: ``"<primitive_type>::<first 80 chars of error>"``."""
        error_head = self.error_signature[:80]
        return "{}::{}".format(self.primitive_type, error_head)
|
|
|
|
@dataclass
class RepairLibrary:
    """A JSON-persistable collection of ``RepairExample`` records.

    Examples are kept in insertion order; ``best_match`` breaks exact ties in
    favour of the earliest-added example.
    """

    examples: list[RepairExample] = field(default_factory=list)

    def add(self, example: RepairExample) -> None:
        """Append *example* to the library (no de-duplication is performed)."""
        self.examples.append(example)

    def best_match(
        self,
        primitive_type: str,
        error_text: str,
        min_overlap: float = 0.0,
    ) -> Optional[RepairExample]:
        """Return the stored example whose error best resembles *error_text*.

        Only examples whose ``primitive_type`` equals *primitive_type* are
        considered. They are ranked first by character-trigram overlap
        (``_ngram_overlap``) between their ``error_signature`` and
        *error_text*, then by ``visible_reward``; among exact ties the
        earliest-added example wins.

        Args:
            primitive_type: primitive type an example must match exactly.
            error_text: error text to compare against stored signatures.
            min_overlap: minimum overlap score the winning example must reach.
                The default 0.0 accepts any same-type example, even one with
                no overlap at all (the historical behaviour).

        Returns:
            The best-ranked example, or ``None`` when no example has the
            requested primitive type or the winner's overlap is below
            *min_overlap*.
        """
        candidates = [e for e in self.examples if e.primitive_type == primitive_type]
        if not candidates:
            return None

        def rank(example: RepairExample) -> tuple[float, float]:
            return (
                _ngram_overlap(example.error_signature, error_text),
                example.visible_reward,
            )

        # max() returns the first maximal element, i.e. the same winner as a
        # stable sort(reverse=True)[0], without sorting every candidate.
        best = max(candidates, key=rank)
        if _ngram_overlap(best.error_signature, error_text) < min_overlap:
            return None
        return best

    def to_dict(self) -> dict:
        """Return a JSON-ready snapshot of the library.

        Keys: ``"version"``, ``"examples"`` (each example via ``asdict``),
        ``"size"`` and ``"by_primitive"`` (example count per primitive type).
        Only ``"examples"`` is read back by ``load``.
        """
        return {
            "version": "1",
            "examples": [asdict(e) for e in self.examples],
            "size": len(self.examples),
            "by_primitive": _count_by_primitive(self.examples),
        }

    def save(self, path: str | Path) -> None:
        """Write the library to *path* as indented UTF-8 JSON.

        Parent directories are created as needed. The JSON is written to a
        sibling ``<name>.tmp`` file and then moved over *path* with
        ``Path.replace``, so an interrupted write cannot leave a truncated
        file at *path*.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Serialise first so a non-JSON-able example fails before touching disk.
        payload = json.dumps(self.to_dict(), indent=2)
        tmp = path.with_name(path.name + ".tmp")
        try:
            tmp.write_text(payload, encoding="utf-8")
            tmp.replace(path)
        except BaseException:
            # Best-effort removal of the partial temp file; re-raise the original error.
            try:
                tmp.unlink()
            except OSError:
                pass
            raise

    @classmethod
    def load(cls, path: str | Path) -> "RepairLibrary":
        """Load a library from a JSON file in the format written by ``save``.

        Only the ``"examples"`` list is read. Keys in an example record that
        ``RepairExample`` does not declare (e.g. from a newer schema) are
        ignored; a missing required field still raises ``TypeError``.
        """
        from dataclasses import fields

        known = {f.name for f in fields(RepairExample)}
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        examples = [
            RepairExample(**{k: v for k, v in record.items() if k in known})
            for record in data.get("examples", [])
        ]
        return cls(examples=examples)
|
|
|
|
def _ngram_overlap(a: str, b: str, n: int = 3) -> float:
    """Jaccard similarity of the lowercased character n-gram sets of *a* and *b*.

    The result lies in [0, 1]. It is 0.0 when either string is empty or, after
    lowercasing, shorter than *n* characters (it then has no n-grams).
    """
    def shingles(text: str) -> set[str]:
        lowered = text.lower()
        return set(lowered[start:start + n] for start in range(len(lowered) - n + 1))

    if not (a and b):
        return 0.0
    left = shingles(a)
    right = shingles(b)
    if not (left and right):
        return 0.0
    shared = left & right
    combined = left | right
    return len(shared) / max(1, len(combined))
|
|
|
|
def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]:
    """Tally examples per ``primitive_type``; keys appear in first-seen order."""
    tally: dict[str, int] = {}
    for example in examples:
        kind = example.primitive_type
        tally[kind] = tally[kind] + 1 if kind in tally else 1
    return tally
|
|
|
|
def _field_getter(rollout: Any):
    """Return a ``get(key, default=None)`` accessor for a dict or attribute-style rollout."""
    if isinstance(rollout, dict):
        return rollout.get
    return lambda key, default=None: getattr(rollout, key, default)


def _as_dict(value: Any) -> dict:
    """Return *value* if it is a dict, otherwise an empty dict (absent/``None`` fields)."""
    return value if isinstance(value, dict) else {}


def curate_from_rollouts(
    rollout_results: list,
    min_reward: float = 0.6,
    min_held_out_clean: float = 0.5,
) -> RepairLibrary:
    """Build a RepairLibrary from a list of rollout dicts/RolloutResults.

    A rollout is kept only when its ``visible_reward`` is at least
    *min_reward* and its ``held_out_breakdown["executed_cleanly"]`` score is
    at least *min_held_out_clean*. Missing or ``None`` numeric fields count as
    0.0; NaN scores are rejected. Missing or non-dict ``info`` /
    ``held_out_breakdown`` / ``breakage_spec`` / ``params`` values are treated
    as empty rather than raising.

    Args:
        rollout_results: rollouts as dicts, or objects exposing the same
            fields as attributes (``visible_reward``, ``held_out_breakdown``,
            ``primitive_type``, ``info``, ``error_trace``,
            ``repair_completion``, ``task_id``).
        min_reward: minimum visible reward for inclusion.
        min_held_out_clean: minimum held-out ``executed_cleanly`` score.

    Returns:
        A new RepairLibrary with one RepairExample per accepted rollout, in
        input order.
    """
    lib = RepairLibrary()
    for rollout in rollout_results:
        get = _field_getter(rollout)

        reward = float(get("visible_reward", 0.0) or 0.0)
        # `not (x >= t)` rather than `x < t` so that NaN fails the filter.
        if not (reward >= min_reward):
            continue
        held_out = _as_dict(get("held_out_breakdown", {}))
        clean = float(held_out.get("executed_cleanly", 0.0) or 0.0)
        if not (clean >= min_held_out_clean):
            continue

        info = _as_dict(get("info", {}))
        spec = _as_dict(info.get("breakage_spec", {}))
        primitive_type = get("primitive_type", "unknown")
        task_id = get("task_id", "")
        lib.add(
            RepairExample(
                primitive_type="unknown" if primitive_type is None else str(primitive_type),
                breakage_params=dict(_as_dict(spec.get("params", {}))),
                error_signature=str(get("error_trace", "") or "")[:160],
                # Prefer the model's completion; fall back to a diff recorded in info.
                repair_diff=str(
                    get("repair_completion", "") or info.get("repair_diff", "") or ""
                )[:2000],
                visible_reward=reward,
                held_out=dict(held_out),
                task_id="" if task_id is None else str(task_id),
            )
        )
    return lib
|
|