"""
Tinman-SmolOmni-MLA Training Script
Stage 1: MLA initialization (SVD) + KL distillation from a SmolLM2 teacher (SmolVLM's language backbone)
Stage 2: Joint AR + flow-matching training on image-text pairs
Based on:
- X-EcoMLA: SVD init + KD fine-tuning (3.6B tokens for SmolLM family)
- Show-o2: Dual AR + flow-matching loss
- JanusFlow: Representation alignment (REPA)
Usage:
python train.py --stage 1 --model_variant 256M
python train.py --stage 2 --model_variant 256M --checkpoint stage1_output
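    accelerate launch train.py --stage 1 --model_variant 256M   # optional: multi-GPU / bf16 launch via Accelerate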
"""
import os
import sys
import math
import argparse
import json
import time
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import (
AutoModelForImageTextToText,
AutoProcessor,
AutoModelForCausalLM,
AutoTokenizer,
get_cosine_schedule_with_warmup,
)
# Add smolomni to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from smolomni.config import SmolOmniConfig
from smolomni.model import SmolOmniModel
from smolomni.svd_init import initialize_mla_from_pretrained
import trackio
# Safe trackio wrapper
def safe_trackio_log(metrics):
try:
trackio.log(metrics)
except Exception:
pass
# ===== Stage 1: KL Distillation Dataset =====
class TextDistillationDataset(IterableDataset):
"""Streams text from FineWeb-Edu for KL distillation."""
def __init__(self, tokenizer, max_length=512, max_samples=None):
from datasets import load_dataset
self.dataset = load_dataset(
"HuggingFaceFW/fineweb-edu",
name="CC-MAIN-2024-10", # Use one recent crawl
split="train",
streaming=True,
)
self.tokenizer = tokenizer
self.max_length = max_length
self.max_samples = max_samples
def __iter__(self):
count = 0
for example in self.dataset:
if self.max_samples and count >= self.max_samples:
break
text = example.get("text", "")
if len(text) < 50:
continue
tokens = self.tokenizer(
text,
max_length=self.max_length,
truncation=True,
return_tensors="pt",
padding="max_length",
)
yield {
"input_ids": tokens["input_ids"].squeeze(0),
"attention_mask": tokens["attention_mask"].squeeze(0),
}
count += 1
# ===== Stage 2: Image-Text Dataset =====
class ImageTextDataset(IterableDataset):
"""Streams image-text pairs for joint AR + flow-matching training."""
def __init__(self, tokenizer, vae, max_length=256, image_size=256, max_samples=None):
from datasets import load_dataset
self.dataset = load_dataset(
"HuggingFaceM4/the_cauldron",
name="chartqa", # Start with a manageable subset
split="train",
streaming=True,
)
self.tokenizer = tokenizer
self.vae = vae
self.max_length = max_length
self.image_size = image_size
self.max_samples = max_samples
from torchvision import transforms
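        # Resize and normalize pixels to [-1, 1], the input range SD-style VAEs expect.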
self.transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
def __iter__(self):
count = 0
for example in self.dataset:
if self.max_samples and count >= self.max_samples:
break
try:
# Get text
texts = example.get("texts", [])
if not texts:
continue
text = texts[0].get("user", "") + " " + texts[0].get("assistant", "")
if len(text) < 10:
continue
# Tokenize
tokens = self.tokenizer(
text, max_length=self.max_length, truncation=True,
return_tensors="pt", padding="max_length",
)
# Get image (use dummy latents if image processing fails)
images = example.get("images", [])
if images and images[0] is not None:
try:
from PIL import Image
img = images[0]
if not isinstance(img, Image.Image):
img = Image.open(img).convert("RGB")
else:
img = img.convert("RGB")
img_tensor = self.transform(img).unsqueeze(0)
# Encode with VAE
with torch.no_grad():
latents = self.vae.encode(img_tensor.to(self.vae.device, dtype=self.vae.dtype)).latent_dist.sample()
latents = latents * self.vae.config.scaling_factor
except Exception:
latents = torch.randn(1, 4, self.image_size // 8, self.image_size // 8)
else:
latents = torch.randn(1, 4, self.image_size // 8, self.image_size // 8)
yield {
"input_ids": tokens["input_ids"].squeeze(0),
"attention_mask": tokens["attention_mask"].squeeze(0),
"latents": latents.squeeze(0).cpu(),
}
count += 1
            except Exception:
                # Skip malformed samples rather than aborting the stream
                continue
def train_stage1(args, config):
"""Stage 1: SVD init + KL distillation from teacher model."""
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision="bf16",
)
if accelerator.is_main_process:
try:
trackio.init(
project="SmolOmni-MLA",
name="Stage1-KD",
config=vars(args),
)
except Exception as e:
print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")
set_seed(args.seed)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Create student model with SVD initialization
print("Creating student model with SVD initialization...")
student = SmolOmniModel(config)
student = initialize_mla_from_pretrained(student, config.base_model, config)
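    # SVD initialization (X-EcoMLA): factorize the base model's attention projections
    # into MLA's low-rank latent KV weights so the student starts close to the pretrained model.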
# Load teacher model (frozen)
print("Loading teacher model...")
# SmolVLM-256M uses SmolLM2-135M as backbone
base_lm_map = {
"256M": "HuggingFaceTB/SmolLM2-135M-Instruct",
"500M": "HuggingFaceTB/SmolLM2-360M-Instruct",
}
teacher_name = base_lm_map.get(config.model_variant, "HuggingFaceTB/SmolLM2-135M-Instruct")
try:
teacher = AutoModelForCausalLM.from_pretrained(teacher_name, torch_dtype=torch.bfloat16)
except Exception:
print(f"Warning: Could not load teacher {teacher_name}, using student as teacher (self-distillation)")
teacher = None
if teacher is not None:
teacher.eval()
for p in teacher.parameters():
p.requires_grad = False
# Dataset
dataset = TextDistillationDataset(
tokenizer,
max_length=args.max_length,
max_samples=args.max_train_samples,
)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
# Optimizer
optimizer = torch.optim.AdamW(
student.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay,
betas=(0.9, 0.95),
)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=args.max_steps,
)
# Prepare
student, optimizer, dataloader, scheduler = accelerator.prepare(
student, optimizer, dataloader, scheduler
)
    if teacher is not None:
        # The teacher is frozen, so just move it to the training device;
        # accelerator.prepare would wrap it (e.g. in DDP) unnecessarily.
        teacher = teacher.to(accelerator.device)
# Training loop
student.train()
global_step = 0
total_loss = 0.0
start_time = time.time()
print(f"\n{'='*60}")
print(f"Stage 1: KL Distillation Training")
print(f"Model: {config.model_variant}, Steps: {args.max_steps}")
print(f"Batch size: {args.batch_size} x {args.gradient_accumulation_steps} = {args.batch_size * args.gradient_accumulation_steps}")
print(f"Learning rate: {args.learning_rate}")
print(f"{'='*60}\n")
for batch in dataloader:
if global_step >= args.max_steps:
break
with accelerator.accumulate(student):
input_ids = batch["input_ids"]
# Student forward
student_output = student.forward_understanding(input_ids, labels=input_ids)
student_logits = student_output["logits"]
# Teacher forward
if teacher is not None:
with torch.no_grad():
teacher_output = teacher(input_ids)
teacher_logits = teacher_output.logits
                # Distillation loss: alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE,
                # with both distributions softened by temperature T.
                # Truncate logits to the shared vocabulary BEFORE the softmax so the
                # compared distributions stay properly normalized.
                T = args.temperature
                min_vocab = min(student_logits.shape[-1], teacher_logits.shape[-1])
                student_log_probs = F.log_softmax(student_logits[..., :min_vocab] / T, dim=-1)
                teacher_probs = F.softmax(teacher_logits[..., :min_vocab] / T, dim=-1)
                kd_loss = F.kl_div(
                    student_log_probs,
                    teacher_probs,
                    reduction="batchmean",
                ) * (T * T)
                # Combined loss: KD term plus the student's own language-modeling loss
                alpha = args.kd_alpha
                loss = alpha * kd_loss + (1 - alpha) * student_output["loss"]
else:
loss = student_output["loss"]
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(student.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
total_loss += loss.item()
global_step += 1
if global_step % args.log_every == 0:
avg_loss = total_loss / args.log_every
elapsed = time.time() - start_time
steps_per_sec = global_step / elapsed
metrics = {
"loss": avg_loss,
"lr": scheduler.get_last_lr()[0],
"steps_per_sec": steps_per_sec,
"step": global_step,
}
if accelerator.is_main_process:
print(f"Step {global_step}/{args.max_steps} | Loss: {avg_loss:.4f} | "
f"LR: {scheduler.get_last_lr()[0]:.2e} | "
f"Speed: {steps_per_sec:.1f} steps/s")
safe_trackio_log(metrics)
total_loss = 0.0
if global_step % args.save_every == 0 and accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
os.makedirs(save_path, exist_ok=True)
unwrapped = accelerator.unwrap_model(student)
torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
config.save(os.path.join(save_path, "config.json"))
print(f"Saved checkpoint to {save_path}")
# Save final
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, "stage1_final")
os.makedirs(save_path, exist_ok=True)
unwrapped = accelerator.unwrap_model(student)
torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
config.save(os.path.join(save_path, "config.json"))
print(f"\nStage 1 complete! Model saved to {save_path}")
# Push to Hub
from huggingface_hub import HfApi
api = HfApi()
api.upload_folder(
folder_path=save_path,
repo_id=f"TinmanLabSL/SmolOmni-MLA-{config.model_variant}",
commit_message="Stage 1: SVD init + KL distillation",
)
print(f"Pushed to TinmanLabSL/SmolOmni-MLA-{config.model_variant}")
def train_stage2(args, config):
"""Stage 2: Joint AR + flow-matching training."""
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision="bf16",
)
if accelerator.is_main_process:
try:
trackio.init(
project="SmolOmni-MLA",
name="Stage2-Joint",
config=vars(args),
)
except Exception as e:
print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")
set_seed(args.seed)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load VAE for image encoding
from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained(
config.flow_head.vae_model,
torch_dtype=torch.bfloat16
)
vae.eval()
for p in vae.parameters():
p.requires_grad = False
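    # The frozen VAE is only used to encode target latents for the flow-matching loss.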
# Load model from Stage 1 checkpoint
model = SmolOmniModel(config)
if args.checkpoint:
ckpt_path = os.path.join(args.checkpoint, "model.pt")
if os.path.exists(ckpt_path):
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state, strict=False)
print(f"Loaded Stage 1 checkpoint from {ckpt_path}")
else:
print("No Stage 1 checkpoint found, training from scratch")
model = initialize_mla_from_pretrained(model, config.base_model, config)
else:
model = initialize_mla_from_pretrained(model, config.base_model, config)
# Cast to bf16 AFTER loading checkpoint (ckpt weights may be fp32)
model = model.to(torch.bfloat16)
print("Model cast to bfloat16")
# Dataset
dataset = ImageTextDataset(
tokenizer, vae,
max_length=args.max_length,
image_size=config.flow_head.gen_resolution,
max_samples=args.max_train_samples,
)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
# Optimizer (separate LR for flow head)
backbone_params = []
flow_params = []
for name, param in model.named_parameters():
if "flow_head" in name or "gen_image_encoder" in name:
flow_params.append(param)
else:
backbone_params.append(param)
optimizer = torch.optim.AdamW([
{"params": backbone_params, "lr": args.learning_rate},
{"params": flow_params, "lr": args.learning_rate * 3}, # Higher LR for new flow head
], weight_decay=args.weight_decay, betas=(0.9, 0.95))
scheduler = get_cosine_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.max_steps,
)
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )
    # The frozen VAE has no trainable parameters, so just move it to the training
    # device rather than wrapping it with accelerator.prepare().
    vae = vae.to(accelerator.device)
model.train()
global_step = 0
total_loss = 0.0
total_ar_loss = 0.0
total_flow_loss = 0.0
start_time = time.time()
print(f"\n{'='*60}")
print(f"Stage 2: Joint AR + Flow-Matching Training")
print(f"Model: {config.model_variant}, Steps: {args.max_steps}")
print(f"{'='*60}\n")
for batch in dataloader:
if global_step >= args.max_steps:
break
with accelerator.accumulate(model):
input_ids = batch["input_ids"]
latents = batch["latents"].to(accelerator.device, dtype=torch.bfloat16)
            # Forward pass: joint objective = AR next-token loss on text + flow-matching loss on image latents
output = model.forward_generation(
input_ids,
clean_latents=latents,
labels=input_ids,
)
loss = output["loss"]
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
total_loss += loss.item()
if output["ar_loss"] is not None:
total_ar_loss += output["ar_loss"].item()
total_flow_loss += output["flow_loss"].item()
global_step += 1
if global_step % args.log_every == 0:
n = args.log_every
metrics = {
"loss": total_loss / n,
"ar_loss": total_ar_loss / n,
"flow_loss": total_flow_loss / n,
"lr": scheduler.get_last_lr()[0],
"step": global_step,
}
if accelerator.is_main_process:
print(f"Step {global_step}/{args.max_steps} | "
f"Loss: {total_loss/n:.4f} | "
f"AR: {total_ar_loss/n:.4f} | "
f"Flow: {total_flow_loss/n:.4f}")
safe_trackio_log(metrics)
total_loss = total_ar_loss = total_flow_loss = 0.0
if global_step % args.save_every == 0 and accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
os.makedirs(save_path, exist_ok=True)
unwrapped = accelerator.unwrap_model(model)
torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
config.save(os.path.join(save_path, "config.json"))
# Final save + push
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, "stage2_final")
os.makedirs(save_path, exist_ok=True)
unwrapped = accelerator.unwrap_model(model)
torch.save(unwrapped.state_dict(), os.path.join(save_path, "model.pt"))
config.save(os.path.join(save_path, "config.json"))
from huggingface_hub import HfApi
api = HfApi()
api.upload_folder(
folder_path=save_path,
repo_id=f"TinmanLabSL/SmolOmni-MLA-{config.model_variant}",
commit_message="Stage 2: Joint AR + flow-matching training",
)
print(f"\nStage 2 complete! Pushed to TinmanLabSL/SmolOmni-MLA-{config.model_variant}")
def main():
parser = argparse.ArgumentParser(description="Tinman-SmolOmni-MLA Training")
parser.add_argument("--stage", type=int, default=1, choices=[1, 2])
parser.add_argument("--model_variant", type=str, default="256M", choices=["256M", "500M", "1B"])
parser.add_argument("--checkpoint", type=str, default=None)
parser.add_argument("--output_dir", type=str, default="./output")
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--warmup_steps", type=int, default=200)
parser.add_argument("--max_steps", type=int, default=5000)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--max_train_samples", type=int, default=None)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--log_every", type=int, default=10)
parser.add_argument("--save_every", type=int, default=1000)
parser.add_argument("--temperature", type=float, default=2.0)
parser.add_argument("--kd_alpha", type=float, default=0.7)
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# Build config
config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{args.model_variant}")
if args.stage == 1:
train_stage1(args, config)
else:
train_stage2(args, config)
if __name__ == "__main__":
main()