# Uploaded by nraptisss as train.py (commit 751b8a2, verified).
"""
QLoRA Fine-Tuning Script for Telecom Intent-to-Config Translation
Optimized for Kaggle T4x2 (2x T4 GPUs, ~30h/week free)
Dataset: nraptisss/TMF921-intent-to-config-augmented (or any dataset with 'messages' column)
Model: Qwen/Qwen2.5-7B-Instruct (or meta-llama/Llama-3.1-8B-Instruct)
Output: LoRA adapters saved locally, then merge_and_push.py merges and pushes
"""
import os
import sys
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
# ============================================================================
# CONFIGURATION — EDIT THESE
# ============================================================================

# --- Base model -------------------------------------------------------------
# Alternative: "meta-llama/Llama-3.1-8B-Instruct"
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

# --- Dataset ----------------------------------------------------------------
DATASET_NAME = "nraptisss/TMF921-intent-to-config-augmented"
DATASET_CONFIG = "default"
TRAIN_SPLIT = "train"
# Set to a falsy value ("" or None) to skip loading an eval split and to
# turn per-epoch evaluation off.
TEST_SPLIT = "test"

# --- Output -----------------------------------------------------------------
OUTPUT_DIR = "./qwen2.5-7b-telecom-intent-lora"

# --- Optimisation (sized for a 16 GB T4) ------------------------------------
NUM_EPOCHS = 3
BATCH_SIZE = 1  # per-device micro-batch (train and eval)
# Samples per optimizer step, per process: BATCH_SIZE * GRAD_ACCUMULATION = 4.
GRAD_ACCUMULATION = 4
LEARNING_RATE = 2e-4
MAX_LENGTH = 512  # forwarded to SFTConfig.max_length (tokens per example)

# --- LoRA adapter -----------------------------------------------------------
LORA_R = 64
LORA_ALPHA = 16  # adapter scaling is lora_alpha / r = 0.25
LORA_DROPOUT = 0.05
# ============================================================================
# SETUP
# ============================================================================
def setup():
"""Verify GPU and set environment."""
if not torch.cuda.is_available():
print("WARNING: No GPU detected. This will be extremely slow on CPU.")
sys.exit(1)
gpu_count = torch.cuda.device_count()
print(f"Detected {gpu_count} GPU(s):")
for i in range(gpu_count):
props = torch.cuda.get_device_properties(i)
print(f" GPU {i}: {props.name} ({props.total_memory / 1e9:.1f} GB)")
# T4-specific: use fp16, not bf16
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
return gpu_count
def load_model_and_tokenizer(model_name: str):
    """Load ``model_name`` quantised to 4-bit NF4, together with its tokenizer.

    Args:
        model_name: Hugging Face Hub id or local path of a causal LM.

    Returns:
        tuple: ``(model, tokenizer)``. The model is loaded with
        ``device_map="auto"``, so it may be split across all visible GPUs.
    """
    print(f"\nLoading model: {model_name}")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,  # T4: fp16, not bf16
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        padding_side="right",
    )
    # Qwen2.5 already has a pad_token; only fall back to EOS if missing.
    # NOTE(review): with pad == eos (e.g. Llama-3.1), collators that mask
    # labels by pad id would also drop EOS from the loss — consider a
    # dedicated pad token for such models; verify against the TRL collator.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        dtype=torch.float16,
    )
    # NOTE: Do NOT manually enable gradient_checkpointing here.
    # SFTTrainer handles it automatically when gradient_checkpointing=True in args.
    # Manual enable + liger_kernel causes Triton crashes on T4.

    # device_map="auto" may shard the model over several GPUs, and
    # memory_allocated() without an argument only reports the current device,
    # so sum the allocation over every visible device.
    vram_bytes = sum(
        torch.cuda.memory_allocated(device)
        for device in range(torch.cuda.device_count())
    )
    print(f"Model loaded. VRAM used: {vram_bytes / 1e9:.2f} GB")
    return model, tokenizer
def load_and_inspect_dataset(dataset_name: str, config_name: str, split: str):
    """Load one split of a chat-format dataset and print a quick sanity check.

    Args:
        dataset_name: Hugging Face Hub dataset id (or local path).
        config_name: dataset configuration name.
        split: split to load, e.g. ``"train"``.

    Returns:
        The loaded dataset.

    Raises:
        ValueError: if the split has no ``messages`` column or has no rows.
    """
    print(f"\nLoading dataset: {dataset_name} (config={config_name}, split={split})")
    ds = load_dataset(dataset_name, config_name, split=split)
    print(f"Dataset size: {len(ds)} examples")

    # Check the schema through column_names instead of ds[0], so a wrong
    # format is reported clearly even for an empty split.
    if "messages" not in ds.column_names:
        raise ValueError(
            f"Dataset must have 'messages' column. Got: {list(ds.column_names)}"
        )
    if len(ds) == 0:
        raise ValueError(f"Dataset split '{split}' is empty; nothing to inspect.")

    msgs = ds[0]["messages"]
    print(f"Sample messages structure: {len(msgs)} messages")
    for m in msgs:
        # A message may carry content=None; report it as length 0 instead of
        # crashing in len().
        content = m.get("content") or ""
        print(f" role={m.get('role')}, content_len={len(content)}")

    # Print (a prefix of) the first user turn as a sample intent.
    user_msg = next((m for m in msgs if m.get("role") == "user"), None)
    if user_msg is not None:
        text = user_msg.get("content") or ""
        ellipsis = "..." if len(text) > 200 else ""
        print(f"\nSample user intent:\n{text[:200]}{ellipsis}")
    return ds
def get_lora_config():
    """Build the PEFT LoRA configuration for the causal-LM adapters.

    Adapters target every linear layer ("all-linear"), no bias terms are
    trained (bias="none"), and rank / alpha / dropout are taken from the
    LORA_R, LORA_ALPHA and LORA_DROPOUT constants.
    """
    lora_kwargs = dict(
        task_type="CAUSAL_LM",
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules="all-linear",
        bias="none",
    )
    return LoraConfig(**lora_kwargs)
def get_training_args(output_dir: str, num_gpus: int):
    """Assemble the SFTConfig used for QLoRA training on Kaggle T4 GPUs.

    Args:
        output_dir: directory for checkpoints (written once per epoch).
        num_gpus: number of visible GPUs; currently unused and kept only for
            call-site compatibility.

    Returns:
        SFTConfig: fp16, gradient-checkpointed, cosine-scheduled config that
        evaluates once per epoch when TEST_SPLIT is set.
    """
    schedule = {
        "num_train_epochs": NUM_EPOCHS,
        "learning_rate": LEARNING_RATE,
        "lr_scheduler_type": "cosine",
        "warmup_ratio": 0.05,
    }
    batching = {
        "per_device_train_batch_size": BATCH_SIZE,
        "per_device_eval_batch_size": BATCH_SIZE,
        "gradient_accumulation_steps": GRAD_ACCUMULATION,
        "max_length": MAX_LENGTH,
        "dataloader_num_workers": 0,
        "remove_unused_columns": False,
    }
    # T4s train in fp16 (no bf16). Liger kernels stay off for T4
    # compatibility; turn use_liger_kernel on only for A100/L40/H100.
    hardware = {
        "fp16": True,
        "bf16": False,
        "gradient_checkpointing": True,
        "use_liger_kernel": False,
    }
    if TEST_SPLIT:
        evaluation = "epoch"
    else:
        evaluation = "no"
    bookkeeping = {
        "output_dir": output_dir,
        "logging_steps": 10,
        "save_strategy": "epoch",
        "eval_strategy": evaluation,
        "load_best_model_at_end": False,
        "report_to": "none",
    }
    return SFTConfig(**bookkeeping, **schedule, **batching, **hardware)
def train(model, tokenizer, train_ds, eval_ds=None):
    """Fine-tune ``model`` with QLoRA through TRL's SFTTrainer and save adapters.

    Args:
        model: quantised causal LM. LoRA adapters are attached by SFTTrainer
            from ``get_lora_config()``.
        tokenizer: passed to SFTTrainer as ``processing_class`` and saved
            next to the adapters.
        train_ds: training dataset.
        eval_ds: optional evaluation dataset. When ``None``, evaluation is
            switched off so the trainer does not reject the configuration.

    Returns:
        The SFTTrainer instance after training.
    """
    from transformers.trainer_utils import IntervalStrategy

    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    peft_config = get_lora_config()
    training_args = get_training_args(OUTPUT_DIR, torch.cuda.device_count())

    # get_training_args() enables per-epoch eval whenever TEST_SPLIT is set,
    # but main() passes eval_ds=None when that split fails to load. The
    # transformers Trainer raises ValueError for eval_strategy != "no"
    # without an eval_dataset, so turn evaluation off instead of crashing.
    if eval_ds is None and training_args.eval_strategy != IntervalStrategy.NO:
        print("No eval dataset provided; disabling evaluation.")
        training_args.eval_strategy = IntervalStrategy.NO
        training_args.do_eval = False

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainer.train()

    # Save adapters
    print(f"\nSaving LoRA adapters to {OUTPUT_DIR}")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("Training complete!")
    return trainer
def _load_optional_eval_split():
    """Best-effort load of TEST_SPLIT; yields None when unset or unavailable."""
    eval_ds = None
    if not TEST_SPLIT:
        return eval_ds
    try:
        eval_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
        print(f"Eval dataset: {len(eval_ds)} examples")
    except Exception as e:
        # Deliberately broad: a missing eval split must not stop training.
        print(f"No eval split available: {e}")
    return eval_ds


def main():
    """Script entry point: check GPUs, load model/tokenizer/data, then train."""
    setup()

    model, tokenizer = load_model_and_tokenizer(MODEL_NAME)
    train_ds = load_and_inspect_dataset(DATASET_NAME, DATASET_CONFIG, TRAIN_SPLIT)
    eval_ds = _load_optional_eval_split()

    train(model, tokenizer, train_ds, eval_ds)

    rule = "=" * 60
    print("\n" + rule)
    print("NEXT STEPS:")
    print(rule)
    follow_ups = (
        "Run inference.py to test the model",
        "Run merge_and_push.py to merge adapters and push to hub",
        "Run benchmark.py to evaluate on the test set",
    )
    for number, step in enumerate(follow_ups, start=1):
        print(f"{number}. {step}")


if __name__ == "__main__":
    main()