#!/usr/bin/env python3
# /// script
# dependencies = [
#     "torch",
#     "transformers>=4.45.0",
#     "trl>=0.25.0",
#     "peft>=0.13.0",
#     "datasets>=2.20.0",
#     "accelerate>=0.34.0",
#     "huggingface_hub>=0.24.0",
# ]
# ///
"""Train a LoRA SFT Budget Router model on HF Jobs and push merged weights."""
from __future__ import annotations

import argparse
import os

DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_DATASET_REPO = "akshay4/budget-router-sft-data"
DEFAULT_OUTPUT_REPO = "akshay4/budget-router-sft-qwen1.5b"
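
# Every hyperparameter below can be set either via a CLI flag or via the
# corresponding environment variable (the env var supplies the default),
# which is convenient when configuring an HF Job.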
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train Budget Router SFT model.")
    parser.add_argument("--base-model", default=os.getenv("BASE_MODEL", DEFAULT_BASE_MODEL))
    parser.add_argument("--dataset-repo", default=os.getenv("DATASET_REPO", DEFAULT_DATASET_REPO))
    parser.add_argument("--output-repo", default=os.getenv("OUTPUT_REPO", DEFAULT_OUTPUT_REPO))
    parser.add_argument("--num-epochs", type=float, default=float(os.getenv("NUM_EPOCHS", "3")))
    parser.add_argument("--learning-rate", type=float, default=float(os.getenv("LEARNING_RATE", "2e-4")))
    parser.add_argument("--lora-r", type=int, default=int(os.getenv("LORA_R", "16")))
    parser.add_argument("--lora-alpha", type=int, default=int(os.getenv("LORA_ALPHA", "32")))
    parser.add_argument("--max-length", type=int, default=int(os.getenv("MAX_SEQ_LENGTH", "4096")))
    parser.add_argument("--batch-size", type=int, default=int(os.getenv("PER_DEVICE_BATCH_SIZE", "2")))
    parser.add_argument("--grad-accum", type=int, default=int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "4")))
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN must be set as a secret in the HF Job.")
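
    # Heavy ML imports are deferred until after the token check so that a
    # misconfigured job fails fast, before torch/transformers are loaded.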
    import torch
    from datasets import load_dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer
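
    # Prefer bfloat16 on GPUs that support it; otherwise fall back to float16.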
    device_supports_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if device_supports_bf16 else torch.float16
print(f"[train-sft] loading dataset={args.dataset_repo}", flush=True)
dataset = load_dataset(args.dataset_repo, split="train", token=token)
print(f"[train-sft] rows={len(dataset)}", flush=True)
print(f"[train-sft] loading base_model={args.base_model}", flush=True)
model = AutoModelForCausalLM.from_pretrained(
args.base_model,
torch_dtype=dtype,
token=token,
)
tokenizer = AutoTokenizer.from_pretrained(args.base_model, token=token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
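
    # LoRA adapters on every attention and MLP projection of the Qwen2-style
    # decoder blocks; rank and alpha come from the CLI (defaults r=16, alpha=32).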
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
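
    # assistant_only_loss restricts the training loss to assistant tokens, so the
    # model learns from the router's responses rather than from the prompts.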
    sft_args = SFTConfig(
        output_dir="./sft_output",
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=device_supports_bf16,
        fp16=not device_supports_bf16,
        optim="adamw_torch",
        logging_steps=5,
        save_strategy="epoch",
        report_to="none",
        max_length=args.max_length,
        packing=False,
        assistant_only_loss=True,
        push_to_hub=False,
    )
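
    # Passing peft_config makes SFTTrainer wrap the base model in a PEFT model,
    # so only the LoRA adapter weights are trained.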
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=sft_args,
        train_dataset=dataset,
        peft_config=peft_config,
    )
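
    # Log the trainable-parameter count when the wrapped model exposes the PEFT
    # helper (it should, given peft_config above).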
    trainable = getattr(trainer.model, "print_trainable_parameters", None)
    if callable(trainable):
        trainable()
print("[train-sft] starting training", flush=True)
train_result = trainer.train()
print(f"[train-sft] train_metrics={train_result.metrics}", flush=True)
final_loss = train_result.metrics.get("train_loss")
if final_loss is not None and float(final_loss) > 0.5:
print("[train-sft] WARNING: train_loss > 0.5; inspect data and consider more epochs.", flush=True)
print("[train-sft] merging LoRA and pushing model", flush=True)
merged = trainer.model.merge_and_unload() if hasattr(trainer.model, "merge_and_unload") else trainer.model
merged.push_to_hub(args.output_repo, token=token)
tokenizer.push_to_hub(args.output_repo, token=token)
print(f"[train-sft] Model pushed to {args.output_repo}. Training complete.", flush=True)


if __name__ == "__main__":
    main()
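
# A minimal (assumed) way to consume the pushed merged weights after the job
# finishes; the repo id matches DEFAULT_OUTPUT_REPO above:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     model = AutoModelForCausalLM.from_pretrained("akshay4/budget-router-sft-qwen1.5b")
#     tokenizer = AutoTokenizer.from_pretrained("akshay4/budget-router-sft-qwen1.5b")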