# Source: forgeenv-source / forgeenv-space / forgeenv/roles/drift_generator.py
# Snapshot: "forgeenv source snapshot for training job" (commit b0fbec3, verified)
"""Drift Generator parser + a deterministic baseline policy.
In training the LLM produces a JSON breakage spec; we parse it. In rollouts
where we want a baseline (or a fallback when the LLM emits malformed JSON)
we use `BaselineDriftGenerator`, which samples from the per-category set of
known good primitive parameterisations.
"""
from __future__ import annotations
import json
import random
import re
from dataclasses import dataclass
from typing import Optional
from forgeenv.primitives.breakage_primitives import (
PRIMITIVE_REGISTRY,
parse_breakage_spec,
BreakagePrimitive,
)
# Greedy span from the first "{" to the last "}" in the text. The span may
# contain trailing non-JSON text; `parse_drift_output` copes with that.
_JSON_RE = re.compile(r"\{[\s\S]*\}")


def parse_drift_output(text: str) -> Optional[dict]:
    """Extract a JSON object from possibly-noisy LLM output.

    Handles markdown fences, prose before or after the object (including
    trailing prose that itself contains braces), and trailing commas
    (best-effort).

    Args:
        text: Raw model output.

    Returns:
        The first JSON object found, decoded into a dict, or None when the
        text is empty, contains no ``{...}`` span, or the span cannot be
        decoded even after trailing-comma cleanup.
    """
    if not text:
        return None

    text = text.strip()
    if text.startswith("```"):
        # Drop an opening fence (with optional language tag) and a closing one.
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = re.sub(r"\n?```$", "", text)

    match = _JSON_RE.search(text)
    if not match:
        return None
    blob = match.group(0)

    # `raw_decode` parses the object that starts at index 0 and ignores
    # whatever follows it. `json.loads` would instead reject the whole span
    # with "Extra data" as soon as later prose contains a "}".
    decoder = json.JSONDecoder()

    try:
        parsed, _ = decoder.raw_decode(blob)
        return parsed
    except json.JSONDecodeError:
        pass

    # Best-effort repair: strip commas that sit directly before "}" or "]".
    # This is purely textual, so it may also touch such sequences inside
    # string values.
    cleaned = re.sub(r",\s*([}\]])", r"\1", blob)
    try:
        parsed, _ = decoder.raw_decode(cleaned)
        return parsed
    except json.JSONDecodeError:
        return None
def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]:
    """Turn raw LLM text into a BreakagePrimitive, or None if that fails.

    The text is first reduced to a JSON object via `parse_drift_output`;
    the resulting dict is then handed to `parse_breakage_spec`. A missing
    or non-dict payload, or a ValueError/TypeError raised while building
    the primitive, yields None.
    """
    spec = parse_drift_output(text)
    if isinstance(spec, dict):
        try:
            primitive = parse_breakage_spec(spec)
        except (ValueError, TypeError):
            # Spec was rejected by parse_breakage_spec; treat as "no primitive".
            primitive = None
        return primitive
    return None
# ---------------------------------------------------------------- baselines
# Candidate parameter sets for `BaselineDriftGenerator.propose`, grouped by
# primitive type.
#
# Each key is emitted as the spec's "primitive_type" and each list entry is a
# candidate "params" dict for that type.
#
# Ordering is significant: when no candidate matches the script, `propose`
# falls back to the first entry of the chosen type, and to the first key of
# this dict when the requested category is unknown.
#
# The per-type notes below name the key that
# `BaselineDriftGenerator._params_apply_to_script` looks for in the script.
# NOTE(review): values such as "1000" or '"steps"' look like source snippets
# to be spliced verbatim; confirm against breakage_primitives.
_DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = {
    # Applicability probe: "old_name".
    "RenameApiCall": [
        {"old_name": "trainer.train", "new_name": "trainer.start_training"},
        {"old_name": "save_pretrained", "new_name": "save_to_hub"},
        {"old_name": "from_pretrained", "new_name": "load_from_hub"},
    ],
    # Applicability probe: "old_module" (a full import statement).
    "DeprecateImport": [
        {
            "old_module": "from transformers import Trainer",
            "new_module": "from transformers.legacy import Trainer",
        },
        {
            "old_module": "from transformers import TrainingArguments",
            "new_module": "from transformers.training import TrainingArguments",
        },
    ],
    # Applicability probe: "removed_arg".
    "ChangeArgumentSignature": [
        {
            "function_name": "TrainingArguments",
            "removed_arg": "num_train_epochs",
            "added_arg": "max_steps",
            "added_value": "1000",
        },
        {
            "function_name": "TrainingArguments",
            "removed_arg": "evaluation_strategy",
            "added_arg": "eval_strategy",
            "added_value": '"steps"',
        },
    ],
    # Applicability probe: "field_name".
    "ModifyConfigField": [
        {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"},
        {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"},
    ],
    # Applicability probe: "old_column".
    "RestructureDatasetSchema": [
        {"old_column": "text", "new_column": "input_text"},
        {"old_column": "label", "new_column": "labels"},
        {"old_column": "tokens", "new_column": "words"},
    ],
    # Applicability probe: "old_kwarg".
    "ChangeTokenizerBehavior": [
        {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"},
        {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"},
    ],
    # Applicability probe: "method_name".
    "RemoveDeprecatedMethod": [
        {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"},
        {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"},
    ],
    # Applicability probe: "old_access" ("function_name" is not probed).
    "ChangeReturnType": [
        {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"},
        {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"},
    ],
}
@dataclass
class BaselineDriftGenerator:
    """Non-LLM stand-in for the Drift Generator.

    Picks breakage specs from `_DEFAULT_PARAMS_BY_TYPE`. Intended for
    warm-start data, baseline rollouts, and unit tests. Output is
    reproducible when `seed` is given; with `seed=None` the shared
    `random` module is used as the randomness source.
    """

    seed: Optional[int] = None

    def __post_init__(self) -> None:
        """Set up the randomness source according to `seed`."""
        if self.seed is None:
            # No seed: draw from the module-level generator.
            self._rng = random
        else:
            self._rng = random.Random(self.seed)

    def propose(
        self, target_category: str = "", script: str = ""
    ) -> dict:
        """Build a JSON-serializable breakage spec for `target_category`.

        Candidates are tried in this order, each group in random order:
            1. Params of `target_category` that apply to `script`.
            2. Params of any type that apply to `script`.
        If nothing applies, the first default params of `target_category`
        are returned (or of the first known type when the category is
        unknown); that spec may be a no-op on `script`.

        Returns:
            A dict with keys "primitive_type" and "params"; "params" is a
            fresh copy of the selected defaults.
        """
        table = _DEFAULT_PARAMS_BY_TYPE
        wanted = [target_category] if target_category in table else []
        everything = list(table)

        for pool in (wanted, everything):
            if not pool:
                continue
            for kind in self._rng.sample(pool, len(pool)):
                # Shuffle lazily, per type, so RNG draws happen only for
                # types that are actually visited.
                options = table[kind]
                for option in self._rng.sample(options, len(options)):
                    if self._params_apply_to_script(kind, option, script):
                        return {"primitive_type": kind, "params": dict(option)}

        # Nothing matched the script: fall back to the first default entry.
        fallback_kind = wanted[0] if wanted else everything[0]
        return {
            "primitive_type": fallback_kind,
            "params": dict(table[fallback_kind][0]),
        }

    @staticmethod
    def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool:
        """Heuristically decide whether `params` would change `script`.

        An empty script always counts as applicable. Otherwise the params
        apply when the non-empty value of any probed "old"-side key occurs
        as a substring of `script`. `ptype` is accepted but not consulted.
        """
        if not script:
            return True

        probe_keys = (
            "old_name",
            "old_module",
            "removed_arg",
            "field_name",
            "old_column",
            "old_kwarg",
            "method_name",
            "old_access",
        )
        needles = (params.get(key) for key in probe_keys)
        return any(needle in script for needle in needles if needle)