| """scripts/train_grpo.py - GRPO RL phase (2026-04 spec rewrite). | |
| Loads the SFT-warm-started LoRA adapter at | |
| ``checkpoints/sft_warmup/checkpoint-50`` on top of the 4-bit NF4 quantised | |
| ``unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit`` base, connects to the | |
| OpenEnv server (local or remote via ``QUBIT_MEDIC_URL``), and runs TRL's | |
| :class:`GRPOTrainer` for 1,500 steps with diversity-focused rollout | |
| sampling (temperature=1.2, top_p=0.95, top_k=50, repetition_penalty=1.1) | |
| and a weighted 5-component reward bounded to [0, 1]. | |
| Why diversity-focused sampling | |
| ------------------------------ | |
| The first GRPO attempt (temperature=0.7) collapsed inside 100 steps to a | |
| constant ``X_ERRORS=[] Z_ERRORS=[]`` policy: every group of 4 generations | |
| was byte-identical, so within-group reward variance was zero and the | |
| GRPO advantage was exactly zero - no gradient. The new sampler defaults | |
| broaden per-token entropy enough to keep within-group variance positive, | |
| which is what GRPO needs to learn anything. | |
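
A toy illustration of the failure mode (illustrative numbers, not executed)::

    group_rewards = [0.5, 0.5, 0.5, 0.5]   # 4 byte-identical generations
    mean, std = 0.5, 0.0                   # within-group variance is zero
    # advantage_i = (r_i - mean) / max(std, eps) == 0 for every member,
    # so every policy-gradient term vanishes and the update is a no-op.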

Major spec features wired up here
---------------------------------
* ``_diversity_preflight`` - 5 prompts x 8 completions at T=1.2; abort if
  fewer than 3 prompts hit >=3 unique completions - the model is then too
  collapsed for GRPO to recover.
* Frozen 200-syndrome eval set seeded ``4284`` (matching the SFT validation
  seed). Cached to ``data/grpo_validation.jsonl`` so reruns and offline
  inspection see the same prompts.
* Tier-1 training metrics (every 5 steps): total_reward_mean,
  reward_std_within_group, completion_uniqueness, advantage_mean_abs,
  kl_divergence, grad_norm, policy_loss, learning_rate.
* Tier-2 eval metrics (every 100 steps, greedy at T=0): logical
  correction rate, pymatching beat rate, format compliance, exact-match
  pymatching, hard-syndrome (>=2 errors) LCR, syndrome consistency,
  avg_completion_length, output_diversity at T=1.0.
* Tier-3 (every eval): per-round logical-error rate at d=3 p=0.001 plus
  its log10 transform.
* Sample-completion table every 50 steps: 5 random eval prompts, all 4
  rollouts for each, per-component rewards, and the parsed action.
* Anti-hacking: 30s per-episode timeout (server-side), reward bounds
  enforced both pre-multiply and post-sum, and a mode-collapse inspection
  every 100 steps that auto-raises temperature by 0.2 if >7 of the
  last 10 prompts produced 4 byte-identical generations.
* Wall-clock cap: 13h. Saves and exits cleanly if exceeded.
* Best-checkpoint tracking: writes ``output/best/`` whenever a new best
  ``eval/total_reward_mean`` is observed. Final state always saves to
  ``output/final/`` regardless of rank.
* Decision rules (warnings only, no auto-fix): a step-50 reward-variance
  floor, a step-500 pymatching-beat sanity check, a format-compliance
  floor, and a 3-consecutive-log grad-norm runaway alarm.

Usage::

    python -m scripts.train_grpo \
        --sft-checkpoint checkpoints/sft_warmup/checkpoint-50 \
        --output checkpoints/grpo \
        --report-to wandb
"""
from __future__ import annotations

import argparse
import inspect
import json
import os
import random

# torch._dynamo recompile-limit guard. Unsloth's GRPO trainer wraps the
# loss/generation graph in torch.compile(fullgraph=True). Two things blow
# past Dynamo's default cache_size_limit (8) over a long GRPO run:
#   1. The mode-collapse hook mutates trainer.args.temperature in flight
#      (e.g. 1.2 -> 1.4 -> 1.6 -> 1.8); each mutation re-specializes the
#      compiled generation path.
#   2. Variable prompt/completion shapes specialize over hundreds of steps.
# When the limit is hit, fullgraph=True turns the overflow into a fatal
# FailOnRecompileLimitHit (we lost a run at step 400/1500 to this). Set the
# limits high before torch is imported so they take effect everywhere.
os.environ.setdefault("TORCHDYNAMO_CACHE_SIZE_LIMIT", "256")
os.environ.setdefault("TORCHDYNAMO_RECOMPILE_LIMIT", "256")

import shutil
import sys
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable

# --------------------------------------------------------------------------- #
# Pre-flight: detect the unsloth / unsloth_zoo signature skew that crashes    #
# GRPO at step 0 with a misleading TypeError.                                 #
# --------------------------------------------------------------------------- #
#
# unsloth==2025.11.1's GRPO trainer template calls
#   grpo_accumulated_loss(..., old_hidden_states=..., ref_hidden_states=...)
# but unsloth_zoo>=2026.x renamed those positional args to old_logps /
# ref_logps with no compat shim. Pip's resolver (with the unpinned `unsloth`
# line in requirements-train.txt) silently couples the two: it picks
#   unsloth==2025.11.1 + unsloth_zoo==2026.4.9
# and that pair crashes at the first training step with:
#   TypeError: grpo_accumulated_loss() missing 2 required positional
#   arguments: 'old_logps' and 'ref_logps'
#
# SFT does not exercise this code path, so SFT finishes cleanly first, the
# checkpoint gets saved, and only then GRPO blows up - wasting the whole SFT
# run. This guard runs in well under a second, before any GPU work, and
# prints the exact pip command to fix it instead of the cryptic TypeError.
# --------------------------------------------------------------------------- #
def _assert_grpo_signature_compatible() -> None:
    """Abort early if the installed unsloth_zoo signature does not match
    the call pattern baked into the installed unsloth.
    """
    try:
        import unsloth  # noqa: F401 (force the patches to apply first)
        import unsloth_zoo
        from unsloth_zoo.rl_replacements import grpo_accumulated_loss
    except Exception as exc:
        print(f"[grpo-guard] WARNING: could not introspect unsloth_zoo "
              f"({exc!r}); skipping signature check.", file=sys.stderr)
        return
    params = list(inspect.signature(grpo_accumulated_loss).parameters.keys())
    has_hidden = "old_hidden_states" in params and "ref_hidden_states" in params
    has_logps = "old_logps" in params and "ref_logps" in params
    # The unsloth in this repo is pinned to the 2025.11.x lineage (matching
    # what SFT just used). That lineage calls with old_hidden_states= /
    # ref_hidden_states=. If unsloth_zoo has those names, we are fine.
    if has_hidden:
        return
    unsloth_ver = getattr(unsloth, "__version__", "?")
    zoo_ver = getattr(unsloth_zoo, "__version__", "?")
    have_logps_only = has_logps and not has_hidden
    msg = [
        "",
        "=" * 78,
        "[grpo-guard] FATAL: unsloth / unsloth_zoo signature mismatch detected.",
        "=" * 78,
        f"  unsloth     == {unsloth_ver}",
        f"  unsloth_zoo == {zoo_ver}",
        f"  grpo_accumulated_loss parameters: {params}",
        "",
        "  unsloth (this version) calls grpo_accumulated_loss with",
        "      old_hidden_states=... , ref_hidden_states=...",
        "  but the installed unsloth_zoo expects",
        "      old_logps=... , ref_logps=...",
        "  as required positional arguments." if have_logps_only else
        "  but the installed unsloth_zoo signature does not contain those names.",
        "",
        "  Without this fix, GRPO will crash at step 0 with:",
        "    TypeError: grpo_accumulated_loss() missing 2 required positional",
        "    arguments: 'old_logps' and 'ref_logps'",
        "",
        "  Fix on Colab (one-liner):",
        "    pip install --no-deps --force-reinstall unsloth_zoo==2025.11.1 \\",
        "      && rm -rf unsloth_compiled_cache",
        "",
        "  Then re-run:",
        "    python -m scripts.train_grpo --sft-checkpoint "
        "checkpoints/sft_warmup/checkpoint-50 \\",
        "      --output checkpoints/grpo",
        "=" * 78,
        "",
    ]
    raise SystemExit("\n".join(msg))


def _wipe_stale_grpo_cache() -> None:
    """Remove unsloth_compiled_cache/UnslothGRPOTrainer.py if present.

    The cache file is regenerated automatically by unsloth on the next
    GRPO import using the *currently installed* unsloth_zoo source, so
    deleting it is safe and is the only way to recover after fixing a
    previously mismatched install.
    """
    cache_file = Path("unsloth_compiled_cache") / "UnslothGRPOTrainer.py"
    if cache_file.exists():
        print(f"[grpo-guard] removing stale {cache_file} so it regenerates "
              f"against the current unsloth_zoo install")
        try:
            cache_file.unlink()
        except OSError as exc:
            print(f"[grpo-guard] WARNING: failed to remove {cache_file}: "
                  f"{exc!r}", file=sys.stderr)


# --------------------------------------------------------------------------- #
# Per-batch scoring cache + reward bounds enforcement                         #
# --------------------------------------------------------------------------- #
#
# The original implementation called the env 5 times per (prompt, completion)
# - once per reward function. We fix that with a single (prompt, completion)
# -> breakdown cache keyed inside one GRPO step, AND we apply the spec's
# weighted-sum + [0, 1] clip in one place so every reward function returns
# a number that is already correctly weighted.
# --------------------------------------------------------------------------- #
@dataclass
class _ScoredCompletion:
    """One scored (prompt, completion) pair, keyed by the env episode."""

    rewards: dict          # raw per-component rewards from the env (in [0, 1])
    weighted_total: float  # weighted sum, clipped to [0, 1]
    parse_success: bool
    parse_partial: bool
    x_pred: list
    z_pred: list
    actual_flip: int
    pm_flip: int
    elapsed: float
    timed_out: bool
    curriculum_level: str
    bounds_violations: int  # >0 if the env returned a component outside [0, 1]


@dataclass
class _BatchScoringCache:
    """Caches per-(prompt, completion) scores within one GRPO step."""

    env_client: object
    reward_weights: dict
    _cache: dict = field(default_factory=dict)
    _step_keys: list = field(default_factory=list)
    _lock: threading.Lock = field(default_factory=threading.Lock)
    _all_curriculum_stats: dict = field(default_factory=dict)
    _episodes: int = 0
    _timeouts: int = 0
    _bounds_violations: int = 0

    def _enforce_bounds(self, name: str, val: float) -> tuple[float, bool]:
        """Clip a reward component to [0, 1]; flag if it was outside."""
        v = float(val)
        if v < 0.0 or v > 1.0:
            return max(0.0, min(1.0, v)), True
        return v, False

    def score(self, prompt: str, completion: str) -> _ScoredCompletion:
        key = (prompt, completion)
        with self._lock:
            entry = self._cache.get(key)
            if entry is not None:
                return entry
        # Env work is independent across (p, c), so it is safe to release
        # the lock during the network round-trip.
        obs = self.env_client.reset()
        result = self.env_client.step(raw_response=completion,
                                      episode_id=obs.episode_id)
        info = result.info
        action = info.get("parsed_action", {})
        # Apply spec weights + [0, 1] bounds enforcement.
        raw = info.get("rewards", {}) or {}
        violations = 0
        weighted_sum = 0.0
        bounded_components: dict = {}
        for name, weight in self.reward_weights.items():
            v, was_oob = self._enforce_bounds(name, raw.get(name, 0.0))
            bounded_components[name] = v
            weighted_sum += weight * v
            if was_oob:
                violations += 1
        # Clip the weighted sum to [0, 1] (already in range when the
        # components are; defensive against weights that don't sum to 1.0).
        weighted_total = max(0.0, min(1.0, weighted_sum))
        # Preserve the env's "total" alongside our weighted total so the
        # downstream wandb log_reward_breakdown still works.
        bounded_components["total"] = weighted_total
        scored = _ScoredCompletion(
            rewards=bounded_components,
            weighted_total=weighted_total,
            parse_success=bool(action.get("parse_success", False)),
            parse_partial=False,
            x_pred=list(action.get("x_error_qubits", []) or []),
            z_pred=list(action.get("z_error_qubits", []) or []),
            actual_flip=int(info.get("actual_observable_flip", 0)),
            pm_flip=int(info.get("pymatching_observable_pred", 0)),
            elapsed=float(info.get("elapsed_seconds", 0.0)),
            timed_out=bool(info.get("timed_out", False)),
            curriculum_level=str(getattr(obs, "curriculum_level", "")),
            bounds_violations=violations,
        )
        with self._lock:
            self._cache[key] = scored
            self._step_keys.append(key)
            self._all_curriculum_stats = info.get("curriculum_stats", {}) or {}
            self._episodes += 1
            if scored.timed_out:
                self._timeouts += 1
            if violations:
                self._bounds_violations += violations
        return scored

    def drain_step(self):
        """Pop everything cached since the last drain_step() call."""
        with self._lock:
            entries = [self._cache[k] for k in self._step_keys]
            keys = list(self._step_keys)
            self._step_keys.clear()
            # Bound memory use over long runs full of unique strings.
            if len(self._cache) > 4096:
                self._cache.clear()
        return entries, keys


def _seed_everything(seed: int) -> None:
    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# --------------------------------------------------------------------------- #
# Reward function factory                                                     #
# --------------------------------------------------------------------------- #
#
# Spec: total reward = sum of weighted components, clipped to [0, 1].
# Implementation: the cache returns a per-completion weighted_total in
# [0, 1]. We expose ONE TRL reward function that returns that bounded
# total, plus zero-weight per-component observers so wandb gets per-
# component traces without altering the policy gradient.
# --------------------------------------------------------------------------- #
_REWARD_COMPONENTS = (
    "logical_correction",
    "hamming_overlap",
    "syndrome_consistency",
    "format_compliance",
    "pymatching_beat",
)


def _make_reward_fns(cache: _BatchScoringCache):
    def total_fn(prompts, completions, **_unused):
        scored = [cache.score(p, c) for p, c in zip(prompts, completions)]
        return [s.weighted_total for s in scored]

    total_fn.__name__ = "reward_total"

    observers: list = []
    for name in _REWARD_COMPONENTS:
        # Bind component_name through a factory so each observer closes over
        # its own name rather than the loop variable (late-binding pitfall).
        def _factory(component_name: str):
            def fn(prompts, completions, **_unused):
                scored = [cache.score(p, c) for p, c in zip(prompts, completions)]
                return [s.rewards.get(component_name, 0.0) for s in scored]

            fn.__name__ = f"reward_obs_{component_name}"
            return fn

        observers.append(_factory(name))
    return [total_fn] + observers
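

# How TRL consumes these functions - a minimal sketch with hypothetical
# numbers (TRL calls each reward function once per batch with parallel
# prompt/completion lists and expects one float per completion):
#
#     fns = _make_reward_fns(cache)
#     totals = fns[0](prompts=["p0", "p0"], completions=["a", "b"])
#     # -> e.g. [0.80, 0.35]: the bounded weighted totals that drive the
#     # gradient. fns[1:] return raw component traces; GRPOTrainer scales
#     # them by the reward_weights=[1.0, 0.0, ...] passed in main(), so
#     # they are logged without affecting the policy update.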


# --------------------------------------------------------------------------- #
# Frozen eval set: 200 syndromes seeded GRPO_VAL_SEED.                        #
# --------------------------------------------------------------------------- #
#
# We snapshot the 200 prompts to data/grpo_validation.jsonl on the first run
# so reruns hit byte-identical syndromes, and so the file can be inspected /
# diffed offline. If the file already exists with >= n rows, we trust it.
# --------------------------------------------------------------------------- #
def _load_or_build_eval_set(env_client, *, seed: int, n: int, path: str) -> list[dict]:
    p = Path(path)
    if p.exists():
        rows: list[dict] = []
        with p.open("r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                rows.append(json.loads(line))
        if len(rows) >= n:
            print(f"[grpo-eval] reusing cached eval set: {p} ({len(rows)} rows)")
            return rows[:n]
        print(f"[grpo-eval] cached eval set at {p} has {len(rows)} < {n} rows; "
              f"regenerating")
    p.parent.mkdir(parents=True, exist_ok=True)
    rows = []
    print(f"[grpo-eval] building frozen eval set seed={seed} n={n} -> {p}")
    cur_seed = seed
    for _ in range(n):
        obs = env_client.reset(seed=cur_seed)
        rows.append({
            "prompt": obs.prompt,
            "episode_id": int(obs.episode_id),
            "curriculum_level": str(getattr(obs, "curriculum_level", "")),
            "distance": int(getattr(obs, "distance", 0)),
            "rounds": int(getattr(obs, "rounds", 0)),
            "p": float(getattr(obs, "p", 0.0)),
            "syndrome_bits": list(getattr(obs, "syndrome_bits", []) or []),
            "seed": cur_seed,
        })
        cur_seed += 1  # deterministic, reproducible
    with p.open("w") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    print(f"[grpo-eval] wrote {len(rows)} eval rows to {p}")
    return rows


# --------------------------------------------------------------------------- #
# Diversity preflight                                                         #
# --------------------------------------------------------------------------- #
def _diversity_preflight(model, tokenizer, *, val_path: str, n_prompts: int = 5,
                         n_samples_per_prompt: int = 8, temperature: float = 1.2,
                         min_unique: int = 3, min_passing: int = 3,
                         max_new_tokens: int = 50) -> bool:
    """Generate ``n_samples_per_prompt`` completions per prompt at high temp.

    Returns True iff at least ``min_passing`` of the prompts produced
    >= ``min_unique`` unique completions (uniqueness judged by byte equality
    of the skip-special-tokens decoding). False means the model is collapsed
    past the point where GRPO can recover, so we should refuse to start
    training.
    """
    import torch

    src = Path(val_path)
    if not src.exists():
        print(f"[grpo-preflight] WARNING: {val_path} not found; "
              f"skipping diversity preflight")
        return True  # don't block on a missing file
    rows: list[dict] = []
    with src.open("r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    if len(rows) < n_prompts:
        print(f"[grpo-preflight] WARNING: only {len(rows)} validation rows "
              f"available, need {n_prompts}; using all")
        n_prompts = len(rows)
    # Mix of trivial (no errors) and non-trivial (errors present) rows, so
    # the diversity probe sees both regimes the model has to handle.
    rng = random.Random(0)
    trivial = [r for r in rows if not r.get("had_errors")]
    non_trivial = [r for r in rows if r.get("had_errors")]
    rng.shuffle(trivial)
    rng.shuffle(non_trivial)
    half = max(1, n_prompts // 2)
    chosen = (non_trivial[:half] + trivial[:n_prompts - half])[:n_prompts]
    if not chosen:
        chosen = rows[:n_prompts]
    print(f"[grpo-preflight] probing diversity at T={temperature} on "
          f"{len(chosen)} prompts x {n_samples_per_prompt} samples each")
    try:
        from unsloth import FastLanguageModel
        FastLanguageModel.for_inference(model)
    except Exception:
        model.eval()
    passing = 0
    per_prompt_unique: list[int] = []
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    for i, row in enumerate(chosen):
        prompt = row["prompt"]
        try:
            chat = [{"role": "user", "content": prompt}]
            text = tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True,
            )
        except Exception:
            text = ("<|im_start|>user\n" + prompt
                    + "\n<|im_end|>\n<|im_start|>assistant\n")
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        completions: list[str] = []
        for _ in range(n_samples_per_prompt):
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=0.95,
                    top_k=50,
                    repetition_penalty=1.1,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=pad_id,
                )
            gen_ids = out[0][inputs["input_ids"].shape[1]:]
            txt = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
            completions.append(txt)
        unique = len(set(completions))
        per_prompt_unique.append(unique)
        verdict = "PASS" if unique >= min_unique else "FAIL"
        print(f"[grpo-preflight] prompt {i}: {unique}/{n_samples_per_prompt} "
              f"unique [{verdict}] examples={completions[:2]!r}")
        if unique >= min_unique:
            passing += 1
    overall = passing >= min_passing
    print(f"[grpo-preflight] {passing}/{len(chosen)} prompts passed "
          f"(threshold: >= {min_passing}). per_prompt_unique={per_prompt_unique}")
    if not overall:
        print("=" * 78)
        print("[grpo-preflight] PRE-FLIGHT FAILED - model is too collapsed; "
              "redo SFT with regularization before launching GRPO")
        print("=" * 78)
    return overall


# --------------------------------------------------------------------------- #
# In-loop W&B callback (tier-1 + tier-2 + tier-3 + sample table + safeguards) #
# --------------------------------------------------------------------------- #
def _build_wandb_callback(cache, model, tokenizer, env_client, eval_rows,
                          *, sample_every: int, sample_n: int,
                          inloop_every: int,
                          inloop_max_new_tokens: int,
                          kl_alarm: float,
                          inspection_every: int, inspection_sample_n: int,
                          inspection_collapse_threshold: int,
                          temp_bump_on_collapse: float,
                          best_dir: Path, output_dir: Path,
                          wall_seconds: float,
                          decision_thresholds: dict):
    from transformers import TrainerCallback

    from qubit_medic import wandb_utils

    if not wandb_utils.is_available():
        return None

    started_at = time.time()
    # Rolling cache for the inspection hook: we record the group unique-count
    # per prompt as it streams in, and at every inspection_every-step
    # boundary look at the most recent inspection_sample_n entries.
    recent_uniques = deque(maxlen=max(inspection_sample_n, 16))
    grad_norm_run = deque(maxlen=decision_thresholds["grad_norm_run_len"])
    state = {
        "best_total_reward": float("-inf"),
        "best_step": -1,
        "wall_exceeded": False,
        "step50_warned": False,
        "step500_warned": False,
        "format_warned_at": -1,
        "grad_norm_warned_at": -1,
        "beat_rate_history": [],
    }

    class _RolloutCallback(TrainerCallback):
        # ------------------------------------------------------------------ #
        # Per-step instrumentation                                           #
        # ------------------------------------------------------------------ #
        def on_step_end(self, args, state_, control, **kwargs):  # noqa: D401
            entries, keys = cache.drain_step()
            if not entries:
                return
            step = state_.global_step
            # Group entries by prompt so we can compute within-group stats.
            groups: list[list[_ScoredCompletion]] = []
            current_prompt = None
            current: list[_ScoredCompletion] = []
            for (p, _), e in zip(keys, entries):
                if p != current_prompt and current:
                    groups.append(current)
                    current = []
                current_prompt = p
                current.append(e)
            if current:
                groups.append(current)

            # ----- Tier-1 metrics ----- #
            totals = [e.weighted_total for e in entries]
            if not totals:
                return
            mean_total = sum(totals) / len(totals)
            within_stds: list[float] = []
            uniques: list[int] = []
            for grp in groups:
                if len(grp) < 2:
                    within_stds.append(0.0)
                    uniques.append(1)
                    continue
                vals = [e.weighted_total for e in grp]
                mu = sum(vals) / len(vals)
                var = sum((v - mu) ** 2 for v in vals) / len(vals)
                within_stds.append(var ** 0.5)
                key_set = {(tuple(e.x_pred), tuple(e.z_pred)) for e in grp}
                uniques.append(len(key_set))
            mean_within_std = sum(within_stds) / max(1, len(within_stds))
            mean_unique = sum(uniques) / max(1, len(uniques))
            # GRPO advantage (recomputed locally, for the log only).
            adv_abs: list[float] = []
            for grp in groups:
                if len(grp) < 2:
                    continue
                vals = [e.weighted_total for e in grp]
                mu = sum(vals) / len(vals)
                var = sum((v - mu) ** 2 for v in vals) / len(vals)
                std = max((var ** 0.5), 1e-4)
                adv_abs.extend(abs((v - mu) / std) for v in vals)
            mean_adv_abs = sum(adv_abs) / max(1, len(adv_abs))
            wandb_utils.log({
                "train/total_reward_mean": mean_total,
                "train/reward_std_within_group": mean_within_std,
                "train/completion_uniqueness": mean_unique,
                "train/advantage_mean_abs": mean_adv_abs,
                "train/global_step": step,
            }, step=step)
            wandb_utils.log_reward_breakdown(
                [e.rewards for e in entries], step=step, prefix="train",
            )
            wandb_utils.log({
                "train/reward_bounds_violations_total": cache._bounds_violations,
                "train/env_episodes_total": cache._episodes,
                "train/env_timeouts_total": cache._timeouts,
            }, step=step)

            # ----- Decision rule: step 50 within-group variance ----- #
            if (not state["step50_warned"]
                    and step >= decision_thresholds["reward_std_check_step"]):
                if mean_within_std < decision_thresholds["reward_std_floor"]:
                    print(f"\n[grpo-decision] CRITICAL @ step {step}: "
                          f"train/reward_std_within_group={mean_within_std:.4f} "
                          f"< {decision_thresholds['reward_std_floor']}. The "
                          f"within-group reward std has collapsed; GRPO has "
                          f"effectively zero advantage signal. Pausing for "
                          f"manual review (warning only - no auto-action).")
                    wandb_utils.log({
                        "alarms/reward_std_collapse": 1.0,
                        "alarms/reward_std_value": mean_within_std,
                    }, step=step)
                state["step50_warned"] = True

            # Compliance Section 8 (audit, 2026-04): a continuous warning for
            # reward_std < 0.02 at ANY step, not only step 50, throttled to
            # once per 100 steps so the message doesn't spam every 5-step log
            # line. The step-50 gate above stays as the harder "pause for
            # review" check at the higher 0.03 threshold; this continuous one
            # fires the moment within-group variance crosses the spec floor
            # and tells the operator to look at the run.
            CONT_REWARD_STD_FLOOR = 0.02
            if mean_within_std < CONT_REWARD_STD_FLOOR:
                last_warn = state.get("reward_std_warned_at", -1)
                if step - last_warn >= 100:
                    print(f"\n[grpo-warn] @ step {step}: "
                          f"train/reward_std_within_group="
                          f"{mean_within_std:.4f} < {CONT_REWARD_STD_FLOOR} "
                          f"(continuous alarm). GRPO advantage signal is "
                          f"vanishing - inspect generations / temperature.")
                    wandb_utils.log({
                        "alarms/reward_std_continuous_low": 1.0,
                        "alarms/reward_std_value": mean_within_std,
                    }, step=step)
                    state["reward_std_warned_at"] = step

            # ----- Mode-collapse inspection hook ----- #
            for u in uniques:
                recent_uniques.append(u)
            if (inspection_every and step > 0
                    and step % inspection_every == 0
                    and len(recent_uniques) >= inspection_sample_n):
                last = list(recent_uniques)[-inspection_sample_n:]
                collapsed_count = sum(1 for u in last if u == 1)
                if collapsed_count > inspection_collapse_threshold:
                    cur_temp = float(getattr(args, "temperature", 1.2))
                    # Cap the bump at 2.0 - going higher does not actually
                    # produce more diversity (the sampler is already at
                    # top_k=50 / top_p=0.95), and every distinct value
                    # re-specializes the torch.compile cache, eventually
                    # tripping FailOnRecompileLimitHit even with raised
                    # limits.
                    new_temp = min(2.0, cur_temp + temp_bump_on_collapse)
                    if new_temp <= cur_temp + 1e-6:
                        print(f"\n[grpo-inspection] WARN @ step {step}: "
                              f"{collapsed_count}/{inspection_sample_n} prompts "
                              f"collapsed but temperature already at cap "
                              f"({cur_temp:.2f}); leaving unchanged.")
                    else:
                        print(f"\n[grpo-inspection] WARN @ step {step}: "
                              f"{collapsed_count}/{inspection_sample_n} of the "
                              f"most recent prompts had ALL 4 generations "
                              f"identical. Bumping rollout temperature "
                              f"{cur_temp:.2f} -> {new_temp:.2f}.")
                        try:
                            args.temperature = new_temp
                        except Exception as exc:
                            print(f"[grpo-inspection] could not patch "
                                  f"temperature on TRL args: {exc!r}")
                    wandb_utils.log({
                        "alarms/mode_collapse_count": collapsed_count,
                        "train/temperature_after_bump": new_temp,
                    }, step=step)

            # ----- Sample-completion table every sample_every steps ----- #
            if sample_every and step > 0 and step % sample_every == 0:
                rows_out = []
                # First sample_n unique prompts in this batch; emit one row
                # per generation (so the W&B table has gen_idx as a column).
                chosen_groups: list[tuple[str, list[_ScoredCompletion]]] = []
                seen_prompts: set = set()
                for (p, _), e in zip(keys, entries):
                    if p in seen_prompts:
                        for q, grp in chosen_groups:
                            if q == p:
                                grp.append(e)
                                break
                        continue
                    if len(chosen_groups) >= sample_n:
                        continue
                    chosen_groups.append((p, [e]))
                    seen_prompts.add(p)
                for prompt, grp in chosen_groups:
                    for gi, e in enumerate(grp[:4]):
                        rows_out.append({
                            "step": step,
                            "prompt": prompt[:600],
                            "gen_idx": gi,
                            "x_pred": ",".join(map(str, e.x_pred)),
                            "z_pred": ",".join(map(str, e.z_pred)),
                            "logical_correction":
                                e.rewards.get("logical_correction", 0.0),
                            "syndrome_consistency":
                                e.rewards.get("syndrome_consistency", 0.0),
                            "hamming_overlap":
                                e.rewards.get("hamming_overlap", 0.0),
                            "format_compliance":
                                e.rewards.get("format_compliance", 0.0),
                            "pymatching_beat":
                                e.rewards.get("pymatching_beat", 0.0),
                            "weighted_total": e.weighted_total,
                            "parse_success": e.parse_success,
                            "actual_obs_flip": e.actual_flip,
                            "pm_obs_flip": e.pm_flip,
                            "curriculum_level": e.curriculum_level,
                        })
                if rows_out:
                    wandb_utils.log_generation_table(
                        rows_out, step=step, table_name="rl/generations",
                        columns=[
                            "step", "prompt", "gen_idx", "x_pred", "z_pred",
                            "logical_correction", "syndrome_consistency",
                            "hamming_overlap", "format_compliance",
                            "pymatching_beat", "weighted_total",
                            "parse_success", "actual_obs_flip", "pm_obs_flip",
                            "curriculum_level",
                        ],
                    )

            # ----- Wall-clock cap ----- #
            elapsed = time.time() - started_at
            if elapsed > wall_seconds and not state["wall_exceeded"]:
                state["wall_exceeded"] = True
                print(f"\n[grpo-walltime] wall-clock cap hit at step {step} "
                      f"({elapsed:.0f}s > {wall_seconds:.0f}s). "
                      f"Saving and exiting.")
                try:
                    control.should_save = True
                    control.should_training_stop = True
                except Exception:
                    pass
                wandb_utils.log({
                    "alarms/wall_exceeded": 1.0,
                    "alarms/wall_seconds_at_cap": elapsed,
                }, step=step)

            # ----- Tier-2 + tier-3 eval ----- #
            if inloop_every and step > 0 and step % inloop_every == 0:
                self._run_inloop_eval(step)

        def on_log(self, args, state_, control, logs=None, **kwargs):  # noqa: D401
            if not logs:
                return
            step = state_.global_step
            # Tier-1: surface the train/* metrics that TRL itself produces.
            extra: dict = {}
            for src_key, dst_key in [
                ("kl", "train/kl_divergence"),
                ("train/kl_divergence", "train/kl_divergence"),
                ("grad_norm", "train/grad_norm"),
                ("loss", "train/policy_loss"),
                ("learning_rate", "train/learning_rate"),
            ]:
                if src_key in logs:
                    try:
                        extra[dst_key] = float(logs[src_key])
                    except (TypeError, ValueError):
                        pass
            if extra:
                wandb_utils.log(extra, step=step)
            # KL alarm.
            kl = logs.get("kl") or logs.get("train/kl_divergence")
            if kl is not None:
                try:
                    kl_v = float(kl)
                except (TypeError, ValueError):
                    kl_v = None
                if kl_v is not None and kl_v > kl_alarm:
                    wandb_utils.log({
                        "alarms/kl_alarm": 1.0,
                        "alarms/kl_alarm_value": kl_v,
                    }, step=step)
                    print(f"[grpo][step {step}] KL ALARM: {kl_v:.3f} "
                          f"> {kl_alarm:.3f} - inspect generations.")
            # Decision rule: grad_norm > ceiling for N consecutive logs.
            gn = logs.get("grad_norm")
            if gn is not None:
                try:
                    gn_v = float(gn)
                except (TypeError, ValueError):
                    gn_v = None
                if gn_v is not None:
                    grad_norm_run.append(gn_v)
                    ceil = decision_thresholds["grad_norm_ceil"]
                    run_len = decision_thresholds["grad_norm_run_len"]
                    if (len(grad_norm_run) >= run_len
                            and all(x > ceil for x in grad_norm_run)
                            and step != state["grad_norm_warned_at"]):
                        print(f"\n[grpo-decision] CRITICAL @ step {step}: "
                              f"train/grad_norm > {ceil} for {run_len} "
                              f"consecutive logs ({list(grad_norm_run)}). "
                              f"Recommend reducing LR (warning only - no "
                              f"auto-action).")
                        wandb_utils.log({
                            "alarms/grad_norm_runaway": 1.0,
                            "alarms/grad_norm_value": gn_v,
                        }, step=step)
                        state["grad_norm_warned_at"] = step

        def on_train_end(self, args, state_, control, **kwargs):  # noqa: D401
            self._run_inloop_eval(state_.global_step, table_name="rl/final_eval")

        # ------------------------------------------------------------------ #
        # Tier-2 / tier-3 eval (greedy, T=0)                                 #
        # ------------------------------------------------------------------ #
        def _run_inloop_eval(self, step: int, table_name: str = "rl/inloop_eval"):
            try:
                from unsloth import FastLanguageModel
                FastLanguageModel.for_inference(model)
            except Exception:
                model.eval()  # type: ignore[attr-defined]
            n = len(eval_rows)
            stats = {
                "logical_correction": 0,
                "format_success": 0,
                "format_partial": 0,
                "pymatching_beat": 0,
                "syndrome_consistency_pass": 0,
                "exact_match_pymatching": 0,
                "total_reward_sum": 0.0,
                "completion_len_sum": 0,
                "hard_lcr_num": 0,
                "hard_lcr_den": 0,
                "ler_d3_p001_logical_errors": 0,
                "ler_d3_p001_total": 0,
                "ler_d3_p001_rounds": 0,
            }
            preview_rows = []
            pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
            import torch
            from qubit_medic.config import REWARD_WEIGHTS
            for ep_idx, row in enumerate(eval_rows):
                prompt = row["prompt"]
                episode_id = int(row.get("episode_id", -1))
                try:
                    chat = [{"role": "user", "content": prompt}]
                    text = tokenizer.apply_chat_template(
                        chat, tokenize=False, add_generation_prompt=True,
                    )
                except Exception:
                    text = ("<|im_start|>user\n" + prompt
                            + "\n<|im_end|>\n<|im_start|>assistant\n")
                inputs = tokenizer(text, return_tensors="pt").to(model.device)
                try:
                    with torch.no_grad():
                        out = model.generate(
                            **inputs,
                            max_new_tokens=inloop_max_new_tokens,
                            do_sample=False,  # greedy at T=0 per spec
                            eos_token_id=tokenizer.eos_token_id,
                            pad_token_id=pad_id,
                        )
                    gen_ids = out[0][inputs["input_ids"].shape[1]:]
                    completion = tokenizer.decode(gen_ids, skip_special_tokens=True)
                    n_tokens = int(gen_ids.shape[0])
                except Exception as exc:  # pragma: no cover
                    completion = f"<gen-error: {exc}>"
                    n_tokens = 0
                # Score against the env. If the episode_id has expired
                # server-side (TTL), fall back to a fresh reset so the run
                # continues; nothing special is logged - the metric
                # arithmetic is still correct.
                try:
                    result = env_client.step(raw_response=completion,
                                             episode_id=episode_id)
                except Exception:
                    obs2 = env_client.reset(seed=row.get("seed"))
                    result = env_client.step(raw_response=completion,
                                             episode_id=obs2.episode_id)
                rwd = result.info.get("rewards", {}) or {}
                action = result.info.get("parsed_action", {}) or {}
                actual = int(result.info.get("actual_observable_flip", 0))
                pm_pred = int(result.info.get("pymatching_observable_pred", 0))
                we_correct = float(rwd.get("logical_correction", 0.0)) >= 0.5
                pm_correct = (pm_pred == actual)
                stats["logical_correction"] += int(we_correct)
                stats["format_success"] += int(action.get("parse_success", False))
                stats["format_partial"] += int(
                    float(rwd.get("format_compliance", 0.0)) >= 0.5
                    and not action.get("parse_success", False)
                )
                stats["pymatching_beat"] += int(
                    float(rwd.get("pymatching_beat", 0.0)) >= 0.5)
                stats["syndrome_consistency_pass"] += int(
                    float(rwd.get("syndrome_consistency", 0.0)) >= 0.999)
                weighted = sum(
                    weight * max(0.0, min(1.0, float(rwd.get(name, 0.0))))
                    for name, weight in REWARD_WEIGHTS.items()
                )
                stats["total_reward_sum"] += max(0.0, min(1.0, weighted))
                stats["completion_len_sum"] += n_tokens
                pm_x = sorted(set(map(
                    int, result.info.get("pymatching_x_errors", []) or [])))
                pm_z = sorted(set(map(
                    int, result.info.get("pymatching_z_errors", []) or [])))
                our_x = sorted(set(map(
                    int, action.get("x_error_qubits", []) or [])))
                our_z = sorted(set(map(
                    int, action.get("z_error_qubits", []) or [])))
                if (action.get("parse_success", False)
                        and pm_x == our_x and pm_z == our_z):
                    stats["exact_match_pymatching"] += 1
                # Hard syndrome: >=2 stabilizers fired (the anti-hacking spec
                # forbids exposing true_x/true_z, so we use the syndrome bit
                # count from the cached eval row as the proxy).
                n_active = sum(1 for b in row.get("syndrome_bits", []) if int(b))
                if n_active >= 2:
                    stats["hard_lcr_den"] += 1
                    stats["hard_lcr_num"] += int(we_correct)
                # Tier-3: per-round LER for d=3 / p=0.001 only.
                d = int(row.get("distance", 0))
                rnds = max(1, int(row.get("rounds", 0)))
                if d == 3 and abs(float(row.get("p", 0.0)) - 0.001) < 1e-6:
                    stats["ler_d3_p001_total"] += 1
                    stats["ler_d3_p001_rounds"] = rnds
                    if not we_correct:
                        stats["ler_d3_p001_logical_errors"] += 1
                if ep_idx < 4:
                    preview_rows.append({
                        "step": step,
                        "episode": ep_idx,
                        "completion": completion[:300],
                        "logical_correction": rwd.get("logical_correction", 0.0),
                        "syndrome_consistency": rwd.get("syndrome_consistency", 0.0),
                        "format_compliance": rwd.get("format_compliance", 0.0),
                        "pymatching_beat": rwd.get("pymatching_beat", 0.0),
                        "weighted_total": weighted,
                    })
            denom = max(1, n)
            lcr = stats["logical_correction"] / denom
            beat_rate = stats["pymatching_beat"] / denom
            fmt_compliance = stats["format_success"] / denom
            hard_lcr = (stats["hard_lcr_num"] / max(1, stats["hard_lcr_den"])
                        if stats["hard_lcr_den"] else 0.0)
            sync_consistency_rate = stats["syndrome_consistency_pass"] / denom
            avg_completion_len = stats["completion_len_sum"] / denom
            mean_total_reward = stats["total_reward_sum"] / denom
            exact_match = stats["exact_match_pymatching"] / denom
            # Tier-3 LER per round, plus its log10.
            ler_per_round = None
            ler_log10 = None
            if stats["ler_d3_p001_total"] > 0:
                p_logical = (stats["ler_d3_p001_logical_errors"]
                             / stats["ler_d3_p001_total"])
                rounds = max(1, stats["ler_d3_p001_rounds"])
                # Per-round LER: 1 - (1 - p_logical)^(1/rounds).
                ler_per_round = 1.0 - (1.0 - max(0.0, min(1.0, p_logical))) ** (1.0 / rounds)
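                # Derivation sketch (assuming rounds behave like independent
                # survival trials): (1 - p_round) ** rounds == 1 - p_logical,
                # hence p_round = 1 - (1 - p_logical) ** (1 / rounds). E.g.
                # p_logical = 0.05 over rounds = 3 gives p_round ~= 0.0170
                # (log10 ~= -1.77).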
                if ler_per_round > 0:
                    import math
                    ler_log10 = math.log10(max(ler_per_round, 1e-12))
            # Tier-2 output diversity probe at T=1.0 (8 samples per prompt,
            # on a small subset to keep eval fast).
            div_probe_n = min(8, len(eval_rows))
            div_samples = 8
            unique_counts: list[int] = []
            for row in eval_rows[:div_probe_n]:
                prompt = row["prompt"]
                try:
                    chat = [{"role": "user", "content": prompt}]
                    text = tokenizer.apply_chat_template(
                        chat, tokenize=False, add_generation_prompt=True,
                    )
                except Exception:
                    text = ("<|im_start|>user\n" + prompt
                            + "\n<|im_end|>\n<|im_start|>assistant\n")
                inputs = tokenizer(text, return_tensors="pt").to(model.device)
                outs = []
                for _ in range(div_samples):
                    try:
                        with torch.no_grad():
                            out = model.generate(
                                **inputs,
                                max_new_tokens=inloop_max_new_tokens,
                                do_sample=True,
                                temperature=1.0,
                                top_p=0.95,
                                top_k=50,
                                eos_token_id=tokenizer.eos_token_id,
                                pad_token_id=pad_id,
                            )
                        gen = tokenizer.decode(
                            out[0][inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True,
                        ).strip()
                    except Exception:
                        gen = ""
                    outs.append(gen)
                unique_counts.append(len(set(outs)))
            output_diversity_t1 = (sum(unique_counts) / max(1, len(unique_counts))
                                   if unique_counts else 0.0)
            eval_metrics = {
                "eval/logical_correction_rate": lcr,
                "eval/pymatching_beat_rate": beat_rate,
                "eval/format_compliance": fmt_compliance,
                "eval/exact_match_pymatching": exact_match,
                "eval/hard_syndrome_lcr": hard_lcr,
                "eval/syndrome_consistency_rate": sync_consistency_rate,
                "eval/avg_completion_length": avg_completion_len,
                "eval/output_diversity_temp_1": output_diversity_t1,
                "eval/total_reward_mean": mean_total_reward,
                "eval/episodes": denom,
            }
            if ler_per_round is not None:
                eval_metrics["eval/ler_per_round_d3_p001"] = ler_per_round
            if ler_log10 is not None:
                eval_metrics["eval/ler_per_round_log10"] = ler_log10
            print(f"[grpo][eval@{step}] " + ", ".join(
                f"{k.split('/')[-1]}={v:.4f}" if isinstance(v, float)
                else f"{k.split('/')[-1]}={v}" for k, v in eval_metrics.items()
            ))
            wandb_utils.log(eval_metrics, step=step)
            if preview_rows:
                wandb_utils.log_generation_table(
                    preview_rows, step=step, table_name=table_name,
                )
            # Decision rule: step-500 pymatching_beat sanity.
            state["beat_rate_history"].append(beat_rate)
            if len(state["beat_rate_history"]) > 5:
                state["beat_rate_history"] = state["beat_rate_history"][-5:]
            if (not state["step500_warned"]
                    and step >= decision_thresholds["beat_rate_check_step"]
                    and len(state["beat_rate_history"]) >= 5
                    and all(b == 0 for b in state["beat_rate_history"])):
                print(f"\n[grpo-decision] WARN @ step {step}: "
                      f"eval/pymatching_beat_rate has been 0.0 across the last "
                      f"5 evals. The model is never finding syndromes where "
                      f"PyMatching fails - consider increasing the "
                      f"pymatching_beat reward weight (warning only).")
                wandb_utils.log({"alarms/zero_beat_rate": 1.0}, step=step)
                state["step500_warned"] = True
            # Decision rule: format_compliance < floor.
            if (fmt_compliance < decision_thresholds["format_floor"]
                    and step != state["format_warned_at"]):
                print(f"\n[grpo-decision] WARN @ step {step}: "
                      f"eval/format_compliance={fmt_compliance:.3f} < "
                      f"{decision_thresholds['format_floor']}. Consider "
                      f"increasing the format_compliance weight (warning only).")
                wandb_utils.log({
                    "alarms/format_below_floor": 1.0,
                    "alarms/format_value": fmt_compliance,
                }, step=step)
                state["format_warned_at"] = step
            # ----- Best-checkpoint tracking ----- #
            if mean_total_reward > state["best_total_reward"]:
                old = state["best_total_reward"]
                state["best_total_reward"] = mean_total_reward
                state["best_step"] = step
                print(f"[grpo][eval@{step}] new best total_reward_mean="
                      f"{mean_total_reward:.4f} (prev {old:.4f}); "
                      f"saving to {best_dir}")
                try:
                    if best_dir.exists():
                        shutil.rmtree(best_dir)
                    best_dir.mkdir(parents=True, exist_ok=True)
                    model.save_pretrained(str(best_dir))
                    tokenizer.save_pretrained(str(best_dir))
                    wandb_utils.update_summary({
                        "best/total_reward_mean": mean_total_reward,
                        "best/step": step,
                    })
                except Exception as exc:
                    print(f"[grpo] WARN: failed to save best checkpoint: "
                          f"{exc!r}", file=sys.stderr)
            # Switch back to training mode.
            try:
                from unsloth import FastLanguageModel
                FastLanguageModel.for_training(model)
            except Exception:
                model.train()  # type: ignore[attr-defined]

    return _RolloutCallback()


# --------------------------------------------------------------------------- #
# Dataset of prompts                                                          #
# --------------------------------------------------------------------------- #
def _build_prompt_pool(env_client, n: int):
    prompts = []
    for _ in range(n):
        obs = env_client.reset()
        prompts.append({"prompt": obs.prompt, "episode_id": obs.episode_id})
    return prompts


# --------------------------------------------------------------------------- #
# Main                                                                        #
# --------------------------------------------------------------------------- #
def main(argv: Iterable[str] = ()) -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--sft-checkpoint", type=str, default=None,
                        help="LoRA adapter directory to start GRPO from. "
                             "Defaults to config.SFT_CHECKPOINT_PATH_FOR_GRPO "
                             "(checkpoints/sft_warmup/checkpoint-50).")
    parser.add_argument("--output", type=str, default="checkpoints/grpo")
    parser.add_argument("--model", type=str,
                        default=os.getenv(
                            "QUBIT_MEDIC_MODEL",
                            "unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit"),
                        help="Base model. Defaults to the 4-bit unsloth bundle "
                             "matching the SFT base.")
    parser.add_argument("--steps", type=int, default=None)
    parser.add_argument("--gen-per-prompt", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--kl-coef", type=float, default=None)
    parser.add_argument("--max-prompt-len", type=int, default=None)
    parser.add_argument("--max-completion-len", type=int, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--prompt-pool", type=int, default=512)
    parser.add_argument("--wandb-run-name", type=str, default=None)
    parser.add_argument("--wandb-group", type=str, default=None)
    parser.add_argument("--wandb-tags", type=str, nargs="*", default=("grpo",))
    parser.add_argument("--wandb-notes", type=str, default=None)
    parser.add_argument("--sample-every", type=int, default=None)
    parser.add_argument("--sample-n", type=int, default=None)
    parser.add_argument("--inloop-eval-every", type=int, default=None)
    parser.add_argument("--inloop-eval-episodes", type=int, default=None)
    parser.add_argument("--kl-alarm", type=float, default=None)
    parser.add_argument("--no-artifact", action="store_true")
    parser.add_argument("--skip-preflight", action="store_true",
                        help="Skip the diversity preflight (DEBUG ONLY)")
    args = parser.parse_args(list(argv))

    # Lazy heavy imports.
    try:
        from unsloth import FastLanguageModel
    except ImportError:
        print("ERROR: unsloth not installed. "
              "Run `pip install -r requirements-train.txt`", file=sys.stderr)
        return 1
    import torch
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    # Belt-and-suspenders for the dynamo recompile-limit crash that killed a
    # previous run at step 400. The env vars at the top of the file cover the
    # case where torch has not been imported yet; this block covers the case
    # where unsloth/torch were already imported (the env vars no-op then) and
    # also flips suppress_errors so any future overflow falls back to eager
    # instead of raising FailOnRecompileLimitHit.
    try:
        import torch._dynamo as _dynamo
        for _attr in ("cache_size_limit", "recompile_limit",
                      "accumulated_cache_size_limit"):
            if hasattr(_dynamo.config, _attr):
                setattr(_dynamo.config, _attr,
                        max(256, getattr(_dynamo.config, _attr)))
        _dynamo.config.suppress_errors = True
    except Exception as _exc:  # pragma: no cover - defensive
        print(f"[grpo-guard] WARNING: could not raise dynamo limits: "
              f"{_exc!r}", file=sys.stderr)

    # Pre-flight signature check + stale-cache wipe.
    _wipe_stale_grpo_cache()
    _assert_grpo_signature_compatible()

    from qubit_medic import wandb_utils
    from qubit_medic.client.client import make_default_client
    from qubit_medic.config import (
        GRPO_BATCH_SIZE, GRPO_CHECKPOINT_EVERY, GRPO_DECISION_BEAT_RATE_CHECK_STEP,
        GRPO_DECISION_FORMAT_FLOOR, GRPO_DECISION_GRAD_NORM_CEIL,
        GRPO_DECISION_GRAD_NORM_RUN_LEN, GRPO_DECISION_REWARD_STD_CHECK_STEP,
        GRPO_DECISION_REWARD_STD_FLOOR, GRPO_DO_SAMPLE, GRPO_GEN_PER_PROMPT,
        GRPO_GRAD_ACCUM, GRPO_INSPECTION_COLLAPSE_THRESHOLD,
        GRPO_INSPECTION_HOOK_EVERY, GRPO_INSPECTION_SAMPLE_N, GRPO_KL_ALARM,
        GRPO_KL_COEF, GRPO_LOG_EVERY, GRPO_LR, GRPO_LR_SCHEDULER,
        GRPO_MAX_COMPLETION_LEN, GRPO_MAX_PROMPT_LEN, GRPO_OPTIMIZER,
        GRPO_REPETITION_PENALTY, GRPO_SAMPLE_LOG_EVERY, GRPO_SAMPLE_LOG_N,
        GRPO_SAVE_TOTAL_LIMIT, GRPO_STEPS, GRPO_TEMP_BUMP_ON_COLLAPSE,
        GRPO_TEMPERATURE, GRPO_TOP_K, GRPO_TOP_P, GRPO_VAL_EPISODES,
        GRPO_VAL_PATH, GRPO_VAL_SEED, GRPO_WALL_SECONDS, LORA_ALPHA, LORA_DROPOUT,
        LORA_R, LORA_TARGET_MODULES, MODEL_ID, PRIMARY_SEED, REWARD_WEIGHTS,
        SFT_CHECKPOINT_PATH_FOR_GRPO, WANDB_INLOOP_EVAL_EPISODES,
        WANDB_INLOOP_EVAL_EVERY,
    )

    sft_ckpt = args.sft_checkpoint or SFT_CHECKPOINT_PATH_FOR_GRPO
    steps = args.steps if args.steps is not None else GRPO_STEPS
    gen_per_prompt = (args.gen_per_prompt if args.gen_per_prompt is not None
                      else GRPO_GEN_PER_PROMPT)
    lr = args.lr if args.lr is not None else GRPO_LR
    kl_coef = args.kl_coef if args.kl_coef is not None else GRPO_KL_COEF
    max_p = (args.max_prompt_len if args.max_prompt_len is not None
             else GRPO_MAX_PROMPT_LEN)
    max_c = (args.max_completion_len if args.max_completion_len is not None
             else GRPO_MAX_COMPLETION_LEN)
    seed = args.seed if args.seed is not None else PRIMARY_SEED
    sample_every = (args.sample_every if args.sample_every is not None
                    else GRPO_SAMPLE_LOG_EVERY)
    sample_n = args.sample_n if args.sample_n is not None else GRPO_SAMPLE_LOG_N
    inloop_every = (args.inloop_eval_every if args.inloop_eval_every is not None
                    else WANDB_INLOOP_EVAL_EVERY)
    inloop_episodes = (args.inloop_eval_episodes
                       if args.inloop_eval_episodes is not None
                       else WANDB_INLOOP_EVAL_EPISODES)
    kl_alarm = args.kl_alarm if args.kl_alarm is not None else GRPO_KL_ALARM
    _seed_everything(seed)

    # ---- Env client ---------------------------------------------------- #
    env_client = make_default_client()
    print(f"using env client: {type(env_client).__name__}; "
          f"health = {env_client.health()}")

    # ---- W&B init ------------------------------------------------------ #
    report_to = wandb_utils.derive_report_to(args.report_to)
    run_name = args.wandb_run_name or wandb_utils.make_run_name("grpo")
    wandb_utils.init_run(
        run_name=run_name,
        job_type="grpo",
        tags=args.wandb_tags,
        notes=args.wandb_notes,
        group=args.wandb_group,
        extra_config={
            "cli": {
                "steps": steps,
                "gen_per_prompt": gen_per_prompt,
                "lr": lr,
                "kl_coef": kl_coef,
                "max_prompt_len": max_p,
                "max_completion_len": max_c,
                "prompt_pool": args.prompt_pool,
                "sample_every": sample_every,
                "sample_n": sample_n,
                "inloop_eval_every": inloop_every,
                "inloop_eval_episodes": inloop_episodes,
                "kl_alarm": kl_alarm,
                "temperature": GRPO_TEMPERATURE,
                "top_p": GRPO_TOP_P,
                "top_k": GRPO_TOP_K,
                "repetition_penalty": GRPO_REPETITION_PENALTY,
                "do_sample": GRPO_DO_SAMPLE,
                "lr_scheduler": GRPO_LR_SCHEDULER,
                "optimizer": GRPO_OPTIMIZER,
                "grad_accum": GRPO_GRAD_ACCUM,
                "effective_batch": GRPO_BATCH_SIZE * GRPO_GRAD_ACCUM,
                "sft_checkpoint": sft_ckpt,
                "model": args.model,
                "seed": seed,
                "report_to": report_to,
                "wall_seconds": GRPO_WALL_SECONDS,
                "reward_weights": dict(REWARD_WEIGHTS),
                "val_seed": GRPO_VAL_SEED,
                "val_episodes": GRPO_VAL_EPISODES,
            },
        },
    )

    # Use train/global_step as the default x-axis for everything we log.
    try:
        run = wandb_utils.get_run()
        if run is not None:
            run.define_metric("train/global_step")
            run.define_metric("train/*", step_metric="train/global_step")
            run.define_metric("eval/*", step_metric="train/global_step")
            run.define_metric("alarms/*", step_metric="train/global_step")
            run.define_metric("rl/*", step_metric="train/global_step")
            run.define_metric("best/*", step_metric="train/global_step")
    except Exception as exc:
        print(f"[wandb] could not define x-axis metric: {exc!r}", file=sys.stderr)

    # ---- Build prompt pool ---------------------------------------------- #
    print(f"pre-generating {args.prompt_pool} prompts ...")
    prompts = _build_prompt_pool(env_client, args.prompt_pool)
    dataset = Dataset.from_list(prompts)
    print(f"  built dataset with {len(dataset)} prompts")

    # ---- Frozen eval set ------------------------------------------------ #
    eval_rows = _load_or_build_eval_set(
        env_client, seed=GRPO_VAL_SEED, n=inloop_episodes, path=GRPO_VAL_PATH,
    )

    # ---- Load model ------------------------------------------------------ #
    print(f"loading base={args.model}, sft adapter={sft_ckpt}")
    base_for_load = sft_ckpt if Path(sft_ckpt).exists() else args.model
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_for_load,
        max_seq_length=max_p + max_c,
        load_in_4bit=True,
        dtype=None,
    )
    if not Path(sft_ckpt).exists():
        print(f"[grpo] WARN: SFT checkpoint {sft_ckpt} not found; "
              f"attaching a fresh LoRA on the base model")
        model = FastLanguageModel.get_peft_model(
            model,
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            target_modules=list(LORA_TARGET_MODULES),
            lora_dropout=LORA_DROPOUT,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=seed,
        )

    # ---- Diversity preflight --------------------------------------------- #
    if not args.skip_preflight:
        ok = _diversity_preflight(
            model, tokenizer,
            val_path="data/sft_validation.jsonl",
            n_prompts=5, n_samples_per_prompt=8,
            temperature=GRPO_TEMPERATURE,
            min_unique=3, min_passing=3,
            max_new_tokens=max_c,
        )
        if not ok:
            wandb_utils.update_summary({"preflight/passed": False})
            wandb_utils.finish_run()
            return 2
        wandb_utils.update_summary({"preflight/passed": True})
    else:
        print("[grpo] --skip-preflight given; bypassing the diversity "
              "preflight (DEBUG ONLY)")

    # ---- TRL GRPOConfig --------------------------------------------------- #
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    best_dir = output_dir / "best"
    final_dir = output_dir / "final"
    bf16_supported = (
        torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    )
    grpo_kwargs: dict = {
        "output_dir": str(output_dir),
        "max_steps": steps,
        "per_device_train_batch_size": GRPO_BATCH_SIZE,
        "gradient_accumulation_steps": GRPO_GRAD_ACCUM,
        "num_generations": gen_per_prompt,
        "max_prompt_length": max_p,
        "max_completion_length": max_c,
        "learning_rate": lr,
        "beta": kl_coef,
        "lr_scheduler_type": GRPO_LR_SCHEDULER,
        "optim": GRPO_OPTIMIZER,
        "logging_steps": GRPO_LOG_EVERY,
        "save_steps": GRPO_CHECKPOINT_EVERY,
        "save_total_limit": GRPO_SAVE_TOTAL_LIMIT,
        "save_only_model": False,
        "seed": seed,
        "bf16": bf16_supported,
        "fp16": torch.cuda.is_available() and not bf16_supported,
        "report_to": report_to,
        "run_name": run_name,
        # Diversity-focused rollout sampling.
        "temperature": GRPO_TEMPERATURE,
        "top_p": GRPO_TOP_P,
        "top_k": GRPO_TOP_K,
        "repetition_penalty": GRPO_REPETITION_PENALTY,
    }
    # Some TRL versions don't accept every sampling kwarg on GRPOConfig;
    # fall back gracefully so the script still runs.
    config = None
    dropped: list[str] = []
    while config is None:
        try:
            config = GRPOConfig(**grpo_kwargs)
        except TypeError as exc:
            msg = str(exc)
            removed = False
            for k in ("repetition_penalty", "top_k", "top_p", "temperature",
                      "save_only_model"):
                if k in msg and k in grpo_kwargs:
                    grpo_kwargs.pop(k)
                    dropped.append(k)
                    removed = True
                    break
            if not removed:
                raise
    if dropped:
        print(f"[grpo] WARN: TRL did not accept these GRPOConfig kwargs and "
              f"they were dropped: {dropped}. Using TRL defaults for them.")

    # ---- Reward functions + scoring cache --------------------------------- #
    cache = _BatchScoringCache(env_client=env_client,
                               reward_weights=dict(REWARD_WEIGHTS))
    reward_fns = _make_reward_fns(cache)
    # The first reward is the bounded weighted total used for the gradient;
    # the rest are zero-weight observers used only for per-component traces.
    reward_weights = [1.0] + [0.0] * len(_REWARD_COMPONENTS)
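    # With the 5 components in _REWARD_COMPONENTS this expands to
    # [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]: one gradient-carrying total plus five
    # zero-weight observer traces.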

    callbacks = []
    cb = _build_wandb_callback(
        cache, model, tokenizer, env_client, eval_rows,
        sample_every=sample_every, sample_n=sample_n,
        inloop_every=inloop_every,
        inloop_max_new_tokens=max_c,
        kl_alarm=kl_alarm,
        inspection_every=GRPO_INSPECTION_HOOK_EVERY,
        inspection_sample_n=GRPO_INSPECTION_SAMPLE_N,
        inspection_collapse_threshold=GRPO_INSPECTION_COLLAPSE_THRESHOLD,
        temp_bump_on_collapse=GRPO_TEMP_BUMP_ON_COLLAPSE,
        best_dir=best_dir, output_dir=output_dir,
        wall_seconds=GRPO_WALL_SECONDS,
        decision_thresholds={
            "reward_std_floor": GRPO_DECISION_REWARD_STD_FLOOR,
            "reward_std_check_step": GRPO_DECISION_REWARD_STD_CHECK_STEP,
            "beat_rate_check_step": GRPO_DECISION_BEAT_RATE_CHECK_STEP,
            "format_floor": GRPO_DECISION_FORMAT_FLOOR,
            "grad_norm_ceil": GRPO_DECISION_GRAD_NORM_CEIL,
            "grad_norm_run_len": GRPO_DECISION_GRAD_NORM_RUN_LEN,
        },
    )
    if cb is not None:
        callbacks.append(cb)

    # Older TRL versions: GRPOTrainer may not accept the reward_weights kwarg.
    trainer_kwargs = dict(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=reward_fns,
        reward_weights=reward_weights,
        callbacks=callbacks,
    )
    try:
        trainer = GRPOTrainer(**trainer_kwargs)
    except TypeError as exc:
        if "reward_weights" in str(exc):
            print("[grpo] WARN: this TRL does not accept reward_weights= "
                  "on GRPOTrainer; falling back to using only the bounded "
                  "weighted-total reward (and no observers).")
            trainer_kwargs.pop("reward_weights")
            trainer_kwargs["reward_funcs"] = [reward_fns[0]]
            trainer = GRPOTrainer(**trainer_kwargs)
        else:
            raise

    print(f"running GRPO for {steps} steps "
          f"(temperature={GRPO_TEMPERATURE}, top_p={GRPO_TOP_P}, "
          f"top_k={GRPO_TOP_K}, repetition_penalty={GRPO_REPETITION_PENALTY}, "
          f"beta={kl_coef}, lr={lr}) ...")
    started = time.time()
    train_result = trainer.train()
    elapsed = time.time() - started
    print(f"finished in {elapsed:.1f}s")

    metrics = getattr(train_result, "metrics", {}) or {}
    wandb_utils.update_summary({
        "grpo/wall_seconds": elapsed,
        "grpo/total_episodes": cache._episodes,
        "grpo/total_timeouts": cache._timeouts,
        "grpo/reward_bounds_violations": cache._bounds_violations,
        **{f"grpo/final/{k}": v for k, v in metrics.items()
           if isinstance(v, (int, float))},
    })

    # ---- Final + rolling adapter saves ------------------------------------ #
    print(f"saving rolling adapter snapshot to {output_dir}")
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    print(f"saving final adapter snapshot to {final_dir}")
    final_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))
    if not args.no_artifact:
        wandb_utils.log_artifact(
            str(final_dir),
            name=f"grpo-final-{run_name}",
            artifact_type="model",
            description="GRPO final LoRA adapter (Qubit-Medic).",
        )
        if best_dir.exists():
            wandb_utils.log_artifact(
                str(best_dir),
                name=f"grpo-best-{run_name}",
                artifact_type="model",
                description="GRPO best-eval LoRA adapter (Qubit-Medic).",
            )

    wandb_utils.finish_run()
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))