#!/usr/bin/env python3
# /// script
# dependencies = [
#   "torch",
#   "transformers>=4.45.0",
#   "trl>=0.25.0",
#   "peft>=0.13.0",
#   "datasets>=2.20.0",
#   "accelerate>=0.34.0",
#   "huggingface_hub>=0.24.0",
# ]
# ///
"""Train a LoRA SFT Budget Router model on HF Jobs and push merged weights."""

from __future__ import annotations

import argparse
import os


DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_DATASET_REPO = "akshay4/budget-router-sft-data"
DEFAULT_OUTPUT_REPO = "akshay4/budget-router-sft-qwen1.5b"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train Budget Router SFT model.")
    parser.add_argument("--base-model", default=os.getenv("BASE_MODEL", DEFAULT_BASE_MODEL))
    parser.add_argument("--dataset-repo", default=os.getenv("DATASET_REPO", DEFAULT_DATASET_REPO))
    parser.add_argument("--output-repo", default=os.getenv("OUTPUT_REPO", DEFAULT_OUTPUT_REPO))
    parser.add_argument("--num-epochs", type=float, default=float(os.getenv("NUM_EPOCHS", "3")))
    parser.add_argument("--learning-rate", type=float, default=float(os.getenv("LEARNING_RATE", "2e-4")))
    parser.add_argument("--lora-r", type=int, default=int(os.getenv("LORA_R", "16")))
    parser.add_argument("--lora-alpha", type=int, default=int(os.getenv("LORA_ALPHA", "32")))
    parser.add_argument("--max-length", type=int, default=int(os.getenv("MAX_SEQ_LENGTH", "4096")))
    parser.add_argument("--batch-size", type=int, default=int(os.getenv("PER_DEVICE_BATCH_SIZE", "2")))
    parser.add_argument("--grad-accum", type=int, default=int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "4")))
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN must be set as a secret in the HF Job.")

    import torch
    from datasets import load_dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

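    # Prefer bfloat16 on GPUs that support it (Ampere or newer); otherwise fall back to float16.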
    device_supports_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if device_supports_bf16 else torch.float16

    print(f"[train-sft] loading dataset={args.dataset_repo}", flush=True)
    dataset = load_dataset(args.dataset_repo, split="train", token=token)
    print(f"[train-sft] rows={len(dataset)}", flush=True)

    print(f"[train-sft] loading base_model={args.base_model}", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=dtype,
        token=token,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, token=token)
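    # Some base tokenizers ship without a pad token; reuse EOS so padding works during training.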
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

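    # Apply LoRA adapters to the attention and MLP projection layers of the Qwen2-style decoder blocks.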
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

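    # assistant_only_loss restricts the loss to assistant turns; it relies on the tokenizer's chat
    # template exposing generation markers, so verify the chosen base model's template supports it.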
    sft_args = SFTConfig(
        output_dir="./sft_output",
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        bf16=device_supports_bf16,
        fp16=not device_supports_bf16,
        optim="adamw_torch",
        logging_steps=5,
        save_strategy="epoch",
        report_to="none",
        max_length=args.max_length,
        packing=False,
        assistant_only_loss=True,
        push_to_hub=False,
    )

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=sft_args,
        train_dataset=dataset,
        peft_config=peft_config,
    )
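    # print_trainable_parameters() only exists once PEFT has wrapped the model, hence the guarded call.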
    trainable = getattr(trainer.model, "print_trainable_parameters", None)
    if callable(trainable):
        trainable()

    print("[train-sft] starting training", flush=True)
    train_result = trainer.train()
    print(f"[train-sft] train_metrics={train_result.metrics}", flush=True)

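    # Rough convergence sanity check; the 0.5 threshold is a heuristic and dataset-dependent.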
    final_loss = train_result.metrics.get("train_loss")
    if final_loss is not None and float(final_loss) > 0.5:
        print("[train-sft] WARNING: train_loss > 0.5; inspect data and consider more epochs.", flush=True)

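    # Fold the LoRA deltas back into the base weights so the pushed checkpoint loads as a plain
    # transformers model, without requiring PEFT at inference time.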
    print("[train-sft] merging LoRA and pushing model", flush=True)
    merged = trainer.model.merge_and_unload() if hasattr(trainer.model, "merge_and_unload") else trainer.model
    merged.push_to_hub(args.output_repo, token=token)
    tokenizer.push_to_hub(args.output_repo, token=token)
    print(f"[train-sft] Model pushed to {args.output_repo}. Training complete.", flush=True)


if __name__ == "__main__":
    main()