mally-2000 committed
Commit 3f19d1a · verified · Parent: 5083bd4

Add Overthrust inference benchmark and model card
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo.png filter=lfs diff=lfs merge=lfs -text
+data/Overthrust_trueimp.mat filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,116 @@
---
library_name: diffusers
pipeline_tag: image-to-image
tags:
- seismic-inversion
- impedance-inversion
- diffusion
- ddpm
- overthrust
---

# Seismic-LDDPM

Seismic-LDDPM is a latent DDPM pipeline for seismic impedance inversion. The
pipeline takes a low-frequency impedance image (`dipin`) and a synthetic seismic
record (`record`) and predicts the impedance image.

This repository includes:

- Diffusers-format model components: `vq_model`, `unet`, `scheduler`, and
  `condition_encoder`.
- `SeismicImpInvLDDPMPipeline` in `pipeline.py`.
- A complete Overthrust benchmark sample at `data/Overthrust_trueimp.mat`.
- Inference scripts under `inference/`.

## Installation

```bash
git clone https://huggingface.co/mally-2000/seismic-lddpm
cd seismic-lddpm
pip install -r requirements.txt
```

## Overthrust Evaluation

The Overthrust evaluation script is intentionally fixed to the bundled
`data/Overthrust_trueimp.mat`. It cuts the full model into six `256 x 256`
patches, synthesizes the seismic records and low-frequency impedance inputs,
runs inference, stitches the six predictions back together, and computes the
metrics.

```bash
python inference/eval.py \
  --model . \
  --output outputs/overthrust \
  --num-inference-steps 1000
```

Outputs:

- `outputs/overthrust/full_target.npy`
- `outputs/overthrust/full_prediction.npy`
- `outputs/overthrust/full_reconstruction.npy`
- `outputs/overthrust/comparison_impedance.png`
- `outputs/overthrust/metrics_summary.json`

## Benchmark Result

Evaluated locally on the bundled Overthrust benchmark with 1000 DDPM steps,
`noise_snr=15`, `dipin_v=0.012`, `f0=30`, `phase=0`, `seed=1234`, and patch
indices `[0, 1, 2, 3, 4, 5]`.

| Space | PSNR | SSIM | PCC | RRE | NMSE |
|---|---:|---:|---:|---:|---:|
| Normalized | 30.7698 | 0.9339 | 0.9963 | 0.0435 | 0.001894 |
| Impedance | 33.4413 | 0.9554 | 0.9957 | 0.0324 | 0.001050 |
| VQ reconstruction | 37.7954 | 0.9677 | 0.9983 | 0.0209 | 0.000435 |

![Overthrust evaluation](assets/demo.png)

## Single-Sample Inference

For a single prepared sample:

```bash
python inference/infer.py \
  --dipin path/to/dipin.npy \
  --record path/to/record.npy \
  --model . \
  --output outputs/single
```

The input arrays may be `H x W`, `C x H x W`, or `B x C x H x W`. The script
converts them to BCHW tensors and saves `prediction.npy` and `prediction.png`.

## Python Usage

```python
import torch
from pipeline import SeismicImpInvLDDPMPipeline

pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
    "mally-2000/seismic-lddpm",
    torch_dtype=torch.float32,
    trust_remote_code=True,
).to("cuda")

result = pipe(
    dipin=dipin,    # torch.Tensor, BCHW
    record=record,  # torch.Tensor, BCHW
    num_inference_steps=1000,
    seed=1234,
)

prediction = result.impedance_samples
```

## Notes

- `inference/dataset.py` contains a lightweight `SeismicBase` and
  `OverthrustTrueimpDataset`; it does not depend on the original training
  repository's `ldm.data.seisimic`.
- Synthetic record generation is seeded through the benchmark configuration so
  the published Overthrust evaluation is reproducible.
- The bundled Overthrust file is used only as a compact benchmark input for
  reproducing this model's inference pipeline.
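The BCHW conversion described under Single-Sample Inference can be sketched in plain NumPy (illustrative only; `to_bchw` is a hypothetical helper, not part of this commit — `inference/infer.py` does the equivalent with torch tensors):

```python
import numpy as np

def to_bchw(arr: np.ndarray) -> np.ndarray:
    """Promote an H x W or C x H x W array to B x C x H x W."""
    if arr.ndim == 2:
        arr = arr[None, None, :, :]   # H x W -> 1 x 1 x H x W
    elif arr.ndim == 3:
        arr = arr[None, :, :, :]      # C x H x W -> 1 x C x H x W
    elif arr.ndim != 4:
        raise ValueError(f"expected a 2D, 3D, or 4D array, got shape {arr.shape}")
    return arr.astype(np.float32)

print(to_bchw(np.zeros((256, 256))).shape)  # (1, 1, 256, 256)
```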
assets/demo.png ADDED

Git LFS Details

  • SHA256: c6408b5736cf116ca133c7db9154624af429397fec2c7b477139a088beea911b
  • Pointer size: 131 Bytes
  • Size of remote file: 537 kB
data/Overthrust_trueimp.mat ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:59345b022a90f174efd444004192f9edf93ef5f65c556a84f52e9596f1695bd5
size 334424
examples/expected_metrics.json ADDED
@@ -0,0 +1,46 @@
{
  "config": {
    "size": 256,
    "patch_indices": [0, 1, 2, 3, 4, 5],
    "noise_snr": 15,
    "dipin_v": 0.012,
    "f0": 30,
    "f0_phase": 0,
    "seed": 1234,
    "zhengyan_type": "nonlinear",
    "normalize": "minmax",
    "batch_size": 3,
    "num_inference_steps": 1000
  },
  "normalized": {
    "PSNR": 30.76981345155257,
    "rre": 0.043521951884031296,
    "SSIM": 0.9339061199595424,
    "PCC": 0.9963035366574778,
    "nmse": 0.001894210814498365,
    "mse": 3.811614279402499e-09
  },
  "impedance": {
    "PSNR": 33.44134288739278,
    "rre": 0.03240736946463585,
    "SSIM": 0.955363744873021,
    "PCC": 0.9957485049549735,
    "nmse": 0.0010502231307327747,
    "mse": 0.11484166561534005
  },
  "encode_impedance": {
    "PSNR": 37.79544219163976,
    "rre": 0.020859846845269203,
    "SSIM": 0.9676508176373475,
    "PCC": 0.9982799028636675,
    "nmse": 0.0004351270035840571,
    "mse": 0.04758103889550172
  }
}
inference/__init__.py ADDED
@@ -0,0 +1 @@
"""Inference helpers for seismic-lddpm."""
inference/dataset.py ADDED
@@ -0,0 +1,376 @@
from __future__ import annotations

import os
import random
from pathlib import Path

import numpy as np
import pylops
import scipy.io
import torch
from scipy.fftpack import fft, ifft
from scipy.signal import butter, filtfilt
from torch.utils.data import Dataset


class SeismicBase:
    @staticmethod
    def phaseshift(w: np.ndarray, d: float) -> np.ndarray:
        if d == 0:
            return w
        wf_shift = fft(w) * np.exp(1j * (np.pi * d / 180.0))
        return np.real(ifft(wf_shift))

    @staticmethod
    def add_gaussian_band_noise(
        target_snr: float,
        data: np.ndarray,
        rng: np.random.Generator | None = None,
    ) -> tuple[np.ndarray, float]:
        if target_snr == 0:
            return data, 0.0
        rng = rng or np.random.default_rng()
        signal_energy = np.linalg.norm(data) ** 2
        noise_energy = signal_energy / (10 ** (target_snr / 10))
        initial_noise = rng.normal(loc=0, scale=1, size=data.shape)
        noise = filtfilt(
            np.ones(3) / 3,
            1,
            filtfilt(np.ones(3) / 3, 1, initial_noise.T, method="gust").T,
            method="gust",
        )
        noise = noise * np.sqrt(noise_energy / np.linalg.norm(noise) ** 2)
        noisy_data = data + noise
        actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2)
        return noisy_data, float(actual_snr)

    @staticmethod
    def add_gaussian_noise(
        target_snr: float,
        data: np.ndarray,
        rng: np.random.Generator | None = None,
    ) -> tuple[np.ndarray, float]:
        if target_snr == 0:
            return data, 0.0
        rng = rng or np.random.default_rng()
        signal_energy = np.linalg.norm(data) ** 2
        noise_energy = signal_energy / (10 ** (target_snr / 10))
        noise_std = np.sqrt(noise_energy / data.size)
        noise = rng.normal(0, noise_std, data.shape)
        noisy_data = data + noise
        actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2)
        return noisy_data, float(actual_snr)

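A standalone sanity check (not part of the commit) that the white-noise construction in `add_gaussian_noise` hits the requested SNR — with 256 x 256 samples the realized noise energy concentrates tightly around its target:

```python
import numpy as np

# Reproduce the scaling used by add_gaussian_noise and verify the
# realized SNR (in dB) lands on the requested target.
rng = np.random.default_rng(1234)
data = rng.normal(size=(256, 256))
target_snr = 15.0
signal_energy = np.linalg.norm(data) ** 2
noise_energy = signal_energy / (10 ** (target_snr / 10))
noise = rng.normal(0, np.sqrt(noise_energy / data.size), data.shape)
actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2)
```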
class OverthrustTrueimpDataset(SeismicBase, Dataset):
    """Overthrust benchmark dataset used by seismic-lddpm evaluation."""

    def __init__(
        self,
        size: int = 256,
        interval: int = 1,
        special_splits: bool = False,
        use_mask: bool = False,
        record_noraml: bool = True,
        normalize: str = "minmax",
        zhengyan_type: str = "linear",
        train_keys: tuple[str, ...] | list[str] = ("image", "dipin", "record"),
        ricks: tuple[int, ...] | list[int] = (30,),
        ricks_phase: tuple[int, ...] | list[int] = (0,),
        noise_snr: tuple[int, ...] | list[int] = (0,),
        noise_type: str = "guassian_band",
        dipins: tuple[float, ...] | list[float] = (0.012,),
        dipin_nsmoothz: int = 20,
        dipin_nsmoothx: int = 20,
        patch_indices: tuple[int, ...] | list[int] | None = None,
        base_seed: int = 1234,
        data_dir: str | Path | None = None,
        cache_dir: str | Path = "outputs/cache",
        fixed_f0: int | None = None,
        fixed_dipin_v: float | None = None,
        fixed_noise_snr: int | None = None,
        fixed_f0_phase: int | None = None,
    ):
        self.name = "Overthrust_trueimp"
        self.size = size
        self.interval = interval
        self.special_splits = special_splits
        self.use_mask = use_mask
        self.record_noraml = record_noraml
        self.normalize = normalize
        self.zhengyan_type = zhengyan_type
        self.train_keys = list(train_keys)
        self.ricks = list(ricks)
        self.ricks_phase = list(ricks_phase)
        self.noise_snr = list(noise_snr)
        self.noise_type = noise_type
        self.dipins = list(dipins)
        self.dipin_nsmoothz = dipin_nsmoothz
        self.dipin_nsmoothx = dipin_nsmoothx
        self.base_seed = base_seed
        self.have_exp = False
        self.info: dict[str, float | str] = {}
        self.fixed_f0 = fixed_f0
        self.fixed_dipin_v = fixed_dipin_v
        self.fixed_noise_snr = fixed_noise_snr
        self.fixed_f0_phase = fixed_f0_phase
        self.data_dir = Path(data_dir or os.getenv("DATASET_DIR", "data"))
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self._load_big_impedance()
        self._build_splits_and_patches(special_splits=special_splits)
        self._build_wavelets()
        self.big_reflect = self._load_or_build_reflect()
        self.record_data = {
            f0: {
                phase: {
                    snr: self._patches_from_big_image(
                        self._load_or_build_record(f0=f0, phase=phase, noise_snr=snr)
                    )
                    for snr in self.noise_snr
                }
                for phase in self.ricks_phase
            }
            for f0 in self.ricks
        }
        self.dipin_datas = {
            dipin_v: self._patches_from_big_image(self._load_or_build_dipin(dipin_v))
            for dipin_v in self.dipins
        }
        all_indices = list(range(len(self.splits)))
        self.patch_indices = all_indices if patch_indices is None else list(patch_indices)

    def __len__(self) -> int:
        return len(self.patch_indices)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        patch_idx = self.patch_indices[index]
        f0 = self.fixed_f0 if self.fixed_f0 is not None else random.choice(self.ricks)
        dipin_v = (
            self.fixed_dipin_v
            if self.fixed_dipin_v is not None
            else random.choice(self.dipins)
        )
        noise_snr = (
            self.fixed_noise_snr
            if self.fixed_noise_snr is not None
            else random.choice(self.noise_snr)
        )
        f0_phase = (
            self.fixed_f0_phase
            if self.fixed_f0_phase is not None
            else random.choice(self.ricks_phase)
        )
        sample = {
            "patch_idx": torch.tensor(patch_idx, dtype=torch.long),
            "seed": torch.tensor(
                self.base_seed + index + int(noise_snr) * 100, dtype=torch.long
            ),
        }
        if "image" in self.train_keys:
            sample["image"] = torch.from_numpy(self.file_data[patch_idx]).float()
        if "dipin" in self.train_keys:
            sample["dipin"] = torch.from_numpy(self.dipin_datas[dipin_v][patch_idx]).float()
            sample["dipin_v"] = torch.tensor(dipin_v, dtype=torch.float32).reshape(1, 1, 1)
        if "record" in self.train_keys:
            sample["record"] = torch.from_numpy(
                self.record_data[f0][f0_phase][noise_snr][patch_idx]
            ).float()
            sample["rick_v"] = torch.tensor(f0, dtype=torch.float32).reshape(1, 1, 1)
            sample["rick_phase"] = torch.tensor(f0_phase, dtype=torch.float32).reshape(1, 1, 1)
            sample["snr_v"] = torch.tensor(noise_snr, dtype=torch.float32).reshape(1, 1, 1)
        if "reflection" in self.train_keys:
            sample["reflection"] = torch.from_numpy(self.reflect_data[patch_idx]).float()
        if "mask_speed" in self.train_keys:
            sample["mask_speed"] = torch.from_numpy(
                self.mask_data[patch_idx] * self.file_data[patch_idx]
            ).float()
        if self.use_mask:
            sample["mask"] = torch.from_numpy(self.mask_data[patch_idx]).float()
        return sample

    def fan(self, x: np.ndarray) -> np.ndarray:
        minn = 5.0931
        maxn = 6.501110975896774
        return np.exp(x * (maxn - minn) + minn) * 10.9 + 200

    def inv_normal(self, x: np.ndarray) -> np.ndarray:
        vmin = float(self.info["normal_min"])
        vmax = float(self.info["normal_max"])
        if self.normalize == "minmax":
            return x * (vmax - vmin) + vmin
        return x * vmax

    def _load_big_impedance(self) -> None:
        file_path = self.data_dir / "Overthrust_trueimp.mat"
        if not file_path.exists():
            raise FileNotFoundError(f"Overthrust data not found: {file_path}")
        wave = scipy.io.loadmat(file_path)["Overthrust_trueimp"].T
        wave = np.log(wave)
        normal_min = wave.min()
        normal_max = wave.max()
        self.info.update(
            {"normal_min": normal_min, "normal_max": normal_max, "normal": "max"}
        )
        self.big_img_unnorm = wave
        self.big_speedimg = wave
        if self.normalize == "max":
            wave = wave / normal_max
        elif self.normalize == "minmax":
            wave = (wave - normal_min) / (normal_max - normal_min)
        else:
            raise ValueError(f"Unsupported normalize: {self.normalize}")
        self.big_img = wave.astype(np.float32)

    def _build_splits_and_patches(self, special_splits: bool = False) -> None:
        self.big_mask = np.zeros(self.big_img.shape, dtype=np.float32)
        for col in (100, 200, 300):
            if col < self.big_mask.shape[1]:
                self.big_mask[:, col : col + 1] = 1
        if special_splits:
            splits = []
            for x in range(0, 551 - self.size, 20):
                for y in range(0, 551 - self.size, 20):
                    splits.append((x, y))
            for y in range(0, 551 - self.size, 9):
                splits.extend([(30, y), (90, y), (140, y)])
        elif self.size == 256:
            splits = [
                (0, 0),
                (146, 0),
                (551 - 256, 0),
                (0, 145),
                (146, 145),
                (551 - 256, 145),
            ]
        else:
            splits = []
            interval_size = self.size - 1
            for r in range(0, self.big_img.shape[0] - self.size, interval_size):
                for c in range(0, self.big_img.shape[1] - self.size, interval_size):
                    splits.append((r, c))
                splits.append((r, self.big_img.shape[1] - self.size))
            for c in range(0, self.big_img.shape[1] - self.size, interval_size):
                splits.append((self.big_img.shape[0] - self.size, c))
            splits.append(
                (self.big_img.shape[0] - self.size, self.big_img.shape[1] - self.size)
            )

        self.splits = []
        patches = []
        masks = []
        for x, y in splits:
            x2 = x + self.size
            y2 = y + self.size
            if x2 > self.big_img.shape[0] or y2 > self.big_img.shape[1]:
                continue
            self.splits.append((x, y))
            patches.append(self.big_img[x:x2, y:y2].reshape(1, self.size, self.size))
            masks.append(self.big_mask[x:x2, y:y2].reshape(1, self.size, self.size))
        self.file_data = np.stack(patches, axis=0).astype(np.float32)[:: self.interval]
        self.mask_data = np.stack(masks, axis=0).astype(np.float32)[:: self.interval]
        self.splits = self.splits[:: self.interval]

    def _build_wavelets(self) -> None:
        nt0 = 256
        dt0 = 0.002
        self.wavelets = {}
        for f0 in self.ricks:
            self.wavelets[f0] = {}
            wav = pylops.utils.wavelets.ricker(np.arange(nt0 // 2) * dt0, f0)[0]
            for phase in self.ricks_phase:
                self.wavelets[f0][phase] = self.phaseshift(wav, phase)

    def _cache_path(self, name: str) -> Path:
        return self.cache_dir / name

    def _load_or_build_reflect(self) -> np.ndarray:
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_reflect_{self.zhengyan_type}.npy"
        )
        if not cache_path.exists():
            size = self.big_img.shape[0]
            if self.zhengyan_type == "linear":
                s1 = np.diag(0.5 * np.ones(size - 1, dtype="float32"), k=1) - np.diag(
                    0.5 * np.ones(size - 1, dtype="float32"), k=-1
                )
                s1[-1] = s1[0] = 0
                reflect = s1 @ self.big_img
            elif self.zhengyan_type == "nonlinear":
                expspeed = (
                    np.exp(self.big_img_unnorm)
                    if self.have_exp is False
                    else self.big_img_unnorm
                )
                s1 = np.eye(size, k=1) - np.eye(size, k=0)
                s2 = np.eye(size, k=1) + np.eye(size, k=0)
                s1[-1] = 0
                s2[-1] = 0
                numerator = s1 @ expspeed
                denominator = s2 @ expspeed
                denominator = np.where(denominator < 1e-6, 1e-6, denominator)
                reflect = numerator / denominator
            else:
                raise ValueError(f"Unsupported zhengyan_type: {self.zhengyan_type}")
            np.save(cache_path, reflect)
        reflect = np.load(cache_path).astype(np.float32)
        self.reflect_data = self._patches_from_big_image(reflect)
        return reflect

    def _load_or_build_record(self, f0: int, phase: int, noise_snr: int) -> np.ndarray:
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_record__{self.zhengyan_type}"
            f"_ricker={f0:02d}-{phase:03d}_{self.noise_type}={noise_snr:02d}"
            f"_seed={self.base_seed}.npy"
        )
        if not cache_path.exists():
            wav = self.wavelets[f0][phase]
            w_mat = pylops.utils.signalprocessing.convmtx(
                wav, self.big_reflect.shape[0], len(wav) // 2
            )[: self.big_reflect.shape[0]]
            records_clear = w_mat @ self.big_reflect
            rng = np.random.default_rng(self.base_seed + f0 * 1000 + phase * 10 + noise_snr)
            if self.noise_type == "guassian_band":
                record, _ = self.add_gaussian_band_noise(noise_snr, records_clear, rng=rng)
            elif self.noise_type == "guassian":
                record, _ = self.add_gaussian_noise(noise_snr, records_clear, rng=rng)
            else:
                raise ValueError(f"Unsupported noise_type: {self.noise_type}")
            np.save(cache_path, record)
        record = np.load(cache_path).astype(np.float32)
        self.info.update(
            {
                "record_minn": min(float(self.info.get("record_minn", 10)), float(record.min())),
                "record_maxn": max(float(self.info.get("record_maxn", -10)), float(record.max())),
                "record_normal": "max",
            }
        )
        if self.record_noraml:
            record = record / 0.3215932963300079
            self.info["record_maxn"] = 0.3215932963300079
        return record

    def _load_or_build_dipin(self, dipin_v: float) -> np.ndarray:
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_dipin={dipin_v:.03f}.npy"
        )
        if not cache_path.exists():
            bb, aa = butter(2, dipin_v, "low")
            smooth_filter_z = np.ones(self.dipin_nsmoothz) / float(self.dipin_nsmoothz)
            smooth_filter_x = np.ones(self.dipin_nsmoothx) / float(self.dipin_nsmoothx)
            mback = filtfilt(bb, aa, self.big_img.T).T
            mback = filtfilt(smooth_filter_z, 1, mback, axis=0)
            mback = filtfilt(smooth_filter_x, 1, mback, axis=1)
            np.save(cache_path, mback)
        return np.load(cache_path).astype(np.float32)

    def _patches_from_big_image(self, big_image: np.ndarray) -> np.ndarray:
        patches = []
        for x, y in self.splits:
            patches.append(
                big_image[x : x + self.size, y : y + self.size].reshape(
                    1, self.size, self.size
                )
            )
        return np.stack(patches, axis=0).astype(np.float32)
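A standalone check (not part of the commit) that the banded-matrix construction in the `nonlinear` branch of `_load_or_build_reflect` matches the usual reflection-coefficient formula `r[i] = (Z[i+1] - Z[i]) / (Z[i+1] + Z[i])`:

```python
import numpy as np

# Toy impedance trace as a column vector, positive everywhere.
n = 8
z = np.exp(np.linspace(5.1, 6.5, n))[:, None]

# Shift matrices exactly as built in _load_or_build_reflect.
s1 = np.eye(n, k=1) - np.eye(n, k=0)
s2 = np.eye(n, k=1) + np.eye(n, k=0)
s1[-1] = 0
s2[-1] = 0
denominator = np.where(s2 @ z < 1e-6, 1e-6, s2 @ z)
r_matrix = (s1 @ z) / denominator

# Direct reflection-coefficient formula; last sample has no interface below.
r_direct = np.zeros_like(z)
r_direct[:-1] = (z[1:] - z[:-1]) / (z[1:] + z[:-1])
```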
inference/eval.py ADDED
@@ -0,0 +1,216 @@
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import pearsonr
from skimage.metrics import structural_similarity
from torch.utils.data import DataLoader

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from inference.dataset import OverthrustTrueimpDataset
from pipeline import SeismicImpInvLDDPMPipeline


OVERTHRUST_CONFIG = {
    "size": 256,
    "patch_indices": [0, 1, 2, 3, 4, 5],
    "noise_snr": 15,
    "dipin_v": 0.012,
    "f0": 30,
    "f0_phase": 0,
    "seed": 1234,
    "zhengyan_type": "nonlinear",
    "normalize": "minmax",
    "batch_size": 3,
}


def stitch_patches(
    patches: list[np.ndarray],
    splits: list[tuple[int, int]],
    big_shape: tuple[int, int],
    img_size: int,
) -> np.ndarray:
    rec = np.zeros(big_shape, dtype=np.float32)
    cnt = np.zeros(big_shape, dtype=np.float32)
    for idx, (x, y) in enumerate(splits):
        rec[x : x + img_size, y : y + img_size] += patches[idx]
        cnt[x : x + img_size, y : y + img_size] += 1
    return rec / np.maximum(cnt, 1)

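The overlap-averaging behavior of `stitch_patches` on a toy example (a standalone sketch, not part of the commit; `stitch` restates the function above):

```python
import numpy as np

# Patches are summed into a canvas and divided by per-pixel coverage counts,
# so overlapping regions become the mean of the contributing patches.
def stitch(patches, splits, big_shape, size):
    rec = np.zeros(big_shape, dtype=np.float32)
    cnt = np.zeros(big_shape, dtype=np.float32)
    for (x, y), patch in zip(splits, patches):
        rec[x:x + size, y:y + size] += patch
        cnt[x:x + size, y:y + size] += 1
    return rec / np.maximum(cnt, 1)

# Two 2x2 patches overlapping in the middle column of a 2x3 canvas.
a = np.ones((2, 2), dtype=np.float32)
b = 3 * np.ones((2, 2), dtype=np.float32)
out = stitch([a, b], [(0, 0), (0, 1)], (2, 3), 2)
print(out[0])  # [1. 2. 3.] -- the overlap column averages to (1 + 3) / 2
```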
def compute_metrics(prediction: np.ndarray, target: np.ndarray) -> dict[str, float]:
    diff = prediction - target
    denom = np.linalg.norm(diff.ravel()) ** 2
    psnr = float("inf") if denom == 0 else float(
        10.0 * np.log10(len(prediction.ravel()) * np.max(prediction.ravel()) ** 2 / denom)
    )
    return {
        "PSNR": psnr,
        "rre": float(np.linalg.norm(diff.ravel()) / np.linalg.norm(target.ravel())),
        "SSIM": float(structural_similarity(target, prediction, data_range=target.max())),
        "PCC": float(pearsonr(prediction.ravel(), target.ravel()).statistic),
        "nmse": float(np.sum(diff ** 2) / np.sum(target ** 2)),
        # Note: the reported "mse" is the mean squared error divided again by
        # the element count; kept as-is to match examples/expected_metrics.json.
        "mse": float(np.mean(diff ** 2) / prediction.size),
    }

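Of the metrics above, `rre` and `nmse` are related by construction (`nmse = rre**2`); a standalone check, not part of the commit:

```python
import numpy as np

# rre = ||p - t|| / ||t|| and nmse = sum((p - t)^2) / sum(t^2), so
# squaring the relative reconstruction error recovers the NMSE.
rng = np.random.default_rng(0)
t = rng.normal(size=(64, 64))
p = t + 0.05 * rng.normal(size=(64, 64))
diff = p - t
rre = np.linalg.norm(diff.ravel()) / np.linalg.norm(t.ravel())
nmse = np.sum(diff ** 2) / np.sum(t ** 2)
```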
def save_comparison(
    target_impedance: np.ndarray,
    prediction_impedance: np.ndarray,
    output_path: Path,
) -> None:
    error = np.abs(target_impedance - prediction_impedance)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    vmin_imp = min(target_impedance.min(), prediction_impedance.min())
    vmax_imp = max(target_impedance.max(), prediction_impedance.max())
    for ax, arr, title in zip(
        axes,
        [target_impedance, prediction_impedance, error],
        ["Target (Impedance)", "Prediction (Impedance)", "Error (Impedance)"],
    ):
        if "Error" in title:
            im = ax.imshow(arr, cmap="hot", vmin=0, vmax=error.max())
        else:
            im = ax.imshow(arr, cmap="jet", vmin=vmin_imp, vmax=vmax_imp)
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, fraction=0.046)
    plt.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)


def evaluate_overthrust(
    pipe: SeismicImpInvLDDPMPipeline,
    output_dir: str | Path = "outputs/overthrust",
    num_inference_steps: int = 1000,
    device: str | torch.device | None = None,
) -> dict[str, object]:
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    pipe = pipe.to(device)

    dataset = OverthrustTrueimpDataset(
        size=OVERTHRUST_CONFIG["size"],
        normalize=OVERTHRUST_CONFIG["normalize"],
        zhengyan_type=OVERTHRUST_CONFIG["zhengyan_type"],
        ricks=[OVERTHRUST_CONFIG["f0"]],
        ricks_phase=[OVERTHRUST_CONFIG["f0_phase"]],
        noise_snr=[OVERTHRUST_CONFIG["noise_snr"]],
        dipins=[OVERTHRUST_CONFIG["dipin_v"]],
        record_noraml=True,
        train_keys=["image", "dipin", "record"],
        patch_indices=OVERTHRUST_CONFIG["patch_indices"],
        base_seed=OVERTHRUST_CONFIG["seed"],
        data_dir=REPO_ROOT / "data",
        cache_dir=output_dir / "cache",
        fixed_f0=OVERTHRUST_CONFIG["f0"],
        fixed_dipin_v=OVERTHRUST_CONFIG["dipin_v"],
        fixed_noise_snr=OVERTHRUST_CONFIG["noise_snr"],
        fixed_f0_phase=OVERTHRUST_CONFIG["f0_phase"],
    )
    loader = DataLoader(
        dataset,
        batch_size=OVERTHRUST_CONFIG["batch_size"],
        shuffle=False,
        num_workers=0,
    )

    all_predictions: list[np.ndarray] = []
    all_targets: list[np.ndarray] = []
    all_reconstructions: list[np.ndarray] = []
    for batch in loader:
        seeds = batch["seed"].tolist()
        dipin = batch["dipin"].to(device)
        record = batch["record"].to(device)
        image = batch["image"].to(device)
        output = pipe(
            dipin=dipin,
            record=record,
            image=image,
            num_inference_steps=num_inference_steps,
            seeds=seeds,
        )
        prediction = output.impedance_samples
        reconstruction = output.impedance_reconstructed
        for local_idx in range(prediction.shape[0]):
            all_predictions.append(prediction[local_idx, 0].detach().cpu().numpy())
            all_targets.append(image[local_idx, 0].detach().cpu().numpy())
            all_reconstructions.append(reconstruction[local_idx, 0].detach().cpu().numpy())

    full_target = stitch_patches(
        all_targets, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
    )
    full_prediction = stitch_patches(
        all_predictions, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
    )
    full_reconstruction = stitch_patches(
        all_reconstructions, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
    )

    full_target_impedance = dataset.fan(full_target)
    full_prediction_impedance = dataset.fan(full_prediction)
    full_reconstruction_impedance = dataset.fan(full_reconstruction)

    metrics_summary = {
        "config": {**OVERTHRUST_CONFIG, "num_inference_steps": num_inference_steps},
        "normalized": compute_metrics(full_prediction, full_target),
        "impedance": compute_metrics(full_prediction_impedance, full_target_impedance),
        "encode_impedance": compute_metrics(
            full_reconstruction_impedance, full_target_impedance
        ),
    }

    paths = {
        "full_target": output_dir / "full_target.npy",
        "full_prediction": output_dir / "full_prediction.npy",
        "full_reconstruction": output_dir / "full_reconstruction.npy",
        "comparison": output_dir / "comparison_impedance.png",
        "metrics": output_dir / "metrics_summary.json",
    }
    np.save(paths["full_target"], full_target)
    np.save(paths["full_prediction"], full_prediction)
    np.save(paths["full_reconstruction"], full_reconstruction)
    save_comparison(full_target_impedance, full_prediction_impedance, paths["comparison"])
    paths["metrics"].write_text(json.dumps(metrics_summary, indent=2), encoding="utf-8")
    return {
        "metrics": metrics_summary,
        "paths": {key: str(value) for key, value in paths.items()},
    }


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate seismic-lddpm on Overthrust.")
    parser.add_argument("--model", default="mally-2000/seismic-lddpm")
    parser.add_argument("--output", default="outputs/overthrust")
    parser.add_argument("--device", default=None)
    parser.add_argument("--num-inference-steps", type=int, default=1000)
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    result = evaluate_overthrust(
        pipe,
        output_dir=args.output,
        num_inference_steps=args.num_inference_steps,
        device=args.device,
    )
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
inference/infer.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
+ import sys
+ from pathlib import Path
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+ REPO_ROOT = Path(__file__).resolve().parents[1]
+ if str(REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(REPO_ROOT))
+
+ from pipeline import SeismicImpInvLDDPMPipeline
+
+
+ def load_bchw_npy(path: str | Path) -> torch.Tensor:
+     arr = np.load(path).astype(np.float32)
+     if arr.ndim == 2:
+         arr = arr[None, None, :, :]
+     elif arr.ndim == 3:
+         arr = arr[None, :, :, :]
+     elif arr.ndim != 4:
+         raise ValueError(f"Expected 2D, 3D, or 4D array at {path}, got shape {arr.shape}")
+     return torch.from_numpy(arr)
+
+
+ def save_prediction_png(prediction: np.ndarray, output_path: Path) -> None:
+     fig, ax = plt.subplots(figsize=(5, 5))
+     im = ax.imshow(prediction, cmap="jet")
+     ax.axis("off")
+     plt.colorbar(im, ax=ax, fraction=0.046)
+     plt.tight_layout()
+     fig.savefig(output_path, dpi=150)
+     plt.close(fig)
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Run seismic-lddpm inference on one sample.")
+     parser.add_argument("--dipin", required=True, help="Path to low-frequency impedance .npy")
+     parser.add_argument("--record", required=True, help="Path to seismic record .npy")
+     parser.add_argument("--model", default="mally-2000/seismic-lddpm")
+     parser.add_argument("--output", default="outputs/single")
+     parser.add_argument("--device", default=None)
+     parser.add_argument("--seed", type=int, default=1234)
+     parser.add_argument("--num-inference-steps", type=int, default=1000)
+     return parser.parse_args()
+
+
+ def main() -> None:
+     args = parse_args()
+     output_dir = Path(args.output)
+     output_dir.mkdir(parents=True, exist_ok=True)
+     device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
+     pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
+         args.model,
+         torch_dtype=torch.float32,
+         trust_remote_code=True,
+     ).to(device)
+     dipin = load_bchw_npy(args.dipin).to(device)
+     record = load_bchw_npy(args.record).to(device)
+     output = pipe(
+         dipin=dipin,
+         record=record,
+         num_inference_steps=args.num_inference_steps,
+         seed=args.seed,
+     )
+     prediction = output.impedance_samples.detach().cpu().numpy()
+     np.save(output_dir / "prediction.npy", prediction)
+     save_prediction_png(prediction[0, 0], output_dir / "prediction.png")
+     print(f"Saved: {output_dir / 'prediction.npy'}")
+     print(f"Saved: {output_dir / 'prediction.png'}")
+
+
+ if __name__ == "__main__":
+     main()
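The shape normalization in `load_bchw_npy` can be sketched standalone with NumPy only; `to_bchw` below is a hypothetical helper name that mirrors the branches above, promoting 2D and 3D arrays to a batched BCHW layout:

```python
import numpy as np

# Mirrors load_bchw_npy's shape handling: promote 2D/3D arrays to BCHW.
def to_bchw(arr: np.ndarray) -> np.ndarray:
    arr = arr.astype(np.float32)
    if arr.ndim == 2:        # (H, W) -> (1, 1, H, W)
        return arr[None, None, :, :]
    if arr.ndim == 3:        # (C, H, W) -> (1, C, H, W)
        return arr[None, :, :, :]
    if arr.ndim == 4:        # already batched
        return arr
    raise ValueError(f"Expected 2D, 3D, or 4D array, got shape {arr.shape}")

assert to_bchw(np.zeros((256, 256))).shape == (1, 1, 256, 256)
assert to_bchw(np.zeros((1, 256, 256))).shape == (1, 1, 256, 256)
assert to_bchw(np.zeros((2, 1, 256, 256))).shape == (2, 1, 256, 256)
```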
pipeline.py ADDED
@@ -0,0 +1,234 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import numpy as np
+ import torch
+ from diffusers import DDPMScheduler, DiffusionPipeline, UNet2DModel, VQModel
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class SeismicImpInvLDDPMPipelineOutput(BaseOutput):
+     impedance_samples: torch.Tensor | np.ndarray
+     impedance_latents: torch.Tensor | np.ndarray
+     impedance_dipin: torch.Tensor | np.ndarray
+     impedance_reconstructed: torch.Tensor | np.ndarray | None = None
+     record_features: torch.Tensor | np.ndarray | None = None
+
+
+ class SeismicImpInvLDDPMPipeline(DiffusionPipeline):
+     """SAII-LDDPM impedance inversion pipeline."""
+
+     def __init__(
+         self,
+         vq_model: VQModel,
+         condition_encoder: torch.nn.Module,
+         unet: UNet2DModel,
+         scheduler: DDPMScheduler,
+     ):
+         super().__init__()
+         self.register_modules(
+             vq_model=vq_model,
+             condition_encoder=condition_encoder,
+             unet=unet,
+             scheduler=scheduler,
+         )
+
+     def _encode_conditioning(
+         self, dipin: torch.Tensor, record: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         dipin_latents = self.vq_model.encode(dipin).latents
+         if hasattr(self.condition_encoder, "encode") and callable(
+             self.condition_encoder.encode
+         ):
+             record_features = self.condition_encoder.encode(record)
+         else:
+             record_features = self.condition_encoder(record)
+         return (
+             dipin_latents.to(dtype=self.unet.dtype),
+             record_features.to(dtype=self.unet.dtype),
+         )
+
+     @staticmethod
+     def _extract_into_tensor(
+         arr: torch.Tensor, timesteps: torch.Tensor, broadcast_shape: torch.Size
+     ) -> torch.Tensor:
+         values = arr.to(device=timesteps.device, dtype=torch.float32).gather(0, timesteps)
+         return values.reshape(timesteps.shape[0], *((1,) * (len(broadcast_shape) - 1)))
+
+     @staticmethod
+     def _build_legacy_ddpm_buffers(
+         scheduler: DDPMScheduler, device: torch.device
+     ) -> dict[str, torch.Tensor]:
+         # Coefficients of the closed-form DDPM posterior q(x_{t-1} | x_t, x_0).
+         betas = scheduler.betas.to(device=device, dtype=torch.float32)
+         alphas = 1.0 - betas
+         alphas_cumprod = torch.cumprod(alphas, dim=0)
+         alphas_cumprod_prev = torch.cat(
+             [torch.ones(1, device=device), alphas_cumprod[:-1]], dim=0
+         )
+         posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+         posterior_log_variance_clipped = torch.log(
+             torch.clamp(posterior_variance, min=1e-20)
+         )
+         return {
+             "sqrt_recip_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod),
+             "sqrt_recipm1_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod - 1),
+             "posterior_mean_coef1": betas
+             * torch.sqrt(alphas_cumprod_prev)
+             / (1.0 - alphas_cumprod),
+             "posterior_mean_coef2": (1.0 - alphas_cumprod_prev)
+             * torch.sqrt(alphas)
+             / (1.0 - alphas_cumprod),
+             "posterior_log_variance_clipped": posterior_log_variance_clipped,
+         }
+
+     @staticmethod
+     def _randn_like_sample(
+         sample: torch.Tensor, generator: torch.Generator | list[torch.Generator] | None
+     ) -> torch.Tensor:
+         if isinstance(generator, list):
+             if len(generator) != sample.shape[0]:
+                 raise ValueError(
+                     f"Expected {sample.shape[0]} generators, got {len(generator)}"
+                 )
+             return torch.cat(
+                 [
+                     torch.randn(
+                         sample[i : i + 1].shape,
+                         generator=sample_generator,
+                         device=sample.device,
+                         dtype=sample.dtype,
+                     )
+                     for i, sample_generator in enumerate(generator)
+                 ],
+                 dim=0,
+             )
+         return torch.randn(
+             sample.shape, generator=generator, device=sample.device, dtype=sample.dtype
+         )
+
+     def _ddpm_step(
+         self,
+         latents: torch.Tensor,
+         conditioning: torch.Tensor,
+         timestep: torch.Tensor,
+         generator: torch.Generator | list[torch.Generator] | None,
+         buffers: dict[str, torch.Tensor],
+     ) -> torch.Tensor:
+         model_input = torch.cat([latents, conditioning], dim=1)
+         noise_pred = self.unet(model_input, timestep).sample
+         # Predict x0 from the noise estimate, then snap it onto the VQ codebook.
+         pred_x0 = (
+             self._extract_into_tensor(
+                 buffers["sqrt_recip_alphas_cumprod"], timestep, latents.shape
+             )
+             * latents
+             - self._extract_into_tensor(
+                 buffers["sqrt_recipm1_alphas_cumprod"], timestep, latents.shape
+             )
+             * noise_pred
+         )
+         pred_x0 = self.vq_model.quantize(pred_x0)[0]
+         model_mean = (
+             self._extract_into_tensor(
+                 buffers["posterior_mean_coef1"], timestep, latents.shape
+             )
+             * pred_x0
+             + self._extract_into_tensor(
+                 buffers["posterior_mean_coef2"], timestep, latents.shape
+             )
+             * latents
+         )
+         noise = self._randn_like_sample(latents, generator)
+         # No noise is added at the final step (t == 0).
+         nonzero_mask = (1 - (timestep == 0).float()).reshape(
+             latents.shape[0], *((1,) * (len(latents.shape) - 1))
+         )
+         return model_mean + nonzero_mask * (
+             0.5
+             * self._extract_into_tensor(
+                 buffers["posterior_log_variance_clipped"], timestep, latents.shape
+             )
+         ).exp() * noise
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         dipin: torch.Tensor,
+         record: torch.Tensor,
+         image: torch.Tensor | None = None,
+         num_inference_steps: int = 1000,
+         seed: int | None = None,
+         seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
+         generator: torch.Generator | None = None,
+         output_type: str = "tensor",
+     ) -> SeismicImpInvLDDPMPipelineOutput:
+         device = self.unet.device
+         # Per-sample generators when explicit seeds are given.
+         if seeds is not None:
+             if isinstance(seeds, torch.Tensor):
+                 seeds = seeds.detach().cpu().tolist()
+             seeds = [int(value) for value in seeds]
+             if len(seeds) != dipin.shape[0]:
+                 raise ValueError(f"Expected {dipin.shape[0]} seeds, got {len(seeds)}")
+             generator = [
+                 torch.Generator(device=device).manual_seed(value) for value in seeds
+             ]
+         elif seed is not None:
+             generator = torch.Generator(device=device).manual_seed(seed)
+         elif generator is None:
+             generator = torch.Generator(device=device)
+
+         dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
+         record = record.to(device=device, dtype=self.unet.dtype)
+         impedance_dipin, record_features = self._encode_conditioning(dipin, record)
+         conditioning = torch.cat([impedance_dipin, record_features], dim=1)
+         impedance_latents = self._randn_like_sample(
+             torch.empty(
+                 impedance_dipin.shape,
+                 device=device,
+                 dtype=self.unet.dtype,
+             ),
+             generator,
+         )
+         buffers = self._build_legacy_ddpm_buffers(self.scheduler, device)
+         for t in reversed(range(num_inference_steps)):
+             timestep = torch.full(
+                 (impedance_latents.shape[0],), t, device=device, dtype=torch.long
+             )
+             impedance_latents = self._ddpm_step(
+                 impedance_latents, conditioning, timestep, generator, buffers
+             )
+
+         impedance_samples = self.vq_model.decode(
+             impedance_latents.to(dtype=self.vq_model.dtype)
+         ).sample
+         impedance_reconstructed = None
+         if image is not None:
+             image = image.to(device=device, dtype=self.vq_model.dtype)
+             image_latents = self.vq_model.encode(image).latents
+             impedance_reconstructed = self.vq_model.decode(image_latents).sample
+
+         if output_type == "np":
+             impedance_samples = impedance_samples.detach().cpu().numpy()
+             impedance_latents = impedance_latents.detach().cpu().numpy()
+             impedance_dipin = impedance_dipin.detach().cpu().numpy()
+             record_features = record_features.detach().cpu().numpy()
+             if impedance_reconstructed is not None:
+                 impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()
+
+         return SeismicImpInvLDDPMPipelineOutput(
+             impedance_samples=impedance_samples,
+             impedance_latents=impedance_latents,
+             impedance_dipin=impedance_dipin,
+             impedance_reconstructed=impedance_reconstructed,
+             record_features=record_features,
+         )
+
+     @torch.no_grad()
+     def encode_decode(
+         self, image: torch.Tensor, output_type: str = "tensor"
+     ) -> torch.Tensor | np.ndarray:
+         image = image.to(device=self.vq_model.device, dtype=self.vq_model.dtype)
+         reconstruction = self.vq_model.decode(self.vq_model.encode(image).latents).sample
+         if output_type == "np":
+             return reconstruction.detach().cpu().numpy()
+         return reconstruction
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch>=2.6.0
+ diffusers>=0.25.0,<0.31.0
+ accelerate>=0.25.0
+ safetensors>=0.4.0
+ numpy>=1.23.0,<2.0
+ scipy>=1.10.1
+ matplotlib>=3.9.4
+ scikit-image>=0.24.0
+ pylops==2.2.0
+ pytorch-wavelets>=1.3.0
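The posterior coefficients computed in `_build_legacy_ddpm_buffers` follow the standard DDPM posterior q(x_{t-1} | x_t, x_0). A minimal standalone check, using stdlib Python and an illustrative linear beta schedule (the bundled scheduler's actual betas may differ):

```python
import math

# Recompute the posterior buffers from _build_legacy_ddpm_buffers with a small
# linear beta schedule (an assumption for illustration, not the shipped config).
T = 10
betas = [1e-4 + (2e-2 - 1e-4) * t / (T - 1) for t in range(T)]
alphas = [1.0 - b for b in betas]
alphas_cumprod = []
acc = 1.0
for a in alphas:
    acc *= a
    alphas_cumprod.append(acc)
alphas_cumprod_prev = [1.0] + alphas_cumprod[:-1]

posterior_variance = [
    b * (1.0 - ap) / (1.0 - ac)
    for b, ap, ac in zip(betas, alphas_cumprod_prev, alphas_cumprod)
]
coef1 = [  # weight on the predicted x0 in the posterior mean
    b * math.sqrt(ap) / (1.0 - ac)
    for b, ap, ac in zip(betas, alphas_cumprod_prev, alphas_cumprod)
]
coef2 = [  # weight on the current latent x_t in the posterior mean
    (1.0 - ap) * math.sqrt(a) / (1.0 - ac)
    for a, ap, ac in zip(alphas, alphas_cumprod_prev, alphas_cumprod)
]

# At t = 0 the posterior collapses onto the predicted x0: zero variance,
# mean coefficient 1 on x0 and 0 on x_t — which is why _ddpm_step masks out
# the noise term at the last step.
assert abs(posterior_variance[0]) < 1e-12
assert abs(coef1[0] - 1.0) < 1e-12
assert abs(coef2[0]) < 1e-12
```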