asdf98 committed on
Commit eb07d9d · verified · 1 Parent(s): 213365a

PERF: force fp32 for FFT, JIT scan, fix RMSNorm, cache Manhattan

Files changed (1):
  1. iris_model.py +42 -45
iris_model.py CHANGED
@@ -372,35 +372,34 @@ class FourierMixingPathway(nn.Module):
         out_imag = torch.einsum('...ki,kij->...kj', x.real, w_imag) + torch.einsum('...ki,kij->...kj', x.imag, w_real)
         return torch.complex(out_real, out_imag)
 
+    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        """Forward pass — forced to fp32 because FFT + ComplexHalf is broken/slow."""
         B, N, D = x.shape
         x_2d = x.reshape(B, H, W, D)
 
-        # 2D Real FFT on spatial dimensions
+        # 2D Real FFT on spatial dimensions (MUST be fp32 — ComplexHalf is broken)
         x_freq = torch.fft.rfft2(x_2d, dim=(1, 2), norm='ortho')  # [B, H, W//2+1, D]
 
         # Reshape channel dim for block-diagonal MLP: D → (num_blocks, block_size)
         Hf, Wf = x_freq.shape[1], x_freq.shape[2]
         x_freq = x_freq.reshape(B, Hf, Wf, self.num_blocks, self.block_size)
 
-        # Block MLP Layer 1: operates on last dim (block_size)
-        # x_freq: [B, Hf, Wf, num_blocks, block_size]
-        # w1: [num_blocks, block_size, block_size]
+        # Block MLP Layer 1
         x_freq = self.complex_matmul(x_freq, self.w1_real, self.w1_imag)
-        x_freq = x_freq + self.b1  # Broadcast bias (real only)
+        x_freq = x_freq + self.b1
         x_freq = torch.complex(F.relu(x_freq.real), F.relu(x_freq.imag))
 
         # Block MLP Layer 2
         x_freq = self.complex_matmul(x_freq, self.w2_real, self.w2_imag)
         x_freq = x_freq + self.b2
 
-        # Reshape back to [B, Hf, Wf, D]
+        # Reshape back
         x_freq = x_freq.reshape(B, Hf, Wf, D)
 
         # Soft-shrinkage (sparsity in Fourier domain)
         magnitude = x_freq.abs()
         shrunk_mag = F.relu(magnitude - self.sparsity_threshold)
-        # Preserve phase, shrink magnitude
         x_freq = x_freq * (shrunk_mag / (magnitude + 1e-8))
 
         # Inverse FFT
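Why the decorator matters: under autocast, torch.fft.rfft2 on an fp16 input yields ComplexHalf (torch.complex32) tensors, which have incomplete and slow kernel coverage, while cast_inputs=torch.float32 forces the whole FFT path to run in fp32. Below is a minimal standalone repro; it is our sketch, not part of iris_model.py (TinyFFT is a made-up name), and torch.amp.custom_fwd with device_type requires a recent PyTorch (the older spelling is torch.cuda.amp.custom_fwd).

import torch
import torch.nn as nn

class TinyFFT(nn.Module):
    # Same decorator as the commit: casts tensor inputs to fp32 and
    # disables autocast inside forward.
    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(self, x):
        return torch.fft.rfft2(x, dim=(1, 2), norm='ortho')

if torch.cuda.is_available():
    m = TinyFFT().cuda()
    x = torch.randn(2, 16, 16, 8, device='cuda')
    with torch.autocast('cuda', dtype=torch.float16):
        print(m(x).dtype)  # torch.complex64 (fp32 path), not torch.complex32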
@@ -432,37 +431,33 @@ class GatedLinearRecurrence(nn.Module):
         # Output projection
         self.proj_out = nn.Linear(recurrence_dim * 2, dim)
 
+    @staticmethod
+    @torch.jit.script
+    def _scan_kernel(a: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
+        """JIT-compiled sequential scan — avoids Python loop overhead on GPU."""
+        B, N, D = a.shape
+        h = torch.zeros(B, D, device=a.device, dtype=a.dtype)
+        outputs = torch.empty_like(u)
+        for t in range(N):
+            h = a[:, t] * h + u[:, t]
+            outputs[:, t] = h
+        return outputs
+
     def _scan(self, x: torch.Tensor) -> torch.Tensor:
-        """Gated linear recurrence scan. x: [B, N, rec_dim]
-
-        Uses chunked computation to reduce Python loop overhead.
-        For production, replace with a CUDA parallel scan kernel.
-        """
+        """Gated linear recurrence scan. x: [B, N, rec_dim]"""
         B, N, D = x.shape
 
         # Compute all gates in one shot (parallelized)
-        a_base = torch.sigmoid(self.Lambda)  # [D]
-        r = torch.sigmoid(self.W_a(x))  # [B, N, D]
-        i = torch.sigmoid(self.W_x(x))  # [B, N, D]
+        a_base = torch.sigmoid(self.Lambda)
+        r = torch.sigmoid(self.W_a(x))
+        i = torch.sigmoid(self.W_x(x))
 
-        # a_t = a_base^(c * r_t) — data-dependent decay
-        a = a_base.pow(self.c * r)  # [B, N, D]
-
-        # Normalized input: sqrt(1 - a^2) for variance preservation
+        a = a_base.pow(self.c * r)
         input_scale = torch.sqrt(1.0 - a * a + 1e-8)
-        u = input_scale * (i * x)  # [B, N, D]
-
-        # Sequential recurrence — use contiguous tensors for speed
-        a = a.contiguous()
-        u = u.contiguous()
-        h = torch.zeros(B, D, device=x.device, dtype=x.dtype)
-        outputs = torch.empty_like(u)  # Pre-allocate output
 
-        for t in range(N):
-            h = a[:, t] * h + u[:, t]
-            outputs[:, t] = h
-
-        return outputs
+        # JIT-compiled scan (much faster than Python loop on GPU)
+        return self._scan_kernel(a.contiguous(), u.contiguous())
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, D = x.shape
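For reference, the scripted scan can be exercised on its own. The harness below is our illustration, not part of the commit (scan_kernel mirrors _scan_kernel above; shapes and values are arbitrary). Note that torch.jit.script only removes per-step Python dispatch overhead; a fused CUDA parallel scan, as the removed docstring suggested, would still be faster for long sequences.

import torch

@torch.jit.script
def scan_kernel(a: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    # h_t = a_t * h_{t-1} + u_t, computed left to right over the N axis
    B, N, D = a.shape
    h = torch.zeros(B, D, device=a.device, dtype=a.dtype)
    outputs = torch.empty_like(u)
    for t in range(N):
        h = a[:, t] * h + u[:, t]
        outputs[:, t] = h
    return outputs

a = 0.9 * torch.rand(4, 256, 64)  # decay coefficients in (0, 0.9)
u = torch.randn(4, 256, 64)
out = scan_kernel(a, u)
# Sanity checks against the recurrence definition (h_{-1} = 0)
assert torch.allclose(out[:, 0], u[:, 0])
assert torch.allclose(out[:, 1], a[:, 1] * u[:, 0] + u[:, 1])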
@@ -526,29 +521,30 @@ class ManhattanSpatialGate(nn.Module):
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
         B, N, D = x.shape
+        input_dtype = x.dtype
 
-        # Compute spatial decay (distance matrix is cached)
-        gamma = torch.sigmoid(self.gamma_logit)  # [num_heads]
-        manhattan_dist = self._get_manhattan_mask(H, W, x.device)  # [N, N]
+        # Compute spatial decay in fp32 (pow in fp16 loses precision badly)
+        gamma = torch.sigmoid(self.gamma_logit).float()  # [num_heads] fp32
+        manhattan_dist = self._get_manhattan_mask(H, W, x.device)  # [N, N] fp32
 
-        # Window mask — only positions within window distance contribute
-        decay_mask = (manhattan_dist <= self.window)  # bool [N, N]
+        # Window mask
+        decay_mask = (manhattan_dist <= self.window)
 
-        # Per-head decay: gamma_h^dist, masked to window
-        # Only compute pow for positions within window (sparse)
+        # Per-head decay: gamma_h^dist (fp32 for precision)
         decay = gamma[:, None, None].pow(manhattan_dist[None, :, :])  # [heads, N, N]
         decay = decay * decay_mask.unsqueeze(0).float()
 
-        # Value and gate
+        # Value and gate (stay in input dtype for speed)
         v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim)
         g = torch.sigmoid(self.g_proj(x))
 
-        # Apply spatial decay to values
+        # Matmul in input dtype (fp16 ok for matmul)
         v = v.permute(0, 2, 1, 3)  # [B, heads, N, head_dim]
-        out = torch.matmul(decay.unsqueeze(0), v)  # [B, heads, N, head_dim]
+        decay_cast = decay.unsqueeze(0).to(input_dtype)
+        out = torch.matmul(decay_cast, v)  # [B, heads, N, head_dim]
 
-        # Normalize by decay sum
-        decay_sum = decay.sum(dim=-1, keepdim=True).unsqueeze(0) + 1e-8
+        # Normalize
+        decay_sum = decay_cast.sum(dim=-1, keepdim=True) + 1e-8
         out = out / decay_sum
 
         out = out.permute(0, 2, 1, 3).reshape(B, N, D)
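The commit message also mentions caching the Manhattan distance matrix, but _get_manhattan_mask itself is outside this diff. A plausible cached helper might look like the sketch below; every name in it is our assumption, not the file's API.

import torch
from functools import lru_cache

@lru_cache(maxsize=8)
def manhattan_dist(H: int, W: int) -> torch.Tensor:
    # All-pairs Manhattan distance over an H x W grid, flattened to N = H*W
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()  # [N, 2]
    # |d_row| + |d_col| for every pair of positions -> [N, N]
    return (coords[:, None, :] - coords[None, :, :]).abs().sum(dim=-1)

d = manhattan_dist(4, 4)   # second call with (4, 4) returns the cached tensor
print(d.shape)             # torch.Size([16, 16])
print(d[0, 5].item())      # 2.0: from (0, 0) to (1, 1)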
@@ -623,8 +619,9 @@ class CrossAttention(nn.Module):
         self.o_proj = nn.Linear(num_heads * head_dim, dim)
 
         # QK normalization for stability (from SANA-Sprint)
-        self.q_norm = nn.RMSNorm(head_dim)
-        self.k_norm = nn.RMSNorm(head_dim)
+        # Use LayerNorm instead of RMSNorm — RMSNorm has fp16 weight mismatch issues
+        self.q_norm = nn.LayerNorm(head_dim, elementwise_affine=False)
+        self.k_norm = nn.LayerNorm(head_dim, elementwise_affine=False)
 
     def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
         B, N, _ = x.shape
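Note the swap is not numerically neutral: RMSNorm divides by the root-mean-square only, while LayerNorm also subtracts the mean, and with elementwise_affine=False the new norms carry no learned fp32 weight that could mismatch fp16 activations. A quick check (our sketch):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 64)
# RMSNorm without affine: scale by reciprocal root-mean-square
rms = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
# LayerNorm without affine: center, then scale by std
ln = F.layer_norm(x, (64,))
print((rms - ln).abs().max().item())  # nonzero: the two norms differ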
 