""" Tinman-SmolOmni-MLA Training Script Stage 1: MLA initialization + KL distillation from SmolVLM teacher Stage 2: Joint AR + flow-matching training on image-text pairs Based on: - X-EcoMLA: SVD init + KD fine-tuning (3.6B tokens for SmolLM family) - Show-o2: Dual AR + flow-matching loss - JanusFlow: Representation alignment (REPA) Usage: python train.py --stage 1 --model_variant 256M python train.py --stage 2 --model_variant 256M --checkpoint stage1_output """ import os import sys import math import argparse import json import time from pathlib import Path import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset, IterableDataset from accelerate import Accelerator from accelerate.utils import set_seed from transformers import ( AutoModelForImageTextToText, AutoProcessor, AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup, ) # Add smolomni to path sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from smolomni.config import SmolOmniConfig from smolomni.model import SmolOmniModel from smolomni.svd_init import initialize_mla_from_pretrained import trackio # Safe trackio wrapper def safe_trackio_log(metrics): try: trackio.log(metrics) except Exception: pass # ===== Stage 1: KL Distillation Dataset ===== class TextDistillationDataset(IterableDataset): """Streams text from FineWeb-Edu for KL distillation.""" def __init__(self, tokenizer, max_length=512, max_samples=None): from datasets import load_dataset self.dataset = load_dataset( "HuggingFaceFW/fineweb-edu", name="CC-MAIN-2024-10", # Use one recent crawl split="train", streaming=True, ) self.tokenizer = tokenizer self.max_length = max_length self.max_samples = max_samples def __iter__(self): count = 0 for example in self.dataset: if self.max_samples and count >= self.max_samples: break text = example.get("text", "") if len(text) < 50: continue tokens = self.tokenizer( text, max_length=self.max_length, truncation=True, return_tensors="pt", padding="max_length", ) yield { "input_ids": tokens["input_ids"].squeeze(0), "attention_mask": tokens["attention_mask"].squeeze(0), } count += 1 # ===== Stage 2: Image-Text Dataset ===== class ImageTextDataset(IterableDataset): """Streams image-text pairs for joint AR + flow-matching training.""" def __init__(self, tokenizer, vae, max_length=256, image_size=256, max_samples=None): from datasets import load_dataset self.dataset = load_dataset( "HuggingFaceM4/the_cauldron", name="chartqa", # Start with a manageable subset split="train", streaming=True, ) self.tokenizer = tokenizer self.vae = vae self.max_length = max_length self.image_size = image_size self.max_samples = max_samples from torchvision import transforms self.transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) def __iter__(self): count = 0 for example in self.dataset: if self.max_samples and count >= self.max_samples: break try: # Get text texts = example.get("texts", []) if not texts: continue text = texts[0].get("user", "") + " " + texts[0].get("assistant", "") if len(text) < 10: continue # Tokenize tokens = self.tokenizer( text, max_length=self.max_length, truncation=True, return_tensors="pt", padding="max_length", ) # Get image (use dummy latents if image processing fails) images = example.get("images", []) if images and images[0] is not None: try: from PIL import Image img = images[0] if not isinstance(img, Image.Image): img = Image.open(img).convert("RGB") else: 


def train_stage1(args, config):
    """Stage 1: SVD init + KL distillation from teacher model."""
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16",
    )

    if accelerator.is_main_process:
        try:
            trackio.init(
                project="SmolOmni-MLA",
                name="Stage1-KD",
                config=vars(args),
            )
        except Exception as e:
            print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")

    set_seed(args.seed)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create student model with SVD initialization
    print("Creating student model with SVD initialization...")
    student = SmolOmniModel(config)
    student = initialize_mla_from_pretrained(student, config.base_model, config)

    # Load teacher model (frozen)
    print("Loading teacher model...")
    # SmolVLM-256M uses SmolLM2-135M as its language backbone
    base_lm_map = {
        "256M": "HuggingFaceTB/SmolLM2-135M-Instruct",
        "500M": "HuggingFaceTB/SmolLM2-360M-Instruct",
    }
    teacher_name = base_lm_map.get(config.model_variant, "HuggingFaceTB/SmolLM2-135M-Instruct")
    try:
        teacher = AutoModelForCausalLM.from_pretrained(teacher_name, torch_dtype=torch.bfloat16)
    except Exception:
        print(f"Warning: Could not load teacher {teacher_name}; falling back to plain LM loss (no distillation)")
        teacher = None

    if teacher is not None:
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad = False

    # Dataset
    dataset = TextDistillationDataset(
        tokenizer,
        max_length=args.max_length,
        max_samples=args.max_train_samples,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    # Optimizer
    optimizer = torch.optim.AdamW(
        student.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.95),
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps,
    )

    # Prepare
    student, optimizer, dataloader, scheduler = accelerator.prepare(
        student, optimizer, dataloader, scheduler
    )
    if teacher is not None:
        teacher = accelerator.prepare(teacher)

    # Training loop
    student.train()
    global_step = 0
    total_loss = 0.0
    start_time = time.time()

    print(f"\n{'='*60}")
    print(f"Stage 1: KL Distillation Training")
    print(f"Model: {config.model_variant}, Steps: {args.max_steps}")
    print(f"Batch size: {args.batch_size} x {args.gradient_accumulation_steps} = "
          f"{args.batch_size * args.gradient_accumulation_steps}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"{'='*60}\n")

    for batch in dataloader:
        if global_step >= args.max_steps:
            break

        with accelerator.accumulate(student):
            input_ids = batch["input_ids"]

            # Student forward
            student_output = student.forward_understanding(input_ids, labels=input_ids)
            student_logits = student_output["logits"]

            if teacher is not None:
                # Teacher forward (no gradients through the frozen teacher)
                with torch.no_grad():
                    teacher_output = teacher(input_ids)
                    teacher_logits = teacher_output.logits

                # KL divergence loss (student learns to match teacher distribution)
                T = args.temperature
                student_probs = F.log_softmax(student_logits / T, dim=-1)
                teacher_probs = F.softmax(teacher_logits / T, dim=-1)

                # Handle vocab size mismatch between student and teacher
                min_vocab = min(student_logits.shape[-1], teacher_logits.shape[-1])
                kd_loss = F.kl_div(
                    student_probs[..., :min_vocab],
                    teacher_probs[..., :min_vocab],
                    reduction="batchmean",
                ) * (T * T)

                # Combined loss: KD term plus the student's own cross-entropy
                alpha = args.kd_alpha
                loss = alpha * kd_loss + (1 - alpha) * student_output["loss"]
            else:
                loss = student_output["loss"]

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item()
        global_step += 1

        if global_step % args.log_every == 0:
            avg_loss = total_loss / args.log_every
            elapsed = time.time() - start_time
            steps_per_sec = global_step / elapsed
            metrics = {
                "loss": avg_loss,
                "lr": scheduler.get_last_lr()[0],
                "steps_per_sec": steps_per_sec,
                "step": global_step,
            }
            if accelerator.is_main_process:
                print(f"Step {global_step}/{args.max_steps} | Loss: {avg_loss:.4f} | "
                      f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                      f"Speed: {steps_per_sec:.1f} steps/s")
                safe_trackio_log(metrics)
            total_loss = 0.0

        if global_step % args.save_every == 0 and accelerator.is_main_process:
            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
            os.makedirs(save_path, exist_ok=True)
            unwrapped = accelerator.unwrap_model(student)
            torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
            config.save(os.path.join(save_path, "config.json"))
            print(f"Saved checkpoint to {save_path}")

    # Save final
    if accelerator.is_main_process:
        save_path = os.path.join(args.output_dir, "stage1_final")
        os.makedirs(save_path, exist_ok=True)
        unwrapped = accelerator.unwrap_model(student)
        torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
        config.save(os.path.join(save_path, "config.json"))
        print(f"\nStage 1 complete! Model saved to {save_path}")

        # Push to Hub
        from huggingface_hub import HfApi

        api = HfApi()
        api.upload_folder(
            folder_path=save_path,
            repo_id=f"TinmanLabSL/SmolOmni-MLA-{config.model_variant}",
            commit_message="Stage 1: SVD init + KL distillation",
        )
        print(f"Pushed to TinmanLabSL/SmolOmni-MLA-{config.model_variant}")
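
# --------------------------------------------------------------------------
# Illustrative restatement of the Stage 1 objective computed inline above
# (not called by the training loop): temperature-scaled KL to the teacher
# plus the student's own cross-entropy, mixed by `alpha`. The helper name
# `kd_objective` is hypothetical; it exists only to make the loss easy to
# test in isolation.
# --------------------------------------------------------------------------
def kd_objective(student_logits, teacher_logits, ce_loss, T=2.0, alpha=0.7):
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # Truncate to the shared vocabulary if student and teacher sizes differ.
    min_vocab = min(student_logits.shape[-1], teacher_logits.shape[-1])
    kd = F.kl_div(
        log_p_student[..., :min_vocab],
        p_teacher[..., :min_vocab],
        reduction="batchmean",
    ) * (T * T)
    return alpha * kd + (1 - alpha) * ce_loss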


def train_stage2(args, config):
    """Stage 2: Joint AR + flow-matching training."""
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16",
    )

    if accelerator.is_main_process:
        try:
            trackio.init(
                project="SmolOmni-MLA",
                name="Stage2-Joint",
                config=vars(args),
            )
        except Exception as e:
            print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")
    set_seed(args.seed)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load VAE for image encoding
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained(
        config.flow_head.vae_model, torch_dtype=torch.bfloat16
    )
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False

    # Load model from Stage 1 checkpoint
    model = SmolOmniModel(config)
    if args.checkpoint:
        ckpt_path = os.path.join(args.checkpoint, "model.pt")
        if os.path.exists(ckpt_path):
            state = torch.load(ckpt_path, map_location="cpu")
            model.load_state_dict(state, strict=False)
            print(f"Loaded Stage 1 checkpoint from {ckpt_path}")
        else:
            print("No Stage 1 checkpoint found, training from scratch")
            model = initialize_mla_from_pretrained(model, config.base_model, config)
    else:
        model = initialize_mla_from_pretrained(model, config.base_model, config)

    # Cast to bf16 AFTER loading checkpoint (ckpt weights may be fp32)
    model = model.to(torch.bfloat16)
    print("Model cast to bfloat16")

    # Dataset
    dataset = ImageTextDataset(
        tokenizer,
        vae,
        max_length=args.max_length,
        image_size=config.flow_head.gen_resolution,
        max_samples=args.max_train_samples,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    # Optimizer (separate LR for flow head)
    backbone_params = []
    flow_params = []
    for name, param in model.named_parameters():
        if "flow_head" in name or "gen_image_encoder" in name:
            flow_params.append(param)
        else:
            backbone_params.append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": args.learning_rate},
            {"params": flow_params, "lr": args.learning_rate * 3},  # Higher LR for new flow head
        ],
        weight_decay=args.weight_decay,
        betas=(0.9, 0.95),
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps,
    )

    model, vae, optimizer, dataloader, scheduler = accelerator.prepare(
        model, vae, optimizer, dataloader, scheduler
    )

    model.train()
    global_step = 0
    total_loss = 0.0
    total_ar_loss = 0.0
    total_flow_loss = 0.0
    start_time = time.time()

    print(f"\n{'='*60}")
    print(f"Stage 2: Joint AR + Flow-Matching Training")
    print(f"Model: {config.model_variant}, Steps: {args.max_steps}")
    print(f"{'='*60}\n")

    for batch in dataloader:
        if global_step >= args.max_steps:
            break

        with accelerator.accumulate(model):
            input_ids = batch["input_ids"]
            latents = batch["latents"].to(accelerator.device, dtype=torch.bfloat16)

            # Forward
            output = model.forward_generation(
                input_ids,
                clean_latents=latents,
                labels=input_ids,
            )
            loss = output["loss"]

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item()
        if output["ar_loss"] is not None:
            total_ar_loss += output["ar_loss"].item()
            total_flow_loss += output["flow_loss"].item()
        global_step += 1

        if global_step % args.log_every == 0:
            n = args.log_every
            metrics = {
                "loss": total_loss / n,
                "ar_loss": total_ar_loss / n,
                "flow_loss": total_flow_loss / n,
                "lr": scheduler.get_last_lr()[0],
                "step": global_step,
            }
            if accelerator.is_main_process:
                print(f"Step {global_step}/{args.max_steps} | "
                      f"Loss: {total_loss/n:.4f} | "
                      f"AR: {total_ar_loss/n:.4f} | "
                      f"Flow: {total_flow_loss/n:.4f}")
                safe_trackio_log(metrics)
            total_loss = total_ar_loss = total_flow_loss = 0.0

        if global_step % args.save_every == 0 and accelerator.is_main_process:
            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
            os.makedirs(save_path, exist_ok=True)
            unwrapped = accelerator.unwrap_model(model)
            torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
            config.save(os.path.join(save_path, "config.json"))

    # Final save + push
    if accelerator.is_main_process:
        save_path = os.path.join(args.output_dir, "stage2_final")
        os.makedirs(save_path, exist_ok=True)
        unwrapped = accelerator.unwrap_model(model)
        torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
        config.save(os.path.join(save_path, "config.json"))

        from huggingface_hub import HfApi

        api = HfApi()
        api.upload_folder(
            folder_path=save_path,
            repo_id=f"TinmanLabSL/SmolOmni-MLA-{config.model_variant}",
            commit_message="Stage 2: Joint AR + flow-matching training",
        )
        print(f"\nStage 2 complete! Pushed to TinmanLabSL/SmolOmni-MLA-{config.model_variant}")
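
# --------------------------------------------------------------------------
# For reference: a minimal sketch of a rectified-flow (flow-matching) loss of
# the kind `SmolOmniModel.forward_generation` is expected to compute on
# `clean_latents`, following the Show-o2-style dual objective named in the
# header. The actual implementation lives in smolomni/model.py and may differ;
# `velocity_pred_fn` is a hypothetical stand-in for the model's flow head, and
# this helper is illustrative only and unused by this script.
# --------------------------------------------------------------------------
def flow_matching_loss_sketch(velocity_pred_fn, clean_latents):
    noise = torch.randn_like(clean_latents)                    # x_0 ~ N(0, I)
    t = torch.rand(clean_latents.shape[0], device=clean_latents.device)
    t_expand = t.view(-1, 1, 1, 1)
    noisy = (1 - t_expand) * noise + t_expand * clean_latents  # linear interpolation path
    target_velocity = clean_latents - noise                    # d x_t / d t along the path
    pred_velocity = velocity_pred_fn(noisy, t)
    return F.mse_loss(pred_velocity, target_velocity)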


def main():
    parser = argparse.ArgumentParser(description="Tinman-SmolOmni-MLA Training")
    parser.add_argument("--stage", type=int, default=1, choices=[1, 2])
    parser.add_argument("--model_variant", type=str, default="256M", choices=["256M", "500M", "1B"])
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=200)
    parser.add_argument("--max_steps", type=int, default=5000)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--save_every", type=int, default=1000)
    parser.add_argument("--temperature", type=float, default=2.0)
    parser.add_argument("--kd_alpha", type=float, default=0.7)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Build config
    config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{args.model_variant}")

    if args.stage == 1:
        train_stage1(args, config)
    else:
        train_stage2(args, config)


if __name__ == "__main__":
    main()
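
# Example invocations (illustrative; adjust to your hardware and Accelerate
# config -- single-process runs also work with plain `python train.py ...`):
#
#   accelerate launch train.py --stage 1 --model_variant 256M --max_steps 5000
#   accelerate launch train.py --stage 2 --model_variant 256M \
#       --checkpoint ./output/stage1_final --max_steps 5000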