from __future__ import annotations

import argparse
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parent
# Make the repo-local ``codes`` package importable when this script is run
# directly, regardless of the current working directory.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402

from codes.dataset import OverthrustTrueimpDataset, SeismicBase  # noqa: E402
from codes.pipeline import (  # noqa: E402
    SeismicImpInvCLDMPipeline,
    SeismicImpInvLDDPMPipeline,
)
from codes.util import OverthrustForwardOperator, ricker_wavelet  # noqa: E402

# Index of the Overthrust patch used for the single-sample demo.
PATCH_INDEX = 0
# Directory handed to ``from_pretrained`` for both pipelines.
MODEL_DIR = REPO_ROOT

# Number of sampling steps used for each method.
LDDPM_NUM_INFERENCE_STEPS = 1000
CLDM_NUM_INFERENCE_STEPS = 30

# Ricker wavelet discretisation for the CLDM forward operator. These must
# agree with the values the dataset used to synthesise its seismic records
# (NOTE(review): presumably OverthrustTrueimpDataset._build_wavelets() —
# confirm if either side changes).
WAVELET_NT = 256 // 2
WAVELET_DT = 0.002


def parse_args() -> argparse.Namespace:
    """Parse the command line: an optional method name and an ``--eval`` flag."""
    parser = argparse.ArgumentParser(description="Run SAII-LDDPM/CLDM inference.")
    parser.add_argument("method", nargs="?", choices=["LDDPM", "CLDM"], default="LDDPM")
    parser.add_argument(
        "--eval",
        action="store_true",
        help="Run full Overthrust evaluation after single-sample inference.",
    )
    return parser.parse_args()


def save_comparison(dipin, record, target, prediction, output_path):
    """Save a 1x4 panel figure comparing inputs, target and prediction.

    Args:
        dipin: Low-frequency impedance array, drawn with the ``jet`` colormap.
        record: Seismic record array, drawn with the ``gray`` colormap.
        target: Ground-truth impedance array (``jet``).
        prediction: Predicted impedance array (``jet``).
        output_path: Destination passed to ``Figure.savefig`` (150 dpi).

    Each array is rendered with ``imshow`` (presumably 2-D), with its own
    colorbar and the axes hidden.
    """
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    try:
        panels = zip(
            axes,
            (dipin, record, target, prediction),
            ("Low-frequency impedance", "Seismic record", "Target", "Prediction"),
            ("jet", "gray", "jet", "jet"),
        )
        for ax, arr, title, cmap in panels:
            im = ax.imshow(arr, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")
            # Operate on this figure explicitly rather than pyplot's "current" one.
            fig.colorbar(im, ax=ax, fraction=0.046)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def _build_pipeline(method, sample, record, device):
    """Load the pipeline for *method*; return ``(pipe, num_inference_steps, extra_kwargs)``.

    LDDPM is called without extra keyword arguments. CLDM additionally gets
    ``measurement=record`` and an ``OverthrustForwardOperator`` whose wavelet
    is rebuilt from this sample's ``rick_v`` (peak frequency) and
    ``rick_phase`` values. The simulated measurements then use the same
    wavelet as the one that produced ``record``.
    """
    if method == "LDDPM":
        pipe = SeismicImpInvLDDPMPipeline.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        ).to(device)
        return pipe, LDDPM_NUM_INFERENCE_STEPS, {}

    pipe = SeismicImpInvCLDMPipeline.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to(device)

    # The forward operator's wavelet must match the wavelet used to generate
    # this sample's record. Frequency and phase therefore come from the
    # sample itself instead of being hard-coded.
    f0 = int(sample["rick_v"].item())
    f0_phase = int(sample["rick_phase"].item())
    wavelet = ricker_wavelet(f0=f0, nt=WAVELET_NT, dt=WAVELET_DT)
    # Apply the same phase shift the dataset applied to its wavelet.
    wavelet = SeismicBase.phaseshift(wavelet, f0_phase)
    operator = OverthrustForwardOperator(
        wavelet=wavelet,
        device=device,
    )
    extra_kwargs = {"measurement": record, "operator": operator}
    return pipe, CLDM_NUM_INFERENCE_STEPS, extra_kwargs


def main() -> None:
    """Run single-sample inference, save arrays and a comparison plot, optionally evaluate."""
    args = parse_args()
    method = args.method.upper()
    out_dir = REPO_ROOT / "outputs" / f"infer_{method}"
    out_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Method: {method}")

    dataset = OverthrustTrueimpDataset(
        patch_indices=[PATCH_INDEX],
        data_dir=REPO_ROOT / "data",
        cache_dir=out_dir / "cache",
    )
    sample = dataset[0]
    # Add a leading batch dimension of size 1 and move each input to the device.
    dipin = sample["dipin"].unsqueeze(0).to(device)
    record = sample["record"].unsqueeze(0).to(device)
    image = sample["image"].unsqueeze(0).to(device)
    seed = int(sample["seed"])

    pipe, num_inference_steps, extra_kwargs = _build_pipeline(method, sample, record, device)

    output = pipe(
        dipin=dipin,
        record=record,
        image=image,
        num_inference_steps=num_inference_steps,
        seeds=[seed],
        **extra_kwargs,
    )

    # Index [0, 0]: first batch element, first entry of the next axis
    # (presumably the channel axis).
    prediction = output.impedance_samples[0, 0].detach().cpu().numpy()
    target = image[0, 0].detach().cpu().numpy()
    dipin_np = dipin[0, 0].detach().cpu().numpy()
    record_np = record[0, 0].detach().cpu().numpy()

    np.save(out_dir / "prediction.npy", prediction)
    np.save(out_dir / "target.npy", target)
    save_comparison(dipin_np, record_np, target, prediction, out_dir / "comparison.png")
    print(f"Saved: {out_dir / 'prediction.npy'}")
    print(f"Saved: {out_dir / 'target.npy'}")
    print(f"Saved: {out_dir / 'comparison.png'}")

    if args.eval:
        # Imported lazily so plain inference does not depend on the eval module.
        from codes.eval_overthrust import evaluate_overthrust

        evaluate_overthrust(pipe, method=method, output_dir=out_dir / "eval")


if __name__ == "__main__":
    main()