""" SSD (Simple Self-Distillation) training for google/gemma-4-E2B-it — QLoRA edition Following the recipe from "Embarrassingly Simple Self-Distillation Improves Code Generation" (arXiv:2604.01193). This script performs SFT on pre-generated completions using QLoRA. Key fixes from sandbox testing: - Use AutoModelForImageTextToText (Gemma-4 is multimodal, not CausalLM) - Explicitly target only language_model layers for LoRA (skip vision/audio towers) - Use dtype= instead of deprecated torch_dtype= - Removed prepare_model_for_kbit_training (causes OOM, not needed with modern TRL) - Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True for better memory management Hardware: L4 24GB (tested), also works on A10G 24GB or A100 80GB - T4 16GB: use max_length=2048 (OOM at 4096 due to 262K vocab logit tensor) - L4/A10G 24GB: max_length=4096 fits comfortably - A100 80GB: can try max_length=8192 or higher Estimated cost on L4 ($0.80/hr): ~3h training = ~$2.40 Requirements: pip install trl transformers datasets trackio accelerate torch peft bitsandbytes """ import os os.environ["TRL_EXPERIMENTAL_SILENCE"] = "1" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch from datasets import load_dataset from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig from peft import LoraConfig from trl import SFTTrainer, SFTConfig # ────────────────────────────────────────────────────────────────────── # 1. Load pre-generated SSD dataset # ────────────────────────────────────────────────────────────────────── print("Loading pre-generated SSD dataset (wrmedford/Gemma-4-E2B-it-SSD)...") ds = load_dataset( "wrmedford/Gemma-4-E2B-it-SSD", data_files={"train": "ssd_dataset.jsonl"}, split="train", ) print(f"Dataset loaded: {len(ds)} examples") print(f"Columns: {ds.column_names}") print(f"Sample messages[0]: {str(ds[0]['messages'][0])[:200]}...") print(f"Sample messages[1]: {str(ds[0]['messages'][1])[:200]}...") # ────────────────────────────────────────────────────────────────────── # 2. Load model with 4-bit quantization (QLoRA) # ────────────────────────────────────────────────────────────────────── MODEL_ID = "google/gemma-4-E2B-it" HUB_MODEL_ID = "ludsvick/gemma-4-E2B-it-SSD" print(f"Loading {MODEL_ID} in 4-bit quantization...") bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) # IMPORTANT: Gemma-4 is multimodal — must use AutoModelForImageTextToText # Using AutoModelForCausalLM will load Gemma4ForCausalLM with random weights! model = AutoModelForImageTextToText.from_pretrained( MODEL_ID, quantization_config=bnb_config, device_map="auto", dtype=torch.bfloat16, attn_implementation="eager", # L4/T4 don't support flash-attn; use eager ) model.config.use_cache = False # required for gradient checkpointing tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print(f"Model loaded. Memory: {torch.cuda.memory_allocated() / 1024**3:.1f} GB") # ────────────────────────────────────────────────────────────────────── # 3. LoRA configuration — target ONLY language_model layers # ────────────────────────────────────────────────────────────────────── # Gemma-4 has vision_tower, audio_tower, and language_model sub-modules. # We only want to adapt the language_model for code generation SSD. # Using generic target_modules=["q_proj", ...] would also LoRA the vision # and audio towers, wasting parameters and potentially hurting quality. target_names = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] lang_modules = [ name for name, _ in model.named_modules() if "language_model" in name and any(name.endswith(t) for t in target_names) ] print(f"Targeting {len(lang_modules)} language model modules for LoRA") peft_config = LoraConfig( r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules=lang_modules, ) # ────────────────────────────────────────────────────────────────────── # 4. SFT training config (following paper's hyperparameters) # ────────────────────────────────────────────────────────────────────── training_args = SFTConfig( output_dir="ssd-gemma-4-e2b-it-qlora", hub_model_id=HUB_MODEL_ID, push_to_hub=True, # SFT-specific max_length=4096, # L4 24GB tested; paper uses 65536 completion_only_loss=True, # SSD: loss only on completions, not prompts # Training duration: 1 epoch ≈ paper's ~300 iters for thinking models num_train_epochs=1, # Batch: effective = 1 * 32 = 32 (matches paper) per_device_train_batch_size=1, gradient_accumulation_steps=32, # Optimizer: paper uses AdamW (beta1=0.9, beta2=0.95, wd=0.1) # LR bumped to 2e-4 for LoRA (10x base rate per TRL PEFT docs) learning_rate=2e-4, lr_scheduler_type="cosine", adam_beta1=0.9, adam_beta2=0.95, weight_decay=0.1, warmup_steps=50, max_grad_norm=1.0, optim="paged_adamw_8bit", # paged optimizer saves GPU memory # Memory optimizations bf16=True, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, # Logging logging_strategy="steps", logging_steps=1, logging_first_step=True, disable_tqdm=True, report_to="trackio", # Saving save_strategy="steps", save_steps=100, save_total_limit=2, ) # ────────────────────────────────────────────────────────────────────── # 5. Create trainer and train # ────────────────────────────────────────────────────────────────────── print("Initializing SFTTrainer...") trainer = SFTTrainer( model=model, args=training_args, train_dataset=ds, processing_class=tokenizer, peft_config=peft_config, ) print("Starting SSD (SFT step) training...") trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad) total = sum(p.numel() for p in trainer.model.parameters()) print(f" Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)") trainer.train() # ────────────────────────────────────────────────────────────────────── # 6. Save & push to Hub # ────────────────────────────────────────────────────────────────────── print("Pushing model to Hub...") trainer.push_to_hub() print(f"✅ LoRA adapter saved to https://huggingface.co/{HUB_MODEL_ID}")