File size: 3,703 Bytes
d105891 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | from __future__ import annotations
import numpy as np
import torch
def ricker_wavelet(f0: float, nt: int, dt: float) -> np.ndarray:
    """Evaluate a Ricker (Mexican hat) wavelet on a symmetric time axis.

    NumPy-only stand-in for ``pylops.utils.wavelets.ricker`` (per the original
    author's note); only the wavelet samples are returned.

    The one-sided axis ``0, dt, ..., (nt - 1) * dt`` is built first. When it
    holds an even number of samples the last one is dropped, and the result is
    mirrored about zero so that ``t = 0`` sits at the centre of the output.

    Args:
        f0: Central frequency in Hz.
        nt: Number of samples on the non-negative half of the time axis
            (zero included).
        dt: Time sampling interval in seconds.

    Returns:
        1-D array of wavelet amplitudes. For a non-empty half axis of ``m``
        samples (after the even-length crop) the result has ``2 * m - 1``
        samples.
    """
    half_axis = np.arange(nt) * dt

    # Crop to an odd number of one-sided samples.
    if half_axis.shape[0] % 2 == 0:
        half_axis = half_axis[:-1]

    # Negative branch: the strictly positive times, negated and reversed.
    negative_branch = np.flipud(-half_axis[1:])
    full_axis = np.concatenate((negative_branch, half_axis), axis=0)

    # w(t) = (1 - 2 a) * exp(-a), with a = (pi * f0 * t)^2
    arg_sq = (np.pi * f0 * full_axis) ** 2
    return (1 - 2 * arg_sq) * np.exp(-arg_sq)
def build_convmtx(wavelet: np.ndarray, size: int) -> np.ndarray:
    """Build a square banded (Toeplitz) matrix from a 1-D wavelet.

    Written as a NumPy-only stand-in for
    ``pylops.utils.signalprocessing.convmtx`` (per the original author's note;
    equivalence with pylops is not checked here).

    Entry ``[i, c]`` holds ``wavelet[c - i + len(wavelet) // 2]`` whenever that
    index is valid and is zero otherwise, i.e. every row carries the wavelet
    with its middle sample on the main diagonal and anything falling outside
    the matrix is truncated. For a wavelet that is symmetric about its middle
    sample, multiplying by this matrix equals a "same"-length convolution.

    Args:
        wavelet: 1-D wavelet samples. Any array-like accepted by
            ``np.asarray`` works (lists and tuples included).
        size: Number of rows and columns of the output matrix.

    Returns:
        Array of shape ``(size, size)`` with the dtype of ``wavelet``.

    Raises:
        ValueError: If ``wavelet`` is not one-dimensional.
    """
    taps = np.asarray(wavelet)
    if taps.ndim != 1:
        raise ValueError(f"wavelet must be 1-D, got {taps.ndim} dimension(s)")

    centre = len(taps) // 2
    matrix = np.zeros((size, size), dtype=taps.dtype)

    # Each wavelet sample occupies one diagonal; fill that diagonal in a single
    # vectorized assignment instead of visiting every (row, sample) pair.
    for j, tap in enumerate(taps):
        shift = j - centre  # column minus row on this diagonal
        # Rows for which the column ``row + shift`` stays inside [0, size).
        # The range is empty when the diagonal lies outside the matrix.
        rows = np.arange(max(0, -shift), min(size, size - shift))
        matrix[rows, rows + shift] = tap
    return matrix
class OverthrustForwardOperator:
    """Differentiable seismic forward model matching OverthrustTrueimpDataset.

    Calling the operator on a normalized image performs, in order: undo the
    normalization, exponentiate to obtain impedance, form reflectivity as the
    ratio of adjacent-sample difference to adjacent-sample sum, multiply by the
    wavelet matrix, and divide by ``record_scale``. All matrices are applied
    from the left with ``torch.matmul``.
    """

    def __init__(
        self,
        *,
        wavelet: np.ndarray,
        size: int = 256,
        normal_min: float = 5.0931,
        normal_max: float = 6.501110975896774,
        record_scale: float = 0.3215932963300079,
        normalize: str = "minmax",
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Precompute the wavelet, difference and sum matrices as tensors.

        Args:
            wavelet: 1-D wavelet passed to ``build_convmtx``.
            size: Side length of the square operator matrices.
            normal_min: Lower bound used to undo "minmax" normalization.
            normal_max: Upper bound used to undo "minmax" / "max" normalization.
            record_scale: Divisor applied to the synthesized record.
            normalize: Normalization mode; "minmax" and "max" are supported.
                The value is only checked when the operator is applied.
            device: Device for the stored tensors. When not given, CUDA is
                used if available, otherwise CPU.
            dtype: Dtype of the stored tensors.
        """
        if not device:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def as_stored(array: np.ndarray) -> torch.Tensor:
            return torch.as_tensor(array, device=device, dtype=dtype)

        conv_matrix = build_convmtx(wavelet, size)

        # (diff_op @ z)[i] = z[i + 1] - z[i]; (sum_op @ z)[i] = z[i + 1] + z[i]
        superdiag = np.eye(size, k=1)
        main_diag = np.eye(size, k=0)
        diff_op = superdiag - main_diag
        sum_op = superdiag + main_diag
        # The last sample has no successor, so its row is blanked in both.
        diff_op[-1] = 0
        sum_op[-1] = 0

        self.wavelet_matrix = as_stored(conv_matrix)
        self.s1 = as_stored(diff_op)
        self.s2 = as_stored(sum_op)
        self.normal_min = float(normal_min)
        self.normal_max = float(normal_max)
        self.record_scale = float(record_scale)
        self.normalize = normalize

    def _inv_normal(self, image: torch.Tensor) -> torch.Tensor:
        """Undo the configured normalization of ``image``.

        Raises:
            ValueError: If ``self.normalize`` is neither "minmax" nor "max".
        """
        mode = self.normalize
        if mode == "minmax":
            span = self.normal_max - self.normal_min
            return image * span + self.normal_min
        if mode == "max":
            return image * self.normal_max
        raise ValueError(f"Unsupported normalize: {self.normalize}")

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Synthesize a record from a normalized image.

        The stored matrices are cast to ``image.dtype`` before use.

        Args:
            image: Normalized input tensor; the axis contracted by
                ``torch.matmul`` must match the operator size.

        Returns:
            Tensor of the synthesized record divided by ``record_scale``.
        """
        impedance = torch.exp(self._inv_normal(image))

        work_dtype = image.dtype
        diff_op = self.s1.to(dtype=work_dtype)
        sum_op = self.s2.to(dtype=work_dtype)
        conv_op = self.wavelet_matrix.to(dtype=work_dtype)

        adjacent_diff = diff_op @ impedance
        adjacent_sum = sum_op @ impedance
        # The blanked last row yields a zero sum; clamp to avoid dividing by 0.
        reflectivity = adjacent_diff / adjacent_sum.clamp(min=1e-6)

        return (conv_op @ reflectivity) / self.record_scale
|