"""Repair primitives — direct inverses of the 8 breakage primitives. Used during warm-start data generation: for every (script, breakage) pair we know the canonical repair, so we can write SFT pairs. These are also useful for unit-testing the breakage primitives: apply(breakage) then apply(repair) should be (close to) the identity. """ from __future__ import annotations import re from abc import ABC, abstractmethod from dataclasses import dataclass, field @dataclass class RepairPrimitive(ABC): category: str = field(default="generic", init=False) name: str = field(default="RepairPrimitive", init=False) description: str = field(default="", init=False) @abstractmethod def apply(self, script: str) -> str: """Transform `script` to undo the corresponding breakage.""" def to_spec(self) -> dict: return { "primitive_type": self.__class__.__name__, "category": self.category, "params": self._get_params(), } @abstractmethod def _get_params(self) -> dict: """Return JSON-serializable constructor parameters.""" @dataclass class RestoreApiCall(RepairPrimitive): new_name: str = "" old_name: str = "" def __post_init__(self) -> None: self.category = "api_drift" self.name = "RestoreApiCall" self.description = f"Rename {self.new_name} -> {self.old_name}" def apply(self, script: str) -> str: if not self.new_name: return script pattern = re.compile(rf"(? 
dict: return {"new_name": self.new_name, "old_name": self.old_name} @dataclass class RestoreImport(RepairPrimitive): new_module: str = "" old_module: str = "" def __post_init__(self) -> None: self.category = "import_drift" self.name = "RestoreImport" self.description = f"Restore import {self.new_module} -> {self.old_module}" def apply(self, script: str) -> str: return script.replace(self.new_module, self.old_module) def _get_params(self) -> dict: return {"new_module": self.new_module, "old_module": self.old_module} @dataclass class RestoreArgument(RepairPrimitive): """Re-add a removed argument to a function call.""" function_name: str = "" arg_name: str = "" arg_value: str = "" def __post_init__(self) -> None: self.category = "api_drift" self.name = "RestoreArgument" self.description = ( f"Add {self.arg_name}={self.arg_value} to {self.function_name}()" ) def apply(self, script: str) -> str: if not self.function_name: return script # Insert the kwarg right after the function-name's opening paren. 
pattern = rf"({re.escape(self.function_name)}\s*\()(\s*)" replacement = rf"\g<1>{self.arg_name}={self.arg_value}, \g<2>" return re.sub(pattern, replacement, script, count=1) def _get_params(self) -> dict: return { "function_name": self.function_name, "arg_name": self.arg_name, "arg_value": self.arg_value, } @dataclass class RestoreConfigField(RepairPrimitive): field_name: str = "" old_value: str = "" def __post_init__(self) -> None: self.category = "config_drift" self.name = "RestoreConfigField" self.description = f"Restore {self.field_name}={self.old_value}" def apply(self, script: str) -> str: if not self.field_name: return script pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)" return re.sub(pattern, rf"\g<1>{self.old_value}", script) def _get_params(self) -> dict: return {"field_name": self.field_name, "old_value": self.old_value} @dataclass class RestoreColumn(RepairPrimitive): new_column: str = "" old_column: str = "" def __post_init__(self) -> None: self.category = "dataset_drift" self.name = "RestoreColumn" self.description = f"Rename column {self.new_column} -> {self.old_column}" def apply(self, script: str) -> str: return script.replace( f'"{self.new_column}"', f'"{self.old_column}"' ).replace( f"'{self.new_column}'", f"'{self.old_column}'" ) def _get_params(self) -> dict: return {"new_column": self.new_column, "old_column": self.old_column} @dataclass class RestoreTokenizerKwarg(RepairPrimitive): new_kwarg: str = "" new_value: str = "" old_kwarg: str = "" old_value: str = "" def __post_init__(self) -> None: self.category = "tokenizer_drift" self.name = "RestoreTokenizerKwarg" self.description = ( f"Restore tokenizer {self.new_kwarg}={self.new_value} -> " f"{self.old_kwarg}={self.old_value}" ) def apply(self, script: str) -> str: if not self.new_kwarg: return script pattern = rf"{re.escape(self.new_kwarg)}\s*=\s*{re.escape(self.new_value)}" replacement = f"{self.old_kwarg}={self.old_value}" return re.sub(pattern, replacement, script) def 
_get_params(self) -> dict: return { "new_kwarg": self.new_kwarg, "new_value": self.new_value, "old_kwarg": self.old_kwarg, "old_value": self.old_value, } @dataclass class RestoreMethod(RepairPrimitive): method_name: str = "" def __post_init__(self) -> None: self.category = "api_drift" self.name = "RestoreMethod" self.description = f"Un-deprecate .{self.method_name}()" def apply(self, script: str) -> str: if not self.method_name: return script return script.replace( f".{self.method_name}_DEPRECATED(", f".{self.method_name}(" ) def _get_params(self) -> dict: return {"method_name": self.method_name} @dataclass class RestoreReturnAccess(RepairPrimitive): new_access: str = "" old_access: str = "" def __post_init__(self) -> None: self.category = "api_drift" self.name = "RestoreReturnAccess" self.description = f"Restore return-access {self.new_access} -> {self.old_access}" def apply(self, script: str) -> str: if not self.new_access: return script return script.replace(self.new_access, self.old_access) def _get_params(self) -> dict: return {"new_access": self.new_access, "old_access": self.old_access} REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = { "RestoreApiCall": RestoreApiCall, "RestoreImport": RestoreImport, "RestoreArgument": RestoreArgument, "RestoreConfigField": RestoreConfigField, "RestoreColumn": RestoreColumn, "RestoreTokenizerKwarg": RestoreTokenizerKwarg, "RestoreMethod": RestoreMethod, "RestoreReturnAccess": RestoreReturnAccess, } # Map a breakage primitive's class name to the repair-primitive class that # inverts it. Used by the warm-start pair generator and by the demo / repair # library curator. 
BREAKAGE_TO_REPAIR: dict[str, str] = {
    # breakage primitive class name -> repair primitive class name
    "RenameApiCall": "RestoreApiCall",
    "DeprecateImport": "RestoreImport",
    "ChangeArgumentSignature": "RestoreArgument",
    "ModifyConfigField": "RestoreConfigField",
    "RestructureDatasetSchema": "RestoreColumn",
    "ChangeTokenizerBehavior": "RestoreTokenizerKwarg",
    "RemoveDeprecatedMethod": "RestoreMethod",
    "ChangeReturnType": "RestoreReturnAccess",
}