asdf98 committed
Commit 3e17403 · verified · 1 Parent(s): ce0086e

Perf fix: cache Manhattan dist, optimize scan, pre-cache CLIP, fix deprecated AMP API

Files changed (1)
  1. iris_model.py +47 -30
iris_model.py CHANGED
@@ -433,29 +433,36 @@ class GatedLinearRecurrence(nn.Module):
         self.proj_out = nn.Linear(recurrence_dim * 2, dim)
 
     def _scan(self, x: torch.Tensor) -> torch.Tensor:
-        """Sequential scan for a single direction. x: [B, N, rec_dim]"""
+        """Gated linear recurrence scan. x: [B, N, rec_dim]
+
+        Pre-allocates the output and keeps tensors contiguous to reduce
+        Python loop overhead. For production, replace with a CUDA parallel scan kernel.
+        """
         B, N, D = x.shape
 
-        # Compute gates (can be parallelized)
+        # Compute all gates in one shot (parallelized)
         a_base = torch.sigmoid(self.Lambda)  # [D]
-        r = torch.sigmoid(self.W_a(x))  # [B, N, D] - recurrence gate
-        i = torch.sigmoid(self.W_x(x))  # [B, N, D] - input gate
+        r = torch.sigmoid(self.W_a(x))  # [B, N, D]
+        i = torch.sigmoid(self.W_x(x))  # [B, N, D]
 
         # a_t = a_base^(c * r_t) — data-dependent decay
         a = a_base.pow(self.c * r)  # [B, N, D]
 
         # Normalized input: sqrt(1 - a^2) for variance preservation
         input_scale = torch.sqrt(1.0 - a * a + 1e-8)
-        scaled_input = input_scale * (i * x)  # [B, N, D]
+        u = input_scale * (i * x)  # [B, N, D]
 
-        # Sequential recurrence (use parallel scan in production)
-        outputs = []
+        # Sequential recurrence; contiguous tensors speed up the per-step slicing
+        a = a.contiguous()
+        u = u.contiguous()
         h = torch.zeros(B, D, device=x.device, dtype=x.dtype)
+        outputs = torch.empty_like(u)  # Pre-allocate output
+
         for t in range(N):
-            h = a[:, t] * h + scaled_input[:, t]
-            outputs.append(h)
+            h = a[:, t] * h + u[:, t]
+            outputs[:, t] = h
 
-        return torch.stack(outputs, dim=1)  # [B, N, D]
+        return outputs  # [B, N, D]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, D = x.shape
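Note on the `_scan` hunk: the new docstring defers real speedups to "a CUDA parallel scan kernel". Because h_t = a_t * h_{t-1} + u_t is an associative recurrence, the same result is computable in ⌈log2 N⌉ passes even in pure PyTorch. A minimal sketch (hypothetical `parallel_scan` helper, not part of this commit), verified against the sequential loop:

```python
import torch

def parallel_scan(a: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Hillis-Steele scan for h_t = a_t * h_{t-1} + u_t with h_0 = 0.

    a, u: [B, N, D]. Two segments combine as
    (a1, u1) then (a2, u2) -> (a1 * a2, a2 * u1 + u2).
    """
    A, h = a.clone(), u.clone()
    N, step = a.shape[1], 1
    while step < N:
        # Fold in the accumulated segment ending `step` positions earlier.
        # The RHS is materialized before assignment, so the overlap is safe.
        h[:, step:] = A[:, step:] * h[:, :-step] + h[:, step:]
        A[:, step:] = A[:, step:] * A[:, :-step]
        step *= 2
    return h

# Sanity check against the sequential loop in _scan
B, N, D = 2, 16, 8
a, u = torch.rand(B, N, D) * 0.9, torch.randn(B, N, D)
h, ref = torch.zeros(B, D), []
for t in range(N):
    h = a[:, t] * h + u[:, t]
    ref.append(h)
assert torch.allclose(parallel_scan(a, u), torch.stack(ref, dim=1), atol=1e-5)
```

Total work rises from O(N) to O(N log N), but the number of sequential steps drops from N to ⌈log2 N⌉, which is usually what the Python loop is actually paying for.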
@@ -476,7 +483,7 @@ class GatedLinearRecurrence(nn.Module):
 class ManhattanSpatialGate(nn.Module):
     """Pathway 3: Manhattan distance spatial decay gating.
     Provides learned 2D spatial inductive bias with per-head multi-scale receptive fields.
-    Uses windowed computation for efficiency.
+    Uses a cached distance matrix and windowed computation for efficiency.
     """
     def __init__(self, dim: int, num_heads: int, window: int = 16):
         super().__init__()
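The "per-head multi-scale receptive fields" claim follows directly from the per-head decay: with weights gamma_h**dist, a token at Manhattan distance d falls below a relevance threshold eps once d > log(eps) / log(gamma_h). A small illustration (example gamma values and threshold assumed here, not taken from the commit):

```python
import torch

gamma = torch.tensor([0.5, 0.8, 0.95, 0.99])  # example per-head decay rates
eps = 0.01                                     # assumed relevance cutoff
radius = torch.log(torch.tensor(eps)) / torch.log(gamma)
print(radius.round())  # tensor([  7.,  21.,  90., 458.]): small to large fields
```

With the default `window=16`, heads whose nominal radius exceeds 16 are clipped to the window, where a gamma close to 1 degenerates toward uniform averaging over the windowed neighborhood.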
@@ -493,49 +500,59 @@ class ManhattanSpatialGate(nn.Module):
493
  self.v_proj = nn.Linear(dim, dim)
494
  self.g_proj = nn.Linear(dim, dim)
495
  self.o_proj = nn.Linear(dim, dim)
 
 
 
 
496
 
497
  def _get_manhattan_mask(self, H: int, W: int, device: torch.device) -> torch.Tensor:
498
- """Compute Manhattan distance matrix between all 2D positions."""
499
- coords = torch.stack(torch.meshgrid(
500
- torch.arange(H, device=device),
501
- torch.arange(W, device=device),
502
- indexing='ij'
503
- ), dim=-1).reshape(-1, 2).float() # [N, 2]
504
-
505
- # Manhattan distance: |x1-x2| + |y1-y2|
506
- dist = torch.cdist(coords, coords, p=1) # [N, N]
 
 
 
 
 
 
 
507
  return dist
508
 
509
  def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
510
  B, N, D = x.shape
511
 
512
- # Compute spatial decay
513
  gamma = torch.sigmoid(self.gamma_logit) # [num_heads]
514
  manhattan_dist = self._get_manhattan_mask(H, W, x.device) # [N, N]
515
 
516
- # Window the distance matrix for efficiency
517
- # Only compute decay for positions within window distance
518
- decay_mask = (manhattan_dist <= self.window).float()
519
 
520
- # Per-head decay: gamma_h^dist
 
521
  decay = gamma[:, None, None].pow(manhattan_dist[None, :, :]) # [heads, N, N]
522
- decay = decay * decay_mask[None, :, :]
523
 
524
  # Value and gate
525
  v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim)
526
  g = torch.sigmoid(self.g_proj(x))
527
 
528
  # Apply spatial decay to values
529
- # [B, heads, N, head_dim] = [heads, N, N] @ [B, heads, N, head_dim]
530
  v = v.permute(0, 2, 1, 3) # [B, heads, N, head_dim]
531
  out = torch.matmul(decay.unsqueeze(0), v) # [B, heads, N, head_dim]
532
 
533
  # Normalize by decay sum
534
- decay_sum = decay.sum(dim=-1, keepdim=True).unsqueeze(0) + 1e-8 # [1, heads, N, 1]
535
  out = out / decay_sum
536
 
537
- out = out.permute(0, 2, 1, 3).reshape(B, N, D) # [B, N, D]
538
- out = out * g # Gating
539
  return self.o_proj(out)
540
 
541
 
 
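Two behavioral claims in the last hunk are easy to sanity-check: the broadcast distance must match the `torch.cdist(p=1)` call it replaces, and the masked, normalized decay kernel should remain a proper weighted average (each row summing to 1). A sketch with assumed toy sizes, not part of the commit:

```python
import torch

H, W, window = 6, 6, 3
grid_r, grid_c = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
coords = torch.stack([grid_r.reshape(-1), grid_c.reshape(-1)], dim=-1).float()  # [N, 2]

# New broadcast path vs. the old cdist path
dist = (coords[:, None, :] - coords[None, :, :]).abs().sum(dim=-1)  # [N, N]
assert torch.allclose(dist, torch.cdist(coords, coords, p=1))

# Masked per-head decay, normalized as in forward()
gamma = torch.tensor([0.7, 0.9])                    # two example heads
decay = gamma[:, None, None].pow(dist[None, :, :])  # [heads, N, N]
decay = decay * (dist <= window).float()
decay = decay / (decay.sum(dim=-1, keepdim=True) + 1e-8)
assert torch.allclose(decay.sum(dim=-1), torch.ones(2, H * W))
```

Design note on the cache itself: keeping `_cached_dist` in a plain attribute rather than a registered buffer leaves the [N, N] matrix out of the `state_dict`, and keying on `(H, W, device)` means a device move or a new grid size triggers a recompute instead of a stale hit.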