# training-scripts / production_training_llama_qlora.py
# panikos's picture
# Upload production_training_llama_qlora.py with huggingface_hub
# 52e3179 verified
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "transformers>=4.40.0", "datasets>=2.18.0", "accelerate>=0.28.0", "bitsandbytes>=0.41.0"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
import torch
import trackio
print("=" * 80)
print("PRODUCTION: Biomedical Llama Fine-Tuning with QLoRA (Full Dataset)")
print("=" * 80)

# Hyperparameters hoisted to one place so the SFTConfig and the progress
# report below cannot drift apart (the original hard-coded counts/steps).
BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
NUM_EPOCHS = 3
PER_DEVICE_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 4

# [1/6] Load the full biomedical instruction dataset from the Hub.
print("\n[1/6] Loading dataset...")
dataset = load_dataset("panikos/biomedical-llama-training")
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]
print(f" Train: {len(train_dataset)} examples")
print(f" Eval: {len(eval_dataset)} examples")

# [2/6] 4-bit NF4 quantization config — the "Q" in QLoRA. The frozen base
# weights are stored in 4 bits; matmuls are computed in bfloat16.
print("\n[2/6] Configuring 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",       # NormalFloat4 (QLoRA paper recommendation)
    bnb_4bit_compute_dtype=torch.bfloat16,
)
print(" Quantization: 4-bit NF4")
print(" Compute dtype: bfloat16")
print(" Double quantization: enabled")

# [3/6] LoRA adapters on every attention and MLP projection of the Llama
# architecture; only these low-rank matrices are trained.
print("\n[3/6] Configuring LoRA...")
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
print(" LoRA rank: 16, alpha: 32")

# [4/6] Load the frozen base model in 4-bit, sharded automatically across
# available devices.
print("\n[4/6] Loading quantized model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
)

# [5/6] SFT trainer: evaluates every 200 steps, checkpoints per epoch
# (keeping the last 2), pushes to a private Hub repo, logs to trackio.
print("\n[5/6] Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    args=SFTConfig(
        output_dir="llama-biomedical-production-qlora",
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        logging_steps=50,
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="epoch",
        save_total_limit=2,
        push_to_hub=True,
        hub_model_id="panikos/llama-biomedical-production-qlora",
        hub_private_repo=True,
        bf16=True,
        gradient_checkpointing=True,  # trade compute for memory on the 8B model
        report_to="trackio",
        project="biomedical-llama-training",
        run_name="production-full-dataset-qlora-v1",
    ),
)

# [6/6] Train. The run summary is derived from the actual dataset sizes and
# the constants above instead of hard-coded numbers.
effective_batch = PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS
steps_per_epoch = (len(train_dataset) + effective_batch - 1) // effective_batch  # ceil div
print("\n[6/6] Starting training...")
print(f" Model: {BASE_MODEL}")
print(" Method: QLoRA (4-bit) with LoRA adapters")
print(f" Epochs: {NUM_EPOCHS}")
print(f" Training examples: {len(train_dataset):,}")
print(f" Validation examples: {len(eval_dataset):,}")
print(f" Batch size: {PER_DEVICE_BATCH_SIZE} x {GRAD_ACCUM_STEPS} = {effective_batch} (effective)")
print(f" Estimated steps: ~{steps_per_epoch * NUM_EPOCHS:,} ({steps_per_epoch:,} per epoch)")
print(" Gradient checkpointing: ENABLED")
print(" Memory: ~5-6GB (optimized with QLoRA)")
print()
trainer.train()

print("\n" + "=" * 80)
print("Pushing model to Hub...")
print("=" * 80)
trainer.push_to_hub()

print("\n" + "=" * 80)
print("PRODUCTION TRAINING COMPLETE!")
print("=" * 80)
print("\nModel: https://huggingface.co/panikos/llama-biomedical-production-qlora")
print("Dashboard: https://panikos-trackio.hf.space/")
print("\nYour biomedical Llama model is ready!")