Spaces:
Paused
Paused
| """SFT (Supervised Fine-Tuning) warm-start for CERNenv. | |
Generates ``--num_episodes`` oracle trajectories against the local
``CERNCollisionEnvironment`` (no HTTP), turns every step of each
successful episode into a (prompt, completion) pair using the *same*
chat templating the GRPO loop uses (see
``training.training_script.build_dataset`` /
``training.llm_agent.build_chat``), and runs ``trl.SFTTrainer`` with
LoRA so the resulting checkpoint can be used as the starting weights
for GRPO. If no episode succeeds, the steps of all episodes are used
instead so training can still proceed.
| This addresses the v1 reward hack head-on. v1 | |
| (``anugrahhu/cernenv-grpo-smollm2-360m``) never saw a positive-reward | |
| trajectory during early training because SmolLM2-360M-Instruct cannot | |
| solve the LHC discovery pipeline zero-shot, so GRPO had no positive | |
| gradient to follow and the policy collapsed to "spam request_systematics | |
| forever". A short SFT on oracle traces gives the policy a non-zero | |
| prior over the *correct* action sequence, which RL can then refine. | |
| Usage | |
| ----- | |
| python -m training.sft_warmstart \\ | |
| --out_dir runs/sft-warmstart \\ | |
| --num_episodes 200 --max_steps 8 --epochs 1 --lr 1e-5 \\ | |
| --base_model HuggingFaceTB/SmolLM2-360M-Instruct \\ | |
| --difficulty mixed --evidence_dir evidence | |
| Smoke test: | |
| python -m training.sft_warmstart --num_episodes 4 --max_steps 4 \\ | |
| --epochs 1 --base_model sshleifer/tiny-gpt2 --out_dir /tmp/sft_smoke | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import time | |
| from dataclasses import asdict | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from models import ActionType, ExperimentAction | |
| from server.environment import CERNCollisionEnvironment | |
| from training.llm_agent import LLMAgentConfig, build_chat | |
# Configure logging once at import time: INFO level, with a timestamp and
# the level name in front of every message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Module-level logger shared by the helpers and ``main`` below.
logger = logging.getLogger(__name__)
# Mirrors training/training_script.py so the pre-train SFT and the
# downstream GRPO use the *same* prompt templating. If you change one,
# change both — divergence here is the most insidious source of
# warm-start ineffectiveness.
#
# The four values below are consumed by ``_build_peft_config``, which
# forwards them to ``peft.LoraConfig``.
LORA_R = 16  # passed as ``r``
LORA_ALPHA = 32  # passed as ``lora_alpha``
LORA_DROPOUT = 0.05  # passed as ``lora_dropout``
# Passed as ``target_modules`` unless the model exposes none of these names,
# in which case ``_build_peft_config`` substitutes ``"all-linear"``.
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]
# ── Oracle trajectory collection ─────────────────────────────────────────
# Rotation used by ``_difficulty_for`` when the requested difficulty is
# "mixed": episode ``i`` gets entry ``i % 3`` of this tuple.
_DIFFICULTY_CYCLE = ("easy", "medium", "hard")
def _serialise_action(action: "ExperimentAction") -> str:
    """Serialise an oracle action into the JSON string used as an SFT completion.

    The payload is assembled by hand (it does *not* go through
    ``model_dump``): ``action_type`` and ``parameters`` are always emitted,
    while ``method``, ``justification`` and ``confidence`` are added only
    when they are set. The output is intended to be accepted by
    ``training.llm_agent.parse_action`` — NOTE(review): that parser is not
    visible from this file, so confirm the key set against it.

    Args:
        action: The action to serialise. ``action_type`` may be an enum
            member (its ``.value`` is written) or an already-plain string.

    Returns:
        A JSON object string. Non-ASCII characters are kept verbatim
        (``ensure_ascii=False``).
    """
    action_type = action.action_type
    payload: Dict[str, Any] = {
        # Tolerate a plain string here so a model configured to store enum
        # values (rather than members) does not crash on ``.value``.
        "action_type": getattr(action_type, "value", action_type),
        # Copy so the caller's dict is never shared with the payload; a
        # missing/None ``parameters`` becomes an empty object.
        "parameters": dict(action.parameters or {}),
    }
    # Optional text fields are dropped when empty to keep completions short.
    if action.method:
        payload["method"] = action.method
    if action.justification:
        payload["justification"] = action.justification
    # ``confidence`` uses an explicit None check so a legitimate 0.0 survives.
    if action.confidence is not None:
        payload["confidence"] = float(action.confidence)
    return json.dumps(payload, ensure_ascii=False)
def _difficulty_for(idx: int, difficulty: str) -> str:
    """Resolve the difficulty to use for episode number ``idx``.

    Any value other than ``"mixed"`` is returned unchanged. For
    ``"mixed"`` the episode index selects an entry of
    ``_DIFFICULTY_CYCLE`` round-robin.
    """
    if difficulty != "mixed":
        return difficulty
    slot = idx % len(_DIFFICULTY_CYCLE)
    return _DIFFICULTY_CYCLE[slot]
# Minimal "role: content" template installed on tokenizers that ship
# without a chat template (e.g. tiny GPT-2 stubs used for smoke tests).
_FALLBACK_CHAT_TEMPLATE = (
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)


def _render_prompt(tokenizer, chat) -> str:
    """Render ``chat`` to prompt text, installing a fallback template if needed.

    The fallback is installed only when the tokenizer has *no* chat
    template. If a real template is present and rendering still fails,
    the error is re-raised: silently swapping the template would make the
    SFT prompts diverge from the ones GRPO builds later.

    Side effect: may set ``tokenizer.chat_template`` on the caller's object.
    """
    try:
        return tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False,
        )
    except Exception as exc:  # pragma: no cover - tiny-gpt2 etc.
        if getattr(tokenizer, "chat_template", None):
            # A template exists, so this is a genuine rendering failure.
            raise
        logger.warning(
            "tokenizer has no chat template (%s); installing a "
            "minimal fallback so SFT can proceed", exc,
        )
        tokenizer.chat_template = _FALLBACK_CHAT_TEMPLATE
        return tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False,
        )


def _collect_oracle_trajectories(
    *,
    tokenizer,
    num_episodes: int,
    max_steps: int,
    difficulty: str,
    seed_base: int = 4242,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Run the OracleAgent against the env and harvest (prompt, completion) pairs.

    One fresh ``CERNCollisionEnvironment`` is created per episode and reset
    with seed ``seed_base + i``. At every step the current observation is
    rendered to a prompt and the oracle's action is serialised as the
    completion. An episode counts as successful when the final env state
    reports both ``correct_mass`` and ``correct_channel``.

    Args:
        tokenizer: Tokenizer used to render prompts. Its ``chat_template``
            may be replaced by a minimal fallback if it has none.
        num_episodes: Number of oracle episodes to run.
        max_steps: Per-episode step budget (passed to the env and also
            enforced locally).
        difficulty: "easy", "medium", "hard", or "mixed" (round-robin).
        seed_base: Seed of episode 0; episode ``i`` uses ``seed_base + i``.

    Returns:
        ``(pairs, stats)``. ``pairs`` holds the transitions of successful
        episodes only; if there are none, it holds *all* transitions and
        ``stats["fallback_used"]`` is set to ``True``. ``stats`` always
        reports the oracle success rate and mean cumulative reward, so
        callers can log it into ``evidence/sft_summary.json``.
    """
    # Imported lazily so importing this module does not pull in the
    # baseline-agent script.
    from scripts.baseline_agents import OracleAgent

    config = LLMAgentConfig()
    pairs_all: List[Dict[str, Any]] = []
    pairs_successful: List[Dict[str, Any]] = []
    successes = 0
    rewards: List[float] = []  # one cumulative reward per episode

    for i in range(num_episodes):
        seed = seed_base + i
        difficulty_i = _difficulty_for(i, difficulty)
        env = CERNCollisionEnvironment(max_steps=max_steps)
        obs = env.reset(seed=seed, difficulty=difficulty_i)
        # The oracle is handed the env's hidden ground truth.
        agent = OracleAgent(truth=env.hidden_truth())
        agent.reset()

        episode_pairs: List[Dict[str, Any]] = []
        cumulative = 0.0
        while not obs.done and len(episode_pairs) < max_steps:
            prompt = _render_prompt(tokenizer, build_chat(obs, config))
            action = agent.act(obs)
            episode_pairs.append({
                "prompt": prompt,
                "completion": _serialise_action(action),
                "step": obs.step_index,
                "difficulty": difficulty_i,
                "seed": seed,
            })
            obs = env.step(action)
            # A missing reward is counted as zero.
            cumulative += float(obs.reward or 0.0)
        rewards.append(cumulative)

        st = env.state
        if bool(st.correct_mass) and bool(st.correct_channel):
            successes += 1
            pairs_successful.extend(episode_pairs)
        pairs_all.extend(episode_pairs)

    success_rate = successes / max(num_episodes, 1)
    mean_reward = sum(rewards) / len(rewards) if rewards else 0.0
    stats: Dict[str, Any] = {
        "num_episodes": num_episodes,
        "num_successful_episodes": successes,
        "oracle_success_rate": round(success_rate, 4),
        "mean_oracle_reward": round(mean_reward, 4),
        "num_transitions_total": len(pairs_all),
        "num_transitions_successful": len(pairs_successful),
    }

    # If we filtered out *everything* (e.g. smoke test with max_steps too
    # small for the oracle to ever finish), fall back to the unfiltered
    # set with a warning. Better to teach the model the prefix of the
    # correct pipeline than to give up entirely.
    if pairs_successful:
        return pairs_successful, stats
    logger.warning(
        "no fully-successful oracle trajectories (max_steps=%d may be "
        "too small); using %d unfiltered transitions instead",
        max_steps, len(pairs_all),
    )
    stats["fallback_used"] = True
    return pairs_all, stats
| # ── Dataset assembly ───────────────────────────────────────────────────── | |
def _build_dataset(pairs: List[Dict[str, Any]], tokenizer):
    """Turn collected (prompt, completion) pairs into a Hugging Face ``Dataset``.

    Each row carries three columns: ``text`` (prompt + completion + the
    tokenizer's EOS token, or nothing if it has none), plus the original
    ``prompt`` and ``completion``. ``main`` points the trainer at the
    ``text`` column via ``dataset_text_field="text"``; the trailing EOS is
    there so the model learns to stop after the completion.

    NOTE(review): the rows also expose ``prompt``/``completion`` columns —
    confirm the installed TRL version trains on ``text`` rather than
    treating these rows as prompt-completion examples.
    """
    from datasets import Dataset

    terminator = tokenizer.eos_token or ""
    records = [
        {
            "text": item["prompt"] + item["completion"] + terminator,
            "prompt": item["prompt"],
            "completion": item["completion"],
        }
        for item in pairs
    ]
    return Dataset.from_list(records)
| # ── LoRA helper ────────────────────────────────────────────────────────── | |
def _build_peft_config(model):
    """Build the LoRA config for ``model`` from the module-level LORA_* constants.

    For tiny stub models like ``sshleifer/tiny-gpt2`` the q_proj/k_proj/
    v_proj/o_proj target modules don't exist (GPT-2 uses a fused
    ``c_attn``); in that case ``target_modules`` falls back to
    ``"all-linear"`` so the smoke test still exercises the LoRA path.

    Args:
        model: The loaded causal-LM; only ``named_modules()`` is inspected.

    Returns:
        A ``peft.LoraConfig`` for task type ``CAUSAL_LM`` with no bias
        training.
    """
    from peft import LoraConfig

    wanted = set(LORA_TARGET_MODULES)
    # PEFT resolves a list of target names by exact or ".<name>" suffix
    # match, so compare the *leaf* component of each module path. A plain
    # substring test would also accept e.g. "q_proj_norm", skipping the
    # fallback even though PEFT would then find nothing to adapt.
    has_target = any(
        name.rsplit(".", 1)[-1] in wanted for name, _ in model.named_modules()
    )

    target_modules: Any = LORA_TARGET_MODULES
    if not has_target:
        logger.warning(
            "model %s exposes none of %s; falling back to target_modules='all-linear'",
            type(model).__name__, LORA_TARGET_MODULES,
        )
        target_modules = "all-linear"

    return LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=target_modules,
        task_type="CAUSAL_LM",
        bias="none",
    )
| # ── Main ───────────────────────────────────────────────────────────────── | |
def main() -> None:  # pragma: no cover - heavy ML path
    """CLI entry point: collect oracle traces, run SFT, save checkpoint + summary.

    Steps, in order:
      1. Parse CLI flags and create ``--out_dir`` / ``--evidence_dir``.
      2. Load the tokenizer and base model.
      3. Collect oracle (prompt, completion) pairs and build the dataset.
      4. Train with ``trl.SFTTrainer`` (LoRA unless ``--no_lora``).
      5. Save the model (merged with the adapters when LoRA was used) and
         the tokenizer to ``--out_dir``.
      6. Write ``sft_summary.json`` into ``--evidence_dir``.

    Raises:
        SystemExit: if the heavy ML dependencies are missing, or if no
            transitions could be collected.
    """
    parser = argparse.ArgumentParser(description="SFT warm-start for CERNenv GRPO")
    parser.add_argument("--out_dir", default="runs/sft-warmstart")
    parser.add_argument("--num_episodes", type=int, default=200)
    parser.add_argument("--max_steps", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--base_model", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    parser.add_argument(
        "--difficulty",
        default="mixed",
        choices=["easy", "medium", "hard", "mixed"],
        help="'mixed' cycles easy/medium/hard across episodes.",
    )
    parser.add_argument("--evidence_dir", default="evidence")
    parser.add_argument("--per_device_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--max_seq_length", type=int, default=1280)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--no_lora",
        action="store_true",
        help="Train the full model without LoRA (used by some smoke tests).",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    evidence_dir = Path(args.evidence_dir)
    evidence_dir.mkdir(parents=True, exist_ok=True)

    # The reported duration covers model loading, collection and training.
    t_start = time.time()

    # Heavy dependencies are imported lazily so ``--help`` and the pure
    # helpers work without them installed.
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import SFTConfig, SFTTrainer
    except ImportError as exc:  # pragma: no cover
        raise SystemExit(
            "Heavy ML deps missing — install -r space/training/requirements.txt"
        ) from exc

    logger.info("Loading tokenizer + base model: %s", args.base_model)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    # Not every CUDA device supports bfloat16, so CUDA availability alone
    # is not enough to enable it; otherwise fall back to float32.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32
    model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=dtype)
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    logger.info(
        "Collecting %d oracle trajectories (max_steps=%d, difficulty=%s)",
        args.num_episodes, args.max_steps, args.difficulty,
    )
    pairs, stats = _collect_oracle_trajectories(
        tokenizer=tokenizer,
        num_episodes=args.num_episodes,
        max_steps=args.max_steps,
        difficulty=args.difficulty,
    )
    logger.info(
        "oracle stats: success_rate=%.2f total_transitions=%d kept=%d",
        stats["oracle_success_rate"],
        stats["num_transitions_total"],
        len(pairs),
    )
    if not pairs:
        raise SystemExit(
            "No transitions collected — check OracleAgent / env wiring."
        )

    dataset = _build_dataset(pairs, tokenizer)
    logger.info("Built SFT dataset: %d rows", len(dataset))

    peft_config = None if args.no_lora else _build_peft_config(model)

    sft_cfg = SFTConfig(
        output_dir=str(out_dir),
        per_device_train_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=float(args.epochs),
        learning_rate=args.lr,
        logging_steps=5,
        save_strategy="no",  # we save manually at the end so checkpoints
                             # don't double the disk footprint.
        bf16=use_bf16,
        fp16=False,
        seed=args.seed,
        report_to=[],
        dataset_text_field="text",
        max_seq_length=args.max_seq_length,
        packing=False,
    )
    trainer = SFTTrainer(
        model=model,
        args=sft_cfg,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    logger.info("Starting SFT training (epochs=%d, lr=%.1e)", args.epochs, args.lr)
    train_result = trainer.train()
    final_loss = float(train_result.training_loss)
    logger.info("SFT done; final training_loss=%.4f", final_loss)

    # If we're using LoRA, merge the adapters back into the base model
    # and save a *full* causal-LM checkpoint to ``out_dir``. GRPO
    # downstream just calls ``AutoModelForCausalLM.from_pretrained(out_dir)``,
    # which is much simpler than asking it to load a base model + adapters
    # separately and keeps the warm-start path one ``--base_model`` flag.
    if peft_config is not None:
        try:
            merged = trainer.model.merge_and_unload()
            merged.save_pretrained(str(out_dir))
            logger.info("Merged LoRA adapters into base model and saved to %s", out_dir)
        except Exception as exc:
            # Deliberate best-effort: a failed merge must not lose the
            # training run, so fall back to the trainer's own save.
            logger.warning(
                "merge_and_unload failed (%s); saving adapters alongside base "
                "and pointing GRPO at this directory will still work via PEFT.",
                exc,
            )
            trainer.save_model(str(out_dir))
    else:
        trainer.save_model(str(out_dir))
    # Saved next to the weights so ``out_dir`` is a self-contained checkpoint
    # (including any fallback chat template installed during collection).
    tokenizer.save_pretrained(str(out_dir))

    duration_s = round(time.time() - t_start, 2)
    summary = dict(stats)
    summary.update({
        "final_loss": round(final_loss, 6),
        "duration_s": duration_s,
        "epochs": args.epochs,
        "learning_rate": args.lr,
        "base_model": args.base_model,
        "out_dir": str(out_dir),
        "lora": peft_config is not None,
        "num_train_rows": len(dataset),
    })
    summary_path = evidence_dir / "sft_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    logger.info(
        "Wrote SFT summary to %s (final_loss=%.4f, duration=%.1fs)",
        summary_path, final_loss, duration_s,
    )
# Run the CLI only when the module is executed directly (e.g.
# ``python -m training.sft_warmstart``), never on import.
if __name__ == "__main__":  # pragma: no cover
    main()