| """Tests for breakage primitives, repair primitives, and task sampler."""
|
| import pytest
|
|
|
| from forgeenv.primitives.breakage_primitives import (
|
| ChangeArgumentSignature,
|
| ChangeReturnType,
|
| ChangeTokenizerBehavior,
|
| DeprecateImport,
|
| ModifyConfigField,
|
| PRIMITIVE_REGISTRY,
|
| RemoveDeprecatedMethod,
|
| RenameApiCall,
|
| RestructureDatasetSchema,
|
| parse_breakage_spec,
|
| )
|
| from forgeenv.primitives.repair_primitives import (
|
| BREAKAGE_TO_REPAIR,
|
| REPAIR_REGISTRY,
|
| RestoreApiCall,
|
| RestoreColumn,
|
| RestoreImport,
|
| RestoreMethod,
|
| )
|
| from forgeenv.tasks.task_sampler import TaskSampler
|
|
|
# Fixture training script fed to the breakage and repair primitives' ``apply``
# methods in the tests below. It references ``model`` without defining it, so in
# this file it is only used as source text to rewrite and compare, never executed.
SAMPLE_SCRIPT = """
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

dataset = load_dataset("glue", "sst2")
dataset = dataset.rename_column("label", "labels")

args = TrainingArguments(num_train_epochs=3, report_to="none")
trainer = Trainer(model=model, args=args, train_dataset=dataset)
trainer.train()
result = trainer.evaluate()
"""
|
|
|
|
|
def test_rename_api_call_word_boundary():
    """Renaming should not break out-of-context substrings.

    RenameApiCall must rewrite the standalone ``evaluate`` call while leaving
    identifiers that merely *contain* ``evaluate`` untouched, and
    RestoreApiCall must invert the rename.
    """
    b = RenameApiCall(old_name="evaluate", new_name="eval_model")
    broken = b.apply(SAMPLE_SCRIPT)
    assert "eval_model" in broken
    assert "trainer.evaluate" not in broken

    r = RestoreApiCall(new_name="eval_model", old_name="evaluate")
    restored = r.apply(broken)
    assert restored.strip() == SAMPLE_SCRIPT.strip()

    # SAMPLE_SCRIPT has no identifier that merely contains "evaluate", so the
    # boundary behaviour promised by the test name is exercised explicitly here.
    script = "score = reevaluate(evaluate_metrics)\nresult = trainer.evaluate()\n"
    broken_script = b.apply(script)
    assert "reevaluate(evaluate_metrics)" in broken_script
    assert "trainer.eval_model()" in broken_script
|
|
|
|
|
def test_deprecate_import():
    """DeprecateImport moves the import to a legacy path; RestoreImport undoes it exactly."""
    current = "from transformers import"
    legacy = "from transformers.legacy import"

    mutated = DeprecateImport(old_module=current, new_module=legacy).apply(SAMPLE_SCRIPT)
    assert "transformers.legacy" in mutated

    repaired = RestoreImport(new_module=legacy, old_module=current).apply(mutated)
    assert repaired == SAMPLE_SCRIPT
|
|
|
|
|
def test_restructure_dataset_string_replacement():
    """Quoted column names are swapped by the breakage and swapped back by the repair."""
    breakage = RestructureDatasetSchema(old_column="label", new_column="sentiment_label")
    mutated = breakage.apply(SAMPLE_SCRIPT)
    assert '"sentiment_label"' in mutated
    assert '"label"' not in mutated

    repair = RestoreColumn(new_column="sentiment_label", old_column="label")
    assert repair.apply(mutated) == SAMPLE_SCRIPT
|
|
|
|
|
def test_modify_config_field_changes_value():
    """ModifyConfigField replaces the field's value instead of adding a second one."""
    b = ModifyConfigField(
        config_class="TrainingArguments",
        field_name="num_train_epochs",
        new_value="999",
    )
    broken = b.apply(SAMPLE_SCRIPT)
    assert "num_train_epochs=999" in broken
    # Without this, appending a duplicate kwarg (leaving the old value in
    # place) would also satisfy the check above.
    assert "num_train_epochs=3" not in broken
    # Unrelated kwargs of the same call must survive the edit.
    assert 'report_to="none"' in broken
|
|
|
|
|
def test_change_tokenizer_behavior_replaces_kwarg():
    """ChangeTokenizerBehavior swaps ``padding=True`` for the new value and nothing else."""
    script = "tok = tokenizer(text, padding=True, truncation=True)"
    b = ChangeTokenizerBehavior(
        old_kwarg="padding",
        old_value="True",
        new_kwarg="padding",
        new_value='"max_length"',
    )
    broken = b.apply(script)
    assert 'padding="max_length"' in broken
    # "Replaces" means the old setting is gone ...
    assert "padding=True" not in broken
    # ... while other kwargs sharing the same value are left alone.
    assert "truncation=True" in broken
|
|
|
|
|
def test_remove_deprecated_method_marks_call():
    """The call site gains a ``_DEPRECATED`` suffix, which RestoreMethod reverses."""
    mutated = RemoveDeprecatedMethod(
        class_name="Trainer",
        method_name="evaluate",
        replacement="evaluate_legacy",
    ).apply(SAMPLE_SCRIPT)
    assert ".evaluate_DEPRECATED(" in mutated

    repaired = RestoreMethod(method_name="evaluate").apply(mutated)
    assert repaired == SAMPLE_SCRIPT
|
|
|
|
|
def test_change_argument_signature_removes_kwarg():
    """The breakage drops the ``report_to="none"`` kwarg from the TrainingArguments call."""
    # NOTE(review): removed_arg and added_arg share the same name here, yet the
    # kwarg is expected to be absent afterwards; confirm this setup is intended.
    signature_change = ChangeArgumentSignature(
        function_name="TrainingArguments",
        removed_arg="report_to",
        added_arg="report_to",
        added_value='"none"',
    )
    mutated = signature_change.apply(SAMPLE_SCRIPT)
    assert 'report_to="none"' not in mutated
|
|
|
|
|
def test_change_return_type_swaps_access():
    """Direct use of ``trainer.evaluate()`` is rewritten to go through ``.metrics``."""
    swap = ChangeReturnType(
        function_name="evaluate",
        old_access="trainer.evaluate()",
        new_access="trainer.evaluate().metrics",
    )
    assert "trainer.evaluate().metrics" in swap.apply(SAMPLE_SCRIPT)
|
|
|
|
|
def test_parse_spec_round_trip():
    """A spec parses into the right primitive, and ``to_spec()`` parses back to an equal one."""
    spec = {
        "primitive_type": "RenameApiCall",
        "params": {"old_name": "evaluate", "new_name": "eval_model"},
    }
    primitive = parse_breakage_spec(spec)
    assert isinstance(primitive, RenameApiCall)
    assert primitive.old_name == "evaluate"
    assert primitive.new_name == "eval_model"

    emitted = primitive.to_spec()
    assert emitted["primitive_type"] == "RenameApiCall"

    # Close the loop: the serialized form must parse back to an equivalent
    # primitive, otherwise this is not actually a round trip.
    reparsed = parse_breakage_spec(emitted)
    assert isinstance(reparsed, RenameApiCall)
    assert reparsed.old_name == primitive.old_name
    assert reparsed.new_name == primitive.new_name
|
|
|
|
|
def test_parse_spec_unknown_raises():
    """An unregistered ``primitive_type`` is rejected with ValueError."""
    bogus_spec = {"primitive_type": "Bogus"}
    with pytest.raises(ValueError):
        parse_breakage_spec(bogus_spec)
|
|
|
|
|
def test_parse_spec_ignores_extra_kwargs():
    """LLMs hallucinate kwargs; we should silently filter them."""
    spec = {
        "primitive_type": "RenameApiCall",
        "params": {
            "old_name": "evaluate",
            "new_name": "eval_model",
            "hallucinated_kwarg": "ignore_me",
        },
    }
    primitive = parse_breakage_spec(spec)
    assert isinstance(primitive, RenameApiCall)
    # Filtering must keep the legitimate parameters intact ...
    assert primitive.old_name == "evaluate"
    assert primitive.new_name == "eval_model"
    # ... and must not leak the unknown one onto the instance.
    assert not hasattr(primitive, "hallucinated_kwarg")
|
|
|
|
|
def test_breakage_creates_actual_difference():
    """Applying a breakage must change the script, specifically at the targeted call."""
    b = RenameApiCall(old_name="trainer.train", new_name="trainer.start_training")
    broken = b.apply(SAMPLE_SCRIPT)
    assert broken != SAMPLE_SCRIPT
    # Inequality alone would accept any incidental edit (e.g. whitespace);
    # pin the intended rewrite.
    assert "trainer.start_training()" in broken
    assert "trainer.train()" not in broken
|
|
|
|
|
def test_all_8_primitives_registered():
    """PRIMITIVE_REGISTRY exposes exactly the eight breakage primitive names."""
    expected_names = frozenset(
        [
            "RenameApiCall",
            "DeprecateImport",
            "ChangeArgumentSignature",
            "ModifyConfigField",
            "RestructureDatasetSchema",
            "ChangeTokenizerBehavior",
            "RemoveDeprecatedMethod",
            "ChangeReturnType",
        ]
    )
    registered = set(PRIMITIVE_REGISTRY)
    assert registered == expected_names
|
|
|
|
|
def test_breakage_repair_registry_alignment():
    """Every breakage -> repair mapping must name registered classes on both sides.

    This checks the entries present in ``BREAKAGE_TO_REPAIR``; it does not
    require every registered breakage to have a mapped repair.
    """
    # NOTE(review): if every breakage is meant to have an inverse, also assert
    # set(PRIMITIVE_REGISTRY) <= set(BREAKAGE_TO_REPAIR).
    # Guard against the loop below passing vacuously on an empty mapping.
    assert BREAKAGE_TO_REPAIR, "BREAKAGE_TO_REPAIR is empty; nothing to check"
    for breakage_name, repair_name in BREAKAGE_TO_REPAIR.items():
        assert breakage_name in PRIMITIVE_REGISTRY, (
            f"unregistered breakage {breakage_name!r}"
        )
        assert repair_name in REPAIR_REGISTRY, (
            f"{breakage_name!r} maps to unregistered repair {repair_name!r}"
        )
|
|
|
|
|
def test_seed_corpus_has_at_least_10_scripts():
    """The bundled seed corpus loads at least ten tasks, each with an id and a script."""
    tasks = TaskSampler().tasks
    assert len(tasks) >= 10
    for task in tasks:
        assert task.script_content
        assert task.task_id
|
|
|
|
|
def test_task_sampler_categories_are_diverse():
    """The loaded corpus spans at least three distinct categories."""
    found = TaskSampler().get_all_categories()
    assert len(found) >= 3, f"Expected at least 3 distinct categories, got {found}"
|
|
|
|
|
def test_task_sampler_difficulty_filter():
    """``sample(difficulty=...)`` returns a task of exactly that difficulty.

    Skips (rather than silently passing) when the corpus has no easy tasks,
    matching how ``test_task_sampler_get_by_id`` handles an empty corpus.
    """
    sampler = TaskSampler()
    if not any(t.difficulty == "easy" for t in sampler.tasks):
        pytest.skip("No easy tasks loaded")

    task = sampler.sample(difficulty="easy")
    # Easy tasks exist, so a None result means the filter is broken.
    assert task is not None
    assert task.difficulty == "easy"
|
|
|
|
|
def test_task_sampler_get_by_id():
    """``get_by_id`` returns the very same object held in ``sampler.tasks``."""
    sampler = TaskSampler()
    if not sampler.tasks:
        pytest.skip("No tasks loaded")
    expected = sampler.tasks[0]
    assert sampler.get_by_id(expected.task_id) is expected
|
|
|