| """GRPO training loop using Unsloth + TRL. | |
| Requires the ``[train]`` optional dependency group (heavy ML deps). | |
| """ | |

from __future__ import annotations

import argparse
import logging
import os
from pathlib import Path
from typing import Literal, Optional

import torch
from datasets import Dataset
from pydantic import BaseModel, ConfigDict
from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState
from transformers import TrainingArguments as HFTrainingArguments

from physix.systems import SUPPORTED_SYSTEMS
from physix.training.dataset import (
    DatasetSpec,
    build_training_dataset,
)
from physix.training.reward_fns import make_reward_funcs
from physix.training.scorer import Scorer

# Unsloth patches must be applied before importing GRPOTrainer — order matters.
# Requires trl<=0.24.0; newer versions break PatchFastRL.
from unsloth import FastLanguageModel, PatchFastRL  # noqa: E402

PatchFastRL("GRPO", FastLanguageModel)

from trl import GRPOConfig, GRPOTrainer  # noqa: E402 (must come after PatchFastRL)

_log = logging.getLogger(__name__)

Ablation = Literal["no_progress", "no_simplicity", "no_format"]
SaveMethod = Literal["lora", "merged_16bit", "merged_4bit"]


class TrainingConfig(BaseModel):
    """All hyperparameters in one place; the CLI populates this."""

    model_config = ConfigDict(frozen=True)

    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
    #: Path to a merged SFT model to warm-start GRPO from.
    sft_checkpoint: Optional[str] = None
    #: Hub repo id or local path of an existing LoRA adapter to resume from.
    lora_adapter_repo: Optional[str] = None
    output_dir: str = "runs/physix-1.5b-rl"
    max_seq_length: int = 2048
    lora_r: int = 16
    lora_alpha: int = 32
    learning_rate: float = 5.0e-6
    temperature: float = 0.9
    max_completion_length: int = 256
    beta: float = 0.04
    num_generations: int = 4
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    num_steps: int = 300
    #: Set to 0 to disable early stopping.
    early_stop_patience: int = 50
    seed: int = 0
    instances_per_system: int = 32
    system_ids: tuple[str, ...] = SUPPORTED_SYSTEMS
    ablation: Optional[Ablation] = None
    wandb_project: str = "physix-live"
    wandb_run_name: Optional[str] = None
    push_to_hub: bool = False
    hub_repo_id: Optional[str] = None
    #: HF repo to push LoRA checkpoints to every save_steps.
    hub_checkpoint_repo_id: Optional[str] = None
    resume_from_checkpoint: Optional[str] = None
    save_method: SaveMethod = "merged_16bit"
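

# Programmatic use (a minimal sketch; unset fields fall back to the defaults
# above, and "damped_spring" is one of the system ids from the CLI help):
#
#   config = TrainingConfig(num_steps=100, system_ids=("damped_spring",))
#   train(config)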


def train(config: TrainingConfig) -> None:
    """Run a full GRPO training loop with the given configuration."""
    _configure_logging()

    import wandb

    run_name = config.wandb_run_name or f"physix-grpo-{config.num_steps}steps"
    wandb.init(
        project=config.wandb_project,
        name=run_name,
        config=config.model_dump(),
        tags=["grpo", "physix", config.model_name.split("/")[-1]],
        resume="allow",
    )
    if config.hub_checkpoint_repo_id:
        ckpt_url = f"https://huggingface.co/{config.hub_checkpoint_repo_id}"
        wandb.run.summary["checkpoint/repo"] = config.hub_checkpoint_repo_id
        wandb.run.summary["checkpoint/repo_url"] = ckpt_url
    if config.hub_repo_id:
        wandb.run.summary["model/final_repo"] = config.hub_repo_id
        wandb.run.summary["model/final_url"] = (
            f"https://huggingface.co/{config.hub_repo_id}"
        )
    if config.lora_adapter_repo:
        wandb.run.summary["resume/from_adapter"] = config.lora_adapter_repo
        wandb.run.summary["resume/from_url"] = (
            f"https://huggingface.co/{config.lora_adapter_repo}"
        )
        parent_run = os.environ.get("WANDB_RESUMED_FROM")
        if parent_run:
            wandb.run.summary["resume/parent_wandb_run"] = parent_run
            wandb.run.summary["resume/parent_wandb_url"] = (
                f"https://wandb.ai/{wandb.run.entity}/{wandb.run.project}/runs/{parent_run}"
            )
        print(
            f"\n[wandb] WARM-STARTED run — adapter "
            f"https://huggingface.co/{config.lora_adapter_repo}\n",
            flush=True,
        )

    _log.info("Loading model %s with Unsloth (4-bit, LoRA-%d)", config.model_name, config.lora_r)
    model, tokenizer = _load_model_and_tokenizer(config)
    train_dataset = _build_and_format_dataset(config, tokenizer)
    reward_funcs = _select_reward_funcs(config.ablation)
    grpo_config = _build_grpo_config(config)

    callbacks = []
    if config.early_stop_patience > 0:
        callbacks.append(_RewardConvergenceCallback(patience=config.early_stop_patience))
        _log.info(
            "Early stopping enabled: will stop if reward_std < 0.05 for %d consecutive steps",
            config.early_stop_patience,
        )
    if config.hub_checkpoint_repo_id:
        callbacks.append(_WandbCheckpointCallback(config.hub_checkpoint_repo_id))
        _log.info(
            "Checkpoint hub push enabled → %s (every %d steps)",
            config.hub_checkpoint_repo_id,
            grpo_config.save_steps,
        )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=grpo_config,
        train_dataset=train_dataset,
        reward_funcs=reward_funcs,
        callbacks=callbacks or None,
    )
    if config.resume_from_checkpoint:
        _log.info("Resuming from checkpoint: %s", config.resume_from_checkpoint)
    _log.info("Starting GRPO training for %d steps", config.num_steps)
    trainer.train(resume_from_checkpoint=config.resume_from_checkpoint)

    _log_reward_summary(trainer)
    _render_training_curves(trainer, config)

    _log.info("Saving adapter (%s) to %s", config.save_method, config.output_dir)
    _save_artifacts(model, tokenizer, config)
    wandb.finish()


def _log_reward_summary(trainer: "GRPOTrainer") -> None:
    """Log first→last reward delta for every component. Raises if no rewards were logged."""
    history = getattr(trainer.state, "log_history", []) or []
    reward_entries = [
        entry for entry in history
        if any(k.startswith("rewards/") or k == "reward" for k in entry)
    ]
    if not reward_entries:
        _log.error(
            "No reward metrics logged during training. This usually means "
            "every rollout failed to parse. Check `train/reward` in W&B and "
            "the most recent completion samples."
        )
        raise RuntimeError(
            "GRPO produced no reward metrics — training silently failed."
        )
    first = reward_entries[0]
    last = reward_entries[-1]
    _log.info("=" * 60)
    _log.info("GRPO reward summary (first → last logged step):")
    for key in sorted(last):
        if key.startswith("rewards/") or key == "reward":
            v0 = first.get(key)
            v1 = last.get(key)
            if isinstance(v0, (int, float)) and isinstance(v1, (int, float)):
                _log.info(" %-40s %.4f → %.4f (Δ=%+.4f)", key, v0, v1, v1 - v0)
    _log.info("=" * 60)


def _render_training_curves(
    trainer: "GRPOTrainer",
    config: TrainingConfig,
) -> None:
    """Render loss/reward/component PNGs from log_history and push to Hub."""
    try:
        import matplotlib

        matplotlib.use("Agg")  # headless / no display server in HF Jobs
        import matplotlib.pyplot as plt
    except Exception as exc:  # noqa: BLE001
        _log.warning("matplotlib unavailable, skipping curve PNGs: %s", exc)
        return

    history = list(getattr(trainer.state, "log_history", []) or [])
    if not history:
        _log.warning("No log_history found — cannot render curves.")
        return

    plots_dir = Path(config.output_dir) / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    def _series(metric: str) -> tuple[list[int], list[float]]:
        xs: list[int] = []
        ys: list[float] = []
        for entry in history:
            if metric in entry and "step" in entry:
                value = entry[metric]
                if isinstance(value, (int, float)):
                    xs.append(int(entry["step"]))
                    ys.append(float(value))
        return xs, ys
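
    # _series returns aligned (steps, values) lists, e.g. _series("loss") ->
    # ([1, 2, 3], [2.31, 2.17, 1.98]) with illustrative values; entries that
    # lack the metric (or a "step") are skipped, keeping the lists in sync.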

    rendered: list[Path] = []

    steps_l, losses = _series("loss")
    if steps_l:
        fig, ax = plt.subplots(figsize=(8, 4.5))
        ax.plot(steps_l, losses, color="#d62728", linewidth=1.8)
        ax.set_xlabel("training step")
        ax.set_ylabel("GRPO surrogate loss")
        ax.set_title("PhysiX GRPO — train/loss (lower is better)")
        ax.grid(alpha=0.3)
        path = plots_dir / "loss.png"
        fig.tight_layout()
        fig.savefig(path, dpi=140)
        plt.close(fig)
        rendered.append(path)
    else:
        _log.warning("No 'loss' entries in log_history.")

    steps_r, rewards = _series("reward")
    _, reward_std = _series("reward_std")
    if steps_r:
        fig, ax = plt.subplots(figsize=(8, 4.5))
        ax.plot(steps_r, rewards, color="#2ca02c", linewidth=2.0, label="mean reward")
        if reward_std and len(reward_std) == len(rewards):
            import numpy as np

            r = np.asarray(rewards)
            s = np.asarray(reward_std)
            ax.fill_between(steps_r, r - s, r + s, color="#2ca02c", alpha=0.18,
                            label="±1σ across rollouts")
        ax.set_xlabel("training step")
        ax.set_ylabel("mean reward (sum of components)")
        ax.set_title("PhysiX GRPO — train/reward (higher is better)")
        ax.legend(loc="best")
        ax.grid(alpha=0.3)
        path = plots_dir / "reward.png"
        fig.tight_layout()
        fig.savefig(path, dpi=140)
        plt.close(fig)
        rendered.append(path)
    else:
        _log.warning("No 'reward' entries in log_history.")

    component_keys = sorted({
        k for entry in history for k in entry
        if k.startswith("rewards/") and k.endswith("/mean")
    })
    if component_keys:
        fig, ax = plt.subplots(figsize=(8, 4.5))
        for k in component_keys:
            xs, ys = _series(k)
            if xs:
                label = k.removeprefix("rewards/").removesuffix("/mean")
                ax.plot(xs, ys, linewidth=1.6, label=label)
        ax.set_xlabel("training step")
        ax.set_ylabel("component mean reward")
        ax.set_title("PhysiX GRPO — per-component reward (rewards/*/mean)")
        ax.legend(loc="best", fontsize=8)
        ax.grid(alpha=0.3)
        path = plots_dir / "reward_components.png"
        fig.tight_layout()
        fig.savefig(path, dpi=140)
        plt.close(fig)
        rendered.append(path)

    if not rendered:
        _log.warning("No PNGs rendered — log_history had no recognised metrics.")
        return
    _log.info("Rendered %d curve PNG(s) to %s", len(rendered), plots_dir)

    try:
        import wandb

        if wandb.run is not None:
            wandb.log({
                f"plots/{p.stem}": wandb.Image(str(p)) for p in rendered
            })
            _log.info("Logged %d plot(s) to wandb.Media", len(rendered))
    except Exception as exc:  # noqa: BLE001
        _log.warning("Could not log plots to wandb: %s", exc)

    if config.push_to_hub and config.hub_repo_id:
        try:
            from huggingface_hub import HfApi, create_repo

            api = HfApi(token=os.environ.get("HUGGINGFACE_HUB_TOKEN"))
            create_repo(
                repo_id=config.hub_repo_id,
                repo_type="model",
                exist_ok=True,
                token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
            )
            for p in rendered:
                api.upload_file(
                    path_or_fileobj=str(p),
                    path_in_repo=f"plots/{p.name}",
                    repo_id=config.hub_repo_id,
                    repo_type="model",
                    commit_message=f"plots: {p.name}",
                )
            _log.info(
                "Pushed %d plot(s) to https://huggingface.co/%s/tree/main/plots",
                len(rendered),
                config.hub_repo_id,
            )
        except Exception as exc:  # noqa: BLE001
            _log.warning("Could not push plots to Hub: %s", exc)


def _load_model_and_tokenizer(
    config: TrainingConfig,
) -> tuple[FastLanguageModel, AutoTokenizer]:
    """Load the model via Unsloth in 4-bit and attach a LoRA adapter."""
    if config.lora_adapter_repo:
        _log.info(
            "Resuming from existing LoRA adapter %s on top of %s",
            config.lora_adapter_repo,
            config.model_name,
        )
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=config.model_name,
            max_seq_length=config.max_seq_length,
            load_in_4bit=True,
            dtype=None,
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=config.seed,
        )
        model.load_adapter(
            config.lora_adapter_repo,
            adapter_name="default",
            is_trainable=True,
        )
        _log.info("Adapter loaded; LoRA is trainable and ready for GRPO.")
        return model, tokenizer

    if config.sft_checkpoint:
        _log.info(
            "Loading SFT-warmed model from %s (GRPO will refine from here)",
            config.sft_checkpoint,
        )
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=config.sft_checkpoint,
            max_seq_length=config.max_seq_length,
            load_in_4bit=True,
            dtype=None,
        )
    else:
        _log.warning(
            "No --sft-checkpoint supplied. Starting GRPO from the cold base model. "
            "Early reward signal will be near-zero; consider running sft.py first."
        )
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=config.model_name,
            max_seq_length=config.max_seq_length,
            load_in_4bit=True,
            dtype=None,
        )
    model = FastLanguageModel.get_peft_model(
        model,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=config.seed,
    )
    return model, tokenizer


def _build_and_format_dataset(
    config: TrainingConfig,
    tokenizer: AutoTokenizer,
) -> Dataset:
    spec = DatasetSpec(
        system_ids=config.system_ids,
        instances_per_system=config.instances_per_system,
        seed=config.seed,
    )
    dataset = build_training_dataset(spec)
    _log.info(
        "Built training dataset: %d rows across %d systems (%s)",
        len(dataset),
        len(config.system_ids),
        ", ".join(config.system_ids),
    )
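
    # Each row's "prompt" is assumed to hold chat messages in the usual
    # [{"role": ..., "content": ...}] shape; apply_chat_template flattens it
    # into a single prompt string ending with the generation prompt, so GRPO
    # rollouts start exactly where the assistant turn begins.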
    def _apply_chat_template(example: dict[str, object]) -> dict[str, object]:
        formatted = tokenizer.apply_chat_template(
            example["prompt"],
            tokenize=False,
            add_generation_prompt=True,
        )
        return {"prompt": formatted}

    return dataset.map(_apply_chat_template)


def _select_reward_funcs(ablation: Optional[Ablation]) -> list[object]:
    """Return the active reward function list, optionally with one signal ablated."""
    scorer = Scorer()
    funcs = make_reward_funcs(scorer)
    full = [
        funcs["match"],
        funcs["match_dense"],
        funcs["correctness"],
        funcs["simplicity"],
        funcs["format"],
    ]
    if ablation is None:
        return full
    if ablation == "no_simplicity":
        return [funcs["match"], funcs["match_dense"], funcs["correctness"], funcs["format"]]
    if ablation == "no_format":
        return [funcs["match"], funcs["match_dense"], funcs["correctness"], funcs["simplicity"]]
    if ablation == "no_progress":
        return full  # the progress reward was removed; kept as a no-op for backward compat
    raise ValueError(
        f"Unknown ablation {ablation!r}. Choose from "
        "no_progress | no_simplicity | no_format, or pass None for the full set."
    )
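

# For example, _select_reward_funcs("no_format") trains on match, match_dense,
# correctness, and simplicity only; well-formed output is then no longer
# rewarded directly.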


class _RewardConvergenceCallback(TrainerCallback):
    """Stop early when reward_std stays below `min_std` for `patience` consecutive steps.

    Once every rollout in a group earns nearly the same reward, the
    group-relative advantage is ~0 and GRPO updates stall, so further
    steps are wasted compute.
    """

    def __init__(self, patience: int = 50, min_std: float = 0.05) -> None:
        self._patience = patience
        self._min_std = min_std
        self._flat_steps: int = 0

    def on_log(
        self,
        args: HFTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: dict | None = None,
        **kwargs,
    ) -> None:
        if not logs:
            return
        reward_std = logs.get("reward_std")
        if reward_std is None:
            return
        if reward_std < self._min_std:
            self._flat_steps += 1
        else:
            self._flat_steps = 0
        if self._flat_steps >= self._patience:
            step = state.global_step
            msg = (
                f"[early-stop] reward_std < {self._min_std} for "
                f"{self._flat_steps} consecutive steps at step {step}. "
                "Stopping training — policy has converged."
            )
            print(f"\n{msg}\n", flush=True)
            _log.info(msg)
            try:
                import wandb

                if wandb.run is not None:
                    wandb.run.summary["early_stop/step"] = step
                    wandb.run.summary["early_stop/reason"] = (
                        f"reward_std < {self._min_std} for {self._flat_steps} steps"
                    )
                    wandb.log({"early_stop/triggered": 1}, step=step)
            except Exception as exc:  # noqa: BLE001
                _log.debug("Could not log early-stop event to W&B: %s", exc)
            control.should_training_stop = True


class _WandbCheckpointCallback(TrainerCallback):
    """Log checkpoint metadata to W&B summary and stdout after each Trainer save."""

    def __init__(self, hub_checkpoint_repo_id: str) -> None:
        self._repo = hub_checkpoint_repo_id
        self._repo_url = f"https://huggingface.co/{hub_checkpoint_repo_id}"
        self._table = None

    def on_train_begin(
        self,
        args: HFTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> None:
        try:
            import wandb

            if wandb.run is None:
                return
            wandb.run.summary["checkpoint/repo_url"] = self._repo_url
            wandb.run.summary["checkpoint/repo"] = self._repo
            wandb.config.update(
                {"checkpoint_repo_url": self._repo_url, "checkpoint_repo": self._repo},
                allow_val_change=True,
            )
            print(
                f"\n[wandb] Checkpoint repo pinned in run summary: {self._repo_url}\n",
                flush=True,
            )
            self._publish_wandb_run_id(wandb.run.id)
        except Exception as exc:  # noqa: BLE001
            _log.warning("Could not pin checkpoint repo to W&B summary: %s", exc)

    def _publish_wandb_run_id(self, run_id: str) -> None:
        try:
            import tempfile

            from huggingface_hub import HfApi, create_repo

            token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
            api = HfApi(token=token)
            create_repo(self._repo, exist_ok=True, repo_type="model", token=token)
            with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
                tmp.write(run_id)
                tmp_path = tmp.name
            api.upload_file(
                path_or_fileobj=tmp_path,
                path_in_repo="wandb_run_id.txt",
                repo_id=self._repo,
                repo_type="model",
                commit_message=f"Pin W&B run id {run_id}",
                token=token,
            )
            print(f"[wandb] Published run_id={run_id} to {self._repo_url}/wandb_run_id.txt", flush=True)
        except Exception as exc:  # noqa: BLE001
            _log.warning("Could not publish wandb run id (non-fatal): %s", exc)

    def on_save(
        self,
        args: HFTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> None:
        try:
            import wandb

            if wandb.run is None:
                return
            step = state.global_step
            commit_sha = self._latest_commit_sha()
            short = (commit_sha or "pending")[:8]
            tree_url = (
                f"{self._repo_url}/tree/{commit_sha}"
                if commit_sha
                else f"{self._repo_url}/tree/main"
            )
            wandb.run.summary["checkpoint/last_step"] = step
            wandb.run.summary["checkpoint/last_commit"] = commit_sha or "pending"
            wandb.run.summary["checkpoint/last_url"] = tree_url
            wandb.log({"checkpoint/step": step}, step=step)
            if self._table is None:
                self._table = wandb.Table(
                    columns=["step", "commit", "url", "repo"]
                )
            self._table.add_data(step, commit_sha or "pending", tree_url, self._repo)
            wandb.log({"checkpoint_history": self._table}, step=step)
            if commit_sha:
                from physix.training.checkpoints import (
                    CheckpointHandle,
                    log_link_artifact_to_wandb,
                )

                handle = CheckpointHandle(
                    repo_id=self._repo,
                    subfolder=f"checkpoint-{step}",
                    revision=commit_sha,
                    step=step,
                )
                log_link_artifact_to_wandb(
                    handle,
                    artifact_name="physix-grpo-checkpoint",
                )
            print(
                "\n"
                "================ CHECKPOINT SAVED ================\n"
                f"  step  : {step}\n"
                f"  commit: {short}\n"
                f"  url   : {tree_url}\n"
                f"  repo  : {self._repo_url}\n"
                "==================================================\n",
                flush=True,
            )
            _log.info(
                "W&B checkpoint metadata logged: step=%d commit=%s",
                step,
                short,
            )
        except Exception as exc:  # noqa: BLE001
            _log.warning(
                "W&B checkpoint callback skipped at step %d: %s. "
                "Training continues; the actual checkpoint is still pushed "
                "to the HF Hub by the trainer's PushToHubCallback.",
                state.global_step,
                exc,
            )

    def _latest_commit_sha(self) -> Optional[str]:
        """Best-effort fetch of the latest commit SHA; returns None on failure."""
        try:
            from huggingface_hub import HfApi

            api = HfApi(token=os.environ.get("HUGGINGFACE_HUB_TOKEN"))
            commits = api.list_repo_commits(repo_id=self._repo, repo_type="model")
            if commits:
                return commits[0].commit_id
        except Exception as exc:  # noqa: BLE001
            _log.debug("Could not fetch latest commit sha: %s", exc)
        return None


def _build_grpo_config(config: TrainingConfig) -> GRPOConfig:
    effective_batch = (
        config.per_device_train_batch_size * config.gradient_accumulation_steps
    )
    if effective_batch % config.num_generations != 0:
        raise ValueError(
            f"effective_batch_size ({effective_batch}) must be divisible by "
            f"num_generations ({config.num_generations}). Adjust "
            "per_device_train_batch_size, gradient_accumulation_steps, or "
            "num_generations."
        )
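
    # With the defaults on a single GPU: 1 x 8 = 8 completions per optimizer
    # step, which num_generations=4 splits into 2 prompts' rollout groups.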

    hub_kwargs: dict = {}
    if config.hub_checkpoint_repo_id:
        hub_kwargs = dict(
            push_to_hub=True,
            hub_model_id=config.hub_checkpoint_repo_id,
            hub_strategy="checkpoint",
            hub_token=os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN"),
        )
    return GRPOConfig(
        output_dir=config.output_dir,
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        num_train_epochs=1,
        max_steps=config.num_steps,
        num_generations=config.num_generations,
        max_completion_length=config.max_completion_length,
        max_prompt_length=config.max_seq_length - config.max_completion_length,
        temperature=config.temperature,
        beta=config.beta,
        logging_steps=1,
        save_strategy="steps",
        save_steps=max(50, config.num_steps // 6),
        report_to=["wandb"],
        run_name=config.wandb_run_name,
        seed=config.seed,
        bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
        fp16=not torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
        **hub_kwargs,
    )


def _save_artifacts(
    model: FastLanguageModel,
    tokenizer: AutoTokenizer,
    config: TrainingConfig,
) -> None:
    """Save the model locally and optionally push it to the Hub."""
    out_path = Path(config.output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    save_dir = out_path / config.save_method
    model.save_pretrained_merged(
        save_directory=str(save_dir),
        tokenizer=tokenizer,
        save_method=config.save_method,
    )
    if config.push_to_hub and config.hub_repo_id:
        _log.info("Pushing %s artifact to Hugging Face Hub: %s", config.save_method, config.hub_repo_id)
        model.push_to_hub_merged(
            config.hub_repo_id,
            tokenizer,
            save_method=config.save_method,
            token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
        )


def _configure_logging() -> None:
    logging.basicConfig(
        level=os.environ.get("PHYSIX_LOG_LEVEL", "INFO"),
        format="[%(asctime)s] %(levelname)s %(name)s | %(message)s",
    )


def _parse_args() -> TrainingConfig:
    parser = argparse.ArgumentParser(description="Train PhysiX-Live with GRPO.")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--output-dir", default="runs/physix-1.5b-rl")
    parser.add_argument("--num-steps", type=int, default=300)
    parser.add_argument("--learning-rate", type=float, default=5.0e-6)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--max-completion-length", type=int, default=256,
                        help="Max tokens per rollout completion. Shorter = faster generation.")
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--instances-per-system", type=int, default=32)
    parser.add_argument(
        "--system-ids",
        default=None,
        help=(
            "Comma-separated list of system IDs to train on "
            "(e.g. 'damped_spring' or 'free_fall,simple_pendulum'). "
            "Defaults to all SUPPORTED_SYSTEMS when omitted."
        ),
    )
    parser.add_argument(
        "--ablation",
        choices=("no_progress", "no_simplicity", "no_format"),
        default=None,
    )
    parser.add_argument(
        "--save-method",
        choices=("lora", "merged_16bit", "merged_4bit"),
        default="merged_16bit",
        help="How to persist the final adapter (merged_16bit is deployable).",
    )
    parser.add_argument("--sft-checkpoint", default=None,
                        help="Path to a merged SFT model from sft.py to warm-start from.")
    parser.add_argument(
        "--lora-adapter-repo",
        default=None,
        help=(
            "Hub repo id (or local path) of an existing LoRA adapter to warm-start "
            "GRPO from — e.g. a previous run's checkpoint at "
            "user/physix-1.5b-rl-ckpt. Mutually exclusive with --sft-checkpoint."
        ),
    )
    parser.add_argument("--wandb-project", default="physix-live")
    parser.add_argument("--wandb-run-name", default=None)
    parser.add_argument("--push-to-hub", action="store_true")
    parser.add_argument("--hub-repo-id", default=None)
    parser.add_argument(
        "--hub-checkpoint-repo-id",
        default=None,
        help="HF repo to push LoRA checkpoints to every save_steps (e.g. user/physix-ckpt).",
    )
    parser.add_argument(
        "--resume-from-checkpoint",
        default=None,
        help="Path to a Trainer checkpoint directory to resume GRPO from.",
    )
    parser.add_argument(
        "--early-stop-patience",
        type=int,
        default=50,
        help=(
            "Stop training early if reward_std stays below 0.05 for this many "
            "consecutive steps (policy converged, GRPO advantage ≈ 0). "
            "Set to 0 to disable."
        ),
    )
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    if args.sft_checkpoint and args.lora_adapter_repo:
        parser.error(
            "--sft-checkpoint and --lora-adapter-repo are mutually exclusive. "
            "Use --lora-adapter-repo to resume from a prior GRPO run, or "
            "--sft-checkpoint for a fresh GRPO run from a merged SFT model."
        )
    system_ids = (
        tuple(s.strip() for s in args.system_ids.split(",") if s.strip())
        if args.system_ids
        else SUPPORTED_SYSTEMS
    )
    return TrainingConfig(
        model_name=args.model,
        sft_checkpoint=args.sft_checkpoint,
        lora_adapter_repo=args.lora_adapter_repo,
        output_dir=args.output_dir,
        num_steps=args.num_steps,
        learning_rate=args.learning_rate,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        lora_r=args.lora_r,
        instances_per_system=args.instances_per_system,
        system_ids=system_ids,
        ablation=args.ablation,
        save_method=args.save_method,
        wandb_project=args.wandb_project,
        wandb_run_name=args.wandb_run_name,
        push_to_hub=args.push_to_hub,
        hub_repo_id=args.hub_repo_id,
        hub_checkpoint_repo_id=args.hub_checkpoint_repo_id,
        resume_from_checkpoint=args.resume_from_checkpoint,
        early_stop_patience=args.early_stop_patience,
        seed=args.seed,
    )


def main() -> None:
    config = _parse_args()
    os.environ.setdefault("WANDB_PROJECT", config.wandb_project)
    train(config)


if __name__ == "__main__":
    main()