Fix bf16 standalone RMSNorm precision
Browse files- README.md +1 -0
- fcdm_diffae/norms.py +4 -4
README.md
CHANGED
|
@@ -17,6 +17,7 @@ library_name: fcdm_diffae
|
|
| 17 |
|
| 18 |
| Date | Change |
|
| 19 |
|------|--------|
|
|
|
|
| 20 |
| 2026-04-08 | Initial release |
|
| 21 |
|
| 22 |
**Experimental patch-32 version** of
|
|
|
|
| 17 |
|
| 18 |
| Date | Change |
|
| 19 |
|------|--------|
|
| 20 |
+
| 2026-04-10 | Refresh standalone package: fix bf16 RMSNorm precision path in both encoder and decoder to match training code; local export tooling now preserves fp32 EMA weights for future re-exports |
|
| 21 |
| 2026-04-08 | Initial release |
|
| 22 |
|
| 23 |
**Experimental patch-32 version** of
|
fcdm_diffae/norms.py
CHANGED
|
@@ -30,10 +30,10 @@ class ChannelWiseRMSNorm(nn.Module):
|
|
| 30 |
# Float32 accumulation for stability
|
| 31 |
ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
|
| 32 |
inv_rms = torch.rsqrt(ms + self._eps)
|
| 33 |
-
y = x * inv_rms
|
| 34 |
if self.weight is not None:
|
| 35 |
shape = (1, -1) + (1,) * (x.dim() - 2)
|
| 36 |
-
y = y * self.weight.view(shape).to(dtype=y.dtype)
|
| 37 |
if self.bias is not None:
|
| 38 |
-
y = y + self.bias.view(shape).to(dtype=y.dtype)
|
| 39 |
-
return y
|
|
|
|
| 30 |
# Float32 accumulation for stability
|
| 31 |
ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
|
| 32 |
inv_rms = torch.rsqrt(ms + self._eps)
|
| 33 |
+
y = x * inv_rms.to(dtype=x.dtype)
|
| 34 |
if self.weight is not None:
|
| 35 |
shape = (1, -1) + (1,) * (x.dim() - 2)
|
| 36 |
+
y = y * self.weight.view(shape).to(dtype=x.dtype)
|
| 37 |
if self.bias is not None:
|
| 38 |
+
y = y + self.bias.view(shape).to(dtype=x.dtype)
|
| 39 |
+
return y
|