File size: 8,835 Bytes
aee8025
 
 
 
 
 
 
 
 
 
 
 
 
2fdbc71
aee8025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fdbc71
 
aee8025
 
 
 
 
 
 
 
 
2fdbc71
aee8025
 
 
 
 
 
 
 
 
2fdbc71
 
 
 
 
 
aee8025
2fdbc71
 
 
 
aee8025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fdbc71
 
 
 
aee8025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fdbc71
aee8025
 
2fdbc71
aee8025
 
 
2fdbc71
aee8025
 
 
 
 
 
 
 
 
2fdbc71
aee8025
 
2fdbc71
aee8025
 
 
 
 
 
 
 
 
 
 
 
 
 
2fdbc71
 
 
 
aee8025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fdbc71
 
 
 
 
 
 
aee8025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#!/usr/bin/env python3
"""
TMF921 Intent-to-Configuration Translation — QLoRA Fine-Tuning
==============================================================
Fine-tunes Qwen3-8B (by default) on the TMF921-intent-to-config-augmented
dataset using 4-bit QLoRA. Designed for a single RTX 6000 Ada (48 GB VRAM).

Usage:
    python train.py                          # defaults
    python train.py --base_model Qwen/Qwen3-8B --epochs 3 --lr 1e-4
    python train.py --base_model Qwen/Qwen2.5-7B-Instruct --lora_r 64
"""

import argparse, os, json, torch, math
from datetime import datetime

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer


# ── CLI ──────────────────────────────────────────────────────────────
def parse_args(argv=None):
    """Parse and validate the command-line options for a training run.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets the
            parser be driven programmatically (tests, notebooks).

    Returns:
        argparse.Namespace holding the model, dataset, LoRA, training,
        output and misc settings.

    Exits via ``ArgumentParser.error`` (exit status 2) when ``--push_to_hub``
    is given without ``--hub_model_id``, or when ``--batch_size`` or
    ``--grad_accum`` is below 1 (both are used as divisors when sizing the
    training schedule).
    """
    p = argparse.ArgumentParser(
        description="QLoRA fine-tuning for TMF921 intent-to-config translation."
    )
    # Model
    p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B",
                   help="HuggingFace model id or local path")
    # Dataset
    p.add_argument("--dataset", type=str,
                   default="nraptisss/TMF921-intent-to-config-augmented",
                   help="HuggingFace dataset id providing 'train' and 'test' splits")
    # LoRA
    p.add_argument("--lora_r", type=int, default=32)
    p.add_argument("--lora_alpha", type=int, default=64)
    p.add_argument("--lora_dropout", type=float, default=0.05)
    # Training
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--grad_accum", type=int, default=8)
    p.add_argument("--max_length", type=int, default=4096)
    p.add_argument("--warmup_steps", type=int, default=100)
    p.add_argument("--weight_decay", type=float, default=0.01)
    # Output
    p.add_argument("--output_dir", type=str, default="./output")
    p.add_argument("--hub_model_id", type=str, default=None,
                   help="Push to this HF model id (e.g. nraptisss/Qwen3-8B-TMF921)")
    p.add_argument("--push_to_hub", action="store_true", default=False,
                   help="Push to the Hugging Face Hub (requires --hub_model_id)")
    # Misc
    p.add_argument("--seed", type=int, default=42)
    # Flash attention is ON by default. --flash_attn is therefore a no-op kept
    # for backward compatibility; --no_flash_attn is the switch that turns it off.
    p.add_argument("--flash_attn", action="store_true", default=True,
                   help="Use FlashAttention-2 (default: on)")
    p.add_argument("--no_flash_attn", dest="flash_attn", action="store_false",
                   help="Disable FlashAttention-2")

    args = p.parse_args(argv)

    # Without an explicit repo id the push destination is implicit, and the
    # final explicit upload in the training script is skipped — fail early.
    if args.push_to_hub and not args.hub_model_id:
        p.error("--push_to_hub requires --hub_model_id")
    # Both values divide the dataset size when computing steps per epoch.
    for name in ("batch_size", "grad_accum"):
        value = getattr(args, name)
        if value < 1:
            p.error(f"--{name} must be >= 1, got {value}")
    return args


def _print_run_config(args):
    """Print a banner summarising the parsed CLI options for this run."""
    rule = "=" * 70
    print(rule)
    print("TMF921 Intent Translation — QLoRA Training")
    print(rule)
    print(f"Base model    : {args.base_model}")
    print(f"Dataset       : {args.dataset}")
    print(f"LoRA r/alpha  : {args.lora_r}/{args.lora_alpha}")
    print(f"Epochs        : {args.epochs}")
    print(f"LR            : {args.lr}")
    print(f"Batch size    : {args.batch_size} × {args.grad_accum} grad_accum = "
          f"{args.batch_size * args.grad_accum} effective")
    print(f"Max length    : {args.max_length}")
    print(f"Flash attn    : {args.flash_attn}")
    print(f"Output        : {args.output_dir}")
    print(f"Push to Hub   : {args.push_to_hub} → {args.hub_model_id}")
    print(rule)


def _resolve_attn_implementation(want_flash):
    """Return the ``attn_implementation`` to request, or ``None`` for default.

    ``flash_attention_2`` is only requested when asked for AND the
    ``flash_attn`` package is importable; otherwise fall back to PyTorch SDPA
    instead of letting ``from_pretrained`` fail on a missing package.
    """
    if not want_flash:
        return None
    import importlib.util  # local: only needed for this availability probe

    if importlib.util.find_spec("flash_attn") is None:
        print("  WARNING: flash_attn is not installed; falling back to 'sdpa'.")
        return "sdpa"
    return "flash_attention_2"


def _load_model_and_tokenizer(args):
    """Load ``args.base_model`` quantized to 4-bit NF4 plus its tokenizer.

    Returns:
        (model, tokenizer). The tokenizer's pad token falls back to its EOS
        token when the checkpoint defines none.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    model_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        # Keep the non-quantized modules (embeddings, norms, lm_head) in bf16
        # to match the 4-bit compute dtype; otherwise they may load in fp32.
        "torch_dtype": torch.bfloat16,
        # NOTE: trust_remote_code executes Python shipped with the model repo;
        # only point --base_model at trusted repositories.
        "trust_remote_code": True,
    }
    attn_impl = _resolve_attn_implementation(args.flash_attn)
    if attn_impl is not None:
        model_kwargs["attn_implementation"] = attn_impl

    model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _compute_schedule(n_train, batch_size, grad_accum, epochs):
    """Return ``(steps_per_epoch, total_steps, eval_steps)`` for the run.

    Targets ~4 evaluations per epoch, spaced at least 50 optimizer steps
    apart, but never further apart than the whole run so that short runs
    still produce at least one eval/checkpoint for ``load_best_model_at_end``
    to choose from.
    """
    steps_per_epoch = math.ceil(n_train / (batch_size * grad_accum))
    total_steps = steps_per_epoch * epochs
    eval_steps = max(1, min(max(steps_per_epoch // 4, 50), total_steps))
    return steps_per_epoch, total_steps, eval_steps


def _build_sft_config(args, run_name, eval_steps):
    """Build the ``SFTConfig`` for this run; eval and save share one cadence."""
    return SFTConfig(
        output_dir=args.output_dir,
        run_name=run_name,
        # Batch
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        # Schedule
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        # Precision & memory
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        # Sequence
        max_length=args.max_length,
        # Loss: train only on assistant outputs
        assistant_only_loss=True,
        # Logging
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=False,
        # Eval
        eval_strategy="steps",
        eval_steps=eval_steps,
        # Save (same step cadence as eval, so every checkpoint has an eval_loss
        # for load_best_model_at_end to compare)
        save_strategy="steps",
        save_steps=eval_steps,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # Hub
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        # Misc
        seed=args.seed,
        report_to="none",
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
    )


def _count_parameters(model):
    """Return ``(trainable, total)`` parameter counts for *model*.

    Prefers PEFT's ``get_nb_trainable_parameters`` when the model provides it,
    because a plain ``numel()`` sum under-counts bitsandbytes 4-bit weights
    (they are stored packed).
    """
    if hasattr(model, "get_nb_trainable_parameters"):
        trainable, total = model.get_nb_trainable_parameters()
        return trainable, total
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


def _save_metrics(train_result, args, n_train, n_eval):
    """Write trainer metrics plus key run settings to ``train_metrics.json``.

    Returns:
        Path of the written JSON file inside ``args.output_dir``.
    """
    metrics = dict(train_result.metrics)
    metrics.update(
        train_samples=n_train,
        eval_samples=n_eval,
        base_model=args.base_model,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        learning_rate=args.lr,
        epochs=args.epochs,
        effective_batch_size=args.batch_size * args.grad_accum,
    )
    os.makedirs(args.output_dir, exist_ok=True)
    path = os.path.join(args.output_dir, "train_metrics.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    return path


def main():
    """Run the end-to-end QLoRA SFT pipeline.

    Parses CLI options, loads the ``train``/``test`` splits of
    ``args.dataset``, loads ``args.base_model`` in 4-bit NF4, attaches a LoRA
    adapter on all linear layers through ``SFTTrainer``, trains, then saves
    the model, tokenizer and a ``train_metrics.json`` summary to
    ``args.output_dir`` and optionally pushes to the Hugging Face Hub.
    """
    args = parse_args()
    _print_run_config(args)

    # ── 1. Load dataset ──────────────────────────────────────────────
    print("\n[1/4] Loading dataset …")
    dataset = load_dataset(args.dataset)
    train_ds = dataset["train"]
    eval_ds = dataset["test"]
    print(f"  Train: {len(train_ds):,}  |  Eval: {len(eval_ds):,}")

    # ── 2. Load model in 4-bit ───────────────────────────────────────
    print("\n[2/4] Loading model in 4-bit NF4 …")
    model, tokenizer = _load_model_and_tokenizer(args)
    print(f"  Model loaded: {model.config._name_or_path}")
    print(f"  Tokenizer pad_token: '{tokenizer.pad_token}'")

    # ── 3. LoRA config ───────────────────────────────────────────────
    print("\n[3/4] Configuring LoRA …")
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",   # QLoRA best practice
    )

    # ── 4. SFT Training ─────────────────────────────────────────────
    print("\n[4/4] Starting SFT training …")
    steps_per_epoch, total_steps, eval_steps = _compute_schedule(
        len(train_ds), args.batch_size, args.grad_accum, args.epochs
    )
    run_name = (
        f"tmf921-{args.base_model.split('/')[-1]}"
        f"-r{args.lora_r}-lr{args.lr}-ep{args.epochs}"
    )
    sft_config = _build_sft_config(args, run_name, eval_steps)

    print(f"  Steps/epoch: {steps_per_epoch}")
    print(f"  Eval every:  {eval_steps} steps")
    print(f"  Total steps: ~{total_steps}")

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    # Count on trainer.model: that is the adapter-wrapped model actually trained.
    trainable, total = _count_parameters(trainer.model)
    pct = 100 * trainable / total if total else 0.0
    print(f"  Trainable params: {trainable:,} / {total:,} ({pct:.2f}%)")

    # Train
    train_result = trainer.train()

    # Save
    print("\nSaving final model …")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    metrics_path = _save_metrics(train_result, args, len(train_ds), len(eval_ds))
    print(f"  Metrics saved to {metrics_path}")

    # Push to Hub
    # NOTE(review): recent transformers versions may already push from
    # Trainer.save_model() when push_to_hub=True; this explicit push also
    # uploads the tokenizer/metrics written afterwards — confirm for the
    # pinned version.
    will_push = bool(args.push_to_hub and args.hub_model_id)
    if will_push:
        print(f"\nPushing to Hub: {args.hub_model_id}")
        trainer.push_to_hub()

    print("\n✅ Training complete!")
    print(f"  Model saved to: {args.output_dir}")
    if will_push:
        print(f"  Pushed to: https://huggingface.co/{args.hub_model_id}")


if __name__ == "__main__":
    # Start training only when executed as a script, never on import.
    main()