""" LiRA Training Script - Ready for Colab/Kaggle This script trains LiRA from scratch on any text-image dataset. Designed to be Colab-friendly: works on a single GPU with 16GB VRAM. Usage: # Quick test (CIFAR-like, no text) python train.py --test_mode # Train on a real dataset python train.py --dataset_name "lambdalabs/naruto-blip-captions" \ --model_config tiny --resolution 256 --batch_size 8 """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset import math import os import sys import argparse import time import json from pathlib import Path sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from lira.model import LiRAModel, LiRAPipeline, estimate_memory_mb from lira.training import ( FlowMatchingScheduler, EMAModel, compute_loss, LiRATrainingConfig, FlowDPMSolver, get_lr_scheduler ) class SyntheticDataset(Dataset): """Synthetic dataset for architecture testing - generates random latents + text""" def __init__(self, num_samples=1000, latent_channels=4, latent_size=32, text_dim=768, text_len=77): self.num_samples = num_samples self.latent_channels = latent_channels self.latent_size = latent_size self.text_dim = text_dim self.text_len = text_len def __len__(self): return self.num_samples def __getitem__(self, idx): # Generate structured patterns (not just noise) for meaningful learning torch.manual_seed(idx) # Create latent with spatial structure z = torch.randn(self.latent_channels, self.latent_size, self.latent_size) # Add some structure: low-frequency patterns freq = torch.randn(self.latent_channels, 4, 4) z = z + F.interpolate(freq.unsqueeze(0), size=self.latent_size, mode='bilinear', align_corners=False).squeeze(0) * 2 # Text features (random but consistent per sample) text_features = torch.randn(self.text_len, self.text_dim) * 0.1 text_mask = torch.ones(self.text_len, dtype=torch.bool) return { 'latent': z, 'text_features': text_features, 'text_mask': text_mask, } def train(config: LiRATrainingConfig): """Main training loop""" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"šŸ”§ Device: {device}") # Create model model = LiRAModel( config_name=config.model_config, in_channels=config.latent_channels, d_text=config.d_text, patch_size=config.patch_size, ).to(device) counts = model.count_parameters() print(f"\nšŸ—ļø Model: LiRA-{config.model_config.capitalize()}") print(f" Parameters: {counts['total']/1e6:.1f}M") print(f" Model size (fp16): {counts['total'] * 2 / (1024**2):.0f}MB") # Optimizer optimizer = torch.optim.AdamW( model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay, betas=(0.9, 0.999), ) # LR scheduler lr_scheduler = get_lr_scheduler(optimizer, config) # EMA ema = EMAModel(model, decay=config.ema_decay) # Flow matching scheduler noise_scheduler = FlowMatchingScheduler(schedule=config.noise_schedule) # Dataset latent_size = config.progressive_stages[0]['resolution'] // config.spatial_compression if config.patch_size > 1: latent_size = latent_size # Patchification happens inside model dataset = SyntheticDataset( num_samples=min(10000, config.max_steps * config.batch_size), latent_channels=config.latent_channels, latent_size=latent_size, text_dim=config.d_text, ) dataloader = DataLoader( dataset, batch_size=config.batch_size, shuffle=True, num_workers=0, # 0 for Colab compatibility drop_last=True, ) # Mixed precision use_amp = config.mixed_precision != 'no' and device.type == 'cuda' scaler = torch.amp.GradScaler(enabled=use_amp and config.mixed_precision == 'fp16') amp_dtype = torch.bfloat16 if config.mixed_precision == 'bf16' else torch.float16 # Training loop print(f"\nšŸš€ Starting training...") print(f" Steps: {config.max_steps}") print(f" Batch size: {config.batch_size}") print(f" Learning rate: {config.learning_rate}") print(f" Noise schedule: {config.noise_schedule}") print(f" Mixed precision: {config.mixed_precision}") os.makedirs(config.output_dir, exist_ok=True) global_step = 0 epoch = 0 losses = [] start_time = time.time() model.train() while global_step < config.max_steps: epoch += 1 for batch in dataloader: if global_step >= config.max_steps: break z_0 = batch['latent'].to(device) text_features = batch['text_features'].to(device) text_mask = batch['text_mask'].to(device) # Forward + backward with mixed precision optimizer.zero_grad(set_to_none=True) if use_amp: with torch.amp.autocast(device_type=device.type, dtype=amp_dtype): loss, info = compute_loss( model, z_0, text_features, noise_scheduler, config, global_step=global_step, text_mask=text_mask, ) scaler.scale(loss).backward() scaler.unscale_(optimizer) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) scaler.step(optimizer) scaler.update() else: loss, info = compute_loss( model, z_0, text_features, noise_scheduler, config, global_step=global_step, text_mask=text_mask, ) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) optimizer.step() lr_scheduler.step() ema.update(model) losses.append(info['loss']) global_step += 1 # Logging if global_step % config.log_every == 0 or global_step == 1: avg_loss = sum(losses[-100:]) / len(losses[-100:]) elapsed = time.time() - start_time steps_per_sec = global_step / elapsed lr = optimizer.param_groups[0]['lr'] print(f" Step {global_step}/{config.max_steps} | " f"loss={avg_loss:.4f} | " f"mse={info['mse_loss']:.4f} | " f"reason_steps={info['reason_steps']} | " f"grad={grad_norm:.3f} | " f"lr={lr:.2e} | " f"speed={steps_per_sec:.1f} steps/s") # Save checkpoint if global_step % config.save_every == 0: save_path = os.path.join(config.output_dir, f'checkpoint-{global_step}.pt') torch.save({ 'step': global_step, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'ema_state_dict': ema.state_dict(), 'config': vars(config), 'losses': losses[-1000:], }, save_path) print(f" šŸ’¾ Saved checkpoint: {save_path}") # Final save save_path = os.path.join(config.output_dir, 'final_model.pt') torch.save({ 'step': global_step, 'model_state_dict': model.state_dict(), 'ema_state_dict': ema.state_dict(), 'config': vars(config), }, save_path) elapsed = time.time() - start_time print(f"\nāœ… Training complete!") print(f" Total steps: {global_step}") print(f" Final loss: {sum(losses[-100:])/len(losses[-100:]):.4f}") print(f" Total time: {elapsed:.0f}s ({elapsed/60:.1f}min)") print(f" Saved to: {save_path}") return model, ema def main(): parser = argparse.ArgumentParser(description='Train LiRA') parser.add_argument('--test_mode', action='store_true', help='Quick test with synthetic data') parser.add_argument('--model_config', type=str, default='tiny') parser.add_argument('--resolution', type=int, default=256) parser.add_argument('--batch_size', type=int, default=4) parser.add_argument('--max_steps', type=int, default=1000) parser.add_argument('--learning_rate', type=float, default=1e-4) parser.add_argument('--output_dir', type=str, default='./lira_output') parser.add_argument('--dataset_name', type=str, default='') args = parser.parse_args() if args.test_mode: config = LiRATrainingConfig( model_config='tiny', latent_channels=4, spatial_compression=8, d_text=768, patch_size=2, batch_size=2, learning_rate=1e-4, max_steps=50, warmup_steps=5, log_every=10, save_every=25, noise_schedule='laplace', use_curriculum=True, curriculum_warmup=20, output_dir=args.output_dir, ) else: spatial_compression = 8 # Default f8 VAE config = LiRATrainingConfig( model_config=args.model_config, latent_channels=4, spatial_compression=spatial_compression, d_text=768, patch_size=2, batch_size=args.batch_size, learning_rate=args.learning_rate, max_steps=args.max_steps, output_dir=args.output_dir, dataset_name=args.dataset_name, ) train(config) if __name__ == '__main__': main()