asdf98 committed on
Commit
3afddfb
·
verified ·
1 Parent(s): 359afd9

Delete luminars/ssm.py

Browse files
Files changed (1) hide show
  1. luminars/ssm.py +0 -113
luminars/ssm.py DELETED
@@ -1,113 +0,0 @@
1
- """
2
- Spatial Recurrent Block (SRB) -- inspired by RWKV + VMamba-UNet.
3
- Uses depthwise conv for spatial token-shift and channel-wise decay mixing.
4
- Pure PyTorch, no heavy deps.
5
- """
6
- import math
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
-
12
def rmsnorm(x, eps=1e-6):
    """Normalize ``x`` to unit root-mean-square over its last dimension.

    Args:
        x: Floating-point tensor of shape (..., dim).
        eps: Small constant added to the mean square for numerical stability.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # RMS is sqrt(mean(x^2)); squaring must happen BEFORE the mean. Squaring
    # the mean instead makes any zero-mean input blow up by ~1/sqrt(eps).
    mean_square = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(mean_square + eps)
14
-
15
-
16
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain.

    Args:
        dim: Size of the last (normalized) dimension.
        eps: Small constant added to the RMS to avoid division by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Scale ``x`` of shape (..., dim) by its RMS over the last dim, then by ``gamma``."""
        # ||x||_2 / sqrt(dim) equals the root-mean-square of the last dimension.
        # torch.linalg.vector_norm replaces the deprecated Tensor.norm / torch.norm.
        rms = torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])
        return self.gamma * x / (rms + self.eps)
25
-
26
-
27
class _DropPath(nn.Module):
    """Stochastic depth: zero a whole residual branch per sample while training."""

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob <= 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every remaining dim.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            # Rescale so the expected value of the branch is unchanged.
            mask = mask / keep_prob
        return x * mask


class SpatialRecurrentBlock(nn.Module):
    """Parallel, RWKV-style spatial mixing block.

    The block:
      1. Builds a spatially mixed "previous token" with a 3x3 depthwise conv.
      2. Blends it with the current token using a per-token, channel-wise
         decay, then gates the blend:
             y_i = sigmoid(gate_i) * (decay_i * prev_i + (1 - decay_i) * x_i)
      3. Adds the result to the input (residual), followed by a residual MLP.

    Every step is parallel over spatial positions; there is no sequential
    scan. The earlier draft of this block attempted a token-scanning SSM that
    relied on undefined helpers and mismatched shapes and never returned a
    value; this implementation follows the parallel formulation instead.

    Args:
        dim: Number of channels ``C`` of the input feature map.
        d_state: Accepted for backward compatibility with the original
            constructor signature; unused by the parallel formulation.
        drop_path: Stochastic-depth probability applied to both residual
            branches during training (0.0 disables it).
    """

    def __init__(self, dim, d_state=64, drop_path=0.0):
        super().__init__()
        self.dim = dim
        self.d_state = d_state

        # Spatial token shift (depthwise 3x3 conv): produces prev_i.
        self.spatial_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
        self.spatial_norm = RMSNorm(dim)

        # Input-dependent decay and gate logits, one pair per token and channel.
        self.mix_proj = nn.Linear(dim, dim * 2)

        # Post-MLP
        self.mlp = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
        )

        # Drop path (stochastic depth). Implemented locally because the file
        # neither defines nor imports a DropPath class.
        self.drop_path = _DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape (B, C, H, W) with ``C == dim``.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        # --- SPATIAL TOKEN SHIFT ---
        prev = self.spatial_conv(x)  # (B, C, H, W)

        # Channels-last so that the norm and linear layers act on the channel dim.
        x_tok = x.permute(0, 2, 3, 1)  # (B, H, W, C)
        prev_tok = prev.permute(0, 2, 3, 1)  # (B, H, W, C)

        # --- GATED DECAY MIX ---
        logits = self.mix_proj(self.spatial_norm(x_tok))  # (B, H, W, 2C)
        decay_logit, gate_logit = logits.chunk(2, dim=-1)
        decay = torch.sigmoid(decay_logit)
        mixed = torch.sigmoid(gate_logit) * (decay * prev_tok + (1.0 - decay) * x_tok)

        # --- RESIDUALS ---
        out = x_tok + self.drop_path(mixed)
        out = out + self.drop_path(self.mlp(out))

        return out.permute(0, 3, 1, 2).contiguous()  # back to (B, C, H, W)