forgeenv-source / tests /test_primitives.py
akhiilll's picture
forgeenv source snapshot for training job
a15535e verified
"""Tests for breakage primitives, repair primitives, and task sampler."""
import pytest
from forgeenv.primitives.breakage_primitives import (
ChangeArgumentSignature,
ChangeReturnType,
ChangeTokenizerBehavior,
DeprecateImport,
ModifyConfigField,
PRIMITIVE_REGISTRY,
RemoveDeprecatedMethod,
RenameApiCall,
RestructureDatasetSchema,
parse_breakage_spec,
)
from forgeenv.primitives.repair_primitives import (
BREAKAGE_TO_REPAIR,
REPAIR_REGISTRY,
RestoreApiCall,
RestoreColumn,
RestoreImport,
RestoreMethod,
)
from forgeenv.tasks.task_sampler import TaskSampler
# Small HF Trainer fine-tuning script used as the input text for breakage and
# repair primitives. Within this file it is only passed to `.apply()` and
# compared as a string; it is never executed (note `model` is undefined).
SAMPLE_SCRIPT = """
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
dataset = load_dataset("glue", "sst2")
dataset = dataset.rename_column("label", "labels")
args = TrainingArguments(num_train_epochs=3, report_to="none")
trainer = Trainer(model=model, args=args, train_dataset=dataset)
trainer.train()
result = trainer.evaluate()
"""
def test_rename_api_call_word_boundary():
    """Renaming should not break out-of-context substrings.

    Renaming ``evaluate`` -> ``eval_model`` must rewrite the call site but leave
    identifiers that merely contain ``evaluate`` (e.g. ``num_evaluated``)
    untouched, and ``RestoreApiCall`` with the swapped names must undo it.
    """
    breakage = RenameApiCall(old_name="evaluate", new_name="eval_model")
    broken = breakage.apply(SAMPLE_SCRIPT)
    assert "eval_model" in broken
    assert "trainer.evaluate" not in broken

    # SAMPLE_SCRIPT has no identifier embedding "evaluate", so the
    # word-boundary guarantee needs its own input to be exercised at all.
    mixed_script = "num_evaluated = 0\nresult = trainer.evaluate()\n"
    mixed_broken = breakage.apply(mixed_script)
    assert "eval_model" in mixed_broken
    assert "trainer.evaluate(" not in mixed_broken
    assert "num_evaluated" in mixed_broken

    # Inverse should restore
    repair = RestoreApiCall(new_name="eval_model", old_name="evaluate")
    restored = repair.apply(broken)
    assert restored.strip() == SAMPLE_SCRIPT.strip()
def test_deprecate_import():
    """DeprecateImport rewrites the import line; RestoreImport undoes it exactly."""
    legacy_import = "from transformers.legacy import"
    current_import = "from transformers import"

    deprecation = DeprecateImport(
        old_module=current_import,
        new_module=legacy_import,
    )
    mutated_script = deprecation.apply(SAMPLE_SCRIPT)
    assert "transformers.legacy" in mutated_script

    restorer = RestoreImport(
        new_module=legacy_import,
        old_module=current_import,
    )
    assert restorer.apply(mutated_script) == SAMPLE_SCRIPT
def test_restructure_dataset_string_replacement():
    """Quoted column names are renamed, and RestoreColumn reverses the rename."""
    schema_change = RestructureDatasetSchema(
        old_column="label",
        new_column="sentiment_label",
    )
    mutated_script = schema_change.apply(SAMPLE_SCRIPT)
    assert '"sentiment_label"' in mutated_script
    assert '"label"' not in mutated_script

    recovered_script = RestoreColumn(
        new_column="sentiment_label",
        old_column="label",
    ).apply(mutated_script)
    assert recovered_script == SAMPLE_SCRIPT
def test_modify_config_field_changes_value():
    """ModifyConfigField writes the new value into the config keyword argument."""
    field_edit = ModifyConfigField(
        config_class="TrainingArguments",
        field_name="num_train_epochs",
        new_value="999",
    )
    assert "num_train_epochs=999" in field_edit.apply(SAMPLE_SCRIPT)
def test_change_tokenizer_behavior_replaces_kwarg():
    """ChangeTokenizerBehavior swaps the value of a tokenizer keyword argument."""
    tokenizer_call = "tok = tokenizer(text, padding=True, truncation=True)"
    behavior_change = ChangeTokenizerBehavior(
        old_kwarg="padding",
        new_kwarg="padding",
        old_value="True",
        new_value='"max_length"',
    )
    rewritten_call = behavior_change.apply(tokenizer_call)
    assert 'padding="max_length"' in rewritten_call
def test_remove_deprecated_method_marks_call():
    """Calls get a ``_DEPRECATED`` suffix; RestoreMethod strips it back off."""
    removal = RemoveDeprecatedMethod(
        class_name="Trainer",
        method_name="evaluate",
        replacement="evaluate_legacy",
    )
    marked_script = removal.apply(SAMPLE_SCRIPT)
    assert ".evaluate_DEPRECATED(" in marked_script

    assert RestoreMethod(method_name="evaluate").apply(marked_script) == SAMPLE_SCRIPT
def test_change_argument_signature_removes_kwarg():
    """The original ``report_to="none"`` keyword must not survive the change."""
    signature_change = ChangeArgumentSignature(
        function_name="TrainingArguments",
        removed_arg="report_to",
        added_arg="report_to",
        added_value='"none"',
    )
    mutated_script = signature_change.apply(SAMPLE_SCRIPT)
    assert 'report_to="none"' not in mutated_script
def test_change_return_type_swaps_access():
    """ChangeReturnType rewrites the result-access expression at the call site."""
    new_access_expr = "trainer.evaluate().metrics"
    swapped_script = ChangeReturnType(
        function_name="evaluate",
        old_access="trainer.evaluate()",
        new_access=new_access_expr,
    ).apply(SAMPLE_SCRIPT)
    assert new_access_expr in swapped_script
def test_parse_spec_round_trip():
    """A dict spec parses into the matching primitive and serialises back."""
    parsed = parse_breakage_spec(
        {
            "primitive_type": "RenameApiCall",
            "params": {"old_name": "evaluate", "new_name": "eval_model"},
        }
    )
    assert isinstance(parsed, RenameApiCall)
    assert parsed.old_name == "evaluate"

    serialised = parsed.to_spec()
    assert serialised["primitive_type"] == "RenameApiCall"
def test_parse_spec_unknown_raises():
    """An unregistered ``primitive_type`` is rejected with ValueError."""
    unknown_spec = {"primitive_type": "Bogus"}
    with pytest.raises(ValueError):
        parse_breakage_spec(unknown_spec)
def test_parse_spec_ignores_extra_kwargs():
    """Params the primitive does not accept (e.g. LLM-hallucinated ones) are
    dropped silently instead of making parsing fail."""
    spec_params = {
        "old_name": "evaluate",
        "new_name": "eval_model",
        "hallucinated_kwarg": "ignore_me",
    }
    parsed = parse_breakage_spec(
        {"primitive_type": "RenameApiCall", "params": spec_params}
    )
    assert isinstance(parsed, RenameApiCall)
def test_breakage_creates_actual_difference():
    """Applying a breakage must actually change the script text."""
    rename = RenameApiCall(
        old_name="trainer.train",
        new_name="trainer.start_training",
    )
    assert rename.apply(SAMPLE_SCRIPT) != SAMPLE_SCRIPT
def test_all_8_primitives_registered():
    """PRIMITIVE_REGISTRY exposes exactly the eight known breakage primitives."""
    registered_names = set(PRIMITIVE_REGISTRY)
    assert registered_names == {
        "ChangeArgumentSignature",
        "ChangeReturnType",
        "ChangeTokenizerBehavior",
        "DeprecateImport",
        "ModifyConfigField",
        "RemoveDeprecatedMethod",
        "RenameApiCall",
        "RestructureDatasetSchema",
    }
def test_breakage_repair_registry_alignment():
    """Every breakage class should have a registered inverse.

    Checks that every name in ``PRIMITIVE_REGISTRY`` has an entry in
    ``BREAKAGE_TO_REPAIR``. It also checks that every mapping refers to names
    present in ``PRIMITIVE_REGISTRY`` and ``REPAIR_REGISTRY`` respectively.
    """
    # Iterating BREAKAGE_TO_REPAIR alone only validates entries that already
    # exist, so a breakage primitive with no mapping would pass unnoticed.
    unmapped = set(PRIMITIVE_REGISTRY) - set(BREAKAGE_TO_REPAIR)
    assert not unmapped, f"Breakage primitives without a repair mapping: {sorted(unmapped)}"

    for breakage_name, repair_name in BREAKAGE_TO_REPAIR.items():
        assert breakage_name in PRIMITIVE_REGISTRY, (
            f"{breakage_name!r} is mapped but not a registered breakage"
        )
        assert repair_name in REPAIR_REGISTRY, (
            f"{breakage_name!r} maps to unregistered repair {repair_name!r}"
        )
def test_seed_corpus_has_at_least_10_scripts():
    """The seed corpus loads at least 10 tasks, each with a script and an id."""
    sampler = TaskSampler()
    loaded = sampler.tasks
    assert len(loaded) >= 10
    assert all(task.script_content for task in loaded)
    assert all(task.task_id for task in loaded)
def test_task_sampler_categories_are_diverse():
    """The task corpus spans at least three distinct categories."""
    distinct = TaskSampler().get_all_categories()
    assert len(distinct) >= 3, f"Expected at least 3 distinct categories, got {distinct}"
def test_task_sampler_difficulty_filter():
    """``sample(difficulty=...)`` only returns tasks of the requested difficulty.

    The test tolerates ``sample`` returning ``None``. That happens, for
    example, when the corpus holds no "easy" task. Any task that is returned
    must carry the requested difficulty.
    """
    sampler = TaskSampler()
    # "easy" may have no matching tasks; sampling must then not raise.
    task = sampler.sample(difficulty="easy")
    if task is not None:
        assert task.difficulty == "easy"
def test_task_sampler_get_by_id():
    """get_by_id returns the identical task object stored in ``.tasks``."""
    sampler = TaskSampler()
    if not sampler.tasks:
        pytest.skip("No tasks loaded")
    expected_task = sampler.tasks[0]
    assert sampler.get_by_id(expected_task.task_id) is expected_task