"""Drift Generator parser + a deterministic baseline policy. In training the LLM produces a JSON breakage spec; we parse it. In rollouts where we want a baseline (or a fallback when the LLM emits malformed JSON) we use `BaselineDriftGenerator`, which samples from the per-category set of known good primitive parameterisations. """ from __future__ import annotations import json import random import re from dataclasses import dataclass from typing import Optional from forgeenv.primitives.breakage_primitives import ( PRIMITIVE_REGISTRY, parse_breakage_spec, BreakagePrimitive, ) _JSON_RE = re.compile(r"\{[\s\S]*\}") def parse_drift_output(text: str) -> Optional[dict]: """Extract a JSON object from possibly-noisy LLM output. Handles markdown fences, prose preamble, trailing commas (best-effort). Returns None on failure. """ if not text: return None text = text.strip() if text.startswith("```"): text = re.sub(r"^```[a-zA-Z]*\n?", "", text) text = re.sub(r"\n?```$", "", text) match = _JSON_RE.search(text) if not match: return None blob = match.group(0) try: return json.loads(blob) except json.JSONDecodeError: cleaned = re.sub(r",\s*([}\]])", r"\1", blob) try: return json.loads(cleaned) except json.JSONDecodeError: return None def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]: """End-to-end: LLM text -> validated BreakagePrimitive (or None).""" data = parse_drift_output(text) if not isinstance(data, dict): return None try: return parse_breakage_spec(data) except (ValueError, TypeError): return None # ---------------------------------------------------------------- baselines _DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = { "RenameApiCall": [ {"old_name": "trainer.train", "new_name": "trainer.start_training"}, {"old_name": "save_pretrained", "new_name": "save_to_hub"}, {"old_name": "from_pretrained", "new_name": "load_from_hub"}, ], "DeprecateImport": [ { "old_module": "from transformers import Trainer", "new_module": "from transformers.legacy import Trainer", }, { "old_module": "from transformers import TrainingArguments", "new_module": "from transformers.training import TrainingArguments", }, ], "ChangeArgumentSignature": [ { "function_name": "TrainingArguments", "removed_arg": "num_train_epochs", "added_arg": "max_steps", "added_value": "1000", }, { "function_name": "TrainingArguments", "removed_arg": "evaluation_strategy", "added_arg": "eval_strategy", "added_value": '"steps"', }, ], "ModifyConfigField": [ {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"}, {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"}, ], "RestructureDatasetSchema": [ {"old_column": "text", "new_column": "input_text"}, {"old_column": "label", "new_column": "labels"}, {"old_column": "tokens", "new_column": "words"}, ], "ChangeTokenizerBehavior": [ {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"}, {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"}, ], "RemoveDeprecatedMethod": [ {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"}, {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"}, ], "ChangeReturnType": [ {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"}, {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"}, ], } @dataclass class BaselineDriftGenerator: """Deterministic stand-in for the LLM Drift Generator. Used for warm-start data, baseline rollouts, and unit tests. """ seed: Optional[int] = None def __post_init__(self) -> None: self._rng = random.Random(self.seed) if self.seed is not None else random def propose( self, target_category: str = "", script: str = "" ) -> dict: """Produce a JSON-serializable breakage spec for `target_category`. Order of preference: 1. A primitive of `target_category` whose default params apply to `script`. 2. A primitive of any type whose default params apply to `script`. 3. A primitive of `target_category` (no-op fallback). """ preferred_types = ( [target_category] if target_category in _DEFAULT_PARAMS_BY_TYPE else [] ) all_types = list(_DEFAULT_PARAMS_BY_TYPE.keys()) for type_set in (preferred_types, all_types): shuffled = self._rng.sample(type_set, len(type_set)) if type_set else [] for ptype in shuffled: for params in self._rng.sample( _DEFAULT_PARAMS_BY_TYPE[ptype], len(_DEFAULT_PARAMS_BY_TYPE[ptype]), ): if self._params_apply_to_script(ptype, params, script): return {"primitive_type": ptype, "params": dict(params)} ptype = preferred_types[0] if preferred_types else all_types[0] return { "primitive_type": ptype, "params": dict(_DEFAULT_PARAMS_BY_TYPE[ptype][0]), } @staticmethod def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool: """Heuristic: would this primitive actually mutate `script`?""" if not script: return True for key in ("old_name", "old_module", "removed_arg", "field_name", "old_column", "old_kwarg", "method_name", "old_access"): if key in params and params[key] and params[key] in script: return True return False