Pratyush-01 committed on
Commit
0b8f87b
·
verified ·
1 Parent(s): c59b8f5

cleanup: strip verbose comments from physix/training/sft.py

Browse files
Files changed (1) hide show
  1. physix/training/sft.py +2 -20
physix/training/sft.py CHANGED
@@ -48,22 +48,10 @@ from physix.models import DEFAULT_MAX_TURNS, PhysiXObservation
48
  _log = logging.getLogger(__name__)
49
 
50
 
51
- # ─── Dataset ──────────────────────────────────────────────────────────────────
52
-
53
  def _gt_completion(system: PhysicalSystem) -> str:
54
- """Build the ground-truth completion JSON for one system.
55
-
56
- We include the system's sampled parameters so the model learns that the
57
- ``params`` field must contain the symbols it references in the equation.
58
- The SFT target is the *exact* JSON string the env's verifier accepts;
59
- GRPO will later teach the model to refine parameter values per trajectory.
60
- """
61
  import re as _re
62
  eq = system.ground_truth_equation()
63
- # Extract all identifier tokens that appear in the equation, then keep
64
- # only those that are declared as system parameters. We use a proper
65
- # identifier regex (not split-on-whitespace) so symbols inside function
66
- # calls like sin(theta) and fractions like -(g/L) are caught.
67
  reserved = set(system.state_variables) | {"dt", "d", "t", "sin", "cos",
68
  "tan", "exp", "log", "sqrt", "abs"}
69
  eq_tokens = set(_re.findall(r'\b([A-Za-z_][A-Za-z0-9_]*)\b', eq))
@@ -130,8 +118,6 @@ def _build_obs(system: PhysicalSystem, trajectory: TrajectoryData) -> PhysiXObse
130
  )
131
 
132
 
133
- # ─── Training ─────────────────────────────────────────────────────────────────
134
-
135
  def train_sft(
136
  model_name: str = "Qwen/Qwen2.5-1.5B-Instruct",
137
  output_dir: str = "runs/physix-1.5b-sft",
@@ -151,15 +137,11 @@ def train_sft(
151
  ) -> None:
152
  _configure_logging()
153
 
154
- # Heavy imports: only available in [train] env.
155
  import wandb
156
  from unsloth import FastLanguageModel
157
  from trl import SFTTrainer, SFTConfig
158
 
159
- # Force a fresh W&B run for SFT regardless of any inherited WANDB_RUN_ID
160
- # / WANDB_RESUME env vars (those are intended for the GRPO stage). If we
161
- # let wandb.init() try to resume a foreign run id it will block for ~90s
162
- # fetching that run's history before giving up.
163
  for stale in ("WANDB_RUN_ID", "WANDB_RESUME"):
164
  os.environ.pop(stale, None)
165
 
 
48
  _log = logging.getLogger(__name__)
49
 
50
 
 
 
51
  def _gt_completion(system: PhysicalSystem) -> str:
52
+ """Return the ground-truth completion JSON for one system."""
 
 
 
 
 
 
53
  import re as _re
54
  eq = system.ground_truth_equation()
 
 
 
 
55
  reserved = set(system.state_variables) | {"dt", "d", "t", "sin", "cos",
56
  "tan", "exp", "log", "sqrt", "abs"}
57
  eq_tokens = set(_re.findall(r'\b([A-Za-z_][A-Za-z0-9_]*)\b', eq))
 
118
  )
119
 
120
 
 
 
121
  def train_sft(
122
  model_name: str = "Qwen/Qwen2.5-1.5B-Instruct",
123
  output_dir: str = "runs/physix-1.5b-sft",
 
137
  ) -> None:
138
  _configure_logging()
139
 
 
140
  import wandb
141
  from unsloth import FastLanguageModel
142
  from trl import SFTTrainer, SFTConfig
143
 
144
+ # Clear stale resume vars so SFT starts a fresh W&B run.
 
 
 
145
  for stale in ("WANDB_RUN_ID", "WANDB_RESUME"):
146
  os.environ.pop(stale, None)
147