""" Physics-Informed Regularization for LiquidFlow. From: "Physics-Informed Diffusion Models" (Bastek & Sun, ICLR 2025) and "PID: Physics-Informed Diffusion for IR Image Generation" (Mao et al., 2024) Physics losses act as TRAINING-ONLY regularizers — they don't affect inference speed. The pattern: 1. During training: denoise to get x̂₀, compute physics residual, add to loss 2. During inference: no change at all Implemented physics constraints for image generation: A. Total Variation (TV) — penalizes non-smooth outputs L_TV = ||∇_x x̂₀||₁ + ||∇_y x̂₀||₁ → Enforces spatial smoothness, reduces artifacts B. Conservation of Intensity — mass conservation across image L_cons = ||mean(x̂₀) - E[mean(x_ref)]||² → Prevents intensity drift C. Spectral Regularizer — penalizes high-frequency noise L_spec = ||FFT_high(x̂₀)||² → Reduces checkerboard artifacts D. Gradient Magnitude Balance — prevents exploding gradients in dark regions L_grad = ||∇x̂₀||² (Sobolev regularization) → Stabilizes training in low-signal regions Pattern: L_total = L_diffusion + λ_TV * L_TV + λ_cons * L_cons + λ_spec * L_spec The virtual-observable paradigm (from PAD-Hand, 2026): Physics constraints are SOFT — they guide without requiring perfect satisfaction. """ import torch import torch.nn as nn import torch.nn.functional as F class PhysicsRegularizer(nn.Module): """ Physics-informed regularizer for image generation training. All losses are computed on the estimated clean sample x̂₀ during training. They are ADDITIVE regularizers — just add to the diffusion loss. Args: tv_weight: Total Variation weight (default 0.01) cons_weight: Conservation of intensity weight (default 0.001) spec_weight: Spectral regularizer weight (default 0.01) grad_weight: Gradient magnitude penalty weight (default 0.001) """ def __init__( self, tv_weight=0.01, cons_weight=0.001, spec_weight=0.01, grad_weight=0.001, ): super().__init__() self.tv_weight = tv_weight self.cons_weight = cons_weight self.spec_weight = spec_weight self.grad_weight = grad_weight # Running mean for intensity conservation self.register_buffer('intensity_mean', torch.tensor(0.0)) self.register_buffer('intensity_count', torch.tensor(0)) self.intensity_alpha = 0.99 # EMA decay def total_variation(self, x): """ Total Variation loss on image batch x. L_TV = mean(|x_{i+1,j} - x_{i,j}| + |x_{i,j+1} - x_{i,j}|) Args: x: [B, C, H, W] images Returns: tv_loss: scalar """ diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]) diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]) return diff_h.mean() + diff_w.mean() def conservation_intensity(self, x): """ Conservation of image intensity (mass). L_cons = (mean(x) - running_mean)^2 This prevents the generator from drifting into producing images that are too dark or too bright. Args: x: [B, C, H, W] images Returns: cons_loss: scalar """ batch_mean = x.mean() # Update running statistics if self.training: with torch.no_grad(): self.intensity_mean = ( self.intensity_alpha * self.intensity_mean + (1 - self.intensity_alpha) * batch_mean.detach() ) # Conservation loss: penalize deviation from running mean if self.intensity_count > 100: # Only after some warmup return ((batch_mean - self.intensity_mean) ** 2).mean() return torch.tensor(0.0, device=x.device) def spectral_regularizer(self, x): """ Spectral regularizer: penalize high-frequency content. Uses FFT and penalizes high-frequency components. This prevents high-frequency artifacts (checkerboard patterns). Args: x: [B, C, H, W] images Returns: spec_loss: scalar """ # 2D FFT x_fft = torch.fft.fft2(x) x_fft_shift = torch.fft.fftshift(x_fft) # Create high-frequency mask (center is low frequency) B, C, H, W = x.shape h_center, w_center = H // 2, W // 2 y, x_coord = torch.meshgrid( torch.arange(H, device=x.device), torch.arange(W, device=x.device), indexing='ij' ) dist = torch.sqrt((y - h_center) ** 2 + (x_coord - w_center) ** 2) # High frequency: distance > quarter of image size high_freq_mask = (dist > min(H, W) / 4).float() # Penalize high-frequency magnitude spec_mag = torch.abs(x_fft_shift) high_freq_energy = (spec_mag * high_freq_mask.unsqueeze(0).unsqueeze(0)).mean() return high_freq_energy def gradient_penalty(self, x): """ Sobolev gradient penalty. L_grad = ||∇x||² (mean squared gradient magnitude) This prevents the generator from creating regions where gradients explode (common in GAN-like training). For diffusion, this helps stabilize the noise prediction. Args: x: [B, C, H, W] images Returns: grad_loss: scalar """ grad_h = x[:, :, 1:, :] - x[:, :, :-1, :] grad_w = x[:, :, :, 1:] - x[:, :, :, :-1] grad_mag = (grad_h ** 2).mean() + (grad_w ** 2).mean() return grad_mag def forward(self, x0_hat, x_ref=None): """ Compute total physics loss. Args: x0_hat: Estimated clean image [B, C, H, W] x_ref: Optional ground truth reference (for intensity tracking) Returns: total_loss: Combined physics regularizer (scalar) loss_dict: Dict of individual losses """ losses = {} # Total Variation if self.tv_weight > 0: losses['tv'] = self.total_variation(x0_hat) # Conservation of Intensity if self.cons_weight > 0: losses['cons'] = self.conservation_intensity(x0_hat) # Spectral Regularizer if self.spec_weight > 0: losses['spec'] = self.spectral_regularizer(x0_hat) # Gradient Penalty if self.grad_weight > 0: losses['grad'] = self.gradient_penalty(x0_hat) # Weighted sum total = ( self.tv_weight * losses.get('tv', 0.0) + self.cons_weight * losses.get('cons', 0.0) + self.spec_weight * losses.get('spec', 0.0) + self.grad_weight * losses.get('grad', 0.0) ) return total, losses class DDIMEstimator: """ DDIM clean-sample estimator for physics loss computation. From the Bastek & Sun (ICLR 2025) pattern: x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t) This provides an estimate of the clean sample at training time without requiring full reverse diffusion. """ @staticmethod def estimate_x0(x_t, eps_pred, alpha_bar_t): """ Estimate clean sample from noisy sample and predicted noise. Args: x_t: Noisy sample [B, C, H, W] eps_pred: Predicted noise [B, C, H, W] alpha_bar_t: Cumulative product of alphas at timestep t [B] Returns: x0_hat: Estimated clean sample [B, C, H, W] """ alpha_bar_t = alpha_bar_t.reshape(-1, 1, 1, 1) x0_hat = (x_t - torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t) return x0_hat @staticmethod def estimate_noise(x_t, x0_hat, alpha_bar_t): """Reverse: estimate noise from clean sample.""" alpha_bar_t = alpha_bar_t.reshape(-1, 1, 1, 1) eps_pred = (x_t - torch.sqrt(alpha_bar_t) * x0_hat) / torch.sqrt(1 - alpha_bar_t) return eps_pred