| """8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift. |
| |
| Each primitive transforms a working script to simulate a library upgrade |
| breakage. They double as the Drift Generator's structured action space. |
| """ |
| from __future__ import annotations |
|
|
| import re |
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass, field |
|
|
|
|
@dataclass
class BreakagePrimitive(ABC):
    """Common interface implemented by every breakage primitive.

    A primitive rewrites the text of a script (``apply``) and can describe
    itself as a JSON-compatible spec (``to_spec``).
    """

    # Metadata fields: excluded from ``__init__`` (init=False), so they can
    # only be changed by assignment after construction.
    category: str = field(default="generic", init=False)
    name: str = field(default="BreakagePrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Return ``script`` rewritten so that it exhibits the breakage."""

    @abstractmethod
    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""

    def to_spec(self) -> dict:
        """Describe this primitive as a JSON-compatible spec.

        Returns:
            A dict with the keys ``primitive_type`` (the concrete class
            name), ``category`` and ``params`` (from ``_get_params``).
        """
        return dict(
            primitive_type=type(self).__name__,
            category=self.category,
            params=self._get_params(),
        )
|
|
|
|
@dataclass
class RenameApiCall(BreakagePrimitive):
    """Rename a function/method call to simulate API deprecation.

    Attributes:
        old_name: Text to look for; it is regex-escaped, so it is matched
            literally.
        new_name: Text inserted in place of every match.
    """

    old_name: str = ""
    new_name: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "api_drift"
        self.name = "RenameApiCall"
        self.description = f"Rename {self.old_name} -> {self.new_name}"

    def apply(self, script: str) -> str:
        """Replace every standalone occurrence of ``old_name`` in ``script``.

        An occurrence only counts when it is not directly preceded or
        followed by a word character, so ``load`` does not match inside
        ``download``.

        Args:
            script: Source text to rewrite.

        Returns:
            The rewritten text; ``script`` unchanged when ``old_name`` is
            empty.
        """
        if not self.old_name:
            return script

        pattern = re.compile(rf"(?<!\w){re.escape(self.old_name)}(?!\w)")
        # A callable replacement inserts new_name verbatim.  Passing the
        # string itself would make re interpret backslashes and group
        # references in it (or raise re.error on an unknown escape).
        return pattern.sub(lambda _match: self.new_name, script)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {"old_name": self.old_name, "new_name": self.new_name}
|
|
|
|
@dataclass
class DeprecateImport(BreakagePrimitive):
    """Change an import path to simulate module restructuring.

    Attributes:
        old_module: Module path (or other text) to look for, matched
            literally.
        new_module: Text inserted in place of every match.
    """

    old_module: str = ""
    new_module: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "import_drift"
        self.name = "DeprecateImport"
        self.description = f"Move {self.old_module} -> {self.new_module}"

    def apply(self, script: str) -> str:
        """Replace occurrences of ``old_module`` with ``new_module``.

        Matches that are merely part of a longer identifier are skipped, so
        ``torch`` does not rewrite ``torchvision`` or ``my_torch``.  Dotted
        continuations such as ``pkg.sub`` for ``old_module="pkg"`` still
        match because ``.`` is not a word character.

        Args:
            script: Source text to rewrite.

        Returns:
            The rewritten text; ``script`` unchanged when ``old_module`` is
            empty.
        """
        if not self.old_module:
            return script

        old = self.old_module
        # Guard an edge only when old_module itself starts/ends with a word
        # character; a value such as "pkg." must still match "pkg.sub".
        lead = r"(?<!\w)" if re.match(r"\w", old) else ""
        trail = r"(?!\w)" if re.search(r"\w\Z", old) else ""
        pattern = re.compile(f"{lead}{re.escape(old)}{trail}")
        # Callable replacement: new_module is inserted verbatim.
        return pattern.sub(lambda _match: self.new_module, script)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {"old_module": self.old_module, "new_module": self.new_module}
|
|
|
|
@dataclass
class ChangeArgumentSignature(BreakagePrimitive):
    """Remove an expected kwarg (and document a new required one).

    Only ``removed_arg`` affects ``apply``.  ``function_name``,
    ``added_arg`` and ``added_value`` are carried in the description and in
    the serialized params but are not applied to the script.
    """

    function_name: str = ""
    removed_arg: str = ""
    added_arg: str = ""
    added_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "api_drift"
        self.name = "ChangeArgumentSignature"
        self.description = (
            f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}"
        )

    @staticmethod
    def _value_end(text: str, start: int) -> int:
        """Return the index just past the kwarg value beginning at ``start``.

        The value ends at the first ``,`` or closing bracket that is not
        nested inside brackets or a string literal opened within the value.
        """
        size = len(text)
        depth = 0
        quote = ""  # delimiter of the string literal being scanned, if any
        i = start
        while i < size:
            ch = text[i]
            if quote:
                if ch == "\\":
                    i += 2  # skip the escaped character
                    continue
                if text.startswith(quote, i):
                    i += len(quote)
                    quote = ""
                    continue
                if ch == "\n" and len(quote) == 1:
                    # A single-quoted literal cannot span lines; treat it as
                    # a stray quote rather than swallowing the rest of the
                    # script.
                    quote = ""
            elif ch == "#":
                # Skip the comment so punctuation inside it is ignored.
                newline = text.find("\n", i)
                if newline == -1:
                    return size
                i = newline
                continue
            elif ch in "\"'":
                quote = ch * 3 if text.startswith(ch * 3, i) else ch
                i += len(quote)
                continue
            elif ch in "([{":
                depth += 1
            elif ch in ")]}":
                if depth == 0:
                    break
                depth -= 1
            elif ch == "," and depth == 0:
                break
            i += 1
        return min(i, size)

    def apply(self, script: str) -> str:
        """Delete every ``removed_arg=<value>`` keyword argument.

        The value is skipped with bracket and string awareness, so values
        such as ``float(cfg["lr"])`` or ``"a, b"`` are removed whole instead
        of leaving unbalanced residue behind.  A trailing comma and the
        whitespace after it are removed as well.  ``removed_arg == x``
        comparisons are left alone.

        Args:
            script: Source text to rewrite.

        Returns:
            The rewritten text; ``script`` unchanged when ``removed_arg`` is
            empty.
        """
        if not self.removed_arg:
            return script

        # (?!=) keeps "arg == value" from being mistaken for a kwarg.
        head = re.compile(rf"\b{re.escape(self.removed_arg)}\s*=(?!=)\s*")
        tail = re.compile(r",?\s*")

        kept = []
        pos = 0
        match = head.search(script, pos)
        while match is not None:
            value_end = self._value_end(script, match.end())
            if value_end == match.end():
                # Nothing that looks like a value follows "=": keep the text.
                kept.append(script[pos:match.end()])
                pos = match.end()
            else:
                kept.append(script[pos:match.start()])
                pos = tail.match(script, value_end).end()
            match = head.search(script, pos)
        kept.append(script[pos:])
        return "".join(kept)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {
            "function_name": self.function_name,
            "removed_arg": self.removed_arg,
            "added_arg": self.added_arg,
            "added_value": self.added_value,
        }
|
|
|
|
@dataclass
class ModifyConfigField(BreakagePrimitive):
    """Change a config-class default value to simulate behaviour drift.

    Only ``field_name`` and ``new_value`` affect ``apply``; ``config_class``
    is carried in the description and in the serialized params.
    """

    config_class: str = ""
    field_name: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "config_drift"
        self.name = "ModifyConfigField"
        self.description = f"Change {self.config_class}.{self.field_name}"

    def apply(self, script: str) -> str:
        """Replace the value in every ``field_name=<value>`` occurrence.

        The value is the text after ``=`` up to the next ``,``, ``)`` or end
        of line.  ``field_name`` must not be the tail of a longer identifier
        (``lr`` does not match ``warmup_lr``), and ``field_name == x``
        comparisons are left alone.

        Args:
            script: Source text to rewrite.

        Returns:
            The rewritten text; ``script`` unchanged when ``field_name`` is
            empty.
        """
        if not self.field_name:
            return script

        pattern = re.compile(
            rf"((?<!\w){re.escape(self.field_name)}\s*=(?!=)\s*)([^,)\n]+)"
        )
        # Callable replacement: new_value is inserted verbatim, so
        # backslashes in it are not interpreted as regex template escapes.
        return pattern.sub(
            lambda match: f"{match.group(1)}{self.new_value}", script
        )

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {
            "config_class": self.config_class,
            "field_name": self.field_name,
            "new_value": self.new_value,
        }
|
|
|
|
@dataclass
class RestructureDatasetSchema(BreakagePrimitive):
    """Rename a quoted dataset column reference to simulate schema drift.

    Attributes:
        old_column: Column name to look for (without quotes).
        new_column: Column name that replaces it.
    """

    old_column: str = ""
    new_column: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "dataset_drift"
        self.name = "RestructureDatasetSchema"
        self.description = f"Rename column {self.old_column} -> {self.new_column}"

    def apply(self, script: str) -> str:
        """Rewrite ``"old_column"`` and ``'old_column'`` string literals.

        Only occurrences wrapped in matching double or single quotes are
        touched; the quote style of each occurrence is kept.

        Returns:
            The rewritten text; ``script`` unchanged when ``old_column`` is
            empty.
        """
        if not self.old_column:
            return script

        rewritten = script
        # Double-quoted form first, then single-quoted.
        for quote in ('"', "'"):
            rewritten = rewritten.replace(
                f"{quote}{self.old_column}{quote}",
                f"{quote}{self.new_column}{quote}",
            )
        return rewritten

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {"old_column": self.old_column, "new_column": self.new_column}
|
|
|
|
@dataclass
class ChangeTokenizerBehavior(BreakagePrimitive):
    """Change tokenizer call arguments.

    Attributes:
        old_kwarg: Keyword name to look for.
        old_value: Value text that must follow ``old_kwarg=``; matched
            literally.
        new_kwarg: Keyword name written in its place.
        new_value: Value text written in its place.
    """

    old_kwarg: str = ""
    old_value: str = ""
    new_kwarg: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "tokenizer_drift"
        self.name = "ChangeTokenizerBehavior"
        self.description = (
            f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> "
            f"{self.new_kwarg}={self.new_value}"
        )

    def apply(self, script: str) -> str:
        """Replace ``old_kwarg=old_value`` with ``new_kwarg=new_value``.

        Whitespace around ``=`` is tolerated in the match.  ``old_kwarg``
        must not be the tail of a longer identifier (``padding`` does not
        match ``do_padding``), and an ``old_value`` ending in a word
        character must not run on into more word characters (``1`` does not
        match the start of ``128``).

        Args:
            script: Source text to rewrite.

        Returns:
            The rewritten text; ``script`` unchanged when ``old_kwarg`` is
            empty.
        """
        if not self.old_kwarg:
            return script

        # Guard the trailing edge only when old_value ends with a word
        # character; a quoted or empty old_value needs no guard.
        trail = r"(?!\w)" if re.search(r"\w\Z", self.old_value) else ""
        pattern = re.compile(
            rf"(?<!\w){re.escape(self.old_kwarg)}\s*=\s*"
            rf"{re.escape(self.old_value)}{trail}"
        )
        replacement = f"{self.new_kwarg}={self.new_value}"
        # Callable replacement: the text is inserted verbatim, so backslashes
        # in new_value are not interpreted as regex template escapes.
        return pattern.sub(lambda _match: replacement, script)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
        }
|
|
|
|
@dataclass
class RemoveDeprecatedMethod(BreakagePrimitive):
    """Simulate the removal of a deprecated method.

    Call sites of the method are renamed to ``<method_name>_DEPRECATED``,
    a sentinel name intended to fail with an AttributeError-style error
    when the script runs.  Only ``method_name`` affects ``apply``;
    ``class_name`` and ``replacement`` are carried in the description and
    in the serialized params.
    """

    class_name: str = ""
    method_name: str = ""
    replacement: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "api_drift"
        self.name = "RemoveDeprecatedMethod"
        self.description = f"Remove {self.class_name}.{self.method_name}"

    def apply(self, script: str) -> str:
        """Rename every ``.method_name(`` call to the ``_DEPRECATED`` sentinel.

        Returns:
            The rewritten text; ``script`` unchanged when ``method_name`` is
            empty.
        """
        if self.method_name:
            call_site = f".{self.method_name}("
            sentinel = f".{self.method_name}_DEPRECATED("
            return script.replace(call_site, sentinel)
        return script

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {
            "class_name": self.class_name,
            "method_name": self.method_name,
            "replacement": self.replacement,
        }
|
|
|
|
@dataclass
class ChangeReturnType(BreakagePrimitive):
    """A function now returns a different structure (e.g. tuple -> object).

    Only ``old_access`` and ``new_access`` affect ``apply``;
    ``function_name`` is carried in the description and in the serialized
    params.
    """

    function_name: str = ""
    old_access: str = ""
    new_access: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "api_drift"
        self.name = "ChangeReturnType"
        self.description = f"Change return type of {self.function_name}"

    def apply(self, script: str) -> str:
        """Replace every literal ``old_access`` with ``new_access``.

        Returns:
            The rewritten text; ``script`` unchanged unless both
            ``old_access`` and ``new_access`` are non-empty.
        """
        if not (self.old_access and self.new_access):
            return script
        return script.replace(self.old_access, self.new_access)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {
            "function_name": self.function_name,
            "old_access": self.old_access,
            "new_access": self.new_access,
        }
|
|
|
|
# Lookup table from a spec's ``primitive_type`` string to the primitive class
# it names.
PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = dict(
    (
        ("RenameApiCall", RenameApiCall),
        ("DeprecateImport", DeprecateImport),
        ("ChangeArgumentSignature", ChangeArgumentSignature),
        ("ModifyConfigField", ModifyConfigField),
        ("RestructureDatasetSchema", RestructureDatasetSchema),
        ("ChangeTokenizerBehavior", ChangeTokenizerBehavior),
        ("RemoveDeprecatedMethod", RemoveDeprecatedMethod),
        ("ChangeReturnType", ChangeReturnType),
    )
)
|
|
|
|
def parse_breakage_spec(spec: dict) -> BreakagePrimitive:
    """Parse a JSON breakage spec into a BreakagePrimitive object.

    Tolerates extra keys; ignores unknown params (LLMs hallucinate these).

    Args:
        spec: Dict with a ``primitive_type`` key naming an entry of
            ``PRIMITIVE_REGISTRY`` and an optional ``params`` dict of
            constructor arguments.  A missing or falsy ``params`` is
            treated as empty.

    Returns:
        An instance of the registered class, built from the params that
        match its constructor fields.

    Raises:
        ValueError: If ``spec`` or ``params`` is not a dict, or if
            ``primitive_type`` is not a registered type name.
    """
    # Malformed input is reported as ValueError, the same exception used for
    # an unknown type, instead of leaking AttributeError from ``.get``.
    if not isinstance(spec, dict):
        raise ValueError(
            f"Breakage spec must be a dict, got {type(spec).__name__}"
        )

    ptype = spec.get("primitive_type", "")
    params = spec.get("params") or {}
    if not isinstance(params, dict):
        raise ValueError(
            f"Breakage spec 'params' must be a dict, got {type(params).__name__}"
        )

    # The isinstance check also protects the membership test: an unhashable
    # primitive_type (e.g. a list) would otherwise raise TypeError.
    if not isinstance(ptype, str) or ptype not in PRIMITIVE_REGISTRY:
        raise ValueError(
            f"Unknown primitive type: {ptype!r}. "
            f"Valid types: {list(PRIMITIVE_REGISTRY)}"
        )

    cls = PRIMITIVE_REGISTRY[ptype]

    # Only fields accepted by the generated __init__ may be passed through.
    valid_fields = {
        f.name for f in cls.__dataclass_fields__.values() if f.init
    }
    filtered = {k: v for k, v in params.items() if k in valid_fields}
    return cls(**filtered)
|
|