| """Tests for breakage primitives, repair primitives, and task sampler.""" |
| import pytest |
|
|
| from forgeenv.primitives.breakage_primitives import ( |
| ChangeArgumentSignature, |
| ChangeReturnType, |
| ChangeTokenizerBehavior, |
| DeprecateImport, |
| ModifyConfigField, |
| PRIMITIVE_REGISTRY, |
| RemoveDeprecatedMethod, |
| RenameApiCall, |
| RestructureDatasetSchema, |
| parse_breakage_spec, |
| ) |
| from forgeenv.primitives.repair_primitives import ( |
| BREAKAGE_TO_REPAIR, |
| REPAIR_REGISTRY, |
| RestoreApiCall, |
| RestoreColumn, |
| RestoreImport, |
| RestoreMethod, |
| ) |
| from forgeenv.tasks.task_sampler import TaskSampler |
|
|
# Small Hugging Face-style training script used as the shared input text for
# the primitive tests in this file. The tests only pass it to ``apply`` as a
# string. It starts with a blank line and ends with a trailing newline.
SAMPLE_SCRIPT = (
    "\n"
    "from transformers import Trainer, TrainingArguments\n"
    "from datasets import load_dataset\n"
    "\n"
    'dataset = load_dataset("glue", "sst2")\n'
    'dataset = dataset.rename_column("label", "labels")\n'
    "\n"
    'args = TrainingArguments(num_train_epochs=3, report_to="none")\n'
    "trainer = Trainer(model=model, args=args, train_dataset=dataset)\n"
    "trainer.train()\n"
    "result = trainer.evaluate()\n"
)
|
|
|
|
def test_rename_api_call_word_boundary():
    """Rename ``evaluate`` to ``eval_model`` in SAMPLE_SCRIPT, then undo it.

    The breakage must introduce the new name and remove the
    ``trainer.evaluate`` call. The repair must give back the original script,
    compared after stripping leading/trailing whitespace.
    """
    # NOTE(review): SAMPLE_SCRIPT contains no identifier that merely embeds
    # "evaluate", so word-boundary handling is only covered indirectly by the
    # round-trip equality below.
    renamed = RenameApiCall(old_name="evaluate", new_name="eval_model").apply(
        SAMPLE_SCRIPT
    )
    assert "eval_model" in renamed
    assert "trainer.evaluate" not in renamed

    undo = RestoreApiCall(new_name="eval_model", old_name="evaluate")
    assert undo.apply(renamed).strip() == SAMPLE_SCRIPT.strip()
|
|
|
|
def test_deprecate_import():
    """Move the transformers import to ``transformers.legacy`` and restore it.

    The restored script is expected to equal SAMPLE_SCRIPT exactly.
    """
    current_prefix = "from transformers import"
    legacy_prefix = "from transformers.legacy import"

    deprecated = DeprecateImport(
        old_module=current_prefix, new_module=legacy_prefix
    ).apply(SAMPLE_SCRIPT)
    assert "transformers.legacy" in deprecated

    repaired = RestoreImport(
        new_module=legacy_prefix, old_module=current_prefix
    ).apply(deprecated)
    assert repaired == SAMPLE_SCRIPT
|
|
|
|
def test_restructure_dataset_string_replacement():
    """Swap the quoted ``"label"`` column for ``"sentiment_label"`` and back.

    After the breakage the quoted old column name must be gone and the quoted
    new one present. The repair must reproduce SAMPLE_SCRIPT exactly.
    """
    old_col, new_col = "label", "sentiment_label"

    mutated = RestructureDatasetSchema(
        old_column=old_col, new_column=new_col
    ).apply(SAMPLE_SCRIPT)
    assert '"sentiment_label"' in mutated
    assert '"label"' not in mutated

    repaired = RestoreColumn(new_column=new_col, old_column=old_col).apply(mutated)
    assert repaired == SAMPLE_SCRIPT
|
|
|
|
def test_modify_config_field_changes_value():
    """Setting ``num_train_epochs`` to 999 must show up in the broken script."""
    field_change = {
        "config_class": "TrainingArguments",
        "field_name": "num_train_epochs",
        "new_value": "999",
    }
    mutated = ModifyConfigField(**field_change).apply(SAMPLE_SCRIPT)
    assert "num_train_epochs=999" in mutated
|
|
|
|
def test_change_tokenizer_behavior_replaces_kwarg():
    """``padding=True`` in a tokenizer call becomes ``padding="max_length"``."""
    tokenizer_call = "tok = tokenizer(text, padding=True, truncation=True)"
    kwarg_swap = {
        "old_kwarg": "padding",
        "old_value": "True",
        "new_kwarg": "padding",
        "new_value": '"max_length"',
    }
    mutated = ChangeTokenizerBehavior(**kwarg_swap).apply(tokenizer_call)
    assert 'padding="max_length"' in mutated
|
|
|
|
def test_remove_deprecated_method_marks_call():
    """Removing ``evaluate`` tags the call site; RestoreMethod reverts it.

    The broken script must contain ``.evaluate_DEPRECATED(`` and the repaired
    script must equal SAMPLE_SCRIPT exactly.
    """
    removal = RemoveDeprecatedMethod(
        class_name="Trainer",
        method_name="evaluate",
        replacement="evaluate_legacy",
    )
    marked = removal.apply(SAMPLE_SCRIPT)
    assert ".evaluate_DEPRECATED(" in marked

    repaired = RestoreMethod(method_name="evaluate").apply(marked)
    assert repaired == SAMPLE_SCRIPT
|
|
|
|
def test_change_argument_signature_removes_kwarg():
    """The ``report_to="none"`` kwarg must be absent after the breakage."""
    signature_change = {
        "function_name": "TrainingArguments",
        "removed_arg": "report_to",
        "added_arg": "report_to",
        "added_value": '"none"',
    }
    mutated = ChangeArgumentSignature(**signature_change).apply(SAMPLE_SCRIPT)
    assert 'report_to="none"' not in mutated
|
|
|
|
def test_change_return_type_swaps_access():
    """``trainer.evaluate()`` must be rewritten to ``trainer.evaluate().metrics``."""
    access_change = {
        "function_name": "evaluate",
        "old_access": "trainer.evaluate()",
        "new_access": "trainer.evaluate().metrics",
    }
    mutated = ChangeReturnType(**access_change).apply(SAMPLE_SCRIPT)
    assert "trainer.evaluate().metrics" in mutated
|
|
|
|
def test_parse_spec_round_trip():
    """A RenameApiCall spec parses to that class and serialises back by name."""
    rename_params = {"old_name": "evaluate", "new_name": "eval_model"}
    parsed = parse_breakage_spec(
        {"primitive_type": "RenameApiCall", "params": rename_params}
    )

    assert isinstance(parsed, RenameApiCall)
    assert parsed.old_name == "evaluate"

    serialised = parsed.to_spec()
    assert serialised["primitive_type"] == "RenameApiCall"
|
|
|
|
def test_parse_spec_unknown_raises():
    """An unrecognised ``primitive_type`` must raise ``ValueError``."""
    bogus_spec = {"primitive_type": "Bogus"}
    with pytest.raises(ValueError):
        parse_breakage_spec(bogus_spec)
|
|
|
|
def test_parse_spec_ignores_extra_kwargs():
    """Unknown params in a spec must not stop it from parsing.

    LLM-produced specs can carry hallucinated kwargs; parsing should still
    yield a RenameApiCall instead of failing on the extra key.
    """
    params_with_noise = {
        "old_name": "evaluate",
        "new_name": "eval_model",
        "hallucinated_kwarg": "ignore_me",
    }
    parsed = parse_breakage_spec(
        {"primitive_type": "RenameApiCall", "params": params_with_noise}
    )
    assert isinstance(parsed, RenameApiCall)
|
|
|
|
def test_breakage_creates_actual_difference():
    """Renaming ``trainer.train`` must change the script text."""
    rename = RenameApiCall(
        old_name="trainer.train",
        new_name="trainer.start_training",
    )
    mutated = rename.apply(SAMPLE_SCRIPT)
    assert mutated != SAMPLE_SCRIPT
|
|
|
|
def test_all_8_primitives_registered():
    """PRIMITIVE_REGISTRY must hold exactly the eight known primitive names."""
    registered_names = set(PRIMITIVE_REGISTRY)
    assert registered_names == {
        "RenameApiCall",
        "DeprecateImport",
        "ChangeArgumentSignature",
        "ModifyConfigField",
        "RestructureDatasetSchema",
        "ChangeTokenizerBehavior",
        "RemoveDeprecatedMethod",
        "ChangeReturnType",
    }
|
|
|
|
def test_breakage_repair_registry_alignment():
    """Each BREAKAGE_TO_REPAIR entry must point at registered names on both sides.

    Only the pairs present in BREAKAGE_TO_REPAIR are checked: the key must be
    in PRIMITIVE_REGISTRY and the value in REPAIR_REGISTRY.
    """
    for breakage, repair in BREAKAGE_TO_REPAIR.items():
        assert breakage in PRIMITIVE_REGISTRY
        assert repair in REPAIR_REGISTRY
|
|
|
|
def test_seed_corpus_has_at_least_10_scripts():
    """A default TaskSampler must load at least ten tasks.

    Every task needs a truthy ``script_content`` and a truthy ``task_id``.
    """
    corpus = TaskSampler()
    assert len(corpus.tasks) >= 10
    assert all(task.script_content for task in corpus.tasks)
    assert all(task.task_id for task in corpus.tasks)
|
|
|
|
def test_task_sampler_categories_are_diverse():
    """The default sampler must report at least three categories."""
    categories = TaskSampler().get_all_categories()
    category_count = len(categories)
    assert category_count >= 3, f"Expected at least 3 distinct categories, got {categories}"
|
|
|
|
def test_task_sampler_difficulty_filter():
    """Sampling with ``difficulty="easy"`` must only return easy tasks.

    When the sampler returns ``None`` the test is skipped rather than passed,
    so a run that never checked the filter is visible in the report. This
    matches how ``test_task_sampler_get_by_id`` handles an empty sampler.
    """
    sampler = TaskSampler()

    task = sampler.sample(difficulty="easy")
    if task is None:
        # Nothing was sampled, so there is no difficulty to verify.
        pytest.skip("No 'easy' task available to sample")
    assert task.difficulty == "easy"
|
|
|
|
def test_task_sampler_get_by_id():
    """``get_by_id`` must return the very same task object held in ``tasks``.

    Skipped when the sampler has no tasks loaded.
    """
    sampler = TaskSampler()
    if not sampler.tasks:
        pytest.skip("No tasks loaded")

    expected_task = sampler.tasks[0]
    assert sampler.get_by_id(expected_task.task_id) is expected_task
|
|