from __future__ import annotations import argparse import json import sys from pathlib import Path REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) import matplotlib.pyplot as plt import numpy as np import torch from scipy.stats import pearsonr from skimage.metrics import structural_similarity from torch.utils.data import DataLoader from codes.dataset import OverthrustTrueimpDataset from codes.pipeline import SeismicImpInvCLDMPipeline, SeismicImpInvLDDPMPipeline from codes.util import OverthrustForwardOperator OVERTHRUST_CONFIG = { "size": 256, "patch_indices": [0, 1, 2, 3, 4, 5], "noise_snr": 15, "dipin_v": 0.012, "f0": 30, "f0_phase": 0, "seed": 1234, "zhengyan_type": "nonlinear", "normalize": "minmax", "batch_size": 3, } def stitch_patches( patches: list[np.ndarray], splits: list[tuple[int, int]], big_shape: tuple[int, int], img_size: int ) -> np.ndarray: rec = np.zeros(big_shape, dtype=np.float32) cnt = np.zeros(big_shape, dtype=np.float32) for idx, (x, y) in enumerate(splits): rec[x : x + img_size, y : y + img_size] += patches[idx] cnt[x : x + img_size, y : y + img_size] += 1 return rec / np.maximum(cnt, 1) def compute_metrics(prediction: np.ndarray, target: np.ndarray) -> dict[str, float]: diff = prediction - target denom = np.linalg.norm(diff.ravel()) ** 2 psnr = float("inf") if denom == 0 else float( 10.0 * np.log10(len(prediction.ravel()) * np.max(prediction.ravel()) ** 2 / denom) ) return { "PSNR": psnr, "rre": float(np.linalg.norm(diff.ravel()) / np.linalg.norm(target.ravel())), "SSIM": float(structural_similarity(target, prediction, data_range=target.max())), "PCC": float(pearsonr(prediction.ravel(), target.ravel()).statistic), "nmse": float(np.sum(diff ** 2) / np.sum(target ** 2)), "mse": float(np.mean(diff ** 2) / prediction.size), } def save_comparison( target_impedance: np.ndarray, prediction_impedance: np.ndarray, output_path: Path, ) -> None: error = np.abs(target_impedance - prediction_impedance) fig, axes = plt.subplots(1, 3, figsize=(15, 5)) vmin_imp = min(target_impedance.min(), prediction_impedance.min()) vmax_imp = max(target_impedance.max(), prediction_impedance.max()) for ax, arr, title in zip( axes, [target_impedance, prediction_impedance, error], ["Target (Impedance)", "Prediction (Impedance)", "Error (Impedance)"], ): if "Error" in title: im = ax.imshow(arr, cmap="hot", vmin=0, vmax=error.max()) else: im = ax.imshow(arr, cmap="jet", vmin=vmin_imp, vmax=vmax_imp) ax.set_title(title) ax.axis("off") plt.colorbar(im, ax=ax, fraction=0.046) plt.tight_layout() fig.savefig(output_path, dpi=150) plt.close(fig) def evaluate_overthrust( pipe: SeismicImpInvLDDPMPipeline, method: str = "LDDPM", output_dir: str | Path = "outputs/overthrust", num_inference_steps: int | None = None, device: str | torch.device | None = None, ) -> dict[str, object]: method = method.upper() if method not in {"LDDPM", "CLDM"}: raise ValueError("method must be LDDPM or CLDM") if num_inference_steps is None: num_inference_steps = 30 if method == "CLDM" else 1000 output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) print(f"[eval] method={method}, steps={num_inference_steps}, device={device}") print(f"[eval] output_dir={output_dir}") print("[eval] moving pipeline to device...") pipe = pipe.to(device) print("[eval] building Overthrust dataset...") dataset = OverthrustTrueimpDataset( size=OVERTHRUST_CONFIG["size"], normalize=OVERTHRUST_CONFIG["normalize"], zhengyan_type=OVERTHRUST_CONFIG["zhengyan_type"], ricks=[OVERTHRUST_CONFIG["f0"]], ricks_phase=[OVERTHRUST_CONFIG["f0_phase"]], noise_snr=[OVERTHRUST_CONFIG["noise_snr"]], dipins=[OVERTHRUST_CONFIG["dipin_v"]], record_noraml=True, train_keys=["image", "dipin", "record"], patch_indices=OVERTHRUST_CONFIG["patch_indices"], base_seed=OVERTHRUST_CONFIG["seed"], data_dir=REPO_ROOT / "data", cache_dir=output_dir / "cache", fixed_f0=OVERTHRUST_CONFIG["f0"], fixed_dipin_v=OVERTHRUST_CONFIG["dipin_v"], fixed_noise_snr=OVERTHRUST_CONFIG["noise_snr"], fixed_f0_phase=OVERTHRUST_CONFIG["f0_phase"], ) print( "[eval] dataset ready: " f"patches={len(dataset)}, batch_size={OVERTHRUST_CONFIG['batch_size']}, " f"patch_indices={OVERTHRUST_CONFIG['patch_indices']}" ) loader = DataLoader( dataset, batch_size=OVERTHRUST_CONFIG["batch_size"], shuffle=False, num_workers=0, ) all_predictions: list[np.ndarray] = [] all_targets: list[np.ndarray] = [] all_reconstructions: list[np.ndarray] = [] total_batches = len(loader) for batch_idx, batch in enumerate(loader, start=1): seeds = batch["seed"].tolist() batch_size = len(seeds) print( f"[eval] batch {batch_idx}/{total_batches}: " f"batch_size={batch_size}, seeds={seeds}" ) dipin = batch["dipin"].to(device) record = batch["record"].to(device) image = batch["image"].to(device) extra_kwargs = {} if method == "CLDM": f0 = int(batch["rick_v"][0].item()) f0_phase = int(batch["rick_phase"][0].item()) extra_kwargs = { "measurement": record, "operator": OverthrustForwardOperator( wavelet=dataset.wavelets[f0][f0_phase], device=device, ), } print(f"[eval] batch {batch_idx}/{total_batches}: CLDM operator ready") print(f"[eval] batch {batch_idx}/{total_batches}: running pipeline...") output = pipe( dipin=dipin, record=record, image=image, num_inference_steps=num_inference_steps, seeds=seeds, **extra_kwargs, ) print(f"[eval] batch {batch_idx}/{total_batches}: collecting predictions...") prediction = output.impedance_samples reconstruction = output.impedance_reconstructed for local_idx in range(prediction.shape[0]): all_predictions.append(prediction[local_idx, 0].detach().cpu().numpy()) all_targets.append(image[local_idx, 0].detach().cpu().numpy()) all_reconstructions.append(reconstruction[local_idx, 0].detach().cpu().numpy()) print("[eval] stitching patches...") full_target = stitch_patches( all_targets, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"] ) full_prediction = stitch_patches( all_predictions, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"] ) full_reconstruction = stitch_patches( all_reconstructions, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"] ) print("[eval] converting normalized predictions to impedance...") full_target_impedance = dataset.fan(full_target) full_prediction_impedance = dataset.fan(full_prediction) full_reconstruction_impedance = dataset.fan(full_reconstruction) print("[eval] computing metrics...") metrics_summary = { "config": { **OVERTHRUST_CONFIG, "method": method, "num_inference_steps": num_inference_steps, }, "normalized": compute_metrics(full_prediction, full_target), "impedance": compute_metrics(full_prediction_impedance, full_target_impedance), "encode_impedance": compute_metrics( full_reconstruction_impedance, full_target_impedance ), } paths = { "full_target": output_dir / "full_target.npy", "full_prediction": output_dir / "full_prediction.npy", "full_reconstruction": output_dir / "full_reconstruction.npy", "comparison": output_dir / "comparison_impedance.png", "metrics": output_dir / "metrics_summary.json", } print("[eval] saving outputs...") np.save(paths["full_target"], full_target) np.save(paths["full_prediction"], full_prediction) np.save(paths["full_reconstruction"], full_reconstruction) save_comparison(full_target_impedance, full_prediction_impedance, paths["comparison"]) paths["metrics"].write_text(json.dumps(metrics_summary, indent=2), encoding="utf-8") print(f"[eval] done. metrics={paths['metrics']}") return { "metrics": metrics_summary, "paths": {key: str(value) for key, value in paths.items()}, } def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Evaluate SAII-LDDPM/CLDM on Overthrust.") parser.add_argument("method", nargs="?", choices=["LDDPM", "CLDM"], default="LDDPM") parser.add_argument("--model", default=str(REPO_ROOT)) parser.add_argument("--output", default="outputs/overthrust") parser.add_argument("--device", default=None) parser.add_argument("--num-inference-steps", type=int, default=None) return parser.parse_args() def main() -> None: args = parse_args() pipe_cls = SeismicImpInvCLDMPipeline if args.method == "CLDM" else SeismicImpInvLDDPMPipeline pipe = pipe_cls.from_pretrained( args.model, torch_dtype=torch.float32, trust_remote_code=True, ) result = evaluate_overthrust( pipe, method=args.method, output_dir=args.output, num_inference_steps=args.num_inference_steps, device=args.device, ) print(json.dumps(result, indent=2)) if __name__ == "__main__": main()