"""SFT warm-start trainer for both roles.
Run on a Colab T4/A100 GPU. Reads `warmstart/data/repair_pairs.jsonl` (or
`drift_pairs.jsonl`), wraps in TRL SFTTrainer with Unsloth's 4-bit Qwen2.5
loader, and saves a LoRA adapter.
Usage:
python -m forgeenv.training.sft_warmstart \\
--role repair_agent \\
--data warmstart/data/repair_pairs.jsonl \\
--output_dir artifacts/checkpoints/repair_agent_sft \\
--base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\
--max_steps 200
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Optional
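
# Each input row is expected to carry a `messages` list in the usual chat
# shape (inferred from `_format_chat` below; rows are not validated):
#   {"messages": [{"role": "system", "content": "..."},
#                 {"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}
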
def _load_jsonl(path: str) -> list[dict]:
rows: list[dict] = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
rows.append(json.loads(line))
return rows
def _format_chat(rows: list[dict]) -> list[dict]:
"""Flatten messages -> a single `text` field for SFT."""
out: list[dict] = []
for row in rows:
msgs = row["messages"]
text_parts = []
for m in msgs:
text_parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>")
out.append({"text": "\n".join(text_parts)})
return out
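
# A one-turn example of the flattened `text` (Qwen2.5's ChatML-style template;
# the content shown here is illustrative, not taken from the dataset):
#   <|im_start|>user
#   Fix the failing test<|im_end|>
#   <|im_start|>assistant
#   ...patch...<|im_end|>
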
def run_sft(
role: str,
data_path: str,
output_dir: str,
base_model: str = "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
max_steps: int = 200,
batch_size: int = 2,
learning_rate: float = 2e-4,
lora_r: int = 16,
seed: int = 0,
use_unsloth: Optional[bool] = None,
) -> None:
"""Run SFT. Imports unsloth/trl lazily so this module is importable on
machines without a GPU."""
rows = _load_jsonl(data_path)
formatted = _format_chat(rows)
print(f"[forgeenv.sft] Loaded {len(formatted)} rows for role={role}")
if use_unsloth is None:
use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1"
if use_unsloth:
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import SFTConfig, SFTTrainer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=base_model,
max_seq_length=4096,
dtype=None,
load_in_4bit=True,
)
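        # Attach LoRA adapters: alpha = 2*r is a common heuristic, and the
        # target list covers every attention and MLP projection, mirroring
        # Unsloth's Qwen2.5 examples.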
model = FastLanguageModel.get_peft_model(
model,
r=lora_r,
lora_alpha=lora_r * 2,
lora_dropout=0.0,
bias="none",
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
use_gradient_checkpointing="unsloth",
random_state=seed,
)
dataset = Dataset.from_list(formatted)
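        # Effective batch size = per_device_train_batch_size
        # * gradient_accumulation_steps (2 * 4 = 8 sequences per optimizer
        # step with the defaults).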
sft_config = SFTConfig(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=4,
warmup_steps=10,
max_steps=max_steps,
learning_rate=learning_rate,
logging_steps=10,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=seed,
save_steps=max(50, max_steps // 4),
save_total_limit=2,
report_to="none",
dataset_text_field="text",
max_seq_length=4096,
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
args=sft_config,
)
trainer.train()
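        # `save_pretrained` on the PEFT-wrapped model writes only the LoRA
        # adapter files (adapter weights + config), not the 4-bit base model.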
Path(output_dir).mkdir(parents=True, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"[forgeenv.sft] Saved adapter to {output_dir}")
return
# CPU/dry-run fallback: just dump the formatted dataset to disk so we
# can verify the pipeline shape locally.
Path(output_dir).mkdir(parents=True, exist_ok=True)
out_file = Path(output_dir) / "formatted_dataset.jsonl"
with out_file.open("w", encoding="utf-8") as f:
for row in formatted:
f.write(json.dumps(row) + "\n")
print(f"[forgeenv.sft] (dry run) wrote {len(formatted)} rows to {out_file}")
def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
parser.add_argument(
"--role", choices=["repair_agent", "drift_generator"], required=True
)
parser.add_argument("--data", required=True, help="Path to JSONL warm-start file")
parser.add_argument("--output_dir", required=True)
parser.add_argument(
"--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit"
)
parser.add_argument("--max_steps", type=int, default=200)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--lora_r", type=int, default=16)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--dry_run", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = _parse_args()
run_sft(
role=args.role,
data_path=args.data,
output_dir=args.output_dir,
base_model=args.base_model,
max_steps=args.max_steps,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
lora_r=args.lora_r,
seed=args.seed,
        use_unsloth=False if args.dry_run else None,  # None defers to FORGEENV_USE_UNSLOTH
)
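
# GPU-free smoke test of the pipeline (hits the dry-run fallback and writes
# formatted_dataset.jsonl; `/tmp/sft_dry_run` is an illustrative path):
#   python -m forgeenv.training.sft_warmstart \
#       --role repair_agent \
#       --data warmstart/data/repair_pairs.jsonl \
#       --output_dir /tmp/sft_dry_run \
#       --dry_run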