physix / tests /test_scorer.py
Pratyush-01's picture
Upload folder using huggingface_hub
0e24aff verified
"""Unit tests for :mod:`physix.training.scorer`."""
from __future__ import annotations
import numpy as np
import pytest
from physix.systems import get_system
from physix.training.scorer import Scorer, SystemContext
def _build_context_for(system_id: str, seed: int = 5) -> SystemContext:
    """Build a :class:`SystemContext` for ``system_id`` from one simulated run.

    The system is looked up via ``get_system`` and simulated once with a NumPy
    generator created from ``seed``. The trajectory's ``timestamps`` and
    ``states`` populate the context's ``timestamps`` and ``observed`` fields,
    and ``previous_total`` always starts at ``0.0``.
    """
    target = get_system(system_id)
    # A fixed seed means simulate() receives an identically-seeded generator
    # on every run of the test suite.
    trajectory = target.simulate(np.random.default_rng(seed))
    context_fields = {
        "system_id": system_id,
        "state_variables": target.state_variables,
        "parameters": target.parameters,
        "initial_conditions": target.initial_conditions,
        "timestamps": trajectory.timestamps,
        "observed": trajectory.states,
        "previous_total": 0.0,
    }
    return SystemContext(**context_fields)
def test_scorer_returns_high_match_for_ground_truth() -> None:
    """The ``free_fall`` equation below must score format 1.0 and match >= 0.9."""
    ground_truth = '{"equation": "d2y/dt2 = -9.81"}'
    context = _build_context_for("free_fall")

    result = Scorer().score(ground_truth, context)

    assert result.format == 1.0
    assert result.match >= 0.9
def test_scorer_returns_zero_format_on_garbage() -> None:
    """A completion that is plain text must score 0.0 for format and match."""
    garbage = "this is not an equation"
    context = _build_context_for("free_fall")

    result = Scorer().score(garbage, context)

    assert result.format == 0.0
    assert result.match == 0.0
def test_scorer_caches_by_key() -> None:
    """Two ``score`` calls sharing a ``cache_key`` must return the same object."""
    pendulum_eq = '{"equation": "d2theta/dt2 = -9.81 * sin(theta)"}'
    shared_scorer = Scorer()
    context = _build_context_for("simple_pendulum")

    initial = shared_scorer.score(pendulum_eq, context, cache_key=42)
    repeated = shared_scorer.score(pendulum_eq, context, cache_key=42)

    # Identity, not equality: the second call must hand back the very same
    # object as the first.
    assert initial is repeated
def test_scorer_progress_shrinks_when_previous_total_is_high() -> None:
    """Raising ``previous_total`` must not increase the reported progress.

    The same completion is scored twice on ``free_fall_drag``: once with the
    default ``previous_total`` of 0.0 and once with it set to 90% of the first
    run's match. The test is skipped when the first run's match is below 0.5.
    """
    drag_eq = '{"equation": "d2y/dt2 = -9.81 + 0.05 * vy**2"}'
    scorer = Scorer()
    baseline_ctx = _build_context_for("free_fall_drag", seed=3)

    baseline = scorer.score(drag_eq, baseline_ctx)
    if baseline.match < 0.5:
        pytest.skip("rng selection produced an awkward instance")

    raised_total = baseline.match * 0.9
    raised_ctx = baseline_ctx.model_copy(update={"previous_total": raised_total})
    raised = scorer.score(drag_eq, raised_ctx)

    # progress = max(0, match - prev_total). Higher prev_total => smaller
    # progress; the 1e-9 slack absorbs floating-point noise.
    progress_ceiling = baseline.progress + 1e-9
    assert raised.progress <= progress_ceiling