"""GRPO training loop using Unsloth + TRL.
Requires the ``[train]`` optional dependency group (heavy ML deps).
"""
from __future__ import annotations
import argparse
import logging
import os
from pathlib import Path
from typing import Literal, Optional
import torch
from datasets import Dataset
from pydantic import BaseModel, ConfigDict
from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState
from transformers import TrainingArguments as HFTrainingArguments
from physix.systems import SUPPORTED_SYSTEMS
from physix.training.dataset import (
DatasetSpec,
build_training_dataset,
)
from physix.training.reward_fns import make_reward_funcs
from physix.training.scorer import Scorer
# Unsloth patches must be applied before importing GRPOTrainer — order matters.
# Requires trl<=0.24.0; newer versions break PatchFastRL.
from unsloth import FastLanguageModel, PatchFastRL # noqa: E402
PatchFastRL("GRPO", FastLanguageModel)
from trl import GRPOConfig, GRPOTrainer # noqa: E402 (must come after PatchFastRL)
_log = logging.getLogger(__name__)
Ablation = Literal["no_progress", "no_simplicity", "no_format"]
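#: "lora" saves only the adapter weights; the merged_* methods bake the LoRA
#: deltas into the base model at the given precision (Unsloth save semantics).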
SaveMethod = Literal["lora", "merged_16bit", "merged_4bit"]
class TrainingConfig(BaseModel):
"""All hyperparameters in one place; the CLI populates this."""
model_config = ConfigDict(frozen=True)
model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
#: Path to merged SFT model to warm-start GRPO from.
sft_checkpoint: Optional[str] = None
#: Hub repo id or local path of an existing LoRA adapter to resume from.
lora_adapter_repo: Optional[str] = None
output_dir: str = "runs/physix-1.5b-rl"
max_seq_length: int = 2048
lora_r: int = 16
lora_alpha: int = 32
learning_rate: float = 5.0e-6
temperature: float = 0.9
max_completion_length: int = 256
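    #: KL penalty coefficient against the reference policy (TRL ``GRPOConfig.beta``).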
beta: float = 0.04
num_generations: int = 4
per_device_train_batch_size: int = 1
gradient_accumulation_steps: int = 8
num_steps: int = 300
#: Set to 0 to disable early stopping.
early_stop_patience: int = 50
seed: int = 0
instances_per_system: int = 32
system_ids: tuple[str, ...] = SUPPORTED_SYSTEMS
ablation: Optional[Ablation] = None
wandb_project: str = "physix-live"
wandb_run_name: Optional[str] = None
push_to_hub: bool = False
hub_repo_id: Optional[str] = None
#: HF repo to push LoRA checkpoints to every save_steps.
hub_checkpoint_repo_id: Optional[str] = None
resume_from_checkpoint: Optional[str] = None
save_method: SaveMethod = "merged_16bit"
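# Programmatic use, bypassing the CLI (a minimal sketch; the step count and
# system id are illustrative, not recommended settings):
#
#   config = TrainingConfig(num_steps=50, system_ids=("free_fall",))
#   train(config)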
def train(config: TrainingConfig) -> None:
"""Run a full GRPO training loop with the given configuration."""
_configure_logging()
import wandb
run_name = config.wandb_run_name or f"physix-grpo-{config.num_steps}steps"
wandb.init(
project=config.wandb_project,
name=run_name,
config=config.model_dump(),
tags=["grpo", "physix", config.model_name.split("/")[-1]],
resume="allow",
)
if config.hub_checkpoint_repo_id:
ckpt_url = f"https://huggingface.co/{config.hub_checkpoint_repo_id}"
wandb.run.summary["checkpoint/repo"] = config.hub_checkpoint_repo_id
wandb.run.summary["checkpoint/repo_url"] = ckpt_url
if config.hub_repo_id:
wandb.run.summary["model/final_repo"] = config.hub_repo_id
wandb.run.summary["model/final_url"] = (
f"https://huggingface.co/{config.hub_repo_id}"
)
if config.lora_adapter_repo:
wandb.run.summary["resume/from_adapter"] = config.lora_adapter_repo
wandb.run.summary["resume/from_url"] = (
f"https://huggingface.co/{config.lora_adapter_repo}"
)
parent_run = os.environ.get("WANDB_RESUMED_FROM")
if parent_run:
wandb.run.summary["resume/parent_wandb_run"] = parent_run
wandb.run.summary["resume/parent_wandb_url"] = (
f"https://wandb.ai/{wandb.run.entity}/{wandb.run.project}/runs/{parent_run}"
)
print(
f"\n[wandb] WARM-STARTED run — adapter "
f"https://huggingface.co/{config.lora_adapter_repo}\n",
flush=True,
)
_log.info("Loading model %s with Unsloth (4-bit, LoRA-%d)", config.model_name, config.lora_r)
model, tokenizer = _load_model_and_tokenizer(config)
train_dataset = _build_and_format_dataset(config, tokenizer)
reward_funcs = _select_reward_funcs(config.ablation)
grpo_config = _build_grpo_config(config)
callbacks = []
if config.early_stop_patience > 0:
callbacks.append(_RewardConvergenceCallback(patience=config.early_stop_patience))
_log.info(
"Early stopping enabled: will stop if reward_std < 0.05 for %d consecutive steps",
config.early_stop_patience,
)
if config.hub_checkpoint_repo_id:
callbacks.append(_WandbCheckpointCallback(config.hub_checkpoint_repo_id))
_log.info(
"Checkpoint hub push enabled → %s (every %d steps)",
config.hub_checkpoint_repo_id,
grpo_config.save_steps,
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
args=grpo_config,
train_dataset=train_dataset,
reward_funcs=reward_funcs,
callbacks=callbacks or None,
)
if config.resume_from_checkpoint:
_log.info("Resuming from checkpoint: %s", config.resume_from_checkpoint)
_log.info("Starting GRPO training for %d steps", config.num_steps)
trainer.train(resume_from_checkpoint=config.resume_from_checkpoint)
_log_reward_summary(trainer)
_render_training_curves(trainer, config)
_log.info("Saving adapter (%s) to %s", config.save_method, config.output_dir)
_save_artifacts(model, tokenizer, config)
wandb.finish()
def _log_reward_summary(trainer: "GRPOTrainer") -> None:
"""Log first→last reward delta for every component. Raises if no rewards were logged."""
history = getattr(trainer.state, "log_history", []) or []
reward_entries = [
entry for entry in history
if any(k.startswith("rewards/") or k == "reward" for k in entry)
]
if not reward_entries:
_log.error(
"No reward metrics logged during training. This usually means "
"every rollout failed to parse. Check `train/reward` in W&B and "
"the most recent completion samples."
)
raise RuntimeError(
"GRPO produced no reward metrics — training silently failed."
)
last = reward_entries[-1]
first = reward_entries[0]
_log.info("=" * 60)
_log.info("GRPO reward summary (first → last logged step):")
for key in sorted(last):
if key.startswith("rewards/") or key == "reward":
v0 = first.get(key)
v1 = last.get(key)
if isinstance(v0, (int, float)) and isinstance(v1, (int, float)):
_log.info(" %-40s %.4f → %.4f (Δ=%+.4f)", key, v0, v1, v1 - v0)
_log.info("=" * 60)
def _render_training_curves(
trainer: "GRPOTrainer",
config: TrainingConfig,
) -> None:
"""Render loss/reward/component PNGs from log_history and push to Hub."""
try:
import matplotlib
matplotlib.use("Agg") # headless / no display server in HF Jobs
import matplotlib.pyplot as plt
except Exception as exc: # noqa: BLE001
_log.warning("matplotlib unavailable, skipping curve PNGs: %s", exc)
return
history = list(getattr(trainer.state, "log_history", []) or [])
if not history:
_log.warning("No log_history found — cannot render curves.")
return
plots_dir = Path(config.output_dir) / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)
def _series(metric: str) -> tuple[list[int], list[float]]:
xs: list[int] = []
ys: list[float] = []
for entry in history:
if metric in entry and "step" in entry:
value = entry[metric]
if isinstance(value, (int, float)):
xs.append(int(entry["step"]))
ys.append(float(value))
return xs, ys
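    # log_history entries are plain dicts, e.g. {"step": 10, "loss": 0.42,
    # "reward": 1.31} (values illustrative); _series skips entries that lack
    # the metric or a step.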
rendered: list[Path] = []
steps_l, losses = _series("loss")
if steps_l:
fig, ax = plt.subplots(figsize=(8, 4.5))
ax.plot(steps_l, losses, color="#d62728", linewidth=1.8)
ax.set_xlabel("training step")
ax.set_ylabel("GRPO surrogate loss")
ax.set_title("PhysiX GRPO — train/loss (lower is better)")
ax.grid(alpha=0.3)
path = plots_dir / "loss.png"
fig.tight_layout()
fig.savefig(path, dpi=140)
plt.close(fig)
rendered.append(path)
else:
_log.warning("No 'loss' entries in log_history.")
steps_r, rewards = _series("reward")
_, reward_std = _series("reward_std")
if steps_r:
fig, ax = plt.subplots(figsize=(8, 4.5))
ax.plot(steps_r, rewards, color="#2ca02c", linewidth=2.0, label="mean reward")
if reward_std and len(reward_std) == len(rewards):
import numpy as np
r = np.asarray(rewards)
s = np.asarray(reward_std)
ax.fill_between(steps_r, r - s, r + s, color="#2ca02c", alpha=0.18,
label="±1σ across rollouts")
ax.set_xlabel("training step")
ax.set_ylabel("mean reward (sum of components)")
ax.set_title("PhysiX GRPO — train/reward (higher is better)")
ax.legend(loc="best")
ax.grid(alpha=0.3)
path = plots_dir / "reward.png"
fig.tight_layout()
fig.savefig(path, dpi=140)
plt.close(fig)
rendered.append(path)
else:
_log.warning("No 'reward' entries in log_history.")
component_keys = sorted({
k for entry in history for k in entry
if k.startswith("rewards/") and k.endswith("/mean")
})
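    # One curve per reward function: the comprehension above assumes TRL's
    # "rewards/<func_name>/mean" logging format for multi-reward runs.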
if component_keys:
fig, ax = plt.subplots(figsize=(8, 4.5))
for k in component_keys:
xs, ys = _series(k)
if xs:
label = k.removeprefix("rewards/").removesuffix("/mean")
ax.plot(xs, ys, linewidth=1.6, label=label)
ax.set_xlabel("training step")
ax.set_ylabel("component mean reward")
ax.set_title("PhysiX GRPO — per-component reward (rewards/*/mean)")
ax.legend(loc="best", fontsize=8)
ax.grid(alpha=0.3)
path = plots_dir / "reward_components.png"
fig.tight_layout()
fig.savefig(path, dpi=140)
plt.close(fig)
rendered.append(path)
if not rendered:
_log.warning("No PNGs rendered — log_history had no recognised metrics.")
return
_log.info("Rendered %d curve PNG(s) to %s", len(rendered), plots_dir)
try:
import wandb
if wandb.run is not None:
wandb.log({
f"plots/{p.stem}": wandb.Image(str(p)) for p in rendered
})
_log.info("Logged %d plot(s) to wandb.Media", len(rendered))
except Exception as exc: # noqa: BLE001
_log.warning("Could not log plots to wandb: %s", exc)
if config.push_to_hub and config.hub_repo_id:
try:
from huggingface_hub import HfApi, create_repo
api = HfApi(token=os.environ.get("HUGGINGFACE_HUB_TOKEN"))
create_repo(
repo_id=config.hub_repo_id,
repo_type="model",
exist_ok=True,
token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
)
for p in rendered:
api.upload_file(
path_or_fileobj=str(p),
path_in_repo=f"plots/{p.name}",
repo_id=config.hub_repo_id,
repo_type="model",
commit_message=f"plots: {p.name}",
)
_log.info(
"Pushed %d plot(s) to https://huggingface.co/%s/tree/main/plots",
len(rendered),
config.hub_repo_id,
)
except Exception as exc: # noqa: BLE001
_log.warning("Could not push plots to Hub: %s", exc)
def _load_model_and_tokenizer(
config: TrainingConfig,
) -> tuple[FastLanguageModel, AutoTokenizer]:
"""Load model via Unsloth in 4-bit and attach a LoRA adapter."""
if config.lora_adapter_repo:
_log.info(
"Resuming from existing LoRA adapter %s on top of %s",
config.lora_adapter_repo,
config.model_name,
)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=config.model_name,
max_seq_length=config.max_seq_length,
load_in_4bit=True,
dtype=None,
)
model = FastLanguageModel.get_peft_model(
model,
r=config.lora_r,
lora_alpha=config.lora_alpha,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
bias="none",
use_gradient_checkpointing="unsloth",
random_state=config.seed,
)
model.load_adapter(
config.lora_adapter_repo,
adapter_name="default",
is_trainable=True,
)
_log.info("Adapter loaded; LoRA is trainable and ready for GRPO.")
return model, tokenizer
if config.sft_checkpoint:
_log.info(
"Loading SFT-warmed model from %s (GRPO will refine from here)",
config.sft_checkpoint,
)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=config.sft_checkpoint,
max_seq_length=config.max_seq_length,
load_in_4bit=True,
dtype=None,
)
else:
_log.warning(
"No --sft-checkpoint supplied. Starting GRPO from cold base model. "
"Early reward signal will be near-zero; consider running sft.py first."
)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=config.model_name,
max_seq_length=config.max_seq_length,
load_in_4bit=True,
dtype=None,
)
model = FastLanguageModel.get_peft_model(
model,
r=config.lora_r,
lora_alpha=config.lora_alpha,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
bias="none",
use_gradient_checkpointing="unsloth",
random_state=config.seed,
)
return model, tokenizer
def _build_and_format_dataset(
config: TrainingConfig,
tokenizer: AutoTokenizer,
) -> Dataset:
spec = DatasetSpec(
system_ids=config.system_ids,
instances_per_system=config.instances_per_system,
seed=config.seed,
)
dataset = build_training_dataset(spec)
_log.info(
"Built training dataset: %d rows across %d systems (%s)",
len(dataset),
len(config.system_ids),
", ".join(config.system_ids),
)
def _apply_chat_template(example: dict[str, object]) -> dict[str, object]:
formatted = tokenizer.apply_chat_template(
example["prompt"],
tokenize=False,
add_generation_prompt=True,
)
return {"prompt": formatted}
return dataset.map(_apply_chat_template)
def _select_reward_funcs(ablation: Optional[Ablation]) -> list[object]:
"""Return the active reward function list, optionally with one signal ablated."""
scorer = Scorer()
funcs = make_reward_funcs(scorer)
full = [
funcs["match"],
funcs["match_dense"],
funcs["correctness"],
funcs["simplicity"],
funcs["format"],
]
if ablation is None:
return full
if ablation == "no_simplicity":
return [funcs["match"], funcs["match_dense"], funcs["correctness"], funcs["format"]]
if ablation == "no_format":
return [funcs["match"], funcs["match_dense"], funcs["correctness"], funcs["simplicity"]]
if ablation == "no_progress":
return full # progress was removed; treat as full set for backward compat
raise ValueError(
f"Unknown ablation {ablation!r}. Choose from "
"no_progress | no_simplicity | no_format | None."
)
class _RewardConvergenceCallback(TrainerCallback):
"""Stop early when reward_std stays below min_std for `patience` consecutive steps."""
def __init__(self, patience: int = 50, min_std: float = 0.05) -> None:
self._patience = patience
self._min_std = min_std
self._flat_steps: int = 0
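    # _build_grpo_config sets logging_steps=1, so each on_log call corresponds
    # to one optimizer step and `patience` counts steps directly.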
def on_log(
self,
args: HFTrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: dict | None = None,
**kwargs,
) -> None:
if not logs:
return
reward_std = logs.get("reward_std")
if reward_std is None:
return
if reward_std < self._min_std:
self._flat_steps += 1
else:
self._flat_steps = 0
if self._flat_steps >= self._patience:
step = state.global_step
msg = (
f"[early-stop] reward_std < {self._min_std} for "
f"{self._flat_steps} consecutive steps at step {step}. "
"Stopping training — policy has converged."
)
print(f"\n{msg}\n", flush=True)
_log.info(msg)
try:
import wandb
if wandb.run is not None:
wandb.run.summary["early_stop/step"] = step
wandb.run.summary["early_stop/reason"] = (
f"reward_std < {self._min_std} for {self._flat_steps} steps"
)
wandb.log({"early_stop/triggered": 1}, step=step)
except Exception as exc: # noqa: BLE001
_log.debug("Could not log early-stop event to W&B: %s", exc)
control.should_training_stop = True
class _WandbCheckpointCallback(TrainerCallback):
"""Log checkpoint metadata to W&B summary and stdout after each Trainer save."""
def __init__(self, hub_checkpoint_repo_id: str) -> None:
self._repo = hub_checkpoint_repo_id
self._repo_url = f"https://huggingface.co/{hub_checkpoint_repo_id}"
self._table = None
def on_train_begin(
self,
args: HFTrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
try:
import wandb
if wandb.run is None:
return
wandb.run.summary["checkpoint/repo_url"] = self._repo_url
wandb.run.summary["checkpoint/repo"] = self._repo
wandb.config.update(
{"checkpoint_repo_url": self._repo_url, "checkpoint_repo": self._repo},
allow_val_change=True,
)
print(
f"\n[wandb] Checkpoint repo pinned in run summary: {self._repo_url}\n",
flush=True,
)
self._publish_wandb_run_id(wandb.run.id)
except Exception as exc: # noqa: BLE001
_log.warning("Could not pin checkpoint repo to W&B summary: %s", exc)
def _publish_wandb_run_id(self, run_id: str) -> None:
try:
import tempfile
from huggingface_hub import HfApi, create_repo
token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
api = HfApi(token=token)
create_repo(self._repo, exist_ok=True, repo_type="model", token=token)
            with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
                tmp.write(run_id)
                tmp_path = tmp.name
            try:
                api.upload_file(
                    path_or_fileobj=tmp_path,
                    path_in_repo="wandb_run_id.txt",
                    repo_id=self._repo,
                    repo_type="model",
                    commit_message=f"Pin W&B run id {run_id}",
                    token=token,
                )
            finally:
                os.unlink(tmp_path)  # delete=False leaves the file behind; clean it up
print(f"[wandb] Published run_id={run_id} to {self._repo_url}/wandb_run_id.txt", flush=True)
except Exception as exc: # noqa: BLE001
_log.warning("Could not publish wandb run id (non-fatal): %s", exc)
def on_save(
self,
args: HFTrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
try:
import wandb
if wandb.run is None:
return
step = state.global_step
commit_sha = self._latest_commit_sha()
short = (commit_sha or "pending")[:8]
tree_url = (
f"{self._repo_url}/tree/{commit_sha}"
if commit_sha
else f"{self._repo_url}/tree/main"
)
wandb.run.summary["checkpoint/last_step"] = step
wandb.run.summary["checkpoint/last_commit"] = commit_sha or "pending"
wandb.run.summary["checkpoint/last_url"] = tree_url
wandb.log({"checkpoint/step": step}, step=step)
if self._table is None:
self._table = wandb.Table(
columns=["step", "commit", "url", "repo"]
)
self._table.add_data(step, commit_sha or "pending", tree_url, self._repo)
wandb.log({"checkpoint_history": self._table}, step=step)
if commit_sha:
from physix.training.checkpoints import (
CheckpointHandle,
log_link_artifact_to_wandb,
)
handle = CheckpointHandle(
repo_id=self._repo,
subfolder=f"checkpoint-{step}",
revision=commit_sha,
step=step,
)
log_link_artifact_to_wandb(
handle,
artifact_name="physix-grpo-checkpoint",
)
print(
"\n"
"================ CHECKPOINT SAVED ================\n"
f" step : {step}\n"
f" commit: {short}\n"
f" url : {tree_url}\n"
f" repo : {self._repo_url}\n"
"==================================================\n",
flush=True,
)
_log.info(
"W&B checkpoint metadata logged: step=%d commit=%s",
step,
short,
)
except Exception as exc: # noqa: BLE001
_log.warning(
"W&B checkpoint callback skipped at step %d: %s. "
"Training continues; the actual checkpoint is still pushed "
"to the HF Hub by the trainer's PushToHubCallback.",
state.global_step,
exc,
)
def _latest_commit_sha(self) -> Optional[str]:
"""Best-effort fetch of the latest commit SHA; returns None on failure."""
try:
from huggingface_hub import HfApi
api = HfApi(token=os.environ.get("HUGGINGFACE_HUB_TOKEN"))
commits = api.list_repo_commits(repo_id=self._repo, repo_type="model")
if commits:
return commits[0].commit_id
except Exception as exc: # noqa: BLE001
_log.debug("Could not fetch latest commit sha: %s", exc)
return None
def _build_grpo_config(config: TrainingConfig) -> GRPOConfig:
effective_batch = (
config.per_device_train_batch_size * config.gradient_accumulation_steps
)
if effective_batch % config.num_generations != 0:
raise ValueError(
f"effective_batch_size ({effective_batch}) must be divisible by "
f"num_generations ({config.num_generations}). Adjust "
"per_device_train_batch_size, gradient_accumulation_steps, or "
"num_generations."
)
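    # Worked example with the defaults: 1 * 8 = 8 completions per optimizer
    # step, i.e. two GRPO groups of num_generations=4 rollouts per prompt.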
hub_kwargs: dict = {}
if config.hub_checkpoint_repo_id:
hub_kwargs = dict(
push_to_hub=True,
hub_model_id=config.hub_checkpoint_repo_id,
hub_strategy="checkpoint",
hub_token=os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN"),
)
return GRPOConfig(
output_dir=config.output_dir,
learning_rate=config.learning_rate,
per_device_train_batch_size=config.per_device_train_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
num_train_epochs=1,
max_steps=config.num_steps,
num_generations=config.num_generations,
max_completion_length=config.max_completion_length,
max_prompt_length=config.max_seq_length - config.max_completion_length,
temperature=config.temperature,
beta=config.beta,
logging_steps=1,
save_strategy="steps",
save_steps=max(50, config.num_steps // 6),
report_to=["wandb"],
run_name=config.wandb_run_name,
seed=config.seed,
bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
fp16=not torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
**hub_kwargs,
)
def _save_artifacts(
model: FastLanguageModel,
tokenizer: AutoTokenizer,
config: TrainingConfig,
) -> None:
"""Save model locally and optionally push to Hub."""
out_path = Path(config.output_dir)
out_path.mkdir(parents=True, exist_ok=True)
save_dir = out_path / config.save_method
model.save_pretrained_merged(
save_directory=str(save_dir),
tokenizer=tokenizer,
save_method=config.save_method,
)
if config.push_to_hub and config.hub_repo_id:
_log.info("Pushing %s artifact to Hugging Face Hub: %s", config.save_method, config.hub_repo_id)
model.push_to_hub_merged(
config.hub_repo_id,
tokenizer,
save_method=config.save_method,
token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
)
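# Reloading the merged artifact later (a sketch; the path assumes the default
# output_dir and save_method):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("runs/physix-1.5b-rl/merged_16bit")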
def _configure_logging() -> None:
logging.basicConfig(
level=os.environ.get("PHYSIX_LOG_LEVEL", "INFO"),
format="[%(asctime)s] %(levelname)s %(name)s | %(message)s",
)
def _parse_args() -> TrainingConfig:
parser = argparse.ArgumentParser(description="Train PhysiX-Live with GRPO.")
parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct")
parser.add_argument("--output-dir", default="runs/physix-1.5b-rl")
parser.add_argument("--num-steps", type=int, default=300)
parser.add_argument("--learning-rate", type=float, default=5.0e-6)
parser.add_argument("--num-generations", type=int, default=4)
parser.add_argument("--max-completion-length", type=int, default=256,
help="Max tokens per rollout completion. Shorter = faster generation.")
parser.add_argument("--lora-r", type=int, default=16)
parser.add_argument("--instances-per-system", type=int, default=32)
parser.add_argument(
"--system-ids",
default=None,
help=(
"Comma-separated list of system IDs to train on "
"(e.g. 'damped_spring' or 'free_fall,simple_pendulum'). "
"Defaults to all SUPPORTED_SYSTEMS when omitted."
),
)
parser.add_argument(
"--ablation",
choices=("no_progress", "no_simplicity", "no_format"),
default=None,
)
parser.add_argument(
"--save-method",
choices=("lora", "merged_16bit", "merged_4bit"),
default="merged_16bit",
help="How to persist the final adapter (merged_16bit is deployable).",
)
parser.add_argument("--sft-checkpoint", default=None,
help="Path to a merged SFT model from sft.py to warm-start from.")
parser.add_argument(
"--lora-adapter-repo",
default=None,
help=(
"Hub repo id (or local path) of an existing LoRA adapter to warm-start "
"GRPO from — e.g. a previous run's checkpoint at "
"user/physix-1.5b-rl-ckpt. Mutually exclusive with --sft-checkpoint."
),
)
parser.add_argument("--wandb-project", default="physix-live")
parser.add_argument("--wandb-run-name", default=None)
parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--hub-repo-id", default=None)
parser.add_argument(
"--hub-checkpoint-repo-id",
default=None,
help="HF repo to push LoRA checkpoints to every save_steps (e.g. user/physix-ckpt).",
)
parser.add_argument(
"--resume-from-checkpoint",
default=None,
help="Path to a Trainer checkpoint directory to resume GRPO from.",
)
parser.add_argument(
"--early-stop-patience",
type=int,
default=50,
help=(
"Stop training early if reward_std stays below 0.05 for this many "
"consecutive steps (policy converged, GRPO advantage ≈ 0). "
"Set to 0 to disable."
),
)
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()
if args.sft_checkpoint and args.lora_adapter_repo:
parser.error(
"--sft-checkpoint and --lora-adapter-repo are mutually exclusive. "
"Use --lora-adapter-repo to resume from a prior GRPO run, or "
"--sft-checkpoint for a fresh GRPO from a merged SFT model."
)
system_ids = (
tuple(s.strip() for s in args.system_ids.split(",") if s.strip())
if args.system_ids
else SUPPORTED_SYSTEMS
)
return TrainingConfig(
model_name=args.model,
sft_checkpoint=args.sft_checkpoint,
lora_adapter_repo=args.lora_adapter_repo,
output_dir=args.output_dir,
num_steps=args.num_steps,
learning_rate=args.learning_rate,
num_generations=args.num_generations,
max_completion_length=args.max_completion_length,
lora_r=args.lora_r,
instances_per_system=args.instances_per_system,
system_ids=system_ids,
ablation=args.ablation,
save_method=args.save_method,
wandb_project=args.wandb_project,
wandb_run_name=args.wandb_run_name,
push_to_hub=args.push_to_hub,
hub_repo_id=args.hub_repo_id,
hub_checkpoint_repo_id=args.hub_checkpoint_repo_id,
resume_from_checkpoint=args.resume_from_checkpoint,
early_stop_patience=args.early_stop_patience,
seed=args.seed,
)
def main() -> None:
config = _parse_args()
os.environ.setdefault("WANDB_PROJECT", config.wandb_project)
train(config)
if __name__ == "__main__":
main()