Spaces:
Sleeping
Sleeping
| """Tests for the SFT dataset builder. | |

Why a dedicated file: :func:`physix.training.sft.build_sft_dataset` lives in
a module that *also* imports heavy ML packages (``unsloth``, ``trl``,
``wandb``), albeit lazily inside :func:`train_sft`. Keeping the dataset-only
tests here documents that boundary: these tests exercise the pure-Python
data path and must not require GPU dependencies.

The point of the behavior pinned by these tests: SFT must warm up the
model on *the same systems GRPO will later refine*. Any divergence between
the SFT and GRPO curricula wastes the warm-start budget on shapes that
never receive a reinforcement signal.
| """ | |
| from __future__ import annotations | |
| import json | |
| import pytest | |
| from physix.systems import SUPPORTED_SYSTEMS | |
| from physix.training.sft import build_sft_dataset | |
def test_build_sft_dataset_default_curriculum_size() -> None:
    """Default curriculum with one instance per system yields one row per supported system."""
    dataset = build_sft_dataset(instances_per_system=1)
    expected_rows = len(SUPPORTED_SYSTEMS)
    assert len(dataset) == expected_rows
def test_build_sft_dataset_default_scales_with_instances_per_system() -> None:
    """Default-curriculum row count equals ``len(SUPPORTED_SYSTEMS) * instances_per_system``."""
    per_system = 4
    dataset = build_sft_dataset(instances_per_system=per_system)
    expected_rows = len(SUPPORTED_SYSTEMS) * per_system
    assert len(dataset) == expected_rows
def test_build_sft_dataset_completions_are_valid_json_action() -> None:
    """Every completion must decode to a JSON object carrying the action keys.

    The ``isinstance`` check is load-bearing: ``json.loads`` may return a
    ``str``, and ``"key" in some_str`` is a *substring* test. Without it, a
    completion that decodes to a bare JSON string mentioning the key names
    would pass.
    """
    required_keys = {"equation", "params", "rationale"}
    ds = build_sft_dataset(instances_per_system=1)
    # An empty dataset would skip the loop and pass vacuously.
    assert len(ds) > 0
    for row in ds:
        completion = row["completion"]
        action = json.loads(completion)
        assert isinstance(action, dict), (
            f"completion is not a JSON object: {completion!r}"
        )
        missing = required_keys - action.keys()
        assert not missing, (
            f"completion missing keys {sorted(missing)}: {completion!r}"
        )
def test_build_sft_dataset_explicit_system_ids_override_default() -> None:
    """An explicit ``system_ids`` tuple replaces the default curriculum.

    The row count must reflect only the single requested system times
    ``instances_per_system``, not the full supported-system list.
    """
    requested_ids = ("free_fall",)
    per_system = 3
    dataset = build_sft_dataset(
        system_ids=requested_ids,
        instances_per_system=per_system,
    )
    assert len(dataset) == 3
def test_build_sft_dataset_rejects_unknown_system_ids() -> None:
    """An unrecognised system id raises ``ValueError`` naming the bad ids."""
    bogus_ids = ("definitely_not_a_system",)
    with pytest.raises(ValueError, match="Unknown system_ids"):
        build_sft_dataset(system_ids=bogus_ids)
def test_build_sft_dataset_rejects_empty_system_ids() -> None:
    """An empty ``system_ids`` tuple raises ``ValueError`` mentioning "non-empty"."""
    no_ids: tuple[str, ...] = ()
    with pytest.raises(ValueError, match="non-empty"):
        build_sft_dataset(system_ids=no_ids)