Add gradient checkpointing + zigzag index caching
model.py
CHANGED
@@ -45,6 +45,7 @@ References:
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 import math
 from typing import Optional, Tuple
 
@@ -168,26 +169,26 @@ class ZigzagScan1D(nn.Module):
                                padding=kernel_size // 2, groups=channels, bias=False)
         self.pw = nn.Conv1d(channels, channels, 1, bias=True)
         self.act = nn.GELU()
+        self._idx_cache = {}
 
-    def _zigzag_indices(self, H: int, W: int, device: torch.device):
-        indices = []
-        for i in range(H):
-            row = list(range(i * W, (i + 1) * W))
-            if i % 2 == 1:
-                row = row[::-1]
-            indices.extend(row)
-        return torch.tensor(indices, device=device, dtype=torch.long)
-
-    def _inverse_zigzag_indices(self, H: int, W: int, device: torch.device):
-        fwd = self._zigzag_indices(H, W, device)
-        inv = torch.empty_like(fwd)
-        inv[fwd] = torch.arange(H * W, device=device)
-        return inv
+    def _get_indices(self, H: int, W: int, device: torch.device):
+        key = (H, W, device)
+        if key not in self._idx_cache:
+            indices = []
+            for i in range(H):
+                row = list(range(i * W, (i + 1) * W))
+                if i % 2 == 1:
+                    row = row[::-1]
+                indices.extend(row)
+            fwd = torch.tensor(indices, device=device, dtype=torch.long)
+            inv = torch.empty_like(fwd)
+            inv[fwd] = torch.arange(H * W, device=device)
+            self._idx_cache[key] = (fwd, inv)
+        return self._idx_cache[key]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
-        zz_idx = self._zigzag_indices(H, W, x.device)
-        inv_idx = self._inverse_zigzag_indices(H, W, x.device)
+        zz_idx, inv_idx = self._get_indices(H, W, x.device)
         x_flat = x.reshape(B, C, H * W)
         x_zz = x_flat[:, :, zz_idx]
         x_mixed = self.pw(self.act(self.conv1d(x_zz)))
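The cached `_get_indices` builds a serpentine ("zigzag") permutation over the `H * W` positions, reversing every odd row, plus its inverse permutation, and memoizes both per `(H, W, device)` key. Caching matters because the same key recurs on every forward pass, so the Python loop and index tensors are built only once. A standalone sanity check of that ordering (a sketch: the `zigzag_indices` helper and the 2x4 toy tensor are illustrative, not part of the commit):

```python
import torch

# Re-derivation of the index logic from ZigzagScan1D._get_indices, without the cache.
def zigzag_indices(H, W, device="cpu"):
    indices = []
    for i in range(H):
        row = list(range(i * W, (i + 1) * W))
        if i % 2 == 1:                      # odd rows are traversed right-to-left
            row = row[::-1]
        indices.extend(row)
    fwd = torch.tensor(indices, device=device, dtype=torch.long)
    inv = torch.empty_like(fwd)
    inv[fwd] = torch.arange(H * W, device=device)   # inverse permutation
    return fwd, inv

fwd, inv = zigzag_indices(2, 4)
print(fwd.tolist())                          # [0, 1, 2, 3, 7, 6, 5, 4]

x = torch.arange(8).float().view(1, 1, 8)    # (B, C, H*W) toy input
assert torch.equal(x[:, :, fwd][:, :, inv], x)   # gather then inverse-gather is the identity
```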
@@ -361,6 +362,16 @@ class LiquidGen(nn.Module):
         nn.init.zeros_(self.unpatch.bias)
 
         self.apply(self._init_weights)
+        self._gradient_checkpointing = False
+
+    def enable_gradient_checkpointing(self):
+        """Enable gradient checkpointing to reduce VRAM by ~40-60%.
+        Recomputes block activations during backward instead of storing them.
+        Slower training (~30%) but allows much larger batch sizes or models."""
+        self._gradient_checkpointing = True
+
+    def disable_gradient_checkpointing(self):
+        self._gradient_checkpointing = False
 
     def _init_weights(self, m):
         if isinstance(m, nn.Conv2d):
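The new flag only changes behavior when both `self._gradient_checkpointing` and `self.training` are true, so evaluation passes are unaffected. Intended call pattern, as a hedged sketch (`model` is assumed to be an already constructed `LiquidGen` instance; the training loop is elided):

```python
# Hypothetical usage; assumes `model` is an existing LiquidGen instance.
model.enable_gradient_checkpointing()   # ~40-60% less activation VRAM, ~30% slower steps
model.train()                           # the checkpoint branch only runs in training mode

# ... usual training loop: forward, loss, loss.backward(), optimizer.step() ...

model.disable_gradient_checkpointing()  # e.g. before inference or latency benchmarking
```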
@@ -410,7 +421,10 @@
             elif i >= mid and len(skip_connections) > 0:
                 skip = skip_connections.pop()
                 h = h + skip
-            h = block(h, cond)
+            if self._gradient_checkpointing and self.training:
+                h = checkpoint(block, h, cond, use_reentrant=False)
+            else:
+                h = block(h, cond)
 
         h = self.final_norm(h)
         h = self.final_proj(h)