Spaces:
Sleeping
Sleeping
| """Unit tests for :mod:`physix.training.scorer`.""" | |
| from __future__ import annotations | |
| import numpy as np | |
| import pytest | |
| from physix.systems import get_system | |
| from physix.training.scorer import Scorer, SystemContext | |
def _build_context_for(system_id: str, seed: int = 5) -> SystemContext:
    """Build a :class:`SystemContext` from one simulated trajectory.

    Args:
        system_id: Identifier handed to :func:`get_system`.
        seed: Seed for the NumPy generator that drives the simulation.

    Returns:
        A context carrying the system's metadata, the simulated timestamps
        and states, and ``previous_total`` set to ``0.0``.
    """
    target = get_system(system_id)
    # A fixed seed keeps the simulated trajectory reproducible between runs.
    trajectory = target.simulate(np.random.default_rng(seed))
    context_fields = {
        "system_id": system_id,
        "state_variables": target.state_variables,
        "parameters": target.parameters,
        "initial_conditions": target.initial_conditions,
        "timestamps": trajectory.timestamps,
        "observed": trajectory.states,
        "previous_total": 0.0,
    }
    return SystemContext(**context_fields)
def test_scorer_returns_high_match_for_ground_truth() -> None:
    """Check that the reference ``free_fall`` equation scores well.

    The completion must be scored with ``format == 1.0`` and a ``match``
    of at least ``0.9``.
    """
    evaluator = Scorer()
    context = _build_context_for("free_fall")
    completion = '{"equation": "d2y/dt2 = -9.81"}'
    result = evaluator.score(completion, context)
    assert result.format == 1.0
    assert result.match >= 0.9
def test_scorer_returns_zero_format_on_garbage() -> None:
    """Check that free text with no equation scores zero.

    Both ``format`` and ``match`` must be ``0.0`` for the completion.
    """
    evaluator = Scorer()
    context = _build_context_for("free_fall")
    completion = "this is not an equation"
    result = evaluator.score(completion, context)
    assert result.format == 0.0
    assert result.match == 0.0
def test_scorer_caches_by_key() -> None:
    """Check that scoring twice with one ``cache_key`` returns one object.

    The assertion is on identity (``is``), not equality: the second call
    must hand back the very same result object as the first.
    """
    evaluator = Scorer()
    context = _build_context_for("simple_pendulum")
    payload = '{"equation": "d2theta/dt2 = -9.81 * sin(theta)"}'
    initial_result = evaluator.score(payload, context, cache_key=42)
    repeated_result = evaluator.score(payload, context, cache_key=42)
    assert initial_result is repeated_result
def test_scorer_progress_shrinks_when_previous_total_is_high() -> None:
    """Check that a higher ``previous_total`` never increases ``progress``.

    The same completion is scored twice: once against a context with
    ``previous_total == 0.0`` and once against a copy whose
    ``previous_total`` is 90% of the first run's ``match``. The test is
    skipped when the first run's ``match`` is below ``0.5``.
    """
    evaluator = Scorer()
    baseline_ctx = _build_context_for("free_fall_drag", seed=3)
    payload = '{"equation": "d2y/dt2 = -9.81 + 0.05 * vy**2"}'

    baseline = evaluator.score(payload, baseline_ctx)
    if baseline.match < 0.5:
        pytest.skip("rng selection produced an awkward instance")

    raised_ctx = baseline_ctx.model_copy(
        update={"previous_total": baseline.match * 0.9}
    )
    raised = evaluator.score(payload, raised_ctx)

    # progress = max(0, match - prev_total). Higher prev_total => smaller progress.
    # The small epsilon absorbs floating-point noise in the comparison.
    assert raised.progress <= baseline.progress + 1e-9