| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Any, Callable |
|
|
| import numpy as np |
| import torch |
| from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DModel, VQModel |
| from diffusers.utils import BaseOutput |
|
|
|
|
@dataclass
class SeismicImpInvLDDPMPipelineOutput(BaseOutput):
    """Output container shared by the SAII-LDDPM and SAII-CLDM pipelines.

    Every field is a ``torch.Tensor`` when the pipeline is called with
    ``output_type="tensor"`` and a ``np.ndarray`` when called with
    ``output_type="np"``.
    """

    # Impedance images decoded from the final latents by the VQ model.
    impedance_samples: torch.Tensor | np.ndarray
    # Final latent codes at the end of the reverse diffusion process.
    impedance_latents: torch.Tensor | np.ndarray
    # VQ-encoded latents of the `dipin` conditioning input (cast to UNet dtype).
    impedance_dipin: torch.Tensor | np.ndarray
    # VQ encode->decode round trip of the optional reference `image`; None when
    # no `image` was supplied to the pipeline call.
    impedance_reconstructed: torch.Tensor | np.ndarray | None = None
    # Features produced by the condition encoder from the seismic `record`.
    record_features: torch.Tensor | np.ndarray | None = None
|
|
|
|
class SeismicImpInvLDDPMPipeline(DiffusionPipeline):
    """SAII-LDDPM impedance inversion pipeline.

    Runs a DDPM reverse process in the VQ latent space of ``vq_model``,
    conditioned on a low-frequency impedance prior (``dipin``, VQ-encoded)
    and a seismic record (passed through ``condition_encoder``). The
    conditioning latents are concatenated to the noisy latents along the
    channel axis before each UNet call.
    """

    def __init__(
        self,
        vq_model: VQModel,
        condition_encoder: torch.nn.Module,
        unet: UNet2DModel,
        scheduler: DDPMScheduler,
    ):
        """Register the trained components with the diffusers pipeline base.

        Args:
            vq_model: VQ autoencoder mapping impedance images to/from latents.
            condition_encoder: Network encoding the seismic record into
                conditioning features.
            unet: Noise-prediction UNet run on ``[latents, conditioning]``.
            scheduler: DDPM scheduler whose betas define the noise schedule.
        """
        super().__init__()
        self.register_modules(
            vq_model=vq_model,
            condition_encoder=condition_encoder,
            unet=unet,
            scheduler=scheduler,
        )

    def _encode_conditioning(
        self, dipin: torch.Tensor, record: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode both conditioning inputs and cast them to the UNet dtype.

        Returns:
            ``(dipin_latents, record_features)``, both in ``self.unet.dtype``.
        """
        dipin_latents = self.vq_model.encode(dipin).latents
        # Prefer a dedicated `encode` method when the condition encoder has
        # one; otherwise fall back to plain __call__.
        if hasattr(self.condition_encoder, "encode") and callable(
            self.condition_encoder.encode
        ):
            record_features = self.condition_encoder.encode(record)
        else:
            record_features = self.condition_encoder(record)
        return (
            dipin_latents.to(dtype=self.unet.dtype),
            record_features.to(dtype=self.unet.dtype),
        )

    @staticmethod
    def _extract_into_tensor(
        arr: torch.Tensor, timesteps: torch.Tensor, broadcast_shape: torch.Size
    ) -> torch.Tensor:
        """Gather per-timestep scalars from ``arr`` and reshape for broadcast.

        Returns a float32 tensor of shape ``(batch, 1, ..., 1)`` with enough
        trailing singleton dims to broadcast against ``broadcast_shape``.
        """
        values = arr.to(device=timesteps.device, dtype=torch.float32).gather(0, timesteps)
        return values.reshape(timesteps.shape[0], *((1,) * (len(broadcast_shape) - 1)))

    @staticmethod
    def _build_legacy_ddpm_buffers(
        scheduler: DDPMScheduler, device: torch.device
    ) -> dict[str, torch.Tensor]:
        """Precompute classic DDPM posterior coefficients from the betas.

        These are the q(x_{t-1} | x_t, x_0) quantities of Ho et al. (2020):
        the ``sqrt_recip*`` terms recover x_0 from a noise prediction, and
        the ``posterior_*`` terms form the posterior mean and log-variance.
        All buffers are computed in float32 on ``device``.
        """
        betas = scheduler.betas.to(device=device, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha-bar_{t-1}, with alpha-bar at t=-1 defined as 1.
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=device), alphas_cumprod[:-1]], dim=0
        )
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # Clamp before log: the posterior variance is exactly 0 at t=0.
        posterior_log_variance_clipped = torch.log(
            torch.clamp(posterior_variance, min=1e-20)
        )
        return {
            "sqrt_recip_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod),
            "sqrt_recipm1_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod - 1),
            "posterior_mean_coef1": betas
            * torch.sqrt(alphas_cumprod_prev)
            / (1.0 - alphas_cumprod),
            "posterior_mean_coef2": (1.0 - alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - alphas_cumprod),
            "posterior_log_variance_clipped": posterior_log_variance_clipped,
        }

    @staticmethod
    def _randn_like_sample(
        sample: torch.Tensor, generator: torch.Generator | list[torch.Generator] | None
    ) -> torch.Tensor:
        """Draw Gaussian noise with the same shape/device/dtype as ``sample``.

        A list of generators draws one independent slice per batch element
        (per-sample reproducible seeding); a single generator — or ``None``
        for torch's default RNG state — draws the whole batch at once.

        Raises:
            ValueError: If a generator list's length differs from the batch size.
        """
        if isinstance(generator, list):
            if len(generator) != sample.shape[0]:
                raise ValueError(
                    f"Expected {sample.shape[0]} generators, got {len(generator)}"
                )
            return torch.cat(
                [
                    torch.randn(
                        sample[i : i + 1].shape,
                        generator=sample_generator,
                        device=sample.device,
                        dtype=sample.dtype,
                    )
                    for i, sample_generator in enumerate(generator)
                ],
                dim=0,
            )
        return torch.randn(
            sample.shape, generator=generator, device=sample.device, dtype=sample.dtype
        )

    def _ddpm_step(
        self,
        latents: torch.Tensor,
        conditioning: torch.Tensor,
        timestep: torch.Tensor,
        generator: torch.Generator | list[torch.Generator] | None,
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """One DDPM ancestral sampling step x_t -> x_{t-1}.

        The UNet sees latents concatenated with the conditioning along the
        channel axis and predicts the noise epsilon. The predicted x_0 is
        snapped to the VQ codebook before forming the posterior mean.

        Args:
            latents: Current noisy latents x_t, shape ``(B, C, H, W)``.
            conditioning: Channel-stacked conditioning latents/features.
            timestep: Per-sample integer timesteps, shape ``(B,)``.
            generator: RNG source(s) for the added posterior noise.
            buffers: Coefficients from :meth:`_build_legacy_ddpm_buffers`.
        """
        model_input = torch.cat([latents, conditioning], dim=1)
        noise_pred = self.unet(model_input, timestep).sample
        # x0-hat = sqrt(1/alpha-bar_t) * x_t - sqrt(1/alpha-bar_t - 1) * eps.
        pred_x0 = (
            self._extract_into_tensor(
                buffers["sqrt_recip_alphas_cumprod"], timestep, latents.shape
            )
            * latents
            - self._extract_into_tensor(
                buffers["sqrt_recipm1_alphas_cumprod"], timestep, latents.shape
            )
            * noise_pred
        )
        # Project the x0 estimate onto the VQ codebook; quantize() returns a
        # tuple whose first element is the quantized latent.
        pred_x0 = self.vq_model.quantize(pred_x0)[0]
        # Posterior mean of q(x_{t-1} | x_t, x0-hat).
        model_mean = (
            self._extract_into_tensor(
                buffers["posterior_mean_coef1"], timestep, latents.shape
            )
            * pred_x0
            + self._extract_into_tensor(
                buffers["posterior_mean_coef2"], timestep, latents.shape
            )
            * latents
        )
        noise = self._randn_like_sample(latents, generator)
        # Mask out the noise term at t == 0 (the final, deterministic step).
        nonzero_mask = (1 - (timestep == 0).float()).reshape(
            latents.shape[0], *((1,) * (len(latents.shape) - 1))
        )
        # Add sigma_t * noise, with sigma_t = exp(0.5 * log-variance).
        return model_mean + nonzero_mask * (
            0.5
            * self._extract_into_tensor(
                buffers["posterior_log_variance_clipped"], timestep, latents.shape
            )
        ).exp() * noise

    @torch.no_grad()
    def __call__(
        self,
        dipin: torch.Tensor,
        record: torch.Tensor,
        image: torch.Tensor | None = None,
        num_inference_steps: int = 1000,
        seed: int | None = None,
        seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
        generator: torch.Generator | None = None,
        output_type: str = "tensor",
    ) -> SeismicImpInvLDDPMPipelineOutput:
        """Run the full reverse DDPM sampling loop.

        Args:
            dipin: Low-frequency impedance prior image batch.
            record: Seismic record batch for conditioning.
            image: Optional reference impedance image; when given, its VQ
                encode->decode reconstruction is returned alongside samples.
            num_inference_steps: Number of reverse steps (iterated from
                ``num_inference_steps - 1`` down to 0).
            seed: Single seed for a fresh generator (used when `seeds` is None).
            seeds: Per-sample seeds; overrides `seed`/`generator` and must
                match the batch size.
            generator: Pre-built generator; a fresh unseeded one is created
                if neither seeds nor seed nor generator is supplied.
            output_type: "tensor" (default) or "np" for NumPy outputs.

        Raises:
            ValueError: If ``len(seeds)`` does not equal the batch size.
        """
        device = self.unet.device
        # Seeding priority: per-sample `seeds` > single `seed` > caller's
        # `generator` > a fresh unseeded generator.
        if seeds is not None:
            if isinstance(seeds, torch.Tensor):
                seeds = seeds.detach().cpu().tolist()
            seeds = [int(value) for value in seeds]
            if len(seeds) != dipin.shape[0]:
                raise ValueError(f"Expected {dipin.shape[0]} seeds, got {len(seeds)}")
            generator = [
                torch.Generator(device=device).manual_seed(value) for value in seeds
            ]
        elif seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)
        elif generator is None:
            generator = torch.Generator(device=device)

        dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
        record = record.to(device=device, dtype=self.unet.dtype)
        impedance_dipin, record_features = self._encode_conditioning(dipin, record)
        conditioning = torch.cat([impedance_dipin, record_features], dim=1)
        # Initial latents: pure Gaussian noise shaped like the dipin latents.
        impedance_latents = self._randn_like_sample(
            torch.empty(
                impedance_dipin.shape,
                device=device,
                dtype=self.unet.dtype,
            ),
            generator,
        )
        buffers = self._build_legacy_ddpm_buffers(self.scheduler, device)
        for t in reversed(range(num_inference_steps)):
            timestep = torch.full(
                (impedance_latents.shape[0],), t, device=device, dtype=torch.long
            )
            impedance_latents = self._ddpm_step(
                impedance_latents, conditioning, timestep, generator, buffers
            )

        impedance_samples = self.vq_model.decode(
            impedance_latents.to(dtype=self.vq_model.dtype)
        ).sample
        impedance_reconstructed = None
        if image is not None:
            # VQ round trip of the reference image, for comparison with samples.
            image = image.to(device=device, dtype=self.vq_model.dtype)
            image_latents = self.vq_model.encode(image).latents
            impedance_reconstructed = self.vq_model.decode(image_latents).sample

        if output_type == "np":
            impedance_samples = impedance_samples.detach().cpu().numpy()
            impedance_latents = impedance_latents.detach().cpu().numpy()
            impedance_dipin = impedance_dipin.detach().cpu().numpy()
            record_features = record_features.detach().cpu().numpy()
            if impedance_reconstructed is not None:
                impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()

        return SeismicImpInvLDDPMPipelineOutput(
            impedance_samples=impedance_samples,
            impedance_latents=impedance_latents,
            impedance_dipin=impedance_dipin,
            impedance_reconstructed=impedance_reconstructed,
            record_features=record_features,
        )

    @torch.no_grad()
    def encode_decode(
        self, image: torch.Tensor, output_type: str = "tensor"
    ) -> torch.Tensor | np.ndarray:
        """VQ encode->decode round trip of ``image`` (reconstruction quality probe)."""
        image = image.to(device=self.vq_model.device, dtype=self.vq_model.dtype)
        reconstruction = self.vq_model.decode(self.vq_model.encode(image).latents).sample
        if output_type == "np":
            return reconstruction.detach().cpu().numpy()
        return reconstruction
|
|
|
|
class SeismicImpInvCLDMPipeline(SeismicImpInvLDDPMPipeline):
    """SAII-CLDM inference pipeline.

    This reuses the same trained components as SAII-LDDPM and replaces only the
    reverse sampling procedure with DDIM plus model-driven resampling: at
    selected steps, the current x0 estimate is decoded to pixel space,
    optimized against a measurement through a forward `operator`, re-encoded,
    and stochastically resampled back onto the diffusion trajectory.
    """

    @staticmethod
    def _get_operator_fn(operator: Any) -> Callable[[torch.Tensor], torch.Tensor]:
        """Normalize ``operator`` into a plain callable forward model.

        Raises:
            TypeError: If `operator` is neither callable nor has a callable
                ``forward`` attribute.
        """
        if callable(operator):
            return operator
        if hasattr(operator, "forward") and callable(operator.forward):
            return operator.forward
        raise TypeError("`operator` must be callable or expose a callable `forward` method.")

    @staticmethod
    def _build_ddim_scheduler(
        scheduler: DDPMScheduler,
        num_inference_steps: int,
        device: torch.device,
    ) -> DDIMScheduler:
        """Build a DDIM scheduler from the trained DDPM scheduler's config.

        Overrides clip/alpha/offset/spacing options so the DDIM alphas line
        up with the original training schedule, then sets the inference
        timesteps on ``device``.
        """
        ddim_scheduler = DDIMScheduler.from_config(
            scheduler.config,
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
            timestep_spacing="leading",
        )
        ddim_scheduler.set_timesteps(num_inference_steps, device=device)
        return ddim_scheduler

    @staticmethod
    def _default_pixel_optimization_param() -> dict[str, float | int]:
        """Defaults for the intermediate pixel-space optimization steps.

        Keys: `eps` early-stop loss threshold, `max_iters` AdamW iterations,
        `lr` learning rate, `y_coef`/`x_coef` weights on the measurement and
        anchor-to-initial losses, `tv_coef` total-variation weight (0 disables
        it), `dh_coef`/`dw_coef` per-axis TV difference weights.
        """
        return {
            "eps": 1e-4,
            "max_iters": 100,
            "lr": 1e-5,
            "y_coef": 1.0,
            "x_coef": 0.0,
            "tv_coef": 0.0,
            "dh_coef": 1.0,
            "dw_coef": 1.5,
        }

    @staticmethod
    def _default_last_pixel_optimization_param() -> dict[str, float | int]:
        """Defaults for the final (index == 0) pixel optimization.

        Compared with the intermediate defaults: a single iteration, a larger
        learning rate, and a nonzero anchor weight (`x_coef`) so the result
        stays close to the diffusion output.
        """
        return {
            "eps": 1e-4,
            "max_iters": 1,
            "lr": 1e-4,
            "y_coef": 1.0,
            "x_coef": 0.1,
            "tv_coef": 0.0,
            "dh_coef": 1.0,
            "dw_coef": 1.5,
        }

    @staticmethod
    def _tv_loss(x: torch.Tensor, *, dh_coef: float, dw_coef: float) -> torch.Tensor:
        """Weighted anisotropic total-variation penalty on the last two axes.

        `dh` differences along the last axis (weight `dh_coef`), `dw` along the
        second-to-last (weight `dw_coef`); both maps are cropped to a common
        shape before averaging.
        """
        dh = dh_coef * torch.abs(x[..., :, 1:] - x[..., :, :-1])
        dw = dw_coef * torch.abs(x[..., 1:, :] - x[..., :-1, :])
        return torch.mean(dh[..., :-1, :] + dw[..., :, :-1])

    def _ddim_step(
        self,
        latents: torch.Tensor,
        conditioning: torch.Tensor,
        timestep: int,
        scheduler: DDIMScheduler,
        eta: float,
        generator: torch.Generator | list[torch.Generator] | None,
        quantize_denoised: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
        """One conditional DDIM step; also returns x0 estimates and alphas.

        Returns:
            ``(prev_sample, pred_x0, pseudo_x0, stats)`` where `stats` holds
            the per-batch broadcastable alpha-bar values ``a_t`` and
            ``a_prev`` used by the stochastic resampling.
        """
        model_input = torch.cat(
            [
                scheduler.scale_model_input(latents, timestep),
                conditioning.to(dtype=latents.dtype),
            ],
            dim=1,
        )
        timestep_tensor = torch.full(
            (latents.shape[0],), timestep, device=latents.device, dtype=torch.long
        )
        noise_pred = self.unet(model_input, timestep_tensor).sample

        alpha_t = scheduler.alphas_cumprod[timestep].to(
            device=latents.device, dtype=latents.dtype
        )
        # Previous timestep under DDIM's uniform stride through the training
        # schedule; negative means we are at the final step.
        prev_timestep = timestep - (
            scheduler.config.num_train_timesteps // scheduler.num_inference_steps
        )
        if prev_timestep >= 0:
            alpha_prev = scheduler.alphas_cumprod[prev_timestep].to(
                device=latents.device, dtype=latents.dtype
            )
        else:
            alpha_prev = scheduler.final_alpha_cumprod.to(
                device=latents.device, dtype=latents.dtype
            )

        beta_t = 1.0 - alpha_t
        # Standard epsilon->x0 inversion: (x_t - sqrt(1-a_t)*eps) / sqrt(a_t).
        pred_x0 = (latents - beta_t.sqrt() * noise_pred) / alpha_t.sqrt()
        # NOTE(review): uses (1 - a_t) rather than its sqrt, so this is NOT
        # the standard x0 inversion — presumably an intentional "pseudo"
        # estimate for the resampling path; confirm against the reference
        # implementation.
        pseudo_x0 = (latents - beta_t * noise_pred) / alpha_t.sqrt()
        if quantize_denoised:
            # Snap the x0 estimate to the VQ codebook, then re-derive the
            # noise prediction consistent with the quantized x0.
            pred_x0 = self.vq_model.quantize(pred_x0.to(dtype=self.vq_model.dtype))[0].to(
                dtype=latents.dtype
            )
            noise_pred = (latents - alpha_t.sqrt() * pred_x0) / beta_t.sqrt()

        # DDIM variance (private diffusers API); eta scales the stochasticity.
        variance = scheduler._get_variance(timestep, prev_timestep).to(
            device=latents.device, dtype=latents.dtype
        )
        sigma_t = eta * variance.sqrt()
        # "Direction pointing to x_t" term; clamped to avoid sqrt of a
        # slightly negative value from float round-off.
        direction = torch.clamp(1.0 - alpha_prev - sigma_t**2, min=0.0).sqrt() * noise_pred
        noise = torch.zeros_like(latents)
        if eta > 0:
            noise = sigma_t * self._randn_like_sample(latents, generator)
        prev_sample = alpha_prev.sqrt() * pred_x0 + direction + noise

        batch_shape = (latents.shape[0], 1, 1, 1)
        return (
            prev_sample,
            pred_x0,
            pseudo_x0,
            {
                "a_t": torch.full(
                    batch_shape,
                    float(alpha_t.item()),
                    device=latents.device,
                    dtype=latents.dtype,
                ),
                "a_prev": torch.full(
                    batch_shape,
                    float(alpha_prev.item()),
                    device=latents.device,
                    dtype=latents.dtype,
                ),
            },
        )

    def _optimize_pixels(
        self,
        x_prime: torch.Tensor,
        measurement: torch.Tensor,
        operator_fn: Callable[[torch.Tensor], torch.Tensor],
        params: dict[str, Any],
    ) -> torch.Tensor:
        """Refine a pixel-space x0 estimate against the measurement.

        Minimizes ``y_coef * MSE(measurement, operator(x)) +
        x_coef * MSE(x_init, x) [+ tv_coef * TV(x)]`` with AdamW, starting
        from ``x_prime``. Stops early when the loss drops below `eps`.

        Args:
            params: Overrides merged over :meth:`_default_pixel_optimization_param`.
        """
        merged = {**self._default_pixel_optimization_param(), **params}
        if int(merged["max_iters"]) <= 0:
            # Optimization disabled: return the input unchanged (detached).
            return x_prime.detach()

        loss_fn = torch.nn.MSELoss(reduction="mean")
        opt_var = x_prime.detach().clone().requires_grad_(True)
        opt_init = x_prime.detach().clone()
        optimizer = torch.optim.AdamW([opt_var], lr=float(merged["lr"]))

        for _ in range(int(merged["max_iters"])):
            optimizer.zero_grad(set_to_none=True)
            measurement_loss = (
                loss_fn(measurement, operator_fn(opt_var)) * float(merged["y_coef"])
                + loss_fn(opt_init, opt_var) * float(merged["x_coef"])
            )
            if float(merged["tv_coef"]) != 0.0:
                measurement_loss = measurement_loss + float(merged["tv_coef"]) * self._tv_loss(
                    opt_var,
                    dh_coef=float(merged["dh_coef"]),
                    dw_coef=float(merged["dw_coef"]),
                )
            measurement_loss.backward()
            optimizer.step()
            # Early stop once the (pre-step) loss is below the threshold.
            if float(measurement_loss.detach().cpu().item()) < float(merged["eps"]):
                break

        return opt_var.detach()

    def _stochastic_resample(
        self,
        pseudo_x0: torch.Tensor,
        x_t: torch.Tensor,
        a_t: torch.Tensor,
        sigma: torch.Tensor,
        generator: torch.Generator | list[torch.Generator] | None,
    ) -> torch.Tensor:
        """Resample a latent consistent with both an x0 estimate and x_t.

        Computes a precision-weighted combination of ``sqrt(a_t) * pseudo_x0``
        (weight ``sigma``) and ``x_t`` (weight ``1 - a_t``), plus Gaussian
        noise with the combined variance ``1 / (1/sigma + 1/(1 - a_t))``.
        Both weights are clamped away from zero for numerical stability.
        """
        sigma = torch.clamp(sigma, min=1e-12)
        one_minus_a_t = torch.clamp(1.0 - a_t, min=1e-12)
        noise = self._randn_like_sample(pseudo_x0, generator)
        return (
            (sigma * a_t.sqrt() * pseudo_x0 + one_minus_a_t * x_t)
            / (sigma + one_minus_a_t)
            + noise * torch.sqrt(1.0 / (1.0 / sigma + 1.0 / one_minus_a_t))
        )

    def __call__(
        self,
        dipin: torch.Tensor,
        record: torch.Tensor,
        measurement: torch.Tensor | None = None,
        operator: Any | None = None,
        image: torch.Tensor | None = None,
        num_inference_steps: int = 30,
        seed: int | None = None,
        seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
        generator: torch.Generator | None = None,
        eta: float = 0.01,
        interval: int = 6,
        sigma_a: float = 20.0,
        pixel_optimization_param: dict[str, Any] | None = None,
        last_pixel_optimization_param: dict[str, Any] | None = None,
        quantize_denoised: bool = False,
        output_type: str = "tensor",
    ) -> SeismicImpInvLDDPMPipelineOutput:
        """Run DDIM sampling with measurement-guided pixel resampling.

        Args:
            dipin: Low-frequency impedance prior image batch.
            record: Seismic record batch for conditioning.
            measurement: Observed data for the data-consistency optimization;
                defaults to `record` when None.
            operator: Forward model mapping impedance pixels to measurement
                space (callable or object with ``forward``). Required.
            image: Optional reference image, returned as a VQ round trip.
            num_inference_steps: Number of DDIM steps.
            seed / seeds / generator: Same seeding priority as SAII-LDDPM.
            eta: DDIM stochasticity (0 = deterministic DDIM).
            interval: Resampling happens every `interval` indices (and always
                at the final index 0).
            sigma_a: Scale of the resampling variance schedule.
            pixel_optimization_param: Overrides for intermediate pixel
                optimizations (merged over the defaults).
            last_pixel_optimization_param: Params for the final optimization;
                defaults to :meth:`_default_last_pixel_optimization_param`.
            quantize_denoised: Snap x0 estimates to the VQ codebook per step.
            output_type: "tensor" (default) or "np".

        Raises:
            ValueError: If `operator` is None, `interval` is not positive, or
                ``len(seeds)`` mismatches the batch size.
        """
        if measurement is None:
            measurement = record
        if operator is None:
            raise ValueError("SAII-CLDM requires a forward `operator`.")
        if interval <= 0:
            raise ValueError("`interval` must be a positive integer.")

        device = self.unet.device
        # Seeding priority: per-sample `seeds` > single `seed` > caller's
        # `generator` > a fresh unseeded generator (same as SAII-LDDPM).
        if seeds is not None:
            if isinstance(seeds, torch.Tensor):
                seeds = seeds.detach().cpu().tolist()
            seeds = [int(value) for value in seeds]
            if len(seeds) != dipin.shape[0]:
                raise ValueError(f"Expected {dipin.shape[0]} seeds, got {len(seeds)}")
            generator = [
                torch.Generator(device=device).manual_seed(value) for value in seeds
            ]
        elif seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)
        elif generator is None:
            generator = torch.Generator(device=device)

        # No @torch.no_grad() on this method: the pixel optimization needs
        # autograd, so gradient-free regions are scoped explicitly.
        with torch.no_grad():
            dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
            record = record.to(device=device, dtype=self.unet.dtype)
            measurement = measurement.to(device=device, dtype=self.unet.dtype)
            impedance_dipin, record_features = self._encode_conditioning(dipin, record)
            conditioning = torch.cat([impedance_dipin, record_features], dim=1)
            impedance_latents = self._randn_like_sample(
                torch.empty(
                    impedance_dipin.shape,
                    device=device,
                    dtype=self.unet.dtype,
                ),
                generator,
            )

        operator_fn = self._get_operator_fn(operator)
        pixel_params = pixel_optimization_param or {}
        last_pixel_params = last_pixel_optimization_param or self._default_last_pixel_optimization_param()
        schedule = self._build_ddim_scheduler(self.scheduler, num_inference_steps, device)
        # DDIM timesteps in descending order.
        time_range = [int(timestep) for timestep in schedule.timesteps.tolist()]
        # Resampling is active for the first ~3/4 of the loop (high indices)
        # and always at the very last step (index 0).
        resample_start_index = len(time_range) // 4

        for step_idx, timestep in enumerate(time_range):
            # `index` counts down to 0 as sampling progresses.
            index = len(time_range) - step_idx - 1
            with torch.no_grad():
                next_latents, pred_x0, pseudo_x0, step_stats = self._ddim_step(
                    impedance_latents,
                    conditioning,
                    timestep,
                    schedule,
                    eta,
                    generator,
                    quantize_denoised,
                )

            if (index >= resample_start_index or index == 0) and (
                index % interval == 0 or index == 0
            ):
                x_t_reference = next_latents.detach().clone()
                # Resampling variance schedule, scaled by sigma_a and shrinking
                # as a_t approaches a_prev.
                sigma = sigma_a * (1.0 - step_stats["a_prev"]) / (
                    1.0 - step_stats["a_t"]
                )
                sigma = sigma * (1.0 - step_stats["a_t"] / step_stats["a_prev"])
                sigma = torch.clamp(sigma, min=1e-12)

                with torch.no_grad():
                    # Decode the pseudo x0 estimate to pixel space for the
                    # measurement-consistency optimization.
                    pseudo_x0_pixel = self.vq_model.decode(
                        pseudo_x0.detach().to(dtype=self.vq_model.dtype)
                    ).sample
                # Gradients required here — runs outside no_grad on purpose.
                optimized_pixels = self._optimize_pixels(
                    pseudo_x0_pixel,
                    measurement,
                    operator_fn,
                    last_pixel_params if index == 0 else pixel_params,
                )
                with torch.no_grad():
                    optimized_latents = self.vq_model.encode(
                        optimized_pixels.to(dtype=self.vq_model.dtype)
                    ).latents.to(dtype=self.unet.dtype)
                    # Pull the trajectory toward the optimized latents while
                    # staying consistent with the DDIM sample.
                    next_latents = self._stochastic_resample(
                        optimized_latents,
                        x_t_reference,
                        step_stats["a_prev"],
                        sigma.to(dtype=self.unet.dtype),
                        generator,
                    )

            impedance_latents = next_latents.detach()

        with torch.no_grad():
            impedance_samples = self.vq_model.decode(
                impedance_latents.to(dtype=self.vq_model.dtype)
            ).sample
            impedance_reconstructed = None
            if image is not None:
                image = image.to(device=device, dtype=self.vq_model.dtype)
                image_latents = self.vq_model.encode(image).latents
                impedance_reconstructed = self.vq_model.decode(image_latents).sample

        if output_type == "np":
            impedance_samples = impedance_samples.detach().cpu().numpy()
            impedance_latents = impedance_latents.detach().cpu().numpy()
            impedance_dipin = impedance_dipin.detach().cpu().numpy()
            record_features = record_features.detach().cpu().numpy()
            if impedance_reconstructed is not None:
                impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()

        return SeismicImpInvLDDPMPipelineOutput(
            impedance_samples=impedance_samples,
            impedance_latents=impedance_latents,
            impedance_dipin=impedance_dipin,
            impedance_reconstructed=impedance_reconstructed,
            record_features=record_features,
        )
|
|