# 3D OCT Latent Diffusion Model

## Model Overview
This repository contains a 3D Optical Coherence Tomography (OCT) latent diffusion model for volumetric image synthesis. The model operates in the latent space of a learned 3D VAE and uses Rectified Flow for the diffusion process. It is intended for research and development in medical imaging synthesis.
## Architecture

- **Autoencoder:** 3D variational autoencoder (MONAI MAISI-style `AutoencoderKlMaisi`), 1 input channel, 4 latent channels, spatial dimension 3. Encoder/decoder use `num_channels=[64, 128, 256]` and an L1 reconstruction loss with optional perceptual and adversarial terms.
- **Diffusion:** 3D UNet (`DiffusionModelUNetMaisi`) in latent space with 4 input/output channels, `num_channels=[64, 128, 256, 512]`, attention at the two deepest levels, ResBlock up/down sampling, and conditioning on spacing. The noise process is Rectified Flow (`RFlowScheduler`, 1000 steps, continuous time, scale 1.4).
- **Pipeline:** Encode the 3D OCT volume to a latent; normalize latents using dataset-derived mean/std; run Rectified Flow sampling in latent space; denormalize and decode back to image space.
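The Rectified Flow formulation can be illustrated in a few lines of PyTorch. This is a minimal sketch with toy tensors, not the `RFlowScheduler` API: the model is trained to regress the constant velocity along a straight path between data and noise.

```python
import torch

# Rectified Flow interpolates linearly between data x0 and noise x1:
#   x_t = (1 - t) * x0 + t * x1
# and the UNet is trained to predict the constant velocity v = x1 - x0.
def rectified_flow_interpolate(x0, x1, t):
    """Linear interpolation at time t in [0, 1], broadcast over the batch."""
    t = t.view(-1, 1, 1, 1, 1)  # broadcast over (B, C, D, H, W)
    return (1 - t) * x0 + t * x1

# Toy latent batch: 2 samples, 4 latent channels, an 8x16x16 latent grid
x0 = torch.randn(2, 4, 8, 16, 16)   # "clean" latents
x1 = torch.randn_like(x0)           # Gaussian noise
t = torch.rand(2)                   # continuous time, one per sample

x_t = rectified_flow_interpolate(x0, x1, t)
v_target = x1 - x0                  # regression target for the UNet

# At t=0 we recover the data; at t=1 pure noise.
assert torch.allclose(rectified_flow_interpolate(x0, x1, torch.zeros(2)), x0)
assert torch.allclose(rectified_flow_interpolate(x0, x1, torch.ones(2)), x1)
```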
## Training Details

- **Stage 1 (VAE warmup):** 3D OCT VAE trained on patches (e.g. 32×128×128) with L1, perceptual, and KL terms. Best and final checkpoints are provided.
- **Stage 2 (latent statistics):** Per-channel mean and standard deviation of the VAE latents are computed on the training set and saved for normalization during diffusion training and inference.
- **Stage 3 (diffusion, from scratch):** 3D UNet trained from scratch with Rectified Flow in latent space for 50k steps. The best validation checkpoint is released as `diffusion_unet_scratch_best.pt`.
- **Stage 4 (diffusion, finetune):** The same UNet architecture finetuned from the scratch run for 20k steps. Released checkpoints: 15k-step weights (`diffusion_unet_20k_finetune.pt`) and a 10k-step full resume checkpoint for reproducibility (`diffusion_unet_resume_10k.pt`).

Training used AdamW, gradient accumulation, an optional LR scheduler with warmup, EMA, and mixed precision (AMP) on NVIDIA A100 GPUs.
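A minimal sketch of that optimization setup, using a toy linear model in place of the 3D UNet. All hyperparameters here are illustrative, not the released training config:

```python
import copy
import torch

# Toy stand-in for the diffusion UNet; real training uses the 3D MAISI UNet.
model = torch.nn.Linear(16, 16)
ema_model = copy.deepcopy(model)       # EMA copy, updated after each optimizer step
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
accum_steps, ema_decay = 4, 0.999

def ema_update(ema, online, decay):
    with torch.no_grad():
        for p_ema, p in zip(ema.parameters(), online.parameters()):
            p_ema.lerp_(p, 1 - decay)  # p_ema = decay * p_ema + (1 - decay) * p

for step in range(8):
    x = torch.randn(4, 16)
    target = torch.randn(4, 16)
    # Mixed-precision region (on GPU this would be device_type="cuda")
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        loss = torch.nn.functional.mse_loss(model(x), target)
    (loss / accum_steps).backward()     # gradient accumulation
    if (step + 1) % accum_steps == 0:
        opt.step()
        opt.zero_grad(set_to_none=True)
        ema_update(ema_model, model, ema_decay)
```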
## Hardware

Training was conducted on NVIDIA A100 GPUs. Inference can be run on a single GPU with sufficient VRAM for 3D patches (e.g. 32×128×128 or similar).
## Latent Normalization
The diffusion model is trained and should be run with normalized latents. At training time, latents are normalized as z_norm = (z - mean) / std using per-channel statistics computed over the training set. The same normalization must be applied at inference:
- Load `latent_stats/latent_mean.npy` and `latent_stats/latent_std.npy` (shape `[4]` for the 4 latent channels).
- When encoding: `z = ae.encode(...)`, then `z_norm = (z - mean) / std`.
- When decoding after sampling: `z = z_norm * std + mean`, then `x = ae.decode(z)`.

The file `latent_stats/latent_stats_report.json` documents the dataset split, sample count, and per-channel statistics used.
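A minimal normalization round-trip, with made-up per-channel statistics standing in for the released `.npy` files:

```python
import numpy as np
import torch

# Illustrative per-channel stats; in practice load latent_stats/latent_mean.npy
# and latent_stats/latent_std.npy from the release (shape [4]).
mean = np.array([0.1, -0.2, 0.05, 0.0], dtype=np.float32)
std = np.array([0.9, 1.1, 1.0, 0.8], dtype=np.float32)

# Reshape to broadcast over a (B, C, D, H, W) latent tensor
mean_t = torch.from_numpy(mean).view(1, 4, 1, 1, 1)
std_t = torch.from_numpy(std).view(1, 4, 1, 1, 1)

z = torch.randn(1, 4, 8, 16, 16)   # latent as returned by ae.encode(...)
z_norm = (z - mean_t) / std_t      # normalize before diffusion
z_back = z_norm * std_t + mean_t   # denormalize before ae.decode(...)

assert torch.allclose(z, z_back, atol=1e-6)  # round-trip recovers the latent
```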
## Folder Structure

```
hf_release_oct_latent_diffusion/
├── vae/
│   ├── vae_best.pt                      # VAE best validation checkpoint (recommended)
│   └── vae_final.pt                     # VAE final training checkpoint
├── diffusion/
│   ├── diffusion_unet_20k_finetune.pt   # UNet weights at 15k finetune steps (recommended for sampling)
│   ├── diffusion_unet_scratch_best.pt   # UNet best validation from 50k scratch training
│   └── diffusion_unet_resume_10k.pt     # Full resume checkpoint at 10k finetune (optimizer/EMA/step)
├── latent_stats/
│   ├── latent_mean.npy                  # Per-channel latent mean [4]
│   ├── latent_std.npy                   # Per-channel latent std [4]
│   └── latent_stats_report.json
├── configs/
│   └── cfg_oct_snapshot.json            # Training config snapshot (paths may point to local env)
└── README.md
```
## Example: Loading and Sampling (PyTorch)

This requires the same network definitions as the training code (e.g. MONAI MAISI `AutoencoderKlMaisi` and `DiffusionModelUNetMaisi`), plus the Rectified Flow scheduler and your datalist/config paths.
```python
import json
import numpy as np
import torch
from pathlib import Path

# Paths (adjust to your clone or Hugging Face cache)
REPO = Path("hf_release_oct_latent_diffusion")
VAE_CKPT = REPO / "vae" / "vae_best.pt"
UNET_CKPT = REPO / "diffusion" / "diffusion_unet_20k_finetune.pt"
LATENT_MEAN = np.load(REPO / "latent_stats" / "latent_mean.npy")
LATENT_STD = np.load(REPO / "latent_stats" / "latent_std.npy")

# Load VAE
def load_ae_state_dict(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    for k in ["autoencoder_state_dict", "autoencoder", "state_dict", "model"]:
        if isinstance(ckpt, dict) and k in ckpt and isinstance(ckpt[k], dict):
            return ckpt[k]
    return ckpt

ae = ...  # build from config_network_rflow.json autoencoder_def
ae.load_state_dict(load_ae_state_dict(VAE_CKPT), strict=False)
ae.eval()

# Load UNet (weights-only checkpoint)
def load_unet_state_dict(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    for k in ["unet_state_dict", "unet", "state_dict", "model"]:
        if isinstance(ckpt, dict) and k in ckpt and isinstance(ckpt[k], dict):
            return ckpt[k]
    return ckpt

unet = ...  # build from config_network_rflow.json diffusion_unet_def
unet.load_state_dict(load_unet_state_dict(UNET_CKPT), strict=False)
unet.eval()

# Sampling: initialize z_T ~ N(0, I) in normalized space, run Rectified Flow
# steps (e.g. 40), then denormalize z = z_norm * LATENT_STD + LATENT_MEAN
# and decode with ae.decode(z).
# See the training repo for the full scheduler and sampling loop.
```
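The sampling step referenced above can be sketched as plain Euler integration of the learned velocity field from t=1 (noise) down to t=0. This is a generic Rectified Flow sampler with a toy velocity model, not the actual `RFlowScheduler` implementation or the released UNet:

```python
import torch

@torch.no_grad()
def rflow_euler_sample(velocity_model, shape, num_steps=40):
    """Generic Rectified Flow sampler: Euler integration from t=1 (noise) to t=0."""
    z = torch.randn(shape)                         # z_T ~ N(0, I), normalized latent space
    ts = torch.linspace(1.0, 0.0, num_steps + 1)
    for i in range(num_steps):
        t, t_next = ts[i], ts[i + 1]
        v = velocity_model(z, t.expand(shape[0]))  # predicted velocity at time t
        z = z + (t_next - t) * v                   # Euler step (t decreases toward 0)
    return z

# Toy velocity model standing in for the 3D UNet: for data concentrated at the
# origin, the state at time t is z_t = t * noise, so the straight-line velocity
# is v = z_t / t, and sampling should drive every latent back to (near) zero.
def toy_velocity(z, t):
    return z / t

z0 = rflow_euler_sample(toy_velocity, (1, 4, 8, 16, 16), num_steps=40)
assert z0.abs().max() < 1e-4  # integrated back to the toy data distribution
```

Afterwards, denormalize with `z = z0 * LATENT_STD + LATENT_MEAN` and decode with `ae.decode(z)` as described in the normalization section.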
## Citation
If you use this model in your research, please cite the associated work (citation to be added upon publication).