"""
Physics-Informed Regularization for LiquidFlow.

From: "Physics-Informed Diffusion Models" (Bastek & Sun, ICLR 2025)
and "PID: Physics-Informed Diffusion for IR Image Generation" (Mao et al., 2024)

Physics losses act as TRAINING-ONLY regularizers — they don't affect
inference speed. The pattern:

1. During training: denoise to get x̂₀, compute physics residual, add to loss
2. During inference: no change at all
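
A minimal sketch of the training-time step in item 1, assuming an
epsilon-prediction model and a DDPM-style forward process. The names
`toy_model`, `tv_residual`, and `lam` are hypothetical stand-ins and not
part of this module; TV is used only as an example residual.

```python
import torch

torch.manual_seed(0)

def tv_residual(x):
    # Anisotropic total variation, used here as an example physics residual
    dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
    dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
    return dh + dw

# Hypothetical stand-in for the denoiser: a single conv predicting noise
toy_model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)

x0 = torch.rand(4, 3, 8, 8)                 # clean training batch
alpha_bar = torch.rand(4) * 0.8 + 0.1       # per-sample alpha-bar in (0.1, 0.9)
a = alpha_bar.reshape(-1, 1, 1, 1)
eps = torch.randn_like(x0)
x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps  # forward (noising) process

eps_pred = toy_model(x_t)                   # a real model would also take t
x0_hat = (x_t - (1 - a).sqrt() * eps_pred) / a.sqrt()

lam = 0.01                                  # assumed regularizer weight
loss_diffusion = torch.nn.functional.mse_loss(eps_pred, eps)
loss_total = loss_diffusion + lam * tv_residual(x0_hat)
loss_total.backward()                       # gradients flow through x0_hat
```

Sampling code never calls `tv_residual`, which is why inference cost is
unchanged.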

Implemented physics constraints for image generation:

A. Total Variation (TV) — penalizes non-smooth outputs
   L_TV = ||∇_x x̂₀||₁ + ||∇_y x̂₀||₁
   → Enforces spatial smoothness, reduces artifacts

B. Conservation of Intensity — mass conservation across image
   L_cons = (mean(x̂₀) - EMA[mean(x̂₀)])²
   → Prevents intensity drift

C. Spectral Regularizer — penalizes high-frequency noise
   L_spec = mean(|FFT_high(x̂₀)|)
   → Reduces checkerboard artifacts

D. Gradient Magnitude Balance — prevents exploding gradients in dark regions
   L_grad = ||∇x̂₀||² (Sobolev regularization)
   → Stabilizes training in low-signal regions
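
As a worked check of terms A to D, the sketch below evaluates each one by
hand on a single 2x2 image. The reference intensity `ref_mean` is an
assumed value for illustration; the spectral mask uses the same
quarter-size radius as `spectral_regularizer`.

```python
import torch

x = torch.tensor([[[[0.0, 1.0],
                    [2.0, 3.0]]]])  # [B=1, C=1, H=2, W=2]

dh = x[:, :, 1:, :] - x[:, :, :-1, :]   # vertical differences: [[2, 2]]
dw = x[:, :, :, 1:] - x[:, :, :, :-1]   # horizontal differences: [[1], [1]]

l_tv = dh.abs().mean() + dw.abs().mean()        # 2 + 1 = 3
l_grad = (dh ** 2).mean() + (dw ** 2).mean()    # 4 + 1 = 5

ref_mean = 1.0  # assumed reference intensity
l_cons = (x.mean() - ref_mean) ** 2             # (1.5 - 1.0)^2 = 0.25

# Spectral term: mean FFT magnitude outside a centered low-pass disc
H, W = x.shape[-2:]
spec = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1)).abs()
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
dist = ((yy - H // 2) ** 2 + (xx - W // 2) ** 2).float().sqrt()
mask = (dist > min(H, W) / 4).float()
l_spec = (spec * mask).mean()                   # (0 + 4 + 2 + 0) / 4 = 1.5
```

On the same image the squared-gradient term (5) exceeds TV (3), because
squaring weighs the larger vertical jump more heavily.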

Pattern: L_total = L_diffusion + λ_TV * L_TV + λ_cons * L_cons + λ_spec * L_spec + λ_grad * L_grad
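
To give a feel for the scale of the weighted sum, here is a small
arithmetic sketch using the default weights of PhysicsRegularizer and
including the gradient term D. The individual loss values and
`l_diffusion` are made up for illustration.

```python
# Default weights from PhysicsRegularizer; loss values are made up
weights = {"tv": 0.01, "cons": 0.001, "spec": 0.01, "grad": 0.001}
terms = {"tv": 3.0, "cons": 0.25, "spec": 1.5, "grad": 5.0}
l_diffusion = 0.5

l_physics = sum(weights[k] * terms[k] for k in weights)
l_total = l_diffusion + l_physics

print(round(l_physics, 5))  # 0.05025
print(round(l_total, 5))    # 0.55025
```

With these numbers the physics terms contribute roughly a tenth of the
diffusion loss, so they nudge training rather than dominate it.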

The virtual-observable paradigm (from PAD-Hand, 2026):
Physics constraints are SOFT — they guide without requiring perfect satisfaction.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class PhysicsRegularizer(nn.Module):
    """
    Physics-informed regularizer for image generation training.
    
    All losses are computed on the estimated clean sample x̂₀ during training.
    They are ADDITIVE regularizers — just add to the diffusion loss.
    
    Args:
        tv_weight: Total Variation weight (default 0.01)
        cons_weight: Conservation of intensity weight (default 0.001)
        spec_weight: Spectral regularizer weight (default 0.01)
        grad_weight: Gradient magnitude penalty weight (default 0.001)
    """
    
    def __init__(
        self,
        tv_weight=0.01,
        cons_weight=0.001,
        spec_weight=0.01,
        grad_weight=0.001,
    ):
        super().__init__()
        self.tv_weight = tv_weight
        self.cons_weight = cons_weight
        self.spec_weight = spec_weight
        self.grad_weight = grad_weight
        
        # Running mean for intensity conservation
        self.register_buffer('intensity_mean', torch.tensor(0.0))
        self.register_buffer('intensity_count', torch.tensor(0))
        self.intensity_alpha = 0.99  # EMA decay
    
    def total_variation(self, x):
        """
        Total Variation loss on image batch x.
        
        L_TV = mean(|x_{i+1,j} - x_{i,j}| + |x_{i,j+1} - x_{i,j}|)
        
        Args:
            x: [B, C, H, W] images
        Returns:
            tv_loss: scalar
        """
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return diff_h.mean() + diff_w.mean()
    
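    def conservation_intensity(self, x):
        """
        Conservation of image intensity (mass).
        
        L_cons = (mean(x) - running_mean)^2
        
        This prevents the generator from drifting into producing
        images that are too dark or too bright.
        
        Args:
            x: [B, C, H, W] images
        Returns:
            cons_loss: scalar (zero during the first 100 training steps)
        """
        batch_mean = x.mean()
        
        # Penalize deviation from the running mean, but only after warmup.
        # The loss uses the mean from previous steps, so the current batch
        # does not pull the target toward itself before being compared.
        if self.intensity_count > 100:
            loss = (batch_mean - self.intensity_mean) ** 2
        else:
            loss = x.new_zeros(())
        
        # Update running statistics (training only, no gradient)
        if self.training:
            with torch.no_grad():
                target = batch_mean.detach().to(self.intensity_mean.dtype)
                if self.intensity_count == 0:
                    # Seed the EMA with the first batch to avoid bias toward 0
                    self.intensity_mean.copy_(target)
                else:
                    self.intensity_mean.mul_(self.intensity_alpha).add_(
                        (1 - self.intensity_alpha) * target
                    )
                # Count updates so the warmup check above can eventually pass
                self.intensity_count += 1
        
        return loss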
    
    def spectral_regularizer(self, x):
        """
        Spectral regularizer: penalize high-frequency content.
        
        Uses FFT and penalizes high-frequency components.
        This prevents high-frequency artifacts (checkerboard patterns).
        
        Args:
            x: [B, C, H, W] images
        Returns:
            spec_loss: scalar
        """
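        # 2D FFT over the spatial dims; shift only H and W so the
        # zero-frequency bin sits at the center of each image
        x_fft = torch.fft.fft2(x)
        x_fft_shift = torch.fft.fftshift(x_fft, dim=(-2, -1))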
        
        # Create high-frequency mask (center is low frequency)
        B, C, H, W = x.shape
        h_center, w_center = H // 2, W // 2
        
        y, x_coord = torch.meshgrid(
            torch.arange(H, device=x.device),
            torch.arange(W, device=x.device),
            indexing='ij'
        )
        dist = torch.sqrt((y - h_center) ** 2 + (x_coord - w_center) ** 2)
        
        # High frequency: distance > quarter of image size
        high_freq_mask = (dist > min(H, W) / 4).float()
        
        # Penalize high-frequency magnitude
        spec_mag = torch.abs(x_fft_shift)
        high_freq_energy = (spec_mag * high_freq_mask.unsqueeze(0).unsqueeze(0)).mean()
        
        return high_freq_energy
    
    def gradient_penalty(self, x):
        """
        Sobolev gradient penalty.
        
        L_grad = ||∇x||² (mean squared gradient magnitude)
        
        This penalizes large spatial (finite-difference) gradients in the
        image; the squared penalty weighs sharp jumps more heavily than TV.
        For diffusion, the intent is to help stabilize the noise prediction.
        
        Args:
            x: [B, C, H, W] images
        Returns:
            grad_loss: scalar
        """
        grad_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        grad_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        
        grad_mag = (grad_h ** 2).mean() + (grad_w ** 2).mean()
        return grad_mag
    
    def forward(self, x0_hat, x_ref=None):
        """
        Compute total physics loss.
        
        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Optional ground truth reference. Currently unused; the
                intensity running mean is tracked from x0_hat.
        
        Returns:
            total_loss: Combined physics regularizer (scalar)
            loss_dict: Dict of individual losses
        """
        losses = {}
        
        # Total Variation
        if self.tv_weight > 0:
            losses['tv'] = self.total_variation(x0_hat)
        
        # Conservation of Intensity
        if self.cons_weight > 0:
            losses['cons'] = self.conservation_intensity(x0_hat)
        
        # Spectral Regularizer
        if self.spec_weight > 0:
            losses['spec'] = self.spectral_regularizer(x0_hat)
        
        # Gradient Penalty
        if self.grad_weight > 0:
            losses['grad'] = self.gradient_penalty(x0_hat)
        
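        # Weighted sum; start from a zero tensor so the result is a tensor
        # even when every term is disabled
        total = x0_hat.new_zeros(()) + (
            self.tv_weight * losses.get('tv', 0.0) +
            self.cons_weight * losses.get('cons', 0.0) +
            self.spec_weight * losses.get('spec', 0.0) +
            self.grad_weight * losses.get('grad', 0.0)
        )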
        
        return total, losses


class DDIMEstimator:
    """
    DDIM clean-sample estimator for physics loss computation.
    
    From the Bastek & Sun (ICLR 2025) pattern:
    x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)
    
    This provides an estimate of the clean sample at training time
    without requiring full reverse diffusion.
    """
    
    @staticmethod
    def estimate_x0(x_t, eps_pred, alpha_bar_t):
        """
        Estimate clean sample from noisy sample and predicted noise.
        
        Args:
            x_t: Noisy sample [B, C, H, W]
            eps_pred: Predicted noise [B, C, H, W]
            alpha_bar_t: Cumulative product of alphas at timestep t [B]
        
        Returns:
            x0_hat: Estimated clean sample [B, C, H, W]
        """
        alpha_bar_t = alpha_bar_t.reshape(-1, 1, 1, 1)
        x0_hat = (x_t - torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)
        return x0_hat
    
    @staticmethod
    def estimate_noise(x_t, x0_hat, alpha_bar_t):
        """Inverse of estimate_x0: recover the noise implied by x_t and x0_hat."""
        alpha_bar_t = alpha_bar_t.reshape(-1, 1, 1, 1)
        eps_pred = (x_t - torch.sqrt(alpha_bar_t) * x0_hat) / torch.sqrt(1 - alpha_bar_t)
        return eps_pred