| """ |
| LiRA Training Script - Ready for Colab/Kaggle |
| |
| This script trains LiRA from scratch on any text-image dataset. |
| Designed to be Colab-friendly: works on a single GPU with 16GB VRAM. |
| |
| Usage: |
| # Quick test (CIFAR-like, no text) |
| python train.py --test_mode |
| |
| # Train on a real dataset |
| python train.py --dataset_name "lambdalabs/naruto-blip-captions" \ |
| --model_config tiny --resolution 256 --batch_size 8 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| import math |
| import os |
| import sys |
| import argparse |
| import time |
| import json |
| from pathlib import Path |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from lira.model import LiRAModel, LiRAPipeline, estimate_memory_mb |
| from lira.training import ( |
| FlowMatchingScheduler, EMAModel, compute_loss, |
| LiRATrainingConfig, FlowDPMSolver, get_lr_scheduler |
| ) |
|
|
|
|
| class SyntheticDataset(Dataset): |
| """Synthetic dataset for architecture testing - generates random latents + text""" |
| |
| def __init__(self, num_samples=1000, latent_channels=4, latent_size=32, |
| text_dim=768, text_len=77): |
| self.num_samples = num_samples |
| self.latent_channels = latent_channels |
| self.latent_size = latent_size |
| self.text_dim = text_dim |
| self.text_len = text_len |
| |
| def __len__(self): |
| return self.num_samples |
| |
| def __getitem__(self, idx): |
| |
| torch.manual_seed(idx) |
| |
| |
| z = torch.randn(self.latent_channels, self.latent_size, self.latent_size) |
| |
| freq = torch.randn(self.latent_channels, 4, 4) |
| z = z + F.interpolate(freq.unsqueeze(0), size=self.latent_size, |
| mode='bilinear', align_corners=False).squeeze(0) * 2 |
| |
| |
| text_features = torch.randn(self.text_len, self.text_dim) * 0.1 |
| text_mask = torch.ones(self.text_len, dtype=torch.bool) |
| |
| return { |
| 'latent': z, |
| 'text_features': text_features, |
| 'text_mask': text_mask, |
| } |
|
|
|
|
| def train(config: LiRATrainingConfig): |
| """Main training loop""" |
| |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| print(f"🔧 Device: {device}") |
| |
| |
| model = LiRAModel( |
| config_name=config.model_config, |
| in_channels=config.latent_channels, |
| d_text=config.d_text, |
| patch_size=config.patch_size, |
| ).to(device) |
| |
| counts = model.count_parameters() |
| print(f"\n🏗️ Model: LiRA-{config.model_config.capitalize()}") |
| print(f" Parameters: {counts['total']/1e6:.1f}M") |
| print(f" Model size (fp16): {counts['total'] * 2 / (1024**2):.0f}MB") |
| |
| |
| optimizer = torch.optim.AdamW( |
| model.parameters(), |
| lr=config.learning_rate, |
| weight_decay=config.weight_decay, |
| betas=(0.9, 0.999), |
| ) |
| |
| |
| lr_scheduler = get_lr_scheduler(optimizer, config) |
| |
| |
| ema = EMAModel(model, decay=config.ema_decay) |
| |
| |
| noise_scheduler = FlowMatchingScheduler(schedule=config.noise_schedule) |
| |
| |
| latent_size = config.progressive_stages[0]['resolution'] // config.spatial_compression |
| if config.patch_size > 1: |
| latent_size = latent_size |
| |
| dataset = SyntheticDataset( |
| num_samples=min(10000, config.max_steps * config.batch_size), |
| latent_channels=config.latent_channels, |
| latent_size=latent_size, |
| text_dim=config.d_text, |
| ) |
| |
| dataloader = DataLoader( |
| dataset, |
| batch_size=config.batch_size, |
| shuffle=True, |
| num_workers=0, |
| drop_last=True, |
| ) |
| |
| |
| use_amp = config.mixed_precision != 'no' and device.type == 'cuda' |
| scaler = torch.amp.GradScaler(enabled=use_amp and config.mixed_precision == 'fp16') |
| amp_dtype = torch.bfloat16 if config.mixed_precision == 'bf16' else torch.float16 |
| |
| |
| print(f"\n🚀 Starting training...") |
| print(f" Steps: {config.max_steps}") |
| print(f" Batch size: {config.batch_size}") |
| print(f" Learning rate: {config.learning_rate}") |
| print(f" Noise schedule: {config.noise_schedule}") |
| print(f" Mixed precision: {config.mixed_precision}") |
| |
| os.makedirs(config.output_dir, exist_ok=True) |
| |
| global_step = 0 |
| epoch = 0 |
| losses = [] |
| start_time = time.time() |
| |
| model.train() |
| |
| while global_step < config.max_steps: |
| epoch += 1 |
| for batch in dataloader: |
| if global_step >= config.max_steps: |
| break |
| |
| z_0 = batch['latent'].to(device) |
| text_features = batch['text_features'].to(device) |
| text_mask = batch['text_mask'].to(device) |
| |
| |
| optimizer.zero_grad(set_to_none=True) |
| |
| if use_amp: |
| with torch.amp.autocast(device_type=device.type, dtype=amp_dtype): |
| loss, info = compute_loss( |
| model, z_0, text_features, noise_scheduler, config, |
| global_step=global_step, text_mask=text_mask, |
| ) |
| scaler.scale(loss).backward() |
| scaler.unscale_(optimizer) |
| grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) |
| scaler.step(optimizer) |
| scaler.update() |
| else: |
| loss, info = compute_loss( |
| model, z_0, text_features, noise_scheduler, config, |
| global_step=global_step, text_mask=text_mask, |
| ) |
| loss.backward() |
| grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) |
| optimizer.step() |
| |
| lr_scheduler.step() |
| ema.update(model) |
| |
| losses.append(info['loss']) |
| global_step += 1 |
| |
| |
| if global_step % config.log_every == 0 or global_step == 1: |
| avg_loss = sum(losses[-100:]) / len(losses[-100:]) |
| elapsed = time.time() - start_time |
| steps_per_sec = global_step / elapsed |
| lr = optimizer.param_groups[0]['lr'] |
| |
| print(f" Step {global_step}/{config.max_steps} | " |
| f"loss={avg_loss:.4f} | " |
| f"mse={info['mse_loss']:.4f} | " |
| f"reason_steps={info['reason_steps']} | " |
| f"grad={grad_norm:.3f} | " |
| f"lr={lr:.2e} | " |
| f"speed={steps_per_sec:.1f} steps/s") |
| |
| |
| if global_step % config.save_every == 0: |
| save_path = os.path.join(config.output_dir, f'checkpoint-{global_step}.pt') |
| torch.save({ |
| 'step': global_step, |
| 'model_state_dict': model.state_dict(), |
| 'optimizer_state_dict': optimizer.state_dict(), |
| 'ema_state_dict': ema.state_dict(), |
| 'config': vars(config), |
| 'losses': losses[-1000:], |
| }, save_path) |
| print(f" 💾 Saved checkpoint: {save_path}") |
| |
| |
| save_path = os.path.join(config.output_dir, 'final_model.pt') |
| torch.save({ |
| 'step': global_step, |
| 'model_state_dict': model.state_dict(), |
| 'ema_state_dict': ema.state_dict(), |
| 'config': vars(config), |
| }, save_path) |
| |
| elapsed = time.time() - start_time |
| print(f"\n✅ Training complete!") |
| print(f" Total steps: {global_step}") |
| print(f" Final loss: {sum(losses[-100:])/len(losses[-100:]):.4f}") |
| print(f" Total time: {elapsed:.0f}s ({elapsed/60:.1f}min)") |
| print(f" Saved to: {save_path}") |
| |
| return model, ema |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description='Train LiRA') |
| parser.add_argument('--test_mode', action='store_true', help='Quick test with synthetic data') |
| parser.add_argument('--model_config', type=str, default='tiny') |
| parser.add_argument('--resolution', type=int, default=256) |
| parser.add_argument('--batch_size', type=int, default=4) |
| parser.add_argument('--max_steps', type=int, default=1000) |
| parser.add_argument('--learning_rate', type=float, default=1e-4) |
| parser.add_argument('--output_dir', type=str, default='./lira_output') |
| parser.add_argument('--dataset_name', type=str, default='') |
| args = parser.parse_args() |
| |
| if args.test_mode: |
| config = LiRATrainingConfig( |
| model_config='tiny', |
| latent_channels=4, |
| spatial_compression=8, |
| d_text=768, |
| patch_size=2, |
| batch_size=2, |
| learning_rate=1e-4, |
| max_steps=50, |
| warmup_steps=5, |
| log_every=10, |
| save_every=25, |
| noise_schedule='laplace', |
| use_curriculum=True, |
| curriculum_warmup=20, |
| output_dir=args.output_dir, |
| ) |
| else: |
| spatial_compression = 8 |
| config = LiRATrainingConfig( |
| model_config=args.model_config, |
| latent_channels=4, |
| spatial_compression=spatial_compression, |
| d_text=768, |
| patch_size=2, |
| batch_size=args.batch_size, |
| learning_rate=args.learning_rate, |
| max_steps=args.max_steps, |
| output_dir=args.output_dir, |
| dataset_name=args.dataset_name, |
| ) |
| |
| train(config) |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|