asdf98 commited on
Commit
8c17293
·
verified ·
1 Parent(s): 70ea268

Fix conv2d bf16 crash on T4: iris/pde_ssm.py

Browse files
Files changed (1) hide show
  1. iris/pde_ssm.py +5 -3
iris/pde_ssm.py CHANGED
@@ -183,9 +183,11 @@ class PDESSMBlock(nn.Module):
183
  x_2d = x.view(B, H, W, D).permute(0, 3, 1, 2) # (B, D, H, W)
184
 
185
  # Spectral mixing (global) + Local conv (residual) + Token diff (anti-smoothing)
186
- spectral_out = self.spectral(x_2d)
187
- local_out = self.local_conv(x_2d)
188
- mixed = spectral_out + local_out
 
 
189
  mixed = self.token_diff(mixed)
190
 
191
  # Back to sequence format
 
183
  x_2d = x.view(B, H, W, D).permute(0, 3, 1, 2) # (B, D, H, W)
184
 
185
  # Spectral mixing (global) + Local conv (residual) + Token diff (anti-smoothing)
186
+ # Run conv in float32 — grouped/1x1 convs lack bf16 cuDNN kernels on T4
187
+ with torch.amp.autocast(device_type='cuda', enabled=False):
188
+ spectral_out = self.spectral(x_2d.float())
189
+ local_out = self.local_conv(x_2d.float())
190
+ mixed = spectral_out.to(x.dtype) + local_out.to(x.dtype)
191
  mixed = self.token_diff(mixed)
192
 
193
  # Back to sequence format