| from __future__ import annotations |
|
|
| import os |
| import random |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import scipy.io |
| import torch |
| from scipy.fftpack import fft, ifft |
| from scipy.signal import butter, filtfilt |
| from torch.utils.data import Dataset |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent)) |
| from util import build_convmtx, ricker_wavelet |
|
|
|
|
| class SeismicBase: |
| @staticmethod |
| def phaseshift(w: np.ndarray, d: float) -> np.ndarray: |
| if d == 0: |
| return w |
| wf_shift = fft(w) * np.exp(1j * (np.pi * d / 180.0)) |
| return np.real(ifft(wf_shift)) |
|
|
| @staticmethod |
| def add_gaussian_band_noise( |
| target_snr: float, |
| data: np.ndarray, |
| rng: np.random.Generator | None = None, |
| ) -> tuple[np.ndarray, float]: |
| if target_snr == 0: |
| return data, 0.0 |
| rng = rng or np.random.default_rng() |
| signal_energy = np.linalg.norm(data) ** 2 |
| noise_energy = signal_energy / (10 ** (target_snr / 10)) |
| initial_noise = rng.normal(loc=0, scale=1, size=data.shape) |
| noise = filtfilt( |
| np.ones(3) / 3, |
| 1, |
| filtfilt(np.ones(3) / 3, 1, initial_noise.T, method="gust").T, |
| method="gust", |
| ) |
| noise = noise * np.sqrt(noise_energy / np.linalg.norm(noise) ** 2) |
| noisy_data = data + noise |
| actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2) |
| return noisy_data, float(actual_snr) |
|
|
| @staticmethod |
| def add_gaussian_noise( |
| target_snr: float, |
| data: np.ndarray, |
| rng: np.random.Generator | None = None, |
| ) -> tuple[np.ndarray, float]: |
| if target_snr == 0: |
| return data, 0.0 |
| rng = rng or np.random.default_rng() |
| signal_energy = np.linalg.norm(data) ** 2 |
| noise_energy = signal_energy / (10 ** (target_snr / 10)) |
| noise_std = np.sqrt(noise_energy / data.size) |
| noise = rng.normal(0, noise_std, data.shape) |
| noisy_data = data + noise |
| actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2) |
| return noisy_data, float(actual_snr) |
|
|
|
|
class OverthrustTrueimpDataset(SeismicBase, Dataset):
    """Overthrust benchmark dataset used by SAII-CLDM synthetic evaluation."""

    # Accepted values of the string options, validated up front in __init__.
    _NORMALIZE_CHOICES = ("minmax", "max")
    _ZHENGYAN_CHOICES = ("linear", "nonlinear")
    _NOISE_CHOICES = ("guassian_band", "guassian")

    def __init__(
        self,
        size: int = 256,
        interval: int = 1,
        special_splits: bool = False,
        use_mask: bool = False,
        record_noraml: bool = True,
        normalize: str = "minmax",
        zhengyan_type: str = "nonlinear",
        train_keys: tuple[str, ...] | list[str] = ("image", "dipin", "record"),
        ricks: tuple[int, ...] | list[int] = (30,),
        ricks_phase: tuple[int, ...] | list[int] = (0,),
        noise_snr: tuple[int, ...] | list[int] = (15,),
        noise_type: str = "guassian_band",
        dipins: tuple[float, ...] | list[float] = (0.012,),
        dipin_nsmoothz: int = 20,
        dipin_nsmoothx: int = 20,
        patch_indices: tuple[int, ...] | list[int] | None = None,
        base_seed: int = 1234,
        data_dir: str | Path | None = None,
        cache_dir: str | Path = "outputs/cache",
        fixed_f0: int | None = None,
        fixed_dipin_v: float | None = None,
        fixed_noise_snr: int | None = None,
        fixed_f0_phase: int | None = None,
    ):
        """Initialize the OverthrustTrueimpDataset.

        Loads the Overthrust impedance model (``Overthrust_trueimp.mat``),
        takes its natural log and normalizes it. From it, the following are
        synthesized on the full-size model:

        - reflection coefficients;
        - seismic records, for every configured wavelet frequency / phase /
          SNR combination;
        - low-frequency background ("dipin") models.

        Synthesized full-size arrays are cached as ``.npy`` files under
        ``cache_dir``. Square patches are cut from the full-size arrays and
        returned as ``(1, size, size)`` float32 tensors.

        Args:
            size: Patch height and width in pixels. Default: 256.
            interval: Keep every ``interval``-th patch (``interval=2`` drops
                every other patch). Default: 1.
            special_splits: If True, use a dense layout: a 20-pixel grid of
                origins, plus extra origins with row offsets 30, 90 and 140
                every 9 columns. If False and ``size == 256``, use six fixed
                patches. Otherwise use a grid with a stride of ``size - 1``
                whose last row/column of patches is aligned to the image
                border. Default: False.
            use_mask: If True, add a ``"mask"`` entry to every sample. The mask
                is 1 on columns 100, 200 and 300 and 0 elsewhere. Default: False.
            record_noraml: If True, divide the seismic records by the fixed
                constant 0.3215932963300079. Default: True.
            normalize: Normalization of the log-impedance ``m``:
                - 'minmax': ``(m - min) / (max - min)``
                - 'max': ``m / max``
                Default: 'minmax'.
            zhengyan_type: How reflection coefficients are computed (along axis 0):
                - 'linear': centered difference of the normalized
                  log-impedance, ``0.5 * (m[i+1] - m[i-1])``; first and last
                  rows are 0.
                - 'nonlinear': ``(Z[i+1] - Z[i]) / (Z[i+1] + Z[i])`` with
                  ``Z = exp(log-impedance)``; last row is 0.
                Default: 'nonlinear'.
            train_keys: Entries to include in each sample. Supported:
                'image' (normalized log-impedance), 'dipin' (background model),
                'record' (seismic record), 'reflection', 'mask_speed'
                (image multiplied by the column mask).
                Default: ('image', 'dipin', 'record').
            ricks: Ricker dominant frequencies in Hz. Default: (30,).
            ricks_phase: Constant phase rotations of the wavelet in degrees.
                Default: (0,).
            noise_snr: Target SNRs in dB. 0 means no noise. Default: (15,).
            noise_type: 'guassian_band' (white noise smoothed by a 3-point
                moving average along both axes) or 'guassian' (white noise).
                Default: 'guassian_band'.
            dipins: Cutoff of the order-2 Butterworth low-pass filter used to
                build the background model. Normalized so that 1 is the
                Nyquist frequency (``scipy.signal.butter`` convention).
                Default: (0.012,).
            dipin_nsmoothz: Length of the moving-average filter applied along
                axis 0 after the low-pass. Default: 20.
            dipin_nsmoothx: Length of the moving-average filter applied along
                axis 1 after the low-pass. Default: 20.
            patch_indices: Patch indices exposed by this dataset. If None, all
                patches are used (useful for train/val splits). Default: None.
            base_seed: Base seed for record noise. Each record uses
                ``base_seed + f0*1000 + phase*10 + snr``. Default: 1234.
            data_dir: Directory containing ``Overthrust_trueimp.mat``. If None,
                the ``DATASET_DIR`` environment variable is used, falling back
                to 'data'. Default: None.
            cache_dir: Directory for cached synthesized arrays. It is created
                if missing. Default: 'outputs/cache'.
            fixed_f0: If not None, every sample uses this frequency. It must be
                in ``ricks``. If None, a frequency is drawn at random per
                sample. Default: None.
            fixed_dipin_v: Same as ``fixed_f0``, for ``dipins``. Default: None.
            fixed_noise_snr: Same as ``fixed_f0``, for ``noise_snr``.
                Default: None.
            fixed_f0_phase: Same as ``fixed_f0``, for ``ricks_phase``.
                Default: None.

        Raises:
            ValueError: If ``normalize``, ``zhengyan_type`` or ``noise_type``
                is unsupported. Also raised if one of ``ricks``,
                ``ricks_phase``, ``noise_snr``, ``dipins`` is empty, or if a
                ``fixed_*`` value is not in its list. The membership check
                applies only when the corresponding 'record' / 'dipin' key is
                in ``train_keys``.
            FileNotFoundError: If ``Overthrust_trueimp.mat`` is not found in
                ``data_dir``.
            ValueError: If no ``size x size`` patch fits inside the model.

        Example:
            >>> # Basic usage with default parameters
            >>> dataset = OverthrustTrueimpDataset(size=256, normalize='minmax')
            >>> sample = dataset[0]
            >>> sample['image'].shape   # torch.Size([1, 256, 256])
            >>> sample['record'].shape  # torch.Size([1, 256, 256])

            >>> # Evaluation mode with fixed wavelet and noise
            >>> dataset = OverthrustTrueimpDataset(
            ...     size=256,
            ...     noise_snr=(15, 20),
            ...     fixed_f0=30,
            ...     fixed_noise_snr=20,
            ...     fixed_f0_phase=0,
            ... )

            >>> # Train/val split using patch_indices
            >>> all_patches = list(range(len(dataset)))
            >>> train_dataset = OverthrustTrueimpDataset(patch_indices=all_patches[:4])
            >>> val_dataset = OverthrustTrueimpDataset(patch_indices=all_patches[4:])

        Note:
            - With a 551x551 model, ``size=256`` and ``special_splits=False``,
              6 patches are used. Their (row, col) origins are (0,0), (146,0),
              (295,0), (0,145), (146,145) and (295,145).
            - Cache file names encode the parameters each array depends on.
              A record example is
              'Overthrust_trueimpBig_sesimic_record__nonlinear_ricker=30-000_guassian_band=15_seed=1234.npy'.
              Linear reflectivity/record caches also encode ``normalize``.
              Dipin caches encode ``normalize`` and both smoothing lengths.
            - Samples are float32 tensors of shape (1, size, size).
        """
        # Fail fast on unsupported string options, before any file I/O.
        self._validate_options(normalize, zhengyan_type, noise_type)

        self.name = "Overthrust_trueimp"
        self.size = size
        self.interval = interval
        self.special_splits = special_splits
        self.use_mask = use_mask
        self.record_noraml = record_noraml
        self.normalize = normalize
        self.zhengyan_type = zhengyan_type
        self.train_keys = list(train_keys)
        self.ricks = list(ricks)
        self.ricks_phase = list(ricks_phase)
        self.noise_snr = list(noise_snr)
        self.noise_type = noise_type
        self.dipins = list(dipins)
        self.dipin_nsmoothz = dipin_nsmoothz
        self.dipin_nsmoothx = dipin_nsmoothx
        self.base_seed = base_seed
        self.have_exp = False
        self.info: dict[str, float | str] = {}
        # None means "draw at random in __getitem__". Replacing None with the
        # first list entry here (as an earlier version did) makes the random
        # branch in __getitem__ unreachable.
        self.fixed_f0 = fixed_f0
        self.fixed_dipin_v = fixed_dipin_v
        self.fixed_noise_snr = fixed_noise_snr
        self.fixed_f0_phase = fixed_f0_phase
        self._validate_choices()
        self.data_dir = Path(data_dir or os.getenv("DATASET_DIR", "data"))
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self._load_big_impedance()
        self._build_splits_and_patches(special_splits=special_splits)
        self._build_wavelets()
        self.big_reflect = self._load_or_build_reflect()
        # Every (f0, phase, snr) combination is precomputed so __getitem__
        # can pick one per sample with a dictionary lookup.
        self.record_data = {
            f0: {
                phase: {
                    snr: self._patches_from_big_image(
                        self._load_or_build_record(f0=f0, phase=phase, noise_snr=snr)
                    )
                    for snr in self.noise_snr
                }
                for phase in self.ricks_phase
            }
            for f0 in self.ricks
        }
        self.dipin_datas = {
            dipin_v: self._patches_from_big_image(self._load_or_build_dipin(dipin_v))
            for dipin_v in self.dipins
        }
        if patch_indices is None:
            self.patch_indices = list(range(len(self.splits)))
        else:
            self.patch_indices = list(patch_indices)

    def _validate_options(self, normalize: str, zhengyan_type: str, noise_type: str) -> None:
        """Raise ValueError for unsupported string options.

        The messages match those raised later by the individual builders.
        """
        if normalize not in self._NORMALIZE_CHOICES:
            raise ValueError(f"Unsupported normalize: {normalize}")
        if zhengyan_type not in self._ZHENGYAN_CHOICES:
            raise ValueError(f"Unsupported zhengyan_type: {zhengyan_type}")
        if noise_type not in self._NOISE_CHOICES:
            raise ValueError(f"Unsupported noise_type: {noise_type}")

    def _validate_choices(self) -> None:
        """Check the option lists are non-empty and fixed_* values exist in them.

        A fixed value is only looked up when its sample key is requested, so
        membership is only enforced for the keys in ``train_keys``.
        """
        record_used = "record" in self.train_keys
        dipin_used = "dipin" in self.train_keys
        checks = (
            ("ricks", self.ricks, "fixed_f0", self.fixed_f0, record_used),
            ("ricks_phase", self.ricks_phase, "fixed_f0_phase", self.fixed_f0_phase, record_used),
            ("noise_snr", self.noise_snr, "fixed_noise_snr", self.fixed_noise_snr, record_used),
            ("dipins", self.dipins, "fixed_dipin_v", self.fixed_dipin_v, dipin_used),
        )
        for list_name, options, fixed_name, fixed, used in checks:
            # __getitem__ draws from every list, so none may be empty.
            if not options:
                raise ValueError(f"{list_name} must contain at least one value")
            if used and fixed is not None and fixed not in options:
                raise ValueError(
                    f"{fixed_name}={fixed!r} is not one of {list_name}={options!r}"
                )

    @staticmethod
    def _pick(fixed, options):
        """Return ``fixed`` unless it is None, else a ``random.choice`` of ``options``."""
        return fixed if fixed is not None else random.choice(options)

    def __len__(self) -> int:
        """Number of patches exposed (length of ``patch_indices``)."""
        return len(self.patch_indices)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return the tensors requested in ``train_keys`` for one patch.

        Wavelet frequency, background cutoff, SNR and phase come from the
        ``fixed_*`` attributes when set. Otherwise each is drawn with the
        global ``random`` module. ``"patch_idx"`` and ``"seed"`` are always
        included.
        """
        patch_idx = self.patch_indices[index]
        f0 = self._pick(self.fixed_f0, self.ricks)
        dipin_v = self._pick(self.fixed_dipin_v, self.dipins)
        noise_snr = self._pick(self.fixed_noise_snr, self.noise_snr)
        f0_phase = self._pick(self.fixed_f0_phase, self.ricks_phase)
        sample = {
            "patch_idx": torch.tensor(patch_idx, dtype=torch.long),
            "seed": torch.tensor(
                self.base_seed + index + int(noise_snr) * 100, dtype=torch.long
            ),
        }
        if "image" in self.train_keys:
            sample["image"] = torch.from_numpy(self.file_data[patch_idx]).float()
        if "dipin" in self.train_keys:
            sample["dipin"] = torch.from_numpy(self.dipin_datas[dipin_v][patch_idx]).float()
            sample["dipin_v"] = torch.tensor(dipin_v, dtype=torch.float32).reshape(1, 1, 1)
        if "record" in self.train_keys:
            sample["record"] = torch.from_numpy(
                self.record_data[f0][f0_phase][noise_snr][patch_idx]
            ).float()
            sample["rick_v"] = torch.tensor(f0, dtype=torch.float32).reshape(1, 1, 1)
            sample["rick_phase"] = torch.tensor(f0_phase, dtype=torch.float32).reshape(1, 1, 1)
            sample["snr_v"] = torch.tensor(noise_snr, dtype=torch.float32).reshape(1, 1, 1)
        if "reflection" in self.train_keys:
            sample["reflection"] = torch.from_numpy(self.reflect_data[patch_idx]).float()
        if "mask_speed" in self.train_keys:
            sample["mask_speed"] = torch.from_numpy(
                self.mask_data[patch_idx] * self.file_data[patch_idx]
            ).float()
        if self.use_mask:
            sample["mask"] = torch.from_numpy(self.mask_data[patch_idx]).float()
        return sample

    def fan(self, x: np.ndarray) -> np.ndarray:
        """Apply a fixed exponential de-normalization with hard-coded constants.

        Computes ``exp(x * (maxn - minn) + minn) * 10.9 + 200`` with
        ``minn = 5.0931`` and ``maxn = 6.501110975896774``. The result does not
        depend on ``self.info`` or ``self.normalize``.
        NOTE(review): where these constants come from is not established in
        this file; confirm before reusing with other data.
        """
        minn = 5.0931
        maxn = 6.501110975896774
        return np.exp(x * (maxn - minn) + minn) * 10.9 + 200

    def inv_normal(self, x: np.ndarray) -> np.ndarray:
        """Undo the impedance normalization, returning log-impedance values.

        Uses the min/max recorded in ``self.info`` by ``_load_big_impedance``.
        """
        vmin = float(self.info["normal_min"])
        vmax = float(self.info["normal_max"])
        if self.normalize == "minmax":
            return x * (vmax - vmin) + vmin
        return x * vmax

    def _load_big_impedance(self) -> None:
        """Load the full model, take its log and normalize it.

        Sets ``big_img_unnorm`` / ``big_speedimg`` (log-impedance),
        ``big_img`` (normalized, float32) and the ``normal_*`` entries of
        ``self.info``.
        """
        file_path = self.data_dir / "Overthrust_trueimp.mat"
        if not file_path.exists():
            raise FileNotFoundError(f"Overthrust data not found: {file_path}")
        log_imp = np.log(scipy.io.loadmat(file_path)["Overthrust_trueimp"].T)
        normal_min = log_imp.min()
        normal_max = log_imp.max()
        # Record the normalization actually applied (previously always "max").
        self.info.update(
            {"normal_min": normal_min, "normal_max": normal_max, "normal": self.normalize}
        )
        self.big_img_unnorm = log_imp
        self.big_speedimg = log_imp
        if self.normalize == "max":
            normed = log_imp / normal_max
        elif self.normalize == "minmax":
            normed = (log_imp - normal_min) / (normal_max - normal_min)
        else:
            raise ValueError(f"Unsupported normalize: {self.normalize}")
        self.big_img = normed.astype(np.float32)

    def _candidate_splits(
        self, height: int, width: int, special_splits: bool
    ) -> list[tuple[int, int]]:
        """Return candidate (row, col) patch origins; out-of-bounds ones are filtered later."""
        size = self.size
        if special_splits:
            splits = [
                (x, y)
                for x in range(0, height - size, 20)
                for y in range(0, width - size, 20)
            ]
            for y in range(0, width - size, 9):
                splits.extend([(30, y), (90, y), (140, y)])
            return splits
        if size == 256:
            # Fixed layout tuned for the 551x551 model (last row origin 295).
            last_row = height - size
            return [
                (0, 0),
                (146, 0),
                (last_row, 0),
                (0, 145),
                (146, 145),
                (last_row, 145),
            ]
        # Grid with one pixel of overlap; the last row/column of patches is
        # aligned to the image border. max() avoids a zero step for size == 1.
        step = max(size - 1, 1)
        last_row, last_col = height - size, width - size
        splits = []
        for r in range(0, last_row, step):
            splits.extend((r, c) for c in range(0, last_col, step))
            splits.append((r, last_col))
        splits.extend((last_row, c) for c in range(0, last_col, step))
        splits.append((last_row, last_col))
        return splits

    def _build_splits_and_patches(self, special_splits: bool = False) -> None:
        """Choose patch origins and cut the image and mask patches.

        Sets ``big_mask`` (1 on columns 100, 200, 300), ``splits`` (kept
        (row, col) origins), ``file_data`` and ``mask_data`` (arrays of shape
        ``(n_patches, 1, size, size)``). All are thinned by ``self.interval``.
        """
        height, width = self.big_img.shape[:2]
        self.big_mask = np.zeros(self.big_img.shape, dtype=np.float32)
        for col in (100, 200, 300):
            if col < width:
                self.big_mask[:, col : col + 1] = 1

        candidates = self._candidate_splits(height, width, special_splits)
        # Negative origins (size larger than the image) must be rejected too;
        # they would otherwise produce short slices that fail to reshape.
        kept = [
            (x, y)
            for x, y in candidates
            if x >= 0 and y >= 0 and x + self.size <= height and y + self.size <= width
        ]
        if not kept:
            raise ValueError(
                f"No {self.size}x{self.size} patch fits in the {height}x{width} model"
            )
        self.splits = kept[:: self.interval]
        self.file_data = self._patches_from_big_image(self.big_img)
        self.mask_data = self._patches_from_big_image(self.big_mask)

    def _build_wavelets(self) -> None:
        """Build ``self.wavelets[f0][phase]`` (Ricker wavelets, dt=0.002, 128 samples requested)."""
        nt0 = 256
        dt0 = 0.002
        self.wavelets = {}
        for f0 in self.ricks:
            base = ricker_wavelet(f0, nt0 // 2, dt0)
            self.wavelets[f0] = {
                phase: self.phaseshift(base, phase) for phase in self.ricks_phase
            }

    def _cache_path(self, name: str) -> Path:
        """Return the path of cache file ``name`` inside ``cache_dir``."""
        return self.cache_dir / name

    def _reflect_tag(self) -> str:
        """Cache-name tag identifying how ``big_reflect`` is computed.

        The linear operator acts on the normalized image, so its caches must
        also be keyed on ``normalize``. Nonlinear uses the unnormalized model,
        so its tag (and existing cache names) are unchanged.
        """
        if self.zhengyan_type == "linear":
            return f"linear-{self.normalize}"
        return self.zhengyan_type

    def _load_or_build_reflect(self) -> np.ndarray:
        """Load or compute full-size reflection coefficients (float32).

        Also sets ``self.reflect_data`` to the per-patch reflectivity.
        """
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_reflect_{self._reflect_tag()}.npy"
        )
        if not cache_path.exists():
            size = self.big_img.shape[0]
            if self.zhengyan_type == "linear":
                # Centered difference 0.5 * (m[i+1] - m[i-1]) along axis 0;
                # first and last rows are zeroed.
                half = 0.5 * np.ones(size - 1, dtype="float32")
                s1 = np.diag(half, k=1) - np.diag(half, k=-1)
                s1[-1] = s1[0] = 0
                reflect = s1 @ self.big_img
            elif self.zhengyan_type == "nonlinear":
                expspeed = (
                    np.exp(self.big_img_unnorm) if not self.have_exp else self.big_img_unnorm
                )
                # (Z[i+1] - Z[i]) / (Z[i+1] + Z[i]); last row zeroed.
                s1 = np.eye(size, k=1) - np.eye(size, k=0)
                s2 = np.eye(size, k=1) + np.eye(size, k=0)
                s1[-1] = 0
                s2[-1] = 0
                numerator = s1 @ expspeed
                denominator = s2 @ expspeed
                # Avoids division by zero on the zeroed last row.
                denominator = np.where(denominator < 1e-6, 1e-6, denominator)
                reflect = numerator / denominator
            else:
                raise ValueError(f"Unsupported zhengyan_type: {self.zhengyan_type}")
            np.save(cache_path, reflect)
        reflect = np.load(cache_path).astype(np.float32)
        self.reflect_data = self._patches_from_big_image(reflect)
        return reflect

    def _load_or_build_record(self, f0: int, phase: int, noise_snr: int) -> np.ndarray:
        """Load or synthesize a full-size noisy seismic record.

        The record is the wavelet convolution matrix applied to
        ``big_reflect``, plus seeded noise. The min/max of the unscaled record
        are tracked in ``self.info``. When ``record_noraml`` is set, the
        returned array is divided by a fixed constant.
        """
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_record__{self._reflect_tag()}"
            f"_ricker={f0:02d}-{phase:03d}_{self.noise_type}={noise_snr:02d}"
            f"_seed={self.base_seed}.npy"
        )
        if not cache_path.exists():
            w_mat = build_convmtx(self.wavelets[f0][phase], self.big_reflect.shape[0])
            records_clear = w_mat @ self.big_reflect
            rng = np.random.default_rng(self.base_seed + f0 * 1000 + phase * 10 + noise_snr)
            if self.noise_type == "guassian_band":
                record, _ = self.add_gaussian_band_noise(noise_snr, records_clear, rng=rng)
            elif self.noise_type == "guassian":
                record, _ = self.add_gaussian_noise(noise_snr, records_clear, rng=rng)
            else:
                raise ValueError(f"Unsupported noise_type: {self.noise_type}")
            np.save(cache_path, record)
        record = np.load(cache_path).astype(np.float32)
        self.info.update(
            {
                "record_minn": min(float(self.info.get("record_minn", 10)), float(record.min())),
                "record_maxn": max(float(self.info.get("record_maxn", -10)), float(record.max())),
                "record_normal": "max",
            }
        )
        if self.record_noraml:
            record = record / 0.3215932963300079
            self.info["record_maxn"] = 0.3215932963300079
        return record

    def _load_or_build_dipin(self, dipin_v: float) -> np.ndarray:
        """Load or build the full-size low-frequency background model (float32).

        The normalized image is low-passed along axis 0 with an order-2
        Butterworth filter (cutoff ``dipin_v``, normalized to Nyquist). It is
        then smoothed with zero-phase moving averages along axes 0 and 1. The
        cache key includes every input that affects the result.
        """
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_dipin={dipin_v:g}_{self.normalize}"
            f"_nz={self.dipin_nsmoothz}_nx={self.dipin_nsmoothx}.npy"
        )
        if not cache_path.exists():
            bb, aa = butter(2, dipin_v, "low")
            smooth_filter_z = np.ones(self.dipin_nsmoothz) / float(self.dipin_nsmoothz)
            smooth_filter_x = np.ones(self.dipin_nsmoothx) / float(self.dipin_nsmoothx)
            mback = filtfilt(bb, aa, self.big_img.T).T
            mback = filtfilt(smooth_filter_z, 1, mback, axis=0)
            mback = filtfilt(smooth_filter_x, 1, mback, axis=1)
            np.save(cache_path, mback)
        return np.load(cache_path).astype(np.float32)

    def _patches_from_big_image(self, big_image: np.ndarray) -> np.ndarray:
        """Cut ``(1, size, size)`` patches at ``self.splits`` and stack them as float32."""
        return np.stack(
            [
                big_image[x : x + self.size, y : y + self.size].reshape(1, self.size, self.size)
                for x, y in self.splits
            ],
            axis=0,
        ).astype(np.float32)
|
|