| """System and user prompts for the two RL roles. | |
| Both roles are trained from the same base policy (Qwen-2.5-Coder-7B) with | |
| LoRA adapters per role, so role prompts are the only thing distinguishing | |
them at inference time. Keep them concise — every prompt token costs GPU
budget during GRPO rollouts.
| """ | |
| from __future__ import annotations | |
| from typing import Iterable | |
# One-line, human-readable summary for each breakage primitive, keyed by the
# primitive type name. The parenthesised suffix of every summary is the drift
# category tag associated with that primitive. Insertion order is preserved
# and is the order in which the entries are listed to callers.
PRIMITIVE_DESCRIPTIONS: dict[str, str] = {
    # Call-site and import-level breakages.
    "RenameApiCall": "Rename a function/method call (api_drift)",
    "DeprecateImport": "Change an import path (import_drift)",
    "ChangeArgumentSignature": "Remove an expected kwarg from a call (api_drift)",
    # Config, dataset and tokenizer breakages.
    "ModifyConfigField": "Change a config-class default (config_drift)",
    "RestructureDatasetSchema": "Rename a dataset column reference (dataset_drift)",
    "ChangeTokenizerBehavior": "Change tokenizer call kwargs (tokenizer_drift)",
    # Removed or reshaped APIs.
    "RemoveDeprecatedMethod": "Remove a method, leaving a sentinel _DEPRECATED suffix (api_drift)",
    "ChangeReturnType": "Function returns a different structure (api_drift)",
}
# System prompt for the Drift Generator role. It asks for exactly one JSON
# object of the form {"primitive_type": ..., "params": {...}} and lists the
# parameter schema for each of the eight primitive types (the same names used
# as keys of PRIMITIVE_DESCRIPTIONS).
#
# NOTE: this is a plain (non-f) string, so the literal braces in the schema
# lines need no escaping. The dashes below are real em dashes; they replace
# mis-encoded characters that would otherwise be sent to the model verbatim.
DRIFT_GENERATOR_SYSTEM_PROMPT = """You are the Drift Generator.
You see a working HuggingFace training script and the curriculum target category.
Output exactly one JSON object describing a breakage primitive that simulates
realistic library version drift. The primitive must:
1. Be PLAUSIBLE — match the kind of breakage that happens between real
transformers/datasets/trl releases.
2. Be SOLVABLE — the Repair Agent should be able to fix it from the error trace alone.
3. Match the requested target_category.
Output schema:
{"primitive_type": "<one of the 8 types>", "params": { ... }}
Available primitive types and parameter schemas:
- RenameApiCall: {"old_name": str, "new_name": str}
- DeprecateImport: {"old_module": str, "new_module": str}
- ChangeArgumentSignature: {"function_name": str, "removed_arg": str, "added_arg": str, "added_value": str}
- ModifyConfigField: {"config_class": str, "field_name": str, "new_value": str}
- RestructureDatasetSchema: {"old_column": str, "new_column": str}
- ChangeTokenizerBehavior: {"old_kwarg": str, "old_value": str, "new_kwarg": str, "new_value": str}
- RemoveDeprecatedMethod: {"class_name": str, "method_name": str, "replacement": str}
- ChangeReturnType: {"function_name": str, "old_access": str, "new_access": str}
Output ONLY the JSON object — no commentary, no markdown fences.
"""
# System prompt for the Repair Agent role. It requires the reply to be a
# unified diff against train.py and nothing else, and forbids workaround-style
# "fixes" (bare-except blocks, monkey-patches, sys.exit calls).
#
# NOTE: the dash in rule 3 is a real em dash; it replaces a mis-encoded
# character that would otherwise be sent to the model verbatim.
REPAIR_AGENT_SYSTEM_PROMPT = """You are the Repair Agent.
You see a broken HuggingFace training script, an error trace, and the current
library version snapshot. Output ONLY a unified diff that fixes the script.
Rules:
1. Use canonical unified-diff format with `--- a/train.py` / `+++ b/train.py`
headers and `@@ ... @@` hunk markers.
2. Make the MINIMAL change that resolves the error AND preserves the original
training intent. Do NOT add bare-except blocks, monkey-patches, or sys.exit
calls.
3. Do NOT add any prose, markdown fences, or thinking output — diff only.
4. If the error is unfixable, output an empty diff.
"""
def render_drift_generator_prompt(
    script: str, target_category: str, library_versions: dict
) -> str:
    """Build the per-example prompt text for the Drift Generator.

    Args:
        script: Source of the working training script; placed verbatim inside
            a ```python fence.
        target_category: Category label echoed on the first line.
        library_versions: Mapping rendered as comma-separated ``name=version``
            pairs, in the mapping's iteration order.

    Returns:
        The prompt string, ending with the cue
        ``Output JSON breakage primitive:`` (no trailing newline).
    """
    version_pairs = [
        f"{name}={version}" for name, version in library_versions.items()
    ]
    prompt_lines = [
        f"Target category: {target_category}",
        f"Library versions: {', '.join(version_pairs)}",
        "Working script:",
        "```python",
        f"{script}",
        "```",
        "Output JSON breakage primitive:",
    ]
    return "\n".join(prompt_lines)
def render_repair_agent_prompt(
    broken_script: str,
    error_trace: str,
    library_versions: dict,
    target_category: str = "",
) -> str:
    """Build the per-example prompt text for the Repair Agent.

    Args:
        broken_script: Source of the broken training script; placed verbatim
            inside a ```python fence.
        error_trace: Error output, inserted unfenced after ``Error trace:``.
        library_versions: Mapping rendered as comma-separated ``name=version``
            pairs, in the mapping's iteration order.
        target_category: Optional category hint; any falsy value is shown as
            ``unknown``.

    Returns:
        The prompt string, ending with the cue
        ``Output unified diff (no prose, no fences):`` (no trailing newline).
    """
    # A falsy hint (including the default "") is reported as 'unknown'.
    category_hint = target_category if target_category else "unknown"
    version_pairs = [
        f"{name}={version}" for name, version in library_versions.items()
    ]
    prompt_lines = [
        f"Library versions: {', '.join(version_pairs)}",
        f"Target category hint: {category_hint}",
        "Broken script:",
        "```python",
        f"{broken_script}",
        "```",
        "Error trace:",
        f"{error_trace}",
        "Output unified diff (no prose, no fences):",
    ]
    return "\n".join(prompt_lines)
def list_primitive_descriptions() -> Iterable[str]:
    """Return a lazy iterable of ``- <type>: <description>`` bullet lines.

    One line is produced per entry of ``PRIMITIVE_DESCRIPTIONS``, in that
    dict's iteration order. The result is a generator, so it can be consumed
    only once.
    """
    entries = PRIMITIVE_DESCRIPTIONS.items()
    return (f"- {type_name}: {summary}" for type_name, summary in entries)