| """ |
| Run before grpo_train.py for SFT→GRPO pipeline. Pass checkpoint path as |
| BASE_MODEL env var to grpo_train.py. |
| """ |
| import argparse |
| import inspect |
| import json |
| import logging |
| from pathlib import Path |
|
|
# Module-level logger; handlers/level are configured in main() via basicConfig.
logger = logging.getLogger(__name__)


# Default HF model id to fine-tune and default checkpoint output directory.
DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_OUTPUT = "checkpoints/sft_1.5b/"
|
|
|
|
def _sft_seq_len_kw(max_tokens: int = 2048) -> dict[str, int]:
    """TRL 1.0+ uses max_length; older TRL used max_seq_length on SFTConfig."""
    from trl import SFTConfig

    accepted = inspect.signature(SFTConfig.__init__).parameters
    # Prefer the modern keyword; fall back to the legacy one, else pass nothing.
    for keyword in ("max_length", "max_seq_length"):
        if keyword in accepted:
            return {keyword: max_tokens}
    return {}
|
|
|
|
| def _extract_completions(rec: dict) -> list[str]: |
| """Return candidate completion texts from a record.""" |
| completion = rec.get("completion") |
| if isinstance(completion, str) and completion.strip(): |
| return [completion.strip()] |
|
|
| conversation = rec.get("conversation", []) |
| candidates: list[str] = [] |
| if isinstance(conversation, list): |
| for turn in conversation: |
| if not isinstance(turn, dict): |
| continue |
| role = str(turn.get("role", "")).lower() |
| content = str(turn.get("content", "")).strip() |
| if role == "negotiator" and content: |
| candidates.append(content) |
| return candidates |
|
|
|
|
| def _row_total_reward(rec: dict) -> float | None: |
| v = rec.get("reward") |
| if v is not None: |
| return float(v) |
| v2 = rec.get("cumulative_reward") |
| if v2 is not None: |
| return float(v2) |
| return None |
|
|
|
|
def load_sft_dataset(
    data_path: Path, min_reward: float = -50.0, model_id: str | None = None
):
    """Build a text dataset from JSONL: Qwen2.5 chat (system + first user + assistant = negotiator).

    Each record needs a non-empty "prompt" plus either a "completion" string or
    negotiator turns in "conversation" (see _extract_completions). Records whose
    total reward (see _row_total_reward) is below *min_reward* are dropped; a
    record with several negotiator turns expands into one text row per turn.

    Args:
        data_path: JSONL file, one episode record per line.
        min_reward: Drop records with total reward below this threshold.
        model_id: Tokenizer/model id; falls back to DEFAULT_MODEL when
            None or blank.

    Returns:
        A `datasets.Dataset` with a single "text" column.

    Raises:
        ImportError: if the `datasets` package is not installed.
        RuntimeError: if no valid examples remain after filtering.
    """
    try:
        from datasets import Dataset
    except ImportError as exc:
        raise ImportError("Install datasets: pip install datasets") from exc

    from training.prompts_qwen import format_sft_text, load_tokenizer_for_chat

    # Fall back to the default model when model_id is None or whitespace-only.
    mid = (model_id or DEFAULT_MODEL).strip() or DEFAULT_MODEL
    _tok = load_tokenizer_for_chat(mid)

    rows: list[dict[str, str]] = []
    skipped = 0          # malformed JSON, missing prompt, or no completions
    reward_filtered = 0  # dropped by the min_reward threshold
    remaining_records = 0
    with data_path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                logger.warning("Skipping malformed JSONL row %d", line_no)
                skipped += 1
                continue

            # Records with no reward at all (r is None) are kept deliberately.
            r = _row_total_reward(rec)
            if r is not None and r < min_reward:
                reward_filtered += 1
                continue

            prompt = str(rec.get("prompt", "")).strip()
            if not prompt:
                logger.warning("Skipping row %d: missing prompt", line_no)
                skipped += 1
                continue

            completions = _extract_completions(rec)
            if not completions:
                logger.warning("Skipping row %d: missing completion and negotiator turns", line_no)
                skipped += 1
                continue

            remaining_records += 1
            for completion in completions:
                rows.append(
                    {
                        "text": format_sft_text(rec, completion, tokenizer=_tok),
                    }
                )

    # Use the module logger (not print) so this summary obeys the logging
    # configuration like every other diagnostic in this file.
    logger.info(
        "Filtered %d records below min_reward=%s, %d remaining for SFT",
        reward_filtered,
        min_reward,
        remaining_records,
    )
    if skipped:
        logger.info("Also skipped %d malformed/empty JSONL rows; expanded to %d text rows", skipped, len(rows))
    if not rows:
        raise RuntimeError("No valid SFT examples found in dataset.")
    return Dataset.from_list(rows)
|
|
|
|
def train_sft(
    data_path: Path,
    model_id: str,
    output_dir: Path,
    min_reward: float = -50.0,
    *,
    per_device_train_batch_size: int = 2,
    gradient_accumulation_steps: int = 8,
) -> None:
    """Fine-tune a base model with LoRA via TRL SFTTrainer.

    Default batch/accum (2×8) keeps effective batch 16 and fits Colab T4 (16GB VRAM) better than 4×4;
    set higher batch if you have headroom. gradient_checkpointing reduces VRAM at some speed cost.

    Args:
        data_path: JSONL episodes file consumed by load_sft_dataset.
        model_id: HF hub id or local path passed straight to SFTTrainer.
        output_dir: Checkpoint directory; created if missing.
        min_reward: Reward filter forwarded to load_sft_dataset.
        per_device_train_batch_size: Micro-batch size per device.
        gradient_accumulation_steps: Steps to accumulate before an optimizer step.
    """
    import torch
    from peft import LoraConfig
    from trl import SFTConfig, SFTTrainer

    dataset = load_sft_dataset(data_path, min_reward=min_reward, model_id=model_id)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Probe hardware BEFORE building SFTConfig: transformers rejects fp16
    # mixed precision without a GPU, so the previous unconditional fp16=True
    # crashed CPU-only runs before the "very slow" warning was ever relevant.
    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        logger.warning("No CUDA GPU detected; training may be very slow.")

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],  # attention-only LoRA keeps the adapter small
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    training_args = SFTConfig(
        output_dir=str(output_dir),
        num_train_epochs=3,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=2e-4,
        logging_steps=10,
        save_strategy="epoch",
        fp16=use_cuda,  # fp16 only when a GPU is present (was hard-coded True)
        report_to="none",
        gradient_checkpointing=True,
        **_sft_seq_len_kw(2048),  # max_length vs max_seq_length across TRL versions
    )

    trainer = SFTTrainer(
        model=model_id,  # SFTTrainer loads the model itself from the id/path
        args=training_args,
        train_dataset=dataset,
        peft_config=lora_config,
    )

    logger.info("Starting SFT: model=%s, examples=%d", model_id, len(dataset))
    trainer.train()
    trainer.save_model(str(output_dir))
    logger.info("Saved SFT checkpoint to %s", output_dir)
|
|
|
|
def main() -> None:
    """CLI entry point: parse flags, configure logging, run SFT training."""
    ap = argparse.ArgumentParser(description="Parlay SFT training")
    ap.add_argument("--data", default="data/episodes.jsonl")
    ap.add_argument("--model", default=DEFAULT_MODEL)
    ap.add_argument("--output", default=DEFAULT_OUTPUT)
    ap.add_argument(
        "--min-reward",
        type=float,
        default=-50.0,
        help="Skip JSONL records with total reward below this (default: -50.0)",
    )
    ap.add_argument(
        "--per-device-train-batch-size",
        type=int,
        default=2,
        help="Lower if GPU OOM (default 2, effective batch = this × grad accum)",
    )
    ap.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=8,
        help="Default 8 with batch 2 for effective batch 16; raise batch and lower this on large GPUs",
    )
    ns = ap.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
    train_sft(
        Path(ns.data),
        ns.model,
        Path(ns.output),
        min_reward=ns.min_reward,
        per_device_train_batch_size=ns.per_device_train_batch_size,
        gradient_accumulation_steps=ns.gradient_accumulation_steps,
    )
|
|
|
|
# Run as a script: delegate to the CLI entry point.
if __name__ == "__main__":
    main()
|
|