| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
# Directory containing this script (symlinks resolved).
REPO_ROOT = Path(__file__).resolve().parents[0]
# Prepend it to the import path (once) so the repo-local ``codes`` package
# imported below resolves no matter which working directory the script is
# launched from.
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
|
|
| from codes.dataset import OverthrustTrueimpDataset, SeismicBase |
| from codes.pipeline import SeismicImpInvCLDMPipeline, SeismicImpInvLDDPMPipeline |
| from codes.util import OverthrustForwardOperator, ricker_wavelet |
|
|
|
|
# Which Overthrust patch the single-sample demo runs on; it is passed to the
# dataset as ``patch_indices=[PATCH_INDEX]``.
PATCH_INDEX: int = 0
# Directory handed to ``from_pretrained`` for both pipelines, i.e. the
# pretrained weights are loaded from the directory containing this script.
MODEL_DIR: Path = REPO_ROOT
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the inference script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which preserves the
            original zero-argument behavior. Passing an explicit list lets
            callers and tests drive the parser directly.

    Returns:
        Namespace with:

        * ``method``: ``"LDDPM"`` (default) or ``"CLDM"``. Input is
          upper-cased before validation, so ``cldm`` is accepted too.
        * ``eval``: ``True`` when ``--eval`` was given.

    Raises:
        SystemExit: On an unknown method or option (standard argparse
            behavior).
    """
    parser = argparse.ArgumentParser(description="Run SAII-LDDPM/CLDM inference.")
    parser.add_argument(
        "method",
        nargs="?",
        # argparse applies ``type`` before checking ``choices`` (and also to
        # a string default), so lower/mixed-case input is normalized first.
        type=str.upper,
        choices=["LDDPM", "CLDM"],
        default="LDDPM",
        help="Sampling method (case-insensitive).",
    )
    parser.add_argument("--eval", action="store_true", help="Run full Overthrust evaluation after single-sample inference.")
    return parser.parse_args(argv)
|
|
def save_comparison(dipin, record, target, prediction, output_path):
    """Save a 1x4 side-by-side figure of inputs, target and prediction.

    Panels, left to right:

    * low-frequency impedance (``jet``)
    * seismic record (``gray``)
    * target (``jet``)
    * prediction (``jet``)

    Each panel has its own colorbar and hidden axes.

    Args:
        dipin, record, target, prediction: Arrays accepted by
            ``Axes.imshow``. Presumably 2-D NumPy arrays; verify at call
            sites.
        output_path: Destination file. Matplotlib infers the format from the
            extension. The figure is written at 150 dpi.
    """
    panels = (
        (dipin, "Low-frequency impedance", "jet"),
        (record, "Seismic record", "gray"),
        (target, "Target", "jet"),
        (prediction, "Prediction", "jet"),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
    try:
        for ax, (arr, title, cmap) in zip(axes, panels):
            im = ax.imshow(arr, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")
            # Use the figure object explicitly instead of pyplot's implicit
            # "current figure" state.
            fig.colorbar(im, ax=ax, fraction=0.046)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
|
|
|
|
def _first_item(tensor):
    """Return ``tensor[0, 0]`` as a NumPy array on the CPU.

    Index 0 of the leading dimension is the single batch element. The second
    index presumably selects channel 0; confirm against the pipeline/dataset.
    """
    return tensor[0, 0].detach().cpu().numpy()


def _load_sample(out_dir, device):
    """Load Overthrust patch ``PATCH_INDEX`` and move its tensors to *device*.

    Returns:
        ``(sample, dipin, record, image, seed)``.

        * ``dipin``, ``record`` and ``image`` get a leading batch dimension of
          size 1 via ``unsqueeze(0)``.
        * ``seed`` is ``int(sample["seed"])``.
    """
    dataset = OverthrustTrueimpDataset(
        patch_indices=[PATCH_INDEX],
        data_dir=REPO_ROOT / "data",
        cache_dir=out_dir / "cache",
    )
    sample = dataset[0]
    dipin = sample["dipin"].unsqueeze(0).to(device)
    record = sample["record"].unsqueeze(0).to(device)
    image = sample["image"].unsqueeze(0).to(device)
    seed = int(sample["seed"])
    return sample, dipin, record, image, seed


def _build_pipeline(method, sample, record, device):
    """Create the pipeline for *method* and its method-specific call arguments.

    Returns:
        ``(pipe, num_inference_steps, extra_kwargs)``.

        * ``"LDDPM"``: 1000 steps and no extra arguments.
        * Any other method (CLDM): 30 steps. The observed ``record`` is passed
          as ``measurement``, together with an ``OverthrustForwardOperator``
          built from the sample's Ricker wavelet parameters.
    """
    pipeline_cls = (
        SeismicImpInvLDDPMPipeline if method == "LDDPM" else SeismicImpInvCLDMPipeline
    )
    pipe = pipeline_cls.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to(device)

    if method == "LDDPM":
        return pipe, 1000, {}

    # NOTE(review): int() truncates; this assumes rick_v / rick_phase are
    # integral in the dataset. Confirm before relying on fractional values.
    f0 = int(sample["rick_v"].item())
    f0_phase = int(sample["rick_phase"].item())
    # nt = 128 samples at dt = 0.002 (presumably seconds, i.e. 2 ms).
    # TODO confirm these match the wavelet used to synthesize the records.
    wavelet = ricker_wavelet(f0=f0, nt=256 // 2, dt=0.002)
    wavelet = SeismicBase.phaseshift(wavelet, f0_phase)
    operator = OverthrustForwardOperator(
        wavelet=wavelet,
        device=device,
    )
    return pipe, 30, dict(measurement=record, operator=operator)


def _save_outputs(out_dir, output, dipin, record, image):
    """Write ``prediction.npy``, ``target.npy`` and ``comparison.png`` to *out_dir*."""
    prediction = _first_item(output.impedance_samples)
    target = _first_item(image)

    np.save(out_dir / "prediction.npy", prediction)
    np.save(out_dir / "target.npy", target)
    save_comparison(
        _first_item(dipin),
        _first_item(record),
        target,
        prediction,
        out_dir / "comparison.png",
    )

    for filename in ("prediction.npy", "target.npy", "comparison.png"):
        print(f"Saved: {out_dir / filename}")


def main():
    """Run single-sample inference, save results, and optionally run evaluation."""
    args = parse_args()
    # Normalize defensively; the parser's choices are already upper-case.
    method = args.method.upper()
    out_dir = REPO_ROOT / "outputs" / f"infer_{method}"
    out_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Method: {method}")

    sample, dipin, record, image, seed = _load_sample(out_dir, device)
    pipe, num_inference_steps, extra_kwargs = _build_pipeline(
        method, sample, record, device
    )

    output = pipe(
        dipin=dipin,
        record=record,
        image=image,
        num_inference_steps=num_inference_steps,
        seeds=[seed],
        **extra_kwargs,
    )

    _save_outputs(out_dir, output, dipin, record, image)

    if args.eval:
        # Imported lazily so the evaluation module is only loaded on --eval.
        from codes.eval_overthrust import evaluate_overthrust

        evaluate_overthrust(pipe, method=method, output_dir=out_dir / "eval")


if __name__ == "__main__":
    main()
|
|