# Page header captured with this file: uploader avatar "mally-2000's picture".
# Last commit: "Make full evaluation optional in infer script" (d6f6beb, verified).
from __future__ import annotations
import argparse
import sys
from pathlib import Path
# Directory containing this script; used as the root for imports, data and outputs.
REPO_ROOT = Path(__file__).resolve().parent

# Put the repository root at the front of the import path (once) so the
# project-local ``codes`` package resolves regardless of the working directory.
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
import matplotlib.pyplot as plt
import numpy as np
import torch
from codes.dataset import OverthrustTrueimpDataset, SeismicBase
from codes.pipeline import SeismicImpInvCLDMPipeline, SeismicImpInvLDDPMPipeline
from codes.util import OverthrustForwardOperator, ricker_wavelet
# Index of the single dataset patch used for the demo inference run; it is
# passed as ``patch_indices=[PATCH_INDEX]`` to OverthrustTrueimpDataset.
PATCH_INDEX = 0
# Directory handed to the pipelines' ``from_pretrained``; here the repository
# root itself.
MODEL_DIR = REPO_ROOT
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the inference script.

    Returns:
        Namespace with two attributes:

        * ``method`` -- ``"LDDPM"`` or ``"CLDM"`` (always upper-case,
          defaults to ``"LDDPM"`` when omitted).
        * ``eval`` -- ``True`` when ``--eval`` was given, else ``False``.
    """
    parser = argparse.ArgumentParser(description="Run SAII-LDDPM/CLDM inference.")
    parser.add_argument(
        "method",
        nargs="?",
        # argparse applies ``type`` before checking ``choices``, so upper-casing
        # here lets users type ``cldm`` / ``lddpm`` as well as the canonical
        # spelling instead of being rejected with "invalid choice".
        type=str.upper,
        choices=["LDDPM", "CLDM"],
        default="LDDPM",
    )
    parser.add_argument(
        "--eval",
        action="store_true",
        help="Run full Overthrust evaluation after single-sample inference.",
    )
    return parser.parse_args()
def save_comparison(dipin, record, target, prediction, output_path):
    """Save a 1x4 side-by-side figure of the inputs, the target and the prediction.

    Each panel is drawn with ``imshow``, given a title and its own colorbar,
    and has its axes hidden.

    Args:
        dipin: Image array for the "Low-frequency impedance" panel (``jet``).
        record: Image array for the "Seismic record" panel (``gray``).
        target: Image array for the "Target" panel (``jet``).
        prediction: Image array for the "Prediction" panel (``jet``).
        output_path: Destination passed to ``Figure.savefig`` (written at 150 dpi).
    """
    panels = (
        (dipin, "Low-frequency impedance", "jet"),
        (record, "Seismic record", "gray"),
        (target, "Target", "jet"),
        (prediction, "Prediction", "jet"),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
    try:
        for ax, (arr, title, cmap) in zip(axes, panels):
            im = ax.imshow(arr, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")
            # Use the figure explicitly rather than pyplot's "current figure"
            # so the call cannot attach to some other open figure.
            fig.colorbar(im, ax=ax, fraction=0.046)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving failed;
        # pyplot otherwise keeps it alive in its global registry.
        plt.close(fig)
# Values passed as ``num_inference_steps`` to the respective pipeline.
_LDDPM_INFERENCE_STEPS = 1000
_CLDM_INFERENCE_STEPS = 30


def _load_pipeline(pipeline_cls, device):
    """Load ``pipeline_cls`` from ``MODEL_DIR`` in float32 and move it to ``device``."""
    return pipeline_cls.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to(device)


def _build_cldm_guidance(sample, record, device):
    """Build the extra keyword arguments the CLDM pipeline call receives.

    Args:
        sample: Dataset item; ``sample["rick_v"]`` and ``sample["rick_phase"]``
            are read via ``.item()`` and truncated to ``int``.
        record: Batched seismic record, forwarded unchanged as ``measurement``.
        device: Device handed to ``OverthrustForwardOperator``.

    Returns:
        Dict with the keys ``measurement`` and ``operator``.
    """
    f0 = int(sample["rick_v"].item())
    f0_phase = int(sample["rick_phase"].item())
    # NOTE: The forward operator's wavelet must match the wavelet the dataset
    # used to generate the seismic record, so that simulated measurements are
    # consistent with the actual data. The peak frequency and phase are read
    # from the sample itself; the wavelet length (256 // 2) and dt=0.002 s are
    # hard-coded here and must agree with
    # OverthrustTrueimpDataset._build_wavelets().
    wavelet = ricker_wavelet(f0=f0, nt=256 // 2, dt=0.002)
    # Apply the phase shift so the wavelet phase matches the dataset's.
    wavelet = SeismicBase.phaseshift(wavelet, f0_phase)
    operator = OverthrustForwardOperator(
        wavelet=wavelet,
        device=device,
    )
    return {"measurement": record, "operator": operator}


def main() -> None:
    """Run single-sample inference, save the results and optionally evaluate.

    Steps: parse the CLI, load patch ``PATCH_INDEX`` of the Overthrust dataset,
    run the selected pipeline on it, write ``prediction.npy``, ``target.npy``
    and ``comparison.png`` under ``outputs/infer_<METHOD>/``, and, when
    ``--eval`` was given, call ``evaluate_overthrust`` with the loaded pipeline.
    """
    args = parse_args()
    # Normalise defensively; a no-op when the name is already upper-case.
    method = args.method.upper()

    out_dir = REPO_ROOT / "outputs" / f"infer_{method}"
    out_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Method: {method}")

    dataset = OverthrustTrueimpDataset(
        patch_indices=[PATCH_INDEX],
        data_dir=REPO_ROOT / "data",
        cache_dir=out_dir / "cache",
    )
    sample = dataset[0]
    # Add a leading batch dimension of size 1 and move the inputs to the device.
    dipin = sample["dipin"].unsqueeze(0).to(device)
    record = sample["record"].unsqueeze(0).to(device)
    image = sample["image"].unsqueeze(0).to(device)
    seed = int(sample["seed"])

    if method == "LDDPM":
        pipe = _load_pipeline(SeismicImpInvLDDPMPipeline, device)
        num_inference_steps = _LDDPM_INFERENCE_STEPS
        extra_kwargs = {}
    else:
        pipe = _load_pipeline(SeismicImpInvCLDMPipeline, device)
        num_inference_steps = _CLDM_INFERENCE_STEPS
        # CLDM additionally receives the measurement and a forward operator.
        extra_kwargs = _build_cldm_guidance(sample, record, device)

    output = pipe(
        dipin=dipin,
        record=record,
        image=image,
        num_inference_steps=num_inference_steps,
        seeds=[seed],
        **extra_kwargs,
    )

    # Take the first sample / first channel of each tensor as a NumPy array.
    prediction = output.impedance_samples[0, 0].detach().cpu().numpy()
    target = image[0, 0].detach().cpu().numpy()
    dipin_np = dipin[0, 0].detach().cpu().numpy()
    record_np = record[0, 0].detach().cpu().numpy()

    prediction_path = out_dir / "prediction.npy"
    target_path = out_dir / "target.npy"
    comparison_path = out_dir / "comparison.png"

    np.save(prediction_path, prediction)
    np.save(target_path, target)
    save_comparison(dipin_np, record_np, target, prediction, comparison_path)
    print(f"Saved: {prediction_path}")
    print(f"Saved: {target_path}")
    print(f"Saved: {comparison_path}")

    if args.eval:
        # Imported lazily so the evaluation module is only loaded on request.
        from codes.eval_overthrust import evaluate_overthrust

        evaluate_overthrust(pipe, method=method, output_dir=out_dir / "eval")


if __name__ == "__main__":
    main()