| """SFT warm-start before GRPO training. | |
| Trains Qwen2.5-1.5B-Instruct for 2 epochs on supervised (prompt, completion) | |
| pairs where the completion is the ground-truth equation in the action JSON | |
| format the env expects. This is the essential bootstrap step: without it a | |
| cold 1.5B model outputs LaTeX / incoherent text on ~80% of turns, yielding | |
| near-zero GRPO advantages and a flat loss curve that wastes GPU credits. | |
| After SFT the model: | |
| - Emits valid JSON with ``equation``, ``params``, ``rationale`` on >90% turns. | |
| - Writes equations in the ASCII grammar (``d2y/dt2 = ...``), not LaTeX. | |
| - Knows the per-system equation family (gravity, drag, pendulum, spring). | |
| Then GRPO refines physics accuracy via the verifiable R² reward. | |
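
An SFT target (produced by ``_gt_completion``) is a single JSON object. An
illustrative example, with a made-up equation string and parameter values
rather than the output of any registered system, looks like::

    {"equation": "d2x/dt2 = -(k/m) * x",
     "params": {"k": 4.2, "m": 1.0},
     "rationale": "Ground-truth equation for damped spring."}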

Run::

    python -m physix.training.sft \
        --model Qwen/Qwen2.5-1.5B-Instruct \
        --output-dir runs/physix-1.5b-sft \
        --epochs 2 \
        --instances-per-system 32

Typical runtime: 5-8 min on an A10G, 3-4 min on an A100.
"""
from __future__ import annotations

import argparse
import json
import logging
import os
from pathlib import Path

import numpy as np
from datasets import Dataset

from physix.systems import (
    SUPPORTED_SYSTEMS,
    SYSTEM_REGISTRY,
    get_system,
)
from physix.systems.base import PhysicalSystem, TrajectoryData
from physix.training.prompt import build_prompt
from physix.models import DEFAULT_MAX_TURNS, PhysiXObservation

_log = logging.getLogger(__name__)


def _gt_completion(system: PhysicalSystem) -> str:
    """Return the ground-truth completion JSON for one system."""
    import re as _re

    eq = system.ground_truth_equation()
    reserved = set(system.state_variables) | {
        "dt", "d", "t", "sin", "cos", "tan", "exp", "log", "sqrt", "abs",
    }
    eq_tokens = set(_re.findall(r'\b([A-Za-z_][A-Za-z0-9_]*)\b', eq))
    relevant_keys = eq_tokens & (set(system.parameters) - reserved)
    relevant = {k: round(system.parameters[k], 4) for k in sorted(relevant_keys)}
    return json.dumps({
        "equation": eq,
        "params": relevant,
        "rationale": (
            f"Ground-truth equation for {system.system_id.replace('_', ' ')}."
        ),
    })


def build_sft_dataset(
    system_ids: tuple[str, ...] = SUPPORTED_SYSTEMS,
    instances_per_system: int = 32,
    seed: int = 0,
) -> Dataset:
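    """Build the supervised (prompt, completion) rows used for SFT.

    Simulates ``instances_per_system`` trajectories per system, renders each
    into the env prompt via ``build_prompt``, and pairs it with the
    ground-truth action JSON from ``_gt_completion``.
    """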
    if not system_ids:
        raise ValueError("system_ids must be non-empty.")
    unknown = [sid for sid in system_ids if sid not in SYSTEM_REGISTRY]
    if unknown:
        raise ValueError(
            f"Unknown system_ids in build_sft_dataset: {unknown!r}. "
            f"Registered: {sorted(SYSTEM_REGISTRY)!r}."
        )
    rng = np.random.default_rng(seed)
    rows: list[dict] = []
    for system_id in system_ids:
        system = get_system(system_id)
        for _ in range(instances_per_system):
            trajectory = system.simulate(rng)
            obs = _build_obs(system, trajectory)
            prompt_messages = build_prompt(obs)
            completion = _gt_completion(system)
            rows.append({"prompt": prompt_messages, "completion": completion})
    _log.info(
        "Built SFT dataset: %d rows across %d systems (%s)",
        len(rows),
        len(system_ids),
        ", ".join(system_ids),
    )
    return Dataset.from_list(rows)


def _build_obs(system: PhysicalSystem, trajectory: TrajectoryData) -> PhysiXObservation:
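    """Build a fresh first-turn observation for one simulated trajectory."""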
    return PhysiXObservation(
        done=False,
        reward=None,
        trajectory=trajectory.to_observation_samples(),
        state_variables=list(system.state_variables),
        hint=system.hint(system.parameters),
        history=[],
        mismatch_summary="",
        turn=0,
        turn_remaining=DEFAULT_MAX_TURNS,
        system_id=system.system_id,
        stats=trajectory.stats(),
        reward_breakdown={},
    )


def train_sft(
    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct",
    output_dir: str = "runs/physix-1.5b-sft",
    epochs: int = 2,
    max_seq_length: int = 2048,
    lora_r: int = 16,
    lora_alpha: int = 32,
    per_device_batch_size: int = 2,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 2e-5,
    instances_per_system: int = 32,
    system_ids: tuple[str, ...] = SUPPORTED_SYSTEMS,
    seed: int = 0,
    wandb_run_name: str | None = None,
    hub_checkpoint_repo_id: str | None = None,
    hub_token: str | None = None,
) -> None:
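    """Run LoRA SFT on the warm-start dataset and save a merged checkpoint.

    Loads the base model in 4-bit via Unsloth, fine-tunes LoRA adapters on
    the (prompt, completion) rows from ``build_sft_dataset``, saves the
    merged 16-bit model under ``<output_dir>/merged``, and optionally pushes
    it to the Hub checkpoint repo.
    """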
    _configure_logging()

    import wandb
    from unsloth import FastLanguageModel
    from trl import SFTTrainer, SFTConfig

    # Clear stale resume vars so SFT starts a fresh W&B run.
    for stale in ("WANDB_RUN_ID", "WANDB_RESUME"):
        os.environ.pop(stale, None)

    wandb.init(
        project=os.environ.get("WANDB_PROJECT", "physix-live"),
        name=wandb_run_name or f"physix-sft-{epochs}ep",
        config={
            "stage": "sft",
            "model_name": model_name,
            "epochs": epochs,
            "lora_r": lora_r,
            "lora_alpha": lora_alpha,
            "learning_rate": learning_rate,
            "per_device_batch_size": per_device_batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "instances_per_system": instances_per_system,
            "seed": seed,
        },
        tags=["sft", "physix", model_name.split("/")[-1]],
    )
| _log.info("Loading model %s (4-bit, LoRA-%d)", model_name, lora_r) | |
| model, tokenizer = FastLanguageModel.from_pretrained( | |
| model_name=model_name, | |
| max_seq_length=max_seq_length, | |
| load_in_4bit=True, | |
| dtype=None, | |
| ) | |
| model = FastLanguageModel.get_peft_model( | |
| model, | |
| r=lora_r, | |
| lora_alpha=lora_alpha, | |
| target_modules=["q_proj", "k_proj", "v_proj", "o_proj", | |
| "gate_proj", "up_proj", "down_proj"], | |
| bias="none", | |
| use_gradient_checkpointing="unsloth", | |
| random_state=seed, | |
| ) | |
    dataset = build_sft_dataset(
        system_ids=system_ids,
        instances_per_system=instances_per_system,
        seed=seed,
    )

    def _format_row(row: dict) -> dict:
        """Combine prompt + completion into a single training string."""
        messages = row["prompt"] + [
            {"role": "assistant", "content": row["completion"]}
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"text": text}

    formatted = dataset.map(_format_row, remove_columns=["prompt", "completion"])
    _log.info("SFT dataset ready: %d rows", len(formatted))
    import torch

    sft_config = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        max_seq_length=max_seq_length,
        dataset_text_field="text",
        packing=True,
        logging_steps=1,
        save_strategy="epoch",
        report_to=["wandb"],
        seed=seed,
        bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
        fp16=not torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=sft_config,
        train_dataset=formatted,
    )
    _log.info("Starting SFT for %d epochs on %d examples", epochs, len(formatted))
    trainer.train()
    # We save as merged_16bit (full model + config + tokenizer) rather than
    # "lora" (adapter weights only). GRPO's downstream
    # ``FastLanguageModel.from_pretrained(sft_checkpoint)`` needs a complete
    # model directory (config.json + tokenizer + weights) to load. A bare
    # adapter shard makes Unsloth raise "No config file found". The merged
    # checkpoint is ~3 GB (1.5B params × 2 bytes), which is fine on /tmp.
    out_path = Path(output_dir) / "merged"
    out_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained_merged(
        save_directory=str(out_path),
        tokenizer=tokenizer,
        save_method="merged_16bit",
    )
    _log.info("SFT model (merged 16-bit) saved → %s", out_path)
    if hub_checkpoint_repo_id:
        # Push the merged SFT model to the same checkpoint repo GRPO uses,
        # under a fixed `sft/` subfolder. Re-runs overwrite the subfolder
        # but produce a new commit, so the revision SHA still uniquely
        # identifies *this* SFT result.
        from physix.training.checkpoints import (
            SFT_SUBFOLDER,
            log_link_artifact_to_wandb,
            push_checkpoint_to_hub,
        )

        try:
            handle = push_checkpoint_to_hub(
                local_dir=out_path,
                repo_id=hub_checkpoint_repo_id,
                subfolder=SFT_SUBFOLDER,
                commit_message=(
                    f"SFT merged_16bit: {model_name} | "
                    f"epochs={epochs} lora_r={lora_r}"
                ),
                token=hub_token,
            )
            _log.info("SFT checkpoint pushed to Hub: %s", handle.hub_url)
            wandb.run.summary["sft/hub_repo"] = handle.repo_id
            wandb.run.summary["sft/hub_url"] = handle.hub_url
            wandb.run.summary["sft/hub_revision"] = handle.revision
            log_link_artifact_to_wandb(
                handle,
                artifact_name="physix-sft-checkpoint",
                extra={"model_name": model_name, "epochs": epochs, "lora_r": lora_r},
            )
        except Exception as exc:  # noqa: BLE001
            # Don't kill SFT just because the hub push failed; the GRPO step
            # downstream can fall back to the local /tmp checkpoint.
            _log.error("SFT hub push failed (non-fatal): %s", exc)

    wandb.finish()


# ─── CLI ──────────────────────────────────────────────────────────────────────


def _configure_logging() -> None:
    logging.basicConfig(
        level=os.environ.get("PHYSIX_LOG_LEVEL", "INFO"),
        format="[%(asctime)s] %(levelname)s %(name)s | %(message)s",
    )


def main() -> None:
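    """Parse CLI arguments and launch ``train_sft``."""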
    parser = argparse.ArgumentParser(description="SFT warm-start for PhysiX RLVR.")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--output-dir", default="runs/physix-1.5b-sft")
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--instances-per-system", type=int, default=32)
    parser.add_argument(
        "--system-ids",
        default=None,
        help=(
            "Comma-separated list of system IDs to include in the SFT dataset "
            "(e.g. 'damped_spring'). Defaults to all SUPPORTED_SYSTEMS."
        ),
    )
    parser.add_argument("--lora-r", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--wandb-run-name",
        default=None,
        help="Override W&B run name. Defaults to physix-sft-{epochs}ep.",
    )
    parser.add_argument(
        "--hub-checkpoint-repo-id",
        default=None,
        help=(
            "If set, push the merged SFT model to <repo>/sft on the Hub "
            "and log a pointer-only artifact to W&B."
        ),
    )
    args = parser.parse_args()
    system_ids = (
        tuple(s.strip() for s in args.system_ids.split(",") if s.strip())
        if args.system_ids
        else SUPPORTED_SYSTEMS
    )
    os.environ.setdefault("WANDB_PROJECT", "physix-live")
    train_sft(
        model_name=args.model,
        output_dir=args.output_dir,
        epochs=args.epochs,
        lora_r=args.lora_r,
        learning_rate=args.learning_rate,
        instances_per_system=args.instances_per_system,
        system_ids=system_ids,
        seed=args.seed,
        wandb_run_name=args.wandb_run_name,
        hub_checkpoint_repo_id=args.hub_checkpoint_repo_id,
        hub_token=os.environ.get("HF_TOKEN"),
    )


if __name__ == "__main__":
    main()