asdf98 committed
Commit 193fbf7 · verified · Parent(s): 49c59f1

Add gradient checkpointing + zigzag index caching

Files changed (1)
  1. model.py +31 -17
model.py CHANGED
@@ -45,6 +45,7 @@ References:
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 import math
 from typing import Optional, Tuple
 
@@ -168,26 +169,26 @@ class ZigzagScan1D(nn.Module):
                              padding=kernel_size // 2, groups=channels, bias=False)
         self.pw = nn.Conv1d(channels, channels, 1, bias=True)
         self.act = nn.GELU()
+        self._idx_cache = {}
 
-    def _zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:
-        indices = []
-        for i in range(H):
-            row = list(range(i * W, (i + 1) * W))
-            if i % 2 == 1:
-                row = row[::-1]
-            indices.extend(row)
-        return torch.tensor(indices, device=device, dtype=torch.long)
-
-    def _inverse_zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:
-        fwd = self._zigzag_indices(H, W, device)
-        inv = torch.empty_like(fwd)
-        inv[fwd] = torch.arange(H * W, device=device)
-        return inv
+    def _get_indices(self, H: int, W: int, device: torch.device):
+        key = (H, W, device)
+        if key not in self._idx_cache:
+            indices = []
+            for i in range(H):
+                row = list(range(i * W, (i + 1) * W))
+                if i % 2 == 1:
+                    row = row[::-1]
+                indices.extend(row)
+            fwd = torch.tensor(indices, device=device, dtype=torch.long)
+            inv = torch.empty_like(fwd)
+            inv[fwd] = torch.arange(H * W, device=device)
+            self._idx_cache[key] = (fwd, inv)
+        return self._idx_cache[key]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
-        zz_idx = self._zigzag_indices(H, W, x.device)
-        inv_idx = self._inverse_zigzag_indices(H, W, x.device)
+        zz_idx, inv_idx = self._get_indices(H, W, x.device)
         x_flat = x.reshape(B, C, H * W)
         x_zz = x_flat[:, :, zz_idx]
         x_mixed = self.pw(self.act(self.conv1d(x_zz)))
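
For reference, a minimal standalone sketch of the permutation pair that `_get_indices` now builds once and caches per `(H, W, device)` key; the 3x4 grid below is illustrative, not from the repo:

    import torch

    H, W = 3, 4  # illustrative grid size
    # Even rows scan left-to-right, odd rows are reversed (boustrophedon order).
    indices = []
    for i in range(H):
        row = list(range(i * W, (i + 1) * W))
        if i % 2 == 1:
            row = row[::-1]
        indices.extend(row)
    fwd = torch.tensor(indices, dtype=torch.long)
    # Scatter each zigzag position back to its raster position to invert.
    inv = torch.empty_like(fwd)
    inv[fwd] = torch.arange(H * W)

    print(fwd.tolist())  # [0, 1, 2, 3, 7, 6, 5, 4, 8, 9, 10, 11]
    x = torch.arange(H * W)
    assert torch.equal(x[fwd][inv], x)  # inverse restores raster order

Because the cache key includes `torch.device`, moving the module to a different device simply creates a second cache entry rather than returning indices on a stale device.
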
@@ -361,6 +362,16 @@ class LiquidGen(nn.Module):
         nn.init.zeros_(self.unpatch.bias)
 
         self.apply(self._init_weights)
+        self._gradient_checkpointing = False
+
+    def enable_gradient_checkpointing(self):
+        """Enable gradient checkpointing to reduce VRAM by ~40-60%.
+        Recomputes block activations during backward instead of storing them.
+        Slower training (~30%) but allows much larger batch sizes or models."""
+        self._gradient_checkpointing = True
+
+    def disable_gradient_checkpointing(self):
+        self._gradient_checkpointing = False
 
     def _init_weights(self, m):
         if isinstance(m, nn.Conv2d):
@@ -410,7 +421,10 @@ class LiquidGen(nn.Module):
             elif i >= mid and len(skip_connections) > 0:
                 skip = skip_connections.pop()
                 h = h + skip
-            h = block(h, cond)
+            if self._gradient_checkpointing and self.training:
+                h = checkpoint(block, h, cond, use_reentrant=False)
+            else:
+                h = block(h, cond)
 
         h = self.final_norm(h)
         h = self.final_proj(h)
 
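The forward-pass change routes each block through `torch.utils.checkpoint.checkpoint` when the model is training and checkpointing was enabled via `enable_gradient_checkpointing()`. A self-contained sketch of the underlying mechanism, using a stand-in `nn.Sequential` rather than the repo's actual block:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    # Stand-in for one model block. Wrapped in checkpoint, its internal
    # activations are not stored during forward; they are recomputed when
    # backward() reaches this segment of the graph.
    block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
    x = torch.randn(8, 64, requires_grad=True)

    y = checkpoint(block, x, use_reentrant=False)  # same output as block(x)
    y.sum().backward()                             # recomputation happens here
    print(x.grad.shape)  # torch.Size([8, 64])

`use_reentrant=False` selects PyTorch's non-reentrant checkpoint implementation, the mode the PyTorch docs recommend; among other things it handles extra inputs that may not require grad, which matters here since `cond` is passed through alongside `h`.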