"""scripts/train_grpo.py - GRPO RL phase (2026-04 spec rewrite). Loads the SFT-warm-started LoRA adapter at ``checkpoints/sft_warmup/checkpoint-50`` on top of the 4-bit NF4 quantised ``unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit`` base, connects to the OpenEnv server (local or remote via ``QUBIT_MEDIC_URL``), and runs TRL's :class:`GRPOTrainer` for 1,500 steps with diversity-focused rollout sampling (temperature=1.2, top_p=0.95, top_k=50, repetition_penalty=1.1) and a weighted 5-component reward bounded to [0, 1]. Why diversity-focused sampling ------------------------------ The first GRPO attempt (temperature=0.7) collapsed inside 100 steps to a constant ``X_ERRORS=[] Z_ERRORS=[]`` policy: every group of 4 generations was byte-identical, so within-group reward variance was zero and the GRPO advantage was exactly zero - no gradient. The new sampler defaults broaden per-token entropy enough to keep within-group variance positive, which is what GRPO needs to learn anything. Major spec features wired up here --------------------------------- * ``_diversity_preflight`` - 5 prompts x 8 completions at T=1.2; abort if fewer than 3 prompts hit >=3 unique completions. The model is too collapsed for GRPO to recover. * Frozen 200-syndrome eval set seeded ``4284`` (matches SFT validation seed). Cached to ``data/grpo_validation.jsonl`` so reruns and offline inspection see the same prompts. * Tier-1 training metrics (every 5 steps): total_reward_mean, reward_std_within_group, completion_uniqueness, advantage_mean_abs, kl_divergence, grad_norm, policy_loss, learning_rate. * Tier-2 eval metrics (every 100 steps, greedy at T=0): logical correction rate, pymatching beat rate, format compliance, exact-match pymatching, hard-syndrome (>=2 errors) LCR, syndrome consistency, avg_completion_length, output_diversity at T=1.0. * Tier-3 (every eval): per-round logical-error rate at d=3 p=0.001 plus log10 transform. * Sample-completion table every 50 steps: 5 random eval prompts, the 4 rollouts each, per-component rewards, parsed action. * Anti-hacking: 30s per-episode timeout (server-side), reward bounds enforced both pre-multiply and post-sum, mode-collapse inspection every 100 steps that auto-raises temperature by 0.2 if >7 of the last 10 prompts produced 4 byte-identical generations. * Wall-clock cap: 13h. Saves+exits cleanly if exceeded. * Best-checkpoint tracking: writes ``output/best/`` whenever a new best ``eval/total_reward_mean`` is observed. Final state always saves to ``output/final/`` regardless of rank. * Decision rules (warnings only, no auto-fix): step-50 reward variance floor, step-500 pymatching-beat sanity, format-compliance floor, and 3-consecutive-log grad-norm runaway alarm. Usage:: python -m scripts.train_grpo \ --sft-checkpoint checkpoints/sft_warmup/checkpoint-50 \ --output checkpoints/grpo \ --report-to wandb """ from __future__ import annotations import argparse import inspect import json import os import random # torch._dynamo recompile-limit guard. Unsloth's GRPO trainer wraps the # loss/generation graph in torch.compile(fullgraph=True). Two things blow # past Dynamo's default cache_size_limit (8) over a long GRPO run: # 1. The mode-collapse hook mutates trainer.args.temperature in flight # (e.g. 1.2 -> 1.4 -> 1.6 -> 1.8); each mutation re-specializes the # compiled generation path. # 2. Variable prompt/completion shapes specialize over hundreds of steps. # When the limit is hit, fullgraph=True turns it into a fatal # FailOnRecompileLimitHit (we lost a run at step 400/1500 to this). 
# Set the limits high before torch is imported so they take effect
# everywhere.
os.environ.setdefault("TORCHDYNAMO_CACHE_SIZE_LIMIT", "256")
os.environ.setdefault("TORCHDYNAMO_RECOMPILE_LIMIT", "256")

import shutil
import sys
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable, Optional
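

# Illustrative sketch (not called anywhere): a quick way to confirm the
# raised Dynamo limits actually took effect. Assumes torch is installed,
# which this script needs anyway; the attribute names are hasattr-guarded
# because they vary across torch versions.
def _sketch_show_dynamo_limits() -> None:
    """Print the active torch._dynamo recompile limits (debug helper)."""
    import torch._dynamo as dynamo

    for attr in ("cache_size_limit", "recompile_limit",
                 "accumulated_cache_size_limit"):
        if hasattr(dynamo.config, attr):
            print(f"[dynamo] {attr} = {getattr(dynamo.config, attr)}")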


# --------------------------------------------------------------------------- #
# Pre-flight: detect the unsloth / unsloth_zoo signature skew that crashes    #
# GRPO at step 0 with a misleading TypeError.                                 #
# --------------------------------------------------------------------------- #
#
# unsloth==2025.11.1's GRPO trainer template calls
#     grpo_accumulated_loss(..., old_hidden_states=..., ref_hidden_states=...)
# but unsloth_zoo>=2026.x renamed those positional args to old_logps /
# ref_logps with no compat shim. Pip's resolver (with the unpinned `unsloth`
# line in requirements-train.txt) silently couples the two: it picks
#     unsloth==2025.11.1 + unsloth_zoo==2026.4.9
# and that pair crashes at the first training step with:
#     TypeError: grpo_accumulated_loss() missing 2 required positional
#     arguments: 'old_logps' and 'ref_logps'
#
# SFT does not exercise this code path, so SFT finishes cleanly first, the
# checkpoint gets saved, and only then GRPO blows up - wasting the whole SFT
# run. This guard runs in well under a second, before any GPU work, and
# prints the exact pip command to fix it instead of the cryptic TypeError.
# --------------------------------------------------------------------------- #


def _assert_grpo_signature_compatible() -> None:
    """Abort early if the installed unsloth_zoo signature does not match the
    call pattern baked into the installed unsloth.
    """
    try:
        import unsloth  # noqa: F401 (force the patches to apply first)
        import unsloth_zoo
        from unsloth_zoo.rl_replacements import grpo_accumulated_loss
    except Exception as exc:
        print(f"[grpo-guard] WARNING: could not introspect unsloth_zoo "
              f"({exc!r}); skipping signature check.", file=sys.stderr)
        return

    params = list(inspect.signature(grpo_accumulated_loss).parameters.keys())
    has_hidden = "old_hidden_states" in params and "ref_hidden_states" in params
    has_logps = "old_logps" in params and "ref_logps" in params

    # The unsloth in this repo is pinned to the 2025.11.x lineage (matches
    # what SFT just used). That lineage calls with old_hidden_states= /
    # ref_hidden_states=. If unsloth_zoo has those names, we are fine.
    if has_hidden:
        return

    unsloth_ver = getattr(unsloth, "__version__", "?")
    zoo_ver = getattr(unsloth_zoo, "__version__", "?")
    have_logps_only = has_logps and not has_hidden
    msg = [
        "",
        "=" * 78,
        "[grpo-guard] FATAL: unsloth / unsloth_zoo signature mismatch detected.",
        "=" * 78,
        f"    unsloth     == {unsloth_ver}",
        f"    unsloth_zoo == {zoo_ver}",
        f"    grpo_accumulated_loss parameters: {params}",
        "",
        "  unsloth (this version) calls grpo_accumulated_loss with",
        "      old_hidden_states=... , ref_hidden_states=...",
        "  but the installed unsloth_zoo expects",
        "      old_logps=... , ref_logps=...",
        "  as required positional arguments."
        if have_logps_only else
        "  but the installed unsloth_zoo signature does not contain those names.",
        "",
        "  Without this fix, GRPO will crash at step 0 with:",
        "      TypeError: grpo_accumulated_loss() missing 2 required positional",
        "      arguments: 'old_logps' and 'ref_logps'",
        "",
        "  Fix on Colab (one-liner):",
        "      pip install --no-deps --force-reinstall unsloth_zoo==2025.11.1 \\",
        "          && rm -rf unsloth_compiled_cache",
        "",
        "  Then re-run:",
        "      python -m scripts.train_grpo --sft-checkpoint "
        "checkpoints/sft_warmup/checkpoint-50 \\",
        "          --output checkpoints/grpo",
        "=" * 78,
        "",
    ]
    raise SystemExit("\n".join(msg))


def _wipe_stale_grpo_cache() -> None:
    """Remove unsloth_compiled_cache/UnslothGRPOTrainer.py if present.

    The cache file is regenerated automatically by unsloth on the next GRPO
    import using the *currently installed* unsloth_zoo source, so deleting
    it is safe and is the only way to recover after fixing a
    previously-mismatched install.
    """
    cache_file = Path("unsloth_compiled_cache") / "UnslothGRPOTrainer.py"
    if cache_file.exists():
        print(f"[grpo-guard] removing stale {cache_file} so it regenerates "
              f"against the current unsloth_zoo install")
        try:
            cache_file.unlink()
        except OSError as exc:
            print(f"[grpo-guard] WARNING: failed to remove {cache_file}: "
                  f"{exc!r}", file=sys.stderr)
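

# The guard above is one instance of a generic trick: introspect the
# installed callee's signature and branch on its parameter names. A minimal
# standalone sketch of the technique (`callee` here is illustrative, not an
# unsloth function):
def _sketch_signature_gate() -> bool:
    """True iff the stand-in callee exposes the expected kwarg names."""
    def callee(old_logps, ref_logps):  # stand-in for the installed function
        return old_logps, ref_logps

    params = set(inspect.signature(callee).parameters)
    return {"old_logps", "ref_logps"} <= params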


# --------------------------------------------------------------------------- #
# Per-batch scoring cache + reward bounds enforcement                         #
# --------------------------------------------------------------------------- #
#
# The original implementation called the env 5 times per (prompt, completion)
# - once per reward function. We fix that with a single (prompt, completion)
# -> breakdown cache keyed inside one GRPO step, AND we apply the spec's
# weighted-sum + [0, 1] clip in one place so every reward function returns
# a number that's already correctly weighted.
# --------------------------------------------------------------------------- #


@dataclass
class _ScoredCompletion:
    """One scored (prompt, completion) pair, keyed by the env episode."""

    rewards: dict          # raw per-component rewards from the env (in [0, 1])
    weighted_total: float  # weighted sum, clipped to [0, 1]
    parse_success: bool
    parse_partial: bool
    x_pred: list
    z_pred: list
    actual_flip: int
    pm_flip: int
    elapsed: float
    timed_out: bool
    curriculum_level: str
    bounds_violations: int  # >0 if env returned a component outside [0, 1]


@dataclass
class _BatchScoringCache:
    """Caches per-(prompt, completion) scores within one GRPO step."""

    env_client: object
    reward_weights: dict
    _cache: dict = field(default_factory=dict)
    _step_keys: list = field(default_factory=list)
    _lock: threading.Lock = field(default_factory=threading.Lock)
    _all_curriculum_stats: dict = field(default_factory=dict)
    _episodes: int = 0
    _timeouts: int = 0
    _bounds_violations: int = 0

    def _enforce_bounds(self, name: str, val: float) -> tuple[float, bool]:
        """Clip a reward component to [0, 1]; flag if it was outside."""
        v = float(val)
        if v < 0.0 or v > 1.0:
            return max(0.0, min(1.0, v)), True
        return v, False

    def score(self, prompt: str, completion: str) -> _ScoredCompletion:
        key = (prompt, completion)
        with self._lock:
            entry = self._cache.get(key)
        if entry is not None:
            return entry

        # Env work is independent across (p, c) so it's safe to release the
        # lock during the network round-trip.
        obs = self.env_client.reset()
        result = self.env_client.step(raw_response=completion,
                                      episode_id=obs.episode_id)
        info = result.info
        action = info.get("parsed_action", {})

        # Apply spec weights + [0, 1] bounds enforcement.
        raw = info.get("rewards", {}) or {}
        violations = 0
        weighted_sum = 0.0
        bounded_components: dict = {}
        for name, weight in self.reward_weights.items():
            v, was_oob = self._enforce_bounds(name, raw.get(name, 0.0))
            bounded_components[name] = v
            weighted_sum += weight * v
            if was_oob:
                violations += 1

        # Clip the weighted sum to [0, 1] (already in range when the
        # components are; defensive against weights that don't sum to 1.0).
        weighted_total = max(0.0, min(1.0, weighted_sum))
        # Preserve the env's "total" alongside our weighted total so the
        # downstream wandb log_reward_breakdown still works.
        bounded_components["total"] = weighted_total

        scored = _ScoredCompletion(
            rewards=bounded_components,
            weighted_total=weighted_total,
            parse_success=bool(action.get("parse_success", False)),
            parse_partial=False,
            x_pred=list(action.get("x_error_qubits", []) or []),
            z_pred=list(action.get("z_error_qubits", []) or []),
            actual_flip=int(info.get("actual_observable_flip", 0)),
            pm_flip=int(info.get("pymatching_observable_pred", 0)),
            elapsed=float(info.get("elapsed_seconds", 0.0)),
            timed_out=bool(info.get("timed_out", False)),
            curriculum_level=str(getattr(obs, "curriculum_level", "")),
            bounds_violations=violations,
        )
        with self._lock:
            self._cache[key] = scored
            self._step_keys.append(key)
            self._all_curriculum_stats = info.get("curriculum_stats", {}) or {}
            self._episodes += 1
            if scored.timed_out:
                self._timeouts += 1
            if violations:
                self._bounds_violations += violations
        return scored

    def drain_step(self):
        """Pop everything cached since the last drain_step() call."""
        with self._lock:
            entries = [self._cache[k] for k in self._step_keys]
            keys = list(self._step_keys)
            self._step_keys.clear()
            # Bound memory use - long runs accumulate unique strings.
            if len(self._cache) > 4096:
                self._cache.clear()
        return entries, keys


def _seed_everything(seed: int) -> None:
    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# --------------------------------------------------------------------------- #
# Reward function factory                                                     #
# --------------------------------------------------------------------------- #
#
# Spec: total reward = sum of weighted components, clipped to [0, 1].
# Implementation: the cache returns a per-completion weighted_total in
# [0, 1]. We expose ONE TRL reward function that returns that bounded
# total, plus zero-weight per-component observers so wandb gets per-
# component traces without altering the policy gradient.
# --------------------------------------------------------------------------- #

_REWARD_COMPONENTS = (
    "logical_correction",
    "hamming_overlap",
    "syndrome_consistency",
    "format_compliance",
    "pymatching_beat",
)


def _make_reward_fns(cache: _BatchScoringCache):
    def total_fn(prompts, completions, **_unused):
        scored = [cache.score(p, c) for p, c in zip(prompts, completions)]
        return [s.weighted_total for s in scored]

    total_fn.__name__ = "reward_total"

    observers: list = []
    for name in _REWARD_COMPONENTS:
        def _factory(component_name: str):
            def fn(prompts, completions, **_unused):
                scored = [cache.score(p, c)
                          for p, c in zip(prompts, completions)]
                return [s.rewards.get(component_name, 0.0) for s in scored]

            fn.__name__ = f"reward_obs_{component_name}"
            return fn

        observers.append(_factory(name))

    return [total_fn] + observers
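

# Why the docstring insists on within-group diversity, in a few lines: GRPO
# normalizes each completion's reward against its own group, so a group of
# identical completions (identical rewards) has std 0 and therefore exactly
# zero advantage - no gradient. A minimal sketch of that group-relative
# normalization (it mirrors the local recomputation in the callback below,
# not TRL's internals):
def _sketch_group_advantages(rewards: list[float],
                             eps: float = 1e-4) -> list[float]:
    """Group-relative advantages: (r - mean) / max(std, eps)."""
    mu = sum(rewards) / len(rewards)
    std = (sum((r - mu) ** 2 for r in rewards) / len(rewards)) ** 0.5
    return [(r - mu) / max(std, eps) for r in rewards]

# _sketch_group_advantages([0.7, 0.7, 0.7, 0.7]) -> [0.0, 0.0, 0.0, 0.0]
# _sketch_group_advantages([0.2, 0.9, 0.5, 0.4]) -> non-zero, learnable signal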


# --------------------------------------------------------------------------- #
# Frozen eval set: 200 syndromes seeded GRPO_VAL_SEED.                        #
# --------------------------------------------------------------------------- #
#
# We snapshot the 200 prompts to data/grpo_validation.jsonl on first run so
# reruns hit byte-identical syndromes, and so the file can be inspected /
# diffed offline. If the file already exists with >= n rows, we trust it.
# --------------------------------------------------------------------------- #


def _load_or_build_eval_set(env_client, *, seed: int, n: int,
                            path: str) -> list[dict]:
    p = Path(path)
    if p.exists():
        rows: list[dict] = []
        with p.open("r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                rows.append(json.loads(line))
        if len(rows) >= n:
            print(f"[grpo-eval] reusing cached eval set: {p} ({len(rows)} rows)")
            return rows[:n]
        print(f"[grpo-eval] cached eval set at {p} has {len(rows)} < {n} rows; "
              f"regenerating")

    p.parent.mkdir(parents=True, exist_ok=True)
    rows = []
    print(f"[grpo-eval] building frozen eval set seed={seed} n={n} -> {p}")
    cur_seed = seed
    for _ in range(n):
        obs = env_client.reset(seed=cur_seed)
        rows.append({
            "prompt": obs.prompt,
            "episode_id": int(obs.episode_id),
            "curriculum_level": str(getattr(obs, "curriculum_level", "")),
            "distance": int(getattr(obs, "distance", 0)),
            "rounds": int(getattr(obs, "rounds", 0)),
            "p": float(getattr(obs, "p", 0.0)),
            "syndrome_bits": list(getattr(obs, "syndrome_bits", []) or []),
            "seed": cur_seed,
        })
        cur_seed += 1  # deterministic, reproducible
    with p.open("w") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    print(f"[grpo-eval] wrote {len(rows)} eval rows to {p}")
    return rows
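

# For reference, one cached row of data/grpo_validation.jsonl looks roughly
# like this (values are illustrative; the schema matches the dict written
# above):
#
#   {"prompt": "...", "episode_id": 17, "curriculum_level": "level_1",
#    "distance": 3, "rounds": 1, "p": 0.001,
#    "syndrome_bits": [0, 1, 0, 0, 1, 0, 0, 0], "seed": 4284}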


# --------------------------------------------------------------------------- #
# Diversity preflight                                                         #
# --------------------------------------------------------------------------- #


def _diversity_preflight(model, tokenizer, *, val_path: str,
                         n_prompts: int = 5, n_samples_per_prompt: int = 8,
                         temperature: float = 1.2, min_unique: int = 3,
                         min_passing: int = 3,
                         max_new_tokens: int = 50) -> bool:
    """Generate ``n_samples_per_prompt`` completions per prompt at high temp.

    Returns True iff at least ``min_passing`` of the prompts produced
    >= ``min_unique`` unique completions (byte-equal under
    skip-special-tokens decoding). False -> the model is collapsed past the
    point where GRPO can recover, so we should refuse to start training.
    """
    import torch

    src = Path(val_path)
    if not src.exists():
        print(f"[grpo-preflight] WARNING: {val_path} not found; "
              f"skipping diversity preflight")
        return True  # don't block on a missing file

    rows: list[dict] = []
    with src.open("r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    if len(rows) < n_prompts:
        print(f"[grpo-preflight] WARNING: only {len(rows)} validation rows "
              f"available, need {n_prompts}; using all")
        n_prompts = len(rows)

    # Mix of trivial (no errors) and non-trivial (errors present), so the
    # diversity probe sees both regimes the model has to handle.
    rng = random.Random(0)
    trivial = [r for r in rows if not r.get("had_errors")]
    non_trivial = [r for r in rows if r.get("had_errors")]
    rng.shuffle(trivial)
    rng.shuffle(non_trivial)
    half = max(1, n_prompts // 2)
    chosen = (non_trivial[:half] + trivial[:n_prompts - half])[:n_prompts]
    if not chosen:
        chosen = rows[:n_prompts]

    print(f"[grpo-preflight] probing diversity at T={temperature} on "
          f"{len(chosen)} prompts x {n_samples_per_prompt} samples each")

    try:
        from unsloth import FastLanguageModel
        FastLanguageModel.for_inference(model)
    except Exception:
        model.eval()

    passing = 0
    per_prompt_unique: list[int] = []
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    for i, row in enumerate(chosen):
        prompt = row["prompt"]
        try:
            chat = [{"role": "user", "content": prompt}]
            text = tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True,
            )
        except Exception:
            text = ("<|im_start|>user\n" + prompt +
                    "\n<|im_end|>\n<|im_start|>assistant\n")
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        completions: list[str] = []
        for _ in range(n_samples_per_prompt):
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=0.95,
                    top_k=50,
                    repetition_penalty=1.1,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=pad_id,
                )
            gen_ids = out[0][inputs["input_ids"].shape[1]:]
            txt = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
            completions.append(txt)

        unique = len(set(completions))
        per_prompt_unique.append(unique)
        verdict = "PASS" if unique >= min_unique else "FAIL"
        print(f"[grpo-preflight] prompt {i}: {unique}/{n_samples_per_prompt} "
              f"unique [{verdict}] examples={completions[:2]!r}")
        if unique >= min_unique:
            passing += 1

    overall = passing >= min_passing
    print(f"[grpo-preflight] {passing}/{len(chosen)} prompts passed "
          f"(threshold: >= {min_passing}). per_prompt_unique={per_prompt_unique}")
    if not overall:
        print("=" * 78)
        print("[grpo-preflight] PRE-FLIGHT FAILED - model is too collapsed; "
              "redo SFT with regularization before launching GRPO")
        print("=" * 78)
    return overall
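

# The preflight's pass/fail rule, isolated as a sketch (not called anywhere):
# a run passes iff at least `min_passing` prompts each produced at least
# `min_unique` distinct completions.
def _sketch_preflight_verdict(per_prompt_unique: list[int],
                              min_unique: int = 3,
                              min_passing: int = 3) -> bool:
    """True iff enough prompts show enough distinct completions."""
    return sum(1 for u in per_prompt_unique if u >= min_unique) >= min_passing

# _sketch_preflight_verdict([8, 5, 1, 4, 1]) -> True  (three prompts pass)
# _sketch_preflight_verdict([1, 1, 8, 1, 1]) -> False (only one passes)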


# --------------------------------------------------------------------------- #
# In-loop W&B callback (tier-1 + tier-2 + tier-3 + sample table + safeguards) #
# --------------------------------------------------------------------------- #


def _build_wandb_callback(cache, model, tokenizer, env_client, eval_rows, *,
                          sample_every: int, sample_n: int,
                          inloop_every: int, inloop_max_new_tokens: int,
                          kl_alarm: float, inspection_every: int,
                          inspection_sample_n: int,
                          inspection_collapse_threshold: int,
                          temp_bump_on_collapse: float, best_dir: Path,
                          output_dir: Path, wall_seconds: float,
                          decision_thresholds: dict):
    from transformers import TrainerCallback

    from qubit_medic import wandb_utils

    if not wandb_utils.is_available():
        return None

    started_at = time.time()

    # Rolling cache for the inspection hook: we record (group_unique_count)
    # per prompt as it streams in, and at every inspection_every-step
    # boundary look at the most recent inspection_sample_n entries.
    recent_uniques = deque(maxlen=max(inspection_sample_n, 16))
    grad_norm_run = deque(maxlen=decision_thresholds["grad_norm_run_len"])
    state = {
        "best_total_reward": float("-inf"),
        "best_step": -1,
        "wall_exceeded": False,
        "step50_warned": False,
        "step500_warned": False,
        "format_warned_at": -1,
        "grad_norm_warned_at": -1,
        "beat_rate_history": [],
    }

    class _RolloutCallback(TrainerCallback):
        # ------------------------------------------------------------------ #
        # Per-step instrumentation                                            #
        # ------------------------------------------------------------------ #
        def on_step_end(self, args, state_, control, **kwargs):  # noqa: D401
            entries, keys = cache.drain_step()
            if not entries:
                return
            step = state_.global_step

            # Group entries by prompt so we can compute within-group stats.
            groups: list[list[_ScoredCompletion]] = []
            current_prompt = None
            current: list[_ScoredCompletion] = []
            for (p, _), e in zip(keys, entries):
                if p != current_prompt and current:
                    groups.append(current)
                    current = []
                current_prompt = p
                current.append(e)
            if current:
                groups.append(current)

            # ----- Tier-1 metrics ----- #
            totals = [e.weighted_total for e in entries]
            if not totals:
                return
            mean_total = sum(totals) / len(totals)

            within_stds: list[float] = []
            uniques: list[int] = []
            for grp in groups:
                if len(grp) < 2:
                    within_stds.append(0.0)
                    uniques.append(1)
                    continue
                vals = [e.weighted_total for e in grp]
                mu = sum(vals) / len(vals)
                var = sum((v - mu) ** 2 for v in vals) / len(vals)
                within_stds.append(var ** 0.5)
                key_set = {(tuple(e.x_pred), tuple(e.z_pred)) for e in grp}
                uniques.append(len(key_set))
            mean_within_std = sum(within_stds) / max(1, len(within_stds))
            mean_unique = sum(uniques) / max(1, len(uniques))

            # GRPO advantage (recomputed locally for the log only).
            adv_abs: list[float] = []
            for grp in groups:
                if len(grp) < 2:
                    continue
                vals = [e.weighted_total for e in grp]
                mu = sum(vals) / len(vals)
                var = sum((v - mu) ** 2 for v in vals) / len(vals)
                std = max((var ** 0.5), 1e-4)
                adv_abs.extend(abs((v - mu) / std) for v in vals)
            mean_adv_abs = sum(adv_abs) / max(1, len(adv_abs))

            wandb_utils.log({
                "train/total_reward_mean": mean_total,
                "train/reward_std_within_group": mean_within_std,
                "train/completion_uniqueness": mean_unique,
                "train/advantage_mean_abs": mean_adv_abs,
                "train/global_step": step,
            }, step=step)
            wandb_utils.log_reward_breakdown(
                [e.rewards for e in entries], step=step, prefix="train",
            )
            wandb_utils.log({
                "train/reward_bounds_violations_total": cache._bounds_violations,
                "train/env_episodes_total": cache._episodes,
                "train/env_timeouts_total": cache._timeouts,
            }, step=step)

            # ----- Decision rule: step 50 within-group variance ----- #
            if (not state["step50_warned"]
                    and step >= decision_thresholds["reward_std_check_step"]):
                if mean_within_std < decision_thresholds["reward_std_floor"]:
                    print(f"\n[grpo-decision] CRITICAL @ step {step}: "
                          f"train/reward_std_within_group={mean_within_std:.4f} "
                          f"< {decision_thresholds['reward_std_floor']}. The "
                          f"within-group reward std has collapsed; GRPO has "
                          f"effectively zero advantage signal. Pausing for "
                          f"manual review (warning only - no auto-action).")
                    wandb_utils.log({
                        "alarms/reward_std_collapse": 1.0,
                        "alarms/reward_std_value": mean_within_std,
                    }, step=step)
                state["step50_warned"] = True

            # Compliance Section 8 (audit, 2026-04): continuous warning for
            # reward_std < 0.02 at ANY step, not only step 50. We throttle
            # to once per 100 steps so the message doesn't spam every
            # 5-step log line.
            # The existing step-50 gate above stays as the harder "pause for
            # review" check at the higher 0.03 threshold; this continuous one
            # fires earlier, the moment within-group variance crosses the
            # spec floor, and tells the operator to look at the run.
            CONT_REWARD_STD_FLOOR = 0.02
            if mean_within_std < CONT_REWARD_STD_FLOOR:
                last_warn = state.get("reward_std_warned_at", -1)
                if step - last_warn >= 100:
                    print(f"\n[grpo-warn] @ step {step}: "
                          f"train/reward_std_within_group="
                          f"{mean_within_std:.4f} < {CONT_REWARD_STD_FLOOR} "
                          f"(continuous alarm). GRPO advantage signal is "
                          f"vanishing - inspect generations / temperature.")
                    wandb_utils.log({
                        "alarms/reward_std_continuous_low": 1.0,
                        "alarms/reward_std_value": mean_within_std,
                    }, step=step)
                    state["reward_std_warned_at"] = step

            # ----- Mode-collapse inspection hook ----- #
            for u in uniques:
                recent_uniques.append(u)
            if (inspection_every and step > 0
                    and step % inspection_every == 0
                    and len(recent_uniques) >= inspection_sample_n):
                last = list(recent_uniques)[-inspection_sample_n:]
                collapsed_count = sum(1 for u in last if u == 1)
                if collapsed_count > inspection_collapse_threshold:
                    cur_temp = float(getattr(args, "temperature", 1.2))
                    # Cap the bump at 2.0 - going higher does not actually
                    # produce more diversity (the sampler is already at
                    # top_k=50 / top_p=0.95) and every distinct value
                    # re-specializes the torch.compile cache, eventually
                    # tripping FailOnRecompileLimitHit even with raised
                    # limits.
                    new_temp = min(2.0, cur_temp + temp_bump_on_collapse)
                    if new_temp <= cur_temp + 1e-6:
                        print(f"\n[grpo-inspection] WARN @ step {step}: "
                              f"{collapsed_count}/{inspection_sample_n} prompts "
                              f"collapsed but temperature already at cap "
                              f"({cur_temp:.2f}); leaving unchanged.")
                    else:
                        print(f"\n[grpo-inspection] WARN @ step {step}: "
                              f"{collapsed_count}/{inspection_sample_n} of the "
                              f"most recent prompts had ALL 4 generations "
                              f"identical. Bumping rollout temperature "
                              f"{cur_temp:.2f} -> {new_temp:.2f}.")
                        try:
                            args.temperature = new_temp
                        except Exception as exc:
                            print(f"[grpo-inspection] could not patch "
                                  f"temperature on TRL args: {exc!r}")
                    wandb_utils.log({
                        "alarms/mode_collapse_count": collapsed_count,
                        "train/temperature_after_bump": new_temp,
                    }, step=step)

            # ----- Sample-completion table every sample_every steps ----- #
            if sample_every and step > 0 and step % sample_every == 0:
                rows_out = []
                # First sample_n unique prompts in this batch; emit a row per
                # generation (so the W&B table has gen_idx as a column).
                chosen_groups: list[tuple[str, list[_ScoredCompletion]]] = []
                seen_prompts: set = set()
                for (p, _), e in zip(keys, entries):
                    if p in seen_prompts:
                        for q, grp in chosen_groups:
                            if q == p:
                                grp.append(e)
                                break
                        continue
                    if len(chosen_groups) >= sample_n:
                        continue
                    chosen_groups.append((p, [e]))
                    seen_prompts.add(p)

                for prompt, grp in chosen_groups:
                    for gi, e in enumerate(grp[:4]):
                        rows_out.append({
                            "step": step,
                            "prompt": prompt[:600],
                            "gen_idx": gi,
                            "x_pred": ",".join(map(str, e.x_pred)),
                            "z_pred": ",".join(map(str, e.z_pred)),
                            "logical_correction":
                                e.rewards.get("logical_correction", 0.0),
                            "syndrome_consistency":
                                e.rewards.get("syndrome_consistency", 0.0),
                            "hamming_overlap":
                                e.rewards.get("hamming_overlap", 0.0),
                            "format_compliance":
                                e.rewards.get("format_compliance", 0.0),
                            "pymatching_beat":
                                e.rewards.get("pymatching_beat", 0.0),
                            "weighted_total": e.weighted_total,
                            "parse_success": e.parse_success,
                            "actual_obs_flip": e.actual_flip,
                            "pm_obs_flip": e.pm_flip,
                            "curriculum_level": e.curriculum_level,
                        })
                if rows_out:
                    wandb_utils.log_generation_table(
                        rows_out, step=step, table_name="rl/generations",
                        columns=[
                            "step", "prompt", "gen_idx", "x_pred", "z_pred",
                            "logical_correction", "syndrome_consistency",
                            "hamming_overlap", "format_compliance",
                            "pymatching_beat", "weighted_total",
                            "parse_success", "actual_obs_flip", "pm_obs_flip",
                            "curriculum_level",
                        ],
                    )

            # ----- Wall-clock cap ----- #
            elapsed = time.time() - started_at
            if elapsed > wall_seconds and not state["wall_exceeded"]:
                state["wall_exceeded"] = True
                print(f"\n[grpo-walltime] wall-clock cap hit at step {step} "
                      f"({elapsed:.0f}s > {wall_seconds:.0f}s). "
                      f"Saving and exiting.")
                try:
                    control.should_save = True
                    control.should_training_stop = True
                except Exception:
                    pass
                wandb_utils.log({
                    "alarms/wall_exceeded": 1.0,
                    "alarms/wall_seconds_at_cap": elapsed,
                }, step=step)

            # ----- Tier-2 + tier-3 eval ----- #
            if inloop_every and step > 0 and step % inloop_every == 0:
                self._run_inloop_eval(step)

        def on_log(self, args, state_, control, logs=None, **kwargs):  # noqa: D401
            if not logs:
                return
            step = state_.global_step

            # Tier-1: surface train/* metrics that TRL itself produces.
            extra: dict = {}
            for src_key, dst_key in [
                ("kl", "train/kl_divergence"),
                ("train/kl_divergence", "train/kl_divergence"),
                ("grad_norm", "train/grad_norm"),
                ("loss", "train/policy_loss"),
                ("learning_rate", "train/learning_rate"),
            ]:
                if src_key in logs:
                    try:
                        extra[dst_key] = float(logs[src_key])
                    except (TypeError, ValueError):
                        pass
            if extra:
                wandb_utils.log(extra, step=step)

            # KL alarm. Explicit None check: kl == 0.0 is a valid reading
            # and must not fall through to the other key.
            kl = logs.get("kl")
            if kl is None:
                kl = logs.get("train/kl_divergence")
            if kl is not None:
                try:
                    kl_v = float(kl)
                except (TypeError, ValueError):
                    kl_v = None
                if kl_v is not None and kl_v > kl_alarm:
                    wandb_utils.log({
                        "alarms/kl_alarm": 1.0,
                        "alarms/kl_alarm_value": kl_v,
                    }, step=step)
                    print(f"[grpo][step {step}] KL ALARM: {kl_v:.3f} "
                          f"> {kl_alarm:.3f} - inspect generations.")

            # Decision rule: grad_norm > ceil for N consecutive logs.
            gn = logs.get("grad_norm")
            if gn is not None:
                try:
                    gn_v = float(gn)
                except (TypeError, ValueError):
                    gn_v = None
                if gn_v is not None:
                    grad_norm_run.append(gn_v)
                    ceil = decision_thresholds["grad_norm_ceil"]
                    run_len = decision_thresholds["grad_norm_run_len"]
                    if (len(grad_norm_run) >= run_len
                            and all(x > ceil for x in grad_norm_run)
                            and step != state["grad_norm_warned_at"]):
                        print(f"\n[grpo-decision] CRITICAL @ step {step}: "
                              f"train/grad_norm > {ceil} for {run_len} "
                              f"consecutive logs ({list(grad_norm_run)}). "
" f"Recommend reducing LR (warning only - no " f"auto-action).") wandb_utils.log({ "alarms/grad_norm_runaway": 1.0, "alarms/grad_norm_value": gn_v, }, step=step) state["grad_norm_warned_at"] = step def on_train_end(self, args, state_, control, **kwargs): # noqa: D401 self._run_inloop_eval(state_.global_step, table_name="rl/final_eval") # ------------------------------------------------------------------ # # Tier-2 / tier-3 eval (greedy, T=0) # # ------------------------------------------------------------------ # def _run_inloop_eval(self, step: int, table_name: str = "rl/inloop_eval"): try: from unsloth import FastLanguageModel FastLanguageModel.for_inference(model) except Exception: model.eval() # type: ignore[attr-defined] n = len(eval_rows) stats = { "logical_correction": 0, "format_success": 0, "format_partial": 0, "pymatching_beat": 0, "syndrome_consistency_pass": 0, "exact_match_pymatching": 0, "total_reward_sum": 0.0, "completion_len_sum": 0, "hard_lcr_num": 0, "hard_lcr_den": 0, "ler_d3_p001_logical_errors": 0, "ler_d3_p001_total": 0, "ler_d3_p001_rounds": 0, } preview_rows = [] pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id import torch from qubit_medic.config import REWARD_WEIGHTS for ep_idx, row in enumerate(eval_rows): prompt = row["prompt"] episode_id = int(row.get("episode_id", -1)) try: chat = [{"role": "user", "content": prompt}] text = tokenizer.apply_chat_template( chat, tokenize=False, add_generation_prompt=True, ) except Exception: text = ("<|im_start|>user\n" + prompt + "\n<|im_end|>\n<|im_start|>assistant\n") inputs = tokenizer(text, return_tensors="pt").to(model.device) try: with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=inloop_max_new_tokens, do_sample=False, # greedy at T=0 per spec eos_token_id=tokenizer.eos_token_id, pad_token_id=pad_id, ) gen_ids = out[0][inputs["input_ids"].shape[1]:] completion = tokenizer.decode(gen_ids, skip_special_tokens=True) n_tokens = int(gen_ids.shape[0]) except Exception as exc: # pragma: no cover completion = f"" n_tokens = 0 # Score against the env. If episode_id has TTL'd we fall back # to a fresh reset so the run continues, but log nothing # special - the metric arithmetic is still correct. 
                try:
                    result = env_client.step(raw_response=completion,
                                             episode_id=episode_id)
                except Exception:
                    obs2 = env_client.reset(seed=row.get("seed"))
                    result = env_client.step(raw_response=completion,
                                             episode_id=obs2.episode_id)
                rwd = result.info.get("rewards", {}) or {}
                action = result.info.get("parsed_action", {}) or {}
                actual = int(result.info.get("actual_observable_flip", 0))
                pm_pred = int(result.info.get("pymatching_observable_pred", 0))

                we_correct = float(rwd.get("logical_correction", 0.0)) >= 0.5
                pm_correct = (pm_pred == actual)
                stats["logical_correction"] += int(we_correct)
                stats["format_success"] += int(action.get("parse_success",
                                                          False))
                stats["format_partial"] += int(
                    float(rwd.get("format_compliance", 0.0)) >= 0.5
                    and not action.get("parse_success", False)
                )
                stats["pymatching_beat"] += int(
                    float(rwd.get("pymatching_beat", 0.0)) >= 0.5)
                stats["syndrome_consistency_pass"] += int(
                    float(rwd.get("syndrome_consistency", 0.0)) >= 0.999)
                weighted = sum(
                    weight * max(0.0, min(1.0, float(rwd.get(name, 0.0))))
                    for name, weight in REWARD_WEIGHTS.items()
                )
                stats["total_reward_sum"] += max(0.0, min(1.0, weighted))
                stats["completion_len_sum"] += n_tokens

                pm_x = sorted(set(map(
                    int, result.info.get("pymatching_x_errors", []) or [])))
                pm_z = sorted(set(map(
                    int, result.info.get("pymatching_z_errors", []) or [])))
                our_x = sorted(set(map(
                    int, action.get("x_error_qubits", []) or [])))
                our_z = sorted(set(map(
                    int, action.get("z_error_qubits", []) or [])))
                if (action.get("parse_success", False)
                        and pm_x == our_x and pm_z == our_z):
                    stats["exact_match_pymatching"] += 1

                # Hard syndrome: >=2 stabilizers fired (the anti-hacking spec
                # forbids exposing true_x/true_z, so we use the syndrome bit
                # count from the cached eval row as the proxy).
                n_active = sum(1 for b in row.get("syndrome_bits", [])
                               if int(b))
                if n_active >= 2:
                    stats["hard_lcr_den"] += 1
                    stats["hard_lcr_num"] += int(we_correct)

                # Tier-3: per-round LER for d=3 / p=0.001 only.
                d = int(row.get("distance", 0))
                rnds = max(1, int(row.get("rounds", 0)))
                if d == 3 and abs(float(row.get("p", 0.0)) - 0.001) < 1e-6:
                    stats["ler_d3_p001_total"] += 1
                    stats["ler_d3_p001_rounds"] = rnds
                    if not we_correct:
                        stats["ler_d3_p001_logical_errors"] += 1

                if ep_idx < 4:
                    preview_rows.append({
                        "step": step,
                        "episode": ep_idx,
                        "completion": completion[:300],
                        "logical_correction":
                            rwd.get("logical_correction", 0.0),
                        "syndrome_consistency":
                            rwd.get("syndrome_consistency", 0.0),
                        "format_compliance":
                            rwd.get("format_compliance", 0.0),
                        "pymatching_beat": rwd.get("pymatching_beat", 0.0),
                        "weighted_total": weighted,
                    })

            denom = max(1, n)
            lcr = stats["logical_correction"] / denom
            beat_rate = stats["pymatching_beat"] / denom
            fmt_compliance = stats["format_success"] / denom
            hard_lcr = (stats["hard_lcr_num"] / max(1, stats["hard_lcr_den"])
                        if stats["hard_lcr_den"] else 0.0)
            sync_consistency_rate = stats["syndrome_consistency_pass"] / denom
            avg_completion_len = stats["completion_len_sum"] / denom
            mean_total_reward = stats["total_reward_sum"] / denom
            exact_match = stats["exact_match_pymatching"] / denom

            # Tier-3 LER per round, log10.
            ler_per_round = None
            ler_log10 = None
            if stats["ler_d3_p001_total"] > 0:
                p_logical = (stats["ler_d3_p001_logical_errors"]
                             / stats["ler_d3_p001_total"])
                rounds = max(1, stats["ler_d3_p001_rounds"])
                # Per-round LER: 1 - (1 - p_logical)^(1/rounds).
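                # Worked example with illustrative numbers: if 20% of these
                # episodes logically fail and each spans 5 rounds, then
                #   per-round LER = 1 - (1 - 0.2) ** (1 / 5) ~= 0.0437,
                #   log10(0.0437) ~= -1.36.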
                ler_per_round = 1.0 - (
                    1.0 - max(0.0, min(1.0, p_logical))) ** (1.0 / rounds)
                if ler_per_round > 0:
                    import math
                    ler_log10 = math.log10(max(ler_per_round, 1e-12))

            # Tier-2 output diversity probe at T=1.0 (8 samples per prompt
            # on a small subset to keep eval fast).
            div_probe_n = min(8, len(eval_rows))
            div_samples = 8
            unique_counts: list[int] = []
            for row in eval_rows[:div_probe_n]:
                prompt = row["prompt"]
                try:
                    chat = [{"role": "user", "content": prompt}]
                    text = tokenizer.apply_chat_template(
                        chat, tokenize=False, add_generation_prompt=True,
                    )
                except Exception:
                    text = ("<|im_start|>user\n" + prompt +
                            "\n<|im_end|>\n<|im_start|>assistant\n")
                inputs = tokenizer(text, return_tensors="pt").to(model.device)
                outs = []
                for _ in range(div_samples):
                    try:
                        with torch.no_grad():
                            out = model.generate(
                                **inputs,
                                max_new_tokens=inloop_max_new_tokens,
                                do_sample=True,
                                temperature=1.0,
                                top_p=0.95,
                                top_k=50,
                                eos_token_id=tokenizer.eos_token_id,
                                pad_token_id=pad_id,
                            )
                        gen = tokenizer.decode(
                            out[0][inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True,
                        ).strip()
                    except Exception:
                        gen = ""
                    outs.append(gen)
                unique_counts.append(len(set(outs)))
            output_diversity_t1 = (sum(unique_counts)
                                   / max(1, len(unique_counts))
                                   if unique_counts else 0.0)

            eval_metrics = {
                "eval/logical_correction_rate": lcr,
                "eval/pymatching_beat_rate": beat_rate,
                "eval/format_compliance": fmt_compliance,
                "eval/exact_match_pymatching": exact_match,
                "eval/hard_syndrome_lcr": hard_lcr,
                "eval/syndrome_consistency_rate": sync_consistency_rate,
                "eval/avg_completion_length": avg_completion_len,
                "eval/output_diversity_temp_1": output_diversity_t1,
                "eval/total_reward_mean": mean_total_reward,
                "eval/episodes": denom,
            }
            if ler_per_round is not None:
                eval_metrics["eval/ler_per_round_d3_p001"] = ler_per_round
            if ler_log10 is not None:
                eval_metrics["eval/ler_per_round_log10"] = ler_log10

            print(f"[grpo][eval@{step}] " + ", ".join(
                f"{k.split('/')[-1]}={v:.4f}" if isinstance(v, float)
                else f"{k.split('/')[-1]}={v}"
                for k, v in eval_metrics.items()
            ))
            wandb_utils.log(eval_metrics, step=step)
            if preview_rows:
                wandb_utils.log_generation_table(
                    preview_rows, step=step, table_name=table_name,
                )

            # Decision rule: step-500 pymatching_beat sanity.
            state["beat_rate_history"].append(beat_rate)
            if len(state["beat_rate_history"]) > 5:
                state["beat_rate_history"] = state["beat_rate_history"][-5:]
            if (not state["step500_warned"]
                    and step >= decision_thresholds["beat_rate_check_step"]
                    and len(state["beat_rate_history"]) >= 5
                    and all(b == 0 for b in state["beat_rate_history"])):
                print(f"\n[grpo-decision] WARN @ step {step}: "
                      f"eval/pymatching_beat_rate has been 0.0 across the "
                      f"last 5 evals. The model is never finding syndromes "
                      f"where PyMatching fails - consider increasing the "
                      f"pymatching_beat reward weight (warning only).")
                wandb_utils.log({"alarms/zero_beat_rate": 1.0}, step=step)
                state["step500_warned"] = True

            # Decision rule: format_compliance < floor.
            if (fmt_compliance < decision_thresholds["format_floor"]
                    and step != state["format_warned_at"]):
                print(f"\n[grpo-decision] WARN @ step {step}: "
                      f"eval/format_compliance={fmt_compliance:.3f} < "
                      f"{decision_thresholds['format_floor']}. "
Consider " f"increasing format_compliance weight (warning only).") wandb_utils.log({ "alarms/format_below_floor": 1.0, "alarms/format_value": fmt_compliance, }, step=step) state["format_warned_at"] = step # ----- Best-checkpoint tracking ----- # if mean_total_reward > state["best_total_reward"]: old = state["best_total_reward"] state["best_total_reward"] = mean_total_reward state["best_step"] = step print(f"[grpo][eval@{step}] new best total_reward_mean=" f"{mean_total_reward:.4f} (prev {old:.4f}); " f"saving to {best_dir}") try: if best_dir.exists(): shutil.rmtree(best_dir) best_dir.mkdir(parents=True, exist_ok=True) model.save_pretrained(str(best_dir)) tokenizer.save_pretrained(str(best_dir)) wandb_utils.update_summary({ "best/total_reward_mean": mean_total_reward, "best/step": step, }) except Exception as exc: print(f"[grpo] WARN: failed to save best checkpoint: " f"{exc!r}", file=sys.stderr) # Switch back to training mode. try: from unsloth import FastLanguageModel FastLanguageModel.for_training(model) except Exception: model.train() # type: ignore[attr-defined] return _RolloutCallback() # --------------------------------------------------------------------------- # # Dataset of prompts # # --------------------------------------------------------------------------- # def _build_prompt_pool(env_client, n: int): prompts = [] for _ in range(n): obs = env_client.reset() prompts.append({"prompt": obs.prompt, "episode_id": obs.episode_id}) return prompts # --------------------------------------------------------------------------- # # Main # # --------------------------------------------------------------------------- # def main(argv: Iterable[str] = ()) -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--sft-checkpoint", type=str, default=None, help="LoRA adapter directory to start GRPO from. " "Defaults to config.SFT_CHECKPOINT_PATH_FOR_GRPO " "(checkpoints/sft_warmup/checkpoint-50).") parser.add_argument("--output", type=str, default="checkpoints/grpo") parser.add_argument("--model", type=str, default=os.getenv( "QUBIT_MEDIC_MODEL", "unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit"), help="Base model. 


# --------------------------------------------------------------------------- #
# Main                                                                         #
# --------------------------------------------------------------------------- #


def main(argv: Iterable[str] = ()) -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--sft-checkpoint", type=str, default=None,
                        help="LoRA adapter directory to start GRPO from. "
                             "Defaults to config.SFT_CHECKPOINT_PATH_FOR_GRPO "
                             "(checkpoints/sft_warmup/checkpoint-50).")
    parser.add_argument("--output", type=str, default="checkpoints/grpo")
    parser.add_argument("--model", type=str,
                        default=os.getenv(
                            "QUBIT_MEDIC_MODEL",
                            "unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit"),
                        help="Base model. Defaults to the 4-bit unsloth "
                             "bundle matching the SFT base.")
    parser.add_argument("--steps", type=int, default=None)
    parser.add_argument("--gen-per-prompt", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--kl-coef", type=float, default=None)
    parser.add_argument("--max-prompt-len", type=int, default=None)
    parser.add_argument("--max-completion-len", type=int, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--prompt-pool", type=int, default=512)
    parser.add_argument("--wandb-run-name", type=str, default=None)
    parser.add_argument("--wandb-group", type=str, default=None)
    parser.add_argument("--wandb-tags", type=str, nargs="*",
                        default=("grpo",))
    parser.add_argument("--wandb-notes", type=str, default=None)
    parser.add_argument("--sample-every", type=int, default=None)
    parser.add_argument("--sample-n", type=int, default=None)
    parser.add_argument("--inloop-eval-every", type=int, default=None)
    parser.add_argument("--inloop-eval-episodes", type=int, default=None)
    parser.add_argument("--kl-alarm", type=float, default=None)
    parser.add_argument("--no-artifact", action="store_true")
    parser.add_argument("--skip-preflight", action="store_true",
                        help="Skip the diversity preflight (DEBUG ONLY)")
    args = parser.parse_args(list(argv))

    # Lazy heavy imports.
    try:
        from unsloth import FastLanguageModel
    except ImportError:
        print("ERROR: unsloth not installed. "
              "Run `pip install -r requirements-train.txt`", file=sys.stderr)
        return 1

    import torch
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    # Belt-and-suspenders for the dynamo recompile-limit crash that killed a
    # previous run at step 400. The env vars at the top of the file cover the
    # case where torch hasn't been imported yet; this block covers the case
    # where unsloth/torch were already imported (env vars no-op then) and
    # also flips suppress_errors so any future overflow falls back to eager
    # instead of raising FailOnRecompileLimitHit.
    try:
        import torch._dynamo as _dynamo
        for _attr in ("cache_size_limit", "recompile_limit",
                      "accumulated_cache_size_limit"):
            if hasattr(_dynamo.config, _attr):
                setattr(_dynamo.config, _attr,
                        max(256, getattr(_dynamo.config, _attr)))
        _dynamo.config.suppress_errors = True
    except Exception as _exc:  # pragma: no cover - defensive
        print(f"[grpo-guard] WARNING: could not raise dynamo limits: "
              f"{_exc!r}", file=sys.stderr)

    # Pre-flight signature check + stale-cache wipe.
    _wipe_stale_grpo_cache()
    _assert_grpo_signature_compatible()

    from qubit_medic import wandb_utils
    from qubit_medic.client.client import make_default_client
    from qubit_medic.config import (
        GRPO_BATCH_SIZE,
        GRPO_CHECKPOINT_EVERY,
        GRPO_DECISION_BEAT_RATE_CHECK_STEP,
        GRPO_DECISION_FORMAT_FLOOR,
        GRPO_DECISION_GRAD_NORM_CEIL,
        GRPO_DECISION_GRAD_NORM_RUN_LEN,
        GRPO_DECISION_REWARD_STD_CHECK_STEP,
        GRPO_DECISION_REWARD_STD_FLOOR,
        GRPO_DO_SAMPLE,
        GRPO_GEN_PER_PROMPT,
        GRPO_GRAD_ACCUM,
        GRPO_INSPECTION_COLLAPSE_THRESHOLD,
        GRPO_INSPECTION_HOOK_EVERY,
        GRPO_INSPECTION_SAMPLE_N,
        GRPO_KL_ALARM,
        GRPO_KL_COEF,
        GRPO_LOG_EVERY,
        GRPO_LR,
        GRPO_LR_SCHEDULER,
        GRPO_MAX_COMPLETION_LEN,
        GRPO_MAX_PROMPT_LEN,
        GRPO_OPTIMIZER,
        GRPO_REPETITION_PENALTY,
        GRPO_SAMPLE_LOG_EVERY,
        GRPO_SAMPLE_LOG_N,
        GRPO_SAVE_TOTAL_LIMIT,
        GRPO_STEPS,
        GRPO_TEMP_BUMP_ON_COLLAPSE,
        GRPO_TEMPERATURE,
        GRPO_TOP_K,
        GRPO_TOP_P,
        GRPO_VAL_EPISODES,
        GRPO_VAL_PATH,
        GRPO_VAL_SEED,
        GRPO_WALL_SECONDS,
        LORA_ALPHA,
        LORA_DROPOUT,
        LORA_R,
        LORA_TARGET_MODULES,
        MODEL_ID,
        PRIMARY_SEED,
        REWARD_WEIGHTS,
        SFT_CHECKPOINT_PATH_FOR_GRPO,
        WANDB_INLOOP_EVAL_EPISODES,
        WANDB_INLOOP_EVAL_EVERY,
    )

    sft_ckpt = args.sft_checkpoint or SFT_CHECKPOINT_PATH_FOR_GRPO
    steps = args.steps if args.steps is not None else GRPO_STEPS
    gen_per_prompt = (args.gen_per_prompt
                      if args.gen_per_prompt is not None
                      else GRPO_GEN_PER_PROMPT)
    lr = args.lr if args.lr is not None else GRPO_LR
    kl_coef = args.kl_coef if args.kl_coef is not None else GRPO_KL_COEF
    max_p = (args.max_prompt_len
             if args.max_prompt_len is not None else GRPO_MAX_PROMPT_LEN)
    max_c = (args.max_completion_len
             if args.max_completion_len is not None
             else GRPO_MAX_COMPLETION_LEN)
    seed = args.seed if args.seed is not None else PRIMARY_SEED
    sample_every = (args.sample_every
                    if args.sample_every is not None
                    else GRPO_SAMPLE_LOG_EVERY)
    sample_n = args.sample_n if args.sample_n is not None else GRPO_SAMPLE_LOG_N
    inloop_every = (args.inloop_eval_every
                    if args.inloop_eval_every is not None
                    else WANDB_INLOOP_EVAL_EVERY)
    inloop_episodes = (args.inloop_eval_episodes
                       if args.inloop_eval_episodes is not None
                       else WANDB_INLOOP_EVAL_EPISODES)
    kl_alarm = args.kl_alarm if args.kl_alarm is not None else GRPO_KL_ALARM

    _seed_everything(seed)

    # ---- Env client --------------------------------------------------- #
    env_client = make_default_client()
    print(f"using env client: {type(env_client).__name__}; "
          f"health = {env_client.health()}")

    # ---- W&B init ----------------------------------------------------- #
    report_to = wandb_utils.derive_report_to(args.report_to)
    run_name = args.wandb_run_name or wandb_utils.make_run_name("grpo")
    wandb_utils.init_run(
        run_name=run_name,
        job_type="grpo",
        tags=args.wandb_tags,
        notes=args.wandb_notes,
        group=args.wandb_group,
        extra_config={
            "cli": {
                "steps": steps,
                "gen_per_prompt": gen_per_prompt,
                "lr": lr,
                "kl_coef": kl_coef,
                "max_prompt_len": max_p,
                "max_completion_len": max_c,
                "prompt_pool": args.prompt_pool,
                "sample_every": sample_every,
                "sample_n": sample_n,
                "inloop_eval_every": inloop_every,
                "inloop_eval_episodes": inloop_episodes,
                "kl_alarm": kl_alarm,
                "temperature": GRPO_TEMPERATURE,
                "top_p": GRPO_TOP_P,
                "top_k": GRPO_TOP_K,
                "repetition_penalty": GRPO_REPETITION_PENALTY,
                "do_sample": GRPO_DO_SAMPLE,
                "lr_scheduler": GRPO_LR_SCHEDULER,
                "optimizer": GRPO_OPTIMIZER,
                "grad_accum": GRPO_GRAD_ACCUM,
                "effective_batch": GRPO_BATCH_SIZE * GRPO_GRAD_ACCUM,
                "sft_checkpoint": sft_ckpt,
                "model": args.model,
                "seed": seed,
                "report_to": report_to,
                "wall_seconds": GRPO_WALL_SECONDS,
                "reward_weights": dict(REWARD_WEIGHTS),
                "val_seed": GRPO_VAL_SEED,
"val_episodes": GRPO_VAL_EPISODES, }, }, ) # Use train/global_step as default x-axis for everything we log. try: run = wandb_utils.get_run() if run is not None: run.define_metric("train/global_step") run.define_metric("train/*", step_metric="train/global_step") run.define_metric("eval/*", step_metric="train/global_step") run.define_metric("alarms/*", step_metric="train/global_step") run.define_metric("rl/*", step_metric="train/global_step") run.define_metric("best/*", step_metric="train/global_step") except Exception as exc: print(f"[wandb] could not define x-axis metric: {exc!r}", file=sys.stderr) # ---- Build prompt pool -------------------------------------------- # print(f"pre-generating {args.prompt_pool} prompts ...") prompts = _build_prompt_pool(env_client, args.prompt_pool) dataset = Dataset.from_list(prompts) print(f" built dataset with {len(dataset)} prompts") # ---- Frozen eval set --------------------------------------------- # eval_rows = _load_or_build_eval_set( env_client, seed=GRPO_VAL_SEED, n=inloop_episodes, path=GRPO_VAL_PATH, ) # ---- Load model --------------------------------------------------- # print(f"loading base={args.model}, sft adapter={sft_ckpt}") base_for_load = sft_ckpt if Path(sft_ckpt).exists() else args.model model, tokenizer = FastLanguageModel.from_pretrained( model_name=base_for_load, max_seq_length=max_p + max_c, load_in_4bit=True, dtype=None, ) if not Path(sft_ckpt).exists(): print(f"[grpo] WARN: SFT checkpoint {sft_ckpt} not found; " f"attaching fresh LoRA on the base model") model = FastLanguageModel.get_peft_model( model, r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=list(LORA_TARGET_MODULES), lora_dropout=LORA_DROPOUT, bias="none", use_gradient_checkpointing="unsloth", random_state=seed, ) # ---- Diversity preflight ----------------------------------------- # if not args.skip_preflight: ok = _diversity_preflight( model, tokenizer, val_path="data/sft_validation.jsonl", n_prompts=5, n_samples_per_prompt=8, temperature=GRPO_TEMPERATURE, min_unique=3, min_passing=3, max_new_tokens=max_c, ) if not ok: wandb_utils.update_summary({"preflight/passed": False}) wandb_utils.finish_run() return 2 wandb_utils.update_summary({"preflight/passed": True}) else: print("[grpo] --skip-preflight given; bypassing diversity preflight " "(DEBUG ONLY)") # ---- TRL GRPOConfig ---------------------------------------------- # output_dir = Path(args.output) output_dir.mkdir(parents=True, exist_ok=True) best_dir = output_dir / "best" final_dir = output_dir / "final" bf16_supported = ( torch.cuda.is_available() and torch.cuda.is_bf16_supported() ) grpo_kwargs: dict = { "output_dir": str(output_dir), "max_steps": steps, "per_device_train_batch_size": GRPO_BATCH_SIZE, "gradient_accumulation_steps": GRPO_GRAD_ACCUM, "num_generations": gen_per_prompt, "max_prompt_length": max_p, "max_completion_length": max_c, "learning_rate": lr, "beta": kl_coef, "lr_scheduler_type": GRPO_LR_SCHEDULER, "optim": GRPO_OPTIMIZER, "logging_steps": GRPO_LOG_EVERY, "save_steps": GRPO_CHECKPOINT_EVERY, "save_total_limit": GRPO_SAVE_TOTAL_LIMIT, "save_only_model": False, "seed": seed, "bf16": bf16_supported, "fp16": torch.cuda.is_available() and not bf16_supported, "report_to": report_to, "run_name": run_name, # Diversity-focused rollout sampling. "temperature": GRPO_TEMPERATURE, "top_p": GRPO_TOP_P, "top_k": GRPO_TOP_K, "repetition_penalty": GRPO_REPETITION_PENALTY, } # Some TRL versions don't accept every sampling kwarg on GRPOConfig; # fall back gracefully so the script still runs. 
    config = None
    dropped: list[str] = []
    while config is None:
        try:
            config = GRPOConfig(**grpo_kwargs)
        except TypeError as exc:
            msg = str(exc)
            removed = False
            for k in ("repetition_penalty", "top_k", "top_p", "temperature",
                      "save_only_model"):
                if k in msg and k in grpo_kwargs:
                    grpo_kwargs.pop(k)
                    dropped.append(k)
                    removed = True
                    break
            if not removed:
                raise
    if dropped:
        print(f"[grpo] WARN: TRL did not accept these GRPOConfig kwargs and "
              f"they were dropped: {dropped}. Using TRL defaults for them.")

    # ---- Reward functions + scoring cache ----------------------------- #
    cache = _BatchScoringCache(env_client=env_client,
                               reward_weights=dict(REWARD_WEIGHTS))
    reward_fns = _make_reward_fns(cache)
    # The first reward is the bounded weighted-total used for the gradient;
    # the rest are zero-weight observers used only for per-component traces.
    reward_weights = [1.0] + [0.0] * len(_REWARD_COMPONENTS)

    callbacks = []
    cb = _build_wandb_callback(
        cache, model, tokenizer, env_client, eval_rows,
        sample_every=sample_every,
        sample_n=sample_n,
        inloop_every=inloop_every,
        inloop_max_new_tokens=max_c,
        kl_alarm=kl_alarm,
        inspection_every=GRPO_INSPECTION_HOOK_EVERY,
        inspection_sample_n=GRPO_INSPECTION_SAMPLE_N,
        inspection_collapse_threshold=GRPO_INSPECTION_COLLAPSE_THRESHOLD,
        temp_bump_on_collapse=GRPO_TEMP_BUMP_ON_COLLAPSE,
        best_dir=best_dir,
        output_dir=output_dir,
        wall_seconds=GRPO_WALL_SECONDS,
        decision_thresholds={
            "reward_std_floor": GRPO_DECISION_REWARD_STD_FLOOR,
            "reward_std_check_step": GRPO_DECISION_REWARD_STD_CHECK_STEP,
            "beat_rate_check_step": GRPO_DECISION_BEAT_RATE_CHECK_STEP,
            "format_floor": GRPO_DECISION_FORMAT_FLOOR,
            "grad_norm_ceil": GRPO_DECISION_GRAD_NORM_CEIL,
            "grad_norm_run_len": GRPO_DECISION_GRAD_NORM_RUN_LEN,
        },
    )
    if cb is not None:
        callbacks.append(cb)

    # Older TRL versions: GRPOTrainer may not accept the reward_weights kw.
    trainer_kwargs = dict(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=reward_fns,
        reward_weights=reward_weights,
        callbacks=callbacks,
    )
    try:
        trainer = GRPOTrainer(**trainer_kwargs)
    except TypeError as exc:
        if "reward_weights" in str(exc):
            print("[grpo] WARN: this TRL does not accept reward_weights= "
                  "on GRPOTrainer; falling back to using only the bounded "
                  "weighted-total reward (and no observers).")
            trainer_kwargs.pop("reward_weights")
            trainer_kwargs["reward_funcs"] = [reward_fns[0]]
            trainer = GRPOTrainer(**trainer_kwargs)
        else:
            raise

    print(f"running GRPO for {steps} steps "
          f"(temperature={GRPO_TEMPERATURE}, top_p={GRPO_TOP_P}, "
          f"top_k={GRPO_TOP_K}, repetition_penalty={GRPO_REPETITION_PENALTY}, "
          f"beta={kl_coef}, lr={lr}) ...")
    started = time.time()
    train_result = trainer.train()
    elapsed = time.time() - started
    print(f"finished in {elapsed:.1f}s")

    metrics = getattr(train_result, "metrics", {}) or {}
    wandb_utils.update_summary({
        "grpo/wall_seconds": elapsed,
        "grpo/total_episodes": cache._episodes,
        "grpo/total_timeouts": cache._timeouts,
        "grpo/reward_bounds_violations": cache._bounds_violations,
        **{f"grpo/final/{k}": v for k, v in metrics.items()
           if isinstance(v, (int, float))},
    })

    # ---- Final + rolling adapter saves ------------------------------- #
    print(f"saving rolling adapter snapshot to {output_dir}")
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"saving final adapter snapshot to {final_dir}")
    final_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))

    if not args.no_artifact:
        wandb_utils.log_artifact(
            str(final_dir),
            name=f"grpo-final-{run_name}",
            artifact_type="model",
            description="GRPO final LoRA adapter (Qubit-Medic).",
        )
        if best_dir.exists():
            wandb_utils.log_artifact(
                str(best_dir),
                name=f"grpo-best-{run_name}",
                artifact_type="model",
                description="GRPO best-eval LoRA adapter (Qubit-Medic).",
            )

    wandb_utils.finish_run()
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))