| """8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift.
|
|
|
| Each primitive transforms a working script to simulate a library upgrade
|
| breakage. They double as the Drift Generator's structured action space.
|
| """
|
| from __future__ import annotations
|
|
|
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
|
|
|
|
|
@dataclass
class BreakagePrimitive(ABC):
    """Common interface shared by every breakage primitive.

    Concrete subclasses declare their own parameters as dataclass fields,
    usually overwrite ``category`` / ``name`` / ``description`` (which are
    excluded from ``__init__``), and implement :meth:`apply` and
    :meth:`_get_params`.
    """

    # Drift family this primitive belongs to (e.g. "api_drift").
    category: str = field(default="generic", init=False)
    # Human-readable primitive name.
    name: str = field(default="BreakagePrimitive", init=False)
    # One-line summary of the concrete change.
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Return a copy of ``script`` with this breakage introduced."""

    def to_spec(self) -> dict:
        """Describe this primitive as a JSON-compatible dict.

        ``primitive_type`` holds the concrete class name, which is the key
        ``parse_breakage_spec`` uses to rebuild the primitive.
        """
        spec_type = type(self).__name__
        spec_category = self.category
        spec_params = self._get_params()
        return {
            "primitive_type": spec_type,
            "category": spec_category,
            "params": spec_params,
        }

    @abstractmethod
    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
|
|
|
|
|
@dataclass
class RenameApiCall(BreakagePrimitive):
    """Rename a function/method call to simulate API deprecation.

    Occurrences of ``old_name`` are replaced with ``new_name``. Where
    ``old_name`` begins/ends with a word character, the match must not be
    glued to another word character on that side, so ``foo`` does not match
    inside ``foobar`` or ``my_foo``.
    """

    old_name: str = ""  # text to rename (identifier or dotted path)
    new_name: str = ""  # replacement text, inserted verbatim

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RenameApiCall"
        self.description = f"Rename {self.old_name} -> {self.new_name}"

    def apply(self, script: str) -> str:
        """Return ``script`` with ``old_name`` renamed; no-op if ``old_name`` is empty."""
        if not self.old_name:
            return script

        # Only demand a boundary on a side where old_name has a word char;
        # otherwise a name like "foo(" could never match "foo(x)".
        lead = r"(?<!\w)" if re.match(r"\w", self.old_name) else ""
        trail = r"(?!\w)" if re.search(r"\w\Z", self.old_name) else ""
        pattern = re.compile(lead + re.escape(self.old_name) + trail)

        # A callable replacement inserts new_name verbatim; a string
        # replacement would have its backslashes processed by re.sub.
        replacement = self.new_name
        return pattern.sub(lambda _match: replacement, script)

    def _get_params(self) -> dict:
        return {"old_name": self.old_name, "new_name": self.new_name}
|
|
|
|
|
@dataclass
class DeprecateImport(BreakagePrimitive):
    """Change an import path to simulate module restructuring.

    Every occurrence of ``old_module`` (import lines and qualified uses) is
    rewritten to ``new_module``. Matches may not extend into a longer
    identifier, so ``pkg.trainer`` does not rewrite ``pkg.trainer_utils``.
    A following ``.`` is allowed, so submodules of ``old_module`` move too.
    """

    old_module: str = ""  # module path to replace
    new_module: str = ""  # module path inserted in its place, verbatim

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "DeprecateImport"
        self.description = f"Move {self.old_module} -> {self.new_module}"

    def apply(self, script: str) -> str:
        """Return ``script`` with ``old_module`` replaced; no-op if it is empty."""
        if not self.old_module:
            return script

        # Boundaries only where old_module itself starts/ends with a word char.
        lead = r"(?<!\w)" if re.match(r"\w", self.old_module) else ""
        trail = r"(?!\w)" if re.search(r"\w\Z", self.old_module) else ""
        pattern = re.compile(lead + re.escape(self.old_module) + trail)

        # Callable replacement: new_module is inserted without escape processing.
        replacement = self.new_module
        return pattern.sub(lambda _match: replacement, script)

    def _get_params(self) -> dict:
        return {"old_module": self.old_module, "new_module": self.new_module}
|
|
|
|
|
@dataclass
class ChangeArgumentSignature(BreakagePrimitive):
    """Remove an expected kwarg (and document a new required one).

    ``apply`` deletes ``removed_arg=<value>`` wherever it appears as a
    keyword argument, i.e. directly after ``(`` or ``,``. It does not touch
    plain assignments, attribute assignments, or ``==`` comparisons.
    ``function_name``, ``added_arg`` and ``added_value`` are only recorded in
    the description / spec; ``apply`` does not use them.
    """

    function_name: str = ""  # callable whose signature changed (informational)
    removed_arg: str = ""  # keyword argument to delete from calls
    added_arg: str = ""  # newly required keyword (informational)
    added_value: str = ""  # value for added_arg (informational)

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "ChangeArgumentSignature"
        self.description = (
            f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}"
        )

    def apply(self, script: str) -> str:
        """Return ``script`` with the ``removed_arg`` keyword argument deleted.

        Values may contain quoted strings and one level of ``()``, ``[]`` or
        ``{}`` nesting. Otherwise the value ends at the first ``,``, closing
        bracket, or end of line.
        """
        if not self.removed_arg:
            return script

        value = (
            r"(?:"
            r'"[^"\n]*"'  # double-quoted string literal
            r"|'[^'\n]*'"  # single-quoted string literal
            r"|\([^()]*\)"  # one level of (...) nesting
            r"|\[[^\[\]]*\]"  # one level of [...] nesting
            r"|\{[^{}]*\}"  # one level of {...} nesting
            r"|[^,()\[\]{}\n\"']"  # any other character on the same line
            r")+"
        )
        pattern = re.compile(
            # The preceding "(" or "," proves we are inside an argument list.
            r"(?P<lead>[(,]\s*)"
            + re.escape(self.removed_arg)
            # "(?!=)" keeps "arg == x" comparisons intact.
            + r"\s*=(?!=)\s*"
            + value
            + r"\s*,?\s*"
        )
        # Keep the "(" / "," we anchored on; drop the kwarg and its separator.
        return pattern.sub(lambda match: match.group("lead"), script)

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "removed_arg": self.removed_arg,
            "added_arg": self.added_arg,
            "added_value": self.added_value,
        }
|
|
|
|
|
@dataclass
class ModifyConfigField(BreakagePrimitive):
    """Change a config-class default value to simulate behaviour drift.

    Every ``field_name = <value>`` assignment or keyword argument has its
    value (up to the next ``,``, ``)`` or end of line) replaced by
    ``new_value``. ``config_class`` is only recorded in the description /
    spec; ``apply`` does not use it.
    """

    config_class: str = ""  # owning config class (informational)
    field_name: str = ""  # field / kwarg whose value is rewritten
    new_value: str = ""  # source text inserted as the new value, verbatim

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "ModifyConfigField"
        self.description = f"Change {self.config_class}.{self.field_name}"

    def apply(self, script: str) -> str:
        """Return ``script`` with ``field_name``'s value set to ``new_value``."""
        if not self.field_name:
            return script

        pattern = re.compile(
            # (?<!\w): "lr" must not match inside "max_lr".
            # (?!=): leave "lr == x" comparisons alone.
            rf"(?<!\w)({re.escape(self.field_name)}\s*=(?!=)\s*)([^,)\n]+)"
        )
        # Callable replacement so new_value's backslashes are not interpreted.
        new_value = self.new_value
        return pattern.sub(lambda match: match.group(1) + new_value, script)

    def _get_params(self) -> dict:
        return {
            "config_class": self.config_class,
            "field_name": self.field_name,
            "new_value": self.new_value,
        }
|
|
|
|
|
@dataclass
class RestructureDatasetSchema(BreakagePrimitive):
    """Rename a dataset column reference to simulate schema drift.

    Only quoted occurrences of the column name are rewritten: ``"col"`` and
    ``'col'``. Unquoted text that happens to equal the name is left as-is.
    """

    old_column: str = ""  # column name currently referenced by the script
    new_column: str = ""  # column name it is renamed to

    def __post_init__(self) -> None:
        self.category, self.name = "dataset_drift", "RestructureDatasetSchema"
        self.description = f"Rename column {self.old_column} -> {self.new_column}"

    def apply(self, script: str) -> str:
        """Swap quoted references to ``old_column`` for ``new_column``."""
        if not self.old_column:
            return script
        rewritten = script
        # Double-quoted literals first, then single-quoted ones.
        for quote in ('"', "'"):
            rewritten = rewritten.replace(
                quote + self.old_column + quote,
                quote + self.new_column + quote,
            )
        return rewritten

    def _get_params(self) -> dict:
        return dict(old_column=self.old_column, new_column=self.new_column)
|
|
|
|
|
@dataclass
class ChangeTokenizerBehavior(BreakagePrimitive):
    """Change tokenizer call arguments.

    Rewrites ``old_kwarg=old_value`` into ``new_kwarg=new_value``.
    If ``old_value`` is empty, any value of ``old_kwarg`` (up to the next
    ``,``, ``)`` or end of line) is matched. If ``new_kwarg`` is empty, the
    keyword name is kept and only the value changes.
    """

    old_kwarg: str = ""  # keyword argument to look for
    old_value: str = ""  # its current value as source text ("" = any value)
    new_kwarg: str = ""  # replacement keyword ("" = keep old_kwarg)
    new_value: str = ""  # replacement value as source text, inserted verbatim

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "ChangeTokenizerBehavior"
        self.description = f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> {self.new_kwarg}={self.new_value}"

    def apply(self, script: str) -> str:
        """Return ``script`` with the matching kwarg/value pair rewritten."""
        if not self.old_kwarg:
            return script

        if self.old_value:
            value_pattern = re.escape(self.old_value)
            # "128" must not match the prefix of "1280", "True" not "Truex".
            if re.search(r"\w\Z", self.old_value):
                value_pattern += r"(?!\w)"
        else:
            # No value given: match whatever value the kwarg currently has.
            value_pattern = r"[^,)\n]+"

        pattern = re.compile(
            # (?<!\w): "padding" must not match inside "max_padding".
            # (?!=): never treat "kwarg == value" as an assignment.
            rf"(?<!\w){re.escape(self.old_kwarg)}\s*=(?!=)\s*{value_pattern}"
        )
        # An empty new_kwarg would otherwise emit "=value", which is invalid
        # syntax; fall back to the original keyword name.
        kwarg = self.new_kwarg or self.old_kwarg
        replacement = f"{kwarg}={self.new_value}"
        # Callable replacement: no backslash/group processing of new_value.
        return pattern.sub(lambda _match: replacement, script)

    def _get_params(self) -> dict:
        return {
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
        }
|
|
|
|
|
@dataclass
class RemoveDeprecatedMethod(BreakagePrimitive):
    """Remove a method that has been deprecated, leaving a sentinel that
    raises AttributeError-style errors when the script runs.

    Each ``.method_name(`` call site is renamed to ``.method_name_DEPRECATED(``.
    ``class_name`` and ``replacement`` are recorded in the spec only.
    """

    class_name: str = ""  # class that owned the method (informational)
    method_name: str = ""  # method whose call sites are broken
    replacement: str = ""  # suggested successor API (informational)

    def __post_init__(self) -> None:
        self.category, self.name = "api_drift", "RemoveDeprecatedMethod"
        self.description = f"Remove {self.class_name}.{self.method_name}"

    def apply(self, script: str) -> str:
        """Rename every ``.method_name(`` call to its ``_DEPRECATED`` sentinel."""
        method = self.method_name
        if method:
            script = script.replace(f".{method}(", f".{method}_DEPRECATED(")
        return script

    def _get_params(self) -> dict:
        return dict(
            class_name=self.class_name,
            method_name=self.method_name,
            replacement=self.replacement,
        )
|
|
|
|
|
@dataclass
class ChangeReturnType(BreakagePrimitive):
    """A function now returns a different structure (e.g. tuple -> object).

    Every literal occurrence of ``old_access`` is replaced by ``new_access``.
    Both must be non-empty; otherwise the script is returned unchanged.
    ``function_name`` is recorded in the description / spec only.
    """

    function_name: str = ""  # function whose return shape changed (informational)
    old_access: str = ""  # access expression valid before the change
    new_access: str = ""  # access expression substituted for it

    def __post_init__(self) -> None:
        self.category, self.name = "api_drift", "ChangeReturnType"
        self.description = f"Change return type of {self.function_name}"

    def apply(self, script: str) -> str:
        """Swap ``old_access`` for ``new_access`` everywhere in ``script``."""
        if not (self.old_access and self.new_access):
            return script
        return script.replace(self.old_access, self.new_access)

    def _get_params(self) -> dict:
        return dict(
            function_name=self.function_name,
            old_access=self.old_access,
            new_access=self.new_access,
        )
|
|
|
|
|
# Maps each primitive's class name (the ``primitive_type`` emitted by
# ``BreakagePrimitive.to_spec``) to its class, in declaration order.
PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = {
    primitive.__name__: primitive
    for primitive in (
        RenameApiCall,
        DeprecateImport,
        ChangeArgumentSignature,
        ModifyConfigField,
        RestructureDatasetSchema,
        ChangeTokenizerBehavior,
        RemoveDeprecatedMethod,
        ChangeReturnType,
    )
}
|
|
|
|
|
def parse_breakage_spec(spec: dict) -> BreakagePrimitive:
    """Parse a JSON breakage spec into a BreakagePrimitive object.

    The expected shape mirrors ``BreakagePrimitive.to_spec``:
    ``{"primitive_type": <class name>, "params": {<field>: <value>, ...}}``.

    Tolerates extra keys; ignores unknown params (LLMs hallucinate these).
    ``null`` param values are dropped so the field keeps its default.
    Bool/int/float values for ``str`` fields are converted with ``str()``,
    so JSON ``true`` / ``128`` become the source text ``"True"`` / ``"128"``.

    Raises:
        ValueError: if ``spec`` or its ``params`` is not a mapping, or if
            ``primitive_type`` is not a key of ``PRIMITIVE_REGISTRY``.
    """
    if not isinstance(spec, Mapping):
        raise ValueError(
            f"Breakage spec must be a mapping, got {type(spec).__name__}"
        )

    ptype = spec.get("primitive_type", "")
    params = spec.get("params", {}) or {}
    if not isinstance(params, Mapping):
        raise ValueError(
            f"Breakage spec 'params' must be a mapping, got {type(params).__name__}"
        )

    # The isinstance check keeps an unhashable primitive_type (e.g. a list)
    # from escaping as TypeError instead of the documented ValueError.
    if not isinstance(ptype, str) or ptype not in PRIMITIVE_REGISTRY:
        raise ValueError(
            f"Unknown primitive type: {ptype!r}. "
            f"Valid types: {list(PRIMITIVE_REGISTRY)}"
        )

    cls = PRIMITIVE_REGISTRY[ptype]
    init_fields = {f.name: f for f in fields(cls) if f.init}

    kwargs = {}
    for key, value in params.items():
        spec_field = init_fields.get(key)
        if spec_field is None or value is None:
            # Unknown (hallucinated) param, or JSON null: keep the default.
            continue
        # With postponed annotations the field type is the string "str".
        if spec_field.type in ("str", str) and isinstance(value, (bool, int, float)):
            value = str(value)
        kwargs[key] = value
    return cls(**kwargs)
|
|
|