krystv committed
Commit 992d967 · verified · 1 Parent(s): 8589a61

Upload liquid_diffusion/trainer.py

Files changed (1)
  1. liquid_diffusion/trainer.py +88 -20
liquid_diffusion/trainer.py CHANGED
@@ -29,20 +29,30 @@ from torchvision.utils import save_image, make_grid
 
 
 class RectifiedFlowTrainer:
-    """Trainer for LiquidDiffusion using Rectified Flow objective."""
+    """Trainer for LiquidDiffusion using Rectified Flow objective.
+
+    Features:
+    - Simple MSE velocity loss (no noise schedule to tune)
+    - Optional logit-normal time sampling (from SD3)
+    - EMA model for stable sampling
+    - Gradient clipping, mixed precision
+    - Checkpoint save/load with resume support
+    """
 
     def __init__(self, model, optimizer=None, lr=1e-4, weight_decay=0.01,
                  ema_decay=0.9999, grad_clip=1.0, time_sampling="logit_normal",
                  logit_normal_mean=0.0, logit_normal_std=1.0, device="cuda",
                  use_amp=True, amp_dtype="float16"):
-        self.model = model.to(device)
         self.device = device
+        self.model = model.to(device)
         self.ema_decay = ema_decay
         self.grad_clip = grad_clip
         self.time_sampling = time_sampling
         self.logit_normal_mean = logit_normal_mean
         self.logit_normal_std = logit_normal_std
-        self.use_amp = use_amp and device == "cuda"
+
+        # AMP only on CUDA
+        self.use_amp = use_amp and (device == "cuda")
         self.amp_dtype = getattr(torch, amp_dtype) if self.use_amp else torch.float32
 
         if optimizer is None:
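A quick sketch of how the CUDA-only AMP guard above behaves; the tiny module below is a stand-in (the real LiquidDiffusion model is not needed to see the effect, and is never called here):

```python
import torch
from torch import nn

# Stand-in module: we only construct the trainer, we never run a forward pass.
dummy = nn.Conv2d(3, 3, 1)

trainer_cpu = RectifiedFlowTrainer(dummy, device="cpu", use_amp=True, amp_dtype="float16")
print(trainer_cpu.use_amp)    # False: AMP is silently disabled off-CUDA
print(trainer_cpu.amp_dtype)  # torch.float32 fallback
```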
@@ -51,12 +61,18 @@ class RectifiedFlowTrainer:
         else:
             self.optimizer = optimizer
 
-        self.scaler = torch.amp.GradScaler("cuda", enabled=(self.use_amp and amp_dtype == "float16"))
+        # GradScaler only when AMP is active
+        if self.use_amp and amp_dtype == "float16":
+            self.scaler = torch.amp.GradScaler("cuda", enabled=True)
+        else:
+            self.scaler = torch.amp.GradScaler("cuda", enabled=False)
+
         self.ema_model = self._build_ema()
         self.step = 0
         self.losses = []
 
     def _build_ema(self):
+        """Create EMA copy of model."""
         ema = copy.deepcopy(self.model)
         ema.eval()
         for p in ema.parameters():
@@ -65,10 +81,12 @@ class RectifiedFlowTrainer:
 
     @torch.no_grad()
     def _update_ema(self):
+        """Update EMA weights: ema = decay * ema + (1-decay) * model"""
         for ema_p, model_p in zip(self.ema_model.parameters(), self.model.parameters()):
             ema_p.data.mul_(self.ema_decay).add_(model_p.data, alpha=1 - self.ema_decay)
 
     def _sample_time(self, batch_size):
+        """Sample timesteps. logit_normal puts more weight near t=0.5."""
         eps = 1e-5
         if self.time_sampling == "uniform":
             return torch.rand(batch_size, device=self.device) * (1 - 2*eps) + eps
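The logit_normal branch of _sample_time falls outside this hunk. As a hedged sketch of what SD3-style logit-normal sampling typically looks like (an assumption about this file, not code shown in the diff): a normal draw is squashed through a sigmoid, which keeps t strictly inside (0, 1) and concentrates mass near t = 0.5.

```python
import torch

def sample_logit_normal(batch_size, mean=0.0, std=1.0, device="cpu"):
    # t = sigmoid(n) with n ~ N(mean, std^2); the density peaks around t = 0.5
    n = torch.randn(batch_size, device=device) * std + mean
    return torch.sigmoid(n)

t = sample_logit_normal(10_000)
print(t.min().item() > 0, t.max().item() < 1)            # strictly inside (0, 1)
print(((t > 0.25) & (t < 0.75)).float().mean().item())   # ~0.73 of samples land near the middle
```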
@@ -78,68 +96,113 @@ class RectifiedFlowTrainer:
             raise ValueError(f"Unknown time_sampling: {self.time_sampling}")
 
     def train_step(self, x0):
+        """Single training step. x0: [B,C,H,W] images in [-1,1]."""
         self.model.train()
+        x0 = x0.to(self.device)
         x1 = torch.randn_like(x0)
         t = self._sample_time(x0.shape[0])
         t_expand = t[:, None, None, None]
         x_t = (1 - t_expand) * x0 + t_expand * x1
         v_target = x1 - x0
 
-        with torch.amp.autocast(self.device, dtype=self.amp_dtype, enabled=self.use_amp):
+        # Forward with optional AMP
+        if self.use_amp:
+            with torch.amp.autocast("cuda", dtype=self.amp_dtype):
+                v_pred = self.model(x_t, t)
+                loss = F.mse_loss(v_pred, v_target)
+        else:
             v_pred = self.model(x_t, t)
             loss = F.mse_loss(v_pred, v_target)
 
+        # Backward
         self.optimizer.zero_grad(set_to_none=True)
         self.scaler.scale(loss).backward()
+
         if self.grad_clip > 0:
             self.scaler.unscale_(self.optimizer)
             grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
         else:
             grad_norm = torch.tensor(0.0)
+
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self._update_ema()
+
         self.step += 1
         loss_val = loss.item()
         self.losses.append(loss_val)
-        return {'loss': loss_val, 'grad_norm': grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm, 'step': self.step}
+        return {
+            'loss': loss_val,
+            'grad_norm': grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,
+            'step': self.step,
+        }
 
     @torch.no_grad()
     def sample(self, batch_size=4, image_size=256, channels=3, num_steps=50, use_ema=True):
+        """Generate images via Euler ODE integration from noise → data."""
         model = self.ema_model if use_ema else self.model
         model.eval()
         z = torch.randn(batch_size, channels, image_size, image_size, device=self.device)
         dt = 1.0 / num_steps
+
         for i in range(num_steps, 0, -1):
             t = torch.full((batch_size,), i / num_steps, device=self.device)
-            with torch.amp.autocast(self.device, dtype=self.amp_dtype, enabled=self.use_amp):
-                v = model(z, t)
-            if self.use_amp and self.amp_dtype == torch.float16:
+            if self.use_amp:
+                with torch.amp.autocast("cuda", dtype=self.amp_dtype):
+                    v = model(z, t)
                 v = v.float()
+            else:
+                v = model(z, t)
             z = z - v * dt
+
         return z.clamp(-1, 1)
 
     def save_checkpoint(self, path, extra=None):
-        ckpt = {'model': self.model.state_dict(), 'ema_model': self.ema_model.state_dict(),
-                'optimizer': self.optimizer.state_dict(), 'scaler': self.scaler.state_dict(),
-                'step': self.step, 'losses': self.losses[-1000:]}
-        if extra: ckpt.update(extra)
-        os.makedirs(os.path.dirname(path) if os.path.dirname(path) else '.', exist_ok=True)
+        """Save model, EMA, optimizer, scaler, and training state."""
+        ckpt = {
+            'model': self.model.state_dict(),
+            'ema_model': self.ema_model.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'scaler': self.scaler.state_dict(),
+            'step': self.step,
+            'losses': self.losses[-1000:],
+        }
+        if extra:
+            ckpt.update(extra)
+        dir_path = os.path.dirname(path)
+        if dir_path:
+            os.makedirs(dir_path, exist_ok=True)
         torch.save(ckpt, path)
 
     def load_checkpoint(self, path):
+        """Load checkpoint and resume training."""
         ckpt = torch.load(path, map_location=self.device, weights_only=False)
         self.model.load_state_dict(ckpt['model'])
         self.ema_model.load_state_dict(ckpt['ema_model'])
         self.optimizer.load_state_dict(ckpt['optimizer'])
-        if 'scaler' in ckpt: self.scaler.load_state_dict(ckpt['scaler'])
+        if 'scaler' in ckpt:
+            self.scaler.load_state_dict(ckpt['scaler'])
         self.step = ckpt.get('step', 0)
         self.losses = ckpt.get('losses', [])
         print(f"Resumed from step {self.step}")
 
 
 class ImageDataset(Dataset):
-    """Image dataset from local folder or HuggingFace Hub."""
+    """Image dataset from local folder or HuggingFace Hub.
+
+    Usage:
+        # From HuggingFace
+        ds = ImageDataset("huggan/CelebA-HQ", image_size=256)
+
+        # From local folder
+        ds = ImageDataset("/path/to/images", image_size=256)
+
+        # With pre-loaded HF dataset
+        from datasets import load_dataset
+        hf_ds = load_dataset("huggan/CelebA-HQ", split="train")
+        ds = ImageDataset(None, image_size=256, hf_dataset=hf_ds)
+    """
+
     def __init__(self, source, image_size=256, split="train",
                  image_column="image", max_samples=None, hf_dataset=None):
         self.image_size = image_size
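A standalone check of the rectified-flow logic in train_step and sample above, using the exact velocity in place of a model: for x_t = (1 - t) * x0 + t * x1 the true velocity dx_t/dt is x1 - x0 (the v_target regressed by the MSE loss), and the Euler update z = z - v * dt from t = 1 down to t = 0 recovers x0 exactly. Toy tensors only, no network involved.

```python
import torch

torch.manual_seed(0)
x0 = torch.randn(2, 3, 8, 8, dtype=torch.float64)   # "data" endpoint (t = 0)
x1 = torch.randn_like(x0)                            # noise endpoint (t = 1)
v_true = x1 - x0                                     # exact velocity of the straight path

# Euler integration as in sample(): start at pure noise and step toward data.
# The timestep value is irrelevant here because the true velocity is constant along the path.
num_steps = 50
dt = 1.0 / num_steps
z = x1.clone()
for _ in range(num_steps):
    z = z - v_true * dt

print((z - x0).abs().max().item())  # ~1e-15: the exact velocity recovers x0
```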
@@ -149,8 +212,9 @@ class ImageDataset(Dataset):
             transforms.CenterCrop(image_size),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
+            transforms.Normalize([0.5], [0.5]),  # → [-1, 1]
         ])
+
         if hf_dataset is not None:
             self.data = hf_dataset
             self.mode = "hf"
@@ -160,12 +224,14 @@ class ImageDataset(Dataset):
             for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:
                 self.files.extend(glob(os.path.join(source, '**', ext), recursive=True))
             self.files.sort()
-            if max_samples: self.files = self.files[:max_samples]
+            if max_samples:
+                self.files = self.files[:max_samples]
             self.mode = "folder"
         else:
             from datasets import load_dataset
             self.data = load_dataset(source, split=split)
-            if max_samples: self.data = self.data.select(range(min(max_samples, len(self.data))))
+            if max_samples:
+                self.data = self.data.select(range(min(max_samples, len(self.data))))
             self.mode = "hf"
 
     def __len__(self):
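A wiring sketch for ImageDataset; the folder path is a placeholder, and it assumes __getitem__ returns the transformed image tensor (that method is not part of this diff). Batches come back in roughly [-1, 1], the range train_step expects.

```python
from torch.utils.data import DataLoader

# "/path/to/images" is a placeholder; any folder of png/jpg/jpeg/webp/bmp images works.
ds = ImageDataset("/path/to/images", image_size=256, max_samples=10_000)
loader = DataLoader(ds, batch_size=16, shuffle=True, num_workers=4,
                    pin_memory=True, drop_last=True)

batch = next(iter(loader))
# Assuming each item is the transformed tensor: [16, 3, 256, 256], values roughly in [-1, 1]
print(batch.shape, batch.min().item(), batch.max().item())
```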
@@ -185,9 +251,11 @@ class ImageDataset(Dataset):
 
 
 def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
+    """Cosine annealing with linear warmup — standard for diffusion training."""
     def lr_lambda(step):
         if step < num_warmup_steps:
             return float(step) / float(max(1, num_warmup_steps))
-        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        progress = float(step - num_warmup_steps) / float(
+            max(1, num_training_steps - num_warmup_steps))
         return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
     return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
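A quick numeric check of the warmup-then-cosine multiplier defined above: it ramps linearly to 1.0 over the warmup steps, then decays along a cosine to 0.0 at the final step. LambdaLR keeps the multiplier function, so it can be evaluated directly.

```python
import torch
from torch import nn

opt = torch.optim.AdamW([nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=100, num_training_steps=1000)

lr_mult = sched.lr_lambdas[0]  # the lr_lambda closure stored by LambdaLR
for step in (0, 50, 100, 550, 1000):
    print(step, round(lr_mult(step), 4))
# 0 -> 0.0, 50 -> 0.5, 100 -> 1.0, 550 -> 0.5, 1000 -> 0.0
```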
 
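Putting the pieces from this file together, a minimal training-loop sketch. The model constructor, dataset path, and all hyperparameters below are illustrative assumptions, not values from this commit.

```python
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image

# build_model() is a placeholder for however the LiquidDiffusion model is constructed.
model = build_model()
device = "cuda" if torch.cuda.is_available() else "cpu"

trainer = RectifiedFlowTrainer(model, lr=1e-4, device=device,
                               use_amp=True, amp_dtype="float16")
dataset = ImageDataset("/path/to/images", image_size=256)
loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    num_workers=4, drop_last=True)
scheduler = get_cosine_schedule_with_warmup(trainer.optimizer,
                                            num_warmup_steps=1_000,
                                            num_training_steps=100_000)

for epoch in range(100):
    for x0 in loader:
        stats = trainer.train_step(x0)   # train_step moves the batch to the device
        scheduler.step()
        if stats['step'] % 100 == 0:
            print(stats)
        if stats['step'] % 5_000 == 0:
            imgs = trainer.sample(batch_size=4, image_size=256, num_steps=50)
            save_image(imgs * 0.5 + 0.5, f"samples_{stats['step']}.png")  # back to [0, 1]
            trainer.save_checkpoint(f"checkpoints/step_{stats['step']}.pt")
```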