data-archetype committed on
Commit
6a12ad8
·
verified ·
1 Parent(s): e12172b

Fix bf16 standalone RMSNorm precision

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. fcdm_diffae/norms.py +4 -4
README.md CHANGED
@@ -17,6 +17,7 @@ library_name: fcdm_diffae
17
 
18
  | Date | Change |
19
  |------|--------|
 
20
  | 2026-04-08 | Initial release |
21
 
22
  **Experimental patch-32 version** of
 
17
 
18
  | Date | Change |
19
  |------|--------|
20
+ | 2026-04-10 | Refresh standalone package: fix bf16 RMSNorm precision path in both encoder and decoder to match training code; local export tooling now preserves fp32 EMA weights for future re-exports |
21
  | 2026-04-08 | Initial release |
22
 
23
  **Experimental patch-32 version** of
fcdm_diffae/norms.py CHANGED
@@ -30,10 +30,10 @@ class ChannelWiseRMSNorm(nn.Module):
30
  # Float32 accumulation for stability
31
  ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
32
  inv_rms = torch.rsqrt(ms + self._eps)
33
- y = x * inv_rms
34
  if self.weight is not None:
35
  shape = (1, -1) + (1,) * (x.dim() - 2)
36
- y = y * self.weight.view(shape).to(dtype=y.dtype)
37
  if self.bias is not None:
38
- y = y + self.bias.view(shape).to(dtype=y.dtype)
39
- return y.to(dtype=x.dtype)
 
30
  # Float32 accumulation for stability
31
  ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
32
  inv_rms = torch.rsqrt(ms + self._eps)
33
+ y = x * inv_rms.to(dtype=x.dtype)
34
  if self.weight is not None:
35
  shape = (1, -1) + (1,) * (x.dim() - 2)
36
+ y = y * self.weight.view(shape).to(dtype=x.dtype)
37
  if self.bias is not None:
38
+ y = y + self.bias.view(shape).to(dtype=x.dtype)
39
+ return y