Add Overthrust inference benchmark and model card

Files changed:
- .gitattributes +2 -0
- README.md +116 -0
- assets/demo.png +3 -0
- data/Overthrust_trueimp.mat +3 -0
- examples/expected_metrics.json +46 -0
- inference/__init__.py +1 -0
- inference/dataset.py +376 -0
- inference/eval.py +216 -0
- inference/infer.py +77 -0
- pipeline.py +234 -0
- requirements.txt +10 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo.png filter=lfs diff=lfs merge=lfs -text
+data/Overthrust_trueimp.mat filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,116 @@
---
library_name: diffusers
pipeline_tag: image-to-image
tags:
- seismic-inversion
- impedance-inversion
- diffusion
- ddpm
- overthrust
---

# Seismic-LDDPM

Seismic-LDDPM is a latent DDPM pipeline for seismic impedance inversion. The pipeline takes a low-frequency impedance image (`dipin`) and a synthetic seismic record (`record`) and predicts the impedance image.

This repository includes:

- Diffusers-format model components: `vq_model`, `unet`, `scheduler`, and `condition_encoder`.
- `SeismicImpInvLDDPMPipeline` in `pipeline.py`.
- A complete Overthrust benchmark sample at `data/Overthrust_trueimp.mat`.
- Inference scripts under `inference/`.

## Installation

```bash
git clone https://huggingface.co/mally-2000/seismic-lddpm
cd seismic-lddpm
pip install -r requirements.txt
```
## Overthrust Evaluation

The Overthrust evaluation script is intentionally fixed to the bundled `data/Overthrust_trueimp.mat`. It cuts the full model into six `256 x 256` patches, synthesizes the seismic records and low-frequency impedance inputs, runs inference, stitches the six predictions back together, and computes the metrics.

```bash
python inference/eval.py \
  --model . \
  --output outputs/overthrust \
  --num-inference-steps 1000
```

Outputs:

- `outputs/overthrust/full_target.npy`
- `outputs/overthrust/full_prediction.npy`
- `outputs/overthrust/full_reconstruction.npy`
- `outputs/overthrust/comparison_impedance.png`
- `outputs/overthrust/metrics_summary.json`
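The stitching step merges the overlapping patch predictions by averaging wherever two patches cover the same pixels. A minimal sketch, mirroring the `stitch_patches` helper in `inference/eval.py`:

```python
import numpy as np

def stitch_patches(patches, splits, big_shape, img_size):
    """Average overlapping patches back into one full-size image.

    `splits` holds the (row, col) top-left corner of each patch; pixels covered
    by several patches get the mean of their predictions.
    """
    rec = np.zeros(big_shape, dtype=np.float32)
    cnt = np.zeros(big_shape, dtype=np.float32)
    for patch, (x, y) in zip(patches, splits):
        rec[x:x + img_size, y:y + img_size] += patch
        cnt[x:x + img_size, y:y + img_size] += 1
    # np.maximum guards against division by zero where no patch landed
    return rec / np.maximum(cnt, 1)
```

This is why `full_prediction.npy` is seamless even though the six patches overlap in the middle of the 551-wide model.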
## Benchmark Result

Evaluated locally on the bundled Overthrust benchmark with 1000 DDPM steps, `noise_snr=15`, `dipin_v=0.012`, `f0=30`, `phase=0`, `seed=1234`, and patch indices `[0, 1, 2, 3, 4, 5]`.

| Space | PSNR | SSIM | PCC | RRE | NMSE |
|---|---:|---:|---:|---:|---:|
| Normalized | 30.7698 | 0.9339 | 0.9963 | 0.0435 | 0.001894 |
| Impedance | 33.4413 | 0.9554 | 0.9957 | 0.0324 | 0.001050 |
| VQ reconstruction | 37.7954 | 0.9677 | 0.9983 | 0.0209 | 0.000435 |

![Overthrust comparison](assets/demo.png)
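The two relative metrics in the table are directly related: NMSE is the squared RRE (e.g. 0.0435² ≈ 0.001894 in the normalized row). A small sketch of how `inference/eval.py` computes them:

```python
import numpy as np

def relative_errors(prediction, target):
    """RRE and NMSE as in inference/eval.py's compute_metrics.

    RRE = ||prediction - target|| / ||target||, and NMSE = RRE**2, since both
    use the same L2 norms of the flattened arrays.
    """
    diff = (prediction - target).ravel()
    rre = float(np.linalg.norm(diff) / np.linalg.norm(target.ravel()))
    nmse = float(np.sum(diff ** 2) / np.sum(target.ravel() ** 2))
    return rre, nmse
```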
## Single-Sample Inference

For a single prepared sample:

```bash
python inference/infer.py \
  --dipin path/to/dipin.npy \
  --record path/to/record.npy \
  --model . \
  --output outputs/single
```

The input arrays may be `H x W`, `C x H x W`, or `B x C x H x W`. The script converts them to BCHW tensors and saves `prediction.npy` and `prediction.png`.
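The shape normalization described above can be sketched as follows. This is a hypothetical illustration of the rule, not `infer.py`'s actual code (which additionally wraps the result with `torch.from_numpy`):

```python
import numpy as np

def to_bchw(array):
    """Promote H x W, C x H x W, or B x C x H x W data to B x C x H x W.

    Hypothetical helper illustrating the promotion rule the README describes;
    the real infer.py may differ in details.
    """
    arr = np.asarray(array, dtype=np.float32)
    if arr.ndim == 2:        # H x W -> 1 x 1 x H x W
        arr = arr[None, None]
    elif arr.ndim == 3:      # C x H x W -> 1 x C x H x W
        arr = arr[None]
    elif arr.ndim != 4:
        raise ValueError(f"expected 2, 3, or 4 dims, got {arr.ndim}")
    return arr
```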
## Python Usage

```python
import torch
from pipeline import SeismicImpInvLDDPMPipeline

pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
    "mally-2000/seismic-lddpm",
    torch_dtype=torch.float32,
    trust_remote_code=True,
).to("cuda")

result = pipe(
    dipin=dipin,    # torch.Tensor, BCHW
    record=record,  # torch.Tensor, BCHW
    num_inference_steps=1000,
    seed=1234,
)

prediction = result.impedance_samples
```
## Notes

- `inference/dataset.py` contains a lightweight `SeismicBase` and `OverthrustTrueimpDataset`; it does not depend on the original training repository's `ldm.data.seisimic`.
- Synthetic record generation is seeded through the benchmark configuration, so the published Overthrust evaluation is reproducible.
- The bundled Overthrust file is used only as a compact benchmark input for reproducing this model's inference pipeline.
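The deterministic seeding mentioned in the second note is derived arithmetically from the benchmark configuration, as implemented in `inference/dataset.py`:

```python
import numpy as np

# Seed derivation used by OverthrustTrueimpDataset: one seed drives the RNG
# that adds band-limited noise to each synthetic record, and each dataset
# sample also carries a per-index seed passed on to the sampler.
def record_noise_seed(base_seed: int, f0: int, phase: int, noise_snr: int) -> int:
    return base_seed + f0 * 1000 + phase * 10 + noise_snr

def sample_seed(base_seed: int, index: int, noise_snr: int) -> int:
    return base_seed + index + int(noise_snr) * 100

# With the published config (seed=1234, f0=30, phase=0, noise_snr=15) the
# record-noise RNG is created as:
rng = np.random.default_rng(record_noise_seed(1234, 30, 0, 15))
```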
assets/demo.png
ADDED
Git LFS Details
data/Overthrust_trueimp.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:59345b022a90f174efd444004192f9edf93ef5f65c556a84f52e9596f1695bd5
size 334424
examples/expected_metrics.json
ADDED
@@ -0,0 +1,46 @@
{
  "config": {
    "size": 256,
    "patch_indices": [
      0,
      1,
      2,
      3,
      4,
      5
    ],
    "noise_snr": 15,
    "dipin_v": 0.012,
    "f0": 30,
    "f0_phase": 0,
    "seed": 1234,
    "zhengyan_type": "nonlinear",
    "normalize": "minmax",
    "batch_size": 3,
    "num_inference_steps": 1000
  },
  "normalized": {
    "PSNR": 30.76981345155257,
    "rre": 0.043521951884031296,
    "SSIM": 0.9339061199595424,
    "PCC": 0.9963035366574778,
    "nmse": 0.001894210814498365,
    "mse": 3.811614279402499e-09
  },
  "impedance": {
    "PSNR": 33.44134288739278,
    "rre": 0.03240736946463585,
    "SSIM": 0.955363744873021,
    "PCC": 0.9957485049549735,
    "nmse": 0.0010502231307327747,
    "mse": 0.11484166561534005
  },
  "encode_impedance": {
    "PSNR": 37.79544219163976,
    "rre": 0.020859846845269203,
    "SSIM": 0.9676508176373475,
    "PCC": 0.9982799028636675,
    "nmse": 0.0004351270035840571,
    "mse": 0.04758103889550172
  }
}
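`examples/expected_metrics.json` can be checked against a fresh run's `metrics_summary.json`. A hypothetical checker (the file layout matches the repo; the 5% relative tolerance is an assumption, not part of the repository):

```python
import json

def compare_metrics(expected_path, actual_path, rtol=0.05):
    """Return (space, metric, expected, actual) tuples exceeding rtol.

    Hypothetical helper; the tolerance choice is an assumption.
    """
    with open(expected_path) as fh:
        expected = json.load(fh)
    with open(actual_path) as fh:
        actual = json.load(fh)
    mismatches = []
    for space in ("normalized", "impedance", "encode_impedance"):
        for name, ref in expected[space].items():
            got = actual[space][name]
            if abs(got - ref) > rtol * abs(ref):
                mismatches.append((space, name, ref, got))
    return mismatches
```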
inference/__init__.py
ADDED
@@ -0,0 +1 @@
"""Inference helpers for seismic-lddpm."""
inference/dataset.py
ADDED
@@ -0,0 +1,376 @@
from __future__ import annotations

import os
import random
from pathlib import Path

import numpy as np
import pylops
import scipy.io
import torch
from scipy.fftpack import fft, ifft
from scipy.signal import butter, filtfilt
from torch.utils.data import Dataset


class SeismicBase:
    @staticmethod
    def phaseshift(w: np.ndarray, d: float) -> np.ndarray:
        if d == 0:
            return w
        wf_shift = fft(w) * np.exp(1j * (np.pi * d / 180.0))
        return np.real(ifft(wf_shift))

    @staticmethod
    def add_gaussian_band_noise(
        target_snr: float,
        data: np.ndarray,
        rng: np.random.Generator | None = None,
    ) -> tuple[np.ndarray, float]:
        if target_snr == 0:
            return data, 0.0
        rng = rng or np.random.default_rng()
        signal_energy = np.linalg.norm(data) ** 2
        noise_energy = signal_energy / (10 ** (target_snr / 10))
        initial_noise = rng.normal(loc=0, scale=1, size=data.shape)
        noise = filtfilt(
            np.ones(3) / 3,
            1,
            filtfilt(np.ones(3) / 3, 1, initial_noise.T, method="gust").T,
            method="gust",
        )
        noise = noise * np.sqrt(noise_energy / np.linalg.norm(noise) ** 2)
        noisy_data = data + noise
        actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2)
        return noisy_data, float(actual_snr)

    @staticmethod
    def add_gaussian_noise(
        target_snr: float,
        data: np.ndarray,
        rng: np.random.Generator | None = None,
    ) -> tuple[np.ndarray, float]:
        if target_snr == 0:
            return data, 0.0
        rng = rng or np.random.default_rng()
        signal_energy = np.linalg.norm(data) ** 2
        noise_energy = signal_energy / (10 ** (target_snr / 10))
        noise_std = np.sqrt(noise_energy / data.size)
        noise = rng.normal(0, noise_std, data.shape)
        noisy_data = data + noise
        actual_snr = 10 * np.log10(signal_energy / np.linalg.norm(noise) ** 2)
        return noisy_data, float(actual_snr)


class OverthrustTrueimpDataset(SeismicBase, Dataset):
    """Overthrust benchmark dataset used by seismic-lddpm evaluation."""

    def __init__(
        self,
        size: int = 256,
        interval: int = 1,
        special_splits: bool = False,
        use_mask: bool = False,
        record_noraml: bool = True,
        normalize: str = "minmax",
        zhengyan_type: str = "linear",
        train_keys: tuple[str, ...] | list[str] = ("image", "dipin", "record"),
        ricks: tuple[int, ...] | list[int] = (30,),
        ricks_phase: tuple[int, ...] | list[int] = (0,),
        noise_snr: tuple[int, ...] | list[int] = (0,),
        noise_type: str = "guassian_band",
        dipins: tuple[float, ...] | list[float] = (0.012,),
        dipin_nsmoothz: int = 20,
        dipin_nsmoothx: int = 20,
        patch_indices: tuple[int, ...] | list[int] | None = None,
        base_seed: int = 1234,
        data_dir: str | Path | None = None,
        cache_dir: str | Path = "outputs/cache",
        fixed_f0: int | None = None,
        fixed_dipin_v: float | None = None,
        fixed_noise_snr: int | None = None,
        fixed_f0_phase: int | None = None,
    ):
        self.name = "Overthrust_trueimp"
        self.size = size
        self.interval = interval
        self.special_splits = special_splits
        self.use_mask = use_mask
        self.record_noraml = record_noraml
        self.normalize = normalize
        self.zhengyan_type = zhengyan_type
        self.train_keys = list(train_keys)
        self.ricks = list(ricks)
        self.ricks_phase = list(ricks_phase)
        self.noise_snr = list(noise_snr)
        self.noise_type = noise_type
        self.dipins = list(dipins)
        self.dipin_nsmoothz = dipin_nsmoothz
        self.dipin_nsmoothx = dipin_nsmoothx
        self.base_seed = base_seed
        self.have_exp = False
        self.info: dict[str, float | str] = {}
        self.fixed_f0 = fixed_f0
        self.fixed_dipin_v = fixed_dipin_v
        self.fixed_noise_snr = fixed_noise_snr
        self.fixed_f0_phase = fixed_f0_phase
        self.data_dir = Path(data_dir or os.getenv("DATASET_DIR", "data"))
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self._load_big_impedance()
        self._build_splits_and_patches(special_splits=special_splits)
        self._build_wavelets()
        self.big_reflect = self._load_or_build_reflect()
        self.record_data = {
            f0: {
                phase: {
                    snr: self._patches_from_big_image(
                        self._load_or_build_record(f0=f0, phase=phase, noise_snr=snr)
                    )
                    for snr in self.noise_snr
                }
                for phase in self.ricks_phase
            }
            for f0 in self.ricks
        }
        self.dipin_datas = {
            dipin_v: self._patches_from_big_image(self._load_or_build_dipin(dipin_v))
            for dipin_v in self.dipins
        }
        all_indices = list(range(len(self.splits)))
        self.patch_indices = all_indices if patch_indices is None else list(patch_indices)

    def __len__(self) -> int:
        return len(self.patch_indices)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        patch_idx = self.patch_indices[index]
        f0 = self.fixed_f0 if self.fixed_f0 is not None else random.choice(self.ricks)
        dipin_v = (
            self.fixed_dipin_v
            if self.fixed_dipin_v is not None
            else random.choice(self.dipins)
        )
        noise_snr = (
            self.fixed_noise_snr
            if self.fixed_noise_snr is not None
            else random.choice(self.noise_snr)
        )
        f0_phase = (
            self.fixed_f0_phase
            if self.fixed_f0_phase is not None
            else random.choice(self.ricks_phase)
        )
        sample = {
            "patch_idx": torch.tensor(patch_idx, dtype=torch.long),
            "seed": torch.tensor(
                self.base_seed + index + int(noise_snr) * 100, dtype=torch.long
            ),
        }
        if "image" in self.train_keys:
            sample["image"] = torch.from_numpy(self.file_data[patch_idx]).float()
        if "dipin" in self.train_keys:
            sample["dipin"] = torch.from_numpy(self.dipin_datas[dipin_v][patch_idx]).float()
            sample["dipin_v"] = torch.tensor(dipin_v, dtype=torch.float32).reshape(1, 1, 1)
        if "record" in self.train_keys:
            sample["record"] = torch.from_numpy(
                self.record_data[f0][f0_phase][noise_snr][patch_idx]
            ).float()
            sample["rick_v"] = torch.tensor(f0, dtype=torch.float32).reshape(1, 1, 1)
            sample["rick_phase"] = torch.tensor(f0_phase, dtype=torch.float32).reshape(1, 1, 1)
            sample["snr_v"] = torch.tensor(noise_snr, dtype=torch.float32).reshape(1, 1, 1)
        if "reflection" in self.train_keys:
            sample["reflection"] = torch.from_numpy(self.reflect_data[patch_idx]).float()
        if "mask_speed" in self.train_keys:
            sample["mask_speed"] = torch.from_numpy(
                self.mask_data[patch_idx] * self.file_data[patch_idx]
            ).float()
        if self.use_mask:
            sample["mask"] = torch.from_numpy(self.mask_data[patch_idx]).float()
        return sample

    def fan(self, x: np.ndarray) -> np.ndarray:
        minn = 5.0931
        maxn = 6.501110975896774
        return np.exp(x * (maxn - minn) + minn) * 10.9 + 200

    def inv_normal(self, x: np.ndarray) -> np.ndarray:
        vmin = float(self.info["normal_min"])
        vmax = float(self.info["normal_max"])
        if self.normalize == "minmax":
            return x * (vmax - vmin) + vmin
        return x * vmax

    def _load_big_impedance(self) -> None:
        file_path = self.data_dir / "Overthrust_trueimp.mat"
        if not file_path.exists():
            raise FileNotFoundError(f"Overthrust data not found: {file_path}")
        wave = scipy.io.loadmat(file_path)["Overthrust_trueimp"].T
        wave = np.log(wave)
        normal_min = wave.min()
        normal_max = wave.max()
        self.info.update(
            {"normal_min": normal_min, "normal_max": normal_max, "normal": "max"}
        )
        self.big_img_unnorm = wave
        self.big_speedimg = wave
        if self.normalize == "max":
            wave = wave / normal_max
        elif self.normalize == "minmax":
            wave = (wave - normal_min) / (normal_max - normal_min)
        else:
            raise ValueError(f"Unsupported normalize: {self.normalize}")
        self.big_img = wave.astype(np.float32)

    def _build_splits_and_patches(self, special_splits: bool = False) -> None:
        self.big_mask = np.zeros(self.big_img.shape, dtype=np.float32)
        for col in (100, 200, 300):
            if col < self.big_mask.shape[1]:
                self.big_mask[:, col : col + 1] = 1
        if special_splits:
            splits = []
            for x in range(0, 551 - self.size, 20):
                for y in range(0, 551 - self.size, 20):
                    splits.append((x, y))
            for y in range(0, 551 - self.size, 9):
                splits.extend([(30, y), (90, y), (140, y)])
        elif self.size == 256:
            splits = [
                (0, 0),
                (146, 0),
                (551 - 256, 0),
                (0, 145),
                (146, 145),
                (551 - 256, 145),
            ]
        else:
            splits = []
            interval_size = self.size - 1
            for r in range(0, self.big_img.shape[0] - self.size, interval_size):
                for c in range(0, self.big_img.shape[1] - self.size, interval_size):
                    splits.append((r, c))
                splits.append((r, self.big_img.shape[1] - self.size))
            for c in range(0, self.big_img.shape[1] - self.size, interval_size):
                splits.append((self.big_img.shape[0] - self.size, c))
            splits.append(
                (self.big_img.shape[0] - self.size, self.big_img.shape[1] - self.size)
            )

        self.splits = []
        patches = []
        masks = []
        for x, y in splits:
            x2 = x + self.size
            y2 = y + self.size
            if x2 > self.big_img.shape[0] or y2 > self.big_img.shape[1]:
                continue
            self.splits.append((x, y))
            patches.append(self.big_img[x:x2, y:y2].reshape(1, self.size, self.size))
            masks.append(self.big_mask[x:x2, y:y2].reshape(1, self.size, self.size))
        self.file_data = np.stack(patches, axis=0).astype(np.float32)[:: self.interval]
        self.mask_data = np.stack(masks, axis=0).astype(np.float32)[:: self.interval]
        self.splits = self.splits[:: self.interval]

    def _build_wavelets(self) -> None:
        nt0 = 256
        dt0 = 0.002
        self.wavelets = {}
        for f0 in self.ricks:
            self.wavelets[f0] = {}
            wav = pylops.utils.wavelets.ricker(np.arange(nt0 // 2) * dt0, f0)[0]
            for phase in self.ricks_phase:
                self.wavelets[f0][phase] = self.phaseshift(wav, phase)

    def _cache_path(self, name: str) -> Path:
        return self.cache_dir / name

    def _load_or_build_reflect(self) -> np.ndarray:
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_reflect_{self.zhengyan_type}.npy"
        )
        if not cache_path.exists():
            size = self.big_img.shape[0]
            if self.zhengyan_type == "linear":
                s1 = np.diag(0.5 * np.ones(size - 1, dtype="float32"), k=1) - np.diag(
                    0.5 * np.ones(size - 1, dtype="float32"), k=-1
                )
                s1[-1] = s1[0] = 0
                reflect = s1 @ self.big_img
            elif self.zhengyan_type == "nonlinear":
                expspeed = (
                    np.exp(self.big_img_unnorm)
                    if self.have_exp is False
                    else self.big_img_unnorm
                )
                s1 = np.eye(size, k=1) - np.eye(size, k=0)
                s2 = np.eye(size, k=1) + np.eye(size, k=0)
                s1[-1] = 0
                s2[-1] = 0
                numerator = s1 @ expspeed
                denominator = s2 @ expspeed
                denominator = np.where(denominator < 1e-6, 1e-6, denominator)
                reflect = numerator / denominator
            else:
                raise ValueError(f"Unsupported zhengyan_type: {self.zhengyan_type}")
            np.save(cache_path, reflect)
        reflect = np.load(cache_path).astype(np.float32)
        self.reflect_data = self._patches_from_big_image(reflect)
        return reflect

    def _load_or_build_record(self, f0: int, phase: int, noise_snr: int) -> np.ndarray:
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_record__{self.zhengyan_type}"
            f"_ricker={f0:02d}-{phase:03d}_{self.noise_type}={noise_snr:02d}"
            f"_seed={self.base_seed}.npy"
        )
        if not cache_path.exists():
            wav = self.wavelets[f0][phase]
            w_mat = pylops.utils.signalprocessing.convmtx(
                wav, self.big_reflect.shape[0], len(wav) // 2
            )[: self.big_reflect.shape[0]]
            records_clear = w_mat @ self.big_reflect
            rng = np.random.default_rng(self.base_seed + f0 * 1000 + phase * 10 + noise_snr)
            if self.noise_type == "guassian_band":
                record, _ = self.add_gaussian_band_noise(noise_snr, records_clear, rng=rng)
            elif self.noise_type == "guassian":
                record, _ = self.add_gaussian_noise(noise_snr, records_clear, rng=rng)
            else:
                raise ValueError(f"Unsupported noise_type: {self.noise_type}")
            np.save(cache_path, record)
        record = np.load(cache_path).astype(np.float32)
        self.info.update(
            {
                "record_minn": min(float(self.info.get("record_minn", 10)), float(record.min())),
                "record_maxn": max(float(self.info.get("record_maxn", -10)), float(record.max())),
                "record_normal": "max",
            }
        )
        if self.record_noraml:
            record = record / 0.3215932963300079
            self.info["record_maxn"] = 0.3215932963300079
        return record

    def _load_or_build_dipin(self, dipin_v: float) -> np.ndarray:
        cache_path = self._cache_path(
            f"Overthrust_trueimpBig_sesimic_dipin={dipin_v:.03f}.npy"
        )
        if not cache_path.exists():
            bb, aa = butter(2, dipin_v, "low")
            smooth_filter_z = np.ones(self.dipin_nsmoothz) / float(self.dipin_nsmoothz)
            smooth_filter_x = np.ones(self.dipin_nsmoothx) / float(self.dipin_nsmoothx)
            mback = filtfilt(bb, aa, self.big_img.T).T
            mback = filtfilt(smooth_filter_z, 1, mback, axis=0)
            mback = filtfilt(smooth_filter_x, 1, mback, axis=1)
            np.save(cache_path, mback)
        return np.load(cache_path).astype(np.float32)

    def _patches_from_big_image(self, big_image: np.ndarray) -> np.ndarray:
        patches = []
        for x, y in self.splits:
            patches.append(
                big_image[x : x + self.size, y : y + self.size].reshape(
                    1, self.size, self.size
                )
            )
        return np.stack(patches, axis=0).astype(np.float32)
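Instantiating the dataset above with the fixed benchmark configuration mirrors what `inference/eval.py` does; this is a configuration sketch (it requires `data/Overthrust_trueimp.mat` to be present):

```python
from inference.dataset import OverthrustTrueimpDataset

# Fixed Overthrust benchmark configuration, as used by inference/eval.py
dataset = OverthrustTrueimpDataset(
    size=256,
    normalize="minmax",
    zhengyan_type="nonlinear",
    ricks=[30],
    ricks_phase=[0],
    noise_snr=[15],
    dipins=[0.012],
    record_noraml=True,
    train_keys=["image", "dipin", "record"],
    patch_indices=[0, 1, 2, 3, 4, 5],
    base_seed=1234,
    data_dir="data",
    cache_dir="outputs/cache",
    fixed_f0=30,
    fixed_dipin_v=0.012,
    fixed_noise_snr=15,
    fixed_f0_phase=0,
)
sample = dataset[0]  # dict with "image", "dipin", "record" tensors plus scalar conditions
```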
inference/eval.py
ADDED
@@ -0,0 +1,216 @@
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import pearsonr
from skimage.metrics import structural_similarity
from torch.utils.data import DataLoader

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from inference.dataset import OverthrustTrueimpDataset
from pipeline import SeismicImpInvLDDPMPipeline


OVERTHRUST_CONFIG = {
    "size": 256,
    "patch_indices": [0, 1, 2, 3, 4, 5],
    "noise_snr": 15,
    "dipin_v": 0.012,
    "f0": 30,
    "f0_phase": 0,
    "seed": 1234,
    "zhengyan_type": "nonlinear",
    "normalize": "minmax",
    "batch_size": 3,
}


def stitch_patches(
    patches: list[np.ndarray],
    splits: list[tuple[int, int]],
    big_shape: tuple[int, int],
    img_size: int,
) -> np.ndarray:
    rec = np.zeros(big_shape, dtype=np.float32)
    cnt = np.zeros(big_shape, dtype=np.float32)
    for idx, (x, y) in enumerate(splits):
        rec[x : x + img_size, y : y + img_size] += patches[idx]
        cnt[x : x + img_size, y : y + img_size] += 1
    return rec / np.maximum(cnt, 1)


def compute_metrics(prediction: np.ndarray, target: np.ndarray) -> dict[str, float]:
    diff = prediction - target
    denom = np.linalg.norm(diff.ravel()) ** 2
    psnr = float("inf") if denom == 0 else float(
        10.0 * np.log10(len(prediction.ravel()) * np.max(prediction.ravel()) ** 2 / denom)
    )
    return {
        "PSNR": psnr,
        "rre": float(np.linalg.norm(diff.ravel()) / np.linalg.norm(target.ravel())),
        "SSIM": float(structural_similarity(target, prediction, data_range=target.max())),
        "PCC": float(pearsonr(prediction.ravel(), target.ravel()).statistic),
        "nmse": float(np.sum(diff ** 2) / np.sum(target ** 2)),
        "mse": float(np.mean(diff ** 2) / prediction.size),
    }


def save_comparison(
    target_impedance: np.ndarray,
    prediction_impedance: np.ndarray,
    output_path: Path,
) -> None:
    error = np.abs(target_impedance - prediction_impedance)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    vmin_imp = min(target_impedance.min(), prediction_impedance.min())
    vmax_imp = max(target_impedance.max(), prediction_impedance.max())
    for ax, arr, title in zip(
        axes,
        [target_impedance, prediction_impedance, error],
        ["Target (Impedance)", "Prediction (Impedance)", "Error (Impedance)"],
    ):
        if "Error" in title:
            im = ax.imshow(arr, cmap="hot", vmin=0, vmax=error.max())
        else:
            im = ax.imshow(arr, cmap="jet", vmin=vmin_imp, vmax=vmax_imp)
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, fraction=0.046)
    plt.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)


def evaluate_overthrust(
    pipe: SeismicImpInvLDDPMPipeline,
    output_dir: str | Path = "outputs/overthrust",
    num_inference_steps: int = 1000,
    device: str | torch.device | None = None,
) -> dict[str, object]:
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    pipe = pipe.to(device)

    dataset = OverthrustTrueimpDataset(
        size=OVERTHRUST_CONFIG["size"],
        normalize=OVERTHRUST_CONFIG["normalize"],
        zhengyan_type=OVERTHRUST_CONFIG["zhengyan_type"],
        ricks=[OVERTHRUST_CONFIG["f0"]],
        ricks_phase=[OVERTHRUST_CONFIG["f0_phase"]],
        noise_snr=[OVERTHRUST_CONFIG["noise_snr"]],
        dipins=[OVERTHRUST_CONFIG["dipin_v"]],
        record_noraml=True,
        train_keys=["image", "dipin", "record"],
        patch_indices=OVERTHRUST_CONFIG["patch_indices"],
        base_seed=OVERTHRUST_CONFIG["seed"],
        data_dir=REPO_ROOT / "data",
        cache_dir=output_dir / "cache",
        fixed_f0=OVERTHRUST_CONFIG["f0"],
        fixed_dipin_v=OVERTHRUST_CONFIG["dipin_v"],
        fixed_noise_snr=OVERTHRUST_CONFIG["noise_snr"],
        fixed_f0_phase=OVERTHRUST_CONFIG["f0_phase"],
    )
    loader = DataLoader(
        dataset,
        batch_size=OVERTHRUST_CONFIG["batch_size"],
        shuffle=False,
        num_workers=0,
    )

    all_predictions: list[np.ndarray] = []
    all_targets: list[np.ndarray] = []
    all_reconstructions: list[np.ndarray] = []
    for batch in loader:
        seeds = batch["seed"].tolist()
        dipin = batch["dipin"].to(device)
        record = batch["record"].to(device)
        image = batch["image"].to(device)
        output = pipe(
            dipin=dipin,
            record=record,
            image=image,
            num_inference_steps=num_inference_steps,
            seeds=seeds,
        )
        prediction = output.impedance_samples
        reconstruction = output.impedance_reconstructed
|
| 144 |
+
for local_idx in range(prediction.shape[0]):
|
| 145 |
+
all_predictions.append(prediction[local_idx, 0].detach().cpu().numpy())
|
| 146 |
+
all_targets.append(image[local_idx, 0].detach().cpu().numpy())
|
| 147 |
+
all_reconstructions.append(reconstruction[local_idx, 0].detach().cpu().numpy())
|
| 148 |
+
|
| 149 |
+
full_target = stitch_patches(
|
| 150 |
+
all_targets, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
|
| 151 |
+
)
|
| 152 |
+
full_prediction = stitch_patches(
|
| 153 |
+
all_predictions, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
|
| 154 |
+
)
|
| 155 |
+
full_reconstruction = stitch_patches(
|
| 156 |
+
all_reconstructions, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
full_target_impedance = dataset.fan(full_target)
|
| 160 |
+
full_prediction_impedance = dataset.fan(full_prediction)
|
| 161 |
+
full_reconstruction_impedance = dataset.fan(full_reconstruction)
|
| 162 |
+
|
| 163 |
+
metrics_summary = {
|
| 164 |
+
"config": {**OVERTHRUST_CONFIG, "num_inference_steps": num_inference_steps},
|
| 165 |
+
"normalized": compute_metrics(full_prediction, full_target),
|
| 166 |
+
"impedance": compute_metrics(full_prediction_impedance, full_target_impedance),
|
| 167 |
+
"encode_impedance": compute_metrics(
|
| 168 |
+
full_reconstruction_impedance, full_target_impedance
|
| 169 |
+
),
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
paths = {
|
| 173 |
+
"full_target": output_dir / "full_target.npy",
|
| 174 |
+
"full_prediction": output_dir / "full_prediction.npy",
|
| 175 |
+
"full_reconstruction": output_dir / "full_reconstruction.npy",
|
| 176 |
+
"comparison": output_dir / "comparison_impedance.png",
|
| 177 |
+
"metrics": output_dir / "metrics_summary.json",
|
| 178 |
+
}
|
| 179 |
+
np.save(paths["full_target"], full_target)
|
| 180 |
+
np.save(paths["full_prediction"], full_prediction)
|
| 181 |
+
np.save(paths["full_reconstruction"], full_reconstruction)
|
| 182 |
+
save_comparison(full_target_impedance, full_prediction_impedance, paths["comparison"])
|
| 183 |
+
paths["metrics"].write_text(json.dumps(metrics_summary, indent=2), encoding="utf-8")
|
| 184 |
+
return {
|
| 185 |
+
"metrics": metrics_summary,
|
| 186 |
+
"paths": {key: str(value) for key, value in paths.items()},
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def parse_args() -> argparse.Namespace:
|
| 191 |
+
parser = argparse.ArgumentParser(description="Evaluate seismic-lddpm on Overthrust.")
|
| 192 |
+
parser.add_argument("--model", default="mally-2000/seismic-lddpm")
|
| 193 |
+
parser.add_argument("--output", default="outputs/overthrust")
|
| 194 |
+
parser.add_argument("--device", default=None)
|
| 195 |
+
parser.add_argument("--num-inference-steps", type=int, default=1000)
|
| 196 |
+
return parser.parse_args()
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def main() -> None:
|
| 200 |
+
args = parse_args()
|
| 201 |
+
pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
|
| 202 |
+
args.model,
|
| 203 |
+
torch_dtype=torch.float32,
|
| 204 |
+
trust_remote_code=True,
|
| 205 |
+
)
|
| 206 |
+
result = evaluate_overthrust(
|
| 207 |
+
pipe,
|
| 208 |
+
output_dir=args.output,
|
| 209 |
+
num_inference_steps=args.num_inference_steps,
|
| 210 |
+
device=args.device,
|
| 211 |
+
)
|
| 212 |
+
print(json.dumps(result, indent=2))
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
if __name__ == "__main__":
|
| 216 |
+
main()
|
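The error metrics in `compute_metrics` above are related: `rre` is the ratio of L2 norms of the error and the target, and `nmse` is exactly its square. A minimal standalone sketch (array values here are illustrative, not Overthrust data):

```python
import numpy as np

# Mirror of the rre / nmse definitions used in eval.py's compute_metrics.
def rre(prediction: np.ndarray, target: np.ndarray) -> float:
    diff = prediction - target
    return float(np.linalg.norm(diff.ravel()) / np.linalg.norm(target.ravel()))

def nmse(prediction: np.ndarray, target: np.ndarray) -> float:
    diff = prediction - target
    return float(np.sum(diff ** 2) / np.sum(target ** 2))

target = np.array([[1.0, 2.0], [3.0, 4.0]])
prediction = target + 0.1

# nmse is the square of the relative reconstruction error.
assert abs(nmse(prediction, target) - rre(prediction, target) ** 2) < 1e-12
# A perfect prediction has zero error under both metrics.
assert rre(target, target) == 0.0
```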
inference/infer.py
ADDED
|
@@ -0,0 +1,77 @@
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
if str(REPO_ROOT) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 14 |
+
|
| 15 |
+
from pipeline import SeismicImpInvLDDPMPipeline
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_bchw_npy(path: str | Path) -> torch.Tensor:
|
| 19 |
+
arr = np.load(path).astype(np.float32)
|
| 20 |
+
if arr.ndim == 2:
|
| 21 |
+
arr = arr[None, None, :, :]
|
| 22 |
+
elif arr.ndim == 3:
|
| 23 |
+
arr = arr[None, :, :, :]
|
| 24 |
+
elif arr.ndim != 4:
|
| 25 |
+
raise ValueError(f"Expected 2D, 3D, or 4D array at {path}, got shape {arr.shape}")
|
| 26 |
+
return torch.from_numpy(arr)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def save_prediction_png(prediction: np.ndarray, output_path: Path) -> None:
|
| 30 |
+
fig, ax = plt.subplots(figsize=(5, 5))
|
| 31 |
+
im = ax.imshow(prediction, cmap="jet")
|
| 32 |
+
ax.axis("off")
|
| 33 |
+
plt.colorbar(im, ax=ax, fraction=0.046)
|
| 34 |
+
plt.tight_layout()
|
| 35 |
+
fig.savefig(output_path, dpi=150)
|
| 36 |
+
plt.close(fig)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def parse_args() -> argparse.Namespace:
|
| 40 |
+
parser = argparse.ArgumentParser(description="Run seismic-lddpm inference on one sample.")
|
| 41 |
+
parser.add_argument("--dipin", required=True, help="Path to low-frequency impedance .npy")
|
| 42 |
+
parser.add_argument("--record", required=True, help="Path to seismic record .npy")
|
| 43 |
+
parser.add_argument("--model", default="mally-2000/seismic-lddpm")
|
| 44 |
+
parser.add_argument("--output", default="outputs/single")
|
| 45 |
+
parser.add_argument("--device", default=None)
|
| 46 |
+
parser.add_argument("--seed", type=int, default=1234)
|
| 47 |
+
parser.add_argument("--num-inference-steps", type=int, default=1000)
|
| 48 |
+
return parser.parse_args()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main() -> None:
|
| 52 |
+
args = parse_args()
|
| 53 |
+
output_dir = Path(args.output)
|
| 54 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 55 |
+
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 56 |
+
pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
|
| 57 |
+
args.model,
|
| 58 |
+
torch_dtype=torch.float32,
|
| 59 |
+
trust_remote_code=True,
|
| 60 |
+
).to(device)
|
| 61 |
+
dipin = load_bchw_npy(args.dipin).to(device)
|
| 62 |
+
record = load_bchw_npy(args.record).to(device)
|
| 63 |
+
output = pipe(
|
| 64 |
+
dipin=dipin,
|
| 65 |
+
record=record,
|
| 66 |
+
num_inference_steps=args.num_inference_steps,
|
| 67 |
+
seed=args.seed,
|
| 68 |
+
)
|
| 69 |
+
prediction = output.impedance_samples.detach().cpu().numpy()
|
| 70 |
+
np.save(output_dir / "prediction.npy", prediction)
|
| 71 |
+
save_prediction_png(prediction[0, 0], output_dir / "prediction.png")
|
| 72 |
+
print(f"Saved: {output_dir / 'prediction.npy'}")
|
| 73 |
+
print(f"Saved: {output_dir / 'prediction.png'}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
main()
|
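`load_bchw_npy` above promotes 2D `(H, W)` and 3D `(C, H, W)` arrays to the 4D `(B, C, H, W)` layout the pipeline expects. A small numpy-only sketch of that shape handling (the helper name `to_bchw` is hypothetical):

```python
import numpy as np

# Sketch of the shape promotion performed by load_bchw_npy in infer.py:
# always hand the pipeline a batched (B, C, H, W) float32 array.
def to_bchw(arr: np.ndarray) -> np.ndarray:
    arr = arr.astype(np.float32)
    if arr.ndim == 2:          # (H, W) -> (1, 1, H, W)
        return arr[None, None, :, :]
    if arr.ndim == 3:          # (C, H, W) -> (1, C, H, W)
        return arr[None, :, :, :]
    if arr.ndim == 4:          # already batched
        return arr
    raise ValueError(f"Expected 2D, 3D, or 4D array, got shape {arr.shape}")

assert to_bchw(np.zeros((64, 64))).shape == (1, 1, 64, 64)
assert to_bchw(np.zeros((1, 64, 64))).shape == (1, 1, 64, 64)
```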
pipeline.py
ADDED
|
@@ -0,0 +1,234 @@
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from diffusers import DDPMScheduler, DiffusionPipeline, UNet2DModel, VQModel
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class SeismicImpInvLDDPMPipelineOutput(BaseOutput):
|
| 13 |
+
impedance_samples: torch.Tensor | np.ndarray
|
| 14 |
+
impedance_latents: torch.Tensor | np.ndarray
|
| 15 |
+
impedance_dipin: torch.Tensor | np.ndarray
|
| 16 |
+
impedance_reconstructed: torch.Tensor | np.ndarray | None = None
|
| 17 |
+
record_features: torch.Tensor | np.ndarray | None = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SeismicImpInvLDDPMPipeline(DiffusionPipeline):
|
| 21 |
+
"""SAII-LDDPM impedance inversion pipeline."""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
vq_model: VQModel,
|
| 26 |
+
condition_encoder: torch.nn.Module,
|
| 27 |
+
unet: UNet2DModel,
|
| 28 |
+
scheduler: DDPMScheduler,
|
| 29 |
+
):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.register_modules(
|
| 32 |
+
vq_model=vq_model,
|
| 33 |
+
condition_encoder=condition_encoder,
|
| 34 |
+
unet=unet,
|
| 35 |
+
scheduler=scheduler,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def _encode_conditioning(
|
| 39 |
+
self, dipin: torch.Tensor, record: torch.Tensor
|
| 40 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 41 |
+
dipin_latents = self.vq_model.encode(dipin).latents
|
| 42 |
+
if hasattr(self.condition_encoder, "encode") and callable(
|
| 43 |
+
self.condition_encoder.encode
|
| 44 |
+
):
|
| 45 |
+
record_features = self.condition_encoder.encode(record)
|
| 46 |
+
else:
|
| 47 |
+
record_features = self.condition_encoder(record)
|
| 48 |
+
return (
|
| 49 |
+
dipin_latents.to(dtype=self.unet.dtype),
|
| 50 |
+
record_features.to(dtype=self.unet.dtype),
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def _extract_into_tensor(
|
| 55 |
+
arr: torch.Tensor, timesteps: torch.Tensor, broadcast_shape: torch.Size
|
| 56 |
+
) -> torch.Tensor:
|
| 57 |
+
values = arr.to(device=timesteps.device, dtype=torch.float32).gather(0, timesteps)
|
| 58 |
+
return values.reshape(timesteps.shape[0], *((1,) * (len(broadcast_shape) - 1)))
|
| 59 |
+
|
| 60 |
+
@staticmethod
|
| 61 |
+
def _build_legacy_ddpm_buffers(
|
| 62 |
+
scheduler: DDPMScheduler, device: torch.device
|
| 63 |
+
) -> dict[str, torch.Tensor]:
|
| 64 |
+
betas = scheduler.betas.to(device=device, dtype=torch.float32)
|
| 65 |
+
alphas = 1.0 - betas
|
| 66 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 67 |
+
alphas_cumprod_prev = torch.cat(
|
| 68 |
+
[torch.ones(1, device=device), alphas_cumprod[:-1]], dim=0
|
| 69 |
+
)
|
| 70 |
+
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
| 71 |
+
posterior_log_variance_clipped = torch.log(
|
| 72 |
+
torch.clamp(posterior_variance, min=1e-20)
|
| 73 |
+
)
|
| 74 |
+
return {
|
| 75 |
+
"sqrt_recip_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod),
|
| 76 |
+
"sqrt_recipm1_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod - 1),
|
| 77 |
+
"posterior_mean_coef1": betas
|
| 78 |
+
* torch.sqrt(alphas_cumprod_prev)
|
| 79 |
+
/ (1.0 - alphas_cumprod),
|
| 80 |
+
"posterior_mean_coef2": (1.0 - alphas_cumprod_prev)
|
| 81 |
+
* torch.sqrt(alphas)
|
| 82 |
+
/ (1.0 - alphas_cumprod),
|
| 83 |
+
"posterior_log_variance_clipped": posterior_log_variance_clipped,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
@staticmethod
|
| 87 |
+
def _randn_like_sample(
|
| 88 |
+
sample: torch.Tensor, generator: torch.Generator | list[torch.Generator] | None
|
| 89 |
+
) -> torch.Tensor:
|
| 90 |
+
if isinstance(generator, list):
|
| 91 |
+
if len(generator) != sample.shape[0]:
|
| 92 |
+
raise ValueError(
|
| 93 |
+
f"Expected {sample.shape[0]} generators, got {len(generator)}"
|
| 94 |
+
)
|
| 95 |
+
return torch.cat(
|
| 96 |
+
[
|
| 97 |
+
torch.randn(
|
| 98 |
+
sample[i : i + 1].shape,
|
| 99 |
+
generator=sample_generator,
|
| 100 |
+
device=sample.device,
|
| 101 |
+
dtype=sample.dtype,
|
| 102 |
+
)
|
| 103 |
+
for i, sample_generator in enumerate(generator)
|
| 104 |
+
],
|
| 105 |
+
dim=0,
|
| 106 |
+
)
|
| 107 |
+
return torch.randn(
|
| 108 |
+
sample.shape, generator=generator, device=sample.device, dtype=sample.dtype
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def _ddpm_step(
|
| 112 |
+
self,
|
| 113 |
+
latents: torch.Tensor,
|
| 114 |
+
conditioning: torch.Tensor,
|
| 115 |
+
timestep: torch.Tensor,
|
| 116 |
+
generator: torch.Generator | list[torch.Generator] | None,
|
| 117 |
+
buffers: dict[str, torch.Tensor],
|
| 118 |
+
) -> torch.Tensor:
|
| 119 |
+
model_input = torch.cat([latents, conditioning], dim=1)
|
| 120 |
+
noise_pred = self.unet(model_input, timestep).sample
|
| 121 |
+
pred_x0 = (
|
| 122 |
+
self._extract_into_tensor(
|
| 123 |
+
buffers["sqrt_recip_alphas_cumprod"], timestep, latents.shape
|
| 124 |
+
)
|
| 125 |
+
* latents
|
| 126 |
+
- self._extract_into_tensor(
|
| 127 |
+
buffers["sqrt_recipm1_alphas_cumprod"], timestep, latents.shape
|
| 128 |
+
)
|
| 129 |
+
* noise_pred
|
| 130 |
+
)
|
| 131 |
+
pred_x0 = self.vq_model.quantize(pred_x0)[0]
|
| 132 |
+
model_mean = (
|
| 133 |
+
self._extract_into_tensor(
|
| 134 |
+
buffers["posterior_mean_coef1"], timestep, latents.shape
|
| 135 |
+
)
|
| 136 |
+
* pred_x0
|
| 137 |
+
+ self._extract_into_tensor(
|
| 138 |
+
buffers["posterior_mean_coef2"], timestep, latents.shape
|
| 139 |
+
)
|
| 140 |
+
* latents
|
| 141 |
+
)
|
| 142 |
+
noise = self._randn_like_sample(latents, generator)
|
| 143 |
+
nonzero_mask = (1 - (timestep == 0).float()).reshape(
|
| 144 |
+
latents.shape[0], *((1,) * (len(latents.shape) - 1))
|
| 145 |
+
)
|
| 146 |
+
return model_mean + nonzero_mask * (
|
| 147 |
+
0.5
|
| 148 |
+
* self._extract_into_tensor(
|
| 149 |
+
buffers["posterior_log_variance_clipped"], timestep, latents.shape
|
| 150 |
+
)
|
| 151 |
+
).exp() * noise
|
| 152 |
+
|
| 153 |
+
@torch.no_grad()
|
| 154 |
+
def __call__(
|
| 155 |
+
self,
|
| 156 |
+
dipin: torch.Tensor,
|
| 157 |
+
record: torch.Tensor,
|
| 158 |
+
image: torch.Tensor | None = None,
|
| 159 |
+
num_inference_steps: int = 1000,
|
| 160 |
+
seed: int | None = None,
|
| 161 |
+
seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
|
| 162 |
+
generator: torch.Generator | None = None,
|
| 163 |
+
output_type: str = "tensor",
|
| 164 |
+
) -> SeismicImpInvLDDPMPipelineOutput:
|
| 165 |
+
device = self.unet.device
|
| 166 |
+
if seeds is not None:
|
| 167 |
+
if isinstance(seeds, torch.Tensor):
|
| 168 |
+
seeds = seeds.detach().cpu().tolist()
|
| 169 |
+
seeds = [int(value) for value in seeds]
|
| 170 |
+
if len(seeds) != dipin.shape[0]:
|
| 171 |
+
raise ValueError(f"Expected {dipin.shape[0]} seeds, got {len(seeds)}")
|
| 172 |
+
generator = [
|
| 173 |
+
torch.Generator(device=device).manual_seed(value) for value in seeds
|
| 174 |
+
]
|
| 175 |
+
elif seed is not None:
|
| 176 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 177 |
+
elif generator is None:
|
| 178 |
+
generator = torch.Generator(device=device)
|
| 179 |
+
|
| 180 |
+
dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
|
| 181 |
+
record = record.to(device=device, dtype=self.unet.dtype)
|
| 182 |
+
impedance_dipin, record_features = self._encode_conditioning(dipin, record)
|
| 183 |
+
conditioning = torch.cat([impedance_dipin, record_features], dim=1)
|
| 184 |
+
impedance_latents = self._randn_like_sample(
|
| 185 |
+
torch.empty(
|
| 186 |
+
impedance_dipin.shape,
|
| 187 |
+
device=device,
|
| 188 |
+
dtype=self.unet.dtype,
|
| 189 |
+
),
|
| 190 |
+
generator,
|
| 191 |
+
)
|
| 192 |
+
buffers = self._build_legacy_ddpm_buffers(self.scheduler, device)
|
| 193 |
+
for t in reversed(range(num_inference_steps)):
|
| 194 |
+
timestep = torch.full(
|
| 195 |
+
(impedance_latents.shape[0],), t, device=device, dtype=torch.long
|
| 196 |
+
)
|
| 197 |
+
impedance_latents = self._ddpm_step(
|
| 198 |
+
impedance_latents, conditioning, timestep, generator, buffers
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
impedance_samples = self.vq_model.decode(
|
| 202 |
+
impedance_latents.to(dtype=self.vq_model.dtype)
|
| 203 |
+
).sample
|
| 204 |
+
impedance_reconstructed = None
|
| 205 |
+
if image is not None:
|
| 206 |
+
image = image.to(device=device, dtype=self.vq_model.dtype)
|
| 207 |
+
image_latents = self.vq_model.encode(image).latents
|
| 208 |
+
impedance_reconstructed = self.vq_model.decode(image_latents).sample
|
| 209 |
+
|
| 210 |
+
if output_type == "np":
|
| 211 |
+
impedance_samples = impedance_samples.detach().cpu().numpy()
|
| 212 |
+
impedance_latents = impedance_latents.detach().cpu().numpy()
|
| 213 |
+
impedance_dipin = impedance_dipin.detach().cpu().numpy()
|
| 214 |
+
record_features = record_features.detach().cpu().numpy()
|
| 215 |
+
if impedance_reconstructed is not None:
|
| 216 |
+
impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()
|
| 217 |
+
|
| 218 |
+
return SeismicImpInvLDDPMPipelineOutput(
|
| 219 |
+
impedance_samples=impedance_samples,
|
| 220 |
+
impedance_latents=impedance_latents,
|
| 221 |
+
impedance_dipin=impedance_dipin,
|
| 222 |
+
impedance_reconstructed=impedance_reconstructed,
|
| 223 |
+
record_features=record_features,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
@torch.no_grad()
|
| 227 |
+
def encode_decode(
|
| 228 |
+
self, image: torch.Tensor, output_type: str = "tensor"
|
| 229 |
+
) -> torch.Tensor | np.ndarray:
|
| 230 |
+
image = image.to(device=self.vq_model.device, dtype=self.vq_model.dtype)
|
| 231 |
+
reconstruction = self.vq_model.decode(self.vq_model.encode(image).latents).sample
|
| 232 |
+
if output_type == "np":
|
| 233 |
+
return reconstruction.detach().cpu().numpy()
|
| 234 |
+
return reconstruction
|
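The `_build_legacy_ddpm_buffers` coefficients above come from the DDPM posterior q(x_{t-1} | x_t, x_0). A numpy sketch with an illustrative linear beta schedule (the actual schedule comes from the trained scheduler config, not these values) shows the sanity check that the posterior mean collapses correctly in the noise-free case:

```python
import numpy as np

# Illustrative linear beta schedule; the model's real scheduler config may differ.
betas = np.linspace(1e-4, 0.02, 1000, dtype=np.float64)
alphas = 1.0 - betas
acp = np.cumprod(alphas)                        # alpha-bar_t
acp_prev = np.concatenate([[1.0], acp[:-1]])    # alpha-bar_{t-1}

# Posterior mean coefficients, as in _build_legacy_ddpm_buffers.
coef1 = betas * np.sqrt(acp_prev) / (1.0 - acp)
coef2 = (1.0 - acp_prev) * np.sqrt(alphas) / (1.0 - acp)
posterior_variance = betas * (1.0 - acp_prev) / (1.0 - acp)

# Sanity check: for a noise-free sample x_t = sqrt(acp_t) * x0, the posterior
# mean coef1 * x0 + coef2 * x_t reduces to sqrt(acp_prev) * x0.
assert np.allclose(coef1 + coef2 * np.sqrt(acp), np.sqrt(acp_prev))
assert np.all(posterior_variance >= 0.0)
```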
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
| 1 |
+
torch>=2.6.0
|
| 2 |
+
diffusers>=0.25.0,<0.31.0
|
| 3 |
+
accelerate>=0.25.0
|
| 4 |
+
safetensors>=0.4.0
|
| 5 |
+
numpy>=1.23.0,<2.0
|
| 6 |
+
scipy>=1.10.1
|
| 7 |
+
matplotlib>=3.9.4
|
| 8 |
+
scikit-image>=0.24.0
|
| 9 |
+
pylops==2.2.0
|
| 10 |
+
pytorch-wavelets>=1.3.0
|