File size: 4,127 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Persisted "repair library" — the model's accumulated knowledge of
known breakage -> repair pairs. Curated from successful rollouts during
training. Loaded at inference time as a few-shot prefix when the agent
recognises a familiar error class.
"""
from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Optional


@dataclass
class RepairExample:
    """A single curated breakage -> repair pair taken from one rollout."""

    # Category of the broken primitive; also the prefix of signature_key().
    primitive_type: str
    # Parameters describing how the primitive was broken.
    breakage_params: dict[str, Any]
    # Error text compared against new failures when looking up a repair.
    error_signature: str
    # The repair that was applied for this breakage.
    repair_diff: str
    # Visible reward achieved by the source rollout.
    visible_reward: float
    # Held-out metric breakdown of the source rollout (metric name -> score).
    held_out: dict[str, float]
    # Identifier of the originating task; empty string when unknown.
    task_id: str = ""

    def signature_key(self) -> str:
        """Return ``"<primitive_type>::<first 80 chars of error_signature>"``.

        Useful as a coarse key for grouping examples that share a primitive
        type and an error prefix.
        """
        error_head = self.error_signature[:80]
        return f"{self.primitive_type}::{error_head}"


@dataclass
class RepairLibrary:
    """Ordered collection of RepairExample records with JSON persistence."""

    # Examples in insertion order; order decides ties in best_match().
    examples: list[RepairExample] = field(default_factory=list)

    def add(self, example: RepairExample) -> None:
        """Append *example* to the library (no de-duplication is performed)."""
        self.examples.append(example)

    def best_match(self, primitive_type: str, error_text: str) -> Optional[RepairExample]:
        """Return the best stored example for *primitive_type* and *error_text*.

        Only examples whose ``primitive_type`` equals *primitive_type* are
        considered. They are ranked by character n-gram overlap between their
        ``error_signature`` and *error_text*, then by ``visible_reward``.
        When several candidates tie on both keys, the earliest-added one wins.

        Note: the primitive type is the only filter, so a candidate is returned
        even if its overlap score is 0.0.

        Returns:
            The highest-ranked example, or None if no example has a matching
            primitive type.
        """
        candidates = [e for e in self.examples if e.primitive_type == primitive_type]
        if not candidates:
            return None
        # max() is O(n) and, like a stable descending sort, returns the first
        # of several equal-key items.
        return max(
            candidates,
            key=lambda e: (
                _ngram_overlap(e.error_signature, error_text),
                e.visible_reward,
            ),
        )

    def to_dict(self) -> dict:
        """Serialise the library to a JSON-compatible dict.

        Besides the examples, the dict carries a format ``version`` and two
        informational summaries (``size`` and ``by_primitive``) that load()
        does not read.
        """
        return {
            "version": "1",
            "examples": [asdict(e) for e in self.examples],
            "size": len(self.examples),
            "by_primitive": _count_by_primitive(self.examples),
        }

    def save(self, path: str | Path) -> None:
        """Write to_dict() as indented UTF-8 JSON to *path*.

        Parent directories are created as needed; an existing file is
        overwritten.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> "RepairLibrary":
        """Read a library previously written by save().

        Only the ``examples`` entry is used. A missing or null ``examples``
        entry yields an empty library.

        Raises:
            FileNotFoundError: if *path* does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
            TypeError: if an example has keys that do not match RepairExample.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        # `or []` also tolerates an explicit null.
        raw_examples = data.get("examples") or []
        examples = [RepairExample(**e) for e in raw_examples]
        return cls(examples=examples)


def _ngram_overlap(a: str, b: str, n: int = 3) -> float:
    """Return the Jaccard similarity of the character n-gram sets of *a* and *b*.

    The comparison is case-insensitive because both texts are lower-cased
    first. A non-empty text shorter than *n* is treated as a single gram (the
    whole text), so identical short texts score 1.0 instead of 0.0.

    Args:
        a: First text. An empty value yields 0.0.
        b: Second text. An empty value yields 0.0.
        n: Gram length; must be at least 1.

    Returns:
        A float in [0.0, 1.0].

    Raises:
        ValueError: if *n* < 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if not a or not b:
        return 0.0

    def grams(text: str) -> set[str]:
        text = text.lower()
        if len(text) < n:
            # Too short to contain any n-gram: fall back to the whole text so
            # that short but identical inputs still match.
            return {text}
        return {text[i : i + n] for i in range(len(text) - n + 1)}

    ga, gb = grams(a), grams(b)
    # Both sets are non-empty here, so the union is never empty.
    return len(ga & gb) / len(ga | gb)


def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]:
    """Return how many *examples* exist for each ``primitive_type``.

    Keys appear in the order in which each primitive type is first seen.
    """
    tally: dict[str, int] = {}
    for example in examples:
        kind = example.primitive_type
        if kind in tally:
            tally[kind] += 1
        else:
            tally[kind] = 1
    return tally


def _read_field(record: Any, key: str, default: Any = None) -> Any:
    """Read *key* from a dict rollout, or the same-named attribute of an object rollout."""
    if isinstance(record, dict):
        return record.get(key, default)
    return getattr(record, key, default)


def _as_dict(value: Any) -> dict:
    """Return *value* if it is a dict, else an empty dict (covers None / missing / wrong type)."""
    return value if isinstance(value, dict) else {}


def curate_from_rollouts(
    rollout_results: list,
    min_reward: float = 0.6,
    min_held_out_clean: float = 0.5,
) -> RepairLibrary:
    """Build a RepairLibrary from a list of rollout dicts/RolloutResults.

    Each rollout may be a dict (fields read with ``.get``) or any object
    (fields read with ``getattr``). A rollout is kept only when its
    ``visible_reward`` is >= *min_reward* and its
    ``held_out_breakdown["executed_cleanly"]`` is >= *min_held_out_clean*.
    Fields that are missing, None or of the wrong container type are treated
    as 0 / empty instead of raising. This applies to ``held_out_breakdown``,
    ``info``, ``info["breakage_spec"]``, its ``params`` and the numeric scores.

    Args:
        rollout_results: Rollouts to curate.
        min_reward: Minimum visible reward for a rollout to be kept.
        min_held_out_clean: Minimum held-out ``executed_cleanly`` score.

    Returns:
        A new RepairLibrary with one RepairExample per kept rollout, in input
        order. ``error_signature`` is truncated to 160 characters and
        ``repair_diff`` to 2000.
    """
    lib = RepairLibrary()
    for r in rollout_results:
        visible_reward = float(_read_field(r, "visible_reward", 0.0) or 0.0)
        if visible_reward < min_reward:
            continue

        held_out = _as_dict(_read_field(r, "held_out_breakdown", {}))
        if float(held_out.get("executed_cleanly", 0.0) or 0.0) < min_held_out_clean:
            continue

        info = _as_dict(_read_field(r, "info", {}))
        breakage_spec = _as_dict(info.get("breakage_spec"))
        # Prefer the rollout's own completion; fall back to the diff stored in
        # info. The final `or ""` keeps a None diff from becoming the text "None".
        repair = _read_field(r, "repair_completion", "") or info.get("repair_diff") or ""

        lib.add(
            RepairExample(
                primitive_type=str(_read_field(r, "primitive_type", "unknown")),
                breakage_params=dict(_as_dict(breakage_spec.get("params"))),
                error_signature=str(_read_field(r, "error_trace", "") or "")[:160],
                repair_diff=str(repair)[:2000],
                visible_reward=visible_reward,
                held_out=dict(held_out),
                task_id=str(_read_field(r, "task_id", "")),
            )
        )
    return lib