diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..e0fc2128b3abb03b579f96f612179ed758a9087a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/example_1.png filter=lfs diff=lfs merge=lfs -text
+assets/generated_images.png filter=lfs diff=lfs merge=lfs -text
+assets/reconstructed.png filter=lfs diff=lfs merge=lfs -text
+assets/teaser.png filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6768d3d8c3e16a0313dc7800eb5cad1241f2eb43
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Maitreya Patel
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 4304c3dea1a8ee2d6b911796730dac72685d98b8..a2ff51f361c0a484aa015443eb09a370c6811b68 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,179 @@
+# [CVPR 2026] VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations
+
+CVPR 2026 | Paper | Project Page | Checkpoints
+
---
-title: VibeToken
-emoji: 🦀
-colorFrom: blue
-colorTo: red
-sdk: gradio
-sdk_version: 6.6.0
-python_version: '3.12'
-app_file: app.py
-pinned: false
-license: mit
----
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+We introduce an efficient, resolution-agnostic autoregressive (AR) approach to image synthesis that generalizes to **arbitrary resolutions and aspect ratios**, narrowing the gap to diffusion models at scale. At its core is **VibeToken**, a novel resolution-agnostic 1D Transformer-based image tokenizer that encodes an image into a dynamic, user-controllable sequence of 32-256 tokens and achieves a state-of-the-art trade-off between efficiency and performance. Building on VibeToken, we present **VibeToken-Gen**, a class-conditioned AR generator that supports arbitrary resolutions out of the box while requiring significantly less compute.
+
+### 🔥 Highlights
+
+| | |
+|---|---|
+| 🎯 **1024×1024 in just 64 tokens** | Achieves **3.94 gFID** vs. 5.87 gFID for diffusion-based SOTA (1,024 tokens) |
+| ⚡ **Constant 179G FLOPs** | 63× more efficient than LlamaGen (11T FLOPs at 1024×1024) |
+| 🌐 **Resolution-agnostic** | Supports arbitrary resolutions and aspect ratios out of the box |
+| 🎛️ **Dynamic token count** | User-controllable 32-256 tokens per image |
+| 🔍 **Native super-resolution** | Supports image super-resolution out of the box |
+
+
+## 📰 News
+
+- **[Feb 2026]** 🎉 VibeToken is accepted at **CVPR 2026**!
+- **[Feb 2026]** Training scripts released.
+- **[Feb 2026]** Inference code and checkpoints released.
+
+
+## 🚀 Quick Start
+
+```bash
+# 1. Clone and setup
+git clone https://github.com//VibeToken.git
+cd VibeToken
+uv venv --python=3.11.6
+source .venv/bin/activate
+uv pip install -r requirements.txt
+
+# 2. Download a checkpoint (see Checkpoints section below)
+mkdir -p checkpoints
+wget https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_LL.bin -O ./checkpoints/VibeToken_LL.bin
+
+# 3. Reconstruct an image
+python reconstruct.py --auto \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint ./checkpoints/VibeToken_LL.bin \
+ --image ./assets/example_1.png \
+ --output ./assets/reconstructed.png
+```
+
+
+## 📦 Checkpoints
+
+All checkpoints are hosted on [Hugging Face](https://huggingface.co/mpatel57/VibeToken).
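+
+If you prefer to fetch files programmatically rather than with `wget`, the same `huggingface_hub` helper that `app.py` uses works directly; the file names are the ones listed in the tables below:
+
+```python
+from huggingface_hub import hf_hub_download
+
+# Files are cached under ~/.cache/huggingface and the local path is returned.
+vq_path = hf_hub_download(repo_id="mpatel57/VibeToken", filename="VibeToken_LL.bin")
+gpt_path = hf_hub_download(repo_id="mpatel57/VibeToken", filename="VibeTokenGen-xxl-dynamic-65_750k.pt")
+print(vq_path, gpt_path)
+```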
+
+#### Reconstruction Checkpoints
+
+| Name | Resolution | rFID (256 tokens) | rFID (64 tokens) | Download |
+|------|:----------:|:-----------------:|:----------------:|----------|
+| VibeToken-LL | 1024×1024 | 3.76 | 4.12 | [VibeToken_LL.bin](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_LL.bin) |
+| VibeToken-LL | 256×256 | 5.12 | 0.90 | same as above |
+| VibeToken-SL | 1024×1024 | 4.25 | 2.41 | [VibeToken_SL.bin](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_SL.bin) |
+| VibeToken-SL | 256×256 | 5.44 | 0.40 | same as above |
+
+#### Generation Checkpoints
+
+| Name | Training Resolution(s) | Tokens | Best gFID | Download |
+|------|:----------------------:|:------:|:---------:|----------|
+| VibeToken-Gen-B | 256×256 | 65 | 7.62 | [VibeTokenGen-b-fixed65_dynamic_1500k.pt](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeTokenGen-b-fixed65_dynamic_1500k.pt) |
+| VibeToken-Gen-B | 1024×1024 | 65 | 7.37 | same as above |
+| VibeToken-Gen-XXL | 256×256 | 65 | 3.62 | [VibeTokenGen-xxl-dynamic-65_750k.pt](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeTokenGen-xxl-dynamic-65_750k.pt) |
+| VibeToken-Gen-XXL | 1024×1024 | 65 | **3.54** | same as above |
+
+
+## 🛠️ Setup
+
+```bash
+uv venv --python=3.11.6
+source .venv/bin/activate
+uv pip install -r requirements.txt
+```
+
+> **Tip:** If you don't have `uv`, install it via `pip install uv` or see [uv docs](https://github.com/astral-sh/uv). Alternatively, use `python -m venv .venv && pip install -r requirements.txt`.
+
+
+## 🖼️ VibeToken Reconstruction
+
+Download the VibeToken-LL checkpoint (see [Checkpoints](#-checkpoints)), then:
+
+```bash
+# Auto mode (recommended) -- automatically determines optimal patch sizes
+python reconstruct.py --auto \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint ./checkpoints/VibeToken_LL.bin \
+ --image ./assets/example_1.png \
+ --output ./assets/reconstructed.png
+
+# Manual mode -- specify patch sizes explicitly
+python reconstruct.py \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint ./checkpoints/VibeToken_LL.bin \
+ --image ./assets/example_1.png \
+ --output ./assets/reconstructed.png \
+ --encoder_patch_size 16 \
+ --decoder_patch_size 16
+```
+
+> **Note:** For best performance, the input image resolution should be a multiple of 32. Images with other resolutions are automatically rescaled to the nearest multiple of 32.
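+
+If you would rather pre-resize images yourself instead of relying on the automatic rescaling, a small illustrative helper (not part of the repo) that rounds a PIL image to the nearest multiple of 32 looks like this:
+
+```python
+from PIL import Image
+
+def round_to_multiple_of_32(img: Image.Image) -> Image.Image:
+    # Round each side to the nearest multiple of 32 (and never below 32).
+    w, h = img.size
+    new_w = max(32, round(w / 32) * 32)
+    new_h = max(32, round(h / 32) * 32)
+    return img.resize((new_w, new_h), Image.Resampling.BICUBIC)  # Pillow >= 9.1
+
+round_to_multiple_of_32(Image.open("./assets/example_1.png")).save("./assets/example_1_32px.png")
+```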
+
+
+## 🎨 VibeToken-Gen: ImageNet-1k Generation
+
+Download both the VibeToken-LL and VibeToken-Gen-XXL checkpoints (see [Checkpoints](#-checkpoints)), then:
+
+```bash
+python generate.py \
+ --gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \
+ --gpt-model GPT-XXL --num-output-layer 4 \
+ --num-codebooks 8 --codebook-size 32768 \
+ --image-size 256 --cfg-scale 4.0 --top-k 500 --temperature 1.0 \
+ --class-dropout-prob 0.1 \
+ --extra-layers "QKV" \
+ --latent-size 65 \
+ --config ./configs/vibetoken_ll.yaml \
+ --vq-ckpt ./checkpoints/VibeToken_LL.bin \
+ --sample-dir ./assets/ \
+ --skip-folder-creation \
+ --compile \
+ --decoder-patch-size 32,32 \
+ --target-resolution 1024,1024 \
+ --llamagen-target-resolution 256,256 \
+ --precision bf16 \
+ --global-seed 156464151
+```
+
+The `--target-resolution` flag sets the tokenizer's output resolution, while `--llamagen-target-resolution` sets the generator's internal resolution (at most 512×512); for higher output resolutions, the tokenizer handles the upscaling.
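+
+The same split shows up in the Python API that `app.py` uses: the AR generator is sampled at a small internal resolution, and the tokenizer then decodes the resulting token indices at whatever output resolution you request. Below is a condensed, untested sketch assembled from `app.py` (checkpoint paths and hyperparameters mirror the command above; see `app.py` for the full checkpoint-key handling and error checks):
+
+```python
+import torch
+from vibetoken import VibeTokenTokenizer
+from vibetokengen.generate import generate
+from vibetokengen.model import GPT_models
+
+device, dtype = "cuda", torch.bfloat16
+vq = VibeTokenTokenizer.from_config(
+    "configs/vibetoken_ll.yaml", "./checkpoints/VibeToken_LL.bin", device=device, dtype=dtype)
+gpt = GPT_models["GPT-XXL"](
+    vocab_size=32768, block_size=65, num_classes=1000, cls_token_num=1,
+    model_type="c2i", num_codebooks=8, n_output_layer=4,
+    class_dropout_prob=0.1, extra_layers="QKV", capping=50.0,
+).to(device=device, dtype=dtype)
+ckpt = torch.load("./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt", map_location="cpu", weights_only=False)
+gpt.load_state_dict(ckpt.get("model", ckpt), strict=True)
+gpt.eval()
+
+cls = torch.tensor([207], device=device)  # ImageNet class 207 (golden retriever)
+# Generator resolution 256x256, normalized by 1536 and repeated for the CFG cond/uncond batch.
+th = torch.tensor(256 / 1536, device=device, dtype=dtype).unsqueeze(0).repeat(2, 1)
+tw = torch.tensor(256 / 1536, device=device, dtype=dtype).unsqueeze(0).repeat(2, 1)
+with torch.inference_mode():
+    idx = generate(gpt, cls, 65, 8, cfg_scale=4.0, cfg_interval=-1,
+                   target_h=th, target_w=tw, temperature=1.0,
+                   top_k=500, top_p=1.0, sample_logits=True)
+    # Decode the same 65 tokens at 1024x1024 (native super-resolution).
+    img = vq.decode(idx.unsqueeze(2), height=1024, width=1024, patch_size=(32, 32))
+```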
+
+
+## 🏋️ Training
+
+To train the VibeToken tokenizer from scratch, please refer to [TRAIN.md](TRAIN.md) for detailed instructions.
+
+
+## 🙏 Acknowledgement
+
+We would like to acknowledge the following repositories that inspired our work and upon which we directly build:
+[1d-tokenizer](https://github.com/bytedance/1d-tokenizer),
+[LlamaGen](https://github.com/FoundationVision/LlamaGen), and
+[UniTok](https://github.com/FoundationVision/UniTok).
+
+
+## 📝 Citation
+
+If you find VibeToken useful in your research, please consider citing:
+
+```bibtex
+@inproceedings{vibetoken2026,
+ title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations},
+ author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan},
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year = {2026}
+}
+```
+
+If you have any questions, feel free to open an issue or reach out!
diff --git a/TRAIN.md b/TRAIN.md
new file mode 100644
index 0000000000000000000000000000000000000000..3ca4fa9a9ba049934fcc98064346b794e26a59fe
--- /dev/null
+++ b/TRAIN.md
@@ -0,0 +1,110 @@
+# Training Instructions
+
+## VibeToken MVQ Tokenizer
+
+This repository contains the training code for our tokenizer.
+We provide an example config, [VibeToken-Small](configs/training/VibeToken_small.yaml), which trains the small encoder/decoder architecture with 32-64 tokens.
+
+### Data Preparation
+
+All data paths are controlled by the `DATA_DIR` environment variable. Set it once to point to your preferred storage location:
+
+```bash
+export DATA_DIR=/path/to/your/storage # defaults to ./data if unset
+```
+
+Download ImageNet-1k and convert to WebDataset format:
+
+```bash
+source .venv/bin/activate
+
+# Option 1: Use the setup script (recommended)
+bash setup.sh
+
+# Option 2: Run steps manually
+export HF_HUB_ENABLE_HF_TRANSFER=1
+huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir "${DATA_DIR}/imagenet-1k"
+python data/convert_imagenet_to_wds.py \
+ --input_dir "${DATA_DIR}/imagenet-1k" \
+ --output_dir "${DATA_DIR}/imagenet_wds"
+```
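+
+To sanity-check the shards before launching a run, you can iterate over one of them directly with `webdataset`; this illustrative snippet assumes the `jpg`/`cls` keys written by `data/convert_imagenet_to_wds.py` (adjust the shard name to whatever the converter produced):
+
+```python
+import os
+import webdataset as wds
+
+shard = os.path.join(os.environ.get("DATA_DIR", "./data"),
+                     "imagenet_wds", "imagenet-train-000000.tar")
+ds = wds.WebDataset(shard).decode("pil")
+for i, sample in enumerate(ds):
+    print(sample["__key__"], sample["cls"], sample["jpg"].size)
+    if i == 4:
+        break
+```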
+
+After preparation, update the shard paths in your training config to match your `DATA_DIR`:
+
+```yaml
+dataset:
+ params:
+    train_shards_path_or_url: "/path/to/your/storage/imagenet_wds/imagenet-train-{000001..000128}.tar"
+    eval_shards_path_or_url: "/path/to/your/storage/imagenet_wds/imagenet-val-{000001..000004}.tar"
+```
+
+### Launch Training
+
+Start training on 1 node with 8 GPUs:
+
+```bash
+source .venv/bin/activate
+bash train_tokenizer.sh
+```
+
+### Config Reference
+
+Below are the key hyperparameters that control training.
+
+```yaml
+model:
+ vq_model:
+ vit_enc_model_size: "small" # this can be small/base/large
+ vit_dec_model_size: "small" # this can be small/base/large
+ num_latent_tokens: 64 # in paper we set this to 256
+
+losses:
+ discriminator_start: 100_000 # set based on convergence, in paper we set this to 250_000
+
+dataset:
+ params:
+ pretokenization: True # keep this true if using the current setup
+ train_shards_path_or_url: "./data/imagenet_wds/imagenet-train-{000001..000128}.tar"
+ eval_shards_path_or_url: "./data/imagenet_wds/imagenet-val-{000001..000004}.tar"
+ preprocessing:
+ resize_shorter_edge: 512 # maximum size during pretraining but can be any value
+ crop_size: 512 # maximum size during pretraining but can be any value
+ min_tokens: 32 # minimum number of tokens to generate
+ max_tokens: 64 # maximum number of tokens to generate
+
+training:
+  gradient_accumulation_steps: 1 # increase for the LL model if it does not fit on a single node
+  per_gpu_batch_size: 32 # decrease to 16 for the LL model; this is halved during GAN training
+  max_train_steps: 400_000 # in the paper we train up to 650_000; the model may diverge after 600_000
+ num_generated_images: 2 # for validation
+ variable_resolution: # any-to-any resolution training
+ any2any: True
+ dim:
+ - [256, 256]
+ - [512, 512]
+ - [384, 256]
+ - [256, 384]
+ - [512, 384]
+ - [384, 512]
+ ratio: [0.3, 0.3, 0.1, 0.1, 0.1, 0.1] # probability per resolution; must sum to 1.0
+
+
+# Only add the patch mixture parameters below if the model does not fit in memory.
+# They slow down training and may hurt performance.
+# We do not use them in our normal setup.
+model:
+ vq_model:
+ encoder:
+ patch_mixture_start_layer: 2
+ patch_mixture_end_layer: 22
+ decoder:
+ patch_mixture_start_layer: 2
+ patch_mixture_end_layer: 22
+```
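+
+For intuition, the `variable_resolution` block means that during training a target `(height, width)` is drawn from `dim` with the corresponding probability in `ratio`; a minimal illustration of that sampling (not the actual training code) is:
+
+```python
+import random
+
+dims = [(256, 256), (512, 512), (384, 256), (256, 384), (512, 384), (384, 512)]
+ratio = [0.3, 0.3, 0.1, 0.1, 0.1, 0.1]  # must sum to 1.0
+
+# Draw the target resolution for one batch.
+h, w = random.choices(dims, weights=ratio, k=1)[0]
+print(h, w)
+```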
+
+
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ba34a59d354c11c378fe362ed38acfc0778908a
--- /dev/null
+++ b/app.py
@@ -0,0 +1,471 @@
+"""
+VibeToken-Gen Gradio Demo
+Class-conditional ImageNet generation with dynamic resolution support.
+"""
+import spaces
+
+import os
+import random
+
+import gradio as gr
+import numpy as np
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.set_float32_matmul_precision("high")
+torch.set_grad_enabled(False)
+setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from vibetokengen.generate import generate
+from vibetokengen.model import GPT_models
+from vibetoken import VibeTokenTokenizer
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+
+HF_REPO = "mpatel57/VibeToken"
+USE_XXL = os.environ.get("VIBETOKEN_XXL", "0") == "1"
+
+if USE_XXL:
+ GPT_MODEL_NAME = "GPT-XXL"
+ GPT_CKPT_FILENAME = "VibeTokenGen-xxl-dynamic-65_750k.pt"
+ NUM_OUTPUT_LAYER = 4
+ EXTRA_LAYERS = "QKV"
+else:
+ GPT_MODEL_NAME = "GPT-B"
+ GPT_CKPT_FILENAME = "VibeTokenGen-b-fixed65_dynamic_1500k.pt"
+ NUM_OUTPUT_LAYER = 4
+ EXTRA_LAYERS = "QKV"
+
+VQ_CKPT_FILENAME = "VibeToken_LL.bin"
+CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "vibetoken_ll.yaml")
+
+CODEBOOK_SIZE = 32768
+NUM_CODEBOOKS = 8
+LATENT_SIZE = 65
+NUM_CLASSES = 1000
+CLS_TOKEN_NUM = 1
+CLASS_DROPOUT_PROB = 0.1
+CAPPING = 50.0
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
+COMPILE = os.environ.get("VIBETOKEN_NO_COMPILE", "0") != "1" and DEVICE == "cuda"
+
+# ---------------------------------------------------------------------------
+# ImageNet class labels (curated popular subset)
+# ---------------------------------------------------------------------------
+
+IMAGENET_CLASSES = {
+ "Golden Retriever": 207,
+ "Labrador Retriever": 208,
+ "German Shepherd": 235,
+ "Siberian Husky": 250,
+ "Pembroke Corgi": 263,
+ "Tabby Cat": 281,
+ "Persian Cat": 283,
+ "Siamese Cat": 284,
+ "Tiger": 292,
+ "Lion": 291,
+ "Cheetah": 293,
+ "Brown Bear": 294,
+ "Giant Panda": 388,
+ "Red Fox": 277,
+ "Arctic Fox": 279,
+ "Timber Wolf": 269,
+ "Bald Eagle": 22,
+ "Macaw": 88,
+ "Flamingo": 130,
+ "Peacock": 84,
+ "Goldfish": 1,
+ "Great White Shark": 2,
+ "Jellyfish": 107,
+ "Monarch Butterfly": 323,
+ "Ladybug": 301,
+ "Snail": 113,
+ "Red Sports Car": 817,
+ "School Bus": 779,
+ "Steam Locomotive": 820,
+ "Sailboat": 914,
+ "Space Shuttle": 812,
+ "Castle": 483,
+ "Church": 497,
+ "Lighthouse": 437,
+ "Volcano": 980,
+ "Lakeside": 975,
+ "Cliff": 972,
+ "Coral Reef": 973,
+ "Valley": 979,
+ "Seashore": 978,
+ "Mushroom": 947,
+ "Broccoli": 937,
+ "Pizza": 963,
+ "Ice Cream": 928,
+ "Cheeseburger": 933,
+ "Espresso": 967,
+ "Acoustic Guitar": 402,
+ "Grand Piano": 579,
+ "Violin": 889,
+ "Balloon": 417,
+}
+
+GENERATOR_RESOLUTION_PRESETS = {
+ "256 × 256": (256, 256),
+ "384 × 256": (384, 256),
+ "256 × 384": (256, 384),
+ "384 × 384": (384, 384),
+ "512 × 256": (512, 256),
+ "256 × 512": (256, 512),
+ "512 × 512": (512, 512),
+}
+
+OUTPUT_RESOLUTION_PRESETS = {
+ "Same as generator": None,
+ "256 × 256": (256, 256),
+ "384 × 384": (384, 384),
+ "512 × 512": (512, 512),
+ "768 × 768": (768, 768),
+ "1024 × 1024": (1024, 1024),
+ "512 × 256 (2:1)": (512, 256),
+ "256 × 512 (1:2)": (256, 512),
+ "768 × 512 (3:2)": (768, 512),
+ "512 × 768 (2:3)": (512, 768),
+ "1024 × 512 (2:1)": (1024, 512),
+ "512 × 1024 (1:2)": (512, 1024),
+}
+
+# ---------------------------------------------------------------------------
+# Model loading
+# ---------------------------------------------------------------------------
+
+vq_model = None
+gpt_model = None
+
+
+def download_checkpoint(filename: str) -> str:
+ return hf_hub_download(repo_id=HF_REPO, filename=filename)
+
+
+def _make_res_tensors(gen_h: int, gen_w: int, multiplier: int):
+ """Create normalized resolution tensors for the GPT generator."""
+ th = torch.tensor(gen_h / 1536, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(multiplier, 1)
+ tw = torch.tensor(gen_w / 1536, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(multiplier, 1)
+ return th, tw
+
+
+def _warmup(model):
+ """Run a throwaway generation to trigger torch.compile and warm CUDA caches."""
+ print("Warming up (first call triggers compilation, may take ~30-60s)...")
+ dummy_cond = torch.tensor([0], device=DEVICE)
+ th, tw = _make_res_tensors(256, 256, multiplier=2)
+ with torch.inference_mode():
+ generate(
+ model, dummy_cond, LATENT_SIZE, NUM_CODEBOOKS,
+ cfg_scale=4.0, cfg_interval=-1,
+ target_h=th, target_w=tw,
+ temperature=1.0, top_k=500, top_p=1.0, sample_logits=True,
+ )
+ if DEVICE == "cuda":
+ torch.cuda.synchronize()
+ print("Warmup complete — subsequent generations will be fast.")
+
+
+def load_models():
+ global vq_model, gpt_model
+
+ print("Downloading checkpoints (if needed)...")
+ vq_path = download_checkpoint(VQ_CKPT_FILENAME)
+ gpt_path = download_checkpoint(GPT_CKPT_FILENAME)
+
+ print(f"Loading VibeToken tokenizer from {vq_path}...")
+ vq_model = VibeTokenTokenizer.from_config(
+ CONFIG_PATH, vq_path, device=DEVICE, dtype=DTYPE,
+ )
+ print("VibeToken tokenizer loaded.")
+
+ print(f"Loading {GPT_MODEL_NAME} from {gpt_path}...")
+ gpt_model = GPT_models[GPT_MODEL_NAME](
+ vocab_size=CODEBOOK_SIZE,
+ block_size=LATENT_SIZE,
+ num_classes=NUM_CLASSES,
+ cls_token_num=CLS_TOKEN_NUM,
+ model_type="c2i",
+ num_codebooks=NUM_CODEBOOKS,
+ n_output_layer=NUM_OUTPUT_LAYER,
+ class_dropout_prob=CLASS_DROPOUT_PROB,
+ extra_layers=EXTRA_LAYERS,
+ capping=CAPPING,
+ ).to(device=DEVICE, dtype=DTYPE)
+
+ checkpoint = torch.load(gpt_path, map_location="cpu", weights_only=False)
+ if "model" in checkpoint:
+ weights = checkpoint["model"]
+ elif "module" in checkpoint:
+ weights = checkpoint["module"]
+ elif "state_dict" in checkpoint:
+ weights = checkpoint["state_dict"]
+ else:
+ weights = checkpoint
+ gpt_model.load_state_dict(weights, strict=True)
+ gpt_model.eval()
+ del checkpoint
+ print(f"{GPT_MODEL_NAME} loaded.")
+
+ if COMPILE:
+ print("Compiling GPT model with torch.compile (max-autotune)...")
+ gpt_model = torch.compile(gpt_model, mode="max-autotune", fullgraph=True)
+ _warmup(gpt_model)
+ else:
+ print("Skipping torch.compile (set VIBETOKEN_NO_COMPILE=0 to enable).")
+
+
+# ---------------------------------------------------------------------------
+# Decoder patch-size heuristic
+# ---------------------------------------------------------------------------
+
+def auto_decoder_patch_size(h: int, w: int) -> tuple[int, int]:
+ max_dim = max(h, w)
+ if max_dim <= 256:
+ ps = 8
+ elif max_dim <= 512:
+ ps = 16
+ else:
+ ps = 32
+ return (ps, ps)
+
+
+# ---------------------------------------------------------------------------
+# Generation
+# ---------------------------------------------------------------------------
+
+@torch.inference_mode()
+@spaces.GPU(duration=90)
+def generate_image(
+ class_name: str,
+ class_id: int,
+ gen_resolution_preset: str,
+ out_resolution_preset: str,
+ decoder_ps_choice: str,
+ cfg_scale: float,
+ temperature: float,
+ top_k: int,
+ top_p: float,
+ seed: int,
+ randomize_seed: bool,
+):
+ if vq_model is None or gpt_model is None:
+ raise gr.Error("Models are still loading. Please wait a moment and try again.")
+
+ if randomize_seed:
+ seed = random.randint(0, 2**31 - 1)
+
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ if DEVICE == "cuda":
+ torch.cuda.manual_seed_all(seed)
+
+ if class_name and class_name != "Custom (enter ID below)":
+ cid = IMAGENET_CLASSES[class_name]
+ else:
+ cid = int(class_id)
+ cid = max(0, min(cid, NUM_CLASSES - 1))
+
+ gen_h, gen_w = GENERATOR_RESOLUTION_PRESETS[gen_resolution_preset]
+
+ out_res = OUTPUT_RESOLUTION_PRESETS[out_resolution_preset]
+ if out_res is None:
+ out_h, out_w = gen_h, gen_w
+ else:
+ out_h, out_w = out_res
+
+ if decoder_ps_choice == "Auto":
+ dec_ps = auto_decoder_patch_size(out_h, out_w)
+ else:
+ ps = int(decoder_ps_choice)
+ dec_ps = (ps, ps)
+
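+    # Classifier-free guidance evaluates a conditional and an unconditional branch together,
+    # so the resolution conditioning tensors are repeated twice whenever cfg_scale > 1.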
+ multiplier = 2 if cfg_scale > 1.0 else 1
+
+ c_indices = torch.tensor([cid], device=DEVICE)
+ th, tw = _make_res_tensors(gen_h, gen_w, multiplier)
+
+ index_sample = generate(
+ gpt_model,
+ c_indices,
+ LATENT_SIZE,
+ NUM_CODEBOOKS,
+ cfg_scale=cfg_scale,
+ cfg_interval=-1,
+ target_h=th,
+ target_w=tw,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ sample_logits=True,
+ )
+
+ index_sample = index_sample.unsqueeze(2)
+ samples = vq_model.decode(
+ index_sample,
+ height=out_h,
+ width=out_w,
+ patch_size=dec_ps,
+ )
+ samples = torch.clamp(samples, 0, 1)
+
+ img_np = (samples[0].permute(1, 2, 0).float().cpu().numpy() * 255).astype("uint8")
+ pil_img = Image.fromarray(img_np)
+
+ return pil_img, seed
+
+
+# ---------------------------------------------------------------------------
+# Gradio UI
+# ---------------------------------------------------------------------------
+
+HEADER_MD = """
+# VibeToken-Gen: Dynamic Resolution Image Generation
+
+Maitreya Patel, Jingtao Li, Weiming Zhuang, Yezhou Yang, Lingjuan Lyu
+
+CVPR 2026 (Main Conference)
+
+🤗 [Model](https://huggingface.co/mpatel57/VibeToken) | 💻 GitHub
+
+Generate ImageNet class-conditional images at **arbitrary resolutions** using only **65 tokens**.
+VibeToken-Gen maintains a constant **179G FLOPs** regardless of output resolution.
+"""
+
+CITATION_MD = """
+### Citation
+```bibtex
+@inproceedings{vibetoken2026,
+ title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations},
+ author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan},
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year = {2026}
+}
+```
+"""
+
+class_choices = ["Custom (enter ID below)"] + sorted(IMAGENET_CLASSES.keys())
+
+with gr.Blocks(
+ title="VibeToken-Gen Demo",
+ theme=gr.themes.Soft(),
+) as demo:
+ gr.Markdown(HEADER_MD)
+
+ with gr.Row():
+ # ---- Left column: controls ----
+ with gr.Column(scale=1):
+ class_dropdown = gr.Dropdown(
+ label="ImageNet Class",
+ choices=class_choices,
+ value="Golden Retriever",
+ info="Pick a class or choose 'Custom' to enter an ID manually.",
+ )
+ class_id_input = gr.Number(
+ label="Custom Class ID (0–999)",
+ value=207,
+ minimum=0,
+ maximum=999,
+ step=1,
+ visible=False,
+ )
+ gen_resolution_dropdown = gr.Dropdown(
+ label="Generator Resolution",
+ choices=list(GENERATOR_RESOLUTION_PRESETS.keys()),
+ value="256 × 256",
+ info="Internal resolution for the AR generator (max 512×512).",
+ )
+ out_resolution_dropdown = gr.Dropdown(
+ label="Output Resolution (Decoder)",
+ choices=list(OUTPUT_RESOLUTION_PRESETS.keys()),
+ value="Same as generator",
+ info="Final image resolution. Set higher for super-resolution (e.g. generate at 256, decode at 1024).",
+ )
+ decoder_ps_dropdown = gr.Dropdown(
+ label="Decoder Patch Size",
+ choices=["Auto", "8", "16", "32"],
+ value="Auto",
+ info="'Auto' selects based on output resolution. Larger = faster but coarser.",
+ )
+
+ with gr.Accordion("Advanced Sampling Parameters", open=False):
+ cfg_slider = gr.Slider(
+ label="CFG Scale",
+ minimum=1.0, maximum=20.0, value=4.0, step=0.5,
+ info="Classifier-free guidance strength.",
+ )
+ temp_slider = gr.Slider(
+ label="Temperature",
+ minimum=0.1, maximum=2.0, value=1.0, step=0.05,
+ )
+ topk_slider = gr.Slider(
+ label="Top-k",
+ minimum=0, maximum=2000, value=500, step=10,
+ info="0 disables top-k filtering.",
+ )
+ topp_slider = gr.Slider(
+ label="Top-p",
+ minimum=0.0, maximum=1.0, value=1.0, step=0.05,
+ info="1.0 disables nucleus sampling.",
+ )
+ seed_input = gr.Number(
+ label="Seed", value=0, minimum=0, maximum=2**31 - 1, step=1,
+ )
+ randomize_cb = gr.Checkbox(label="Randomize seed", value=True)
+
+ generate_btn = gr.Button("Generate", variant="primary", size="lg")
+
+ # ---- Right column: output ----
+ with gr.Column(scale=2):
+ output_image = gr.Image(label="Generated Image", type="pil", height=512)
+ used_seed = gr.Number(label="Seed used", interactive=False)
+
+ # Show/hide custom class ID field
+ def toggle_custom_id(choice):
+ return gr.update(visible=(choice == "Custom (enter ID below)"))
+
+ class_dropdown.change(
+ fn=toggle_custom_id,
+ inputs=[class_dropdown],
+ outputs=[class_id_input],
+ )
+
+ generate_btn.click(
+ fn=generate_image,
+ inputs=[
+ class_dropdown,
+ class_id_input,
+ gen_resolution_dropdown,
+ out_resolution_dropdown,
+ decoder_ps_dropdown,
+ cfg_slider,
+ temp_slider,
+ topk_slider,
+ topp_slider,
+ seed_input,
+ randomize_cb,
+ ],
+ outputs=[output_image, used_seed],
+ )
+
+ gr.Markdown(CITATION_MD)
+
+
+if __name__ == "__main__":
+ load_models()
+ demo.launch()
diff --git a/assets/example_1.png b/assets/example_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..02baa35758b74ca837670d8fbfee94c0c9f62063
--- /dev/null
+++ b/assets/example_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da07f6cd58181e54c6b4bbbc0458d99f48946da19a27ea02e9a3920bfb2b5d15
+size 334096
diff --git a/assets/generated_images.png b/assets/generated_images.png
new file mode 100644
index 0000000000000000000000000000000000000000..c162f8a3610735f4295df42285d8061f5b6ed409
--- /dev/null
+++ b/assets/generated_images.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1f5cdef18942f6331460be883485ce5141c7d9c6db7e1cd7596422a57b5cba7
+size 11085224
diff --git a/assets/reconstructed.png b/assets/reconstructed.png
new file mode 100644
index 0000000000000000000000000000000000000000..b8c8f0317a2afe89ef7a49b24727818d8b55141a
--- /dev/null
+++ b/assets/reconstructed.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86b57a2d196f6af62b37979b15a7dcda8de1097669ccb297d9213b32098d5873
+size 343690
diff --git a/assets/teaser.png b/assets/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..5bd55143ee3ba7080e1bb3e5ddda5fb2eee5aa9a
--- /dev/null
+++ b/assets/teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46bd1ff58c18d17b2808ba9445d6beab0ae9098be21b316e9ed730d553b607fb
+size 2859418
diff --git a/configs/training/VibeToken_small.yaml b/configs/training/VibeToken_small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4922a7b450f3db0bc69ae24a27cd11e981a4fd2
--- /dev/null
+++ b/configs/training/VibeToken_small.yaml
@@ -0,0 +1,102 @@
+experiment:
+ project: "VibeToken_mvq_tiny_main"
+ name: "VibeToken_mvq_tiny_main"
+ output_dir: "wandb/VibeToken_mvq_tiny_main"
+ max_train_examples: 1_281_167
+ save_every: 10_000
+ eval_every: 10_000
+ generate_every: 5_000
+ log_every: 50
+ log_grad_norm_every: 1_000
+ resume: True
+
+model:
+ sub_model_type: "vibetoken"
+ train_with_attention: True
+ eval_with_attention: True
+ vq_model:
+ # encoder: # patch mixture is not supported
+ # patch_mixture_start_layer: 2
+ # patch_mixture_end_layer: 22
+ # decoder: # patch mixture is not supported
+ # patch_mixture_start_layer: 2
+ # patch_mixture_end_layer: 22
+ quantize_mode: mvq
+ codebook_size: 32768 # 32768 / 8 = 4096
+ token_size: 256 # 256 / 8 = 32
+ use_l2_norm: False
+ commitment_cost: 0.25
+ clustering_vq: False
+ num_codebooks: 8
+ # vit arch
+ vit_enc_model_size: "small"
+ vit_dec_model_size: "small"
+ vit_enc_patch_size: 32
+ vit_dec_patch_size: 32
+ num_latent_tokens: 64
+ finetune_decoder: False
+ is_legacy: False
+
+losses:
+ discriminator_start: 100_000
+ quantizer_weight: 1.0
+ discriminator_factor: 1.0
+ discriminator_weight: 0.1
+ perceptual_loss: "lpips-convnext_s-1.0-0.1"
+ perceptual_weight: 1.1
+ reconstruction_loss: "l2"
+ reconstruction_weight: 1.0
+ lecam_regularization_weight: 0.001
+
+dataset:
+ params:
+ pretokenization: True
+ train_shards_path_or_url: "./data/imagenet_wds/imagenet-train-{000001..000128}.tar"
+ eval_shards_path_or_url: "./data/imagenet_wds/imagenet-val-{000001..000004}.tar"
+ num_workers_per_gpu: 12
+ preprocessing:
+ resize_shorter_edge: 512
+ crop_size: 512
+ random_crop: True
+ random_flip: True
+ res_ratio_filtering: True
+ min_tokens: 32
+ max_tokens: 64
+
+optimizer:
+ name: adamw
+ params:
+ learning_rate: 1e-4
+ discriminator_learning_rate: 1e-4
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 1e-4
+
+lr_scheduler:
+ scheduler: "cosine"
+ params:
+ learning_rate: ${optimizer.params.learning_rate}
+ warmup_steps: 10_000
+ end_lr: 1e-5
+
+training:
+ gradient_accumulation_steps: 1
+ per_gpu_batch_size: 32
+ mixed_precision: "fp16"
+ enable_tf32: True
+ enable_wandb: True
+ use_ema: True
+ seed: 42
+ max_train_steps: 400_000
+ num_generated_images: 2
+ max_grad_norm: 1.0
+ variable_resolution:
+ any2any: True
+ dim:
+ - [256, 256]
+ - [512, 512]
+ - [384, 256]
+ - [256, 384]
+ - [512, 384]
+ - [384, 512]
+ ratio: [0.3, 0.3, 0.1, 0.1, 0.1, 0.1]
\ No newline at end of file
diff --git a/configs/vibetoken_ll.yaml b/configs/vibetoken_ll.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7c69a2440fb3f0e528d7eeee20fb6f2aa0bc2e2e
--- /dev/null
+++ b/configs/vibetoken_ll.yaml
@@ -0,0 +1,40 @@
+# VibeToken Large-Large Configuration
+# Large encoder + Large decoder for highest quality
+#
+# Usage:
+# from vibetoken import VibeTokenTokenizer
+# tokenizer = VibeTokenTokenizer.from_config(
+# "configs/vibetoken_ll.yaml",
+# "path/to/checkpoint.bin"
+# )
+
+model:
+ sub_model_type: "vibetoken"
+ vq_model:
+ # Quantization settings
+ quantize_mode: mvq
+ codebook_size: 32768 # 32768 / 8 = 4096 per codebook
+ token_size: 256 # 256 / 8 = 32 per codebook
+ num_codebooks: 8
+ use_l2_norm: false
+ commitment_cost: 0.25
+
+ # Encoder architecture
+ vit_enc_model_size: "large"
+ vit_enc_patch_size: 32
+
+ # Decoder architecture
+ vit_dec_model_size: "large"
+ vit_dec_patch_size: 32
+
+ # Latent tokens
+ num_latent_tokens: 256
+
+ # Mode flags
+ is_legacy: false
+ finetune_decoder: false
+
+# Dataset preprocessing defaults (for reference)
+dataset:
+ preprocessing:
+ crop_size: 512
diff --git a/configs/vibetoken_sl.yaml b/configs/vibetoken_sl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4d0f6def50e0b6d01c3e554fd6c15ff1aa2197ad
--- /dev/null
+++ b/configs/vibetoken_sl.yaml
@@ -0,0 +1,40 @@
+# VibeToken Small-Large Configuration
+# Small encoder + Large decoder for faster encoding
+#
+# Usage:
+# from vibetoken import VibeTokenTokenizer
+# tokenizer = VibeTokenTokenizer.from_config(
+# "configs/vibetoken_sl.yaml",
+# "path/to/checkpoint.bin"
+# )
+
+model:
+ sub_model_type: "vibetoken"
+ vq_model:
+ # Quantization settings
+ quantize_mode: mvq
+ codebook_size: 32768 # 32768 / 8 = 4096 per codebook
+ token_size: 256 # 256 / 8 = 32 per codebook
+ num_codebooks: 8
+ use_l2_norm: false
+ commitment_cost: 0.25
+
+ # Encoder architecture (Small for faster encoding)
+ vit_enc_model_size: "small"
+ vit_enc_patch_size: 32
+
+ # Decoder architecture (Large for quality)
+ vit_dec_model_size: "large"
+ vit_dec_patch_size: 32
+
+ # Latent tokens
+ num_latent_tokens: 256
+
+ # Mode flags
+ is_legacy: false
+ finetune_decoder: false
+
+# Dataset preprocessing defaults (for reference)
+dataset:
+ preprocessing:
+ crop_size: 512
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..283b24901f4bca6318dfe6f31b7ff1c39b0d5375
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1 @@
+from .webdataset_reader import SimpleImageDataset, PretoeknizedDataSetJSONL, PretokenizedWebDataset
diff --git a/data/convert_imagenet_to_wds.py b/data/convert_imagenet_to_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4c15cd898055f8d2339f9818fd1eadaf8291f83
--- /dev/null
+++ b/data/convert_imagenet_to_wds.py
@@ -0,0 +1,56 @@
+# Adapted from https://github.com/webdataset/webdataset-imagenet/blob/main/convert-imagenet.py
+
+import argparse
+import os
+import sys
+import time
+
+import webdataset as wds
+from datasets import load_dataset
+
+
+def convert_imagenet_to_wds(input_dir, output_dir, max_train_samples_per_shard, max_val_samples_per_shard):
+ assert not os.path.exists(os.path.join(output_dir, "imagenet-train-000000.tar"))
+ assert not os.path.exists(os.path.join(output_dir, "imagenet-val-000000.tar"))
+
+ opat = os.path.join(output_dir, "imagenet-train-%06d.tar")
+ output = wds.ShardWriter(opat, maxcount=max_train_samples_per_shard)
+ dataset = load_dataset(input_dir, split="train")
+ now = time.time()
+ for i, example in enumerate(dataset):
+ if i % max_train_samples_per_shard == 0:
+ print(i, file=sys.stderr)
+ img, label = example["image"], example["label"]
+ output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label})
+ output.close()
+ time_taken = time.time() - now
+ print(f"Wrote {i+1} train examples in {time_taken // 3600} hours.")
+
+ opat = os.path.join(output_dir, "imagenet-val-%06d.tar")
+ output = wds.ShardWriter(opat, maxcount=max_val_samples_per_shard)
+ dataset = load_dataset(input_dir, split="validation")
+ now = time.time()
+ for i, example in enumerate(dataset):
+ if i % max_val_samples_per_shard == 0:
+ print(i, file=sys.stderr)
+ img, label = example["image"], example["label"]
+ output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label})
+ output.close()
+ time_taken = time.time() - now
+ print(f"Wrote {i+1} val examples in {time_taken // 60} min.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_dir", type=str, required=True,
+ help="Path to the ImageNet-1k dataset (HuggingFace format).")
+ parser.add_argument("--output_dir", type=str, required=True,
+ help="Path to the output directory for WebDataset shards.")
+ parser.add_argument("--max_train_samples_per_shard", type=int, default=10000,
+ help="Maximum number of training samples per shard.")
+ parser.add_argument("--max_val_samples_per_shard", type=int, default=10000,
+ help="Maximum number of validation samples per shard.")
+ args = parser.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ convert_imagenet_to_wds(args.input_dir, args.output_dir, args.max_train_samples_per_shard, args.max_val_samples_per_shard)
\ No newline at end of file
diff --git a/data/webdataset_reader.py b/data/webdataset_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..76b95ba61a5e24873ec82f5a60f8bc1f53309743
--- /dev/null
+++ b/data/webdataset_reader.py
@@ -0,0 +1,518 @@
+"""Data loader using webdataset.
+
+Reference:
+ https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
+ https://github.com/huggingface/open-muse/blob/main/training/data.py
+"""
+
+import math
+from typing import List, Union, Text
+import webdataset as wds
+import numpy as np
+import torch
+from torch.utils.data import default_collate
+from torchvision import transforms
+from torch.utils.data import Dataset
+import linecache
+import json
+from PIL import Image
+import random
+import cv2
+from tqdm import tqdm
+
+Image.MAX_IMAGE_PIXELS = None
+
+
+def load_json(sample):
+ sample['json'] = json.loads(sample['json'].decode('utf-8'))
+ return sample
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def filter_by_res_ratio(min_res=256, min_ratio=0.5, max_ratio=2.0):
+ def _f(sample):
+ cfg = sample['json']
+ h, w = cfg['original_height'], cfg['original_width']
+ ratio = h/w
+ longer_side = max(h, w)
+ return ratio >= min_ratio and ratio <= max_ratio and longer_side >= min_res
+ return _f
+
+def calculate_laplacian_variance(image):
+ """Calculate the variance of Laplacian which is a measure of image sharpness/blur."""
+ # Convert to grayscale if it's RGB
+ image = np.array(image)
+ if len(image.shape) == 3:
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ else:
+ gray = image
+
+ # Calculate Laplacian
+ laplacian = cv2.Laplacian(gray, cv2.CV_64F)
+
+ # Calculate variance
+ return laplacian.var()
+
+# Add this function to map Laplacian values to token lengths
+def get_dynamic_length(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=256, mean_tokens=128):
+ """
+ Maps Laplacian values to token lengths using a bell curve approach.
+ At the mean Laplacian value, uses mean_tokens.
+ Values further from the mean get mapped to shorter/longer token lengths.
+ """
+ # Prevent division by zero and handle edge cases
+ if std <= 0:
+ return mean_tokens
+
+ # Calculate z-score
+ z_score = (laplacian_value - mean) / std
+
+ # Use bell curve mapping (gaussian)
+ # When z_score is 0 (at mean), we get mean_tokens
+ # As z_score increases, token length increases toward max_tokens
+ # As z_score decreases, token length decreases toward min_tokens
+ scaling_factor = 2.0 # Controls how quickly we reach min/max tokens
+ normalized_position = 0.5 * (1 + math.tanh(scaling_factor * z_score))
+
+ # Map to token range [min_tokens, max_tokens]
+ token_length = min_tokens + normalized_position * (max_tokens - min_tokens)
+ return int(round(token_length))
+
+# Add this function to map Laplacian values to token lengths
+def get_dynamic_length_v2(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=128, mean_tokens=128):
+ """
+ Maps Laplacian values to token lengths using a linear mapping.
+ Ensures laplacian_value=0 maps to min_tokens, mean maps to mean_tokens,
+ and higher values scale up to max_tokens.
+ """
+ # Prevent division by zero and handle edge cases
+ if std <= 0:
+ return mean_tokens
+
+ # Linear mapping from laplacian space to token space
+
+ # Map 0->min_tokens, mean->mean_tokens, and scale up linearly
+ if laplacian_value <= mean:
+ # Linear interpolation between min_tokens and mean_tokens
+ ratio = laplacian_value / mean
+ token_length = min_tokens + (mean_tokens - min_tokens) * ratio
+ else:
+ # Linear interpolation between mean_tokens and max_tokens
+ ratio = (laplacian_value - mean) / mean # How far past mean
+ token_length = mean_tokens + (max_tokens - mean_tokens) * ratio
+
+ # Clamp to valid range
+ token_length = max(min_tokens, min(max_tokens, token_length))
+ return int(round(token_length))
+
+def get_laplacian_attention_mask(sample):
+ """Process sample to add Laplacian variance and attention mask."""
+ # Create a new dict to avoid modifying the input
+ processed = dict(sample)
+
+ # Calculate Laplacian variance
+ var = calculate_laplacian_variance(processed["image"])
+ length = get_dynamic_length(var)
+
+ # Create attention mask
+ attention_mask = torch.zeros((128,), dtype=torch.float32)
+ attention_mask[:length+1] = 1.0
+
+ # Add new fields to processed dict
+ processed["laplacian_var"] = var
+ processed["attention_mask"] = attention_mask
+
+ return processed
+
+def get_uniform_attention_mask(min_tokens=32, max_tokens=128):
+ """Process sample to add uniform random attention mask."""
+ def _f(dictionary):
+ # Sample length uniformly between min_tokens and max_tokens
+ length = torch.randint(min_tokens, max_tokens+1, (1,)).item()
+
+ # Create attention mask
+ attention_mask = torch.zeros((max_tokens,), dtype=torch.float32)
+ attention_mask[:length+1] = 1.0
+
+ # Add attention mask to dictionary
+ dictionary["attention_mask"] = attention_mask
+ return dictionary
+ return _f
+
+def process_recap_text(p):
+ def _f(dictionary):
+ if "recap_txt" in dictionary:
+ if random.random() < p:
+ recap_prefixes = ["The image " + v for v in ['depicts', "displays", 'showcases', 'features', 'shows']]
+ # Convert input to string and strip whitespace
+ text = dictionary["recap_txt"].decode("utf-8").strip()
+ # Check if text starts with any of the phrases
+ for phrase in recap_prefixes:
+ if text.startswith(phrase):
+ # Remove the phrase and any leading/trailing whitespace
+ text = text[len(phrase):].strip()
+ # Capitalize the first letter
+ text = text[0].upper() + text[1:] if text else ""
+ break
+
+ dictionary["text"] = text.encode("utf-8")
+ return dictionary
+
+ return _f
+
+
+def identity(x):
+ return x
+
+
+class ImageTransform:
+ def __init__(self,
+ resize_shorter_edge: int = 256,
+ crop_size: int = 256,
+ random_crop: bool = True,
+ random_flip: bool = True,
+ normalize_mean: List[float] = [0., 0., 0.],
+ normalize_std: List[float] = [1., 1., 1.]):
+ """Initializes the WebDatasetReader with specified augmentation parameters.
+
+ Args:
+ resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
+ crop_size: An integer, the size to crop the input image to.
+ random_crop: A boolean, whether to use random crop augmentation during training.
+ random_flip: A boolean, whether to use random flipping augmentation during training.
+ normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
+ normalize_std: A list of float, the normalization std used to normalize the image tensor.
+
+ Raises:
+ NotImplementedError: If the interpolation mode is not one of ["bicubic", "bilinear"].
+ """
+ train_transform = []
+ interpolation = transforms.InterpolationMode.BICUBIC
+
+ train_transform.append(
+ transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True))
+ if random_crop:
+ train_transform.append(transforms.RandomCrop(crop_size))
+ else:
+ train_transform.append(transforms.CenterCrop(crop_size))
+ if random_flip:
+ train_transform.append(transforms.RandomHorizontalFlip())
+ train_transform.append(transforms.ToTensor())
+ # normalize_mean = [0, 0, 0] and normalize_std = [1, 1, 1] will normalize images into [0, 1],
+ # normalize_mean = [0.5, 0.5, 0.5] and normalize_std = [0.5, 0.5, 0.5] will normalize images into [-1, 1].
+ train_transform.append(transforms.Normalize(normalize_mean, normalize_std))
+
+ self.train_transform = transforms.Compose(train_transform)
+ self.eval_transform = transforms.Compose(
+ [
+ # Note that we always resize to crop_size during eval to ensure the results
+ # can be compared against reference numbers on ImageNet etc.
+ transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
+ transforms.CenterCrop(crop_size),
+ transforms.ToTensor(),
+ transforms.Normalize(normalize_mean, normalize_std)
+ ]
+ )
+ print(f"self.train_transform: {self.train_transform}")
+ print(f"self.eval_transform: {self.eval_transform}")
+
+
+class SimpleImageDataset:
+ def __init__(
+ self,
+ train_shards_path: Union[Text, List[Text]],
+ eval_shards_path: Union[Text, List[Text]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers_per_gpu: int = 12,
+ resize_shorter_edge: int = 256,
+ crop_size: int = 256,
+ random_crop = True,
+ random_flip = True,
+ normalize_mean: List[float] = [0., 0., 0.],
+ normalize_std: List[float] = [1., 1., 1.],
+ dataset_with_class_label: bool = True,
+ dataset_with_text_label: bool = False,
+ res_ratio_filtering = False,
+ min_tokens = 32,
+ max_tokens = 128,
+ ):
+ """Initializes the WebDatasetReader class.
+
+ Args:
+ train_shards_path: A string or list of string, path to the training data shards in webdataset format.
+ eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
+ num_train_examples: An integer, total number of training examples.
+ per_gpu_batch_size: An integer, number of examples per GPU batch.
+ global_batch_size: An integer, total number of examples in a batch across all GPUs.
+ num_workers_per_gpu: An integer, number of workers per GPU.
+ resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
+ crop_size: An integer, the size to crop the input image to.
+ random_crop: A boolean, whether to use random crop augmentation during training.
+ random_flip: A boolean, whether to use random flipping augmentation during training.
+ normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
+ normalize_std: A list of float, the normalization std used to normalize the image tensor.
+ """
+ transform = ImageTransform(
+ resize_shorter_edge, crop_size, random_crop, random_flip,
+ normalize_mean, normalize_std)
+
+ if dataset_with_class_label:
+ train_processing_pipeline = [
+ wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp",
+ class_id="cls",
+ handler=wds.warn_and_continue,
+ ),
+ wds.map(filter_keys(set(["image", "class_id", "filename"]))),
+ wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
+ wds.map_dict(
+ image=transform.train_transform,
+ class_id=lambda x: int(x),
+ attention_mask=lambda x: x,
+ handler=wds.warn_and_continue,
+ ),
+ ]
+ elif dataset_with_text_label:
+ train_processing_pipeline = [
+ wds.map(load_json),
+ wds.select(filter_by_res_ratio()) if res_ratio_filtering else wds.map(identity),
+ wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]),only=["webp", "png", "jpg", "jpeg", "txt"], handler=wds.warn_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp",
+ text="txt",
+ handler=wds.warn_and_continue,
+ ),
+ wds.map(filter_keys(set(["image", "text", "__key__"]))),
+ wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
+ wds.map_dict(
+ image=transform.train_transform,
+ attention_mask=lambda x: x,
+ handler=wds.warn_and_continue,
+ ),
+ ]
+ else:
+ raise NotImplementedError
+
+ test_processing_pipeline = [
+ wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp",
+ class_id="cls",
+ handler=wds.warn_and_continue,
+ ),
+ wds.map(filter_keys(set(["image", "class_id", "filename"]))),
+ wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
+ wds.map_dict(
+ image=transform.eval_transform,
+ class_id=lambda x: int(x),
+ # laplacian_var=lambda x: x,
+ attention_mask=lambda x: x,
+ handler=wds.warn_and_continue,
+ ),
+ ]
+
+ # Create train dataset and loader.
+ pipeline = [
+ wds.ResampledShards(train_shards_path),
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
+ wds.shuffle(bufsize=5000,
+ initial=1000),
+ *train_processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_batches = math.ceil(num_train_examples / global_batch_size)
+ num_worker_batches = math.ceil(num_train_examples /
+ (global_batch_size * num_workers_per_gpu))
+ num_batches = num_worker_batches * num_workers_per_gpu
+ num_samples = num_batches * global_batch_size
+
+ # Each worker is iterating over the complete dataset.
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers_per_gpu,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ # Add meta-data to dataloader instance for convenience.
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ # Create eval dataset and loader.
+ pipeline = [
+ wds.SimpleShardList(eval_shards_path),
+ wds.split_by_worker,
+ wds.tarfile_to_samples(handler=wds.ignore_and_continue),
+ *test_processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
+ ]
+ self._eval_dataset = wds.DataPipeline(*pipeline)
+ self._eval_dataloader = wds.WebLoader(
+ self._eval_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers_per_gpu,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+ @property
+ def eval_dataset(self):
+ return self._eval_dataset
+
+ @property
+ def eval_dataloader(self):
+ return self._eval_dataloader
+
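+# Illustrative usage (not executed anywhere in this file): constructing the ImageNet
+# loader with the values from configs/training/VibeToken_small.yaml. Adjust shard
+# paths, batch sizes, and worker counts to your own setup.
+#
+#   dataset = SimpleImageDataset(
+#       train_shards_path="./data/imagenet_wds/imagenet-train-{000001..000128}.tar",
+#       eval_shards_path="./data/imagenet_wds/imagenet-val-{000001..000004}.tar",
+#       num_train_examples=1_281_167,
+#       per_gpu_batch_size=32,
+#       global_batch_size=32 * 8,  # 8 GPUs
+#       num_workers_per_gpu=12,
+#       resize_shorter_edge=512,
+#       crop_size=512,
+#       min_tokens=32,
+#       max_tokens=64,
+#   )
+#   for batch in dataset.train_dataloader:
+#       images, class_ids, mask = batch["image"], batch["class_id"], batch["attention_mask"]
+#       break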
+
+class PretoeknizedDataSetJSONL(Dataset):
+ def __init__(self, data_path):
+ super().__init__()
+ self.jsonl_file = data_path
+ self.num_lines = sum(1 for _ in open(self.jsonl_file))
+ # Ensure the file is cached
+ linecache.checkcache(self.jsonl_file)
+ print("Number of data:", self.num_lines)
+
+ def __len__(self):
+ return self.num_lines
+
+ def __getitem__(self, idx):
+ line = linecache.getline(self.jsonl_file, idx + 1).strip()
+ data = json.loads(line)
+ return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])
+
+
+class PretokenizedWebDataset(SimpleImageDataset):
+ def __init__ (
+ self,
+ train_shards_path: Union[Text, List[Text]],
+ eval_shards_path: Union[Text, List[Text]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers_per_gpu: int,
+ resize_shorter_edge: int = 256,
+ crop_size: int = 256,
+ random_crop = True,
+ random_flip = True,
+ normalize_mean: List[float] = [0., 0., 0.],
+ normalize_std: List[float] = [1., 1., 1.],
+ process_recap = False,
+ use_recap_prob = 0.95,
+ ):
+ """Initializes the PretokenizedWebDataset class.
+
+        Text-to-image datasets are pretokenized with careful filtering (Tab. 7 in Supp.) to speed up training.
+ """
+ transform = ImageTransform(
+ resize_shorter_edge, crop_size, random_crop, random_flip,
+ normalize_mean, normalize_std)
+
+ def decode_npy(x):
+ arr = np.frombuffer(x, dtype=np.float16)
+ ret = torch.tensor(arr)
+ return ret
+
+ def decode_text(x):
+ ret = x.decode("utf-8")
+ return ret
+
+ train_processing_pipeline = [
+ wds.rename(
+ tokens="token.npy",
+ text="txt",
+ handler=wds.warn_and_continue,
+ ),
+            wds.map(process_recap_text(use_recap_prob)) if process_recap else wds.map(identity),
+ wds.map(filter_keys(set(["tokens", "text", "aes_score", "__key__"]))),
+ wds.map_dict(
+ tokens=decode_npy,
+ text=decode_text,
+ handler=wds.warn_and_continue,
+ ),
+ ]
+
+ test_processing_pipeline = [
+ wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
+ wds.rename(
+ image="jpg;png;jpeg;webp",
+ handler=wds.warn_and_continue,
+ ),
+ wds.map_dict(
+ image=transform.eval_transform,
+ handler=wds.warn_and_continue,
+ ),
+ ]
+
+
+ # Create train dataset and loader.
+ pipeline = [
+ wds.ResampledShards(train_shards_path),
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
+ wds.shuffle(bufsize=5000,
+ initial=1000),
+ *train_processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_batches = math.ceil(num_train_examples / global_batch_size)
+ num_worker_batches = math.ceil(num_train_examples /
+ (global_batch_size * num_workers_per_gpu))
+ num_batches = num_worker_batches * num_workers_per_gpu
+ num_samples = num_batches * global_batch_size
+
+ # Each worker is iterating over the complete dataset.
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers_per_gpu,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ # Add meta-data to dataloader instance for convenience.
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ # Create eval dataset and loader.
+ pipeline = [
+ wds.SimpleShardList(eval_shards_path),
+ wds.split_by_worker,
+ wds.tarfile_to_samples(handler=wds.ignore_and_continue),
+ *test_processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
+ ]
+ self._eval_dataset = wds.DataPipeline(*pipeline)
+ self._eval_dataloader = wds.WebLoader(
+ self._eval_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers_per_gpu,
+ pin_memory=True,
+ persistent_workers=True,
+ )
\ No newline at end of file
diff --git a/evaluator/__init__.py b/evaluator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5774b098f8a2d51bd5ad898a7ad5015434d124c
--- /dev/null
+++ b/evaluator/__init__.py
@@ -0,0 +1 @@
+from .evaluator import VQGANEvaluator
\ No newline at end of file
diff --git a/evaluator/evaluator.py b/evaluator/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..8860328f0ed5f22512fde935472ee47b92d51be5
--- /dev/null
+++ b/evaluator/evaluator.py
@@ -0,0 +1,230 @@
+"""Evaluator for reconstruction results."""
+
+import warnings
+
+from typing import Sequence, Optional, Mapping, Text
+import numpy as np
+from scipy import linalg
+import torch
+import torch.nn.functional as F
+
+from .inception import get_inception_model
+
+
+def get_covariance(sigma: torch.Tensor, total: torch.Tensor, num_examples: int) -> torch.Tensor:
+ """Computes covariance of the input tensor.
+
+ Args:
+ sigma: A torch.Tensor, sum of outer products of input features.
+ total: A torch.Tensor, sum of all input features.
+ num_examples: An integer, number of examples in the input tensor.
+ Returns:
+ A torch.Tensor, covariance of the input tensor.
+ """
+ if num_examples == 0:
+ return torch.zeros_like(sigma)
+
+ sub_matrix = torch.outer(total, total)
+ sub_matrix = sub_matrix / num_examples
+
+ return (sigma - sub_matrix) / (num_examples - 1)
+
+
+class VQGANEvaluator:
+ def __init__(
+ self,
+ device,
+ enable_rfid: bool = True,
+ enable_inception_score: bool = True,
+ enable_codebook_usage_measure: bool = False,
+ enable_codebook_entropy_measure: bool = False,
+ num_codebook_entries: int = 1024
+ ):
+ """Initializes VQGAN Evaluator.
+
+ Args:
+ device: The device to use for evaluation.
+ enable_rfid: A boolean, whether enabling rFID score.
+ enable_inception_score: A boolean, whether enabling Inception Score.
+ enable_codebook_usage_measure: A boolean, whether enabling codebook usage measure.
+ enable_codebook_entropy_measure: A boolean, whether enabling codebook entropy measure.
+ num_codebook_entries: An integer, the number of codebook entries.
+ """
+ self._device = device
+
+ self._enable_rfid = enable_rfid
+ self._enable_inception_score = enable_inception_score
+ self._enable_codebook_usage_measure = enable_codebook_usage_measure
+ self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
+ self._num_codebook_entries = num_codebook_entries
+
+ # Variables related to Inception score and rFID.
+ self._inception_model = None
+ self._is_num_features = 0
+ self._rfid_num_features = 0
+ if self._enable_inception_score or self._enable_rfid:
+ self._rfid_num_features = 2048
+ self._is_num_features = 1008
+ self._inception_model = get_inception_model().to(self._device)
+ self._inception_model.eval()
+ self._is_eps = 1e-16
+ self._rfid_eps = 1e-6
+
+ self.reset_metrics()
+
+ def reset_metrics(self):
+ """Resets all metrics."""
+ self._num_examples = 0
+ self._num_updates = 0
+
+ self._is_prob_total = torch.zeros(
+ self._is_num_features, dtype=torch.float64, device=self._device
+ )
+ self._is_total_kl_d = torch.zeros(
+ self._is_num_features, dtype=torch.float64, device=self._device
+ )
+ self._rfid_real_sigma = torch.zeros(
+ (self._rfid_num_features, self._rfid_num_features),
+ dtype=torch.float64, device=self._device
+ )
+ self._rfid_real_total = torch.zeros(
+ self._rfid_num_features, dtype=torch.float64, device=self._device
+ )
+ self._rfid_fake_sigma = torch.zeros(
+ (self._rfid_num_features, self._rfid_num_features),
+ dtype=torch.float64, device=self._device
+ )
+ self._rfid_fake_total = torch.zeros(
+ self._rfid_num_features, dtype=torch.float64, device=self._device
+ )
+
+ self._set_of_codebook_indices = set()
+ self._codebook_frequencies = torch.zeros((self._num_codebook_entries), dtype=torch.float64, device=self._device)
+
+ def update(
+ self,
+ real_images: torch.Tensor,
+ fake_images: torch.Tensor,
+ codebook_indices: Optional[torch.Tensor] = None
+ ):
+ """Updates the metrics with the given images.
+
+ Args:
+ real_images: A torch.Tensor, the real images.
+ fake_images: A torch.Tensor, the fake images.
+ codebook_indices: A torch.Tensor, the indices of the codebooks for each image.
+
+ Raises:
+ ValueError: If the fake images is not in RGB (3 channel).
+ ValueError: If the fake and real images have different shape.
+ """
+
+ batch_size = real_images.shape[0]
+ dim = tuple(range(1, real_images.ndim))
+ self._num_examples += batch_size
+ self._num_updates += 1
+
+ if self._enable_inception_score or self._enable_rfid:
+ # Quantize to uint8 as a real image.
+ fake_inception_images = (fake_images * 255).to(torch.uint8)
+ features_fake = self._inception_model(fake_inception_images)
+ inception_logits_fake = features_fake["logits_unbiased"]
+ inception_probabilities_fake = F.softmax(inception_logits_fake, dim=-1)
+
+ if self._enable_inception_score:
+                probabilities_sum = torch.sum(inception_probabilities_fake, 0, dtype=torch.float64)
+
+                log_prob = torch.log(inception_probabilities_fake + self._is_eps)
+                if log_prob.dtype != inception_probabilities_fake.dtype:
+                    log_prob = log_prob.to(inception_probabilities_fake)
+                kl_sum = torch.sum(inception_probabilities_fake * log_prob, 0, dtype=torch.float64)
+
+                self._is_prob_total += probabilities_sum
+ self._is_total_kl_d += kl_sum
+
+ if self._enable_rfid:
+ real_inception_images = (real_images * 255).to(torch.uint8)
+ features_real = self._inception_model(real_inception_images)
+ if (features_real['2048'].shape[0] != features_fake['2048'].shape[0] or
+ features_real['2048'].shape[1] != features_fake['2048'].shape[1]):
+ raise ValueError(f"Number of features should be equal for real and fake.")
+
+ for f_real, f_fake in zip(features_real['2048'], features_fake['2048']):
+ self._rfid_real_total += f_real
+ self._rfid_fake_total += f_fake
+
+ self._rfid_real_sigma += torch.outer(f_real, f_real)
+ self._rfid_fake_sigma += torch.outer(f_fake, f_fake)
+
+ if self._enable_codebook_usage_measure:
+ self._set_of_codebook_indices |= set(torch.unique(codebook_indices, sorted=False).tolist())
+
+ if self._enable_codebook_entropy_measure:
+ entries, counts = torch.unique(codebook_indices, sorted=False, return_counts=True)
+ self._codebook_frequencies.index_add_(0, entries.int(), counts.double())
+
+
+ def result(self) -> Mapping[Text, torch.Tensor]:
+ """Returns the evaluation result."""
+ eval_score = {}
+
+ if self._num_examples < 1:
+ raise ValueError("No examples to evaluate.")
+
+ if self._enable_inception_score:
+ mean_probs = self._is_prob_total / self._num_examples
+ log_mean_probs = torch.log(mean_probs + self._is_eps)
+ if log_mean_probs.dtype != self._is_prob_total.dtype:
+ log_mean_probs = log_mean_probs.to(self._is_prob_total)
+ excess_entropy = self._is_prob_total * log_mean_probs
+ avg_kl_d = torch.sum(self._is_total_kl_d - excess_entropy) / self._num_examples
+
+ inception_score = torch.exp(avg_kl_d).item()
+ eval_score["InceptionScore"] = inception_score
+
+ if self._enable_rfid:
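+            # rFID is the Fréchet distance between the real and fake feature Gaussians:
+            # ||mu_r - mu_f||^2 + Tr(Sigma_r + Sigma_f - 2 (Sigma_r Sigma_f)^{1/2}).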
+ mu_real = self._rfid_real_total / self._num_examples
+ mu_fake = self._rfid_fake_total / self._num_examples
+ sigma_real = get_covariance(self._rfid_real_sigma, self._rfid_real_total, self._num_examples)
+ sigma_fake = get_covariance(self._rfid_fake_sigma, self._rfid_fake_total, self._num_examples)
+
+ mu_real, mu_fake = mu_real.cpu(), mu_fake.cpu()
+ sigma_real, sigma_fake = sigma_real.cpu(), sigma_fake.cpu()
+
+ diff = mu_real - mu_fake
+
+            # Product might be almost singular.
+            covmean, _ = linalg.sqrtm(sigma_real.mm(sigma_fake).numpy(), disp=False)
+            # Numerical error might give slight imaginary component.
+            if np.iscomplexobj(covmean):
+                if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+                    m = np.max(np.abs(covmean.imag))
+                    raise ValueError("Imaginary component {}".format(m))
+                covmean = covmean.real
+
+            tr_covmean = np.trace(covmean)
+
+            if not np.isfinite(covmean).all():
+                tr_covmean = np.sum(np.sqrt((
+                    (np.diag(sigma_real) * self._rfid_eps) * (np.diag(sigma_fake) * self._rfid_eps))
+                    / (self._rfid_eps * self._rfid_eps)
+                ))
+
+            rfid = float(diff.dot(diff).item() + torch.trace(sigma_real) + torch.trace(sigma_fake)
+                - 2 * tr_covmean
+            )
+            if not np.isfinite(rfid):
+                warnings.warn("The product of covariance of train and test features is out of bounds.")
+
+            eval_score["rFID"] = rfid
+
+ if self._enable_codebook_usage_measure:
+ usage = float(len(self._set_of_codebook_indices)) / self._num_codebook_entries
+ eval_score["CodebookUsage"] = usage
+
+ if self._enable_codebook_entropy_measure:
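+            # Shannon entropy (in bits) of the empirical codebook-usage distribution.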
+ probs = self._codebook_frequencies / self._codebook_frequencies.sum()
+ entropy = (-torch.log2(probs + 1e-8) * probs).sum()
+ eval_score["CodebookEntropy"] = entropy
+
+ return eval_score
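+
+
+# Illustrative usage sketch (the data loader and variable names below are placeholders,
+# not part of this module): accumulate statistics batch-by-batch, then read out the metrics.
+#
+#     evaluator = VQGANEvaluator(device=torch.device("cuda"), enable_rfid=True)
+#     evaluator.reset_metrics()
+#     for real, fake, indices in eval_loader:   # images in [0, 1], shape (B, 3, H, W)
+#         evaluator.update(real.to("cuda"), fake.to("cuda"), codebook_indices=indices)
+#     print(evaluator.result())                 # e.g. {"InceptionScore": ..., "rFID": ...}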
diff --git a/evaluator/inception.py b/evaluator/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..cda0355bbe510617e474ef2128f501044f4eab7a
--- /dev/null
+++ b/evaluator/inception.py
@@ -0,0 +1,215 @@
+"""Inception model for FID evaluation.
+
+Reference:
+ https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
+"""
+import torch
+import torch.nn.functional as F
+
+from torch_fidelity.feature_extractor_base import FeatureExtractorBase
+from torch_fidelity.helpers import vassert
+from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2
+from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
+
+try:
+ from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+
+# Note: Compared shasum and models should be the same.
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
+
+class FeatureExtractorInceptionV3(FeatureExtractorBase):
+ INPUT_IMAGE_SIZE = 299
+
+ def __init__(
+ self,
+ name,
+ features_list,
+ **kwargs,
+ ):
+ """
+ InceptionV3 feature extractor for 2D RGB 24bit images.
+
+ Args:
+
+ name (str): Unique name of the feature extractor, must be the same as used in
+ :func:`register_feature_extractor`.
+
+ features_list (list): A list of the requested feature names, which will be produced for each input. This
+ feature extractor provides the following features:
+
+ - '64'
+ - '192'
+ - '768'
+ - '2048'
+ - 'logits_unbiased'
+ - 'logits'
+
+ """
+ super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
+ self.feature_extractor_internal_dtype = torch.float64
+
+ self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
+ self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
+ self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
+ self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
+
+ self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
+ self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
+ self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
+
+ self.Mixed_5b = InceptionA(192, pool_features=32)
+ self.Mixed_5c = InceptionA(256, pool_features=64)
+ self.Mixed_5d = InceptionA(288, pool_features=64)
+ self.Mixed_6a = InceptionB(288)
+ self.Mixed_6b = InceptionC(768, channels_7x7=128)
+ self.Mixed_6c = InceptionC(768, channels_7x7=160)
+ self.Mixed_6d = InceptionC(768, channels_7x7=160)
+ self.Mixed_6e = InceptionC(768, channels_7x7=192)
+
+ self.Mixed_7a = InceptionD(768)
+ self.Mixed_7b = InceptionE_1(1280)
+ self.Mixed_7c = InceptionE_2(2048)
+ self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
+
+ self.fc = torch.nn.Linear(2048, 1008)
+
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
+ #state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu')
+ self.load_state_dict(state_dict)
+
+ self.to(self.feature_extractor_internal_dtype)
+ self.requires_grad_(False)
+ self.eval()
+
+ def forward(self, x):
+ vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
+ vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}')
+ features = {}
+ remaining_features = self.features_list.copy()
+
+ x = x.to(self.feature_extractor_internal_dtype)
+ # N x 3 x ? x ?
+
+ x = interpolate_bilinear_2d_like_tensorflow1x(
+ x,
+ size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
+ align_corners=False,
+ )
+ # N x 3 x 299 x 299
+
+ # x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device) # really happening in graph
+ x = (x - 128) / 128 # but this gives bit-exact output _of this step_ too
+ # N x 3 x 299 x 299
+
+ x = self.Conv2d_1a_3x3(x)
+ # N x 32 x 149 x 149
+ x = self.Conv2d_2a_3x3(x)
+ # N x 32 x 147 x 147
+ x = self.Conv2d_2b_3x3(x)
+ # N x 64 x 147 x 147
+ x = self.MaxPool_1(x)
+ # N x 64 x 73 x 73
+
+ if '64' in remaining_features:
+ features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+ remaining_features.remove('64')
+ if len(remaining_features) == 0:
+ return features
+
+ x = self.Conv2d_3b_1x1(x)
+ # N x 80 x 73 x 73
+ x = self.Conv2d_4a_3x3(x)
+ # N x 192 x 71 x 71
+ x = self.MaxPool_2(x)
+ # N x 192 x 35 x 35
+
+ if '192' in remaining_features:
+ features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+ remaining_features.remove('192')
+ if len(remaining_features) == 0:
+ return features
+
+ x = self.Mixed_5b(x)
+ # N x 256 x 35 x 35
+ x = self.Mixed_5c(x)
+ # N x 288 x 35 x 35
+ x = self.Mixed_5d(x)
+ # N x 288 x 35 x 35
+ x = self.Mixed_6a(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6b(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6c(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6d(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6e(x)
+ # N x 768 x 17 x 17
+
+ if '768' in remaining_features:
+ features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
+ remaining_features.remove('768')
+ if len(remaining_features) == 0:
+ return features
+
+ x = self.Mixed_7a(x)
+ # N x 1280 x 8 x 8
+ x = self.Mixed_7b(x)
+ # N x 2048 x 8 x 8
+ x = self.Mixed_7c(x)
+ # N x 2048 x 8 x 8
+ x = self.AvgPool(x)
+ # N x 2048 x 1 x 1
+
+ x = torch.flatten(x, 1)
+ # N x 2048
+
+ if '2048' in remaining_features:
+ features['2048'] = x
+ remaining_features.remove('2048')
+ if len(remaining_features) == 0:
+ return features
+
+ if 'logits_unbiased' in remaining_features:
+ x = x.mm(self.fc.weight.T)
+ # N x 1008 (num_classes)
+ features['logits_unbiased'] = x
+ remaining_features.remove('logits_unbiased')
+ if len(remaining_features) == 0:
+ return features
+
+ x = x + self.fc.bias.unsqueeze(0)
+ else:
+ x = self.fc(x)
+ # N x 1008 (num_classes)
+
+ features['logits'] = x
+ return features
+
+ @staticmethod
+ def get_provided_features_list():
+ return '64', '192', '768', '2048', 'logits_unbiased', 'logits'
+
+ @staticmethod
+ def get_default_feature_layer_for_metric(metric):
+ return {
+ 'isc': 'logits_unbiased',
+ 'fid': '2048',
+ 'kid': '2048',
+ 'prc': '2048',
+ }[metric]
+
+ @staticmethod
+ def can_be_compiled():
+ return True
+
+ @staticmethod
+ def get_dummy_input_for_compile():
+ return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)
+
+def get_inception_model():
+ model = FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"])
+ return model
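+
+
+# Illustrative sketch (assumes uint8 RGB input of shape (B, 3, H, W)): the extractor
+# returns a dict with pooled 2048-d features and the unbiased 1008-way logits.
+#
+#     model = get_inception_model().to("cuda")
+#     images = (torch.rand(4, 3, 256, 256) * 255).to(torch.uint8).to("cuda")
+#     feats = model(images)
+#     feats["2048"].shape             # torch.Size([4, 2048])
+#     feats["logits_unbiased"].shape  # torch.Size([4, 1008])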
diff --git a/examples/batch_inference.py b/examples/batch_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb0094ce5080f1c95aa3905c3cac3ba6f425b7cd
--- /dev/null
+++ b/examples/batch_inference.py
@@ -0,0 +1,241 @@
+#!/usr/bin/env python3
+"""Batch inference example for VibeToken.
+
+Demonstrates how to process multiple images efficiently in batches.
+
+Usage:
+ # Auto mode (recommended)
+ python examples/batch_inference.py --auto \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint path/to/checkpoint.bin \
+ --input_dir path/to/images/ \
+ --output_dir path/to/output/ \
+ --batch_size 4
+
+ # Manual mode
+ python examples/batch_inference.py \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint path/to/checkpoint.bin \
+ --input_dir path/to/images/ \
+ --output_dir path/to/output/ \
+ --batch_size 4 \
+ --resolution 512 \
+ --encoder_patch_size 16,32 \
+ --decoder_patch_size 16
+"""
+
+import argparse
+import time
+from pathlib import Path
+
+import torch
+from PIL import Image
+import numpy as np
+
+import sys
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
+
+
+def parse_patch_size(value):
+ """Parse patch size from string. Supports single int or tuple (e.g., '16' or '16,32')."""
+ if value is None:
+ return None
+ if ',' in value:
+ parts = value.split(',')
+ return (int(parts[0]), int(parts[1]))
+ return int(value)
+
+
+def load_and_preprocess_image(path: Path, target_size: tuple = None, auto_mode: bool = False) -> tuple:
+ """Load and preprocess image.
+
+ Args:
+ path: Path to image
+ target_size: Optional target size (width, height) for resizing
+ auto_mode: If True, use auto_preprocess_image for cropping
+
+    Returns:
+        image: numpy array
+        patch_size: auto-determined patch size (if auto_mode) or None
+        info: preprocessing info dict (if auto_mode) or None
+    """
+ img = Image.open(path).convert("RGB")
+
+ if auto_mode:
+ # Use centralized auto_preprocess_image
+ img, patch_size, info = auto_preprocess_image(img, verbose=False)
+ return np.array(img), patch_size, info
+ else:
+ if target_size:
+ img = img.resize(target_size, Image.LANCZOS)
+ # Always center crop to ensure dimensions divisible by 32
+ img = center_crop_to_multiple(img, multiple=32)
+ return np.array(img), None, None
+
+
+def main():
+ parser = argparse.ArgumentParser(description="VibeToken batch inference example")
+ parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
+ parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
+ parser.add_argument("--input_dir", type=str, required=True, help="Directory with input images")
+ parser.add_argument("--output_dir", type=str, required=True, help="Directory for output images")
+ parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
+ parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
+
+ # Auto mode
+ parser.add_argument("--auto", action="store_true",
+ help="Auto mode: automatically determine optimal settings per image")
+
+ # Manual mode options
+ parser.add_argument("--resolution", type=int, default=512, help="Target resolution (manual mode)")
+ parser.add_argument("--encoder_patch_size", type=str, default=None,
+ help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
+ parser.add_argument("--decoder_patch_size", type=str, default=None,
+ help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
+ args = parser.parse_args()
+
+ # Parse patch sizes
+ encoder_patch_size = parse_patch_size(args.encoder_patch_size)
+ decoder_patch_size = parse_patch_size(args.decoder_patch_size)
+
+ # Check CUDA
+ if args.device == "cuda" and not torch.cuda.is_available():
+ print("CUDA not available, falling back to CPU")
+ args.device = "cpu"
+
+ # Create output directory
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Load tokenizer
+ print(f"Loading tokenizer from {args.config}")
+ tokenizer = VibeTokenTokenizer.from_config(
+ config_path=args.config,
+ checkpoint_path=args.checkpoint,
+ device=args.device,
+ )
+
+ if args.auto:
+ print("Running in AUTO MODE - optimal settings determined per image")
+ else:
+ print(f"Running in MANUAL MODE - resolution: {args.resolution}")
+ if encoder_patch_size:
+ print(f" Encoder patch size: {encoder_patch_size}")
+ if decoder_patch_size:
+ print(f" Decoder patch size: {decoder_patch_size}")
+
+ # Find all images
+ input_dir = Path(args.input_dir)
+ image_extensions = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
+ image_paths = [p for p in input_dir.iterdir() if p.suffix.lower() in image_extensions]
+ print(f"Found {len(image_paths)} images")
+
+ if not image_paths:
+ print("No images found!")
+ return
+
+ # Process in batches
+ target_size = (args.resolution, args.resolution) if not args.auto else None
+ total_time = 0
+ num_processed = 0
+
+ if args.auto:
+ # AUTO MODE: Process images one by one since each may have different sizes
+ for i, path in enumerate(image_paths):
+ try:
+ img_array, patch_size, info = load_and_preprocess_image(path, auto_mode=True)
+ batch_array = img_array[np.newaxis, ...] # Add batch dim
+
+ start_time = time.time()
+
+ # Reconstruct with auto-determined patch size
+ height, width = info['cropped_size'][1], info['cropped_size'][0]
+ reconstructed = tokenizer.reconstruct(
+ batch_array,
+ encode_patch_size=patch_size,
+ decode_patch_size=patch_size,
+ target_height=height,
+ target_width=width,
+ )
+
+ if args.device == "cuda":
+ torch.cuda.synchronize()
+
+ batch_time = time.time() - start_time
+ total_time += batch_time
+ num_processed += 1
+
+ # Save output
+ output_images = tokenizer.to_pil(reconstructed)
+ output_path = output_dir / f"{path.stem}_recon.png"
+ output_images[0].save(output_path)
+
+ print(f"[{i+1}/{len(image_paths)}] {path.name}: "
+ f"{info['cropped_size'][0]}x{info['cropped_size'][1]}, "
+ f"patch_size={patch_size}, {batch_time:.2f}s")
+
+ except Exception as e:
+ print(f"Error processing {path}: {e}")
+ continue
+ else:
+ # MANUAL MODE: Batch processing with uniform size
+ for batch_start in range(0, len(image_paths), args.batch_size):
+ batch_paths = image_paths[batch_start:batch_start + args.batch_size]
+
+            # Load batch, keeping names only for images that load successfully
+            batch_images = []
+            batch_names = []
+            for path in batch_paths:
+                try:
+                    img_array, _, _ = load_and_preprocess_image(path, target_size, auto_mode=False)
+                    batch_images.append(img_array)
+                    batch_names.append(path.stem)
+                except Exception as e:
+                    print(f"Error loading {path}: {e}")
+                    continue
+
+ if not batch_images:
+ continue
+
+ # Stack into batch tensor
+ batch_array = np.stack(batch_images, axis=0)
+
+ # Measure time
+ start_time = time.time()
+
+ # Reconstruct
+ reconstructed = tokenizer.reconstruct(
+ batch_array,
+ encode_patch_size=encoder_patch_size,
+ decode_patch_size=decoder_patch_size,
+ target_height=args.resolution,
+ target_width=args.resolution,
+ )
+
+ # Synchronize if GPU
+ if args.device == "cuda":
+ torch.cuda.synchronize()
+
+ batch_time = time.time() - start_time
+ total_time += batch_time
+ num_processed += len(batch_images)
+
+ # Save outputs
+ output_images = tokenizer.to_pil(reconstructed)
+            for name, img in zip(batch_names, output_images):
+ output_path = output_dir / f"{name}_recon.png"
+ img.save(output_path)
+
+ print(f"Processed batch {batch_start // args.batch_size + 1}: "
+ f"{len(batch_images)} images in {batch_time:.2f}s "
+ f"({len(batch_images) / batch_time:.2f} img/s)")
+
+ # Summary
+ if num_processed > 0:
+ print(f"\nTotal: {num_processed} images in {total_time:.2f}s")
+ print(f"Average: {num_processed / total_time:.2f} images/sec")
+ print(f"Per image: {total_time / num_processed * 1000:.1f}ms")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/encode_decode.py b/examples/encode_decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7eef4436e9baa8fd8b6f3f5da3c022d1920c5ef
--- /dev/null
+++ b/examples/encode_decode.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python3
+"""Basic encode-decode example for VibeToken.
+
+Demonstrates how to:
+1. Load the tokenizer from config and checkpoint
+2. Encode an image to discrete tokens
+3. Decode tokens back to an image
+4. Save the reconstructed image
+
+Usage:
+ # Auto mode (recommended)
+ python examples/encode_decode.py --auto \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint path/to/checkpoint.bin \
+ --image path/to/image.jpg \
+ --output reconstructed.png
+
+ # Manual mode
+ python examples/encode_decode.py \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint path/to/checkpoint.bin \
+ --image path/to/image.jpg \
+ --output reconstructed.png \
+ --encoder_patch_size 16,32 \
+ --decoder_patch_size 16
+"""
+
+import argparse
+from pathlib import Path
+
+import torch
+from PIL import Image
+
+import sys
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
+
+
+def parse_patch_size(value):
+ """Parse patch size from string. Supports single int or tuple (e.g., '16' or '16,32')."""
+ if value is None:
+ return None
+ if ',' in value:
+ parts = value.split(',')
+ return (int(parts[0]), int(parts[1]))
+ return int(value)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="VibeToken encode-decode example")
+ parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
+ parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
+ parser.add_argument("--image", type=str, required=True, help="Path to input image")
+ parser.add_argument("--output", type=str, default="reconstructed.png", help="Output image path")
+ parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
+
+ # Auto mode
+ parser.add_argument("--auto", action="store_true",
+ help="Auto mode: automatically determine optimal settings")
+
+ parser.add_argument("--height", type=int, default=None, help="Output height (default: input height)")
+ parser.add_argument("--width", type=int, default=None, help="Output width (default: input width)")
+ parser.add_argument("--encoder_patch_size", type=str, default=None,
+ help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
+ parser.add_argument("--decoder_patch_size", type=str, default=None,
+ help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
+ parser.add_argument("--num_tokens", type=int, default=None, help="Number of tokens to encode")
+
+ args = parser.parse_args()
+
+ # Check if CUDA is available
+ if args.device == "cuda" and not torch.cuda.is_available():
+ print("CUDA not available, falling back to CPU")
+ args.device = "cpu"
+
+ print(f"Loading tokenizer from {args.config}")
+ tokenizer = VibeTokenTokenizer.from_config(
+ config_path=args.config,
+ checkpoint_path=args.checkpoint,
+ device=args.device,
+ )
+ print(f"Tokenizer loaded: codebook_size={tokenizer.codebook_size}, "
+ f"num_latent_tokens={tokenizer.num_latent_tokens}")
+
+ # Load image
+ print(f"Loading image from {args.image}")
+ image = Image.open(args.image).convert("RGB")
+ original_size = image.size # (W, H)
+ print(f"Original image size: {original_size[0]}x{original_size[1]}")
+
+ if args.auto:
+ # AUTO MODE - use centralized auto_preprocess_image
+ print("\n=== AUTO MODE ===")
+ image, patch_size, info = auto_preprocess_image(image, verbose=True)
+ encoder_patch_size = patch_size
+ decoder_patch_size = patch_size
+ height, width = info['cropped_size'][1], info['cropped_size'][0]
+ print("=================\n")
+
+ # Encode to tokens
+ print("Encoding image to tokens...")
+ print(f" Using encoder patch size: {encoder_patch_size}")
+ tokens = tokenizer.encode(image, patch_size=encoder_patch_size)
+ print(f"Token shape: {tokens.shape}")
+
+ # Decode back to image
+ print(f"Decoding tokens to image ({width}x{height})...")
+ print(f" Using decoder patch size: {decoder_patch_size}")
+ reconstructed = tokenizer.decode(
+ tokens, height=height, width=width, patch_size=decoder_patch_size
+ )
+
+ else:
+ # MANUAL MODE
+ # Parse patch sizes
+ encoder_patch_size = parse_patch_size(args.encoder_patch_size)
+ decoder_patch_size = parse_patch_size(args.decoder_patch_size)
+
+ # Always center crop to ensure dimensions divisible by 32
+ image = center_crop_to_multiple(image, multiple=32)
+ cropped_size = image.size # (W, H)
+ if cropped_size != original_size:
+ print(f"Center cropped to {cropped_size[0]}x{cropped_size[1]} (divisible by 32)")
+
+ # Encode to tokens
+ print("Encoding image to tokens...")
+ if encoder_patch_size:
+ print(f" Using encoder patch size: {encoder_patch_size}")
+ tokens = tokenizer.encode(image, patch_size=encoder_patch_size, num_tokens=args.num_tokens)
+ print(f"Token shape: {tokens.shape}")
+
+ if tokenizer.model.quantize_mode == "mvq":
+ print(f" - Batch size: {tokens.shape[0]}")
+ print(f" - Num codebooks: {tokens.shape[1]}")
+ print(f" - Sequence length: {tokens.shape[2]}")
+ else:
+ print(f" - Batch size: {tokens.shape[0]}")
+ print(f" - Sequence length: {tokens.shape[1]}")
+
+ # Decode back to image (use cropped size as default)
+ height = args.height or cropped_size[1]
+ width = args.width or cropped_size[0]
+ print(f"Decoding tokens to image ({width}x{height})...")
+ if decoder_patch_size:
+ print(f" Using decoder patch size: {decoder_patch_size}")
+
+ reconstructed = tokenizer.decode(
+ tokens, height=height, width=width, patch_size=decoder_patch_size
+ )
+
+ print(f"Reconstructed image shape: {reconstructed.shape}")
+
+ # Convert to PIL and save
+ output_images = tokenizer.to_pil(reconstructed)
+ output_path = Path(args.output)
+ output_images[0].save(output_path)
+ print(f"Saved reconstructed image to {output_path}")
+
+ # Compute PSNR (compare with cropped image)
+ import numpy as np
+ original_np = np.array(image).astype(np.float32)
+ recon_np = np.array(output_images[0]).astype(np.float32)
+ if original_np.shape == recon_np.shape:
+ mse = np.mean((original_np - recon_np) ** 2)
+ if mse > 0:
+ psnr = 20 * np.log10(255.0 / np.sqrt(mse))
+ print(f"PSNR: {psnr:.2f} dB")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/generate.py b/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b893bff5c6117e4970ea9b18b3fa6bd3e574ff3c
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,240 @@
+# Modified from:
+# DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
+
+"""Example run:
+python generate.py \
+ --gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \
+ --gpt-model GPT-XXL --num-output-layer 4 \
+ --num-codebooks 8 --codebook-size 32768 \
+ --image-size 256 --cfg-scale 2.0 --top-k 0 --temperature 1.0 \
+ --class-dropout-prob 0.1 \
+ --extra-layers "QKV" \
+ --latent-size 65 \
+ --config ./configs/vibetoken_ll.yaml \
+ --vq-ckpt ./checkpoints/VibeToken_LL.bin \
+ --sample-dir ./assets/ \
+ --skip-folder-creation \
+ --compile \
+ --decoder-patch-size 16,16 \
+ --target-resolution 1024,1024 \
+ --llamagen-target-resolution 256,256 \
+ --precision bf16
+"""
+
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.set_float32_matmul_precision('high')
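+# Skip the default (slow) parameter re-initialization of Linear/LayerNorm layers;
+# all weights are overwritten by the checkpoints loaded below, so this only speeds up model creation.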
+setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
+setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from tqdm import tqdm
+import os
+from PIL import Image
+import numpy as np
+import math
+import argparse
+import sys
+from omegaconf import OmegaConf
+
+from vibetokengen.model import GPT_models
+from vibetokengen.generate import generate
+
+from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
+
+
+def create_npz_from_sample_folder(sample_dir, num=50_000):
+ """
+ Builds a single .npz file from a folder of .png samples.
+ """
+ samples = []
+ for i in tqdm(range(num), desc="Building .npz file from samples"):
+ sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
+ samples.append(sample_np)
+ samples = np.stack(samples)
+ assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
+ npz_path = f"{sample_dir}.npz"
+ np.savez(npz_path, arr_0=samples)
+ print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
+ return npz_path
+
+
+def main(args):
+ # Setup PyTorch:
+    assert torch.cuda.is_available(), "Sampling requires at least one GPU."
+ torch.set_grad_enabled(False)
+
+ # Set global seed for reproducibility
+ torch.manual_seed(args.global_seed)
+ np.random.seed(args.global_seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(args.global_seed)
+ torch.cuda.manual_seed_all(args.global_seed)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
+
+ # Load VibeToken model
+ vq_model = VibeTokenTokenizer.from_config(
+ args.config,
+ args.vq_ckpt,
+ device=device,
+ dtype=precision,
+ )
+ print(f"VibeToken image tokenizer is loaded")
+
+ # create and load gpt model
+ gpt_model = GPT_models[args.gpt_model](
+ vocab_size=args.codebook_size,
+ block_size=args.latent_size,
+ num_classes=args.num_classes,
+ cls_token_num=args.cls_token_num,
+ model_type=args.gpt_type,
+ num_codebooks=args.num_codebooks,
+ n_output_layer=args.num_output_layer,
+ class_dropout_prob=args.class_dropout_prob,
+ extra_layers=args.extra_layers,
+ capping=args.capping,
+ ).to(device=device, dtype=precision)
+ print(f"GPT model is loaded")
+
+ checkpoint = torch.load(args.gpt_ckpt, map_location="cpu", weights_only=False)
+ if args.from_fsdp: # fsdp
+ model_weight = checkpoint
+ elif "model" in checkpoint: # ddp
+ model_weight = checkpoint["model"]
+ elif "module" in checkpoint: # deepspeed
+ model_weight = checkpoint["module"]
+ elif "state_dict" in checkpoint:
+ model_weight = checkpoint["state_dict"]
+ else:
+ raise Exception("please check model weight, maybe add --from-fsdp to run command")
+ gpt_model.load_state_dict(model_weight, strict=True)
+ gpt_model.eval()
+ del checkpoint
+
+ print(f"GPT model weights are loaded")
+
+    if args.compile:
+        print("compiling the model...")
+        gpt_model = torch.compile(
+            gpt_model,
+            mode="reduce-overhead",
+            fullgraph=True
+        )  # requires PyTorch 2.0 (optional)
+        print("GPT model is compiled")
+    else:
+        print("no model compile")
+
+ # Create folder to save samples:
+ model_string_name = args.gpt_model.replace("/", "-")
+ if args.from_fsdp:
+ ckpt_string_name = args.gpt_ckpt.split('/')[-2]
+ else:
+ ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
+ folder_name = f"{model_string_name}-{ckpt_string_name}-target-resolution-{args.target_resolution}-llamagen-target-resolution-{args.llamagen_target_resolution}-vibetoken-" \
+ f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
+ f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
+ if args.skip_folder_creation:
+ sample_folder_dir = args.sample_dir
+ else:
+ sample_folder_dir = f"{args.sample_dir}/{folder_name}"
+
+ os.makedirs(sample_folder_dir, exist_ok=True)
+ print(f"Saving .png samples at {sample_folder_dir}")
+
+ multiplier = 2 if args.cfg_scale > 1.0 else 1
+
+ # Use fixed class labels
+ class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
+ c_indices = torch.tensor(class_labels, device=device)
+ n = len(class_labels)
+ nrow = 4 # 2 rows x 4 columns for 8 images
+
+ index_sample = generate(
+ gpt_model, c_indices, args.latent_size, args.num_codebooks,
+ cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
+ target_h=torch.tensor(args.llamagen_target_resolution[0]/1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices)*multiplier, 1),
+ target_w=torch.tensor(args.llamagen_target_resolution[1]/1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices)*multiplier, 1),
+ temperature=args.temperature, top_k=args.top_k,
+ top_p=args.top_p, sample_logits=True,
+ )
+
+ # Use VibeToken decode_tokens method
+ # VibeToken expects tokens in shape (batch_size, seq_len, 1)
+ index_sample = index_sample.unsqueeze(2)
+ samples = vq_model.decode(
+ index_sample,
+ height=args.target_resolution[0],
+ width=args.target_resolution[1],
+ patch_size=args.decoder_patch_size
+ )
+
+ # VibeToken output is in [0, 1] range, clamp and convert to uint8
+ samples = torch.clamp(samples, 0, 1)
+
+ # Create a grid of images (2 rows x 4 columns)
+ from torchvision.utils import make_grid
+ grid = make_grid(samples, nrow=nrow, padding=2, normalize=False)
+
+ # Convert to PIL and save
+ grid_np = (grid.permute(1, 2, 0).to(torch.float32).cpu().numpy() * 255).astype('uint8')
+ Image.fromarray(grid_np).save(f"{sample_folder_dir}/generated_images.png")
+ print(f"Saved grid of {n} images to {sample_folder_dir}/generated_images.png")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
+ parser.add_argument("--gpt-ckpt", type=str, default=None)
+ parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i",
+ help="class-conditional or text-conditional")
+ parser.add_argument("--from-fsdp", action='store_true')
+ parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
+ parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
+ parser.add_argument("--compile", action='store_true', default=True)
+ # parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
+ parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
+ parser.add_argument("--config", type=str, required=True, help="Path to VibeToken config file")
+ parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
+ parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
+ parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
+ parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
+ parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
+ parser.add_argument("--num-classes", type=int, default=1000)
+ parser.add_argument("--cfg-scale", type=float, default=4.0)
+ parser.add_argument("--cfg-interval", type=float, default=-1)
+ parser.add_argument("--sample-dir", type=str, default="samples")
+ parser.add_argument("--per-proc-batch-size", type=int, default=32)
+ parser.add_argument("--num-fid-samples", type=int, default=50000)
+ parser.add_argument("--global-seed", type=int, default=0) # not used
+ parser.add_argument("--top-k", type=int, default=500, help="top-k value to sample with")
+ parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
+ parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
+ parser.add_argument("--num-codebooks", type=int, default=1)
+ parser.add_argument("--num-output-layer", type=int, default=1)
+ parser.add_argument("--class-dropout-prob", type=float, default=0.1)
+ parser.add_argument("--extra-layers", type=str, choices=['QK', 'QKV', 'FC', 'cap', 'clip', 'QK_cap', 'QKV_cap', 'QK_clip', 'QKV_clip', 'QK_FC_cap', 'QKV_FC_cap', 'QK_FC_clip', 'QKV_FC_clip'], default=None,
+ help="Type of extra layers to add: QK (query-key), QKV (query-key-value), FC (fully connected), cap (caption), clip (clip), QK_cap (query-key-caption), QKV_cap (query-key-value-caption), QK_clip (query-key-clip), QKV_clip (query-key-value-clip), QK_FC_cap (query-key-fully-connected-caption), QKV_FC_cap (query-key-value-fully-connected-caption), QK_FC_clip (query-key-fully-connected-clip), QKV_FC_clip (query-key-value-fully-connected-clip)")
+ parser.add_argument("--capping", type=float, default=50.0, help="Capping for attention softmax")
+
+ # VibeToken dynamic
+ parser.add_argument("--decoder-patch-size", type=str, default="8,8", help="Decoder patch size as 'width,height'")
+ parser.add_argument("--target-resolution", type=str, default="256,256", help="Target resolution as 'width,height'")
+ parser.add_argument("--llamagen-target-resolution", type=str, default="256,256", help="LlamaGen target resolution as 'width,height'")
+
+ parser.add_argument("--latent-size", type=int, default=16, help="Latent size")
+ parser.add_argument("--skip-folder-creation", action='store_true', default=False, help="skip folder creation")
+
+ args = parser.parse_args()
+
+ args.decoder_patch_size = tuple(map(int, args.decoder_patch_size.split(",")))
+ args.target_resolution = tuple(map(int, args.target_resolution.split(",")))
+ args.llamagen_target_resolution = tuple(map(int, args.llamagen_target_resolution.split(",")))
+
+ main(args)
\ No newline at end of file
diff --git a/generator/__init__.py b/generator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b1caa2c3a152ce588930de89eca95b6b437fcf0
--- /dev/null
+++ b/generator/__init__.py
@@ -0,0 +1,4 @@
+"""Generator module placeholder for VibeToken-Gen integration."""
+
+# Future: Add GPT-based generator for image synthesis
+# from .gpt import VibeTokenGenerator
diff --git a/modeling/__init__.py b/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modeling/modules/__init__.py b/modeling/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5abed7a88cf447243ab6dc17ba76ae4097d7910a
--- /dev/null
+++ b/modeling/modules/__init__.py
@@ -0,0 +1,6 @@
+from .base_model import BaseModel
+from .ema_model import EMAModel
+from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, ReconstructionLoss_Single_Stage
+from .blocks import TiTokEncoder, TiTokDecoder, TATiTokDecoder, UViTBlock
+from .maskgit_vqgan import Decoder as Pixel_Decoder
+from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
\ No newline at end of file
diff --git a/modeling/modules/base_model.py b/modeling/modules/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb1e4791155ec018f2d663174208b61027a3317
--- /dev/null
+++ b/modeling/modules/base_model.py
@@ -0,0 +1,124 @@
+"""Base class implementation for models.
+
+Reference:
+ https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py
+"""
+import os
+from typing import Union, Callable, Dict, Optional
+
+import torch
+
+
+class BaseModel(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def save_pretrained_weight(
+ self,
+ save_directory: Union[str, os.PathLike],
+ save_function: Callable = None,
+ state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ ):
+ """Saves a model and its configuration file to a directory.
+
+ Args:
+ save_directory: A string or os.PathLike, directory to which to save.
+ Will be created if it doesn't exist.
+ save_function: A Callable function, the function to use to save the state dictionary.
+                Useful for distributed training (e.g., on TPUs) when one needs to replace `torch.save` with
+ another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`.
+ state_dict: A dictionary from str to torch.Tensor, the state dictionary to save.
+ If `None`, the model's state dictionary will be saved.
+ """
+ if os.path.isfile(save_directory):
+ print(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ if state_dict is None:
+ state_dict = model_to_save.state_dict()
+ weights_name = "pytorch_model.bin"
+
+ save_function(state_dict, os.path.join(save_directory, weights_name))
+
+ print(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+
+ def load_pretrained_weight(
+ self,
+ pretrained_model_path: Union[str, os.PathLike],
+ strict_loading: bool = True,
+ torch_dtype: Optional[torch.dtype] = None
+ ):
+ r"""Instantiates a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ Args:
+ pretrained_model_path: A string or os.PathLike, a path to a *directory* or *file* containing model weights.
+
+ Raises:
+ ValueError: If pretrained_model_path does not exist.
+ """
+ # If pretrained_model_path is a file, set model_file to this file.
+ if os.path.isfile(pretrained_model_path):
+ model_file = pretrained_model_path
+ # If pretrained_model_path is a directory, set model_file to the path of the
+ # file "pytorch_model.bin" in this directory.
+ elif os.path.isdir(pretrained_model_path):
+ pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin")
+ if os.path.isfile(pretrained_model_path):
+ model_file = pretrained_model_path
+ else:
+ raise ValueError(f"{pretrained_model_path} does not exist")
+ else:
+ raise ValueError(f"{pretrained_model_path} does not exist")
+
+ # Load model state from checkpoint.
+ checkpoint = torch.load(model_file, map_location="cpu")
+ # Load state dictionary into self.
+ msg = self.load_state_dict(checkpoint, strict=strict_loading)
+ # Print information about loading weights.
+ print(f"loading weight from {model_file}, msg: {msg}")
+ # If torch_dtype is specified and is a valid torch.dtype, convert self to this dtype.
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ self.to(torch_dtype)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default.
+ self.eval()
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """Gets the number of parameters in the module.
+
+ Args:
+ only_trainable: A boolean, whether to only include trainable parameters.
+ exclude_embeddings: A boolean, whether to exclude parameters associated with embeddings.
+
+ Returns:
+ An integer, the number of parameters.
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+ ]
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+ else:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
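+
+# Illustrative usage sketch (MyTokenizer is a hypothetical subclass, not defined in this file):
+# save_pretrained_weight / load_pretrained_weight round-trip a plain state dict.
+#
+#     model = MyTokenizer()
+#     model.save_pretrained_weight("./ckpt")            # writes ./ckpt/pytorch_model.bin
+#     model.load_pretrained_weight("./ckpt", torch_dtype=torch.float16)
+#     print(f"{model.num_parameters(only_trainable=True):,} trainable parameters")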
diff --git a/modeling/modules/blocks.py b/modeling/modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..caa5fda7bc78647c76533fa88f605762beab231d
--- /dev/null
+++ b/modeling/modules/blocks.py
@@ -0,0 +1,617 @@
+"""Transformer building blocks.
+
+Reference:
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
+ https://github.com/baofff/U-ViT/blob/main/libs/timm.py
+"""
+
+import math
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from collections import OrderedDict
+import einops
+from einops.layers.torch import Rearrange
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ n_head,
+ mlp_ratio = 4.0,
+ act_layer = nn.GELU,
+ norm_layer = nn.LayerNorm
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.mlp_ratio = mlp_ratio
+ # optionally we can disable the FFN
+ if mlp_ratio > 0:
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+
+ def attention(
+ self,
+ x: torch.Tensor
+ ):
+ return self.attn(x, x, x, need_weights=False)[0]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ):
+ attn_output = self.attention(x=self.ln_1(x))
+ x = x + attn_output
+ if self.mlp_ratio > 0:
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+ ATTENTION_MODE = 'flash'
+else:
+ try:
+ import xformers
+ import xformers.ops
+ ATTENTION_MODE = 'xformers'
+    except ImportError:
+ ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, L, C = x.shape
+
+ qkv = self.qkv(x)
+ if ATTENTION_MODE == 'flash':
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'xformers':
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
+ x = xformers.ops.memory_efficient_attention(q, k, v)
+ x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
+ elif ATTENTION_MODE == 'math':
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+ else:
+            raise NotImplementedError(f"Unknown attention mode: {ATTENTION_MODE}")
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class UViTBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, x, skip=None):
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
+ else:
+ return self._forward(x, skip)
+
+ def _forward(self, x, skip=None):
+ if self.skip_linear is not None:
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+def _expand_token(token, batch_size: int):
+ return token.unsqueeze(0).expand(batch_size, -1, -1)
+
+
+class TiTokEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.image_size = config.dataset.preprocessing.crop_size
+ self.patch_size = config.model.vq_model.vit_enc_patch_size
+ self.grid_size = self.image_size // self.patch_size
+ self.model_size = config.model.vq_model.vit_enc_model_size
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
+ self.token_size = config.model.vq_model.token_size
+
+ if config.model.vq_model.get("quantize_mode", "vq") == "vae":
+ self.token_size = self.token_size * 2 # needs to split into mean and std
+
+ self.is_legacy = config.model.vq_model.get("is_legacy", True)
+
+ self.width = {
+ "small": 512,
+ "base": 768,
+ "large": 1024,
+ }[self.model_size]
+ self.num_layers = {
+ "small": 8,
+ "base": 12,
+ "large": 24,
+ }[self.model_size]
+ self.num_heads = {
+ "small": 8,
+ "base": 12,
+ "large": 16,
+ }[self.model_size]
+
+ self.patch_embed = nn.Conv2d(
+ in_channels=3, out_channels=self.width,
+ kernel_size=self.patch_size, stride=self.patch_size, bias=True)
+
+ scale = self.width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
+ self.positional_embedding = nn.Parameter(
+ scale * torch.randn(self.grid_size ** 2 + 1, self.width))
+ self.latent_token_positional_embedding = nn.Parameter(
+ scale * torch.randn(self.num_latent_tokens, self.width))
+ self.ln_pre = nn.LayerNorm(self.width)
+ self.transformer = nn.ModuleList()
+ for i in range(self.num_layers):
+ self.transformer.append(ResidualAttentionBlock(
+ self.width, self.num_heads, mlp_ratio=4.0
+ ))
+ self.ln_post = nn.LayerNorm(self.width)
+ self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
+
+ def forward(self, pixel_values, latent_tokens):
+ batch_size = pixel_values.shape[0]
+ x = pixel_values
+ x = self.patch_embed(x)
+ x = x.reshape(x.shape[0], x.shape[1], -1)
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ # class embeddings and positional embeddings
+ x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
+ x = x + self.positional_embedding.to(x.dtype) # shape = [*, grid ** 2 + 1, width]
+
+
+ latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
+ latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)
+ x = torch.cat([x, latent_tokens], dim=1)
+
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for i in range(self.num_layers):
+ x = self.transformer[i](x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ latent_tokens = x[:, 1+self.grid_size**2:]
+ latent_tokens = self.ln_post(latent_tokens)
+ # fake 2D shape
+ if self.is_legacy:
+ latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1)
+ else:
+ # Fix legacy problem.
+ latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)
+ latent_tokens = self.conv_out(latent_tokens)
+ latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)
+ return latent_tokens
+
+
+class TiTokDecoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.image_size = config.dataset.preprocessing.crop_size
+ self.patch_size = config.model.vq_model.vit_dec_patch_size
+ self.grid_size = self.image_size // self.patch_size
+ self.model_size = config.model.vq_model.vit_dec_model_size
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
+ self.token_size = config.model.vq_model.token_size
+ self.is_legacy = config.model.vq_model.get("is_legacy", True)
+ self.width = {
+ "small": 512,
+ "base": 768,
+ "large": 1024,
+ }[self.model_size]
+ self.num_layers = {
+ "small": 8,
+ "base": 12,
+ "large": 24,
+ }[self.model_size]
+ self.num_heads = {
+ "small": 8,
+ "base": 12,
+ "large": 16,
+ }[self.model_size]
+
+ self.decoder_embed = nn.Linear(
+ self.token_size, self.width, bias=True)
+ scale = self.width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
+ self.positional_embedding = nn.Parameter(
+ scale * torch.randn(self.grid_size ** 2 + 1, self.width))
+ # add mask token and query pos embed
+ self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
+ self.latent_token_positional_embedding = nn.Parameter(
+ scale * torch.randn(self.num_latent_tokens, self.width))
+ self.ln_pre = nn.LayerNorm(self.width)
+ self.transformer = nn.ModuleList()
+ for i in range(self.num_layers):
+ self.transformer.append(ResidualAttentionBlock(
+ self.width, self.num_heads, mlp_ratio=4.0
+ ))
+ self.ln_post = nn.LayerNorm(self.width)
+
+ if self.is_legacy:
+ self.ffn = nn.Sequential(
+ nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
+ nn.Tanh(),
+ nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
+ )
+ self.conv_out = nn.Identity()
+ else:
+ # Directly predicting RGB pixels
+ self.ffn = nn.Sequential(
+ nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True),
+ Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
+ p1 = self.patch_size, p2 = self.patch_size),)
+ self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)
+
+ def forward(self, z_quantized):
+ N, C, H, W = z_quantized.shape
+ assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
+ x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
+ x = self.decoder_embed(x)
+
+ batchsize, seq_len, _ = x.shape
+
+ mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype)
+ mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
+ mask_tokens], dim=1)
+ mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
+ x = x + self.latent_token_positional_embedding[:seq_len]
+ x = torch.cat([mask_tokens, x], dim=1)
+
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for i in range(self.num_layers):
+ x = self.transformer[i](x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = x[:, 1:1+self.grid_size**2] # remove cls embed
+ x = self.ln_post(x)
+ # N L D -> N D H W
+ x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
+ x = self.ffn(x.contiguous())
+ x = self.conv_out(x)
+ return x
+
+
+class TATiTokDecoder(TiTokDecoder):
+ def __init__(self, config):
+ super().__init__(config)
+ scale = self.width ** -0.5
+ self.text_context_length = config.model.vq_model.get("text_context_length", 77)
+ self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768)
+ self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width)
+ self.text_guidance_positional_embedding = nn.Parameter(scale * torch.randn(self.text_context_length, self.width))
+
+ def forward(self, z_quantized, text_guidance):
+ N, C, H, W = z_quantized.shape
+ assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
+ x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
+ x = self.decoder_embed(x)
+
+ batchsize, seq_len, _ = x.shape
+
+ mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype)
+ mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
+ mask_tokens], dim=1)
+ mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
+ x = x + self.latent_token_positional_embedding[:seq_len]
+ x = torch.cat([mask_tokens, x], dim=1)
+
+ text_guidance = self.text_guidance_proj(text_guidance)
+ text_guidance = text_guidance + self.text_guidance_positional_embedding
+ x = torch.cat([x, text_guidance], dim=1)
+
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for i in range(self.num_layers):
+ x = self.transformer[i](x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = x[:, 1:1+self.grid_size**2] # remove cls embed
+ x = self.ln_post(x)
+ # N L D -> N D H W
+ x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
+ x = self.ffn(x.contiguous())
+ x = self.conv_out(x)
+ return x
+
+
+class WeightTiedLMHead(nn.Module):
+ def __init__(self, embeddings, target_codebook_size):
+ super().__init__()
+ self.weight = embeddings.weight
+ self.target_codebook_size = target_codebook_size
+
+ def forward(self, x):
+ # x shape: [batch_size, seq_len, embed_dim]
+ # Get the weights for the target codebook size
+ weight = self.weight[:self.target_codebook_size] # Shape: [target_codebook_size, embed_dim]
+ # Compute the logits by matrix multiplication
+ logits = torch.matmul(x, weight.t()) # Shape: [batch_size, seq_len, target_codebook_size]
+ return logits
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
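+
+    # Illustrative sketch (not part of the original code): for dim=4 and max_period=10000,
+    # freqs = exp(-ln(10000) * [0, 1] / 2) = [1.0, 0.01], so a timestep t is embedded as
+    # [cos(t), cos(0.01 * t), sin(t), sin(0.01 * t)].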
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class ResBlock(nn.Module):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ """
+
+ def __init__(
+ self,
+ channels
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.in_ln = nn.LayerNorm(channels, eps=1e-6)
+ self.mlp = nn.Sequential(
+ nn.Linear(channels, channels, bias=True),
+ nn.SiLU(),
+ nn.Linear(channels, channels, bias=True),
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 3 * channels, bias=True)
+ )
+
+ def forward(self, x, y):
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+ h = self.mlp(h)
+ return x + gate_mlp * h
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer adopted from DiT.
+ """
+ def __init__(self, model_channels, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(model_channels, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(model_channels, 2 * model_channels, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class SimpleMLPAdaLN(nn.Module):
+ """
+ The MLP for Diffusion Loss.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param z_channels: channels in the condition.
+ :param num_res_blocks: number of residual blocks per downsample.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ z_channels,
+ num_res_blocks,
+ grad_checkpointing=False,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.grad_checkpointing = grad_checkpointing
+
+ self.time_embed = TimestepEmbedder(model_channels)
+ self.cond_embed = nn.Linear(z_channels, model_channels)
+
+ self.input_proj = nn.Linear(in_channels, model_channels)
+
+ res_blocks = []
+ for i in range(num_res_blocks):
+ res_blocks.append(ResBlock(
+ model_channels,
+ ))
+
+ self.res_blocks = nn.ModuleList(res_blocks)
+ self.final_layer = FinalLayer(model_channels, out_channels)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize timestep embedding MLP
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers
+ for block in self.res_blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
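+
+        # Note (illustrative): zero-initializing the adaLN gates and the final projection
+        # makes every residual block start as the identity map, so the network initially
+        # outputs zeros regardless of the conditioning.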
+
+ def forward(self, x, t, c):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C] Tensor of inputs.
+ :param t: a 1-D batch of timesteps.
+ :param c: conditioning from AR transformer.
+ :return: an [N x C] Tensor of outputs.
+ """
+ x = self.input_proj(x)
+ t = self.time_embed(t)
+ c = self.cond_embed(c)
+
+ y = t + c
+
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ for block in self.res_blocks:
+ x = checkpoint(block, x, y)
+ else:
+ for block in self.res_blocks:
+ x = block(x, y)
+
+ return self.final_layer(x, y)
+
+ def forward_with_cfg(self, x, t, c, cfg_scale):
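+        # Classifier-free guidance (illustrative note): the conditioning `c` is assumed to
+        # stack the conditional half first and the unconditional half second; only the first
+        # `in_channels` output dims (the predicted noise) are guided, any extra dims pass through.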
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, c)
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
\ No newline at end of file
diff --git a/modeling/modules/discriminator.py b/modeling/modules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fcb38e84b90c9e028d9666a9c1c47a4a0173dc8
--- /dev/null
+++ b/modeling/modules/discriminator.py
@@ -0,0 +1,124 @@
+"""Discriminator implementation."""
+import functools
+import math
+from typing import Tuple
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .maskgit_vqgan import Conv2dSame
+
+
+class BlurBlock(torch.nn.Module):
+ def __init__(self,
+ kernel: Tuple[int] = (1, 3, 3, 1)
+ ):
+ super().__init__()
+
+ kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
+ kernel = kernel[None, :] * kernel[:, None]
+ kernel /= kernel.sum()
+ kernel = kernel.unsqueeze(0).unsqueeze(0)
+ self.register_buffer("kernel", kernel)
+
+ def calc_same_pad(self, i: int, k: int, s: int) -> int:
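+        # Example (illustrative): i=64, k=4, s=2 -> max((32 - 1) * 2 + 4 - 64, 0) = 2,
+        # i.e. TensorFlow-style "same" padding for a stride-2, kernel-4 blur.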
+ return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ ic, ih, iw = x.size()[-3:]
+ pad_h = self.calc_same_pad(i=ih, k=4, s=2)
+ pad_w = self.calc_same_pad(i=iw, k=4, s=2)
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+
+ weight = self.kernel.expand(ic, -1, -1, -1)
+
+ out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
+ return out
+
+
+class NLayerDiscriminator(torch.nn.Module):
+ def __init__(
+ self,
+ num_channels: int = 3,
+ hidden_channels: int = 128,
+ num_stages: int = 3,
+ blur_resample: bool = True,
+ blur_kernel_size: int = 4
+ ):
+ """ Initializes the NLayerDiscriminator.
+
+ Args:
+ num_channels -> int: The number of input channels.
+ hidden_channels -> int: The number of hidden channels.
+ num_stages -> int: The number of stages.
+ blur_resample -> bool: Whether to use blur resampling.
+ blur_kernel_size -> int: The blur kernel size.
+ """
+ super().__init__()
+ assert num_stages > 0, "Discriminator cannot have 0 stages"
+        assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3, 5] when blur_resample is True"
+
+ in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
+ init_kernel_size = 5
+ activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
+
+ self.block_in = torch.nn.Sequential(
+ Conv2dSame(
+ num_channels,
+ hidden_channels,
+ kernel_size=init_kernel_size
+ ),
+ activation(),
+ )
+
+ BLUR_KERNEL_MAP = {
+ 3: (1,2,1),
+ 4: (1,3,3,1),
+ 5: (1,4,6,4,1),
+ }
+
+ discriminator_blocks = []
+ for i_level in range(num_stages):
+ in_channels = hidden_channels * in_channel_mult[i_level]
+ out_channels = hidden_channels * in_channel_mult[i_level + 1]
+ block = torch.nn.Sequential(
+ Conv2dSame(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ ),
+ torch.nn.AvgPool2d(kernel_size=2, stride=2) if not blur_resample else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]),
+ torch.nn.GroupNorm(32, out_channels),
+ activation(),
+ )
+ discriminator_blocks.append(block)
+
+ self.blocks = torch.nn.ModuleList(discriminator_blocks)
+
+ self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))
+
+ self.to_logits = torch.nn.Sequential(
+ Conv2dSame(out_channels, out_channels, 1),
+ activation(),
+ Conv2dSame(out_channels, 1, kernel_size=5)
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """ Forward pass.
+
+ Args:
+ x -> torch.Tensor: The input tensor.
+
+ Returns:
+ output -> torch.Tensor: The output tensor.
+ """
+ hidden_states = self.block_in(x)
+ for block in self.blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.pool(hidden_states)
+
+ return self.to_logits(hidden_states)
diff --git a/modeling/modules/ema_model.py b/modeling/modules/ema_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0d2f319b48e9a228eec464b6c2d58b5e37ffc91
--- /dev/null
+++ b/modeling/modules/ema_model.py
@@ -0,0 +1,241 @@
+"""EMA (Exponential Moving Average) model.
+
+Reference:
+ https://github.com/huggingface/open-muse/blob/64e1afe033717d795866ab8204484705cd4dc3f7/muse/modeling_ema.py#L8
+"""
+
+
+import copy
+from typing import Any, Iterable, Optional, Union
+
+import torch
+
+
+class EMAModel:
+ """Exponential Moving Average of models weights."""
+ def __init__(
+ self,
+ parameters: Iterable[torch.nn.Parameter],
+ decay: float = 0.9999,
+ min_decay: float = 0.0,
+ update_after_step: int = 0,
+ update_every: int = 1,
+ current_step: int = 0,
+ use_ema_warmup: bool = False,
+ inv_gamma: Union[float, int] = 1.0,
+ power: Union[float, int] = 2 / 3,
+ model_cls: Optional[Any] = None,
+ **model_config_kwargs
+ ):
+ """
+ Args:
+ parameters (Iterable[torch.nn.Parameter]): The parameters to track.
+ decay (float): The decay factor for the exponential moving average.
+ min_decay (float): The minimum decay factor for the exponential moving average.
+ update_after_step (int): The number of steps to wait before starting to update the EMA weights.
+ update_every (int): The number of steps between each EMA update.
+ current_step (int): The current training step.
+ use_ema_warmup (bool): Whether to use EMA warmup.
+ inv_gamma (float):
+ Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
+ power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
+
+ notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ """
+
+ parameters = list(parameters)
+ self.shadow_params = [p.clone().detach() for p in parameters]
+ self.temp_stored_params = None
+
+ self.decay = decay
+ self.min_decay = min_decay
+ self.update_after_step = update_after_step
+ self.update_every = update_every
+ self.use_ema_warmup = use_ema_warmup
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.optimization_step = current_step
+ self.cur_decay_value = None # set in `step()`
+
+ self.model_cls = model_cls
+ self.model_config_kwargs = model_config_kwargs
+
+ @classmethod
+ def from_pretrained(cls, checkpoint, model_cls, **model_config_kwargs) -> "EMAModel":
+ model = model_cls(**model_config_kwargs)
+ model.load_pretrained_weight(checkpoint)
+
+ ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs)
+ return ema_model
+
+ def save_pretrained(self, path):
+ if self.model_cls is None:
+ raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
+
+ if self.model_config_kwargs is None:
+ raise ValueError("`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__.")
+
+ model = self.model_cls(**self.model_config_kwargs)
+ self.copy_to(model.parameters())
+ model.save_pretrained_weight(path)
+
+ def set_step(self, optimization_step: int):
+ self.optimization_step = optimization_step
+
+ def get_decay(self, optimization_step: int) -> float:
+ """Computes the decay factor for the exponential moving average."""
+ step = max(0, optimization_step - self.update_after_step - 1)
+
+ if step <= 0:
+ return 0.0
+
+ if self.use_ema_warmup:
+ cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
+ else:
+ cur_decay_value = (1 + step) / (10 + step)
+
+ cur_decay_value = min(cur_decay_value, self.decay)
+ # Make sure decay is not smaller than min_decay.
+ cur_decay_value = max(cur_decay_value, self.min_decay)
+ return cur_decay_value
+
+ @torch.no_grad()
+ def step(self, parameters: Iterable[torch.nn.Parameter]):
+ parameters = list(parameters)
+
+ self.optimization_step += 1
+
+ if (self.optimization_step - 1) % self.update_every != 0:
+ return
+
+ # Compute the decay factor for the exponential moving average.
+ decay = self.get_decay(self.optimization_step)
+ self.cur_decay_value = decay
+ one_minus_decay = 1 - decay
+
+ for s_param, param in zip(self.shadow_params, parameters):
+ if param.requires_grad:
+ s_param.sub_(one_minus_decay * (s_param - param))
+ else:
+ s_param.copy_(param)
+
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ """Copies current averaged parameters into given collection of parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the parameters with which this
+ `ExponentialMovingAverage` was initialized will be used.
+ """
+ parameters = list(parameters)
+ for s_param, param in zip(self.shadow_params, parameters):
+ param.data.copy_(s_param.to(param.device).data)
+
+ def to(self, device=None, dtype=None) -> None:
+ r"""Moves internal buffers of the ExponentialMovingAverage to `device`.
+
+ Args:
+ device: like `device` argument to `torch.Tensor.to`
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+ for p in self.shadow_params
+ ]
+
+ def state_dict(self) -> dict:
+ r"""Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
+ checkpointing to save the ema state dict.
+ """
+ # Following PyTorch conventions, references to tensors are returned:
+ # "returns a reference to the state and not its copy!" -
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+ return {
+ "decay": self.decay,
+ "min_decay": self.min_decay,
+ "optimization_step": self.optimization_step,
+ "update_after_step": self.update_after_step,
+ "use_ema_warmup": self.use_ema_warmup,
+ "inv_gamma": self.inv_gamma,
+ "power": self.power,
+ "shadow_params": self.shadow_params,
+ }
+
+ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ r"""
+ Args:
+ Save the current parameters for restoring later.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
+
+ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ r"""Restores the parameters stored with the `store` method. Useful to validate
+ the model with EMA parameters without affecting the original optimization process.
+ Store the parameters before the `copy_to()` method. After validation (or
+ model saving), use this to restore the former parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the parameters with which this
+ `ExponentialMovingAverage` was initialized will be used.
+ """
+ if self.temp_stored_params is None:
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
+ for c_param, param in zip(self.temp_stored_params, parameters):
+ param.data.copy_(c_param.data)
+
+ # Better memory-wise.
+ self.temp_stored_params = None
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
+ ema state dict.
+
+ Args:
+ state_dict (dict): EMA state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # Deepcopy, to be consistent with module API
+ state_dict = copy.deepcopy(state_dict)
+
+ self.decay = state_dict.get("decay", self.decay)
+ if self.decay < 0.0 or self.decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.min_decay = state_dict.get("min_decay", self.min_decay)
+ if not isinstance(self.min_decay, float):
+ raise ValueError("Invalid min_decay")
+
+ self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
+ if not isinstance(self.optimization_step, int):
+ raise ValueError("Invalid optimization_step")
+
+ self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
+ if not isinstance(self.update_after_step, int):
+ raise ValueError("Invalid update_after_step")
+
+ self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
+ if not isinstance(self.use_ema_warmup, bool):
+ raise ValueError("Invalid use_ema_warmup")
+
+ self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
+ if not isinstance(self.inv_gamma, (float, int)):
+ raise ValueError("Invalid inv_gamma")
+
+ self.power = state_dict.get("power", self.power)
+ if not isinstance(self.power, (float, int)):
+ raise ValueError("Invalid power")
+
+ shadow_params = state_dict.get("shadow_params", None)
+ if shadow_params is not None:
+ self.shadow_params = shadow_params
+ if not isinstance(self.shadow_params, list):
+ raise ValueError("shadow_params must be a list")
+ if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+ raise ValueError("shadow_params must all be Tensors")
\ No newline at end of file
diff --git a/modeling/modules/encoder_decoder.py b/modeling/modules/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..33559d9a446872dafab40c281ee44aa4fa3e8abe
--- /dev/null
+++ b/modeling/modules/encoder_decoder.py
@@ -0,0 +1,1142 @@
+"""Encoder and decoder building blocks for VibeToken.
+
+Reference:
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
+ https://github.com/baofff/U-ViT/blob/main/libs/timm.py
+"""
+
+import random
+import math
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from collections import OrderedDict
+import einops
+from einops.layers.torch import Rearrange
+from typing import Optional, Sequence, Tuple, Union
+from modeling.modules.fuzzy_embedding import FuzzyEmbedding
+import collections.abc
+from itertools import repeat
+from typing import Any
+import numpy as np
+import torch.nn.functional as F
+from einops import rearrange
+from torch import vmap
+from torch import Tensor
+
+def to_2tuple(x: Any) -> Tuple:
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, 2))
+
+class PatchMixture:
+    """Token-routing helper: randomly selects a subset of patch tokens to keep (optionally
+    biased by token magnitude via `l1_reg`), drops the rest, and restores the full,
+    correctly ordered sequence later via `end_route`."""
+
+    def __init__(self, seed=42):
+        self.seed = seed
+
+ def get_mask(self, x, mask_ratio=0.0, l1_reg=0.0, inverse=False):
+ batch_size, num_patches, _ = x.shape
+ device = x.device
+ num_mask = int(num_patches * mask_ratio)
+ num_keep = num_patches - num_mask
+ token_magnitudes = x.abs().sum(dim=-1)
+ min_mags = token_magnitudes.min(dim=1, keepdim=True)[0]
+ max_mags = token_magnitudes.max(dim=1, keepdim=True)[0]
+ token_magnitudes = (token_magnitudes - min_mags) / (max_mags - min_mags + 1e-8)
+ if inverse:
+ adjusted_magnitudes = 1.0 - token_magnitudes
+ else:
+ adjusted_magnitudes = token_magnitudes
+ noise_random = torch.rand(batch_size, num_patches, device=device)
+ noise = (1.0 - l1_reg) * noise_random + l1_reg * adjusted_magnitudes
+ ids_shuffle = torch.argsort(noise, dim=1)
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+ ids_keep = ids_shuffle[:, :num_keep]
+ ids_mask = ids_shuffle[:, num_keep:]
+ mask = torch.ones((batch_size, num_patches), device=device, dtype=torch.bool)
+ mask.scatter_(1, ids_keep, False)
+ return {
+ 'mask': mask,
+ 'ids_keep': ids_keep,
+ 'ids_mask': ids_mask,
+ 'ids_shuffle': ids_shuffle,
+ 'ids_restore': ids_restore
+ }
+
+ def start_route(self, x, mask_info):
+ ids_shuffle = mask_info['ids_shuffle']
+ num_keep = mask_info['ids_keep'].size(1)
+ batch_indices = torch.arange(x.size(0), device=x.device).unsqueeze(-1)
+ x_shuffled = x.gather(1, ids_shuffle.unsqueeze(-1).expand(-1, -1, x.size(2)))
+ masked_x = x_shuffled[:, :num_keep, :]
+ return masked_x
+
+ def end_route(self, masked_x, mask_info, original_x=None, mask_token=0.0):
+ batch_size, num_patches = mask_info['mask'].shape
+ num_keep = masked_x.size(1)
+ dim = masked_x.size(2)
+ device = masked_x.device
+ ids_restore = mask_info['ids_restore']
+ batch_indices = torch.arange(batch_size, device=device).unsqueeze(-1)
+        x_unshuffled = torch.empty((batch_size, num_patches, dim), device=device, dtype=masked_x.dtype)
+ x_unshuffled[:, :num_keep, :] = masked_x
+ if original_x is not None:
+ x_shuffled = original_x.gather(1, mask_info['ids_shuffle'].unsqueeze(-1).expand(-1, -1, dim))
+ x_unshuffled[:, num_keep:, :] = x_shuffled[:, num_keep:, :]
+ else:
+ x_unshuffled[:, num_keep:, :].fill_(mask_token)
+ x_unmasked = x_unshuffled.gather(1, ids_restore.unsqueeze(-1).expand(-1, -1, dim))
+ return x_unmasked
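+
+# Usage sketch (illustrative, not part of the original code):
+#   pm = PatchMixture()
+#   info = pm.get_mask(tokens, mask_ratio=0.5)               # tokens: [B, N, D]
+#   kept = pm.start_route(tokens, info)                      # keeps the unmasked tokens
+#   restored = pm.end_route(kept, info, original_x=tokens)   # back to [B, N, D] in original order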
+
+class ResizableBlur(nn.Module):
+ """
+ Single-parameter anti‑aliasing layer.
+    forward(x, input_size, target_size) resizes the learnable kernel and chooses the stride
+    needed to anti-alias and downsample from input_size to target_size (identity, 2×, 4×, or arbitrary ratios).
+ """
+ def __init__(self, channels: int,
+ max_kernel_size: int = 9,
+ init_type: str = "gaussian"):
+ super().__init__()
+ self.C = channels
+ K = max_kernel_size # e.g. 9 for 4×
+ assert K % 2 == 1, "kernel must be odd"
+
+ # ----- initialise the largest kernel ---------------------------------
+ if init_type == "gaussian":
+ # 2‑D separable Gaussian, σ≈K/6
+ ax = torch.arange(-(K//2), K//2 + 1)
+ g1d = torch.exp(-0.5 * (ax / (K/6.0))**2)
+ g2d = torch.outer(g1d, g1d)
+ kernel = g2d / g2d.sum()
+ elif init_type == "lanczos":
+ a = K//2 # window size parameter
+ x = torch.arange(-a, a+1).float()
+ sinc = lambda t: torch.where(t==0, torch.ones_like(t), torch.sin(torch.pi*t)/(torch.pi*t))
+ k1d = sinc(x) * sinc(x/a)
+ k2d = torch.outer(k1d, k1d)
+ kernel = k2d / k2d.sum()
+ else:
+ raise ValueError("unknown init_type")
+
+ # learnable base kernel (shape 1×1×K×K)
+ self.weight = nn.Parameter(kernel.unsqueeze(0).unsqueeze(0))
+
+ # ------------------------------------------------------------------------
+ @staticmethod
+ def _resize_and_normalise(weight: torch.Tensor, k_size: int) -> torch.Tensor:
+ """
+ Bilinearly interpolate weight (B,C,H,W) to target k_size×k_size,
+ then L1‑normalise over spatial dims so Σ=1.
+ """
+ if weight.shape[-1] != k_size:
+ weight = F.interpolate(weight, size=(k_size, k_size),
+ mode="bilinear", align_corners=True)
+ weight = weight / weight.sum(dim=(-2, -1), keepdim=True).clamp(min=1e-8)
+ return weight
+
+ # ------------------------------------------------------------------------
+ def forward(self, x: torch.Tensor, input_size, target_size) -> torch.Tensor:
+ # Unpack input and target dimensions
+ input_h, input_w = input_size
+ target_h, target_w = target_size
+
+ # Calculate scale factors for height and width
+ scale_h = input_h / target_h
+ scale_w = input_w / target_w
+
+ # Determine kernel size based on scale factors
+ # Larger scale factors need larger kernels for better anti-aliasing
+ k_size_h = min(self.weight.shape[-1], max(1, int(2 * scale_h + 3)))
+ k_size_w = min(self.weight.shape[-1], max(1, int(2 * scale_w + 3)))
+
+ # Make sure kernel sizes are odd
+ k_size_h = k_size_h if k_size_h % 2 == 1 else k_size_h + 1
+ k_size_w = k_size_w if k_size_w % 2 == 1 else k_size_w + 1
+
+        # Use the maximum of the two for a single square kernel; per-axis padding plus the
+        # final interpolation below absorb any height/width mismatch
+        k_size = max(k_size_h, k_size_w)
+
+ # Calculate appropriate stride and padding
+ stride_h = max(1, round(scale_h))
+ stride_w = max(1, round(scale_w))
+ pad_h = k_size_h // 2
+ pad_w = k_size_w // 2
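+
+        # Worked example (illustrative): downsampling 1024 -> 512 gives scale 2.0, so
+        # k_size_h = k_size_w = 7, stride 2 and padding 3; any residual size mismatch is
+        # fixed by the bilinear interpolation below.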
+
+ # Get the kernel and normalize it
+ k = self._resize_and_normalise(self.weight, k_size) # (1,1,k,k)
+ k = k.repeat(self.C, 1, 1, 1) # depth-wise
+
+ # Apply convolution with calculated parameters
+ result = F.conv2d(x, weight=k, stride=(stride_h, stride_w),
+ padding=(pad_h, pad_w), groups=self.C)
+
+ # If the convolution didn't get us exactly to the target size, use interpolation for fine adjustment
+ if result.shape[2:] != target_size:
+ result = F.interpolate(result, size=target_size, mode='bilinear', align_corners=True)
+
+ return result
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ n_head,
+ mlp_ratio = 4.0,
+ act_layer = nn.GELU,
+ norm_layer = nn.LayerNorm
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.mlp_ratio = mlp_ratio
+ # optionally we can disable the FFN
+ if mlp_ratio > 0:
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+
+ def attention(
+ self,
+ x: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None
+ ):
+ return self.attn(x, x, x, attn_mask=attention_mask, need_weights=False)[0]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None
+ ):
+ attn_output = self.attention(x=self.ln_1(x), attention_mask=attention_mask)
+ x = x + attn_output
+ if self.mlp_ratio > 0:
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+ ATTENTION_MODE = 'flash'
+else:
+ try:
+ import xformers
+ import xformers.ops
+ ATTENTION_MODE = 'xformers'
+    except ImportError:
+ ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, L, C = x.shape
+
+ qkv = self.qkv(x)
+ if ATTENTION_MODE == 'flash':
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'xformers':
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
+ x = xformers.ops.memory_efficient_attention(q, k, v)
+ x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
+ elif ATTENTION_MODE == 'math':
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+ else:
+            raise NotImplementedError(f'Unsupported attention mode: {ATTENTION_MODE}')
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class UViTBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, x, skip=None):
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
+ else:
+ return self._forward(x, skip)
+
+ def _forward(self, x, skip=None):
+ if self.skip_linear is not None:
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+def _expand_token(token, batch_size: int):
+ return token.unsqueeze(0).expand(batch_size, -1, -1)
+
+
+class ResolutionEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.image_size = config.dataset.preprocessing.crop_size
+ self.patch_size = config.model.vq_model.vit_enc_patch_size
+ self.model_size = config.model.vq_model.vit_enc_model_size
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
+ self.token_size = config.model.vq_model.token_size
+ self.apply_fuzzy = config.model.vq_model.get("apply_fuzzy", False)
+ self.patch_mixture_start_layer = config.model.vq_model.get("patch_mixture_start_layer", 100)
+ self.patch_mixture_end_layer = config.model.vq_model.get("patch_mixture_end_layer", 100)
+
+ if config.model.vq_model.get("quantize_mode", "vq") == "vae":
+ self.token_size = self.token_size * 2 # needs to split into mean and std
+
+ self.is_legacy = config.model.vq_model.get("is_legacy", True)
+
+ self.width = {
+ "tiny": 256,
+ "small": 512,
+ "base": 768,
+ "large": 1024,
+ }[self.model_size]
+ self.num_layers = {
+ "tiny": 4,
+ "small": 8,
+ "base": 12,
+ "large": 24,
+ }[self.model_size]
+ self.num_heads = {
+ "tiny": 4,
+ "small": 8,
+ "base": 12,
+ "large": 16,
+ }[self.model_size]
+
+ self.patch_embed = nn.Conv2d(
+ in_channels=3, out_channels=self.width,
+ kernel_size=self.patch_size, stride=self.patch_size, bias=True)
+
+ scale = self.width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
+
+ self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
+
+ self.latent_token_positional_embedding = nn.Parameter(
+ scale * torch.randn(self.num_latent_tokens, self.width))
+ self.ln_pre = nn.LayerNorm(self.width)
+
+ self.patch_mixture = PatchMixture()
+
+ self.transformer = nn.ModuleList()
+ for i in range(self.num_layers):
+ self.transformer.append(ResidualAttentionBlock(
+ self.width, self.num_heads, mlp_ratio=4.0
+ ))
+
+ self.ln_post = nn.LayerNorm(self.width)
+ self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
+ self.pinvs = {}
+
+ def apply_flexivit_patch_embed(self, x, target_patch_size):
+ patch_size = to_2tuple(target_patch_size)
+
+ # Resize conv weights
+ if patch_size == to_2tuple(self.patch_size):
+ weight = self.patch_embed.weight
+ else:
+ weight = self.resize_patch_embed(self.patch_embed.weight, patch_size)
+
+ # Apply conv with resized weights
+ x = F.conv2d(x, weight, bias=self.patch_embed.bias, stride=patch_size)
+ return x
+
+ def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
+ x_resized = F.interpolate(
+ x[None, None, ...],
+ shape,
+ mode="bilinear",
+ antialias=False,
+ )
+ return x_resized[0, 0, ...]
+
+ def _calculate_pinv(
+ self, old_shape: Tuple[int, int], new_shape: Tuple[int, int], device=None
+ ) -> Tensor:
+ # Use the device from patch_embed weights if available
+ if device is None and hasattr(self, 'patch_embed'):
+ device = self.patch_embed.weight.device
+
+ mat = []
+ for i in range(np.prod(old_shape)):
+ basis_vec = torch.zeros(old_shape, device=device) # Specify device here
+ basis_vec[np.unravel_index(i, old_shape)] = 1.0
+ mat.append(self._resize(basis_vec, new_shape).reshape(-1))
+ resize_matrix = torch.stack(mat)
+ return torch.linalg.pinv(resize_matrix)
+
+ def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]):
+ """Resize patch_embed to target resolution via pseudo-inverse resizing"""
+ # Return original kernel if no resize is necessary
+ if to_2tuple(self.patch_size) == new_patch_size:
+ return patch_embed
+
+ # Calculate pseudo-inverse of resize matrix
+ if new_patch_size not in self.pinvs:
+ self.pinvs[new_patch_size] = self._calculate_pinv(
+ to_2tuple(self.patch_size), new_patch_size, device=patch_embed.device
+ )
+ pinv = self.pinvs[new_patch_size]
+
+ def resample_patch_embed(patch_embed: Tensor):
+ h, w = new_patch_size
+ original_dtype = patch_embed.dtype
+ patch_embed_float = patch_embed.float()
+ resampled_kernel = pinv @ patch_embed_float.reshape(-1)
+ resampled_kernel = resampled_kernel.to(original_dtype)
+ return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w)
+
+ v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)
+
+ return v_resample_patch_embed(patch_embed)
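+
+    # Note (illustrative): the pseudo-inverse resize above (FlexiViT-style) picks the resampled
+    # kernel so that its response on resized patches matches the original kernel's response on
+    # the original patches as closely as possible.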
+
+ def get_attention_mask(self, target_shape, attention_mask):
+ # Create mask for mask_tokens (all True since we want to attend to all mask tokens)
+ mask_token_mask = torch.ones(target_shape).to(attention_mask.device)
+ # Combine with input attention mask
+ attention_mask = torch.cat((mask_token_mask, attention_mask), dim=1).bool()
+ sequence_length = attention_mask.shape[1]
+
+        # Expand the key-padding mask to a per-head square attention mask
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, S]
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0],
+ self.num_heads,
+ sequence_length,
+ sequence_length
+ )
+
+ # Reshape to [B*num_heads, S, S]
+ attention_mask = attention_mask.reshape(
+ -1, sequence_length, sequence_length
+ )
+
+        # Convert to an additive float mask: True -> 0.0, False -> -inf
+        attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill(
+            ~attention_mask,
+            float('-inf')
+        )
+ return attention_mask
+
+ def forward(self, pixel_values, latent_tokens, attention_mask=None, encode_patch_size=None, train=True):
+ batch_size, _, H, W = pixel_values.shape
+ x = pixel_values
+
+ # Apply dynamic patch embedding
+ # Determine patch size dynamically based on image resolution
+ # Base patch size (32) is for 512x512 images
+ # Scale proportionally for other resolutions to maintain ~256 tokens
+ base_resolution = 512
+
+ if encode_patch_size is None:
+ base_patch_size = random.choice([16, 32])
+ target_patch_size = min(int(min(H, W) / base_resolution * base_patch_size), 32) # we force it to be at most 32 otherwise we lose information
+ else:
+ target_patch_size = encode_patch_size
+
+ if isinstance(target_patch_size, int):
+ target_patch_size = (target_patch_size, target_patch_size)
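+
+        # Worked example (illustrative): a 512x512 input with base_patch_size 32 keeps patch
+        # size 32 (a 16x16 = 256-token grid); a 1024x1024 input would scale it to 64 but is
+        # clamped to 32, giving a 32x32 = 1024-token grid.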
+
+ x = self.apply_flexivit_patch_embed(x, target_patch_size)
+
+ x = x.reshape(x.shape[0], x.shape[1], -1)
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ # class embeddings and positional embeddings
+ x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
+
+ # create image_rotary_emb
+ grid_height = H // target_patch_size[0]
+ grid_width = W // target_patch_size[1]
+
+ mask_ratio = 0.0
+ if grid_height*grid_width > 256 and train:
+ mask_ratio = torch.empty(1).uniform_(0.5, 0.7).item()
+
+ num_latent_tokens = latent_tokens.shape[0]
+ latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
+ latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)[:num_latent_tokens]
+
+ x = x + self.positional_embedding(grid_height, grid_width, train=train, dtype=x.dtype)
+
+ # apply attention_mask before concatenating x and latent_tokens
+ if attention_mask is not None:
+ key_attention_mask = attention_mask.clone()
+ attention_mask = self.get_attention_mask((batch_size, x.shape[1]), key_attention_mask)
+ full_seq_attention_mask = attention_mask.clone()
+ else:
+ key_attention_mask = None
+ full_seq_attention_mask = None
+
+ # Concatenate x and latent_tokens first
+ x = torch.cat([x, latent_tokens], dim=1)
+
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for i in range(self.num_layers):
+ if i == self.patch_mixture_start_layer:
+ x = x.permute(1, 0, 2)
+ x_D_last = x[:, 1:grid_height*grid_width+1].clone()
+ mask_info = self.patch_mixture.get_mask(x[:, 1:grid_height*grid_width+1], mask_ratio=mask_ratio)
+ new_x = self.patch_mixture.start_route(x, mask_info)
+ x = torch.cat([x[:, :1], new_x, x[:, grid_height*grid_width+1:]], dim=1)
+ x = x.permute(1, 0, 2)
+ if key_attention_mask is not None:
+ attention_mask = self.get_attention_mask((batch_size, 1+new_x.shape[1]), key_attention_mask)
+ else:
+ attention_mask = None
+
+ x = self.transformer[i](x, attention_mask=attention_mask)
+
+ if i == self.patch_mixture_end_layer:
+ x = x.permute(1, 0, 2)
+ new_x = self.patch_mixture.end_route(x[:, 1:-self.num_latent_tokens], mask_info, original_x=x_D_last)
+ x = torch.cat([x[:, :1], new_x, x[:, -self.num_latent_tokens:]], dim=1)
+ x = x.permute(1, 0, 2)
+ if full_seq_attention_mask is not None:
+ attention_mask = full_seq_attention_mask.clone()
+ else:
+ attention_mask = None
+
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ latent_tokens = x[:, 1+grid_height*grid_width:]
+ latent_tokens = self.ln_post(latent_tokens)
+
+ # fake 2D shape
+ if self.is_legacy:
+ latent_tokens = latent_tokens.reshape(batch_size, self.width, num_latent_tokens, 1)
+ else:
+ # Fix legacy problem.
+ latent_tokens = latent_tokens.reshape(batch_size, num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)
+ latent_tokens = self.conv_out(latent_tokens)
+ latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, num_latent_tokens)
+ return latent_tokens
+
+# Keep the original TiTokEncoder as a legacy class
+class TiTokEncoder(ResolutionEncoder):
+ """Legacy TiTokEncoder - now inherits from ResolutionEncoder for backward compatibility"""
+ pass
+
+class ResolutionDecoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.image_size = config.dataset.preprocessing.crop_size
+ self.patch_size = config.model.vq_model.vit_dec_patch_size
+ self.model_size = config.model.vq_model.vit_dec_model_size
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
+ self.token_size = config.model.vq_model.token_size
+ self.apply_fuzzy = config.model.vq_model.get("apply_fuzzy", False)
+ self.patch_mixture_start_layer = config.model.vq_model.get("patch_mixture_start_layer", 100)
+ self.patch_mixture_end_layer = config.model.vq_model.get("patch_mixture_end_layer", 100)
+
+ self.is_legacy = config.model.vq_model.get("is_legacy", True)
+ self.width = {
+ "tiny": 256,
+ "small": 512,
+ "base": 768,
+ "large": 1024,
+ }[self.model_size]
+ self.num_layers = {
+ "tiny": 4,
+ "small": 8,
+ "base": 12,
+ "large": 24,
+ }[self.model_size]
+ self.num_heads = {
+ "tiny": 4,
+ "small": 8,
+ "base": 12,
+ "large": 16,
+ }[self.model_size]
+
+ self.decoder_embed = nn.Linear(
+ self.token_size, self.width, bias=True)
+ scale = self.width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
+
+ self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
+
+ # add mask token and query pos embed
+ self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
+ self.latent_token_positional_embedding = nn.Parameter(
+ scale * torch.randn(self.num_latent_tokens, self.width))
+ self.ln_pre = nn.LayerNorm(self.width)
+
+ self.patch_mixture = PatchMixture()
+
+ self.transformer = nn.ModuleList()
+ for i in range(self.num_layers):
+ self.transformer.append(ResidualAttentionBlock(
+ self.width, self.num_heads, mlp_ratio=4.0
+ ))
+ self.ln_post = nn.LayerNorm(self.width)
+
+ if self.is_legacy:
+ raise NotImplementedError("Legacy mode is not implemented for ResolutionDecoder")
+ else:
+ # Directly predicting RGB pixels
+ self.ffn = nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True)
+ self.rearrange = Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
+ p1 = self.patch_size, p2 = self.patch_size)
+ self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos")
+ self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)
+
+ def get_attention_mask(self, target_shape, attention_mask):
+ # Create mask for mask_tokens (all True since we want to attend to all mask tokens)
+ mask_token_mask = torch.ones(target_shape).to(attention_mask.device)
+ # Combine with input attention mask
+ attention_mask = torch.cat((mask_token_mask, attention_mask), dim=1).bool()
+ sequence_length = attention_mask.shape[1]
+
+        # Expand the key-padding mask to a per-head square attention mask
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, S]
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0],
+ self.num_heads,
+ sequence_length,
+ sequence_length
+ )
+
+ # Reshape to [B*num_heads, S, S]
+ attention_mask = attention_mask.reshape(
+ -1, sequence_length, sequence_length
+ )
+
+        # Convert to an additive float mask: True -> 0.0, False -> -inf
+        attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill(
+            ~attention_mask,
+            float('-inf')
+        )
+ return attention_mask
+
+ def forward(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
+ N, C, H, W = z_quantized.shape
+ x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
+ x = self.decoder_embed(x)
+
+ batchsize, seq_len, _ = x.shape
+
+ if height is None:
+ height = self.image_size
+ if width is None:
+ width = self.image_size
+
+ # create image_rotary_emb
+ if decode_patch_size is None:
+ # Calculate total area and determine appropriate patch size
+ total_pixels = height * width
+
+ # Target patch counts between 256 and 1024
+ min_patches = 256
+ max_patches = 1024
+
+ # Calculate possible patch sizes that would give us patch counts in our target range
+ possible_patch_sizes = []
+ for patch_size in [8, 16, 32]:
+ grid_h = height // patch_size
+ grid_w = width // patch_size
+ total_patches = grid_h * grid_w
+ if min_patches <= total_patches <= max_patches:
+ possible_patch_sizes.append(patch_size)
+
+ if not possible_patch_sizes:
+ # If no patch size gives us the desired range, pick the one closest to our target range
+ patch_counts = []
+ for patch_size in [8, 16, 32]:
+ grid_h = height // patch_size
+ grid_w = width // patch_size
+ patch_counts.append((patch_size, grid_h * grid_w))
+
+ # Sort by how close the patch count is to our target range
+ patch_counts.sort(key=lambda x: min(abs(x[1] - min_patches), abs(x[1] - max_patches)))
+ possible_patch_sizes = [patch_counts[0][0]]
+
+ selected_patch_size = random.choice(possible_patch_sizes)
+ else:
+ selected_patch_size = decode_patch_size
+
+ if isinstance(selected_patch_size, int):
+ selected_patch_size = (selected_patch_size, selected_patch_size)
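+
+        # Worked example (illustrative): for a 512x512 output, patch sizes 16 and 32 give 1024
+        # and 256 patches respectively, so both are valid and one is sampled during training;
+        # patch size 8 would give 4096 patches and is excluded.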
+
+ grid_height = height // selected_patch_size[0]
+ grid_width = width // selected_patch_size[1]
+
+ # if grid_height*grid_width>1024 and train:
+ # grid_height = 32
+ # grid_width = 32
+
+ mask_ratio = 0.0
+ if grid_height*grid_width > 256 and train:
+ mask_ratio = torch.empty(1).uniform_(0.5, 0.7).item()
+
+ mask_tokens = self.mask_token.repeat(batchsize, grid_height*grid_width, 1).to(x.dtype)
+ mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
+ mask_tokens], dim=1)
+
+ mask_tokens = mask_tokens + self.positional_embedding(grid_height, grid_width, train=train).to(mask_tokens.dtype)
+
+ x = x + self.latent_token_positional_embedding[:seq_len]
+ x = torch.cat([mask_tokens, x], dim=1)
+
+ if attention_mask is not None:
+ key_attention_mask = attention_mask.clone()
+ attention_mask = self.get_attention_mask((batchsize, 1+grid_height*grid_width), key_attention_mask)
+ full_seq_attention_mask = attention_mask.clone()
+ else:
+ key_attention_mask = None
+ full_seq_attention_mask = None
+
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for i in range(self.num_layers):
+ if i == self.patch_mixture_start_layer:
+ x = x.permute(1, 0, 2)
+ x_D_last = x[:, 1:grid_height*grid_width+1].clone()
+ mask_info = self.patch_mixture.get_mask(x[:, 1:grid_height*grid_width+1], mask_ratio=mask_ratio)
+ new_x = self.patch_mixture.start_route(x, mask_info)
+ x = torch.cat([x[:, :1], new_x, x[:, grid_height*grid_width+1:]], dim=1)
+ x = x.permute(1, 0, 2)
+ if key_attention_mask is not None:
+ attention_mask = self.get_attention_mask((batchsize, 1+new_x.shape[1]), key_attention_mask)
+ else:
+ attention_mask = None
+
+ x = self.transformer[i](x, attention_mask=attention_mask)
+
+ if i == self.patch_mixture_end_layer:
+ x = x.permute(1, 0, 2)
+ new_x = self.patch_mixture.end_route(x[:, 1:-self.num_latent_tokens], mask_info, original_x=x_D_last)
+ x = torch.cat([x[:, :1], new_x, x[:, -self.num_latent_tokens:]], dim=1)
+ x = x.permute(1, 0, 2)
+ if full_seq_attention_mask is not None:
+ attention_mask = full_seq_attention_mask.clone()
+ else:
+ attention_mask = None
+
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = x[:, 1:1+grid_height*grid_width] # remove cls embed
+ x = self.ln_post(x)
+ # N L D -> N D H W
+ x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width)
+ x = self.ffn(x.contiguous())
+ x = self.rearrange(x)
+ _, _, org_h, org_w = x.shape
+ x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width))
+ x = self.conv_out(x)
+
+ return x
+
+# Keep the original TiTokDecoder as a legacy class that inherits from ResolutionDecoder
+class TiTokDecoder(ResolutionDecoder):
+ """Legacy TiTokDecoder - now inherits from ResolutionDecoder for backward compatibility"""
+
+ def __init__(self, config):
+ # Override config to disable patch mixture and other advanced features for legacy mode
+ config_copy = type(config)()
+ for attr in dir(config):
+ if not attr.startswith('__'):
+ try:
+ setattr(config_copy, attr, getattr(config, attr))
+                except Exception:
+ pass
+
+ # Disable patch mixture for legacy mode
+ if hasattr(config_copy.model.vq_model, 'patch_mixture_start_layer'):
+ config_copy.model.vq_model.patch_mixture_start_layer = -1
+ if hasattr(config_copy.model.vq_model, 'patch_mixture_end_layer'):
+ config_copy.model.vq_model.patch_mixture_end_layer = -1
+
+ super().__init__(config_copy)
+
+ # Override grid_size for legacy compatibility
+ self.grid_size = self.image_size // self.patch_size
+
+ # Replace ResolutionDecoder's advanced final layers with legacy ones if needed
+ if self.is_legacy:
+ self.ffn = nn.Sequential(
+ nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
+ nn.Tanh(),
+ nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
+ )
+ self.conv_out = nn.Identity()
+ else:
+ # Use simpler final layers for backward compatibility
+ self.ffn = nn.Sequential(
+ nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True),
+ Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
+ p1 = self.patch_size, p2 = self.patch_size),)
+ self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)
+
+ def forward(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
+ # Legacy compatibility: use fixed grid size if height/width not provided
+ if height is None:
+ height = self.image_size
+ if width is None:
+ width = self.image_size
+
+ # Force decode_patch_size to be the original patch_size for legacy compatibility
+ if decode_patch_size is None:
+ decode_patch_size = self.patch_size
+
+ # Use the parent's forward method but with legacy parameters
+ return super().forward(z_quantized, attention_mask, height, width, decode_patch_size, train)
+
+
+class TATiTokDecoder(ResolutionDecoder):
+ def __init__(self, config):
+ super().__init__(config)
+ scale = self.width ** -0.5
+ self.text_context_length = config.model.vq_model.get("text_context_length", 77)
+ self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768)
+ self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width)
+ self.text_guidance_positional_embedding = nn.Parameter(scale * torch.randn(self.text_context_length, self.width))
+
+ # Add grid_size for backward compatibility
+ self.grid_size = self.image_size // self.patch_size
+
+ def forward(self, z_quantized, text_guidance, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
+ N, C, H, W = z_quantized.shape
+ x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
+ x = self.decoder_embed(x)
+
+ batchsize, seq_len, _ = x.shape
+
+ # Use fixed grid size for backward compatibility
+ if height is None:
+ height = self.image_size
+ if width is None:
+ width = self.image_size
+ if decode_patch_size is None:
+ decode_patch_size = self.patch_size
+
+ grid_height = height // decode_patch_size
+ grid_width = width // decode_patch_size
+
+ mask_tokens = self.mask_token.repeat(batchsize, grid_height*grid_width, 1).to(x.dtype)
+ mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
+ mask_tokens], dim=1)
+ mask_tokens = mask_tokens + self.positional_embedding(grid_height, grid_width, train=train).to(mask_tokens.dtype)
+ x = x + self.latent_token_positional_embedding[:seq_len]
+ x = torch.cat([mask_tokens, x], dim=1)
+
+ text_guidance = self.text_guidance_proj(text_guidance)
+ text_guidance = text_guidance + self.text_guidance_positional_embedding
+ x = torch.cat([x, text_guidance], dim=1)
+
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for i in range(self.num_layers):
+ x = self.transformer[i](x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = x[:, 1:1+grid_height*grid_width] # remove cls embed
+ x = self.ln_post(x)
+ # N L D -> N D H W
+ x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width)
+ x = self.ffn(x.contiguous())
+ x = self.conv_out(x)
+ return x
+
+
+class WeightTiedLMHead(nn.Module):
+ def __init__(self, embeddings, target_codebook_size):
+ super().__init__()
+ self.weight = embeddings.weight
+ self.target_codebook_size = target_codebook_size
+
+ def forward(self, x):
+ # x shape: [batch_size, seq_len, embed_dim]
+ # Get the weights for the target codebook size
+ weight = self.weight[:self.target_codebook_size] # Shape: [target_codebook_size, embed_dim]
+ # Compute the logits by matrix multiplication
+ logits = torch.matmul(x, weight.t()) # Shape: [batch_size, seq_len, target_codebook_size]
+ return logits
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class ResBlock(nn.Module):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ """
+
+ def __init__(
+ self,
+ channels
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.in_ln = nn.LayerNorm(channels, eps=1e-6)
+ self.mlp = nn.Sequential(
+ nn.Linear(channels, channels, bias=True),
+ nn.SiLU(),
+ nn.Linear(channels, channels, bias=True),
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 3 * channels, bias=True)
+ )
+
+ def forward(self, x, y):
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+ h = self.mlp(h)
+ return x + gate_mlp * h
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer adopted from DiT.
+ """
+ def __init__(self, model_channels, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(model_channels, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(model_channels, 2 * model_channels, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class SimpleMLPAdaLN(nn.Module):
+ """
+ The MLP for Diffusion Loss.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param z_channels: channels in the condition.
+ :param num_res_blocks: number of residual blocks per downsample.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ z_channels,
+ num_res_blocks,
+ grad_checkpointing=False,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.grad_checkpointing = grad_checkpointing
+
+ self.time_embed = TimestepEmbedder(model_channels)
+ self.cond_embed = nn.Linear(z_channels, model_channels)
+
+ self.input_proj = nn.Linear(in_channels, model_channels)
+
+ res_blocks = []
+ for i in range(num_res_blocks):
+ res_blocks.append(ResBlock(
+ model_channels,
+ ))
+
+ self.res_blocks = nn.ModuleList(res_blocks)
+ self.final_layer = FinalLayer(model_channels, out_channels)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize timestep embedding MLP
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers
+ for block in self.res_blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, c):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C] Tensor of inputs.
+ :param t: a 1-D batch of timesteps.
+ :param c: conditioning from AR transformer.
+ :return: an [N x C] Tensor of outputs.
+ """
+ x = self.input_proj(x)
+ t = self.time_embed(t)
+ c = self.cond_embed(c)
+
+ y = t + c
+
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ for block in self.res_blocks:
+ x = checkpoint(block, x, y)
+ else:
+ for block in self.res_blocks:
+ x = block(x, y)
+
+ return self.final_layer(x, y)
+
+ def forward_with_cfg(self, x, t, c, cfg_scale):
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, c)
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
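+
+
+# Usage sketch for classifier-free guidance (example values only):
+# forward_with_cfg assumes the batch is the conditional half followed by an
+# identical unconditional half, with `c` holding the conditional and null AR
+# features respectively.
+#
+#   mlp = SimpleMLPAdaLN(in_channels=16, model_channels=1024, out_channels=32,
+#                        z_channels=768, num_res_blocks=6)
+#   x = torch.cat([x_noisy, x_noisy], dim=0)        # (2N, 16) noisy latents
+#   t = torch.cat([timesteps, timesteps], dim=0)    # (2N,) diffusion steps
+#   c = torch.cat([cond_feats, null_feats], dim=0)  # (2N, 768) AR conditioning
+#   out = mlp.forward_with_cfg(x, t, c, cfg_scale=3.0)  # (2N, 32) guided output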
\ No newline at end of file
diff --git a/modeling/modules/fuzzy_embedding.py b/modeling/modules/fuzzy_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0820631a0114b7f0aee125ecc861b83482dd375
--- /dev/null
+++ b/modeling/modules/fuzzy_embedding.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+import math
+
+class FuzzyEmbedding(nn.Module):
+ def __init__(self, grid_size, scale, width, apply_fuzzy=False):
+ super(FuzzyEmbedding, self).__init__()
+ assert grid_size == 1024, "grid_size must be 1024 for now"
+
+ self.grid_size = grid_size
+ self.scale = scale
+ self.width = width
+ self.apply_fuzzy = apply_fuzzy
+        # grid_size is the number of positions in the learned base grid (32x32 = 1024).
+        # grid_sample then interpolates this grid to produce embeddings for any target resolution.
+ self.positional_embedding = nn.Parameter(
+ scale * torch.randn(grid_size, width))
+
+ self.class_positional_embedding = nn.Parameter(
+ scale * torch.randn(1, width))
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, grid_height, grid_width, train=True, dtype=torch.float32):
+        meshx, meshy = torch.meshgrid(
+            torch.arange(grid_height, device=self.positional_embedding.device),
+            torch.arange(grid_width, device=self.positional_embedding.device),
+            indexing="ij",
+        )
+ meshx = meshx.to(dtype)
+ meshy = meshy.to(dtype)
+
+ # Normalize coordinates to [-1, 1] range
+ meshx = 2 * (meshx / (grid_height - 1)) - 1
+ meshy = 2 * (meshy / (grid_width - 1)) - 1
+
+ if self.apply_fuzzy:
+ # Add uniform noise in range [-0.0004, 0.0004] to x and y coordinates
+ if train:
+ noise_x = torch.rand_like(meshx) * 0.0008 - 0.0004
+ noise_y = torch.rand_like(meshy) * 0.0008 - 0.0004
+ else:
+ noise_x = torch.zeros_like(meshx)
+ noise_y = torch.zeros_like(meshy)
+
+ # Apply noise to the mesh coordinates
+ meshx = meshx + noise_x
+ meshy = meshy + noise_y
+
+ grid = torch.stack((meshy, meshx), 2).to(self.positional_embedding.device)
+ grid = grid.unsqueeze(0) # add batch dim
+
+ positional_embedding = einops.rearrange(self.positional_embedding, "(h w) d -> d h w", h=int(math.sqrt(self.grid_size)), w=int(math.sqrt(self.grid_size)))
+ positional_embedding = positional_embedding.to(dtype)
+ positional_embedding = positional_embedding.unsqueeze(0) # add batch dim
+
+ fuzzy_embedding = F.grid_sample(positional_embedding, grid, align_corners=False)
+ fuzzy_embedding = fuzzy_embedding.to(dtype)
+ fuzzy_embedding = einops.rearrange(fuzzy_embedding, "b d h w -> b (h w) d").squeeze(0)
+
+ final_embedding = torch.cat([self.class_positional_embedding, fuzzy_embedding], dim=0)
+ return final_embedding
+
+
+if __name__ == "__main__":
+    # grid_size must currently be 1024 (see the assert in __init__).
+    embedder = FuzzyEmbedding(grid_size=1024, scale=1.0, width=1024)
+    grid_height = 16
+    grid_width = 32
+    fuzzy_embedding = embedder(grid_height, grid_width, dtype=torch.bfloat16)
+    print(fuzzy_embedding.shape)
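+    # With grid_size=1024 (a 32x32 base grid), the call above bilinearly resamples
+    # the grid to 16x32 positions and prepends the class-token embedding, so the
+    # printed shape should be (1 + 16*32, 1024) = torch.Size([513, 1024]).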
\ No newline at end of file
diff --git a/modeling/modules/losses.py b/modeling/modules/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..22aae755b0ba10c0452e6cd243a0e9407c306f06
--- /dev/null
+++ b/modeling/modules/losses.py
@@ -0,0 +1,339 @@
+"""Training loss implementation.
+
+Ref:
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py
+"""
+from typing import Mapping, Text, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.cuda.amp import autocast
+
+from modeling.modules.blocks import SimpleMLPAdaLN
+from .perceptual_loss import PerceptualLoss
+from .discriminator import NLayerDiscriminator
+
+
+def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
+    """Hinge loss for the discriminator.
+
+ This function is borrowed from
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20
+ """
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def compute_lecam_loss(
+ logits_real_mean: torch.Tensor,
+ logits_fake_mean: torch.Tensor,
+ ema_logits_real_mean: torch.Tensor,
+ ema_logits_fake_mean: torch.Tensor
+) -> torch.Tensor:
+ """Computes the LeCam loss for the given average real and fake logits.
+
+ Args:
+ logits_real_mean -> torch.Tensor: The average real logits.
+ logits_fake_mean -> torch.Tensor: The average fake logits.
+ ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
+ ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.
+
+ Returns:
+ lecam_loss -> torch.Tensor: The LeCam loss.
+ """
+ lecam_loss = torch.mean(torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2))
+ lecam_loss += torch.mean(torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2))
+ return lecam_loss
+
+
+class ReconstructionLoss_Stage1(torch.nn.Module):
+ def __init__(
+ self,
+ config
+ ):
+ super().__init__()
+ loss_config = config.losses
+ self.quantizer_weight = loss_config.quantizer_weight
+ self.target_codebook_size = 1024
+
+ def forward(self,
+ target_codes: torch.Tensor,
+ reconstructions: torch.Tensor,
+ quantizer_loss: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+ return self._forward_generator(target_codes, reconstructions, quantizer_loss)
+
+ def _forward_generator(self,
+ target_codes: torch.Tensor,
+ reconstructions: torch.Tensor,
+ quantizer_loss: Mapping[Text, torch.Tensor],
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+ reconstructions = reconstructions.contiguous()
+ loss_fct = nn.CrossEntropyLoss(reduction="mean")
+ batch_size = reconstructions.shape[0]
+ reconstruction_loss = loss_fct(reconstructions.view(batch_size, self.target_codebook_size, -1),
+ target_codes.view(batch_size, -1))
+ total_loss = reconstruction_loss + \
+ self.quantizer_weight * quantizer_loss["quantizer_loss"]
+
+ loss_dict = dict(
+ total_loss=total_loss.clone().detach(),
+ reconstruction_loss=reconstruction_loss.detach(),
+ quantizer_loss=(self.quantizer_weight * quantizer_loss["quantizer_loss"]).detach(),
+ commitment_loss=quantizer_loss["commitment_loss"].detach(),
+ codebook_loss=quantizer_loss["codebook_loss"].detach(),
+ )
+
+ return total_loss, loss_dict
+
+
+class ReconstructionLoss_Stage2(torch.nn.Module):
+ def __init__(
+ self,
+ config
+ ):
+ """Initializes the losses module.
+
+ Args:
+            config: The configuration object; config.losses provides the loss hyperparameters.
+ """
+ super().__init__()
+ loss_config = config.losses
+ self.discriminator = NLayerDiscriminator()
+
+ self.reconstruction_loss = loss_config.reconstruction_loss
+ self.reconstruction_weight = loss_config.reconstruction_weight
+ self.quantizer_weight = loss_config.quantizer_weight
+ self.perceptual_loss = PerceptualLoss(
+ loss_config.perceptual_loss).eval()
+ self.perceptual_weight = loss_config.perceptual_weight
+ self.discriminator_iter_start = loss_config.discriminator_start
+
+ self.discriminator_factor = loss_config.discriminator_factor
+ self.discriminator_weight = loss_config.discriminator_weight
+ self.lecam_regularization_weight = loss_config.lecam_regularization_weight
+ self.lecam_ema_decay = loss_config.get("lecam_ema_decay", 0.999)
+ if self.lecam_regularization_weight > 0.0:
+ self.register_buffer("ema_real_logits_mean", torch.zeros((1)))
+ self.register_buffer("ema_fake_logits_mean", torch.zeros((1)))
+
+ self.config = config
+
+ @autocast(enabled=False)
+ def forward(self,
+ inputs: torch.Tensor,
+ reconstructions: torch.Tensor,
+ extra_result_dict: Mapping[Text, torch.Tensor],
+ global_step: int,
+ mode: str = "generator",
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+ # Both inputs and reconstructions are in range [0, 1].
+ inputs = inputs.float()
+ reconstructions = reconstructions.float()
+
+ if mode == "generator":
+ return self._forward_generator(inputs, reconstructions, extra_result_dict, global_step)
+ elif mode == "discriminator":
+ return self._forward_discriminator(inputs, reconstructions, global_step)
+ else:
+ raise ValueError(f"Unsupported mode {mode}")
+
+ def should_discriminator_be_trained(self, global_step : int):
+ return global_step >= self.discriminator_iter_start
+
+ def _forward_generator(self,
+ inputs: torch.Tensor,
+ reconstructions: torch.Tensor,
+ extra_result_dict: Mapping[Text, torch.Tensor],
+ global_step: int
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+ """Generator training step."""
+ inputs = inputs.contiguous()
+ reconstructions = reconstructions.contiguous()
+ if self.reconstruction_loss == "l1":
+ reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
+ elif self.reconstruction_loss == "l2":
+ reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
+ else:
+            raise ValueError(f"Unsupported reconstruction_loss {self.reconstruction_loss}")
+ reconstruction_loss *= self.reconstruction_weight
+
+ # Compute perceptual loss.
+ perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()
+
+ # Compute discriminator loss.
+ generator_loss = torch.zeros((), device=inputs.device)
+ discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
+ d_weight = 1.0
+ if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
+ # Disable discriminator gradients.
+ for param in self.discriminator.parameters():
+ param.requires_grad = False
+ logits_fake = self.discriminator(reconstructions)
+ generator_loss = -torch.mean(logits_fake)
+
+ d_weight *= self.discriminator_weight
+
+ # Compute quantizer loss.
+ quantizer_loss = extra_result_dict["quantizer_loss"]
+ total_loss = (
+ reconstruction_loss
+ + self.perceptual_weight * perceptual_loss
+ + self.quantizer_weight * quantizer_loss
+ + d_weight * discriminator_factor * generator_loss
+ )
+ loss_dict = dict(
+ total_loss=total_loss.clone().detach(),
+ reconstruction_loss=reconstruction_loss.detach(),
+ perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
+ quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
+ weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
+ discriminator_factor=torch.tensor(discriminator_factor),
+ commitment_loss=extra_result_dict["commitment_loss"].detach(),
+ codebook_loss=extra_result_dict["codebook_loss"].detach(),
+ d_weight=d_weight,
+ gan_loss=generator_loss.detach(),
+ )
+
+ return total_loss, loss_dict
+
+ def _forward_discriminator(self,
+ inputs: torch.Tensor,
+ reconstructions: torch.Tensor,
+ global_step: int,
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        """Discriminator training step."""
+ discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
+ loss_dict = {}
+ # Turn the gradients on.
+ for param in self.discriminator.parameters():
+ param.requires_grad = True
+
+ real_images = inputs.detach().requires_grad_(True)
+ logits_real = self.discriminator(real_images)
+ logits_fake = self.discriminator(reconstructions.detach())
+
+ discriminator_loss = discriminator_factor * hinge_d_loss(logits_real=logits_real, logits_fake=logits_fake)
+
+ # optional lecam regularization
+ lecam_loss = torch.zeros((), device=inputs.device)
+ if self.lecam_regularization_weight > 0.0:
+ lecam_loss = compute_lecam_loss(
+ torch.mean(logits_real),
+ torch.mean(logits_fake),
+ self.ema_real_logits_mean,
+ self.ema_fake_logits_mean
+ ) * self.lecam_regularization_weight
+
+ self.ema_real_logits_mean = self.ema_real_logits_mean * self.lecam_ema_decay + torch.mean(logits_real).detach() * (1 - self.lecam_ema_decay)
+ self.ema_fake_logits_mean = self.ema_fake_logits_mean * self.lecam_ema_decay + torch.mean(logits_fake).detach() * (1 - self.lecam_ema_decay)
+
+ discriminator_loss += lecam_loss
+
+ loss_dict = dict(
+ discriminator_loss=discriminator_loss.detach(),
+ logits_real=logits_real.detach().mean(),
+ logits_fake=logits_fake.detach().mean(),
+ lecam_loss=lecam_loss.detach(),
+ )
+ return discriminator_loss, loss_dict
+
+
+class ReconstructionLoss_Single_Stage(ReconstructionLoss_Stage2):
+ def __init__(
+ self,
+ config
+ ):
+ super().__init__(config)
+ loss_config = config.losses
+ self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
+
+ if self.quantize_mode == "vae":
+ self.kl_weight = loss_config.get("kl_weight", 1e-6)
+ logvar_init = loss_config.get("logvar_init", 0.0)
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init, requires_grad=False)
+
+ def _forward_generator(self,
+ inputs: torch.Tensor,
+ reconstructions: torch.Tensor,
+ extra_result_dict: Mapping[Text, torch.Tensor],
+ global_step: int
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+ """Generator training step."""
+ inputs = inputs.contiguous()
+ reconstructions = reconstructions.contiguous()
+ if self.reconstruction_loss == "l1":
+ reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
+ elif self.reconstruction_loss == "l2":
+ reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
+ else:
+            raise ValueError(f"Unsupported reconstruction_loss {self.reconstruction_loss}")
+ reconstruction_loss *= self.reconstruction_weight
+
+ # Compute perceptual loss.
+ perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()
+
+ # Compute discriminator loss.
+ generator_loss = torch.zeros((), device=inputs.device)
+ discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
+ d_weight = 1.0
+ if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
+ # Disable discriminator gradients.
+ for param in self.discriminator.parameters():
+ param.requires_grad = False
+ logits_fake = self.discriminator(reconstructions)
+ generator_loss = -torch.mean(logits_fake)
+
+ d_weight *= self.discriminator_weight
+
+ if self.quantize_mode in ["vq", "mvq", "softvq"]:
+ # Compute quantizer loss.
+ quantizer_loss = extra_result_dict["quantizer_loss"]
+ total_loss = (
+ reconstruction_loss
+ + self.perceptual_weight * perceptual_loss
+ + self.quantizer_weight * quantizer_loss
+ + d_weight * discriminator_factor * generator_loss
+ )
+ loss_dict = dict(
+ total_loss=total_loss.clone().detach(),
+ reconstruction_loss=reconstruction_loss.detach(),
+ perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
+ quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
+ weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
+ discriminator_factor=torch.tensor(discriminator_factor),
+ commitment_loss=extra_result_dict["commitment_loss"].detach(),
+ codebook_loss=extra_result_dict["codebook_loss"].detach(),
+ d_weight=d_weight,
+ gan_loss=generator_loss.detach(),
+ )
+ elif self.quantize_mode == "vae":
+ # Compute kl loss.
+ reconstruction_loss = reconstruction_loss / torch.exp(self.logvar)
+ posteriors = extra_result_dict
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ total_loss = (
+ reconstruction_loss
+ + self.perceptual_weight * perceptual_loss
+ + self.kl_weight * kl_loss
+ + d_weight * discriminator_factor * generator_loss
+ )
+ loss_dict = dict(
+ total_loss=total_loss.clone().detach(),
+ reconstruction_loss=reconstruction_loss.detach(),
+ perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
+ kl_loss=(self.kl_weight * kl_loss).detach(),
+ weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
+ discriminator_factor=torch.tensor(discriminator_factor),
+ d_weight=d_weight,
+ gan_loss=generator_loss.detach(),
+ )
+ else:
+ raise NotImplementedError
+
+ return total_loss, loss_dict
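+
+
+# Training-loop sketch (variable names are placeholders): the module is called
+# once per optimizer each step, first for the tokenizer/generator and then for
+# the discriminator (reconstructions are detached inside the discriminator step).
+#
+#   loss_module = ReconstructionLoss_Single_Stage(config)
+#   g_loss, g_dict = loss_module(images, recon, extra_result_dict,
+#                                global_step, mode="generator")
+#   d_loss, d_dict = loss_module(images, recon, extra_result_dict,
+#                                global_step, mode="discriminator")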
\ No newline at end of file
diff --git a/modeling/modules/lpips.py b/modeling/modules/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..b204699e1f2bf242dfbfc7eef4d34f2caad0fde7
--- /dev/null
+++ b/modeling/modules/lpips.py
@@ -0,0 +1,181 @@
+"""LPIPS perceptual loss.
+
+Reference:
+ https://github.com/richzhang/PerceptualSimilarity/
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py
+ https://github.com/CompVis/taming-transformers/blob/master/taming/util.py
+"""
+
+import os
+import hashlib
+import requests
+from collections import namedtuple
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+
+from torchvision import models
+
+_LPIPS_MEAN = [-0.030, -0.088, -0.188]
+_LPIPS_STD = [0.458, 0.448, 0.450]
+
+
+URL_MAP = {
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+ "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric.
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_pretrained(self):
+ workspace = os.environ.get('WORKSPACE', '')
+ VGG_PATH = get_ckpt_path("vgg_lpips", os.path.join(workspace, "models/vgg_lpips.pth"), check=True)
+ self.load_state_dict(torch.load(VGG_PATH, map_location=torch.device("cpu"), weights_only=True), strict=False)
+
+ def forward(self, input, target):
+ # Notably, the LPIPS w/ pre-trained weights expect the input in the range of [-1, 1].
+ # However, our codebase assumes all inputs are in range of [0, 1], and thus a scaling is needed.
+ input = input * 2. - 1.
+ target = target * 2. - 1.
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer("shift", torch.Tensor(_LPIPS_MEAN)[None, :, None, None])
+ self.register_buffer("scale", torch.Tensor(_LPIPS_STD)[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """A single linear layer which does a 1x1 conv."""
+
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = (
+ [
+ nn.Dropout(),
+ ]
+ if (use_dropout)
+ else []
+ )
+ layers += [
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
+ ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2, 3], keepdim=keepdim)
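+
+
+# Usage sketch: LPIPS expects inputs in [0, 1] (they are rescaled to [-1, 1]
+# internally) and keeps the batch dimension in its output, e.g.
+#   lpips = LPIPS().eval()   # fetches the VGG-LPIPS weights on first use
+#   d = lpips(x, y)          # x, y: (B, 3, H, W) in [0, 1]; d: (B, 1, 1, 1)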
diff --git a/modeling/modules/maskgit_vqgan.py b/modeling/modules/maskgit_vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..8275b811c83f5ad254d40ace51fd74c3857d34e6
--- /dev/null
+++ b/modeling/modules/maskgit_vqgan.py
@@ -0,0 +1,346 @@
+"""MaskGIT-VQGAN tokenizer.
+
+Reference:
+ https://github.com/huggingface/open-muse/blob/main/muse/modeling_maskgit_vqgan.py
+"""
+
+r"""MaskGIT Tokenizer based on VQGAN.
+
+This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841]
+with several modifications. The non-local layers are removed from VQGAN for
+faster speed.
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+# Conv2D with same padding
+class Conv2dSame(nn.Conv2d):
+ def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
+ return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ ih, iw = x.size()[-2:]
+
+ pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
+ pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
+
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+ return super().forward(x)
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = None,
+ dropout_prob: float = 0.0,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = Conv2dSame(self.in_channels, self.out_channels_, kernel_size=3, bias=False)
+
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)
+ self.dropout = nn.Dropout(dropout_prob)
+ self.conv2 = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=3, bias=False)
+
+ if self.in_channels != self.out_channels_:
+ self.nin_shortcut = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=1, bias=False)
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.in_channels != self.out_channels_:
+ residual = self.nin_shortcut(hidden_states)
+
+ return hidden_states + residual
+
+
+class DownsamplingBlock(nn.Module):
+ def __init__(self, config, block_idx: int):
+ super().__init__()
+
+ self.config = config
+ self.block_idx = block_idx
+
+ in_channel_mult = (1,) + tuple(self.config.channel_mult)
+ block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
+
+ res_blocks = nn.ModuleList()
+ for _ in range(self.config.num_res_blocks):
+ res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
+ block_in = block_out
+ self.block = res_blocks
+
+ self.downsample = self.block_idx != self.config.num_resolutions - 1
+
+ def forward(self, hidden_states):
+ for res_block in self.block:
+ hidden_states = res_block(hidden_states)
+
+ if self.downsample:
+ hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2)
+
+ return hidden_states
+
+
+class UpsamplingBlock(nn.Module):
+ def __init__(self, config, block_idx: int):
+ super().__init__()
+
+ self.config = config
+ self.block_idx = block_idx
+
+ if self.block_idx == self.config.num_resolutions - 1:
+ block_in = self.config.hidden_channels * self.config.channel_mult[-1]
+ else:
+ block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]
+
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
+
+ res_blocks = []
+ for _ in range(self.config.num_res_blocks):
+ res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
+ block_in = block_out
+ self.block = nn.ModuleList(res_blocks)
+
+ self.add_upsample = self.block_idx != 0
+ if self.add_upsample:
+ self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3)
+
+ def forward(self, hidden_states):
+ for res_block in self.block:
+ hidden_states = res_block(hidden_states)
+
+ if self.add_upsample:
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+ hidden_states = self.upsample_conv(hidden_states)
+
+ return hidden_states
+
+
+class Encoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ # downsampling
+ self.conv_in = Conv2dSame(self.config.num_channels, self.config.hidden_channels, kernel_size=3, bias=False)
+
+ downsample_blocks = []
+ for i_level in range(self.config.num_resolutions):
+ downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level))
+ self.down = nn.ModuleList(downsample_blocks)
+
+ # middle
+ mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
+ res_blocks = nn.ModuleList()
+ for _ in range(self.config.num_res_blocks):
+ res_blocks.append(ResnetBlock(mid_channels, mid_channels, dropout_prob=self.config.dropout))
+ self.mid = res_blocks
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
+ self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1)
+
+ def forward(self, pixel_values):
+ # downsampling
+ hidden_states = self.conv_in(pixel_values)
+ for block in self.down:
+ hidden_states = block(hidden_states)
+
+ # middle
+ for block in self.mid:
+ hidden_states = block(hidden_states)
+
+ # end
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class Decoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+
+ # compute in_channel_mult, block_in and curr_res at lowest res
+ block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
+ curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
+ self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3)
+
+ # middle
+ res_blocks = nn.ModuleList()
+ for _ in range(self.config.num_res_blocks):
+ res_blocks.append(ResnetBlock(block_in, block_in, dropout_prob=self.config.dropout))
+ self.mid = res_blocks
+
+ # upsampling
+ upsample_blocks = []
+ for i_level in reversed(range(self.config.num_resolutions)):
+ upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level))
+ self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order
+
+ # end
+ block_out = self.config.hidden_channels * self.config.channel_mult[0]
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)
+ self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3)
+
+ def forward(self, hidden_states):
+ # z to block_in
+ hidden_states = self.conv_in(hidden_states)
+
+ # middle
+ for block in self.mid:
+ hidden_states = block(hidden_states)
+
+ # upsampling
+ for block in reversed(self.up):
+ hidden_states = block(hidden_states)
+
+ # end
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ Discretization bottleneck part of the VQ-VAE.
+ """
+
+ def __init__(self, num_embeddings, embedding_dim, commitment_cost):
+ r"""
+ Args:
+ num_embeddings: number of vectors in the quantized space.
+ embedding_dim: dimensionality of the tensors in the quantized space.
+ Inputs to the modules must be in this format as well.
+ commitment_cost: scalar which controls the weighting of the loss terms
+ (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
+ """
+ super().__init__()
+
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.commitment_cost = commitment_cost
+
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
+ self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
+
+ def forward(self, hidden_states, return_loss=False):
+ """
+        Maps the encoder output z to the index of the closest codebook embedding e_j:
+        z (continuous) -> z_q (discrete), where z.shape = (batch, channel, height, width).
+        Quantization pipeline:
+            1. get encoder input (B, C, H, W)
+            2. flatten input to (B*H*W, C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
+
+ distances = self.compute_distances(hidden_states)
+ min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1)
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)
+
+ # reshape to (batch, num_tokens)
+ min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
+
+ # compute loss for embedding
+ loss = None
+ if return_loss:
+ loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
+ (z_q - hidden_states.detach()) ** 2
+ )
+ # preserve gradients
+ z_q = hidden_states + (z_q - hidden_states).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, min_encoding_indices, loss
+
+ def compute_distances(self, hidden_states):
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim))
+ emb_weights = self.embedding.weight.t()
+
+ inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True)
+ codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
+ distances = torch.addmm(
+ inputs_norm_sq + codebook_t_norm_sq,
+ hidden_states_flattended,
+ emb_weights,
+ alpha=-2.0,
+ )
+ return distances
+
+ def get_codebook_entry(self, indices):
+ # indices are expected to be of shape (batch, num_tokens)
+ # get quantized latent vectors
+ if len(indices.shape) == 2:
+ batch, num_tokens = indices.shape
+ z_q = self.embedding(indices)
+ z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
+ elif len(indices.shape) == 3:
+ batch, height, width = indices.shape
+ indices = indices.view(batch, -1)
+ z_q = self.embedding(indices)
+ z_q = z_q.reshape(batch, height, width, -1).permute(0, 3, 1, 2)
+ else:
+ print(indices.shape)
+ raise NotImplementedError
+ return z_q
+
+ # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
+ def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel)
+ distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings)
+
+ soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings)
+ if stochastic:
+ code = torch.multinomial(soft_code, 1) # (batch * height * width, 1)
+ else:
+ code = distances.argmin(dim=-1) # (batch * height * width)
+
+ code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width)
+ batch, num_tokens = code.shape
+ soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings)
+ return soft_code, code
+
+ def get_code(self, hidden_states):
+ # reshape z -> (batch, height, width, channel)
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
+ distances = self.compute_distances(hidden_states)
+ indices = torch.argmin(distances, axis=1).unsqueeze(1)
+ indices = indices.reshape(hidden_states.shape[0], -1)
+ return indices
\ No newline at end of file
diff --git a/modeling/modules/perceptual_loss.py b/modeling/modules/perceptual_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a6f0c8fe334cf17812f2d703acce86cfda1348
--- /dev/null
+++ b/modeling/modules/perceptual_loss.py
@@ -0,0 +1,101 @@
+"""Perceptual loss module using LPIPS and ConvNeXt-S."""
+
+import torch
+import torch.nn.functional as F
+
+from torchvision import models
+from .lpips import LPIPS
+
+_IMAGENET_MEAN = [0.485, 0.456, 0.406]
+_IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+class PerceptualLoss(torch.nn.Module):
+ def __init__(self, model_name: str = "convnext_s"):
+ """Initializes the PerceptualLoss class.
+
+ Args:
+ model_name: A string, the name of the perceptual loss model to use.
+
+        Raises:
+ ValueError: If the model_name does not contain "lpips" or "convnext_s".
+ """
+ super().__init__()
+ if ("lpips" not in model_name) and (
+ "convnext_s" not in model_name):
+ raise ValueError(f"Unsupported Perceptual Loss model name {model_name}")
+ self.lpips = None
+ self.convnext = None
+ self.loss_weight_lpips = None
+ self.loss_weight_convnext = None
+
+        # Parse the model name. Supported names follow the format
+        # "lpips-convnext_s-{float_number}-{float_number}", where each
+        # {float_number} is the loss weight of the corresponding component.
+        # E.g., "lpips-convnext_s-1.0-2.0" computes the perceptual loss with
+        # both lpips and convnext_s and averages the final loss as
+        # (1.0 * loss(lpips) + 2.0 * loss(convnext_s)) / (1.0 + 2.0).
+ if "lpips" in model_name:
+ self.lpips = LPIPS().eval()
+
+ if "convnext_s" in model_name:
+ self.convnext = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).eval()
+
+ if "lpips" in model_name and "convnext_s" in model_name:
+ loss_config = model_name.split('-')[-2:]
+ self.loss_weight_lpips, self.loss_weight_convnext = float(loss_config[0]), float(loss_config[1])
+ print(f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}")
+
+ self.register_buffer("imagenet_mean", torch.Tensor(_IMAGENET_MEAN)[None, :, None, None])
+ self.register_buffer("imagenet_std", torch.Tensor(_IMAGENET_STD)[None, :, None, None])
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, input: torch.Tensor, target: torch.Tensor):
+ """Computes the perceptual loss.
+
+ Args:
+ input: A tensor of shape (B, C, H, W), the input image. Normalized to [0, 1].
+ target: A tensor of shape (B, C, H, W), the target image. Normalized to [0, 1].
+
+ Returns:
+ A scalar tensor, the perceptual loss.
+ """
+ # Always in eval mode.
+ self.eval()
+ loss = 0.
+ num_losses = 0.
+ lpips_loss = 0.
+ convnext_loss = 0.
+ # Computes LPIPS loss, if available.
+ if self.lpips is not None:
+ lpips_loss = self.lpips(input, target)
+ if self.loss_weight_lpips is None:
+ loss += lpips_loss
+ num_losses += 1
+ else:
+ num_losses += self.loss_weight_lpips
+ loss += self.loss_weight_lpips * lpips_loss
+
+ if self.convnext is not None:
+ # Computes ConvNeXt-s loss, if available.
+ input = torch.nn.functional.interpolate(input, size=224, mode="bilinear", align_corners=False, antialias=True)
+ target = torch.nn.functional.interpolate(target, size=224, mode="bilinear", align_corners=False, antialias=True)
+ pred_input = self.convnext((input - self.imagenet_mean) / self.imagenet_std)
+ pred_target = self.convnext((target - self.imagenet_mean) / self.imagenet_std)
+ convnext_loss = torch.nn.functional.mse_loss(
+ pred_input,
+ pred_target,
+ reduction="mean")
+
+ if self.loss_weight_convnext is None:
+ num_losses += 1
+ loss += convnext_loss
+ else:
+ num_losses += self.loss_weight_convnext
+ loss += self.loss_weight_convnext * convnext_loss
+
+ # weighted avg.
+ loss = loss / num_losses
+ return loss
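+
+
+# Construction sketch: following the naming scheme documented in __init__,
+#   PerceptualLoss("lpips-convnext_s-1.0-2.0")(x, y)
+# evaluates (1.0 * lpips(x, y) + 2.0 * convnext_mse(x, y)) / 3.0 on inputs in
+# [0, 1]; callers (e.g. the losses module) reduce the result with .mean().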
\ No newline at end of file
diff --git a/modeling/quantizer/__init__.py b/modeling/quantizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4707078c10391a8ba7edffcb3e72178adbd5a128
--- /dev/null
+++ b/modeling/quantizer/__init__.py
@@ -0,0 +1,3 @@
+from .quantizer import VectorQuantizer, DiagonalGaussianDistribution
+from .mvq import VectorQuantizerMVQ
+from .softvq import SoftVectorQuantizer
\ No newline at end of file
diff --git a/modeling/quantizer/dist.py b/modeling/quantizer/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..48d7fc9962a26decb036487720ad41e50082984c
--- /dev/null
+++ b/modeling/quantizer/dist.py
@@ -0,0 +1,302 @@
+import datetime
+import functools
+import os
+import sys
+from typing import List
+from typing import Union
+
+import pytz
+import torch
+import torch.distributed as tdist
+import torch.multiprocessing as mp
+
+__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
+__rank_str_zfill = '0'
+__initialized = False
+
+
+def initialized():
+ return __initialized
+
+
+def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
+ global __device
+ if not torch.cuda.is_available():
+        print('[dist initialize] cuda is not available, using cpu instead', file=sys.stderr)
+ return
+ elif 'RANK' not in os.environ:
+ torch.cuda.set_device(gpu_id_if_not_distibuted)
+ __device = torch.empty(1).cuda().device
+        print(f'[dist initialize] env variable "RANK" is not set, using {__device} as the device', file=sys.stderr)
+ return
+ # then 'RANK' must exist
+ global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
+ local_rank = global_rank % num_gpus
+ torch.cuda.set_device(local_rank)
+
+ # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
+ if mp.get_start_method(allow_none=True) is None:
+ method = 'fork' if fork else 'spawn'
+ print(f'[dist initialize] mp method={method}')
+ mp.set_start_method(method)
+ tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))
+
+ global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
+ __local_rank = local_rank
+ __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
+ __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
+ __device = torch.empty(1).cuda().device
+ __initialized = True
+
+ assert tdist.is_initialized(), 'torch.distributed is not initialized!'
+ print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
+
+
+def get_rank():
+ return __rank
+
+
+def get_rank_str_zfill():
+ return __rank_str_zfill
+
+
+def get_local_rank():
+ return __local_rank
+
+
+def get_world_size():
+ return __world_size
+
+
+def get_device():
+ return __device
+
+
+def set_gpu_id(gpu_id: int):
+ if gpu_id is None: return
+ global __device
+ if isinstance(gpu_id, (str, int)):
+ torch.cuda.set_device(int(gpu_id))
+ __device = torch.empty(1).cuda().device
+ else:
+ raise NotImplementedError
+
+
+def is_master():
+ return __rank == 0
+
+
+def is_local_master():
+ return __local_rank == 0
+
+
+def new_group(ranks: List[int]):
+ if __initialized:
+ return tdist.new_group(ranks=ranks)
+ return None
+
+
+def new_local_machine_group():
+ if __initialized:
+ cur_subgroup, subgroups = tdist.new_subgroups()
+ return cur_subgroup
+ return None
+
+
+def barrier():
+ if __initialized:
+ tdist.barrier()
+
+
+def allreduce(t: torch.Tensor, async_op=False):
+ if __initialized:
+ if not t.is_cuda:
+ cu = t.detach().cuda()
+ ret = tdist.all_reduce(cu, async_op=async_op)
+ t.copy_(cu.cpu())
+ else:
+ ret = tdist.all_reduce(t, async_op=async_op)
+ return ret
+ return None
+
+
+def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
+ if __initialized:
+ if not t.is_cuda:
+ t = t.cuda()
+ ls = [torch.empty_like(t) for _ in range(__world_size)]
+ tdist.all_gather(ls, t)
+ else:
+ ls = [t]
+ if cat:
+ ls = torch.cat(ls, dim=0)
+ return ls
+
+
+def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
+ if __initialized:
+ if not t.is_cuda:
+ t = t.cuda()
+
+ t_size = torch.tensor(t.size(), device=t.device)
+ ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
+ tdist.all_gather(ls_size, t_size)
+
+ max_B = max(size[0].item() for size in ls_size)
+ pad = max_B - t_size[0].item()
+ if pad:
+ pad_size = (pad, *t.size()[1:])
+ t = torch.cat((t, t.new_empty(pad_size)), dim=0)
+
+ ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
+ tdist.all_gather(ls_padded, t)
+ ls = []
+ for t, size in zip(ls_padded, ls_size):
+ ls.append(t[:size[0].item()])
+ else:
+ ls = [t]
+ if cat:
+ ls = torch.cat(ls, dim=0)
+ return ls
+
+
+def broadcast(t: torch.Tensor, src_rank) -> None:
+ if __initialized:
+ if not t.is_cuda:
+ cu = t.detach().cuda()
+ tdist.broadcast(cu, src=src_rank)
+ t.copy_(cu.cpu())
+ else:
+ tdist.broadcast(t, src=src_rank)
+
+
+def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
+ if not initialized():
+ return torch.tensor([val]) if fmt is None else [fmt % val]
+
+ ts = torch.zeros(__world_size)
+ ts[__rank] = val
+ allreduce(ts)
+ if fmt is None:
+ return ts
+ return [fmt % v for v in ts.cpu().numpy().tolist()]
+
+
+def master_only(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if force or is_master():
+ ret = func(*args, **kwargs)
+ else:
+ ret = None
+ barrier()
+ return ret
+ return wrapper
+
+
+def local_master_only(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if force or is_local_master():
+ ret = func(*args, **kwargs)
+ else:
+ ret = None
+ barrier()
+ return ret
+ return wrapper
+
+
+def for_visualize(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ if is_master():
+ # with torch.no_grad():
+ ret = func(*args, **kwargs)
+ else:
+ ret = None
+ return ret
+ return wrapper
+
+
+def finalize():
+ if __initialized:
+ tdist.destroy_process_group()
+
+
+def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
+ try:
+ __initialize(fork=False, timeout_minutes=timeout_minutes)
+ barrier()
+ except RuntimeError as e:
+ print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
+ raise e
+
+ if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
+ _change_builtin_print(is_local_master())
+ if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
+ sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)
+
+
+def _change_builtin_print(is_master):
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+ if type(builtin_print) != type(open):
+ return
+
+ def prt(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ clean = kwargs.pop('clean', False)
+ deeper = kwargs.pop('deeper', False)
+ if is_master or force:
+ if not clean:
+ f_back = sys._getframe().f_back
+ if deeper and f_back.f_back is not None:
+ f_back = f_back.f_back
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
+ time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
+ builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
+ else:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = prt
+
+
+class BackupStreamToFile(object):
+ def __init__(self, local_output_dir, for_stdout=True):
+ self.for_stdout = for_stdout
+ self.terminal_stream = sys.stdout if for_stdout else sys.stderr
+ fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
+ existing = os.path.exists(fname)
+ self.file_stream = open(fname, 'a')
+ if existing:
+ time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
+ self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
+ self.file_stream.flush()
+ self.enabled = True
+
+ def write(self, message):
+ self.terminal_stream.write(message)
+ self.file_stream.write(message)
+
+ def flush(self):
+ self.terminal_stream.flush()
+ self.file_stream.flush()
+
+ def close(self):
+ if not self.enabled:
+ return
+ self.enabled = False
+ self.file_stream.flush()
+ self.file_stream.close()
+ if self.for_stdout:
+ sys.stdout = self.terminal_stream
+ sys.stdout.flush()
+ else:
+ sys.stderr = self.terminal_stream
+ sys.stderr.flush()
+
+ def __del__(self):
+ self.close()
\ No newline at end of file
diff --git a/modeling/quantizer/mvq.py b/modeling/quantizer/mvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..97dd699299b65af40bfc0690ff240f749a71ad29
--- /dev/null
+++ b/modeling/quantizer/mvq.py
@@ -0,0 +1,159 @@
+import torch
+from typing import List, Tuple
+from torch.nn import functional as F
+from torch import distributed as tdist, nn as nn
+
+from .quantizer import VectorQuantizer
+
+def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
+ E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
+ E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
+ logits = -E_dist.float().mul_(inv_entropy_tau)
+ # calc per_sample_entropy
+ prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
+ per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
+ # calc codebook_entropy
+ avg_prob = prob.mean(dim=0) # (vocab_size,)
+ log_avg_prob = torch.log(avg_prob + 1e-7)
+ codebook_entropy = (-avg_prob * log_avg_prob).sum()
+ # calc entropy_loss
+ entropy_loss = per_sample_entropy - codebook_entropy
+ return entropy_loss
+
+
+class NormalizedEmbedding(nn.Embedding):
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+ # self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
+
+ def forward(self, idx):
+ return F.embedding(
+ idx, F.normalize(self.weight, dim=1), self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse
+ )
+
+ def get_norm_weight(self):
+ return F.normalize(self.weight, dim=1)
+
+
+class ResConv(nn.Conv2d):
+ def __init__(self, embed_dim, quant_resi):
+ ks = 3 if quant_resi < 0 else 1
+ super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
+ self.resi_ratio = abs(quant_resi)
+
+ def forward(self, h_BChw):
+ return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
+
+class VectorQuantizerMVQ(nn.Module):
+ def __init__(
+ self,
+ codebook_size,
+ token_size,
+ commitment_cost=0.25,
+ use_l2_norm=False,
+ # entropy_temp=0.01, # we do not use this
+ clustering_vq=False,
+ num_codebooks=16
+ ):
+ super().__init__()
+ self.num_codebooks = num_codebooks
+ self.codebooks = nn.ModuleList()
+ for _ in range(num_codebooks):
+ codebook = VectorQuantizer(
+ codebook_size=codebook_size // num_codebooks,
+ token_size=token_size // num_codebooks,
+ commitment_cost=commitment_cost,
+ use_l2_norm=use_l2_norm,
+ clustering_vq=clustering_vq,
+ )
+ self.codebooks.append(codebook)
+
+ def init_vocab(self, eini: float):
+ for codebook in self.codebooks:
+ codebook.init_vocab(eini)
+
+ def f_to_idx(self, features):
+ indices = []
+ chunk_size = features.shape[-1] // self.num_codebooks
+ splited_features = features.split(chunk_size, dim=-1)
+ for i, codebook in enumerate(self.codebooks):
+ indices.append(codebook.f_to_idx(splited_features[i]))
+ indices = torch.stack(indices, dim=1)
+ return indices
+
+ def idx_to_f(self, indices):
+ assert indices.shape[1] == self.num_codebooks
+ latent_features = []
+ for i, codebook in enumerate(self.codebooks):
+ sub_indices = indices[:, i].flatten(start_dim=1)
+ latent_feature = codebook.codebook(sub_indices)
+ latent_features.append(latent_feature)
+ latent_features = torch.cat(latent_features, dim=-1)
+ return latent_features
+
+ def get_codebook_entry(self, indices):
+ """Get codebook entries for multi-codebook indices.
+
+ Args:
+ indices: Tensor of shape (N, num_codebooks) or (N, num_codebooks, H, W)
+
+ Returns:
+ z_quantized: Quantized features
+ """
+ if len(indices.shape) == 2:
+ # indices shape: (N, num_codebooks)
+ latent_features = []
+ for i, codebook in enumerate(self.codebooks):
+ sub_indices = indices[:, i]
+ latent_feature = codebook.get_codebook_entry(sub_indices)
+ latent_features.append(latent_feature)
+ return torch.cat(latent_features, dim=-1)
+ elif len(indices.shape) == 4:
+ # indices shape: (B, num_codebooks, H, W)
+ batch_size, _, height, width = indices.shape
+ latent_features = []
+ for i, codebook in enumerate(self.codebooks):
+ sub_indices = indices[:, i] # (B, H, W)
+ latent_feature = codebook.get_codebook_entry(sub_indices.flatten())
+ # Reshape to (B, H, W, token_size // num_codebooks)
+ latent_feature = latent_feature.view(batch_size, height, width, -1)
+ latent_features.append(latent_feature)
+ # Concatenate along the last dimension and rearrange to (B, C, H, W)
+ latent_features = torch.cat(latent_features, dim=-1) # (B, H, W, C)
+ return latent_features.permute(0, 3, 1, 2).contiguous() # (B, C, H, W)
+ else:
+ raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")
+
+ def forward(self, features):
+ latent_features = []
+ all_result_dicts = []
+ chunk_size = features.shape[1] // self.num_codebooks
+ splited_features = features.split(chunk_size, dim=1)
+
+ for i, codebook in enumerate(self.codebooks):
+ latent_feature, result_dict = codebook(splited_features[i].float())
+ latent_features.append(latent_feature.to(features.dtype))
+ all_result_dicts.append(result_dict)
+
+ # Concatenate latent features
+ z_quantized = torch.cat(latent_features, dim=1) # Concatenate along channel dimension
+
+ # Calculate global losses
+ global_quantizer_loss = sum(rd['quantizer_loss'] for rd in all_result_dicts) / self.num_codebooks
+ global_commitment_loss = sum(rd['commitment_loss'] for rd in all_result_dicts) / self.num_codebooks
+ global_codebook_loss = sum(rd['codebook_loss'] for rd in all_result_dicts) / self.num_codebooks
+
+ # Collect all min_encoding_indices
+ # Each codebook returns indices of shape (B, H, W)
+ # Stack them to get shape (B, num_codebooks, H, W)
+ all_indices = torch.stack([rd['min_encoding_indices'] for rd in all_result_dicts], dim=1)
+
+ result_dict = dict(
+ quantizer_loss=global_quantizer_loss,
+ commitment_loss=global_commitment_loss,
+ codebook_loss=global_codebook_loss,
+ min_encoding_indices=all_indices
+ )
+
+ return z_quantized, result_dict
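+
+
+# Shape sketch (example sizes only): with codebook_size=4096, token_size=256 and
+# num_codebooks=16, features of shape (B, 256, H, W) are split into 16 chunks of
+# 16 channels, each quantized against its own 256-entry sub-codebook; the result
+# dict stacks the per-codebook indices to (B, 16, H, W) and averages the
+# quantizer/commitment/codebook losses over the codebooks.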
\ No newline at end of file
diff --git a/modeling/quantizer/quantizer.py b/modeling/quantizer/quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..db529f1dfe0c8584d0af1a50ff4d3c4ab3949c81
--- /dev/null
+++ b/modeling/quantizer/quantizer.py
@@ -0,0 +1,158 @@
+"""Vector quantizer.
+
+Reference:
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
+ https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py
+ https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py
+ https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py
+"""
+from typing import Mapping, Text, Tuple
+
+import torch
+from einops import rearrange
+from accelerate.utils.operations import gather
+from torch.cuda.amp import autocast
+
+class VectorQuantizer(torch.nn.Module):
+ def __init__(self,
+ codebook_size: int = 1024,
+ token_size: int = 256,
+ commitment_cost: float = 0.25,
+ use_l2_norm: bool = False,
+ clustering_vq: bool = False
+ ):
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.token_size = token_size
+ self.commitment_cost = commitment_cost
+
+ self.embedding = torch.nn.Embedding(codebook_size, token_size)
+ self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
+ self.use_l2_norm = use_l2_norm
+
+ self.clustering_vq = clustering_vq
+ if clustering_vq:
+ self.decay = 0.99
+ self.register_buffer("embed_prob", torch.zeros(self.codebook_size))
+
+ # Ensure quantization is performed using f32
+ @autocast(enabled=False)
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+ z = z.float()
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = rearrange(z, 'b h w c -> (b h w) c')
+ unnormed_z_flattened = z_flattened
+
+ if self.use_l2_norm:
+ z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
+ embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
+ else:
+ embedding = self.embedding.weight
+ d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
+ torch.sum(embedding**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, embedding.T)
+
+ min_encoding_indices = torch.argmin(d, dim=1) # num_ele
+ z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)
+
+ if self.use_l2_norm:
+ z = torch.nn.functional.normalize(z, dim=-1)
+
+ # compute loss for embedding
+ commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2)
+ codebook_loss = torch.mean((z_quantized - z.detach()) **2)
+
+ if self.clustering_vq and self.training:
+ with torch.no_grad():
+ # Gather distance matrix from all GPUs.
+ encoding_indices = gather(min_encoding_indices)
+ if len(min_encoding_indices.shape) != 1:
+                    raise ValueError(f"min_encoding_indices has an unexpected shape: {min_encoding_indices.shape}")
+ # Compute and update the usage of each entry in the codebook.
+ encodings = torch.zeros(encoding_indices.shape[0], self.codebook_size, device=z.device)
+ encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
+ avg_probs = torch.mean(encodings, dim=0)
+ self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1-self.decay)
+ # Closest sampling to update the codebook.
+ all_d = gather(d)
+ all_unnormed_z_flattened = gather(unnormed_z_flattened).detach()
+ if all_d.shape[0] != all_unnormed_z_flattened.shape[0]:
+ raise ValueError(
+                        "all_d and all_unnormed_z_flattened have different lengths: " +
+ f"{all_d.shape}, {all_unnormed_z_flattened.shape}")
+ indices = torch.argmin(all_d, dim=0)
+ random_feat = all_unnormed_z_flattened[indices]
+ # Decay parameter based on the average usage.
+ decay = torch.exp(-(self.embed_prob * self.codebook_size * 10) /
+ (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.token_size)
+ self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
+
+ loss = commitment_loss + codebook_loss
+
+ # preserve gradients
+ z_quantized = z + (z_quantized - z).detach()
+
+ # reshape back to match original input shape
+ z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
+
+ result_dict = dict(
+ quantizer_loss=loss,
+ commitment_loss=commitment_loss,
+ codebook_loss=codebook_loss,
+ min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
+ )
+
+ return z_quantized, result_dict
+
+ @autocast(enabled=False)
+ def get_codebook_entry(self, indices):
+ indices = indices.long()
+ if len(indices.shape) == 1:
+ z_quantized = self.embedding(indices)
+ elif len(indices.shape) == 2:
+ z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
+ else:
+ raise NotImplementedError
+ if self.use_l2_norm:
+ z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
+ return z_quantized
+
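+# Illustrative usage (a sketch, not exercised by the training code; the 16-dim tokens and
+# 8x8 grid below are arbitrary):
+#   vq = VectorQuantizer(codebook_size=1024, token_size=16)
+#   z = torch.randn(2, 16, 8, 8)             # (B, C, H, W) encoder features
+#   z_q, info = vq(z)                         # z_q: (2, 16, 8, 8), straight-through gradients
+#   info["min_encoding_indices"].shape        # torch.Size([2, 8, 8])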
+
+class DiagonalGaussianDistribution(object):
+ @autocast(enabled=False)
+ def __init__(self, parameters, deterministic=False):
+ """Initializes a Gaussian distribution instance given the parameters.
+
+ Args:
+ parameters (torch.Tensor): The parameters for the Gaussian distribution. It is expected
+ to be in shape [B, 2 * C, *], where B is batch size, and C is the embedding dimension.
+ First C channels are used for mean and last C are used for logvar in the Gaussian distribution.
+ deterministic (bool): Whether to use deterministic sampling. When it is true, the sampling results
+ is purely based on mean (i.e., std = 0).
+ """
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ @autocast(enabled=False)
+ def sample(self):
+ x = self.mean.float() + self.std.float() * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ @autocast(enabled=False)
+ def mode(self):
+ return self.mean
+
+ @autocast(enabled=False)
+ def kl(self):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ return 0.5 * torch.sum(torch.pow(self.mean.float(), 2)
+ + self.var.float() - 1.0 - self.logvar.float(),
+ dim=[1, 2])
diff --git a/modeling/quantizer/softvq.py b/modeling/quantizer/softvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..68ad4413a9b2b78e1e44d81e0bd2158f3d7d9db5
--- /dev/null
+++ b/modeling/quantizer/softvq.py
@@ -0,0 +1,170 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Mapping, Text, Tuple
+from einops import rearrange
+from torch.cuda.amp import autocast
+
+
+class SoftVectorQuantizer(torch.nn.Module):
+ def __init__(self,
+ codebook_size: int = 1024,
+ token_size: int = 256,
+ commitment_cost: float = 0.25,
+ use_l2_norm: bool = False,
+ clustering_vq: bool = False,
+ entropy_loss_ratio: float = 0.01,
+ tau: float = 0.07,
+ num_codebooks: int = 1,
+ show_usage: bool = False
+ ):
+ super().__init__()
+ # Map new parameter names to internal names for compatibility
+ self.codebook_size = codebook_size
+ self.token_size = token_size
+ self.commitment_cost = commitment_cost
+ self.use_l2_norm = use_l2_norm
+ self.clustering_vq = clustering_vq
+
+ # Keep soft quantization specific parameters
+ self.num_codebooks = num_codebooks
+ self.n_e = codebook_size
+ self.e_dim = token_size
+ self.entropy_loss_ratio = entropy_loss_ratio
+ self.l2_norm = use_l2_norm
+ self.show_usage = show_usage
+ self.tau = tau
+
+ # Single embedding layer for all codebooks
+ self.embedding = nn.Parameter(torch.randn(num_codebooks, codebook_size, token_size))
+ self.embedding.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ if self.l2_norm:
+ self.embedding.data = F.normalize(self.embedding.data, p=2, dim=-1)
+
+ if self.show_usage:
+ self.register_buffer("codebook_used", torch.zeros(num_codebooks, 65536))
+
+ # Ensure quantization is performed using f32
+ @autocast(enabled=False)
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+ z = z.float()
+ original_shape = z.shape
+
+ # Handle input reshaping to match VectorQuantizer format
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z = z.view(z.size(0), -1, z.size(-1))
+
+ batch_size, seq_length, _ = z.shape
+
+ # Ensure sequence length is divisible by number of codebooks
+ assert seq_length % self.num_codebooks == 0, \
+ f"Sequence length ({seq_length}) must be divisible by number of codebooks ({self.num_codebooks})"
+
+ segment_length = seq_length // self.num_codebooks
+ z_segments = z.view(batch_size, self.num_codebooks, segment_length, self.e_dim)
+
+ # Apply L2 norm if needed
+ embedding = F.normalize(self.embedding, p=2, dim=-1) if self.l2_norm else self.embedding
+ if self.l2_norm:
+ z_segments = F.normalize(z_segments, p=2, dim=-1)
+
+ z_flat = z_segments.permute(1, 0, 2, 3).contiguous().view(self.num_codebooks, -1, self.e_dim)
+
+ logits = torch.einsum('nbe, nke -> nbk', z_flat, embedding.detach())
+
+ # Calculate probabilities (soft quantization)
+ probs = F.softmax(logits / self.tau, dim=-1)
+
+ # Soft quantize
+ z_q = torch.einsum('nbk, nke -> nbe', probs, embedding)
+
+ # Reshape back
+ z_q = z_q.view(self.num_codebooks, batch_size, segment_length, self.e_dim).permute(1, 0, 2, 3).contiguous()
+
+ # Calculate cosine similarity
+ with torch.no_grad():
+ zq_z_cos = F.cosine_similarity(
+ z_segments.view(-1, self.e_dim),
+ z_q.view(-1, self.e_dim),
+ dim=-1
+ ).mean()
+
+ # Get indices for usage tracking
+ indices = torch.argmax(probs, dim=-1) # (num_codebooks, batch_size * segment_length)
+ indices = indices.transpose(0, 1).contiguous() # (batch_size * segment_length, num_codebooks)
+
+ # Track codebook usage
+ if self.show_usage and self.training:
+ for k in range(self.num_codebooks):
+ cur_len = indices.size(0)
+ self.codebook_used[k, :-cur_len].copy_(self.codebook_used[k, cur_len:].clone())
+ self.codebook_used[k, -cur_len:].copy_(indices[:, k])
+
+ # Calculate losses if training
+ if self.training:
+ # Soft quantization doesn't have traditional commitment/codebook loss
+ # Map entropy loss to quantizer_loss for compatibility
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(logits.view(-1, self.n_e))
+ quantizer_loss = entropy_loss
+ commitment_loss = torch.tensor(0.0, device=z.device)
+ codebook_loss = torch.tensor(0.0, device=z.device)
+ else:
+ quantizer_loss = torch.tensor(0.0, device=z.device)
+ commitment_loss = torch.tensor(0.0, device=z.device)
+ codebook_loss = torch.tensor(0.0, device=z.device)
+
+ # Calculate codebook usage
+ codebook_usage = torch.tensor([
+ len(torch.unique(self.codebook_used[k])) / self.n_e
+ for k in range(self.num_codebooks)
+ ]).mean() if self.show_usage else 0
+
+ z_q = z_q.view(batch_size, -1, self.e_dim)
+
+ # Reshape back to original input shape to match VectorQuantizer
+ z_q = z_q.view(batch_size, original_shape[2], original_shape[3], original_shape[1])
+ z_quantized = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ # Calculate average probabilities
+ avg_probs = torch.mean(torch.mean(probs, dim=-1))
+ max_probs = torch.mean(torch.max(probs, dim=-1)[0])
+
+ # Return format matching VectorQuantizer
+ result_dict = dict(
+ quantizer_loss=quantizer_loss,
+ commitment_loss=commitment_loss,
+ codebook_loss=codebook_loss,
+ min_encoding_indices=indices.view(batch_size, self.num_codebooks, segment_length).view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
+ )
+
+ return z_quantized, result_dict
+
+ def get_codebook_entry(self, indices):
+ """Added for compatibility with VectorQuantizer API"""
+ if len(indices.shape) == 1:
+ # For single codebook case
+ z_quantized = self.embedding[0][indices]
+ elif len(indices.shape) == 2:
+ z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding[0])
+ else:
+ raise NotImplementedError
+ if self.use_l2_norm:
+ z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
+ return z_quantized
+
+
+def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
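+    # Encourages confident per-sample code assignments (low sample entropy) while keeping the
+    # batch-averaged code usage spread out (high average entropy): loss = E_i[H(p_i)] - H(E_i[p_i]).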
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
+ flat_affinity /= temperature
+ probs = F.softmax(flat_affinity, dim=-1)
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
+ if loss_type == "softmax":
+ target_probs = probs
+ else:
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
+ avg_probs = torch.mean(target_probs, dim=0)
+ avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-6))
+ sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
+ loss = sample_entropy - avg_entropy
+ return loss
\ No newline at end of file
diff --git a/modeling/vibetoken_model.py b/modeling/vibetoken_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a77d2a60f9caf3074a5c6d4f1a6c89980ee488b
--- /dev/null
+++ b/modeling/vibetoken_model.py
@@ -0,0 +1,219 @@
+"""VibeToken model definition."""
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from modeling.modules.base_model import BaseModel
+from modeling.modules.encoder_decoder import ResolutionEncoder, ResolutionDecoder
+from modeling.quantizer import VectorQuantizer, DiagonalGaussianDistribution, VectorQuantizerMVQ, SoftVectorQuantizer
+from modeling.modules.maskgit_vqgan import Encoder as Pixel_Encoder
+from modeling.modules.maskgit_vqgan import Decoder as Pixel_Decoder
+from modeling.modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
+import json
+from omegaconf import OmegaConf
+from pathlib import Path
+
+from huggingface_hub import PyTorchModelHubMixin
+
+
+class PretrainedTokenizer(nn.Module):
+ def __init__(self, pretrained_weight):
+ super().__init__()
+ conf = OmegaConf.create(
+ {"channel_mult": [1, 1, 2, 2, 4],
+ "num_resolutions": 5,
+ "dropout": 0.0,
+ "hidden_channels": 128,
+ "num_channels": 3,
+ "num_res_blocks": 2,
+ "resolution": 256,
+ "z_channels": 256})
+        self.encoder = Pixel_Encoder(conf)
+ self.decoder = Pixel_Decoder(conf)
+ self.quantize = Pixel_Quantizer(
+ num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
+ # Load pretrained weights
+ self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True)
+
+ self.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def encode(self, x):
+ hidden_states = self.encoder(x)
+ quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states)
+ return codebook_indices.detach()
+
+ @torch.no_grad()
+ def decode(self, codes):
+ quantized_states = self.quantize.get_codebook_entry(codes)
+ rec_images = self.decoder(quantized_states)
+ rec_images = torch.clamp(rec_images, 0.0, 1.0)
+ return rec_images.detach()
+
+ @torch.no_grad()
+ def decode_tokens(self, codes):
+ return self.decode(codes)
+
+
+class VibeTokenModel(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-tokenization"]):
+ def __init__(self, config):
+
+ if isinstance(config, dict):
+ config = OmegaConf.create(config)
+
+ super().__init__()
+ self.config = config
+ # This should be False for stage1 and True for stage2.
+ self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True)
+
+ self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
+ if self.quantize_mode not in ["vq", "vae", "softvq", "mvq"]:
+ raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.")
+
+ if self.finetune_decoder and self.quantize_mode not in ["vq", "softvq", "mvq"]:
+            raise ValueError("Only vq/softvq/mvq quantization is supported with finetune_decoder for now.")
+
+ self.encoder = ResolutionEncoder(config)
+ self.decoder = ResolutionDecoder(config)
+
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
+ scale = self.encoder.width ** -0.5
+ self.latent_tokens = nn.Parameter(
+ scale * torch.randn(self.num_latent_tokens, self.encoder.width))
+
+ self.apply(self._init_weights)
+
+ if self.quantize_mode == "vq":
+ self.quantize = VectorQuantizer(
+ codebook_size=config.model.vq_model.codebook_size,
+ token_size=config.model.vq_model.token_size,
+ commitment_cost=config.model.vq_model.commitment_cost,
+ use_l2_norm=config.model.vq_model.use_l2_norm,)
+ elif self.quantize_mode == "vae":
+ self.quantize = DiagonalGaussianDistribution
+ elif self.quantize_mode == "mvq":
+ self.quantize = VectorQuantizerMVQ(
+ codebook_size=config.model.vq_model.codebook_size,
+ token_size=config.model.vq_model.token_size,
+ commitment_cost=config.model.vq_model.commitment_cost,
+ use_l2_norm=config.model.vq_model.use_l2_norm,
+ num_codebooks=config.model.vq_model.num_codebooks,
+ )
+ elif self.quantize_mode == "softvq":
+ self.quantize = SoftVectorQuantizer(
+ codebook_size=config.model.vq_model.codebook_size,
+ token_size=config.model.vq_model.token_size,
+ commitment_cost=config.model.vq_model.commitment_cost,
+ use_l2_norm=config.model.vq_model.use_l2_norm,
+ num_codebooks=config.model.vq_model.num_codebooks,
+ )
+ else:
+ raise NotImplementedError
+
+ if self.finetune_decoder:
+ # Freeze encoder/quantizer/latent tokens
+ self.latent_tokens.requires_grad_(False)
+ self.encoder.eval()
+ self.encoder.requires_grad_(False)
+ self.quantize.eval()
+ self.quantize.requires_grad_(False)
+
+ # Include MaskGiT-VQGAN's quantizer and decoder
+ self.pixel_quantize = Pixel_Quantizer(
+ num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
+ self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
+ {"channel_mult": [1, 1, 2, 2, 4],
+ "num_resolutions": 5,
+ "dropout": 0.0,
+ "hidden_channels": 128,
+ "num_channels": 3,
+ "num_res_blocks": 2,
+ "resolution": 256,
+ "z_channels": 256}))
+
+ def _save_pretrained(self, save_directory: Path) -> None:
+ """Save weights and config to a local directory."""
+        # Convert the OmegaConf config to a plain dictionary and save it as JSON.
+        dict_config = OmegaConf.to_container(self.config)
+ file_path = Path(save_directory) / "config.json"
+ with open(file_path, 'w') as json_file:
+ json.dump(dict_config, json_file, indent=4)
+ super()._save_pretrained(save_directory)
+
+ def _init_weights(self, module):
+        """Initialize the weights.
+
+        Args:
+            module: The torch.nn.Module to initialize.
+        """
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d):
+ module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def encode(self, x, attention_mask=None, encode_patch_size=None, train=True, length=None):
+ if self.finetune_decoder:
+ with torch.no_grad():
+ self.encoder.eval()
+ self.quantize.eval()
+ z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
+ z_quantized, result_dict = self.quantize(z)
+ result_dict["quantizer_loss"] *= 0
+ result_dict["commitment_loss"] *= 0
+ result_dict["codebook_loss"] *= 0
+ else:
+ if length is not None:
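+                # Dynamic token count: use only the first `length + 1` learned latent tokens
+                # and drop the attention mask so the encoder runs on a truncated sequence.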
+ attention_mask = None
+ latent_tokens = self.latent_tokens[:length+1]
+ else:
+ latent_tokens = self.latent_tokens
+ z = self.encoder(pixel_values=x, latent_tokens=latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train)
+ if self.quantize_mode in ["vq", "mvq", "softvq"]:
+ z_quantized, result_dict = self.quantize(z)
+ elif self.quantize_mode == "vae":
+ posteriors = self.quantize(z)
+ z_quantized = posteriors.sample()
+ result_dict = posteriors
+
+ return z_quantized, result_dict
+
+ def decode(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
+ decoded = self.decoder(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
+ if self.finetune_decoder:
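+            # Soft-argmax over the MaskGiT-VQGAN pixel codebook: logits (B, K, H, W) are
+            # softmaxed over the K entries and mixed with the codebook embedding (K, D)
+            # to give quantized features (B, D, H, W) for the pixel decoder.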
+ quantized_states = torch.einsum(
+ 'nchw,cd->ndhw', decoded.softmax(1),
+ self.pixel_quantize.embedding.weight)
+ decoded = self.pixel_decoder(quantized_states)
+ return decoded
+
+ def decode_tokens(self, tokens, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True):
+ if self.quantize_mode in ["vq", "softvq"]:
+ tokens = tokens.squeeze(1)
+ batch, seq_len = tokens.shape # B x N
+ z_quantized = self.quantize.get_codebook_entry(
+ tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
+ z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
+ elif self.quantize_mode == "mvq":
+ z_quantized = self.quantize.get_codebook_entry(tokens)
+ elif self.quantize_mode == "vae":
+ z_quantized = tokens
+ z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
+ decoded = self.decode(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train)
+ return decoded
+
+ def forward(self, x, key_attention_mask=None, height=None, width=None, train=True):
+ if height is None:
+ batch_size, channels, height, width = x.shape
+ z_quantized, result_dict = self.encode(x, attention_mask=key_attention_mask, train=train)
+ z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
+ decoded = self.decode(z_quantized, attention_mask=key_attention_mask, height=height, width=width, train=train)
+ return decoded, result_dict
\ No newline at end of file
diff --git a/reconstruct.py b/reconstruct.py
new file mode 100644
index 0000000000000000000000000000000000000000..29e9fe4ef0fe5aaef4a7bab363c861af16617c63
--- /dev/null
+++ b/reconstruct.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python3
+"""Simple reconstruction script for VibeToken.
+
+Usage:
+ # Auto mode (recommended) - automatically determines optimal settings
+ python reconstruct.py --auto \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint /path/to/checkpoint.bin \
+ --image assets/example_1.jpg \
+ --output assets/reconstructed.png
+
+ # Manual mode - specify all parameters
+ python reconstruct.py \
+ --config configs/vibetoken_ll.yaml \
+ --checkpoint /path/to/checkpoint.bin \
+ --image assets/example_1.jpg \
+ --output assets/reconstructed.png \
+ --input_height 512 --input_width 512 \
+ --encoder_patch_size 16,32 \
+ --decoder_patch_size 16
+"""
+
+import argparse
+from PIL import Image
+from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
+
+
+def parse_patch_size(value):
+ """Parse patch size from string. Supports single int or tuple (e.g., '16' or '16,32')."""
+ if value is None:
+ return None
+ if ',' in value:
+ parts = value.split(',')
+ return (int(parts[0]), int(parts[1]))
+ return int(value)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="VibeToken image reconstruction")
+ parser.add_argument("--config", type=str, default="configs/vibetoken_ll.yaml",
+ help="Path to config YAML")
+ parser.add_argument("--checkpoint", type=str, required=True,
+ help="Path to model checkpoint")
+ parser.add_argument("--image", type=str, default="assets/example_1.jpg",
+ help="Path to input image")
+ parser.add_argument("--output", type=str, default="./assets/reconstructed.png",
+ help="Path to output image")
+ parser.add_argument("--device", type=str, default="cuda",
+ help="Device (cuda/cpu)")
+
+ # Auto mode
+ parser.add_argument("--auto", action="store_true",
+ help="Auto mode: automatically determine optimal input resolution and patch sizes")
+
+ # Input resolution (optional - resize input before encoding)
+ parser.add_argument("--input_height", type=int, default=None,
+ help="Resize input to this height before encoding (default: original)")
+ parser.add_argument("--input_width", type=int, default=None,
+ help="Resize input to this width before encoding (default: original)")
+
+ # Output resolution (optional - decode to this size)
+ parser.add_argument("--output_height", type=int, default=None,
+ help="Decode to this height (default: same as input)")
+ parser.add_argument("--output_width", type=int, default=None,
+ help="Decode to this width (default: same as input)")
+
+ # Patch sizes (optional) - supports single int or tuple like "16,32"
+ parser.add_argument("--encoder_patch_size", type=str, default=None,
+ help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
+ parser.add_argument("--decoder_patch_size", type=str, default=None,
+ help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
+
+ args = parser.parse_args()
+
+ # Load tokenizer
+ print(f"Loading tokenizer from {args.config}")
+ tokenizer = VibeTokenTokenizer.from_config(
+ args.config,
+ args.checkpoint,
+ device=args.device,
+ )
+
+ # Load image
+ print(f"Loading image from {args.image}")
+ image = Image.open(args.image).convert("RGB")
+ original_size = image.size # (W, H)
+ print(f"Original image size: {original_size[0]}x{original_size[1]}")
+
+ if args.auto:
+ # AUTO MODE - use centralized auto_preprocess_image
+ print("\n=== AUTO MODE ===")
+ image, patch_size, info = auto_preprocess_image(image, verbose=True)
+ input_width, input_height = info["cropped_size"]
+ output_width, output_height = input_width, input_height
+ encoder_patch_size = patch_size
+ decoder_patch_size = patch_size
+ print("=================\n")
+
+ else:
+ # MANUAL MODE
+ # Parse patch sizes
+ encoder_patch_size = parse_patch_size(args.encoder_patch_size)
+ decoder_patch_size = parse_patch_size(args.decoder_patch_size)
+
+ # Resize input if specified
+ if args.input_width or args.input_height:
+ input_width = args.input_width or original_size[0]
+ input_height = args.input_height or original_size[1]
+ print(f"Resizing input to {input_width}x{input_height}")
+ image = image.resize((input_width, input_height), Image.LANCZOS)
+
+ # Always center crop to ensure dimensions divisible by 32
+ image = center_crop_to_multiple(image, multiple=32)
+ input_width, input_height = image.size
+ if (input_width, input_height) != original_size:
+ print(f"Center cropped to {input_width}x{input_height} (divisible by 32)")
+
+ # Determine output size
+ output_height = args.output_height or input_height
+ output_width = args.output_width or input_width
+
+ # Encode image to tokens
+ print("Encoding image to tokens...")
+ if encoder_patch_size:
+ print(f" Using encoder patch size: {encoder_patch_size}")
+ tokens = tokenizer.encode(image, patch_size=encoder_patch_size)
+ print(f"Token shape: {tokens.shape}")
+
+ # Decode back to image
+ print(f"Decoding to {output_width}x{output_height}...")
+ if decoder_patch_size:
+ print(f" Using decoder patch size: {decoder_patch_size}")
+ reconstructed = tokenizer.decode(
+ tokens,
+ height=output_height,
+ width=output_width,
+ patch_size=decoder_patch_size
+ )
+ print(f"Reconstructed shape: {reconstructed.shape}")
+
+ # Convert tensor to PIL and save
+ output_images = tokenizer.to_pil(reconstructed)
+ output_images[0].save(args.output)
+ print(f"Saved reconstructed image to {args.output}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..756560a57d27a87a7e3e5a847181fde70663474c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,26 @@
+spaces
+torch>=2.0.0
+torchvision
+einops>=0.6.0
+omegaconf>=2.3.0
+pillow>=9.0.0
+numpy>=1.20.0
+huggingface_hub>=0.16.0
+accelerate
+wandb
+webdataset
+timm
+open_clip_torch
+transformers
+scipy
+torch-fidelity
+torchinfo
+termcolor
+iopath
+opencv-python
+diffusers
+gdown
+tqdm
+requests
+datasets
+gradio>=4.0.0
\ No newline at end of file
diff --git a/scripts/train_vibetoken.py b/scripts/train_vibetoken.py
new file mode 100644
index 0000000000000000000000000000000000000000..241bb12540474854a43048331036975d5e6b8705
--- /dev/null
+++ b/scripts/train_vibetoken.py
@@ -0,0 +1,223 @@
+"""Training script for VibeToken.
+
+Reference:
+ https://github.com/huggingface/open-muse
+"""
+import math
+import os
+import sys
+from pathlib import Path
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
+sys.path.append(parent_dir)
+
+from accelerate.utils import set_seed
+from accelerate import Accelerator
+
+import torch
+import wandb
+from omegaconf import OmegaConf
+from utils.logger import setup_logger
+
+from utils.train_utils import (
+ get_config, create_pretrained_tokenizer,
+ create_model_and_loss_module,
+ create_optimizer, create_lr_scheduler, create_dataloader,
+ create_evaluator, auto_resume, save_checkpoint,
+ train_one_epoch)
+
+
+def main():
+ workspace = os.environ.get('WORKSPACE', '')
+ if workspace:
+ torch.hub.set_dir(workspace + "/models/hub")
+
+ config = get_config()
+ # Enable TF32 on Ampere GPUs.
+ if config.training.enable_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.deterministic = False
+
+ output_dir = config.experiment.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ config.experiment.logging_dir = os.path.join(output_dir, "logs")
+
+ # Whether logging to Wandb or Tensorboard.
+ tracker = "tensorboard"
+ if config.training.enable_wandb:
+ tracker = "wandb"
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=config.training.gradient_accumulation_steps,
+ mixed_precision=config.training.mixed_precision,
+ log_with=tracker,
+ project_dir=config.experiment.logging_dir,
+ split_batches=False,
+ )
+
+ logger = setup_logger(name="VibeToken", log_level="INFO",
+ output_file=f"{output_dir}/log{accelerator.process_index}.txt")
+
+ if accelerator.is_main_process:
+ if config.training.enable_wandb:
+ wandb_config = config.training.get("wandb", {})
+ wandb_project = wandb_config.get("project", config.experiment.project)
+ wandb_entity = wandb_config.get("entity", None)
+ wandb_name = wandb_config.get("name", config.experiment.name)
+ wandb_tags = list(wandb_config.get("tags", []))
+ wandb_notes = wandb_config.get("notes", None)
+ wandb_resume_id = wandb_config.get("resume_id", None)
+
+ wandb_init_kwargs = {
+ "wandb": {
+ "name": wandb_name,
+ "dir": output_dir,
+ "resume": "allow",
+ }
+ }
+ if wandb_entity:
+ wandb_init_kwargs["wandb"]["entity"] = wandb_entity
+ if wandb_tags:
+ wandb_init_kwargs["wandb"]["tags"] = wandb_tags
+ if wandb_notes:
+ wandb_init_kwargs["wandb"]["notes"] = wandb_notes
+ if wandb_resume_id:
+ wandb_init_kwargs["wandb"]["id"] = wandb_resume_id
+
+ accelerator.init_trackers(
+ project_name=wandb_project,
+ config=OmegaConf.to_container(config, resolve=True),
+ init_kwargs=wandb_init_kwargs,
+ )
+ logger.info(f"WandB initialized - Project: {wandb_project}, Name: {wandb_name}")
+ else:
+ accelerator.init_trackers(config.experiment.name)
+
+ config_path = Path(output_dir) / "config.yaml"
+ logger.info(f"Saving config to {config_path}")
+ OmegaConf.save(config, config_path)
+ logger.info(f"Config:\n{OmegaConf.to_yaml(config)}")
+
+ # If passed along, set the training seed now.
+ if config.training.seed is not None:
+ set_seed(config.training.seed, device_specific=True)
+
+ accelerator.wait_for_everyone()
+
+ # Create pretrained tokenizer in a synchronized manner
+ if config.model.vq_model.is_legacy:
+ if accelerator.is_main_process:
+ logger.info("Creating pretrained tokenizer on main process...")
+ accelerator.wait_for_everyone()
+ pretrained_tokenizer = create_pretrained_tokenizer(config, accelerator)
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ logger.info("Pretrained tokenizer creation completed.")
+ else:
+ pretrained_tokenizer = None
+
+ if accelerator.is_main_process:
+ logger.info("Creating model and loss module...")
+ accelerator.wait_for_everyone()
+
+ model, ema_model, loss_module = create_model_and_loss_module(
+ config, logger, accelerator, model_type="vibetoken")
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ logger.info("Model creation completed.")
+
+ optimizer, discriminator_optimizer = create_optimizer(config, logger, model, loss_module, model_type="vibetoken")
+
+ lr_scheduler, discriminator_lr_scheduler = create_lr_scheduler(
+ config, logger, accelerator, optimizer, discriminator_optimizer)
+
+ if accelerator.is_main_process:
+ logger.info("Creating dataloaders...")
+ train_dataloader, eval_dataloader = create_dataloader(config, logger, accelerator)
+ accelerator.wait_for_everyone()
+
+ # Set up evaluator.
+ if accelerator.is_main_process:
+ logger.info("Setting up evaluator...")
+ evaluator = create_evaluator(config, logger, accelerator)
+
+ # Prepare everything with accelerator.
+ logger.info("Preparing model, optimizer and dataloaders")
+    # The dataloaders are already aware of distributed training, so we don't need to prepare them.
+ if config.model.vq_model.is_legacy:
+ if config.model.vq_model.finetune_decoder:
+ model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
+ model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
+ )
+ else:
+ model, optimizer, lr_scheduler = accelerator.prepare(
+ model, optimizer, lr_scheduler
+ )
+ else:
+ model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
+ model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
+ )
+
+ if config.training.use_ema:
+ ema_model.to(accelerator.device)
+
+ total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes
+ num_batches = math.ceil(
+ config.experiment.max_train_examples / total_batch_size_without_accum)
+ num_update_steps_per_epoch = math.ceil(num_batches / config.training.gradient_accumulation_steps)
+ num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch)
+
+ # Start training.
+ logger.info("***** Running training *****")
+ logger.info(f" Num training steps = {config.training.max_train_steps}")
+ logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}")
+ logger.info(f" Instantaneous batch size per gpu = { config.training.per_gpu_batch_size}")
+ logger.info(f""" Total train batch size (w. parallel, distributed & accumulation) = {(
+ config.training.per_gpu_batch_size *
+ accelerator.num_processes *
+ config.training.gradient_accumulation_steps)}""")
+ global_step = 0
+ first_epoch = 0
+
+ global_step, first_epoch = auto_resume(
+ config, logger, accelerator, ema_model, num_update_steps_per_epoch,
+ strict=True)
+
+ for current_epoch in range(first_epoch, num_train_epochs):
+ accelerator.print(f"Epoch {current_epoch}/{num_train_epochs-1} started.")
+ global_step = train_one_epoch(config, logger, accelerator,
+ model, ema_model, loss_module,
+ optimizer, discriminator_optimizer,
+ lr_scheduler, discriminator_lr_scheduler,
+ train_dataloader, eval_dataloader,
+ evaluator,
+ global_step,
+ pretrained_tokenizer=pretrained_tokenizer,
+ model_type="vibetoken")
+ # Stop training if max steps is reached.
+ if global_step >= config.training.max_train_steps:
+ accelerator.print(
+ f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}"
+ )
+ break
+
+ accelerator.wait_for_everyone()
+ # Save checkpoint at the end of training.
+ save_checkpoint(model, output_dir, accelerator, global_step, logger=logger)
+ # Save the final trained checkpoint
+ if accelerator.is_main_process:
+ model = accelerator.unwrap_model(model)
+ if config.training.use_ema:
+ ema_model.copy_to(model.parameters())
+ model.save_pretrained_weight(output_dir)
+
+ if accelerator.is_main_process and config.training.enable_wandb:
+ wandb.finish()
+ logger.info("WandB run finished")
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/setup.sh b/setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4ca3020f3dcd0efb2dee70ed6e7cc0d05b9053ac
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+# Data preparation script for VibeToken training.
+# Set DATA_DIR to control where datasets are stored (defaults to ./data).
+#
+# Usage:
+# export DATA_DIR=/mnt/fastssd/datasets # optional, defaults to ./data
+# bash setup.sh
+
+DATA_DIR="${DATA_DIR:-./data}"
+
+echo "Using DATA_DIR=${DATA_DIR}"
+
+# Download ImageNet-1k via HuggingFace
+export HF_HUB_ENABLE_HF_TRANSFER=1
+huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir "${DATA_DIR}/imagenet-1k"
+
+# Convert to WebDataset format
+python data/convert_imagenet_to_wds.py \
+ --input_dir "${DATA_DIR}/imagenet-1k" \
+ --output_dir "${DATA_DIR}/imagenet_wds"
diff --git a/train_tokenvibe.sh b/train_tokenvibe.sh
new file mode 100644
index 0000000000000000000000000000000000000000..517edf995d0b12071e0fb86c857646f953771dff
--- /dev/null
+++ b/train_tokenvibe.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+# Run training with 8 GPUs on a single node.
+NODE_RANK=${RANK:-0}
+MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
+MASTER_PORT=${MASTER_PORT:-9871}
+
+accelerate launch \
+ --num_machines=1 \
+ --num_processes=8 \
+ --machine_rank=$NODE_RANK \
+ --main_process_ip=$MASTER_ADDR \
+ --main_process_port=$MASTER_PORT \
+ --same_network \
+ scripts/train_tokenvibe.py \
+ config=configs/training/VibeToken_small.yaml
\ No newline at end of file
diff --git a/train_vibetoken.sh b/train_vibetoken.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ab7ebdb0a4df3774d4aab3d0afaf778dee78effb
--- /dev/null
+++ b/train_vibetoken.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+# Run training with 8 GPUs on a single node.
+NODE_RANK=${RANK:-0}
+MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
+MASTER_PORT=${MASTER_PORT:-9871}
+
+accelerate launch \
+ --num_machines=1 \
+ --num_processes=8 \
+ --machine_rank=$NODE_RANK \
+ --main_process_ip=$MASTER_ADDR \
+ --main_process_port=$MASTER_PORT \
+ --same_network \
+ scripts/train_vibetoken.py \
+ config=configs/training/VibeToken_small.yaml
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e7e94c7b322c0a04d213d587ed94d666c9d073d
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,69 @@
+"""Util functions supporting logging to terminal and files."""
+
+import functools
+import sys
+from accelerate.logging import MultiProcessAdapter
+import logging
+from termcolor import colored
+
+from iopath.common.file_io import PathManager as PathManagerClass
+
+__all__ = ["setup_logger", "PathManager"]
+
+PathManager = PathManagerClass()
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", self._root_name)
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+@functools.lru_cache()
+def setup_logger(name="TiTok", log_level: str = None, color=True, use_accelerate=True,
+ output_file=None):
+ logger = logging.getLogger(name)
+ if log_level is None:
+ logger.setLevel(logging.DEBUG)
+ else:
+ logger.setLevel(log_level.upper())
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
+ )
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ if output_file is not None:
+ fileHandler = logging.FileHandler(output_file)
+ fileHandler.setFormatter(formatter)
+ logger.addHandler(fileHandler)
+
+ if use_accelerate:
+ return MultiProcessAdapter(logger, {})
+ else:
+ return logger
\ No newline at end of file
diff --git a/utils/lr_schedulers.py b/utils/lr_schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d5ff550c6accb988b2e8f827ea76f61cbbdd939
--- /dev/null
+++ b/utils/lr_schedulers.py
@@ -0,0 +1,129 @@
+"""Learning rate schedulers.
+
+Reference:
+ https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py
+"""
+import math
+from enum import Enum
+from typing import Optional, Union
+
+import torch
+
+
+class SchedulerType(Enum):
+ COSINE = "cosine"
+ CONSTANT = "constant"
+
+def get_cosine_schedule_with_warmup(
+ optimizer: torch.optim.Optimizer,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ last_epoch: int = -1,
+ base_lr: float = 1e-4,
+ end_lr: float = 0.0,
+):
+ """Creates a cosine learning rate schedule with warm-up and ending learning rate.
+
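+
+
+# Illustrative usage (a sketch, not exercised by the training code; the (B, 2*C, N) shape
+# below is arbitrary):
+#   params = torch.randn(2, 32, 64)           # 16 mean channels followed by 16 logvar channels
+#   posterior = DiagonalGaussianDistribution(params)
+#   z = posterior.sample()                    # (2, 16, 64)
+#   kl = posterior.kl()                       # (2,): per-sample KL divergence to N(0, I)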
+ Args:
+ optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate.
+ num_warmup_steps: An integer, the number of steps for the warmup phase.
+ num_training_steps: An integer, the total number of training steps.
+ num_cycles : A float, the number of periods of the cosine function in a schedule (the default is to
+ just decrease from the max value to 0 following a half-cosine).
+ last_epoch: An integer, the index of the last epoch when resuming training.
+ base_lr: A float, the base learning rate.
+ end_lr: A float, the final learning rate.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / \
+ float(max(1, num_training_steps - num_warmup_steps))
+ ratio = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+ return (end_lr + (base_lr - end_lr) * ratio) / base_lr
+
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
+
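+# Illustrative usage (a sketch; the optimizer, step counts and learning rates are placeholders):
+#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+#   scheduler = get_cosine_schedule_with_warmup(
+#       optimizer, num_warmup_steps=1_000, num_training_steps=100_000,
+#       base_lr=1e-4, end_lr=1e-5)
+#   # Call scheduler.step() after each optimizer update: the lr warms up linearly to base_lr,
+#   # then follows a half-cosine down to end_lr.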
+
+def get_constant_schedule_with_warmup(
+ optimizer: torch.optim.Optimizer,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ base_lr: float = 1e-4,
+ end_lr: float = 0.0,
+):
+ """UViT: Creates a constant learning rate schedule with warm-up.
+
+ Args:
+ optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate.
+ num_warmup_steps: An integer, the number of steps for the warmup phase.
+        num_training_steps: An integer, the total number of training steps (unused here;
+            kept for interface compatibility with the cosine schedule).
+        base_lr: A float, the base learning rate (unused here; kept for interface compatibility).
+        end_lr: A float, the final learning rate (unused here; kept for interface compatibility).
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ else:
+ return 1.0
+
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+ SchedulerType.CONSTANT: get_constant_schedule_with_warmup,
+}
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ optimizer: torch.optim.Optimizer,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+ base_lr: float = 1e-4,
+ end_lr: float = 0.0,
+):
+ """Retrieves a learning rate scheduler from the given name and optimizer.
+
+ Args:
+ name: A string or SchedulerType, the name of the scheduler to retrieve.
+ optimizer: torch.optim.Optimizer. The optimizer to use with the scheduler.
+ num_warmup_steps: An integer, the number of warmup steps.
+ num_training_steps: An integer, the total number of training steps.
+ base_lr: A float, the base learning rate.
+ end_lr: A float, the final learning rate.
+
+ Returns:
+        An instance of torch.optim.lr_scheduler.LambdaLR.
+
+ Raises:
+ ValueError: If num_warmup_steps or num_training_steps is not provided.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ base_lr=base_lr,
+ end_lr=end_lr,
+ )
\ No newline at end of file
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..154b64488c5150a8221ed5f53f23ba942d8852f9
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,342 @@
+"""This file is borrowed from https://github.com/LTH14/mar/blob/main/util/misc.py
+"""
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
+ from torch._six import inf
+else:
+ from torch import inf
+import copy
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
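+# Illustrative usage (a sketch):
+#   meter = SmoothedValue(window_size=20)
+#   meter.update(0.5); meter.update(0.7); meter.update(0.6)
+#   str(meter)    # "0.6000 (0.6000)" -- windowed median and running global average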
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
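+# Illustrative usage (a sketch; `model`, `optimizer` and the loss computation are placeholders):
+#   scaler = NativeScalerWithGradNormCount()
+#   loss = model(images).mean()
+#   grad_norm = scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
+#   optimizer.zero_grad()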
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
+
+
+def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
+ decay = []
+ no_decay = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
+ no_decay.append(param) # no weight decay on bias, norm and diffloss
+ else:
+ decay.append(param)
+ return [
+ {'params': no_decay, 'weight_decay': 0.},
+ {'params': decay, 'weight_decay': weight_decay}]
+
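+# Illustrative usage (a sketch; `model` is any nn.Module):
+#   param_groups = add_weight_decay(model, weight_decay=0.05)
+#   optimizer = torch.optim.AdamW(param_groups, lr=1e-4)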
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None):
+ if epoch_name is None:
+ epoch_name = str(epoch)
+ output_dir = Path(args.output_dir)
+ checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)
+
+ # ema
+ if ema_params is not None:
+ ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
+ for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
+ assert name in ema_state_dict
+ ema_state_dict[name] = ema_params[i]
+ else:
+ ema_state_dict = None
+
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'model_ema': ema_state_dict,
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }
+ save_on_master(to_save, checkpoint_path)
+
+
+def all_reduce_mean(x):
+ world_size = get_world_size()
+ if world_size > 1:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
\ No newline at end of file
diff --git a/utils/train_utils.py b/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f399c4fdaa4f50fd482382bb61bb40e9f891816
--- /dev/null
+++ b/utils/train_utils.py
@@ -0,0 +1,836 @@
+"""Training utils for VibeToken."""
+import json
+import os
+import time
+import math
+from pathlib import Path
+import pprint
+import glob
+from collections import defaultdict
+import random
+import gc
+
+from data import SimpleImageDataset, PretoeknizedDataSetJSONL, PretokenizedWebDataset
+import torch
+from torch.utils.data import DataLoader
+from omegaconf import OmegaConf
+from torch.optim import AdamW
+from utils.lr_schedulers import get_scheduler
+from modeling.modules import EMAModel, ReconstructionLoss_Single_Stage
+from modeling.vibetoken_model import VibeTokenModel, PretrainedTokenizer
+from evaluator import VQGANEvaluator
+
+from utils.viz_utils import make_viz_from_samples
+from torchinfo import summary
+import accelerate
+
+def get_config():
+ """Reads configs from a yaml file and terminal."""
+ cli_conf = OmegaConf.from_cli()
+
+ yaml_conf = OmegaConf.load(cli_conf.config)
+ conf = OmegaConf.merge(yaml_conf, cli_conf)
+
+ return conf
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value.
+
+ This class is borrowed from
+ https://github.com/pytorch/examples/blob/main/imagenet/main.py#L423
+ """
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def create_pretrained_tokenizer(config, accelerator=None):
+ if config.model.vq_model.finetune_decoder:
+        pretrained_tokenizer = None
+    else:
+        pretrained_tokenizer = PretrainedTokenizer(config.model.vq_model.pretrained_tokenizer_weight)
+        if accelerator is not None:
+            pretrained_tokenizer.to(accelerator.device)
+    return pretrained_tokenizer
+
+
+def create_model_and_loss_module(config, logger, accelerator,
+ model_type="vibetoken"):
+ """Creates model and loss module."""
+ logger.info("Creating model and loss module.")
+ if model_type == "vibetoken":
+ if config.model.sub_model_type == "vibetoken":
+ model_cls = VibeTokenModel
+ loss_cls = ReconstructionLoss_Single_Stage
+ else:
+ raise ValueError(f"Unsupported sub_model_type {config.model.sub_model_type}")
+ else:
+ raise ValueError(f"Unsupported model_type {model_type}")
+ model = model_cls(config)
+
+ if config.experiment.get("init_weight", ""):
+ model_weight = torch.load(config.experiment.init_weight, map_location="cpu")
+ if config.model.vq_model.finetune_decoder:
+ pretrained_tokenizer_weight = torch.load(
+ config.model.vq_model.pretrained_tokenizer_weight, map_location="cpu"
+ )
+ pretrained_tokenizer_weight = {"pixel_" + k:v for k,v in pretrained_tokenizer_weight.items() if not "encoder." in k}
+ model_weight.update(pretrained_tokenizer_weight)
+
+ msg = model.load_state_dict(model_weight, strict=False)
+ logger.info(f"loading weight from {config.experiment.init_weight}, msg: {msg}")
+
+ # Create the EMA model.
+ ema_model = None
+ if config.training.use_ema:
+ ema_model = EMAModel(model.parameters(), decay=0.999,
+ model_cls=model_cls, config=config)
+ def load_model_hook(models, input_dir):
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"),
+ model_cls=model_cls, config=config)
+ ema_model.load_state_dict(load_model.state_dict())
+ ema_model.to(accelerator.device)
+ del load_model
+
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ ema_model.save_pretrained(os.path.join(output_dir, "ema_model"))
+
+ accelerator.register_load_state_pre_hook(load_model_hook)
+ accelerator.register_save_state_pre_hook(save_model_hook)
+
+ loss_module = loss_cls(config=config) if loss_cls is not None else None
+
+ if accelerator.is_main_process:
+ if model_type in ["vibetoken"]:
+ logger.info("VibeToken model summary not implemented yet.")
+ else:
+ raise NotImplementedError
+
+ return model, ema_model, loss_module
+
+
+def create_optimizer(config, logger, model, loss_module,
+ model_type="vibetoken", need_discrminator=True):
+ """Creates optimizer for model and discriminator."""
+ logger.info("Creating optimizers.")
+ optimizer_config = config.optimizer.params
+ learning_rate = optimizer_config.learning_rate
+
+ optimizer_type = config.optimizer.name
+ if optimizer_type == "adamw":
+ optimizer_cls = AdamW
+ else:
+ raise ValueError(f"Optimizer {optimizer_type} not supported")
+
+ exclude = (lambda n, p: p.ndim < 2 or "ln" in n or "bias" in n or 'latent_tokens' in n
+ or 'mask_token' in n or 'embedding' in n or 'norm' in n or 'gamma' in n or 'embed' in n)
+ include = lambda n, p: not exclude(n, p)
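+    # Parameters matched by `exclude` -- biases, norm/embedding/gain terms, latent and
+    # mask tokens, and any tensor with fewer than 2 dims -- are trained without weight
+    # decay; the remaining projection/conv weights receive optimizer_config.weight_decay.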
+ named_parameters = list(model.named_parameters())
+ gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
+ rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
+ optimizer = optimizer_cls(
+ [
+ {"params": gain_or_bias_params, "weight_decay": 0.},
+ {"params": rest_params, "weight_decay": optimizer_config.weight_decay},
+ ],
+ lr=learning_rate,
+ betas=(optimizer_config.beta1, optimizer_config.beta2)
+ )
+
+ if (config.model.vq_model.finetune_decoder or model_type == "vibetoken") and need_discrminator:
+ discriminator_learning_rate = optimizer_config.discriminator_learning_rate
+ discriminator_named_parameters = list(loss_module.named_parameters())
+ discriminator_gain_or_bias_params = [p for n, p in discriminator_named_parameters if exclude(n, p) and p.requires_grad]
+ discriminator_rest_params = [p for n, p in discriminator_named_parameters if include(n, p) and p.requires_grad]
+
+ discriminator_optimizer = optimizer_cls(
+ [
+ {"params": discriminator_gain_or_bias_params, "weight_decay": 0.},
+ {"params": discriminator_rest_params, "weight_decay": optimizer_config.weight_decay},
+ ],
+ lr=discriminator_learning_rate,
+ betas=(optimizer_config.beta1, optimizer_config.beta2)
+ )
+ else:
+ discriminator_optimizer = None
+
+    assert discriminator_optimizer is not None, f"Discriminator optimizer is None with condition values: {config.model.vq_model.finetune_decoder} {model_type} {need_discrminator}"
+
+ return optimizer, discriminator_optimizer
+
+
+def create_lr_scheduler(config, logger, accelerator, optimizer, discriminator_optimizer=None):
+ """Creates learning rate scheduler for model and discriminator."""
+ logger.info("Creating lr_schedulers.")
+ lr_scheduler = get_scheduler(
+ config.lr_scheduler.scheduler,
+ optimizer=optimizer,
+ num_training_steps=config.training.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=config.lr_scheduler.params.warmup_steps * accelerator.num_processes,
+ base_lr=config.lr_scheduler.params.learning_rate,
+ end_lr=config.lr_scheduler.params.end_lr,
+ )
+ if discriminator_optimizer is not None:
+ discriminator_lr_scheduler = get_scheduler(
+ config.lr_scheduler.scheduler,
+ optimizer=discriminator_optimizer,
+ num_training_steps=config.training.max_train_steps * accelerator.num_processes - config.losses.discriminator_start,
+ num_warmup_steps=config.lr_scheduler.params.warmup_steps * accelerator.num_processes,
+ base_lr=config.lr_scheduler.params.learning_rate,
+ end_lr=config.lr_scheduler.params.end_lr,
+ )
+ else:
+ discriminator_lr_scheduler = None
+ return lr_scheduler, discriminator_lr_scheduler
+
+
+def create_dataloader(config, logger, accelerator):
+ """Creates data loader for training and testing."""
+ logger.info("Creating dataloaders.")
+ total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes
+ total_batch_size = (
+ config.training.per_gpu_batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps
+ )
+ preproc_config = config.dataset.preprocessing
+ dataset_config = config.dataset.params
+
+ if dataset_config.get("pretokenization", "") and dataset_config.get("dataset_with_text_label", False) is True:
+ dataset = PretokenizedWebDataset(
+ train_shards_path=dataset_config.train_shards_path_or_url,
+ eval_shards_path=dataset_config.eval_shards_path_or_url,
+ num_train_examples=config.experiment.max_train_examples,
+ per_gpu_batch_size=config.training.per_gpu_batch_size,
+ global_batch_size=total_batch_size_without_accum,
+ num_workers_per_gpu=dataset_config.num_workers_per_gpu,
+ resize_shorter_edge=preproc_config.resize_shorter_edge,
+ crop_size=preproc_config.crop_size,
+ random_crop=preproc_config.random_crop,
+ random_flip=preproc_config.random_flip,
+ normalize_mean=preproc_config.normalize_mean,
+ normalize_std=preproc_config.normalize_std,
+ process_recap=preproc_config.get("preproc_recap", True),
+ use_recap_prob=preproc_config.get("use_recap_prob", 0.95)
+ )
+ train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader
+ elif dataset_config.get("pretokenization", "") and dataset_config.get("dataset_with_text_label", False) is False:
+ dataset = SimpleImageDataset(
+ train_shards_path=dataset_config.train_shards_path_or_url,
+ eval_shards_path=dataset_config.eval_shards_path_or_url,
+ num_train_examples=config.experiment.max_train_examples,
+ per_gpu_batch_size=config.training.per_gpu_batch_size,
+ global_batch_size=total_batch_size_without_accum,
+ num_workers_per_gpu=dataset_config.num_workers_per_gpu,
+ resize_shorter_edge=preproc_config.resize_shorter_edge,
+ crop_size=preproc_config.crop_size,
+ random_crop=preproc_config.random_crop,
+ random_flip=preproc_config.random_flip,
+ dataset_with_class_label=dataset_config.get("dataset_with_class_label", True),
+ dataset_with_text_label=dataset_config.get("dataset_with_text_label", False),
+ res_ratio_filtering=preproc_config.get("res_ratio_filtering", False),
+ min_tokens=preproc_config.min_tokens,
+ max_tokens=preproc_config.max_tokens,
+ )
+ train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader
+    else:
+        if dataset_config.get("pretokenization", ""):
+            train_dataloader = DataLoader(
+                PretoeknizedDataSetJSONL(dataset_config.pretokenization),
+                batch_size=config.training.per_gpu_batch_size,
+                shuffle=True, drop_last=True, pin_memory=True)
+            train_dataloader.num_batches = math.ceil(
+                config.experiment.max_train_examples / total_batch_size_without_accum)
+            # The pre-tokenized JSONL path has no held-out split, so evaluation is skipped.
+            eval_dataloader = None
+        else:
+            raise ValueError(
+                "Unsupported dataset configuration: expected webdataset shards or a "
+                "pretokenized JSONL path.")
+
+ return train_dataloader, eval_dataloader
+
+
+class LazyVQGANEvaluator:
+ """A lazy-loading wrapper for VQGANEvaluator that delays inception model initialization."""
+
+ def __init__(self, device, enable_rfid=True, enable_inception_score=True,
+ enable_codebook_usage_measure=False, enable_codebook_entropy_measure=False,
+ num_codebook_entries=1024, accelerator=None):
+ self._device = device
+ self._enable_rfid = enable_rfid
+ self._enable_inception_score = enable_inception_score
+ self._enable_codebook_usage_measure = enable_codebook_usage_measure
+ self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
+ self._num_codebook_entries = num_codebook_entries
+ self._accelerator = accelerator
+ self._evaluator = None
+ self._initialized = False
+
+ def _ensure_initialized(self):
+ """Initialize the real evaluator only when needed."""
+ if not self._initialized:
+ if self._accelerator and self._accelerator.num_processes > 1:
+ if self._accelerator.is_main_process:
+ try:
+ from evaluator.inception import get_inception_model
+ _ = get_inception_model()
+ except Exception as e:
+ print(f"Warning: Failed to pre-load inception model: {e}")
+
+ if self._accelerator:
+ self._accelerator.wait_for_everyone()
+
+ try:
+ self._evaluator = VQGANEvaluator(
+ device=self._device,
+ enable_rfid=self._enable_rfid,
+ enable_inception_score=self._enable_inception_score,
+ enable_codebook_usage_measure=self._enable_codebook_usage_measure,
+ enable_codebook_entropy_measure=self._enable_codebook_entropy_measure,
+ num_codebook_entries=self._num_codebook_entries
+ )
+ self._initialized = True
+ except Exception as e:
+ print(f"Warning: Failed to create VQGANEvaluator, using dummy: {e}")
+ class DummyEvaluator:
+ def reset_metrics(self): pass
+ def update(self, real_images, fake_images, codebook_indices=None): pass
+ def result(self):
+ return {"InceptionScore": 0.0, "rFID": 0.0, "CodebookUsage": 0.0, "CodebookEntropy": 0.0}
+ self._evaluator = DummyEvaluator()
+ self._initialized = True
+
+ def reset_metrics(self):
+ self._ensure_initialized()
+ return self._evaluator.reset_metrics()
+
+ def update(self, real_images, fake_images, codebook_indices=None):
+ self._ensure_initialized()
+ return self._evaluator.update(real_images, fake_images, codebook_indices)
+
+ def result(self):
+ self._ensure_initialized()
+ return self._evaluator.result()
+
+
+def create_evaluator(config, logger, accelerator):
+ """Creates evaluator."""
+ logger.info("Creating evaluator.")
+
+ if config.model.vq_model.get("quantize_mode", "vq") in ["vq", "softvq", "mvq"]:
+ evaluator = LazyVQGANEvaluator(
+ device=accelerator.device,
+ enable_rfid=True,
+ enable_inception_score=True,
+ enable_codebook_usage_measure=True,
+ enable_codebook_entropy_measure=True,
+ num_codebook_entries=config.model.vq_model.codebook_size,
+ accelerator=accelerator
+ )
+ elif config.model.vq_model.get("quantize_mode", "vq") == "vae":
+ evaluator = LazyVQGANEvaluator(
+ device=accelerator.device,
+ enable_rfid=True,
+ enable_inception_score=True,
+ enable_codebook_usage_measure=False,
+ enable_codebook_entropy_measure=False,
+ accelerator=accelerator
+ )
+ else:
+ raise NotImplementedError
+
+ logger.info("Lazy evaluator creation completed.")
+ return evaluator
+
+
+def auto_resume(config, logger, accelerator, ema_model,
+ num_update_steps_per_epoch, strict=True):
+ """Auto resuming the training."""
+ global_step = 0
+ first_epoch = 0
+ if config.experiment.resume:
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ local_ckpt_list = list(glob.glob(os.path.join(
+ config.experiment.output_dir, "checkpoint*")))
+ logger.info(f"All globbed checkpoints are: {local_ckpt_list}")
+ else:
+ local_ckpt_list = []
+
+ if accelerator.num_processes > 1:
+ checkpoint_count = torch.tensor(len(local_ckpt_list), device=accelerator.device)
+ accelerate.utils.broadcast(checkpoint_count, 0)
+
+ if checkpoint_count > 0:
+ if accelerator.is_main_process:
+ if len(local_ckpt_list) > 1:
+ fn = lambda x: int(x.split('/')[-1].split('-')[-1])
+ checkpoint_paths = sorted(local_ckpt_list, key=fn, reverse=True)
+ else:
+ checkpoint_paths = local_ckpt_list
+ latest_checkpoint = checkpoint_paths[0]
+ else:
+ latest_checkpoint = ""
+
+ if accelerator.is_main_process:
+ checkpoint_path_tensor = torch.tensor([ord(c) for c in latest_checkpoint], device=accelerator.device, dtype=torch.long)
+ path_length = torch.tensor(len(latest_checkpoint), device=accelerator.device)
+ else:
+ path_length = torch.tensor(0, device=accelerator.device)
+
+ accelerate.utils.broadcast(path_length, 0)
+
+ if not accelerator.is_main_process:
+ checkpoint_path_tensor = torch.zeros(path_length.item(), device=accelerator.device, dtype=torch.long)
+
+ accelerate.utils.broadcast(checkpoint_path_tensor, 0)
+
+ if not accelerator.is_main_process:
+ latest_checkpoint = ''.join([chr(c.item()) for c in checkpoint_path_tensor])
+
+ global_step = load_checkpoint(
+ Path(latest_checkpoint),
+ accelerator,
+ logger=logger,
+ strict=strict
+ )
+ if config.training.use_ema:
+ ema_model.set_step(global_step)
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ logger.info("Training from scratch.")
+ else:
+ if len(local_ckpt_list) >= 1:
+ if len(local_ckpt_list) > 1:
+ fn = lambda x: int(x.split('/')[-1].split('-')[-1])
+ checkpoint_paths = sorted(local_ckpt_list, key=fn, reverse=True)
+ else:
+ checkpoint_paths = local_ckpt_list
+ global_step = load_checkpoint(
+ Path(checkpoint_paths[0]),
+ accelerator,
+ logger=logger,
+ strict=strict
+ )
+ if config.training.use_ema:
+ ema_model.set_step(global_step)
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ logger.info("Training from scratch.")
+
+ accelerator.wait_for_everyone()
+ return global_step, first_epoch
+
+
+def train_one_epoch(config, logger, accelerator,
+ model, ema_model, loss_module,
+ optimizer, discriminator_optimizer,
+ lr_scheduler, discriminator_lr_scheduler,
+ train_dataloader, eval_dataloader,
+ evaluator,
+ global_step,
+ model_type="vibetoken",
+ clip_tokenizer=None,
+ clip_encoder=None,
+ pretrained_tokenizer=None):
+ """One epoch training."""
+ batch_time_meter = AverageMeter()
+ data_time_meter = AverageMeter()
+ end = time.time()
+
+ model.train()
+
+ autoencoder_logs = defaultdict(float)
+ discriminator_logs = defaultdict(float)
+ for i, batch in enumerate(train_dataloader):
+ model.train()
+ if "image" in batch:
+ images = batch["image"].to(
+ accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
+ )
+ if config.training.get("variable_resolution", False):
+ any2any = config.training.variable_resolution.get("any2any", True)
+
+ dims = config.training.variable_resolution.dim
+ ratios = config.training.variable_resolution.ratio
+ assert len(dims) == len(ratios), "dims and ratios must have the same length"
+ input_res = tuple(random.choices(dims, weights=ratios, k=1)[0])
+
+ if any2any:
+ output_res = tuple(random.choices(dims, weights=ratios, k=1)[0])
+ else:
+ output_res = input_res
+
+ images = torch.nn.functional.interpolate(images, size=output_res, mode="bilinear", align_corners=False)
+ input_images = torch.nn.functional.interpolate(images, size=input_res, mode="bilinear", align_corners=False)
+ else:
+ input_images = images
+ output_res = (None, None)
+
+ fnames = batch["__key__"]
+ data_time_meter.update(time.time() - end)
+
+ if pretrained_tokenizer is not None:
+ pretrained_tokenizer.eval()
+ proxy_codes = pretrained_tokenizer.encode(images)
+ else:
+ proxy_codes = None
+
+ with accelerator.accumulate([model, loss_module]):
+ additional_args = {}
+ if config.model.get("train_with_attention", False):
+ additional_args["key_attention_mask"] = batch["attention_mask"].to(
+ accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
+ )
+ reconstructed_images, extra_results_dict = model(input_images, height=output_res[0], width=output_res[1], **additional_args)
+ autoencoder_loss, loss_dict = loss_module(
+ images,
+ reconstructed_images,
+ extra_results_dict,
+ global_step,
+ mode="generator",
+ )
+
+ autoencoder_logs = {}
+ for k, v in loss_dict.items():
+ if k in ["discriminator_factor", "d_weight"]:
+ if type(v) == torch.Tensor:
+ autoencoder_logs["train/" + k] = v.cpu().item()
+ else:
+ autoencoder_logs["train/" + k] = v
+ else:
+ gathered_tensor = accelerator.gather(v)
+ autoencoder_logs["train/" + k] = gathered_tensor.mean().item()
+ del gathered_tensor
+
+ torch.cuda.empty_cache()
+ accelerator.backward(autoencoder_loss)
+
+ if config.training.max_grad_norm is not None and accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+
+ if (
+ accelerator.sync_gradients
+ and (global_step + 1) % config.experiment.log_grad_norm_every == 0
+ and accelerator.is_main_process
+ ):
+ log_grad_norm(model, accelerator, global_step + 1)
+
+ optimizer.zero_grad(set_to_none=True)
+
+ # Train discriminator.
+ discriminator_logs = defaultdict(float)
+ if (config.model.vq_model.finetune_decoder or model_type == "vibetoken") and accelerator.unwrap_model(loss_module).should_discriminator_be_trained(global_step):
+ discriminator_logs = defaultdict(float)
+ discriminator_loss, loss_dict_discriminator = loss_module(
+ images,
+ reconstructed_images,
+ extra_results_dict,
+ global_step=global_step,
+ mode="discriminator",
+ )
+
+ for k, v in loss_dict_discriminator.items():
+ if k in ["logits_real", "logits_fake"]:
+ if type(v) == torch.Tensor:
+ discriminator_logs["train/" + k] = v.cpu().item()
+ else:
+ discriminator_logs["train/" + k] = v
+ else:
+ gathered_tensor = accelerator.gather(v)
+ discriminator_logs["train/" + k] = gathered_tensor.mean().item()
+ del gathered_tensor
+
+ torch.cuda.empty_cache()
+ accelerator.backward(discriminator_loss)
+
+ if config.training.max_grad_norm is not None and accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(loss_module.parameters(), config.training.max_grad_norm)
+
+ discriminator_optimizer.step()
+ discriminator_lr_scheduler.step()
+
+ if (
+ accelerator.sync_gradients
+ and (global_step + 1) % config.experiment.log_grad_norm_every == 0
+ and accelerator.is_main_process
+ ):
+ log_grad_norm(loss_module, accelerator, global_step + 1)
+
+ discriminator_optimizer.zero_grad(set_to_none=True)
+
+ if accelerator.sync_gradients:
+ if config.training.use_ema:
+ ema_model.step(model.parameters())
+ batch_time_meter.update(time.time() - end)
+ end = time.time()
+
+ if (global_step + 1) % config.experiment.log_every == 0:
+ samples_per_second_per_gpu = (
+ config.training.gradient_accumulation_steps * config.training.per_gpu_batch_size / batch_time_meter.val
+ )
+
+ lr = lr_scheduler.get_last_lr()[0]
+ logger.info(
+ f"Data (t): {data_time_meter.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu "
+ f"Batch (t): {batch_time_meter.val:0.4f} "
+ f"LR: {lr:0.6f} "
+ f"Step: {global_step + 1} "
+ f"Total Loss: {autoencoder_logs['train/total_loss']:0.4f} "
+ f"Recon Loss: {autoencoder_logs['train/reconstruction_loss']:0.4f} "
+ )
+ logs = {
+ "lr": lr,
+ "lr/generator": lr,
+ "samples/sec/gpu": samples_per_second_per_gpu,
+ "time/data_time": data_time_meter.val,
+ "time/batch_time": batch_time_meter.val,
+ }
+ logs.update(autoencoder_logs)
+ logs.update(discriminator_logs)
+ accelerator.log(logs, step=global_step + 1)
+
+ del autoencoder_logs, discriminator_logs, logs
+ gc.collect()
+
+ batch_time_meter.reset()
+ data_time_meter.reset()
+
+ # Save model checkpoint.
+ if (global_step + 1) % config.experiment.save_every == 0:
+ save_path = save_checkpoint(
+ model, config.experiment.output_dir, accelerator, global_step + 1, logger=logger)
+ accelerator.wait_for_everyone()
+
+ # Generate images.
+ if (global_step + 1) % config.experiment.generate_every == 0:
+ if accelerator.is_main_process:
+ if config.training.get("use_ema", False):
+ ema_model.store(model.parameters())
+ ema_model.copy_to(model.parameters())
+
+ reconstruct_images(
+ model,
+ images[:config.training.num_generated_images],
+ fnames[:config.training.num_generated_images],
+ accelerator,
+ global_step + 1,
+ config.experiment.output_dir,
+ logger=logger,
+ config=config,
+ pretrained_tokenizer=pretrained_tokenizer
+ )
+
+ if config.training.get("use_ema", False):
+ ema_model.restore(model.parameters())
+
+ accelerator.wait_for_everyone()
+
+
+ # Evaluate reconstruction.
+ if eval_dataloader is not None and (global_step + 1) % config.experiment.eval_every == 0:
+ logger.info(f"Computing metrics on the validation set.")
+ if config.training.get("use_ema", False):
+ ema_model.store(model.parameters())
+ ema_model.copy_to(model.parameters())
+ eval_scores = eval_reconstruction(
+ config,
+ model,
+ eval_dataloader,
+ accelerator,
+ evaluator,
+ pretrained_tokenizer=pretrained_tokenizer
+ )
+ logger.info(
+ f"EMA EVALUATION "
+ f"Step: {global_step + 1} "
+ )
+ logger.info(pprint.pformat(eval_scores))
+ if accelerator.is_main_process:
+ eval_log = {f'ema_eval/'+k: v for k, v in eval_scores.items()}
+ accelerator.log(eval_log, step=global_step + 1)
+ if config.training.get("use_ema", False):
+ ema_model.restore(model.parameters())
+ else:
+ eval_scores = eval_reconstruction(
+ config,
+ model,
+ eval_dataloader,
+ accelerator,
+ evaluator,
+ pretrained_tokenizer=pretrained_tokenizer
+ )
+
+ logger.info(
+ f"Non-EMA EVALUATION "
+ f"Step: {global_step + 1} "
+ )
+ logger.info(pprint.pformat(eval_scores))
+ if accelerator.is_main_process:
+ eval_log = {f'eval/'+k: v for k, v in eval_scores.items()}
+ accelerator.log(eval_log, step=global_step + 1)
+
+ accelerator.wait_for_everyone()
+
+ global_step += 1
+
+ if global_step >= config.training.max_train_steps:
+ accelerator.print(
+ f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}"
+ )
+ break
+
+
+ return global_step
+
+
+@torch.no_grad()
+def eval_reconstruction(
+ config,
+ model,
+ eval_loader,
+ accelerator,
+ evaluator,
+ pretrained_tokenizer=None
+):
+ model.eval()
+ evaluator.reset_metrics()
+ local_model = accelerator.unwrap_model(model)
+
+ accelerator.wait_for_everyone()
+
+ for batch in eval_loader:
+ images = batch["image"].to(
+ accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
+ )
+
+ original_images = torch.clone(images)
+ additional_args = {}
+ if config.model.get("eval_with_attention", False):
+ additional_args["key_attention_mask"] = batch["attention_mask"].to(
+ accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
+ )
+ reconstructed_images, model_dict = local_model(images, **additional_args)
+
+ if pretrained_tokenizer is not None:
+ reconstructed_images = pretrained_tokenizer.decode(reconstructed_images.argmax(1))
+ reconstructed_images = torch.clamp(reconstructed_images, 0.0, 1.0)
+ reconstructed_images = torch.round(reconstructed_images * 255.0) / 255.0
+ original_images = torch.clamp(original_images, 0.0, 1.0)
+
+ if isinstance(model_dict, dict):
+ evaluator.update(original_images, reconstructed_images.squeeze(2), model_dict["min_encoding_indices"])
+ else:
+ evaluator.update(original_images, reconstructed_images.squeeze(2), None)
+
+ accelerator.wait_for_everyone()
+
+ local_results = evaluator.result()
+
+ if accelerator.num_processes > 1:
+ gathered_results = {}
+ for key, value in local_results.items():
+ if isinstance(value, (int, float)):
+ value_tensor = torch.tensor(value, device=accelerator.device)
+ gathered_values = accelerator.gather(value_tensor)
+ gathered_results[key] = gathered_values.mean().item()
+ else:
+ gathered_results[key] = value
+
+ accelerator.wait_for_everyone()
+ model.train()
+ return gathered_results
+ else:
+ model.train()
+ return local_results
+
+
+@torch.no_grad()
+def reconstruct_images(model, original_images, fnames, accelerator,
+ global_step, output_dir, logger, config=None,
+ pretrained_tokenizer=None):
+ logger.info("Reconstructing images...")
+ original_images = torch.clone(original_images)
+ _, _, height, width = original_images.shape
+ model.eval()
+ dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ dtype = torch.bfloat16
+
+ with torch.autocast("cuda", dtype=dtype, enabled=accelerator.mixed_precision != "no"):
+ enc_tokens, encoder_dict = accelerator.unwrap_model(model).encode(original_images)
+ reconstructed_images = accelerator.unwrap_model(model).decode(enc_tokens, height=height, width=width)
+ if pretrained_tokenizer is not None:
+ reconstructed_images = pretrained_tokenizer.decode(reconstructed_images.argmax(1))
+
+ images_for_saving, images_for_logging = make_viz_from_samples(
+ original_images,
+ reconstructed_images
+ )
+ if config.training.enable_wandb:
+ accelerator.get_tracker("wandb").log_images(
+ {f"Train Reconstruction": images_for_saving},
+ step=global_step
+ )
+ else:
+ accelerator.get_tracker("tensorboard").log_images(
+ {"Train Reconstruction": images_for_logging}, step=global_step
+ )
+ root = Path(output_dir) / "train_images"
+ os.makedirs(root, exist_ok=True)
+ for i,img in enumerate(images_for_saving):
+ filename = f"{global_step:08}_s-{i:03}-{fnames[i]}.png"
+ path = os.path.join(root, filename)
+ img.save(path)
+
+ model.train()
+
+
+def save_checkpoint(model, output_dir, accelerator, global_step, logger) -> Path:
+ save_path = Path(output_dir) / f"checkpoint-{global_step}"
+
+ state_dict = accelerator.get_state_dict(model)
+ if accelerator.is_main_process:
+ unwrapped_model = accelerator.unwrap_model(model)
+ unwrapped_model.save_pretrained_weight(
+ save_path / "unwrapped_model",
+ save_function=accelerator.save,
+ state_dict=state_dict,
+ )
+ json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+"))
+ logger.info(f"Saved state to {save_path}")
+
+ accelerator.save_state(save_path)
+ return save_path
+
+
+def load_checkpoint(checkpoint_path: Path, accelerator, logger, strict=True):
+ logger.info(f"Load checkpoint from {checkpoint_path}")
+
+ accelerator.load_state(checkpoint_path, strict=strict)
+
+ with open(checkpoint_path / "metadata.json", "r") as f:
+ global_step = int(json.load(f)["global_step"])
+
+ logger.info(f"Resuming at global_step {global_step}")
+ return global_step
+
+
+def log_grad_norm(model, accelerator, global_step):
+ for name, param in model.named_parameters():
+ if param.grad is not None:
+ grads = param.grad.detach().data
+ grad_norm = (grads.norm(p=2) / grads.numel()).item()
+ accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step)
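+
+
+# The helpers above compose into a training entry point roughly as follows (an
+# illustrative sketch; the real entry point also wraps the model, optimizers, and
+# dataloaders with accelerator.prepare and sets up logging/trackers):
+#
+#   config = get_config()
+#   model, ema_model, loss_module = create_model_and_loss_module(config, logger, accelerator)
+#   optimizer, d_optimizer = create_optimizer(config, logger, model, loss_module)
+#   lr_sched, d_lr_sched = create_lr_scheduler(config, logger, accelerator, optimizer, d_optimizer)
+#   train_dl, eval_dl = create_dataloader(config, logger, accelerator)
+#   evaluator = create_evaluator(config, logger, accelerator)
+#   global_step, first_epoch = auto_resume(config, logger, accelerator, ema_model,
+#                                          num_update_steps_per_epoch)
+#   for _epoch in range(first_epoch, num_train_epochs):
+#       global_step = train_one_epoch(config, logger, accelerator, model, ema_model,
+#                                     loss_module, optimizer, d_optimizer, lr_sched,
+#                                     d_lr_sched, train_dl, eval_dl, evaluator, global_step)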
diff --git a/utils/viz_utils.py b/utils/viz_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f7e6c28f2af996b6fd489d56c34dced895379ac
--- /dev/null
+++ b/utils/viz_utils.py
@@ -0,0 +1,83 @@
+"""Utils functions for visualization."""
+
+import torch
+import torchvision.transforms.functional as F
+from einops import rearrange
+from PIL import Image, ImageDraw, ImageFont
+
+def make_viz_from_samples(
+ original_images,
+ reconstructed_images
+):
+ """Generates visualization images from original images and reconstructed images.
+
+ Args:
+ original_images: A torch.Tensor, original images.
+ reconstructed_images: A torch.Tensor, reconstructed images.
+
+ Returns:
+ A tuple containing two lists - images_for_saving and images_for_logging.
+ """
+ reconstructed_images = torch.clamp(reconstructed_images, 0.0, 1.0)
+ reconstructed_images = reconstructed_images * 255.0
+ reconstructed_images = reconstructed_images.cpu()
+
+ original_images = torch.clamp(original_images, 0.0, 1.0)
+    original_images = original_images * 255.0  # avoid mutating the caller's tensor in place
+ original_images = original_images.cpu()
+
+ diff_img = torch.abs(original_images - reconstructed_images)
+ to_stack = [original_images, reconstructed_images, diff_img]
+
+ images_for_logging = rearrange(
+ torch.stack(to_stack),
+ "(l1 l2) b c h w -> b c (l1 h) (l2 w)",
+ l1=1).byte()
+ images_for_saving = [F.to_pil_image(image) for image in images_for_logging]
+
+ return images_for_saving, images_for_logging
+
+
+def make_viz_from_samples_generation(
+ generated_images,
+):
+ generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0
+ images_for_logging = rearrange(
+ generated,
+ "(l1 l2) c h w -> c (l1 h) (l2 w)",
+ l1=2)
+
+ images_for_logging = images_for_logging.cpu().byte()
+ images_for_saving = F.to_pil_image(images_for_logging)
+
+ return images_for_saving, images_for_logging
+
+
+def make_viz_from_samples_t2i_generation(
+ generated_images,
+ captions,
+):
+ generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0
+ images_for_logging = rearrange(
+ generated,
+ "(l1 l2) c h w -> c (l1 h) (l2 w)",
+ l1=2)
+
+ images_for_logging = images_for_logging.cpu().byte()
+ images_for_saving = F.to_pil_image(images_for_logging)
+
+ # Create a new image with space for captions
+ width, height = images_for_saving.size
+ caption_height = 20 * len(captions) + 10
+ new_height = height + caption_height
+ new_image = Image.new("RGB", (width, new_height), "black")
+ new_image.paste(images_for_saving, (0, 0))
+
+ # Adding captions below the image
+ draw = ImageDraw.Draw(new_image)
+ font = ImageFont.load_default()
+
+ for i, caption in enumerate(captions):
+ draw.text((10, height + 10 + i * 20), caption, fill="white", font=font)
+
+ return new_image, images_for_logging
\ No newline at end of file
diff --git a/vibetoken/__init__.py b/vibetoken/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee1543c09c194ce2ac9ed6dc4fd213506323206d
--- /dev/null
+++ b/vibetoken/__init__.py
@@ -0,0 +1,16 @@
+"""VibeToken - Minimal inference library for VibeToken image tokenizer."""
+
+from .tokenizer import (
+ VibeTokenTokenizer,
+ center_crop_to_multiple,
+ get_auto_patch_size,
+ auto_preprocess_image,
+)
+
+__version__ = "0.1.0"
+__all__ = [
+ "VibeTokenTokenizer",
+ "center_crop_to_multiple",
+ "get_auto_patch_size",
+ "auto_preprocess_image",
+]
diff --git a/vibetoken/modeling/__init__.py b/vibetoken/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab0334ac4e28c40f68a3261677d55ba214d9c23
--- /dev/null
+++ b/vibetoken/modeling/__init__.py
@@ -0,0 +1,7 @@
+"""VibeToken modeling components."""
+
+from .vibetoken import VibeToken
+from .encoder import ResolutionEncoder
+from .decoder import ResolutionDecoder
+
+__all__ = ["VibeToken", "ResolutionEncoder", "ResolutionDecoder"]
diff --git a/vibetoken/modeling/blocks.py b/vibetoken/modeling/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f5c70e95ccf7385802aeb80e78d3cc68ce61b4c
--- /dev/null
+++ b/vibetoken/modeling/blocks.py
@@ -0,0 +1,288 @@
+"""Transformer building blocks for VibeToken.
+
+Reference:
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
+ https://github.com/baofff/U-ViT/blob/main/libs/timm.py
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import OrderedDict
+from typing import Optional
+import einops
+
+
+# Determine attention mode based on available implementations
+if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+ ATTENTION_MODE = 'flash'
+else:
+ try:
+ import xformers
+ import xformers.ops
+ ATTENTION_MODE = 'xformers'
+ except ImportError:
+ ATTENTION_MODE = 'math'
+
+
+class Attention(nn.Module):
+ """Multi-head self-attention with support for flash/xformers/math backends."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_scale: Optional[float] = None,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, L, C = x.shape
+
+ qkv = self.qkv(x)
+ if ATTENTION_MODE == 'flash':
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'xformers':
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ x = xformers.ops.memory_efficient_attention(q, k, v)
+ x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
+ else: # math
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class ResidualAttentionBlock(nn.Module):
+ """Residual attention block with MLP."""
+
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ act_layer: type = nn.GELU,
+ norm_layer: type = nn.LayerNorm,
+ ):
+ super().__init__()
+ self.ln_1 = norm_layer(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.mlp_ratio = mlp_ratio
+
+ if mlp_ratio > 0:
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+
+ def attention(
+ self,
+ x: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ return self.attn(x, x, x, attn_mask=attention_mask, need_weights=False)[0]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ attn_output = self.attention(x=self.ln_1(x), attention_mask=attention_mask)
+ x = x + attn_output
+ if self.mlp_ratio > 0:
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """Drop paths (Stochastic Depth) per sample."""
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_()
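+    # Kept samples are rescaled by 1 / keep_prob so the expected activation is unchanged;
+    # e.g. with drop_prob=0.25 a sample is zeroed with probability 0.25 and survivors are
+    # scaled by 1 / 0.75 ~= 1.33.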
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample."""
+
+ def __init__(self, drop_prob: float = 0.0):
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+ """MLP block with GELU activation."""
+
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: type = nn.GELU,
+ drop: float = 0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class UViTBlock(nn.Module):
+ """U-ViT block with optional skip connection."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ qk_scale: Optional[float] = None,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ drop_path: float = 0.0,
+ act_layer: type = nn.GELU,
+ norm_layer: type = nn.LayerNorm,
+ skip: bool = False,
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias,
+ qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, skip, use_reentrant=False)
+ return self._forward(x, skip)
+
+ def _forward(self, x: torch.Tensor, skip: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.skip_linear is not None and skip is not None:
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class ResizableBlur(nn.Module):
+ """Anti-aliasing layer for downsampling with learnable blur kernel."""
+
+ def __init__(
+ self,
+ channels: int,
+ max_kernel_size: int = 9,
+ init_type: str = "gaussian",
+ ):
+ super().__init__()
+ self.C = channels
+ K = max_kernel_size
+ assert K % 2 == 1, "kernel must be odd"
+
+ if init_type == "gaussian":
+ ax = torch.arange(-(K // 2), K // 2 + 1)
+ g1d = torch.exp(-0.5 * (ax / (K / 6.0)) ** 2)
+ g2d = torch.outer(g1d, g1d)
+ kernel = g2d / g2d.sum()
+ elif init_type == "lanczos":
+ a = K // 2
+ x = torch.arange(-a, a + 1).float()
+ sinc = lambda t: torch.where(
+ t == 0, torch.ones_like(t),
+ torch.sin(torch.pi * t) / (torch.pi * t)
+ )
+ k1d = sinc(x) * sinc(x / a)
+ k2d = torch.outer(k1d, k1d)
+ kernel = k2d / k2d.sum()
+ else:
+ raise ValueError(f"Unknown init_type: {init_type}")
+
+ self.weight = nn.Parameter(kernel.unsqueeze(0).unsqueeze(0))
+
+ @staticmethod
+ def _resize_and_normalise(weight: torch.Tensor, k_size: int) -> torch.Tensor:
+ if weight.shape[-1] != k_size:
+ weight = F.interpolate(weight, size=(k_size, k_size), mode="bilinear", align_corners=True)
+ weight = weight / weight.sum(dim=(-2, -1), keepdim=True).clamp(min=1e-8)
+ return weight
+
+ def forward(self, x: torch.Tensor, input_size: tuple, target_size: tuple) -> torch.Tensor:
+ input_h, input_w = input_size
+ target_h, target_w = target_size
+
+ scale_h = input_h / target_h
+ scale_w = input_w / target_w
+
+ k_size_h = min(self.weight.shape[-1], max(1, int(2 * scale_h + 3)))
+ k_size_w = min(self.weight.shape[-1], max(1, int(2 * scale_w + 3)))
+ k_size_h = k_size_h if k_size_h % 2 == 1 else k_size_h + 1
+ k_size_w = k_size_w if k_size_w % 2 == 1 else k_size_w + 1
+ k_size = max(k_size_h, k_size_w)
+
+ stride_h = max(1, round(scale_h))
+ stride_w = max(1, round(scale_w))
+ pad_h = k_size_h // 2
+ pad_w = k_size_w // 2
+
+ k = self._resize_and_normalise(self.weight, k_size)
+ k = k.repeat(self.C, 1, 1, 1)
+
+ result = F.conv2d(x, weight=k, stride=(stride_h, stride_w),
+ padding=(pad_h, pad_w), groups=self.C)
+
+ if result.shape[2:] != target_size:
+ result = F.interpolate(result, size=target_size, mode='bilinear', align_corners=True)
+
+ return result
+
+
+def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
+ """Expand a single token to batch size."""
+ return token.unsqueeze(0).expand(batch_size, -1, -1)
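+
+
+if __name__ == "__main__":
+    # Minimal shape sanity check for the blocks above (illustrative sizes only).
+    _x = torch.randn(2, 16, 512)
+    assert Attention(dim=512, num_heads=8)(_x).shape == (2, 16, 512)
+    assert UViTBlock(dim=512, num_heads=8, mlp_ratio=4.0)(_x).shape == (2, 16, 512)
+    assert _expand_token(torch.zeros(1, 512), batch_size=2).shape == (2, 1, 512)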
diff --git a/vibetoken/modeling/decoder.py b/vibetoken/modeling/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8e265c648fcefdd89702a5ff102cbc5cf0c993b
--- /dev/null
+++ b/vibetoken/modeling/decoder.py
@@ -0,0 +1,206 @@
+"""Resolution-aware decoder for VibeToken.
+
+Vision Transformer-based decoder with flexible output resolutions.
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple
+from einops.layers.torch import Rearrange
+
+from .blocks import ResidualAttentionBlock, ResizableBlur, _expand_token
+from .embeddings import FuzzyEmbedding
+
+
+class ResolutionDecoder(nn.Module):
+ """Vision Transformer decoder with flexible resolution support.
+
+ Decodes latent tokens back to images with support for variable
+ output resolutions and patch sizes.
+ """
+
+ # Model size configurations
+ MODEL_CONFIGS = {
+ "small": {"width": 512, "num_layers": 8, "num_heads": 8},
+ "base": {"width": 768, "num_layers": 12, "num_heads": 12},
+ "large": {"width": 1024, "num_layers": 24, "num_heads": 16},
+ }
+
+ def __init__(self, config):
+ """Initialize ResolutionDecoder.
+
+ Args:
+ config: OmegaConf config with model parameters.
+ """
+ super().__init__()
+ self.config = config
+
+ # Extract config values
+ vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model
+ self.image_size = getattr(config.dataset.preprocessing, 'crop_size', 512) if hasattr(config, 'dataset') else 512
+ self.patch_size = getattr(vq_config, 'vit_dec_patch_size', 32)
+ self.model_size = getattr(vq_config, 'vit_dec_model_size', 'large')
+ self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256)
+ self.token_size = getattr(vq_config, 'token_size', 256)
+ self.is_legacy = getattr(vq_config, 'is_legacy', False)
+
+ if self.is_legacy:
+ raise NotImplementedError("Legacy mode is not supported in this inference-only version")
+
+ # Get model dimensions
+ model_cfg = self.MODEL_CONFIGS[self.model_size]
+ self.width = model_cfg["width"]
+ self.num_layers = model_cfg["num_layers"]
+ self.num_heads = model_cfg["num_heads"]
+
+ # Input projection
+ self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)
+
+ # Embeddings
+ scale = self.width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
+ self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
+ self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
+ self.latent_token_positional_embedding = nn.Parameter(
+ scale * torch.randn(self.num_latent_tokens, self.width)
+ )
+ self.ln_pre = nn.LayerNorm(self.width)
+
+ # Transformer layers
+ self.transformer = nn.ModuleList([
+ ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
+ for _ in range(self.num_layers)
+ ])
+
+ # Output projection
+ self.ln_post = nn.LayerNorm(self.width)
+ self.ffn = nn.Conv2d(
+ self.width, self.patch_size * self.patch_size * 3,
+ kernel_size=1, padding=0, bias=True
+ )
+ self.rearrange = Rearrange(
+ 'b (p1 p2 c) h w -> b c (h p1) (w p2)',
+ p1=self.patch_size, p2=self.patch_size
+ )
+ self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos")
+ self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)
+
+ def _select_patch_size(self, height: int, width: int) -> int:
+ """Select appropriate patch size based on target resolution.
+
+ Args:
+ height: Target image height.
+ width: Target image width.
+
+ Returns:
+ Selected patch size.
+ """
+ total_pixels = height * width
+ min_patches, max_patches = 256, 1024
+
+ possible_sizes = []
+ for ps in [8, 16, 32]:
+ grid_h = height // ps
+ grid_w = width // ps
+ total_patches = grid_h * grid_w
+ if min_patches <= total_patches <= max_patches:
+ possible_sizes.append(ps)
+
+ if not possible_sizes:
+ # Find closest to target range
+ patch_counts = []
+ for ps in [8, 16, 32]:
+ grid_h = height // ps
+ grid_w = width // ps
+ patch_counts.append((ps, grid_h * grid_w))
+ patch_counts.sort(key=lambda x: min(abs(x[1] - min_patches), abs(x[1] - max_patches)))
+ return patch_counts[0][0]
+
+ return possible_sizes[0]
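+
+    # Example: a 1024x1024 target admits only patch size 32 (a 32x32 grid, 1024 patches,
+    # exactly the upper bound), while a 256x256 target admits both 8 (32x32 = 1024) and
+    # 16 (16x16 = 256), so the smallest qualifying size, 8, is returned.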
+
+ def forward(
+ self,
+ z_quantized: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ decode_patch_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ """Decode latent tokens to images.
+
+ Args:
+ z_quantized: Quantized latent features (B, C, H, W).
+ attention_mask: Optional attention mask.
+ height: Target image height.
+ width: Target image width.
+ decode_patch_size: Optional custom patch size for decoding.
+
+ Returns:
+ Decoded images (B, 3, height, width), values in [0, 1].
+ """
+ N, C, H, W = z_quantized.shape
+
+ # Reshape and project input
+ x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1) # (N, seq_len, C*H)
+ x = self.decoder_embed(x)
+
+ batchsize, seq_len, _ = x.shape
+
+ # Default output size
+ if height is None:
+ height = self.image_size
+ if width is None:
+ width = self.image_size
+
+ # Determine patch size
+ if decode_patch_size is None:
+ selected_patch_size = self._select_patch_size(height, width)
+ else:
+ selected_patch_size = decode_patch_size
+
+ if isinstance(selected_patch_size, int):
+ selected_patch_size = (selected_patch_size, selected_patch_size)
+
+ grid_height = height // selected_patch_size[0]
+ grid_width = width // selected_patch_size[1]
+
+ # Create mask tokens for output positions
+ mask_tokens = self.mask_token.repeat(batchsize, grid_height * grid_width, 1).to(x.dtype)
+ mask_tokens = torch.cat([
+ _expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
+ mask_tokens
+ ], dim=1)
+
+ # Add positional embeddings
+ mask_tokens = mask_tokens + self.positional_embedding(
+ grid_height, grid_width, train=False
+ ).to(mask_tokens.dtype)
+
+ x = x + self.latent_token_positional_embedding[:seq_len]
+ x = torch.cat([mask_tokens, x], dim=1)
+
+ # Pre-norm and reshape for transformer
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # (seq_len, B, width)
+
+ # Apply transformer layers
+ for layer in self.transformer:
+ x = layer(x, attention_mask=None)
+
+ x = x.permute(1, 0, 2) # (B, seq_len, width)
+
+ # Extract output tokens (excluding class token and latent tokens)
+ x = x[:, 1:1 + grid_height * grid_width]
+ x = self.ln_post(x)
+
+ # Reshape to spatial format and project to pixels
+ x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width)
+ x = self.ffn(x.contiguous())
+ x = self.rearrange(x)
+
+ # Downsample to target resolution
+ _, _, org_h, org_w = x.shape
+ x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width))
+ x = self.conv_out(x)
+
+ return x
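+
+
+# Illustrative shape walk-through (the sizes are assumptions, not fixed by the model):
+# given z_quantized of shape (B, 256, 1, 64), i.e. 64 latent tokens of size 256, and a
+# 1024x1024 target, _select_patch_size picks 32, so the transformer attends over
+# 1 class token + 32*32 mask tokens + 64 latent tokens; the 1024 output tokens are
+# projected to 32*32*3 channels per grid cell, unpatchified to (B, 3, 1024, 1024), and
+# finished with the anti-aliased blur and a 3x3 output convolution.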
diff --git a/vibetoken/modeling/embeddings.py b/vibetoken/modeling/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bf3c89fb5ccf4e7a648fb7ef0426475c82ac424
--- /dev/null
+++ b/vibetoken/modeling/embeddings.py
@@ -0,0 +1,193 @@
+"""Embedding modules for VibeToken.
+
+Includes positional embeddings with fuzzy interpolation for variable resolutions.
+"""
+
+import math
+from typing import Tuple, Any
+import collections.abc
+from itertools import repeat
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+import numpy as np
+from einops import rearrange
+from torch import Tensor, vmap
+
+
+def to_2tuple(x: Any) -> Tuple:
+ """Convert input to 2-tuple."""
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, 2))
+
+
+class FuzzyEmbedding(nn.Module):
+ """Fuzzy positional embedding with bilinear interpolation.
+
+ Supports variable-resolution inputs by interpolating from a base
+ positional embedding grid.
+ """
+
+ def __init__(
+ self,
+ grid_size: int,
+ scale: float,
+ width: int,
+ apply_fuzzy: bool = False,
+ ):
+ """Initialize FuzzyEmbedding.
+
+ Args:
+ grid_size: Base grid size (must be 1024).
+ scale: Initialization scale for embeddings.
+ width: Embedding dimension.
+ apply_fuzzy: Whether to add noise during training.
+ """
+ super().__init__()
+ assert grid_size == 1024, "grid_size must be 1024"
+
+ self.grid_size = grid_size
+ self.scale = scale
+ self.width = width
+ self.apply_fuzzy = apply_fuzzy
+
+ self.positional_embedding = nn.Parameter(scale * torch.randn(grid_size, width))
+ self.class_positional_embedding = nn.Parameter(scale * torch.randn(1, width))
+
+ @torch.amp.autocast('cuda', enabled=False)
+ def forward(
+ self,
+ grid_height: int,
+ grid_width: int,
+ train: bool = False,
+ dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+ """Compute positional embeddings for given grid size.
+
+ Args:
+ grid_height: Target grid height.
+ grid_width: Target grid width.
+ train: Whether in training mode (affects fuzzy noise).
+ dtype: Output dtype.
+
+ Returns:
+ Positional embeddings of shape (1 + grid_height * grid_width, width).
+ """
+ device = self.positional_embedding.device
+
+ meshx, meshy = torch.meshgrid(
+ torch.arange(grid_height, device=device),
+ torch.arange(grid_width, device=device),
+ indexing='ij'
+ )
+ meshx = meshx.to(dtype)
+ meshy = meshy.to(dtype)
+
+ # Normalize coordinates to [-1, 1]
+ meshx = 2 * (meshx / max(grid_height - 1, 1)) - 1
+ meshy = 2 * (meshy / max(grid_width - 1, 1)) - 1
+
+ if self.apply_fuzzy and train:
+ noise_x = torch.rand_like(meshx) * 0.0008 - 0.0004
+ noise_y = torch.rand_like(meshy) * 0.0008 - 0.0004
+ meshx = meshx + noise_x
+ meshy = meshy + noise_y
+
+ grid = torch.stack((meshy, meshx), 2).unsqueeze(0)
+
+ # Reshape positional embedding to 2D grid
+ base_size = int(math.sqrt(self.grid_size))
+ positional_embedding = einops.rearrange(
+ self.positional_embedding, "(h w) d -> d h w", h=base_size, w=base_size
+ )
+ positional_embedding = positional_embedding.to(dtype).unsqueeze(0)
+
+ # Bilinear interpolation
+ fuzzy_embedding = F.grid_sample(positional_embedding, grid, align_corners=False)
+ fuzzy_embedding = fuzzy_embedding.to(dtype)
+ fuzzy_embedding = einops.rearrange(fuzzy_embedding, "b d h w -> b (h w) d").squeeze(0)
+
+ final_embedding = torch.cat([self.class_positional_embedding, fuzzy_embedding], dim=0)
+ return final_embedding
+
+
+class FlexiPatchEmbed:
+ """Flexible patch embedding utilities for variable patch sizes.
+
+ Based on FlexiViT: https://arxiv.org/abs/2212.08013
+ """
+
+ def __init__(self, base_patch_size: int, width: int):
+ """Initialize FlexiPatchEmbed.
+
+ Args:
+ base_patch_size: Original patch size of the model.
+ width: Embedding dimension.
+ """
+ self.base_patch_size = base_patch_size
+ self.width = width
+ self.pinvs = {}
+
+ def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
+ """Bilinear resize of 2D tensor."""
+ x_resized = F.interpolate(
+ x[None, None, ...],
+ shape,
+ mode="bilinear",
+ antialias=False,
+ )
+ return x_resized[0, 0, ...]
+
+ def _calculate_pinv(
+ self,
+ old_shape: Tuple[int, int],
+ new_shape: Tuple[int, int],
+ device: torch.device,
+ ) -> Tensor:
+ """Calculate pseudo-inverse of resize matrix."""
+ mat = []
+ for i in range(np.prod(old_shape)):
+ basis_vec = torch.zeros(old_shape, device=device)
+ basis_vec[np.unravel_index(i, old_shape)] = 1.0
+ mat.append(self._resize(basis_vec, new_shape).reshape(-1))
+ resize_matrix = torch.stack(mat)
+ return torch.linalg.pinv(resize_matrix)
+
+ def resize_patch_embed(
+ self,
+ patch_embed: Tensor,
+ base_patch_size: Tuple[int, int],
+ new_patch_size: Tuple[int, int],
+ ) -> Tensor:
+ """Resize patch embedding kernel to new patch size.
+
+ Args:
+ patch_embed: Original patch embedding weight (out_ch, in_ch, H, W).
+ base_patch_size: Original patch size.
+ new_patch_size: Target patch size.
+
+ Returns:
+ Resized patch embedding weight.
+ """
+ if base_patch_size == new_patch_size:
+ return patch_embed
+
+ if new_patch_size not in self.pinvs:
+ self.pinvs[new_patch_size] = self._calculate_pinv(
+ base_patch_size, new_patch_size, device=patch_embed.device
+ )
+ pinv = self.pinvs[new_patch_size]
+
+ def resample_patch_embed(patch_embed: Tensor):
+ h, w = new_patch_size
+ original_dtype = patch_embed.dtype
+ patch_embed_float = patch_embed.float()
+ resampled_kernel = pinv @ patch_embed_float.reshape(-1)
+ resampled_kernel = resampled_kernel.to(original_dtype)
+ return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w)
+
+ v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)
+ return v_resample_patch_embed(patch_embed)
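+
+
+# Following FlexiViT's PI-resize, the pseudo-inverse of the bilinear resize operator is
+# cached per target patch size and applied channel-wise via vmap, so the resized kernel
+# responds to a resized patch (approximately) as the original kernel responds to the
+# original patch.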
diff --git a/vibetoken/modeling/encoder.py b/vibetoken/modeling/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d78db0b765557a0d1c97595369c22a4b557bfe
--- /dev/null
+++ b/vibetoken/modeling/encoder.py
@@ -0,0 +1,234 @@
+"""Resolution-aware encoder for VibeToken.
+
+Vision Transformer-based encoder with flexible patch sizes for variable-resolution inputs.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple
+from einops import rearrange
+from torch import Tensor, vmap
+import numpy as np
+
+from .blocks import ResidualAttentionBlock, _expand_token
+from .embeddings import FuzzyEmbedding, to_2tuple
+
+
+class ResolutionEncoder(nn.Module):
+ """Vision Transformer encoder with flexible resolution support.
+
+ Encodes images into latent tokens using a ViT architecture with
+ support for variable input resolutions and patch sizes.
+ """
+
+ # Model size configurations
+ MODEL_CONFIGS = {
+ "small": {"width": 512, "num_layers": 8, "num_heads": 8},
+ "base": {"width": 768, "num_layers": 12, "num_heads": 12},
+ "large": {"width": 1024, "num_layers": 24, "num_heads": 16},
+ }
+
+ def __init__(self, config):
+ """Initialize ResolutionEncoder.
+
+ Args:
+ config: OmegaConf config with model parameters.
+ """
+ super().__init__()
+ self.config = config
+
+ # Extract config values
+ vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model
+ self.patch_size = getattr(vq_config, 'vit_enc_patch_size', 32)
+ self.model_size = getattr(vq_config, 'vit_enc_model_size', 'large')
+ self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256)
+ self.token_size = getattr(vq_config, 'token_size', 256)
+ self.is_legacy = getattr(vq_config, 'is_legacy', False)
+
+ # Handle VAE mode (doubles token size for mean+std)
+ quantize_mode = getattr(vq_config, 'quantize_mode', 'vq')
+ if quantize_mode == "vae":
+ self.token_size = self.token_size * 2
+
+ # Get model dimensions from config
+ model_cfg = self.MODEL_CONFIGS[self.model_size]
+ self.width = model_cfg["width"]
+ self.num_layers = model_cfg["num_layers"]
+ self.num_heads = model_cfg["num_heads"]
+
+ # Patch embedding
+ self.patch_embed = nn.Conv2d(
+ in_channels=3, out_channels=self.width,
+ kernel_size=self.patch_size, stride=self.patch_size, bias=True
+ )
+
+ # Embeddings
+ scale = self.width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
+ self.positional_embedding = FuzzyEmbedding(1024, scale, self.width)
+ self.latent_token_positional_embedding = nn.Parameter(
+ scale * torch.randn(self.num_latent_tokens, self.width)
+ )
+ self.ln_pre = nn.LayerNorm(self.width)
+
+ # Transformer layers
+ self.transformer = nn.ModuleList([
+ ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
+ for _ in range(self.num_layers)
+ ])
+
+ # Output projection
+ self.ln_post = nn.LayerNorm(self.width)
+ self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
+
+ # Cache for pseudo-inverse matrices
+ self.pinvs = {}
+
+ def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
+ """Bilinear resize of 2D tensor."""
+ x_resized = F.interpolate(
+ x[None, None, ...], shape, mode="bilinear", antialias=False
+ )
+ return x_resized[0, 0, ...]
+
+ def _calculate_pinv(
+ self,
+ old_shape: Tuple[int, int],
+ new_shape: Tuple[int, int],
+ device: torch.device,
+ ) -> Tensor:
+ """Calculate pseudo-inverse of resize matrix for FlexiViT."""
+ mat = []
+ for i in range(np.prod(old_shape)):
+ basis_vec = torch.zeros(old_shape, device=device)
+ basis_vec[np.unravel_index(i, old_shape)] = 1.0
+ mat.append(self._resize(basis_vec, new_shape).reshape(-1))
+ resize_matrix = torch.stack(mat)
+ return torch.linalg.pinv(resize_matrix)
+
+ def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]) -> Tensor:
+ """Resize patch embedding kernel to new patch size (FlexiViT).
+
+ Args:
+ patch_embed: Original weight tensor (out_ch, in_ch, H, W).
+ new_patch_size: Target (H, W) patch size.
+
+ Returns:
+ Resized weight tensor.
+ """
+ base_size = to_2tuple(self.patch_size)
+ if base_size == new_patch_size:
+ return patch_embed
+
+ if new_patch_size not in self.pinvs:
+ self.pinvs[new_patch_size] = self._calculate_pinv(
+ base_size, new_patch_size, device=patch_embed.device
+ )
+ pinv = self.pinvs[new_patch_size]
+
+ def resample_patch_embed(pe: Tensor) -> Tensor:
+ h, w = new_patch_size
+ original_dtype = pe.dtype
+ resampled = pinv @ pe.float().reshape(-1)
+ return rearrange(resampled.to(original_dtype), "(h w) -> h w", h=h, w=w)
+
+ v_resample = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)
+ return v_resample(patch_embed)
+
+ def apply_flexivit_patch_embed(self, x: Tensor, target_patch_size: Tuple[int, int]) -> Tensor:
+ """Apply patch embedding with flexible patch size.
+
+ Args:
+ x: Input image tensor (B, 3, H, W).
+ target_patch_size: Target patch size (H, W).
+
+ Returns:
+ Patch embeddings (B, C, grid_H, grid_W).
+ """
+ patch_size = to_2tuple(target_patch_size)
+
+ if patch_size == to_2tuple(self.patch_size):
+ weight = self.patch_embed.weight
+ else:
+ weight = self.resize_patch_embed(self.patch_embed.weight, patch_size)
+
+ return F.conv2d(x, weight, bias=self.patch_embed.bias, stride=patch_size)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ latent_tokens: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encode_patch_size: Optional[Tuple[int, int]] = None,
+ ) -> torch.Tensor:
+ """Encode images to latent tokens.
+
+ Args:
+ pixel_values: Input images (B, 3, H, W), values in [0, 1].
+ latent_tokens: Learnable latent tokens (num_latent, width).
+ attention_mask: Optional attention mask.
+ encode_patch_size: Optional custom patch size for encoding.
+
+ Returns:
+ Encoded latent features (B, token_size, 1, num_latent).
+ """
+ batch_size, _, H, W = pixel_values.shape
+
+ # Determine patch size
+ if encode_patch_size is None:
+ target_patch_size = (self.patch_size, self.patch_size)
+ elif isinstance(encode_patch_size, int):
+ target_patch_size = (encode_patch_size, encode_patch_size)
+ else:
+ target_patch_size = encode_patch_size
+
+ # Apply flexible patch embedding
+ x = self.apply_flexivit_patch_embed(pixel_values, target_patch_size)
+
+ # Flatten spatial dimensions
+ x = x.reshape(x.shape[0], x.shape[1], -1)
+ x = x.permute(0, 2, 1) # (B, num_patches, width)
+
+ # Add class embedding
+ x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
+
+ # Compute grid dimensions
+ grid_height = H // target_patch_size[0]
+ grid_width = W // target_patch_size[1]
+
+ # Add positional embeddings to latent tokens
+ num_latent = latent_tokens.shape[0]
+ latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
+ latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)[:num_latent]
+
+ # Add positional embeddings to image patches
+ x = x + self.positional_embedding(grid_height, grid_width, train=False, dtype=x.dtype)
+
+ # Concatenate image patches and latent tokens
+ x = torch.cat([x, latent_tokens], dim=1)
+
+ # Pre-norm and reshape for transformer
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # (seq_len, B, width)
+
+ # Apply transformer layers
+ for layer in self.transformer:
+ x = layer(x, attention_mask=None)
+
+ x = x.permute(1, 0, 2) # (B, seq_len, width)
+
+ # Extract latent tokens
+ latent_tokens = x[:, 1 + grid_height * grid_width:]
+ latent_tokens = self.ln_post(latent_tokens)
+
+ # Reshape and project to token size
+ if self.is_legacy:
+ latent_tokens = latent_tokens.reshape(batch_size, self.width, num_latent, 1)
+ else:
+ latent_tokens = latent_tokens.reshape(batch_size, num_latent, self.width, 1).permute(0, 2, 1, 3)
+
+ latent_tokens = self.conv_out(latent_tokens)
+ latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, num_latent)
+
+ return latent_tokens
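+
+
+# Shape sketch (illustrative): a 512x512 input with patch size 16 gives a 32x32 patch grid
+# (1024 patch tokens plus one class token); with 256 latent queries the encoder returns a
+# tensor of shape (B, token_size, 1, 256), which is subsequently quantized into discrete tokens.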
diff --git a/vibetoken/modeling/vibetoken.py b/vibetoken/modeling/vibetoken.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60b76472dba6dc78659bee63baead4625c9ea99
--- /dev/null
+++ b/vibetoken/modeling/vibetoken.py
@@ -0,0 +1,273 @@
+"""VibeToken model - Main model class for image tokenization.
+
+Combines encoder, quantizer, and decoder into a unified model for
+encoding images to discrete tokens and decoding back to images.
+"""
+
+import torch
+import torch.nn as nn
+from typing import Dict, Any, Optional, Tuple, Union
+from einops import rearrange
+from omegaconf import OmegaConf
+
+from .encoder import ResolutionEncoder
+from .decoder import ResolutionDecoder
+from ..quantizer import VectorQuantizer, VectorQuantizerMVQ
+
+
+class VibeToken(nn.Module):
+ """VibeToken image tokenizer model.
+
+ A Vision Transformer-based image tokenizer that encodes images into
+ discrete tokens and decodes them back to images. Supports multiple
+ quantization modes (VQ, MVQ, VAE) and variable resolutions.
+ """
+
+ def __init__(self, config: Union[dict, Any]):
+ """Initialize VibeToken model.
+
+ Args:
+ config: Configuration dict or OmegaConf object with model parameters.
+ """
+ super().__init__()
+
+ if isinstance(config, dict):
+ config = OmegaConf.create(config)
+
+ self.config = config
+
+ # Get model config
+ vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model
+
+ # Quantization mode
+ self.quantize_mode = getattr(vq_config, 'quantize_mode', 'vq')
+ if self.quantize_mode not in ["vq", "vae", "mvq"]:
+ raise ValueError(f"Unsupported quantize mode: {self.quantize_mode}")
+
+ # Build encoder and decoder
+ self.encoder = ResolutionEncoder(config)
+ self.decoder = ResolutionDecoder(config)
+
+ # Latent tokens (learnable queries for encoder)
+ self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256)
+ scale = self.encoder.width ** -0.5
+ self.latent_tokens = nn.Parameter(
+ scale * torch.randn(self.num_latent_tokens, self.encoder.width)
+ )
+
+ # Build quantizer based on mode
+ if self.quantize_mode == "vq":
+ self.quantize = VectorQuantizer(
+ codebook_size=getattr(vq_config, 'codebook_size', 32768),
+ token_size=getattr(vq_config, 'token_size', 256),
+ commitment_cost=getattr(vq_config, 'commitment_cost', 0.25),
+ use_l2_norm=getattr(vq_config, 'use_l2_norm', False),
+ )
+ elif self.quantize_mode == "mvq":
+ self.quantize = VectorQuantizerMVQ(
+ codebook_size=getattr(vq_config, 'codebook_size', 32768),
+ token_size=getattr(vq_config, 'token_size', 256),
+ commitment_cost=getattr(vq_config, 'commitment_cost', 0.25),
+ use_l2_norm=getattr(vq_config, 'use_l2_norm', False),
+ num_codebooks=getattr(vq_config, 'num_codebooks', 8),
+ )
+ elif self.quantize_mode == "vae":
+ # VAE mode uses DiagonalGaussianDistribution directly in forward
+ from ..quantizer.vector_quantizer import DiagonalGaussianDistribution
+ self.quantize = DiagonalGaussianDistribution
+
+ # Initialize weights
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize weights for module."""
+ if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ nn.init.zeros_(module.bias)
+ nn.init.ones_(module.weight)
+
+ def encode(
+ self,
+ x: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encode_patch_size: Optional[Tuple[int, int]] = None,
+ length: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Encode images to quantized latent tokens.
+
+ Args:
+ x: Input images (B, 3, H, W), values in [0, 1].
+ attention_mask: Optional attention mask.
+ encode_patch_size: Optional custom patch size for encoding.
+ length: Optional token length limit.
+
+ Returns:
+ z_quantized: Quantized latent features.
+ result_dict: Dictionary with token indices and optional losses.
+ """
+ # Select latent tokens
+ if length is not None:
+ latent_tokens = self.latent_tokens[:length + 1]
+ else:
+ latent_tokens = self.latent_tokens
+
+ # Encode through ViT encoder
+ z = self.encoder(
+ pixel_values=x,
+ latent_tokens=latent_tokens,
+ attention_mask=attention_mask,
+ encode_patch_size=encode_patch_size,
+ )
+
+ # Quantize
+ if self.quantize_mode in ["vq", "mvq"]:
+ z_quantized, result_dict = self.quantize(z)
+ elif self.quantize_mode == "vae":
+ posteriors = self.quantize(z)
+ z_quantized = posteriors.sample()
+ result_dict = {"posteriors": posteriors}
+
+ return z_quantized, result_dict
+
+ def decode(
+ self,
+ z_quantized: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ decode_patch_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ """Decode quantized features to images.
+
+ Args:
+ z_quantized: Quantized latent features.
+ attention_mask: Optional attention mask.
+ height: Target image height.
+ width: Target image width.
+ decode_patch_size: Optional custom patch size for decoding.
+
+ Returns:
+ Decoded images (B, 3, height, width), values approximately in [0, 1].
+ """
+ z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
+ decoded = self.decoder(
+ z_quantized,
+ attention_mask=attention_mask,
+ height=height,
+ width=width,
+ decode_patch_size=decode_patch_size,
+ )
+ return decoded
+
+ def decode_tokens(
+ self,
+ tokens: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ decode_patch_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ """Decode discrete tokens directly to images.
+
+ This is the main interface used by generators.
+
+ Args:
+ tokens: Token indices. Shape depends on quantize_mode:
+ - VQ: (B, seq_len) or (B, 1, seq_len)
+ - MVQ: (B, num_codebooks, seq_len) or (B, seq_len, 1)
+ - VAE: Continuous latent (B, C, H, W)
+ attention_mask: Optional attention mask.
+ height: Target image height.
+ width: Target image width.
+ decode_patch_size: Optional custom patch size for decoding.
+
+ Returns:
+ Decoded images (B, 3, height, width), values approximately in [0, 1].
+ """
+ if self.quantize_mode == "vq":
+ tokens = tokens.squeeze(1)
+ batch, seq_len = tokens.shape
+ z_quantized = self.quantize.get_codebook_entry(tokens.reshape(-1))
+ z_quantized = z_quantized.reshape(batch, 1, seq_len, -1)
+ z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
+ elif self.quantize_mode == "mvq":
+ z_quantized = self.quantize.get_codebook_entry(tokens)
+ elif self.quantize_mode == "vae":
+ z_quantized = tokens
+
+ z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
+ decoded = self.decode(
+ z_quantized,
+ attention_mask=attention_mask,
+ height=height,
+ width=width,
+ decode_patch_size=decode_patch_size,
+ )
+ return decoded
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ key_attention_mask: Optional[torch.Tensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Full encode-decode forward pass.
+
+ Args:
+ x: Input images (B, 3, H, W), values in [0, 1].
+ key_attention_mask: Optional attention mask.
+ height: Target output height (default: same as input).
+ width: Target output width (default: same as input).
+
+ Returns:
+ decoded: Reconstructed images.
+ result_dict: Dictionary with token indices and losses.
+ """
+ if height is None:
+ _, _, height, width = x.shape
+
+ z_quantized, result_dict = self.encode(x, attention_mask=key_attention_mask)
+ z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype)
+ decoded = self.decode(z_quantized, attention_mask=key_attention_mask, height=height, width=width)
+
+ return decoded, result_dict
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ config_path: str,
+ checkpoint_path: str,
+ device: str = "cuda",
+ dtype: Optional[torch.dtype] = None,
+ ) -> "VibeToken":
+ """Load pretrained model from config and checkpoint.
+
+ Args:
+ config_path: Path to YAML config file.
+ checkpoint_path: Path to model checkpoint (.pt or .bin).
+ device: Device to load model on.
+ dtype: Optional dtype for model parameters.
+
+ Returns:
+ Loaded VibeToken model.
+ """
+ config = OmegaConf.load(config_path)
+ model = cls(config)
+
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
+ model.load_state_dict(state_dict)
+
+ model.eval()
+ model.requires_grad_(False)
+ model.to(device)
+
+ if dtype is not None:
+ model.to(dtype)
+
+ return model
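+
+
+# Minimal usage sketch (config and checkpoint paths are illustrative):
+#
+#   model = VibeToken.from_pretrained("configs/vibetoken_ll.yaml", "vibetoken_ll.bin", device="cuda")
+#   images = torch.rand(1, 3, 512, 512, device="cuda")   # values in [0, 1]
+#   recon, result = model(images)                         # full encode-decode pass
+#   tokens = result["min_encoding_indices"]               # discrete indices (VQ/MVQ modes)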
diff --git a/vibetoken/quantizer/__init__.py b/vibetoken/quantizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2044a6a5301a12a660da4e20c79a63a89bf5c0ed
--- /dev/null
+++ b/vibetoken/quantizer/__init__.py
@@ -0,0 +1,6 @@
+"""Quantizer modules for VibeToken."""
+
+from .vector_quantizer import VectorQuantizer
+from .mvq import VectorQuantizerMVQ
+
+__all__ = ["VectorQuantizer", "VectorQuantizerMVQ"]
diff --git a/vibetoken/quantizer/mvq.py b/vibetoken/quantizer/mvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c8033572653783cfc25b4174b6228069b365d8a
--- /dev/null
+++ b/vibetoken/quantizer/mvq.py
@@ -0,0 +1,183 @@
+"""Multi-codebook Vector Quantizer (MVQ) for VibeToken.
+
+Uses multiple independent codebooks for richer discrete representations.
+"""
+
+import torch
+import torch.nn as nn
+from typing import Tuple, Dict, Any
+
+from .vector_quantizer import VectorQuantizer
+
+
+class VectorQuantizerMVQ(nn.Module):
+ """Multi-codebook Vector Quantizer.
+
+ Splits the latent representation into multiple parts, each quantized
+ by an independent codebook. This allows for more expressive discrete
+ representations.
+ """
+
+ def __init__(
+ self,
+ codebook_size: int,
+ token_size: int,
+ commitment_cost: float = 0.25,
+ use_l2_norm: bool = False,
+ num_codebooks: int = 8,
+ ):
+ """Initialize MVQ.
+
+ Args:
+ codebook_size: Total codebook size (divided among codebooks).
+ token_size: Total token dimension (divided among codebooks).
+ commitment_cost: Weight for commitment loss.
+ use_l2_norm: Whether to L2-normalize embeddings.
+ num_codebooks: Number of independent codebooks.
+ """
+ super().__init__()
+ self.num_codebooks = num_codebooks
+ self.codebooks = nn.ModuleList()
+
+ for _ in range(num_codebooks):
+ codebook = VectorQuantizer(
+ codebook_size=codebook_size // num_codebooks,
+ token_size=token_size // num_codebooks,
+ commitment_cost=commitment_cost,
+ use_l2_norm=use_l2_norm,
+ )
+ self.codebooks.append(codebook)
+
+ def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Quantize features using multiple codebooks.
+
+ Args:
+ features: Input features of shape (B, C, H, W).
+
+ Returns:
+ z_quantized: Quantized features of shape (B, C, H, W).
+ result_dict: Dictionary with losses and indices.
+ """
+ latent_features = []
+ all_result_dicts = []
+ chunk_size = features.shape[1] // self.num_codebooks
+        split_features = features.split(chunk_size, dim=1)
+
+        for i, codebook in enumerate(self.codebooks):
+            latent_feature, result_dict = codebook(split_features[i].float())
+ latent_features.append(latent_feature.to(features.dtype))
+ all_result_dicts.append(result_dict)
+
+ # Concatenate quantized features
+ z_quantized = torch.cat(latent_features, dim=1)
+
+ # Aggregate losses
+ global_quantizer_loss = sum(rd['quantizer_loss'] for rd in all_result_dicts) / self.num_codebooks
+ global_commitment_loss = sum(rd['commitment_loss'] for rd in all_result_dicts) / self.num_codebooks
+ global_codebook_loss = sum(rd['codebook_loss'] for rd in all_result_dicts) / self.num_codebooks
+
+ # Stack indices: shape (B, num_codebooks, H, W)
+ all_indices = torch.stack([rd['min_encoding_indices'] for rd in all_result_dicts], dim=1)
+
+ result_dict = dict(
+ quantizer_loss=global_quantizer_loss,
+ commitment_loss=global_commitment_loss,
+ codebook_loss=global_codebook_loss,
+ min_encoding_indices=all_indices
+ )
+
+ return z_quantized, result_dict
+
+ def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
+ """Get codebook entries for multi-codebook indices.
+
+ Args:
+ indices: Tensor of shape:
+ - (B, num_codebooks): single token per codebook
+ - (B, num_codebooks, seq_len): sequence of tokens per codebook
+ - (B, num_codebooks, H, W): 2D spatial tokens per codebook
+ - (B, seq_len, 1): generator format (single codebook index per position)
+
+ Returns:
+ z_quantized: Quantized features.
+ """
+ if len(indices.shape) == 2:
+ # Shape: (B, num_codebooks) - each entry is a token index
+ latent_features = []
+ for i, codebook in enumerate(self.codebooks):
+ sub_indices = indices[:, i]
+ latent_feature = codebook.get_codebook_entry(sub_indices)
+ latent_features.append(latent_feature)
+ return torch.cat(latent_features, dim=-1)
+
+ elif len(indices.shape) == 3:
+ batch_size, dim1, dim2 = indices.shape
+
+ # Check if this is (B, num_codebooks, seq_len) or (B, seq_len, 1)
+ if dim1 == self.num_codebooks:
+ # Shape: (B, num_codebooks, seq_len) - from encode()
+ seq_len = dim2
+ latent_features = []
+ for i, codebook in enumerate(self.codebooks):
+ sub_indices = indices[:, i, :] # (B, seq_len)
+ latent_feature = codebook.get_codebook_entry(sub_indices.flatten())
+ latent_feature = latent_feature.view(batch_size, seq_len, -1)
+ latent_features.append(latent_feature)
+
+ # Concatenate along feature dimension: (B, seq_len, C)
+ z_quantized = torch.cat(latent_features, dim=-1)
+ # Reshape to (B, C, 1, seq_len) for decoder
+ z_quantized = z_quantized.permute(0, 2, 1).unsqueeze(2)
+ return z_quantized
+
+ elif dim2 == 1:
+ # Shape: (B, seq_len, 1) - common format from generator
+ indices = indices.squeeze(-1) # (B, seq_len)
+ seq_len = dim1
+
+ # For generator format, all codebooks use the same indices
+ latent_features = []
+ for i, codebook in enumerate(self.codebooks):
+ latent_feature = codebook.get_codebook_entry(indices.flatten())
+ latent_feature = latent_feature.view(batch_size, seq_len, -1)
+ latent_features.append(latent_feature)
+
+ z_quantized = torch.cat(latent_features, dim=-1) # (B, seq_len, C)
+ z_quantized = z_quantized.permute(0, 2, 1).unsqueeze(2)
+ return z_quantized
+ else:
+ raise ValueError(f"Ambiguous 3D indices shape: {indices.shape}. "
+ f"Expected (B, {self.num_codebooks}, seq_len) or (B, seq_len, 1)")
+
+ elif len(indices.shape) == 4:
+ # Shape: (B, num_codebooks, H, W)
+ batch_size, _, height, width = indices.shape
+ latent_features = []
+ for i, codebook in enumerate(self.codebooks):
+ sub_indices = indices[:, i] # (B, H, W)
+ latent_feature = codebook.get_codebook_entry(sub_indices.flatten())
+ latent_feature = latent_feature.view(batch_size, height, width, -1)
+ latent_features.append(latent_feature)
+
+ # Concatenate and permute to (B, C, H, W)
+ latent_features = torch.cat(latent_features, dim=-1)
+ return latent_features.permute(0, 3, 1, 2).contiguous()
+ else:
+ raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")
+
+ def f_to_idx(self, features: torch.Tensor) -> torch.Tensor:
+        """Convert features directly to token indices (nearest-codebook lookup),
+        skipping the quantized outputs and losses computed by forward().
+
+        Args:
+            features: Input features, split along the last dimension across codebooks.
+
+        Returns:
+            indices: Token indices for each codebook, stacked along dim 1.
+        """
+        indices = []
+        chunk_size = features.shape[-1] // self.num_codebooks
+        split_features = features.split(chunk_size, dim=-1)
+        for i, codebook in enumerate(self.codebooks):
+            indices.append(codebook.f_to_idx(split_features[i]))
+ indices = torch.stack(indices, dim=1)
+ return indices
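+
+
+# Round-trip sketch (illustrative shapes): with num_codebooks=8 and encoder features of
+# shape (B, C, 1, L), forward() returns quantized features of the same shape plus stacked
+# indices of shape (B, 8, 1, L); get_codebook_entry() maps those indices back to features:
+#
+#   z_q, res = mvq(features)                                        # (B, C, 1, L)
+#   z_back = mvq.get_codebook_entry(res["min_encoding_indices"])    # (B, C, 1, L)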
diff --git a/vibetoken/quantizer/vector_quantizer.py b/vibetoken/quantizer/vector_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bfc32961e252d3fa58ce8a8845268b5f9f88e42
--- /dev/null
+++ b/vibetoken/quantizer/vector_quantizer.py
@@ -0,0 +1,174 @@
+"""Vector Quantizer for VibeToken.
+
+Simplified for inference-only use. Training-specific features removed.
+
+Reference:
+ https://github.com/CompVis/taming-transformers
+ https://github.com/google-research/magvit
+"""
+
+from typing import Mapping, Text, Tuple
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch.amp import autocast
+
+
+class VectorQuantizer(nn.Module):
+ """Vector Quantizer module for discrete tokenization.
+
+ Converts continuous latent representations to discrete tokens using
+ a learned codebook.
+ """
+
+ def __init__(
+ self,
+ codebook_size: int = 1024,
+ token_size: int = 256,
+ commitment_cost: float = 0.25,
+ use_l2_norm: bool = False,
+ ):
+ """Initialize VectorQuantizer.
+
+ Args:
+ codebook_size: Number of entries in the codebook.
+ token_size: Dimension of each codebook entry.
+ commitment_cost: Weight for commitment loss (unused in inference).
+ use_l2_norm: Whether to L2-normalize embeddings.
+ """
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.token_size = token_size
+ self.commitment_cost = commitment_cost
+ self.use_l2_norm = use_l2_norm
+
+ self.embedding = nn.Embedding(codebook_size, token_size)
+ self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
+
+ @autocast('cuda', enabled=False)
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+ """Quantize input tensor.
+
+ Args:
+ z: Input tensor of shape (B, C, H, W).
+
+ Returns:
+ z_quantized: Quantized tensor of shape (B, C, H, W).
+ result_dict: Dictionary containing min_encoding_indices and losses.
+ """
+ z = z.float()
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = rearrange(z, 'b h w c -> (b h w) c')
+
+ if self.use_l2_norm:
+ z_flattened = nn.functional.normalize(z_flattened, dim=-1)
+ embedding = nn.functional.normalize(self.embedding.weight, dim=-1)
+ else:
+ embedding = self.embedding.weight
+
+ # Compute distances to codebook entries
+ d = (torch.sum(z_flattened**2, dim=1, keepdim=True) +
+ torch.sum(embedding**2, dim=1) -
+ 2 * torch.einsum('bd,dn->bn', z_flattened, embedding.T))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)
+
+ if self.use_l2_norm:
+ z_norm = nn.functional.normalize(z, dim=-1)
+ else:
+ z_norm = z
+
+ # Compute losses (for compatibility, not used in inference)
+ commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z_norm) ** 2)
+ codebook_loss = torch.mean((z_quantized - z_norm.detach()) ** 2)
+ loss = commitment_loss + codebook_loss
+
+ # Straight-through estimator: preserve gradients
+ z_quantized = z_norm + (z_quantized - z_norm).detach()
+
+ # Reshape back to original format
+ z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
+
+ result_dict = dict(
+ quantizer_loss=loss,
+ commitment_loss=commitment_loss,
+ codebook_loss=codebook_loss,
+ min_encoding_indices=min_encoding_indices.view(
+ z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]
+ )
+ )
+
+ return z_quantized, result_dict
+
+ @autocast('cuda', enabled=False)
+ def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
+ """Get codebook entries for given indices.
+
+ Args:
+ indices: Token indices, shape (N,) or (N, vocab_size) for soft indices.
+
+ Returns:
+ Codebook entries, shape (N, token_size).
+ """
+        if len(indices.shape) == 1:
+            # Hard indices: direct codebook lookup
+            z_quantized = self.embedding(indices.long())
+        elif len(indices.shape) == 2:
+            # Soft indices (weighted sum of embeddings); keep float weights rather than casting to long
+            z_quantized = torch.einsum('bd,dn->bn', indices.float(), self.embedding.weight)
+ else:
+ raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")
+
+ if self.use_l2_norm:
+ z_quantized = nn.functional.normalize(z_quantized, dim=-1)
+
+ return z_quantized
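+
+    @autocast('cuda', enabled=False)
+    def f_to_idx(self, features: torch.Tensor) -> torch.Tensor:
+        """Map features to their nearest codebook indices.
+
+        Sketch added because VectorQuantizerMVQ.f_to_idx calls this method; assumes
+        features of shape (..., token_size) and returns indices of shape (...).
+        """
+        z_flattened = features.float().reshape(-1, self.token_size)
+        if self.use_l2_norm:
+            z_flattened = nn.functional.normalize(z_flattened, dim=-1)
+            embedding = nn.functional.normalize(self.embedding.weight, dim=-1)
+        else:
+            embedding = self.embedding.weight
+        # Squared L2 distance to every codebook entry, then take the argmin
+        d = (torch.sum(z_flattened**2, dim=1, keepdim=True) +
+             torch.sum(embedding**2, dim=1) -
+             2 * torch.einsum('bd,dn->bn', z_flattened, embedding.T))
+        return torch.argmin(d, dim=1).reshape(features.shape[:-1])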
+
+
+class DiagonalGaussianDistribution:
+ """Diagonal Gaussian distribution for VAE-style quantization.
+
+ Used when quantize_mode='vae' instead of discrete VQ.
+ """
+
+ @autocast('cuda', enabled=False)
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+ """Initialize Gaussian distribution.
+
+ Args:
+ parameters: Tensor of shape (B, 2*C, H, W) containing mean and logvar.
+ deterministic: If True, sample() returns mean (no noise).
+ """
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ @autocast('cuda', enabled=False)
+ def sample(self) -> torch.Tensor:
+ """Sample from the distribution."""
+ x = self.mean.float() + self.std.float() * torch.randn(
+ self.mean.shape, device=self.parameters.device
+ )
+ return x
+
+ @autocast('cuda', enabled=False)
+ def mode(self) -> torch.Tensor:
+ """Return the mode (mean) of the distribution."""
+ return self.mean
+
+ @autocast('cuda', enabled=False)
+ def kl(self) -> torch.Tensor:
+ """Compute KL divergence from standard Gaussian."""
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ return 0.5 * torch.sum(
+ torch.pow(self.mean.float(), 2) + self.var.float() - 1.0 - self.logvar.float(),
+ dim=[1, 2]
+ )
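+
+
+# Quantization sketch (illustrative): continuous features are snapped to their nearest
+# codebook entries, and the chosen indices are returned alongside the quantized tensor.
+#
+#   vq = VectorQuantizer(codebook_size=32768, token_size=256)
+#   z = torch.randn(2, 256, 1, 64)             # (B, C, H=1, W=num_tokens)
+#   z_q, res = vq(z)                           # z_q: (2, 256, 1, 64)
+#   idx = res["min_encoding_indices"]          # (2, 1, 64), values in [0, 32768)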
diff --git a/vibetoken/tokenizer.py b/vibetoken/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6435b6c3ff44f506ab18f67f0cdda4bb96a4475a
--- /dev/null
+++ b/vibetoken/tokenizer.py
@@ -0,0 +1,439 @@
+"""High-level VibeToken Tokenizer API.
+
+Provides a clean, user-friendly interface for image tokenization.
+"""
+
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import torch
+from PIL import Image
+import numpy as np
+
+from .modeling.vibetoken import VibeToken
+
+
+def center_crop_to_multiple(image: Image.Image, multiple: int = 32) -> Image.Image:
+ """Center crop image to dimensions divisible by multiple.
+
+ Args:
+ image: PIL Image
+ multiple: Dimensions must be divisible by this value (default: 32)
+
+ Returns:
+ Cropped PIL Image
+ """
+ w, h = image.size
+ new_w = (w // multiple) * multiple
+ new_h = (h // multiple) * multiple
+
+ if new_w == w and new_h == h:
+ return image
+
+ left = (w - new_w) // 2
+ top = (h - new_h) // 2
+ right = left + new_w
+ bottom = top + new_h
+
+ return image.crop((left, top, right, bottom))
+
+
+def get_auto_patch_size(height: int, width: int) -> int:
+ """Automatically determine optimal patch size based on resolution.
+
+ Rules:
+ - <= 256x256: patch size 8
+ - <= 512x512: patch size 16
+ - > 512x512: patch size 32
+
+ Args:
+ height: Image height
+ width: Image width
+
+ Returns:
+ Optimal patch size as int
+ """
+ # Check <= 256x256
+ if height <= 256 and width <= 256:
+ return 8
+
+ # Check <= 512x512
+ if height <= 512 and width <= 512:
+ return 16
+
+ # Larger resolutions
+ return 32
+
+
+def auto_preprocess_image(image: Image.Image, verbose: bool = True) -> Tuple[Image.Image, int, dict]:
+ """Automatically preprocess image: center crop to divisible by 32 and determine optimal patch size.
+
+ Args:
+ image: PIL Image
+ verbose: Whether to print preprocessing info
+
+ Returns:
+ image: Preprocessed PIL Image
+ patch_size: Optimal patch size
+ info: Dictionary with preprocessing details
+ """
+ original_size = image.size # (W, H)
+
+ # Center crop to dimensions divisible by 32
+ image = center_crop_to_multiple(image, multiple=32)
+ cropped_size = image.size # (W, H)
+
+ # Get optimal patch size
+ patch_size = get_auto_patch_size(cropped_size[1], cropped_size[0]) # (H, W)
+
+ info = {
+ "original_size": original_size,
+ "cropped_size": cropped_size,
+ "patch_size": patch_size,
+ "was_cropped": original_size != cropped_size,
+ }
+
+ if verbose:
+ if info["was_cropped"]:
+ print(f"Center cropped: {original_size[0]}x{original_size[1]} -> {cropped_size[0]}x{cropped_size[1]} (divisible by 32)")
+ print(f"Auto patch size: {patch_size}x{patch_size}")
+
+ return image, patch_size, info
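+
+
+# Example (illustrative): a 1000x700 input is center-cropped to 992x672 (divisible by 32);
+# since one side exceeds 512, the auto-selected patch size is 32.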
+
+
+class VibeTokenTokenizer:
+ """High-level API for VibeToken image tokenization.
+
+ Provides simple encode/decode methods for converting images to/from
+ discrete tokens.
+
+ Example:
+ >>> tokenizer = VibeTokenTokenizer.from_config("configs/vibetoken_ll.yaml", "model.bin")
+ >>> tokens = tokenizer.encode(images) # (B, num_codebooks, seq_len)
+ >>> reconstructed = tokenizer.decode(tokens, height=512, width=512)
+ """
+
+ def __init__(
+ self,
+ model: VibeToken,
+ device: str = "cuda",
+ ):
+ """Initialize tokenizer with a VibeToken model.
+
+ Args:
+ model: Loaded VibeToken model.
+ device: Device for inference.
+ """
+ self.model = model
+ self.device = device
+ self.num_codebooks = getattr(model.config.model.vq_model, 'num_codebooks', 1) \
+ if hasattr(model.config.model, 'vq_model') else 1
+
+ @classmethod
+ def from_config(
+ cls,
+ config_path: str,
+ checkpoint_path: str,
+ device: str = "cuda",
+ dtype: Optional[torch.dtype] = None,
+ ) -> "VibeTokenTokenizer":
+ """Create tokenizer from config and checkpoint files.
+
+ Args:
+ config_path: Path to YAML config file.
+ checkpoint_path: Path to model checkpoint.
+ device: Device for inference.
+ dtype: Optional dtype for model.
+
+ Returns:
+ VibeTokenTokenizer instance.
+ """
+ model = VibeToken.from_pretrained(
+ config_path=config_path,
+ checkpoint_path=checkpoint_path,
+ device=device,
+ dtype=dtype,
+ )
+ return cls(model=model, device=device)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_name: str,
+ checkpoint_path: str,
+ device: str = "cuda",
+ dtype: Optional[torch.dtype] = None,
+ ) -> "VibeTokenTokenizer":
+ """Load tokenizer by model name.
+
+ Args:
+ model_name: Model variant name ("ll" for Large-Large, "sl" for Small-Large).
+ checkpoint_path: Path to model checkpoint.
+ device: Device for inference.
+ dtype: Optional dtype for model.
+
+ Returns:
+ VibeTokenTokenizer instance.
+ """
+ # Get config path based on model name
+ package_dir = Path(__file__).parent.parent
+ config_map = {
+ "ll": "vibetoken_ll.yaml",
+ "sl": "vibetoken_sl.yaml",
+ }
+
+ if model_name not in config_map:
+ raise ValueError(f"Unknown model name: {model_name}. Choose from: {list(config_map.keys())}")
+
+ config_path = package_dir / "configs" / config_map[model_name]
+
+ if not config_path.exists():
+ raise FileNotFoundError(f"Config file not found: {config_path}")
+
+ return cls.from_config(
+ config_path=str(config_path),
+ checkpoint_path=checkpoint_path,
+ device=device,
+ dtype=dtype,
+ )
+
+ def preprocess(
+ self,
+ images: Union[torch.Tensor, np.ndarray, Image.Image, list],
+ ) -> torch.Tensor:
+ """Preprocess images for encoding.
+
+ Args:
+ images: Input images. Can be:
+ - torch.Tensor (B, 3, H, W) or (3, H, W), values in [0, 1]
+ - numpy array (B, H, W, 3) or (H, W, 3), values in [0, 255]
+ - PIL Image
+ - List of PIL Images
+
+ Returns:
+ Preprocessed tensor (B, 3, H, W), values in [0, 1].
+ """
+ if isinstance(images, Image.Image):
+ images = [images]
+
+ if isinstance(images, list):
+ # List of PIL images
+ tensors = []
+ for img in images:
+ if isinstance(img, Image.Image):
+ arr = np.array(img.convert("RGB"))
+ tensor = torch.from_numpy(arr).float().permute(2, 0, 1) / 255.0
+ tensors.append(tensor)
+ else:
+ raise ValueError(f"Unsupported image type in list: {type(img)}")
+ images = torch.stack(tensors)
+ elif isinstance(images, np.ndarray):
+ if images.ndim == 3:
+ images = images[None] # Add batch dim
+ # Assume (B, H, W, C) format
+ images = torch.from_numpy(images).float().permute(0, 3, 1, 2) / 255.0
+ elif isinstance(images, torch.Tensor):
+ if images.ndim == 3:
+ images = images.unsqueeze(0) # Add batch dim
+            # Expect values in [0, 1]; rescale if they look like [0, 255]
+            if images.max() > 1.0:
+                images = images / 255.0
+ else:
+ raise ValueError(f"Unsupported image type: {type(images)}")
+
+ return images.to(self.device)
+
+ @torch.no_grad()
+ def encode(
+ self,
+ images: Union[torch.Tensor, np.ndarray, Image.Image, list],
+ patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+ return_dict: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ """Encode images to discrete tokens.
+
+ Args:
+ images: Input images (see preprocess for supported formats).
+ patch_size: Optional encoding patch size. Can be int or tuple (H, W).
+ return_dict: Whether to return result dictionary with additional info.
+
+ Returns:
+ tokens: Token indices. Shape depends on quantize_mode:
+ - VQ: (B, seq_len)
+ - MVQ: (B, num_codebooks, seq_len)
+ result_dict (optional): Dictionary with additional encoding info.
+ """
+ images = self.preprocess(images)
+
+ # Convert int to tuple if needed
+ if isinstance(patch_size, int):
+ encode_ps = (patch_size, patch_size)
+ else:
+ encode_ps = patch_size # Already tuple or None
+ _, result_dict = self.model.encode(images, encode_patch_size=encode_ps)
+
+ # Extract token indices
+ tokens = result_dict['min_encoding_indices']
+
+ # Reshape based on quantize mode
+ if self.model.quantize_mode == "mvq":
+ # Already in (B, num_codebooks, H, W) format, flatten spatial
+ B, num_cb, H, W = tokens.shape
+ tokens = tokens.reshape(B, num_cb, H * W)
+ else:
+ # VQ mode: (B, H, W) -> (B, H*W)
+ B, H, W = tokens.shape
+ tokens = tokens.reshape(B, H * W)
+
+ if return_dict:
+ return tokens, result_dict
+ return tokens
+
+ @torch.no_grad()
+ def decode(
+ self,
+ tokens: torch.Tensor,
+ height: int,
+ width: int,
+ patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+ clamp: bool = True,
+ ) -> torch.Tensor:
+ """Decode tokens back to images.
+
+ Args:
+ tokens: Token indices from encode().
+ height: Target image height.
+ width: Target image width.
+ patch_size: Optional decoding patch size. Can be int or tuple (H, W).
+ clamp: Whether to clamp output to [0, 1].
+
+ Returns:
+ Decoded images (B, 3, height, width), values in [0, 1].
+ """
+ tokens = tokens.to(self.device)
+
+ # Convert int to tuple if needed
+ if isinstance(patch_size, int):
+ decode_ps = (patch_size, patch_size)
+ else:
+ decode_ps = patch_size # Already tuple or None
+
+ decoded = self.model.decode_tokens(
+ tokens,
+ height=height,
+ width=width,
+ decode_patch_size=decode_ps,
+ )
+
+ if clamp:
+ decoded = torch.clamp(decoded, 0.0, 1.0)
+
+ return decoded
+
+ @torch.no_grad()
+ def reconstruct(
+ self,
+ images: Union[torch.Tensor, np.ndarray, Image.Image, list],
+ encode_patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+ decode_patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+ target_height: Optional[int] = None,
+ target_width: Optional[int] = None,
+ ) -> torch.Tensor:
+ """Encode and decode images (full reconstruction).
+
+ Args:
+ images: Input images.
+ encode_patch_size: Optional encoding patch size. Can be int or tuple (H, W).
+ decode_patch_size: Optional decoding patch size. Can be int or tuple (H, W).
+ target_height: Target output height (default: same as input).
+ target_width: Target output width (default: same as input).
+
+ Returns:
+ Reconstructed images (B, 3, H, W), values in [0, 1].
+ """
+ images = self.preprocess(images)
+ B, C, H, W = images.shape
+
+ if target_height is None:
+ target_height = H
+ if target_width is None:
+ target_width = W
+
+ # Encode with optional patch size
+ tokens = self.encode(images, patch_size=encode_patch_size)
+
+ # Decode with optional patch size
+ decoded = self.decode(
+ tokens,
+ height=target_height,
+ width=target_width,
+ patch_size=decode_patch_size,
+ )
+
+ return decoded
+
+ def to_pil(self, images: torch.Tensor) -> list:
+ """Convert tensor images to PIL Images.
+
+ Args:
+ images: Tensor (B, 3, H, W), values in [0, 1].
+
+ Returns:
+ List of PIL Images.
+ """
+ images = torch.clamp(images * 255, 0, 255)
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
+ return [Image.fromarray(img) for img in images]
+
+ @torch.no_grad()
+ def auto_reconstruct(
+ self,
+ image: Image.Image,
+ verbose: bool = True,
+ ) -> Tuple[torch.Tensor, dict]:
+ """Automatically reconstruct image with optimal settings.
+
+ This method:
+ 1. Center crops input to dimensions divisible by 32
+ 2. Auto-determines optimal patch sizes based on resolution
+ 3. Returns output at same size as (cropped) input
+
+ Args:
+ image: Input PIL Image
+ verbose: Whether to print auto-selected settings
+
+ Returns:
+ reconstructed: Reconstructed image tensor (1, 3, H, W)
+ info: Dictionary with auto-selected settings
+ """
+ # Use centralized auto_preprocess_image
+ original_size = image.size
+ image, patch_size, preprocess_info = auto_preprocess_image(image, verbose=verbose)
+ width, height = preprocess_info["cropped_size"]
+
+ # Encode and decode
+ tokens = self.encode(image, patch_size=patch_size)
+ reconstructed = self.decode(tokens, height=height, width=width, patch_size=patch_size)
+
+ info = {
+ "original_size": original_size,
+ "cropped_size": (width, height),
+ "patch_size": patch_size,
+ "token_shape": tokens.shape,
+ }
+
+ return reconstructed, info
+
+ @property
+ def codebook_size(self) -> int:
+ """Get total codebook size."""
+ vq_config = self.model.config.model.vq_model if hasattr(self.model.config.model, 'vq_model') else self.model.config.model
+ return getattr(vq_config, 'codebook_size', 32768)
+
+ @property
+ def num_latent_tokens(self) -> int:
+ """Get number of latent tokens."""
+ return self.model.num_latent_tokens
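+
+
+# Usage sketch (checkpoint and image paths are illustrative):
+#
+#   tokenizer = VibeTokenTokenizer.from_pretrained("ll", "vibetoken_ll.bin")
+#   image = Image.open("example.jpg")
+#   recon, info = tokenizer.auto_reconstruct(image)       # auto crop + auto patch size
+#   tokenizer.to_pil(recon)[0].save("reconstructed.jpg")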
diff --git a/vibetokengen/__init__.py b/vibetokengen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vibetokengen/generate.py b/vibetokengen/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b61934df10639adc238bdd30dff9f9b9b0986d7
--- /dev/null
+++ b/vibetokengen/generate.py
@@ -0,0 +1,219 @@
+# Modified from:
+# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
+# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torch._dynamo.config
+import torch._inductor.config
+import copy
+
+
+# torch._inductor.config.coordinate_descent_tuning = True
+# torch._inductor.config.triton.unique_kernel_names = True
+# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+
+
+### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
+def top_k_top_p_filtering(
+ logits,
+ top_k: int = 0,
+ top_p: float = 1.0,
+ filter_value: float = -float("Inf"),
+ min_tokens_to_keep: int = 1,
+):
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+ Args:
+ logits: logits distribution shape (batch size, vocabulary size)
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+ """
+ if top_k > 0:
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ if min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ logits[indices_to_remove] = filter_value
+ return logits
+
+
+def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True):
+ logits = logits[:, -1, :] / max(temperature, 1e-5)
+ if top_k > 0 or top_p < 1.0:
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+ probs = F.softmax(logits, dim=-1)
+ if sample_logits:
+ idx = torch.multinomial(probs, num_samples=1)
+ else:
+ _, idx = torch.topk(probs, k=1, dim=-1)
+ return idx, probs
+
+
+def logits_to_probs(logits, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, **kwargs):
+ logits = logits / max(temperature, 1e-5)
+ if top_k > 0 or top_p < 1.0:
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ return probs
+
+
+def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, guidance_scale: float, guidance_scale_pow: float, target_h: float, target_w: float, **sampling_kwargs):
+ scale_pow = torch.ones((1), device=cond_idx.device) * guidance_scale_pow
+ scale_step = (1 + torch.cos(
+ ((0 / 256) ** scale_pow) * torch.pi)) / 2
+ cfg_scale = (guidance_scale - 1) * scale_step + 1
+
+ next_token = model(None, cond_idx, input_pos, target_h=target_h, target_w=target_w)
+ model.setup_head_caches(
+ next_token.shape[0],
+ model.num_codebooks,
+ dtype=next_token.dtype,
+ device=next_token.device
+ )
+ indices = []
+ bs = next_token.shape[0]
+ for i in range(model.num_codebooks):
+ start_pos = torch.tensor([i], dtype=torch.int)
+ mask = model.output.head_causal_mask[:bs, None, start_pos]
+ if i == 0:
+ logits = model.output.forward_test(next_token, input_pos=start_pos, mask=mask)
+ else:
+ if guidance_scale > 0.0:
+ pred_idx = torch.cat([pred_idx, pred_idx])
+ logits = model.output.forward_test(next_token, idx=pred_idx, input_pos=start_pos, mask=mask)
+ if guidance_scale > 0.0:
+ cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+ pred_idx = sample(logits, **sampling_kwargs)[0]
+ else:
+ pred_idx = sample(logits, **sampling_kwargs)[0]
+ indices.append(pred_idx)
+ indices = torch.stack(indices, dim=1)
+ return indices
+
+
+def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, guidance_scale: float, guidance_scale_pow: float, cfg_flag: bool, target_h: float, target_w: float, token_idx: int, **sampling_kwargs):
+ assert input_pos.shape[-1] == 1
+ scale_pow = torch.ones((1), device=x.device) * guidance_scale_pow
+ scale_step = (1 + torch.cos(
+ ((token_idx / 256) ** scale_pow) * torch.pi)) / 2
+ cfg_scale = (guidance_scale - 1) * scale_step + 1
+
+ if guidance_scale > 0.0:
+ x_combined = torch.cat([x, x])
+ next_token = model(x_combined, cond_idx=None, input_pos=input_pos, target_h=target_h, target_w=target_w)
+ else:
+ next_token = model(x, cond_idx=None, input_pos=input_pos, target_h=target_h, target_w=target_w)
+ model.setup_head_caches(
+ next_token.shape[0],
+ model.num_codebooks,
+ dtype=next_token.dtype,
+ device=next_token.device
+ )
+ indices = []
+ bs = next_token.shape[0]
+ for i in range(model.num_codebooks):
+ start_pos = torch.tensor([i], dtype=torch.int)
+ mask = model.output.head_causal_mask[:bs, None, start_pos]
+ if i == 0:
+ logits = model.output.forward_test(next_token, input_pos=start_pos, mask=mask)
+ else:
+ if guidance_scale > 0.0:
+ pred_idx = torch.cat([pred_idx, pred_idx])
+ logits = model.output.forward_test(next_token, idx=pred_idx, input_pos=start_pos, mask=mask)
+ if guidance_scale > 0.0:
+ cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+ pred_idx = sample(logits, **sampling_kwargs)[0]
+ else:
+ pred_idx = sample(logits, **sampling_kwargs)[0]
+ indices.append(pred_idx)
+ indices = torch.stack(indices, dim=1)
+ return indices
+
+
+def decode_n_tokens(
+ model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
+ guidance_scale: float, guidance_scale_pow: float, cfg_interval: int, target_h: float, target_w: float, **sampling_kwargs
+ ):
+ new_tokens = []
+ cfg_flag = True
+ for i in range(num_new_tokens):
+ # Actually better for Inductor to codegen attention here
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+ next_token = decode_one_token(model, cur_token, input_pos, guidance_scale, guidance_scale_pow, cfg_flag, target_h, target_w, i+1, **sampling_kwargs)
+ input_pos += 1
+ new_tokens.append(next_token.clone())
+ cur_token = next_token
+ return new_tokens
+
+
+@torch.no_grad()
+def generate(model, cond, max_new_tokens, num_codebooks, emb_masks=None, cfg_scale=1.0, cfg_interval=-1,
+ target_h=None, target_w=None, guidance_scale=16.0, guidance_scale_pow=2.5, **sampling_kwargs):
+
+ if model.model_type == 'c2i':
+ if guidance_scale > 0.0:
+ cond_null = torch.ones_like(cond) * model.num_classes
+ cond_combined = torch.cat([cond, cond_null])
+ else:
+ cond_combined = cond
+ T = 1
+ else:
+ raise Exception("please check model type")
+
+ T_new = T + max_new_tokens
+ max_seq_length = T_new
+ max_batch_size = cond.shape[0]
+
+ device = cond.device
+ with torch.device(device):
+ max_batch_size_cfg = max_batch_size * 2 if guidance_scale > 0.0 else max_batch_size
+ model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length,
+ dtype=model.tok_embeddings.codebooks[0].weight.dtype)
+
+ if emb_masks is not None:
+ assert emb_masks.shape[0] == max_batch_size
+ assert emb_masks.shape[-1] == T
+ if guidance_scale > 0.0:
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
+ else:
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
+
+ eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
+ model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
+
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ seq = torch.empty((max_batch_size, num_codebooks, T_new), dtype=torch.int, device=device)
+
+ input_pos = torch.arange(0, T, device=device)
+ next_token = prefill(model, cond_combined, input_pos, guidance_scale, guidance_scale_pow, target_h, target_w, **sampling_kwargs)
+ seq[:, :, T:T+1] = next_token
+
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
+ generated_tokens = decode_n_tokens(model, next_token, input_pos, max_new_tokens - 1, guidance_scale, guidance_scale_pow, cfg_interval, target_h, target_w,
+ **sampling_kwargs)
+ seq[:, :, T+1:] = torch.cat(generated_tokens, dim=-1)
+
+ return seq[:, :, T:]
+
+
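+
+# Sampling sketch (illustrative; assumes a loaded VibeToken-Gen 'c2i' model on GPU):
+#
+#   cond = torch.tensor([207, 360], device="cuda")   # ImageNet class labels
+#   seq = generate(model, cond, max_new_tokens=64, num_codebooks=model.num_codebooks,
+#                  guidance_scale=16.0, target_h=512, target_w=512,   # target size is illustrative
+#                  temperature=1.0, top_k=0, top_p=1.0)
+#   # seq: (B, num_codebooks, 64) discrete indices, decodable via VibeToken.decode_tokens(...)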
diff --git a/vibetokengen/model.py b/vibetokengen/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f45c58d6d74f0428a26e6e6b3ac241db270b6191
--- /dev/null
+++ b/vibetokengen/model.py
@@ -0,0 +1,777 @@
+# Modified from:
+# VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
+# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
+# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
+# llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
+# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
+# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+from dataclasses import dataclass
+from typing import Optional, List
+
+import math
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from contextlib import nullcontext
+import logging
+
+logger = logging.getLogger(__name__)
+
+def find_multiple(n: int, k: int):
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
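+# Stochastic-depth helper for DropPath below. Assumption: the standard timm-style
+# implementation, since no `drop_path` is imported or defined above in this file.
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    # Broadcast a per-sample keep/drop mask over all non-batch dimensions
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+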
+class DropPath(torch.nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+@dataclass
+class ModelArgs:
+ dim: int = 4096
+ n_layer: int = 32
+ n_head: int = 32
+ n_kv_head: Optional[int] = None
+ n_output_layer: int = 1
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
+ ffn_dim_multiplier: Optional[float] = None
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ initializer_range: float = 0.02
+
+ token_dropout_p: float = 0.1
+ attn_dropout_p: float = 0.0
+ resid_dropout_p: float = 0.1
+ ffn_dropout_p: float = 0.1
+ drop_path_rate: float = 0.0
+
+ num_codebooks: int = 8
+ num_classes: int = 1000
+ caption_dim: int = 2048
+ class_dropout_prob: float = 0.1
+ model_type: str = 'c2i'
+    extra_layers: str = ''  # substring flags checked in Attention ('QK', 'QKV', 'cap'); '' avoids 'in None' errors
+ capping: float = 50.0
+
+ vocab_size: int = 16384
+ cls_token_num: int = 1
+ block_size: int = 256
+ max_batch_size: int = 32
+ max_seq_len: int = 2048
+
+ # Dynamic V1
+ dynamic: bool = True
+ extra_cls_token_num: int = 2
+
+
+
+#################################################################################
+# Capping for Attention Softmax #
+#################################################################################
+
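+# Works like torch.nn.functional.scaled_dot_product_attention, with optional tanh logit capping:
+# query/key/value are (..., n_head, seq_len, head_dim) and the output matches the query shape.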
+def custom_scaled_dot_product_attention(
+ query, key, value,
+ attn_mask=None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale=None,
+ enable_gqa: bool = False,
+ logit_cap: float = 50.0, # set to None to disable tanh-capping
+) -> torch.Tensor:
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = (1 / math.sqrt(query.size(-1))) if scale is None else scale
+
+ # additive bias for masking (same dtype/device as query)
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+
+ if is_causal:
+ assert attn_mask is None
+ causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
+ attn_bias = attn_bias.masked_fill(~causal, float("-inf"))
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
+ else:
+ attn_bias = attn_bias + attn_mask.to(dtype=query.dtype, device=query.device)
+
+ if enable_gqa:
+ repeat = query.size(-3) // key.size(-3)
+ key = key.repeat_interleave(repeat, dim=-3)
+ value = value.repeat_interleave(repeat, dim=-3)
+
+ # logits (pre-softmax)
+ logits = (query @ key.transpose(-2, -1)) * scale_factor
+ logits = logits + attn_bias
+
+ # --- tanh cap: softmax(tanh(logits/cap)*cap), but keep -inf from masks ---
+ if logit_cap is not None:
+ finite = torch.isfinite(logits)
+ logits = torch.where(
+ finite,
+ torch.tanh(logits / logit_cap) * logit_cap,
+ logits, # keeps -inf (and +inf if any)
+ )
+
+ # softmax in float32 for stability, then back to value dtype
+ attn_weight = torch.softmax(logits.float(), dim=-1).to(value.dtype)
+
+ if dropout_p:
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+ return attn_weight @ value
+
+
+def custom_scaled_dot_product_attention_v2(
+ query, key, value,
+ attn_mask=None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale=None,
+ enable_gqa: bool = False,
+ logit_cap: float = 50.0, # set to None to disable tanh-capping
+ force_true_fp32_qk: bool = True, # disable TF32 for QK^T matmul
+ force_true_fp32_out: bool = False, # set True if you also want (P@V) in true FP32
+) -> torch.Tensor:
+ """
+ Shapes follow (..., H, L, D). Works with bf16/fp16 autocast; stabilizes numerics by
+ doing QK^T + softmax in real FP32 (not TF32), then casting back.
+ """
+ device = query.device
+ q_dtype = query.dtype
+ v_dtype = value.dtype
+
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = (1.0 / math.sqrt(query.size(-1))) if scale is None else float(scale)
+
+ # --- broadcastable FP32 bias (safer for -inf) ---
+ attn_bias = torch.zeros(L, S, dtype=torch.float32, device=device)
+
+ if is_causal:
+ assert attn_mask is None
+ causal = torch.ones(L, S, dtype=torch.bool, device=device).tril()
+ attn_bias = attn_bias.masked_fill(~causal, float("-inf"))
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
+ else:
+ attn_bias = attn_bias + attn_mask.to(dtype=torch.float32, device=device)
+
+ if enable_gqa:
+ repeat = query.size(-3) // key.size(-3)
+ key = key.repeat_interleave(repeat, dim=-3)
+ value = value.repeat_interleave(repeat, dim=-3)
+
+    # --- QK^T in true FP32 (optionally disable TF32 just here) ---
+    # Note: torch.backends.cuda.matmul.allow_tf32 is a bool flag, not a context manager,
+    # so toggle it around the matmul and restore it afterwards.
+    q32 = query.to(torch.float32)
+    k32T = key.to(torch.float32).transpose(-2, -1)
+    prev_tf32 = torch.backends.cuda.matmul.allow_tf32
+    if force_true_fp32_qk:
+        torch.backends.cuda.matmul.allow_tf32 = False
+    try:
+        logits = (q32 @ k32T) * scale_factor  # FP32 matmul
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = prev_tf32
+
+ # add FP32 bias/mask
+ logits = logits + attn_bias # still FP32
+
+ # optional tanh-cap, preserving -inf from masks
+ if logit_cap is not None:
+ finite = torch.isfinite(logits)
+ logits = torch.where(finite,
+ torch.tanh(logits / logit_cap) * logit_cap,
+ logits)
+
+ # --- softmax in FP32, then cast ---
+ attn_weight = torch.softmax(logits, dim=-1) # FP32
+ attn_weight = attn_weight.to(v_dtype) # back to value dtype for matmul
+
+ if dropout_p:
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+ # --- (P @ V): you can keep speed (bf16) or force true FP32 here too ---
+    if force_true_fp32_out:
+        P32 = attn_weight.to(torch.float32)
+        V32 = value.to(torch.float32)
+        prev_tf32 = torch.backends.cuda.matmul.allow_tf32
+        torch.backends.cuda.matmul.allow_tf32 = False  # bool flag, not a context manager
+        try:
+            out = P32 @ V32
+        finally:
+            torch.backends.cuda.matmul.allow_tf32 = prev_tf32
+        return out.to(v_dtype)
+ else:
+ return attn_weight @ value
+
+
+#################################################################################
+# Embedding Layers for Class Labels #
+#################################################################################
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels).unsqueeze(1)
+ return embeddings
+
+
+#################################################################################
+# Embedding Layers for Text Feature #
+#################################################################################
+class CaptionEmbedder(nn.Module):
+ """
+ Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
+ super().__init__()
+ self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
+ self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
+ self.uncond_prob = uncond_prob
+
+ def token_drop(self, caption, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption)
+ return caption
+
+ def forward(self, caption, train, force_drop_ids=None):
+ use_dropout = self.uncond_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ caption = self.token_drop(caption, force_drop_ids)
+ embeddings = self.cap_proj(caption)
+ return embeddings
+
+
+class TokenEmbedder(nn.Module):
+ def __init__(self, vocab_size, hidden_size, num_codebooks):
+ super().__init__()
+ self.num_codebooks = num_codebooks
+ self.codebooks = nn.ModuleList()
+ for _ in range(num_codebooks):
+ codebook = nn.Embedding(vocab_size // num_codebooks, hidden_size)
+ self.codebooks.append(codebook)
+
+ def forward(self, indices):
+ assert indices.shape[1] == self.num_codebooks
+ latent_features = []
+ for i in range(self.num_codebooks):
+ latent_feature = self.codebooks[i](indices[:, i])
+ latent_features.append(latent_feature)
+ latent_features = torch.stack(latent_features).sum(dim=0)
+ return latent_features
+
+
+class MLP(nn.Module):
+ def __init__(self, in_features, hidden_features, out_features):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
+ self.act = nn.GELU(approximate='tanh')
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ return x
+
+
+#################################################################################
+# GPT Model #
+#################################################################################
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ hidden_dim = 4 * config.dim
+ hidden_dim = int(2 * hidden_dim / 3)
+ # custom dim factor multiplier
+ if config.ffn_dim_multiplier is not None:
+ hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
+ hidden_dim = find_multiple(hidden_dim, config.multiple_of)
+
+ self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
+ self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
+
+ def forward(self, x):
+ return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class KVCache(nn.Module):
+ def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
+ super().__init__()
+ cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
+ self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+ return k_out, v_out
+
+
+class Attention(nn.Module):
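+ """Multi-head self-attention with grouped-query KV heads (n_kv_head), rotary
+ embeddings, optional QK/QKV RMSNorm, optional logit capping, and a KV cache."""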
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+ self.dim = config.dim
+ self.head_dim = config.dim // config.n_head
+ self.n_head = config.n_head
+ self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
+ total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
+
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.kv_cache = None
+
+ # regularization
+ self.attn_dropout_p = config.attn_dropout_p
+ self.resid_dropout = nn.Dropout(config.resid_dropout_p)
+
+ self.capping = config.capping if 'cap' in config.extra_layers else None
+
+ # Add QK or QKV normalization based on config
+ if 'QKV' in config.extra_layers:
+ self.q_norm = RMSNorm(self.head_dim, eps=config.norm_eps)
+ self.k_norm = RMSNorm(self.head_dim, eps=config.norm_eps)
+ self.v_norm = RMSNorm(self.head_dim, eps=config.norm_eps)
+ elif 'QK' in config.extra_layers:
+ self.q_norm = RMSNorm(self.head_dim, eps=config.norm_eps)
+ self.k_norm = RMSNorm(self.head_dim, eps=config.norm_eps)
+
+ def forward(
+ self, x: torch.Tensor, freqs_cis: torch.Tensor = None,
+ input_pos: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ full_attn: Optional[bool] = False,
+ ):
+ bsz, seqlen, _ = x.shape
+ kv_size = self.n_kv_head * self.head_dim
+ xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
+ xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
+ xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
+
+ # Apply QK or QKV normalization if configured
+ if hasattr(self, 'q_norm'):
+ xq = self.q_norm(xq)
+ xk = self.k_norm(xk)
+ if hasattr(self, 'v_norm'):
+ xv = self.v_norm(xv)
+
+ if freqs_cis is not None:
+ xq = apply_rotary_emb(xq, freqs_cis)
+ xk = apply_rotary_emb(xk, freqs_cis)
+
+ xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
+
+ if self.kv_cache is not None:
+ keys, values = self.kv_cache.update(input_pos, xk, xv)
+ else:
+ keys, values = xk, xv
+ keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+ values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+
+ output = custom_scaled_dot_product_attention(
+ xq, keys, values,
+ attn_mask=mask,
+ is_causal=(mask is None and not full_attn),  # an explicit mask (e.g. KV-cache decoding) or full_attn turns off the causal flag
+ dropout_p=self.attn_dropout_p if self.training else 0,
+ logit_cap=self.capping
+ )
+
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ output = self.resid_dropout(self.wo(output))
+ return output
+
+
+class TransformerBlock(nn.Module):
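+ """Pre-norm block: RMSNorm -> attention and RMSNorm -> SwiGLU FFN, each added
+ back through a residual connection with stochastic depth (DropPath)."""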
+ def __init__(self, config: ModelArgs, drop_path: float):
+ super().__init__()
+ self.attention = Attention(config)
+ self.feed_forward = FeedForward(config)
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None,
+ full_attn=False):
+ h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask, full_attn))
+ out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
+ return out
+
+
+class AutoRegressiveHead(nn.Module):
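+ """Shallow transformer head that predicts the per-position codebook indices
+ one sub-codebook at a time, conditioned on the backbone feature and on the
+ embeddings of the previously predicted indices."""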
+ def __init__(self, config):
+ super().__init__()
+
+ self.num_codebooks = config.num_codebooks
+ self.sub_vocab_size = config.vocab_size // config.num_codebooks
+
+ self.codebooks = nn.ModuleList()
+ for _ in range(self.num_codebooks - 1):
+ codebook = nn.Embedding(self.sub_vocab_size, config.dim)
+ self.codebooks.append(codebook)
+
+ self.layers = torch.nn.ModuleList()
+ for _ in range(config.n_output_layer):
+ self.layers.append(TransformerBlock(config, drop_path=0.))
+
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.linear_head = nn.Linear(config.dim, self.sub_vocab_size)
+
+ def forward_train(self, base_tokens, targets):
+ K = targets.shape[1]
+ B, L, C = base_tokens.shape
+ base_tokens = base_tokens.reshape(B * L, 1, C)
+ targets = targets.permute(0, 2, 1).reshape(B * L, K)[:, :-1]
+ index_embeddings = []
+ for i in range(self.num_codebooks - 1):
+ index_embed = self.codebooks[i](targets[:, i])
+ index_embeddings.append(index_embed)
+ index_embeddings = torch.stack(index_embeddings, dim=1)
+ h = torch.cat((base_tokens, index_embeddings), dim=1)
+ for layer in self.layers:
+ h = layer(h, freqs_cis=None, start_pos=None, mask=None)
+ h = self.norm(h)
+ logits = self.linear_head(h)
+ logits = logits.reshape(B, L, K, -1).permute(0, 2, 1, 3)
+ return logits
+
+ def forward_test(self, base_tokens, idx=None, input_pos=None, mask=None):
+ if idx is not None:
+ h = self.codebooks[input_pos - 1](idx)
+ else:
+ h = base_tokens
+ for layer in self.layers:
+ h = layer(h, freqs_cis=None, start_pos=input_pos, mask=mask)
+ h = self.norm(h)
+ logits = self.linear_head(h)
+ return logits
+
+
+class Transformer(nn.Module):
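+ """Decoder-only autoregressive backbone over 1D token sequences with class or
+ caption conditioning, optional target height/width conditioning (`dynamic`),
+ 1D rotary embeddings, and an AutoRegressiveHead over the sub-codebooks."""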
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.config = config
+ self.vocab_size = config.vocab_size
+ self.n_layer = config.n_layer
+ self.block_size = config.block_size
+ self.num_classes = config.num_classes
+ self.model_type = config.model_type
+ self.num_codebooks = config.num_codebooks
+ self.cls_token_num = config.cls_token_num
+ self.extra_layers = config.extra_layers
+ self.dynamic = config.dynamic
+ self.extra_cls_token_num = config.extra_cls_token_num
+
+ if self.extra_layers not in [None, 'QK', 'QKV', 'FC', 'cap', 'clip', 'QK_cap', 'QKV_cap', 'QK_clip', 'QKV_clip', 'QK_FC_cap', 'QKV_FC_cap', 'QK_FC_clip', 'QKV_FC_clip']:
+ raise ValueError(f"Invalid extra layers: {self.extra_layers}")
+
+ if self.model_type == 'c2i':
+ self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
+ elif self.model_type == 't2i':
+ self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
+ else:
+ raise Exception("please check model type")
+ # self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+ self.tok_embeddings = TokenEmbedder(config.vocab_size, config.dim, config.num_codebooks)
+ self.tok_dropout = nn.Dropout(config.token_dropout_p)
+
+ # transformer blocks
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
+ self.layers = torch.nn.ModuleList()
+ for layer_id in range(config.n_layer):
+ self.layers.append(TransformerBlock(config, dpr[layer_id]))
+
+ # if dynamic, add new layers
+ if self.dynamic:
+ self.target_h_embedder = nn.Linear(1, config.dim, bias=True)
+ self.target_w_embedder = nn.Linear(1, config.dim, bias=True)
+ else:
+ raise ValueError("Dynamic is not supported")
+
+ # output layer
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.output = AutoRegressiveHead(config)
+
+ # 2d rotary pos embedding
+ # grid_size = int(self.block_size ** 0.5)
+ # assert grid_size * grid_size == self.block_size
+ # self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head,
+ # self.config.rope_base, self.cls_token_num)
+ self.freqs_cis = precompute_freqs_cis(self.block_size, self.config.dim // self.config.n_head,
+ self.config.rope_base, self.cls_token_num)
+
+ # KVCache
+ self.max_batch_size = -1
+ self.max_seq_length = -1
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize nn.Linear and nn.Embedding
+ self.apply(self._init_weights)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.output.linear_head.weight, 0)
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+
+ def setup_caches(self, max_batch_size, max_seq_length, dtype):
+ # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+ # return
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_length = find_multiple(max_seq_length, 8)
+ self.max_seq_length = max_seq_length
+ self.max_batch_size = max_batch_size
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)
+
+ causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+ self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
+ # grid_size = int(self.config.block_size ** 0.5)
+ # assert grid_size * grid_size == self.block_size
+ # self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head,
+ # self.config.rope_base, self.cls_token_num)
+ self.freqs_cis = precompute_freqs_cis(self.block_size, self.config.dim // self.config.n_head,
+ self.config.rope_base, self.cls_token_num)
+
+ def setup_head_caches(self, max_batch_size, max_seq_length, dtype, device):
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_length = find_multiple(max_seq_length, 8)
+ self.max_seq_length = max_seq_length
+ self.max_batch_size = max_batch_size
+ for b in self.output.layers:
+ b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype).to(
+ device)
+ causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
+ self.output.head_causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
+
+ def forward(
+ self,
+ idx: torch.Tensor,
+ cond_idx: torch.Tensor, # cond_idx_or_embed
+ input_pos: Optional[torch.Tensor] = None,
+ targets: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ valid: Optional[torch.Tensor] = None,
+ target_h: Optional[torch.Tensor] = None,
+ target_w: Optional[torch.Tensor] = None,
+ ):
+
+ if idx is not None and cond_idx is not None: # training or naive inference
+ cond_embeddings = self.cls_embedding(cond_idx, train=self.training)[:, :self.cls_token_num]
+ token_embeddings = self.tok_embeddings(idx)
+
+ if self.dynamic:
+ target_h_embed = self.target_h_embedder(target_h).unsqueeze(1)
+ target_w_embed = self.target_w_embedder(target_w).unsqueeze(1)
+
+ cond_embeddings = cond_embeddings + target_h_embed + target_w_embed
+
+ token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
+ h = self.tok_dropout(token_embeddings)
+ self.freqs_cis = self.freqs_cis.to(h.device)
+
+ else:
+ if cond_idx is not None: # prefill in inference
+ token_embeddings = self.cls_embedding(cond_idx, train=self.training)[:, :self.cls_token_num]
+ if self.dynamic:
+ target_h_embed = self.target_h_embedder(target_h).unsqueeze(1)
+ target_w_embed = self.target_w_embedder(target_w).unsqueeze(1)
+
+ token_embeddings = token_embeddings + target_h_embed + target_w_embed
+
+ else: # decode_n_tokens(kv cache) in inference
+ token_embeddings = self.tok_embeddings(idx)
+
+ bs = token_embeddings.shape[0]
+ mask = self.causal_mask[:bs, None, input_pos]
+ h = self.tok_dropout(token_embeddings)
+ self.freqs_cis = self.freqs_cis.to(h.device)  # keep the rotary cache on the same device as the activations
+
+ if self.training:
+ freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
+ else:
+ freqs_cis = self.freqs_cis[input_pos]
+ # transformer blocks
+ for layer in self.layers:
+ h = layer(h, freqs_cis, input_pos, mask)
+
+ # output layers
+ h = self.norm(h)
+ if targets is None:
+ return h
+ else:
+ logits = self.output.forward_train(h, targets=targets)
+
+ if self.training:
+ logits = logits[:, self.cls_token_num - 1:].contiguous()
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if valid is not None:
+ loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+ valid_all = valid[:, None].repeat(1, targets.shape[1]).view(-1)
+ loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
+ elif targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
+
+ def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
+ return list(self.layers)
+
+
+#################################################################################
+# Rotary Positional Embedding Functions #
+#################################################################################
+# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
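+ """Builds a 1D rotary cache of shape (cls_token_num + seq_len, n_elem // 2, 2),
+ with zero-filled rows prepended for the conditioning tokens."""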
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) # (seq_len, head_dim // 2, 2)
+ cond_cache = torch.cat(
+ [torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2)
+ return cond_cache
+
+
+def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
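+ """Builds a 2D rotary cache for a grid_size x grid_size grid in the same layout
+ as `precompute_freqs_cis`; the Transformer currently uses the 1D variant."""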
+ # split the dimension into half, one for x and one for y
+ half_dim = n_elem // 2
+ freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
+ t = torch.arange(grid_size, device=freqs.device)
+ freqs = torch.outer(t, freqs) # (grid_size, head_dim // 4)
+ freqs_grid = torch.concat([
+ freqs[:, None, :].expand(-1, grid_size, -1),
+ freqs[None, :, :].expand(grid_size, -1, -1),
+ ], dim=-1) # (grid_size, grid_size, head_dim // 2)
+ cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)],
+ dim=-1) # (grid_size, grid_size, head_dim // 2, 2)
+ cache = cache_grid.flatten(0, 1)
+ cond_cache = torch.cat(
+ [torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2)
+ return cond_cache
+
+
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
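+ """Rotates each (real, imag) channel pair of x by the cached cos/sin values."""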
+ # x: (bs, seq_len, n_head, head_dim)
+ # freqs_cis (seq_len, head_dim // 2, 2)
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
+ x_out2 = torch.stack([
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ], dim=-1)
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
+
+
+#################################################################################
+# GPT Configs #
+#################################################################################
+### text-conditional
+def GPT_7B(**kwargs):
+ return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B
+
+
+def GPT_3B(**kwargs):
+ return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B
+
+
+def GPT_1B(**kwargs):
+ return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B
+
+
+### class-conditional
+def GPT_XXXL(**kwargs):
+ return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B
+
+
+def GPT_XXL(**kwargs):
+ return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B
+
+
+def GPT_XL(**kwargs):
+ return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M
+
+
+def GPT_L(**kwargs):
+ return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M
+
+
+def GPT_B(**kwargs):
+ return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M
+
+
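+# Registry mapping config names to constructors; keyword arguments are
+# forwarded to ModelArgs, e.g. GPT_models['GPT-B'](vocab_size=..., num_classes=...).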
+GPT_models = {
+ 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
+ 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
+}
\ No newline at end of file