File size: 4,043 Bytes
b0fbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""System and user prompts for the two RL roles (Drift Generator and Repair Agent).

Both roles are trained from the same base policy (Qwen2.5-Coder-7B) with
LoRA adapters per role, so role prompts are the only thing distinguishing
them at inference time. Keep them concise — every prompt token costs GPU
budget during GRPO rollouts.
"""
from __future__ import annotations

from typing import Iterable


# One-line summary of each breakage primitive, with its drift category in
# parentheses. dict preserves insertion order, so iterating this mapping
# yields the primitives in the order listed here.
PRIMITIVE_DESCRIPTIONS = dict(
    RenameApiCall="Rename a function/method call (api_drift)",
    DeprecateImport="Change an import path (import_drift)",
    ChangeArgumentSignature="Remove an expected kwarg from a call (api_drift)",
    ModifyConfigField="Change a config-class default (config_drift)",
    RestructureDatasetSchema="Rename a dataset column reference (dataset_drift)",
    ChangeTokenizerBehavior="Change tokenizer call kwargs (tokenizer_drift)",
    RemoveDeprecatedMethod="Remove a method, leaving a sentinel _DEPRECATED suffix (api_drift)",
    ChangeReturnType="Function returns a different structure (api_drift)",
)

# System prompt for the Drift Generator role: the model must answer with a
# single JSON object naming one of the 8 primitive types and its params.
# Lines are single-spaced on purpose; a blank line only separates sections,
# because every extra newline is a rollout token.
# NOTE(review): requirement 3 asks the model to match target_category, but the
# schema list below never says which category each primitive belongs to (that
# mapping only lives in PRIMITIVE_DESCRIPTIONS). Consider adding the category
# tags if the generator struggles to pick a matching primitive.
DRIFT_GENERATOR_SYSTEM_PROMPT = """You are the Drift Generator.
You see a working HuggingFace training script and the curriculum target category.
Output exactly one JSON object describing a breakage primitive that simulates
realistic library version drift. The primitive must:
1. Be PLAUSIBLE — match the kind of breakage that happens between real
   transformers/datasets/trl releases.
2. Be SOLVABLE — the Repair Agent should be able to fix it from the error trace alone.
3. Match the requested target_category.

Output schema:
{"primitive_type": "<one of the 8 types>", "params": { ... }}

Available primitive types and parameter schemas:
- RenameApiCall: {"old_name": str, "new_name": str}
- DeprecateImport: {"old_module": str, "new_module": str}
- ChangeArgumentSignature: {"function_name": str, "removed_arg": str, "added_arg": str, "added_value": str}
- ModifyConfigField: {"config_class": str, "field_name": str, "new_value": str}
- RestructureDatasetSchema: {"old_column": str, "new_column": str}
- ChangeTokenizerBehavior: {"old_kwarg": str, "old_value": str, "new_kwarg": str, "new_value": str}
- RemoveDeprecatedMethod: {"class_name": str, "method_name": str, "replacement": str}
- ChangeReturnType: {"function_name": str, "old_access": str, "new_access": str}

Output ONLY the JSON object — no commentary, no markdown fences.
"""


# System prompt for the Repair Agent role: the model must answer with a bare
# unified diff against train.py (rules 1-4 are part of the prompt text).
# Lines are single-spaced on purpose; a blank line only separates sections,
# because every extra newline is a rollout token.
REPAIR_AGENT_SYSTEM_PROMPT = """You are the Repair Agent.
You see a broken HuggingFace training script, an error trace, and the current
library version snapshot. Output ONLY a unified diff that fixes the script.

Rules:
1. Use canonical unified-diff format with `--- a/train.py` / `+++ b/train.py`
   headers and `@@ ... @@` hunk markers.
2. Make the MINIMAL change that resolves the error AND preserves the original
   training intent. Do NOT add bare-except blocks, monkey-patches, or sys.exit
   calls.
3. Do NOT add any prose, markdown fences, or thinking output — diff only.
4. If the error is unfixable, output an empty diff.
"""


def render_drift_generator_prompt(
    script: str, target_category: str, library_versions: dict
) -> str:
    """Build the user prompt for the Drift Generator role.

    Args:
        script: Source of the working training script to be broken; inserted
            verbatim inside a ```python fence.
        target_category: Curriculum drift category the primitive must match.
        library_versions: Mapping of library name to version, rendered as
            comma-separated ``name=version`` pairs in iteration order.
            An empty mapping (or None) is rendered as ``unknown``.

    Returns:
        The prompt text, ending with the cue for the JSON breakage primitive.
    """
    # Same "unknown" fallback the repair prompt uses for its category hint, so
    # an empty mapping does not leave a dangling "Library versions: " label.
    versions_str = (
        ", ".join(f"{k}={v}" for k, v in (library_versions or {}).items())
        or "unknown"
    )
    # Single newlines between lines; one blank line between sections only.
    return f"""Target category: {target_category}
Library versions: {versions_str}

Working script:
```python
{script}
```

Output JSON breakage primitive:"""


def render_repair_agent_prompt(
    broken_script: str,
    error_trace: str,
    library_versions: dict,
    target_category: str = "",
) -> str:
    """Build the user prompt for the Repair Agent role.

    Args:
        broken_script: Source of the training script after the drift was
            applied; inserted verbatim inside a ```python fence.
        error_trace: Error trace for ``broken_script`` (presumably captured
            from running it); inserted verbatim.
        library_versions: Mapping of library name to version, rendered as
            comma-separated ``name=version`` pairs in iteration order.
            An empty mapping (or None) is rendered as ``unknown``.
        target_category: Optional drift-category hint; ``unknown`` is shown
            when it is empty.

    Returns:
        The prompt text, ending with the cue for the unified diff.
    """
    # Mirror the category-hint fallback so an empty mapping does not leave a
    # dangling "Library versions: " label.
    versions_str = (
        ", ".join(f"{k}={v}" for k, v in (library_versions or {}).items())
        or "unknown"
    )
    # Single newlines between lines; one blank line between sections only.
    return f"""Library versions: {versions_str}
Target category hint: {target_category or 'unknown'}

Broken script:
```python
{broken_script}
```

Error trace:
{error_trace}

Output unified diff (no prose, no fences):"""


def list_primitive_descriptions() -> Iterable[str]:
    """Return ``"- <Primitive>: <description>"`` bullet lines, one per primitive.

    Entries follow the iteration order of ``PRIMITIVE_DESCRIPTIONS``. The
    result is a one-shot generator; wrap it in ``list()`` to iterate it twice.
    """
    pairs = PRIMITIVE_DESCRIPTIONS.items()
    return (f"- {name}: {description}" for name, description in pairs)