| """SFT warm-start trainer for both roles. |
| |
| Run on a Colab T4/A100 GPU. Reads `warmstart/data/repair_pairs.jsonl` (or |
| `drift_pairs.jsonl`), wraps in TRL SFTTrainer with Unsloth's 4-bit Qwen2.5 |
| loader, and saves a LoRA adapter. |
| |
| Usage: |
| python -m forgeenv.training.sft_warmstart \\ |
| --role repair_agent \\ |
| --data warmstart/data/repair_pairs.jsonl \\ |
| --output_dir artifacts/checkpoints/repair_agent_sft \\ |
| --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\ |
| --max_steps 200 |
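
Each input line must be a JSON object with a `messages` list of
{"role": ..., "content": ...} dicts; `_format_chat` flattens these into
the ChatML `text` field the trainer consumes.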
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| from pathlib import Path |
| from typing import Optional |
|
|
|
|
def _load_jsonl(path: str) -> list[dict]:
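    """Read newline-delimited JSON, skipping blank lines."""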
    rows: list[dict] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def _format_chat(rows: list[dict]) -> list[dict]:
    """Flatten chat `messages` into a single ChatML `text` field for SFT."""
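    # Each turn is rendered as "<|im_start|>{role}" on one line and
    # "{content}<|im_end|>" on the next -- the ChatML template that Qwen2.5
    # instruct models ship with; turns are then joined with newlines.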
    out: list[dict] = []
    for row in rows:
        msgs = row["messages"]
        text_parts = []
        for m in msgs:
            text_parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>")
        out.append({"text": "\n".join(text_parts)})
    return out


def run_sft(
    role: str,
    data_path: str,
    output_dir: str,
    base_model: str = "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
    max_steps: int = 200,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
    lora_r: int = 16,
    seed: int = 0,
    use_unsloth: Optional[bool] = None,
) -> None:
    """Run SFT. Imports unsloth/trl lazily so this module is importable on
    machines without a GPU."""
    rows = _load_jsonl(data_path)
    formatted = _format_chat(rows)
    print(f"[forgeenv.sft] Loaded {len(formatted)} rows for role={role}")

    if use_unsloth is None:
        use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1"

    if use_unsloth:
        from unsloth import FastLanguageModel
        from datasets import Dataset
        from trl import SFTConfig, SFTTrainer

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=4096,
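            # dtype=None lets Unsloth auto-detect float16/bfloat16 for the GPU.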
            dtype=None,
            load_in_4bit=True,
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r=lora_r,
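            # alpha = 2*r is a common LoRA scaling heuristic.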
            lora_alpha=lora_r * 2,
            lora_dropout=0.0,
            bias="none",
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            use_gradient_checkpointing="unsloth",
            random_state=seed,
        )

        dataset = Dataset.from_list(formatted)
        sft_config = SFTConfig(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
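            # effective batch size = batch_size * 4 sequences per optimizer step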
            warmup_steps=10,
            max_steps=max_steps,
            learning_rate=learning_rate,
            logging_steps=10,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
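            # checkpoint roughly four times over the run, but never more
            # often than every 50 steps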
            save_steps=max(50, max_steps // 4),
            save_total_limit=2,
            report_to="none",
            dataset_text_field="text",
            max_seq_length=4096,
        )
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            args=sft_config,
        )
        trainer.train()
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        print(f"[forgeenv.sft] Saved adapter to {output_dir}")
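        # Only LoRA weights are saved above; to serve the adapter later,
        # reload a base model and attach it (sketch; `base` stands for a
        # freshly loaded base model, and `peft` is assumed installed):
        #   from peft import PeftModel
        #   model = PeftModel.from_pretrained(base, output_dir)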
        return

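    # Dry-run / no-GPU path: materialize the formatted dataset so it can be
    # inspected or fed to another trainer.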
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    out_file = Path(output_dir) / "formatted_dataset.jsonl"
    with out_file.open("w", encoding="utf-8") as f:
        for row in formatted:
            f.write(json.dumps(row) + "\n")
    print(f"[forgeenv.sft] (dry run) wrote {len(formatted)} rows to {out_file}")


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--role", choices=["repair_agent", "drift_generator"], required=True
    )
    parser.add_argument("--data", required=True, help="Path to JSONL warm-start file")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument(
        "--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit"
    )
    parser.add_argument("--max_steps", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dry_run", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    run_sft(
        role=args.role,
        data_path=args.data,
        output_dir=args.output_dir,
        base_model=args.base_model,
        max_steps=args.max_steps,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lora_r=args.lora_r,
        seed=args.seed,
        use_unsloth=not args.dry_run,
    )