File size: 3,941 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""System and user prompts for the two RL roles.

Both roles are trained from the same base policy (Qwen-2.5-Coder-7B) with
LoRA adapters per role, so role prompts are the only thing distinguishing
them at inference time. Keep them concise — every token is a token of GPU
budget during GRPO rollouts.
"""
from __future__ import annotations

from typing import Iterable


# One-line, human-readable summary of each breakage primitive, keyed by the
# primitive_type name the Drift Generator is asked to emit. The parenthesised
# suffix names the drift category each primitive falls under. Insertion order
# is significant: list_primitive_descriptions() yields lines in this order.
# NOTE(review): DRIFT_GENERATOR_SYSTEM_PROMPT lists the same eight type names
# by hand rather than deriving them from this mapping — keep both in sync.
# NOTE(review): the ChangeArgumentSignature summary only mentions removing a
# kwarg, while the generator's schema also carries added_arg/added_value —
# confirm the wording against the primitive's implementation.
PRIMITIVE_DESCRIPTIONS = dict(
    RenameApiCall="Rename a function/method call (api_drift)",
    DeprecateImport="Change an import path (import_drift)",
    ChangeArgumentSignature="Remove an expected kwarg from a call (api_drift)",
    ModifyConfigField="Change a config-class default (config_drift)",
    RestructureDatasetSchema="Rename a dataset column reference (dataset_drift)",
    ChangeTokenizerBehavior="Change tokenizer call kwargs (tokenizer_drift)",
    RemoveDeprecatedMethod="Remove a method, leaving a sentinel _DEPRECATED suffix (api_drift)",
    ChangeReturnType="Function returns a different structure (api_drift)",
)

# System prompt for the Drift Generator role. It asks the policy for exactly
# one JSON object of the form {"primitive_type": ..., "params": {...}} and
# spells out the parameter schema of every primitive type inline.
# Per the module docstring, role prompts are what distinguish the roles at
# inference time, so any edit to this text changes what the policy sees.
# NOTE(review): the type list below (and the "8 types" count) is maintained by
# hand, separately from PRIMITIVE_DESCRIPTIONS — keep the two in sync when a
# primitive is added or renamed.
# NOTE(review): the prompt says the generator sees the script and the target
# category; render_drift_generator_prompt additionally supplies library
# versions in the user turn.
DRIFT_GENERATOR_SYSTEM_PROMPT = """You are the Drift Generator.
You see a working HuggingFace training script and the curriculum target category.
Output exactly one JSON object describing a breakage primitive that simulates
realistic library version drift. The primitive must:
1. Be PLAUSIBLE — match the kind of breakage that happens between real
   transformers/datasets/trl releases.
2. Be SOLVABLE — the Repair Agent should be able to fix it from the error trace alone.
3. Match the requested target_category.

Output schema:
{"primitive_type": "<one of the 8 types>", "params": { ... }}

Available primitive types and parameter schemas:
- RenameApiCall: {"old_name": str, "new_name": str}
- DeprecateImport: {"old_module": str, "new_module": str}
- ChangeArgumentSignature: {"function_name": str, "removed_arg": str, "added_arg": str, "added_value": str}
- ModifyConfigField: {"config_class": str, "field_name": str, "new_value": str}
- RestructureDatasetSchema: {"old_column": str, "new_column": str}
- ChangeTokenizerBehavior: {"old_kwarg": str, "old_value": str, "new_kwarg": str, "new_value": str}
- RemoveDeprecatedMethod: {"class_name": str, "method_name": str, "replacement": str}
- ChangeReturnType: {"function_name": str, "old_access": str, "new_access": str}

Output ONLY the JSON object — no commentary, no markdown fences.
"""


# System prompt for the Repair Agent role. The model must answer with a bare
# unified diff against train.py (`--- a/train.py` / `+++ b/train.py` headers,
# `@@ ... @@` hunks), or with an empty diff when it judges the error unfixable.
# Rule 2 explicitly forbids "fixes" that hide the error (bare except,
# monkey-patching, sys.exit) rather than repairing the script.
# The matching user turn is built by render_repair_agent_prompt (versions,
# category hint, broken script, error trace).
REPAIR_AGENT_SYSTEM_PROMPT = """You are the Repair Agent.
You see a broken HuggingFace training script, an error trace, and the current
library version snapshot. Output ONLY a unified diff that fixes the script.

Rules:
1. Use canonical unified-diff format with `--- a/train.py` / `+++ b/train.py`
   headers and `@@ ... @@` hunk markers.
2. Make the MINIMAL change that resolves the error AND preserves the original
   training intent. Do NOT add bare-except blocks, monkey-patches, or sys.exit
   calls.
3. Do NOT add any prose, markdown fences, or thinking output — diff only.
4. If the error is unfixable, output an empty diff.
"""


def _format_library_versions(library_versions: dict | None) -> str:
    """Render a version snapshot as ``"pkg=ver, pkg=ver"`` in insertion order.

    Shared by both role renderers so the two prompts always format the
    snapshot identically. ``None`` or an empty mapping renders as ``""``.
    """
    if not library_versions:
        return ""
    return ", ".join(
        f"{package}={version}" for package, version in library_versions.items()
    )


def render_drift_generator_prompt(
    script: str, target_category: str, library_versions: dict | None
) -> str:
    """Build the user-turn prompt for the Drift Generator role.

    Args:
        script: Source of the working training script; embedded verbatim
            inside a ```python fence.
        target_category: Curriculum drift category the emitted primitive
            must match.
        library_versions: Package -> version snapshot, rendered as
            ``pkg=ver`` pairs in insertion order. ``None`` is treated as an
            empty snapshot.

    Returns:
        The prompt text, ending with the cue line for the JSON primitive.
    """
    # NOTE(review): a script that itself contains a ``` line would close the
    # fence early; presumably training scripts never do — confirm if needed.
    versions_str = _format_library_versions(library_versions)
    return f"""Target category: {target_category}
Library versions: {versions_str}

Working script:
```python
{script}
```

Output JSON breakage primitive:"""


def render_repair_agent_prompt(
    broken_script: str,
    error_trace: str,
    library_versions: dict | None,
    target_category: str = "",
) -> str:
    """Build the user-turn prompt for the Repair Agent role.

    Args:
        broken_script: Source of the script after the breakage primitive was
            applied; embedded verbatim inside a ```python fence.
        error_trace: Error output from running the broken script; embedded
            as-is (unfenced).
        library_versions: Package -> version snapshot, rendered as
            ``pkg=ver`` pairs in insertion order. ``None`` is treated as an
            empty snapshot.
        target_category: Optional drift-category hint; an empty string is
            rendered as ``unknown``.

    Returns:
        The prompt text, ending with the cue line for the unified diff.
    """
    versions_str = _format_library_versions(library_versions)
    return f"""Library versions: {versions_str}
Target category hint: {target_category or 'unknown'}

Broken script:
```python
{broken_script}
```

Error trace:
{error_trace}

Output unified diff (no prose, no fences):"""


def list_primitive_descriptions() -> Iterable[str]:
    """Return ``"- Name: summary"`` bullet lines for every known primitive.

    One line per entry of PRIMITIVE_DESCRIPTIONS, in the mapping's insertion
    order. The result is a one-shot generator: materialise it (for example
    with ``list(...)``) if it must be iterated more than once.
    """
    catalogue = PRIMITIVE_DESCRIPTIONS.items()
    return (f"- {primitive}: {summary}" for primitive, summary in catalogue)