| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
|
|
|
|
def ricker_wavelet(f0: float, nt: int, dt: float) -> np.ndarray:
    """Sample a Ricker (Mexican hat) wavelet on a time axis symmetric about zero.

    NumPy-only stand-in for ``pylops.utils.wavelets.ricker``; only the
    wavelet samples are returned.

    The non-negative half of the axis is ``0, dt, ..., (nt - 1) * dt``. If
    that half holds an even number of samples, its last sample is dropped
    (silently) before it is mirrored, so the mirrored axis always has an odd
    length with ``t = 0`` at the centre.

    Args:
        f0: Central frequency in Hz.
        nt: Number of samples on the non-negative half of the axis
            (including zero).
        dt: Time sampling interval in seconds.

    Returns:
        1-D array of wavelet amplitudes of length ``2 * m - 1``, where ``m``
        is the (possibly trimmed) number of non-negative samples. The centre
        sample equals 1.
    """
    half_axis = np.arange(nt) * dt

    # Keep an odd number of non-negative samples.
    if half_axis.size % 2 == 0:
        half_axis = half_axis[:-1]

    # Mirror everything except t = 0 onto the negative side.
    negative_side = -half_axis[:0:-1]
    t = np.concatenate((negative_side, half_axis), axis=0)

    # w(t) = (1 - 2 (pi f0 t)^2) * exp(-(pi f0 t)^2)
    phase_sq = (np.pi * f0 * t) ** 2
    return (1 - 2 * phase_sq) * np.exp(-phase_sq)
|
|
|
|
def build_convmtx(wavelet: np.ndarray, size: int) -> np.ndarray:
    """Build a square, centred Toeplitz matrix from a 1-D wavelet.

    NumPy-only stand-in for ``pylops.utils.signalprocessing.convmtx``
    restricted to a ``size x size`` window.

    Entry ``(i, c)`` of the result is ``wavelet[c - i + len(wavelet) // 2]``
    when that index is valid and 0 otherwise, i.e. tap ``j`` fills the
    diagonal at offset ``j - len(wavelet) // 2``. Taps whose diagonal lies
    entirely outside the matrix are ignored.

    Args:
        wavelet: 1-D wavelet samples (an ndarray or any sequence accepted by
            ``np.asarray``).
        size: Number of rows and columns of the output.

    Returns:
        Array of shape ``(size, size)`` with the dtype of ``wavelet``.

    Raises:
        ValueError: If ``wavelet`` is not one-dimensional.
    """
    taps = np.asarray(wavelet)
    if taps.ndim != 1:
        raise ValueError(f"wavelet must be 1-D, got {taps.ndim}-D input")

    offset = taps.size // 2
    matrix = np.zeros((size, size), dtype=taps.dtype)

    # Each tap occupies exactly one diagonal, so fill whole diagonals at once
    # instead of visiting every (row, tap) pair in Python.
    for j, w_val in enumerate(taps):
        shift = j - offset
        # Rows i for which column i + shift stays inside [0, size).
        rows = np.arange(max(0, -shift), min(size, size - shift))
        matrix[rows, rows + shift] = w_val

    return matrix
|
|
|
|
class OverthrustForwardOperator:
    """Differentiable seismic forward model matching OverthrustTrueimpDataset.

    Calling the operator on a normalised image performs, in order:

    1. undo the normalisation (``_inv_normal``) and exponentiate, giving the
       impedance ``z`` (so the input is treated as normalised log-impedance);
    2. compute reflectivity ``(z[i + 1] - z[i]) / (z[i + 1] + z[i])`` along
       the second-to-last axis, with the last sample forced to 0;
    3. left-multiply by the ``size x size`` wavelet matrix;
    4. divide by ``record_scale``.

    All steps are plain torch operations, so gradients flow back to the image.
    """

    # NOTE(review): the numeric defaults below look like dataset statistics;
    # confirm them against the dataset code before changing.
    def __init__(
        self,
        *,
        wavelet: np.ndarray,
        size: int = 256,
        normal_min: float = 5.0931,
        normal_max: float = 6.501110975896774,
        record_scale: float = 0.3215932963300079,
        normalize: str = "minmax",
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Precompute the operator matrices.

        Args:
            wavelet: 1-D wavelet passed to ``build_convmtx``.
            size: Length of the axis the operator acts on.
            normal_min: Lower bound used by the ``"minmax"`` de-normalisation.
            normal_max: Upper bound / scale used by both de-normalisations.
            record_scale: Divisor applied to the synthesised record.
            normalize: ``"minmax"`` or ``"max"``; checked when the operator
                is applied, not here.
            device: Where to store the matrices. ``None`` selects CUDA when
                available, otherwise CPU.
            dtype: Storage dtype of the matrices.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        wavelet_matrix = build_convmtx(wavelet, size)

        # First-difference (s1) and two-point-sum (s2) operators:
        # row i picks z[i + 1] -/+ z[i].
        shifted = np.eye(size, k=1)
        identity = np.eye(size, k=0)
        s1 = shifted - identity
        s2 = shifted + identity
        # The last sample has no successor; zero its row so the reflectivity
        # there is exactly 0 instead of -1.
        s1[-1] = 0
        s2[-1] = 0

        self.wavelet_matrix = torch.as_tensor(wavelet_matrix, device=device, dtype=dtype)
        self.s1 = torch.as_tensor(s1, device=device, dtype=dtype)
        self.s2 = torch.as_tensor(s2, device=device, dtype=dtype)
        self.normal_min = float(normal_min)
        self.normal_max = float(normal_max)
        self.record_scale = float(record_scale)
        self.normalize = normalize

    def _inv_normal(self, image: torch.Tensor) -> torch.Tensor:
        """Map a normalised image back to the un-normalised value range.

        Raises:
            ValueError: If ``self.normalize`` is neither ``"minmax"`` nor
                ``"max"``.
        """
        if self.normalize == "minmax":
            return image * (self.normal_max - self.normal_min) + self.normal_min
        if self.normalize == "max":
            return image * self.normal_max
        raise ValueError(f"Unsupported normalize: {self.normalize}")

    @staticmethod
    def _like(matrix: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        """Return ``matrix`` on the device and in the dtype of ``image``."""
        # ``Tensor.to`` returns the tensor itself when nothing has to change,
        # so this is free in the common matching case. A mismatch costs one
        # copy per call; construct the operator on the right device to avoid it.
        return matrix.to(device=image.device, dtype=image.dtype)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Synthesise the seismic record for a normalised image.

        Args:
            image: Tensor whose second-to-last axis has length ``size`` (the
                matrices are applied by left-multiplication and broadcast
                over any leading axes).

        Returns:
            Tensor with the same shape, dtype and device as ``image``.

        Raises:
            ValueError: If ``self.normalize`` is unsupported.
        """
        impedance = torch.exp(self._inv_normal(image))
        numerator = torch.matmul(self._like(self.s1, image), impedance)
        denominator = torch.matmul(self._like(self.s2, image), impedance)
        # The denominator is a sum of exponentials, hence non-negative; the
        # clamp guards the zeroed last row (and underflow) against 0 / 0.
        reflectivity = numerator / torch.clamp(denominator, min=1e-6)
        record = torch.matmul(self._like(self.wavelet_matrix, image), reflectivity)
        return record / self.record_scale
|
|