"""
Seismic Wavelet Encoder - Minimal diffusers-compatible implementation.

Corresponds to HWD_down4 from ldm.modules.encoders.C2f_CGuidedBlock
"""

from typing import Optional

import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class SeismicWaveletEncoder(ModelMixin, ConfigMixin):
    """
    Lightweight wavelet-based condition encoder for seismic data.

    Architecture: two-stage single-level wavelet downsampling
    (Haar by default, see ``wavelet_type``)
    - Stage 1: in_channels -> channels, spatial /2
    - Stage 2: channels -> out_channels, spatial /2
    - Total: 256x256 -> 64x64 for the default configuration

    Each stage applies a single-level DWT, concatenates the low-pass band with
    the three high-frequency bands along the channel axis (4x channels),
    center-crops to half of the stage's input resolution, and fuses the result
    with a 1x1 Conv -> BatchNorm -> Tanh.

    Corresponds to HWD_down4 in C2f_CGuidedBlock.py
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 4,
        out_channels: int = 4,
        wavelet_type: str = "haar",
        mode: str = "reflect",
    ):
        super().__init__()
        # NOTE: the attribute names dwt1/fusion1/dwt2/fusion2 become state_dict
        # key prefixes; keep them stable so existing checkpoints still load.

        # Stage 1: in_channels -> channels
        self.dwt1 = DWTForward(J=1, wave=wavelet_type, mode=mode)
        self.fusion1 = nn.Sequential(
            nn.Conv2d(in_channels * 4, channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(channels),
            nn.Tanh(),
        )

        # Stage 2: channels -> out_channels
        self.dwt2 = DWTForward(J=1, wave=wavelet_type, mode=mode)
        self.fusion2 = nn.Sequential(
            nn.Conv2d(channels * 4, out_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.Tanh(),
        )

    @staticmethod
    def _center_crop_like_legacy(
        x: torch.Tensor, out_size: int, out_w: Optional[int] = None
    ) -> torch.Tensor:
        """Center-crop a DWT output, matching the crop in HWD.forward.

        Args:
            x: Tensor of shape [B, C, H, W].
            out_size: Target height (also used as the width when ``out_w``
                is None).
            out_w: Optional target width, for non-square inputs.

        Returns:
            The centered [B, C, out_size, out_w] window of ``x``. If ``x`` is
            smaller than the target along an axis, slicing truncates it.
        """
        out_h = out_size
        if out_w is None:
            out_w = out_size
        # Offsets are computed per axis so non-square inputs are cropped around
        # their true center. For square inputs this is identical to the legacy
        # crop, which reused the width-derived offset for both axes.
        top = max(0, (x.shape[2] // 2) - (out_h // 2))
        left = max(0, (x.shape[3] // 2) - (out_w // 2))
        return x[:, :, top:top + out_h, left:left + out_w]

    def _wavelet_stage(
        self, x: torch.Tensor, dwt: nn.Module, fusion: nn.Module
    ) -> torch.Tensor:
        """Run one stage: DWT -> concat subbands -> center crop -> fusion."""
        # Target size is half of this stage's input, per axis.
        out_h, out_w = x.shape[2] // 2, x.shape[3] // 2
        xll, xh = dwt(x)
        # xh[0] stacks the three high-frequency bands along dim 2:
        # [B, C, 3, H', W'] -> three [B, C, H', W'] tensors.
        xlh, xhl, xhh = torch.unbind(xh[0], dim=2)
        x = torch.cat([xll, xlh, xhl, xhh], dim=1)  # [B, C*4, H', W']
        x = self._center_crop_like_legacy(x, out_h, out_w)
        return fusion(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode seismic record to latent features.

        Args:
            x: Input tensor [B, in_channels, H, W], typically [B, 1, 256, 256]

        Returns:
            Encoded features [B, out_channels, H/4, W/4], typically [B, 4, 64, 64]
        """
        # Stage 1: [B, in_channels, H, W] -> [B, channels, H/2, W/2]
        x = self._wavelet_stage(x, self.dwt1, self.fusion1)
        # Stage 2: [B, channels, H/2, W/2] -> [B, out_channels, H/4, W/4]
        x = self._wavelet_stage(x, self.dwt2, self.fusion2)
        return x

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Alias for forward, for semantic clarity."""
        return self.forward(x)