# LiquidFlow-Gen / liquid_flow/physics_loss.py
# Uploaded by krystv (commit 1a3345c, verified): "Upload liquid_flow/physics_loss.py"
"""
Physics-Informed Regularization for LiquidFlow.
From: "Physics-Informed Diffusion Models" (Bastek & Sun, ICLR 2025)
and "PID: Physics-Informed Diffusion for IR Image Generation" (Mao et al., 2024)
Physics losses act as TRAINING-ONLY regularizers — they don't affect
inference speed. The pattern:
1. During training: denoise to get x̂₀, compute physics residual, add to loss
2. During inference: no change at all
Implemented physics constraints for image generation:
A. Total Variation (TV) — penalizes non-smooth outputs
L_TV = ||∇_x x̂₀||₁ + ||∇_y x̂₀||₁
→ Enforces spatial smoothness, reduces artifacts
B. Conservation of Intensity — mass conservation across image
L_cons = ||mean(x̂₀) - E[mean(x_ref)]||²
→ Prevents intensity drift
C. Spectral Regularizer — penalizes high-frequency noise
L_spec = mean(|FFT_high(x̂₀)|)
→ Reduces checkerboard artifacts
D. Gradient Magnitude Balance — prevents exploding gradients in dark regions
L_grad = ||∇x̂₀||² (Sobolev regularization)
→ Stabilizes training in low-signal regions
Pattern: L_total = L_diffusion + λ_TV * L_TV + λ_cons * L_cons + λ_spec * L_spec + λ_grad * L_grad
The virtual-observable paradigm (from PAD-Hand, 2026):
Physics constraints are SOFT — they guide without requiring perfect satisfaction.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PhysicsRegularizer(nn.Module):
    """
    Physics-informed regularizer for image generation training.

    All losses are computed on the estimated clean sample x̂₀ during training.
    They are ADDITIVE regularizers — just add the returned total to the
    diffusion loss.

    Args:
        tv_weight: Total Variation weight (default 0.01)
        cons_weight: Conservation of intensity weight (default 0.001)
        spec_weight: Spectral regularizer weight (default 0.01)
        grad_weight: Gradient magnitude penalty weight (default 0.001)

    Attributes:
        intensity_alpha: EMA decay used for the running intensity mean.
        intensity_warmup: Number of running-mean updates that must have
            happened before the conservation loss becomes non-zero.
    """

    # Class-level default so instances created before this attribute existed
    # (e.g. unpickled modules) still resolve it.
    intensity_warmup = 100

    def __init__(
        self,
        tv_weight=0.01,
        cons_weight=0.001,
        spec_weight=0.01,
        grad_weight=0.001,
    ):
        super().__init__()
        self.tv_weight = tv_weight
        self.cons_weight = cons_weight
        self.spec_weight = spec_weight
        self.grad_weight = grad_weight
        # Running mean for intensity conservation and the number of updates
        # applied to it. Registered as buffers so they move with the module
        # and are saved in its state_dict.
        self.register_buffer('intensity_mean', torch.tensor(0.0))
        self.register_buffer('intensity_count', torch.tensor(0))
        self.intensity_alpha = 0.99  # EMA decay

    def total_variation(self, x):
        """
        Total Variation loss on image batch x.

        L_TV = mean(|x_{i+1,j} - x_{i,j}|) + mean(|x_{i,j+1} - x_{i,j}|)

        Args:
            x: [B, C, H, W] images
        Returns:
            tv_loss: scalar
        """
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return diff_h.mean() + diff_w.mean()

    def conservation_intensity(self, x, x_ref=None):
        """
        Conservation of image intensity (mass).

        L_cons = (mean(x) - running_mean)^2

        In training mode the running mean is updated (EMA) from ``x_ref`` when
        it is given, otherwise from ``x`` itself. The loss stays at zero until
        more than ``intensity_warmup`` updates have been applied.

        Args:
            x: [B, C, H, W] images
            x_ref: Optional reference images whose mean intensity should be
                tracked instead of the mean of ``x``.
        Returns:
            cons_loss: scalar
        """
        batch_mean = x.mean()

        # Update running statistics
        if self.training:
            with torch.no_grad():
                source = x if x_ref is None else x_ref
                observed = source.detach().mean().to(self.intensity_mean.dtype)
                if int(self.intensity_count) == 0:
                    # Seed the EMA with the first observation; starting from
                    # the 0.0 initial value would bias it toward zero.
                    self.intensity_mean.copy_(observed)
                else:
                    self.intensity_mean.mul_(self.intensity_alpha).add_(
                        observed, alpha=1 - self.intensity_alpha
                    )
                # Without this increment the warmup check below can never pass.
                self.intensity_count.add_(1)

        # Conservation loss: penalize deviation from running mean
        if int(self.intensity_count) > self.intensity_warmup:  # Only after some warmup
            return (batch_mean - self.intensity_mean) ** 2
        return x.new_zeros(())

    def spectral_regularizer(self, x):
        """
        Spectral regularizer: penalize high-frequency content.

        Takes the 2D FFT over the last two dimensions and returns the mean
        spectral magnitude, with every frequency bin closer than
        ``min(H, W) / 4`` to the (shifted) spectrum center zeroed out.

        Args:
            x: [B, C, H, W] images
        Returns:
            spec_loss: scalar
        """
        # torch.fft has only limited half-precision support, so compute the
        # spectrum in float32 for reduced-precision inputs.
        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.float()

        H, W = x.shape[-2], x.shape[-1]

        # 2D FFT over the spatial axes; shift only those axes so the zero
        # frequency sits at (H // 2, W // 2) and batch/channel order is kept.
        spectrum = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))

        # Create high-frequency mask (center is low frequency)
        rows = torch.arange(H, device=x.device, dtype=torch.float32).view(H, 1) - H // 2
        cols = torch.arange(W, device=x.device, dtype=torch.float32).view(1, W) - W // 2
        dist = torch.sqrt(rows ** 2 + cols ** 2)
        # High frequency: distance > quarter of image size
        high_freq_mask = (dist > min(H, W) / 4).float()

        # Penalize high-frequency magnitude; the [H, W] mask broadcasts over
        # the leading batch and channel dimensions.
        return (torch.abs(spectrum) * high_freq_mask).mean()

    def gradient_penalty(self, x):
        """
        Sobolev gradient penalty.

        L_grad = mean((x_{i+1,j} - x_{i,j})^2) + mean((x_{i,j+1} - x_{i,j})^2)

        Args:
            x: [B, C, H, W] images
        Returns:
            grad_loss: scalar
        """
        grad_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        grad_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        return (grad_h ** 2).mean() + (grad_w ** 2).mean()

    def forward(self, x0_hat, x_ref=None):
        """
        Compute total physics loss.

        A term is evaluated only when its weight is > 0; disabled terms are
        absent from the returned dict.

        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Optional ground truth reference (for intensity tracking)
        Returns:
            total_loss: Combined physics regularizer (scalar tensor)
            loss_dict: Dict of individual (unweighted) losses
        """
        weights = {
            'tv': self.tv_weight,
            'cons': self.cons_weight,
            'spec': self.spec_weight,
            'grad': self.grad_weight,
        }
        losses = {}

        # Total Variation
        if self.tv_weight > 0:
            losses['tv'] = self.total_variation(x0_hat)
        # Conservation of Intensity
        if self.cons_weight > 0:
            losses['cons'] = self.conservation_intensity(x0_hat, x_ref)
        # Spectral Regularizer
        if self.spec_weight > 0:
            losses['spec'] = self.spectral_regularizer(x0_hat)
        # Gradient Penalty
        if self.grad_weight > 0:
            losses['grad'] = self.gradient_penalty(x0_hat)

        # Weighted sum; start from a tensor so the result is a tensor even
        # when every term is disabled.
        total = x0_hat.new_zeros(())
        for name, value in losses.items():
            total = total + weights[name] * value
        return total, losses
class DDIMEstimator:
    """
    DDIM clean-sample estimator for physics loss computation.

    From the Bastek & Sun (ICLR 2025) pattern:
        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)

    This provides an estimate of the clean sample at training time
    without requiring full reverse diffusion.

    Note: no clamping is applied, so ᾱ_t = 0 (in ``estimate_x0``) or
    ᾱ_t = 1 (in ``estimate_noise``) divides by zero.
    """

    @staticmethod
    def _broadcast_alpha_bar(alpha_bar_t, like):
        """Shape ``alpha_bar_t`` as [B, 1, ..., 1] so it broadcasts against ``like``."""
        if not torch.is_tensor(alpha_bar_t):
            # Accept plain Python/NumPy scalars or sequences as well.
            alpha_bar_t = torch.as_tensor(
                alpha_bar_t, dtype=like.dtype, device=like.device
            )
        # One trailing singleton axis per non-batch dimension of ``like``;
        # for 4-D inputs this is exactly reshape(-1, 1, 1, 1).
        trailing = [1] * (like.dim() - 1)
        return alpha_bar_t.reshape(-1, *trailing)

    @staticmethod
    def estimate_x0(x_t, eps_pred, alpha_bar_t):
        """
        Estimate clean sample from noisy sample and predicted noise.

        Args:
            x_t: Noisy sample [B, C, H, W] (any rank with a leading batch
                dimension is accepted)
            eps_pred: Predicted noise, same shape as x_t
            alpha_bar_t: Cumulative product of alphas at timestep t; a tensor
                of shape [B] (or a single scalar shared by the whole batch)
        Returns:
            x0_hat: Estimated clean sample, same shape as x_t
        """
        alpha_bar_t = DDIMEstimator._broadcast_alpha_bar(alpha_bar_t, x_t)
        x0_hat = (x_t - torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)
        return x0_hat

    @staticmethod
    def estimate_noise(x_t, x0_hat, alpha_bar_t):
        """
        Reverse: estimate noise from clean sample.

        Args:
            x_t: Noisy sample (leading batch dimension)
            x0_hat: Clean sample (or its estimate), same shape as x_t
            alpha_bar_t: Cumulative product of alphas at timestep t; a tensor
                of shape [B] (or a single scalar shared by the whole batch)
        Returns:
            eps_pred: Noise implied by x_t and x0_hat, same shape as x_t
        """
        alpha_bar_t = DDIMEstimator._broadcast_alpha_bar(alpha_bar_t, x_t)
        eps_pred = (x_t - torch.sqrt(alpha_bar_t) * x0_hat) / torch.sqrt(1 - alpha_bar_t)
        return eps_pred