mally-2000 committed on
Commit d105891 · verified · 1 Parent(s): 5c1eece

Add SAII-CLDM inference pipeline

Add a unified infer.py entry point for SAII-LDDPM and SAII-CLDM, add the differentiable Overthrust forward operator, update Overthrust evaluation to support CLDM, and align CLDM sampling defaults with the published SAII-CLDM resampling setup.

README.md CHANGED
@@ -1,113 +1,312 @@
- ---
- library_name: diffusers
- pipeline_tag: image-to-image
- tags:
- - seismic-inversion
- - impedance-inversion
- - diffusion
- - ddpm
- - overthrust
- ---
 
- # Seismic-LDDPM
 
- Seismic-LDDPM is a latent DDPM pipeline for seismic impedance inversion. The
- pipeline takes a low-frequency impedance image (`dipin`) and a synthetic seismic
- record (`record`) and predicts the impedance image.
 
- This repository includes:
 
- - Diffusers-format model components: `vq_model`, `unet`, `scheduler`, and
-   `condition_encoder`.
- - `SeismicImpInvLDDPMPipeline` in `pipeline.py`.
- - A complete Overthrust benchmark sample at `data/Overthrust_trueimp.mat`.
- - Inference scripts under `inference/`.
 
- ## Installation
 
  ```bash
- git clone https://huggingface.co/mally-2000/seismic-lddpm
- cd seismic-lddpm
- pip install -r requirements.txt
  ```
 
- ## Overthrust Evaluation
 
- The Overthrust evaluation script is intentionally fixed to the bundled
- `data/Overthrust_trueimp.mat`. It cuts the full model into six `256 x 256`
- patches, synthesizes the seismic records and low-frequency impedance inputs,
- runs inference, stitches the six predictions back together, and computes the
- metrics.
 
  ```bash
- python inference/eval_overthrust.py \
-   --model . \
-   --output outputs/overthrust \
-   --num-inference-steps 1000
  ```
 
- Outputs:
 
- - `outputs/overthrust/full_target.npy`
- - `outputs/overthrust/full_prediction.npy`
- - `outputs/overthrust/full_reconstruction.npy`
- - `outputs/overthrust/comparison_impedance.png`
- - `outputs/overthrust/metrics_summary.json`
 
- ## Benchmark Result
 
- Evaluated locally on the bundled Overthrust benchmark with 1000 DDPM steps,
- `noise_snr=15`, `dipin_v=0.012`, `f0=30`, `phase=0`, `seed=1234`, and patch
- indices `[0, 1, 2, 3, 4, 5]`.
 
- | Space | PSNR | SSIM | PCC | RRE | NMSE |
- |---|---:|---:|---:|---:|---:|
- | Normalized | 30.7698 | 0.9339 | 0.9963 | 0.0435 | 0.001894 |
- | Impedance | 33.4413 | 0.9554 | 0.9957 | 0.0324 | 0.001050 |
- | VQ reconstruction | 37.7954 | 0.9677 | 0.9983 | 0.0209 | 0.000435 |
 
- ![Overthrust evaluation](assets/demo.png)
 
- ## Single-Sample Inference
 
- For a single default Overthrust patch:
 
  ```bash
- python inference/infer_LDDPM.py
  ```
 
- The script builds one Overthrust test sample internally, synthesizes the
- low-frequency impedance and seismic record, and saves `prediction.npy`,
- `target.npy`, and `comparison.png` under `outputs/infer_LDDPM`.
 
- ## Python Usage
 
- ```python
- import torch
- from pipeline import SeismicImpInvLDDPMPipeline
 
- pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
-     "mally-2000/seismic-lddpm",
-     torch_dtype=torch.float32,
-     trust_remote_code=True,
- ).to("cuda")
 
- result = pipe(
-     dipin=dipin,  # torch.Tensor, BCHW
-     record=record,  # torch.Tensor, BCHW
-     num_inference_steps=1000,
-     seed=1234,
- )
 
- prediction = result.impedance_samples
  ```
 
- ## Notes
 
- - `inference/dataset.py` contains a lightweight `SeismicBase` and
-   `OverthrustTrueimpDataset`; it does not depend on the original training
-   repository's `ldm.data.seisimic`.
- - Synthetic record generation is seeded through the benchmark configuration so
-   the published Overthrust evaluation is reproducible.
- - The bundled Overthrust file is used only as a compact benchmark input for
-   reproducing this model's inference pipeline.

+ ## Quick Start
 
+ ## Seismic-LDDPM Open-Source Inference
 
+ The minimal inference entry points for the Hugging Face model repository are:
 
+ - `inference/infer.py`: unified single-sample inference entry point. It uses SAII-LDDPM by default; pass `CLDM` to use SAII-CLDM.
+ - `inference/eval_overthrust.py`: fixed Overthrust benchmark. It accepts no external data path and uses the bundled `data/Overthrust_trueimp.mat` to run inference on 6 patches, stitch the results, compute metrics, and save a comparison figure.
 
+ Overthrust evaluation example:
 
+ ```bash
+ python inference/eval_overthrust.py \
+   --model mally-2000/seismic-lddpm \
+   --output outputs/overthrust \
+   --num-inference-steps 1000
+ ```
+
+ Single-sample inference example:
 
  ```bash
+ python inference/infer.py
  ```
 
+ SAII-CLDM single-sample inference example:
 
+ ```bash
+ python inference/infer.py CLDM
+ ```
 
+ ### Environment
  ```bash
+ uv sync
+ source .venv/bin/activate
  ```
 
+ ### Training
+ The official training entry point is now unified as `scripts/train.py`. It remains YAML-config driven but no longer relies on `pytorch_lightning.Trainer` as the main path:
 
+ ```bash
+ uv run scripts/train.py \
+   --config-path configs/task/F02_diffusers.yaml \
+   --output-dir tmp/train_f02 \
+   --max-train-steps 10 \
+   data.params.batch_size=1 data.params.train.params.number=1
+ ```
 
+ Common arguments:
+ - `--config-path`: selects the training config.
+ - `--output-dir`: output directory for the training summary and `diffusers-export/`.
+ - `--max-train-steps`: number of steps to run; `0` performs only a dry run / assembly check.
+ - `--train-batch-size` / `--learning-rate`: override the YAML defaults.
+ - Additional `key=value` arguments override the YAML as an OmegaConf dotlist.
 
+ The `F02` compatibility wrapper script is still kept:
 
+ ```bash
+ uv run scripts/train_diffusers_f02.py --max-train-steps 10 data.params.batch_size=1
+ ```
+
+ Two-stage training can now be chained directly:
+
+ ```bash
+ uv run scripts/train_two_stage_latent_diffusion.py \
+   --output-dir tmp/two_stage_train \
+   --stage1-max-train-steps 10 \
+   --stage2-max-train-steps 10 \
+   --stage1-set data.params.batch_size=1 \
+   --stage1-set data.params.train.params.number=1 \
+   --stage2-set data.params.batch_size=1 \
+   --stage2-set data.params.train.params.number=1
+ ```
 
+ This script will:
+ - first run `configs/task/F01_diffusers.yaml`
+ - then automatically inject `stage1_vq/diffusers-export/vq_model` into `model.params.official_vq_pretrained_dir` of `F02`
+ - `configs/task/F02_diffusers.yaml` itself no longer embeds hard-coded VQ / UNet / condition encoder asset paths
+ - finally write `two_stage_summary.json` in the output root directory
 
+ `main.py` can still be used as a legacy entry point, but it is no longer the official training path.
 
+ ## Migration Status
+
+ The actual state of the `2D -> diffusers` migration in this repository is documented in [docs/diffusers_status.md](/root/test/cldm2/docs/diffusers_status.md), which is authoritative.
+
+ This round also adds two practical usage notes:
+
+ - [docs/colab_quickstart.md](/root/test/cldm2/docs/colab_quickstart.md)
+ - [docs/diffusers_cleanup_plan.md](/root/test/cldm2/docs/diffusers_cleanup_plan.md)
+
+ Summary:
+
+ - The `2D` inference path has been migrated to `diffusers`
+ - The `2D` training path has not been fully migrated yet
+ - The `3D` entry points and related Python code paths have been removed and are outside the current diffusers path
+ - The old path is kept as a regression oracle, not as an official `2D` user entry point
+
+ Current training-migration experiment configs:
+
+ - [docs/diffusers_f02_training_mvp.md](/root/test/cldm2/docs/diffusers_f02_training_mvp.md)
+ - `configs/task/F02_diffusers.yaml`
+ - `configs/task/F01_diffusers.yaml`
+
+ Confirmed so far:
+
+ - `scripts/train.py --config-path configs/task/F02_diffusers.yaml` works as the new unified training entry point
+ - `scripts/train.py` automatically recognizes `F02 diffusers`, legacy `F01/VQ`, and `F01_diffusers` based on the official `diffusers.VQModel`
+ - Training ends by exporting `diffusers-export/`, which contains `vq_model/`, `unet/`, `scheduler/`, and `condition_encoder.pt`
+ - `sample.py --export-dir ...` no longer requires `--project-config` in explicit dataset preset mode, which suits Colab better
+ - `main.py` is kept as a legacy compatibility path; extending it as the main training entry point is no longer recommended
+
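+ ## 2D Inference Path
+
+ The default 2D inference entry point has moved to the diffusers path:
+ - `sample.py`: unified 2D inference entry point; single-sample by default, with the `batch` subcommand for batch inference
+ - `ldm/pipelines/seismic_inversion_pipeline.py`: the shared `SeismicInversionPipeline`
+
+ This path uses:
+ - the official `diffusers.VQModel`
+ - the official `diffusers.UNet2DModel`
+ - the official scheduler buffers / timesteps
+ - a custom inversion loop that preserves `ddim_resample + DPS`
+
+ The historical `3D` entry points have been removed; the repository now keeps only the `2D` inference and training-migration paths.
+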
+ ### Single-Sample Inference
+ Uses `field_testdata` by default:
 
  ```bash
+ uv run sample.py \
+   --project-config /root/use_model_param/2025-04-18T15-59-17_A101/configs/2025-04-18T15-59-17-project.yaml \
+   --legacy-checkpoint /root/use_model_param/2025-04-18T15-59-17_A101/checkpoints/epoch=000211-step=000013991.ckpt \
+   --vq-dir /root/use_model_param/old_vqgan_diffusers_vqmodel \
+   --unet-dir tmp/a101_diffusers_unet \
+   --sample-index 0 \
+   --sampler-type ddim_resample \
+   --num-inference-steps 30
  ```
 
+ Default outputs:
+ - `output-dir/summary.json`
+ - `output-dir/sample_000_overview.png`
+ - `output-dir/sample_000_*.png`
 
+ To disable image saving:
 
+ ```bash
+ uv run sample.py --no-save-images
+ ```
 
+ To display the saved overview interactively:
 
+ ```bash
+ uv run sample.py --show
+ ```
 
+ To load the self-contained model components exported by training instead:
+
+ ```bash
+ uv run sample.py \
+   --export-dir debuglogs/<run>/diffusers-export \
+   --sample-index 0 \
+   --dataset-name field_testdata \
+   --sampler-type ddim \
+   --num-inference-steps 30
+ ```
+
+ Notes:
+ - `vq_model` / `unet` / `scheduler` / `condition_encoder` are loaded from `--export-dir`
+ - Explicit dataset presets such as `field_testdata`, `feild_traindata`, `Marmousi*`, and `Overthrust` no longer require `project-config`
+ - `--dataset-name config_train` still requires `project-config`
+
+ ### Batch Inference
+ Uses the `feild_traindata` compatibility preset by default:
+
+ ```bash
+ uv run sample.py batch \
+   --project-config /root/use_model_param/2025-04-18T15-59-17_A101/configs/2025-04-18T15-59-17-project.yaml \
+   --legacy-checkpoint /root/use_model_param/2025-04-18T15-59-17_A101/checkpoints/epoch=000211-step=000013991.ckpt \
+   --vq-dir /root/use_model_param/old_vqgan_diffusers_vqmodel \
+   --unet-dir tmp/a101_diffusers_unet \
+   --max-samples 8
+ ```
+
+ Default outputs:
+ - `output-dir/x_samples_2d.npy`
+ - `output-dir/x_true_2d.npy`
+ - `output-dir/summary.json`
+
+ To save per-sample images on demand:
+
+ ```bash
+ uv run sample.py batch --save-images
+ ```
+
+ ## Colab
+
+ The most reliable Colab workflow at present is:
+
+ 1. Clone the repository
+ 2. `pip install -r requirements-colab.txt`
+ 3. `pip install -e ./src/taming-transformers`
+ 4. Run `sample.py` directly against `diffusers-export/`
+
+ Minimal single-sample command example:
+
+ ```bash
+ python sample.py \
+   --export-dir /content/drive/MyDrive/SAII-CLDM/two_stage_run/stage2_f02/diffusers-export \
+   --dataset-name Marmousi3 \
+   --dataset-dt-path /content/drive/MyDrive/SAII-CLDM/data/dtA89-1.npz \
+   --sample-index 0 \
+   --sampler-type ddim \
+   --num-inference-steps 30 \
+   --output-dir /content/drive/MyDrive/SAII-CLDM/colab_outputs/sample_single \
+   --device cuda
+ ```
+
+ See [docs/colab_quickstart.md](/root/test/cldm2/docs/colab_quickstart.md) for the full steps.
+
+ ### Common Inference Arguments
+ - `--dataset-name`: `config_train`, `feild_traindata`, `field_testdata`, `Marmousi*`, `Overthrust`
+ - `--dataset-interval`
+ - `--img-size`
+ - `--f0`
+ - `--f0-phase`
+ - `--dipin-v`
+ - `--noise-snr`
+ - `--zhengyan-type`
+ - `--noise-type`
+ - `--sampler-type`: `ddim_resample`, `ddim`, `ddpm`
+ - `--num-inference-steps`
+ - `--eta`
+ - `--use-dps` / `--no-use-dps`
+ - `--dps-scale`
+ - `--resample-interval`
+ - `--sigma-a`
+ - `--pixel-max-iters`
+ - `--last-pixel-max-iters`
+ - `--device`
+ - `--seed`
+ - `--export-dir`
+ - `batch --start-index`
+ - `batch --max-samples`
+
+ ## Inference Flow
+
+ The current 2D inference flow of `sample.py` is as follows:
+
+ ```mermaid
+ sequenceDiagram
+     participant Entry as sample.py
+     participant Runner as diffusers_inference_runner
+     participant Dataset as Dataset preset
+     participant Operator as zhengyan operator
+     participant Pipe as SeismicInversionPipeline
+     participant UNet as diffusers.UNet2DModel
+     participant VQ as diffusers.VQModel
+
+     Entry->>Runner: Parse CLI
+     Runner->>Runner: Load project config / cond encoder / scheduler
+     Runner->>Dataset: Build 2D dataset and draw a sample
+     Runner->>Operator: Build forward-modeling operator
+     Runner->>Pipe: Call pipeline(image, dipin, record, measurement)
+
+     activate Pipe
+     Pipe->>VQ: Encode dipin / image
+     Pipe->>UNet: Predict noise step by step
+     alt sampler_type = ddim_resample
+         Pipe->>Pipe: DPS + resample + optional pixel optimization
+     else sampler_type = ddim / ddpm
+         Pipe->>Pipe: Standard reverse-diffusion update
+     end
+     Pipe->>VQ: Decode final latent
+     Pipe-->>Runner: prediction / latents / measurement error
+     deactivate Pipe
+
+     Runner->>Runner: Write summary / npy / visualizations
  ```
 
+ ## Regression and Comparison
+
+ The old sampling implementation is not kept as a user-facing entry point, but it continues to serve as a regression oracle:
+
+ ```bash
+ uv run scripts/compare_legacy_and_diffusers_inversion.py \
+   --dataset-name config_train \
+   --sample-index 0 \
+   --sampler-type ddim_resample \
+   --num-inference-steps 10
+ ```
+
+ This script produces:
+ - `legacy/` visualizations
+ - `official/` visualizations
+ - `summary.json`
+
+ Current repository conventions:
+ - `legacy-like` / legacy paths are used only for validation
+ - the diffusers pipeline path is the official 2D inference path
+
+ For more background, see:
+ - `docs/diffusers_status.md`
+ - `docs/official_sampling_comparison.md`
+
+ ## Code Structure
 
+ - `ldm/pipelines`: diffusers pipelines and visualization utilities
+ - `ldm/data`: seismic datasets
+ - `ldm/ldm_inverse`: forward-modeling and measurement logic
+ - `scripts`: conversion, validation, comparison, and inference runners
+ - `tests`: pipeline CLI smoke tests
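
The training notes in the README diff above say that extra `key=value` arguments are applied to the YAML config as an OmegaConf dotlist. The sketch below illustrates that override behavior with a pure-Python stand-in, so it runs without extra dependencies. `apply_dotlist` and the base config values (`batch_size: 8`, `number: 100`) are hypothetical and not taken from the repository; the real scripts presumably delegate to OmegaConf rather than to a helper like this.

```python
import ast
from copy import deepcopy


def apply_dotlist(config: dict, overrides: list) -> dict:
    """Return a copy of `config` with dotted `key=value` overrides applied.

    Hypothetical stand-in for OmegaConf dotlist merging, for illustration only.
    """
    result = deepcopy(config)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        # Parse numbers and other literals; fall back to the raw string.
        try:
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            value = raw_value
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return result


# Hypothetical base values; the real YAML contents are not shown in this commit.
base = {"data": {"params": {"batch_size": 8, "train": {"params": {"number": 100}}}}}
merged = apply_dotlist(
    base,
    ["data.params.batch_size=1", "data.params.train.params.number=1"],
)
print(merged["data"]["params"])
```

With OmegaConf installed, `OmegaConf.merge(base, OmegaConf.from_dotlist(overrides))` is the usual way to get comparable behavior.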
 
 
inference/eval_overthrust.py CHANGED
@@ -17,7 +17,8 @@ if str(REPO_ROOT) not in sys.path:
      sys.path.insert(0, str(REPO_ROOT))
 
  from inference.dataset import OverthrustTrueimpDataset
- from pipeline import SeismicImpInvLDDPMPipeline
 
 
  OVERTHRUST_CONFIG = {
@@ -89,10 +90,17 @@ def save_comparison(
 
  def evaluate_overthrust(
      pipe: SeismicImpInvLDDPMPipeline,
      output_dir: str | Path = "outputs/overthrust",
-     num_inference_steps: int = 1000,
      device: str | torch.device | None = None,
  ) -> dict[str, object]:
      output_dir = Path(output_dir)
      output_dir.mkdir(parents=True, exist_ok=True)
      device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
@@ -132,12 +140,24 @@ def evaluate_overthrust(
          dipin = batch["dipin"].to(device)
          record = batch["record"].to(device)
          image = batch["image"].to(device)
          output = pipe(
              dipin=dipin,
              record=record,
              image=image,
              num_inference_steps=num_inference_steps,
              seeds=seeds,
          )
          prediction = output.impedance_samples
          reconstruction = output.impedance_reconstructed
@@ -161,7 +181,11 @@ def evaluate_overthrust(
      full_reconstruction_impedance = dataset.fan(full_reconstruction)
 
      metrics_summary = {
-         "config": {**OVERTHRUST_CONFIG, "num_inference_steps": num_inference_steps},
          "normalized": compute_metrics(full_prediction, full_target),
          "impedance": compute_metrics(full_prediction_impedance, full_target_impedance),
          "encode_impedance": compute_metrics(
@@ -188,23 +212,26 @@ def evaluate_overthrust(
 
 
  def parse_args() -> argparse.Namespace:
-     parser = argparse.ArgumentParser(description="Evaluate SAII-LDDPM on Overthrust.")
      parser.add_argument("--model", default="mally-2000/seismic-lddpm")
      parser.add_argument("--output", default="outputs/overthrust")
      parser.add_argument("--device", default=None)
-     parser.add_argument("--num-inference-steps", type=int, default=1000)
      return parser.parse_args()
 
 
  def main() -> None:
      args = parse_args()
-     pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
          args.model,
          torch_dtype=torch.float32,
          trust_remote_code=True,
      )
      result = evaluate_overthrust(
          pipe,
          output_dir=args.output,
          num_inference_steps=args.num_inference_steps,
          device=args.device,

      sys.path.insert(0, str(REPO_ROOT))
 
  from inference.dataset import OverthrustTrueimpDataset
+ from inference.util import OverthrustForwardOperator
+ from pipeline import SeismicImpInvCLDMPipeline, SeismicImpInvLDDPMPipeline
 
 
  OVERTHRUST_CONFIG = {

 
  def evaluate_overthrust(
      pipe: SeismicImpInvLDDPMPipeline,
+     method: str = "LDDPM",
      output_dir: str | Path = "outputs/overthrust",
+     num_inference_steps: int | None = None,
      device: str | torch.device | None = None,
  ) -> dict[str, object]:
+     method = method.upper()
+     if method not in {"LDDPM", "CLDM"}:
+         raise ValueError("method must be LDDPM or CLDM")
+     if num_inference_steps is None:
+         num_inference_steps = 30 if method == "CLDM" else 1000
+
      output_dir = Path(output_dir)
      output_dir.mkdir(parents=True, exist_ok=True)
      device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

          dipin = batch["dipin"].to(device)
          record = batch["record"].to(device)
          image = batch["image"].to(device)
+         extra_kwargs = {}
+         if method == "CLDM":
+             f0 = int(batch["rick_v"][0].item())
+             f0_phase = int(batch["rick_phase"][0].item())
+             extra_kwargs = {
+                 "measurement": record,
+                 "operator": OverthrustForwardOperator(
+                     wavelet=dataset.wavelets[f0][f0_phase],
+                     device=device,
+                 ),
+             }
          output = pipe(
              dipin=dipin,
              record=record,
              image=image,
              num_inference_steps=num_inference_steps,
              seeds=seeds,
+             **extra_kwargs,
          )
          prediction = output.impedance_samples
          reconstruction = output.impedance_reconstructed

      full_reconstruction_impedance = dataset.fan(full_reconstruction)
 
      metrics_summary = {
+         "config": {
+             **OVERTHRUST_CONFIG,
+             "method": method,
+             "num_inference_steps": num_inference_steps,
+         },
          "normalized": compute_metrics(full_prediction, full_target),
          "impedance": compute_metrics(full_prediction_impedance, full_target_impedance),
          "encode_impedance": compute_metrics(

 
 
  def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Evaluate SAII-LDDPM/CLDM on Overthrust.")
+     parser.add_argument("method", nargs="?", choices=["LDDPM", "CLDM"], default="LDDPM")
      parser.add_argument("--model", default="mally-2000/seismic-lddpm")
      parser.add_argument("--output", default="outputs/overthrust")
      parser.add_argument("--device", default=None)
+     parser.add_argument("--num-inference-steps", type=int, default=None)
      return parser.parse_args()
 
 
  def main() -> None:
      args = parse_args()
+     pipe_cls = SeismicImpInvCLDMPipeline if args.method == "CLDM" else SeismicImpInvLDDPMPipeline
+     pipe = pipe_cls.from_pretrained(
          args.model,
          torch_dtype=torch.float32,
          trust_remote_code=True,
      )
      result = evaluate_overthrust(
          pipe,
+         method=args.method,
          output_dir=args.output,
          num_inference_steps=args.num_inference_steps,
          device=args.device,
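
The `evaluate_overthrust` changes above normalize the method name and pick a per-method default step count when `--num-inference-steps` is omitted (1000 for LDDPM, 30 for CLDM). A minimal standalone sketch of that resolution logic is shown here. The helper name `resolve_sampling_defaults` and the `DEFAULT_STEPS` table are hypothetical; only the two default values and the error message come from the diff.

```python
from typing import Optional, Tuple

# Defaults as they appear in this diff: 1000 steps for LDDPM, 30 for CLDM.
DEFAULT_STEPS = {"LDDPM": 1000, "CLDM": 30}


def resolve_sampling_defaults(
    method: str = "LDDPM",
    num_inference_steps: Optional[int] = None,
) -> Tuple[str, int]:
    """Normalize the method name and fill in its default step count."""
    method = method.upper()
    if method not in DEFAULT_STEPS:
        raise ValueError("method must be LDDPM or CLDM")
    if num_inference_steps is None:
        num_inference_steps = DEFAULT_STEPS[method]
    return method, num_inference_steps


print(resolve_sampling_defaults("cldm"))
print(resolve_sampling_defaults("LDDPM", 250))
```

An explicit step count always takes precedence, so a short smoke run remains possible for either method.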
inference/{infer_LDDPM.py → infer.py} RENAMED
@@ -11,15 +11,15 @@ REPO_ROOT = Path(__file__).resolve().parents[1]
  if str(REPO_ROOT) not in sys.path:
      sys.path.insert(0, str(REPO_ROOT))
 
- from inference.dataset import OverthrustTrueimpDataset
- from pipeline import SeismicImpInvLDDPMPipeline
 
 
- MODEL_ID = "mally-2000/seismic-lddpm"
- OUT_DIR = REPO_ROOT / "outputs" / "infer_LDDPM"
- NUM_INFERENCE_STEPS = 1000
  PATCH_INDEX = 0
-
 
  def save_comparison(dipin, record, target, prediction, output_path):
      fig, axes = plt.subplots(1, 4, figsize=(16, 4))
@@ -37,21 +37,16 @@ def save_comparison(dipin, record, target, prediction, output_path):
      fig.savefig(output_path, dpi=150)
      plt.close(fig)
 
  if __name__ == "__main__":
      OUT_DIR.mkdir(parents=True, exist_ok=True)
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      print(f"Using device: {device}")
 
-     pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
-         MODEL_ID,
-         torch_dtype=torch.float32,
-         trust_remote_code=True,
-     ).to(device)
-     print(f"UNet device: {pipe.unet.device}")
-
-
-     # One default Overthrust patch. Dataset defaults define the LDDPM test setup:
-     # nonlinear forward model, 30 Hz Ricker wavelet, 15 dB noise, and dipin=0.012.
      dataset = OverthrustTrueimpDataset(
          patch_indices=[PATCH_INDEX],
          data_dir=REPO_ROOT / "data",
@@ -61,13 +56,51 @@ if __name__ == "__main__":
      dipin = sample["dipin"].unsqueeze(0).to(device)
      record = sample["record"].unsqueeze(0).to(device)
      image = sample["image"].unsqueeze(0).to(device)
 
      output = pipe(
          dipin=dipin,
          record=record,
          image=image,
-         num_inference_steps=NUM_INFERENCE_STEPS,
-         seeds=[int(sample["seed"])],
      )
 
      prediction = output.impedance_samples[0, 0].detach().cpu().numpy()
@@ -83,3 +116,7 @@ if __name__ == "__main__":
      print(f"Saved: {OUT_DIR / 'target.npy'}")
      print(f"Saved: {OUT_DIR / 'comparison.png'}")
 

  if str(REPO_ROOT) not in sys.path:
      sys.path.insert(0, str(REPO_ROOT))
 
+ from inference.dataset import OverthrustTrueimpDataset, SeismicBase
+ from inference.util import OverthrustForwardOperator, ricker_wavelet
+ from pipeline import SeismicImpInvCLDMPipeline, SeismicImpInvLDDPMPipeline
 
 
+ METHOD = sys.argv[1].upper() if len(sys.argv) > 1 else "LDDPM"
+ OUT_DIR = REPO_ROOT / "outputs" / f"infer_{METHOD}"
  PATCH_INDEX = 0
+ RUN_EVAL = True
 
  def save_comparison(dipin, record, target, prediction, output_path):
      fig, axes = plt.subplots(1, 4, figsize=(16, 4))

      fig.savefig(output_path, dpi=150)
      plt.close(fig)
 
+
  if __name__ == "__main__":
+     if METHOD not in {"LDDPM", "CLDM"}:
+         raise ValueError("METHOD must be LDDPM or CLDM. Example: python inference/infer.py CLDM")
+
      OUT_DIR.mkdir(parents=True, exist_ok=True)
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      print(f"Using device: {device}")
+     print(f"Method: {METHOD}")
 
      dataset = OverthrustTrueimpDataset(
          patch_indices=[PATCH_INDEX],
          data_dir=REPO_ROOT / "data",

      dipin = sample["dipin"].unsqueeze(0).to(device)
      record = sample["record"].unsqueeze(0).to(device)
      image = sample["image"].unsqueeze(0).to(device)
+     seed = int(sample["seed"])
+
+     if METHOD == "LDDPM":
+         num_inference_steps = 1000
+         extra_kwargs = {}
+         pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
+             "mally-2000/seismic-lddpm",
+             torch_dtype=torch.float32,
+             trust_remote_code=True,
+         ).to(device)
+
+     else:
+         pipe = SeismicImpInvCLDMPipeline.from_pretrained(
+             "mally-2000/seismic-lddpm",
+             torch_dtype=torch.float32,
+             trust_remote_code=True,
+         ).to(device)
+         num_inference_steps = 30
+         f0 = int(sample["rick_v"].item())
+         f0_phase = int(sample["rick_phase"].item())
+
+         # NOTE: The forward operator's wavelet must match the dataset's wavelet
+         # to ensure consistency between simulated measurements and actual data.
+         # The parameters (f0=30Hz, dt=0.002s) must match the values used in
+         # OverthrustTrueimpDataset._build_wavelets() to generate the seismic records.
+         wavelet = ricker_wavelet(f0=f0, nt=256 // 2, dt=0.002)
+         # Apply phase shift to match the dataset's wavelet phase
+         wavelet = SeismicBase.phaseshift(wavelet, f0_phase)
+
+         operator = OverthrustForwardOperator(
+             wavelet=wavelet,
+             device=device,
+         )
+         extra_kwargs = dict(
+             measurement=record,
+             operator=operator,
+         )
 
      output = pipe(
          dipin=dipin,
          record=record,
          image=image,
+         num_inference_steps=num_inference_steps,
+         seeds=[seed],
+         **extra_kwargs,
      )
 
      prediction = output.impedance_samples[0, 0].detach().cpu().numpy()

      print(f"Saved: {OUT_DIR / 'target.npy'}")
      print(f"Saved: {OUT_DIR / 'comparison.png'}")
 
+     if RUN_EVAL:
+         from inference.eval_overthrust import evaluate_overthrust
+
+         evaluate_overthrust(pipe, method=METHOD, output_dir=OUT_DIR / "eval")
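
In this diff, `infer.py` reads the method from `sys.argv[1]` and uppercases it, while `eval_overthrust.py` declares the same positional argument through `argparse` with uppercase-only `choices`. The sketch below shows one way both behaviors could be combined in a single parser. `build_parser` is a hypothetical helper and is not part of the commit.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser with an optional, case-insensitive positional method argument."""
    parser = argparse.ArgumentParser(description="Single-sample SAII inference (sketch).")
    parser.add_argument(
        "method",
        nargs="?",                  # optional positional, as in eval_overthrust.py
        type=str.upper,             # normalize "cldm" to "CLDM" before validation
        choices=["LDDPM", "CLDM"],
        default="LDDPM",
    )
    return parser


args = build_parser().parse_args(["cldm"])
print(args.method)
```

Because `argparse` applies `type` before checking `choices`, lowercase input is accepted here, whereas an invalid name still fails with the standard usage error instead of a traceback.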
inference/util.py ADDED
@@ -0,0 +1,106 @@
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+
+
+ def ricker_wavelet(f0: float, nt: int, dt: float) -> np.ndarray:
+     """Ricker (Mexican hat) wavelet - pure NumPy implementation.
+
+     Replaces pylops.utils.wavelets.ricker with identical output.
+     Creates a Ricker wavelet given time axis parameters and central frequency.
+
+     Args:
+         f0: Central frequency in Hz
+         nt: Number of time samples (positive part including zero)
+         dt: Time sampling interval in seconds
+
+     Returns:
+         Wavelet array with symmetric time axis
+     """
+     # Construct positive time axis (including zero)
+     t_positive = np.arange(nt) * dt
+
+     # _tcrop: if even length, remove last sample to ensure odd length
+     if len(t_positive) % 2 == 0:
+         t_positive = t_positive[:-1]
+
+     # Construct symmetric time axis (negative + positive)
+     t = np.concatenate((np.flipud(-t_positive[1:]), t_positive), axis=0)
+
+     # Ricker wavelet formula
+     w = (1 - 2 * (np.pi * f0 * t) ** 2) * np.exp(-((np.pi * f0 * t) ** 2))
+
+     return w
+
+
+ def build_convmtx(wavelet: np.ndarray, size: int) -> np.ndarray:
+     """Build convolution matrix (Toeplitz matrix) - pure NumPy implementation.
+
+     Replaces pylops.utils.signalprocessing.convmtx with identical output.
+
+     Args:
+         wavelet: 1D wavelet array
+         size: Output matrix size (size x size)
+
+     Returns:
+         Convolution matrix of shape (size, size)
+     """
+     wlen = len(wavelet)
+     offset = wlen // 2
+     matrix = np.zeros((size, size), dtype=wavelet.dtype)
+
+     for i in range(size):
+         for j, w_val in enumerate(wavelet):
+             col_idx = i - offset + j
+             if 0 <= col_idx < size:
+                 matrix[i, col_idx] = w_val
+
+     return matrix
+
+
+ class OverthrustForwardOperator:
+     """Differentiable seismic forward model matching OverthrustTrueimpDataset."""
+
+     def __init__(
+         self,
+         *,
+         wavelet: np.ndarray,
+         size: int = 256,
+         normal_min: float = 5.0931,
+         normal_max: float = 6.501110975896774,
+         record_scale: float = 0.3215932963300079,
+         normalize: str = "minmax",
+         device: torch.device | None = None,
+         dtype: torch.dtype = torch.float32,
+     ):
+         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         wavelet_matrix = build_convmtx(wavelet, size)
+         s1 = np.eye(size, k=1) - np.eye(size, k=0)
+         s2 = np.eye(size, k=1) + np.eye(size, k=0)
+         s1[-1] = 0
+         s2[-1] = 0
+
+         self.wavelet_matrix = torch.as_tensor(wavelet_matrix, device=device, dtype=dtype)
+         self.s1 = torch.as_tensor(s1, device=device, dtype=dtype)
+         self.s2 = torch.as_tensor(s2, device=device, dtype=dtype)
+         self.normal_min = float(normal_min)
+         self.normal_max = float(normal_max)
+         self.record_scale = float(record_scale)
+         self.normalize = normalize
+
+     def _inv_normal(self, image: torch.Tensor) -> torch.Tensor:
+         if self.normalize == "minmax":
+             return image * (self.normal_max - self.normal_min) + self.normal_min
+         if self.normalize == "max":
+             return image * self.normal_max
+         raise ValueError(f"Unsupported normalize: {self.normalize}")
+
+     def __call__(self, image: torch.Tensor) -> torch.Tensor:
+         impedance = torch.exp(self._inv_normal(image))
+         numerator = torch.matmul(self.s1.to(dtype=image.dtype), impedance)
+         denominator = torch.matmul(self.s2.to(dtype=image.dtype), impedance)
+         reflectivity = numerator / torch.clamp(denominator, min=1e-6)
+         record = torch.matmul(self.wavelet_matrix.to(dtype=image.dtype), reflectivity)
+         return record / self.record_scale
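
Two details of `inference/util.py` are easy to misread: the even-length crop in `ricker_wavelet`, and the difference-over-sum matrices `s1` / `s2` in `OverthrustForwardOperator`. The dependency-free sketch below reproduces both for a single trace, so the wavelet length and the reflectivity values can be checked by hand. `ricker_samples` and `reflectivity` are hypothetical helper names; the formulas follow the code above, and the `1e-6` denominator clamp is omitted on the assumption that impedance is strictly positive.

```python
import math


def ricker_samples(f0: float, nt: int, dt: float) -> list:
    """Ricker wavelet on a symmetric time axis, mirroring the crop in util.py."""
    # An even number of positive samples is cropped by one, as in util.py.
    n_pos = nt if nt % 2 == 1 else nt - 1
    times = [k * dt for k in range(-(n_pos - 1), n_pos)]
    return [
        (1 - 2 * (math.pi * f0 * t) ** 2) * math.exp(-((math.pi * f0 * t) ** 2))
        for t in times
    ]


def reflectivity(impedance: list) -> list:
    """r_i = (z_{i+1} - z_i) / (z_{i+1} + z_i), with the last sample set to zero."""
    refl = [(b - a) / (b + a) for a, b in zip(impedance, impedance[1:])]
    return refl + [0.0]  # matches the zeroed last rows of s1 and s2


# Same arguments as the call in infer.py: f0=30, nt=256 // 2, dt=0.002.
wavelet = ricker_samples(30.0, 256 // 2, 0.002)
print(len(wavelet))  # 253 samples: 127 non-negative times mirrored around zero
print(reflectivity([2.0, 2.0, 6.0]))
```

With `nt=128` the wavelet has 253 samples, so `build_convmtx` centers it with `offset = 126` inside the 256 x 256 matrix.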
pipeline.py CHANGED
@@ -1,10 +1,11 @@
  from __future__ import annotations
 
  from dataclasses import dataclass
 
  import numpy as np
  import torch
- from diffusers import DDPMScheduler, DiffusionPipeline, UNet2DModel, VQModel
  from diffusers.utils import BaseOutput
 
 
@@ -232,3 +233,332 @@ class SeismicImpInvLDDPMPipeline(DiffusionPipeline):
          if output_type == "np":
              return reconstruction.detach().cpu().numpy()
          return reconstruction

1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass
4
+ from typing import Any, Callable
5
 
6
  import numpy as np
7
  import torch
8
+ from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DModel, VQModel
9
  from diffusers.utils import BaseOutput
10
 
11
 
 
233
  if output_type == "np":
234
  return reconstruction.detach().cpu().numpy()
235
  return reconstruction
236
+
237
+
238
+ class SeismicImpInvCLDMPipeline(SeismicImpInvLDDPMPipeline):
239
+ """SAII-CLDM inference pipeline.
240
+
241
+ This reuses the same trained components as SAII-LDDPM and replaces only the
242
+ reverse sampling procedure with DDIM plus model-driven resampling.
243
+ """
244
+
245
+ @staticmethod
246
+ def _get_operator_fn(operator: Any) -> Callable[[torch.Tensor], torch.Tensor]:
247
+ if callable(operator):
248
+ return operator
249
+ if hasattr(operator, "forward") and callable(operator.forward):
250
+ return operator.forward
251
+ raise TypeError("`operator` must be callable or expose a callable `forward` method.")
252
+
253
+ @staticmethod
254
+ def _build_ddim_scheduler(
255
+ scheduler: DDPMScheduler,
256
+ num_inference_steps: int,
257
+ device: torch.device,
258
+ ) -> DDIMScheduler:
259
+ ddim_scheduler = DDIMScheduler.from_config(
260
+ scheduler.config,
261
+ clip_sample=False,
262
+ set_alpha_to_one=False,
263
+ steps_offset=1,
264
+ timestep_spacing="leading",
265
+ )
266
+ ddim_scheduler.set_timesteps(num_inference_steps, device=device)
267
+ return ddim_scheduler
+
+     @staticmethod
+     def _default_pixel_optimization_param() -> dict[str, float | int]:
+         return {
+             "eps": 1e-4,
+             "max_iters": 100,
+             "lr": 1e-5,
+             "y_coef": 1.0,
+             "x_coef": 0.0,
+             "tv_coef": 0.0,
+             "dh_coef": 1.0,
+             "dw_coef": 1.5,
+         }
+
+     @staticmethod
+     def _default_last_pixel_optimization_param() -> dict[str, float | int]:
+         return {
+             "eps": 1e-4,
+             "max_iters": 1,
+             "lr": 1e-4,
+             "y_coef": 1.0,
+             "x_coef": 0.1,
+             "tv_coef": 0.0,
+             "dh_coef": 1.0,
+             "dw_coef": 1.5,
+         }
+
+     @staticmethod
+     def _tv_loss(x: torch.Tensor, *, dh_coef: float, dw_coef: float) -> torch.Tensor:
+         dh = dh_coef * torch.abs(x[..., :, 1:] - x[..., :, :-1])
+         dw = dw_coef * torch.abs(x[..., 1:, :] - x[..., :-1, :])
+         return torch.mean(dh[..., :-1, :] + dw[..., :, :-1])
+
+     def _ddim_step(
+         self,
+         latents: torch.Tensor,
+         conditioning: torch.Tensor,
+         timestep: int,
+         scheduler: DDIMScheduler,
+         eta: float,
+         generator: torch.Generator | list[torch.Generator] | None,
+         quantize_denoised: bool,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
+         model_input = torch.cat(
+             [
+                 scheduler.scale_model_input(latents, timestep),
+                 conditioning.to(dtype=latents.dtype),
+             ],
+             dim=1,
+         )
+         timestep_tensor = torch.full(
+             (latents.shape[0],), timestep, device=latents.device, dtype=torch.long
+         )
+         noise_pred = self.unet(model_input, timestep_tensor).sample
+
+         alpha_t = scheduler.alphas_cumprod[timestep].to(
+             device=latents.device, dtype=latents.dtype
+         )
+         prev_timestep = timestep - (
+             scheduler.config.num_train_timesteps // scheduler.num_inference_steps
+         )
+         if prev_timestep >= 0:
+             alpha_prev = scheduler.alphas_cumprod[prev_timestep].to(
+                 device=latents.device, dtype=latents.dtype
+             )
+         else:
+             alpha_prev = scheduler.final_alpha_cumprod.to(
+                 device=latents.device, dtype=latents.dtype
+             )
+
+         beta_t = 1.0 - alpha_t
+         pred_x0 = (latents - beta_t.sqrt() * noise_pred) / alpha_t.sqrt()
+         pseudo_x0 = (latents - beta_t * noise_pred) / alpha_t.sqrt()
+         if quantize_denoised:
+             pred_x0 = self.vq_model.quantize(pred_x0.to(dtype=self.vq_model.dtype))[0].to(
+                 dtype=latents.dtype
+             )
+             noise_pred = (latents - alpha_t.sqrt() * pred_x0) / beta_t.sqrt()
+
+         variance = scheduler._get_variance(timestep, prev_timestep).to(
+             device=latents.device, dtype=latents.dtype
+         )
+         sigma_t = eta * variance.sqrt()
+         direction = torch.clamp(1.0 - alpha_prev - sigma_t**2, min=0.0).sqrt() * noise_pred
+         noise = torch.zeros_like(latents)
+         if eta > 0:
+             noise = sigma_t * self._randn_like_sample(latents, generator)
+         prev_sample = alpha_prev.sqrt() * pred_x0 + direction + noise
+
+         batch_shape = (latents.shape[0], 1, 1, 1)
+         return (
+             prev_sample,
+             pred_x0,
+             pseudo_x0,
+             {
+                 "a_t": torch.full(
+                     batch_shape,
+                     float(alpha_t.item()),
+                     device=latents.device,
+                     dtype=latents.dtype,
+                 ),
+                 "a_prev": torch.full(
+                     batch_shape,
+                     float(alpha_prev.item()),
+                     device=latents.device,
+                     dtype=latents.dtype,
+                 ),
+             },
+         )
+
+     def _optimize_pixels(
+         self,
+         x_prime: torch.Tensor,
+         measurement: torch.Tensor,
+         operator_fn: Callable[[torch.Tensor], torch.Tensor],
+         params: dict[str, Any],
+     ) -> torch.Tensor:
+         merged = {**self._default_pixel_optimization_param(), **params}
+         if int(merged["max_iters"]) <= 0:
+             return x_prime.detach()
+
+         loss_fn = torch.nn.MSELoss(reduction="mean")
+         opt_var = x_prime.detach().clone().requires_grad_(True)
+         opt_init = x_prime.detach().clone()
+         optimizer = torch.optim.AdamW([opt_var], lr=float(merged["lr"]))
+
+         for _ in range(int(merged["max_iters"])):
+             optimizer.zero_grad(set_to_none=True)
+             measurement_loss = (
+                 loss_fn(measurement, operator_fn(opt_var)) * float(merged["y_coef"])
+                 + loss_fn(opt_init, opt_var) * float(merged["x_coef"])
+             )
+             if float(merged["tv_coef"]) != 0.0:
+                 measurement_loss = measurement_loss + float(merged["tv_coef"]) * self._tv_loss(
+                     opt_var,
+                     dh_coef=float(merged["dh_coef"]),
+                     dw_coef=float(merged["dw_coef"]),
+                 )
+             measurement_loss.backward()
+             optimizer.step()
+             if float(measurement_loss.detach().cpu().item()) < float(merged["eps"]):
+                 break
+
+         return opt_var.detach()
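
Review note: `_optimize_pixels` always merges the supplied parameters over `_default_pixel_optimization_param()`, while `__call__` substitutes the last-step defaults only when `last_pixel_optimization_param` is `None` or empty. The sketch below, which uses plain dictionaries and a hypothetical helper name `effective_last_params` (not part of the commit), illustrates what that seems to imply: a partial last-step override would inherit the regular defaults (for example `max_iters=100`) rather than the last-step ones.

```python
# Values copied from `_default_pixel_optimization_param` and
# `_default_last_pixel_optimization_param` in this diff.
PIXEL_DEFAULTS = {
    "eps": 1e-4, "max_iters": 100, "lr": 1e-5, "y_coef": 1.0,
    "x_coef": 0.0, "tv_coef": 0.0, "dh_coef": 1.0, "dw_coef": 1.5,
}
LAST_PIXEL_DEFAULTS = {
    "eps": 1e-4, "max_iters": 1, "lr": 1e-4, "y_coef": 1.0,
    "x_coef": 0.1, "tv_coef": 0.0, "dh_coef": 1.0, "dw_coef": 1.5,
}


def effective_last_params(user_params=None):
    """Hypothetical helper: parameters the final resampling step would use."""
    # `__call__`: fall back to the last-step defaults only for None or {}.
    params = user_params or LAST_PIXEL_DEFAULTS
    # `_optimize_pixels`: fill any missing keys from the regular defaults.
    return {**PIXEL_DEFAULTS, **params}


print(effective_last_params()["max_iters"])               # 1
print(effective_last_params({"lr": 1e-3})["max_iters"])   # 100
print(effective_last_params({"lr": 1e-3})["x_coef"])      # 0.0
```

If that fallback is not intended, callers can pass a complete dictionary for `last_pixel_optimization_param`.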
+
+     def _stochastic_resample(
+         self,
+         pseudo_x0: torch.Tensor,
+         x_t: torch.Tensor,
+         a_t: torch.Tensor,
+         sigma: torch.Tensor,
+         generator: torch.Generator | list[torch.Generator] | None,
+     ) -> torch.Tensor:
+         sigma = torch.clamp(sigma, min=1e-12)
+         one_minus_a_t = torch.clamp(1.0 - a_t, min=1e-12)
+         noise = self._randn_like_sample(pseudo_x0, generator)
+         return (
+             (sigma * a_t.sqrt() * pseudo_x0 + one_minus_a_t * x_t)
+             / (sigma + one_minus_a_t)
+             + noise * torch.sqrt(1.0 / (1.0 / sigma + 1.0 / one_minus_a_t))
+         )
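
Review note: read term by term, the return expression of `_stochastic_resample` can be written as the Gaussian draw below. This is only a transcription of the code, not a statement taken from the SAII-CLDM publication. The symbols are assumptions introduced here: $\bar\alpha$ stands for the `a_t` argument, $\hat{x}_0$ for the `pseudo_x0` argument, $x_t$ for the reference latents, $\sigma$ for the clamped `sigma` (used as a variance), and $\epsilon \sim \mathcal{N}(0, I)$ for the sampled noise.

```latex
\tilde{x}_t
  = \frac{\sigma \sqrt{\bar\alpha}\,\hat{x}_0 + (1-\bar\alpha)\,x_t}
         {\sigma + (1-\bar\alpha)}
  + \sqrt{\frac{\sigma\,(1-\bar\alpha)}{\sigma + (1-\bar\alpha)}}\;\epsilon ,
\qquad
\frac{1}{\frac{1}{\sigma} + \frac{1}{1-\bar\alpha}}
  = \frac{\sigma\,(1-\bar\alpha)}{\sigma + (1-\bar\alpha)} .
```

A large $\sigma$ pulls the mean toward the measurement-consistent estimate $\sqrt{\bar\alpha}\,\hat{x}_0$, and a small $\sigma$ keeps it close to $x_t$.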
+
+     def __call__(
+         self,
+         dipin: torch.Tensor,
+         record: torch.Tensor,
+         measurement: torch.Tensor | None = None,
+         operator: Any | None = None,
+         image: torch.Tensor | None = None,
+         num_inference_steps: int = 30,
+         seed: int | None = None,
+         seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
+         generator: torch.Generator | None = None,
+         eta: float = 0.01,
+         interval: int = 6,
+         sigma_a: float = 20.0,
+         pixel_optimization_param: dict[str, Any] | None = None,
+         last_pixel_optimization_param: dict[str, Any] | None = None,
+         quantize_denoised: bool = False,
+         output_type: str = "tensor",
+     ) -> SeismicImpInvLDDPMPipelineOutput:
+         if measurement is None:
+             measurement = record
+         if operator is None:
+             raise ValueError("SAII-CLDM requires a forward `operator`.")
+         if interval <= 0:
+             raise ValueError("`interval` must be a positive integer.")
+
+         device = self.unet.device
+         if seeds is not None:
+             if isinstance(seeds, torch.Tensor):
+                 seeds = seeds.detach().cpu().tolist()
+             seeds = [int(value) for value in seeds]
+             if len(seeds) != dipin.shape[0]:
+                 raise ValueError(f"Expected {dipin.shape[0]} seeds, got {len(seeds)}")
+             generator = [
+                 torch.Generator(device=device).manual_seed(value) for value in seeds
+             ]
+         elif seed is not None:
+             generator = torch.Generator(device=device).manual_seed(seed)
+         elif generator is None:
+             generator = torch.Generator(device=device)
+
+         with torch.no_grad():
+             dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
+             record = record.to(device=device, dtype=self.unet.dtype)
+             measurement = measurement.to(device=device, dtype=self.unet.dtype)
+             impedance_dipin, record_features = self._encode_conditioning(dipin, record)
+             conditioning = torch.cat([impedance_dipin, record_features], dim=1)
+             impedance_latents = self._randn_like_sample(
+                 torch.empty(
+                     impedance_dipin.shape,
+                     device=device,
+                     dtype=self.unet.dtype,
+                 ),
+                 generator,
+             )
+
+         operator_fn = self._get_operator_fn(operator)
+         pixel_params = pixel_optimization_param or {}
+         last_pixel_params = last_pixel_optimization_param or self._default_last_pixel_optimization_param()
+         schedule = self._build_ddim_scheduler(self.scheduler, num_inference_steps, device)
+         time_range = [int(timestep) for timestep in schedule.timesteps.tolist()]
+         resample_start_index = len(time_range) // 4
+
+         for step_idx, timestep in enumerate(time_range):
+             index = len(time_range) - step_idx - 1
+             with torch.no_grad():
+                 next_latents, pred_x0, pseudo_x0, step_stats = self._ddim_step(
+                     impedance_latents,
+                     conditioning,
+                     timestep,
+                     schedule,
+                     eta,
+                     generator,
+                     quantize_denoised,
+                 )
+
+             if (index >= resample_start_index or index == 0) and (
+                 index % interval == 0 or index == 0
+             ):
+                 x_t_reference = next_latents.detach().clone()
+                 sigma = sigma_a * (1.0 - step_stats["a_prev"]) / (
+                     1.0 - step_stats["a_t"]
+                 )
+                 sigma = sigma * (1.0 - step_stats["a_t"] / step_stats["a_prev"])
+                 sigma = torch.clamp(sigma, min=1e-12)
+
+                 with torch.no_grad():
+                     pseudo_x0_pixel = self.vq_model.decode(
+                         pseudo_x0.detach().to(dtype=self.vq_model.dtype)
+                     ).sample
+                 optimized_pixels = self._optimize_pixels(
+                     pseudo_x0_pixel,
+                     measurement,
+                     operator_fn,
+                     last_pixel_params if index == 0 else pixel_params,
+                 )
+                 with torch.no_grad():
+                     optimized_latents = self.vq_model.encode(
+                         optimized_pixels.to(dtype=self.vq_model.dtype)
+                     ).latents.to(dtype=self.unet.dtype)
+                     next_latents = self._stochastic_resample(
+                         optimized_latents,
+                         x_t_reference,
+                         step_stats["a_prev"],
+                         sigma.to(dtype=self.unet.dtype),
+                         generator,
+                     )
+
+             impedance_latents = next_latents.detach()
+
+         with torch.no_grad():
+             impedance_samples = self.vq_model.decode(
+                 impedance_latents.to(dtype=self.vq_model.dtype)
+             ).sample
+             impedance_reconstructed = None
+             if image is not None:
+                 image = image.to(device=device, dtype=self.vq_model.dtype)
+                 image_latents = self.vq_model.encode(image).latents
+                 impedance_reconstructed = self.vq_model.decode(image_latents).sample
+
+         if output_type == "np":
+             impedance_samples = impedance_samples.detach().cpu().numpy()
+             impedance_latents = impedance_latents.detach().cpu().numpy()
+             impedance_dipin = impedance_dipin.detach().cpu().numpy()
+             record_features = record_features.detach().cpu().numpy()
+             if impedance_reconstructed is not None:
+                 impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()
+
+         return SeismicImpInvLDDPMPipelineOutput(
+             impedance_samples=impedance_samples,
+             impedance_latents=impedance_latents,
+             impedance_dipin=impedance_dipin,
+             impedance_reconstructed=impedance_reconstructed,
+             record_features=record_features,
+         )
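
Review note: the resampling condition in `__call__` is easier to reason about when evaluated in isolation. The following is a minimal sketch, assuming the DDIM scheduler yields exactly `num_inference_steps` timesteps; the function name `resample_indices` is hypothetical and not part of the commit. It lists the countdown values of `index` at which the pixel optimization and stochastic resampling branch would run.

```python
def resample_indices(num_inference_steps: int = 30, interval: int = 6) -> list[int]:
    """Return the countdown indices at which the resampling branch would run."""
    if interval <= 0:
        raise ValueError("`interval` must be a positive integer.")
    # Mirrors `resample_start_index = len(time_range) // 4` in `__call__`.
    resample_start_index = num_inference_steps // 4
    selected = []
    for step_idx in range(num_inference_steps):
        index = num_inference_steps - step_idx - 1  # counts down to 0
        if (index >= resample_start_index or index == 0) and (
            index % interval == 0 or index == 0
        ):
            selected.append(index)
    return selected


print(resample_indices())  # [24, 18, 12, 0]
```

With the defaults in this diff (`num_inference_steps=30`, `interval=6`), the branch would therefore run four times, and only the final one (`index == 0`) uses `last_pixel_optimization_param`.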