from __future__ import annotations import os import random import sys from pathlib import Path import numpy as np import scipy.io import torch from scipy.fftpack import fft, ifft from scipy.signal import butter, filtfilt from torch.utils.data import Dataset # Import shared wavelet and convolution utilities sys.path.insert(0, str(Path(__file__).parent)) from util import build_convmtx, ricker_wavelet class SeismicBase: @staticmethod def phaseshift(w: np.ndarray, d: float) -> np.ndarray: if d == 0: return w wf_shift = fft(w) * np.exp(1j * (np.pi * d / 180.0)) return np.real(ifft(wf_shift)) @staticmethod def add_gaussian_band_noise( target_snr: float, data: np.ndarray, rng: np.random.Generator | None = None, ) -> tuple[np.ndarray, float]: if target_snr == 0: return data, 0.0 rng = rng or np.random.default_rng() signal_energy = np.linalg.norm(data) ** 2 noise_energy = signal_energy / (10 ** (target_snr / 10)) initial_noise = rng.normal(loc=0, scale=1, size=data.shape) noise = filtfilt( np.ones(3) / 3, 1, filtfilt(np.ones(3) / 3, 1, initial_noise.T, method="gust").T, method="gust", ) noise = noise * np.sqrt(noise_energy / np.linalg.norm(noise) ** 2) noisy_data = data + noise actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2) return noisy_data, float(actual_snr) @staticmethod def add_gaussian_noise( target_snr: float, data: np.ndarray, rng: np.random.Generator | None = None, ) -> tuple[np.ndarray, float]: if target_snr == 0: return data, 0.0 rng = rng or np.random.default_rng() signal_energy = np.linalg.norm(data) ** 2 noise_energy = signal_energy / (10 ** (target_snr / 10)) noise_std = np.sqrt(noise_energy / data.size) noise = rng.normal(0, noise_std, data.shape) noisy_data = data + noise actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2) return noisy_data, float(actual_snr) class OverthrustTrueimpDataset(SeismicBase, Dataset): """Overthrust benchmark dataset used by SAII-CLDM synthetic evaluation.""" def __init__( self, size: int = 256, interval: int = 1, special_splits: bool = False, use_mask: bool = False, record_noraml: bool = True, normalize: str = "minmax", zhengyan_type: str = "nonlinear", train_keys: tuple[str, ...] | list[str] = ("image", "dipin", "record"), ricks: tuple[int, ...] | list[int] = (30,), ricks_phase: tuple[int, ...] | list[int] = (0,), noise_snr: tuple[int, ...] | list[int] = (15,), noise_type: str = "guassian_band", dipins: tuple[float, ...] | list[float] = (0.012,), dipin_nsmoothz: int = 20, dipin_nsmoothx: int = 20, patch_indices: tuple[int, ...] | list[int] | None = None, base_seed: int = 1234, data_dir: str | Path | None = None, cache_dir: str | Path = "outputs/cache", fixed_f0: int | None = None, fixed_dipin_v: float | None = None, fixed_noise_snr: int | None = None, fixed_f0_phase: int | None = None, ): """Initialize the OverthrustTrueimpDataset. This dataset loads the Overthrust benchmark impedance model and synthesizes seismic records, low-frequency backgrounds, and reflection coefficients. Data is cached to disk to avoid recomputation. Patches are extracted from the full-size model and returned as CHW tensors for PyTorch DataLoader. Args: size: Patch size in pixels (height and width). Extracted patches are square regions of size×size from the 551×551 Overthrust model. Common values: 64, 128, 256. Default: 256. interval: Sampling interval for patch extraction. interval=1 uses all patches, interval=2 skips every other patch to reduce dataset size. Default: 1. special_splits: Whether to use specialized patch splitting strategy for the Overthrust model at specific locations. If False, uses standard grid-based splitting with overlap. Default: False. use_mask: Whether to include acquisition mask in returned samples. Masks indicate missing trace columns at positions 100, 200, 300. Default: False. record_noraml: Whether to normalize seismic records by a fixed constant (0.32159). If False, records remain in original amplitude scale. Default: True. normalize: Normalization method for impedance data. - 'minmax': Min-max scaling to [0, 1] range - 'max': Divide by maximum value only Default: 'minmax'. zhengyan_type: Forward modeling type for computing reflection coefficients. - 'linear': Linear approximation R = (v2 - v1) / (v2 + v1) - 'nonlinear': Nonlinear exact formula from impedance Default: 'nonlinear'. train_keys: Keys to include in each sample dict. Available keys: 'image' (impedance), 'dipin' (low-frequency), 'record' (seismic), 'reflection', 'mask_speed', 'mask'. Default: ('image', 'dipin', 'record'). ricks: Ricker wavelet dominant frequencies in Hz. Multiple frequencies can be specified; one will be randomly selected per sample unless fixed_f0 is set. Common values: 20, 25, 30, 35, 40. Default: (30,). ricks_phase: Phase shifts for Ricker wavelet in degrees. Multiple phases can be specified; one will be randomly selected unless fixed_f0_phase is set. Default: (0,). noise_snr: Target signal-to-noise ratios in dB for synthetic records. Multiple SNRs can be specified; one will be randomly selected unless fixed_noise_snr is set. Value 0 means no noise. Default: (15,). noise_type: Type of Gaussian noise added to seismic records. - 'guassian_band': Band-limited Gaussian noise (filtered) - 'guassian': White Gaussian noise (unfiltered) Default: 'guassian_band'. dipins: Low-pass filter cutoff frequencies for generating dipin (low- frequency background model). Multiple frequencies can be specified; one will be randomly selected unless fixed_dipin_v is set. Default: (0.012,). dipin_nsmoothz: Smoothing window size along depth (z) axis for dipin generation. Larger values produce smoother backgrounds. Default: 20. dipin_nsmoothx: Smoothing window size along horizontal (x) axis for dipin generation. Larger values produce smoother backgrounds. Default: 20. patch_indices: Specific patch indices to use. If None, all valid patches are used. Useful for creating train/validation splits. Default: None. base_seed: Base random seed for reproducible noise generation. Different frequencies and SNRs use derived seeds (base_seed + f0*1000 + phase*10 + snr) to ensure consistent noise patterns across runs. Default: 1234. data_dir: Directory containing Overthrust_trueimp.mat file. If None, uses DATASET_DIR environment variable or falls back to 'data'. Default: None. cache_dir: Directory for caching synthesized arrays (records, dipin). Cached files are reused across runs to avoid recomputation. Default: 'outputs/cache'. fixed_f0: If set, all samples use this specific Ricker frequency instead of random selection from ricks. Useful for evaluation with fixed wavelet. Default: None (random selection). fixed_dipin_v: If set, all samples use this specific dipin frequency instead of random selection from dipins. Default: None. fixed_noise_snr: If set, all samples use this specific SNR instead of random selection from noise_snr. Default: None. fixed_f0_phase: If set, all samples use this specific wavelet phase instead of random selection from ricks_phase. Default: None. Raises: FileNotFoundError: If Overthrust_trueimp.mat file not found in data_dir. ValueError: If normalize method is not 'minmax' or 'max'. ValueError: If zhengyan_type is not 'linear' or 'nonlinear'. ValueError: If noise_type is not 'guassian_band' or 'guassian'. Example: >>> # Basic usage with default parameters >>> dataset = OverthrustTrueimpDataset(size=256, normalize='minmax') >>> sample = dataset[0] >>> sample['image'].shape # torch.Size([1, 256, 256]) >>> sample['record'].shape # torch.Size([1, 256, 256]) >>> # Evaluation mode with fixed wavelet and noise >>> dataset = OverthrustTrueimpDataset( ... size=256, ... fixed_f0=30, ... fixed_noise_snr=20, ... fixed_f0_phase=0 ... ) >>> # All samples will have consistent wavelet parameters >>> # Train/val split using patch_indices >>> all_patches = list(range(len(dataset))) >>> train_patches = all_patches[:80] >>> val_patches = all_patches[80:] >>> train_dataset = OverthrustTrueimpDataset(patch_indices=train_patches) >>> val_dataset = OverthrustTrueimpDataset(patch_indices=val_patches) Note: - The full Overthrust model is 551×551 pixels; with size=256, 6 patches are extracted at fixed locations (0,0), (146,0), (295,0), (0,145), (146,145), (295,145). - Cached files are named with parameters encoded in filename to ensure correct cache reuse (e.g., 'Overthrust_trueimpBig_sesimic_record__nonlinear_ ricker=30-000_guassian_band=15_seed=1234.npy'). - The dataset inherits from both SeismicBase (noise generation, phase shift) and torch.utils.data.Dataset. - Samples are returned as float32 tensors with shape (1, size, size) in CHW format, ready for convolutional networks. """ self.name = "Overthrust_trueimp" self.size = size self.interval = interval self.special_splits = special_splits self.use_mask = use_mask self.record_noraml = record_noraml self.normalize = normalize self.zhengyan_type = zhengyan_type self.train_keys = list(train_keys) self.ricks = list(ricks) self.ricks_phase = list(ricks_phase) self.noise_snr = list(noise_snr) self.noise_type = noise_type self.dipins = list(dipins) self.dipin_nsmoothz = dipin_nsmoothz self.dipin_nsmoothx = dipin_nsmoothx self.base_seed = base_seed self.have_exp = False self.info: dict[str, float | str] = {} self.fixed_f0 = self.ricks[0] if fixed_f0 is None else fixed_f0 self.fixed_dipin_v = self.dipins[0] if fixed_dipin_v is None else fixed_dipin_v self.fixed_noise_snr = self.noise_snr[0] if fixed_noise_snr is None else fixed_noise_snr self.fixed_f0_phase = self.ricks_phase[0] if fixed_f0_phase is None else fixed_f0_phase self.data_dir = Path(data_dir or os.getenv("DATASET_DIR", "data")) self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self._load_big_impedance() self._build_splits_and_patches(special_splits=special_splits) self._build_wavelets() self.big_reflect = self._load_or_build_reflect() self.record_data = { f0: { phase: { snr: self._patches_from_big_image( self._load_or_build_record(f0=f0, phase=phase, noise_snr=snr) ) for snr in self.noise_snr } for phase in self.ricks_phase } for f0 in self.ricks } self.dipin_datas = { dipin_v: self._patches_from_big_image(self._load_or_build_dipin(dipin_v)) for dipin_v in self.dipins } all_indices = list(range(len(self.splits))) self.patch_indices = all_indices if patch_indices is None else list(patch_indices) def __len__(self) -> int: return len(self.patch_indices) def __getitem__(self, index: int) -> dict[str, torch.Tensor]: patch_idx = self.patch_indices[index] f0 = self.fixed_f0 if self.fixed_f0 is not None else random.choice(self.ricks) dipin_v = ( self.fixed_dipin_v if self.fixed_dipin_v is not None else random.choice(self.dipins) ) noise_snr = ( self.fixed_noise_snr if self.fixed_noise_snr is not None else random.choice(self.noise_snr) ) f0_phase = ( self.fixed_f0_phase if self.fixed_f0_phase is not None else random.choice(self.ricks_phase) ) sample = { "patch_idx": torch.tensor(patch_idx, dtype=torch.long), "seed": torch.tensor( self.base_seed + index + int(noise_snr) * 100, dtype=torch.long ), } if "image" in self.train_keys: sample["image"] = torch.from_numpy(self.file_data[patch_idx]).float() if "dipin" in self.train_keys: sample["dipin"] = torch.from_numpy(self.dipin_datas[dipin_v][patch_idx]).float() sample["dipin_v"] = torch.tensor(dipin_v, dtype=torch.float32).reshape(1, 1, 1) if "record" in self.train_keys: sample["record"] = torch.from_numpy( self.record_data[f0][f0_phase][noise_snr][patch_idx] ).float() sample["rick_v"] = torch.tensor(f0, dtype=torch.float32).reshape(1, 1, 1) sample["rick_phase"] = torch.tensor(f0_phase, dtype=torch.float32).reshape(1, 1, 1) sample["snr_v"] = torch.tensor(noise_snr, dtype=torch.float32).reshape(1, 1, 1) if "reflection" in self.train_keys: sample["reflection"] = torch.from_numpy(self.reflect_data[patch_idx]).float() if "mask_speed" in self.train_keys: sample["mask_speed"] = torch.from_numpy( self.mask_data[patch_idx] * self.file_data[patch_idx] ).float() if self.use_mask: sample["mask"] = torch.from_numpy(self.mask_data[patch_idx]).float() return sample def fan(self, x: np.ndarray) -> np.ndarray: minn = 5.0931 maxn = 6.501110975896774 return np.exp(x * (maxn - minn) + minn) * 10.9 + 200 def inv_normal(self, x: np.ndarray) -> np.ndarray: vmin = float(self.info["normal_min"]) vmax = float(self.info["normal_max"]) if self.normalize == "minmax": return x * (vmax - vmin) + vmin return x * vmax def _load_big_impedance(self) -> None: file_path = self.data_dir / "Overthrust_trueimp.mat" if not file_path.exists(): raise FileNotFoundError(f"Overthrust data not found: {file_path}") wave = scipy.io.loadmat(file_path)["Overthrust_trueimp"].T wave = np.log(wave) normal_min = wave.min() normal_max = wave.max() self.info.update( {"normal_min": normal_min, "normal_max": normal_max, "normal": "max"} ) self.big_img_unnorm = wave self.big_speedimg = wave if self.normalize == "max": wave = wave / normal_max elif self.normalize == "minmax": wave = (wave - normal_min) / (normal_max - normal_min) else: raise ValueError(f"Unsupported normalize: {self.normalize}") self.big_img = wave.astype(np.float32) def _build_splits_and_patches(self, special_splits: bool = False) -> None: self.big_mask = np.zeros(self.big_img.shape, dtype=np.float32) for col in (100, 200, 300): if col < self.big_mask.shape[1]: self.big_mask[:, col : col + 1] = 1 if special_splits: splits = [] for x in range(0, 551 - self.size, 20): for y in range(0, 551 - self.size, 20): splits.append((x, y)) for y in range(0, 551 - self.size, 9): splits.extend([(30, y), (90, y), (140, y)]) elif self.size == 256: splits = [ (0, 0), (146, 0), (551 - 256, 0), (0, 145), (146, 145), (551 - 256, 145), ] else: splits = [] interval_size = self.size - 1 for r in range(0, self.big_img.shape[0] - self.size, interval_size): for c in range(0, self.big_img.shape[1] - self.size, interval_size): splits.append((r, c)) splits.append((r, self.big_img.shape[1] - self.size)) for c in range(0, self.big_img.shape[1] - self.size, interval_size): splits.append((self.big_img.shape[0] - self.size, c)) splits.append( (self.big_img.shape[0] - self.size, self.big_img.shape[1] - self.size) ) self.splits = [] patches = [] masks = [] for x, y in splits: x2 = x + self.size y2 = y + self.size if x2 > self.big_img.shape[0] or y2 > self.big_img.shape[1]: continue self.splits.append((x, y)) patches.append(self.big_img[x:x2, y:y2].reshape(1, self.size, self.size)) masks.append(self.big_mask[x:x2, y:y2].reshape(1, self.size, self.size)) self.file_data = np.stack(patches, axis=0).astype(np.float32)[:: self.interval] self.mask_data = np.stack(masks, axis=0).astype(np.float32)[:: self.interval] self.splits = self.splits[:: self.interval] def _build_wavelets(self) -> None: nt0 = 256 dt0 = 0.002 self.wavelets = {} for f0 in self.ricks: self.wavelets[f0] = {} wav = ricker_wavelet(f0, nt0 // 2, dt0) for phase in self.ricks_phase: self.wavelets[f0][phase] = self.phaseshift(wav, phase) def _cache_path(self, name: str) -> Path: return self.cache_dir / name def _load_or_build_reflect(self) -> np.ndarray: cache_path = self._cache_path( f"Overthrust_trueimpBig_sesimic_reflect_{self.zhengyan_type}.npy" ) if not cache_path.exists(): size = self.big_img.shape[0] if self.zhengyan_type == "linear": s1 = np.diag(0.5 * np.ones(size - 1, dtype="float32"), k=1) - np.diag( 0.5 * np.ones(size - 1, dtype="float32"), k=-1 ) s1[-1] = s1[0] = 0 reflect = s1 @ self.big_img elif self.zhengyan_type == "nonlinear": expspeed = ( np.exp(self.big_img_unnorm) if self.have_exp is False else self.big_img_unnorm ) s1 = np.eye(size, k=1) - np.eye(size, k=0) s2 = np.eye(size, k=1) + np.eye(size, k=0) s1[-1] = 0 s2[-1] = 0 numerator = s1 @ expspeed denominator = s2 @ expspeed denominator = np.where(denominator < 1e-6, 1e-6, denominator) reflect = numerator / denominator else: raise ValueError(f"Unsupported zhengyan_type: {self.zhengyan_type}") np.save(cache_path, reflect) reflect = np.load(cache_path).astype(np.float32) self.reflect_data = self._patches_from_big_image(reflect) return reflect def _load_or_build_record(self, f0: int, phase: int, noise_snr: int) -> np.ndarray: cache_path = self._cache_path( f"Overthrust_trueimpBig_sesimic_record__{self.zhengyan_type}" f"_ricker={f0:02d}-{phase:03d}_{self.noise_type}={noise_snr:02d}" f"_seed={self.base_seed}.npy" ) if not cache_path.exists(): wav = self.wavelets[f0][phase] w_mat = build_convmtx(wav, self.big_reflect.shape[0]) records_clear = w_mat @ self.big_reflect rng = np.random.default_rng(self.base_seed + f0 * 1000 + phase * 10 + noise_snr) if self.noise_type == "guassian_band": record, _ = self.add_gaussian_band_noise(noise_snr, records_clear, rng=rng) elif self.noise_type == "guassian": record, _ = self.add_gaussian_noise(noise_snr, records_clear, rng=rng) else: raise ValueError(f"Unsupported noise_type: {self.noise_type}") np.save(cache_path, record) record = np.load(cache_path).astype(np.float32) self.info.update( { "record_minn": min(float(self.info.get("record_minn", 10)), float(record.min())), "record_maxn": max(float(self.info.get("record_maxn", -10)), float(record.max())), "record_normal": "max", } ) if self.record_noraml: record = record / 0.3215932963300079 self.info["record_maxn"] = 0.3215932963300079 return record def _load_or_build_dipin(self, dipin_v: float) -> np.ndarray: cache_path = self._cache_path( f"Overthrust_trueimpBig_sesimic_dipin={dipin_v:.03f}.npy" ) if not cache_path.exists(): bb, aa = butter(2, dipin_v, "low") smooth_filter_z = np.ones(self.dipin_nsmoothz) / float(self.dipin_nsmoothz) smooth_filter_x = np.ones(self.dipin_nsmoothx) / float(self.dipin_nsmoothx) mback = filtfilt(bb, aa, self.big_img.T).T mback = filtfilt(smooth_filter_z, 1, mback, axis=0) mback = filtfilt(smooth_filter_x, 1, mback, axis=1) np.save(cache_path, mback) return np.load(cache_path).astype(np.float32) def _patches_from_big_image(self, big_image: np.ndarray) -> np.ndarray: patches = [] for x, y in self.splits: patches.append( big_image[x : x + self.size, y : y + self.size].reshape( 1, self.size, self.size ) ) return np.stack(patches, axis=0).astype(np.float32)