| """ |
| Tinman-SmolOmni-MLA Training Script |
| Stage 1: MLA initialization + KL distillation from SmolVLM teacher |
| Stage 2: Joint AR + flow-matching training on image-text pairs |
| |
| Based on: |
| - X-EcoMLA: SVD init + KD fine-tuning (3.6B tokens for SmolLM family) |
| - Show-o2: Dual AR + flow-matching loss |
| - JanusFlow: Representation alignment (REPA) |
| |
| Usage: |
| python train.py --stage 1 --model_variant 256M |
| python train.py --stage 2 --model_variant 256M --checkpoint stage1_output |
| """ |
import os
import sys
import argparse
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset

from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_cosine_schedule_with_warmup,
)

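# Make the local smolomni package importable when running this file directly.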
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from smolomni.config import SmolOmniConfig
from smolomni.model import SmolOmniModel
from smolomni.svd_init import initialize_mla_from_pretrained

import trackio


def safe_trackio_log(metrics):
    """Log metrics to trackio, swallowing errors so tracking can never stall training."""
    try:
        trackio.log(metrics)
    except Exception:
        pass


class TextDistillationDataset(IterableDataset):
    """Streams text from FineWeb-Edu for KL distillation."""

    def __init__(self, tokenizer, max_length=512, max_samples=None):
        from datasets import load_dataset
        self.dataset = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name="CC-MAIN-2024-10",
            split="train",
            streaming=True,
        )
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_samples = max_samples

    def __iter__(self):
        count = 0
        for example in self.dataset:
            if self.max_samples and count >= self.max_samples:
                break
            text = example.get("text", "")
            # Skip trivially short documents.
            if len(text) < 50:
                continue
            # Fixed-length padding keeps tensor shapes static for the default collate.
            tokens = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
                padding="max_length",
            )
            yield {
                "input_ids": tokens["input_ids"].squeeze(0),
                "attention_mask": tokens["attention_mask"].squeeze(0),
            }
            count += 1


class ImageTextDataset(IterableDataset):
    """Streams image-text pairs for joint AR + flow-matching training."""

    def __init__(self, tokenizer, vae, max_length=256, image_size=256, max_samples=None):
        from datasets import load_dataset
        self.dataset = load_dataset(
            "HuggingFaceM4/the_cauldron",
            name="chartqa",
            split="train",
            streaming=True,
        )
        self.tokenizer = tokenizer
        self.vae = vae
        self.max_length = max_length
        self.image_size = image_size
        self.max_samples = max_samples

        from torchvision import transforms
        # Map pixels to [-1, 1], the range SD-style VAEs expect.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __iter__(self):
        count = 0
        for example in self.dataset:
            if self.max_samples and count >= self.max_samples:
                break
            try:
                # Concatenate the first QA turn into a single training string.
                texts = example.get("texts", [])
                if not texts:
                    continue
                text = texts[0].get("user", "") + " " + texts[0].get("assistant", "")
                if len(text) < 10:
                    continue

                tokens = self.tokenizer(
                    text, max_length=self.max_length, truncation=True,
                    return_tensors="pt", padding="max_length",
                )

                images = example.get("images", [])
                if images and images[0] is not None:
                    try:
                        from PIL import Image
                        img = images[0]
                        if not isinstance(img, Image.Image):
                            img = Image.open(img).convert("RGB")
                        else:
                            img = img.convert("RGB")
                        img_tensor = self.transform(img).unsqueeze(0)

                        # Encode to VAE latents and apply the VAE's scaling factor.
                        with torch.no_grad():
                            latents = self.vae.encode(img_tensor.to(self.vae.device, dtype=self.vae.dtype)).latent_dist.sample()
                            latents = latents * self.vae.config.scaling_factor
                    except Exception:
                        # Placeholder latents keep the stream flowing if decoding fails.
                        latents = torch.randn(1, 4, self.image_size // 8, self.image_size // 8)
                else:
                    latents = torch.randn(1, 4, self.image_size // 8, self.image_size // 8)

                yield {
                    "input_ids": tokens["input_ids"].squeeze(0),
                    "attention_mask": tokens["attention_mask"].squeeze(0),
                    "latents": latents.squeeze(0).cpu(),
                }
                count += 1
            except Exception:
                # Malformed samples are silently skipped.
                continue


def train_stage1(args, config):
    """Stage 1: SVD init + KL distillation from the teacher model."""
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16",
    )

    if accelerator.is_main_process:
        try:
            trackio.init(
                project="SmolOmni-MLA",
                name="Stage1-KD",
                config=vars(args),
            )
        except Exception as e:
            print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")

    set_seed(args.seed)

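    # Load the base LM tokenizer; reuse EOS as pad if none is defined.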
    tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Student: MLA attention layers initialized from the base LM's weights via SVD.
    print("Creating student model with SVD initialization...")
    student = SmolOmniModel(config)
    student = initialize_mla_from_pretrained(student, config.base_model, config)

    print("Loading teacher model...")
    # Teacher: the unmodified base LM matching each student variant.
    base_lm_map = {
        "256M": "HuggingFaceTB/SmolLM2-135M-Instruct",
        "500M": "HuggingFaceTB/SmolLM2-360M-Instruct",
    }
    teacher_name = base_lm_map.get(config.model_variant, "HuggingFaceTB/SmolLM2-135M-Instruct")
    try:
        teacher = AutoModelForCausalLM.from_pretrained(teacher_name, torch_dtype=torch.bfloat16)
    except Exception:
        print(f"Warning: could not load teacher {teacher_name}; falling back to plain CE training (no distillation)")
        teacher = None

    if teacher is not None:
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad = False

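    # Streaming dataset: no epoch boundaries, we just consume batches until max_steps.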
    dataset = TextDistillationDataset(
        tokenizer,
        max_length=args.max_length,
        max_samples=args.max_train_samples,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

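    # AdamW with betas=(0.9, 0.95), the usual choice for LM pretraining.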
    optimizer = torch.optim.AdamW(
        student.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.95),
    )

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps,
    )

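    # Wrap the trainables for mixed precision and (if launched that way) distributed execution.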
    student, optimizer, dataloader, scheduler = accelerator.prepare(
        student, optimizer, dataloader, scheduler
    )
    if teacher is not None:
        # Keep the frozen teacher out of accelerator.prepare(): DDP refuses to wrap
        # a module with no trainable parameters. Moving it to device is enough.
        teacher = teacher.to(accelerator.device)

    student.train()
    global_step = 0
    total_loss = 0.0
    start_time = time.time()

    print(f"\n{'='*60}")
    print("Stage 1: KL Distillation Training")
    print(f"Model: {config.model_variant}, Steps: {args.max_steps}")
    print(f"Batch size: {args.batch_size} x {args.gradient_accumulation_steps} = {args.batch_size * args.gradient_accumulation_steps}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"{'='*60}\n")

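    # NOTE: global_step counts micro-batches; the optimizer only advances every
    # gradient_accumulation_steps of them (when accelerator.sync_gradients is True).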
    for batch in dataloader:
        if global_step >= args.max_steps:
            break

        with accelerator.accumulate(student):
            input_ids = batch["input_ids"]

            # Student forward pass with next-token labels (CE loss computed inside).
            student_output = student.forward_understanding(input_ids, labels=input_ids)
            student_logits = student_output["logits"]

            if teacher is not None:
                with torch.no_grad():
                    teacher_output = teacher(input_ids)
                    teacher_logits = teacher_output.logits

                # Restrict both models to the shared vocabulary BEFORE the softmax,
                # so each distribution is properly normalized over the same support.
                min_vocab = min(student_logits.shape[-1], teacher_logits.shape[-1])
                T = args.temperature
                student_log_probs = F.log_softmax(student_logits[..., :min_vocab] / T, dim=-1)
                teacher_probs = F.softmax(teacher_logits[..., :min_vocab] / T, dim=-1)

                # Scale by T^2 to keep gradient magnitudes comparable across temperatures.
                kd_loss = F.kl_div(
                    student_log_probs,
                    teacher_probs,
                    reduction="batchmean",
                ) * (T * T)

                # Blend soft-label distillation with the hard CE objective.
                alpha = args.kd_alpha
                loss = alpha * kd_loss + (1 - alpha) * student_output["loss"]
            else:
                loss = student_output["loss"]

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(student.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        total_loss += loss.item()
        global_step += 1

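        # Windowed logging: average the loss over the last log_every micro-batches.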
        if global_step % args.log_every == 0:
            avg_loss = total_loss / args.log_every
            elapsed = time.time() - start_time
            steps_per_sec = global_step / elapsed

            metrics = {
                "loss": avg_loss,
                "lr": scheduler.get_last_lr()[0],
                "steps_per_sec": steps_per_sec,
                "step": global_step,
            }

            if accelerator.is_main_process:
                print(f"Step {global_step}/{args.max_steps} | Loss: {avg_loss:.4f} | "
                      f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                      f"Speed: {steps_per_sec:.1f} steps/s")
                safe_trackio_log(metrics)

            total_loss = 0.0

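        # Periodic checkpoint: a raw state_dict + config is enough to resume or to seed Stage 2.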
        if global_step % args.save_every == 0 and accelerator.is_main_process:
            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
            os.makedirs(save_path, exist_ok=True)
            unwrapped = accelerator.unwrap_model(student)
            torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
            config.save(os.path.join(save_path, "config.json"))
            print(f"Saved checkpoint to {save_path}")

    if accelerator.is_main_process:
        save_path = os.path.join(args.output_dir, "stage1_final")
        os.makedirs(save_path, exist_ok=True)
        unwrapped = accelerator.unwrap_model(student)
        torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
        config.save(os.path.join(save_path, "config.json"))
        print(f"\nStage 1 complete! Model saved to {save_path}")

        # Push to the Hub; a failed upload (missing token, network) shouldn't kill the run.
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            repo_id = f"TinmanLabSL/SmolOmni-MLA-{config.model_variant}"
            api.create_repo(repo_id, exist_ok=True)
            api.upload_folder(
                folder_path=save_path,
                repo_id=repo_id,
                commit_message="Stage 1: SVD init + KL distillation",
            )
            print(f"Pushed to {repo_id}")
        except Exception as e:
            print(f"[WARN] Hub push failed: {e}")


def train_stage2(args, config):
    """Stage 2: Joint AR + flow-matching training."""
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16",
    )

    if accelerator.is_main_process:
        try:
            trackio.init(
                project="SmolOmni-MLA",
                name="Stage2-Joint",
                config=vars(args),
            )
        except Exception as e:
            print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")

    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

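    # Frozen VAE: maps images into the latent space the flow head is trained in.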
    from diffusers import AutoencoderKL
    vae = AutoencoderKL.from_pretrained(
        config.flow_head.vae_model,
        torch_dtype=torch.bfloat16,
    )
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False
    # Keep the frozen VAE out of accelerator.prepare() (DDP rejects modules with no
    # trainable parameters); moving it to the device by hand is sufficient.
    vae = vae.to(accelerator.device)

    model = SmolOmniModel(config)
    if args.checkpoint:
        ckpt_path = os.path.join(args.checkpoint, "model.pt")
        if os.path.exists(ckpt_path):
            state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            model.load_state_dict(state, strict=False)
            print(f"Loaded Stage 1 checkpoint from {ckpt_path}")
        else:
            print("No Stage 1 checkpoint found, falling back to fresh SVD initialization")
            model = initialize_mla_from_pretrained(model, config.base_model, config)
    else:
        model = initialize_mla_from_pretrained(model, config.base_model, config)

    # Cast the whole model to bf16 to match the mixed-precision setting and the VAE.
    model = model.to(torch.bfloat16)
    print("Model cast to bfloat16")

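    # Image-text pairs; the dataset encodes each image to VAE latents on the fly.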
    dataset = ImageTextDataset(
        tokenizer, vae,
        max_length=args.max_length,
        image_size=config.flow_head.gen_resolution,
        max_samples=args.max_train_samples,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

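    # Two LR groups: the freshly initialized flow head learns at 3x the rate of
    # the distilled backbone.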
    backbone_params = []
    flow_params = []
    for name, param in model.named_parameters():
        if "flow_head" in name or "gen_image_encoder" in name:
            flow_params.append(param)
        else:
            backbone_params.append(param)

    optimizer = torch.optim.AdamW([
        {"params": backbone_params, "lr": args.learning_rate},
        {"params": flow_params, "lr": args.learning_rate * 3},
    ], weight_decay=args.weight_decay, betas=(0.9, 0.95))

    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.max_steps,
    )

    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    model.train()
    global_step = 0
    total_loss = 0.0
    total_ar_loss = 0.0
    total_flow_loss = 0.0
    start_time = time.time()

    print(f"\n{'='*60}")
    print("Stage 2: Joint AR + Flow-Matching Training")
    print(f"Model: {config.model_variant}, Steps: {args.max_steps}")
    print(f"{'='*60}\n")

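    # Each batch carries text tokens and clean VAE latents; the model combines the
    # AR (next-token) loss and the flow-matching loss internally.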
    for batch in dataloader:
        if global_step >= args.max_steps:
            break

        with accelerator.accumulate(model):
            input_ids = batch["input_ids"]
            latents = batch["latents"].to(accelerator.device, dtype=torch.bfloat16)

            output = model.forward_generation(
                input_ids,
                clean_latents=latents,
                labels=input_ids,
            )

            loss = output["loss"]
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        total_loss += loss.item()
        if output["ar_loss"] is not None:
            total_ar_loss += output["ar_loss"].item()
        if output["flow_loss"] is not None:
            total_flow_loss += output["flow_loss"].item()
        global_step += 1

        if global_step % args.log_every == 0:
            n = args.log_every
            metrics = {
                "loss": total_loss / n,
                "ar_loss": total_ar_loss / n,
                "flow_loss": total_flow_loss / n,
                "lr": scheduler.get_last_lr()[0],
                "step": global_step,
            }
            if accelerator.is_main_process:
                print(f"Step {global_step}/{args.max_steps} | "
                      f"Loss: {total_loss/n:.4f} | "
                      f"AR: {total_ar_loss/n:.4f} | "
                      f"Flow: {total_flow_loss/n:.4f}")
                safe_trackio_log(metrics)
            total_loss = total_ar_loss = total_flow_loss = 0.0

        if global_step % args.save_every == 0 and accelerator.is_main_process:
            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
            os.makedirs(save_path, exist_ok=True)
            unwrapped = accelerator.unwrap_model(model)
            torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
            config.save(os.path.join(save_path, "config.json"))
            print(f"Saved checkpoint to {save_path}")

    if accelerator.is_main_process:
        save_path = os.path.join(args.output_dir, "stage2_final")
        os.makedirs(save_path, exist_ok=True)
        unwrapped = accelerator.unwrap_model(model)
        torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
        config.save(os.path.join(save_path, "config.json"))

        # Push to the Hub; a failed upload shouldn't kill the run.
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            repo_id = f"TinmanLabSL/SmolOmni-MLA-{config.model_variant}"
            api.create_repo(repo_id, exist_ok=True)
            api.upload_folder(
                folder_path=save_path,
                repo_id=repo_id,
                commit_message="Stage 2: Joint AR + flow-matching training",
            )
            print(f"\nStage 2 complete! Pushed to {repo_id}")
        except Exception as e:
            print(f"[WARN] Hub push failed: {e}")


def main():
    parser = argparse.ArgumentParser(description="Tinman-SmolOmni-MLA Training")
    parser.add_argument("--stage", type=int, default=1, choices=[1, 2])
    parser.add_argument("--model_variant", type=str, default="256M", choices=["256M", "500M", "1B"])
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=200)
    parser.add_argument("--max_steps", type=int, default=5000)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--save_every", type=int, default=1000)
    parser.add_argument("--temperature", type=float, default=2.0, help="KD softmax temperature (stage 1)")
    parser.add_argument("--kd_alpha", type=float, default=0.7, help="weight on the KD term vs. hard CE (stage 1)")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{args.model_variant}")

    if args.stage == 1:
        train_stage1(args, config)
    else:
        train_stage2(args, config)


if __name__ == "__main__":
    main()