# saii-cldm-synthetic / codes/pipeline.py
# Last commit 89fe4af (verified) by mally-2000: "Reorganize inference code under codes"
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
import numpy as np
import torch
from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DModel, VQModel
from diffusers.utils import BaseOutput
@dataclass
class SeismicImpInvLDDPMPipelineOutput(BaseOutput):
    """Outputs returned by the SAII-LDDPM and SAII-CLDM impedance pipelines.

    Each field holds a ``torch.Tensor`` by default, or a ``np.ndarray`` when the
    pipeline is called with ``output_type="np"``. The field order defines the
    positional construction order of this dataclass, so keep it stable.
    """

    # Decoded samples: the VQ decoder output for the final sampled latents.
    impedance_samples: torch.Tensor | np.ndarray
    # Final sampled latents, before decoding.
    impedance_latents: torch.Tensor | np.ndarray
    # VQ-encoded ``dipin`` latents (cast to the UNet dtype) used as conditioning.
    impedance_dipin: torch.Tensor | np.ndarray
    # VQ encode/decode reconstruction of the optional ``image`` input;
    # None when the pipeline is called without ``image``.
    impedance_reconstructed: torch.Tensor | np.ndarray | None = None
    # Condition-encoder features of the ``record`` input (cast to the UNet dtype).
    record_features: torch.Tensor | np.ndarray | None = None
class SeismicImpInvLDDPMPipeline(DiffusionPipeline):
    """SAII-LDDPM impedance inversion pipeline.

    Samples impedance latents with ancestral DDPM updates in the latent space of
    ``vq_model``. At every reverse step the UNet receives the current latents
    concatenated (along dim 1) with a conditioning tensor made of the
    VQ-encoded ``dipin`` input and the ``condition_encoder`` features of the
    ``record`` input.

    Registered components:
        vq_model: VQ autoencoder; encodes ``dipin`` / ``image``, quantizes every
            x0 prediction during sampling and decodes the final latents.
        condition_encoder: maps ``record`` to conditioning features; its
            ``encode`` method is used when it has one, otherwise it is called.
        unet: noise-prediction network.
        scheduler: DDPM scheduler; only its ``betas`` are read, to rebuild the
            posterior coefficients used by the sampler.
    """

    def __init__(
        self,
        vq_model: VQModel,
        condition_encoder: torch.nn.Module,
        unet: UNet2DModel,
        scheduler: DDPMScheduler,
    ):
        super().__init__()
        self.register_modules(
            vq_model=vq_model,
            condition_encoder=condition_encoder,
            unet=unet,
            scheduler=scheduler,
        )

    def _encode_conditioning(
        self, dipin: torch.Tensor, record: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode the two conditioning inputs.

        Returns:
            ``(dipin_latents, record_features)``: ``vq_model.encode(dipin).latents``
            and the condition-encoder output for ``record``, both cast to the
            UNet dtype.
        """
        dipin_latents = self.vq_model.encode(dipin).latents
        # Prefer an explicit ``encode`` method; otherwise call the module itself.
        if hasattr(self.condition_encoder, "encode") and callable(
            self.condition_encoder.encode
        ):
            record_features = self.condition_encoder.encode(record)
        else:
            record_features = self.condition_encoder(record)
        return (
            dipin_latents.to(dtype=self.unet.dtype),
            record_features.to(dtype=self.unet.dtype),
        )

    @staticmethod
    def _extract_into_tensor(
        arr: torch.Tensor, timesteps: torch.Tensor, broadcast_shape: torch.Size
    ) -> torch.Tensor:
        """Gather ``arr[timesteps]`` as float32 and reshape it to
        ``(batch, 1, ..., 1)`` so it broadcasts against ``broadcast_shape``."""
        values = arr.to(device=timesteps.device, dtype=torch.float32).gather(0, timesteps)
        return values.reshape(timesteps.shape[0], *((1,) * (len(broadcast_shape) - 1)))

    @staticmethod
    def _build_legacy_ddpm_buffers(
        scheduler: DDPMScheduler, device: torch.device
    ) -> dict[str, torch.Tensor]:
        """Precompute the per-timestep float32 coefficients used by ``_ddpm_step``.

        Everything is derived from ``scheduler.betas`` alone:

        * ``sqrt_recip_alphas_cumprod`` / ``sqrt_recipm1_alphas_cumprod``:
          recover x0 from ``(x_t, eps)``.
        * ``posterior_mean_coef1`` / ``posterior_mean_coef2``: posterior mean
          ``coef1 * x0 + coef2 * x_t``.
        * ``posterior_log_variance_clipped``: log posterior variance. The
          variance at t=0 is exactly 0, so it is clamped before taking the log.
        """
        betas = scheduler.betas.to(device=device, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=device, dtype=betas.dtype), alphas_cumprod[:-1]],
            dim=0,
        )
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        posterior_log_variance_clipped = torch.log(
            torch.clamp(posterior_variance, min=1e-20)
        )
        return {
            "sqrt_recip_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod),
            "sqrt_recipm1_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod - 1),
            "posterior_mean_coef1": betas
            * torch.sqrt(alphas_cumprod_prev)
            / (1.0 - alphas_cumprod),
            "posterior_mean_coef2": (1.0 - alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - alphas_cumprod),
            "posterior_log_variance_clipped": posterior_log_variance_clipped,
        }

    @staticmethod
    def _randn_like_sample(
        sample: torch.Tensor, generator: torch.Generator | list[torch.Generator] | None
    ) -> torch.Tensor:
        """Draw standard normal noise shaped like ``sample``.

        With a list of generators, each batch row is drawn from its own
        generator, so per-sample results do not depend on the batch composition.

        Raises:
            ValueError: if the generator list length differs from the batch size.
        """
        if isinstance(generator, list):
            if len(generator) != sample.shape[0]:
                raise ValueError(
                    f"Expected {sample.shape[0]} generators, got {len(generator)}"
                )
            return torch.cat(
                [
                    torch.randn(
                        sample[i : i + 1].shape,
                        generator=sample_generator,
                        device=sample.device,
                        dtype=sample.dtype,
                    )
                    for i, sample_generator in enumerate(generator)
                ],
                dim=0,
            )
        return torch.randn(
            sample.shape, generator=generator, device=sample.device, dtype=sample.dtype
        )

    @staticmethod
    def _prepare_generator(
        seed: int | None,
        seeds: list[int] | tuple[int, ...] | torch.Tensor | None,
        generator: torch.Generator | list[torch.Generator] | None,
        batch_size: int,
        device: torch.device,
    ) -> torch.Generator | list[torch.Generator]:
        """Resolve the RNG to use for sampling.

        Precedence: ``seeds`` (one generator per sample) > ``seed`` (one shared
        generator) > ``generator``. When none is given, a fresh
        ``torch.Generator`` is created; it starts from PyTorch's fixed default
        seed, so such calls are reproducible rather than random.

        Raises:
            ValueError: if ``len(seeds)`` differs from ``batch_size``.
        """
        if seeds is not None:
            if isinstance(seeds, torch.Tensor):
                seeds = seeds.detach().cpu().tolist()
            seeds = [int(value) for value in seeds]
            if len(seeds) != batch_size:
                raise ValueError(f"Expected {batch_size} seeds, got {len(seeds)}")
            return [torch.Generator(device=device).manual_seed(value) for value in seeds]
        if seed is not None:
            return torch.Generator(device=device).manual_seed(seed)
        if generator is None:
            return torch.Generator(device=device)
        return generator

    def _ddpm_step(
        self,
        latents: torch.Tensor,
        conditioning: torch.Tensor,
        timestep: torch.Tensor,
        generator: torch.Generator | list[torch.Generator] | None,
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """One ancestral DDPM step ``x_t -> x_{t-1}``.

        The x0 prediction is snapped onto the VQ codebook via
        ``vq_model.quantize`` before forming the posterior mean. No noise is
        added for batch entries whose timestep is 0.

        Returns:
            The previous sample, in the same dtype as ``latents``.
        """
        shape = latents.shape
        model_input = torch.cat([latents, conditioning], dim=1)
        noise_pred = self.unet(model_input, timestep).sample
        pred_x0 = (
            self._extract_into_tensor(
                buffers["sqrt_recip_alphas_cumprod"], timestep, shape
            )
            * latents
            - self._extract_into_tensor(
                buffers["sqrt_recipm1_alphas_cumprod"], timestep, shape
            )
            * noise_pred
        )
        # The buffers are float32, which promotes pred_x0; give the codebook
        # lookup a tensor in the VQ model's own dtype (no-op under fp32).
        pred_x0 = self.vq_model.quantize(pred_x0.to(dtype=self.vq_model.dtype))[0]
        model_mean = (
            self._extract_into_tensor(buffers["posterior_mean_coef1"], timestep, shape)
            * pred_x0
            + self._extract_into_tensor(buffers["posterior_mean_coef2"], timestep, shape)
            * latents
        )
        noise = self._randn_like_sample(latents, generator)
        nonzero_mask = (1 - (timestep == 0).float()).reshape(
            shape[0], *((1,) * (len(shape) - 1))
        )
        log_variance = self._extract_into_tensor(
            buffers["posterior_log_variance_clipped"], timestep, shape
        )
        prev_sample = model_mean + nonzero_mask * (0.5 * log_variance).exp() * noise
        # Keep the sampling state in the UNet input dtype; otherwise the float32
        # promotion above leaks into the next (half-precision) UNet call.
        return prev_sample.to(dtype=latents.dtype)

    @torch.no_grad()
    def __call__(
        self,
        dipin: torch.Tensor,
        record: torch.Tensor,
        image: torch.Tensor | None = None,
        num_inference_steps: int = 1000,
        seed: int | None = None,
        seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
        generator: torch.Generator | list[torch.Generator] | None = None,
        output_type: str = "tensor",
    ) -> SeismicImpInvLDDPMPipelineOutput:
        """Run ancestral DDPM sampling conditioned on ``dipin`` and ``record``.

        Args:
            dipin: input encoded by ``vq_model`` into latent conditioning; its
                batch size sets the number of samples.
            record: input mapped by ``condition_encoder`` to conditioning features.
            image: optional input; when given, its VQ encode/decode
                reconstruction is returned as ``impedance_reconstructed``.
            num_inference_steps: number of reverse steps. Timesteps
                ``num_inference_steps - 1, ..., 0`` are visited one by one, so
                this should normally equal the scheduler's number of training
                timesteps.
            seed: seeds one generator shared by the whole batch.
            seeds: one seed per batch element (takes precedence over ``seed``).
            generator: generator, or per-sample list of generators, used when
                neither ``seed`` nor ``seeds`` is given.
            output_type: ``"np"`` returns NumPy arrays; anything else returns tensors.

        Raises:
            ValueError: if ``num_inference_steps`` is outside
                ``[1, num_train_timesteps]`` or ``seeds`` does not match the
                batch size.
        """
        num_train_timesteps = int(self.scheduler.betas.shape[0])
        if not 1 <= num_inference_steps <= num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps` must be between 1 and {num_train_timesteps} "
                f"for ancestral DDPM sampling, got {num_inference_steps}."
            )
        device = self.unet.device
        generator = self._prepare_generator(
            seed, seeds, generator, dipin.shape[0], device
        )
        dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
        record = record.to(device=device, dtype=self.unet.dtype)
        impedance_dipin, record_features = self._encode_conditioning(dipin, record)
        conditioning = torch.cat([impedance_dipin, record_features], dim=1)
        # Start from Gaussian noise shaped like the encoded dipin latents.
        impedance_latents = self._randn_like_sample(
            torch.empty(
                impedance_dipin.shape,
                device=device,
                dtype=self.unet.dtype,
            ),
            generator,
        )
        buffers = self._build_legacy_ddpm_buffers(self.scheduler, device)
        for t in reversed(range(num_inference_steps)):
            timestep = torch.full(
                (impedance_latents.shape[0],), t, device=device, dtype=torch.long
            )
            impedance_latents = self._ddpm_step(
                impedance_latents, conditioning, timestep, generator, buffers
            )
        impedance_samples = self.vq_model.decode(
            impedance_latents.to(dtype=self.vq_model.dtype)
        ).sample
        impedance_reconstructed = None
        if image is not None:
            image = image.to(device=device, dtype=self.vq_model.dtype)
            image_latents = self.vq_model.encode(image).latents
            impedance_reconstructed = self.vq_model.decode(image_latents).sample
        if output_type == "np":
            impedance_samples = impedance_samples.detach().cpu().numpy()
            impedance_latents = impedance_latents.detach().cpu().numpy()
            impedance_dipin = impedance_dipin.detach().cpu().numpy()
            record_features = record_features.detach().cpu().numpy()
            if impedance_reconstructed is not None:
                impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()
        return SeismicImpInvLDDPMPipelineOutput(
            impedance_samples=impedance_samples,
            impedance_latents=impedance_latents,
            impedance_dipin=impedance_dipin,
            impedance_reconstructed=impedance_reconstructed,
            record_features=record_features,
        )

    @torch.no_grad()
    def encode_decode(
        self, image: torch.Tensor, output_type: str = "tensor"
    ) -> torch.Tensor | np.ndarray:
        """Return the VQ encode/decode reconstruction of ``image``.

        ``output_type="np"`` returns a NumPy array; anything else returns a tensor.
        """
        image = image.to(device=self.vq_model.device, dtype=self.vq_model.dtype)
        reconstruction = self.vq_model.decode(self.vq_model.encode(image).latents).sample
        if output_type == "np":
            return reconstruction.detach().cpu().numpy()
        return reconstruction
class SeismicImpInvCLDMPipeline(SeismicImpInvLDDPMPipeline):
    """SAII-CLDM inference pipeline.

    This reuses the same trained components as SAII-LDDPM and replaces only the
    reverse sampling procedure with DDIM plus model-driven resampling. At
    selected DDIM steps, the latent x0 estimate is handled in four stages:

    1. It is decoded to pixel space.
    2. It is optimized so that ``operator(x)`` matches ``measurement``.
    3. The result is re-encoded.
    4. It is blended back into the DDIM sample by stochastic resampling.
    """

    @staticmethod
    def _get_operator_fn(operator: Any) -> Callable[[torch.Tensor], torch.Tensor]:
        """Return ``operator`` itself if callable, else its ``forward`` method.

        Raises:
            TypeError: if neither is callable.
        """
        if callable(operator):
            return operator
        if hasattr(operator, "forward") and callable(operator.forward):
            return operator.forward
        raise TypeError("`operator` must be callable or expose a callable `forward` method.")

    @staticmethod
    def _build_ddim_scheduler(
        scheduler: DDPMScheduler,
        num_inference_steps: int,
        device: torch.device,
    ) -> DDIMScheduler:
        """Build a DDIM scheduler from the DDPM config.

        The scheduler uses "leading" timestep spacing, ``steps_offset=1``, no
        sample clipping and ``set_alpha_to_one=False``.
        """
        ddim_scheduler = DDIMScheduler.from_config(
            scheduler.config,
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
            timestep_spacing="leading",
        )
        ddim_scheduler.set_timesteps(num_inference_steps, device=device)
        return ddim_scheduler

    @staticmethod
    def _default_pixel_optimization_param() -> dict[str, float | int]:
        """Default pixel-optimization settings for intermediate resampling steps."""
        return {
            "eps": 1e-4,
            "max_iters": 100,
            "lr": 1e-5,
            "y_coef": 1.0,
            "x_coef": 0.0,
            "tv_coef": 0.0,
            "dh_coef": 1.0,
            "dw_coef": 1.5,
        }

    @staticmethod
    def _default_last_pixel_optimization_param() -> dict[str, float | int]:
        """Default pixel-optimization settings for the final DDIM step."""
        return {
            "eps": 1e-4,
            "max_iters": 1,
            "lr": 1e-4,
            "y_coef": 1.0,
            "x_coef": 0.1,
            "tv_coef": 0.0,
            "dh_coef": 1.0,
            "dw_coef": 1.5,
        }

    @staticmethod
    def _tv_loss(x: torch.Tensor, *, dh_coef: float, dw_coef: float) -> torch.Tensor:
        """Anisotropic total-variation penalty.

        ``dh_coef`` weights absolute differences along the last axis and
        ``dw_coef`` those along the second-to-last axis. Both difference maps
        are cropped to a common window before averaging.
        """
        dh = dh_coef * torch.abs(x[..., :, 1:] - x[..., :, :-1])
        dw = dw_coef * torch.abs(x[..., 1:, :] - x[..., :-1, :])
        return torch.mean(dh[..., :-1, :] + dw[..., :, :-1])

    def _ddim_step(
        self,
        latents: torch.Tensor,
        conditioning: torch.Tensor,
        timestep: int,
        scheduler: DDIMScheduler,
        eta: float,
        generator: torch.Generator | list[torch.Generator] | None,
        quantize_denoised: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
        """One DDIM step from ``timestep`` to the previous DDIM timestep.

        Returns:
            ``(prev_sample, pred_x0, pseudo_x0, stats)``:

            * ``pred_x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)``. When
              ``quantize_denoised`` is set, it is snapped to the VQ codebook and
              ``eps`` is recomputed from it.
            * ``pseudo_x0 = (x_t - (1 - a_t) * eps) / sqrt(a_t)`` uses the
              unquantized ``eps``.
            * ``stats`` holds ``"a_t"`` and ``"a_prev"`` as
              ``(batch, 1, 1, 1)`` tensors (assumes 4-D latents).
        """
        model_input = torch.cat(
            [
                scheduler.scale_model_input(latents, timestep),
                conditioning.to(dtype=latents.dtype),
            ],
            dim=1,
        )
        timestep_tensor = torch.full(
            (latents.shape[0],), timestep, device=latents.device, dtype=torch.long
        )
        noise_pred = self.unet(model_input, timestep_tensor).sample
        alpha_t = scheduler.alphas_cumprod[timestep].to(
            device=latents.device, dtype=latents.dtype
        )
        # Same stride as "leading" spacing in DDIMScheduler.set_timesteps.
        prev_timestep = timestep - (
            scheduler.config.num_train_timesteps // scheduler.num_inference_steps
        )
        if prev_timestep >= 0:
            alpha_prev = scheduler.alphas_cumprod[prev_timestep].to(
                device=latents.device, dtype=latents.dtype
            )
        else:
            alpha_prev = scheduler.final_alpha_cumprod.to(
                device=latents.device, dtype=latents.dtype
            )
        beta_t = 1.0 - alpha_t
        pred_x0 = (latents - beta_t.sqrt() * noise_pred) / alpha_t.sqrt()
        pseudo_x0 = (latents - beta_t * noise_pred) / alpha_t.sqrt()
        if quantize_denoised:
            pred_x0 = self.vq_model.quantize(pred_x0.to(dtype=self.vq_model.dtype))[0].to(
                dtype=latents.dtype
            )
            noise_pred = (latents - alpha_t.sqrt() * pred_x0) / beta_t.sqrt()
        variance = scheduler._get_variance(timestep, prev_timestep).to(
            device=latents.device, dtype=latents.dtype
        )
        sigma_t = eta * variance.sqrt()
        direction = torch.clamp(1.0 - alpha_prev - sigma_t**2, min=0.0).sqrt() * noise_pred
        noise = torch.zeros_like(latents)
        if eta > 0:
            noise = sigma_t * self._randn_like_sample(latents, generator)
        prev_sample = alpha_prev.sqrt() * pred_x0 + direction + noise
        batch_shape = (latents.shape[0], 1, 1, 1)
        return (
            prev_sample,
            pred_x0,
            pseudo_x0,
            {
                "a_t": torch.full(
                    batch_shape,
                    float(alpha_t.item()),
                    device=latents.device,
                    dtype=latents.dtype,
                ),
                "a_prev": torch.full(
                    batch_shape,
                    float(alpha_prev.item()),
                    device=latents.device,
                    dtype=latents.dtype,
                ),
            },
        )

    def _optimize_pixels(
        self,
        x_prime: torch.Tensor,
        measurement: torch.Tensor,
        operator_fn: Callable[[torch.Tensor], torch.Tensor],
        params: dict[str, Any],
    ) -> torch.Tensor:
        """Refine the pixel-space estimate ``x_prime`` with AdamW.

        Minimizes
        ``y_coef * MSE(measurement, operator_fn(x)) + x_coef * MSE(x_prime, x)``,
        plus ``tv_coef * _tv_loss(x)`` when ``tv_coef`` is non-zero. ``params``
        overrides ``_default_pixel_optimization_param()``. Stops after
        ``max_iters`` steps, or early once the loss drops below ``eps``.

        Returns:
            The optimized tensor, detached. ``x_prime`` is returned detached and
            unchanged when ``max_iters <= 0``.
        """
        merged = {**self._default_pixel_optimization_param(), **params}
        max_iters = int(merged["max_iters"])
        if max_iters <= 0:
            return x_prime.detach()
        lr = float(merged["lr"])
        y_coef = float(merged["y_coef"])
        x_coef = float(merged["x_coef"])
        tv_coef = float(merged["tv_coef"])
        dh_coef = float(merged["dh_coef"])
        dw_coef = float(merged["dw_coef"])
        eps = float(merged["eps"])
        loss_fn = torch.nn.MSELoss(reduction="mean")
        # Gradients are required here even if the caller wrapped the pipeline
        # call in ``torch.no_grad()``; otherwise backward() would fail.
        with torch.enable_grad():
            opt_var = x_prime.detach().clone().requires_grad_(True)
            opt_init = x_prime.detach().clone()
            optimizer = torch.optim.AdamW([opt_var], lr=lr)
            for _ in range(max_iters):
                optimizer.zero_grad(set_to_none=True)
                loss = (
                    loss_fn(measurement, operator_fn(opt_var)) * y_coef
                    + loss_fn(opt_init, opt_var) * x_coef
                )
                if tv_coef != 0.0:
                    loss = loss + tv_coef * self._tv_loss(
                        opt_var, dh_coef=dh_coef, dw_coef=dw_coef
                    )
                loss.backward()
                optimizer.step()
                if loss.item() < eps:
                    break
        return opt_var.detach()

    def _stochastic_resample(
        self,
        pseudo_x0: torch.Tensor,
        x_t: torch.Tensor,
        a_t: torch.Tensor,
        sigma: torch.Tensor,
        generator: torch.Generator | list[torch.Generator] | None,
    ) -> torch.Tensor:
        """Sample between ``sqrt(a_t) * pseudo_x0`` and ``x_t``.

        The returned sample has:

        * mean ``(sigma * sqrt(a_t) * pseudo_x0 + (1 - a_t) * x_t) / (sigma + 1 - a_t)``;
        * std ``sqrt(1 / (1/sigma + 1/(1 - a_t)))``.

        ``sigma`` and ``1 - a_t`` are clamped to at least 1e-12.
        """
        sigma = torch.clamp(sigma, min=1e-12)
        one_minus_a_t = torch.clamp(1.0 - a_t, min=1e-12)
        noise = self._randn_like_sample(pseudo_x0, generator)
        return (
            (sigma * a_t.sqrt() * pseudo_x0 + one_minus_a_t * x_t)
            / (sigma + one_minus_a_t)
            + noise * torch.sqrt(1.0 / (1.0 / sigma + 1.0 / one_minus_a_t))
        )

    def _resample_latents(
        self,
        latents: torch.Tensor,
        pseudo_x0: torch.Tensor,
        step_stats: dict[str, torch.Tensor],
        measurement: torch.Tensor,
        operator_fn: Callable[[torch.Tensor], torch.Tensor],
        params: dict[str, Any],
        sigma_a: float,
        generator: torch.Generator | list[torch.Generator] | None,
    ) -> torch.Tensor:
        """Pull one DDIM proposal toward measurement consistency.

        The procedure runs in four stages:

        1. Decode ``pseudo_x0`` to pixels.
        2. Fit the pixels to ``measurement`` through ``operator_fn``.
        3. Re-encode the fitted pixels.
        4. Stochastically resample against ``latents``, with a step-dependent
           weight ``sigma = sigma_a * (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)``.
        """
        x_t_reference = latents.detach().clone()
        a_t = step_stats["a_t"]
        a_prev = step_stats["a_prev"]
        sigma = sigma_a * (1.0 - a_prev) / (1.0 - a_t)
        sigma = sigma * (1.0 - a_t / a_prev)
        sigma = torch.clamp(sigma, min=1e-12)
        with torch.no_grad():
            pseudo_x0_pixel = self.vq_model.decode(
                pseudo_x0.detach().to(dtype=self.vq_model.dtype)
            ).sample
        optimized_pixels = self._optimize_pixels(
            pseudo_x0_pixel, measurement, operator_fn, params
        )
        with torch.no_grad():
            optimized_latents = self.vq_model.encode(
                optimized_pixels.to(dtype=self.vq_model.dtype)
            ).latents.to(dtype=self.unet.dtype)
        return self._stochastic_resample(
            optimized_latents,
            x_t_reference,
            a_prev,
            sigma.to(dtype=self.unet.dtype),
            generator,
        )

    # Deliberately not decorated with @torch.no_grad(): pixel optimization
    # needs autograd; the inference-only parts use local no_grad blocks.
    def __call__(
        self,
        dipin: torch.Tensor,
        record: torch.Tensor,
        measurement: torch.Tensor | None = None,
        operator: Any | None = None,
        image: torch.Tensor | None = None,
        num_inference_steps: int = 30,
        seed: int | None = None,
        seeds: list[int] | tuple[int, ...] | torch.Tensor | None = None,
        generator: torch.Generator | list[torch.Generator] | None = None,
        eta: float = 0.01,
        interval: int = 6,
        sigma_a: float = 20.0,
        pixel_optimization_param: dict[str, Any] | None = None,
        last_pixel_optimization_param: dict[str, Any] | None = None,
        quantize_denoised: bool = False,
        output_type: str = "tensor",
    ) -> SeismicImpInvLDDPMPipelineOutput:
        """Run SAII-CLDM sampling (DDIM + measurement-driven resampling).

        Args:
            dipin: input encoded by ``vq_model`` into latent conditioning; its
                batch size sets the number of samples.
            record: input mapped by ``condition_encoder`` to conditioning features.
            measurement: target for ``operator``; defaults to ``record``.
            operator: required forward operator (callable, or an object with a
                callable ``forward``), applied to decoded pixel-space estimates.
            image: optional input; when given, its VQ encode/decode
                reconstruction is returned as ``impedance_reconstructed``.
            num_inference_steps: number of DDIM steps.
            seed: seeds one generator shared by the whole batch.
            seeds: one seed per batch element (takes precedence over ``seed``).
            generator: used when neither ``seed`` nor ``seeds`` is given. A
                fresh ``torch.Generator`` (PyTorch's fixed default seed) is
                used when it is None, so unseeded calls are reproducible.
            eta: DDIM stochasticity; no DDIM noise is drawn when ``eta <= 0``.
            interval: resample every ``interval`` steps (counted from the end).
            sigma_a: scale of the resampling weight.
            pixel_optimization_param: overrides merged onto
                ``_default_pixel_optimization_param()`` for intermediate steps.
            last_pixel_optimization_param: overrides merged onto
                ``_default_last_pixel_optimization_param()`` for the final step.
            quantize_denoised: snap each DDIM x0 prediction to the VQ codebook.
            output_type: ``"np"`` returns NumPy arrays; anything else returns tensors.

        Raises:
            ValueError: if ``operator`` is None, ``interval <= 0``, or ``seeds``
                does not match the batch size.
            TypeError: if ``operator`` is neither callable nor has ``forward``.
        """
        if measurement is None:
            measurement = record
        if operator is None:
            raise ValueError("SAII-CLDM requires a forward `operator`.")
        if interval <= 0:
            raise ValueError("`interval` must be a positive integer.")
        device = self.unet.device
        # Precedence: per-sample seeds > shared seed > caller's generator > fresh one.
        if seeds is not None:
            if isinstance(seeds, torch.Tensor):
                seeds = seeds.detach().cpu().tolist()
            seeds = [int(value) for value in seeds]
            if len(seeds) != dipin.shape[0]:
                raise ValueError(f"Expected {dipin.shape[0]} seeds, got {len(seeds)}")
            generator = [
                torch.Generator(device=device).manual_seed(value) for value in seeds
            ]
        elif seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)
        elif generator is None:
            generator = torch.Generator(device=device)
        with torch.no_grad():
            dipin = dipin.to(device=device, dtype=self.vq_model.dtype)
            record = record.to(device=device, dtype=self.unet.dtype)
            measurement = measurement.to(device=device, dtype=self.unet.dtype)
            impedance_dipin, record_features = self._encode_conditioning(dipin, record)
            conditioning = torch.cat([impedance_dipin, record_features], dim=1)
            impedance_latents = self._randn_like_sample(
                torch.empty(
                    impedance_dipin.shape,
                    device=device,
                    dtype=self.unet.dtype,
                ),
                generator,
            )
        operator_fn = self._get_operator_fn(operator)
        pixel_params = pixel_optimization_param or {}
        # Partial final-step overrides are merged onto the final-step defaults,
        # not onto the intermediate-step defaults used inside _optimize_pixels.
        last_pixel_params = {
            **self._default_last_pixel_optimization_param(),
            **(last_pixel_optimization_param or {}),
        }
        schedule = self._build_ddim_scheduler(self.scheduler, num_inference_steps, device)
        time_range = [int(timestep) for timestep in schedule.timesteps.tolist()]
        resample_start_index = len(time_range) // 4
        for step_idx, timestep in enumerate(time_range):
            # ``index`` counts down and reaches 0 at the final DDIM step.
            index = len(time_range) - step_idx - 1
            with torch.no_grad():
                next_latents, _pred_x0, pseudo_x0, step_stats = self._ddim_step(
                    impedance_latents,
                    conditioning,
                    timestep,
                    schedule,
                    eta,
                    generator,
                    quantize_denoised,
                )
            # Always resample on the final step; otherwise every ``interval``
            # steps while ``index`` has not yet dropped below the last quarter.
            if index == 0 or (
                index >= resample_start_index and index % interval == 0
            ):
                next_latents = self._resample_latents(
                    next_latents,
                    pseudo_x0,
                    step_stats,
                    measurement,
                    operator_fn,
                    last_pixel_params if index == 0 else pixel_params,
                    sigma_a,
                    generator,
                )
            impedance_latents = next_latents.detach()
        with torch.no_grad():
            impedance_samples = self.vq_model.decode(
                impedance_latents.to(dtype=self.vq_model.dtype)
            ).sample
            impedance_reconstructed = None
            if image is not None:
                image = image.to(device=device, dtype=self.vq_model.dtype)
                image_latents = self.vq_model.encode(image).latents
                impedance_reconstructed = self.vq_model.decode(image_latents).sample
        if output_type == "np":
            impedance_samples = impedance_samples.detach().cpu().numpy()
            impedance_latents = impedance_latents.detach().cpu().numpy()
            impedance_dipin = impedance_dipin.detach().cpu().numpy()
            record_features = record_features.detach().cpu().numpy()
            if impedance_reconstructed is not None:
                impedance_reconstructed = impedance_reconstructed.detach().cpu().numpy()
        return SeismicImpInvLDDPMPipelineOutput(
            impedance_samples=impedance_samples,
            impedance_latents=impedance_latents,
            impedance_dipin=impedance_dipin,
            impedance_reconstructed=impedance_reconstructed,
            record_features=record_features,
        )