""" Run before grpo_train.py for SFT→GRPO pipeline. Pass checkpoint path as BASE_MODEL env var to grpo_train.py. """ import argparse import inspect import json import logging from pathlib import Path logger = logging.getLogger(__name__) DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct" DEFAULT_OUTPUT = "checkpoints/sft_1.5b/" def _sft_seq_len_kw(max_tokens: int = 2048) -> dict[str, int]: """TRL 1.0+ uses max_length; older TRL used max_seq_length on SFTConfig.""" from trl import SFTConfig p = set(inspect.signature(SFTConfig.__init__).parameters) if "max_length" in p: return {"max_length": max_tokens} if "max_seq_length" in p: return {"max_seq_length": max_tokens} return {} def _extract_completions(rec: dict) -> list[str]: """Return candidate completion texts from a record.""" completion = rec.get("completion") if isinstance(completion, str) and completion.strip(): return [completion.strip()] conversation = rec.get("conversation", []) candidates: list[str] = [] if isinstance(conversation, list): for turn in conversation: if not isinstance(turn, dict): continue role = str(turn.get("role", "")).lower() content = str(turn.get("content", "")).strip() if role == "negotiator" and content: candidates.append(content) return candidates def _row_total_reward(rec: dict) -> float | None: v = rec.get("reward") if v is not None: return float(v) v2 = rec.get("cumulative_reward") if v2 is not None: return float(v2) return None def load_sft_dataset( data_path: Path, min_reward: float = -50.0, model_id: str | None = None ): """Build a text dataset from JSONL: Qwen2.5 chat (system + first user + assistant = negotiator).""" try: from datasets import Dataset except ImportError as exc: raise ImportError("Install datasets: pip install datasets") from exc from training.prompts_qwen import format_sft_text, load_tokenizer_for_chat mid = (model_id or DEFAULT_MODEL).strip() or DEFAULT_MODEL _tok = load_tokenizer_for_chat(mid) rows: list[dict[str, str]] = [] skipped = 0 reward_filtered = 0 remaining_records = 0 with data_path.open("r", encoding="utf-8") as f: for line_no, line in enumerate(f, start=1): line = line.strip() if not line: continue try: rec = json.loads(line) except json.JSONDecodeError: logger.warning("Skipping malformed JSONL row %d", line_no) skipped += 1 continue r = _row_total_reward(rec) if r is not None and r < min_reward: reward_filtered += 1 continue prompt = str(rec.get("prompt", "")).strip() if not prompt: logger.warning("Skipping row %d: missing prompt", line_no) skipped += 1 continue completions = _extract_completions(rec) if not completions: logger.warning("Skipping row %d: missing completion and negotiator turns", line_no) skipped += 1 continue remaining_records += 1 for completion in completions: rows.append( { "text": format_sft_text(rec, completion, tokenizer=_tok), } ) print( f"Filtered {reward_filtered} records below min_reward={min_reward}, " f"{remaining_records} remaining for SFT" ) if skipped: logger.info("Also skipped %d malformed/empty JSONL rows; expanded to %d text rows", skipped, len(rows)) if not rows: raise RuntimeError("No valid SFT examples found in dataset.") return Dataset.from_list(rows) def train_sft( data_path: Path, model_id: str, output_dir: Path, min_reward: float = -50.0, *, per_device_train_batch_size: int = 2, gradient_accumulation_steps: int = 8, ) -> None: """Fine-tune a base model with LoRA via TRL SFTTrainer. Default batch/accum (2×8) keeps effective batch 16 and fits Colab T4 (16GB VRAM) better than 4×4; set higher batch if you have headroom. 
    gradient_checkpointing reduces VRAM usage at some speed cost.
    """
    import torch
    from peft import LoraConfig
    from trl import SFTConfig, SFTTrainer

    dataset = load_sft_dataset(data_path, min_reward=min_reward, model_id=model_id)
    output_dir.mkdir(parents=True, exist_ok=True)

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    training_args = SFTConfig(
        output_dir=str(output_dir),
        num_train_epochs=3,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=2e-4,
        logging_steps=10,
        save_strategy="epoch",
        fp16=True,
        report_to="none",
        gradient_checkpointing=True,
        **_sft_seq_len_kw(2048),
    )

    if not torch.cuda.is_available():
        logger.warning("No CUDA GPU detected; training may be very slow.")

    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
        peft_config=lora_config,
    )
    logger.info("Starting SFT: model=%s, examples=%d", model_id, len(dataset))
    trainer.train()
    trainer.save_model(str(output_dir))
    logger.info("Saved SFT checkpoint to %s", output_dir)


def main() -> None:
    parser = argparse.ArgumentParser(description="Parlay SFT training")
    parser.add_argument("--data", default="data/episodes.jsonl", help="Path to episodes JSONL")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Base model ID or path")
    parser.add_argument("--output", default=DEFAULT_OUTPUT, help="Checkpoint output directory")
    parser.add_argument(
        "--min-reward",
        type=float,
        default=-50.0,
        help="Skip JSONL records with total reward below this (default: -50.0)",
    )
    parser.add_argument(
        "--per-device-train-batch-size",
        type=int,
        default=2,
        help="Lower this if the GPU OOMs (default 2; effective batch = this × grad accum)",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=8,
        help="Default 8 with batch 2 for an effective batch of 16; raise batch and lower this on larger GPUs",
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
    train_sft(
        Path(args.data),
        args.model,
        Path(args.output),
        min_reward=args.min_reward,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )


if __name__ == "__main__":
    main()
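
# Example pipeline invocation (a sketch: the script filename below is assumed;
# paths are the argparse defaults, and grpo_train.py reads the checkpoint path
# from the BASE_MODEL env var):
#   python sft_train.py --data data/episodes.jsonl --output checkpoints/sft_1.5b/
#   BASE_MODEL=checkpoints/sft_1.5b/ python grpo_train.py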