"""Tests for breakage primitives, repair primitives, and task sampler."""
import pytest
from forgeenv.primitives.breakage_primitives import (
ChangeArgumentSignature,
ChangeReturnType,
ChangeTokenizerBehavior,
DeprecateImport,
ModifyConfigField,
PRIMITIVE_REGISTRY,
RemoveDeprecatedMethod,
RenameApiCall,
RestructureDatasetSchema,
parse_breakage_spec,
)
from forgeenv.primitives.repair_primitives import (
BREAKAGE_TO_REPAIR,
REPAIR_REGISTRY,
RestoreApiCall,
RestoreColumn,
RestoreImport,
RestoreMethod,
)
from forgeenv.tasks.task_sampler import TaskSampler
# Minimal HF-style training script used as the mutation target throughout
# this module. It deliberately contains an import line, a dataset column
# rename, config kwargs, and trainer calls so that every breakage primitive
# under test has something concrete to rewrite.
SAMPLE_SCRIPT = """
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
dataset = load_dataset("glue", "sst2")
dataset = dataset.rename_column("label", "labels")
args = TrainingArguments(num_train_epochs=3, report_to="none")
trainer = Trainer(model=model, args=args, train_dataset=dataset)
trainer.train()
result = trainer.evaluate()
"""
def test_rename_api_call_word_boundary():
    """Renaming an API call must hit whole identifiers only and be invertible."""
    breakage = RenameApiCall(old_name="evaluate", new_name="eval_model")
    mutated = breakage.apply(SAMPLE_SCRIPT)
    assert "eval_model" in mutated
    assert "trainer.evaluate" not in mutated
    # The paired repair primitive should undo the rename exactly.
    repair = RestoreApiCall(new_name="eval_model", old_name="evaluate")
    recovered = repair.apply(mutated)
    assert recovered.strip() == SAMPLE_SCRIPT.strip()
def test_deprecate_import():
    """DeprecateImport rewrites the import line; RestoreImport reverses it."""
    breakage = DeprecateImport(
        old_module="from transformers import",
        new_module="from transformers.legacy import",
    )
    mutated = breakage.apply(SAMPLE_SCRIPT)
    assert "transformers.legacy" in mutated
    repair = RestoreImport(
        new_module="from transformers.legacy import",
        old_module="from transformers import",
    )
    # Round trip must reproduce the original script byte-for-byte.
    assert repair.apply(mutated) == SAMPLE_SCRIPT
def test_restructure_dataset_string_replacement():
    """Column rename replaces every quoted occurrence and is invertible."""
    breakage = RestructureDatasetSchema(old_column="label", new_column="sentiment_label")
    mutated = breakage.apply(SAMPLE_SCRIPT)
    assert '"sentiment_label"' in mutated
    assert '"label"' not in mutated
    repair = RestoreColumn(new_column="sentiment_label", old_column="label")
    assert repair.apply(mutated) == SAMPLE_SCRIPT
def test_modify_config_field_changes_value():
    """ModifyConfigField rewrites the targeted kwarg value in place."""
    primitive = ModifyConfigField(
        config_class="TrainingArguments",
        field_name="num_train_epochs",
        new_value="999",
    )
    mutated = primitive.apply(SAMPLE_SCRIPT)
    assert "num_train_epochs=999" in mutated
def test_change_tokenizer_behavior_replaces_kwarg():
    """ChangeTokenizerBehavior swaps a tokenizer kwarg's value."""
    source = "tok = tokenizer(text, padding=True, truncation=True)"
    primitive = ChangeTokenizerBehavior(
        old_kwarg="padding",
        old_value="True",
        new_kwarg="padding",
        new_value='"max_length"',
    )
    assert 'padding="max_length"' in primitive.apply(source)
def test_remove_deprecated_method_marks_call():
    """RemoveDeprecatedMethod tags the call site; RestoreMethod reverts it."""
    breakage = RemoveDeprecatedMethod(
        class_name="Trainer", method_name="evaluate", replacement="evaluate_legacy"
    )
    mutated = breakage.apply(SAMPLE_SCRIPT)
    assert ".evaluate_DEPRECATED(" in mutated
    repair = RestoreMethod(method_name="evaluate")
    assert repair.apply(mutated) == SAMPLE_SCRIPT
def test_change_argument_signature_removes_kwarg():
    """ChangeArgumentSignature strips the removed kwarg from the call site."""
    primitive = ChangeArgumentSignature(
        function_name="TrainingArguments",
        removed_arg="report_to",
        added_arg="report_to",
        added_value='"none"',
    )
    mutated = primitive.apply(SAMPLE_SCRIPT)
    assert 'report_to="none"' not in mutated
def test_change_return_type_swaps_access():
    """ChangeReturnType rewrites how the call's result is accessed."""
    primitive = ChangeReturnType(
        function_name="evaluate",
        old_access="trainer.evaluate()",
        new_access="trainer.evaluate().metrics",
    )
    mutated = primitive.apply(SAMPLE_SCRIPT)
    assert "trainer.evaluate().metrics" in mutated
def test_parse_spec_round_trip():
    """A spec dict parses into the right primitive and serializes back."""
    spec = {
        "primitive_type": "RenameApiCall",
        "params": {"old_name": "evaluate", "new_name": "eval_model"},
    }
    parsed = parse_breakage_spec(spec)
    assert isinstance(parsed, RenameApiCall)
    assert parsed.old_name == "evaluate"
    # Serialization must preserve the primitive type tag.
    assert parsed.to_spec()["primitive_type"] == "RenameApiCall"
def test_parse_spec_unknown_raises():
    """An unregistered primitive_type must be rejected with ValueError."""
    with pytest.raises(ValueError):
        parse_breakage_spec({"primitive_type": "Bogus"})
def test_parse_spec_ignores_extra_kwargs():
    """LLMs hallucinate kwargs; we should silently filter them."""
    spec = {
        "primitive_type": "RenameApiCall",
        "params": {
            "old_name": "evaluate",
            "new_name": "eval_model",
            "hallucinated_kwarg": "ignore_me",
        },
    }
    # Parsing must succeed despite the unknown param.
    parsed = parse_breakage_spec(spec)
    assert isinstance(parsed, RenameApiCall)
def test_breakage_creates_actual_difference():
    """Applying a breakage must actually change the script text."""
    primitive = RenameApiCall(old_name="trainer.train", new_name="trainer.start_training")
    assert primitive.apply(SAMPLE_SCRIPT) != SAMPLE_SCRIPT
def test_all_8_primitives_registered():
    """The registry exposes exactly the eight known breakage primitives."""
    expected = {
        "ChangeArgumentSignature",
        "ChangeReturnType",
        "ChangeTokenizerBehavior",
        "DeprecateImport",
        "ModifyConfigField",
        "RemoveDeprecatedMethod",
        "RenameApiCall",
        "RestructureDatasetSchema",
    }
    assert set(PRIMITIVE_REGISTRY) == expected
def test_breakage_repair_registry_alignment():
    """Every breakage class should have a registered inverse."""
    for breakage, repair in BREAKAGE_TO_REPAIR.items():
        # Both sides of the mapping must point at registered classes.
        assert breakage in PRIMITIVE_REGISTRY
        assert repair in REPAIR_REGISTRY
def test_seed_corpus_has_at_least_10_scripts():
    """The default seed corpus ships enough fully-populated tasks."""
    tasks = TaskSampler().tasks
    assert len(tasks) >= 10
    for task in tasks:
        # Every task needs a script body and a non-empty identifier.
        assert task.script_content
        assert task.task_id
def test_task_sampler_categories_are_diverse():
    """The corpus should span several distinct task categories."""
    categories = TaskSampler().get_all_categories()
    assert len(categories) >= 3, f"Expected at least 3 distinct categories, got {categories}"
def test_task_sampler_difficulty_filter():
    """Sampling with a difficulty filter yields a matching task or None.

    The original inline comment claimed this exercised an *unknown*
    difficulty, but "easy" is an expected corpus difficulty; the test's
    real contract is: sample() must not crash, and when it returns a task
    that task must match the requested difficulty.
    """
    sampler = TaskSampler()
    # sample() may legitimately return None when no task matches the filter.
    task = sampler.sample(difficulty="easy")
    if task is not None:
        assert task.difficulty == "easy"
def test_task_sampler_get_by_id():
    """get_by_id returns the exact task object held by the sampler."""
    sampler = TaskSampler()
    if not sampler.tasks:
        pytest.skip("No tasks loaded")
    expected = sampler.tasks[0]
    # Identity (not just equality): the sampler must hand back the stored object.
    assert sampler.get_by_id(expected.task_id) is expected