"""SFT warm-start trainer for both roles. Run on a Colab T4/A100 GPU. Reads `warmstart/data/repair_pairs.jsonl` (or `drift_pairs.jsonl`), wraps in TRL SFTTrainer with Unsloth's 4-bit Qwen2.5 loader, and saves a LoRA adapter. Usage: python -m forgeenv.training.sft_warmstart \\ --role repair_agent \\ --data warmstart/data/repair_pairs.jsonl \\ --output_dir artifacts/checkpoints/repair_agent_sft \\ --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\ --max_steps 200 """ from __future__ import annotations import argparse import json import os from pathlib import Path from typing import Optional def _load_jsonl(path: str) -> list[dict]: rows: list[dict] = [] with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: rows.append(json.loads(line)) return rows def _format_chat(rows: list[dict]) -> list[dict]: """Flatten messages -> a single `text` field for SFT.""" out: list[dict] = [] for row in rows: msgs = row["messages"] text_parts = [] for m in msgs: text_parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>") out.append({"text": "\n".join(text_parts)}) return out def run_sft( role: str, data_path: str, output_dir: str, base_model: str = "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit", max_steps: int = 200, batch_size: int = 2, learning_rate: float = 2e-4, lora_r: int = 16, seed: int = 0, use_unsloth: Optional[bool] = None, ) -> None: """Run SFT. Imports unsloth/trl lazily so this module is importable on machines without a GPU.""" rows = _load_jsonl(data_path) formatted = _format_chat(rows) print(f"[forgeenv.sft] Loaded {len(formatted)} rows for role={role}") if use_unsloth is None: use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1" if use_unsloth: from unsloth import FastLanguageModel from datasets import Dataset from trl import SFTConfig, SFTTrainer model, tokenizer = FastLanguageModel.from_pretrained( model_name=base_model, max_seq_length=4096, dtype=None, load_in_4bit=True, ) model = FastLanguageModel.get_peft_model( model, r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.0, bias="none", target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], use_gradient_checkpointing="unsloth", random_state=seed, ) dataset = Dataset.from_list(formatted) sft_config = SFTConfig( output_dir=output_dir, per_device_train_batch_size=batch_size, gradient_accumulation_steps=4, warmup_steps=10, max_steps=max_steps, learning_rate=learning_rate, logging_steps=10, optim="adamw_8bit", weight_decay=0.01, lr_scheduler_type="linear", seed=seed, save_steps=max(50, max_steps // 4), save_total_limit=2, report_to="none", dataset_text_field="text", max_seq_length=4096, ) trainer = SFTTrainer( model=model, tokenizer=tokenizer, train_dataset=dataset, args=sft_config, ) trainer.train() Path(output_dir).mkdir(parents=True, exist_ok=True) model.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) print(f"[forgeenv.sft] Saved adapter to {output_dir}") return # CPU/dry-run fallback: just dump the formatted dataset to disk so we # can verify the pipeline shape locally. Path(output_dir).mkdir(parents=True, exist_ok=True) out_file = Path(output_dir) / "formatted_dataset.jsonl" with out_file.open("w", encoding="utf-8") as f: for row in formatted: f.write(json.dumps(row) + "\n") print(f"[forgeenv.sft] (dry run) wrote {len(formatted)} rows to {out_file}") def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--role", choices=["repair_agent", "drift_generator"], required=True ) parser.add_argument("--data", required=True, help="Path to JSONL warm-start file") parser.add_argument("--output_dir", required=True) parser.add_argument( "--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" ) parser.add_argument("--max_steps", type=int, default=200) parser.add_argument("--batch_size", type=int, default=2) parser.add_argument("--learning_rate", type=float, default=2e-4) parser.add_argument("--lora_r", type=int, default=16) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--dry_run", action="store_true") return parser.parse_args() if __name__ == "__main__": args = _parse_args() run_sft( role=args.role, data_path=args.data, output_dir=args.output_dir, base_model=args.base_model, max_steps=args.max_steps, batch_size=args.batch_size, learning_rate=args.learning_rate, lora_r=args.lora_r, seed=args.seed, use_unsloth=not args.dry_run, )