asdf98 committed on
Commit
dd9c2aa
·
verified ·
1 Parent(s): 8c17293

Fix conv2d bf16 crash on T4: iris/blocks.py

Browse files
Files changed (1) hide show
  1. iris/blocks.py +4 -1
iris/blocks.py CHANGED
@@ -134,7 +134,10 @@ class UIBFFN(nn.Module):
134
  h = self.pw_up(x)
135
  g = F.silu(self.gate(x))
136
  h_2d = h.view(B, H, W, -1).permute(0, 3, 1, 2)
137
- h = self.dw_conv(h_2d).permute(0, 2, 3, 1).reshape(B, N, -1)
 
 
 
138
  return residual + self.pw_down(h * g)
139
 
140
 
 
134
  h = self.pw_up(x)
135
  g = F.silu(self.gate(x))
136
  h_2d = h.view(B, H, W, -1).permute(0, 3, 1, 2)
137
+ # Run depthwise conv in float32: grouped convs lack bf16 cuDNN kernels on T4
138
+ with torch.amp.autocast(device_type='cuda', enabled=False):
139
+ h = self.dw_conv(h_2d.float()).permute(0, 2, 3, 1).reshape(B, N, -1)
140
+ h = h.to(g.dtype)
141
  return residual + self.pw_down(h * g)
142
 
143