File size: 3,703 Bytes
d105891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import annotations

import numpy as np
import torch


def ricker_wavelet(f0: float, nt: int, dt: float) -> np.ndarray:
    """Return a zero-phase Ricker ("Mexican hat") wavelet using only NumPy.

    Stand-in for ``pylops.utils.wavelets.ricker``. The wavelet is evaluated on
    a symmetric time axis built from ``nt`` one-sided samples spaced ``dt``
    apart. An even ``nt`` first loses its last sample, so the result always has
    odd length with its peak (value 1) at the centre.

    Args:
        f0: Central (peak) frequency in Hz.
        nt: Number of one-sided time samples, counting t = 0.
        dt: Sampling interval in seconds.

    Returns:
        1D array of wavelet amplitudes on the symmetric time axis.
    """
    # One-sided samples 0, dt, 2*dt, ...
    half_axis = np.arange(nt) * dt

    # Same rule as pylops' _tcrop: an even count drops its final sample.
    if half_axis.size % 2 == 0:
        half_axis = half_axis[:-1]

    # Reversed negative branch (t = 0 excluded), then the positive branch.
    negative_branch = -half_axis[1:][::-1]
    time_axis = np.concatenate((negative_branch, half_axis), axis=0)

    arg = (np.pi * f0 * time_axis) ** 2
    return (1 - 2 * arg) * np.exp(-arg)


def build_convmtx(wavelet: np.ndarray, size: int) -> np.ndarray:
    """Build a square, centred convolution (Toeplitz) matrix - pure NumPy.

    Stand-in for ``pylops.utils.signalprocessing.convmtx`` trimmed to a
    ``size x size`` output, with the filter centred on ``len(wavelet) // 2``.
    Entry ``[i, c]`` is ``wavelet[i - c + len(wavelet) // 2]``, or zero when
    that index falls outside the wavelet. Hence ``matrix @ x`` is the
    convolution of ``x`` with ``wavelet``. For an odd-length wavelet and
    ``len(x) >= len(wavelet)``, this equals
    ``np.convolve(x, wavelet, mode="same")``. For symmetric odd-length
    wavelets (e.g. a Ricker wavelet), the kernel flip inherent to convolution
    has no visible effect.

    Args:
        wavelet: 1D wavelet (any array-like; converted with ``np.asarray``).
        size: Output matrix size (size x size).

    Returns:
        Convolution matrix of shape (size, size) with the wavelet's dtype.
    """
    wavelet = np.asarray(wavelet)
    wlen = len(wavelet)
    offset = wlen // 2

    idx = np.arange(size)
    # Wavelet tap landing on (row i, column c) for a true convolution.
    taps = idx[:, None] - idx[None, :] + offset
    inside = (taps >= 0) & (taps < wlen)

    matrix = np.zeros((size, size), dtype=wavelet.dtype)
    matrix[inside] = wavelet[taps[inside]]
    return matrix


class OverthrustForwardOperator:
    """Differentiable seismic forward model matching OverthrustTrueimpDataset."""

    def __init__(
        self,
        *,
        wavelet: np.ndarray,
        size: int = 256,
        normal_min: float = 5.0931,
        normal_max: float = 6.501110975896774,
        record_scale: float = 0.3215932963300079,
        normalize: str = "minmax",
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ):
        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        wavelet_matrix = build_convmtx(wavelet, size)
        s1 = np.eye(size, k=1) - np.eye(size, k=0)
        s2 = np.eye(size, k=1) + np.eye(size, k=0)
        s1[-1] = 0
        s2[-1] = 0

        self.wavelet_matrix = torch.as_tensor(wavelet_matrix, device=device, dtype=dtype)
        self.s1 = torch.as_tensor(s1, device=device, dtype=dtype)
        self.s2 = torch.as_tensor(s2, device=device, dtype=dtype)
        self.normal_min = float(normal_min)
        self.normal_max = float(normal_max)
        self.record_scale = float(record_scale)
        self.normalize = normalize

    def _inv_normal(self, image: torch.Tensor) -> torch.Tensor:
        if self.normalize == "minmax":
            return image * (self.normal_max - self.normal_min) + self.normal_min
        if self.normalize == "max":
            return image * self.normal_max
        raise ValueError(f"Unsupported normalize: {self.normalize}")

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        impedance = torch.exp(self._inv_normal(image))
        numerator = torch.matmul(self.s1.to(dtype=image.dtype), impedance)
        denominator = torch.matmul(self.s2.to(dtype=image.dtype), impedance)
        reflectivity = numerator / torch.clamp(denominator, min=1e-6)
        record = torch.matmul(self.wavelet_matrix.to(dtype=image.dtype), reflectivity)
        return record / self.record_scale