from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable

import numpy as np
import torch
from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DModel, VQModel
from diffusers.utils import BaseOutput


@dataclass
class SeismicImpInvLDDPMPipelineOutput(BaseOutput):
    """Output container for the SAII pipelines.

    All fields are ``torch.Tensor`` when ``output_type="tensor"`` and
    ``np.ndarray`` when ``output_type="np"``.
    """

    # Decoded impedance images sampled by the reverse diffusion process.
    impedance_samples: torch.Tensor | np.ndarray
    # Final latent-space samples (pre-decode), in the VQ latent layout.
    impedance_latents: torch.Tensor | np.ndarray
    # VQ-encoded latents of the `dipin` conditioning input.
    impedance_dipin: torch.Tensor | np.ndarray
    # VQ encode→decode round-trip of the optional reference `image` (None when
    # no reference image was supplied).
    impedance_reconstructed: torch.Tensor | np.ndarray | None = None
    # Features produced by the condition encoder from the seismic `record`.
    record_features: torch.Tensor | np.ndarray | None = None


class SeismicImpInvLDDPMPipeline(DiffusionPipeline):
    """SAII-LDDPM impedance inversion pipeline."""

    def __init__(
        self,
        vq_model: VQModel,
        condition_encoder: torch.nn.Module,
        unet: UNet2DModel,
        scheduler: DDPMScheduler,
    ):
        """Register the trained components with the DiffusionPipeline machinery.

        Args:
            vq_model: VQ autoencoder used to move between pixel and latent space.
            condition_encoder: Module mapping a seismic record to conditioning
                features (may expose an ``encode`` method; see
                ``_encode_conditioning``).
            unet: Denoising UNet operating on [latents ++ conditioning] channels.
            scheduler: DDPM scheduler whose betas define the diffusion process.
        """
        super().__init__()
        self.register_modules(
            vq_model=vq_model,
            condition_encoder=condition_encoder,
            unet=unet,
            scheduler=scheduler,
        )

    def _encode_conditioning(
        self, dipin: torch.Tensor, record: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode both conditioning inputs and cast them to the UNet dtype.

        Returns:
            ``(dipin_latents, record_features)``; both are later concatenated
            along the channel dim, so they are assumed to share spatial shape
            with the sample latents — TODO confirm against training code.
        """
        dipin_latents = self.vq_model.encode(dipin).latents
        # Prefer an explicit `encode` method when the condition encoder exposes
        # one; otherwise fall back to calling the module directly.
        if hasattr(self.condition_encoder, "encode") and callable(
            self.condition_encoder.encode
        ):
            record_features = self.condition_encoder.encode(record)
        else:
            record_features = self.condition_encoder(record)
        return (
            dipin_latents.to(dtype=self.unet.dtype),
            record_features.to(dtype=self.unet.dtype),
        )

    @staticmethod
    def _extract_into_tensor(
        arr: torch.Tensor, timesteps: torch.Tensor, broadcast_shape: torch.Size
    ) -> torch.Tensor:
        """Gather per-timestep values from a 1-D buffer and reshape for broadcast.

        Args:
            arr: 1-D buffer indexed by diffusion timestep.
            timesteps: Long tensor of shape ``(batch,)``.
            broadcast_shape: Shape of the tensor the result will be multiplied
                with; only its rank is used.

        Returns:
            float32 tensor of shape ``(batch, 1, ..., 1)`` matching the rank of
            ``broadcast_shape``.
        """
        values = arr.to(device=timesteps.device, dtype=torch.float32).gather(0, timesteps)
        return values.reshape(timesteps.shape[0], *((1,) * (len(broadcast_shape) - 1)))

    @staticmethod
    def _build_legacy_ddpm_buffers(
        scheduler: DDPMScheduler, device: torch.device
    ) -> dict[str, torch.Tensor]:
        """Precompute posterior coefficients from the scheduler's betas.

        Recreates the classic DDPM posterior quantities (guided-diffusion
        naming) so that ``_ddpm_step`` can run the ancestral update directly,
        bypassing ``scheduler.step``.
        """
        betas = scheduler.betas.to(device=device, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha-bar_{t-1}; alpha-bar_{-1} is defined as 1.
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=device), alphas_cumprod[:-1]], dim=0
        )
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # Clamp before log: the t=0 posterior variance is 0.
        posterior_log_variance_clipped = torch.log(
            torch.clamp(posterior_variance, min=1e-20)
        )
        return {
            # Coefficients to recover x0 from (x_t, predicted noise).
            "sqrt_recip_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod),
            "sqrt_recipm1_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod - 1),
            # q(x_{t-1} | x_t, x0) mean coefficients.
            "posterior_mean_coef1": betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
            "posterior_mean_coef2": (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
            "posterior_log_variance_clipped": posterior_log_variance_clipped,
        }

    @staticmethod
    def _randn_like_sample(
        sample: torch.Tensor, generator: torch.Generator | list[torch.Generator] | None
    ) -> torch.Tensor:
        """Draw Gaussian noise shaped like ``sample``.

        A list of generators yields one independent stream per batch element
        (list length must equal the batch size); a single generator (or None)
        draws the whole batch from one stream.

        Raises:
            ValueError: If a generator list does not match the batch size.
        """
        if isinstance(generator, list):
            if len(generator) != sample.shape[0]:
                raise ValueError(
                    f"Expected {sample.shape[0]} generators, got {len(generator)}"
                )
            return torch.cat(
                [
                    torch.randn(
                        sample[i : i + 1].shape,
                        generator=sample_generator,
                        device=sample.device,
                        dtype=sample.dtype,
                    )
                    for i, sample_generator in enumerate(generator)
                ],
                dim=0,
            )
        return torch.randn(
            sample.shape, generator=generator, device=sample.device, dtype=sample.dtype
        )

    def _ddpm_step(
        self,
        latents: torch.Tensor,
        conditioning: torch.Tensor,
        timestep: torch.Tensor,
        generator: torch.Generator | list[torch.Generator] | None,
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """One ancestral DDPM reverse step x_t -> x_{t-1} in latent space.

        Args:
            latents: Current noisy latents x_t.
            conditioning: Channel-concatenated conditioning latents/features.
            timestep: Long tensor of shape ``(batch,)`` holding t for each item.
            generator: RNG source(s) for the step noise.
            buffers: Output of ``_build_legacy_ddpm_buffers``.
        """
        # UNet consumes [latents ++ conditioning] along the channel dim.
        model_input = torch.cat([latents, conditioning], dim=1)
        noise_pred = self.unet(model_input, timestep).sample
        # Reconstruct x0-hat from the epsilon prediction.
        pred_x0 = (
            self._extract_into_tensor(
                buffers["sqrt_recip_alphas_cumprod"], timestep, latents.shape
            )
            * latents
            - self._extract_into_tensor(
                buffers["sqrt_recipm1_alphas_cumprod"], timestep, latents.shape
            )
            * noise_pred
        )
        # Snap the x0 estimate onto the VQ codebook every step (latent-space
        # quantization of the denoised prediction).
        pred_x0 = self.vq_model.quantize(pred_x0)[0]
        # Posterior mean of q(x_{t-1} | x_t, x0-hat).
        model_mean = (
            self._extract_into_tensor(
                buffers["posterior_mean_coef1"], timestep, latents.shape
            )
            * pred_x0
            + self._extract_into_tensor(
                buffers["posterior_mean_coef2"], timestep, latents.shape
            )
            * latents
        )
        noise = self._randn_like_sample(latents, generator)
        # Suppress the noise term at t == 0 (final step is deterministic).
        nonzero_mask = (1 - (timestep == 0).float()).reshape(
            latents.shape[0], *((1,) * (len(latents.shape) - 1))
        )
        # std = exp(0.5 * log-variance).
        return model_mean + nonzero_mask * (
            0.5
            * self._extract_into_tensor(
                buffers["posterior_log_variance_clipped"], timestep, latents.shape
            )
        ).exp() * noise

    @torch.no_grad()
    def __call__(
        self,
        dipin: torch.Tensor,
        record: torch.Tensor,
        image: torch.Tensor | None = None,
        num_inference_steps: int = 1000,
        seed: int | None = None,
        seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
        generator: torch.Generator | None = None,
        output_type: str = "tensor",
    ) -> SeismicImpInvLDDPMPipelineOutput:
        """Run full ancestral DDPM sampling conditioned on (dipin, record).

        Args:
            dipin: Low-frequency impedance prior image; VQ-encoded as
                conditioning.
            record: Seismic record fed to the condition encoder.
            image: Optional reference impedance image; when given, its VQ
                encode/decode round-trip is returned for comparison.
            num_inference_steps: Number of reverse steps. NOTE(review): steps
                index directly into buffers built from the scheduler's full
                beta table, so this presumably must equal the number of
                training timesteps — confirm before changing the default.
            seed / seeds / generator: RNG control; `seeds` (one per batch item)
                takes precedence over `seed`, which takes precedence over
                `generator`.
            output_type: "np" converts all outputs to numpy; anything else
                returns tensors.
        """
        device = self.unet.device
        if seeds is not None:
            if isinstance(seeds, torch.Tensor):
                seeds = seeds.detach().cpu().tolist()
            seeds = [int(value) for value in seeds]
            if len(seeds) != dipin.shape[0]:
                raise ValueError(f"Expected {dipin.shape[0]} seeds, got {len(seeds)}")
            # One generator per batch element for reproducible per-sample noise.
            generator = [
                torch.Generator(device=device).manual_seed(value) for value in seeds
            ]
        elif seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)
        elif generator is None:
            generator = torch.Generator(device=device)
        dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
        record = record.to(device=device, dtype=self.unet.dtype)
        impedance_dipin, record_features = self._encode_conditioning(dipin, record)
        conditioning = torch.cat([impedance_dipin, record_features], dim=1)
        # Initial latents: pure Gaussian noise with the dipin-latent shape
        # (torch.empty is only a shape/dtype carrier here).
        impedance_latents = self._randn_like_sample(
            torch.empty(
                impedance_dipin.shape,
                device=device,
                dtype=self.unet.dtype,
            ),
            generator,
        )
        buffers = self._build_legacy_ddpm_buffers(self.scheduler, device)
        # Reverse diffusion: t = T-1, ..., 0.
        for t in reversed(range(num_inference_steps)):
            timestep = torch.full(
                (impedance_latents.shape[0],), t, device=device, dtype=torch.long
            )
            impedance_latents = self._ddpm_step(
                impedance_latents, conditioning, timestep, generator, buffers
            )
        impedance_samples = self.vq_model.decode(
            impedance_latents.to(dtype=self.vq_model.dtype)
        ).sample
        impedance_reconstructed = None
        if image is not None:
            image = image.to(device=device, dtype=self.vq_model.dtype)
            image_latents = self.vq_model.encode(image).latents
            impedance_reconstructed = self.vq_model.decode(image_latents).sample
        if output_type == "np":
            impedance_samples = impedance_samples.detach().cpu().numpy()
            impedance_latents = impedance_latents.detach().cpu().numpy()
            impedance_dipin = impedance_dipin.detach().cpu().numpy()
            record_features = record_features.detach().cpu().numpy()
            if impedance_reconstructed is not None:
                impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()
        return SeismicImpInvLDDPMPipelineOutput(
            impedance_samples=impedance_samples,
            impedance_latents=impedance_latents,
            impedance_dipin=impedance_dipin,
            impedance_reconstructed=impedance_reconstructed,
            record_features=record_features,
        )

    @torch.no_grad()
    def encode_decode(
        self, image: torch.Tensor, output_type: str = "tensor"
    ) -> torch.Tensor | np.ndarray:
        """VQ encode→decode round-trip, e.g. to gauge autoencoder fidelity."""
        image = image.to(device=self.vq_model.device, dtype=self.vq_model.dtype)
        reconstruction = self.vq_model.decode(self.vq_model.encode(image).latents).sample
        if output_type == "np":
            return reconstruction.detach().cpu().numpy()
        return reconstruction


class SeismicImpInvCLDMPipeline(SeismicImpInvLDDPMPipeline):
    """SAII-CLDM inference pipeline.

    This reuses the same trained components as SAII-LDDPM and replaces only
    the reverse sampling procedure with DDIM plus model-driven resampling.
    """

    @staticmethod
    def _get_operator_fn(operator: Any) -> Callable[[torch.Tensor], torch.Tensor]:
        """Normalize the forward-modeling operator to a plain callable.

        Raises:
            TypeError: If ``operator`` is neither callable nor an object with a
                callable ``forward`` attribute.
        """
        if callable(operator):
            return operator
        if hasattr(operator, "forward") and callable(operator.forward):
            return operator.forward
        raise TypeError("`operator` must be callable or expose a callable `forward` method.")

    @staticmethod
    def _build_ddim_scheduler(
        scheduler: DDPMScheduler,
        num_inference_steps: int,
        device: torch.device,
    ) -> DDIMScheduler:
        """Derive a DDIM scheduler from the trained DDPM scheduler's config.

        ``clip_sample=False`` keeps latent values unclipped;
        ``set_alpha_to_one=False`` makes the final alpha the first alpha-bar
        rather than 1, matching the legacy DDIM formulation.
        """
        ddim_scheduler = DDIMScheduler.from_config(
            scheduler.config,
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
            timestep_spacing="leading",
        )
        ddim_scheduler.set_timesteps(num_inference_steps, device=device)
        return ddim_scheduler

    @staticmethod
    def _default_pixel_optimization_param() -> dict[str, float | int]:
        """Defaults for the per-resample pixel-space optimization.

        Keys: `eps` early-stop loss threshold, `max_iters` AdamW iterations,
        `lr` learning rate, `y_coef`/`x_coef` weights for the measurement and
        proximity (to the initial estimate) MSE terms, `tv_coef` total
        variation weight (0 disables TV), `dh_coef`/`dw_coef` per-axis TV
        weights.
        """
        return {
            "eps": 1e-4,
            "max_iters": 100,
            "lr": 1e-5,
            "y_coef": 1.0,
            "x_coef": 0.0,
            "tv_coef": 0.0,
            "dh_coef": 1.0,
            "dw_coef": 1.5,
        }

    @staticmethod
    def _default_last_pixel_optimization_param() -> dict[str, float | int]:
        """Defaults for the final (index == 0) optimization pass.

        Deliberately gentler: a single higher-LR iteration with a proximity
        term (`x_coef` > 0) so the last correction stays near the sample.
        """
        return {
            "eps": 1e-4,
            "max_iters": 1,
            "lr": 1e-4,
            "y_coef": 1.0,
            "x_coef": 0.1,
            "tv_coef": 0.0,
            "dh_coef": 1.0,
            "dw_coef": 1.5,
        }

    @staticmethod
    def _tv_loss(x: torch.Tensor, *, dh_coef: float, dw_coef: float) -> torch.Tensor:
        """Anisotropic total-variation penalty over the last two dims.

        The two difference maps are cropped so they align before summing.
        Assumes x has at least 2 spatial dims (last two axes) — TODO confirm
        expected layout is (batch, channel, H, W).
        """
        dh = dh_coef * torch.abs(x[..., :, 1:] - x[..., :, :-1])
        dw = dw_coef * torch.abs(x[..., 1:, :] - x[..., :-1, :])
        return torch.mean(dh[..., :-1, :] + dw[..., :, :-1])

    def _ddim_step(
        self,
        latents: torch.Tensor,
        conditioning: torch.Tensor,
        timestep: int,
        scheduler: DDIMScheduler,
        eta: float,
        generator: torch.Generator | list[torch.Generator] | None,
        quantize_denoised: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
        """One DDIM reverse step; also returns x0 estimates and alpha stats.

        Returns:
            ``(prev_sample, pred_x0, pseudo_x0, stats)`` where ``stats`` holds
            per-batch broadcastable ``a_t`` (alpha-bar_t) and ``a_prev``
            (alpha-bar_{t_prev}) tensors used by the resampling schedule.
        """
        model_input = torch.cat(
            [
                scheduler.scale_model_input(latents, timestep),
                conditioning.to(dtype=latents.dtype),
            ],
            dim=1,
        )
        timestep_tensor = torch.full(
            (latents.shape[0],), timestep, device=latents.device, dtype=torch.long
        )
        noise_pred = self.unet(model_input, timestep_tensor).sample
        alpha_t = scheduler.alphas_cumprod[timestep].to(
            device=latents.device, dtype=latents.dtype
        )
        # Previous timestep under "leading" spacing; negative means final step.
        prev_timestep = timestep - (
            scheduler.config.num_train_timesteps // scheduler.num_inference_steps
        )
        if prev_timestep >= 0:
            alpha_prev = scheduler.alphas_cumprod[prev_timestep].to(
                device=latents.device, dtype=latents.dtype
            )
        else:
            alpha_prev = scheduler.final_alpha_cumprod.to(
                device=latents.device, dtype=latents.dtype
            )
        beta_t = 1.0 - alpha_t
        # Standard DDIM x0 prediction from epsilon.
        pred_x0 = (latents - beta_t.sqrt() * noise_pred) / alpha_t.sqrt()
        # NOTE(review): pseudo_x0 uses beta_t WITHOUT sqrt, unlike pred_x0 —
        # presumably the paper's "pseudo" x0 used only for the resampling
        # guidance; confirm against the reference implementation.
        pseudo_x0 = (latents - beta_t * noise_pred) / alpha_t.sqrt()
        if quantize_denoised:
            # Snap x0-hat onto the VQ codebook, then make epsilon consistent
            # with the quantized x0.
            pred_x0 = self.vq_model.quantize(pred_x0.to(dtype=self.vq_model.dtype))[0].to(
                dtype=latents.dtype
            )
            noise_pred = (latents - alpha_t.sqrt() * pred_x0) / beta_t.sqrt()
        # NOTE(review): relies on the private diffusers API
        # DDIMScheduler._get_variance(t, t_prev) — may break across versions.
        variance = scheduler._get_variance(timestep, prev_timestep).to(
            device=latents.device, dtype=latents.dtype
        )
        sigma_t = eta * variance.sqrt()
        # "Direction pointing to x_t" term; clamp guards tiny negatives.
        direction = torch.clamp(1.0 - alpha_prev - sigma_t**2, min=0.0).sqrt() * noise_pred
        noise = torch.zeros_like(latents)
        if eta > 0:
            noise = sigma_t * self._randn_like_sample(latents, generator)
        prev_sample = alpha_prev.sqrt() * pred_x0 + direction + noise
        # Hard-codes rank-4 latents (batch, channel, H, W) — TODO confirm.
        batch_shape = (latents.shape[0], 1, 1, 1)
        return (
            prev_sample,
            pred_x0,
            pseudo_x0,
            {
                "a_t": torch.full(
                    batch_shape,
                    float(alpha_t.item()),
                    device=latents.device,
                    dtype=latents.dtype,
                ),
                "a_prev": torch.full(
                    batch_shape,
                    float(alpha_prev.item()),
                    device=latents.device,
                    dtype=latents.dtype,
                ),
            },
        )

    def _optimize_pixels(
        self,
        x_prime: torch.Tensor,
        measurement: torch.Tensor,
        operator_fn: Callable[[torch.Tensor], torch.Tensor],
        params: dict[str, Any],
    ) -> torch.Tensor:
        """Data-consistency optimization of a pixel-space x0 estimate.

        Minimizes ``y_coef * MSE(measurement, A(x)) + x_coef * MSE(x_init, x)
        [+ tv_coef * TV(x)]`` with AdamW, starting from ``x_prime``; stops
        early once the loss drops below ``eps``.

        Args:
            params: Overrides merged over ``_default_pixel_optimization_param``.

        Returns:
            The detached optimized tensor (``x_prime`` unchanged when
            ``max_iters <= 0``).
        """
        merged = {**self._default_pixel_optimization_param(), **params}
        if int(merged["max_iters"]) <= 0:
            return x_prime.detach()
        loss_fn = torch.nn.MSELoss(reduction="mean")
        opt_var = x_prime.detach().clone().requires_grad_(True)
        opt_init = x_prime.detach().clone()
        optimizer = torch.optim.AdamW([opt_var], lr=float(merged["lr"]))
        for _ in range(int(merged["max_iters"])):
            optimizer.zero_grad(set_to_none=True)
            measurement_loss = (
                loss_fn(measurement, operator_fn(opt_var)) * float(merged["y_coef"])
                + loss_fn(opt_init, opt_var) * float(merged["x_coef"])
            )
            if float(merged["tv_coef"]) != 0.0:
                measurement_loss = measurement_loss + float(merged["tv_coef"]) * self._tv_loss(
                    opt_var,
                    dh_coef=float(merged["dh_coef"]),
                    dw_coef=float(merged["dw_coef"]),
                )
            measurement_loss.backward()
            optimizer.step()
            if float(measurement_loss.detach().cpu().item()) < float(merged["eps"]):
                break
        return opt_var.detach()

    def _stochastic_resample(
        self,
        pseudo_x0: torch.Tensor,
        x_t: torch.Tensor,
        a_t: torch.Tensor,
        sigma: torch.Tensor,
        generator: torch.Generator | list[torch.Generator] | None,
    ) -> torch.Tensor:
        """Stochastically re-noise an optimized x0 back toward x_t.

        Precision-weighted combination of ``sqrt(a_t) * pseudo_x0`` (weight
        1/sigma) and ``x_t`` (weight 1/(1 - a_t)), plus noise with the
        corresponding combined variance. Clamps keep both denominators
        strictly positive.
        """
        sigma = torch.clamp(sigma, min=1e-12)
        one_minus_a_t = torch.clamp(1.0 - a_t, min=1e-12)
        noise = self._randn_like_sample(pseudo_x0, generator)
        return (
            (sigma * a_t.sqrt() * pseudo_x0 + one_minus_a_t * x_t)
            / (sigma + one_minus_a_t)
            + noise * torch.sqrt(1.0 / (1.0 / sigma + 1.0 / one_minus_a_t))
        )

    def __call__(
        self,
        dipin: torch.Tensor,
        record: torch.Tensor,
        measurement: torch.Tensor | None = None,
        operator: Any | None = None,
        image: torch.Tensor | None = None,
        num_inference_steps: int = 30,
        seed: int | None = None,
        seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
        generator: torch.Generator | None = None,
        eta: float = 0.01,
        interval: int = 6,
        sigma_a: float = 20.0,
        pixel_optimization_param: dict[str, Any] | None = None,
        last_pixel_optimization_param: dict[str, Any] | None = None,
        quantize_denoised: bool = False,
        output_type: str = "tensor",
    ) -> SeismicImpInvLDDPMPipelineOutput:
        """DDIM sampling with periodic measurement-driven resampling.

        Not decorated with ``@torch.no_grad()``: gradients are needed inside
        ``_optimize_pixels``, so no-grad regions are scoped explicitly.

        Args:
            measurement: Observed data for the consistency loss; defaults to
                ``record`` when omitted.
            operator: Forward operator A mapping impedance pixels to
                measurement space (callable or object with ``forward``).
            eta: DDIM stochasticity (0 = deterministic).
            interval: Resample every ``interval`` steps (by descending index);
                the final step (index 0) always resamples.
            sigma_a: Scale of the resampling variance schedule.
            pixel_optimization_param / last_pixel_optimization_param:
                Overrides for the intermediate / final optimization passes.
            quantize_denoised: Snap each x0 estimate onto the VQ codebook.

        Raises:
            ValueError: If ``operator`` is missing, ``interval`` is not
                positive, or a seeds list mismatches the batch size.
        """
        if measurement is None:
            measurement = record
        if operator is None:
            raise ValueError("SAII-CLDM requires a forward `operator`.")
        if interval <= 0:
            raise ValueError("`interval` must be a positive integer.")
        device = self.unet.device
        if seeds is not None:
            if isinstance(seeds, torch.Tensor):
                seeds = seeds.detach().cpu().tolist()
            seeds = [int(value) for value in seeds]
            if len(seeds) != dipin.shape[0]:
                raise ValueError(f"Expected {dipin.shape[0]} seeds, got {len(seeds)}")
            generator = [
                torch.Generator(device=device).manual_seed(value) for value in seeds
            ]
        elif seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)
        elif generator is None:
            generator = torch.Generator(device=device)
        with torch.no_grad():
            dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
            record = record.to(device=device, dtype=self.unet.dtype)
            measurement = measurement.to(device=device, dtype=self.unet.dtype)
            impedance_dipin, record_features = self._encode_conditioning(dipin, record)
            conditioning = torch.cat([impedance_dipin, record_features], dim=1)
            impedance_latents = self._randn_like_sample(
                torch.empty(
                    impedance_dipin.shape,
                    device=device,
                    dtype=self.unet.dtype,
                ),
                generator,
            )
        operator_fn = self._get_operator_fn(operator)
        pixel_params = pixel_optimization_param or {}
        last_pixel_params = last_pixel_optimization_param or self._default_last_pixel_optimization_param()
        schedule = self._build_ddim_scheduler(self.scheduler, num_inference_steps, device)
        time_range = [int(timestep) for timestep in schedule.timesteps.tolist()]
        # Resampling is active only in the last three quarters of the
        # trajectory (index counts DOWN to 0 as sampling progresses).
        resample_start_index = len(time_range) // 4
        for step_idx, timestep in enumerate(time_range):
            index = len(time_range) - step_idx - 1
            with torch.no_grad():
                next_latents, pred_x0, pseudo_x0, step_stats = self._ddim_step(
                    impedance_latents,
                    conditioning,
                    timestep,
                    schedule,
                    eta,
                    generator,
                    quantize_denoised,
                )
            # Resample on every `interval`-th index past the start threshold,
            # and always on the final step (index == 0).
            if (index >= resample_start_index or index == 0) and (
                index % interval == 0 or index == 0
            ):
                x_t_reference = next_latents.detach().clone()
                # Variance schedule for the resample: shrinks as a_t -> a_prev.
                sigma = sigma_a * (1.0 - step_stats["a_prev"]) / (
                    1.0 - step_stats["a_t"]
                )
                sigma = sigma * (1.0 - step_stats["a_t"] / step_stats["a_prev"])
                sigma = torch.clamp(sigma, min=1e-12)
                with torch.no_grad():
                    pseudo_x0_pixel = self.vq_model.decode(
                        pseudo_x0.detach().to(dtype=self.vq_model.dtype)
                    ).sample
                # Gradient-based data-consistency step (runs WITH grad).
                optimized_pixels = self._optimize_pixels(
                    pseudo_x0_pixel,
                    measurement,
                    operator_fn,
                    last_pixel_params if index == 0 else pixel_params,
                )
                with torch.no_grad():
                    optimized_latents = self.vq_model.encode(
                        optimized_pixels.to(dtype=self.vq_model.dtype)
                    ).latents.to(dtype=self.unet.dtype)
                    next_latents = self._stochastic_resample(
                        optimized_latents,
                        x_t_reference,
                        step_stats["a_prev"],
                        sigma.to(dtype=self.unet.dtype),
                        generator,
                    )
            impedance_latents = next_latents.detach()
        with torch.no_grad():
            impedance_samples = self.vq_model.decode(
                impedance_latents.to(dtype=self.vq_model.dtype)
            ).sample
            impedance_reconstructed = None
            if image is not None:
                image = image.to(device=device, dtype=self.vq_model.dtype)
                image_latents = self.vq_model.encode(image).latents
                impedance_reconstructed = self.vq_model.decode(image_latents).sample
        if output_type == "np":
            impedance_samples = impedance_samples.detach().cpu().numpy()
            impedance_latents = impedance_latents.detach().cpu().numpy()
            impedance_dipin = impedance_dipin.detach().cpu().numpy()
            record_features = record_features.detach().cpu().numpy()
            if impedance_reconstructed is not None:
                impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()
        return SeismicImpInvLDDPMPipelineOutput(
            impedance_samples=impedance_samples,
            impedance_latents=impedance_latents,
            impedance_dipin=impedance_dipin,
            impedance_reconstructed=impedance_reconstructed,
            record_features=record_features,
        )