Spaces:
Sleeping
Sleeping
| """Build the prompt+context dataset for GRPO training.""" | |
from __future__ import annotations

from collections.abc import Iterable

import numpy as np
from datasets import Dataset
from pydantic import BaseModel, ConfigDict, NonNegativeInt, PositiveInt

from physix.models import DEFAULT_MAX_TURNS, PhysiXObservation
from physix.systems import (
    SYSTEM_REGISTRY,
    SUPPORTED_SYSTEMS,
    SystemTier,
    get_system,
    list_systems_by_tier,
)
from physix.systems.base import PhysicalSystem, TrajectoryData
from physix.training.prompt import build_prompt
class DatasetSpec(BaseModel):
    """Configuration for :func:`build_training_dataset`.

    Numeric fields are range-checked when the spec is constructed, so an
    invalid configuration raises a pydantic ``ValidationError`` (a
    ``ValueError`` subclass) before any simulation work starts.
    """

    model_config = ConfigDict(frozen=True)

    # Registry ids of the systems to sample. Emptiness and registration are
    # checked by ``_validate_system_ids`` when the dataset is built.
    system_ids: tuple[str, ...] = SUPPORTED_SYSTEMS
    # Rows generated per system. Must be >= 1: zero or a negative count would
    # otherwise make ``range()`` empty and silently yield an empty dataset.
    instances_per_system: PositiveInt = 32
    # Seed for ``np.random.default_rng``. NumPy rejects negative seeds, so they
    # are rejected here too, at spec construction rather than at build time.
    seed: NonNegativeInt = 0
class EvalDatasetSpec(BaseModel):
    """Configuration for :func:`build_eval_dataset`.

    The evaluation set spans both ``train_tiers`` and ``held_out_tiers`` and
    is drawn from its own RNG seed, separate from the training spec's seed.
    Numeric fields are range-checked at construction time (a pydantic
    ``ValidationError``, which subclasses ``ValueError``).
    """

    model_config = ConfigDict(frozen=True)

    # In-distribution tiers; their systems get ``is_held_out=False`` rows.
    train_tiers: tuple[SystemTier, ...] = (SystemTier.TIER_1, SystemTier.TIER_2)
    # Evaluation-only tiers; their systems get ``is_held_out=True`` rows.
    held_out_tiers: tuple[SystemTier, ...] = (SystemTier.TIER_3,)
    # Rows generated per system. Must be >= 1 (0 would produce an empty set).
    instances_per_system: PositiveInt = 8
    # Deliberately different from DatasetSpec's default seed (0) so the eval
    # RNG stream differs from the training one. Any distinct value works; the
    # magnitude itself is irrelevant. NumPy rejects negative seeds.
    seed: NonNegativeInt = 1_000_000
def build_training_dataset(spec: DatasetSpec | None = None) -> Dataset:
    """Build the GRPO training dataset.

    Emits ``spec.instances_per_system`` rows for each id in
    ``spec.system_ids`` (grouped by system, in spec order). Every row is a
    single (system, instance) prompt at turn 0. All rows draw from one
    generator seeded with ``spec.seed``, so identical specs consume the same
    random stream.

    Args:
        spec: Dataset configuration; falls back to ``DatasetSpec()`` when
            not given.

    Raises:
        ValueError: If ``spec.system_ids`` is empty or contains an id that
            is not registered.
    """
    spec = spec or DatasetSpec()
    _validate_system_ids(spec.system_ids)

    generator = np.random.default_rng(spec.seed)
    # Outer loop over systems, inner loop over instances: all instances of
    # the first system come first. Rows are built sequentially because they
    # share the single generator above.
    records = [
        _build_row(system_id, generator)
        for system_id in spec.system_ids
        for _ in range(spec.instances_per_system)
    ]
    return Dataset.from_list(records)
def _validate_system_ids(system_ids: tuple[str, ...]) -> None:
    """Reject an empty or partially unregistered ``system_ids`` tuple.

    Both error messages include the sorted keys of ``SYSTEM_REGISTRY`` so a
    misspelled id is easy to spot.

    Raises:
        ValueError: If *system_ids* is empty, or if any id is absent from
            ``SYSTEM_REGISTRY``. All unknown ids are reported, in input order.
    """
    if not system_ids:
        raise ValueError(
            "DatasetSpec.system_ids must be non-empty. "
            f"Available: {sorted(SYSTEM_REGISTRY)!r}."
        )

    unregistered = [
        candidate for candidate in system_ids if candidate not in SYSTEM_REGISTRY
    ]
    if not unregistered:
        return

    raise ValueError(
        f"Unknown system_ids in DatasetSpec: {unregistered!r}. "
        f"Registered: {sorted(SYSTEM_REGISTRY)!r}."
    )
def build_eval_dataset(spec: EvalDatasetSpec | None = None) -> Dataset:
    """Build the evaluation dataset, covering both train and held-out tiers.

    Systems from ``spec.train_tiers`` followed by those from
    ``spec.held_out_tiers`` each contribute ``spec.instances_per_system``
    rows. On top of the usual row fields, every row carries an
    ``is_held_out`` flag that is ``True`` when the system belongs to any
    held-out tier.

    A system reachable through more than one tier (overlapping tiers, or a
    tier listed twice) is emitted only once, at its first position.

    Args:
        spec: Evaluation configuration; falls back to ``EvalDatasetSpec()``
            when not given.
    """
    spec = spec or EvalDatasetSpec()
    rng = np.random.default_rng(spec.seed)

    # Loop-invariant: compute the held-out membership once (as a set for O(1)
    # lookups) instead of re-listing the held-out systems for every row.
    held_out = set(_list_systems(spec.held_out_tiers))
    # dict.fromkeys de-duplicates while keeping first-seen order, so an
    # overlapping tier configuration cannot emit the same system twice.
    system_ids = list(
        dict.fromkeys(_list_systems(spec.train_tiers + spec.held_out_tiers))
    )

    rows: list[dict[str, object]] = []
    for system_id in system_ids:
        is_held_out = system_id in held_out
        for _ in range(spec.instances_per_system):
            row = _build_row(system_id, rng)
            row["is_held_out"] = is_held_out
            rows.append(row)
    return Dataset.from_list(rows)
def _list_systems(tiers: Iterable[SystemTier]) -> list[str]:
    """Concatenate the system ids registered under each tier in *tiers*.

    Order follows *tiers*, and within a tier whatever order
    ``list_systems_by_tier`` returns. Duplicates are not removed.
    """
    return [
        system_id
        for tier in tiers
        for system_id in list_systems_by_tier(tier)
    ]
def _build_row(system_id: str, rng: np.random.Generator) -> dict[str, object]:
    """Simulate one instance of *system_id* and package it as a dataset row.

    The row holds the turn-0 prompt together with the system context:
    state variables, parameters, initial conditions, and the simulated
    trajectory. NumPy arrays are converted to plain Python lists.
    ``previous_r_match`` starts at ``0.0``.
    """
    system = get_system(system_id)
    trajectory = system.simulate(rng)
    prompt = build_prompt(_build_observation(system, trajectory))

    # Read the system context only after simulate() has run, as before.
    variables = list(system.state_variables)
    parameters = dict(system.parameters)
    initial_conditions = dict(system.initial_conditions)
    timestamps = trajectory.timestamps.tolist()
    observed = {
        name: trajectory.states[name].tolist()
        for name in system.state_variables
    }

    return {
        "prompt": prompt,  # chat list of {"role", "content"} dicts
        "system_id": system_id,
        "state_variables": variables,
        "parameters": parameters,
        "initial_conditions": initial_conditions,
        "timestamps": timestamps,
        "observed": observed,
        "previous_r_match": 0.0,
    }
def _build_observation(
    system: PhysicalSystem,
    trajectory: TrajectoryData,
) -> PhysiXObservation:
    """Assemble the turn-0 :class:`PhysiXObservation` for a fresh system.

    :class:`PhysiXEnvironment` is bypassed on purpose: its lifecycle
    (history, convergence flag, episode budget) is irrelevant when building
    a static dataset. The observation is therefore filled in directly, with
    an empty history, no reward, and ``DEFAULT_MAX_TURNS`` turns remaining.
    """
    # Bind the derived values in the same order the original keyword
    # arguments were evaluated.
    samples = trajectory.to_observation_samples()
    variables = list(system.state_variables)
    hint_text = system.hint(system.parameters)
    source_id = system.system_id
    summary_stats = trajectory.stats()

    return PhysiXObservation(
        done=False,
        reward=None,
        trajectory=samples,
        state_variables=variables,
        hint=hint_text,
        history=[],
        mismatch_summary="",
        turn=0,
        turn_remaining=DEFAULT_MAX_TURNS,
        system_id=source_id,
        stats=summary_stats,
        reward_breakdown={},
    )