asdf98 committed
Commit 654d061 · verified · 1 Parent(s): dd9c2aa

Fix conv2d bf16 crash on T4: iris/model.py

Files changed (1): iris/model.py +18 -6
iris/model.py CHANGED

@@ -17,7 +17,11 @@ class Patchify(nn.Module):
     def forward(self, z):
         B, C, H, W = z.shape
         p = self.patch_size
-        z = self.dw_conv(z)
+        orig_dtype = z.dtype
+        # Run grouped conv in float32 — cuDNN lacks bf16 kernels for grouped convs on T4
+        with torch.amp.autocast(device_type='cuda', enabled=False):
+            z = self.dw_conv(z.float())
+        z = z.to(orig_dtype)
         H_tok, W_tok = H // p, W // p
         z = z.view(B, C, H_tok, p, W_tok, p).permute(0, 2, 4, 1, 3, 5).reshape(B, H_tok * W_tok, C * p * p)
         return self.proj(z), H_tok, W_tok

@@ -37,7 +41,11 @@ class Unpatchify(nn.Module):
         C = self.out_channels
         z = self.proj(tokens).view(B, H_tok, W_tok, C, p, p)
         z = z.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H_tok * p, W_tok * p)
-        return self.dw_conv(z)
+        # Run grouped conv in float32 — cuDNN lacks bf16 kernels for grouped convs on T4
+        orig_dtype = z.dtype
+        with torch.amp.autocast(device_type='cuda', enabled=False):
+            z = self.dw_conv(z.float())
+        return z.to(orig_dtype)
 
 
 class TinyDecoder(nn.Module):

@@ -55,10 +63,14 @@ class TinyDecoder(nn.Module):
         self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)
 
     def forward(self, z):
-        x = z
-        for stage in self.stages:
-            x = stage(x)
-        return torch.tanh(self.final(x))
+        # Run decoder convs in float32 — cuDNN lacks bf16 kernels on T4
+        orig_dtype = z.dtype
+        with torch.amp.autocast(device_type='cuda', enabled=False):
+            x = z.float()
+            for stage in self.stages:
+                x = stage(x)
+            x = torch.tanh(self.final(x))
+        return x.to(orig_dtype)
 
 
 class IRIS(nn.Module):
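
For reference, the pattern the commit applies can be seen in isolation below. This is a minimal sketch, not code from this repo: the module name FP32DepthwiseConv, its parameters, and the tensor sizes are made up for illustration; only the autocast-disable-plus-float32 trick mirrors the diff above. The idea is that a grouped (depthwise) conv is forced to execute in float32, sidestepping the missing bf16 cuDNN kernels on pre-Ampere GPUs such as the T4, while the caller still receives a tensor in its original dtype.

import torch
import torch.nn as nn

class FP32DepthwiseConv(nn.Module):
    """Hypothetical wrapper: a depthwise conv that always runs in float32.

    groups == channels is the grouped-conv case that hits the missing
    bf16 cuDNN kernels on T4-class (pre-Ampere) GPUs.
    """
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size,
                              padding=kernel_size // 2, groups=channels)

    def forward(self, x):
        orig_dtype = x.dtype
        # Leave autocast so the conv really executes in float32,
        # then cast back so downstream layers keep seeing the caller's dtype.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            y = self.conv(x.float())
        return y.to(orig_dtype)

if __name__ == "__main__":
    if torch.cuda.is_available():
        conv = FP32DepthwiseConv(16).cuda()
        x = torch.randn(2, 16, 32, 32, device="cuda", dtype=torch.bfloat16)
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            out = conv(x)
        print(out.dtype)  # torch.bfloat16: the conv ran in fp32, but the caller sees bf16
    else:
        print("Demo needs a CUDA device")

Note that this relies on the module's weights being kept in float32 (the usual setup under autocast mixed precision); if the whole model were cast to bf16 with .to(torch.bfloat16), the input and weight dtypes inside the wrapper would no longer match.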