asdf98 committed on
Commit
1c7e629
·
verified ·
1 Parent(s): 85accf4

Add train.py

Browse files
Files changed (1) hide show
  1. train.py +286 -0
train.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiRA Training Script - Ready for Colab/Kaggle
3
+
4
+ This script trains LiRA from scratch on any text-image dataset.
5
+ Designed to be Colab-friendly: works on a single GPU with 16GB VRAM.
6
+
7
+ Usage:
8
+ # Quick test (synthetic latents + random text features, short run)
9
+ python train.py --test_mode
10
+
11
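+ # Train with a dataset name (NOTE: train() currently always builds SyntheticDataset; --dataset_name is stored in the config but not yet consumed)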
12
+ python train.py --dataset_name "lambdalabs/naruto-blip-captions" \
13
+ --model_config tiny --resolution 256 --batch_size 8
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.utils.data import DataLoader, Dataset
20
+ import math
21
+ import os
22
+ import sys
23
+ import argparse
24
+ import time
25
+ import json
26
+ from pathlib import Path
27
+
28
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
29
+
30
+ from lira.model import LiRAModel, LiRAPipeline, estimate_memory_mb
31
+ from lira.training import (
32
+ FlowMatchingScheduler, EMAModel, compute_loss,
33
+ LiRATrainingConfig, FlowDPMSolver, get_lr_scheduler
34
+ )
35
+
36
+
37
class SyntheticDataset(Dataset):
    """Synthetic dataset for architecture testing.

    Each item is a dict with:
      * ``'latent'``: random tensor of shape
        ``(latent_channels, latent_size, latent_size)`` with an added
        low-frequency component,
      * ``'text_features'``: random tensor of shape ``(text_len, text_dim)``
        scaled by 0.1,
      * ``'text_mask'``: all-True boolean tensor of shape ``(text_len,)``.

    Items are a deterministic function of their index.
    """

    def __init__(self, num_samples=1000, latent_channels=4, latent_size=32,
                 text_dim=768, text_len=77):
        self.num_samples = num_samples
        self.latent_channels = latent_channels
        self.latent_size = latent_size
        self.text_dim = text_dim
        self.text_len = text_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Seed a *private* generator with the index instead of calling
        # torch.manual_seed(): reseeding the global RNG on every sample fetch
        # would make every later global-RNG draw in the training step depend
        # on the last index loaded.
        gen = torch.Generator()
        gen.manual_seed(int(idx))

        # Base latent: white noise.
        z = torch.randn(
            self.latent_channels, self.latent_size, self.latent_size,
            generator=gen,
        )
        # Add structure: a 4x4 random grid upsampled to the latent size gives
        # low-frequency patterns (not just noise) for more meaningful learning.
        freq = torch.randn(self.latent_channels, 4, 4, generator=gen)
        low_freq = F.interpolate(
            freq.unsqueeze(0), size=self.latent_size,
            mode='bilinear', align_corners=False,
        ).squeeze(0)
        z = z + low_freq * 2

        # Text features (random but consistent per sample).
        text_features = torch.randn(
            self.text_len, self.text_dim, generator=gen,
        ) * 0.1
        text_mask = torch.ones(self.text_len, dtype=torch.bool)

        return {
            'latent': z,
            'text_features': text_features,
            'text_mask': text_mask,
        }
71
+
72
+
73
def _mean_of_recent(values, window=100):
    """Return the mean of the last ``window`` entries, or NaN if ``values`` is empty."""
    recent = values[-window:]
    if not recent:
        return float('nan')
    return sum(recent) / len(recent)


def train(config: LiRATrainingConfig):
    """Main training loop.

    Builds the model, optimizer, LR scheduler, EMA tracker and flow-matching
    scheduler from ``config``, trains on ``SyntheticDataset`` for
    ``config.max_steps`` optimizer steps, writes periodic checkpoints plus a
    final ``final_model.pt`` into ``config.output_dir``.

    Args:
        config: training configuration; the attributes read here are
            model_config, latent_channels, d_text, patch_size, learning_rate,
            weight_decay, ema_decay, noise_schedule, progressive_stages,
            spatial_compression, max_steps, batch_size, mixed_precision,
            output_dir, grad_clip, log_every and save_every.

    Returns:
        ``(model, ema)``: the trained model and its EMA tracker.

    Raises:
        ValueError: if steps were requested but the data loader would yield
            no batches (which would otherwise loop forever).
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🔧 Device: {device}")

    # Create model
    model = LiRAModel(
        config_name=config.model_config,
        in_channels=config.latent_channels,
        d_text=config.d_text,
        patch_size=config.patch_size,
    ).to(device)

    counts = model.count_parameters()
    print(f"\n🏗️ Model: LiRA-{config.model_config.capitalize()}")
    print(f" Parameters: {counts['total']/1e6:.1f}M")
    print(f" Model size (fp16): {counts['total'] * 2 / (1024**2):.0f}MB")

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.999),
    )

    # LR scheduler
    lr_scheduler = get_lr_scheduler(optimizer, config)

    # EMA
    ema = EMAModel(model, decay=config.ema_decay)

    # Flow matching scheduler
    noise_scheduler = FlowMatchingScheduler(schedule=config.noise_schedule)

    # Dataset. The latent side length is the first stage's pixel resolution
    # divided by the VAE compression factor; patchification happens inside the
    # model, so patch_size does not change the dataset's latent size.
    latent_size = config.progressive_stages[0]['resolution'] // config.spatial_compression

    # This loop only knows how to build SyntheticDataset; say so instead of
    # silently ignoring a requested dataset.
    requested_dataset = getattr(config, 'dataset_name', '')
    if requested_dataset:
        print(f"⚠️ dataset_name={requested_dataset!r} is ignored: "
              f"training currently uses SyntheticDataset only.")

    dataset = SyntheticDataset(
        num_samples=min(10000, config.max_steps * config.batch_size),
        latent_channels=config.latent_channels,
        latent_size=latent_size,
        text_dim=config.d_text,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,  # 0 for Colab compatibility
        drop_last=True,
    )

    # With drop_last=True a dataset smaller than one batch yields nothing, and
    # the outer while-loop below would then never make progress.
    if config.max_steps > 0 and len(dataloader) == 0:
        raise ValueError(
            f"DataLoader yields no batches: dataset has {len(dataset)} samples "
            f"but batch_size={config.batch_size} with drop_last=True."
        )

    # Mixed precision. The GradScaler is only active for fp16; for bf16 it is
    # constructed disabled, so its calls below pass straight through.
    use_amp = config.mixed_precision != 'no' and device.type == 'cuda'
    scaler = torch.amp.GradScaler(enabled=use_amp and config.mixed_precision == 'fp16')
    amp_dtype = torch.bfloat16 if config.mixed_precision == 'bf16' else torch.float16

    # Training loop
    print(f"\n🚀 Starting training...")
    print(f" Steps: {config.max_steps}")
    print(f" Batch size: {config.batch_size}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" Noise schedule: {config.noise_schedule}")
    print(f" Mixed precision: {config.mixed_precision}")

    os.makedirs(config.output_dir, exist_ok=True)

    global_step = 0
    epoch = 0
    losses = []
    start_time = time.time()

    model.train()

    while global_step < config.max_steps:
        epoch += 1
        for batch in dataloader:
            if global_step >= config.max_steps:
                break

            z_0 = batch['latent'].to(device)
            text_features = batch['text_features'].to(device)
            text_mask = batch['text_mask'].to(device)

            # Forward + backward with mixed precision
            optimizer.zero_grad(set_to_none=True)

            if use_amp:
                with torch.amp.autocast(device_type=device.type, dtype=amp_dtype):
                    loss, info = compute_loss(
                        model, z_0, text_features, noise_scheduler, config,
                        global_step=global_step, text_mask=text_mask,
                    )
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the
                # true gradient magnitudes.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss, info = compute_loss(
                    model, z_0, text_features, noise_scheduler, config,
                    global_step=global_step, text_mask=text_mask,
                )
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                optimizer.step()

            lr_scheduler.step()
            ema.update(model)

            losses.append(info['loss'])
            global_step += 1

            # Logging
            if global_step % config.log_every == 0 or global_step == 1:
                avg_loss = _mean_of_recent(losses)
                elapsed = time.time() - start_time
                # Guard against a zero reading from a coarse clock.
                steps_per_sec = global_step / max(elapsed, 1e-9)
                lr = optimizer.param_groups[0]['lr']

                print(f" Step {global_step}/{config.max_steps} | "
                      f"loss={avg_loss:.4f} | "
                      f"mse={info['mse_loss']:.4f} | "
                      f"reason_steps={info['reason_steps']} | "
                      f"grad={grad_norm:.3f} | "
                      f"lr={lr:.2e} | "
                      f"speed={steps_per_sec:.1f} steps/s")

            # Save checkpoint
            if global_step % config.save_every == 0:
                save_path = os.path.join(config.output_dir, f'checkpoint-{global_step}.pt')
                torch.save({
                    'step': global_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'ema_state_dict': ema.state_dict(),
                    'config': vars(config),
                    'losses': losses[-1000:],
                }, save_path)
                print(f" 💾 Saved checkpoint: {save_path}")

    # Final save
    save_path = os.path.join(config.output_dir, 'final_model.pt')
    torch.save({
        'step': global_step,
        'model_state_dict': model.state_dict(),
        'ema_state_dict': ema.state_dict(),
        'config': vars(config),
    }, save_path)

    elapsed = time.time() - start_time
    print(f"\n✅ Training complete!")
    print(f" Total steps: {global_step}")
    # _mean_of_recent returns NaN instead of dividing by zero when no step ran.
    print(f" Final loss: {_mean_of_recent(losses):.4f}")
    print(f" Total time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
    print(f" Saved to: {save_path}")

    return model, ema
235
+
236
+
237
def main(argv=None):
    """Parse command-line options, build a LiRATrainingConfig and run train().

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    With ``--test_mode`` a fixed small configuration is used and every other
    option except ``--output_dir`` is ignored. Otherwise the configuration is
    built from the command-line options.
    """
    parser = argparse.ArgumentParser(description='Train LiRA')
    parser.add_argument('--test_mode', action='store_true', help='Quick test with synthetic data')
    parser.add_argument('--model_config', type=str, default='tiny')
    # Default is None so that omitting the flag keeps the resolution defined
    # by the config itself.
    parser.add_argument('--resolution', type=int, default=None,
                        help="Training resolution in pixels for the first progressive stage "
                             "(default: the config's own value)")
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--max_steps', type=int, default=1000)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--output_dir', type=str, default='./lira_output')
    parser.add_argument('--dataset_name', type=str, default='')
    args = parser.parse_args(argv)

    if args.test_mode:
        config = LiRATrainingConfig(
            model_config='tiny',
            latent_channels=4,
            spatial_compression=8,
            d_text=768,
            patch_size=2,
            batch_size=2,
            learning_rate=1e-4,
            max_steps=50,
            warmup_steps=5,
            log_every=10,
            save_every=25,
            noise_schedule='laplace',
            use_curriculum=True,
            curriculum_warmup=20,
            output_dir=args.output_dir,
        )
    else:
        spatial_compression = 8  # Default f8 VAE

        # train() floor-divides the resolution by the compression factor, so
        # reject values that would be silently truncated.
        if args.resolution is not None and (
            args.resolution <= 0 or args.resolution % spatial_compression != 0
        ):
            parser.error(
                f"--resolution must be a positive multiple of {spatial_compression}, "
                f"got {args.resolution}"
            )

        config = LiRATrainingConfig(
            model_config=args.model_config,
            latent_channels=4,
            spatial_compression=spatial_compression,
            d_text=768,
            patch_size=2,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_steps=args.max_steps,
            output_dir=args.output_dir,
            dataset_name=args.dataset_name,
        )

        if args.resolution is not None:
            # train() reads config.progressive_stages[0]['resolution'], so
            # that is where the requested resolution has to go.
            # NOTE(review): assumes progressive_stages is a list of dicts and
            # that the config attribute is assignable; confirm against
            # LiRATrainingConfig. Stages are copied so no shared default
            # object is mutated.
            stages = [dict(stage) for stage in config.progressive_stages]
            stages[0]['resolution'] = args.resolution
            config.progressive_stages = stages

    train(config)


if __name__ == '__main__':
    main()