| """Repair primitives — direct inverses of the 8 breakage primitives. |
| |
| Used during warm-start data generation: for every (script, breakage) |
| pair we know the canonical repair, so we can write SFT pairs. |
| |
| These are also useful for unit-testing the breakage primitives: |
| apply(breakage) then apply(repair) should be (close to) the identity. |
| """ |
| from __future__ import annotations |
|
|
| import re |
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass, field |
|
|
|
|
@dataclass
class RepairPrimitive(ABC):
    """Abstract base for a text transformation that undoes one breakage.

    The three metadata fields below are excluded from ``__init__``
    (``init=False``) and start out with generic defaults; subclasses are
    expected to overwrite them, typically in ``__post_init__``.
    """

    # Metadata only; never passed to the constructor.
    category: str = field(init=False, default="generic")
    name: str = field(init=False, default="RepairPrimitive")
    description: str = field(init=False, default="")

    @abstractmethod
    def _get_params(self) -> dict:
        """Return JSON-serializable constructor parameters."""

    @abstractmethod
    def apply(self, script: str) -> str:
        """Transform `script` to undo the corresponding breakage."""

    def to_spec(self) -> dict:
        """Describe this primitive as a plain dictionary.

        Returns:
            A dict with the concrete class name under ``"primitive_type"``,
            the instance's ``category`` and whatever ``_get_params()``
            returns under ``"params"``.
        """
        spec: dict = {}
        spec["primitive_type"] = type(self).__name__
        spec["category"] = self.category
        spec["params"] = self._get_params()
        return spec
|
|
|
|
@dataclass
class RestoreApiCall(RepairPrimitive):
    """Rename every standalone occurrence of ``new_name`` back to ``old_name``.

    An occurrence counts as standalone when it is neither preceded nor
    followed by a word character, so ``new_name`` embedded in a longer
    identifier is left alone.

    Attributes:
        new_name: The name currently present in the script.
        old_name: The name to put back in its place.
    """

    new_name: str = ""
    old_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreApiCall"
        self.description = f"Rename {self.new_name} -> {self.old_name}"

    def apply(self, script: str) -> str:
        """Return `script` with ``new_name`` replaced by ``old_name``.

        The script is returned unchanged when ``new_name`` is empty.
        """
        if not self.new_name:
            return script
        pattern = re.compile(rf"(?<!\w){re.escape(self.new_name)}(?!\w)")
        # A callable replacement inserts old_name verbatim. Passing it as a
        # string would make re expand backslash escapes / group references.
        return pattern.sub(lambda _match: self.old_name, script)

    def _get_params(self) -> dict:
        """Return the constructor arguments as a JSON-serializable dict."""
        return {"new_name": self.new_name, "old_name": self.old_name}
|
|
|
|
@dataclass
class RestoreImport(RepairPrimitive):
    """Replace every occurrence of ``new_module`` with ``old_module``.

    The replacement is a plain substring substitution over the whole script,
    not limited to import statements.

    Attributes:
        new_module: The module path currently present in the script.
        old_module: The module path to put back.
    """

    new_module: str = ""
    old_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "RestoreImport"
        self.description = f"Restore import {self.new_module} -> {self.old_module}"

    def apply(self, script: str) -> str:
        """Return `script` with ``new_module`` swapped for ``old_module``.

        The script is returned unchanged when ``new_module`` is empty.
        """
        # str.replace("", x) inserts x between every character, so an empty
        # search string must be a no-op (same guard as the other primitives).
        if not self.new_module:
            return script
        return script.replace(self.new_module, self.old_module)

    def _get_params(self) -> dict:
        """Return the constructor arguments as a JSON-serializable dict."""
        return {"new_module": self.new_module, "old_module": self.old_module}
|
|
|
|
@dataclass
class RestoreArgument(RepairPrimitive):
    """Re-add a removed argument to a function call.

    ``arg_name=arg_value, `` is inserted directly after the opening
    parenthesis of the first textual match of ``function_name(``.

    Attributes:
        function_name: Name whose call should receive the argument.
        arg_name: Keyword name to insert.
        arg_value: Source text of the value to insert.
    """

    function_name: str = ""
    arg_name: str = ""
    arg_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreArgument"
        self.description = (
            f"Add {self.arg_name}={self.arg_value} to {self.function_name}()"
        )

    def apply(self, script: str) -> str:
        """Return `script` with the keyword argument inserted once.

        The script is returned unchanged when ``function_name`` or
        ``arg_name`` is empty, or when no call is found.
        """
        # An empty arg_name would emit "=value, ", which is not valid Python.
        if not self.function_name or not self.arg_name:
            return script

        # NOTE: the pattern has no word-boundary anchor, so a longer name that
        # ends in function_name also matches.
        pattern = rf"({re.escape(self.function_name)}\s*\()(\s*)"
        inserted = f"{self.arg_name}={self.arg_value}, "

        def _insert(match: re.Match) -> str:
            # Built by hand so that backslashes in arg_value stay literal
            # instead of being parsed as a regex replacement template.
            return f"{match.group(1)}{inserted}{match.group(2)}"

        return re.sub(pattern, _insert, script, count=1)

    def _get_params(self) -> dict:
        """Return the constructor arguments as a JSON-serializable dict."""
        return {
            "function_name": self.function_name,
            "arg_name": self.arg_name,
            "arg_value": self.arg_value,
        }
|
|
|
|
@dataclass
class RestoreConfigField(RepairPrimitive):
    """Reset the value assigned to ``field_name`` back to ``old_value``.

    Every ``field_name = <value>`` occurrence is rewritten. The value is
    taken to run up to the next comma, closing parenthesis or newline.

    Attributes:
        field_name: Name on the left-hand side of the ``=``.
        old_value: Source text of the value to restore.
    """

    field_name: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "RestoreConfigField"
        self.description = f"Restore {self.field_name}={self.old_value}"

    def apply(self, script: str) -> str:
        """Return `script` with the field's value replaced by ``old_value``.

        The script is returned unchanged when ``field_name`` is empty.
        """
        if not self.field_name:
            return script
        pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)"
        # Keep the "name =" prefix and append old_value verbatim. A callable
        # avoids re interpreting backslashes in old_value (a string path
        # such as "C:\data" would otherwise raise "bad escape").
        return re.sub(pattern, lambda match: match.group(1) + self.old_value, script)

    def _get_params(self) -> dict:
        """Return the constructor arguments as a JSON-serializable dict."""
        return {"field_name": self.field_name, "old_value": self.old_value}
|
|
|
|
@dataclass
class RestoreColumn(RepairPrimitive):
    """Rename a quoted column name back to its previous spelling.

    Only occurrences wrapped in double or single quotes are replaced; the
    quote style of each occurrence is kept.

    Attributes:
        new_column: Column name currently present in the script.
        old_column: Column name to put back.
    """

    new_column: str = ""
    old_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestoreColumn"
        self.description = f"Rename column {self.new_column} -> {self.old_column}"

    def apply(self, script: str) -> str:
        """Return `script` with quoted ``new_column`` renamed to ``old_column``.

        The script is returned unchanged when ``new_column`` is empty.
        """
        # With an empty new_column the search strings would be "" and '' and
        # every empty string literal in the script would be rewritten.
        if not self.new_column:
            return script
        return script.replace(
            f'"{self.new_column}"', f'"{self.old_column}"'
        ).replace(
            f"'{self.new_column}'", f"'{self.old_column}'"
        )

    def _get_params(self) -> dict:
        """Return the constructor arguments as a JSON-serializable dict."""
        return {"new_column": self.new_column, "old_column": self.old_column}
|
|
|
|
@dataclass
class RestoreTokenizerKwarg(RepairPrimitive):
    """Swap a ``new_kwarg=new_value`` pair back to ``old_kwarg=old_value``.

    Whitespace around the ``=`` of the matched pair is tolerated; the
    restored pair is written without spaces.

    Attributes:
        new_kwarg: Keyword name currently present in the script.
        new_value: Source text of the value currently present.
        old_kwarg: Keyword name to restore.
        old_value: Source text of the value to restore.
    """

    new_kwarg: str = ""
    new_value: str = ""
    old_kwarg: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "RestoreTokenizerKwarg"
        self.description = (
            f"Restore tokenizer {self.new_kwarg}={self.new_value} -> "
            f"{self.old_kwarg}={self.old_value}"
        )

    def apply(self, script: str) -> str:
        """Return `script` with every matching pair replaced.

        The script is returned unchanged when ``new_kwarg`` is empty.
        """
        if not self.new_kwarg:
            return script
        pattern = rf"{re.escape(self.new_kwarg)}\s*=\s*{re.escape(self.new_value)}"
        restored = f"{self.old_kwarg}={self.old_value}"
        # Callable replacement: the restored text is inserted literally and
        # is not parsed as a regex template (backslashes, \g<...>).
        return re.sub(pattern, lambda _match: restored, script)

    def _get_params(self) -> dict:
        """Return the constructor arguments as a JSON-serializable dict."""
        return {
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
        }
|
|
|
|
@dataclass
class RestoreMethod(RepairPrimitive):
    """Strip the ``_DEPRECATED`` suffix from calls to ``.method_name``.

    Attributes:
        method_name: Method name without the suffix.
    """

    method_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreMethod"
        self.description = f"Un-deprecate .{self.method_name}()"

    def _get_params(self) -> dict:
        """Return the constructor arguments as a JSON-serializable dict."""
        return {"method_name": self.method_name}

    def apply(self, script: str) -> str:
        """Return `script` with ``.<method>_DEPRECATED(`` turned into ``.<method>(``.

        An empty ``method_name`` leaves the script untouched.
        """
        if self.method_name:
            deprecated_call = f".{self.method_name}_DEPRECATED("
            restored_call = f".{self.method_name}("
            return script.replace(deprecated_call, restored_call)
        return script
|
|
|
|
@dataclass
class RestoreReturnAccess(RepairPrimitive):
    """Replace the ``new_access`` snippet with ``old_access`` everywhere.

    Attributes:
        new_access: Access expression currently present in the script.
        old_access: Access expression to put back.
    """

    new_access: str = ""
    old_access: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreReturnAccess"
        self.description = f"Restore return-access {self.new_access} -> {self.old_access}"

    def _get_params(self) -> dict:
        """Return the constructor arguments as a JSON-serializable dict."""
        return {"new_access": self.new_access, "old_access": self.old_access}

    def apply(self, script: str) -> str:
        """Return `script` with every ``new_access`` substring replaced.

        An empty ``new_access`` leaves the script untouched.
        """
        current, previous = self.new_access, self.old_access
        return script.replace(current, previous) if current else script
|
|
|
|
# Lookup table from a repair primitive's class name to the class itself.
# Keys are taken from ``__name__`` so they cannot drift from the class names.
REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = {
    primitive_cls.__name__: primitive_cls
    for primitive_cls in (
        RestoreApiCall,
        RestoreImport,
        RestoreArgument,
        RestoreConfigField,
        RestoreColumn,
        RestoreTokenizerKwarg,
        RestoreMethod,
        RestoreReturnAccess,
    )
}
|
|
|
|
| |
| |
| |
# Maps a breakage primitive's name to the name of the repair primitive that
# undoes it (each value is a key of REPAIR_REGISTRY). The breakage names are
# presumably class names defined in the breakage module; verify there.
BREAKAGE_TO_REPAIR: dict[str, str] = dict(
    (
        ("RenameApiCall", "RestoreApiCall"),
        ("DeprecateImport", "RestoreImport"),
        ("ChangeArgumentSignature", "RestoreArgument"),
        ("ModifyConfigField", "RestoreConfigField"),
        ("RestructureDatasetSchema", "RestoreColumn"),
        ("ChangeTokenizerBehavior", "RestoreTokenizerKwarg"),
        ("RemoveDeprecatedMethod", "RestoreMethod"),
        ("ChangeReturnType", "RestoreReturnAccess"),
    )
)
|
|