"""
Seismic Wavelet Encoder - Minimal diffusers-compatible implementation.
Corresponds to HWD_down4 from ldm.modules.encoders.C2f_CGuidedBlock
"""
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
class SeismicWaveletEncoder(ModelMixin, ConfigMixin):
    """
    Wavelet-based condition encoder for seismic records (diffusers-compatible).

    Two cascaded single-level Haar-wavelet downsampling stages, each followed
    by a 1x1 convolutional fusion of the four sub-bands:
      - stage 1: in_channels -> channels,     spatial size halved
      - stage 2: channels    -> out_channels, spatial size halved again
    Overall reduction is H x W -> H/4 x W/4 (e.g. 256x256 -> 64x64).

    Mirrors HWD_down4 from ldm.modules.encoders.C2f_CGuidedBlock.
    """
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 4,
        out_channels: int = 4,
        wavelet_type: str = "haar",
        mode: str = "reflect",
    ):
        super().__init__()
        # Each stage pairs a single-level DWT with a 1x1 conv that mixes the
        # four sub-bands (LL, LH, HL, HH) down to the target channel count.
        # Attribute names are kept as dwt1/fusion1/dwt2/fusion2 so checkpoint
        # state_dict keys remain compatible.
        self.dwt1 = DWTForward(J=1, wave=wavelet_type, mode=mode)
        self.fusion1 = self._make_fusion(in_channels * 4, channels)
        self.dwt2 = DWTForward(J=1, wave=wavelet_type, mode=mode)
        self.fusion2 = self._make_fusion(channels * 4, out_channels)

    @staticmethod
    def _make_fusion(in_ch: int, out_ch: int) -> nn.Sequential:
        """Build the 1x1 conv + BatchNorm + Tanh sub-band fusion head."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_ch),
            nn.Tanh(),
        )

    @staticmethod
    def _center_crop_like_legacy(x: torch.Tensor, out_size: int) -> torch.Tensor:
        """Center-crop the spatial dims to out_size, matching legacy HWD.forward.

        NOTE(review): the offset is derived from the width (dim 3) and reused
        for the height crop as well — exact behavior of the legacy code, which
        assumes square inputs.
        """
        start = (x.shape[3] // 2) - (out_size // 2)
        stop = start + out_size
        return x[:, :, start:stop, start:stop]

    def _stage(self, dwt, fusion, x: torch.Tensor) -> torch.Tensor:
        """Run one downsampling stage: DWT -> stack sub-bands -> crop -> fuse."""
        target = x.shape[2] // 2
        low, high = dwt(x)
        # high[0] has shape [B, C, 3, H/2, W/2]: the LH/HL/HH detail bands.
        lh, hl, hh = torch.unbind(high[0], dim=2)
        bands = torch.cat((low, lh, hl, hh), dim=1)  # [B, C*4, H/2, W/2]
        bands = self._center_crop_like_legacy(bands, target)
        return fusion(bands)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode a seismic record to latent features.

        Args:
            x: Input tensor [B, in_channels, H, W], typically [B, 1, 256, 256]
        Returns:
            Encoded features [B, out_channels, H/4, W/4], typically [B, 4, 64, 64]
        """
        x = self._stage(self.dwt1, self.fusion1, x)
        return self._stage(self.dwt2, self.fusion2, x)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Alias for forward, for semantic clarity."""
        return self.forward(x)