"""8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift. Each primitive transforms a working script to simulate a library upgrade breakage. They double as the Drift Generator's structured action space. """ from __future__ import annotations import re from abc import ABC, abstractmethod from dataclasses import dataclass, field @dataclass class BreakagePrimitive(ABC): """Abstract base class for all breakage types.""" category: str = field(default="generic", init=False) name: str = field(default="BreakagePrimitive", init=False) description: str = field(default="", init=False) @abstractmethod def apply(self, script: str) -> str: """Transform `script` to introduce the breakage.""" def to_spec(self) -> dict: """Serialize to JSON-compatible spec for the LLM action space.""" return { "primitive_type": self.__class__.__name__, "category": self.category, "params": self._get_params(), } @abstractmethod def _get_params(self) -> dict: """Return a JSON-serializable dict of constructor parameters.""" @dataclass class RenameApiCall(BreakagePrimitive): """Rename a function/method call to simulate API deprecation.""" old_name: str = "" new_name: str = "" def __post_init__(self) -> None: self.category = "api_drift" self.name = "RenameApiCall" self.description = f"Rename {self.old_name} -> {self.new_name}" def apply(self, script: str) -> str: if not self.old_name: return script # Use word-boundary replacement so we don't substring-match identifiers. pattern = re.compile(rf"(? dict: return {"old_name": self.old_name, "new_name": self.new_name} @dataclass class DeprecateImport(BreakagePrimitive): """Change an import path to simulate module restructuring.""" old_module: str = "" new_module: str = "" def __post_init__(self) -> None: self.category = "import_drift" self.name = "DeprecateImport" self.description = f"Move {self.old_module} -> {self.new_module}" def apply(self, script: str) -> str: if not self.old_module: return script return script.replace(self.old_module, self.new_module) def _get_params(self) -> dict: return {"old_module": self.old_module, "new_module": self.new_module} @dataclass class ChangeArgumentSignature(BreakagePrimitive): """Remove an expected kwarg (and document a new required one).""" function_name: str = "" removed_arg: str = "" added_arg: str = "" added_value: str = "" def __post_init__(self) -> None: self.category = "api_drift" self.name = "ChangeArgumentSignature" self.description = ( f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}" ) def apply(self, script: str) -> str: if not self.removed_arg: return script pattern = rf"(\b{re.escape(self.removed_arg)}\s*=\s*[^,)]+,?\s*)" return re.sub(pattern, "", script) def _get_params(self) -> dict: return { "function_name": self.function_name, "removed_arg": self.removed_arg, "added_arg": self.added_arg, "added_value": self.added_value, } @dataclass class ModifyConfigField(BreakagePrimitive): """Change a config-class default value to simulate behaviour drift.""" config_class: str = "" field_name: str = "" new_value: str = "" def __post_init__(self) -> None: self.category = "config_drift" self.name = "ModifyConfigField" self.description = f"Change {self.config_class}.{self.field_name}" def apply(self, script: str) -> str: if not self.field_name: return script pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)" return re.sub(pattern, rf"\g<1>{self.new_value}", script) def _get_params(self) -> dict: return { "config_class": self.config_class, "field_name": self.field_name, "new_value": self.new_value, } @dataclass class RestructureDatasetSchema(BreakagePrimitive): """Rename a dataset column reference to simulate schema drift.""" old_column: str = "" new_column: str = "" def __post_init__(self) -> None: self.category = "dataset_drift" self.name = "RestructureDatasetSchema" self.description = f"Rename column {self.old_column} -> {self.new_column}" def apply(self, script: str) -> str: if not self.old_column: return script return script.replace( f'"{self.old_column}"', f'"{self.new_column}"' ).replace( f"'{self.old_column}'", f"'{self.new_column}'" ) def _get_params(self) -> dict: return {"old_column": self.old_column, "new_column": self.new_column} @dataclass class ChangeTokenizerBehavior(BreakagePrimitive): """Change tokenizer call arguments.""" old_kwarg: str = "" old_value: str = "" new_kwarg: str = "" new_value: str = "" def __post_init__(self) -> None: self.category = "tokenizer_drift" self.name = "ChangeTokenizerBehavior" self.description = f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> {self.new_kwarg}={self.new_value}" def apply(self, script: str) -> str: if not self.old_kwarg: return script pattern = rf"{re.escape(self.old_kwarg)}\s*=\s*{re.escape(self.old_value)}" replacement = f"{self.new_kwarg}={self.new_value}" return re.sub(pattern, replacement, script) def _get_params(self) -> dict: return { "old_kwarg": self.old_kwarg, "old_value": self.old_value, "new_kwarg": self.new_kwarg, "new_value": self.new_value, } @dataclass class RemoveDeprecatedMethod(BreakagePrimitive): """Remove a method that has been deprecated, leaving a sentinel that raises AttributeError-style errors when the script runs.""" class_name: str = "" method_name: str = "" replacement: str = "" def __post_init__(self) -> None: self.category = "api_drift" self.name = "RemoveDeprecatedMethod" self.description = f"Remove {self.class_name}.{self.method_name}" def apply(self, script: str) -> str: if not self.method_name: return script return script.replace( f".{self.method_name}(", f".{self.method_name}_DEPRECATED(" ) def _get_params(self) -> dict: return { "class_name": self.class_name, "method_name": self.method_name, "replacement": self.replacement, } @dataclass class ChangeReturnType(BreakagePrimitive): """A function now returns a different structure (e.g. tuple -> object).""" function_name: str = "" old_access: str = "" new_access: str = "" def __post_init__(self) -> None: self.category = "api_drift" self.name = "ChangeReturnType" self.description = f"Change return type of {self.function_name}" def apply(self, script: str) -> str: if self.old_access and self.new_access: return script.replace(self.old_access, self.new_access) return script def _get_params(self) -> dict: return { "function_name": self.function_name, "old_access": self.old_access, "new_access": self.new_access, } PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = { "RenameApiCall": RenameApiCall, "DeprecateImport": DeprecateImport, "ChangeArgumentSignature": ChangeArgumentSignature, "ModifyConfigField": ModifyConfigField, "RestructureDatasetSchema": RestructureDatasetSchema, "ChangeTokenizerBehavior": ChangeTokenizerBehavior, "RemoveDeprecatedMethod": RemoveDeprecatedMethod, "ChangeReturnType": ChangeReturnType, } def parse_breakage_spec(spec: dict) -> BreakagePrimitive: """Parse a JSON breakage spec into a BreakagePrimitive object. Tolerates extra keys; ignores unknown params (LLMs hallucinate these). """ ptype = spec.get("primitive_type", "") params = spec.get("params", {}) or {} if ptype not in PRIMITIVE_REGISTRY: raise ValueError( f"Unknown primitive type: {ptype!r}. " f"Valid types: {list(PRIMITIVE_REGISTRY)}" ) cls = PRIMITIVE_REGISTRY[ptype] # Filter to known fields only so a hallucinated kwarg can't crash us. valid_fields = { f.name for f in cls.__dataclass_fields__.values() if f.init # type: ignore[attr-defined] } filtered = {k: v for k, v in params.items() if k in valid_fields} return cls(**filtered)