| """ |
| Physics-Informed Regularization for LiquidFlow. |
| |
| From: "Physics-Informed Diffusion Models" (Bastek & Sun, ICLR 2025) |
| and "PID: Physics-Informed Diffusion for IR Image Generation" (Mao et al., 2024) |
| |
| Physics losses act as TRAINING-ONLY regularizers — they don't affect |
| inference speed. The pattern: |
| |
| 1. During training: denoise to get x̂₀, compute physics residual, add to loss |
| 2. During inference: no change at all |
| |
| Implemented physics constraints for image generation: |
| |
| A. Total Variation (TV) — penalizes non-smooth outputs |
| L_TV = ||∇_x x̂₀||₁ + ||∇_y x̂₀||₁ |
| → Enforces spatial smoothness, reduces artifacts |
| |
| B. Conservation of Intensity — mass conservation across image |
| L_cons = ||mean(x̂₀) - E[mean(x_ref)]||² |
| → Prevents intensity drift |
| |
| C. Spectral Regularizer — penalizes high-frequency noise |
| L_spec = ||FFT_high(x̂₀)||² |
| → Reduces checkerboard artifacts |
| |
| D. Gradient Magnitude Balance — prevents exploding gradients in dark regions |
| L_grad = ||∇x̂₀||² (Sobolev regularization) |
| → Stabilizes training in low-signal regions |
| |
Pattern: L_total = L_diffusion + λ_TV * L_TV + λ_cons * L_cons + λ_spec * L_spec + λ_grad * L_grad
| |
| The virtual-observable paradigm (from PAD-Hand, 2026): |
| Physics constraints are SOFT — they guide without requiring perfect satisfaction. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class PhysicsRegularizer(nn.Module):
    """
    Physics-informed regularizer for image generation training.

    All losses are computed on the estimated clean sample x̂₀ during training.
    They are ADDITIVE regularizers — just add to the diffusion loss.

    Args:
        tv_weight: Total Variation weight (default 0.01)
        cons_weight: Conservation of intensity weight (default 0.001)
        spec_weight: Spectral regularizer weight (default 0.01)
        grad_weight: Gradient magnitude penalty weight (default 0.001)
    """

    def __init__(
        self,
        tv_weight=0.01,
        cons_weight=0.001,
        spec_weight=0.01,
        grad_weight=0.001,
    ):
        super().__init__()
        self.tv_weight = tv_weight
        self.cons_weight = cons_weight
        self.spec_weight = spec_weight
        self.grad_weight = grad_weight

        # EMA of the batch intensity mean, plus an update counter used as a
        # warm-up gate. Registered as buffers so they travel with the module's
        # state_dict and device placement, but receive no gradients.
        self.register_buffer('intensity_mean', torch.tensor(0.0))
        self.register_buffer('intensity_count', torch.tensor(0))
        self.intensity_alpha = 0.99

    def total_variation(self, x):
        """
        Total Variation loss on image batch x.

        L_TV = mean(|x_{i+1,j} - x_{i,j}| + |x_{i,j+1} - x_{i,j}|)

        Args:
            x: [B, C, H, W] images
        Returns:
            tv_loss: scalar
        """
        # Anisotropic TV: mean absolute forward difference along H and W.
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return diff_h.mean() + diff_w.mean()

    def conservation_intensity(self, x):
        """
        Conservation of image intensity (mass).

        L_cons = (mean(x) - running_mean)^2

        This prevents the generator from drifting into producing
        images that are too dark or too bright.

        Args:
            x: [B, C, H, W] images
        Returns:
            cons_loss: scalar (0 until the EMA has warmed up for >100 steps)
        """
        batch_mean = x.mean()

        if self.training:
            with torch.no_grad():
                self.intensity_mean = (
                    self.intensity_alpha * self.intensity_mean +
                    (1 - self.intensity_alpha) * batch_mean.detach()
                )
                # BUG FIX: the counter was registered but never incremented,
                # so the warm-up gate below never opened and this loss was
                # permanently zero. Count each EMA update.
                self.intensity_count += 1

        # Only penalize once the running mean has seen enough batches to be a
        # meaningful reference; an early EMA (initialized at 0) would drag
        # outputs toward black.
        if self.intensity_count > 100:
            return ((batch_mean - self.intensity_mean) ** 2).mean()
        return torch.tensor(0.0, device=x.device)

    def spectral_regularizer(self, x):
        """
        Spectral regularizer: penalize high-frequency content.

        Uses FFT and penalizes high-frequency components.
        This prevents high-frequency artifacts (checkerboard patterns).

        Args:
            x: [B, C, H, W] images
        Returns:
            spec_loss: scalar
        """
        # Shift the zero-frequency (DC) component to the spatial center so a
        # radial distance mask can separate low from high frequencies.
        x_fft = torch.fft.fft2(x)
        x_fft_shift = torch.fft.fftshift(x_fft)

        B, C, H, W = x.shape
        h_center, w_center = H // 2, W // 2

        # Radial distance of every frequency bin from the (shifted) DC bin.
        y, x_coord = torch.meshgrid(
            torch.arange(H, device=x.device),
            torch.arange(W, device=x.device),
            indexing='ij'
        )
        dist = torch.sqrt((y - h_center) ** 2 + (x_coord - w_center) ** 2)

        # Everything beyond a quarter of the smaller image dimension counts
        # as "high frequency".
        high_freq_mask = (dist > min(H, W) / 4).float()

        # Mean spectral magnitude inside the high-frequency band.
        spec_mag = torch.abs(x_fft_shift)
        high_freq_energy = (spec_mag * high_freq_mask.unsqueeze(0).unsqueeze(0)).mean()

        return high_freq_energy

    def gradient_penalty(self, x):
        """
        Sobolev gradient penalty.

        L_grad = ||∇x||² (mean squared gradient magnitude)

        This prevents the generator from creating regions where
        gradients explode (common in GAN-like training).
        For diffusion, this helps stabilize the noise prediction.

        Args:
            x: [B, C, H, W] images
        Returns:
            grad_loss: scalar
        """
        # Squared (rather than absolute) forward differences — an L2 analogue
        # of the TV term above.
        grad_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        grad_w = x[:, :, :, 1:] - x[:, :, :, :-1]

        grad_mag = (grad_h ** 2).mean() + (grad_w ** 2).mean()
        return grad_mag

    def forward(self, x0_hat, x_ref=None):
        """
        Compute total physics loss.

        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Optional ground truth reference (for intensity tracking);
                currently unused, kept for interface compatibility.

        Returns:
            total_loss: Combined physics regularizer (scalar)
            loss_dict: Dict of individual losses
        """
        losses = {}

        # Each term is skipped entirely (not just zero-weighted) when its
        # weight is 0, avoiding wasted computation (notably the FFT).
        if self.tv_weight > 0:
            losses['tv'] = self.total_variation(x0_hat)

        if self.cons_weight > 0:
            losses['cons'] = self.conservation_intensity(x0_hat)

        if self.spec_weight > 0:
            losses['spec'] = self.spectral_regularizer(x0_hat)

        if self.grad_weight > 0:
            losses['grad'] = self.gradient_penalty(x0_hat)

        total = (
            self.tv_weight * losses.get('tv', 0.0) +
            self.cons_weight * losses.get('cons', 0.0) +
            self.spec_weight * losses.get('spec', 0.0) +
            self.grad_weight * losses.get('grad', 0.0)
        )

        return total, losses
|
|
|
|
class DDIMEstimator:
    """
    DDIM clean-sample estimator for physics loss computation.

    From the Bastek & Sun (ICLR 2025) pattern:
        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)

    Gives a one-shot estimate of the clean sample during training, with no
    full reverse-diffusion pass required.
    """

    @staticmethod
    def estimate_x0(x_t, eps_pred, alpha_bar_t):
        """
        Recover the estimated clean sample from a noisy sample and noise.

        Args:
            x_t: Noisy sample [B, C, H, W]
            eps_pred: Predicted noise [B, C, H, W]
            alpha_bar_t: Cumulative product of alphas at timestep t [B]

        Returns:
            x0_hat: Estimated clean sample [B, C, H, W]
        """
        # Broadcast the per-sample schedule coefficient over C, H, W.
        ab = alpha_bar_t.reshape(-1, 1, 1, 1)
        noise_scale = (1 - ab).sqrt()
        signal_scale = ab.sqrt()
        return (x_t - noise_scale * eps_pred) / signal_scale

    @staticmethod
    def estimate_noise(x_t, x0_hat, alpha_bar_t):
        """Reverse: estimate noise from clean sample."""
        ab = alpha_bar_t.reshape(-1, 1, 1, 1)
        return (x_t - ab.sqrt() * x0_hat) / (1 - ab).sqrt()
|
|