krystv committed on
Commit
1a3345c
·
verified ·
1 Parent(s): b9e8cb3

Upload liquid_flow/physics_loss.py

Browse files
Files changed (1) hide show
  1. liquid_flow/physics_loss.py +249 -0
liquid_flow/physics_loss.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Physics-Informed Regularization for LiquidFlow.
3
+
4
+ From: "Physics-Informed Diffusion Models" (Bastek & Sun, ICLR 2025)
5
+ and "PID: Physics-Informed Diffusion for IR Image Generation" (Mao et al., 2024)
6
+
7
+ Physics losses act as TRAINING-ONLY regularizers — they don't affect
8
+ inference speed. The pattern:
9
+
10
+ 1. During training: denoise to get x̂₀, compute physics residual, add to loss
11
+ 2. During inference: no change at all
12
+
13
+ Implemented physics constraints for image generation:
14
+
15
+ A. Total Variation (TV) — penalizes non-smooth outputs
16
+ L_TV = ||∇_x x̂₀||₁ + ||∇_y x̂₀||₁
17
+ → Enforces spatial smoothness, reduces artifacts
18
+
19
+ B. Conservation of Intensity — mass conservation across image
20
+ L_cons = ||mean(x̂₀) - E[mean(x_ref)]||²
21
+ → Prevents intensity drift
22
+
23
+ C. Spectral Regularizer — penalizes high-frequency noise
24
+ L_spec = mean(|FFT_high(x̂₀)|)   (mean magnitude of the high-frequency bins)
25
+ → Reduces checkerboard artifacts
26
+
27
+ D. Gradient Magnitude Balance — prevents exploding gradients in dark regions
28
+ L_grad = ||∇x̂₀||² (Sobolev regularization)
29
+ → Stabilizes training in low-signal regions
30
+
31
+ Pattern: L_total = L_diffusion + λ_TV * L_TV + λ_cons * L_cons + λ_spec * L_spec + λ_grad * L_grad
32
+
33
+ The virtual-observable paradigm (from PAD-Hand, 2026):
34
+ Physics constraints are SOFT — they guide without requiring perfect satisfaction.
35
+ """
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+
41
+
42
class PhysicsRegularizer(nn.Module):
    """
    Physics-informed regularizer for image generation training.

    All losses are computed on the estimated clean sample x̂₀ during training.
    They are ADDITIVE regularizers — just add to the diffusion loss.

    Args:
        tv_weight: Total Variation weight (default 0.01)
        cons_weight: Conservation of intensity weight (default 0.001)
        spec_weight: Spectral regularizer weight (default 0.01)
        grad_weight: Gradient magnitude penalty weight (default 0.001)

    Attributes:
        intensity_mean: 0-dim buffer holding the EMA of the mean intensity.
        intensity_count: 0-dim integer buffer counting EMA updates.
        intensity_alpha: EMA decay used for ``intensity_mean``.
        intensity_warmup: number of EMA updates that must be exceeded before
            the conservation loss becomes non-zero.
    """

    def __init__(
        self,
        tv_weight=0.01,
        cons_weight=0.001,
        spec_weight=0.01,
        grad_weight=0.001,
    ):
        super().__init__()
        self.tv_weight = tv_weight
        self.cons_weight = cons_weight
        self.spec_weight = spec_weight
        self.grad_weight = grad_weight

        # Running mean for intensity conservation
        self.register_buffer('intensity_mean', torch.tensor(0.0))
        self.register_buffer('intensity_count', torch.tensor(0))
        self.intensity_alpha = 0.99  # EMA decay
        self.intensity_warmup = 100  # updates required before the loss is active

    @staticmethod
    def _safe_mean(t):
        """Return ``t.mean()``, or a zero scalar when ``t`` has no elements."""
        # A size-1 spatial axis yields an empty difference tensor, whose
        # mean would be NaN and poison the whole loss.
        if t.numel() == 0:
            return t.new_zeros(())
        return t.mean()

    def total_variation(self, x):
        """
        Total Variation loss on image batch x.

        L_TV = mean(|x_{i+1,j} - x_{i,j}|) + mean(|x_{i,j+1} - x_{i,j}|)

        Args:
            x: [B, C, H, W] images
        Returns:
            tv_loss: scalar (a direction with a single pixel contributes 0)
        """
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return self._safe_mean(diff_h) + self._safe_mean(diff_w)

    def conservation_intensity(self, x, x_ref=None):
        """
        Conservation of image intensity (mass).

        L_cons = (mean(x) - running_mean)^2

        In training mode the running mean is an exponential moving average
        of ``mean(x_ref)`` when a reference batch is given, otherwise of
        ``mean(x)``. The first update seeds the average directly so that it
        is not biased toward its 0.0 initial value. The loss stays at zero
        until more than ``intensity_warmup`` updates have been recorded.

        Args:
            x: [B, C, H, W] images
            x_ref: optional reference images used to update the running mean
        Returns:
            cons_loss: scalar
        """
        batch_mean = x.mean()

        # Update running statistics
        if self.training:
            with torch.no_grad():
                source = x if x_ref is None else x_ref
                target = source.detach().mean().to(
                    device=self.intensity_mean.device,
                    dtype=self.intensity_mean.dtype,
                )
                if int(self.intensity_count) == 0:
                    self.intensity_mean.copy_(target)
                else:
                    self.intensity_mean.mul_(self.intensity_alpha).add_(
                        target, alpha=1 - self.intensity_alpha
                    )
                # Without this increment the warm-up gate below never opens.
                self.intensity_count.add_(1)

        # Conservation loss: penalize deviation from running mean
        if int(self.intensity_count) > self.intensity_warmup:
            running = self.intensity_mean.to(device=batch_mean.device)
            return (batch_mean - running) ** 2
        return x.new_zeros(())

    def spectral_regularizer(self, x):
        """
        Spectral regularizer: penalize high-frequency content.

        Takes the 2D FFT over the last two axes, centers the zero frequency,
        and returns the mean spectral magnitude after zeroing every bin whose
        distance from the center is at most ``min(H, W) / 4``. The mean is
        taken over all bins, including the zeroed ones.

        Args:
            x: [..., H, W] images (typically [B, C, H, W])
        Returns:
            spec_loss: scalar
        """
        # Run the transform in float32 for reduced-precision inputs.
        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.float()

        H, W = x.shape[-2], x.shape[-1]

        # Shift only the spatial axes; the default would also roll the
        # batch and channel axes.
        x_fft = torch.fft.fft2(x)
        x_fft_shift = torch.fft.fftshift(x_fft, dim=(-2, -1))

        # Distance of every bin from the (shifted) zero-frequency bin
        rows = torch.arange(H, device=x.device) - H // 2
        cols = torch.arange(W, device=x.device) - W // 2
        dist = torch.sqrt((rows[:, None] ** 2 + cols[None, :] ** 2).float())

        # High frequency: distance > quarter of image size
        spec_mag = torch.abs(x_fft_shift)
        high_freq_mask = (dist > min(H, W) / 4).to(spec_mag.dtype)

        # [H, W] mask broadcasts over the leading axes
        return (spec_mag * high_freq_mask).mean()

    def gradient_penalty(self, x):
        """
        Sobolev gradient penalty.

        L_grad = mean(grad_h^2) + mean(grad_w^2), using forward differences.

        Args:
            x: [B, C, H, W] images
        Returns:
            grad_loss: scalar (a direction with a single pixel contributes 0)
        """
        grad_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        grad_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        return self._safe_mean(grad_h ** 2) + self._safe_mean(grad_w ** 2)

    def forward(self, x0_hat, x_ref=None):
        """
        Compute total physics loss.

        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Optional ground truth reference (for intensity tracking)

        Returns:
            total_loss: Combined physics regularizer (scalar tensor, also
                when every weight is zero)
            loss_dict: Dict of the individual unweighted losses; a term is
                only present when its weight is > 0
        """
        losses = {}

        # Total Variation
        if self.tv_weight > 0:
            losses['tv'] = self.total_variation(x0_hat)

        # Conservation of Intensity
        if self.cons_weight > 0:
            losses['cons'] = self.conservation_intensity(x0_hat, x_ref)

        # Spectral Regularizer
        if self.spec_weight > 0:
            losses['spec'] = self.spectral_regularizer(x0_hat)

        # Gradient Penalty
        if self.grad_weight > 0:
            losses['grad'] = self.gradient_penalty(x0_hat)

        # Weighted sum, started from a tensor so the return type is stable
        weights = {
            'tv': self.tv_weight,
            'cons': self.cons_weight,
            'spec': self.spec_weight,
            'grad': self.grad_weight,
        }
        total = x0_hat.new_zeros(())
        for name, value in losses.items():
            total = total + weights[name] * value

        return total, losses
214
+
215
+
216
class DDIMEstimator:
    """
    DDIM clean-sample estimator for physics loss computation.

    From the Bastek & Sun (ICLR 2025) pattern:
        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)

    This provides an estimate of the clean sample at training time
    without requiring full reverse diffusion.
    """

    @staticmethod
    def _expand_alpha_bar(alpha_bar_t, x_t):
        """Reshape ``alpha_bar_t`` to [B, 1, ..., 1] so it broadcasts over ``x_t``."""
        # Accept plain Python numbers as well as tensors.
        if not torch.is_tensor(alpha_bar_t):
            alpha_bar_t = torch.as_tensor(
                alpha_bar_t, dtype=x_t.dtype, device=x_t.device
            )
        trailing = max(x_t.dim() - 1, 0)
        return alpha_bar_t.reshape(-1, *([1] * trailing))

    @staticmethod
    def estimate_x0(x_t, eps_pred, alpha_bar_t, eps=1e-8):
        """
        Estimate clean sample from noisy sample and predicted noise.

        Args:
            x_t: Noisy sample [B, ...] (typically [B, C, H, W])
            eps_pred: Predicted noise, same shape as ``x_t``
            alpha_bar_t: Cumulative product of alphas at timestep t, [B] or
                a scalar
            eps: Lower bound applied to ᾱ_t before the division so that
                ᾱ_t = 0 does not produce inf/NaN

        Returns:
            x0_hat: Estimated clean sample, same shape as ``x_t``
        """
        alpha_bar_t = DDIMEstimator._expand_alpha_bar(alpha_bar_t, x_t)
        # clamp_min(0) guards the sqrt against ᾱ_t marginally above 1.
        noise_scale = torch.sqrt((1 - alpha_bar_t).clamp_min(0))
        signal_scale = torch.sqrt(alpha_bar_t.clamp_min(eps))
        x0_hat = (x_t - noise_scale * eps_pred) / signal_scale
        return x0_hat

    @staticmethod
    def estimate_noise(x_t, x0_hat, alpha_bar_t, eps=1e-8):
        """
        Reverse: estimate noise from clean sample.

        Args:
            x_t: Noisy sample [B, ...]
            x0_hat: Clean-sample estimate, same shape as ``x_t``
            alpha_bar_t: Cumulative product of alphas at timestep t, [B] or
                a scalar
            eps: Lower bound applied to (1 - ᾱ_t) before the division so
                that ᾱ_t = 1 does not produce inf/NaN

        Returns:
            eps_pred: Noise estimate, same shape as ``x_t``
        """
        alpha_bar_t = DDIMEstimator._expand_alpha_bar(alpha_bar_t, x_t)
        signal_scale = torch.sqrt(alpha_bar_t.clamp_min(0))
        noise_scale = torch.sqrt((1 - alpha_bar_t).clamp_min(eps))
        eps_pred = (x_t - signal_scale * x0_hat) / noise_scale
        return eps_pred