Spaces:
Sleeping
Sleeping
| """Unit tests for :mod:`physix.training.dataset`.""" | |
| from __future__ import annotations | |
| import pytest | |
| from physix.systems import SUPPORTED_SYSTEMS | |
| from physix.training.dataset import ( | |
| DatasetSpec, | |
| EvalDatasetSpec, | |
| build_eval_dataset, | |
| build_training_dataset, | |
| ) | |
| _EXPECTED_COLUMNS = { | |
| "prompt", | |
| "system_id", | |
| "state_variables", | |
| "parameters", | |
| "initial_conditions", | |
| "timestamps", | |
| "observed", | |
| "previous_r_match", | |
| } | |
def test_training_dataset_has_expected_schema() -> None:
    """Every column downstream code reads must be present in the dataset."""
    dataset = build_training_dataset(DatasetSpec(instances_per_system=2))
    missing = _EXPECTED_COLUMNS - set(dataset.column_names)
    assert not missing
    # One row per (system, instance) pair across the default curriculum.
    assert len(dataset) == 2 * len(SUPPORTED_SYSTEMS)
def test_training_dataset_default_curriculum_is_demo_systems() -> None:
    """Default DatasetSpec must train on exactly the same systems the
    live demo exposes. Mismatches would mean GRPO improves systems we
    never benchmark — the failure mode that produced flat headline
    metrics in the v1 training run.
    """
    dataset = build_training_dataset(DatasetSpec(instances_per_system=1))
    trained_on = set(dataset["system_id"])
    assert trained_on == set(SUPPORTED_SYSTEMS)
def test_training_dataset_explicit_system_ids_override_default() -> None:
    """An explicit system_ids tuple replaces the default curriculum."""
    chosen = ("free_fall", "simple_pendulum")
    dataset = build_training_dataset(
        DatasetSpec(system_ids=chosen, instances_per_system=1)
    )
    assert set(dataset["system_id"]) == set(chosen)
def test_training_dataset_rejects_unknown_system_ids() -> None:
    """A system id outside the registry must fail fast with ValueError."""
    bad_spec = DatasetSpec(system_ids=("not_a_real_system",))
    with pytest.raises(ValueError, match="Unknown system_ids"):
        build_training_dataset(bad_spec)
def test_training_dataset_rejects_empty_system_ids() -> None:
    """An empty curriculum is a configuration error, not a no-op."""
    empty_spec = DatasetSpec(system_ids=())
    with pytest.raises(ValueError, match="non-empty"):
        build_training_dataset(empty_spec)
def test_training_dataset_prompts_are_chat_lists() -> None:
    """Prompts are chat-format message lists: system turn, then user turn."""
    dataset = build_training_dataset(DatasetSpec(instances_per_system=1))
    messages = dataset[0]["prompt"]
    assert isinstance(messages, list)
    system_msg, user_msg = messages[0], messages[1]
    assert system_msg["role"] == "system"
    assert user_msg["role"] == "user"
    # The user turn must embed the observed trajectory data.
    assert "TRAJECTORY" in user_msg["content"]
def test_eval_dataset_marks_held_out_rows() -> None:
    """Eval rows are flagged so held-out systems can be scored separately."""
    dataset = build_eval_dataset(EvalDatasetSpec(instances_per_system=1))
    held_out, in_train = [], []
    for row in dataset:
        (held_out if row["is_held_out"] else in_train).append(row)
    assert len(held_out) >= 2  # 2 Tier 3 systems
    assert len(in_train) >= 6  # 6 Tier 1+2 systems
    assert {row["system_id"] for row in held_out} == {
        "projectile_drag",
        "charged_b_field",
    }
def test_dataset_observed_arrays_match_state_variables() -> None:
    """Each declared state variable has a non-empty observed series."""
    dataset = build_training_dataset(DatasetSpec(instances_per_system=1))
    for row in dataset:
        observations = row["observed"]
        for variable in row["state_variables"]:
            assert variable in observations
            assert len(observations[variable]) > 0