| """ |
| Seismic Wavelet Encoder - Minimal diffusers-compatible implementation. |
| Corresponds to HWD_down4 from ldm.modules.encoders.C2f_CGuidedBlock |
| """ |
| import torch |
| import torch.nn as nn |
| from pytorch_wavelets import DWTForward |
| from diffusers.models.modeling_utils import ModelMixin |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
|
|
class SeismicWaveletEncoder(ModelMixin, ConfigMixin):
    """
    Lightweight wavelet-based condition encoder for seismic data.

    Architecture: Two-stage Haar wavelet downsampling
    - Stage 1: in_channels -> channels, spatial /2
    - Stage 2: channels -> out_channels, spatial /2
    - Total: 256x256 -> 64x64

    Corresponds to HWD_down4 in C2f_CGuidedBlock.py

    Args:
        in_channels: Number of channels in the input seismic record.
        channels: Intermediate channel count produced by the first fusion conv.
        out_channels: Channel count of the encoded output.
        wavelet_type: Wavelet family name forwarded to ``DWTForward`` (e.g. "haar").
        mode: Signal-extension (padding) mode forwarded to ``DWTForward``.
    """

    # Required by diffusers ConfigMixin: filename used by save_config/from_config.
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 4,
        out_channels: int = 4,
        wavelet_type: str = "haar",
        mode: str = "reflect",
    ):
        super().__init__()

        # Stage 1: one DWT level turns each channel into 4 sub-bands
        # (LL, LH, HL, HH) at half resolution; a 1x1 conv fuses them.
        self.dwt1 = DWTForward(J=1, wave=wavelet_type, mode=mode)
        self.fusion1 = nn.Sequential(
            nn.Conv2d(in_channels * 4, channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(channels),
            nn.Tanh(),
        )

        # Stage 2: identical pattern, mapping to the final output width.
        self.dwt2 = DWTForward(J=1, wave=wavelet_type, mode=mode)
        self.fusion2 = nn.Sequential(
            nn.Conv2d(channels * 4, out_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.Tanh(),
        )

    @staticmethod
    def _center_crop_like_legacy(x: torch.Tensor, out_size: int) -> torch.Tensor:
        """Match HWD.forward center crop after pytorch_wavelets DWTForward.

        NOTE(review): the offset is computed from the width (shape[3]) but
        applied to both spatial dims, so this assumes square feature maps —
        consistent with the documented 256x256 input; confirm before feeding
        non-square data.
        """
        offset = (x.shape[3] // 2) - (out_size // 2)
        return x[:, :, offset:offset + out_size, offset:offset + out_size]

    def _wavelet_stage(
        self, x: torch.Tensor, dwt: nn.Module, fusion: nn.Module
    ) -> torch.Tensor:
        """Run one DWT -> concat sub-bands -> crop -> 1x1-fusion step.

        Halves the spatial resolution of ``x`` and remaps its channels
        through ``fusion``. Shared by both stages of ``forward``.
        """
        out_size = x.shape[2] // 2
        xll, xh = dwt(x)
        # xh[0] stacks the detail bands along dim 2: [B, C, 3, H', W'].
        xlh, xhl, xhh = torch.unbind(xh[0], dim=2)
        x = torch.cat([xll, xlh, xhl, xhh], dim=1)
        # DWT padding may leave the output slightly larger than H/2; crop back
        # exactly the way the legacy HWD module did.
        x = self._center_crop_like_legacy(x, out_size)
        return fusion(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode seismic record to latent features.

        Args:
            x: Input tensor [B, in_channels, H, W], typically [B, 1, 256, 256]

        Returns:
            Encoded features [B, out_channels, H/4, W/4], typically [B, 4, 64, 64]
        """
        x = self._wavelet_stage(x, self.dwt1, self.fusion1)
        x = self._wavelet_stage(x, self.dwt2, self.fusion2)
        return x

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Alias for forward, for semantic clarity."""
        return self.forward(x)
|
|