from __future__ import annotations

import numpy as np
import torch


def ricker_wavelet(f0: float, nt: int, dt: float) -> np.ndarray:
    """Ricker (Mexican hat) wavelet - pure NumPy implementation.

    Intended as a drop-in for the wavelet returned by
    ``pylops.utils.wavelets.ricker`` so this module does not need PyLops.

    A positive time axis ``[0, dt, ..., (nt - 1) * dt]`` is built. If it has an
    even number of samples, its last sample is dropped so it has odd length.
    The axis is then mirrored around zero, and
    ``(1 - 2 (pi f0 t)^2) * exp(-(pi f0 t)^2)`` is evaluated on it.

    Args:
        f0: Central frequency in Hz.
        nt: Number of samples of the positive time axis, including zero.
        dt: Time sampling interval in seconds.

    Returns:
        1-D wavelet, exactly symmetric about its centre sample (value 1 at
        ``t = 0`` when non-empty). Its length is ``2 * nt - 1`` for odd
        ``nt``, ``2 * nt - 3`` for even ``nt > 0``, and 0 for ``nt <= 0``.
    """
    # Positive time axis, including t = 0.
    t_positive = np.arange(nt) * dt
    # Same cropping as pylops' _tcrop: an even-length axis loses its last
    # sample so the mirrored axis (and the wavelet) has odd length.
    if len(t_positive) % 2 == 0:
        t_positive = t_positive[:-1]
    # Negative times (without repeating t = 0) followed by the positive ones.
    t = np.concatenate((np.flipud(-t_positive[1:]), t_positive), axis=0)
    arg = (np.pi * f0 * t) ** 2
    return (1 - 2 * arg) * np.exp(-arg)


def build_convmtx(wavelet: np.ndarray, size: int) -> np.ndarray:
    """Build a square, centred convolution (Toeplitz) matrix - pure NumPy.

    Intended as a replacement for ``pylops.utils.signalprocessing.convmtx``
    with the wavelet centred on sample ``len(wavelet) // 2``. For a 1-D
    signal ``x`` of length ``size``::

        build_convmtx(w, size) @ x == np.convolve(x, w)[off:off + size]

    where ``off = len(w) // 2``. Entry ``(i, c)`` is ``w[i - c + off]`` when
    that index lies inside the wavelet, and 0 otherwise.

    Args:
        wavelet: 1-D wavelet (array-like; converted with ``np.asarray``).
        size: Signal length; the matrix is ``size x size``.

    Returns:
        Array of shape ``(size, size)`` with the wavelet's dtype.

    Raises:
        ValueError: If ``wavelet`` is not one-dimensional (or ``size`` is
            negative, via ``np.zeros``).
    """
    wavelet = np.asarray(wavelet)
    if wavelet.ndim != 1:
        raise ValueError(f"wavelet must be 1-D, got shape {wavelet.shape}")
    wlen = wavelet.shape[0]
    offset = wlen // 2
    matrix = np.zeros((size, size), dtype=wavelet.dtype)

    rows = np.arange(size)[:, np.newaxis]
    cols = np.arange(size)[np.newaxis, :]
    # Indexing by (i - c), not (c - i), makes this a true convolution rather
    # than a cross-correlation. The two coincide only for symmetric odd-length
    # wavelets such as those returned by ricker_wavelet.
    tap = rows - cols + offset
    inside = (tap >= 0) & (tap < wlen)
    matrix[inside] = wavelet[tap[inside]]
    return matrix


class OverthrustForwardOperator:
    """Differentiable seismic forward model matching OverthrustTrueimpDataset.

    Maps a normalised image to a synthetic record in three steps:

    1. Undo the normalisation (``_inv_normal``) and exponentiate, giving
       impedance ``I``. The de-normalised image is therefore log-impedance.
    2. Compute reflectivity ``r[k] = (I[k+1] - I[k]) / (I[k+1] + I[k])``,
       with the last sample set to 0.
    3. Convolve ``r`` with the wavelet (matrix from ``build_convmtx``) and
       divide by ``record_scale``.

    Every step is a torch operation, so gradients flow back to ``image``.
    The operators are applied by left matrix multiplication. The axis they
    act on is therefore the second-to-last dimension of ``image`` (or its
    only dimension for 1-D input), and it must have length ``size``.
    """

    def __init__(
        self,
        *,
        wavelet: np.ndarray,
        size: int = 256,
        normal_min: float = 5.0931,
        normal_max: float = 6.501110975896774,
        record_scale: float = 0.3215932963300079,
        normalize: str = "minmax",
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Precompute the difference, sum and wavelet matrices.

        Args:
            wavelet: 1-D wavelet passed to ``build_convmtx``.
            size: Length of the axis the operators act on.
            normal_min: Offset used by the ``"minmax"`` de-normalisation.
            normal_max: Upper bound for ``"minmax"``; scale for ``"max"``.
            record_scale: Divisor applied to the convolved record. The
                default values are dataset-specific constants.
            normalize: ``"minmax"`` or ``"max"``. Any other value raises
                ``ValueError`` when the operator is called.
            device: Device for the stored matrices. Defaults to CUDA when
                available, else CPU. Inputs on another device still work,
                because ``__call__`` moves the matrices to the input's device.
            dtype: Dtype of the stored matrices.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        wavelet_matrix = build_convmtx(wavelet, size)
        # s1 @ I gives I[k+1] - I[k], and s2 @ I gives I[k+1] + I[k].
        # Their last rows are zeroed because I[size] does not exist.
        s1 = np.eye(size, k=1) - np.eye(size, k=0)
        s2 = np.eye(size, k=1) + np.eye(size, k=0)
        s1[-1] = 0
        s2[-1] = 0

        self.wavelet_matrix = torch.as_tensor(wavelet_matrix, device=device, dtype=dtype)
        self.s1 = torch.as_tensor(s1, device=device, dtype=dtype)
        self.s2 = torch.as_tensor(s2, device=device, dtype=dtype)
        self.normal_min = float(normal_min)
        self.normal_max = float(normal_max)
        self.record_scale = float(record_scale)
        self.normalize = normalize

    def _inv_normal(self, image: torch.Tensor) -> torch.Tensor:
        """Undo the normalisation selected by ``self.normalize``.

        Raises:
            ValueError: If ``self.normalize`` is neither ``"minmax"`` nor
                ``"max"``.
        """
        if self.normalize == "minmax":
            return image * (self.normal_max - self.normal_min) + self.normal_min
        if self.normalize == "max":
            return image * self.normal_max
        raise ValueError(f"Unsupported normalize: {self.normalize}")

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Return the synthetic record for a normalised ``image``.

        The result has the same shape, device and dtype as
        ``exp(_inv_normal(image))``.
        """
        impedance = torch.exp(self._inv_normal(image))
        # Move the matrices to the input's device and dtype. `.to` is a no-op
        # when they already match; otherwise a CPU input would fail against
        # matrices built on the CUDA default device.
        target = {"device": impedance.device, "dtype": impedance.dtype}
        s1 = self.s1.to(**target)
        s2 = self.s2.to(**target)
        wavelet_matrix = self.wavelet_matrix.to(**target)

        numerator = torch.matmul(s1, impedance)
        denominator = torch.matmul(s2, impedance)
        # exp(...) is positive (barring underflow), so the denominator is zero
        # only on the zeroed last row, where the numerator is zero too. The
        # clamp turns that 0/0 into 0 instead of NaN.
        reflectivity = numerator / torch.clamp(denominator, min=1e-6)
        record = torch.matmul(wavelet_matrix, reflectivity)
        return record / self.record_scale