#!/usr/bin/env python3
# /// script
# dependencies = [
# "torch",
# "transformers>=4.45.0",
# "trl>=0.25.0",
# "peft>=0.13.0",
# "datasets>=2.20.0",
# "accelerate>=0.34.0",
# "huggingface_hub>=0.24.0",
# ]
# ///
"""Train a LoRA SFT Budget Router model on HF Jobs and push merged weights."""
from __future__ import annotations

import argparse
import os

DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_DATASET_REPO = "akshay4/budget-router-sft-data"
DEFAULT_OUTPUT_REPO = "akshay4/budget-router-sft-qwen1.5b"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train Budget Router SFT model.")
    parser.add_argument("--base-model", default=os.getenv("BASE_MODEL", DEFAULT_BASE_MODEL))
    parser.add_argument("--dataset-repo", default=os.getenv("DATASET_REPO", DEFAULT_DATASET_REPO))
    parser.add_argument("--output-repo", default=os.getenv("OUTPUT_REPO", DEFAULT_OUTPUT_REPO))
    parser.add_argument("--num-epochs", type=float, default=float(os.getenv("NUM_EPOCHS", "3")))
    parser.add_argument("--learning-rate", type=float, default=float(os.getenv("LEARNING_RATE", "2e-4")))
    parser.add_argument("--lora-r", type=int, default=int(os.getenv("LORA_R", "16")))
    parser.add_argument("--lora-alpha", type=int, default=int(os.getenv("LORA_ALPHA", "32")))
    parser.add_argument("--max-length", type=int, default=int(os.getenv("MAX_SEQ_LENGTH", "4096")))
    parser.add_argument("--batch-size", type=int, default=int(os.getenv("PER_DEVICE_BATCH_SIZE", "2")))
    parser.add_argument("--grad-accum", type=int, default=int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "4")))
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN must be set as a secret in the HF Job.")
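
    # Heavy imports are kept below the token check so a misconfigured job fails
    # fast, before any model or dataset downloads start.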
    import torch
    from datasets import load_dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer
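
    # Prefer bfloat16 where the GPU supports it (e.g. Ampere or newer); otherwise
    # fall back to float16. A CUDA GPU is assumed for the job.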
    device_supports_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if device_supports_bf16 else torch.float16

    print(f"[train-sft] loading dataset={args.dataset_repo}", flush=True)
    dataset = load_dataset(args.dataset_repo, split="train", token=token)
    print(f"[train-sft] rows={len(dataset)}", flush=True)

    print(f"[train-sft] loading base_model={args.base_model}", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=dtype,
        token=token,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, token=token)
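    # Some tokenizers ship without a pad token; reuse EOS so padded batches work.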
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
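
    # LoRA adapters on every attention and MLP projection of the transformer
    # blocks (the usual target set for Qwen-style architectures).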
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
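
    # Loss is computed on assistant turns only and packing is disabled, so each
    # routing example is trained as its own sequence.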
    sft_args = SFTConfig(
        output_dir="./sft_output",
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=device_supports_bf16,
        fp16=not device_supports_bf16,
        optim="adamw_torch",
        logging_steps=5,
        save_strategy="epoch",
        report_to="none",
        max_length=args.max_length,
        packing=False,
        assistant_only_loss=True,
        push_to_hub=False,
    )
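
    # SFTTrainer applies the LoRA config itself, so the base model is passed in
    # unwrapped; the tokenizer goes in via processing_class.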
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=sft_args,
        train_dataset=dataset,
        peft_config=peft_config,
    )
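
    # PEFT-wrapped models expose print_trainable_parameters(); log the adapter
    # size when the method is available.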
    trainable = getattr(trainer.model, "print_trainable_parameters", None)
    if callable(trainable):
        trainable()

    print("[train-sft] starting training", flush=True)
    train_result = trainer.train()
    print(f"[train-sft] train_metrics={train_result.metrics}", flush=True)

    final_loss = train_result.metrics.get("train_loss")
    if final_loss is not None and float(final_loss) > 0.5:
        print("[train-sft] WARNING: train_loss > 0.5; inspect data and consider more epochs.", flush=True)
print("[train-sft] merging LoRA and pushing model", flush=True)
merged = trainer.model.merge_and_unload() if hasattr(trainer.model, "merge_and_unload") else trainer.model
merged.push_to_hub(args.output_repo, token=token)
tokenizer.push_to_hub(args.output_repo, token=token)
print(f"[train-sft] Model pushed to {args.output_repo}. Training complete.", flush=True)
if __name__ == "__main__":
main()
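

# Example invocation (hypothetical values; the env vars above can be used in
# place of the CLI flags):
#   HF_TOKEN=hf_xxx uv run train_sft.py --num-epochs 1 --batch-size 1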