"""
TMF921 Intent-to-Configuration Translation - QLoRA Fine-Tuning
===============================================================
Fine-tunes Qwen3-8B on the TMF921-intent-to-config-augmented dataset
using 4-bit QLoRA. Designed for a single RTX 6000 Ada (48 GB VRAM).

Usage:
    python train.py                                                  # defaults
    python train.py --base_model Qwen/Qwen3-8B --epochs 3 --lr 1e-4
    python train.py --base_model Qwen/Qwen2.5-7B-Instruct --lora_r 64
"""
|
|
import argparse
import json
import math
import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTConfig, SFTTrainer
|
|
|
|
def parse_args():
    p = argparse.ArgumentParser()

    p.add_argument("--base_model", type=str, default="Qwen/Qwen3-8B",
                   help="HuggingFace model id or local path")

    p.add_argument("--dataset", type=str,
                   default="nraptisss/TMF921-intent-to-config-augmented")

    p.add_argument("--lora_r", type=int, default=32)
    p.add_argument("--lora_alpha", type=int, default=64)
    p.add_argument("--lora_dropout", type=float, default=0.05)

    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--grad_accum", type=int, default=8)
    p.add_argument("--max_length", type=int, default=4096)
    p.add_argument("--warmup_steps", type=int, default=100)
    p.add_argument("--weight_decay", type=float, default=0.01)

    p.add_argument("--output_dir", type=str, default="./output")
    p.add_argument("--hub_model_id", type=str, default=None,
                   help="Push to this HF model id (e.g. nraptisss/Qwen3-8B-TMF921)")
    p.add_argument("--push_to_hub", action="store_true", default=False)

    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--flash_attn", action="store_true", default=True)
    p.add_argument("--no_flash_attn", dest="flash_attn", action="store_false")
    return p.parse_args()
|
|
|
|
def main():
    args = parse_args()

    print("=" * 70)
    print("TMF921 Intent Translation - QLoRA Training")
    print("=" * 70)
    print(f"Base model   : {args.base_model}")
    print(f"Dataset      : {args.dataset}")
    print(f"LoRA r/alpha : {args.lora_r}/{args.lora_alpha}")
    print(f"Epochs       : {args.epochs}")
    print(f"LR           : {args.lr}")
    print(f"Batch size   : {args.batch_size} x {args.grad_accum} grad_accum = "
          f"{args.batch_size * args.grad_accum} effective")
    print(f"Max length   : {args.max_length}")
    print(f"Flash attn   : {args.flash_attn}")
    print(f"Output       : {args.output_dir}")
    print(f"Push to Hub  : {args.push_to_hub} -> {args.hub_model_id}")
    print("=" * 70)
|
|
    print("\n[1/4] Loading dataset ...")
    dataset = load_dataset(args.dataset)
    train_ds = dataset["train"]
    eval_ds = dataset["test"]
    print(f" Train: {len(train_ds):,} | Eval: {len(eval_ds):,}")
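    # NOTE (assumption): assistant_only_loss=True below requires the dataset to
    # be in TRL's conversational format, i.e. each example carries a "messages"
    # list along the lines of
    #   {"messages": [
    #       {"role": "user", "content": "<TMF921 intent JSON>"},
    #       {"role": "assistant", "content": "<target configuration>"}]}
    # so that loss is computed only on the assistant (configuration) tokens. If
    # the dataset instead uses plain text or prompt/completion columns, convert
    # it first or drop assistant_only_loss.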
|
|
    print("\n[2/4] Loading model in 4-bit NF4 ...")
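    # NF4 with double quantization keeps the frozen base weights at roughly half
    # a byte per parameter while compute runs in bfloat16; together with LoRA,
    # this is what should let an 8B model fine-tune comfortably on a single
    # 48 GB card at 4k-token sequence length.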
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
|
|
    model_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    if args.flash_attn:
        model_kwargs["attn_implementation"] = "flash_attention_2"
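    # flash_attention_2 needs the separate flash-attn package (and an Ampere or
    # newer GPU); if it is not installed, rerun with --no_flash_attn to use the
    # default SDPA attention implementation instead.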
|
|
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model, **model_kwargs
    )
|
|
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
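    # Qwen tokenizers usually ship with a pad token already defined; the eos
    # fallback above is a guard for other --base_model choices.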
|
|
    print(f" Model loaded: {model.config._name_or_path}")
    print(f" Tokenizer pad_token: '{tokenizer.pad_token}'")
|
|
    print("\n[3/4] Configuring LoRA ...")
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )
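    # target_modules="all-linear" puts adapters on every linear layer (attention
    # and MLP projections) while leaving the output head alone; alpha = 2*r is a
    # common scaling heuristic for QLoRA-style runs.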
|
|
    print("\n[4/4] Starting SFT training ...")

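    # Evaluate and checkpoint roughly four times per epoch, but never more often
    # than every 50 optimizer steps on small datasets.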
    steps_per_epoch = math.ceil(len(train_ds) / (args.batch_size * args.grad_accum))
    eval_steps = max(steps_per_epoch // 4, 50)
|
|
    run_name = (
        f"tmf921-{args.base_model.split('/')[-1]}"
        f"-r{args.lora_r}-lr{args.lr}-ep{args.epochs}"
    )
|
|
    sft_config = SFTConfig(
        output_dir=args.output_dir,
        run_name=run_name,

        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,

        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,

        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},

        max_length=args.max_length,

        assistant_only_loss=True,

        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=False,

        eval_strategy="steps",
        eval_steps=eval_steps,

        save_strategy="steps",
        save_steps=eval_steps,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,

        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,

        seed=args.seed,
        report_to="none",
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
    )
|
|
    print(f" Steps/epoch: {steps_per_epoch}")
    print(f" Eval every: {eval_steps} steps")
    print(f" Total steps: ~{steps_per_epoch * args.epochs}")
|
|
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
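    # Passing peft_config lets SFTTrainer attach the LoRA adapter itself; with a
    # 4-bit base model it should also take care of the usual k-bit preparation,
    # so no manual get_peft_model()/prepare_model_for_kbit_training() call is
    # needed here.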
|
|
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f" Trainable params: {trainable:,} / {total:,} "
          f"({100 * trainable / total:.2f}%)")
|
|
    train_result = trainer.train()
|
|
    print("\nSaving final model ...")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
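    # With a PEFT-wrapped model, save_model() writes only the adapter files
    # (adapter_model.safetensors + adapter_config.json), not merged full weights.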
|
|
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_ds)
    metrics["eval_samples"] = len(eval_ds)
    metrics["base_model"] = args.base_model
    metrics["lora_r"] = args.lora_r
    metrics["lora_alpha"] = args.lora_alpha
    metrics["learning_rate"] = args.lr
    metrics["epochs"] = args.epochs
    metrics["effective_batch_size"] = args.batch_size * args.grad_accum

    with open(os.path.join(args.output_dir, "train_metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)
    print(f" Metrics saved to {args.output_dir}/train_metrics.json")
|
|
    if args.push_to_hub and args.hub_model_id:
        print(f"\nPushing to Hub: {args.hub_model_id}")
        trainer.push_to_hub()

    print("\n✅ Training complete!")
    print(f" Model saved to: {args.output_dir}")
    if args.push_to_hub and args.hub_model_id:
        print(f" Pushed to: https://huggingface.co/{args.hub_model_id}")
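
# ---------------------------------------------------------------------------
# Minimal sketch for using the result (assumes the adapter was saved to
# ./output and the same base model is available locally or on the Hub):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B",
#                                               torch_dtype="auto",
#                                               device_map="auto")
#   model = PeftModel.from_pretrained(base, "./output")
#   tokenizer = AutoTokenizer.from_pretrained("./output")
#   # optionally: model = model.merge_and_unload()  # bake the adapter in
# ---------------------------------------------------------------------------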
|
|
|
|
if __name__ == "__main__":
    main()
|
|