# gemma-4-E2B-it-SSD / train_ssd.py
# ludsvick's picture
# Fix train_ssd.py: use correct model class (AutoModelForImageTextToText), target language_model only for LoRA, fix deprecated torch_dtype, remove OOM-causing prepare_model_for_kbit_training
# 923df00 verified
"""
SSD (Simple Self-Distillation) training for google/gemma-4-E2B-it — QLoRA edition
Following the recipe from "Embarrassingly Simple Self-Distillation Improves Code Generation"
(arXiv:2604.01193).
This script performs SFT on pre-generated completions using QLoRA.
Key fixes from sandbox testing:
- Use AutoModelForImageTextToText (Gemma-4 is multimodal, not CausalLM)
- Explicitly target only language_model layers for LoRA (skip vision/audio towers)
- Use dtype= instead of deprecated torch_dtype=
- Removed prepare_model_for_kbit_training (causes OOM, not needed with modern TRL)
- Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True for better memory management
Hardware: L4 24GB (tested), also works on A10G 24GB or A100 80GB
- T4 16GB: use max_length=2048 (OOM at 4096 due to 262K vocab logit tensor)
- L4/A10G 24GB: max_length=4096 fits comfortably
- A100 80GB: can try max_length=8192 or higher
Estimated cost on L4 ($0.80/hr): ~3h training = ~$2.40
Requirements:
pip install trl transformers datasets trackio accelerate torch peft bitsandbytes
"""
import os
# Presumably silences TRL's experimental-feature notices — TODO confirm
# against the TRL documentation.
os.environ["TRL_EXPERIMENTAL_SILENCE"] = "1"
# Allocator setting chosen for better CUDA memory management (see the module
# docstring). It is assigned here, ahead of `import torch` below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
from datasets import load_dataset
from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
# ──────────────────────────────────────────────────────────────────────
# 1. Load pre-generated SSD dataset
# ──────────────────────────────────────────────────────────────────────
print("Loading pre-generated SSD dataset (wrmedford/Gemma-4-E2B-it-SSD)...")
ds = load_dataset(
"wrmedford/Gemma-4-E2B-it-SSD",
data_files={"train": "ssd_dataset.jsonl"},
split="train",
)
print(f"Dataset loaded: {len(ds)} examples")
print(f"Columns: {ds.column_names}")
print(f"Sample messages[0]: {str(ds[0]['messages'][0])[:200]}...")
print(f"Sample messages[1]: {str(ds[0]['messages'][1])[:200]}...")
# ──────────────────────────────────────────────────────────────────────
# 2. Load model with 4-bit quantization (QLoRA)
# ──────────────────────────────────────────────────────────────────────
MODEL_ID = "google/gemma-4-E2B-it"
HUB_MODEL_ID = "ludsvick/gemma-4-E2B-it-SSD"
print(f"Loading {MODEL_ID} in 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# IMPORTANT: Gemma-4 is multimodal β€” must use AutoModelForImageTextToText
# Using AutoModelForCausalLM will load Gemma4ForCausalLM with random weights!
model = AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
quantization_config=bnb_config,
device_map="auto",
dtype=torch.bfloat16,
attn_implementation="eager", # L4/T4 don't support flash-attn; use eager
)
model.config.use_cache = False # required for gradient checkpointing
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print(f"Model loaded. Memory: {torch.cuda.memory_allocated() / 1024**3:.1f} GB")
# ──────────────────────────────────────────────────────────────────────
# 3. LoRA configuration β€” target ONLY language_model layers
# ──────────────────────────────────────────────────────────────────────
# Gemma-4 has vision_tower, audio_tower, and language_model sub-modules.
# We only want to adapt the language_model for code generation SSD.
# Using generic target_modules=["q_proj", ...] would also LoRA the vision
# and audio towers, wasting parameters and potentially hurting quality.
target_names = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lang_modules = [
name for name, _ in model.named_modules()
if "language_model" in name and any(name.endswith(t) for t in target_names)
]
print(f"Targeting {len(lang_modules)} language model modules for LoRA")
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=lang_modules,
)
# ──────────────────────────────────────────────────────────────────────
# 4. SFT training config (following paper's hyperparameters)
# ──────────────────────────────────────────────────────────────────────
training_args = SFTConfig(
output_dir="ssd-gemma-4-e2b-it-qlora",
hub_model_id=HUB_MODEL_ID,
push_to_hub=True,
# SFT-specific
max_length=4096, # L4 24GB tested; paper uses 65536
completion_only_loss=True, # SSD: loss only on completions, not prompts
# Training duration: 1 epoch β‰ˆ paper's ~300 iters for thinking models
num_train_epochs=1,
# Batch: effective = 1 * 32 = 32 (matches paper)
per_device_train_batch_size=1,
gradient_accumulation_steps=32,
# Optimizer: paper uses AdamW (beta1=0.9, beta2=0.95, wd=0.1)
# LR bumped to 2e-4 for LoRA (10x base rate per TRL PEFT docs)
learning_rate=2e-4,
lr_scheduler_type="cosine",
adam_beta1=0.9,
adam_beta2=0.95,
weight_decay=0.1,
warmup_steps=50,
max_grad_norm=1.0,
optim="paged_adamw_8bit", # paged optimizer saves GPU memory
# Memory optimizations
bf16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
# Logging
logging_strategy="steps",
logging_steps=1,
logging_first_step=True,
disable_tqdm=True,
report_to="trackio",
# Saving
save_strategy="steps",
save_steps=100,
save_total_limit=2,
)
# ──────────────────────────────────────────────────────────────────────
# 5. Create trainer and train
# ──────────────────────────────────────────────────────────────────────
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=ds,
processing_class=tokenizer,
peft_config=peft_config,
)
print("Starting SSD (SFT step) training...")
trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in trainer.model.parameters())
print(f" Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
trainer.train()
# ──────────────────────────────────────────────────────────────────────
# 6. Save & push to Hub
# ──────────────────────────────────────────────────────────────────────
print("Pushing model to Hub...")
trainer.push_to_hub()
print(f"βœ… LoRA adapter saved to https://huggingface.co/{HUB_MODEL_ID}")