"""GRPO training loop using Unsloth + TRL. Requires the ``[train]`` optional dependency group (heavy ML deps). """ from __future__ import annotations import argparse import logging import os from pathlib import Path from typing import Literal, Optional import torch from datasets import Dataset from pydantic import BaseModel, ConfigDict from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState from transformers import TrainingArguments as HFTrainingArguments from physix.systems import SUPPORTED_SYSTEMS from physix.training.dataset import ( DatasetSpec, build_training_dataset, ) from physix.training.reward_fns import make_reward_funcs from physix.training.scorer import Scorer # Unsloth patches must be applied before importing GRPOTrainer — order matters. # Requires trl<=0.24.0; newer versions break PatchFastRL. from unsloth import FastLanguageModel, PatchFastRL # noqa: E402 PatchFastRL("GRPO", FastLanguageModel) from trl import GRPOConfig, GRPOTrainer # noqa: E402 (must come after PatchFastRL) _log = logging.getLogger(__name__) Ablation = Literal["no_progress", "no_simplicity", "no_format"] SaveMethod = Literal["lora", "merged_16bit", "merged_4bit"] class TrainingConfig(BaseModel): """All hyperparameters in one place; the CLI populates this.""" model_config = ConfigDict(frozen=True) model_name: str = "Qwen/Qwen2.5-1.5B-Instruct" #: Path to merged SFT model to warm-start GRPO from. sft_checkpoint: Optional[str] = None #: Hub repo id or local path of an existing LoRA adapter to resume from. lora_adapter_repo: Optional[str] = None output_dir: str = "runs/physix-1.5b-rl" max_seq_length: int = 2048 lora_r: int = 16 lora_alpha: int = 32 learning_rate: float = 5.0e-6 temperature: float = 0.9 max_completion_length: int = 256 beta: float = 0.04 num_generations: int = 4 per_device_train_batch_size: int = 1 gradient_accumulation_steps: int = 8 num_steps: int = 300 #: Set to 0 to disable early stopping. early_stop_patience: int = 50 seed: int = 0 instances_per_system: int = 32 system_ids: tuple[str, ...] = SUPPORTED_SYSTEMS ablation: Optional[Ablation] = None wandb_project: str = "physix-live" wandb_run_name: Optional[str] = None push_to_hub: bool = False hub_repo_id: Optional[str] = None #: HF repo to push LoRA checkpoints to every save_steps. hub_checkpoint_repo_id: Optional[str] = None resume_from_checkpoint: Optional[str] = None save_method: SaveMethod = "merged_16bit" def train(config: TrainingConfig) -> None: """Run a full GRPO training loop with the given configuration.""" _configure_logging() import wandb run_name = config.wandb_run_name or f"physix-grpo-{config.num_steps}steps" wandb.init( project=config.wandb_project, name=run_name, config=config.model_dump(), tags=["grpo", "physix", config.model_name.split("/")[-1]], resume="allow", ) if config.hub_checkpoint_repo_id: ckpt_url = f"https://huggingface.co/{config.hub_checkpoint_repo_id}" wandb.run.summary["checkpoint/repo"] = config.hub_checkpoint_repo_id wandb.run.summary["checkpoint/repo_url"] = ckpt_url if config.hub_repo_id: wandb.run.summary["model/final_repo"] = config.hub_repo_id wandb.run.summary["model/final_url"] = ( f"https://huggingface.co/{config.hub_repo_id}" ) if config.lora_adapter_repo: wandb.run.summary["resume/from_adapter"] = config.lora_adapter_repo wandb.run.summary["resume/from_url"] = ( f"https://huggingface.co/{config.lora_adapter_repo}" ) parent_run = os.environ.get("WANDB_RESUMED_FROM") if parent_run: wandb.run.summary["resume/parent_wandb_run"] = parent_run wandb.run.summary["resume/parent_wandb_url"] = ( f"https://wandb.ai/{wandb.run.entity}/{wandb.run.project}/runs/{parent_run}" ) print( f"\n[wandb] WARM-STARTED run — adapter " f"https://huggingface.co/{config.lora_adapter_repo}\n", flush=True, ) _log.info("Loading model %s with Unsloth (4-bit, LoRA-%d)", config.model_name, config.lora_r) model, tokenizer = _load_model_and_tokenizer(config) train_dataset = _build_and_format_dataset(config, tokenizer) reward_funcs = _select_reward_funcs(config.ablation) grpo_config = _build_grpo_config(config) callbacks = [] if config.early_stop_patience > 0: callbacks.append(_RewardConvergenceCallback(patience=config.early_stop_patience)) _log.info( "Early stopping enabled: will stop if reward_std < 0.05 for %d consecutive steps", config.early_stop_patience, ) if config.hub_checkpoint_repo_id: callbacks.append(_WandbCheckpointCallback(config.hub_checkpoint_repo_id)) _log.info( "Checkpoint hub push enabled → %s (every %d steps)", config.hub_checkpoint_repo_id, grpo_config.save_steps, ) trainer = GRPOTrainer( model=model, processing_class=tokenizer, args=grpo_config, train_dataset=train_dataset, reward_funcs=reward_funcs, callbacks=callbacks or None, ) if config.resume_from_checkpoint: _log.info("Resuming from checkpoint: %s", config.resume_from_checkpoint) _log.info("Starting GRPO training for %d steps", config.num_steps) trainer.train(resume_from_checkpoint=config.resume_from_checkpoint) _log_reward_summary(trainer) _render_training_curves(trainer, config) _log.info("Saving adapter (%s) to %s", config.save_method, config.output_dir) _save_artifacts(model, tokenizer, config) wandb.finish() def _log_reward_summary(trainer: "GRPOTrainer") -> None: """Log first→last reward delta for every component. Raises if no rewards were logged.""" history = getattr(trainer.state, "log_history", []) or [] reward_entries = [ entry for entry in history if any(k.startswith("rewards/") or k == "reward" for k in entry) ] if not reward_entries: _log.error( "No reward metrics logged during training. This usually means " "every rollout failed to parse. Check `train/reward` in W&B and " "the most recent completion samples." ) raise RuntimeError( "GRPO produced no reward metrics — training silently failed." ) last = reward_entries[-1] first = reward_entries[0] _log.info("=" * 60) _log.info("GRPO reward summary (first → last logged step):") for key in sorted(last): if key.startswith("rewards/") or key == "reward": v0 = first.get(key) v1 = last.get(key) if isinstance(v0, (int, float)) and isinstance(v1, (int, float)): _log.info(" %-40s %.4f → %.4f (Δ=%+.4f)", key, v0, v1, v1 - v0) _log.info("=" * 60) def _render_training_curves( trainer: "GRPOTrainer", config: TrainingConfig, ) -> None: """Render loss/reward/component PNGs from log_history and push to Hub.""" try: import matplotlib matplotlib.use("Agg") # headless / no display server in HF Jobs import matplotlib.pyplot as plt except Exception as exc: # noqa: BLE001 _log.warning("matplotlib unavailable, skipping curve PNGs: %s", exc) return history = list(getattr(trainer.state, "log_history", []) or []) if not history: _log.warning("No log_history found — cannot render curves.") return plots_dir = Path(config.output_dir) / "plots" plots_dir.mkdir(parents=True, exist_ok=True) def _series(metric: str) -> tuple[list[int], list[float]]: xs: list[int] = [] ys: list[float] = [] for entry in history: if metric in entry and "step" in entry: value = entry[metric] if isinstance(value, (int, float)): xs.append(int(entry["step"])) ys.append(float(value)) return xs, ys rendered: list[Path] = [] steps_l, losses = _series("loss") if steps_l: fig, ax = plt.subplots(figsize=(8, 4.5)) ax.plot(steps_l, losses, color="#d62728", linewidth=1.8) ax.set_xlabel("training step") ax.set_ylabel("GRPO surrogate loss") ax.set_title("PhysiX GRPO — train/loss (lower is better)") ax.grid(alpha=0.3) path = plots_dir / "loss.png" fig.tight_layout() fig.savefig(path, dpi=140) plt.close(fig) rendered.append(path) else: _log.warning("No 'loss' entries in log_history.") steps_r, rewards = _series("reward") _, reward_std = _series("reward_std") if steps_r: fig, ax = plt.subplots(figsize=(8, 4.5)) ax.plot(steps_r, rewards, color="#2ca02c", linewidth=2.0, label="mean reward") if reward_std and len(reward_std) == len(rewards): import numpy as np r = np.asarray(rewards) s = np.asarray(reward_std) ax.fill_between(steps_r, r - s, r + s, color="#2ca02c", alpha=0.18, label="±1σ across rollouts") ax.set_xlabel("training step") ax.set_ylabel("mean reward (sum of components)") ax.set_title("PhysiX GRPO — train/reward (higher is better)") ax.legend(loc="best") ax.grid(alpha=0.3) path = plots_dir / "reward.png" fig.tight_layout() fig.savefig(path, dpi=140) plt.close(fig) rendered.append(path) else: _log.warning("No 'reward' entries in log_history.") component_keys = sorted({ k for entry in history for k in entry if k.startswith("rewards/") and k.endswith("/mean") }) if component_keys: fig, ax = plt.subplots(figsize=(8, 4.5)) for k in component_keys: xs, ys = _series(k) if xs: label = k.removeprefix("rewards/").removesuffix("/mean") ax.plot(xs, ys, linewidth=1.6, label=label) ax.set_xlabel("training step") ax.set_ylabel("component mean reward") ax.set_title("PhysiX GRPO — per-component reward (rewards/*/mean)") ax.legend(loc="best", fontsize=8) ax.grid(alpha=0.3) path = plots_dir / "reward_components.png" fig.tight_layout() fig.savefig(path, dpi=140) plt.close(fig) rendered.append(path) if not rendered: _log.warning("No PNGs rendered — log_history had no recognised metrics.") return _log.info("Rendered %d curve PNG(s) to %s", len(rendered), plots_dir) try: import wandb if wandb.run is not None: wandb.log({ f"plots/{p.stem}": wandb.Image(str(p)) for p in rendered }) _log.info("Logged %d plot(s) to wandb.Media", len(rendered)) except Exception as exc: # noqa: BLE001 _log.warning("Could not log plots to wandb: %s", exc) if config.push_to_hub and config.hub_repo_id: try: from huggingface_hub import HfApi, create_repo api = HfApi(token=os.environ.get("HUGGINGFACE_HUB_TOKEN")) create_repo( repo_id=config.hub_repo_id, repo_type="model", exist_ok=True, token=os.environ.get("HUGGINGFACE_HUB_TOKEN"), ) for p in rendered: api.upload_file( path_or_fileobj=str(p), path_in_repo=f"plots/{p.name}", repo_id=config.hub_repo_id, repo_type="model", commit_message=f"plots: {p.name}", ) _log.info( "Pushed %d plot(s) to https://huggingface.co/%s/tree/main/plots", len(rendered), config.hub_repo_id, ) except Exception as exc: # noqa: BLE001 _log.warning("Could not push plots to Hub: %s", exc) def _load_model_and_tokenizer( config: TrainingConfig, ) -> tuple[FastLanguageModel, AutoTokenizer]: """Load model via Unsloth in 4-bit and attach a LoRA adapter.""" if config.lora_adapter_repo: _log.info( "Resuming from existing LoRA adapter %s on top of %s", config.lora_adapter_repo, config.model_name, ) model, tokenizer = FastLanguageModel.from_pretrained( model_name=config.model_name, max_seq_length=config.max_seq_length, load_in_4bit=True, dtype=None, ) model = FastLanguageModel.get_peft_model( model, r=config.lora_r, lora_alpha=config.lora_alpha, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], bias="none", use_gradient_checkpointing="unsloth", random_state=config.seed, ) model.load_adapter( config.lora_adapter_repo, adapter_name="default", is_trainable=True, ) _log.info("Adapter loaded; LoRA is trainable and ready for GRPO.") return model, tokenizer if config.sft_checkpoint: _log.info( "Loading SFT-warmed model from %s (GRPO will refine from here)", config.sft_checkpoint, ) model, tokenizer = FastLanguageModel.from_pretrained( model_name=config.sft_checkpoint, max_seq_length=config.max_seq_length, load_in_4bit=True, dtype=None, ) else: _log.warning( "No --sft-checkpoint supplied. Starting GRPO from cold base model. " "Early reward signal will be near-zero; consider running sft.py first." ) model, tokenizer = FastLanguageModel.from_pretrained( model_name=config.model_name, max_seq_length=config.max_seq_length, load_in_4bit=True, dtype=None, ) model = FastLanguageModel.get_peft_model( model, r=config.lora_r, lora_alpha=config.lora_alpha, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], bias="none", use_gradient_checkpointing="unsloth", random_state=config.seed, ) return model, tokenizer def _build_and_format_dataset( config: TrainingConfig, tokenizer: AutoTokenizer, ) -> Dataset: spec = DatasetSpec( system_ids=config.system_ids, instances_per_system=config.instances_per_system, seed=config.seed, ) dataset = build_training_dataset(spec) _log.info( "Built training dataset: %d rows across %d systems (%s)", len(dataset), len(config.system_ids), ", ".join(config.system_ids), ) def _apply_chat_template(example: dict[str, object]) -> dict[str, object]: formatted = tokenizer.apply_chat_template( example["prompt"], tokenize=False, add_generation_prompt=True, ) return {"prompt": formatted} return dataset.map(_apply_chat_template) def _select_reward_funcs(ablation: Optional[Ablation]) -> list[object]: """Return the active reward function list, optionally with one signal ablated.""" scorer = Scorer() funcs = make_reward_funcs(scorer) full = [ funcs["match"], funcs["match_dense"], funcs["correctness"], funcs["simplicity"], funcs["format"], ] if ablation is None: return full if ablation == "no_simplicity": return [funcs["match"], funcs["match_dense"], funcs["correctness"], funcs["format"]] if ablation == "no_format": return [funcs["match"], funcs["match_dense"], funcs["correctness"], funcs["simplicity"]] if ablation == "no_progress": return full # progress was removed; treat as full set for backward compat raise ValueError( f"Unknown ablation {ablation!r}. Choose from " "no_progress | no_simplicity | no_format | None." ) class _RewardConvergenceCallback(TrainerCallback): """Stop early when reward_std stays below min_std for `patience` consecutive steps.""" def __init__(self, patience: int = 50, min_std: float = 0.05) -> None: self._patience = patience self._min_std = min_std self._flat_steps: int = 0 def on_log( self, args: HFTrainingArguments, state: TrainerState, control: TrainerControl, logs: dict | None = None, **kwargs, ) -> None: if not logs: return reward_std = logs.get("reward_std") if reward_std is None: return if reward_std < self._min_std: self._flat_steps += 1 else: self._flat_steps = 0 if self._flat_steps >= self._patience: step = state.global_step msg = ( f"[early-stop] reward_std < {self._min_std} for " f"{self._flat_steps} consecutive steps at step {step}. " "Stopping training — policy has converged." ) print(f"\n{msg}\n", flush=True) _log.info(msg) try: import wandb if wandb.run is not None: wandb.run.summary["early_stop/step"] = step wandb.run.summary["early_stop/reason"] = ( f"reward_std < {self._min_std} for {self._flat_steps} steps" ) wandb.log({"early_stop/triggered": 1}, step=step) except Exception as exc: # noqa: BLE001 _log.debug("Could not log early-stop event to W&B: %s", exc) control.should_training_stop = True class _WandbCheckpointCallback(TrainerCallback): """Log checkpoint metadata to W&B summary and stdout after each Trainer save.""" def __init__(self, hub_checkpoint_repo_id: str) -> None: self._repo = hub_checkpoint_repo_id self._repo_url = f"https://huggingface.co/{hub_checkpoint_repo_id}" self._table = None def on_train_begin( self, args: HFTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs, ) -> None: try: import wandb if wandb.run is None: return wandb.run.summary["checkpoint/repo_url"] = self._repo_url wandb.run.summary["checkpoint/repo"] = self._repo wandb.config.update( {"checkpoint_repo_url": self._repo_url, "checkpoint_repo": self._repo}, allow_val_change=True, ) print( f"\n[wandb] Checkpoint repo pinned in run summary: {self._repo_url}\n", flush=True, ) self._publish_wandb_run_id(wandb.run.id) except Exception as exc: # noqa: BLE001 _log.warning("Could not pin checkpoint repo to W&B summary: %s", exc) def _publish_wandb_run_id(self, run_id: str) -> None: try: import tempfile from huggingface_hub import HfApi, create_repo token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN") api = HfApi(token=token) create_repo(self._repo, exist_ok=True, repo_type="model", token=token) with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp: tmp.write(run_id) tmp_path = tmp.name api.upload_file( path_or_fileobj=tmp_path, path_in_repo="wandb_run_id.txt", repo_id=self._repo, repo_type="model", commit_message=f"Pin W&B run id {run_id}", token=token, ) print(f"[wandb] Published run_id={run_id} to {self._repo_url}/wandb_run_id.txt", flush=True) except Exception as exc: # noqa: BLE001 _log.warning("Could not publish wandb run id (non-fatal): %s", exc) def on_save( self, args: HFTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs, ) -> None: try: import wandb if wandb.run is None: return step = state.global_step commit_sha = self._latest_commit_sha() short = (commit_sha or "pending")[:8] tree_url = ( f"{self._repo_url}/tree/{commit_sha}" if commit_sha else f"{self._repo_url}/tree/main" ) wandb.run.summary["checkpoint/last_step"] = step wandb.run.summary["checkpoint/last_commit"] = commit_sha or "pending" wandb.run.summary["checkpoint/last_url"] = tree_url wandb.log({"checkpoint/step": step}, step=step) if self._table is None: self._table = wandb.Table( columns=["step", "commit", "url", "repo"] ) self._table.add_data(step, commit_sha or "pending", tree_url, self._repo) wandb.log({"checkpoint_history": self._table}, step=step) if commit_sha: from physix.training.checkpoints import ( CheckpointHandle, log_link_artifact_to_wandb, ) handle = CheckpointHandle( repo_id=self._repo, subfolder=f"checkpoint-{step}", revision=commit_sha, step=step, ) log_link_artifact_to_wandb( handle, artifact_name="physix-grpo-checkpoint", ) print( "\n" "================ CHECKPOINT SAVED ================\n" f" step : {step}\n" f" commit: {short}\n" f" url : {tree_url}\n" f" repo : {self._repo_url}\n" "==================================================\n", flush=True, ) _log.info( "W&B checkpoint metadata logged: step=%d commit=%s", step, short, ) except Exception as exc: # noqa: BLE001 _log.warning( "W&B checkpoint callback skipped at step %d: %s. " "Training continues; the actual checkpoint is still pushed " "to the HF Hub by the trainer's PushToHubCallback.", state.global_step, exc, ) def _latest_commit_sha(self) -> Optional[str]: """Best-effort fetch of the latest commit SHA; returns None on failure.""" try: from huggingface_hub import HfApi api = HfApi(token=os.environ.get("HUGGINGFACE_HUB_TOKEN")) commits = api.list_repo_commits(repo_id=self._repo, repo_type="model") if commits: return commits[0].commit_id except Exception as exc: # noqa: BLE001 _log.debug("Could not fetch latest commit sha: %s", exc) return None def _build_grpo_config(config: TrainingConfig) -> GRPOConfig: effective_batch = ( config.per_device_train_batch_size * config.gradient_accumulation_steps ) if effective_batch % config.num_generations != 0: raise ValueError( f"effective_batch_size ({effective_batch}) must be divisible by " f"num_generations ({config.num_generations}). Adjust " "per_device_train_batch_size, gradient_accumulation_steps, or " "num_generations." ) hub_kwargs: dict = {} if config.hub_checkpoint_repo_id: hub_kwargs = dict( push_to_hub=True, hub_model_id=config.hub_checkpoint_repo_id, hub_strategy="checkpoint", hub_token=os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN"), ) return GRPOConfig( output_dir=config.output_dir, learning_rate=config.learning_rate, per_device_train_batch_size=config.per_device_train_batch_size, gradient_accumulation_steps=config.gradient_accumulation_steps, num_train_epochs=1, max_steps=config.num_steps, num_generations=config.num_generations, max_completion_length=config.max_completion_length, max_prompt_length=config.max_seq_length - config.max_completion_length, temperature=config.temperature, beta=config.beta, logging_steps=1, save_strategy="steps", save_steps=max(50, config.num_steps // 6), report_to=["wandb"], run_name=config.wandb_run_name, seed=config.seed, bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False, fp16=not torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False, **hub_kwargs, ) def _save_artifacts( model: FastLanguageModel, tokenizer: AutoTokenizer, config: TrainingConfig, ) -> None: """Save model locally and optionally push to Hub.""" out_path = Path(config.output_dir) out_path.mkdir(parents=True, exist_ok=True) save_dir = out_path / config.save_method model.save_pretrained_merged( save_directory=str(save_dir), tokenizer=tokenizer, save_method=config.save_method, ) if config.push_to_hub and config.hub_repo_id: _log.info("Pushing %s artifact to Hugging Face Hub: %s", config.save_method, config.hub_repo_id) model.push_to_hub_merged( config.hub_repo_id, tokenizer, save_method=config.save_method, token=os.environ.get("HUGGINGFACE_HUB_TOKEN"), ) def _configure_logging() -> None: logging.basicConfig( level=os.environ.get("PHYSIX_LOG_LEVEL", "INFO"), format="[%(asctime)s] %(levelname)s %(name)s | %(message)s", ) def _parse_args() -> TrainingConfig: parser = argparse.ArgumentParser(description="Train PhysiX-Live with GRPO.") parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct") parser.add_argument("--output-dir", default="runs/physix-1.5b-rl") parser.add_argument("--num-steps", type=int, default=300) parser.add_argument("--learning-rate", type=float, default=5.0e-6) parser.add_argument("--num-generations", type=int, default=4) parser.add_argument("--max-completion-length", type=int, default=256, help="Max tokens per rollout completion. Shorter = faster generation.") parser.add_argument("--lora-r", type=int, default=16) parser.add_argument("--instances-per-system", type=int, default=32) parser.add_argument( "--system-ids", default=None, help=( "Comma-separated list of system IDs to train on " "(e.g. 'damped_spring' or 'free_fall,simple_pendulum'). " "Defaults to all SUPPORTED_SYSTEMS when omitted." ), ) parser.add_argument( "--ablation", choices=("no_progress", "no_simplicity", "no_format"), default=None, ) parser.add_argument( "--save-method", choices=("lora", "merged_16bit", "merged_4bit"), default="merged_16bit", help="How to persist the final adapter (merged_16bit is deployable).", ) parser.add_argument("--sft-checkpoint", default=None, help="Path to a merged SFT model from sft.py to warm-start from.") parser.add_argument( "--lora-adapter-repo", default=None, help=( "Hub repo id (or local path) of an existing LoRA adapter to warm-start " "GRPO from — e.g. a previous run's checkpoint at " "user/physix-1.5b-rl-ckpt. Mutually exclusive with --sft-checkpoint." ), ) parser.add_argument("--wandb-project", default="physix-live") parser.add_argument("--wandb-run-name", default=None) parser.add_argument("--push-to-hub", action="store_true") parser.add_argument("--hub-repo-id", default=None) parser.add_argument( "--hub-checkpoint-repo-id", default=None, help="HF repo to push LoRA checkpoints to every save_steps (e.g. user/physix-ckpt).", ) parser.add_argument( "--resume-from-checkpoint", default=None, help="Path to a Trainer checkpoint directory to resume GRPO from.", ) parser.add_argument( "--early-stop-patience", type=int, default=50, help=( "Stop training early if reward_std stays below 0.05 for this many " "consecutive steps (policy converged, GRPO advantage ≈ 0). " "Set to 0 to disable." ), ) parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() if args.sft_checkpoint and args.lora_adapter_repo: parser.error( "--sft-checkpoint and --lora-adapter-repo are mutually exclusive. " "Use --lora-adapter-repo to resume from a prior GRPO run, or " "--sft-checkpoint for a fresh GRPO from a merged SFT model." ) system_ids = ( tuple(s.strip() for s in args.system_ids.split(",") if s.strip()) if args.system_ids else SUPPORTED_SYSTEMS ) return TrainingConfig( model_name=args.model, sft_checkpoint=args.sft_checkpoint, lora_adapter_repo=args.lora_adapter_repo, output_dir=args.output_dir, num_steps=args.num_steps, learning_rate=args.learning_rate, num_generations=args.num_generations, max_completion_length=args.max_completion_length, lora_r=args.lora_r, instances_per_system=args.instances_per_system, system_ids=system_ids, ablation=args.ablation, save_method=args.save_method, wandb_project=args.wandb_project, wandb_run_name=args.wandb_run_name, push_to_hub=args.push_to_hub, hub_repo_id=args.hub_repo_id, hub_checkpoint_repo_id=args.hub_checkpoint_repo_id, resume_from_checkpoint=args.resume_from_checkpoint, early_stop_patience=args.early_stop_patience, seed=args.seed, ) def main() -> None: config = _parse_args() os.environ.setdefault("WANDB_PROJECT", config.wandb_project) train(config) if __name__ == "__main__": main()