| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
# Repository root: two levels up from this file's resolved path (i.e. the
# directory that contains this script's own directory).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root at the front of sys.path (only once) before the
# `codes.*` imports further down — presumably so they resolve when this file
# is run directly as a script.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| from scipy.stats import pearsonr |
| from skimage.metrics import structural_similarity |
| from torch.utils.data import DataLoader |
|
|
| from codes.dataset import OverthrustTrueimpDataset |
| from codes.pipeline import SeismicImpInvCLDMPipeline, SeismicImpInvLDDPMPipeline |
| from codes.util import OverthrustForwardOperator |
|
|
|
|
# Fixed evaluation settings consumed by `evaluate_overthrust`. The whole
# mapping is also copied verbatim into the "config" section of the metrics
# JSON, so key order and values are part of the script's output.
OVERTHRUST_CONFIG = {
    # Passed to the dataset as `size` and used as the patch edge when stitching.
    "size": 256,
    # Patch indices handed to the dataset.
    "patch_indices": list(range(6)),
    # Passed to the dataset as `noise_snr` / `fixed_noise_snr`.
    "noise_snr": 15,
    # Passed to the dataset as `dipins` / `fixed_dipin_v`.
    "dipin_v": 0.012,
    # Passed to the dataset as `ricks` / `fixed_f0`.
    "f0": 30,
    # Passed to the dataset as `ricks_phase` / `fixed_f0_phase`.
    "f0_phase": 0,
    # Passed to the dataset as `base_seed`.
    "seed": 1234,
    # Forwarded unchanged to the dataset's `zhengyan_type` argument.
    "zhengyan_type": "nonlinear",
    # Forwarded unchanged to the dataset's `normalize` argument.
    "normalize": "minmax",
    # DataLoader batch size.
    "batch_size": 3,
}
|
|
|
|
def stitch_patches(
    patches: list[np.ndarray], splits: list[tuple[int, int]], big_shape: tuple[int, int], img_size: int
) -> np.ndarray:
    """Reassemble square patches into one image, averaging where they overlap.

    Args:
        patches: Patch arrays; ``patches[i]`` is placed at ``splits[i]``.
        splits: Top-left ``(row, col)`` offset of each patch in the big image.
        big_shape: Shape of the stitched output array.
        img_size: Edge length of the square window each patch is written into.

    Returns:
        A float32 array of shape ``big_shape``. Each pixel is the mean of all
        patches covering it; pixels covered by no patch stay at zero.
    """
    accumulated = np.zeros(big_shape, dtype=np.float32)
    coverage = np.zeros(big_shape, dtype=np.float32)
    for patch_no, (row, col) in enumerate(splits):
        window = (slice(row, row + img_size), slice(col, col + img_size))
        accumulated[window] += patches[patch_no]
        coverage[window] += 1
    # Clamp the divisor to 1 so uncovered pixels yield 0 rather than 0/0.
    return np.divide(accumulated, np.maximum(coverage, 1))
|
|
|
|
def compute_metrics(prediction: np.ndarray, target: np.ndarray) -> dict[str, float]:
    """Compute error and similarity metrics of ``prediction`` against ``target``.

    Args:
        prediction: Predicted array.
        target: Reference array with the same shape as ``prediction``.

    Returns:
        A dict of plain Python floats with keys:

        * ``"PSNR"``: ``10 * log10(N * peak**2 / ||prediction - target||**2)``
          where ``N`` is the element count; ``inf`` when the arrays are equal.
        * ``"rre"``: ``||prediction - target|| / ||target||``.
        * ``"SSIM"``: ``structural_similarity(target, prediction)`` with
          ``data_range=target.max()``.
        * ``"PCC"``: Pearson correlation coefficient of the flattened arrays.
        * ``"nmse"``: ``sum(diff**2) / sum(target**2)``.
        * ``"mse"``: mean of ``diff**2``.
    """
    diff = prediction - target
    # Euclidean norm of the residual, shared by PSNR and the relative error.
    residual_norm = np.linalg.norm(diff.ravel())
    denom = residual_norm ** 2
    if denom == 0:
        psnr = float("inf")
    else:
        # NOTE(review): the peak value is taken from the prediction, not the
        # target; kept as-is so reported PSNR values stay comparable.
        psnr = float(10.0 * np.log10(prediction.size * np.max(prediction) ** 2 / denom))
    return {
        "PSNR": psnr,
        "rre": float(residual_norm / np.linalg.norm(target.ravel())),
        # NOTE(review): data_range is target.max(), which equals the value
        # range only when the target's minimum is 0 — confirm this is intended.
        "SSIM": float(structural_similarity(target, prediction, data_range=target.max())),
        "PCC": float(pearsonr(prediction.ravel(), target.ravel()).statistic),
        "nmse": float(np.sum(diff ** 2) / np.sum(target ** 2)),
        # np.mean already divides by the element count; the previous extra
        # division by prediction.size reported sum/N**2 instead of the MSE.
        "mse": float(np.mean(diff ** 2)),
    }
|
|
|
|
def save_comparison(
    target_impedance: np.ndarray,
    prediction_impedance: np.ndarray,
    output_path: Path,
) -> None:
    """Save a three-panel figure: target, prediction and absolute error.

    The target and prediction panels share one colour scale (``jet``) spanning
    the joint min/max of both arrays; the error panel uses ``hot`` scaled from
    0 to the largest absolute error.

    Args:
        target_impedance: Reference image.
        prediction_impedance: Predicted image, same shape as the target.
        output_path: Destination file passed to ``Figure.savefig`` (dpi=150).
    """
    error = np.abs(target_impedance - prediction_impedance)
    vmin_imp = min(target_impedance.min(), prediction_impedance.min())
    vmax_imp = max(target_impedance.max(), prediction_impedance.max())

    impedance_style = {"cmap": "jet", "vmin": vmin_imp, "vmax": vmax_imp}
    error_style = {"cmap": "hot", "vmin": 0, "vmax": error.max()}
    panels = [
        (target_impedance, "Target (Impedance)", impedance_style),
        (prediction_impedance, "Prediction (Impedance)", impedance_style),
        (error, "Error (Impedance)", error_style),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, (arr, title, style) in zip(axes, panels):
            im = ax.imshow(arr, **style)
            ax.set_title(title)
            ax.axis("off")
            # Use the figure's own methods rather than pyplot's implicit
            # "current figure" so this cannot touch an unrelated figure.
            fig.colorbar(im, ax=ax, fraction=0.046)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
def evaluate_overthrust(
    pipe: SeismicImpInvLDDPMPipeline,
    method: str = "LDDPM",
    output_dir: str | Path = "outputs/overthrust",
    num_inference_steps: int | None = None,
    device: str | torch.device | None = None,
) -> dict[str, object]:
    """Run a pipeline over the Overthrust patches and write metrics and outputs.

    The dataset is built from ``OVERTHRUST_CONFIG``, every batch is passed
    through ``pipe``, the per-patch outputs are stitched into full images,
    mapped through ``dataset.fan`` (the "impedance" versions), scored with
    ``compute_metrics`` and saved under ``output_dir``.

    Args:
        pipe: Loaded pipeline; it is moved to ``device`` and called per batch.
        method: ``"LDDPM"`` or ``"CLDM"`` (case-insensitive). For ``"CLDM"`` the
            pipeline additionally receives ``measurement`` and ``operator``.
        output_dir: Directory for the outputs; created if missing. The dataset
            cache directory is ``output_dir / "cache"``.
        num_inference_steps: Steps passed to the pipeline. When ``None`` it
            defaults to 30 for CLDM and 1000 for LDDPM.
        device: Target device. When falsy, CUDA is used if available, else CPU.

    Returns:
        ``{"metrics": <metrics summary dict>, "paths": <name -> path string>}``.

    Raises:
        ValueError: If ``method`` is neither LDDPM nor CLDM.
    """
    default_steps = {"CLDM": 30, "LDDPM": 1000}
    method = method.upper()
    if method not in default_steps:
        raise ValueError("method must be LDDPM or CLDM")
    if num_inference_steps is None:
        num_inference_steps = default_steps[method]

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    print(f"[eval] method={method}, steps={num_inference_steps}, device={device}")
    print(f"[eval] output_dir={output_dir}")
    print("[eval] moving pipeline to device...")
    pipe = pipe.to(device)

    print("[eval] building Overthrust dataset...")
    cfg = OVERTHRUST_CONFIG
    dataset = OverthrustTrueimpDataset(
        size=cfg["size"],
        normalize=cfg["normalize"],
        zhengyan_type=cfg["zhengyan_type"],
        ricks=[cfg["f0"]],
        ricks_phase=[cfg["f0_phase"]],
        noise_snr=[cfg["noise_snr"]],
        dipins=[cfg["dipin_v"]],
        record_noraml=True,  # keyword spelled as the dataset API expects it
        train_keys=["image", "dipin", "record"],
        patch_indices=cfg["patch_indices"],
        base_seed=cfg["seed"],
        data_dir=REPO_ROOT / "data",
        cache_dir=output_dir / "cache",
        fixed_f0=cfg["f0"],
        fixed_dipin_v=cfg["dipin_v"],
        fixed_noise_snr=cfg["noise_snr"],
        fixed_f0_phase=cfg["f0_phase"],
    )
    print(
        "[eval] dataset ready: "
        f"patches={len(dataset)}, batch_size={OVERTHRUST_CONFIG['batch_size']}, "
        f"patch_indices={OVERTHRUST_CONFIG['patch_indices']}"
    )
    # Unshuffled, single-process loading keeps patches in dataset order, which
    # the index-based stitching below relies on.
    loader = DataLoader(
        dataset,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
    )

    def _patch_to_numpy(tensor: torch.Tensor, sample: int) -> np.ndarray:
        """Return channel 0 of one batch item as a NumPy array on the CPU."""
        return tensor[sample, 0].detach().cpu().numpy()

    def _stitch(patch_list: list[np.ndarray]) -> np.ndarray:
        """Stitch per-patch arrays into a full image using the dataset layout."""
        return stitch_patches(
            patch_list, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
        )

    predicted_patches: list[np.ndarray] = []
    target_patches: list[np.ndarray] = []
    reconstructed_patches: list[np.ndarray] = []
    total_batches = len(loader)
    for batch_idx, batch in enumerate(loader, start=1):
        seeds = batch["seed"].tolist()
        batch_size = len(seeds)
        print(
            f"[eval] batch {batch_idx}/{total_batches}: "
            f"batch_size={batch_size}, seeds={seeds}"
        )
        dipin, record, image = (
            batch[key].to(device) for key in ("dipin", "record", "image")
        )

        guidance_kwargs: dict[str, object] = {}
        if method == "CLDM":
            # The wavelet is selected from the first sample of the batch.
            wavelet_freq = int(batch["rick_v"][0].item())
            wavelet_phase = int(batch["rick_phase"][0].item())
            forward_operator = OverthrustForwardOperator(
                wavelet=dataset.wavelets[wavelet_freq][wavelet_phase],
                device=device,
            )
            guidance_kwargs.update(measurement=record, operator=forward_operator)
            print(f"[eval] batch {batch_idx}/{total_batches}: CLDM operator ready")

        print(f"[eval] batch {batch_idx}/{total_batches}: running pipeline...")
        output = pipe(
            dipin=dipin,
            record=record,
            image=image,
            num_inference_steps=num_inference_steps,
            seeds=seeds,
            **guidance_kwargs,
        )
        print(f"[eval] batch {batch_idx}/{total_batches}: collecting predictions...")
        samples = output.impedance_samples
        reconstructed = output.impedance_reconstructed
        for sample in range(samples.shape[0]):
            predicted_patches.append(_patch_to_numpy(samples, sample))
            target_patches.append(_patch_to_numpy(image, sample))
            reconstructed_patches.append(_patch_to_numpy(reconstructed, sample))

    print("[eval] stitching patches...")
    full_target = _stitch(target_patches)
    full_prediction = _stitch(predicted_patches)
    full_reconstruction = _stitch(reconstructed_patches)

    print("[eval] converting normalized predictions to impedance...")
    full_target_impedance, full_prediction_impedance, full_reconstruction_impedance = (
        dataset.fan(stitched)
        for stitched in (full_target, full_prediction, full_reconstruction)
    )

    print("[eval] computing metrics...")
    run_config = {
        **OVERTHRUST_CONFIG,
        "method": method,
        "num_inference_steps": num_inference_steps,
    }
    normalized_metrics = compute_metrics(full_prediction, full_target)
    impedance_metrics = compute_metrics(full_prediction_impedance, full_target_impedance)
    encode_metrics = compute_metrics(full_reconstruction_impedance, full_target_impedance)
    metrics_summary = {
        "config": run_config,
        "normalized": normalized_metrics,
        "impedance": impedance_metrics,
        "encode_impedance": encode_metrics,
    }

    paths = {
        "full_target": output_dir / "full_target.npy",
        "full_prediction": output_dir / "full_prediction.npy",
        "full_reconstruction": output_dir / "full_reconstruction.npy",
        "comparison": output_dir / "comparison_impedance.png",
        "metrics": output_dir / "metrics_summary.json",
    }
    print("[eval] saving outputs...")
    # The arrays saved here are the stitched, still-normalized images.
    for name, array in (
        ("full_target", full_target),
        ("full_prediction", full_prediction),
        ("full_reconstruction", full_reconstruction),
    ):
        np.save(paths[name], array)
    save_comparison(full_target_impedance, full_prediction_impedance, paths["comparison"])
    paths["metrics"].write_text(json.dumps(metrics_summary, indent=2), encoding="utf-8")
    print(f"[eval] done. metrics={paths['metrics']}")
    return {
        "metrics": metrics_summary,
        "paths": {name: str(location) for name, location in paths.items()},
    }
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour callers get when
            they pass nothing.

    Returns:
        Namespace with ``method`` (``"LDDPM"`` or ``"CLDM"``), ``model``,
        ``output``, ``device`` and ``num_inference_steps``.
    """
    parser = argparse.ArgumentParser(description="Evaluate SAII-LDDPM/CLDM on Overthrust.")
    parser.add_argument(
        "method",
        nargs="?",
        # Upper-case before the choices check so "cldm" / "lddpm" are accepted,
        # matching the case-insensitive handling in evaluate_overthrust.
        type=str.upper,
        choices=["LDDPM", "CLDM"],
        default="LDDPM",
        help="Pipeline to evaluate (case-insensitive, default: LDDPM).",
    )
    parser.add_argument(
        "--model",
        default=str(REPO_ROOT),
        help="Location passed to from_pretrained (default: repository root).",
    )
    parser.add_argument(
        "--output",
        default="outputs/overthrust",
        help="Directory where evaluation outputs are written.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Torch device; when omitted, CUDA is used if available, else CPU.",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=None,
        help="Number of inference steps; when omitted a per-method default is used.",
    )
    return parser.parse_args(argv)
|
|
|
|
def main() -> None:
    """Load the pipeline chosen on the command line, evaluate it, print the result.

    The pipeline class follows the parsed ``method`` (CLDM pipeline for
    ``"CLDM"``, LDDPM pipeline otherwise) and is loaded in float32 via
    ``from_pretrained``. The dict returned by ``evaluate_overthrust`` is
    printed as indented JSON.
    """
    args = parse_args()

    if args.method == "CLDM":
        pipe_cls = SeismicImpInvCLDMPipeline
    else:
        pipe_cls = SeismicImpInvLDDPMPipeline
    pipe = pipe_cls.from_pretrained(
        args.model,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

    eval_options = {
        "method": args.method,
        "output_dir": args.output,
        "num_inference_steps": args.num_inference_steps,
        "device": args.device,
    }
    result = evaluate_overthrust(pipe, **eval_options)
    print(json.dumps(result, indent=2))
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|