Spaces:
Sleeping
Sleeping
cleanup: strip verbose comments from physix/training/sft.py
Browse files- physix/training/sft.py +2 -20
physix/training/sft.py
CHANGED
|
@@ -48,22 +48,10 @@ from physix.models import DEFAULT_MAX_TURNS, PhysiXObservation
|
|
| 48 |
_log = logging.getLogger(__name__)
|
| 49 |
|
| 50 |
|
| 51 |
-
# ─── Dataset ──────────────────────────────────────────────────────────────────
|
| 52 |
-
|
| 53 |
def _gt_completion(system: PhysicalSystem) -> str:
|
| 54 |
-
"""
|
| 55 |
-
|
| 56 |
-
We include the system's sampled parameters so the model learns that the
|
| 57 |
-
``params`` field must contain the symbols it references in the equation.
|
| 58 |
-
The SFT target is the *exact* JSON string the env's verifier accepts;
|
| 59 |
-
GRPO will later teach the model to refine parameter values per trajectory.
|
| 60 |
-
"""
|
| 61 |
import re as _re
|
| 62 |
eq = system.ground_truth_equation()
|
| 63 |
-
# Extract all identifier tokens that appear in the equation, then keep
|
| 64 |
-
# only those that are declared as system parameters. We use a proper
|
| 65 |
-
# identifier regex (not split-on-whitespace) so symbols inside function
|
| 66 |
-
# calls like sin(theta) and fractions like -(g/L) are caught.
|
| 67 |
reserved = set(system.state_variables) | {"dt", "d", "t", "sin", "cos",
|
| 68 |
"tan", "exp", "log", "sqrt", "abs"}
|
| 69 |
eq_tokens = set(_re.findall(r'\b([A-Za-z_][A-Za-z0-9_]*)\b', eq))
|
|
@@ -130,8 +118,6 @@ def _build_obs(system: PhysicalSystem, trajectory: TrajectoryData) -> PhysiXObse
|
|
| 130 |
)
|
| 131 |
|
| 132 |
|
| 133 |
-
# ─── Training ─────────────────────────────────────────────────────────────────
|
| 134 |
-
|
| 135 |
def train_sft(
|
| 136 |
model_name: str = "Qwen/Qwen2.5-1.5B-Instruct",
|
| 137 |
output_dir: str = "runs/physix-1.5b-sft",
|
|
@@ -151,15 +137,11 @@ def train_sft(
|
|
| 151 |
) -> None:
|
| 152 |
_configure_logging()
|
| 153 |
|
| 154 |
-
# Heavy imports: only available in [train] env.
|
| 155 |
import wandb
|
| 156 |
from unsloth import FastLanguageModel
|
| 157 |
from trl import SFTTrainer, SFTConfig
|
| 158 |
|
| 159 |
-
#
|
| 160 |
-
# / WANDB_RESUME env vars (those are intended for the GRPO stage). If we
|
| 161 |
-
# let wandb.init() try to resume a foreign run id it will block for ~90s
|
| 162 |
-
# fetching that run's history before giving up.
|
| 163 |
for stale in ("WANDB_RUN_ID", "WANDB_RESUME"):
|
| 164 |
os.environ.pop(stale, None)
|
| 165 |
|
|
|
|
| 48 |
_log = logging.getLogger(__name__)
|
| 49 |
|
| 50 |
|
|
|
|
|
|
|
| 51 |
def _gt_completion(system: PhysicalSystem) -> str:
|
| 52 |
+
"""Return the ground-truth completion JSON for one system."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
import re as _re
|
| 54 |
eq = system.ground_truth_equation()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
reserved = set(system.state_variables) | {"dt", "d", "t", "sin", "cos",
|
| 56 |
"tan", "exp", "log", "sqrt", "abs"}
|
| 57 |
eq_tokens = set(_re.findall(r'\b([A-Za-z_][A-Za-z0-9_]*)\b', eq))
|
|
|
|
| 118 |
)
|
| 119 |
|
| 120 |
|
|
|
|
|
|
|
| 121 |
def train_sft(
|
| 122 |
model_name: str = "Qwen/Qwen2.5-1.5B-Instruct",
|
| 123 |
output_dir: str = "runs/physix-1.5b-sft",
|
|
|
|
| 137 |
) -> None:
|
| 138 |
_configure_logging()
|
| 139 |
|
|
|
|
| 140 |
import wandb
|
| 141 |
from unsloth import FastLanguageModel
|
| 142 |
from trl import SFTTrainer, SFTConfig
|
| 143 |
|
| 144 |
+
# Clear stale resume vars so SFT starts a fresh W&B run.
|
|
|
|
|
|
|
|
|
|
| 145 |
for stale in ("WANDB_RUN_ID", "WANDB_RESUME"):
|
| 146 |
os.environ.pop(stale, None)
|
| 147 |
|