| """Drift Generator parser + a deterministic baseline policy. |
| |
| In training the LLM produces a JSON breakage spec; we parse it. In rollouts |
| where we want a baseline (or a fallback when the LLM emits malformed JSON) |
| we use `BaselineDriftGenerator`, which samples from the per-category set of |
| known good primitive parameterisations. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import random |
| import re |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| from forgeenv.primitives.breakage_primitives import ( |
| PRIMITIVE_REGISTRY, |
| parse_breakage_spec, |
| BreakagePrimitive, |
| ) |
|
|
|
|
# Greedy on purpose: spans from the first "{" to the last "}" in the text so
# that nested objects are captured whole.
_JSON_RE = re.compile(r"\{[\s\S]*\}")


def parse_drift_output(text: str) -> Optional[dict]:
    """Extract a JSON object from possibly-noisy LLM output.

    Handles a surrounding markdown fence, prose before the object, trailing
    commas (best-effort), and extra text after the object even when that
    text contains braces of its own (for example an explanation mentioning
    ``{old}`` or a second JSON object).

    Args:
        text: Raw model output. Empty or falsy input yields None.

    Returns:
        The decoded object as a dict, or None when no JSON object can be
        recovered.
    """
    if not text:
        return None
    text = text.strip()
    if text.startswith("```"):
        # Drop an opening fence (with optional language tag) and a closing
        # fence when the whole reply is wrapped in one.
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = re.sub(r"\n?```$", "", text)
    match = _JSON_RE.search(text)
    if not match:
        return None
    blob = match.group(0)
    # Best-effort repair: remove commas that directly precede a closing
    # brace or bracket.
    cleaned = re.sub(r",\s*([}\]])", r"\1", blob)
    candidates = (blob, cleaned)

    # First choice: the whole span is a single JSON document.
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    # The span runs to the *last* "}", so anything brace-bearing after the
    # object invalidates it as a whole. Decode only the first complete
    # object (the span always starts at "{") and ignore what follows.
    decoder = json.JSONDecoder()
    for candidate in candidates:
        try:
            obj, _end = decoder.raw_decode(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            return obj
    return None
|
|
|
|
def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]:
    """Convert raw LLM text into a validated BreakagePrimitive.

    The text is first run through `parse_drift_output`; the resulting dict is
    then handed to `parse_breakage_spec`.

    Returns:
        The primitive built by `parse_breakage_spec`, or None when no dict
        could be extracted or the spec is rejected with ValueError/TypeError.
    """
    spec = parse_drift_output(text)
    if isinstance(spec, dict):
        try:
            primitive = parse_breakage_spec(spec)
        except (ValueError, TypeError):
            # An invalid spec is reported as "no primitive", not as an error.
            primitive = None
        return primitive
    return None
|
|
|
|
| |
# Per-primitive-type lists of ready-made parameter dicts that
# `BaselineDriftGenerator` samples from.
_DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = {
    "RenameApiCall": [
        {"old_name": "trainer.train", "new_name": "trainer.start_training"},
        {"old_name": "save_pretrained", "new_name": "save_to_hub"},
        {"old_name": "from_pretrained", "new_name": "load_from_hub"},
    ],
    "DeprecateImport": [
        {
            "old_module": "from transformers import Trainer",
            "new_module": "from transformers.legacy import Trainer",
        },
        {
            "old_module": "from transformers import TrainingArguments",
            "new_module": "from transformers.training import TrainingArguments",
        },
    ],
    "ChangeArgumentSignature": [
        {
            "function_name": "TrainingArguments",
            "removed_arg": "num_train_epochs",
            "added_arg": "max_steps",
            "added_value": "1000",
        },
        {
            "function_name": "TrainingArguments",
            "removed_arg": "evaluation_strategy",
            "added_arg": "eval_strategy",
            "added_value": '"steps"',
        },
    ],
    "ModifyConfigField": [
        {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"},
        {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"},
    ],
    "RestructureDatasetSchema": [
        {"old_column": "text", "new_column": "input_text"},
        {"old_column": "label", "new_column": "labels"},
        {"old_column": "tokens", "new_column": "words"},
    ],
    "ChangeTokenizerBehavior": [
        {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"},
        {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"},
    ],
    "RemoveDeprecatedMethod": [
        {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"},
        {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"},
    ],
    "ChangeReturnType": [
        {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"},
        {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"},
    ],
}


@dataclass
class BaselineDriftGenerator:
    """Table-driven stand-in for the LLM Drift Generator.

    Picks breakage specs from `_DEFAULT_PARAMS_BY_TYPE`. Meant for warm-start
    data, baseline rollouts, and unit tests. Sampling is reproducible when
    `seed` is given; without a seed the module-level `random` generator is
    used.
    """

    seed: Optional[int] = None

    def __post_init__(self) -> None:
        # A private Random instance isolates seeded runs from global state.
        if self.seed is None:
            self._rng = random
        else:
            self._rng = random.Random(self.seed)

    def propose(self, target_category: str = "", script: str = "") -> dict:
        """Build a JSON-serializable breakage spec for `target_category`.

        Candidates are tried in this order, each group in random order:
          1. Default params of `target_category` that apply to `script`.
          2. Default params of any primitive type that apply to `script`.
        If nothing applies, the first default of `target_category` is
        returned (or of the first known type when the category is unknown);
        that spec may be a no-op for `script`.

        Returns:
            A dict with keys "primitive_type" and "params"; "params" is a
            fresh copy, so callers may mutate it freely.
        """
        every_type = list(_DEFAULT_PARAMS_BY_TYPE)
        wanted = [target_category] if target_category in _DEFAULT_PARAMS_BY_TYPE else []

        for pool in (wanted, every_type):
            if not pool:
                # Unknown category: nothing to draw from in the first pass.
                continue
            for name in self._rng.sample(pool, len(pool)):
                options = _DEFAULT_PARAMS_BY_TYPE[name]
                for candidate in self._rng.sample(options, len(options)):
                    if self._params_apply_to_script(name, candidate, script):
                        return {"primitive_type": name, "params": dict(candidate)}

        # Nothing matched the script: fall back to a fixed default.
        fallback_type = wanted[0] if wanted else every_type[0]
        first_params = _DEFAULT_PARAMS_BY_TYPE[fallback_type][0]
        return {"primitive_type": fallback_type, "params": dict(first_params)}

    @staticmethod
    def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool:
        """Heuristic: does `script` contain any "old" value named in `params`?

        An empty script is treated as matching everything. `ptype` is
        accepted but not consulted.
        """
        if not script:
            return True
        probe_keys = ("old_name", "old_module", "removed_arg", "field_name", "old_column", "old_kwarg", "method_name", "old_access")
        return any(
            bool(params[key]) and params[key] in script
            for key in probe_keys
            if key in params
        )
|
|