# Page-header residue from the repository web view (not part of the module):
# SouravNath's picture / Initial commit / dc71cad
"""
fine_tuning/train.py
──────────────────────
QLoRA fine-tuning entry point for DeepSeek-Coder-7B.
Usage:
    # Standard training
    python -m fine_tuning.train

    # Specific variant for ablation
    python -m fine_tuning.train --variant large_r

    # Dry run (dataset check, no GPU needed)
    python -m fine_tuning.train --dry-run

    # Custom config
    python -m fine_tuning.train --model deepseek-ai/deepseek-coder-7b-instruct-v1.5 \
        --epochs 3 --lr 2e-4 --batch 4

The script performs:
1. Dataset validation (token count, format check)
2. Model loading with 4-bit quantisation
3. LoRA adapter injection
4. SFT training with HuggingFace TRL's SFTTrainer
5. Checkpoint saving + optional adapter merging (only with --merge)
6. MLflow logging of training metrics + config
IMPORTANT: Requires GPU with >= 14GB VRAM.
For development/testing, use --dry-run to validate without GPU.
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
from fine_tuning.qlora_config import TrainingConfig, get_config
# Module-level logger named after this module; its level and format are set by
# the logging.basicConfig(...) call in main().
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Reads ``sys.argv`` and returns a namespace with ``variant``, ``model``,
    ``epochs``, ``lr``, ``batch``, ``output``, ``dry_run``, ``resume`` and
    ``merge``. Overrides that were not supplied stay ``None``; the boolean
    switches default to ``False``.
    """
    parser = argparse.ArgumentParser(description="QLoRA fine-tuning for DeepSeek-Coder")

    parser.add_argument(
        "--variant",
        default="default",
        help="Config variant (default/small_r/large_r/qwen)",
    )
    parser.add_argument("--model", default=None, help="Override model name")

    # Numeric hyper-parameter overrides; they default to None when not given.
    for flag, caster in (("--epochs", int), ("--lr", float), ("--batch", int)):
        parser.add_argument(flag, type=caster, default=None)

    parser.add_argument("--output", default=None, help="Override output directory")

    # Boolean switches, off unless present on the command line.
    switches = (
        ("--dry-run", "Validate dataset only, no training"),
        ("--resume", "Resume from latest checkpoint"),
        ("--merge", "Merge LoRA into base model after training"),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)

    return parser.parse_args()
def validate_dataset(config: TrainingConfig) -> dict:
    """Validate that the dataset files exist and look like JSONL. No GPU needed.

    For each of the ``train`` / ``val`` splits named by ``config.train_file``
    and ``config.val_file``:

    * a missing file is recorded as
      ``{"error": "file not found", "path": <path>}``;
    * otherwise the lines are counted, the first three lines are parsed as
      JSON and checked for a ``text``, ``conversations`` or ``messages`` key,
      and whatever ``estimate_token_counts`` returns is merged into the entry.

    Args:
        config: Training configuration; only ``train_file`` and ``val_file``
            are read here.

    Returns:
        Mapping of split name (``"train"`` / ``"val"``) to its stats dict with
        ``n_examples``, ``format_ok``, ``format_errors`` and the token stats.
    """
    from fine_tuning.dataset_builder import estimate_token_counts

    max_checked_lines = 3
    accepted_keys = ("text", "conversations", "messages")

    results = {}
    for split, path_str in (("train", config.train_file), ("val", config.val_file)):
        path = Path(path_str)
        if not path.exists():
            logger.warning("Dataset file not found: %s", path)
            results[split] = {"error": "file not found", "path": str(path)}
            continue

        n_lines = 0
        format_errors = []
        # Single pass inside a context manager: counts every line and checks
        # the format of the first few, without leaving a file handle open.
        with path.open(encoding="utf-8") as f:
            for n_lines, line in enumerate(f, start=1):
                if n_lines > max_checked_lines:
                    continue  # past the sampled prefix: only keep counting
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as e:
                    format_errors.append(f"Line {n_lines}: JSON error: {e}")
                    continue
                if not isinstance(obj, dict):
                    # A scalar would make the key test raise TypeError and a
                    # list could pass it by accident, so reject both explicitly.
                    format_errors.append(
                        f"Line {n_lines}: expected a JSON object, got {type(obj).__name__}"
                    )
                elif not any(key in obj for key in accepted_keys):
                    format_errors.append(
                        f"Line {n_lines}: missing 'text' or 'conversations' or 'messages'"
                    )
        format_ok = not format_errors

        token_stats = estimate_token_counts(path)

        results[split] = {
            "n_examples": n_lines,
            "format_ok": format_ok,
            "format_errors": format_errors,
            **token_stats,
        }
        logger.info(
            "%s: %d examples | ~%s tokens | format_ok=%s",
            split, n_lines,
            f"{token_stats.get('estimated_tokens', 0):,}",
            format_ok,
        )
    return results
def _checkpoint_step(path: Path) -> int:
    """Return the step number of a ``checkpoint-<step>`` path, or -1 if it has none."""
    suffix = path.name.rpartition("-")[2]
    return int(suffix) if suffix.isascii() and suffix.isdigit() else -1


def _latest_checkpoint(output_dir: str) -> str | None:
    """Return the ``checkpoint-*`` entry with the highest step in *output_dir*, if any."""
    candidates = list(Path(output_dir).glob("checkpoint-*"))
    if not candidates:
        return None
    # Compare steps numerically: a plain string sort would rank
    # "checkpoint-900" after "checkpoint-1000" and resume from a stale state.
    return str(max(candidates, key=lambda p: (_checkpoint_step(p), p.name)))


def train(config: TrainingConfig, resume: bool = False, merge_after: bool = False) -> None:
    """
    Run the QLoRA fine-tuning loop.

    Steps: load the 4-bit quantised base model and tokenizer, inject LoRA
    adapters, load the JSON train/validation files, train with TRL's
    ``SFTTrainer``, save the adapter and tokenizer to
    ``<output_dir>/lora_adapter`` and optionally merge the adapter into the
    base model under ``<output_dir>/merged``.

    Requires: transformers, peft, trl, bitsandbytes, datasets, torch.
    If any of them cannot be imported the process exits with status 1.

    Args:
        config: Training configuration (model name, quantisation, LoRA and
            trainer settings, dataset paths, output directory).
        resume: Resume from the highest-step ``checkpoint-*`` directory in
            ``config.output_dir``; starts from scratch if none exists.
        merge_after: Merge the saved adapter into the base model afterwards.
    """
    try:
        import torch
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            BitsAndBytesConfig as BnBConfig,
            TrainingArguments,
        )
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from trl import SFTTrainer
        from datasets import load_dataset
    except ImportError as e:
        logger.error(
            "Missing dependency: %s\n"
            "Install with: pip install transformers peft trl bitsandbytes datasets torch\n"
            "Or run with --dry-run to validate without GPU.",
            e
        )
        sys.exit(1)

    logger.info("Loading model: %s", config.model_name)
    logger.info("Estimated VRAM: %.1f GB", config.estimate_vram_gb())

    # ── Quantisation ───────────────────────────────────────────────────────
    bnb_config = BnBConfig(
        load_in_4bit=config.bnb.load_in_4bit,
        bnb_4bit_quant_type=config.bnb.bnb_4bit_quant_type,
        # The config stores the dtype by name (e.g. "bfloat16"); resolve it on torch.
        bnb_4bit_compute_dtype=getattr(torch, config.bnb.bnb_4bit_compute_dtype),
        bnb_4bit_use_double_quant=config.bnb.bnb_4bit_use_double_quant,
    )

    # ── Model + tokenizer ─────────────────────────────────────────────────
    # NOTE(security): trust_remote_code=True runs Python code shipped with the
    # model repository; only use model names you trust.
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model = prepare_model_for_kbit_training(model)

    tokenizer = AutoTokenizer.from_pretrained(
        config.model_name, trust_remote_code=True, padding_side="right"
    )
    if tokenizer.pad_token is None:
        # Fall back to EOS so that padding is possible.
        tokenizer.pad_token = tokenizer.eos_token

    # ── LoRA ──────────────────────────────────────────────────────────────
    lora_config = LoraConfig(
        r=config.lora.r,
        lora_alpha=config.lora.lora_alpha,
        lora_dropout=config.lora.lora_dropout,
        bias=config.lora.bias,
        task_type=config.lora.task_type,
        target_modules=config.lora.target_modules,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # ── Dataset ───────────────────────────────────────────────────────────
    dataset = load_dataset(
        "json",
        data_files={"train": config.train_file, "validation": config.val_file},
    )

    # ── Training args ─────────────────────────────────────────────────────
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        run_name=config.run_name,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        lr_scheduler_type=config.lr_scheduler_type,
        warmup_ratio=config.warmup_ratio,
        weight_decay=config.weight_decay,
        max_grad_norm=config.max_grad_norm,
        optim=config.optim,
        bf16=config.bf16,
        fp16=config.fp16,
        save_strategy=config.save_strategy,
        save_steps=config.save_steps,
        save_total_limit=config.save_total_limit,
        logging_steps=config.logging_steps,
        eval_strategy=config.eval_strategy,
        eval_steps=config.eval_steps,
        load_best_model_at_end=config.load_best_model_at_end,
        metric_for_best_model=config.metric_for_best_model,
        report_to=config.report_to,
    )

    # ── SFT Trainer ───────────────────────────────────────────────────────
    # NOTE(review): these keyword arguments depend on the installed TRL
    # version's SFTTrainer signature; confirm against the pinned TRL release.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        dataset_text_field=config.dataset_text_field,
        max_seq_length=config.max_seq_length,
        packing=config.packing,
    )

    resume_checkpoint = None
    if resume:
        resume_checkpoint = _latest_checkpoint(config.output_dir)
        if resume_checkpoint is not None:
            logger.info("Resuming from checkpoint: %s", resume_checkpoint)
        else:
            logger.warning(
                "--resume requested but no checkpoint-* found in %s; starting from scratch",
                config.output_dir,
            )

    # ── Train ─────────────────────────────────────────────────────────────
    logger.info("Starting training: %d epochs, effective batch=%d, lr=%.2e",
                config.num_train_epochs, config.effective_batch_size, config.learning_rate)
    trainer.train(resume_from_checkpoint=resume_checkpoint)

    # ── Save ──────────────────────────────────────────────────────────────
    adapter_path = Path(config.output_dir) / "lora_adapter"
    trainer.model.save_pretrained(adapter_path)
    tokenizer.save_pretrained(adapter_path)
    logger.info("LoRA adapter saved to %s", adapter_path)

    # ── Merge ─────────────────────────────────────────────────────────────
    if merge_after:
        merge_adapter(config.model_name, adapter_path, Path(config.output_dir) / "merged")
def merge_adapter(base_model_name: str, adapter_path: Path, output_path: Path) -> None:
    """Merge LoRA weights into base model for fast inference (no PEFT at inference time).

    Loads the base model on CPU in bfloat16, applies the adapter stored at
    ``adapter_path``, folds the LoRA weights into the base weights and writes
    the merged model plus the base tokenizer to ``output_path``.

    The merge is best-effort: any failure (including missing dependencies) is
    logged with its traceback and swallowed, so the caller is not interrupted.

    Args:
        base_model_name: Name or path of the base model to load.
        adapter_path: Directory containing the saved LoRA adapter.
        output_path: Directory the merged model and tokenizer are written to.
    """
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from peft import PeftModel
        import torch

        logger.info("Merging LoRA adapter into base model...")
        # trust_remote_code=True matches how train() loads the same model, so a
        # model that needs custom code does not train fine and then fail to merge.
        # NOTE(security): this runs code shipped with the model repository.
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            trust_remote_code=True,
        )
        model = PeftModel.from_pretrained(model, str(adapter_path))
        merged = model.merge_and_unload()
        merged.save_pretrained(str(output_path))

        tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
        tokenizer.save_pretrained(str(output_path))
        logger.info("Merged model saved to %s", output_path)
    except Exception as e:
        # logger.exception keeps the traceback, which a bare error message loses.
        logger.exception("Merge failed: %s", e)
def main() -> None:
    """Command-line entry point: build the config, validate the dataset, train.

    Applies the CLI overrides on top of the selected config variant, runs
    ``validate_dataset`` and then either stops (``--dry-run``) or starts
    training. Exits with status 1 when a dry run finds dataset problems or
    when a dataset file is missing before training.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    )
    args = parse_args()

    # Build config
    config = get_config(args.variant)
    if args.model:
        config.model_name = args.model
    # Numeric overrides are compared against None so that an explicit 0
    # (e.g. --epochs 0) is honoured instead of being silently dropped.
    if args.epochs is not None:
        config.num_train_epochs = args.epochs
    if args.lr is not None:
        config.learning_rate = args.lr
    if args.batch is not None:
        config.per_device_train_batch_size = args.batch
    if args.output:
        config.output_dir = args.output

    logger.info("Training config: model=%s, variant=%s", config.model_name, args.variant)
    logger.info("LoRA: r=%d, alpha=%d, modules=%s",
                config.lora.r, config.lora.lora_alpha, config.lora.target_modules)

    # Validate dataset
    dataset_stats = validate_dataset(config)
    logger.info("Dataset validation: %s", dataset_stats)

    missing = [
        split for split, stats in dataset_stats.items()
        if stats.get("error") == "file not found"
    ]
    malformed = [
        split for split, stats in dataset_stats.items()
        if stats.get("format_ok") is False
    ]

    if args.dry_run:
        if missing or malformed:
            logger.error(
                "Dry run complete — dataset problems found (missing: %s; malformed: %s).",
                ", ".join(missing) or "none",
                ", ".join(malformed) or "none",
            )
            sys.exit(1)
        logger.info("Dry run complete — dataset valid. Run without --dry-run to start training.")
        return

    if missing:
        # Fail before the expensive model load rather than when the dataset is read.
        logger.error("Cannot train: dataset file not found for split(s): %s", ", ".join(missing))
        sys.exit(1)
    if malformed:
        logger.warning(
            "Format check failed for split(s): %s; continuing, but training may fail.",
            ", ".join(malformed),
        )

    # Train
    train(config, resume=args.resume, merge_after=args.merge)


if __name__ == "__main__":
    main()