#!/usr/bin/env python3
"""
TMF921 Intent-to-Configuration Translation – QLoRA Fine-Tuning
=============================================================
Fine-tunes Qwen3-8B on the TMF921-intent-to-config-augmented dataset
using 4-bit QLoRA. Designed for a single RTX 6000 Ada (48 GB VRAM).
Usage:
python train.py # defaults
python train.py --base_model Qwen/Qwen3-8B --epochs 3 --lr 1e-4
python train.py --base_model Qwen/Qwen2.5-7B-Instruct --lora_r 64
"""
import argparse, os, json, torch, math
from datetime import datetime
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
# ── CLI ──────────────────────────────────────────────────────────────
def parse_args():
p = argparse.ArgumentParser()
# Model
p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B",
help="HuggingFace model id or local path")
# Dataset
p.add_argument("--dataset", type=str,
default="nraptisss/TMF921-intent-to-config-augmented")
# LoRA
p.add_argument("--lora_r", type=int, default=32)
p.add_argument("--lora_alpha", type=int, default=64)
p.add_argument("--lora_dropout", type=float, default=0.05)
# Training
p.add_argument("--epochs", type=int, default=3)
p.add_argument("--lr", type=float, default=1e-4)
p.add_argument("--batch_size", type=int, default=4)
p.add_argument("--grad_accum", type=int, default=8)
p.add_argument("--max_length", type=int, default=4096)
p.add_argument("--warmup_steps", type=int, default=100)
p.add_argument("--weight_decay", type=float, default=0.01)
# Output
p.add_argument("--output_dir", type=str, default="./output")
p.add_argument("--hub_model_id", type=str, default=None,
help="Push to this HF model id (e.g. nraptisss/Qwen3-8B-TMF921)")
p.add_argument("--push_to_hub", action="store_true", default=False)
# Misc
p.add_argument("--seed", type=int, default=42)
p.add_argument("--flash_attn", action="store_true", default=True)
p.add_argument("--no_flash_attn", dest="flash_attn", action="store_false")
return p.parse_args()
def main():
args = parse_args()
print("=" * 70)
print("TMF921 Intent Translation β€” QLoRA Training")
print("=" * 70)
print(f"Base model : {args.base_model}")
print(f"Dataset : {args.dataset}")
print(f"LoRA r/alpha : {args.lora_r}/{args.lora_alpha}")
print(f"Epochs : {args.epochs}")
print(f"LR : {args.lr}")
print(f"Batch size : {args.batch_size} Γ— {args.grad_accum} grad_accum = "
f"{args.batch_size * args.grad_accum} effective")
print(f"Max length : {args.max_length}")
print(f"Flash attn : {args.flash_attn}")
print(f"Output : {args.output_dir}")
print(f"Push to Hub : {args.push_to_hub} β†’ {args.hub_model_id}")
print("=" * 70)
# ── 1. Load dataset ──────────────────────────────────────────────
print("\n[1/4] Loading dataset …")
dataset = load_dataset(args.dataset)
train_ds = dataset["train"]
eval_ds = dataset["test"]
print(f" Train: {len(train_ds):,} | Eval: {len(eval_ds):,}")
# ── 2. Load model in 4-bit ───────────────────────────────────────
print("\n[2/4] Loading model in 4-bit NF4 …")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
model_kwargs = {
"quantization_config": bnb_config,
"device_map": "auto",
"trust_remote_code": True,
}
if args.flash_attn:
model_kwargs["attn_implementation"] = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
args.base_model, **model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(
args.base_model, trust_remote_code=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
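        # Reusing EOS as the pad token is safe here: padded positions are masked out of the loss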
print(f" Model loaded: {model.config._name_or_path}")
print(f" Tokenizer pad_token: '{tokenizer.pad_token}'")
# ── 3. LoRA config ───────────────────────────────────────────────
print("\n[3/4] Configuring LoRA …")
peft_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear", # QLoRA best practice
)
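    # With the defaults (r=32, alpha=64) the LoRA scaling alpha/r is 2; changing
    # --lora_r without adjusting --lora_alpha changes that ratio.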
# ── 4. SFT Training ─────────────────────────────────────────────
print("\n[4/4] Starting SFT training …")
# Compute eval steps: ~4 evals per epoch
steps_per_epoch = math.ceil(len(train_ds) / (args.batch_size * args.grad_accum))
eval_steps = max(steps_per_epoch // 4, 50)
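    # The max() clamp keeps evaluations at least 50 optimizer steps apart on small datasets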
run_name = (
f"tmf921-{args.base_model.split('/')[-1]}"
f"-r{args.lora_r}-lr{args.lr}-ep{args.epochs}"
)
sft_config = SFTConfig(
output_dir=args.output_dir,
run_name=run_name,
# Batch
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
gradient_accumulation_steps=args.grad_accum,
# Schedule
num_train_epochs=args.epochs,
learning_rate=args.lr,
lr_scheduler_type="cosine",
warmup_steps=args.warmup_steps,
weight_decay=args.weight_decay,
# Precision & memory
bf16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
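        # Non-reentrant checkpointing avoids the frozen-input gradient issues the
        # reentrant path can hit when training PEFT adapters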
# Sequence
max_length=args.max_length,
# Loss: train only on assistant outputs
assistant_only_loss=True,
# Logging
logging_strategy="steps",
logging_steps=10,
logging_first_step=True,
disable_tqdm=False,
# Eval
eval_strategy="steps",
eval_steps=eval_steps,
# Save
save_strategy="steps",
save_steps=eval_steps,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
# Hub
push_to_hub=args.push_to_hub,
hub_model_id=args.hub_model_id,
# Misc
seed=args.seed,
report_to="none",
dataloader_num_workers=4,
dataloader_pin_memory=True,
)
print(f" Steps/epoch: {steps_per_epoch}")
print(f" Eval every: {eval_steps} steps")
print(f" Total steps: ~{steps_per_epoch * args.epochs}")
trainer = SFTTrainer(
model=model,
args=sft_config,
train_dataset=train_ds,
eval_dataset=eval_ds,
processing_class=tokenizer,
peft_config=peft_config,
)
    # Report trainable vs. total parameters on the PEFT-wrapped model the trainer optimizes
    trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in trainer.model.parameters())
print(f" Trainable params: {trainable:,} / {total:,} "
f"({100 * trainable / total:.2f}%)")
# Train
train_result = trainer.train()
# Save
print("\nSaving final model …")
trainer.save_model(args.output_dir)
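    # With a PEFT-wrapped model this writes only the LoRA adapter weights
    # (adapter_model.safetensors + adapter_config.json), not merged full weights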
tokenizer.save_pretrained(args.output_dir)
# Save training metrics
metrics = train_result.metrics
metrics["train_samples"] = len(train_ds)
metrics["eval_samples"] = len(eval_ds)
metrics["base_model"] = args.base_model
metrics["lora_r"] = args.lora_r
metrics["lora_alpha"] = args.lora_alpha
metrics["learning_rate"] = args.lr
metrics["epochs"] = args.epochs
metrics["effective_batch_size"] = args.batch_size * args.grad_accum
with open(os.path.join(args.output_dir, "train_metrics.json"), "w") as f:
json.dump(metrics, f, indent=2)
print(f" Metrics saved to {args.output_dir}/train_metrics.json")
# Push to Hub
if args.push_to_hub and args.hub_model_id:
print(f"\nPushing to Hub: {args.hub_model_id}")
trainer.push_to_hub()
print("\nβœ… Training complete!")
print(f" Model saved to: {args.output_dir}")
if args.push_to_hub and args.hub_model_id:
print(f" Pushed to: https://huggingface.co/{args.hub_model_id}")
if __name__ == "__main__":
main()