| """ |
SSD (Simple Self-Distillation) training for google/gemma-4-E2B-it — QLoRA edition
| |
| Following the recipe from "Embarrassingly Simple Self-Distillation Improves Code Generation" |
| (arXiv:2604.01193). |
| |
| This script performs SFT on pre-generated completions using QLoRA. |
| |
| Key fixes from sandbox testing: |
| - Use AutoModelForImageTextToText (Gemma-4 is multimodal, not CausalLM) |
| - Explicitly target only language_model layers for LoRA (skip vision/audio towers) |
| - Use dtype= instead of deprecated torch_dtype= |
| - Removed prepare_model_for_kbit_training (causes OOM, not needed with modern TRL) |
| - Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True for better memory management |
| |
| Hardware: L4 24GB (tested), also works on A10G 24GB or A100 80GB |
| - T4 16GB: use max_length=2048 (OOM at 4096 due to 262K vocab logit tensor) |
| - L4/A10G 24GB: max_length=4096 fits comfortably |
| - A100 80GB: can try max_length=8192 or higher |
| |
| Estimated cost on L4 ($0.80/hr): ~3h training = ~$2.40 |
| |
| Requirements: |
| pip install trl transformers datasets trackio accelerate torch peft bitsandbytes |
| """ |
|
|
import os

# Process-wide settings, applied before torch/trl are imported below.
#   TRL_EXPERIMENTAL_SILENCE: presumably mutes TRL's experimental-feature
#     warnings (name-based; confirm against the installed TRL version).
#   PYTORCH_CUDA_ALLOC_CONF: lets the CUDA caching allocator use expandable
#     segments to reduce fragmentation; it must be in place before the
#     allocator is first used, which setting it ahead of the imports ensures.
os.environ.update(
    {
        "TRL_EXPERIMENTAL_SILENCE": "1",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
|
|
| import torch |
| from datasets import load_dataset |
| from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig |
| from peft import LoraConfig |
| from trl import SFTTrainer, SFTConfig |
|
|
| |
| |
| |
print("Loading pre-generated SSD dataset (wrmedford/Gemma-4-E2B-it-SSD)...")
ds = load_dataset(
    "wrmedford/Gemma-4-E2B-it-SSD",
    data_files={"train": "ssd_dataset.jsonl"},
    split="train",
)
print(f"Dataset loaded: {len(ds)} examples")
print(f"Columns: {ds.column_names}")
print(f"Sample messages[0]: {str(ds[0]['messages'][0])[:200]}...")
print(f"Sample messages[1]: {str(ds[0]['messages'][1])[:200]}...")


def _to_prompt_completion(example):
    """Split one chat transcript into TRL's conversational prompt-completion format.

    Every turn except the last becomes ``prompt``; the final turn (presumably
    the model's pre-generated completion) becomes ``completion``.

    Args:
        example: A dataset row with a ``messages`` list of chat turns.

    Returns:
        A dict with ``prompt`` (list of turns) and ``completion`` (one-turn list).

    Raises:
        ValueError: If the transcript has fewer than two turns, leaving no
            prompt or no completion to train on.
    """
    messages = example["messages"]
    if len(messages) < 2:
        raise ValueError(
            "Expected at least a prompt turn and a completion turn, "
            f"got {len(messages)} message(s)"
        )
    return {"prompt": messages[:-1], "completion": messages[-1:]}


# TRL's `completion_only_loss` masks prompt tokens only for prompt-completion
# datasets; on a plain `messages` dataset it is silently ignored and the loss
# also covers the prompt. Convert so the SFT step trains on completions only.
# `remove_columns` drops the original columns (including `messages`) so TRL
# sees an unambiguous prompt-completion dataset.
ds = ds.map(_to_prompt_completion, remove_columns=ds.column_names)
print(f"Converted to prompt-completion format. Columns: {ds.column_names}")
|
|
| |
| |
| |
# Base checkpoint to fine-tune.
MODEL_ID = "google/gemma-4-E2B-it"
# Hub repo that receives the LoRA adapter. The default is the original
# author's namespace, which you most likely cannot push to; set the
# SSD_HUB_MODEL_ID environment variable to target your own repo instead.
HUB_MODEL_ID = os.environ.get("SSD_HUB_MODEL_ID", "ludsvick/gemma-4-E2B-it-SSD")
|
|
print(f"Loading {MODEL_ID} in 4-bit quantization...")

# 4-bit NF4 weights with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Gemma-4 is multimodal, so it is loaded through the image-text-to-text auto
# class rather than a causal-LM class (see the module docstring).
_model_load_kwargs = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "dtype": torch.bfloat16,
    "attn_implementation": "eager",
}
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **_model_load_kwargs)
# The KV cache is not needed for training and conflicts with gradient checkpointing.
model.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Padding needs a pad token; fall back to EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE: memory_allocated() reports the current CUDA device only.
_allocated_gib = torch.cuda.memory_allocated() / 1024**3
print(f"Model loaded. Memory: {_allocated_gib:.1f} GB")
|
|
| |
| |
| |
| |
| |
| |
| |
# Projection layers of the text decoder that receive LoRA adapters. Only
# modules under `language_model` are selected, so the vision/audio towers are
# left untouched.
target_names = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lang_modules = [
    name
    for name, _ in model.named_modules()
    # Compare the last dotted component exactly: a bare `endswith("o_proj")`
    # would also capture any unrelated module whose name merely ends in it.
    if "language_model" in name and name.rsplit(".", 1)[-1] in target_names
]
if not lang_modules:
    raise RuntimeError(
        f"No LoRA target modules found under 'language_model' matching {target_names}; "
        "the model's module naming may differ from what this script expects."
    )
print(f"Targeting {len(lang_modules)} language model modules for LoRA")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=lang_modules,
)
|
|
| |
| |
| |
# SFTConfig arguments grouped by concern, then merged into a single call.

# Where checkpoints go locally and on the Hub.
_output_kwargs = dict(
    output_dir="ssd-gemma-4-e2b-it-qlora",
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
)

# Sequence handling. 4096 tokens targets 24 GB cards; see the module docstring
# for T4 / A100 guidance. Per TRL's docs, completion_only_loss only takes
# effect for prompt-completion datasets.
_sequence_kwargs = dict(
    max_length=4096,
    completion_only_loss=True,
)

# One pass over the data; 1 sequence per device, accumulated over 32 steps.
_schedule_kwargs = dict(
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,
)

# Optimizer and learning-rate schedule.
_optimizer_kwargs = dict(
    optim="paged_adamw_8bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=50,
    adam_beta1=0.9,
    adam_beta2=0.95,
    weight_decay=0.1,
    max_grad_norm=1.0,
)

# Precision and activation-memory settings.
_memory_kwargs = dict(
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

# Per-step logging to trackio, without a progress bar.
_logging_kwargs = dict(
    logging_strategy="steps",
    logging_steps=1,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="trackio",
)

# Checkpoint every 100 steps, keeping the two most recent.
_checkpoint_kwargs = dict(
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
)

training_args = SFTConfig(
    **_output_kwargs,
    **_sequence_kwargs,
    **_schedule_kwargs,
    **_optimizer_kwargs,
    **_memory_kwargs,
    **_logging_kwargs,
    **_checkpoint_kwargs,
)
|
|
| |
| |
| |
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=ds,
    processing_class=tokenizer,
    peft_config=peft_config,
)

print("Starting SSD (SFT step) training...")
if hasattr(trainer.model, "get_nb_trainable_parameters"):
    # PEFT's counter corrects for bitsandbytes 4-bit weights, which are stored
    # packed (two values per uint8 element); a plain numel() sum undercounts
    # the frozen base model and inflates the trainable percentage.
    trainable, total = trainer.model.get_nb_trainable_parameters()
else:
    # Fallback when the trainer did not wrap the model in a PEFT model.
    trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in trainer.model.parameters())
print(f" Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
trainer.train()
|
|
| |
| |
| |
# Upload the trained LoRA adapter (Trainer.push_to_hub saves it first).
print("Pushing model to Hub...")
trainer.push_to_hub()
# The original "✅" had been mis-decoded into "β" plus a line break, which
# left this f-string unterminated.
print(f"✅ LoRA adapter saved to https://huggingface.co/{HUB_MODEL_ID}")
|
|