forgeenv-source / forgeenv /primitives /repair_primitives.py
akhiilll's picture
forgeenv source snapshot for training job
a15535e verified
"""Repair primitives — direct inverses of the 8 breakage primitives.
Used during warm-start data generation: for every (script, breakage)
pair we know the canonical repair, so we can write SFT pairs.
These are also useful for unit-testing the breakage primitives:
apply(breakage) then apply(repair) should be (close to) the identity.
"""
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@dataclass
class RepairPrimitive(ABC):
    """Abstract base class shared by every repair primitive.

    ``category``, ``name`` and ``description`` are excluded from the generated
    ``__init__`` (``init=False``); concrete subclasses in this module assign
    them in ``__post_init__``.
    """

    category: str = field(default="generic", init=False)
    name: str = field(default="RepairPrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Return a copy of `script` with the corresponding breakage undone."""

    def to_spec(self) -> dict:
        """Describe this primitive as a plain dict.

        Returns:
            ``{"primitive_type": <class name>, "category": <category>,
            "params": <result of _get_params()>}``.
        """
        spec: dict = {}
        spec["primitive_type"] = type(self).__name__
        spec["category"] = self.category
        spec["params"] = self._get_params()
        return spec

    @abstractmethod
    def _get_params(self) -> dict:
        """Return the JSON-serializable constructor parameters."""
@dataclass
class RestoreApiCall(RepairPrimitive):
    """Rename every standalone occurrence of ``new_name`` back to ``old_name``.

    Listed in ``BREAKAGE_TO_REPAIR`` as the inverse of ``RenameApiCall``.
    """

    new_name: str = ""
    old_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreApiCall"
        self.description = f"Rename {self.new_name} -> {self.old_name}"

    def apply(self, script: str) -> str:
        """Return `script` with ``new_name`` replaced by ``old_name``.

        A no-op when ``new_name`` is empty.
        """
        if not self.new_name:
            return script
        name = self.new_name
        # Word boundaries only where the name itself has a word-character edge:
        # `foo` must not match inside `my_foo`/`foobar`, but a name such as
        # `.generate` must still match right after `model`.
        lead = r"(?<!\w)" if re.match(r"\w", name[0]) else ""
        trail = r"(?!\w)" if re.match(r"\w", name[-1]) else ""
        pattern = re.compile(f"{lead}{re.escape(name)}{trail}")
        replacement = self.old_name
        # Callable replacement inserts `old_name` literally; a plain string
        # would have its backslashes / group references interpreted by re.
        return pattern.sub(lambda _match: replacement, script)

    def _get_params(self) -> dict:
        return {"new_name": self.new_name, "old_name": self.old_name}
@dataclass
class RestoreImport(RepairPrimitive):
    """Replace every occurrence of ``new_module`` with ``old_module``.

    Listed in ``BREAKAGE_TO_REPAIR`` as the inverse of ``DeprecateImport``.
    This is a plain substring replacement over the whole script.
    """

    new_module: str = ""
    old_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "RestoreImport"
        self.description = f"Restore import {self.new_module} -> {self.old_module}"

    def apply(self, script: str) -> str:
        """Return `script` with ``new_module`` replaced by ``old_module``.

        A no-op when ``new_module`` is empty: ``str.replace("", x)`` would
        otherwise insert ``x`` between every character of the script.
        """
        if not self.new_module:
            return script
        return script.replace(self.new_module, self.old_module)

    def _get_params(self) -> dict:
        return {"new_module": self.new_module, "old_module": self.old_module}
@dataclass
class RestoreArgument(RepairPrimitive):
    """Re-add a removed argument to a function call."""

    function_name: str = ""
    arg_name: str = ""
    arg_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreArgument"
        self.description = (
            f"Add {self.arg_name}={self.arg_value} to {self.function_name}()"
        )

    def apply(self, script: str) -> str:
        """Insert ``arg_name=arg_value`` into the first call of ``function_name``.

        The keyword is placed right after the call's opening parenthesis, so
        ``f(x=1)`` becomes ``f(arg=val, x=1)`` and ``f()`` becomes
        ``f(arg=val, )``. A no-op when ``function_name`` or ``arg_name`` is
        empty or no call is found.

        NOTE: if the call passes positional arguments, ``f(a)`` becomes
        ``f(arg=val, a)``, which Python rejects ("positional argument follows
        keyword argument").
        """
        if not self.function_name or not self.arg_name:
            return script
        name = self.function_name
        # Anchor on a word boundary only when the name starts with a word
        # character: `Trainer` must not match inside `Seq2SeqTrainer(`, while
        # `.generate` should still match after `model`.
        lead = r"(?<!\w)" if re.match(r"\w", name[0]) else ""
        pattern = re.compile(rf"{lead}({re.escape(name)}\s*\()(\s*)")
        kwarg = f"{self.arg_name}={self.arg_value}, "
        # Callable replacement inserts `kwarg` literally (no backslash or
        # group-reference processing of arg_value).
        return pattern.sub(
            lambda m: f"{m.group(1)}{kwarg}{m.group(2)}", script, count=1
        )

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "arg_name": self.arg_name,
            "arg_value": self.arg_value,
        }
@dataclass
class RestoreConfigField(RepairPrimitive):
    """Reset the value assigned to ``field_name`` back to ``old_value``.

    Listed in ``BREAKAGE_TO_REPAIR`` as the inverse of ``ModifyConfigField``.
    """

    field_name: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "RestoreConfigField"
        self.description = f"Restore {self.field_name}={self.old_value}"

    def apply(self, script: str) -> str:
        """Rewrite every ``field_name = <value>`` to ``field_name = old_value``.

        ``<value>`` is the text up to the next ``,``, ``)`` or newline. All
        matching assignments / keyword arguments are rewritten. A no-op when
        ``field_name`` is empty.
        """
        if not self.field_name:
            return script
        name = self.field_name
        # Leading boundary so field `steps` does not rewrite `max_steps=...`;
        # only applied when the name starts with a word character.
        lead = r"(?<!\w)" if re.match(r"\w", name[0]) else ""
        # `=(?!=)` keeps comparisons such as `cfg.lr == 1e-3` untouched.
        pattern = re.compile(
            rf"{lead}({re.escape(name)}\s*=(?!=)\s*)([^,)\n]+)"
        )
        new_value = self.old_value
        # Callable replacement inserts old_value literally (no escape processing).
        return pattern.sub(lambda m: m.group(1) + new_value, script)

    def _get_params(self) -> dict:
        return {"field_name": self.field_name, "old_value": self.old_value}
@dataclass
class RestoreColumn(RepairPrimitive):
    """Rename quoted occurrences of ``new_column`` back to ``old_column``.

    Listed in ``BREAKAGE_TO_REPAIR`` as the inverse of
    ``RestructureDatasetSchema``.
    """

    new_column: str = ""
    old_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestoreColumn"
        self.description = f"Rename column {self.new_column} -> {self.old_column}"

    def apply(self, script: str) -> str:
        """Replace ``"new_column"`` and ``'new_column'`` string literals.

        Double-quoted literals are rewritten first, then single-quoted ones.
        A no-op when ``new_column`` is empty; otherwise every ``""`` / ``''``
        literal in the script would become the old column name.
        """
        if not self.new_column:
            return script
        for quote in ('"', "'"):
            script = script.replace(
                f"{quote}{self.new_column}{quote}",
                f"{quote}{self.old_column}{quote}",
            )
        return script

    def _get_params(self) -> dict:
        return {"new_column": self.new_column, "old_column": self.old_column}
@dataclass
class RestoreTokenizerKwarg(RepairPrimitive):
    """Swap ``new_kwarg=new_value`` back to ``old_kwarg=old_value``.

    Listed in ``BREAKAGE_TO_REPAIR`` as the inverse of
    ``ChangeTokenizerBehavior``.
    """

    new_kwarg: str = ""
    new_value: str = ""
    old_kwarg: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "RestoreTokenizerKwarg"
        self.description = (
            f"Restore tokenizer {self.new_kwarg}={self.new_value} -> "
            f"{self.old_kwarg}={self.old_value}"
        )

    def apply(self, script: str) -> str:
        """Replace every ``new_kwarg = new_value`` with ``old_kwarg=old_value``.

        A no-op when ``new_kwarg`` is empty.
        """
        if not self.new_kwarg:
            return script
        kwarg, value = self.new_kwarg, self.new_value
        # Boundaries keep `max_length` from matching inside `model_max_length`
        # and `max_length=1` from matching the prefix of `max_length=128`.
        # They are only added at edges that are word characters.
        lead = r"(?<!\w)" if re.match(r"\w", kwarg[0]) else ""
        trail = r"(?!\w)" if value and re.match(r"\w", value[-1]) else ""
        pattern = re.compile(
            rf"{lead}{re.escape(kwarg)}\s*=\s*{re.escape(value)}{trail}"
        )
        replacement = f"{self.old_kwarg}={self.old_value}"
        # Callable replacement inserts the text literally (no escape processing).
        return pattern.sub(lambda _match: replacement, script)

    def _get_params(self) -> dict:
        return {
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
        }
@dataclass
class RestoreMethod(RepairPrimitive):
    """Undo a method deprecation by dropping the ``_DEPRECATED`` suffix.

    Rewrites call sites ``.<method_name>_DEPRECATED(`` to ``.<method_name>(``.
    """

    method_name: str = ""

    def __post_init__(self) -> None:
        self.name = "RestoreMethod"
        self.category = "api_drift"
        self.description = f"Un-deprecate .{self.method_name}()"

    def apply(self, script: str) -> str:
        """Return `script` with deprecated-suffix calls restored (no-op if unnamed)."""
        method = self.method_name
        if method:
            deprecated_call = "." + method + "_DEPRECATED("
            restored_call = "." + method + "("
            script = script.replace(deprecated_call, restored_call)
        return script

    def _get_params(self) -> dict:
        return dict(method_name=self.method_name)
@dataclass
class RestoreReturnAccess(RepairPrimitive):
    """Turn every ``new_access`` substring back into ``old_access``.

    Listed in ``BREAKAGE_TO_REPAIR`` as the inverse of ``ChangeReturnType``.
    """

    new_access: str = ""
    old_access: str = ""

    def __post_init__(self) -> None:
        self.name = "RestoreReturnAccess"
        self.category = "api_drift"
        self.description = f"Restore return-access {self.new_access} -> {self.old_access}"

    def apply(self, script: str) -> str:
        """Return `script` with ``new_access`` replaced (unchanged if it is empty)."""
        return (
            script.replace(self.new_access, self.old_access)
            if self.new_access
            else script
        )

    def _get_params(self) -> dict:
        return dict(new_access=self.new_access, old_access=self.old_access)
# Class name -> repair primitive class, one entry per primitive defined above.
# Keys equal each class's `__name__`, which is also the "primitive_type"
# emitted by `RepairPrimitive.to_spec()`.
REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = {
    repair_cls.__name__: repair_cls
    for repair_cls in (
        RestoreApiCall,
        RestoreImport,
        RestoreArgument,
        RestoreConfigField,
        RestoreColumn,
        RestoreTokenizerKwarg,
        RestoreMethod,
        RestoreReturnAccess,
    )
}
# Breakage-primitive class name -> name of the repair-primitive class that
# undoes it. Consumed by the warm-start pair generator and by the demo /
# repair-library curator.
BREAKAGE_TO_REPAIR: dict[str, str] = dict(
    RenameApiCall="RestoreApiCall",
    DeprecateImport="RestoreImport",
    ChangeArgumentSignature="RestoreArgument",
    ModifyConfigField="RestoreConfigField",
    RestructureDatasetSchema="RestoreColumn",
    ChangeTokenizerBehavior="RestoreTokenizerKwarg",
    RemoveDeprecatedMethod="RestoreMethod",
    ChangeReturnType="RestoreReturnAccess",
)