
3D OCT Latent Diffusion Model

Model Overview

This repository contains a 3D Optical Coherence Tomography (OCT) latent diffusion model for volumetric image synthesis. The model operates in the latent space of a learned 3D VAE and uses Rectified Flow for the diffusion process. It is intended for research and development in medical imaging synthesis.

Architecture

  • Autoencoder: 3D variational autoencoder (MONAI MAISI-style AutoencoderKlMaisi) with 1 input channel, 4 latent channels, and 3 spatial dimensions. The encoder/decoder use num_channels [64, 128, 256]; training uses an L1 reconstruction loss with optional perceptual and adversarial terms.
  • Diffusion: 3D UNet (DiffusionModelUNetMaisi) in latent space with 4 input/output channels, num_channels [64, 128, 256, 512], attention at the two deepest levels, ResBlock up/down, and conditioning on spacing. The noise process is Rectified Flow (RFlowScheduler, 1000 steps, continuous time, scale 1.4).
  • Pipeline: During training, 3D OCT volumes are encoded to latents, which are normalized with dataset-derived per-channel mean/std. At sampling time, Rectified Flow generates normalized latents from Gaussian noise; these are denormalized and decoded back to image space.
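The Rectified Flow process can be illustrated with a minimal PyTorch sketch of a straight-line forward process and a fixed-step Euler sampler. It assumes the convention x_t = (1 - t)·x0 + t·noise, with t running from 1 (noise) to 0 (data), and a network that predicts the velocity x0 - noise. The released RFlowScheduler configuration (1000 training timesteps, scale 1.4) and the spacing conditioning are not reproduced; the function names are hypothetical and `velocity_fn` stands in for the UNet.

```python
import torch


def rflow_interpolate(x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Straight-line forward process x_t = (1 - t) * x0 + t * noise, with one t in [0, 1] per sample."""
    t = t.view(-1, *([1] * (x0.dim() - 1)))  # broadcast [B] over [B, C, D, H, W]
    return (1.0 - t) * x0 + t * noise


def rflow_velocity_target(x0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Regression target under this convention: constant velocity pointing from noise to data."""
    return x0 - noise


@torch.no_grad()
def rflow_euler_sample(velocity_fn, shape, num_steps=40, generator=None):
    """Integrate from t=1 (pure noise) to t=0 (data) with equal-size Euler steps."""
    x = torch.randn(shape, generator=generator)
    ts = torch.linspace(1.0, 0.0, num_steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        t_batch = torch.full((shape[0],), float(t_cur))
        v = velocity_fn(x, t_batch)   # predicted x0 - noise
        x = x + (t_cur - t_next) * v  # move toward the data end of the path
    return x
```

Because the path is a straight line, an exact velocity such as (x0 - x_t) / t lets the Euler sampler land on x0 for any step count; a learned velocity is only approximate, which is why the number of sampling steps (e.g. 40) still matters.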

Training Details

  • Stage 1 – VAE warmup: 3D OCT VAE trained on patches (e.g. 32×128×128) with L1, perceptual, and KL terms. Best and final checkpoints are provided.
  • Stage 2 – Latent statistics: Per-channel mean and standard deviation of the VAE latents are computed on the training set and saved for normalization during diffusion training and inference.
  • Stage 3 – Diffusion (scratch): 3D UNet trained from scratch with Rectified Flow in latent space for 50k steps. Best validation checkpoint is released as diffusion_unet_scratch_best.pt.
  • Stage 4 – Diffusion (finetune): Same UNet architecture, finetuned from the scratch run for 20k steps. Released checkpoints: diffusion_unet_20k_finetune.pt, which despite its name holds the weights at 15k steps, and diffusion_unet_resume_10k.pt, a full resume checkpoint at 10k steps for reproducibility.
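Stage 1's objective can be sketched as an L1 reconstruction term plus the closed-form KL divergence between the encoder's diagonal Gaussian and N(0, I). This is a minimal sketch under stated assumptions: `kl_weight` is a hypothetical value, the perceptual (and optional adversarial) terms are omitted, and the actual loss weighting lives in the training config.

```python
import torch
import torch.nn.functional as F


def kl_diag_gaussian(z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent elements and averaged over the batch."""
    kl = 0.5 * (z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1.0)
    return kl.flatten(1).sum(dim=1).mean()


def vae_loss(x, x_recon, z_mu, z_sigma, kl_weight=1e-6):
    """L1 reconstruction plus weighted KL; kl_weight is a hypothetical placeholder."""
    recon = F.l1_loss(x_recon, x)
    return recon + kl_weight * kl_diag_gaussian(z_mu, z_sigma)
```

The KL term is zero exactly when the posterior is N(0, I), so a small weight mainly keeps the latent scale bounded rather than forcing a strict prior match.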

Training used AdamW, gradient accumulation, optional LR scheduler with warmup, EMA, and mixed precision (AMP) on NVIDIA A100 GPUs.
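The optimization setup in the paragraph above can be sketched with AdamW, gradient accumulation, a linear LR warmup, and an EMA copy of the weights updated after every optimizer step. The learning rate, warmup length, accumulation count, and EMA decay are hypothetical placeholders (the real values are presumably recorded in configs/cfg_oct_snapshot.json), the helper names are illustrative, and AMP is left out so the sketch runs on CPU.

```python
import copy

import torch


def make_warmup_scheduler(optimizer, warmup_steps):
    """Linear warmup from base_lr / warmup_steps up to base_lr, then constant."""
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: min(1.0, (step + 1) / warmup_steps)
    )


@torch.no_grad()
def ema_update(ema_model, model, decay):
    """ema <- decay * ema + (1 - decay) * current, for parameters only (buffers are not averaged)."""
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.lerp_(p, 1.0 - decay)


def train_steps(model, loss_fn, batches, accum_steps=4, lr=1e-4, warmup_steps=100, ema_decay=0.999):
    """AdamW with gradient accumulation; returns the EMA model and the number of optimizer steps."""
    ema_model = copy.deepcopy(model).requires_grad_(False)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    sched = make_warmup_scheduler(opt, warmup_steps)
    opt.zero_grad(set_to_none=True)
    n_steps = 0
    for i, batch in enumerate(batches, start=1):
        # Scale so the accumulated gradient is the mean over accum_steps micro-batches
        loss = loss_fn(model, batch) / accum_steps
        loss.backward()
        if i % accum_steps == 0:
            opt.step()
            sched.step()
            opt.zero_grad(set_to_none=True)
            ema_update(ema_model, model, ema_decay)
            n_steps += 1
    # Trailing micro-batches that do not fill an accumulation window are not applied
    return ema_model, n_steps
```

Sampling from the EMA weights rather than the raw weights is the usual reason to keep the EMA copy; whether the released UNet file holds EMA or raw weights is not stated here.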

Hardware

Training was conducted on NVIDIA A100 GPUs. Inference can be run on a single GPU with sufficient VRAM for 3D patches (e.g. 32×128×128 or similar).
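For a sense of scale, the sketch below computes the latent shape for a given patch. It assumes the VAE halves every spatial dimension once per encoder level except the last (the layout of MONAI's AutoencoderKL family), which gives a factor of 4 for num_channels [64, 128, 256]. `latent_shape` is a hypothetical helper, and the numbers describe the input and latent tensors only, not the intermediate activations that dominate VRAM.

```python
from math import prod


def latent_shape(patch_dhw, num_channels=(64, 128, 256), latent_channels=4):
    """Latent [C, D, H, W] for a [D, H, W] patch, assuming one 2x downsample per level except the last."""
    factor = 2 ** (len(num_channels) - 1)
    if any(s % factor for s in patch_dhw):
        raise ValueError(f"patch size {patch_dhw} is not divisible by {factor}")
    return (latent_channels, *(s // factor for s in patch_dhw))


patch = (32, 128, 128)
lat = latent_shape(patch)
ratio = prod(patch) / prod(lat)  # single-channel image voxels per latent element
print(lat, ratio)  # (4, 8, 32, 32) 16.0
```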

Latent Normalization

The diffusion model was trained on normalized latents and must be sampled in the same normalized space. At training time, latents are normalized as z_norm = (z - mean) / std using per-channel statistics computed over the training set. The same normalization must be applied at inference:

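  • Load latent_stats/latent_mean.npy and latent_stats/latent_std.npy (shape [4], one value per latent channel) and reshape them to [1, 4, 1, 1, 1] so they broadcast over the channel axis of [B, 4, D, H, W] latents.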
  • When encoding: z = ae.encode(...) then z_norm = (z - mean) / std.
  • When decoding after sampling: z = z_norm * std + mean then x = ae.decode(z).
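A minimal sketch of the statistics and the normalize/denormalize round trip, assuming latents are [C, D, H, W] tensors and that mean and standard deviation pool every voxel of every training latent per channel. How the released statistics were actually aggregated is recorded in latent_stats_report.json; the helper names are hypothetical.

```python
import numpy as np
import torch


def compute_latent_stats(latents):
    """Per-channel mean/std over an iterable of [C, D, H, W] latents, pooling all voxels."""
    flat = torch.cat([z.flatten(1) for z in latents], dim=1)  # [C, total_voxels]
    return flat.mean(dim=1).numpy(), flat.std(dim=1).numpy()


def as_channel_stats(arr: np.ndarray, like: torch.Tensor) -> torch.Tensor:
    """Reshape a [C] array to [1, C, 1, 1, 1] on the latent's device and dtype."""
    return torch.as_tensor(arr, dtype=like.dtype, device=like.device).view(1, -1, 1, 1, 1)


def normalize(z, mean, std):
    return (z - as_channel_stats(mean, z)) / as_channel_stats(std, z)


def denormalize(z_norm, mean, std):
    return z_norm * as_channel_stats(std, z_norm) + as_channel_stats(mean, z_norm)
```

Without the reshape, a bare [4] array would broadcast against the last (W) axis of a [B, 4, D, H, W] tensor, silently producing wrong values or a shape error.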

The file latent_stats/latent_stats_report.json documents the dataset split, sample count, and per-channel statistics used.

Folder Structure

hf_release_oct_latent_diffusion/
├── vae/
│   ├── vae_best.pt          # VAE best validation checkpoint (recommended)
│   └── vae_final.pt         # VAE final training checkpoint
├── diffusion/
│   ├── diffusion_unet_20k_finetune.pt   # UNet weights at 15k finetune steps (recommended for sampling)
│   ├── diffusion_unet_scratch_best.pt   # UNet best validation from 50k scratch training
│   └── diffusion_unet_resume_10k.pt     # Full resume checkpoint at 10k finetune (optimizer/EMA/step)
├── latent_stats/
│   ├── latent_mean.npy      # Per-channel latent mean [4]
│   ├── latent_std.npy       # Per-channel latent std [4]
│   └── latent_stats_report.json
├── configs/
│   └── cfg_oct_snapshot.json   # Training config snapshot (paths may point to local env)
└── README.md

Example: Loading and Sampling (PyTorch)

Requires the same network definitions as in the training code (e.g. MONAI MAISI AutoencoderKlMaisi and DiffusionModelUNetMaisi), plus the Rectified Flow scheduler and your datalist/config paths.

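import numpy as np
import torch
from pathlib import Path
from monai.bundle import ConfigParser

# Paths (adjust to your clone or Hugging Face cache)
REPO = Path("hf_release_oct_latent_diffusion")
VAE_CKPT = REPO / "vae" / "vae_best.pt"
UNET_CKPT = REPO / "diffusion" / "diffusion_unet_20k_finetune.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel latent stats have shape [4]; reshape to [1, 4, 1, 1, 1] so they
# broadcast over the channel axis of [B, C, D, H, W] latents (a bare [4] array
# would broadcast over the last spatial axis instead).
def load_channel_stat(path):
    return torch.as_tensor(np.load(path), dtype=torch.float32, device=DEVICE).view(1, -1, 1, 1, 1)

LATENT_MEAN = load_channel_stat(REPO / "latent_stats" / "latent_mean.npy")
LATENT_STD = load_channel_stat(REPO / "latent_stats" / "latent_std.npy")

def load_state_dict_from(ckpt_path, container_keys):
    """Return the first nested state dict found under container_keys, else the checkpoint itself."""
    # weights_only=False unpickles arbitrary objects: only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    for k in container_keys:
        if isinstance(ckpt, dict) and isinstance(ckpt.get(k), dict):
            return ckpt[k]
    return ckpt

# Network definitions come from config_network_rflow.json in the training code
# (not part of this release); this assumes it is a MONAI bundle-style config
# with autoencoder_def and diffusion_unet_def entries.
net_config = ConfigParser()
net_config.read_config("config_network_rflow.json")

# Load VAE
ae = net_config.get_parsed_content("autoencoder_def")
keys = ae.load_state_dict(
    load_state_dict_from(VAE_CKPT, ["autoencoder_state_dict", "autoencoder", "state_dict", "model"]),
    strict=False,
)
# strict=False silently skips mismatched keys, so report what was not loaded
print("VAE missing:", keys.missing_keys, "unexpected:", keys.unexpected_keys)
ae.to(DEVICE).eval()

# Load UNet (weights-only checkpoint)
unet = net_config.get_parsed_content("diffusion_unet_def")
keys = unet.load_state_dict(
    load_state_dict_from(UNET_CKPT, ["unet_state_dict", "unet", "state_dict", "model"]),
    strict=False,
)
print("UNet missing:", keys.missing_keys, "unexpected:", keys.unexpected_keys)
unet.to(DEVICE).eval()

# Sampling: draw z_norm ~ N(0, I) in normalized latent space, run the Rectified Flow
# steps (e.g. 40), then denormalize and decode:
#   z = z_norm * LATENT_STD + LATENT_MEAN
#   x = ae.decode(z)
# See the training repo for the full scheduler and sampling loop.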
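If load_state_dict(..., strict=False) reports many missing keys, a common cause is a wrapper prefix such as `module.` (left by DistributedDataParallel) or `_orig_mod.` (left by torch.compile). Whether the released checkpoints carry such prefixes is not documented here; the following is a hypothetical helper that unwraps a container key, strips these prefixes, and then loads strictly so any remaining mismatch fails loudly.

```python
import torch


def extract_state_dict(ckpt, container_keys=("state_dict", "model")):
    """Return the first nested dict found under container_keys, else the checkpoint itself."""
    for k in container_keys:
        if isinstance(ckpt, dict) and isinstance(ckpt.get(k), dict):
            return ckpt[k]
    return ckpt


def strip_prefixes(state_dict, prefixes=("module.", "_orig_mod.")):
    """Drop wrapper prefixes (DDP's 'module.', torch.compile's '_orig_mod.') from every key."""
    out = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        out[key] = value
    return out


def load_checkpoint_strict(model, path, container_keys=("state_dict", "model")):
    """Load a trusted checkpoint and raise on any missing or unexpected key."""
    ckpt = torch.load(path, map_location="cpu", weights_only=False)  # unpickles: trusted files only
    model.load_state_dict(strip_prefixes(extract_state_dict(ckpt, container_keys)), strict=True)
    return model
```

For example, load_checkpoint_strict(unet, UNET_CKPT, ("unet_state_dict", "unet", "state_dict", "model")) would either load every UNet weight or raise with the list of offending keys.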

Citation

If you use this model in your research, please cite the associated work (citation to be added upon publication).
