File size: 3,121 Bytes
7e43d7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Seismic Wavelet Encoder - Minimal diffusers-compatible implementation.
Corresponds to HWD_down4 from ldm.modules.encoders.C2f_CGuidedBlock
"""
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class SeismicWaveletEncoder(ModelMixin, ConfigMixin):
    """
    Lightweight wavelet-based condition encoder for seismic data.

    Architecture: two cascaded single-level wavelet downsampling stages
    (Haar wavelet by default).
    - Stage 1: in_channels -> channels, spatial /2
    - Stage 2: channels -> out_channels, spatial /2
    - Total: e.g. 256x256 -> 64x64

    Each stage runs a one-level 2-D DWT, stacks the low-pass band and the
    three high-pass bands along the channel axis (4x the channels), crops the
    result to half the stage's input size, and fuses it with a
    1x1 convolution + BatchNorm + Tanh.

    Corresponds to HWD_down4 in C2f_CGuidedBlock.py
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 4,
        out_channels: int = 4,
        wavelet_type: str = "haar",
        mode: str = "reflect",
    ):
        """
        Build the two wavelet stages.

        Args:
            in_channels: Number of channels of the input tensor.
            channels: Number of channels produced by stage 1 (and consumed
                by stage 2).
            out_channels: Number of channels of the encoded output.
            wavelet_type: Wavelet name forwarded to ``DWTForward(wave=...)``.
            mode: Padding mode forwarded to ``DWTForward(mode=...)``.
        """
        super().__init__()

        # NOTE: the attribute names below (dwt1/fusion1/dwt2/fusion2) are the
        # state_dict keys; renaming them would break loading saved weights.

        # Stage 1: in_channels -> channels
        self.dwt1 = DWTForward(J=1, wave=wavelet_type, mode=mode)
        self.fusion1 = nn.Sequential(
            # The DWT yields 4 sub-bands per input channel, hence the "* 4".
            nn.Conv2d(in_channels * 4, channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(channels),
            nn.Tanh(),
        )

        # Stage 2: channels -> out_channels
        self.dwt2 = DWTForward(J=1, wave=wavelet_type, mode=mode)
        self.fusion2 = nn.Sequential(
            nn.Conv2d(channels * 4, out_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.Tanh(),
        )

    @staticmethod
    def _center_crop(x: torch.Tensor, out_h: int, out_w: int) -> torch.Tensor:
        """Return the centred ``out_h`` x ``out_w`` window of a [B, C, H, W] tensor.

        The start offset is computed independently for each spatial axis so
        that non-square tensors are cropped correctly.
        """
        top = (x.shape[2] // 2) - (out_h // 2)
        left = (x.shape[3] // 2) - (out_w // 2)
        return x[:, :, top:top + out_h, left:left + out_w]

    @staticmethod
    def _center_crop_like_legacy(x: torch.Tensor, out_size: int) -> torch.Tensor:
        """Match HWD.forward center crop after pytorch_wavelets DWTForward.

        Crops a centred ``out_size`` x ``out_size`` window. For square
        tensors this selects the same window as the previous implementation,
        which derived a single offset from the width and used it on both
        axes; the offset is now computed per axis.
        """
        return SeismicWaveletEncoder._center_crop(x, out_size, out_size)

    def _wavelet_stage(
        self,
        x: torch.Tensor,
        dwt: nn.Module,
        fusion: nn.Module,
    ) -> torch.Tensor:
        """Run one DWT -> concat sub-bands -> centre crop -> fusion stage.

        Args:
            x: Stage input of shape [B, C, H, W].
            dwt: Single-level forward wavelet transform for this stage.
            fusion: Module mapping the 4*C stacked sub-band channels to the
                stage's output channels.

        Returns:
            Tensor of shape [B, C_out, H // 2, W // 2].
        """
        # Target size is taken from the stage input, per axis, before the
        # transform changes the spatial dimensions.
        target_h = x.shape[2] // 2
        target_w = x.shape[3] // 2

        xll, xh = dwt(x)
        # xh[0] shape: [B, C, 3, H', W'] -> 3 high-freq bands
        xlh, xhl, xhh = torch.unbind(xh[0], dim=2)
        # Channel order (low-pass first, then the three high-pass bands) must
        # stay fixed: the fusion conv weights are tied to it.
        bands = torch.cat([xll, xlh, xhl, xhh], dim=1)  # [B, C*4, H', W']
        bands = self._center_crop(bands, target_h, target_w)
        return fusion(bands)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode seismic record to latent features.

        Args:
            x: Input tensor [B, in_channels, H, W], typically [B, 1, 256, 256].
                H and W do not need to be equal.

        Returns:
            Encoded features [B, out_channels, H/4, W/4], typically [B, 4, 64, 64]
        """
        x = self._wavelet_stage(x, self.dwt1, self.fusion1)  # [B, channels, H/2, W/2]
        x = self._wavelet_stage(x, self.dwt2, self.fusion2)  # [B, out_channels, H/4, W/4]
        return x

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Alias for forward, for semantic clarity."""
        return self.forward(x)