"""
Run before grpo_train.py for SFT→GRPO pipeline. Pass checkpoint path as
BASE_MODEL env var to grpo_train.py.
"""
import argparse
import inspect
import json
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_OUTPUT = "checkpoints/sft_1.5b/"
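
# Pipeline sketch (invocation paths are assumptions based on this file's
# location and the CLI defaults below, not verified against the repo):
#   python training/sft_train.py --data data/episodes.jsonl --output checkpoints/sft_1.5b/
#   BASE_MODEL=checkpoints/sft_1.5b/ python training/grpo_train.py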
def _sft_seq_len_kw(max_tokens: int = 2048) -> dict[str, int]:
"""TRL 1.0+ uses max_length; older TRL used max_seq_length on SFTConfig."""
from trl import SFTConfig
p = set(inspect.signature(SFTConfig.__init__).parameters)
if "max_length" in p:
return {"max_length": max_tokens}
if "max_seq_length" in p:
return {"max_seq_length": max_tokens}
return {}
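
# Example: SFTConfig(..., **_sft_seq_len_kw(2048)) expands to max_length=2048 on
# TRL >= 1.0 and to max_seq_length=2048 on older releases, so one script runs
# against either version without pinning TRL.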
def _extract_completions(rec: dict) -> list[str]:
"""Return candidate completion texts from a record."""
completion = rec.get("completion")
if isinstance(completion, str) and completion.strip():
return [completion.strip()]
conversation = rec.get("conversation", [])
candidates: list[str] = []
if isinstance(conversation, list):
for turn in conversation:
if not isinstance(turn, dict):
continue
role = str(turn.get("role", "")).lower()
content = str(turn.get("content", "")).strip()
if role == "negotiator" and content:
candidates.append(content)
return candidates
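
# Illustrative record shapes _extract_completions handles (field names come from
# the code above; the example values are made up):
#   {"prompt": "...", "completion": "I can do $40 if you pick up today."}
#   {"prompt": "...", "conversation": [{"role": "negotiator", "content": "..."}]}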
def _row_total_reward(rec: dict) -> float | None:
    """Return the record's total reward: ``reward`` if present, else ``cumulative_reward``."""
    v = rec.get("reward")
if v is not None:
return float(v)
v2 = rec.get("cumulative_reward")
if v2 is not None:
return float(v2)
return None
def load_sft_dataset(
data_path: Path, min_reward: float = -50.0, model_id: str | None = None
):
"""Build a text dataset from JSONL: Qwen2.5 chat (system + first user + assistant = negotiator)."""
try:
from datasets import Dataset
except ImportError as exc:
raise ImportError("Install datasets: pip install datasets") from exc
from training.prompts_qwen import format_sft_text, load_tokenizer_for_chat
mid = (model_id or DEFAULT_MODEL).strip() or DEFAULT_MODEL
_tok = load_tokenizer_for_chat(mid)
rows: list[dict[str, str]] = []
skipped = 0
reward_filtered = 0
remaining_records = 0
with data_path.open("r", encoding="utf-8") as f:
for line_no, line in enumerate(f, start=1):
line = line.strip()
if not line:
continue
try:
rec = json.loads(line)
except json.JSONDecodeError:
logger.warning("Skipping malformed JSONL row %d", line_no)
skipped += 1
continue
r = _row_total_reward(rec)
if r is not None and r < min_reward:
reward_filtered += 1
continue
prompt = str(rec.get("prompt", "")).strip()
if not prompt:
logger.warning("Skipping row %d: missing prompt", line_no)
skipped += 1
continue
completions = _extract_completions(rec)
if not completions:
logger.warning("Skipping row %d: missing completion and negotiator turns", line_no)
skipped += 1
continue
remaining_records += 1
for completion in completions:
rows.append(
{
"text": format_sft_text(rec, completion, tokenizer=_tok),
}
)
    logger.info(
        "Filtered %d records below min_reward=%s; %d remaining for SFT",
        reward_filtered,
        min_reward,
        remaining_records,
    )
    if skipped:
        logger.info(
            "Also skipped %d malformed/empty JSONL rows; expanded to %d text rows",
            skipped,
            len(rows),
        )
if not rows:
raise RuntimeError("No valid SFT examples found in dataset.")
return Dataset.from_list(rows)
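
# Usage sketch (paths illustrative):
#   ds = load_sft_dataset(Path("data/episodes.jsonl"), min_reward=-50.0)
#   ds[0]["text"]  # one fully chat-templated training string per negotiator turn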
def train_sft(
data_path: Path,
model_id: str,
output_dir: Path,
min_reward: float = -50.0,
*,
per_device_train_batch_size: int = 2,
gradient_accumulation_steps: int = 8,
) -> None:
"""Fine-tune a base model with LoRA via TRL SFTTrainer.
Default batch/accum (2×8) keeps effective batch 16 and fits Colab T4 (16GB VRAM) better than 4×4;
set higher batch if you have headroom. gradient_checkpointing reduces VRAM at some speed cost.
"""
import torch
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
dataset = load_sft_dataset(data_path, min_reward=min_reward, model_id=model_id)
output_dir.mkdir(parents=True, exist_ok=True)
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
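    # r=16/alpha=32 on the attention q/v projections keeps the base model frozen
    # and trains only a small adapter, which is the main reason this fits the
    # 16GB T4 mentioned in the docstring.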
training_args = SFTConfig(
output_dir=str(output_dir),
num_train_epochs=3,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=2e-4,
logging_steps=10,
save_strategy="epoch",
        fp16=torch.cuda.is_available(),  # fp16 requires a CUDA GPU; fall back to fp32 on CPU
report_to="none",
gradient_checkpointing=True,
**_sft_seq_len_kw(2048),
)
    if not torch.cuda.is_available():
        logger.warning("No CUDA GPU detected; training falls back to fp32 and may be very slow.")
trainer = SFTTrainer(
model=model_id,
args=training_args,
train_dataset=dataset,
peft_config=lora_config,
)
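    # SFTTrainer accepts a model id string and loads the model/tokenizer itself;
    # with peft_config set it wraps the model in a LoRA adapter, so save_model()
    # below writes the adapter weights rather than a full checkpoint.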
logger.info("Starting SFT: model=%s, examples=%d", model_id, len(dataset))
trainer.train()
trainer.save_model(str(output_dir))
logger.info("Saved SFT checkpoint to %s", output_dir)
def main() -> None:
parser = argparse.ArgumentParser(description="Parlay SFT training")
parser.add_argument("--data", default="data/episodes.jsonl")
parser.add_argument("--model", default=DEFAULT_MODEL)
parser.add_argument("--output", default=DEFAULT_OUTPUT)
parser.add_argument(
"--min-reward",
type=float,
default=-50.0,
help="Skip JSONL records with total reward below this (default: -50.0)",
)
parser.add_argument(
"--per-device-train-batch-size",
type=int,
default=2,
help="Lower if GPU OOM (default 2, effective batch = this × grad accum)",
)
parser.add_argument(
"--gradient-accumulation-steps",
type=int,
default=8,
help="Default 8 with batch 2 for effective batch 16; raise batch and lower this on large GPUs",
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
train_sft(
Path(args.data),
args.model,
Path(args.output),
min_reward=args.min_reward,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
)
if __name__ == "__main__":
main()