# saii-cldm-synthetic / codes/eval_overthrust.py
# Author: mally-2000
# Commit message: Add Overthrust evaluation progress logs
# Commit: e8c53ea (verified)
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
# Repository root: two levels above this file once symlinks are resolved.
# It is placed at the front of the import path (only if absent) so the
# `codes.*` imports below resolve when this file is executed as a script.
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(REPO_ROOT)]
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import pearsonr
from skimage.metrics import structural_similarity
from torch.utils.data import DataLoader
from codes.dataset import OverthrustTrueimpDataset
from codes.pipeline import SeismicImpInvCLDMPipeline, SeismicImpInvLDDPMPipeline
from codes.util import OverthrustForwardOperator
# Fixed settings for the Overthrust evaluation. These are passed to the
# dataset, the loader and the metrics summary by `evaluate_overthrust`.
OVERTHRUST_CONFIG = dict(
    size=256,  # side length of each square patch
    patch_indices=[0, 1, 2, 3, 4, 5],  # which patches the dataset builds
    noise_snr=15,  # forwarded as noise_snr / fixed_noise_snr (presumably dB — confirm)
    dipin_v=0.012,  # forwarded as dipins / fixed_dipin_v
    f0=30,  # wavelet frequency key (`ricks`, `dataset.wavelets[f0]`)
    f0_phase=0,  # wavelet phase key (`ricks_phase`)
    seed=1234,  # dataset base_seed
    zhengyan_type="nonlinear",  # forward-modelling ("zhengyan") variant name
    normalize="minmax",  # dataset normalisation scheme
    batch_size=3,  # DataLoader batch size
)
def stitch_patches(
    patches: list[np.ndarray], splits: list[tuple[int, int]], big_shape: tuple[int, int], img_size: int
) -> np.ndarray:
    """Reassemble square patches into one image, averaging overlapping pixels.

    Args:
        patches: Patch arrays; ``patches[i]`` is placed at ``splits[i]`` and
            must broadcast into an ``(img_size, img_size)`` window.
        splits: ``(row, col)`` top-left corner for every patch.
        big_shape: Shape of the reassembled image.
        img_size: Side length of each square patch.

    Returns:
        A float32 array of shape ``big_shape``. Each pixel is the mean of all
        patches covering it; pixels covered by no patch are 0.
    """
    total = np.zeros(big_shape, dtype=np.float32)
    coverage = np.zeros(big_shape, dtype=np.float32)
    for i, (row, col) in enumerate(splits):
        window = (slice(row, row + img_size), slice(col, col + img_size))
        total[window] += patches[i]
        coverage[window] += 1
    # Clamp the divisor to 1 so uncovered pixels stay 0 instead of 0/0 = NaN.
    return total / np.maximum(coverage, 1)
def compute_metrics(prediction: np.ndarray, target: np.ndarray) -> dict[str, float]:
    """Score ``prediction`` against the reference image ``target``.

    Args:
        prediction: Estimated 2-D image.
        target: Reference 2-D image with the same shape as ``prediction``.

    Returns:
        A dict with:
          * ``PSNR`` – peak signal-to-noise ratio in dB (``inf`` on an exact match).
          * ``rre``  – relative L2 error ``||prediction - target|| / ||target||``.
          * ``SSIM`` – structural similarity (skimage).
          * ``PCC``  – Pearson correlation coefficient of the flattened images.
          * ``nmse`` – squared error normalised by the target energy.
          * ``mse``  – mean squared error per element.
    """
    diff = prediction - target
    denom = np.linalg.norm(diff.ravel()) ** 2  # sum of squared errors
    # NOTE(review): the PSNR peak is taken from the *prediction* maximum
    # rather than the reference; kept as-is so numbers stay comparable with
    # earlier runs — confirm this convention is intended.
    psnr = float("inf") if denom == 0 else float(
        10.0 * np.log10(len(prediction.ravel()) * np.max(prediction.ravel()) ** 2 / denom)
    )
    # skimage defines `data_range` as (max - min) of the data. Using only
    # target.max() assumes a zero minimum, which inflates SSIM for offset
    # data such as impedance.
    ssim_range = float(target.max() - target.min())
    return {
        "PSNR": psnr,
        "rre": float(np.linalg.norm(diff.ravel()) / np.linalg.norm(target.ravel())),
        "SSIM": float(structural_similarity(target, prediction, data_range=ssim_range)),
        "PCC": float(pearsonr(prediction.ravel(), target.ravel()).statistic),
        "nmse": float(np.sum(diff ** 2) / np.sum(target ** 2)),
        # np.mean already divides by the element count; the previous extra
        # division by prediction.size reported MSE / N instead of MSE.
        "mse": float(np.mean(diff ** 2)),
    }
def save_comparison(
    target_impedance: np.ndarray,
    prediction_impedance: np.ndarray,
    output_path: Path,
) -> None:
    """Save a three-panel figure: target, prediction and absolute error.

    Target and prediction share one "jet" colour scale spanning both arrays,
    so they are directly comparable. The error panel uses "hot", scaled from
    0 to the largest absolute error. Each panel has its own colorbar.

    Args:
        target_impedance: Reference image.
        prediction_impedance: Predicted image, same shape as the target.
        output_path: Destination file, written at 150 dpi.
    """
    error = np.abs(target_impedance - prediction_impedance)
    shared_scale = {
        "cmap": "jet",
        "vmin": min(target_impedance.min(), prediction_impedance.min()),
        "vmax": max(target_impedance.max(), prediction_impedance.max()),
    }
    panels = [
        (target_impedance, "Target (Impedance)", shared_scale),
        (prediction_impedance, "Prediction (Impedance)", shared_scale),
        (error, "Error (Impedance)", {"cmap": "hot", "vmin": 0, "vmax": error.max()}),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, (arr, title, display_kwargs) in zip(axes, panels):
            im = ax.imshow(arr, **display_kwargs)
            ax.set_title(title)
            ax.axis("off")
            plt.colorbar(im, ax=ax, fraction=0.046)
        plt.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even when plotting or saving fails, so
        # repeated calls do not accumulate open pyplot figures.
        plt.close(fig)
def _build_overthrust_dataset(output_dir: Path) -> OverthrustTrueimpDataset:
    """Build the Overthrust dataset from OVERTHRUST_CONFIG and log its size."""
    print("[eval] building Overthrust dataset...")
    dataset = OverthrustTrueimpDataset(
        size=OVERTHRUST_CONFIG["size"],
        normalize=OVERTHRUST_CONFIG["normalize"],
        zhengyan_type=OVERTHRUST_CONFIG["zhengyan_type"],
        ricks=[OVERTHRUST_CONFIG["f0"]],
        ricks_phase=[OVERTHRUST_CONFIG["f0_phase"]],
        noise_snr=[OVERTHRUST_CONFIG["noise_snr"]],
        dipins=[OVERTHRUST_CONFIG["dipin_v"]],
        record_noraml=True,
        train_keys=["image", "dipin", "record"],
        patch_indices=OVERTHRUST_CONFIG["patch_indices"],
        base_seed=OVERTHRUST_CONFIG["seed"],
        data_dir=REPO_ROOT / "data",
        cache_dir=output_dir / "cache",
        fixed_f0=OVERTHRUST_CONFIG["f0"],
        fixed_dipin_v=OVERTHRUST_CONFIG["dipin_v"],
        fixed_noise_snr=OVERTHRUST_CONFIG["noise_snr"],
        fixed_f0_phase=OVERTHRUST_CONFIG["f0_phase"],
    )
    print(
        "[eval] dataset ready: "
        f"patches={len(dataset)}, batch_size={OVERTHRUST_CONFIG['batch_size']}, "
        f"patch_indices={OVERTHRUST_CONFIG['patch_indices']}"
    )
    return dataset


def _run_overthrust_inference(
    pipe: SeismicImpInvLDDPMPipeline | SeismicImpInvCLDMPipeline,
    dataset: OverthrustTrueimpDataset,
    method: str,
    num_inference_steps: int,
    device: torch.device,
) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray]]:
    """Run ``pipe`` over every batch; return per-patch (predictions, targets, reconstructions)."""
    loader = DataLoader(
        dataset,
        batch_size=OVERTHRUST_CONFIG["batch_size"],
        shuffle=False,  # keep dataset order so patches line up with dataset.splits
        num_workers=0,
    )
    all_predictions: list[np.ndarray] = []
    all_targets: list[np.ndarray] = []
    all_reconstructions: list[np.ndarray] = []
    total_batches = len(loader)
    for batch_idx, batch in enumerate(loader, start=1):
        seeds = batch["seed"].tolist()
        batch_size = len(seeds)
        print(
            f"[eval] batch {batch_idx}/{total_batches}: "
            f"batch_size={batch_size}, seeds={seeds}"
        )
        dipin = batch["dipin"].to(device)
        record = batch["record"].to(device)
        image = batch["image"].to(device)
        extra_kwargs = {}
        if method == "CLDM":
            # The wavelet key is read from the first sample of the batch;
            # the config fixes a single f0 / phase for the whole dataset.
            f0 = int(batch["rick_v"][0].item())
            f0_phase = int(batch["rick_phase"][0].item())
            extra_kwargs = {
                "measurement": record,
                "operator": OverthrustForwardOperator(
                    wavelet=dataset.wavelets[f0][f0_phase],
                    device=device,
                ),
            }
            print(f"[eval] batch {batch_idx}/{total_batches}: CLDM operator ready")
        print(f"[eval] batch {batch_idx}/{total_batches}: running pipeline...")
        output = pipe(
            dipin=dipin,
            record=record,
            image=image,
            num_inference_steps=num_inference_steps,
            seeds=seeds,
            **extra_kwargs,
        )
        print(f"[eval] batch {batch_idx}/{total_batches}: collecting predictions...")
        prediction = output.impedance_samples
        reconstruction = output.impedance_reconstructed
        # Keep channel 0 of every sample as a NumPy array.
        for local_idx in range(prediction.shape[0]):
            all_predictions.append(prediction[local_idx, 0].detach().cpu().numpy())
            all_targets.append(image[local_idx, 0].detach().cpu().numpy())
            all_reconstructions.append(reconstruction[local_idx, 0].detach().cpu().numpy())
    return all_predictions, all_targets, all_reconstructions


def evaluate_overthrust(
    pipe: SeismicImpInvLDDPMPipeline | SeismicImpInvCLDMPipeline,
    method: str = "LDDPM",
    output_dir: str | Path = "outputs/overthrust",
    num_inference_steps: int | None = None,
    device: str | torch.device | None = None,
) -> dict[str, object]:
    """Evaluate a pretrained pipeline on the Overthrust patches and save results.

    The patches are run through ``pipe`` batch by batch. The per-patch
    outputs are stitched into full images, and metrics are computed in
    normalised and impedance space. Arrays, a comparison figure and a JSON
    metrics summary are then written to ``output_dir``.

    Args:
        pipe: Pretrained LDDPM or CLDM pipeline. It is moved to ``device``
            and its output must expose ``impedance_samples`` and
            ``impedance_reconstructed``.
        method: ``"LDDPM"`` or ``"CLDM"`` (case-insensitive). CLDM
            additionally passes ``measurement`` and a forward ``operator``.
        output_dir: Output directory, created if missing. The dataset
            cache lives under ``output_dir / "cache"``.
        num_inference_steps: Sampling steps; defaults to 30 for CLDM and
            1000 for LDDPM.
        device: Torch device; defaults to CUDA when available, else CPU.

    Returns:
        ``{"metrics": <summary dict>, "paths": {name: str path}}``, where the
        summary holds the config plus ``normalized``, ``impedance`` and
        ``encode_impedance`` metric dicts.

    Raises:
        ValueError: If ``method`` is neither LDDPM nor CLDM.
    """
    method = method.upper()
    if method not in {"LDDPM", "CLDM"}:
        raise ValueError("method must be LDDPM or CLDM")
    if num_inference_steps is None:
        num_inference_steps = 30 if method == "CLDM" else 1000
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"[eval] method={method}, steps={num_inference_steps}, device={device}")
    print(f"[eval] output_dir={output_dir}")
    print("[eval] moving pipeline to device...")
    pipe = pipe.to(device)

    dataset = _build_overthrust_dataset(output_dir)
    all_predictions, all_targets, all_reconstructions = _run_overthrust_inference(
        pipe, dataset, method, num_inference_steps, device
    )

    print("[eval] stitching patches...")
    full_target = stitch_patches(
        all_targets, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
    )
    full_prediction = stitch_patches(
        all_predictions, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
    )
    full_reconstruction = stitch_patches(
        all_reconstructions, dataset.splits, dataset.big_img.shape, OVERTHRUST_CONFIG["size"]
    )

    print("[eval] converting normalized predictions to impedance...")
    full_target_impedance = dataset.fan(full_target)
    full_prediction_impedance = dataset.fan(full_prediction)
    full_reconstruction_impedance = dataset.fan(full_reconstruction)

    print("[eval] computing metrics...")
    metrics_summary = {
        "config": {
            **OVERTHRUST_CONFIG,
            "method": method,
            "num_inference_steps": num_inference_steps,
        },
        "normalized": compute_metrics(full_prediction, full_target),
        "impedance": compute_metrics(full_prediction_impedance, full_target_impedance),
        # Metrics for the reconstruction output from the pipeline.
        "encode_impedance": compute_metrics(
            full_reconstruction_impedance, full_target_impedance
        ),
    }

    paths = {
        "full_target": output_dir / "full_target.npy",
        "full_prediction": output_dir / "full_prediction.npy",
        "full_reconstruction": output_dir / "full_reconstruction.npy",
        "comparison": output_dir / "comparison_impedance.png",
        "metrics": output_dir / "metrics_summary.json",
    }
    print("[eval] saving outputs...")
    # The .npy files hold the stitched *normalised* arrays; the figure and
    # metrics also cover the impedance-space versions.
    np.save(paths["full_target"], full_target)
    np.save(paths["full_prediction"], full_prediction)
    np.save(paths["full_reconstruction"], full_reconstruction)
    save_comparison(full_target_impedance, full_prediction_impedance, paths["comparison"])
    paths["metrics"].write_text(json.dumps(metrics_summary, indent=2), encoding="utf-8")
    print(f"[eval] done. metrics={paths['metrics']}")
    return {
        "metrics": metrics_summary,
        "paths": {key: str(value) for key, value in paths.items()},
    }
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the Overthrust evaluation script.

    The positional ``method`` is matched case-insensitively and normalised to
    upper case, so ``cldm`` and ``CLDM`` are both accepted. This is consistent
    with ``evaluate_overthrust``, which upper-cases the method itself.
    """
    parser = argparse.ArgumentParser(description="Evaluate SAII-LDDPM/CLDM on Overthrust.")
    parser.add_argument(
        "method",
        nargs="?",
        # Applied before the `choices` check, and also to the string default.
        type=str.upper,
        choices=["LDDPM", "CLDM"],
        default="LDDPM",
        help="Pipeline to evaluate (case-insensitive; default: LDDPM).",
    )
    parser.add_argument(
        "--model",
        default=str(REPO_ROOT),
        help="Path or identifier passed to from_pretrained (default: repository root).",
    )
    parser.add_argument(
        "--output",
        default="outputs/overthrust",
        help="Directory for arrays, figure and metrics (default: outputs/overthrust).",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Torch device; defaults to CUDA when available, otherwise CPU.",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=None,
        help="Sampling steps; defaults to 30 for CLDM and 1000 for LDDPM.",
    )
    return parser.parse_args()
def main() -> None:
    """CLI entry point: load the selected pipeline, evaluate it, print the result JSON."""
    cli = parse_args()
    if cli.method == "CLDM":
        pipeline_cls = SeismicImpInvCLDMPipeline
    else:
        pipeline_cls = SeismicImpInvLDDPMPipeline
    pipeline = pipeline_cls.from_pretrained(
        cli.model,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    summary = evaluate_overthrust(
        pipeline,
        method=cli.method,
        output_dir=cli.output,
        num_inference_steps=cli.num_inference_steps,
        device=cli.device,
    )
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()