PERF: force fp32 for FFT, JIT scan, fix RMSNorm, cache Manhattan
iris_model.py (+42 -45)
@@ -372,35 +372,34 @@ class FourierMixingPathway(nn.Module):
         out_imag = torch.einsum('...ki,kij->...kj', x.real, w_imag) + torch.einsum('...ki,kij->...kj', x.imag, w_real)
         return torch.complex(out_real, out_imag)
 
+    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        """Forward pass — forced to fp32 because FFT + ComplexHalf is broken/slow."""
         B, N, D = x.shape
         x_2d = x.reshape(B, H, W, D)
 
-        # 2D Real FFT on spatial dimensions
+        # 2D Real FFT on spatial dimensions (MUST be fp32 — ComplexHalf is broken)
         x_freq = torch.fft.rfft2(x_2d, dim=(1, 2), norm='ortho')  # [B, H, W//2+1, D]
 
         # Reshape channel dim for block-diagonal MLP: D → (num_blocks, block_size)
         Hf, Wf = x_freq.shape[1], x_freq.shape[2]
         x_freq = x_freq.reshape(B, Hf, Wf, self.num_blocks, self.block_size)
 
         # Block MLP Layer 1
-        # x_freq: [B, Hf, Wf, num_blocks, block_size]
-        # w1: [num_blocks, block_size, block_size]
         x_freq = self.complex_matmul(x_freq, self.w1_real, self.w1_imag)
         x_freq = x_freq + self.b1
         x_freq = torch.complex(F.relu(x_freq.real), F.relu(x_freq.imag))
 
         # Block MLP Layer 2
         x_freq = self.complex_matmul(x_freq, self.w2_real, self.w2_imag)
         x_freq = x_freq + self.b2
 
         # Reshape back
         x_freq = x_freq.reshape(B, Hf, Wf, D)
 
         # Soft-shrinkage (sparsity in Fourier domain)
         magnitude = x_freq.abs()
         shrunk_mag = F.relu(magnitude - self.sparsity_threshold)
-        # Preserve phase, shrink magnitude
         x_freq = x_freq * (shrunk_mag / (magnitude + 1e-8))
 
         # Inverse FFT
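Note on the fp32 forcing: torch.amp.custom_fwd is documented for torch.autograd.Function.forward, so applying it to a plain nn.Module method relies on it tolerating self as the first argument. A minimal sketch of the same fp32-island pattern written explicitly, without the decorator (fft_in_fp32 is an illustrative helper, not part of the commit):

    import torch

    def fft_in_fp32(x_2d: torch.Tensor) -> torch.Tensor:
        # Illustrative helper (not in the commit): disable autocast and
        # upcast before the FFT, since cuFFT's ComplexHalf support is
        # incomplete and the frequency path should run in fp32.
        with torch.autocast(device_type='cuda', enabled=False):
            return torch.fft.rfft2(x_2d.float(), dim=(1, 2), norm='ortho')

Either way, the complex tensors that feed the block-diagonal MLP stay in fp32 even when the surrounding model runs under autocast.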
@@ -432,37 +431,33 @@ class GatedLinearRecurrence(nn.Module):
         # Output projection
         self.proj_out = nn.Linear(recurrence_dim * 2, dim)
 
+    @staticmethod
+    @torch.jit.script
+    def _scan_kernel(a: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
+        """JIT-compiled sequential scan — avoids Python loop overhead on GPU."""
+        B, N, D = a.shape
+        h = torch.zeros(B, D, device=a.device, dtype=a.dtype)
+        outputs = torch.empty_like(u)
+        for t in range(N):
+            h = a[:, t] * h + u[:, t]
+            outputs[:, t] = h
+        return outputs
+
     def _scan(self, x: torch.Tensor) -> torch.Tensor:
-        """Gated linear recurrence scan. x: [B, N, rec_dim]
-
-        Uses chunked computation to reduce Python loop overhead.
-        For production, replace with a CUDA parallel scan kernel.
-        """
+        """Gated linear recurrence scan. x: [B, N, rec_dim]"""
         B, N, D = x.shape
 
         # Compute all gates in one shot (parallelized)
         a_base = torch.sigmoid(self.Lambda)
         r = torch.sigmoid(self.W_a(x))
         i = torch.sigmoid(self.W_x(x))
 
-        a = a_base.pow(self.c * r)  # [B, N, D]
-
-        # Normalized input: sqrt(1 - a^2) for variance preservation
+        a = a_base.pow(self.c * r)
         input_scale = torch.sqrt(1.0 - a * a + 1e-8)
         u = input_scale * (i * x)
 
-        # Sequential recurrence — use contiguous tensors for speed
-        a = a.contiguous()
-        u = u.contiguous()
-        h = torch.zeros(B, D, device=x.device, dtype=x.dtype)
-        outputs = torch.empty_like(u)  # Pre-allocate output
-
-        for t in range(N):
-            h = a[:, t] * h + u[:, t]
-            outputs[:, t] = h
-        return outputs
+        # JIT-compiled scan (much faster than Python loop on GPU)
+        return self._scan_kernel(a.contiguous(), u.contiguous())
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, D = x.shape
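Note on the scan: the recurrence is h_t = a_t * h_{t-1} + u_t with u_t = sqrt(1 - a_t^2) * i_t * x_t, so per-step variance is approximately preserved when the input is white. A quick eager-mode cross-check of a scripted kernel of this shape (scan_kernel and scan_reference are illustrative names, not the module's API):

    import torch

    @torch.jit.script
    def scan_kernel(a: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        # h_t = a_t * h_{t-1} + u_t, computed stepwise into a preallocated buffer
        h = torch.zeros_like(u[:, 0])
        outputs = torch.empty_like(u)
        for t in range(u.size(1)):
            h = a[:, t] * h + u[:, t]
            outputs[:, t] = h
        return outputs

    def scan_reference(a: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        # Plain Python loop, stacked at the end
        h = torch.zeros_like(u[:, 0])
        outs = []
        for t in range(u.shape[1]):
            h = a[:, t] * h + u[:, t]
            outs.append(h)
        return torch.stack(outs, dim=1)

    a = torch.rand(2, 16, 8)   # gates in (0, 1)
    u = torch.randn(2, 16, 8)
    assert torch.allclose(scan_kernel(a, u), scan_reference(a, u), atol=1e-6)

TorchScript removes per-step Python dispatch overhead, but the scan is still O(N) sequential; the removed docstring's suggestion of a CUDA parallel scan (e.g. a Blelloch-style prefix scan over (a, u) pairs) remains the real fix.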
@@ -526,29 +521,30 @@ class ManhattanSpatialGate(nn.Module):
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
         B, N, D = x.shape
+        input_dtype = x.dtype
 
-        # Compute spatial decay
-        gamma = torch.sigmoid(self.gamma_logit)  # [num_heads]
-        manhattan_dist = self._get_manhattan_mask(H, W, x.device)  # [N, N]
+        # Compute spatial decay in fp32 (pow in fp16 loses precision badly)
+        gamma = torch.sigmoid(self.gamma_logit).float()  # [num_heads] fp32
+        manhattan_dist = self._get_manhattan_mask(H, W, x.device)  # [N, N] fp32
 
         # Window mask
         decay_mask = (manhattan_dist <= self.window)
 
-        # Per-head decay: gamma_h^dist
-        # Only compute pow for positions within window (sparse)
+        # Per-head decay: gamma_h^dist (fp32 for precision)
         decay = gamma[:, None, None].pow(manhattan_dist[None, :, :])  # [heads, N, N]
         decay = decay * decay_mask.unsqueeze(0).float()
 
-        # Value and gate
+        # Value and gate (stay in input dtype for speed)
         v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim)
         g = torch.sigmoid(self.g_proj(x))
 
+        # Matmul in input dtype (fp16 ok for matmul)
         v = v.permute(0, 2, 1, 3)  # [B, heads, N, head_dim]
-        out = torch.matmul(decay.unsqueeze(0), v)  # [B, heads, N, head_dim]
+        decay_cast = decay.unsqueeze(0).to(input_dtype)
+        out = torch.matmul(decay_cast, v)  # [B, heads, N, head_dim]
 
         # Normalize
-        decay_sum = decay.unsqueeze(0).sum(dim=-1, keepdim=True) + 1e-8
+        decay_sum = decay_cast.sum(dim=-1, keepdim=True) + 1e-8
         out = out / decay_sum
 
         out = out.permute(0, 2, 1, 3).reshape(B, N, D)
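This hunk consumes self._get_manhattan_mask(H, W, x.device) but the helper itself is outside the diff; per the commit title ("cache Manhattan") it now caches the [N, N] distance table instead of rebuilding it every forward. A sketch of what such a cached helper could look like (a module-level dict keyed on (H, W) is an assumption here, not the actual implementation):

    import torch

    _mask_cache: dict = {}

    def get_manhattan_mask(H: int, W: int, device: torch.device) -> torch.Tensor:
        # Illustrative cache (not the commit's code): build the pairwise
        # Manhattan-distance table once per (H, W) and reuse it.
        key = (H, W)
        if key not in _mask_cache:
            ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
            coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()  # [N, 2]
            # Pairwise |dy| + |dx| over the H*W grid positions
            _mask_cache[key] = (coords[:, None, :] - coords[None, :, :]).abs().sum(-1)
        return _mask_cache[key].to(device)

Keeping the cached table in fp32 matches the hunk's comment that pow in fp16 loses precision badly; only decay_cast is downcast, right before the matmul.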
@@ -623,8 +619,9 @@ class CrossAttention(nn.Module):
         self.o_proj = nn.Linear(num_heads * head_dim, dim)
 
         # QK normalization for stability (from SANA-Sprint)
-        self.q_norm = RMSNorm(head_dim)
-        self.k_norm = RMSNorm(head_dim)
+        # Use LayerNorm instead of RMSNorm — RMSNorm has fp16 weight mismatch issues
+        self.q_norm = nn.LayerNorm(head_dim, elementwise_affine=False)
+        self.k_norm = nn.LayerNorm(head_dim, elementwise_affine=False)
 
     def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
         B, N, _ = x.shape
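Note on the norm swap: with elementwise_affine=False the LayerNorm has no learnable weight, so there is no fp32 master weight to fall out of sync with fp16 activations. A quick illustration of how close the two norms are on per-head queries (head_dim=64 and the tensor shapes are arbitrary example values):

    import torch
    import torch.nn as nn

    head_dim = 64
    q = torch.randn(2, 8, 16, head_dim)  # [B, heads, N, head_dim]

    ln = nn.LayerNorm(head_dim, elementwise_affine=False)  # parameter-free
    q_ln = ln(q)

    # Minimal RMSNorm for comparison: no mean subtraction, scale only
    q_rms = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6)
    print(q_ln.std().item(), q_rms.std().item())  # both close to 1.0

The two differ only by the mean subtraction, so the QK-norm stabilization behavior should be comparable while sidestepping the dtype mismatch.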