"""Persisted "repair library" — the model's accumulated knowledge of
known breakage -> repair pairs. Curated from successful rollouts during
training. Loaded at inference time as a few-shot prefix when the agent
recognises a familiar error class.
"""
from __future__ import annotations
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Optional
@dataclass
class RepairExample:
    """One stored breakage -> repair pair.

    Instances are plain records: they are serialised with
    ``dataclasses.asdict`` and rebuilt from keyword arguments, so the
    field names double as the on-disk JSON keys.
    """

    # Category of breakage; examples are matched on exact equality of this.
    primitive_type: str
    # Parameters describing how the breakage was produced.
    breakage_params: dict[str, Any]
    # Error text associated with the breakage; compared against new errors.
    error_signature: str
    # Text of the repair that was applied.
    repair_diff: str
    # Reward recorded for the repair; used as a ranking tie-breaker.
    visible_reward: float
    # Named float scores recorded alongside the visible reward.
    held_out: dict[str, float]
    # Optional identifier of the originating task; empty when unknown.
    task_id: str = ""

    def signature_key(self) -> str:
        """Return ``"<primitive_type>::<first 80 chars of error_signature>"``."""
        signature_head = self.error_signature[:80]
        return "{}::{}".format(self.primitive_type, signature_head)
@dataclass
class RepairLibrary:
    """Ordered collection of ``RepairExample`` records with JSON persistence."""

    # Stored examples, in insertion order.
    examples: list[RepairExample] = field(default_factory=list)

    def add(self, example: RepairExample) -> None:
        """Append ``example`` to the end of the library."""
        self.examples.append(example)

    def best_match(self, primitive_type: str, error_text: str) -> Optional[RepairExample]:
        """Return the best stored example for ``primitive_type``.

        Only examples whose ``primitive_type`` equals the argument are
        considered. They are ranked by n-gram overlap between their
        ``error_signature`` and ``error_text``, with ``visible_reward``
        breaking ties; on a full tie the earliest-added example wins.
        No minimum overlap is required, so a same-type example is returned
        even when the overlap is zero. Returns ``None`` when no example has
        the requested type.
        """
        same_type = [ex for ex in self.examples if ex.primitive_type == primitive_type]
        if not same_type:
            return None

        def rank(ex: RepairExample) -> tuple:
            return (_ngram_overlap(ex.error_signature, error_text), ex.visible_reward)

        # max() keeps the first of several equally ranked candidates.
        return max(same_type, key=rank)

    def to_dict(self) -> dict:
        """Return a JSON-serialisable snapshot of the library.

        Besides the examples themselves, the snapshot carries a format
        version, the example count and a per-``primitive_type`` tally.
        """
        serialised = [asdict(ex) for ex in self.examples]
        snapshot = {
            "version": "1",
            "examples": serialised,
            "size": len(self.examples),
            "by_primitive": _count_by_primitive(self.examples),
        }
        return snapshot

    def save(self, path: str | Path) -> None:
        """Write the library to ``path`` as indented UTF-8 JSON.

        Missing parent directories are created first.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        text = json.dumps(self.to_dict(), indent=2)
        target.write_text(text, encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> "RepairLibrary":
        """Rebuild a library from the JSON file at ``path``.

        Only the ``"examples"`` entry is read (a missing entry yields an
        empty library); each item is passed to ``RepairExample`` as keyword
        arguments.
        """
        raw = Path(path).read_text(encoding="utf-8")
        payload = json.loads(raw)
        restored = [RepairExample(**record) for record in payload.get("examples", [])]
        return cls(examples=restored)
def _ngram_overlap(a: str, b: str, n: int = 3) -> float:
    """Return the Jaccard similarity of the character n-gram sets of two texts.

    Both texts are lower-cased before being split into overlapping
    substrings of length ``n``. The result is ``|A & B| / |A | B|``, a value
    between 0.0 (nothing shared) and 1.0 (identical gram sets).

    Args:
        a: First text.
        b: Second text.
        n: Gram length in characters.

    Returns:
        0.0 if either text is empty, otherwise the Jaccard similarity.
    """
    if not a or not b:
        return 0.0

    def grams(text: str) -> set[str]:
        lowered = text.lower()
        if len(lowered) < n:
            # A text too short to contain a full n-gram is kept as one gram,
            # so short identical texts still compare as equal instead of
            # collapsing to an empty set and scoring 0.0.
            return {lowered}
        return {lowered[i : i + n] for i in range(len(lowered) - n + 1)}

    ga, gb = grams(a), grams(b)
    # Both sets are non-empty here, so the union is never empty.
    return len(ga & gb) / len(ga | gb)
def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]:
    """Tally how many of ``examples`` fall under each ``primitive_type``.

    Keys appear in the order their type is first seen.
    """
    tally: dict[str, int] = {}
    for example in examples:
        kind = example.primitive_type
        if kind in tally:
            tally[kind] += 1
        else:
            tally[kind] = 1
    return tally
def curate_from_rollouts(
    rollout_results: list,
    min_reward: float = 0.6,
    min_held_out_clean: float = 0.5,
) -> RepairLibrary:
    """Build a RepairLibrary from a list of rollout dicts/RolloutResults.

    Each rollout may be a dict (read by key) or any other object (read by
    attribute). A rollout is kept only when both thresholds are met; missing
    or ``None`` values count as 0.0, and nested fields that are missing or
    ``None`` are treated as empty rather than raising.

    Args:
        rollout_results: Rollouts to filter and convert.
        min_reward: Minimum ``visible_reward`` a rollout needs to be kept.
        min_held_out_clean: Minimum ``held_out_breakdown["executed_cleanly"]``
            a rollout needs to be kept.

    Returns:
        A new library holding one ``RepairExample`` per accepted rollout, in
        input order.
    """
    lib = RepairLibrary()
    for r in rollout_results:
        # Uniform accessor so dict and object rollouts share the code below.
        if isinstance(r, dict):
            get = r.get
        else:
            get = lambda k, default=None, _r=r: getattr(_r, k, default)  # noqa: E731

        visible_reward = float(get("visible_reward", 0.0) or 0.0)
        if visible_reward < min_reward:
            continue

        # A field that is present but None must not break the .get() below.
        held_out = get("held_out_breakdown", {}) or {}
        if float(held_out.get("executed_cleanly", 0.0) or 0.0) < min_held_out_clean:
            continue

        # "info" is only trusted when it is a dict; the same guard now covers
        # both the breakage params and the fallback repair diff.
        info = get("info", {})
        if not isinstance(info, dict):
            info = {}
        breakage_spec = info.get("breakage_spec") or {}
        params = breakage_spec.get("params") or {}

        # Prefer the completion; fall back to info's diff, never to "None".
        repair_text = get("repair_completion", "") or info.get("repair_diff", "") or ""

        lib.add(
            RepairExample(
                primitive_type=str(get("primitive_type", "unknown")),
                breakage_params=dict(params),
                # Stored texts are capped at 160 and 2000 characters.
                error_signature=str(get("error_trace", "") or "")[:160],
                repair_diff=str(repair_text)[:2000],
                visible_reward=visible_reward,
                held_out=dict(held_out),
                task_id=str(get("task_id", "")),
            )
        )
    return lib
|