# forgeenv-source / forgeenv /primitives /breakage_primitives.py
# akhiilll's picture
# forgeenv source snapshot for training job
# a15535e verified
"""8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift.
Each primitive transforms a working script to simulate a library upgrade
breakage. They double as the Drift Generator's structured action space.
"""
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@dataclass
class BreakagePrimitive(ABC):
    """Common interface implemented by every breakage primitive.

    Concrete subclasses provide :meth:`apply` (the script rewrite) and
    :meth:`_get_params` (the parameters echoed back by :meth:`to_spec`).
    The three descriptive fields below are excluded from ``__init__`` and
    start out with placeholder defaults.
    """

    # Descriptive metadata; none of these are constructor arguments.
    category: str = field(init=False, default="generic")
    name: str = field(init=False, default="BreakagePrimitive")
    description: str = field(init=False, default="")

    @abstractmethod
    def apply(self, script: str) -> str:
        """Return `script` rewritten so that it contains the breakage."""

    @abstractmethod
    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""

    def to_spec(self) -> dict:
        """Serialize this primitive to a JSON-compatible spec.

        Returns:
            A dict with three keys: ``primitive_type`` (the concrete class
            name), ``category`` and ``params`` (whatever
            :meth:`_get_params` returns).
        """
        spec: dict = {}
        spec["primitive_type"] = type(self).__name__
        spec["category"] = self.category
        spec["params"] = self._get_params()
        return spec
@dataclass
class RenameApiCall(BreakagePrimitive):
    """Rename a function/method call to simulate API deprecation.

    Attributes:
        old_name: Text to search for (typically an identifier).
        new_name: Text written in its place, inserted verbatim.
    """

    old_name: str = ""
    new_name: str = ""

    def __post_init__(self) -> None:
        """Fill in the descriptive metadata for this primitive."""
        self.category = "api_drift"
        self.name = "RenameApiCall"
        self.description = f"Rename {self.old_name} -> {self.new_name}"

    def apply(self, script: str) -> str:
        """Replace each standalone occurrence of ``old_name`` in `script`.

        Returns `script` unchanged when ``old_name`` is empty.
        """
        if not self.old_name:
            return script
        # Guard an edge only when that edge of old_name is a word character,
        # so `load` is not rewritten inside `load_dataset`. An edge that is
        # punctuation (e.g. ".evaluate(") gets no guard: a guard there would
        # reject every match that directly touches an identifier.
        lead = r"(?<!\w)" if re.match(r"\w", self.old_name) else ""
        trail = r"(?!\w)" if re.search(r"\w\Z", self.old_name) else ""
        pattern = re.compile(f"{lead}{re.escape(self.old_name)}{trail}")
        # A callable replacement inserts new_name as-is; a plain string would
        # be parsed as a template, so backslashes or "\1" in new_name would
        # raise or be expanded.
        return pattern.sub(lambda _match: self.new_name, script)

    def _get_params(self) -> dict:
        """Return the constructor parameters for :meth:`to_spec`."""
        return {"old_name": self.old_name, "new_name": self.new_name}
@dataclass
class DeprecateImport(BreakagePrimitive):
    """Change an import path to simulate module restructuring.

    Attributes:
        old_module: Module path text to search for.
        new_module: Module path text written in its place, verbatim.
    """

    old_module: str = ""
    new_module: str = ""

    def __post_init__(self) -> None:
        """Fill in the descriptive metadata for this primitive."""
        self.category = "import_drift"
        self.name = "DeprecateImport"
        self.description = f"Move {self.old_module} -> {self.new_module}"

    def apply(self, script: str) -> str:
        """Replace each standalone occurrence of ``old_module`` in `script`.

        Matching is case-sensitive and literal. A match is rejected when it
        is merely part of a longer identifier (``torch`` does not match
        inside ``torchvision``), but a dotted continuation such as
        ``old_module.sub`` still matches because ``.`` is not a word
        character. Returns `script` unchanged when ``old_module`` is empty.
        """
        if not self.old_module:
            return script
        # Only guard an edge that is itself a word character; a trailing "."
        # in old_module must stay free to match "pkg.Name".
        lead = r"(?<!\w)" if re.match(r"\w", self.old_module) else ""
        trail = r"(?!\w)" if re.search(r"\w\Z", self.old_module) else ""
        pattern = f"{lead}{re.escape(self.old_module)}{trail}"
        # Callable replacement: new_module is inserted without template
        # (backslash / group reference) processing.
        return re.sub(pattern, lambda _match: self.new_module, script)

    def _get_params(self) -> dict:
        """Return the constructor parameters for :meth:`to_spec`."""
        return {"old_module": self.old_module, "new_module": self.new_module}
@dataclass
class ChangeArgumentSignature(BreakagePrimitive):
    """Remove an expected kwarg (and document a new required one).

    Attributes:
        function_name: Function whose signature is said to change. Used in
            the description and spec only; :meth:`apply` does not consult it.
        removed_arg: Keyword argument name stripped from the script.
        added_arg: Name of the newly required argument (description/spec
            only; :meth:`apply` does not insert it).
        added_value: Value associated with ``added_arg`` (spec only).
    """

    function_name: str = ""
    removed_arg: str = ""
    added_arg: str = ""
    added_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the descriptive metadata for this primitive."""
        self.category = "api_drift"
        self.name = "ChangeArgumentSignature"
        self.description = (
            f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}"
        )

    def apply(self, script: str) -> str:
        """Delete every ``removed_arg=<value>`` occurrence from `script`.

        The value is taken to run up to the first comma, closing
        parenthesis or newline; an immediately following comma and any
        whitespace after it are removed as well. Because ``function_name``
        is not consulted, the kwarg is removed wherever it appears.
        Returns `script` unchanged when ``removed_arg`` is empty.
        """
        if not self.removed_arg:
            return script
        pattern = (
            rf"\b{re.escape(self.removed_arg)}"
            # "(?!=)" keeps `name == value` comparisons out of scope.
            r"\s*=(?!=)\s*"
            # Stop the value at a newline too: otherwise a statement-level
            # `name = value` would swallow the following lines up to the
            # next "," or ")".
            r"[^,)\n]+,?\s*"
        )
        return re.sub(pattern, "", script)

    def _get_params(self) -> dict:
        """Return the constructor parameters for :meth:`to_spec`."""
        return {
            "function_name": self.function_name,
            "removed_arg": self.removed_arg,
            "added_arg": self.added_arg,
            "added_value": self.added_value,
        }
@dataclass
class ModifyConfigField(BreakagePrimitive):
    """Change a config-class default value to simulate behaviour drift.

    Attributes:
        config_class: Config class the field is said to belong to. Used in
            the description and spec only; :meth:`apply` does not consult it.
        field_name: Name whose assigned value is rewritten.
        new_value: Source text written as the new value, verbatim.
    """

    config_class: str = ""
    field_name: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the descriptive metadata for this primitive."""
        self.category = "config_drift"
        self.name = "ModifyConfigField"
        self.description = f"Change {self.config_class}.{self.field_name}"

    def apply(self, script: str) -> str:
        """Rewrite the value of every ``field_name=<value>`` in `script`.

        The existing value is taken to end at the first comma, closing
        parenthesis or newline. The ``field_name =`` prefix (including its
        original spacing) is kept. Returns `script` unchanged when
        ``field_name`` is empty.
        """
        if not self.field_name:
            return script
        # Left guard so `lr` does not match inside `min_lr`; only applied
        # when field_name actually starts with a word character.
        lead = r"(?<!\w)" if re.match(r"\w", self.field_name) else ""
        pattern = (
            rf"({lead}{re.escape(self.field_name)}"
            # "(?!=)" leaves `field == value` comparisons untouched.
            r"\s*=(?!=)\s*)[^,)\n]+"
        )
        # Callable replacement: new_value is inserted literally, so
        # backslashes in it are not parsed as template escapes.
        return re.sub(
            pattern,
            lambda match: f"{match.group(1)}{self.new_value}",
            script,
        )

    def _get_params(self) -> dict:
        """Return the constructor parameters for :meth:`to_spec`."""
        return {
            "config_class": self.config_class,
            "field_name": self.field_name,
            "new_value": self.new_value,
        }
@dataclass
class RestructureDatasetSchema(BreakagePrimitive):
    """Simulate dataset schema drift by renaming a quoted column name.

    Attributes:
        old_column: Column name to look for (as a quoted string literal).
        new_column: Column name substituted for it.
    """

    old_column: str = ""
    new_column: str = ""

    def __post_init__(self) -> None:
        """Fill in the descriptive metadata for this primitive."""
        self.description = f"Rename column {self.old_column} -> {self.new_column}"
        self.name = "RestructureDatasetSchema"
        self.category = "dataset_drift"

    def apply(self, script: str) -> str:
        """Swap ``old_column`` for ``new_column`` wherever it is quoted.

        Double-quoted occurrences are rewritten first, then single-quoted
        ones. An empty ``old_column`` leaves `script` untouched.
        """
        if self.old_column:
            rewritten = script
            for quote in ('"', "'"):
                before = f"{quote}{self.old_column}{quote}"
                after = f"{quote}{self.new_column}{quote}"
                rewritten = rewritten.replace(before, after)
            return rewritten
        return script

    def _get_params(self) -> dict:
        """Return the constructor parameters for :meth:`to_spec`."""
        return dict(old_column=self.old_column, new_column=self.new_column)
@dataclass
class ChangeTokenizerBehavior(BreakagePrimitive):
    """Change tokenizer call arguments.

    Attributes:
        old_kwarg: Keyword argument name to look for.
        old_value: Source text of the value it must currently have.
        new_kwarg: Keyword argument name written in its place.
        new_value: Source text of the value written in its place.
    """

    old_kwarg: str = ""
    old_value: str = ""
    new_kwarg: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the descriptive metadata for this primitive."""
        self.category = "tokenizer_drift"
        self.name = "ChangeTokenizerBehavior"
        self.description = (
            f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> "
            f"{self.new_kwarg}={self.new_value}"
        )

    def apply(self, script: str) -> str:
        """Replace ``old_kwarg=old_value`` with ``new_kwarg=new_value``.

        Whitespace around the ``=`` in the script is tolerated when
        matching; the replacement is always written without spaces.
        Returns `script` unchanged when ``old_kwarg`` is empty.
        """
        if not self.old_kwarg:
            return script
        # Guard the outer edges (when they are word characters) so that
        # `padding` is not found inside `do_padding`, and a value of `51`
        # is not found as the prefix of `512`.
        lead = r"(?<!\w)" if re.match(r"\w", self.old_kwarg) else ""
        trail = r"(?!\w)" if re.search(r"\w\Z", self.old_value) else ""
        pattern = "".join(
            (
                lead,
                re.escape(self.old_kwarg),
                r"\s*=\s*",
                re.escape(self.old_value),
                trail,
            )
        )
        replacement = f"{self.new_kwarg}={self.new_value}"
        # Callable replacement so backslashes in the new text are inserted
        # literally instead of being parsed as template escapes.
        return re.sub(pattern, lambda _match: replacement, script)

    def _get_params(self) -> dict:
        """Return the constructor parameters for :meth:`to_spec`."""
        return {
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
        }
@dataclass
class RemoveDeprecatedMethod(BreakagePrimitive):
    """Simulate the removal of a deprecated method.

    Call sites of the method are redirected to a ``_DEPRECATED``-suffixed
    name, presumably so the call fails at runtime because no such attribute
    exists.

    Attributes:
        class_name: Class the method is said to belong to. Used in the
            description and spec only; :meth:`apply` does not consult it.
        method_name: Method whose call sites are rewritten.
        replacement: Recorded in the spec only; :meth:`apply` ignores it.
    """

    class_name: str = ""
    method_name: str = ""
    replacement: str = ""

    def __post_init__(self) -> None:
        """Fill in the descriptive metadata for this primitive."""
        self.description = f"Remove {self.class_name}.{self.method_name}"
        self.name = "RemoveDeprecatedMethod"
        self.category = "api_drift"

    def apply(self, script: str) -> str:
        """Rewrite every ``.method_name(`` to ``.method_name_DEPRECATED(``.

        An empty ``method_name`` leaves `script` untouched.
        """
        if self.method_name:
            call_site = f".{self.method_name}("
            sentinel = f".{self.method_name}_DEPRECATED("
            return script.replace(call_site, sentinel)
        return script

    def _get_params(self) -> dict:
        """Return the constructor parameters for :meth:`to_spec`."""
        return dict(
            class_name=self.class_name,
            method_name=self.method_name,
            replacement=self.replacement,
        )
@dataclass
class ChangeReturnType(BreakagePrimitive):
    """Simulate a function returning a different structure (e.g. tuple -> object).

    Attributes:
        function_name: Function whose return type is said to change. Used
            in the description and spec only; :meth:`apply` ignores it.
        old_access: Source text of the current access expression.
        new_access: Source text substituted for it.
    """

    function_name: str = ""
    old_access: str = ""
    new_access: str = ""

    def __post_init__(self) -> None:
        """Fill in the descriptive metadata for this primitive."""
        self.description = f"Change return type of {self.function_name}"
        self.name = "ChangeReturnType"
        self.category = "api_drift"

    def apply(self, script: str) -> str:
        """Replace every literal ``old_access`` with ``new_access``.

        The script is returned untouched unless both strings are non-empty.
        """
        if not self.old_access or not self.new_access:
            return script
        return script.replace(self.old_access, self.new_access)

    def _get_params(self) -> dict:
        """Return the constructor parameters for :meth:`to_spec`."""
        return dict(
            function_name=self.function_name,
            old_access=self.old_access,
            new_access=self.new_access,
        )
# Lookup table from a spec's "primitive_type" string to the primitive class.
# Each key is the class's own name; insertion order follows the tuple below.
PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = {
    primitive_cls.__name__: primitive_cls
    for primitive_cls in (
        RenameApiCall,
        DeprecateImport,
        ChangeArgumentSignature,
        ModifyConfigField,
        RestructureDatasetSchema,
        ChangeTokenizerBehavior,
        RemoveDeprecatedMethod,
        ChangeReturnType,
    )
}
def parse_breakage_spec(spec: dict) -> BreakagePrimitive:
    """Parse a JSON breakage spec into a BreakagePrimitive object.

    Tolerates extra keys and ignores unknown params (LLMs hallucinate
    these). A missing or ``None`` ``params`` entry is treated as empty.
    Param values that are ``None`` are dropped so the field default applies,
    and values for ``str``-typed fields are coerced with ``str()`` so JSON
    numbers/booleans do not reach the regex code as non-strings.

    Args:
        spec: Mapping with a ``primitive_type`` key naming an entry of
            ``PRIMITIVE_REGISTRY`` and an optional ``params`` mapping.

    Returns:
        An instance of the registered primitive class.

    Raises:
        ValueError: If `spec` or its ``params`` is not a dict, or if
            ``primitive_type`` is not a registered primitive name.
    """
    # Local import: the module header only imports `dataclass` and `field`.
    from dataclasses import fields as dataclass_fields

    if not isinstance(spec, dict):
        raise ValueError(
            f"Breakage spec must be a dict, got {type(spec).__name__}"
        )

    ptype = spec.get("primitive_type", "")
    # The isinstance check also keeps unhashable values (e.g. a list) from
    # raising TypeError inside the membership test.
    if not isinstance(ptype, str) or ptype not in PRIMITIVE_REGISTRY:
        raise ValueError(
            f"Unknown primitive type: {ptype!r}. "
            f"Valid types: {list(PRIMITIVE_REGISTRY)}"
        )

    params = spec.get("params") or {}
    if not isinstance(params, dict):
        raise ValueError(
            f"Breakage spec 'params' must be a dict, got {type(params).__name__}"
        )

    cls = PRIMITIVE_REGISTRY[ptype]
    # Filter to known fields only so a hallucinated kwarg can't crash us.
    init_fields = [f for f in dataclass_fields(cls) if f.init]
    valid_fields = {f.name for f in init_fields}
    # With `from __future__ import annotations` the recorded type is the
    # string "str"; accept the real type as well.
    str_fields = {f.name for f in init_fields if f.type in (str, "str")}

    filtered = {}
    for key, value in params.items():
        if key not in valid_fields or value is None:
            continue
        filtered[key] = str(value) if key in str_fields else value
    return cls(**filtered)