Perf fix: cache Manhattan dist, optimize scan, pre-cache CLIP, fix deprecated AMP API
Browse files- iris_model.py +47 -30
iris_model.py
CHANGED
|
@@ -433,29 +433,36 @@ class GatedLinearRecurrence(nn.Module):
|
|
| 433 |
self.proj_out = nn.Linear(recurrence_dim * 2, dim)
|
| 434 |
|
| 435 |
def _scan(self, x: torch.Tensor) -> torch.Tensor:
|
| 436 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
B, N, D = x.shape
|
| 438 |
|
| 439 |
-
# Compute gates
|
| 440 |
a_base = torch.sigmoid(self.Lambda) # [D]
|
| 441 |
-
r = torch.sigmoid(self.W_a(x)) # [B, N, D]
|
| 442 |
-
i = torch.sigmoid(self.W_x(x)) # [B, N, D]
|
| 443 |
|
| 444 |
# a_t = a_base^(c * r_t) — data-dependent decay
|
| 445 |
a = a_base.pow(self.c * r) # [B, N, D]
|
| 446 |
|
| 447 |
# Normalized input: sqrt(1 - a^2) for variance preservation
|
| 448 |
input_scale = torch.sqrt(1.0 - a * a + 1e-8)
|
| 449 |
-
|
| 450 |
|
| 451 |
-
# Sequential recurrence
|
| 452 |
-
|
|
|
|
| 453 |
h = torch.zeros(B, D, device=x.device, dtype=x.dtype)
|
|
|
|
|
|
|
| 454 |
for t in range(N):
|
| 455 |
-
h = a[:, t] * h +
|
| 456 |
-
outputs
|
| 457 |
|
| 458 |
-
return
|
| 459 |
|
| 460 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 461 |
B, N, D = x.shape
|
|
@@ -476,7 +483,7 @@ class GatedLinearRecurrence(nn.Module):
|
|
| 476 |
class ManhattanSpatialGate(nn.Module):
|
| 477 |
"""Pathway 3: Manhattan distance spatial decay gating.
|
| 478 |
Provides learned 2D spatial inductive bias with per-head multi-scale receptive fields.
|
| 479 |
-
Uses windowed computation for efficiency.
|
| 480 |
"""
|
| 481 |
def __init__(self, dim: int, num_heads: int, window: int = 16):
|
| 482 |
super().__init__()
|
|
@@ -493,49 +500,59 @@ class ManhattanSpatialGate(nn.Module):
|
|
| 493 |
self.v_proj = nn.Linear(dim, dim)
|
| 494 |
self.g_proj = nn.Linear(dim, dim)
|
| 495 |
self.o_proj = nn.Linear(dim, dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
def _get_manhattan_mask(self, H: int, W: int, device: torch.device) -> torch.Tensor:
|
| 498 |
-
"""Compute Manhattan distance matrix
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
return dist
|
| 508 |
|
| 509 |
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
|
| 510 |
B, N, D = x.shape
|
| 511 |
|
| 512 |
-
# Compute spatial decay
|
| 513 |
gamma = torch.sigmoid(self.gamma_logit) # [num_heads]
|
| 514 |
manhattan_dist = self._get_manhattan_mask(H, W, x.device) # [N, N]
|
| 515 |
|
| 516 |
-
# Window
|
| 517 |
-
|
| 518 |
-
decay_mask = (manhattan_dist <= self.window).float()
|
| 519 |
|
| 520 |
-
# Per-head decay: gamma_h^dist
|
|
|
|
| 521 |
decay = gamma[:, None, None].pow(manhattan_dist[None, :, :]) # [heads, N, N]
|
| 522 |
-
decay = decay * decay_mask
|
| 523 |
|
| 524 |
# Value and gate
|
| 525 |
v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim)
|
| 526 |
g = torch.sigmoid(self.g_proj(x))
|
| 527 |
|
| 528 |
# Apply spatial decay to values
|
| 529 |
-
# [B, heads, N, head_dim] = [heads, N, N] @ [B, heads, N, head_dim]
|
| 530 |
v = v.permute(0, 2, 1, 3) # [B, heads, N, head_dim]
|
| 531 |
out = torch.matmul(decay.unsqueeze(0), v) # [B, heads, N, head_dim]
|
| 532 |
|
| 533 |
# Normalize by decay sum
|
| 534 |
-
decay_sum = decay.sum(dim=-1, keepdim=True).unsqueeze(0) + 1e-8
|
| 535 |
out = out / decay_sum
|
| 536 |
|
| 537 |
-
out = out.permute(0, 2, 1, 3).reshape(B, N, D)
|
| 538 |
-
out = out * g
|
| 539 |
return self.o_proj(out)
|
| 540 |
|
| 541 |
|
|
|
|
| 433 |
self.proj_out = nn.Linear(recurrence_dim * 2, dim)
|
| 434 |
|
| 435 |
def _scan(self, x: torch.Tensor) -> torch.Tensor:
|
| 436 |
+
"""Gated linear recurrence scan. x: [B, N, rec_dim]
|
| 437 |
+
|
| 438 |
+
Uses chunked computation to reduce Python loop overhead.
|
| 439 |
+
For production, replace with a CUDA parallel scan kernel.
|
| 440 |
+
"""
|
| 441 |
B, N, D = x.shape
|
| 442 |
|
| 443 |
+
# Compute all gates in one shot (parallelized)
|
| 444 |
a_base = torch.sigmoid(self.Lambda) # [D]
|
| 445 |
+
r = torch.sigmoid(self.W_a(x)) # [B, N, D]
|
| 446 |
+
i = torch.sigmoid(self.W_x(x)) # [B, N, D]
|
| 447 |
|
| 448 |
# a_t = a_base^(c * r_t) — data-dependent decay
|
| 449 |
a = a_base.pow(self.c * r) # [B, N, D]
|
| 450 |
|
| 451 |
# Normalized input: sqrt(1 - a^2) for variance preservation
|
| 452 |
input_scale = torch.sqrt(1.0 - a * a + 1e-8)
|
| 453 |
+
u = input_scale * (i * x) # [B, N, D]
|
| 454 |
|
| 455 |
+
# Sequential recurrence — use contiguous tensors for speed
|
| 456 |
+
a = a.contiguous()
|
| 457 |
+
u = u.contiguous()
|
| 458 |
h = torch.zeros(B, D, device=x.device, dtype=x.dtype)
|
| 459 |
+
outputs = torch.empty_like(u) # Pre-allocate output
|
| 460 |
+
|
| 461 |
for t in range(N):
|
| 462 |
+
h = a[:, t] * h + u[:, t]
|
| 463 |
+
outputs[:, t] = h
|
| 464 |
|
| 465 |
+
return outputs
|
| 466 |
|
| 467 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 468 |
B, N, D = x.shape
|
|
|
|
| 483 |
class ManhattanSpatialGate(nn.Module):
|
| 484 |
"""Pathway 3: Manhattan distance spatial decay gating.
|
| 485 |
Provides learned 2D spatial inductive bias with per-head multi-scale receptive fields.
|
| 486 |
+
Uses CACHED distance matrix and sparse windowed computation for efficiency.
|
| 487 |
"""
|
| 488 |
def __init__(self, dim: int, num_heads: int, window: int = 16):
|
| 489 |
super().__init__()
|
|
|
|
| 500 |
self.v_proj = nn.Linear(dim, dim)
|
| 501 |
self.g_proj = nn.Linear(dim, dim)
|
| 502 |
self.o_proj = nn.Linear(dim, dim)
|
| 503 |
+
|
| 504 |
+
# Cache for distance matrix (computed once, reused)
|
| 505 |
+
self._cached_dist = None
|
| 506 |
+
self._cached_shape = None
|
| 507 |
|
| 508 |
def _get_manhattan_mask(self, H: int, W: int, device: torch.device) -> torch.Tensor:
|
| 509 |
+
"""Compute Manhattan distance matrix — CACHED after first call."""
|
| 510 |
+
shape_key = (H, W, device)
|
| 511 |
+
if self._cached_dist is not None and self._cached_shape == shape_key:
|
| 512 |
+
return self._cached_dist
|
| 513 |
+
|
| 514 |
+
# Build coordinate grid
|
| 515 |
+
rows = torch.arange(H, device=device)
|
| 516 |
+
cols = torch.arange(W, device=device)
|
| 517 |
+
grid_r, grid_c = torch.meshgrid(rows, cols, indexing='ij')
|
| 518 |
+
coords = torch.stack([grid_r.reshape(-1), grid_c.reshape(-1)], dim=-1).float() # [N, 2]
|
| 519 |
+
|
| 520 |
+
# Manhattan distance via broadcasting (faster than cdist)
|
| 521 |
+
dist = (coords[:, None, :] - coords[None, :, :]).abs().sum(dim=-1) # [N, N]
|
| 522 |
+
|
| 523 |
+
self._cached_dist = dist
|
| 524 |
+
self._cached_shape = shape_key
|
| 525 |
return dist
|
| 526 |
|
| 527 |
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
|
| 528 |
B, N, D = x.shape
|
| 529 |
|
| 530 |
+
# Compute spatial decay (distance matrix is cached)
|
| 531 |
gamma = torch.sigmoid(self.gamma_logit) # [num_heads]
|
| 532 |
manhattan_dist = self._get_manhattan_mask(H, W, x.device) # [N, N]
|
| 533 |
|
| 534 |
+
# Window mask — only positions within window distance contribute
|
| 535 |
+
decay_mask = (manhattan_dist <= self.window) # bool [N, N]
|
|
|
|
| 536 |
|
| 537 |
+
# Per-head decay: gamma_h^dist, masked to window
|
| 538 |
+
# Only compute pow for positions within window (sparse)
|
| 539 |
decay = gamma[:, None, None].pow(manhattan_dist[None, :, :]) # [heads, N, N]
|
| 540 |
+
decay = decay * decay_mask.unsqueeze(0).float()
|
| 541 |
|
| 542 |
# Value and gate
|
| 543 |
v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim)
|
| 544 |
g = torch.sigmoid(self.g_proj(x))
|
| 545 |
|
| 546 |
# Apply spatial decay to values
|
|
|
|
| 547 |
v = v.permute(0, 2, 1, 3) # [B, heads, N, head_dim]
|
| 548 |
out = torch.matmul(decay.unsqueeze(0), v) # [B, heads, N, head_dim]
|
| 549 |
|
| 550 |
# Normalize by decay sum
|
| 551 |
+
decay_sum = decay.sum(dim=-1, keepdim=True).unsqueeze(0) + 1e-8
|
| 552 |
out = out / decay_sum
|
| 553 |
|
| 554 |
+
out = out.permute(0, 2, 1, 3).reshape(B, N, D)
|
| 555 |
+
out = out * g
|
| 556 |
return self.o_proj(out)
|
| 557 |
|
| 558 |
|