| """Repair primitives — direct inverses of the 8 breakage primitives.
|
|
|
| Used during warm-start data generation: for every (script, breakage)
|
| pair we know the canonical repair, so we can write SFT pairs.
|
|
|
| These are also useful for unit-testing the breakage primitives:
|
| apply(breakage) then apply(repair) should be (close to) the identity.
|
| """
|
| from __future__ import annotations
|
|
|
| import re
|
| from abc import ABC, abstractmethod
|
| from dataclasses import dataclass, field
|
|
|
|
|
@dataclass
class RepairPrimitive(ABC):
    """Abstract base class for script repair primitives.

    Concrete primitives implement :meth:`apply` and :meth:`_get_params`.
    The three metadata fields are declared with ``init=False``, so they are
    not constructor arguments and start out at the generic defaults below.
    """

    # Metadata fields; excluded from the generated __init__.
    category: str = field(init=False, default="generic")
    name: str = field(init=False, default="RepairPrimitive")
    description: str = field(init=False, default="")

    @abstractmethod
    def apply(self, script: str) -> str:
        """Return `script` transformed to undo the corresponding breakage."""

    @abstractmethod
    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""

    def to_spec(self) -> dict:
        """Describe this primitive as a plain dict.

        Returns:
            A dict with the keys ``primitive_type`` (the concrete class
            name), ``category`` and ``params`` (from :meth:`_get_params`).
        """
        spec: dict = {}
        spec["primitive_type"] = type(self).__name__
        spec["category"] = self.category
        spec["params"] = self._get_params()
        return spec
|
|
|
|
|
@dataclass
class RestoreApiCall(RepairPrimitive):
    """Rename whole-identifier occurrences of `new_name` back to `old_name`."""

    new_name: str = ""
    old_name: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "api_drift"
        self.name = "RestoreApiCall"
        self.description = f"Rename {self.new_name} -> {self.old_name}"

    def apply(self, script: str) -> str:
        """Replace every standalone occurrence of `new_name` with `old_name`.

        An occurrence only counts when it is not directly preceded or
        followed by a word character, so longer identifiers that merely
        contain `new_name` are left alone. With an empty `new_name` the
        script is returned unchanged.
        """
        if not self.new_name:
            return script
        pattern = re.compile(rf"(?<!\w){re.escape(self.new_name)}(?!\w)")
        # A callable replacement is inserted verbatim; a plain string would be
        # parsed as a regex template (backslashes, \g<...>) and could raise.
        return pattern.sub(lambda _match: self.old_name, script)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {"new_name": self.new_name, "old_name": self.old_name}
|
|
|
|
|
@dataclass
class RestoreImport(RepairPrimitive):
    """Replace occurrences of `new_module` in a script with `old_module`."""

    new_module: str = ""
    old_module: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "import_drift"
        self.name = "RestoreImport"
        self.description = f"Restore import {self.new_module} -> {self.old_module}"

    def apply(self, script: str) -> str:
        """Substitute every occurrence of `new_module` with `old_module`.

        This is a plain substring replacement over the whole script. With an
        empty `new_module` the script is returned unchanged.
        """
        # Guard: str.replace("", x) would insert x between every character.
        if not self.new_module:
            return script
        return script.replace(self.new_module, self.old_module)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {"new_module": self.new_module, "old_module": self.old_module}
|
|
|
|
|
@dataclass
class RestoreArgument(RepairPrimitive):
    """Re-add a removed argument to a function call."""

    function_name: str = ""
    arg_name: str = ""
    arg_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "api_drift"
        self.name = "RestoreArgument"
        self.description = (
            f"Add {self.arg_name}={self.arg_value} to {self.function_name}()"
        )

    def apply(self, script: str) -> str:
        """Insert ``arg_name=arg_value, `` as the first argument of a call.

        Only the first ``function_name(`` found in the script is modified.
        Whitespace that followed the opening parenthesis is kept after the
        inserted argument. The script is returned unchanged when
        `function_name` or `arg_name` is empty.
        """
        # An empty arg_name would emit "=value, ", which is never valid.
        if not self.function_name or not self.arg_name:
            return script

        pattern = rf"({re.escape(self.function_name)}\s*\()(\s*)"

        def _insert(match: re.Match) -> str:
            # Built as a normal string so backslashes in arg_value (e.g. a
            # regex or Windows path literal) are not read as template escapes.
            return (
                f"{match.group(1)}{self.arg_name}={self.arg_value}, "
                f"{match.group(2)}"
            )

        return re.sub(pattern, _insert, script, count=1)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {
            "function_name": self.function_name,
            "arg_name": self.arg_name,
            "arg_value": self.arg_value,
        }
|
|
|
|
|
@dataclass
class RestoreConfigField(RepairPrimitive):
    """Reset the value assigned to `field_name` back to `old_value`."""

    field_name: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "config_drift"
        self.name = "RestoreConfigField"
        self.description = f"Restore {self.field_name}={self.old_value}"

    def apply(self, script: str) -> str:
        """Rewrite every ``field_name = <value>`` so the value is `old_value`.

        The replaced value is the run of characters after ``=`` up to the
        next comma, closing parenthesis or newline. The ``field_name =``
        prefix, including its spacing, is kept as written. Comparisons
        (``field_name == ...``) are not treated as assignments. With an
        empty `field_name` the script is returned unchanged.
        """
        if not self.field_name:
            return script
        # (?!=) keeps "field == x" from being mistaken for an assignment.
        pattern = rf"({re.escape(self.field_name)}\s*=(?!=)\s*)([^,)\n]+)"
        # Callable replacement so backslashes in old_value are inserted
        # verbatim instead of being parsed as regex template escapes.
        return re.sub(
            pattern,
            lambda match: f"{match.group(1)}{self.old_value}",
            script,
        )

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {"field_name": self.field_name, "old_value": self.old_value}
|
|
|
|
|
@dataclass
class RestoreColumn(RepairPrimitive):
    """Rename quoted occurrences of `new_column` back to `old_column`."""

    new_column: str = ""
    old_column: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "dataset_drift"
        self.name = "RestoreColumn"
        self.description = f"Rename column {self.new_column} -> {self.old_column}"

    def apply(self, script: str) -> str:
        """Replace ``"new_column"`` and ``'new_column'`` with the old name.

        Only occurrences wrapped in matching double or single quotes are
        replaced, and the quote style is kept. With an empty `new_column`
        the script is returned unchanged.
        """
        # Guard: an empty name would rewrite every "" and '' in the script.
        if not self.new_column:
            return script
        return script.replace(
            f'"{self.new_column}"', f'"{self.old_column}"'
        ).replace(
            f"'{self.new_column}'", f"'{self.old_column}'"
        )

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {"new_column": self.new_column, "old_column": self.old_column}
|
|
|
|
|
@dataclass
class RestoreTokenizerKwarg(RepairPrimitive):
    """Swap a ``new_kwarg=new_value`` pair back to ``old_kwarg=old_value``."""

    new_kwarg: str = ""
    new_value: str = ""
    old_kwarg: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.category = "tokenizer_drift"
        self.name = "RestoreTokenizerKwarg"
        self.description = (
            f"Restore tokenizer {self.new_kwarg}={self.new_value} -> "
            f"{self.old_kwarg}={self.old_value}"
        )

    def apply(self, script: str) -> str:
        """Replace every ``new_kwarg = new_value`` with ``old_kwarg=old_value``.

        Whitespace around the ``=`` in the script is tolerated when
        matching; the replacement is always written without spaces. With an
        empty `new_kwarg` the script is returned unchanged.
        """
        if not self.new_kwarg:
            return script
        pattern = rf"{re.escape(self.new_kwarg)}\s*=\s*{re.escape(self.new_value)}"
        replacement = f"{self.old_kwarg}={self.old_value}"
        # Callable replacement: the text is inserted verbatim, whereas a plain
        # string would be parsed as a regex template (backslash escapes).
        return re.sub(pattern, lambda _match: replacement, script)

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        return {
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
        }
|
|
|
|
|
@dataclass
class RestoreMethod(RepairPrimitive):
    """Strip the ``_DEPRECATED`` suffix from calls to `method_name`."""

    method_name: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.name = "RestoreMethod"
        self.category = "api_drift"
        self.description = f"Un-deprecate .{self.method_name}()"

    def apply(self, script: str) -> str:
        """Turn each ``.<method_name>_DEPRECATED(`` into ``.<method_name>(``.

        With an empty `method_name` the script is returned unchanged.
        """
        if self.method_name:
            deprecated_call = f".{self.method_name}_DEPRECATED("
            restored_call = f".{self.method_name}("
            return script.replace(deprecated_call, restored_call)
        return script

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        params = {"method_name": self.method_name}
        return params
|
|
|
|
|
@dataclass
class RestoreReturnAccess(RepairPrimitive):
    """Replace the `new_access` expression text with `old_access`."""

    new_access: str = ""
    old_access: str = ""

    def __post_init__(self) -> None:
        """Fill in the metadata fields for this primitive."""
        self.name = "RestoreReturnAccess"
        self.category = "api_drift"
        self.description = f"Restore return-access {self.new_access} -> {self.old_access}"

    def apply(self, script: str) -> str:
        """Substitute every occurrence of `new_access` with `old_access`.

        This is a plain substring replacement. With an empty `new_access`
        the script is returned unchanged.
        """
        return (
            script.replace(self.new_access, self.old_access)
            if self.new_access
            else script
        )

    def _get_params(self) -> dict:
        """Return the constructor parameters as a JSON-serializable dict."""
        params = {"new_access": self.new_access, "old_access": self.old_access}
        return params
|
|
|
|
|
# Maps each repair primitive's class name to the class itself.
REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = {
    primitive_cls.__name__: primitive_cls
    for primitive_cls in (
        RestoreApiCall,
        RestoreImport,
        RestoreArgument,
        RestoreConfigField,
        RestoreColumn,
        RestoreTokenizerKwarg,
        RestoreMethod,
        RestoreReturnAccess,
    )
}
|
|
|
|
|
|
|
|
|
|
|
# Maps a breakage primitive's class name to the name of the repair primitive
# (a key of REPAIR_REGISTRY) that undoes it.
BREAKAGE_TO_REPAIR: dict[str, str] = dict(
    RenameApiCall="RestoreApiCall",
    DeprecateImport="RestoreImport",
    ChangeArgumentSignature="RestoreArgument",
    ModifyConfigField="RestoreConfigField",
    RestructureDatasetSchema="RestoreColumn",
    ChangeTokenizerBehavior="RestoreTokenizerKwarg",
    RemoveDeprecatedMethod="RestoreMethod",
    ChangeReturnType="RestoreReturnAccess",
)
|
|
|