#!/usr/bin/env python3
# /// script
# dependencies = [
#     "torch",
#     "transformers>=4.45.0",
#     "trl>=0.25.0",
#     "peft>=0.13.0",
#     "datasets>=2.20.0",
#     "accelerate>=0.34.0",
#     "huggingface_hub>=0.24.0",
# ]
# ///
"""Train a LoRA SFT Budget Router model on HF Jobs and push merged weights."""
from __future__ import annotations

import argparse
import os

DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_DATASET_REPO = "akshay4/budget-router-sft-data"
DEFAULT_OUTPUT_REPO = "akshay4/budget-router-sft-qwen1.5b"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train Budget Router SFT model.")
    parser.add_argument("--base-model", default=os.getenv("BASE_MODEL", DEFAULT_BASE_MODEL))
    parser.add_argument("--dataset-repo", default=os.getenv("DATASET_REPO", DEFAULT_DATASET_REPO))
    parser.add_argument("--output-repo", default=os.getenv("OUTPUT_REPO", DEFAULT_OUTPUT_REPO))
    parser.add_argument("--num-epochs", type=float, default=float(os.getenv("NUM_EPOCHS", "3")))
    parser.add_argument("--learning-rate", type=float, default=float(os.getenv("LEARNING_RATE", "2e-4")))
    parser.add_argument("--lora-r", type=int, default=int(os.getenv("LORA_R", "16")))
    parser.add_argument("--lora-alpha", type=int, default=int(os.getenv("LORA_ALPHA", "32")))
    parser.add_argument("--max-length", type=int, default=int(os.getenv("MAX_SEQ_LENGTH", "4096")))
    parser.add_argument("--batch-size", type=int, default=int(os.getenv("PER_DEVICE_BATCH_SIZE", "2")))
    parser.add_argument("--grad-accum", type=int, default=int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "4")))
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN must be set as a secret in the HF Job.")
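
    # Heavy imports are deferred until after the token check so a misconfigured
    # job fails fast instead of spending time importing torch first.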
    import torch
    from datasets import load_dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer
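
    # Prefer bfloat16 where the GPU supports it (Ampere and newer); otherwise
    # fall back to float16 with loss scaling via the fp16 flag below.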
    device_supports_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if device_supports_bf16 else torch.float16

    print(f"[train-sft] loading dataset={args.dataset_repo}", flush=True)
    dataset = load_dataset(args.dataset_repo, split="train", token=token)
    print(f"[train-sft] rows={len(dataset)}", flush=True)

    print(f"[train-sft] loading base_model={args.base_model}", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=dtype,
        token=token,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, token=token)
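
    # Some base tokenizers ship without a dedicated pad token; fall back to EOS
    # so padded batches work during training.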
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
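
    # Apply LoRA to every attention and MLP projection of the Qwen2-style block;
    # with the defaults (r=16, alpha=32) the adapter scaling alpha/r is 2.0.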
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
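
    # Effective batch size is per-device batch * grad accumulation (2 * 4 = 8 by
    # default). assistant_only_loss restricts the loss to assistant turns, which
    # requires conversational data and a chat template with generation markers.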
    sft_args = SFTConfig(
        output_dir="./sft_output",
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=device_supports_bf16,
        fp16=not device_supports_bf16,
        optim="adamw_torch",
        logging_steps=5,
        save_strategy="epoch",
        report_to="none",
        max_length=args.max_length,
        packing=False,
        assistant_only_loss=True,
        push_to_hub=False,
    )
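
    # Passing peft_config makes SFTTrainer wrap the base model in a PeftModel,
    # so only the adapter weights train while the base stays frozen.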
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=sft_args,
        train_dataset=dataset,
        peft_config=peft_config,
    )
    trainable = getattr(trainer.model, "print_trainable_parameters", None)
    if callable(trainable):
        trainable()

    print("[train-sft] starting training", flush=True)
    train_result = trainer.train()
    print(f"[train-sft] train_metrics={train_result.metrics}", flush=True)
    final_loss = train_result.metrics.get("train_loss")
    if final_loss is not None and float(final_loss) > 0.5:
        print("[train-sft] WARNING: train_loss > 0.5; inspect data and consider more epochs.", flush=True)

    print("[train-sft] merging LoRA and pushing model", flush=True)
    merged = trainer.model.merge_and_unload() if hasattr(trainer.model, "merge_and_unload") else trainer.model
    merged.push_to_hub(args.output_repo, token=token)
    tokenizer.push_to_hub(args.output_repo, token=token)
    print(f"[train-sft] Model pushed to {args.output_repo}. Training complete.", flush=True)


if __name__ == "__main__":
    main()
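

# Example launch on HF Jobs (a sketch: the flavor name, the script filename, and
# the --secrets flag are assumptions; check `hf jobs uv run --help` for your
# huggingface_hub version):
#   hf jobs uv run --flavor a10g-large --secrets HF_TOKEN train_sft.py
#
# After the merge-and-push, the checkpoint loads like any plain model:
#   AutoModelForCausalLM.from_pretrained("akshay4/budget-router-sft-qwen1.5b")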