Spaces:
No application file
No application file
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "transformers>=4.40.0", "datasets>=2.18.0", "accelerate>=0.28.0", "bitsandbytes>=0.41.0"]
# ///
"""Production QLoRA fine-tuning of Llama-3.1-8B-Instruct on biomedical data.

Loads the full ``panikos/biomedical-llama-training`` dataset, quantizes the
base model to 4-bit NF4, attaches LoRA adapters, trains with TRL's
``SFTTrainer`` (logging to trackio), and pushes the adapter to the Hub.
"""
import math

import torch
import trackio  # side-effect import: backs SFTConfig(report_to="trackio")
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

# Hyperparameters shared between the trainer config and the status prints,
# so the console summary can never drift from what the trainer actually uses.
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
NUM_EPOCHS = 3
PER_DEVICE_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 4


def main() -> None:
    """Run the full fine-tuning pipeline end to end (train + push to Hub)."""
    print("=" * 80)
    print("PRODUCTION: Biomedical Llama Fine-Tuning with QLoRA (Full Dataset)")
    print("=" * 80)

    # BUG FIX: progress labels originally ran [1/5]..[5/5] and then "[6/6]";
    # there are six steps, so they are renumbered consistently as [n/6].
    print("\n[1/6] Loading dataset...")
    dataset = load_dataset("panikos/biomedical-llama-training")
    train_dataset = dataset["train"]
    eval_dataset = dataset["validation"]
    print(f" Train: {len(train_dataset)} examples")
    print(f" Eval: {len(eval_dataset)} examples")

    print("\n[2/6] Configuring 4-bit quantization...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    print(" Quantization: 4-bit NF4")
    print(" Compute dtype: bfloat16")
    print(" Double quantization: enabled")

    print("\n[3/6] Configuring LoRA...")
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    print(" LoRA rank: 16, alpha: 32")

    print("\n[4/6] Loading quantized model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
    )

    print("\n[5/6] Initializing trainer...")
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=lora_config,
        args=SFTConfig(
            output_dir="llama-biomedical-production-qlora",
            num_train_epochs=NUM_EPOCHS,
            per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
            gradient_accumulation_steps=GRAD_ACCUM_STEPS,
            learning_rate=2e-4,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            logging_steps=50,
            eval_strategy="steps",
            eval_steps=200,
            save_strategy="epoch",
            save_total_limit=2,
            push_to_hub=True,
            hub_model_id="panikos/llama-biomedical-production-qlora",
            hub_private_repo=True,
            bf16=True,
            gradient_checkpointing=True,  # trade compute for memory on the 8B base
            report_to="trackio",
            project="biomedical-llama-training",
            run_name="production-full-dataset-qlora-v1",
        ),
    )

    print("\n[6/6] Starting training...")
    # BUG FIX: the summary originally hard-coded example counts (17,008 / 896),
    # the effective batch size and the step estimates; derive them from the
    # actual dataset and the hyperparameter constants instead.
    effective_batch = PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS
    steps_per_epoch = math.ceil(len(train_dataset) / effective_batch)
    print(f" Model: {MODEL_ID}")
    print(" Method: QLoRA (4-bit) with LoRA adapters")
    print(f" Epochs: {NUM_EPOCHS}")
    print(f" Training examples: {len(train_dataset):,}")
    print(f" Validation examples: {len(eval_dataset):,}")
    print(f" Batch size: {PER_DEVICE_BATCH_SIZE} x {GRAD_ACCUM_STEPS} = {effective_batch} (effective)")
    print(f" Estimated steps: ~{steps_per_epoch * NUM_EPOCHS:,} ({steps_per_epoch:,} per epoch)")
    print(" Gradient checkpointing: ENABLED")
    print(" Memory: ~5-6GB (optimized with QLoRA)")
    print()
    trainer.train()

    print("\n" + "=" * 80)
    print("Pushing model to Hub...")
    print("=" * 80)
    trainer.push_to_hub()

    print("\n" + "=" * 80)
    print("PRODUCTION TRAINING COMPLETE!")
    print("=" * 80)
    print("\nModel: https://huggingface.co/panikos/llama-biomedical-production-qlora")
    print("Dashboard: https://panikos-trackio.hf.space/")
    print("\nYour biomedical Llama model is ready!")


if __name__ == "__main__":
    main()