| """Drift Generator parser + a deterministic baseline policy.
|
|
|
| In training the LLM produces a JSON breakage spec; we parse it. In rollouts
|
| where we want a baseline (or a fallback when the LLM emits malformed JSON)
|
| we use `BaselineDriftGenerator`, which samples from the per-category set of
|
| known good primitive parameterisations.
|
| """
|
| from __future__ import annotations
|
|
|
| import json
|
| import random
|
| import re
|
| from dataclasses import dataclass
|
| from typing import Optional
|
|
|
| from forgeenv.primitives.breakage_primitives import (
|
| PRIMITIVE_REGISTRY,
|
| parse_breakage_spec,
|
| BreakagePrimitive,
|
| )
|
|
|
|
|
# Greedy: spans from the first "{" to the LAST "}" in the text, so it can
# swallow trailing prose between multiple brace pairs; callers do best-effort
# cleanup on the captured blob.
_JSON_RE = re.compile(r"\{[\s\S]*\}")
|
|
|
|
|
def parse_drift_output(text: str) -> Optional[dict]:
    """Extract a JSON object from possibly-noisy LLM output.

    Handles markdown fences, prose preamble/postamble, and trailing commas
    (best-effort).  Scans each "{" and decodes with
    ``json.JSONDecoder.raw_decode``, which stops at the end of the first
    complete JSON value — so a valid object followed by extra text (including
    stray "}" characters) still parses.  The previous greedy
    first-"{"-to-last-"}" regex failed in that case.

    Returns the first decoded JSON object (a dict), or None on failure.
    """
    if not text:
        return None
    text = text.strip()
    # Strip a surrounding markdown code fence (```json ... ```), if any.
    if text.startswith("```"):
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = re.sub(r"\n?```$", "", text)
    decoder = json.JSONDecoder()
    for match in re.finditer(r"\{", text):
        candidate = text[match.start():]
        try:
            obj, _ = decoder.raw_decode(candidate)
        except json.JSONDecodeError:
            # Best-effort repair: drop trailing commas before "}" or "]".
            cleaned = re.sub(r",\s*([}\]])", r"\1", candidate)
            try:
                obj, _ = decoder.raw_decode(cleaned)
            except json.JSONDecodeError:
                continue
        # raw_decode can yield any JSON value; we only accept objects.
        if isinstance(obj, dict):
            return obj
    return None
|
|
|
|
|
def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]:
    """Parse raw LLM text all the way to a validated BreakagePrimitive.

    Returns None when the text contains no JSON object or when the spec
    fails validation in `parse_breakage_spec`.
    """
    spec = parse_drift_output(text)
    if not isinstance(spec, dict):
        return None
    try:
        primitive = parse_breakage_spec(spec)
    except (ValueError, TypeError):
        # Spec was syntactically JSON but semantically invalid.
        return None
    return primitive
|
|
|
|
|
|
|
# Known-good parameterisations per breakage-primitive type, consumed by
# BaselineDriftGenerator.propose().  Order matters: when no entry applies to
# the target script, propose() falls back to the FIRST entry of a category,
# so keep the most broadly-applicable parameterisation first.
_DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = {
    "RenameApiCall": [
        {"old_name": "trainer.train", "new_name": "trainer.start_training"},
        {"old_name": "save_pretrained", "new_name": "save_to_hub"},
        {"old_name": "from_pretrained", "new_name": "load_from_hub"},
    ],
    "DeprecateImport": [
        {
            "old_module": "from transformers import Trainer",
            "new_module": "from transformers.legacy import Trainer",
        },
        {
            "old_module": "from transformers import TrainingArguments",
            "new_module": "from transformers.training import TrainingArguments",
        },
    ],
    "ChangeArgumentSignature": [
        {
            "function_name": "TrainingArguments",
            "removed_arg": "num_train_epochs",
            "added_arg": "max_steps",
            "added_value": "1000",
        },
        {
            "function_name": "TrainingArguments",
            "removed_arg": "evaluation_strategy",
            "added_arg": "eval_strategy",
            "added_value": '"steps"',
        },
    ],
    "ModifyConfigField": [
        {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"},
        {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"},
    ],
    "RestructureDatasetSchema": [
        {"old_column": "text", "new_column": "input_text"},
        {"old_column": "label", "new_column": "labels"},
        {"old_column": "tokens", "new_column": "words"},
    ],
    "ChangeTokenizerBehavior": [
        {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"},
        {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"},
    ],
    "RemoveDeprecatedMethod": [
        {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"},
        {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"},
    ],
    "ChangeReturnType": [
        {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"},
        {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"},
    ],
}
|
|
|
|
|
@dataclass
class BaselineDriftGenerator:
    """Deterministic stand-in for the LLM Drift Generator.

    Used for warm-start data, baseline rollouts, and unit tests.
    """

    # RNG seed; when None the shared module-level `random` is used, so
    # unseeded instances are NOT reproducible across runs.
    seed: Optional[int] = None

    def __post_init__(self) -> None:
        # Seeded instances get a private Random so draws are reproducible
        # and isolated from other consumers of the global RNG.
        self._rng = random.Random(self.seed) if self.seed is not None else random

    def propose(
        self, target_category: str = "", script: str = ""
    ) -> dict:
        """Produce a JSON-serializable breakage spec for `target_category`.

        Order of preference:
        1. A primitive of `target_category` whose default params apply to `script`.
        2. A primitive of any type whose default params apply to `script`.
        3. A primitive of `target_category` (no-op fallback).
        """

        # Unknown categories simply get no preference pass.
        preferred_types = (
            [target_category] if target_category in _DEFAULT_PARAMS_BY_TYPE else []
        )
        all_types = list(_DEFAULT_PARAMS_BY_TYPE.keys())

        # Second pass re-visits the preferred type; harmless since any
        # applicable params would already have been returned in pass one.
        for type_set in (preferred_types, all_types):
            # sample(k=len) yields a shuffled COPY without mutating the source.
            shuffled = self._rng.sample(type_set, len(type_set)) if type_set else []
            for ptype in shuffled:
                for params in self._rng.sample(
                    _DEFAULT_PARAMS_BY_TYPE[ptype],
                    len(_DEFAULT_PARAMS_BY_TYPE[ptype]),
                ):
                    if self._params_apply_to_script(ptype, params, script):
                        # dict(params) copies so callers can't mutate defaults.
                        return {"primitive_type": ptype, "params": dict(params)}

        # Nothing applies: emit the first default of the preferred category
        # (or of the first registered category) even though it may be a no-op
        # on this script.
        ptype = preferred_types[0] if preferred_types else all_types[0]
        return {
            "primitive_type": ptype,
            "params": dict(_DEFAULT_PARAMS_BY_TYPE[ptype][0]),
        }

    @staticmethod
    def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool:
        """Heuristic: would this primitive actually mutate `script`?

        Substring check of each "target" param value (the thing the primitive
        looks for) against the script text.  An empty script is treated as
        "anything applies" so callers without a script get a spec regardless.
        """
        if not script:
            return True
        for key in ("old_name", "old_module", "removed_arg", "field_name", "old_column", "old_kwarg", "method_name", "old_access"):
            if key in params and params[key] and params[key] in script:
                return True
        return False
|
|
|