"""Tests for breakage primitives, repair primitives, and task sampler."""

import pytest

from forgeenv.primitives.breakage_primitives import (
    ChangeArgumentSignature,
    ChangeReturnType,
    ChangeTokenizerBehavior,
    DeprecateImport,
    ModifyConfigField,
    PRIMITIVE_REGISTRY,
    RemoveDeprecatedMethod,
    RenameApiCall,
    RestructureDatasetSchema,
    parse_breakage_spec,
)
from forgeenv.primitives.repair_primitives import (
    BREAKAGE_TO_REPAIR,
    REPAIR_REGISTRY,
    RestoreApiCall,
    RestoreColumn,
    RestoreImport,
    RestoreMethod,
)
from forgeenv.tasks.task_sampler import TaskSampler

# Shared fixture text: in this module it is only ever passed to the
# primitives' ``apply()`` methods and compared as a string.
SAMPLE_SCRIPT = """
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
dataset = load_dataset("glue", "sst2")
dataset = dataset.rename_column("label", "labels")
args = TrainingArguments(num_train_epochs=3, report_to="none")
trainer = Trainer(model=model, args=args, train_dataset=dataset)
trainer.train()
result = trainer.evaluate()
"""


def test_rename_api_call_word_boundary():
    """Renaming ``evaluate`` rewrites the call and the inverse undoes it."""
    rename = RenameApiCall(old_name="evaluate", new_name="eval_model")
    mutated = rename.apply(SAMPLE_SCRIPT)
    assert "eval_model" in mutated
    assert "trainer.evaluate" not in mutated

    # The inverse is compared modulo surrounding whitespace only.
    undo = RestoreApiCall(new_name="eval_model", old_name="evaluate")
    assert undo.apply(mutated).strip() == SAMPLE_SCRIPT.strip()


def test_deprecate_import():
    """Swapping the import prefix is exactly reversible via RestoreImport."""
    original_prefix = "from transformers import"
    legacy_prefix = "from transformers.legacy import"

    deprecate = DeprecateImport(
        old_module=original_prefix,
        new_module=legacy_prefix,
    )
    mutated = deprecate.apply(SAMPLE_SCRIPT)
    assert "transformers.legacy" in mutated

    undo = RestoreImport(
        new_module=legacy_prefix,
        old_module=original_prefix,
    )
    assert undo.apply(mutated) == SAMPLE_SCRIPT


def test_restructure_dataset_string_replacement():
    """Renaming a column rewrites the quoted name; RestoreColumn reverts it."""
    restructure = RestructureDatasetSchema(
        old_column="label", new_column="sentiment_label"
    )
    mutated = restructure.apply(SAMPLE_SCRIPT)
    assert '"sentiment_label"' in mutated
    assert '"label"' not in mutated

    undo = RestoreColumn(new_column="sentiment_label", old_column="label")
    assert undo.apply(mutated) == SAMPLE_SCRIPT


def test_modify_config_field_changes_value():
    """The targeted config keyword ends up with the new value."""
    mutated = ModifyConfigField(
        config_class="TrainingArguments",
        field_name="num_train_epochs",
        new_value="999",
    ).apply(SAMPLE_SCRIPT)
    assert "num_train_epochs=999" in mutated


def test_change_tokenizer_behavior_replaces_kwarg():
    """A tokenizer keyword's value is replaced in the call text."""
    tokenizer_call = "tok = tokenizer(text, padding=True, truncation=True)"
    change = ChangeTokenizerBehavior(
        old_kwarg="padding",
        old_value="True",
        new_kwarg="padding",
        new_value='"max_length"',
    )
    assert 'padding="max_length"' in change.apply(tokenizer_call)


def test_remove_deprecated_method_marks_call():
    """The method call gets a ``_DEPRECATED`` marker that RestoreMethod removes."""
    removal = RemoveDeprecatedMethod(
        class_name="Trainer", method_name="evaluate", replacement="evaluate_legacy"
    )
    mutated = removal.apply(SAMPLE_SCRIPT)
    assert ".evaluate_DEPRECATED(" in mutated

    undo = RestoreMethod(method_name="evaluate")
    assert undo.apply(mutated) == SAMPLE_SCRIPT


def test_change_argument_signature_removes_kwarg():
    """After the signature change the ``report_to="none"`` kwarg is gone."""
    change = ChangeArgumentSignature(
        function_name="TrainingArguments",
        removed_arg="report_to",
        added_arg="report_to",
        added_value='"none"',
    )
    mutated = change.apply(SAMPLE_SCRIPT)
    assert 'report_to="none"' not in mutated


def test_change_return_type_swaps_access():
    """The old access expression is replaced by the new one."""
    mutated = ChangeReturnType(
        function_name="evaluate",
        old_access="trainer.evaluate()",
        new_access="trainer.evaluate().metrics",
    ).apply(SAMPLE_SCRIPT)
    assert "trainer.evaluate().metrics" in mutated


def test_parse_spec_round_trip():
    """A spec dict parses to the right primitive and serialises back."""
    spec = {
        "primitive_type": "RenameApiCall",
        "params": {"old_name": "evaluate", "new_name": "eval_model"},
    }
    parsed = parse_breakage_spec(spec)
    assert isinstance(parsed, RenameApiCall)
    assert parsed.old_name == "evaluate"
    assert parsed.to_spec()["primitive_type"] == "RenameApiCall"


def test_parse_spec_unknown_raises():
    """An unrecognised primitive type is rejected with ValueError."""
    unknown_spec = {"primitive_type": "Bogus"}
    with pytest.raises(ValueError):
        parse_breakage_spec(unknown_spec)


def test_parse_spec_ignores_extra_kwargs():
    """LLMs hallucinate kwargs; we should silently filter them."""
    params = {
        "old_name": "evaluate",
        "new_name": "eval_model",
        "hallucinated_kwarg": "ignore_me",
    }
    spec = {"primitive_type": "RenameApiCall", "params": params}
    parsed = parse_breakage_spec(spec)
    assert isinstance(parsed, RenameApiCall)


def test_breakage_creates_actual_difference():
    """Applying a breakage must yield text that differs from the input."""
    rename = RenameApiCall(
        old_name="trainer.train", new_name="trainer.start_training"
    )
    assert rename.apply(SAMPLE_SCRIPT) != SAMPLE_SCRIPT


def test_all_8_primitives_registered():
    """The registry's keys are exactly the eight expected primitive names."""
    expected_names = {
        "RenameApiCall",
        "DeprecateImport",
        "ChangeArgumentSignature",
        "ModifyConfigField",
        "RestructureDatasetSchema",
        "ChangeTokenizerBehavior",
        "RemoveDeprecatedMethod",
        "ChangeReturnType",
    }
    assert set(PRIMITIVE_REGISTRY) == expected_names


def test_breakage_repair_registry_alignment():
    """Each BREAKAGE_TO_REPAIR entry names a registered breakage and repair."""
    for breakage_name, repair_name in BREAKAGE_TO_REPAIR.items():
        assert breakage_name in PRIMITIVE_REGISTRY
        assert repair_name in REPAIR_REGISTRY


def test_seed_corpus_has_at_least_10_scripts():
    """The default sampler loads >= 10 tasks, each with content and an id."""
    sampler = TaskSampler()
    assert len(sampler.tasks) >= 10
    without_content = [t for t in sampler.tasks if not t.script_content]
    assert not without_content
    without_id = [t for t in sampler.tasks if not t.task_id]
    assert not without_id


def test_task_sampler_categories_are_diverse():
    """The sampler reports at least three distinct categories."""
    categories = TaskSampler().get_all_categories()
    assert len(categories) >= 3, f"Expected at least 3 distinct categories, got {categories}"


def test_task_sampler_difficulty_filter():
    """A task sampled with a difficulty filter carries that difficulty."""
    sampled = TaskSampler().sample(difficulty="easy")
    # The guard tolerates a sampler that returns no task for this
    # difficulty; the attribute is only checked when a task comes back.
    if sampled is None:
        return
    assert sampled.difficulty == "easy"


def test_task_sampler_get_by_id():
    """Looking a task up by its id returns the very same object."""
    sampler = TaskSampler()
    if not sampler.tasks:
        pytest.skip("No tasks loaded")
    head = sampler.tasks[0]
    assert sampler.get_by_id(head.task_id) is head