Fix conv2d bf16 crash on T4: iris/pde_ssm.py
Browse files- iris/pde_ssm.py +5 -3
iris/pde_ssm.py
CHANGED
|
@@ -183,9 +183,11 @@ class PDESSMBlock(nn.Module):
|
|
| 183 |
x_2d = x.view(B, H, W, D).permute(0, 3, 1, 2) # (B, D, H, W)
|
| 184 |
|
| 185 |
# Spectral mixing (global) + Local conv (residual) + Token diff (anti-smoothing)
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
| 189 |
mixed = self.token_diff(mixed)
|
| 190 |
|
| 191 |
# Back to sequence format
|
|
|
|
| 183 |
x_2d = x.view(B, H, W, D).permute(0, 3, 1, 2) # (B, D, H, W)
|
| 184 |
|
| 185 |
# Spectral mixing (global) + Local conv (residual) + Token diff (anti-smoothing)
|
| 186 |
+
# Run conv in float32 — grouped/1x1 convs lack bf16 cuDNN kernels on T4
|
| 187 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
| 188 |
+
spectral_out = self.spectral(x_2d.float())
|
| 189 |
+
local_out = self.local_conv(x_2d.float())
|
| 190 |
+
mixed = spectral_out.to(x.dtype) + local_out.to(x.dtype)
|
| 191 |
mixed = self.token_diff(mixed)
|
| 192 |
|
| 193 |
# Back to sequence format
|