Fix conv2d bf16 crash on T4: iris/blocks.py
Browse files - iris/blocks.py +4 -1
iris/blocks.py
CHANGED
|
@@ -134,7 +134,10 @@ class UIBFFN(nn.Module):
|
|
| 134 |
h = self.pw_up(x)
|
| 135 |
g = F.silu(self.gate(x))
|
| 136 |
h_2d = h.view(B, H, W, -1).permute(0, 3, 1, 2)
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
| 138 |
return residual + self.pw_down(h * g)
|
| 139 |
|
| 140 |
|
|
|
|
| 134 |
h = self.pw_up(x)
|
| 135 |
g = F.silu(self.gate(x))
|
| 136 |
h_2d = h.view(B, H, W, -1).permute(0, 3, 1, 2)
|
| 137 |
+
# Run depthwise conv in float32 — grouped convs lack bf16 cuDNN kernels on T4
|
| 138 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
| 139 |
+
h = self.dw_conv(h_2d.float()).permute(0, 2, 3, 1).reshape(B, N, -1)
|
| 140 |
+
h = h.to(g.dtype)
|
| 141 |
return residual + self.pw_down(h * g)
|
| 142 |
|
| 143 |
|