diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e0fc2128b3abb03b579f96f612179ed758a9087a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/example_1.png filter=lfs diff=lfs merge=lfs -text +assets/generated_images.png filter=lfs diff=lfs merge=lfs -text +assets/reconstructed.png filter=lfs diff=lfs merge=lfs -text +assets/teaser.png filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6768d3d8c3e16a0313dc7800eb5cad1241f2eb43 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Maitreya Patel + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 4304c3dea1a8ee2d6b911796730dac72685d98b8..a2ff51f361c0a484aa015443eb09a370c6811b68 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,179 @@ +# [CVPR 2026] VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations + +

+<div align="center">
+  <img src="assets/teaser.png" alt="VibeToken Teaser">
+</div>
+
+<div align="center">
+  <b>CVPR 2026</b> &nbsp;|&nbsp; Paper &nbsp;|&nbsp; Project Page &nbsp;|&nbsp; <a href="https://huggingface.co/mpatel57/VibeToken">Checkpoints</a>
+</div>
+
+<div align="center">
+  CVPR 2026 &middot; arXiv &middot; License: MIT &middot; 🤗 Hugging Face
+</div>

+ --- -title: VibeToken -emoji: 🦀 -colorFrom: blue -colorTo: red -sdk: gradio -sdk_version: 6.6.0 -python_version: '3.12' -app_file: app.py -pinned: false -license: mit ---- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +We introduce an efficient, resolution-agnostic autoregressive (AR) image synthesis approach that generalizes to **arbitrary resolutions and aspect ratios**, narrowing the gap to diffusion models at scale. At its core is **VibeToken**, a novel resolution-agnostic 1D Transformer-based image tokenizer that encodes images into a dynamic, user-controllable sequence of 32--256 tokens, achieving state-of-the-art efficiency and performance trade-off. Building on VibeToken, we present **VibeToken-Gen**, a class-conditioned AR generator with out-of-the-box support for arbitrary resolutions while requiring significantly fewer compute resources. + +### 🔥 Highlights + +| | | +|---|---| +| 🎯 **1024×1024 in just 64 tokens** | Achieves **3.94 gFID** vs. 5.87 gFID for diffusion-based SOTA (1,024 tokens) | +| ⚡ **Constant 179G FLOPs** | 63× more efficient than LlamaGen (11T FLOPs at 1024×1024) | +| 🌐 **Resolution-agnostic** | Supports arbitrary resolutions and aspect ratios out of the box | +| 🎛️ **Dynamic token count** | User-controllable 32--256 tokens per image | +| 🔍 **Native super-resolution** | Supports image super-resolution out of the box | + + +## 📰 News + +- **[Feb 2026]** 🎉 VibeToken is accepted at **CVPR 2026**! +- **[Feb 2026]** Training scripts released. +- **[Feb 2026]** Inference code and checkpoints released. + + +## 🚀 Quick Start + +```bash +# 1. Clone and setup +git clone https://github.com//VibeToken.git +cd VibeToken +uv venv --python=3.11.6 +source .venv/bin/activate +uv pip install -r requirements.txt + +# 2. Download a checkpoint (see Checkpoints section below) +mkdir -p checkpoints +wget https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_LL.bin -O ./checkpoints/VibeToken_LL.bin + +# 3. Reconstruct an image +python reconstruct.py --auto \ + --config configs/vibetoken_ll.yaml \ + --checkpoint ./checkpoints/VibeToken_LL.bin \ + --image ./assets/example_1.png \ + --output ./assets/reconstructed.png +``` + + +## 📦 Checkpoints + +All checkpoints are hosted on [Hugging Face](https://huggingface.co/mpatel57/VibeToken). 
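If you prefer Python over `wget`, the snippet below is a minimal sketch of fetching the same files with `huggingface_hub` (it mirrors what `app.py` does; the repo id and filenames are the ones listed in the tables below):

```python
from huggingface_hub import hf_hub_download

REPO_ID = "mpatel57/VibeToken"

# Files land in the local Hugging Face cache; the returned paths can be passed
# directly to --vq-ckpt / --gpt-ckpt or copied into ./checkpoints/.
vq_path = hf_hub_download(repo_id=REPO_ID, filename="VibeToken_LL.bin")
gpt_path = hf_hub_download(repo_id=REPO_ID, filename="VibeTokenGen-xxl-dynamic-65_750k.pt")
print(vq_path, gpt_path)
```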
+ +#### Reconstruction Checkpoints + +| Name | Resolution | rFID (256 tokens) | rFID (64 tokens) | Download | +|------|:----------:|:-----------------:|:----------------:|----------| +| VibeToken-LL | 1024×1024 | 3.76 | 4.12 | [VibeToken_LL.bin](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_LL.bin) | +| VibeToken-LL | 256×256 | 5.12 | 0.90 | same as above | +| VibeToken-SL | 1024×1024 | 4.25 | 2.41 | [VibeToken_SL.bin](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeToken_SL.bin) | +| VibeToken-SL | 256×256 | 5.44 | 0.40 | same as above | + +#### Generation Checkpoints + +| Name | Training Resolution(s) | Tokens | Best gFID | Download | +|------|:----------------------:|:------:|:---------:|----------| +| VibeToken-Gen-B | 256×256 | 65 | 7.62 | [VibeTokenGen-b-fixed65_dynamic_1500k.pt](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeTokenGen-b-fixed65_dynamic_1500k.pt) | +| VibeToken-Gen-B | 1024×1024 | 65 | 7.37 | same as above | +| VibeToken-Gen-XXL | 256×256 | 65 | 3.62 | [VibeTokenGen-xxl-dynamic-65_750k.pt](https://huggingface.co/mpatel57/VibeToken/resolve/main/VibeTokenGen-xxl-dynamic-65_750k.pt) | +| VibeToken-Gen-XXL | 1024×1024 | 65 | **3.54** | same as above | + + +## 🛠️ Setup + +```bash +uv venv --python=3.11.6 +source .venv/bin/activate +uv pip install -r requirements.txt +``` + +> **Tip:** If you don't have `uv`, install it via `pip install uv` or see [uv docs](https://github.com/astral-sh/uv). Alternatively, use `python -m venv .venv && pip install -r requirements.txt`. + + +## 🖼️ VibeToken Reconstruction + +Download the VibeToken-LL checkpoint (see [Checkpoints](#-checkpoints)), then: + +```bash +# Auto mode (recommended) -- automatically determines optimal patch sizes +python reconstruct.py --auto \ + --config configs/vibetoken_ll.yaml \ + --checkpoint ./checkpoints/VibeToken_LL.bin \ + --image ./assets/example_1.png \ + --output ./assets/reconstructed.png + +# Manual mode -- specify patch sizes explicitly +python reconstruct.py \ + --config configs/vibetoken_ll.yaml \ + --checkpoint ./checkpoints/VibeToken_LL.bin \ + --image ./assets/example_1.png \ + --output ./assets/reconstructed.png \ + --encoder_patch_size 16 \ + --decoder_patch_size 16 +``` + +> **Note:** For best performance, the input image resolution should be a multiple of 32. Images with other resolutions are automatically rescaled to the nearest multiple of 32. + + +## 🎨 VibeToken-Gen: ImageNet-1k Generation + +Download both the VibeToken-LL and VibeToken-Gen-XXL checkpoints (see [Checkpoints](#-checkpoints)), then: + +```bash +python generate.py \ + --gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \ + --gpt-model GPT-XXL --num-output-layer 4 \ + --num-codebooks 8 --codebook-size 32768 \ + --image-size 256 --cfg-scale 4.0 --top-k 500 --temperature 1.0 \ + --class-dropout-prob 0.1 \ + --extra-layers "QKV" \ + --latent-size 65 \ + --config ./configs/vibetoken_ll.yaml \ + --vq-ckpt ./checkpoints/VibeToken_LL.bin \ + --sample-dir ./assets/ \ + --skip-folder-creation \ + --compile \ + --decoder-patch-size 32,32 \ + --target-resolution 1024,1024 \ + --llamagen-target-resolution 256,256 \ + --precision bf16 \ + --global-seed 156464151 +``` + +The `--target-resolution` controls the tokenizer output resolution, while `--llamagen-target-resolution` controls the generator's internal resolution (max 512×512; for higher resolutions, the tokenizer handles upscaling). 
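For programmatic use, the sketch below condenses the loading and sampling path from `app.py` with the same settings as the command above. Treat it as illustrative rather than a supported API: the `GPT_models` arguments, the `generate(...)` and `decode(...)` signatures, the checkpoint key `"model"`, and the `1536` resolution normalizer are taken from `app.py` and may change.

```python
import torch
from vibetoken import VibeTokenTokenizer
from vibetokengen.model import GPT_models
from vibetokengen.generate import generate

device, dtype = "cuda", torch.bfloat16

# Tokenizer (VibeToken-LL) and class-conditional AR generator (VibeToken-Gen-XXL).
vq_model = VibeTokenTokenizer.from_config(
    "configs/vibetoken_ll.yaml", "./checkpoints/VibeToken_LL.bin", device=device, dtype=dtype)
gpt_model = GPT_models["GPT-XXL"](
    vocab_size=32768, block_size=65, num_classes=1000, cls_token_num=1,
    model_type="c2i", num_codebooks=8, n_output_layer=4,
    class_dropout_prob=0.1, extra_layers="QKV", capping=50.0,
).to(device=device, dtype=dtype)
ckpt = torch.load("./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt",
                  map_location="cpu", weights_only=False)
gpt_model.load_state_dict(ckpt.get("model", ckpt))
gpt_model.eval()

# Sample one image of class 207 (Golden Retriever): generate at 256x256, then
# decode the 65 tokens at 1024x1024 (native super-resolution).
cls = torch.tensor([207], device=device)
# Resolution conditioning is normalized by 1536 and duplicated along dim 0 for CFG.
th = torch.full((2, 1), 256 / 1536, device=device, dtype=dtype)
tw = torch.full((2, 1), 256 / 1536, device=device, dtype=dtype)
with torch.inference_mode():
    tokens = generate(gpt_model, cls, 65, 8, cfg_scale=4.0, cfg_interval=-1,
                      target_h=th, target_w=tw, temperature=1.0, top_k=500,
                      top_p=1.0, sample_logits=True)
    image = vq_model.decode(tokens.unsqueeze(2), height=1024, width=1024,
                            patch_size=(32, 32)).clamp(0, 1)
```

Decoding the same tokens at a different `height`/`width` than the generator resolution is exactly what the `--target-resolution` vs. `--llamagen-target-resolution` split exposes on the command line.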
+ + +## 🏋️ Training + +To train the VibeToken tokenizer from scratch, please refer to [TRAIN.md](TRAIN.md) for detailed instructions. + + +## 🙏 Acknowledgement + +We would like to acknowledge the following repositories that inspired our work and upon which we directly build: +[1d-tokenizer](https://github.com/bytedance/1d-tokenizer), +[LlamaGen](https://github.com/FoundationVision/LlamaGen), and +[UniTok](https://github.com/FoundationVision/UniTok). + + +## 📝 Citation + +If you find VibeToken useful in your research, please consider citing: + +```bibtex +@inproceedings{vibetoken2026, + title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations}, + author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2026} +} +``` + +If you have any questions, feel free to open an issue or reach out! diff --git a/TRAIN.md b/TRAIN.md new file mode 100644 index 0000000000000000000000000000000000000000..3ca4fa9a9ba049934fcc98064346b794e26a59fe --- /dev/null +++ b/TRAIN.md @@ -0,0 +1,110 @@ +# Training Instructions + +## VibeToken MVQ Tokenizer + +This repository contains the training code for our tokenizer. +We provide the example config [VibeToken-Small](configs/training/VibeToken_small.yaml) that trains the small encoder/decoder architecture with 32-64 tokens. + +### Data Preparation + +All data paths are controlled by the `DATA_DIR` environment variable. Set it once to point to your preferred storage location: + +```bash +export DATA_DIR=/path/to/your/storage # defaults to ./data if unset +``` + +Download ImageNet-1k and convert to WebDataset format: + +```bash +source .venv/bin/activate + +# Option 1: Use the setup script (recommended) +bash setup.sh + +# Option 2: Run steps manually +export HF_HUB_ENABLE_HF_TRANSFER=1 +huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir "${DATA_DIR}/imagenet-1k" +python data/convert_imagenet_to_wds.py \ + --input_dir "${DATA_DIR}/imagenet-1k" \ + --output_dir "${DATA_DIR}/imagenet_wds" +``` + +After preparation, update the shard paths in your training config to match your `DATA_DIR`: + +```yaml +dataset: + params: + train_shards_path_or_url: "/imagenet_wds/imagenet-train-{000001..000128}.tar" + eval_shards_path_or_url: "/imagenet_wds/imagenet-val-{000001..000004}.tar" +``` + +### Launch Training + +Start training on 1 node with 8 GPUs: + +```bash +source .venv/bin/activate +bash train_tokenizer.sh +``` + +### Config Reference + +Below are the important hyperparameters to manage the training. 
+ +```yaml +model: + vq_model: + vit_enc_model_size: "small" # this can be small/base/large + vit_dec_model_size: "small" # this can be small/base/large + num_latent_tokens: 64 # in paper we set this to 256 + +losses: + discriminator_start: 100_000 # set based on convergence, in paper we set this to 250_000 + +dataset: + params: + pretokenization: True # keep this true if using the current setup + train_shards_path_or_url: "./data/imagenet_wds/imagenet-train-{000001..000128}.tar" + eval_shards_path_or_url: "./data/imagenet_wds/imagenet-val-{000001..000004}.tar" + preprocessing: + resize_shorter_edge: 512 # maximum size during pretraining but can be any value + crop_size: 512 # maximum size during pretraining but can be any value + min_tokens: 32 # minimum number of tokens to generate + max_tokens: 64 # maximum number of tokens to generate + +training: + gradient_accumulation_steps: 1 # increase for LL model that does not fit on single node + per_gpu_batch_size: 32 # decrease to 16 for LL model; during GAN training this is halved + max_train_steps: 400_000 # in paper we train up to 650_000; model may diverge after 600_000 + num_generated_images: 2 # for validation + variable_resolution: # any-to-any resolution training + any2any: True + dim: + - [256, 256] + - [512, 512] + - [384, 256] + - [256, 384] + - [512, 384] + - [384, 512] + ratio: [0.3, 0.3, 0.1, 0.1, 0.1, 0.1] # probability per resolution; must sum to 1.0 + + +# Remove patch mixture parameters unless the model does not fit in memory. +# This will slow down training and may hurt performance. +# We do not use this in our normal setup. +model: + vq_model: + encoder: + patch_mixture_start_layer: 2 + patch_mixture_end_layer: 22 + decoder: + patch_mixture_start_layer: 2 + patch_mixture_end_layer: 22 +``` + + + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba34a59d354c11c378fe362ed38acfc0778908a --- /dev/null +++ b/app.py @@ -0,0 +1,471 @@ +""" +VibeToken-Gen Gradio Demo +Class-conditional ImageNet generation with dynamic resolution support. 
+""" +import spaces + +import os +import random + +import gradio as gr +import numpy as np +import torch + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("high") +torch.set_grad_enabled(False) +setattr(torch.nn.Linear, "reset_parameters", lambda self: None) +setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + +from huggingface_hub import hf_hub_download +from PIL import Image + +from vibetokengen.generate import generate +from vibetokengen.model import GPT_models +from vibetoken import VibeTokenTokenizer + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +HF_REPO = "mpatel57/VibeToken" +USE_XXL = os.environ.get("VIBETOKEN_XXL", "0") == "1" + +if USE_XXL: + GPT_MODEL_NAME = "GPT-XXL" + GPT_CKPT_FILENAME = "VibeTokenGen-xxl-dynamic-65_750k.pt" + NUM_OUTPUT_LAYER = 4 + EXTRA_LAYERS = "QKV" +else: + GPT_MODEL_NAME = "GPT-B" + GPT_CKPT_FILENAME = "VibeTokenGen-b-fixed65_dynamic_1500k.pt" + NUM_OUTPUT_LAYER = 4 + EXTRA_LAYERS = "QKV" + +VQ_CKPT_FILENAME = "VibeToken_LL.bin" +CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "vibetoken_ll.yaml") + +CODEBOOK_SIZE = 32768 +NUM_CODEBOOKS = 8 +LATENT_SIZE = 65 +NUM_CLASSES = 1000 +CLS_TOKEN_NUM = 1 +CLASS_DROPOUT_PROB = 0.1 +CAPPING = 50.0 + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32 +COMPILE = os.environ.get("VIBETOKEN_NO_COMPILE", "0") != "1" and DEVICE == "cuda" + +# --------------------------------------------------------------------------- +# ImageNet class labels (curated popular subset) +# --------------------------------------------------------------------------- + +IMAGENET_CLASSES = { + "Golden Retriever": 207, + "Labrador Retriever": 208, + "German Shepherd": 235, + "Siberian Husky": 250, + "Pembroke Corgi": 263, + "Tabby Cat": 281, + "Persian Cat": 283, + "Siamese Cat": 284, + "Tiger": 292, + "Lion": 291, + "Cheetah": 293, + "Brown Bear": 294, + "Giant Panda": 388, + "Red Fox": 277, + "Arctic Fox": 279, + "Timber Wolf": 269, + "Bald Eagle": 22, + "Macaw": 88, + "Flamingo": 130, + "Peacock": 84, + "Goldfish": 1, + "Great White Shark": 2, + "Jellyfish": 107, + "Monarch Butterfly": 323, + "Ladybug": 301, + "Snail": 113, + "Red Sports Car": 817, + "School Bus": 779, + "Steam Locomotive": 820, + "Sailboat": 914, + "Space Shuttle": 812, + "Castle": 483, + "Church": 497, + "Lighthouse": 437, + "Volcano": 980, + "Lakeside": 975, + "Cliff": 972, + "Coral Reef": 973, + "Valley": 979, + "Seashore": 978, + "Mushroom": 947, + "Broccoli": 937, + "Pizza": 963, + "Ice Cream": 928, + "Cheeseburger": 933, + "Espresso": 967, + "Acoustic Guitar": 402, + "Grand Piano": 579, + "Violin": 889, + "Balloon": 417, +} + +GENERATOR_RESOLUTION_PRESETS = { + "256 × 256": (256, 256), + "384 × 256": (384, 256), + "256 × 384": (256, 384), + "384 × 384": (384, 384), + "512 × 256": (512, 256), + "256 × 512": (256, 512), + "512 × 512": (512, 512), +} + +OUTPUT_RESOLUTION_PRESETS = { + "Same as generator": None, + "256 × 256": (256, 256), + "384 × 384": (384, 384), + "512 × 512": (512, 512), + "768 × 768": (768, 768), + "1024 × 1024": (1024, 1024), + "512 × 256 (2:1)": (512, 256), + "256 × 512 (1:2)": (256, 512), + "768 × 512 (3:2)": (768, 512), + "512 × 768 (2:3)": (512, 768), + "1024 × 512 (2:1)": (1024, 512), + "512 × 1024 (1:2)": (512, 1024), +} + +# 
--------------------------------------------------------------------------- +# Model loading +# --------------------------------------------------------------------------- + +vq_model = None +gpt_model = None + + +def download_checkpoint(filename: str) -> str: + return hf_hub_download(repo_id=HF_REPO, filename=filename) + + +def _make_res_tensors(gen_h: int, gen_w: int, multiplier: int): + """Create normalized resolution tensors for the GPT generator.""" + th = torch.tensor(gen_h / 1536, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(multiplier, 1) + tw = torch.tensor(gen_w / 1536, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(multiplier, 1) + return th, tw + + +def _warmup(model): + """Run a throwaway generation to trigger torch.compile and warm CUDA caches.""" + print("Warming up (first call triggers compilation, may take ~30-60s)...") + dummy_cond = torch.tensor([0], device=DEVICE) + th, tw = _make_res_tensors(256, 256, multiplier=2) + with torch.inference_mode(): + generate( + model, dummy_cond, LATENT_SIZE, NUM_CODEBOOKS, + cfg_scale=4.0, cfg_interval=-1, + target_h=th, target_w=tw, + temperature=1.0, top_k=500, top_p=1.0, sample_logits=True, + ) + if DEVICE == "cuda": + torch.cuda.synchronize() + print("Warmup complete — subsequent generations will be fast.") + + +def load_models(): + global vq_model, gpt_model + + print("Downloading checkpoints (if needed)...") + vq_path = download_checkpoint(VQ_CKPT_FILENAME) + gpt_path = download_checkpoint(GPT_CKPT_FILENAME) + + print(f"Loading VibeToken tokenizer from {vq_path}...") + vq_model = VibeTokenTokenizer.from_config( + CONFIG_PATH, vq_path, device=DEVICE, dtype=DTYPE, + ) + print("VibeToken tokenizer loaded.") + + print(f"Loading {GPT_MODEL_NAME} from {gpt_path}...") + gpt_model = GPT_models[GPT_MODEL_NAME]( + vocab_size=CODEBOOK_SIZE, + block_size=LATENT_SIZE, + num_classes=NUM_CLASSES, + cls_token_num=CLS_TOKEN_NUM, + model_type="c2i", + num_codebooks=NUM_CODEBOOKS, + n_output_layer=NUM_OUTPUT_LAYER, + class_dropout_prob=CLASS_DROPOUT_PROB, + extra_layers=EXTRA_LAYERS, + capping=CAPPING, + ).to(device=DEVICE, dtype=DTYPE) + + checkpoint = torch.load(gpt_path, map_location="cpu", weights_only=False) + if "model" in checkpoint: + weights = checkpoint["model"] + elif "module" in checkpoint: + weights = checkpoint["module"] + elif "state_dict" in checkpoint: + weights = checkpoint["state_dict"] + else: + weights = checkpoint + gpt_model.load_state_dict(weights, strict=True) + gpt_model.eval() + del checkpoint + print(f"{GPT_MODEL_NAME} loaded.") + + if COMPILE: + print("Compiling GPT model with torch.compile (max-autotune)...") + gpt_model = torch.compile(gpt_model, mode="max-autotune", fullgraph=True) + _warmup(gpt_model) + else: + print("Skipping torch.compile (set VIBETOKEN_NO_COMPILE=0 to enable).") + + +# --------------------------------------------------------------------------- +# Decoder patch-size heuristic +# --------------------------------------------------------------------------- + +def auto_decoder_patch_size(h: int, w: int) -> tuple[int, int]: + max_dim = max(h, w) + if max_dim <= 256: + ps = 8 + elif max_dim <= 512: + ps = 16 + else: + ps = 32 + return (ps, ps) + + +# --------------------------------------------------------------------------- +# Generation +# --------------------------------------------------------------------------- + +@torch.inference_mode() +@spaces.GPU(duration=90) +def generate_image( + class_name: str, + class_id: int, + gen_resolution_preset: str, + out_resolution_preset: str, + 
decoder_ps_choice: str, + cfg_scale: float, + temperature: float, + top_k: int, + top_p: float, + seed: int, + randomize_seed: bool, +): + if vq_model is None or gpt_model is None: + raise gr.Error("Models are still loading. Please wait a moment and try again.") + + if randomize_seed: + seed = random.randint(0, 2**31 - 1) + + torch.manual_seed(seed) + np.random.seed(seed) + if DEVICE == "cuda": + torch.cuda.manual_seed_all(seed) + + if class_name and class_name != "Custom (enter ID below)": + cid = IMAGENET_CLASSES[class_name] + else: + cid = int(class_id) + cid = max(0, min(cid, NUM_CLASSES - 1)) + + gen_h, gen_w = GENERATOR_RESOLUTION_PRESETS[gen_resolution_preset] + + out_res = OUTPUT_RESOLUTION_PRESETS[out_resolution_preset] + if out_res is None: + out_h, out_w = gen_h, gen_w + else: + out_h, out_w = out_res + + if decoder_ps_choice == "Auto": + dec_ps = auto_decoder_patch_size(out_h, out_w) + else: + ps = int(decoder_ps_choice) + dec_ps = (ps, ps) + + multiplier = 2 if cfg_scale > 1.0 else 1 + + c_indices = torch.tensor([cid], device=DEVICE) + th, tw = _make_res_tensors(gen_h, gen_w, multiplier) + + index_sample = generate( + gpt_model, + c_indices, + LATENT_SIZE, + NUM_CODEBOOKS, + cfg_scale=cfg_scale, + cfg_interval=-1, + target_h=th, + target_w=tw, + temperature=temperature, + top_k=top_k, + top_p=top_p, + sample_logits=True, + ) + + index_sample = index_sample.unsqueeze(2) + samples = vq_model.decode( + index_sample, + height=out_h, + width=out_w, + patch_size=dec_ps, + ) + samples = torch.clamp(samples, 0, 1) + + img_np = (samples[0].permute(1, 2, 0).float().cpu().numpy() * 255).astype("uint8") + pil_img = Image.fromarray(img_np) + + return pil_img, seed + + +# --------------------------------------------------------------------------- +# Gradio UI +# --------------------------------------------------------------------------- + +HEADER_MD = """ +# VibeToken-Gen: Dynamic Resolution Image Generation + +

+Maitreya Patel, Jingtao Li, Weiming Zhuang, Yezhou Yang, Lingjuan Lyu
+
+CVPR 2026 (Main Conference)
+
+[🤗 Model](https://huggingface.co/mpatel57/VibeToken) &nbsp;|&nbsp; 💻 GitHub

+ +Generate ImageNet class-conditional images at **arbitrary resolutions** using only **65 tokens**. +VibeToken-Gen maintains a constant **179G FLOPs** regardless of output resolution. +""" + +CITATION_MD = """ +### Citation +```bibtex +@inproceedings{vibetoken2026, + title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations}, + author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2026} +} +``` +""" + +class_choices = ["Custom (enter ID below)"] + sorted(IMAGENET_CLASSES.keys()) + +with gr.Blocks( + title="VibeToken-Gen Demo", + theme=gr.themes.Soft(), +) as demo: + gr.Markdown(HEADER_MD) + + with gr.Row(): + # ---- Left column: controls ---- + with gr.Column(scale=1): + class_dropdown = gr.Dropdown( + label="ImageNet Class", + choices=class_choices, + value="Golden Retriever", + info="Pick a class or choose 'Custom' to enter an ID manually.", + ) + class_id_input = gr.Number( + label="Custom Class ID (0–999)", + value=207, + minimum=0, + maximum=999, + step=1, + visible=False, + ) + gen_resolution_dropdown = gr.Dropdown( + label="Generator Resolution", + choices=list(GENERATOR_RESOLUTION_PRESETS.keys()), + value="256 × 256", + info="Internal resolution for the AR generator (max 512×512).", + ) + out_resolution_dropdown = gr.Dropdown( + label="Output Resolution (Decoder)", + choices=list(OUTPUT_RESOLUTION_PRESETS.keys()), + value="Same as generator", + info="Final image resolution. Set higher for super-resolution (e.g. generate at 256, decode at 1024).", + ) + decoder_ps_dropdown = gr.Dropdown( + label="Decoder Patch Size", + choices=["Auto", "8", "16", "32"], + value="Auto", + info="'Auto' selects based on output resolution. 
Larger = faster but coarser.", + ) + + with gr.Accordion("Advanced Sampling Parameters", open=False): + cfg_slider = gr.Slider( + label="CFG Scale", + minimum=1.0, maximum=20.0, value=4.0, step=0.5, + info="Classifier-free guidance strength.", + ) + temp_slider = gr.Slider( + label="Temperature", + minimum=0.1, maximum=2.0, value=1.0, step=0.05, + ) + topk_slider = gr.Slider( + label="Top-k", + minimum=0, maximum=2000, value=500, step=10, + info="0 disables top-k filtering.", + ) + topp_slider = gr.Slider( + label="Top-p", + minimum=0.0, maximum=1.0, value=1.0, step=0.05, + info="1.0 disables nucleus sampling.", + ) + seed_input = gr.Number( + label="Seed", value=0, minimum=0, maximum=2**31 - 1, step=1, + ) + randomize_cb = gr.Checkbox(label="Randomize seed", value=True) + + generate_btn = gr.Button("Generate", variant="primary", size="lg") + + # ---- Right column: output ---- + with gr.Column(scale=2): + output_image = gr.Image(label="Generated Image", type="pil", height=512) + used_seed = gr.Number(label="Seed used", interactive=False) + + # Show/hide custom class ID field + def toggle_custom_id(choice): + return gr.update(visible=(choice == "Custom (enter ID below)")) + + class_dropdown.change( + fn=toggle_custom_id, + inputs=[class_dropdown], + outputs=[class_id_input], + ) + + generate_btn.click( + fn=generate_image, + inputs=[ + class_dropdown, + class_id_input, + gen_resolution_dropdown, + out_resolution_dropdown, + decoder_ps_dropdown, + cfg_slider, + temp_slider, + topk_slider, + topp_slider, + seed_input, + randomize_cb, + ], + outputs=[output_image, used_seed], + ) + + gr.Markdown(CITATION_MD) + + +if __name__ == "__main__": + load_models() + demo.launch() diff --git a/assets/example_1.png b/assets/example_1.png new file mode 100644 index 0000000000000000000000000000000000000000..02baa35758b74ca837670d8fbfee94c0c9f62063 --- /dev/null +++ b/assets/example_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da07f6cd58181e54c6b4bbbc0458d99f48946da19a27ea02e9a3920bfb2b5d15 +size 334096 diff --git a/assets/generated_images.png b/assets/generated_images.png new file mode 100644 index 0000000000000000000000000000000000000000..c162f8a3610735f4295df42285d8061f5b6ed409 --- /dev/null +++ b/assets/generated_images.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1f5cdef18942f6331460be883485ce5141c7d9c6db7e1cd7596422a57b5cba7 +size 11085224 diff --git a/assets/reconstructed.png b/assets/reconstructed.png new file mode 100644 index 0000000000000000000000000000000000000000..b8c8f0317a2afe89ef7a49b24727818d8b55141a --- /dev/null +++ b/assets/reconstructed.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86b57a2d196f6af62b37979b15a7dcda8de1097669ccb297d9213b32098d5873 +size 343690 diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..5bd55143ee3ba7080e1bb3e5ddda5fb2eee5aa9a --- /dev/null +++ b/assets/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46bd1ff58c18d17b2808ba9445d6beab0ae9098be21b316e9ed730d553b607fb +size 2859418 diff --git a/configs/training/VibeToken_small.yaml b/configs/training/VibeToken_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4922a7b450f3db0bc69ae24a27cd11e981a4fd2 --- /dev/null +++ b/configs/training/VibeToken_small.yaml @@ -0,0 +1,102 @@ +experiment: + project: "VibeToken_mvq_tiny_main" + name: "VibeToken_mvq_tiny_main" + output_dir: 
"wandb/VibeToken_mvq_tiny_main" + max_train_examples: 1_281_167 + save_every: 10_000 + eval_every: 10_000 + generate_every: 5_000 + log_every: 50 + log_grad_norm_every: 1_000 + resume: True + +model: + sub_model_type: "vibetoken" + train_with_attention: True + eval_with_attention: True + vq_model: + # encoder: # patch mixture is not supported + # patch_mixture_start_layer: 2 + # patch_mixture_end_layer: 22 + # decoder: # patch mixture is not supported + # patch_mixture_start_layer: 2 + # patch_mixture_end_layer: 22 + quantize_mode: mvq + codebook_size: 32768 # 32768 / 8 = 4096 + token_size: 256 # 256 / 8 = 32 + use_l2_norm: False + commitment_cost: 0.25 + clustering_vq: False + num_codebooks: 8 + # vit arch + vit_enc_model_size: "small" + vit_dec_model_size: "small" + vit_enc_patch_size: 32 + vit_dec_patch_size: 32 + num_latent_tokens: 64 + finetune_decoder: False + is_legacy: False + +losses: + discriminator_start: 100_000 + quantizer_weight: 1.0 + discriminator_factor: 1.0 + discriminator_weight: 0.1 + perceptual_loss: "lpips-convnext_s-1.0-0.1" + perceptual_weight: 1.1 + reconstruction_loss: "l2" + reconstruction_weight: 1.0 + lecam_regularization_weight: 0.001 + +dataset: + params: + pretokenization: True + train_shards_path_or_url: "./data/imagenet_wds/imagenet-train-{000001..000128}.tar" + eval_shards_path_or_url: "./data/imagenet_wds/imagenet-val-{000001..000004}.tar" + num_workers_per_gpu: 12 + preprocessing: + resize_shorter_edge: 512 + crop_size: 512 + random_crop: True + random_flip: True + res_ratio_filtering: True + min_tokens: 32 + max_tokens: 64 + +optimizer: + name: adamw + params: + learning_rate: 1e-4 + discriminator_learning_rate: 1e-4 + beta1: 0.9 + beta2: 0.999 + weight_decay: 1e-4 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 10_000 + end_lr: 1e-5 + +training: + gradient_accumulation_steps: 1 + per_gpu_batch_size: 32 + mixed_precision: "fp16" + enable_tf32: True + enable_wandb: True + use_ema: True + seed: 42 + max_train_steps: 400_000 + num_generated_images: 2 + max_grad_norm: 1.0 + variable_resolution: + any2any: True + dim: + - [256, 256] + - [512, 512] + - [384, 256] + - [256, 384] + - [512, 384] + - [384, 512] + ratio: [0.3, 0.3, 0.1, 0.1, 0.1, 0.1] \ No newline at end of file diff --git a/configs/vibetoken_ll.yaml b/configs/vibetoken_ll.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c69a2440fb3f0e528d7eeee20fb6f2aa0bc2e2e --- /dev/null +++ b/configs/vibetoken_ll.yaml @@ -0,0 +1,40 @@ +# VibeToken Large-Large Configuration +# Large encoder + Large decoder for highest quality +# +# Usage: +# from vibetoken import VibeTokenTokenizer +# tokenizer = VibeTokenTokenizer.from_config( +# "configs/vibetoken_ll.yaml", +# "path/to/checkpoint.bin" +# ) + +model: + sub_model_type: "vibetoken" + vq_model: + # Quantization settings + quantize_mode: mvq + codebook_size: 32768 # 32768 / 8 = 4096 per codebook + token_size: 256 # 256 / 8 = 32 per codebook + num_codebooks: 8 + use_l2_norm: false + commitment_cost: 0.25 + + # Encoder architecture + vit_enc_model_size: "large" + vit_enc_patch_size: 32 + + # Decoder architecture + vit_dec_model_size: "large" + vit_dec_patch_size: 32 + + # Latent tokens + num_latent_tokens: 256 + + # Mode flags + is_legacy: false + finetune_decoder: false + +# Dataset preprocessing defaults (for reference) +dataset: + preprocessing: + crop_size: 512 diff --git a/configs/vibetoken_sl.yaml b/configs/vibetoken_sl.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..4d0f6def50e0b6d01c3e554fd6c15ff1aa2197ad --- /dev/null +++ b/configs/vibetoken_sl.yaml @@ -0,0 +1,40 @@ +# VibeToken Small-Large Configuration +# Small encoder + Large decoder for faster encoding +# +# Usage: +# from vibetoken import VibeTokenTokenizer +# tokenizer = VibeTokenTokenizer.from_config( +# "configs/vibetoken_sl.yaml", +# "path/to/checkpoint.bin" +# ) + +model: + sub_model_type: "vibetoken" + vq_model: + # Quantization settings + quantize_mode: mvq + codebook_size: 32768 # 32768 / 8 = 4096 per codebook + token_size: 256 # 256 / 8 = 32 per codebook + num_codebooks: 8 + use_l2_norm: false + commitment_cost: 0.25 + + # Encoder architecture (Small for faster encoding) + vit_enc_model_size: "small" + vit_enc_patch_size: 32 + + # Decoder architecture (Large for quality) + vit_dec_model_size: "large" + vit_dec_patch_size: 32 + + # Latent tokens + num_latent_tokens: 256 + + # Mode flags + is_legacy: false + finetune_decoder: false + +# Dataset preprocessing defaults (for reference) +dataset: + preprocessing: + crop_size: 512 diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..283b24901f4bca6318dfe6f31b7ff1c39b0d5375 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1 @@ +from .webdataset_reader import SimpleImageDataset, PretoeknizedDataSetJSONL, PretokenizedWebDataset diff --git a/data/convert_imagenet_to_wds.py b/data/convert_imagenet_to_wds.py new file mode 100644 index 0000000000000000000000000000000000000000..d4c15cd898055f8d2339f9818fd1eadaf8291f83 --- /dev/null +++ b/data/convert_imagenet_to_wds.py @@ -0,0 +1,56 @@ +# Adapted from https://github.com/webdataset/webdataset-imagenet/blob/main/convert-imagenet.py + +import argparse +import os +import sys +import time + +import webdataset as wds +from datasets import load_dataset + + +def convert_imagenet_to_wds(input_dir, output_dir, max_train_samples_per_shard, max_val_samples_per_shard): + assert not os.path.exists(os.path.join(output_dir, "imagenet-train-000000.tar")) + assert not os.path.exists(os.path.join(output_dir, "imagenet-val-000000.tar")) + + opat = os.path.join(output_dir, "imagenet-train-%06d.tar") + output = wds.ShardWriter(opat, maxcount=max_train_samples_per_shard) + dataset = load_dataset(input_dir, split="train") + now = time.time() + for i, example in enumerate(dataset): + if i % max_train_samples_per_shard == 0: + print(i, file=sys.stderr) + img, label = example["image"], example["label"] + output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label}) + output.close() + time_taken = time.time() - now + print(f"Wrote {i+1} train examples in {time_taken // 3600} hours.") + + opat = os.path.join(output_dir, "imagenet-val-%06d.tar") + output = wds.ShardWriter(opat, maxcount=max_val_samples_per_shard) + dataset = load_dataset(input_dir, split="validation") + now = time.time() + for i, example in enumerate(dataset): + if i % max_val_samples_per_shard == 0: + print(i, file=sys.stderr) + img, label = example["image"], example["label"] + output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label}) + output.close() + time_taken = time.time() - now + print(f"Wrote {i+1} val examples in {time_taken // 60} min.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, + help="Path to the ImageNet-1k dataset (HuggingFace format).") + parser.add_argument("--output_dir", type=str, required=True, + help="Path to the 
output directory for WebDataset shards.") + parser.add_argument("--max_train_samples_per_shard", type=int, default=10000, + help="Maximum number of training samples per shard.") + parser.add_argument("--max_val_samples_per_shard", type=int, default=10000, + help="Maximum number of validation samples per shard.") + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + convert_imagenet_to_wds(args.input_dir, args.output_dir, args.max_train_samples_per_shard, args.max_val_samples_per_shard) \ No newline at end of file diff --git a/data/webdataset_reader.py b/data/webdataset_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..76b95ba61a5e24873ec82f5a60f8bc1f53309743 --- /dev/null +++ b/data/webdataset_reader.py @@ -0,0 +1,518 @@ +"""Data loader using webdataset. + +Reference: + https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py + https://github.com/huggingface/open-muse/blob/main/training/data.py +""" + +import math +from typing import List, Union, Text +import webdataset as wds +import numpy as np +import torch +from torch.utils.data import default_collate +from torchvision import transforms +from torch.utils.data import Dataset +import linecache +import json +from PIL import Image +import random +import cv2 +import numpy as np +from tqdm import tqdm + +Image.MAX_IMAGE_PIXELS = None + + +def load_json(sample): + sample['json'] = json.loads(sample['json'].decode('utf-8')) + return sample + + +def filter_keys(key_set): + def _f(dictionary): + return {k: v for k, v in dictionary.items() if k in key_set} + + return _f + + +def filter_by_res_ratio(min_res=256, min_ratio=0.5, max_ratio=2.0): + def _f(sample): + cfg = sample['json'] + h, w = cfg['original_height'], cfg['original_width'] + ratio = h/w + longer_side = max(h, w) + return ratio >= min_ratio and ratio <= max_ratio and longer_side >= min_res + return _f + +def calculate_laplacian_variance(image): + """Calculate the variance of Laplacian which is a measure of image sharpness/blur.""" + # Convert to grayscale if it's RGB + image = np.array(image) + if len(image.shape) == 3: + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + else: + gray = image + + # Calculate Laplacian + laplacian = cv2.Laplacian(gray, cv2.CV_64F) + + # Calculate variance + return laplacian.var() + +# Add this function to map Laplacian values to token lengths +def get_dynamic_length(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=256, mean_tokens=128): + """ + Maps Laplacian values to token lengths using a bell curve approach. + At the mean Laplacian value, uses mean_tokens. + Values further from the mean get mapped to shorter/longer token lengths. 
+ """ + # Prevent division by zero and handle edge cases + if std <= 0: + return mean_tokens + + # Calculate z-score + z_score = (laplacian_value - mean) / std + + # Use bell curve mapping (gaussian) + # When z_score is 0 (at mean), we get mean_tokens + # As z_score increases, token length increases toward max_tokens + # As z_score decreases, token length decreases toward min_tokens + scaling_factor = 2.0 # Controls how quickly we reach min/max tokens + normalized_position = 0.5 * (1 + math.tanh(scaling_factor * z_score)) + + # Map to token range [min_tokens, max_tokens] + token_length = min_tokens + normalized_position * (max_tokens - min_tokens) + return int(round(token_length)) + +# Add this function to map Laplacian values to token lengths +def get_dynamic_length_v2(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=128, mean_tokens=128): + """ + Maps Laplacian values to token lengths using a linear mapping. + Ensures laplacian_value=0 maps to min_tokens, mean maps to mean_tokens, + and higher values scale up to max_tokens. + """ + # Prevent division by zero and handle edge cases + if std <= 0: + return mean_tokens + + # Linear mapping from laplacian space to token space + # First normalize laplacian value relative to mean + normalized = (laplacian_value - 0.0) / mean + + # Map 0->min_tokens, mean->mean_tokens, and scale up linearly + if laplacian_value <= mean: + # Linear interpolation between min_tokens and mean_tokens + ratio = laplacian_value / mean + token_length = min_tokens + (mean_tokens - min_tokens) * ratio + else: + # Linear interpolation between mean_tokens and max_tokens + ratio = (laplacian_value - mean) / mean # How far past mean + token_length = mean_tokens + (max_tokens - mean_tokens) * ratio + + # Clamp to valid range + token_length = max(min_tokens, min(max_tokens, token_length)) + return int(round(token_length)) + +def get_laplacian_attention_mask(sample): + """Process sample to add Laplacian variance and attention mask.""" + # Create a new dict to avoid modifying the input + processed = dict(sample) + + # Calculate Laplacian variance + var = calculate_laplacian_variance(processed["image"]) + length = get_dynamic_length(var) + + # Create attention mask + attention_mask = torch.zeros((128,), dtype=torch.float32) + attention_mask[:length+1] = 1.0 + + # Add new fields to processed dict + processed["laplacian_var"] = var + processed["attention_mask"] = attention_mask + + return processed + +def get_uniform_attention_mask(min_tokens=32, max_tokens=128): + """Process sample to add uniform random attention mask.""" + def _f(dictionary): + # Sample length uniformly between min_tokens and max_tokens + length = torch.randint(min_tokens, max_tokens+1, (1,)).item() + + # Create attention mask + attention_mask = torch.zeros((max_tokens,), dtype=torch.float32) + attention_mask[:length+1] = 1.0 + + # Add attention mask to dictionary + dictionary["attention_mask"] = attention_mask + return dictionary + return _f + +def process_recap_text(p): + def _f(dictionary): + if "recap_txt" in dictionary: + if random.random() < p: + recap_prefixes = ["The image " + v for v in ['depicts', "displays", 'showcases', 'features', 'shows']] + # Convert input to string and strip whitespace + text = dictionary["recap_txt"].decode("utf-8").strip() + # Check if text starts with any of the phrases + for phrase in recap_prefixes: + if text.startswith(phrase): + # Remove the phrase and any leading/trailing whitespace + text = text[len(phrase):].strip() + # Capitalize the first letter + text = 
text[0].upper() + text[1:] if text else "" + break + + dictionary["text"] = text.encode("utf-8") + return dictionary + + return _f + + +def identity(x): + return x + + +class ImageTransform: + def __init__(self, + resize_shorter_edge: int = 256, + crop_size: int = 256, + random_crop: bool = True, + random_flip: bool = True, + normalize_mean: List[float] = [0., 0., 0.], + normalize_std: List[float] = [1., 1., 1.]): + """Initializes the WebDatasetReader with specified augmentation parameters. + + Args: + resize_shorter_edge: An integer, the shorter edge size to resize the input image to. + crop_size: An integer, the size to crop the input image to. + random_crop: A boolean, whether to use random crop augmentation during training. + random_flip: A boolean, whether to use random flipping augmentation during training. + normalize_mean: A list of float, the normalization mean used to normalize the image tensor. + normalize_std: A list of float, the normalization std used to normalize the image tensor. + + Raises: + NotImplementedError: If the interpolation mode is not one of ["bicubic", "bilinear"]. + """ + train_transform = [] + interpolation = transforms.InterpolationMode.BICUBIC + + train_transform.append( + transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True)) + if random_crop: + train_transform.append(transforms.RandomCrop(crop_size)) + else: + train_transform.append(transforms.CenterCrop(crop_size)) + if random_flip: + train_transform.append(transforms.RandomHorizontalFlip()) + train_transform.append(transforms.ToTensor()) + # normalize_mean = [0, 0, 0] and normalize_std = [1, 1, 1] will normalize images into [0, 1], + # normalize_mean = [0.5, 0.5, 0.5] and normalize_std = [0.5, 0.5, 0.5] will normalize images into [-1, 1]. + train_transform.append(transforms.Normalize(normalize_mean, normalize_std)) + + self.train_transform = transforms.Compose(train_transform) + self.eval_transform = transforms.Compose( + [ + # Note that we always resize to crop_size during eval to ensure the results + # can be compared against reference numbers on ImageNet etc. + transforms.Resize(crop_size, interpolation=interpolation, antialias=True), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize(normalize_mean, normalize_std) + ] + ) + print(f"self.train_transform: {self.train_transform}") + print(f"self.eval_transform: {self.eval_transform}") + + +class SimpleImageDataset: + def __init__( + self, + train_shards_path: Union[Text, List[Text]], + eval_shards_path: Union[Text, List[Text]], + num_train_examples: int, + per_gpu_batch_size: int, + global_batch_size: int, + num_workers_per_gpu: int = 12, + resize_shorter_edge: int = 256, + crop_size: int = 256, + random_crop = True, + random_flip = True, + normalize_mean: List[float] = [0., 0., 0.], + normalize_std: List[float] = [1., 1., 1.], + dataset_with_class_label: bool = True, + dataset_with_text_label: bool = False, + res_ratio_filtering = False, + min_tokens = 32, + max_tokens = 128, + ): + """Initializes the WebDatasetReader class. + + Args: + train_shards_path: A string or list of string, path to the training data shards in webdataset format. + eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format. + num_train_examples: An integer, total number of training examples. + per_gpu_batch_size: An integer, number of examples per GPU batch. + global_batch_size: An integer, total number of examples in a batch across all GPUs. 
+ num_workers_per_gpu: An integer, number of workers per GPU. + resize_shorter_edge: An integer, the shorter edge size to resize the input image to. + crop_size: An integer, the size to crop the input image to. + random_crop: A boolean, whether to use random crop augmentation during training. + random_flip: A boolean, whether to use random flipping augmentation during training. + normalize_mean: A list of float, the normalization mean used to normalize the image tensor. + normalize_std: A list of float, the normalization std used to normalize the image tensor. + """ + transform = ImageTransform( + resize_shorter_edge, crop_size, random_crop, random_flip, + normalize_mean, normalize_std) + + if dataset_with_class_label: + train_processing_pipeline = [ + wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue), + wds.rename( + image="jpg;png;jpeg;webp", + class_id="cls", + handler=wds.warn_and_continue, + ), + wds.map(filter_keys(set(["image", "class_id", "filename"]))), + wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)), + wds.map_dict( + image=transform.train_transform, + class_id=lambda x: int(x), + attention_mask=lambda x: x, + handler=wds.warn_and_continue, + ), + ] + elif dataset_with_text_label: + train_processing_pipeline = [ + wds.map(load_json), + wds.select(filter_by_res_ratio()) if res_ratio_filtering else wds.map(identity), + wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]),only=["webp", "png", "jpg", "jpeg", "txt"], handler=wds.warn_and_continue), + wds.rename( + image="jpg;png;jpeg;webp", + text="txt", + handler=wds.warn_and_continue, + ), + wds.map(filter_keys(set(["image", "text", "__key__"]))), + wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)), + wds.map_dict( + image=transform.train_transform, + attention_mask=lambda x: x, + handler=wds.warn_and_continue, + ), + ] + else: + raise NotImplementedError + + test_processing_pipeline = [ + wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue), + wds.rename( + image="jpg;png;jpeg;webp", + class_id="cls", + handler=wds.warn_and_continue, + ), + wds.map(filter_keys(set(["image", "class_id", "filename"]))), + wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)), + wds.map_dict( + image=transform.eval_transform, + class_id=lambda x: int(x), + # laplacian_var=lambda x: x, + attention_mask=lambda x: x, + handler=wds.warn_and_continue, + ), + ] + + # Create train dataset and loader. + pipeline = [ + wds.ResampledShards(train_shards_path), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(bufsize=5000, + initial=1000), + *train_processing_pipeline, + wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate), + ] + + num_batches = math.ceil(num_train_examples / global_batch_size) + num_worker_batches = math.ceil(num_train_examples / + (global_batch_size * num_workers_per_gpu)) + num_batches = num_worker_batches * num_workers_per_gpu + num_samples = num_batches * global_batch_size + + # Each worker is iterating over the complete dataset. 
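        # with_epoch(num_worker_batches) caps how many batches each worker yields per
        # nominal epoch: shards are resampled rather than split across workers, so the
        # num_workers_per_gpu workers together emit num_batches batches, i.e. roughly
        # one pass over num_train_examples at the global batch size.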
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) + self._train_dataloader = wds.WebLoader( + self._train_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers_per_gpu, + pin_memory=True, + persistent_workers=True, + ) + # Add meta-data to dataloader instance for convenience. + self._train_dataloader.num_batches = num_batches + self._train_dataloader.num_samples = num_samples + + # Create eval dataset and loader. + pipeline = [ + wds.SimpleShardList(eval_shards_path), + wds.split_by_worker, + wds.tarfile_to_samples(handler=wds.ignore_and_continue), + *test_processing_pipeline, + wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate), + ] + self._eval_dataset = wds.DataPipeline(*pipeline) + self._eval_dataloader = wds.WebLoader( + self._eval_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers_per_gpu, + pin_memory=True, + persistent_workers=True, + ) + + @property + def train_dataset(self): + return self._train_dataset + + @property + def train_dataloader(self): + return self._train_dataloader + + @property + def eval_dataset(self): + return self._eval_dataset + + @property + def eval_dataloader(self): + return self._eval_dataloader + + +class PretoeknizedDataSetJSONL(Dataset): + def __init__(self, data_path): + super().__init__() + self.jsonl_file = data_path + self.num_lines = sum(1 for _ in open(self.jsonl_file)) + # Ensure the file is cached + linecache.checkcache(self.jsonl_file) + print("Number of data:", self.num_lines) + + def __len__(self): + return self.num_lines + + def __getitem__(self, idx): + line = linecache.getline(self.jsonl_file, idx + 1).strip() + data = json.loads(line) + return torch.tensor(data["class_id"]), torch.tensor(data["tokens"]) + + +class PretokenizedWebDataset(SimpleImageDataset): + def __init__ ( + self, + train_shards_path: Union[Text, List[Text]], + eval_shards_path: Union[Text, List[Text]], + num_train_examples: int, + per_gpu_batch_size: int, + global_batch_size: int, + num_workers_per_gpu: int, + resize_shorter_edge: int = 256, + crop_size: int = 256, + random_crop = True, + random_flip = True, + normalize_mean: List[float] = [0., 0., 0.], + normalize_std: List[float] = [1., 1., 1.], + process_recap = False, + use_recap_prob = 0.95, + ): + """Initializes the PretokenizedWebDataset class. + + Text-to-image datasets are pretokenized with careful filtering (Tab. 7 in Supp.) to speed up the training + """ + transform = ImageTransform( + resize_shorter_edge, crop_size, random_crop, random_flip, + normalize_mean, normalize_std) + + def decode_npy(x): + arr = np.frombuffer(x, dtype=np.float16) + ret = torch.tensor(arr) + return ret + + def decode_text(x): + ret = x.decode("utf-8") + return ret + + train_processing_pipeline = [ + wds.rename( + tokens="token.npy", + text="txt", + handler=wds.warn_and_continue, + ), + wds.map(process_recap_text(use_recap_prob) if process_recap else wds.map(identity)), + wds.map(filter_keys(set(["tokens", "text", "aes_score", "__key__"]))), + wds.map_dict( + tokens=decode_npy, + text=decode_text, + handler=wds.warn_and_continue, + ), + ] + + test_processing_pipeline = [ + wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])), + wds.rename( + image="jpg;png;jpeg;webp", + handler=wds.warn_and_continue, + ), + wds.map_dict( + image=transform.eval_transform, + handler=wds.warn_and_continue, + ), + ] + + + # Create train dataset and loader. 
+ pipeline = [ + wds.ResampledShards(train_shards_path), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(bufsize=5000, + initial=1000), + *train_processing_pipeline, + wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate), + ] + + num_batches = math.ceil(num_train_examples / global_batch_size) + num_worker_batches = math.ceil(num_train_examples / + (global_batch_size * num_workers_per_gpu)) + num_batches = num_worker_batches * num_workers_per_gpu + num_samples = num_batches * global_batch_size + + # Each worker is iterating over the complete dataset. + self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) + self._train_dataloader = wds.WebLoader( + self._train_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers_per_gpu, + pin_memory=True, + persistent_workers=True, + ) + # Add meta-data to dataloader instance for convenience. + self._train_dataloader.num_batches = num_batches + self._train_dataloader.num_samples = num_samples + + # Create eval dataset and loader. + pipeline = [ + wds.SimpleShardList(eval_shards_path), + wds.split_by_worker, + wds.tarfile_to_samples(handler=wds.ignore_and_continue), + *test_processing_pipeline, + wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate), + ] + self._eval_dataset = wds.DataPipeline(*pipeline) + self._eval_dataloader = wds.WebLoader( + self._eval_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers_per_gpu, + pin_memory=True, + persistent_workers=True, + ) \ No newline at end of file diff --git a/evaluator/__init__.py b/evaluator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5774b098f8a2d51bd5ad898a7ad5015434d124c --- /dev/null +++ b/evaluator/__init__.py @@ -0,0 +1 @@ +from .evaluator import VQGANEvaluator \ No newline at end of file diff --git a/evaluator/evaluator.py b/evaluator/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..8860328f0ed5f22512fde935472ee47b92d51be5 --- /dev/null +++ b/evaluator/evaluator.py @@ -0,0 +1,230 @@ +"""Evaluator for reconstruction results.""" + +import warnings + +from typing import Sequence, Optional, Mapping, Text +import numpy as np +from scipy import linalg +import torch +import torch.nn.functional as F + +from .inception import get_inception_model + + +def get_covariance(sigma: torch.Tensor, total: torch.Tensor, num_examples: int) -> torch.Tensor: + """Computes covariance of the input tensor. + + Args: + sigma: A torch.Tensor, sum of outer products of input features. + total: A torch.Tensor, sum of all input features. + num_examples: An integer, number of examples in the input tensor. + Returns: + A torch.Tensor, covariance of the input tensor. + """ + if num_examples == 0: + return torch.zeros_like(sigma) + + sub_matrix = torch.outer(total, total) + sub_matrix = sub_matrix / num_examples + + return (sigma - sub_matrix) / (num_examples - 1) + + +class VQGANEvaluator: + def __init__( + self, + device, + enable_rfid: bool = True, + enable_inception_score: bool = True, + enable_codebook_usage_measure: bool = False, + enable_codebook_entropy_measure: bool = False, + num_codebook_entries: int = 1024 + ): + """Initializes VQGAN Evaluator. + + Args: + device: The device to use for evaluation. + enable_rfid: A boolean, whether enabling rFID score. + enable_inception_score: A boolean, whether enabling Inception Score. + enable_codebook_usage_measure: A boolean, whether enabling codebook usage measure. 
+ enable_codebook_entropy_measure: A boolean, whether enabling codebook entropy measure. + num_codebook_entries: An integer, the number of codebook entries. + """ + self._device = device + + self._enable_rfid = enable_rfid + self._enable_inception_score = enable_inception_score + self._enable_codebook_usage_measure = enable_codebook_usage_measure + self._enable_codebook_entropy_measure = enable_codebook_entropy_measure + self._num_codebook_entries = num_codebook_entries + + # Variables related to Inception score and rFID. + self._inception_model = None + self._is_num_features = 0 + self._rfid_num_features = 0 + if self._enable_inception_score or self._enable_rfid: + self._rfid_num_features = 2048 + self._is_num_features = 1008 + self._inception_model = get_inception_model().to(self._device) + self._inception_model.eval() + self._is_eps = 1e-16 + self._rfid_eps = 1e-6 + + self.reset_metrics() + + def reset_metrics(self): + """Resets all metrics.""" + self._num_examples = 0 + self._num_updates = 0 + + self._is_prob_total = torch.zeros( + self._is_num_features, dtype=torch.float64, device=self._device + ) + self._is_total_kl_d = torch.zeros( + self._is_num_features, dtype=torch.float64, device=self._device + ) + self._rfid_real_sigma = torch.zeros( + (self._rfid_num_features, self._rfid_num_features), + dtype=torch.float64, device=self._device + ) + self._rfid_real_total = torch.zeros( + self._rfid_num_features, dtype=torch.float64, device=self._device + ) + self._rfid_fake_sigma = torch.zeros( + (self._rfid_num_features, self._rfid_num_features), + dtype=torch.float64, device=self._device + ) + self._rfid_fake_total = torch.zeros( + self._rfid_num_features, dtype=torch.float64, device=self._device + ) + + self._set_of_codebook_indices = set() + self._codebook_frequencies = torch.zeros((self._num_codebook_entries), dtype=torch.float64, device=self._device) + + def update( + self, + real_images: torch.Tensor, + fake_images: torch.Tensor, + codebook_indices: Optional[torch.Tensor] = None + ): + """Updates the metrics with the given images. + + Args: + real_images: A torch.Tensor, the real images. + fake_images: A torch.Tensor, the fake images. + codebook_indices: A torch.Tensor, the indices of the codebooks for each image. + + Raises: + ValueError: If the fake images is not in RGB (3 channel). + ValueError: If the fake and real images have different shape. + """ + + batch_size = real_images.shape[0] + dim = tuple(range(1, real_images.ndim)) + self._num_examples += batch_size + self._num_updates += 1 + + if self._enable_inception_score or self._enable_rfid: + # Quantize to uint8 as a real image. 
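            # (The Inception feature extractor consumes uint8 RGB in [0, 255];
            # real_images / fake_images are expected as float tensors in [0, 1].)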
+ fake_inception_images = (fake_images * 255).to(torch.uint8) + features_fake = self._inception_model(fake_inception_images) + inception_logits_fake = features_fake["logits_unbiased"] + inception_probabilities_fake = F.softmax(inception_logits_fake, dim=-1) + + if self._enable_inception_score: + probabiliies_sum = torch.sum(inception_probabilities_fake, 0, dtype=torch.float64) + + log_prob = torch.log(inception_probabilities_fake + self._is_eps) + if log_prob.dtype != inception_probabilities_fake.dtype: + log_prob = log_prob.to(inception_probabilities_fake) + kl_sum = torch.sum(inception_probabilities_fake * log_prob, 0, dtype=torch.float64) + + self._is_prob_total += probabiliies_sum + self._is_total_kl_d += kl_sum + + if self._enable_rfid: + real_inception_images = (real_images * 255).to(torch.uint8) + features_real = self._inception_model(real_inception_images) + if (features_real['2048'].shape[0] != features_fake['2048'].shape[0] or + features_real['2048'].shape[1] != features_fake['2048'].shape[1]): + raise ValueError(f"Number of features should be equal for real and fake.") + + for f_real, f_fake in zip(features_real['2048'], features_fake['2048']): + self._rfid_real_total += f_real + self._rfid_fake_total += f_fake + + self._rfid_real_sigma += torch.outer(f_real, f_real) + self._rfid_fake_sigma += torch.outer(f_fake, f_fake) + + if self._enable_codebook_usage_measure: + self._set_of_codebook_indices |= set(torch.unique(codebook_indices, sorted=False).tolist()) + + if self._enable_codebook_entropy_measure: + entries, counts = torch.unique(codebook_indices, sorted=False, return_counts=True) + self._codebook_frequencies.index_add_(0, entries.int(), counts.double()) + + + def result(self) -> Mapping[Text, torch.Tensor]: + """Returns the evaluation result.""" + eval_score = {} + + if self._num_examples < 1: + raise ValueError("No examples to evaluate.") + + if self._enable_inception_score: + mean_probs = self._is_prob_total / self._num_examples + log_mean_probs = torch.log(mean_probs + self._is_eps) + if log_mean_probs.dtype != self._is_prob_total.dtype: + log_mean_probs = log_mean_probs.to(self._is_prob_total) + excess_entropy = self._is_prob_total * log_mean_probs + avg_kl_d = torch.sum(self._is_total_kl_d - excess_entropy) / self._num_examples + + inception_score = torch.exp(avg_kl_d).item() + eval_score["InceptionScore"] = inception_score + + if self._enable_rfid: + mu_real = self._rfid_real_total / self._num_examples + mu_fake = self._rfid_fake_total / self._num_examples + sigma_real = get_covariance(self._rfid_real_sigma, self._rfid_real_total, self._num_examples) + sigma_fake = get_covariance(self._rfid_fake_sigma, self._rfid_fake_total, self._num_examples) + + mu_real, mu_fake = mu_real.cpu(), mu_fake.cpu() + sigma_real, sigma_fake = sigma_real.cpu(), sigma_fake.cpu() + + diff = mu_real - mu_fake + + # Product might be almost singular. + covmean, _ = linalg.sqrtm(sigma_real.mm(sigma_fake).numpy(), disp=False) + # Numerical error might give slight imaginary component. 
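+            # Small imaginary parts are discarded below; the quantity assembled
+            # afterwards is the Fréchet distance between the two feature Gaussians:
+            #   rFID = ||mu_real - mu_fake||^2 + Tr(sigma_real) + Tr(sigma_fake) - 2 * Tr(covmean)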
+ if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + if not np.isfinite(covmean).all(): + tr_covmean = np.sum(np.sqrt(( + (np.diag(sigma_real) * self._rfid_eps) * (np.diag(sigma_fake) * self._rfid_eps)) + / (self._rfid_eps * self._rfid_eps) + )) + + rfid = float(diff.dot(diff).item() + torch.trace(sigma_real) + torch.trace(sigma_fake) + - 2 * tr_covmean + ) + if torch.isnan(torch.tensor(rfid)) or torch.isinf(torch.tensor(rfid)): + warnings.warn("The product of covariance of train and test features is out of bounds.") + + eval_score["rFID"] = rfid + + if self._enable_codebook_usage_measure: + usage = float(len(self._set_of_codebook_indices)) / self._num_codebook_entries + eval_score["CodebookUsage"] = usage + + if self._enable_codebook_entropy_measure: + probs = self._codebook_frequencies / self._codebook_frequencies.sum() + entropy = (-torch.log2(probs + 1e-8) * probs).sum() + eval_score["CodebookEntropy"] = entropy + + return eval_score diff --git a/evaluator/inception.py b/evaluator/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..cda0355bbe510617e474ef2128f501044f4eab7a --- /dev/null +++ b/evaluator/inception.py @@ -0,0 +1,215 @@ +"""Inception model for FID evaluation. + +Reference: + https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py +""" +import torch +import torch.nn.functional as F + +from torch_fidelity.feature_extractor_base import FeatureExtractorBase +from torch_fidelity.helpers import vassert +from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2 +from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +# Note: Compared shasum and models should be the same. +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' + +class FeatureExtractorInceptionV3(FeatureExtractorBase): + INPUT_IMAGE_SIZE = 299 + + def __init__( + self, + name, + features_list, + **kwargs, + ): + """ + InceptionV3 feature extractor for 2D RGB 24bit images. + + Args: + + name (str): Unique name of the feature extractor, must be the same as used in + :func:`register_feature_extractor`. + + features_list (list): A list of the requested feature names, which will be produced for each input. 
This + feature extractor provides the following features: + + - '64' + - '192' + - '768' + - '2048' + - 'logits_unbiased' + - 'logits' + + """ + super(FeatureExtractorInceptionV3, self).__init__(name, features_list) + self.feature_extractor_internal_dtype = torch.float64 + + self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2) + + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2) + + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE_1(1280) + self.Mixed_7c = InceptionE_2(2048) + self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1)) + + self.fc = torch.nn.Linear(2048, 1008) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False) + #state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu') + self.load_state_dict(state_dict) + + self.to(self.feature_extractor_internal_dtype) + self.requires_grad_(False) + self.eval() + + def forward(self, x): + vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8') + vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}') + features = {} + remaining_features = self.features_list.copy() + + x = x.to(self.feature_extractor_internal_dtype) + # N x 3 x ? x ? 
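+        # Arbitrary input sizes are resized to 299x299 with a TF1-compatible bilinear
+        # interpolation and mapped from [0, 255] to roughly [-1, 1], mirroring the
+        # original TensorFlow Inception graph used for FID.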
+ + x = interpolate_bilinear_2d_like_tensorflow1x( + x, + size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE), + align_corners=False, + ) + # N x 3 x 299 x 299 + + # x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device) # really happening in graph + x = (x - 128) / 128 # but this gives bit-exact output _of this step_ too + # N x 3 x 299 x 299 + + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.MaxPool_1(x) + # N x 64 x 73 x 73 + + if '64' in remaining_features: + features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1) + remaining_features.remove('64') + if len(remaining_features) == 0: + return features + + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.MaxPool_2(x) + # N x 192 x 35 x 35 + + if '192' in remaining_features: + features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1) + remaining_features.remove('192') + if len(remaining_features) == 0: + return features + + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + + if '768' in remaining_features: + features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32) + remaining_features.remove('768') + if len(remaining_features) == 0: + return features + + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + x = self.AvgPool(x) + # N x 2048 x 1 x 1 + + x = torch.flatten(x, 1) + # N x 2048 + + if '2048' in remaining_features: + features['2048'] = x + remaining_features.remove('2048') + if len(remaining_features) == 0: + return features + + if 'logits_unbiased' in remaining_features: + x = x.mm(self.fc.weight.T) + # N x 1008 (num_classes) + features['logits_unbiased'] = x + remaining_features.remove('logits_unbiased') + if len(remaining_features) == 0: + return features + + x = x + self.fc.bias.unsqueeze(0) + else: + x = self.fc(x) + # N x 1008 (num_classes) + + features['logits'] = x + return features + + @staticmethod + def get_provided_features_list(): + return '64', '192', '768', '2048', 'logits_unbiased', 'logits' + + @staticmethod + def get_default_feature_layer_for_metric(metric): + return { + 'isc': 'logits_unbiased', + 'fid': '2048', + 'kid': '2048', + 'prc': '2048', + }[metric] + + @staticmethod + def can_be_compiled(): + return True + + @staticmethod + def get_dummy_input_for_compile(): + return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8) + +def get_inception_model(): + model = FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"]) + return model diff --git a/examples/batch_inference.py b/examples/batch_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0094ce5080f1c95aa3905c3cac3ba6f425b7cd --- /dev/null +++ b/examples/batch_inference.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +"""Batch inference example for VibeToken. + +Demonstrates how to process multiple images efficiently in batches. 
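+
+Two modes are supported: with --auto, each image is preprocessed via
+auto_preprocess_image and reconstructed individually with a per-image patch size;
+without it, images are resized to --resolution, center-cropped to a multiple of 32,
+and reconstructed in fixed-size batches.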
+ +Usage: + # Auto mode (recommended) + python examples/batch_inference.py --auto \ + --config configs/vibetoken_ll.yaml \ + --checkpoint path/to/checkpoint.bin \ + --input_dir path/to/images/ \ + --output_dir path/to/output/ \ + --batch_size 4 + + # Manual mode + python examples/batch_inference.py \ + --config configs/vibetoken_ll.yaml \ + --checkpoint path/to/checkpoint.bin \ + --input_dir path/to/images/ \ + --output_dir path/to/output/ \ + --batch_size 4 \ + --resolution 512 \ + --encoder_patch_size 16,32 \ + --decoder_patch_size 16 +""" + +import argparse +import time +from pathlib import Path + +import torch +from PIL import Image +import numpy as np + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple + + +def parse_patch_size(value): + """Parse patch size from string. Supports single int or tuple (e.g., '16' or '16,32').""" + if value is None: + return None + if ',' in value: + parts = value.split(',') + return (int(parts[0]), int(parts[1])) + return int(value) + + +def load_and_preprocess_image(path: Path, target_size: tuple = None, auto_mode: bool = False) -> tuple: + """Load and preprocess image. + + Args: + path: Path to image + target_size: Optional target size (width, height) for resizing + auto_mode: If True, use auto_preprocess_image for cropping + + Returns: + image: numpy array + patch_size: auto-determined patch size (if auto_mode) or None + """ + img = Image.open(path).convert("RGB") + + if auto_mode: + # Use centralized auto_preprocess_image + img, patch_size, info = auto_preprocess_image(img, verbose=False) + return np.array(img), patch_size, info + else: + if target_size: + img = img.resize(target_size, Image.LANCZOS) + # Always center crop to ensure dimensions divisible by 32 + img = center_crop_to_multiple(img, multiple=32) + return np.array(img), None, None + + +def main(): + parser = argparse.ArgumentParser(description="VibeToken batch inference example") + parser.add_argument("--config", type=str, required=True, help="Path to config YAML") + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") + parser.add_argument("--input_dir", type=str, required=True, help="Directory with input images") + parser.add_argument("--output_dir", type=str, required=True, help="Directory for output images") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size") + parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)") + + # Auto mode + parser.add_argument("--auto", action="store_true", + help="Auto mode: automatically determine optimal settings per image") + + # Manual mode options + parser.add_argument("--resolution", type=int, default=512, help="Target resolution (manual mode)") + parser.add_argument("--encoder_patch_size", type=str, default=None, + help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)") + parser.add_argument("--decoder_patch_size", type=str, default=None, + help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)") + args = parser.parse_args() + + # Parse patch sizes + encoder_patch_size = parse_patch_size(args.encoder_patch_size) + decoder_patch_size = parse_patch_size(args.decoder_patch_size) + + # Check CUDA + if args.device == "cuda" and not torch.cuda.is_available(): + print("CUDA not available, falling back to CPU") + args.device = "cpu" + + # Create output directory + output_dir = Path(args.output_dir) + 
output_dir.mkdir(parents=True, exist_ok=True) + + # Load tokenizer + print(f"Loading tokenizer from {args.config}") + tokenizer = VibeTokenTokenizer.from_config( + config_path=args.config, + checkpoint_path=args.checkpoint, + device=args.device, + ) + + if args.auto: + print("Running in AUTO MODE - optimal settings determined per image") + else: + print(f"Running in MANUAL MODE - resolution: {args.resolution}") + if encoder_patch_size: + print(f" Encoder patch size: {encoder_patch_size}") + if decoder_patch_size: + print(f" Decoder patch size: {decoder_patch_size}") + + # Find all images + input_dir = Path(args.input_dir) + image_extensions = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} + image_paths = [p for p in input_dir.iterdir() if p.suffix.lower() in image_extensions] + print(f"Found {len(image_paths)} images") + + if not image_paths: + print("No images found!") + return + + # Process in batches + target_size = (args.resolution, args.resolution) if not args.auto else None + total_time = 0 + num_processed = 0 + + if args.auto: + # AUTO MODE: Process images one by one since each may have different sizes + for i, path in enumerate(image_paths): + try: + img_array, patch_size, info = load_and_preprocess_image(path, auto_mode=True) + batch_array = img_array[np.newaxis, ...] # Add batch dim + + start_time = time.time() + + # Reconstruct with auto-determined patch size + height, width = info['cropped_size'][1], info['cropped_size'][0] + reconstructed = tokenizer.reconstruct( + batch_array, + encode_patch_size=patch_size, + decode_patch_size=patch_size, + target_height=height, + target_width=width, + ) + + if args.device == "cuda": + torch.cuda.synchronize() + + batch_time = time.time() - start_time + total_time += batch_time + num_processed += 1 + + # Save output + output_images = tokenizer.to_pil(reconstructed) + output_path = output_dir / f"{path.stem}_recon.png" + output_images[0].save(output_path) + + print(f"[{i+1}/{len(image_paths)}] {path.name}: " + f"{info['cropped_size'][0]}x{info['cropped_size'][1]}, " + f"patch_size={patch_size}, {batch_time:.2f}s") + + except Exception as e: + print(f"Error processing {path}: {e}") + continue + else: + # MANUAL MODE: Batch processing with uniform size + for batch_start in range(0, len(image_paths), args.batch_size): + batch_paths = image_paths[batch_start:batch_start + args.batch_size] + batch_names = [p.stem for p in batch_paths] + + # Load batch + batch_images = [] + for path in batch_paths: + try: + img_array, _, _ = load_and_preprocess_image(path, target_size, auto_mode=False) + batch_images.append(img_array) + except Exception as e: + print(f"Error loading {path}: {e}") + continue + + if not batch_images: + continue + + # Stack into batch tensor + batch_array = np.stack(batch_images, axis=0) + + # Measure time + start_time = time.time() + + # Reconstruct + reconstructed = tokenizer.reconstruct( + batch_array, + encode_patch_size=encoder_patch_size, + decode_patch_size=decoder_patch_size, + target_height=args.resolution, + target_width=args.resolution, + ) + + # Synchronize if GPU + if args.device == "cuda": + torch.cuda.synchronize() + + batch_time = time.time() - start_time + total_time += batch_time + num_processed += len(batch_images) + + # Save outputs + output_images = tokenizer.to_pil(reconstructed) + for name, img in zip(batch_names[:len(output_images)], output_images): + output_path = output_dir / f"{name}_recon.png" + img.save(output_path) + + print(f"Processed batch {batch_start // args.batch_size + 1}: " + f"{len(batch_images)} 
images in {batch_time:.2f}s " + f"({len(batch_images) / batch_time:.2f} img/s)") + + # Summary + if num_processed > 0: + print(f"\nTotal: {num_processed} images in {total_time:.2f}s") + print(f"Average: {num_processed / total_time:.2f} images/sec") + print(f"Per image: {total_time / num_processed * 1000:.1f}ms") + + +if __name__ == "__main__": + main() diff --git a/examples/encode_decode.py b/examples/encode_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..b7eef4436e9baa8fd8b6f3f5da3c022d1920c5ef --- /dev/null +++ b/examples/encode_decode.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +"""Basic encode-decode example for VibeToken. + +Demonstrates how to: +1. Load the tokenizer from config and checkpoint +2. Encode an image to discrete tokens +3. Decode tokens back to an image +4. Save the reconstructed image + +Usage: + # Auto mode (recommended) + python examples/encode_decode.py --auto \ + --config configs/vibetoken_ll.yaml \ + --checkpoint path/to/checkpoint.bin \ + --image path/to/image.jpg \ + --output reconstructed.png + + # Manual mode + python examples/encode_decode.py \ + --config configs/vibetoken_ll.yaml \ + --checkpoint path/to/checkpoint.bin \ + --image path/to/image.jpg \ + --output reconstructed.png \ + --encoder_patch_size 16,32 \ + --decoder_patch_size 16 +""" + +import argparse +from pathlib import Path + +import torch +from PIL import Image + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple + + +def parse_patch_size(value): + """Parse patch size from string. Supports single int or tuple (e.g., '16' or '16,32').""" + if value is None: + return None + if ',' in value: + parts = value.split(',') + return (int(parts[0]), int(parts[1])) + return int(value) + + +def main(): + parser = argparse.ArgumentParser(description="VibeToken encode-decode example") + parser.add_argument("--config", type=str, required=True, help="Path to config YAML") + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") + parser.add_argument("--image", type=str, required=True, help="Path to input image") + parser.add_argument("--output", type=str, default="reconstructed.png", help="Output image path") + parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)") + + # Auto mode + parser.add_argument("--auto", action="store_true", + help="Auto mode: automatically determine optimal settings") + + parser.add_argument("--height", type=int, default=None, help="Output height (default: input height)") + parser.add_argument("--width", type=int, default=None, help="Output width (default: input width)") + parser.add_argument("--encoder_patch_size", type=str, default=None, + help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)") + parser.add_argument("--decoder_patch_size", type=str, default=None, + help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)") + parser.add_argument("--num_tokens", type=int, default=None, help="Number of tokens to encode") + + args = parser.parse_args() + + # Check if CUDA is available + if args.device == "cuda" and not torch.cuda.is_available(): + print("CUDA not available, falling back to CPU") + args.device = "cpu" + + print(f"Loading tokenizer from {args.config}") + tokenizer = VibeTokenTokenizer.from_config( + config_path=args.config, + checkpoint_path=args.checkpoint, + device=args.device, + ) + print(f"Tokenizer loaded: 
codebook_size={tokenizer.codebook_size}, " + f"num_latent_tokens={tokenizer.num_latent_tokens}") + + # Load image + print(f"Loading image from {args.image}") + image = Image.open(args.image).convert("RGB") + original_size = image.size # (W, H) + print(f"Original image size: {original_size[0]}x{original_size[1]}") + + if args.auto: + # AUTO MODE - use centralized auto_preprocess_image + print("\n=== AUTO MODE ===") + image, patch_size, info = auto_preprocess_image(image, verbose=True) + encoder_patch_size = patch_size + decoder_patch_size = patch_size + height, width = info['cropped_size'][1], info['cropped_size'][0] + print("=================\n") + + # Encode to tokens + print("Encoding image to tokens...") + print(f" Using encoder patch size: {encoder_patch_size}") + tokens = tokenizer.encode(image, patch_size=encoder_patch_size) + print(f"Token shape: {tokens.shape}") + + # Decode back to image + print(f"Decoding tokens to image ({width}x{height})...") + print(f" Using decoder patch size: {decoder_patch_size}") + reconstructed = tokenizer.decode( + tokens, height=height, width=width, patch_size=decoder_patch_size + ) + + else: + # MANUAL MODE + # Parse patch sizes + encoder_patch_size = parse_patch_size(args.encoder_patch_size) + decoder_patch_size = parse_patch_size(args.decoder_patch_size) + + # Always center crop to ensure dimensions divisible by 32 + image = center_crop_to_multiple(image, multiple=32) + cropped_size = image.size # (W, H) + if cropped_size != original_size: + print(f"Center cropped to {cropped_size[0]}x{cropped_size[1]} (divisible by 32)") + + # Encode to tokens + print("Encoding image to tokens...") + if encoder_patch_size: + print(f" Using encoder patch size: {encoder_patch_size}") + tokens = tokenizer.encode(image, patch_size=encoder_patch_size, num_tokens=args.num_tokens) + print(f"Token shape: {tokens.shape}") + + if tokenizer.model.quantize_mode == "mvq": + print(f" - Batch size: {tokens.shape[0]}") + print(f" - Num codebooks: {tokens.shape[1]}") + print(f" - Sequence length: {tokens.shape[2]}") + else: + print(f" - Batch size: {tokens.shape[0]}") + print(f" - Sequence length: {tokens.shape[1]}") + + # Decode back to image (use cropped size as default) + height = args.height or cropped_size[1] + width = args.width or cropped_size[0] + print(f"Decoding tokens to image ({width}x{height})...") + if decoder_patch_size: + print(f" Using decoder patch size: {decoder_patch_size}") + + reconstructed = tokenizer.decode( + tokens, height=height, width=width, patch_size=decoder_patch_size + ) + + print(f"Reconstructed image shape: {reconstructed.shape}") + + # Convert to PIL and save + output_images = tokenizer.to_pil(reconstructed) + output_path = Path(args.output) + output_images[0].save(output_path) + print(f"Saved reconstructed image to {output_path}") + + # Compute PSNR (compare with cropped image) + import numpy as np + original_np = np.array(image).astype(np.float32) + recon_np = np.array(output_images[0]).astype(np.float32) + if original_np.shape == recon_np.shape: + mse = np.mean((original_np - recon_np) ** 2) + if mse > 0: + psnr = 20 * np.log10(255.0 / np.sqrt(mse)) + print(f"PSNR: {psnr:.2f} dB") + + +if __name__ == "__main__": + main() diff --git a/generate.py b/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..b893bff5c6117e4970ea9b18b3fa6bd3e574ff3c --- /dev/null +++ b/generate.py @@ -0,0 +1,240 @@ +# Modified from: +# DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py + +"""Example run: +python generate.py \ 
+ --gpt-ckpt ./checkpoints/VibeTokenGen-xxl-dynamic-65_750k.pt \ + --gpt-model GPT-XXL --num-output-layer 4 \ + --num-codebooks 8 --codebook-size 32768 \ + --image-size 256 --cfg-scale 2.0 --top-k 0 --temperature 1.0 \ + --class-dropout-prob 0.1 \ + --extra-layers "QKV" \ + --latent-size 65 \ + --config ./configs/vibetoken_ll.yaml \ + --vq-ckpt ./checkpoints/VibeToken_LL.bin \ + --sample-dir ./assets/ \ + --skip-folder-creation \ + --compile \ + --decoder-patch-size 16,16 \ + --target-resolution 1024,1024 \ + --llamagen-target-resolution 256,256 \ + --precision bf16 +""" + +import torch + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision('high') +setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) +setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) +import torch.nn.functional as F +import torch.distributed as dist + +from tqdm import tqdm +import os +from PIL import Image +import numpy as np +import math +import argparse +import sys +from omegaconf import OmegaConf + +from vibetokengen.model import GPT_models +from vibetokengen.generate import generate + +from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple + + +def create_npz_from_sample_folder(sample_dir, num=50_000): + """ + Builds a single .npz file from a folder of .png samples. + """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + +def main(args): + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. 
sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Set global seed for reproducibility + torch.manual_seed(args.global_seed) + np.random.seed(args.global_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(args.global_seed) + torch.cuda.manual_seed_all(args.global_seed) + + device = "cuda" if torch.cuda.is_available() else "cpu" + precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] + + # Load VibeToken model + vq_model = VibeTokenTokenizer.from_config( + args.config, + args.vq_ckpt, + device=device, + dtype=precision, + ) + print(f"VibeToken image tokenizer is loaded") + + # create and load gpt model + gpt_model = GPT_models[args.gpt_model]( + vocab_size=args.codebook_size, + block_size=args.latent_size, + num_classes=args.num_classes, + cls_token_num=args.cls_token_num, + model_type=args.gpt_type, + num_codebooks=args.num_codebooks, + n_output_layer=args.num_output_layer, + class_dropout_prob=args.class_dropout_prob, + extra_layers=args.extra_layers, + capping=args.capping, + ).to(device=device, dtype=precision) + print(f"GPT model is loaded") + + checkpoint = torch.load(args.gpt_ckpt, map_location="cpu", weights_only=False) + if args.from_fsdp: # fsdp + model_weight = checkpoint + elif "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "module" in checkpoint: # deepspeed + model_weight = checkpoint["module"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight, maybe add --from-fsdp to run command") + gpt_model.load_state_dict(model_weight, strict=True) + gpt_model.eval() + del checkpoint + + print(f"GPT model weights are loaded") + + if args.compile: + print(f"compiling the model...") + gpt_model = torch.compile( + gpt_model, + mode="reduce-overhead", + fullgraph=True + ) # requires PyTorch 2.0 (optional) + else: + print(f"no model compile") + + print(f"GPT model is compiled") + + # Create folder to save samples: + model_string_name = args.gpt_model.replace("/", "-") + if args.from_fsdp: + ckpt_string_name = args.gpt_ckpt.split('/')[-2] + else: + ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "") + folder_name = f"{model_string_name}-{ckpt_string_name}-target-resolution-{args.target_resolution}-llamagen-target-resolution-{args.llamagen_target_resolution}-vibetoken-" \ + f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \ + f"cfg-{args.cfg_scale}-seed-{args.global_seed}" + if args.skip_folder_creation: + sample_folder_dir = args.sample_dir + else: + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + + multiplier = 2 if args.cfg_scale > 1.0 else 1 + + # Use fixed class labels + class_labels = [207, 360, 387, 974, 88, 979, 417, 279] + c_indices = torch.tensor(class_labels, device=device) + n = len(class_labels) + nrow = 4 # 2 rows x 4 columns for 8 images + + index_sample = generate( + gpt_model, c_indices, args.latent_size, args.num_codebooks, + cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval, + target_h=torch.tensor(args.llamagen_target_resolution[0]/1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices)*multiplier, 1), + target_w=torch.tensor(args.llamagen_target_resolution[1]/1536, device=device, dtype=precision).unsqueeze(0).repeat(len(c_indices)*multiplier, 1), + temperature=args.temperature, top_k=args.top_k, + 
top_p=args.top_p, sample_logits=True, + ) + + # Use VibeToken decode_tokens method + # VibeToken expects tokens in shape (batch_size, seq_len, 1) + index_sample = index_sample.unsqueeze(2) + samples = vq_model.decode( + index_sample, + height=args.target_resolution[0], + width=args.target_resolution[1], + patch_size=args.decoder_patch_size + ) + + # VibeToken output is in [0, 1] range, clamp and convert to uint8 + samples = torch.clamp(samples, 0, 1) + + # Create a grid of images (2 rows x 4 columns) + from torchvision.utils import make_grid + grid = make_grid(samples, nrow=nrow, padding=2, normalize=False) + + # Convert to PIL and save + grid_np = (grid.permute(1, 2, 0).to(torch.float32).cpu().numpy() * 255).astype('uint8') + Image.fromarray(grid_np).save(f"{sample_folder_dir}/generated_images.png") + print(f"Saved grid of {n} images to {sample_folder_dir}/generated_images.png") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B") + parser.add_argument("--gpt-ckpt", type=str, default=None) + parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", + help="class-conditional or text-conditional") + parser.add_argument("--from-fsdp", action='store_true') + parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input") + parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--compile", action='store_true', default=True) + # parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") + parser.add_argument("--config", type=str, required=True, help="Path to VibeToken config file") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384) + parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256) + parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=4.0) + parser.add_argument("--cfg-interval", type=float, default=-1) + parser.add_argument("--sample-dir", type=str, default="samples") + parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--num-fid-samples", type=int, default=50000) + parser.add_argument("--global-seed", type=int, default=0) # not used + parser.add_argument("--top-k", type=int, default=500, help="top-k value to sample with") + parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") + parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") + parser.add_argument("--num-codebooks", type=int, default=1) + parser.add_argument("--num-output-layer", type=int, default=1) + parser.add_argument("--class-dropout-prob", type=float, default=0.1) + parser.add_argument("--extra-layers", type=str, choices=['QK', 'QKV', 'FC', 'cap', 'clip', 'QK_cap', 'QKV_cap', 'QK_clip', 'QKV_clip', 'QK_FC_cap', 'QKV_FC_cap', 'QK_FC_clip', 'QKV_FC_clip'], default=None, + help="Type of extra layers to add: QK 
(query-key), QKV (query-key-value), FC (fully connected), cap (caption), clip (clip), QK_cap (query-key-caption), QKV_cap (query-key-value-caption), QK_clip (query-key-clip), QKV_clip (query-key-value-clip), QK_FC_cap (query-key-fully-connected-caption), QKV_FC_cap (query-key-value-fully-connected-caption), QK_FC_clip (query-key-fully-connected-clip), QKV_FC_clip (query-key-value-fully-connected-clip)") + parser.add_argument("--capping", type=float, default=50.0, help="Capping for attention softmax") + + # VibeToken dynamic + parser.add_argument("--decoder-patch-size", type=str, default="8,8", help="Decoder patch size as 'width,height'") + parser.add_argument("--target-resolution", type=str, default="256,256", help="Target resolution as 'width,height'") + parser.add_argument("--llamagen-target-resolution", type=str, default="256,256", help="LlamaGen target resolution as 'width,height'") + + parser.add_argument("--latent-size", type=int, default=16, help="Latent size") + parser.add_argument("--skip-folder-creation", action='store_true', default=False, help="skip folder creation") + + args = parser.parse_args() + + args.decoder_patch_size = tuple(map(int, args.decoder_patch_size.split(","))) + args.target_resolution = tuple(map(int, args.target_resolution.split(","))) + args.llamagen_target_resolution = tuple(map(int, args.llamagen_target_resolution.split(","))) + + main(args) \ No newline at end of file diff --git a/generator/__init__.py b/generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1caa2c3a152ce588930de89eca95b6b437fcf0 --- /dev/null +++ b/generator/__init__.py @@ -0,0 +1,4 @@ +"""Generator module placeholder for VibeToken-Gen integration.""" + +# Future: Add GPT-based generator for image synthesis +# from .gpt import VibeTokenGenerator diff --git a/modeling/__init__.py b/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modeling/modules/__init__.py b/modeling/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5abed7a88cf447243ab6dc17ba76ae4097d7910a --- /dev/null +++ b/modeling/modules/__init__.py @@ -0,0 +1,6 @@ +from .base_model import BaseModel +from .ema_model import EMAModel +from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, ReconstructionLoss_Single_Stage +from .blocks import TiTokEncoder, TiTokDecoder, TATiTokDecoder, UViTBlock +from .maskgit_vqgan import Decoder as Pixel_Decoder +from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer \ No newline at end of file diff --git a/modeling/modules/base_model.py b/modeling/modules/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ddb1e4791155ec018f2d663174208b61027a3317 --- /dev/null +++ b/modeling/modules/base_model.py @@ -0,0 +1,124 @@ +"""Base class implementation for models. + +Reference: + https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py +""" +import os +from typing import Union, Callable, Dict, Optional + +import torch + + +class BaseModel(torch.nn.Module): + + def __init__(self): + super().__init__() + + def save_pretrained_weight( + self, + save_directory: Union[str, os.PathLike], + save_function: Callable = None, + state_dict: Optional[Dict[str, torch.Tensor]] = None, + ): + """Saves a model and its configuration file to a directory. + + Args: + save_directory: A string or os.PathLike, directory to which to save. + Will be created if it doesn't exist. 
+ save_function: A Callable function, the function to use to save the state dictionary. + Useful on distributed training like TPUs when one need to replace `torch.save` by + another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. + state_dict: A dictionary from str to torch.Tensor, the state dictionary to save. + If `None`, the model's state dictionary will be saved. + """ + if os.path.isfile(save_directory): + print(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + if state_dict is None: + state_dict = model_to_save.state_dict() + weights_name = "pytorch_model.bin" + + save_function(state_dict, os.path.join(save_directory, weights_name)) + + print(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + + def load_pretrained_weight( + self, + pretrained_model_path: Union[str, os.PathLike], + strict_loading: bool = True, + torch_dtype: Optional[torch.dtype] = None + ): + r"""Instantiates a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + Args: + pretrained_model_path: A string or os.PathLike, a path to a *directory* or *file* containing model weights. + + Raises: + ValueError: If pretrained_model_path does not exist. + """ + # If pretrained_model_path is a file, set model_file to this file. + if os.path.isfile(pretrained_model_path): + model_file = pretrained_model_path + # If pretrained_model_path is a directory, set model_file to the path of the + # file "pytorch_model.bin" in this directory. + elif os.path.isdir(pretrained_model_path): + pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin") + if os.path.isfile(pretrained_model_path): + model_file = pretrained_model_path + else: + raise ValueError(f"{pretrained_model_path} does not exist") + else: + raise ValueError(f"{pretrained_model_path} does not exist") + + # Load model state from checkpoint. + checkpoint = torch.load(model_file, map_location="cpu") + # Load state dictionary into self. + msg = self.load_state_dict(checkpoint, strict=strict_loading) + # Print information about loading weights. + print(f"loading weight from {model_file}, msg: {msg}") + # If torch_dtype is specified and is a valid torch.dtype, convert self to this dtype. + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + self.to(torch_dtype) + + # Set model in evaluation mode to deactivate DropOut modules by default. + self.eval() + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """Gets the number of parameters in the module. + + Args: + only_trainable: A boolean, whether to only include trainable parameters. + exclude_embeddings: A boolean, whether to exclude parameters associated with embeddings. + + Returns: + An integer, the number of parameters. 
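+
+        For example, num_parameters(only_trainable=True, exclude_embeddings=True)
+        counts only trainable parameters outside nn.Embedding weight matrices.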
+ """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + diff --git a/modeling/modules/blocks.py b/modeling/modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..caa5fda7bc78647c76533fa88f605762beab231d --- /dev/null +++ b/modeling/modules/blocks.py @@ -0,0 +1,617 @@ +"""Transformer building blocks. + +Reference: + https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py + https://github.com/baofff/U-ViT/blob/main/libs/timm.py +""" + +import math +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from collections import OrderedDict +import einops +from einops.layers.torch import Rearrange + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model, + n_head, + mlp_ratio = 4.0, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.mlp_ratio = mlp_ratio + # optionally we can disable the FFN + if mlp_ratio > 0: + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + + def attention( + self, + x: torch.Tensor + ): + return self.attn(x, x, x, need_weights=False)[0] + + def forward( + self, + x: torch.Tensor, + ): + attn_output = self.attention(x=self.ln_1(x)) + x = x + attn_output + if self.mlp_ratio > 0: + x = x + self.mlp(self.ln_2(x)) + return x + +if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + try: + import xformers + import xformers.ops + ATTENTION_MODE = 'xformers' + except: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, L, C = x.shape + + qkv = self.qkv(x) + if ATTENTION_MODE == 'flash': + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'xformers': + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B L H D + x = xformers.ops.memory_efficient_attention(q, k, v) + x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) + elif ATTENTION_MODE == 'math': + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) + q, k, v = qkv[0], 
qkv[1], qkv[2] # B H L D + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplemented + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class UViTBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x, skip=None): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, skip) + else: + return self._forward(x, skip) + + def _forward(self, x, skip=None): + if self.skip_linear is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +def _expand_token(token, batch_size: int): + return token.unsqueeze(0).expand(batch_size, -1, -1) + + +class TiTokEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.image_size = config.dataset.preprocessing.crop_size + self.patch_size = config.model.vq_model.vit_enc_patch_size + self.grid_size = self.image_size // self.patch_size + self.model_size = config.model.vq_model.vit_enc_model_size + self.num_latent_tokens = config.model.vq_model.num_latent_tokens + self.token_size = config.model.vq_model.token_size + + if config.model.vq_model.get("quantize_mode", "vq") == "vae": + self.token_size = self.token_size * 2 # needs to split into mean and std + + self.is_legacy = config.model.vq_model.get("is_legacy", True) + + self.width = { + "small": 512, + "base": 768, + "large": 1024, + }[self.model_size] + self.num_layers = { + "small": 8, + "base": 12, + "large": 24, + }[self.model_size] + self.num_heads = { + "small": 8, + "base": 12, + "large": 16, + }[self.model_size] + + self.patch_embed = nn.Conv2d( + in_channels=3, out_channels=self.width, + kernel_size=self.patch_size, stride=self.patch_size, bias=True) + + scale = self.width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size ** 2 + 1, self.width)) + self.latent_token_positional_embedding = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.width)) + self.ln_pre = nn.LayerNorm(self.width) + self.transformer = nn.ModuleList() + for i in range(self.num_layers): + self.transformer.append(ResidualAttentionBlock( + self.width, self.num_heads, mlp_ratio=4.0 + )) + self.ln_post = nn.LayerNorm(self.width) + self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True) + + def forward(self, pixel_values, latent_tokens): + batch_size = pixel_values.shape[0] + x = pixel_values + x = self.patch_embed(x) + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + # class embeddings and positional embeddings + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + x = x + self.positional_embedding.to(x.dtype) # shape = [*, grid ** 2 + 1, width] + + + latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype) + latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype) + x = torch.cat([x, latent_tokens], dim=1) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + x = self.transformer[i](x) + x = x.permute(1, 0, 2) # LND -> NLD + + latent_tokens = x[:, 1+self.grid_size**2:] + latent_tokens = self.ln_post(latent_tokens) + # fake 2D shape + if self.is_legacy: + latent_tokens = latent_tokens.reshape(batch_size, self.width, 
self.num_latent_tokens, 1) + else: + # Fix legacy problem. + latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3) + latent_tokens = self.conv_out(latent_tokens) + latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens) + return latent_tokens + + +class TiTokDecoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.image_size = config.dataset.preprocessing.crop_size + self.patch_size = config.model.vq_model.vit_dec_patch_size + self.grid_size = self.image_size // self.patch_size + self.model_size = config.model.vq_model.vit_dec_model_size + self.num_latent_tokens = config.model.vq_model.num_latent_tokens + self.token_size = config.model.vq_model.token_size + self.is_legacy = config.model.vq_model.get("is_legacy", True) + self.width = { + "small": 512, + "base": 768, + "large": 1024, + }[self.model_size] + self.num_layers = { + "small": 8, + "base": 12, + "large": 24, + }[self.model_size] + self.num_heads = { + "small": 8, + "base": 12, + "large": 16, + }[self.model_size] + + self.decoder_embed = nn.Linear( + self.token_size, self.width, bias=True) + scale = self.width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size ** 2 + 1, self.width)) + # add mask token and query pos embed + self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width)) + self.latent_token_positional_embedding = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.width)) + self.ln_pre = nn.LayerNorm(self.width) + self.transformer = nn.ModuleList() + for i in range(self.num_layers): + self.transformer.append(ResidualAttentionBlock( + self.width, self.num_heads, mlp_ratio=4.0 + )) + self.ln_post = nn.LayerNorm(self.width) + + if self.is_legacy: + self.ffn = nn.Sequential( + nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True), + nn.Tanh(), + nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True), + ) + self.conv_out = nn.Identity() + else: + # Directly predicting RGB pixels + self.ffn = nn.Sequential( + nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True), + Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)', + p1 = self.patch_size, p2 = self.patch_size),) + self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True) + + def forward(self, z_quantized): + N, C, H, W = z_quantized.shape + assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}" + x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD + x = self.decoder_embed(x) + + batchsize, seq_len, _ = x.shape + + mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype) + mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype), + mask_tokens], dim=1) + mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype) + x = x + self.latent_token_positional_embedding[:seq_len] + x = torch.cat([mask_tokens, x], dim=1) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + x = self.transformer[i](x) + x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1:1+self.grid_size**2] # remove cls embed + x = self.ln_post(x) + # N L D -> N D H W + x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size) + x = self.ffn(x.contiguous()) + x = self.conv_out(x) + return x + + +class 
TATiTokDecoder(TiTokDecoder): + def __init__(self, config): + super().__init__(config) + scale = self.width ** -0.5 + self.text_context_length = config.model.vq_model.get("text_context_length", 77) + self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768) + self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width) + self.text_guidance_positional_embedding = nn.Parameter(scale * torch.randn(self.text_context_length, self.width)) + + def forward(self, z_quantized, text_guidance): + N, C, H, W = z_quantized.shape + assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}" + x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD + x = self.decoder_embed(x) + + batchsize, seq_len, _ = x.shape + + mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype) + mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype), + mask_tokens], dim=1) + mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype) + x = x + self.latent_token_positional_embedding[:seq_len] + x = torch.cat([mask_tokens, x], dim=1) + + text_guidance = self.text_guidance_proj(text_guidance) + text_guidance = text_guidance + self.text_guidance_positional_embedding + x = torch.cat([x, text_guidance], dim=1) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + x = self.transformer[i](x) + x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1:1+self.grid_size**2] # remove cls embed + x = self.ln_post(x) + # N L D -> N D H W + x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size) + x = self.ffn(x.contiguous()) + x = self.conv_out(x) + return x + + +class WeightTiedLMHead(nn.Module): + def __init__(self, embeddings, target_codebook_size): + super().__init__() + self.weight = embeddings.weight + self.target_codebook_size = target_codebook_size + + def forward(self, x): + # x shape: [batch_size, seq_len, embed_dim] + # Get the weights for the target codebook size + weight = self.weight[:self.target_codebook_size] # Shape: [target_codebook_size, embed_dim] + # Compute the logits by matrix multiplication + logits = torch.matmul(x, weight.t()) # Shape: [batch_size, seq_len, target_codebook_size] + return logits + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
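+
+        The first dim // 2 channels are cos(t * max_period ** (-i / (dim // 2))) and the
+        remaining channels are the matching sin terms, for i = 0, ..., dim // 2 - 1;
+        a zero channel is appended when dim is odd.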
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + """ + + def __init__( + self, + channels + ): + super().__init__() + self.channels = channels + + self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(channels, channels, bias=True), + nn.SiLU(), + nn.Linear(channels, channels, bias=True), + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 3 * channels, bias=True) + ) + + def forward(self, x, y): + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h + + +class FinalLayer(nn.Module): + """ + The final layer adopted from DiT. + """ + def __init__(self, model_channels, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 2 * model_channels, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class SimpleMLPAdaLN(nn.Module): + """ + The MLP for Diffusion Loss. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param z_channels: channels in the condition. + :param num_res_blocks: number of residual blocks per downsample. 
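+    :param grad_checkpointing: if True, each residual block is wrapped in
+        torch.utils.checkpoint during the forward pass to trade compute for memory.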
+ """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + z_channels, + num_res_blocks, + grad_checkpointing=False, + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.grad_checkpointing = grad_checkpointing + + self.time_embed = TimestepEmbedder(model_channels) + self.cond_embed = nn.Linear(z_channels, model_channels) + + self.input_proj = nn.Linear(in_channels, model_channels) + + res_blocks = [] + for i in range(num_res_blocks): + res_blocks.append(ResBlock( + model_channels, + )) + + self.res_blocks = nn.ModuleList(res_blocks) + self.final_layer = FinalLayer(model_channels, out_channels) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP + nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers + for block in self.res_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, c): + """ + Apply the model to an input batch. + :param x: an [N x C] Tensor of inputs. + :param t: a 1-D batch of timesteps. + :param c: conditioning from AR transformer. + :return: an [N x C] Tensor of outputs. 
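+
+        The timestep embedding and the condition are projected to model_channels
+        and summed into a single conditioning vector, which modulates every
+        residual block through adaLN shift/scale/gate and the final layer
+        through shift/scale.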
+ """ + x = self.input_proj(x) + t = self.time_embed(t) + c = self.cond_embed(c) + + y = t + c + + if self.grad_checkpointing and not torch.jit.is_scripting(): + for block in self.res_blocks: + x = checkpoint(block, x, y) + else: + for block in self.res_blocks: + x = block(x, y) + + return self.final_layer(x, y) + + def forward_with_cfg(self, x, t, c, cfg_scale): + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, c) + eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) \ No newline at end of file diff --git a/modeling/modules/discriminator.py b/modeling/modules/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..2fcb38e84b90c9e028d9666a9c1c47a4a0173dc8 --- /dev/null +++ b/modeling/modules/discriminator.py @@ -0,0 +1,124 @@ +"""Discriminator implementation.""" +import functools +import math +from typing import Tuple + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .maskgit_vqgan import Conv2dSame + + +class BlurBlock(torch.nn.Module): + def __init__(self, + kernel: Tuple[int] = (1, 3, 3, 1) + ): + super().__init__() + + kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False) + kernel = kernel[None, :] * kernel[:, None] + kernel /= kernel.sum() + kernel = kernel.unsqueeze(0).unsqueeze(0) + self.register_buffer("kernel", kernel) + + def calc_same_pad(self, i: int, k: int, s: int) -> int: + return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + ic, ih, iw = x.size()[-3:] + pad_h = self.calc_same_pad(i=ih, k=4, s=2) + pad_w = self.calc_same_pad(i=iw, k=4, s=2) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + + weight = self.kernel.expand(ic, -1, -1, -1) + + out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1]) + return out + + +class NLayerDiscriminator(torch.nn.Module): + def __init__( + self, + num_channels: int = 3, + hidden_channels: int = 128, + num_stages: int = 3, + blur_resample: bool = True, + blur_kernel_size: int = 4 + ): + """ Initializes the NLayerDiscriminator. + + Args: + num_channels -> int: The number of input channels. + hidden_channels -> int: The number of hidden channels. + num_stages -> int: The number of stages. + blur_resample -> bool: Whether to use blur resampling. + blur_kernel_size -> int: The blur kernel size. 
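+
+        Each stage halves the spatial resolution (average pooling, or a blurred
+        stride-2 convolution when blur_resample is True); features are then
+        pooled to a fixed 16x16 grid and projected to a single-channel patch
+        logit map, so the output shape is independent of the input resolution.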
+ """ + super().__init__() + assert num_stages > 0, "Discriminator cannot have 0 stages" + assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]" + + in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages))) + init_kernel_size = 5 + activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1) + + self.block_in = torch.nn.Sequential( + Conv2dSame( + num_channels, + hidden_channels, + kernel_size=init_kernel_size + ), + activation(), + ) + + BLUR_KERNEL_MAP = { + 3: (1,2,1), + 4: (1,3,3,1), + 5: (1,4,6,4,1), + } + + discriminator_blocks = [] + for i_level in range(num_stages): + in_channels = hidden_channels * in_channel_mult[i_level] + out_channels = hidden_channels * in_channel_mult[i_level + 1] + block = torch.nn.Sequential( + Conv2dSame( + in_channels, + out_channels, + kernel_size=3, + ), + torch.nn.AvgPool2d(kernel_size=2, stride=2) if not blur_resample else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]), + torch.nn.GroupNorm(32, out_channels), + activation(), + ) + discriminator_blocks.append(block) + + self.blocks = torch.nn.ModuleList(discriminator_blocks) + + self.pool = torch.nn.AdaptiveMaxPool2d((16, 16)) + + self.to_logits = torch.nn.Sequential( + Conv2dSame(out_channels, out_channels, 1), + activation(), + Conv2dSame(out_channels, 1, kernel_size=5) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward pass. + + Args: + x -> torch.Tensor: The input tensor. + + Returns: + output -> torch.Tensor: The output tensor. + """ + hidden_states = self.block_in(x) + for block in self.blocks: + hidden_states = block(hidden_states) + + hidden_states = self.pool(hidden_states) + + return self.to_logits(hidden_states) diff --git a/modeling/modules/ema_model.py b/modeling/modules/ema_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d2f319b48e9a228eec464b6c2d58b5e37ffc91 --- /dev/null +++ b/modeling/modules/ema_model.py @@ -0,0 +1,241 @@ +"""EMA (Exponential Moving Average) model. + +Reference: + https://github.com/huggingface/open-muse/blob/64e1afe033717d795866ab8204484705cd4dc3f7/muse/modeling_ema.py#L8 +""" + + +import copy +from typing import Any, Iterable, Optional, Union + +import torch + + +class EMAModel: + """Exponential Moving Average of models weights.""" + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + update_every: int = 1, + current_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + model_cls: Optional[Any] = None, + **model_config_kwargs + ): + """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + update_every (int): The number of steps between each EMA update. + current_step (int): The current training step. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + + notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. 
gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + self.temp_stored_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.update_every = update_every + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + self.optimization_step = current_step + self.cur_decay_value = None # set in `step()` + + self.model_cls = model_cls + self.model_config_kwargs = model_config_kwargs + + @classmethod + def from_pretrained(cls, checkpoint, model_cls, **model_config_kwargs) -> "EMAModel": + model = model_cls(**model_config_kwargs) + model.load_pretrained_weight(checkpoint) + + ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs) + return ema_model + + def save_pretrained(self, path): + if self.model_cls is None: + raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") + + if self.model_config_kwargs is None: + raise ValueError("`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__.") + + model = self.model_cls(**self.model_config_kwargs) + self.copy_to(model.parameters()) + model.save_pretrained_weight(path) + + def set_step(self, optimization_step: int): + self.optimization_step = optimization_step + + def get_decay(self, optimization_step: int) -> float: + """Computes the decay factor for the exponential moving average.""" + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # Make sure decay is not smaller than min_decay. + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + parameters = list(parameters) + + self.optimization_step += 1 + + if (self.optimization_step - 1) % self.update_every != 0: + return + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + self.cur_decay_value = decay + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """Copies current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) + + def to(self, device=None, dtype=None) -> None: + r"""Moves internal buffers of the ExponentialMovingAverage to `device`. 
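+
+        Non-floating-point shadow parameters are moved to the device but keep
+        their original dtype.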
+ + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.min_decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + } + + def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Save the current parameters for restoring later. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] + + def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r"""Restores the parameters stored with the `store` method. Useful to validate + the model with EMA parameters without affecting the original optimization process. + Store the parameters before the `copy_to()` method. After validation (or + model saving), use this to restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + if self.temp_stored_params is None: + raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`") + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) + + # Better memory-wise. + self.temp_stored_params = None + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. 
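+
+        For reference, a minimal usage sketch (here `model`, `loader`, and
+        `train_step` are placeholders for the caller's own objects):
+
+            ema = EMAModel(model.parameters(), decay=0.9999)
+            for batch in loader:
+                train_step(model, batch)              # optimizer update
+                ema.step(model.parameters())          # update shadow weights
+            torch.save(ema.state_dict(), "ema.pt")    # checkpoint EMA state
+            ema.load_state_dict(torch.load("ema.pt")) # restore EMA state
+            ema.copy_to(model.parameters())           # evaluate with EMA weights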
+ """ + # Deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + shadow_params = state_dict.get("shadow_params", None) + if shadow_params is not None: + self.shadow_params = shadow_params + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") \ No newline at end of file diff --git a/modeling/modules/encoder_decoder.py b/modeling/modules/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..33559d9a446872dafab40c281ee44aa4fa3e8abe --- /dev/null +++ b/modeling/modules/encoder_decoder.py @@ -0,0 +1,1142 @@ +"""Encoder and decoder building blocks for VibeToken. 
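+
+This module collects the resolution-agnostic ViT building blocks used by the
+tokenizer: the ResolutionEncoder/ResolutionDecoder pair (with TiTok-style legacy
+wrappers), attention and MLP blocks, the FuzzyEmbedding-based positional
+encoding, and the adaLN MLP head used for the diffusion loss.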
+ +Reference: + https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py + https://github.com/baofff/U-ViT/blob/main/libs/timm.py +""" + +import random +import math +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from collections import OrderedDict +import einops +from einops.layers.torch import Rearrange +from typing import Optional, Sequence, Tuple, Union +from modeling.modules.fuzzy_embedding import FuzzyEmbedding +import collections.abc +from itertools import repeat +from typing import Any +import numpy as np +import torch.nn.functional as F +from einops import rearrange +from torch import vmap +from torch import Tensor + +def to_2tuple(x: Any) -> Tuple: + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, 2)) + +class PatchMixture(): + def __init__(self, seed=42): + self.seed = seed + + def get_mask(self, x, mask_ratio=0.0, l1_reg=0.0, inverse=False): + batch_size, num_patches, _ = x.shape + device = x.device + num_mask = int(num_patches * mask_ratio) + num_keep = num_patches - num_mask + token_magnitudes = x.abs().sum(dim=-1) + min_mags = token_magnitudes.min(dim=1, keepdim=True)[0] + max_mags = token_magnitudes.max(dim=1, keepdim=True)[0] + token_magnitudes = (token_magnitudes - min_mags) / (max_mags - min_mags + 1e-8) + if inverse: + adjusted_magnitudes = 1.0 - token_magnitudes + else: + adjusted_magnitudes = token_magnitudes + noise_random = torch.rand(batch_size, num_patches, device=device) + noise = (1.0 - l1_reg) * noise_random + l1_reg * adjusted_magnitudes + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + ids_keep = ids_shuffle[:, :num_keep] + ids_mask = ids_shuffle[:, num_keep:] + mask = torch.ones((batch_size, num_patches), device=device, dtype=torch.bool) + mask.scatter_(1, ids_keep, False) + return { + 'mask': mask, + 'ids_keep': ids_keep, + 'ids_mask': ids_mask, + 'ids_shuffle': ids_shuffle, + 'ids_restore': ids_restore + } + + def start_route(self, x, mask_info): + ids_shuffle = mask_info['ids_shuffle'] + num_keep = mask_info['ids_keep'].size(1) + batch_indices = torch.arange(x.size(0), device=x.device).unsqueeze(-1) + x_shuffled = x.gather(1, ids_shuffle.unsqueeze(-1).expand(-1, -1, x.size(2))) + masked_x = x_shuffled[:, :num_keep, :] + return masked_x + + def end_route(self, masked_x, mask_info, original_x=None, mask_token=0.0): + batch_size, num_patches = mask_info['mask'].shape + num_keep = masked_x.size(1) + dim = masked_x.size(2) + device = masked_x.device + ids_restore = mask_info['ids_restore'] + batch_indices = torch.arange(batch_size, device=device).unsqueeze(-1) + x_unshuffled = torch.empty((batch_size, num_patches, dim), device=device) + x_unshuffled[:, :num_keep, :] = masked_x + if original_x is not None: + x_shuffled = original_x.gather(1, mask_info['ids_shuffle'].unsqueeze(-1).expand(-1, -1, dim)) + x_unshuffled[:, num_keep:, :] = x_shuffled[:, num_keep:, :] + else: + x_unshuffled[:, num_keep:, :].fill_(mask_token) + x_unmasked = x_unshuffled.gather(1, ids_restore.unsqueeze(-1).expand(-1, -1, dim)) + return x_unmasked + +class ResizableBlur(nn.Module): + """ + Single-parameter anti‑aliasing layer. + Call with scale=1,2,4 to downsample by 1× (identity), 2×, or 4×. + """ + def __init__(self, channels: int, + max_kernel_size: int = 9, + init_type: str = "gaussian"): + super().__init__() + self.C = channels + K = max_kernel_size # e.g. 
9 for 4× + assert K % 2 == 1, "kernel must be odd" + + # ----- initialise the largest kernel --------------------------------- + if init_type == "gaussian": + # 2‑D separable Gaussian, σ≈K/6 + ax = torch.arange(-(K//2), K//2 + 1) + g1d = torch.exp(-0.5 * (ax / (K/6.0))**2) + g2d = torch.outer(g1d, g1d) + kernel = g2d / g2d.sum() + elif init_type == "lanczos": + a = K//2 # window size parameter + x = torch.arange(-a, a+1).float() + sinc = lambda t: torch.where(t==0, torch.ones_like(t), torch.sin(torch.pi*t)/(torch.pi*t)) + k1d = sinc(x) * sinc(x/a) + k2d = torch.outer(k1d, k1d) + kernel = k2d / k2d.sum() + else: + raise ValueError("unknown init_type") + + # learnable base kernel (shape 1×1×K×K) + self.weight = nn.Parameter(kernel.unsqueeze(0).unsqueeze(0)) + + # ------------------------------------------------------------------------ + @staticmethod + def _resize_and_normalise(weight: torch.Tensor, k_size: int) -> torch.Tensor: + """ + Bilinearly interpolate weight (B,C,H,W) to target k_size×k_size, + then L1‑normalise over spatial dims so Σ=1. + """ + if weight.shape[-1] != k_size: + weight = F.interpolate(weight, size=(k_size, k_size), + mode="bilinear", align_corners=True) + weight = weight / weight.sum(dim=(-2, -1), keepdim=True).clamp(min=1e-8) + return weight + + # ------------------------------------------------------------------------ + def forward(self, x: torch.Tensor, input_size, target_size) -> torch.Tensor: + # Unpack input and target dimensions + input_h, input_w = input_size + target_h, target_w = target_size + + # Calculate scale factors for height and width + scale_h = input_h / target_h + scale_w = input_w / target_w + + # Determine kernel size based on scale factors + # Larger scale factors need larger kernels for better anti-aliasing + k_size_h = min(self.weight.shape[-1], max(1, int(2 * scale_h + 3))) + k_size_w = min(self.weight.shape[-1], max(1, int(2 * scale_w + 3))) + + # Make sure kernel sizes are odd + k_size_h = k_size_h if k_size_h % 2 == 1 else k_size_h + 1 + k_size_w = k_size_w if k_size_w % 2 == 1 else k_size_w + 1 + + # Use the maximum for a square kernel, or create a rectangular kernel if needed + k_size = max(k_size_h, k_size_w) + + # Calculate appropriate stride and padding + stride_h = max(1, round(scale_h)) + stride_w = max(1, round(scale_w)) + pad_h = k_size_h // 2 + pad_w = k_size_w // 2 + + # Get the kernel and normalize it + k = self._resize_and_normalise(self.weight, k_size) # (1,1,k,k) + k = k.repeat(self.C, 1, 1, 1) # depth-wise + + # Apply convolution with calculated parameters + result = F.conv2d(x, weight=k, stride=(stride_h, stride_w), + padding=(pad_h, pad_w), groups=self.C) + + # If the convolution didn't get us exactly to the target size, use interpolation for fine adjustment + if result.shape[2:] != target_size: + result = F.interpolate(result, size=target_size, mode='bilinear', align_corners=True) + + return result + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model, + n_head, + mlp_ratio = 4.0, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.mlp_ratio = mlp_ratio + # optionally we can disable the FFN + if mlp_ratio > 0: + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", 
nn.Linear(mlp_width, d_model)) + ])) + + def attention( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None + ): + return self.attn(x, x, x, attn_mask=attention_mask, need_weights=False)[0] + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None + ): + attn_output = self.attention(x=self.ln_1(x), attention_mask=attention_mask) + x = x + attn_output + if self.mlp_ratio > 0: + x = x + self.mlp(self.ln_2(x)) + return x + +if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + try: + import xformers + import xformers.ops + ATTENTION_MODE = 'xformers' + except: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, L, C = x.shape + + qkv = self.qkv(x) + if ATTENTION_MODE == 'flash': + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'xformers': + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B L H D + x = xformers.ops.memory_efficient_attention(q, k, v) + x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) + elif ATTENTION_MODE == 'math': + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplemented + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
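+
+    During training, each sample's residual-branch output is zeroed with
+    probability drop_prob and the surviving samples are rescaled by
+    1 / (1 - drop_prob); at inference the layer is a no-op.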
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class UViTBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x, skip=None): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, skip) + else: + return self._forward(x, skip) + + def _forward(self, x, skip=None): + if self.skip_linear is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +def _expand_token(token, batch_size: int): + return token.unsqueeze(0).expand(batch_size, -1, -1) + + +class ResolutionEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.image_size = config.dataset.preprocessing.crop_size + self.patch_size = config.model.vq_model.vit_enc_patch_size + self.model_size = config.model.vq_model.vit_enc_model_size + self.num_latent_tokens = config.model.vq_model.num_latent_tokens + self.token_size = config.model.vq_model.token_size + self.apply_fuzzy = config.model.vq_model.get("apply_fuzzy", False) + self.patch_mixture_start_layer = config.model.vq_model.get("patch_mixture_start_layer", 100) + self.patch_mixture_end_layer = config.model.vq_model.get("patch_mixture_end_layer", 100) + + if config.model.vq_model.get("quantize_mode", "vq") == "vae": + self.token_size = self.token_size * 2 # needs to split into mean and std + + self.is_legacy = config.model.vq_model.get("is_legacy", True) + + self.width = { + "tiny": 256, + "small": 512, + "base": 768, + "large": 1024, + }[self.model_size] + self.num_layers = { + "tiny": 4, + "small": 8, + "base": 12, + "large": 24, + }[self.model_size] + self.num_heads = { + "tiny": 4, + "small": 8, + "base": 12, + "large": 16, + }[self.model_size] + + self.patch_embed = nn.Conv2d( + in_channels=3, out_channels=self.width, + kernel_size=self.patch_size, stride=self.patch_size, bias=True) + + scale = self.width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) + + self.positional_embedding = 
FuzzyEmbedding(1024, scale, self.width) + + self.latent_token_positional_embedding = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.width)) + self.ln_pre = nn.LayerNorm(self.width) + + self.patch_mixture = PatchMixture() + + self.transformer = nn.ModuleList() + for i in range(self.num_layers): + self.transformer.append(ResidualAttentionBlock( + self.width, self.num_heads, mlp_ratio=4.0 + )) + + self.ln_post = nn.LayerNorm(self.width) + self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True) + self.pinvs = {} + + def apply_flexivit_patch_embed(self, x, target_patch_size): + patch_size = to_2tuple(target_patch_size) + + # Resize conv weights + if patch_size == to_2tuple(self.patch_size): + weight = self.patch_embed.weight + else: + weight = self.resize_patch_embed(self.patch_embed.weight, patch_size) + + # Apply conv with resized weights + x = F.conv2d(x, weight, bias=self.patch_embed.bias, stride=patch_size) + return x + + def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor: + x_resized = F.interpolate( + x[None, None, ...], + shape, + mode="bilinear", + antialias=False, + ) + return x_resized[0, 0, ...] + + def _calculate_pinv( + self, old_shape: Tuple[int, int], new_shape: Tuple[int, int], device=None + ) -> Tensor: + # Use the device from patch_embed weights if available + if device is None and hasattr(self, 'patch_embed'): + device = self.patch_embed.weight.device + + mat = [] + for i in range(np.prod(old_shape)): + basis_vec = torch.zeros(old_shape, device=device) # Specify device here + basis_vec[np.unravel_index(i, old_shape)] = 1.0 + mat.append(self._resize(basis_vec, new_shape).reshape(-1)) + resize_matrix = torch.stack(mat) + return torch.linalg.pinv(resize_matrix) + + def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]): + """Resize patch_embed to target resolution via pseudo-inverse resizing""" + # Return original kernel if no resize is necessary + if to_2tuple(self.patch_size) == new_patch_size: + return patch_embed + + # Calculate pseudo-inverse of resize matrix + if new_patch_size not in self.pinvs: + self.pinvs[new_patch_size] = self._calculate_pinv( + to_2tuple(self.patch_size), new_patch_size, device=patch_embed.device + ) + pinv = self.pinvs[new_patch_size] + + def resample_patch_embed(patch_embed: Tensor): + h, w = new_patch_size + original_dtype = patch_embed.dtype + patch_embed_float = patch_embed.float() + resampled_kernel = pinv @ patch_embed_float.reshape(-1) + resampled_kernel = resampled_kernel.to(original_dtype) + return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w) + + v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1) + + return v_resample_patch_embed(patch_embed) + + def get_attention_mask(self, target_shape, attention_mask): + # Create mask for mask_tokens (all True since we want to attend to all mask tokens) + mask_token_mask = torch.ones(target_shape).to(attention_mask.device) + # Combine with input attention mask + attention_mask = torch.cat((mask_token_mask, attention_mask), dim=1).bool() + sequence_length = attention_mask.shape[1] + + # Create causal attention mask + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, S] + attention_mask = attention_mask.expand( + attention_mask.shape[0], + self.num_heads, + sequence_length, + sequence_length + ) + + # Reshape to [B*num_heads, S, S] + attention_mask = attention_mask.reshape( + -1, sequence_length, sequence_length + ) + + # Convert boolean mask to float + attention_mask = 
attention_mask.float() + + # Convert mask values: True -> 0.0, False -> -inf + attention_mask = attention_mask.masked_fill( + ~attention_mask.bool(), + float('-inf') + ) + return attention_mask + + def forward(self, pixel_values, latent_tokens, attention_mask=None, encode_patch_size=None, train=True): + batch_size, _, H, W = pixel_values.shape + x = pixel_values + + # Apply dynamic patch embedding + # Determine patch size dynamically based on image resolution + # Base patch size (32) is for 512x512 images + # Scale proportionally for other resolutions to maintain ~256 tokens + base_resolution = 512 + + if encode_patch_size is None: + base_patch_size = random.choice([16, 32]) + target_patch_size = min(int(min(H, W) / base_resolution * base_patch_size), 32) # we force it to be at most 32 otherwise we lose information + else: + target_patch_size = encode_patch_size + + if isinstance(target_patch_size, int): + target_patch_size = (target_patch_size, target_patch_size) + + x = self.apply_flexivit_patch_embed(x, target_patch_size) + + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + # class embeddings and positional embeddings + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + + # create image_rotary_emb + grid_height = H // target_patch_size[0] + grid_width = W // target_patch_size[1] + + mask_ratio = 0.0 + if grid_height*grid_width > 256 and train: + mask_ratio = torch.empty(1).uniform_(0.5, 0.7).item() + + num_latent_tokens = latent_tokens.shape[0] + latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype) + latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)[:num_latent_tokens] + + x = x + self.positional_embedding(grid_height, grid_width, train=train, dtype=x.dtype) + + # apply attention_mask before concatenating x and latent_tokens + if attention_mask is not None: + key_attention_mask = attention_mask.clone() + attention_mask = self.get_attention_mask((batch_size, x.shape[1]), key_attention_mask) + full_seq_attention_mask = attention_mask.clone() + else: + key_attention_mask = None + full_seq_attention_mask = None + + # Concatenate x and latent_tokens first + x = torch.cat([x, latent_tokens], dim=1) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + if i == self.patch_mixture_start_layer: + x = x.permute(1, 0, 2) + x_D_last = x[:, 1:grid_height*grid_width+1].clone() + mask_info = self.patch_mixture.get_mask(x[:, 1:grid_height*grid_width+1], mask_ratio=mask_ratio) + new_x = self.patch_mixture.start_route(x, mask_info) + x = torch.cat([x[:, :1], new_x, x[:, grid_height*grid_width+1:]], dim=1) + x = x.permute(1, 0, 2) + if key_attention_mask is not None: + attention_mask = self.get_attention_mask((batch_size, 1+new_x.shape[1]), key_attention_mask) + else: + attention_mask = None + + x = self.transformer[i](x, attention_mask=attention_mask) + + if i == self.patch_mixture_end_layer: + x = x.permute(1, 0, 2) + new_x = self.patch_mixture.end_route(x[:, 1:-self.num_latent_tokens], mask_info, original_x=x_D_last) + x = torch.cat([x[:, :1], new_x, x[:, -self.num_latent_tokens:]], dim=1) + x = x.permute(1, 0, 2) + if full_seq_attention_mask is not None: + attention_mask = full_seq_attention_mask.clone() + else: + attention_mask = None + + x = x.permute(1, 0, 2) # LND -> NLD + + latent_tokens = x[:, 1+grid_height*grid_width:] + latent_tokens = self.ln_post(latent_tokens) + + # fake 2D shape + if self.is_legacy: + 
latent_tokens = latent_tokens.reshape(batch_size, self.width, num_latent_tokens, 1) + else: + # Fix legacy problem. + latent_tokens = latent_tokens.reshape(batch_size, num_latent_tokens, self.width, 1).permute(0, 2, 1, 3) + latent_tokens = self.conv_out(latent_tokens) + latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, num_latent_tokens) + return latent_tokens + +# Keep the original TiTokEncoder as a legacy class +class TiTokEncoder(ResolutionEncoder): + """Legacy TiTokEncoder - now inherits from ResolutionEncoder for backward compatibility""" + pass + +class ResolutionDecoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.image_size = config.dataset.preprocessing.crop_size + self.patch_size = config.model.vq_model.vit_dec_patch_size + self.model_size = config.model.vq_model.vit_dec_model_size + self.num_latent_tokens = config.model.vq_model.num_latent_tokens + self.token_size = config.model.vq_model.token_size + self.apply_fuzzy = config.model.vq_model.get("apply_fuzzy", False) + self.patch_mixture_start_layer = config.model.vq_model.get("patch_mixture_start_layer", 100) + self.patch_mixture_end_layer = config.model.vq_model.get("patch_mixture_end_layer", 100) + + self.is_legacy = config.model.vq_model.get("is_legacy", True) + self.width = { + "tiny": 256, + "small": 512, + "base": 768, + "large": 1024, + }[self.model_size] + self.num_layers = { + "tiny": 4, + "small": 8, + "base": 12, + "large": 24, + }[self.model_size] + self.num_heads = { + "tiny": 4, + "small": 8, + "base": 12, + "large": 16, + }[self.model_size] + + self.decoder_embed = nn.Linear( + self.token_size, self.width, bias=True) + scale = self.width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) + + self.positional_embedding = FuzzyEmbedding(1024, scale, self.width) + + # add mask token and query pos embed + self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width)) + self.latent_token_positional_embedding = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.width)) + self.ln_pre = nn.LayerNorm(self.width) + + self.patch_mixture = PatchMixture() + + self.transformer = nn.ModuleList() + for i in range(self.num_layers): + self.transformer.append(ResidualAttentionBlock( + self.width, self.num_heads, mlp_ratio=4.0 + )) + self.ln_post = nn.LayerNorm(self.width) + + if self.is_legacy: + raise NotImplementedError("Legacy mode is not implemented for ResolutionDecoder") + else: + # Directly predicting RGB pixels + self.ffn = nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True) + self.rearrange = Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)', + p1 = self.patch_size, p2 = self.patch_size) + self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos") + self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True) + + def get_attention_mask(self, target_shape, attention_mask): + # Create mask for mask_tokens (all True since we want to attend to all mask tokens) + mask_token_mask = torch.ones(target_shape).to(attention_mask.device) + # Combine with input attention mask + attention_mask = torch.cat((mask_token_mask, attention_mask), dim=1).bool() + sequence_length = attention_mask.shape[1] + + # Create causal attention mask + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, S] + attention_mask = attention_mask.expand( + attention_mask.shape[0], + self.num_heads, + sequence_length, + sequence_length + ) + + # Reshape to [B*num_heads, S, S] + 
attention_mask = attention_mask.reshape( + -1, sequence_length, sequence_length + ) + + # Convert boolean mask to float + attention_mask = attention_mask.float() + + # Convert mask values: True -> 0.0, False -> -inf + attention_mask = attention_mask.masked_fill( + ~attention_mask.bool(), + float('-inf') + ) + return attention_mask + + def forward(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True): + N, C, H, W = z_quantized.shape + x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD + x = self.decoder_embed(x) + + batchsize, seq_len, _ = x.shape + + if height is None: + height = self.image_size + if width is None: + width = self.image_size + + # create image_rotary_emb + if decode_patch_size is None: + # Calculate total area and determine appropriate patch size + total_pixels = height * width + + # Target patch counts between 256 and 1024 + min_patches = 256 + max_patches = 1024 + + # Calculate possible patch sizes that would give us patch counts in our target range + possible_patch_sizes = [] + for patch_size in [8, 16, 32]: + grid_h = height // patch_size + grid_w = width // patch_size + total_patches = grid_h * grid_w + if min_patches <= total_patches <= max_patches: + possible_patch_sizes.append(patch_size) + + if not possible_patch_sizes: + # If no patch size gives us the desired range, pick the one closest to our target range + patch_counts = [] + for patch_size in [8, 16, 32]: + grid_h = height // patch_size + grid_w = width // patch_size + patch_counts.append((patch_size, grid_h * grid_w)) + + # Sort by how close the patch count is to our target range + patch_counts.sort(key=lambda x: min(abs(x[1] - min_patches), abs(x[1] - max_patches))) + possible_patch_sizes = [patch_counts[0][0]] + + selected_patch_size = random.choice(possible_patch_sizes) + else: + selected_patch_size = decode_patch_size + + if isinstance(selected_patch_size, int): + selected_patch_size = (selected_patch_size, selected_patch_size) + + grid_height = height // selected_patch_size[0] + grid_width = width // selected_patch_size[1] + + # if grid_height*grid_width>1024 and train: + # grid_height = 32 + # grid_width = 32 + + mask_ratio = 0.0 + if grid_height*grid_width > 256 and train: + mask_ratio = torch.empty(1).uniform_(0.5, 0.7).item() + + mask_tokens = self.mask_token.repeat(batchsize, grid_height*grid_width, 1).to(x.dtype) + mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype), + mask_tokens], dim=1) + + mask_tokens = mask_tokens + self.positional_embedding(grid_height, grid_width, train=train).to(mask_tokens.dtype) + + x = x + self.latent_token_positional_embedding[:seq_len] + x = torch.cat([mask_tokens, x], dim=1) + + if attention_mask is not None: + key_attention_mask = attention_mask.clone() + attention_mask = self.get_attention_mask((batchsize, 1+grid_height*grid_width), key_attention_mask) + full_seq_attention_mask = attention_mask.clone() + else: + key_attention_mask = None + full_seq_attention_mask = None + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + if i == self.patch_mixture_start_layer: + x = x.permute(1, 0, 2) + x_D_last = x[:, 1:grid_height*grid_width+1].clone() + mask_info = self.patch_mixture.get_mask(x[:, 1:grid_height*grid_width+1], mask_ratio=mask_ratio) + new_x = self.patch_mixture.start_route(x, mask_info) + x = torch.cat([x[:, :1], new_x, x[:, grid_height*grid_width+1:]], dim=1) + x = x.permute(1, 0, 2) + if key_attention_mask is 
not None: + attention_mask = self.get_attention_mask((batchsize, 1+new_x.shape[1]), key_attention_mask) + else: + attention_mask = None + + x = self.transformer[i](x, attention_mask=attention_mask) + + if i == self.patch_mixture_end_layer: + x = x.permute(1, 0, 2) + new_x = self.patch_mixture.end_route(x[:, 1:-self.num_latent_tokens], mask_info, original_x=x_D_last) + x = torch.cat([x[:, :1], new_x, x[:, -self.num_latent_tokens:]], dim=1) + x = x.permute(1, 0, 2) + if full_seq_attention_mask is not None: + attention_mask = full_seq_attention_mask.clone() + else: + attention_mask = None + + x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1:1+grid_height*grid_width] # remove cls embed + x = self.ln_post(x) + # N L D -> N D H W + x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width) + x = self.ffn(x.contiguous()) + x = self.rearrange(x) + _, _, org_h, org_w = x.shape + x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width)) + x = self.conv_out(x) + + return x + +# Keep the original TiTokDecoder as a legacy class that inherits from ResolutionDecoder +class TiTokDecoder(ResolutionDecoder): + """Legacy TiTokDecoder - now inherits from ResolutionDecoder for backward compatibility""" + + def __init__(self, config): + # Override config to disable patch mixture and other advanced features for legacy mode + config_copy = type(config)() + for attr in dir(config): + if not attr.startswith('__'): + try: + setattr(config_copy, attr, getattr(config, attr)) + except: + pass + + # Disable patch mixture for legacy mode + if hasattr(config_copy.model.vq_model, 'patch_mixture_start_layer'): + config_copy.model.vq_model.patch_mixture_start_layer = -1 + if hasattr(config_copy.model.vq_model, 'patch_mixture_end_layer'): + config_copy.model.vq_model.patch_mixture_end_layer = -1 + + super().__init__(config_copy) + + # Override grid_size for legacy compatibility + self.grid_size = self.image_size // self.patch_size + + # Replace ResolutionDecoder's advanced final layers with legacy ones if needed + if self.is_legacy: + self.ffn = nn.Sequential( + nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True), + nn.Tanh(), + nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True), + ) + self.conv_out = nn.Identity() + else: + # Use simpler final layers for backward compatibility + self.ffn = nn.Sequential( + nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True), + Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)', + p1 = self.patch_size, p2 = self.patch_size),) + self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True) + + def forward(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True): + # Legacy compatibility: use fixed grid size if height/width not provided + if height is None: + height = self.image_size + if width is None: + width = self.image_size + + # Force decode_patch_size to be the original patch_size for legacy compatibility + if decode_patch_size is None: + decode_patch_size = self.patch_size + + # Use the parent's forward method but with legacy parameters + return super().forward(z_quantized, attention_mask, height, width, decode_patch_size, train) + + +class TATiTokDecoder(ResolutionDecoder): + def __init__(self, config): + super().__init__(config) + scale = self.width ** -0.5 + self.text_context_length = config.model.vq_model.get("text_context_length", 77) + self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768) + self.text_guidance_proj = 
nn.Linear(self.text_embed_dim, self.width) + self.text_guidance_positional_embedding = nn.Parameter(scale * torch.randn(self.text_context_length, self.width)) + + # Add grid_size for backward compatibility + self.grid_size = self.image_size // self.patch_size + + def forward(self, z_quantized, text_guidance, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True): + N, C, H, W = z_quantized.shape + x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD + x = self.decoder_embed(x) + + batchsize, seq_len, _ = x.shape + + # Use fixed grid size for backward compatibility + if height is None: + height = self.image_size + if width is None: + width = self.image_size + if decode_patch_size is None: + decode_patch_size = self.patch_size + + grid_height = height // decode_patch_size + grid_width = width // decode_patch_size + + mask_tokens = self.mask_token.repeat(batchsize, grid_height*grid_width, 1).to(x.dtype) + mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype), + mask_tokens], dim=1) + mask_tokens = mask_tokens + self.positional_embedding(grid_height, grid_width, train=train).to(mask_tokens.dtype) + x = x + self.latent_token_positional_embedding[:seq_len] + x = torch.cat([mask_tokens, x], dim=1) + + text_guidance = self.text_guidance_proj(text_guidance) + text_guidance = text_guidance + self.text_guidance_positional_embedding + x = torch.cat([x, text_guidance], dim=1) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + x = self.transformer[i](x) + x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1:1+grid_height*grid_width] # remove cls embed + x = self.ln_post(x) + # N L D -> N D H W + x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width) + x = self.ffn(x.contiguous()) + x = self.conv_out(x) + return x + + +class WeightTiedLMHead(nn.Module): + def __init__(self, embeddings, target_codebook_size): + super().__init__() + self.weight = embeddings.weight + self.target_codebook_size = target_codebook_size + + def forward(self, x): + # x shape: [batch_size, seq_len, embed_dim] + # Get the weights for the target codebook size + weight = self.weight[:self.target_codebook_size] # Shape: [target_codebook_size, embed_dim] + # Compute the logits by matrix multiplication + logits = torch.matmul(x, weight.t()) # Shape: [batch_size, seq_len, target_codebook_size] + return logits + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + """ + + def __init__( + self, + channels + ): + super().__init__() + self.channels = channels + + self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(channels, channels, bias=True), + nn.SiLU(), + nn.Linear(channels, channels, bias=True), + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 3 * channels, bias=True) + ) + + def forward(self, x, y): + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h + + +class FinalLayer(nn.Module): + """ + The final layer adopted from DiT. + """ + def __init__(self, model_channels, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 2 * model_channels, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class SimpleMLPAdaLN(nn.Module): + """ + The MLP for Diffusion Loss. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param z_channels: channels in the condition. + :param num_res_blocks: number of residual blocks per downsample. 
+ """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + z_channels, + num_res_blocks, + grad_checkpointing=False, + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.grad_checkpointing = grad_checkpointing + + self.time_embed = TimestepEmbedder(model_channels) + self.cond_embed = nn.Linear(z_channels, model_channels) + + self.input_proj = nn.Linear(in_channels, model_channels) + + res_blocks = [] + for i in range(num_res_blocks): + res_blocks.append(ResBlock( + model_channels, + )) + + self.res_blocks = nn.ModuleList(res_blocks) + self.final_layer = FinalLayer(model_channels, out_channels) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP + nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers + for block in self.res_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, c): + """ + Apply the model to an input batch. + :param x: an [N x C] Tensor of inputs. + :param t: a 1-D batch of timesteps. + :param c: conditioning from AR transformer. + :return: an [N x C] Tensor of outputs. 
+ """ + x = self.input_proj(x) + t = self.time_embed(t) + c = self.cond_embed(c) + + y = t + c + + if self.grad_checkpointing and not torch.jit.is_scripting(): + for block in self.res_blocks: + x = checkpoint(block, x, y) + else: + for block in self.res_blocks: + x = block(x, y) + + return self.final_layer(x, y) + + def forward_with_cfg(self, x, t, c, cfg_scale): + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, c) + eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) \ No newline at end of file diff --git a/modeling/modules/fuzzy_embedding.py b/modeling/modules/fuzzy_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..b0820631a0114b7f0aee125ecc861b83482dd375 --- /dev/null +++ b/modeling/modules/fuzzy_embedding.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import einops +import math + +class FuzzyEmbedding(nn.Module): + def __init__(self, grid_size, scale, width, apply_fuzzy=False): + super(FuzzyEmbedding, self).__init__() + assert grid_size == 1024, "grid_size must be 1024 for now" + + self.grid_size = grid_size + self.scale = scale + self.width = width + self.apply_fuzzy = apply_fuzzy + # grid_size is the minimum possible token size + # then we can use grid_sample to get the fuzzy embedding for any resolution + self.positional_embedding = nn.Parameter( + scale * torch.randn(grid_size, width)) + + self.class_positional_embedding = nn.Parameter( + scale * torch.randn(1, width)) + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, grid_height, grid_width, train=True, dtype=torch.float32): + meshx, meshy = torch.meshgrid( + torch.tensor(list(range(grid_height)), device=self.positional_embedding.device), + torch.tensor(list(range(grid_width)), device=self.positional_embedding.device) + ) + meshx = meshx.to(dtype) + meshy = meshy.to(dtype) + + # Normalize coordinates to [-1, 1] range + meshx = 2 * (meshx / (grid_height - 1)) - 1 + meshy = 2 * (meshy / (grid_width - 1)) - 1 + + if self.apply_fuzzy: + # Add uniform noise in range [-0.0004, 0.0004] to x and y coordinates + if train: + noise_x = torch.rand_like(meshx) * 0.0008 - 0.0004 + noise_y = torch.rand_like(meshy) * 0.0008 - 0.0004 + else: + noise_x = torch.zeros_like(meshx) + noise_y = torch.zeros_like(meshy) + + # Apply noise to the mesh coordinates + meshx = meshx + noise_x + meshy = meshy + noise_y + + grid = torch.stack((meshy, meshx), 2).to(self.positional_embedding.device) + grid = grid.unsqueeze(0) # add batch dim + + positional_embedding = einops.rearrange(self.positional_embedding, "(h w) d -> d h w", h=int(math.sqrt(self.grid_size)), w=int(math.sqrt(self.grid_size))) + positional_embedding = positional_embedding.to(dtype) + positional_embedding = positional_embedding.unsqueeze(0) # add batch dim + + fuzzy_embedding = F.grid_sample(positional_embedding, grid, align_corners=False) + fuzzy_embedding = fuzzy_embedding.to(dtype) + fuzzy_embedding = einops.rearrange(fuzzy_embedding, "b d h w -> b (h w) d").squeeze(0) + + final_embedding = torch.cat([self.class_positional_embedding, fuzzy_embedding], dim=0) + return final_embedding + + +if __name__ == "__main__": + fuzzy_embedding = FuzzyEmbedding(256, 1.0, 1024) + grid_height = 16 + grid_width = 32 + 
+    fuzzy_embedding = fuzzy_embedding(grid_height, grid_width, dtype=torch.bfloat16)
+    print(fuzzy_embedding.shape)
\ No newline at end of file
diff --git a/modeling/modules/losses.py b/modeling/modules/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..22aae755b0ba10c0452e6cd243a0e9407c306f06
--- /dev/null
+++ b/modeling/modules/losses.py
@@ -0,0 +1,339 @@
+"""Training loss implementation.
+
+Ref:
+    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py
+"""
+from typing import Mapping, Text, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.cuda.amp import autocast
+
+from modeling.modules.blocks import SimpleMLPAdaLN
+from .perceptual_loss import PerceptualLoss
+from .discriminator import NLayerDiscriminator
+
+
+def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
+    """Hinge loss for discriminator.
+
+    This function is borrowed from
+    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20
+    """
+    loss_real = torch.mean(F.relu(1.0 - logits_real))
+    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+    d_loss = 0.5 * (loss_real + loss_fake)
+    return d_loss
+
+
+def compute_lecam_loss(
+    logits_real_mean: torch.Tensor,
+    logits_fake_mean: torch.Tensor,
+    ema_logits_real_mean: torch.Tensor,
+    ema_logits_fake_mean: torch.Tensor
+) -> torch.Tensor:
+    """Computes the LeCam loss for the given average real and fake logits.
+
+    Args:
+        logits_real_mean -> torch.Tensor: The average real logits.
+        logits_fake_mean -> torch.Tensor: The average fake logits.
+        ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
+        ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.
+
+    Returns:
+        lecam_loss -> torch.Tensor: The LeCam loss.
+ """ + lecam_loss = torch.mean(torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2)) + lecam_loss += torch.mean(torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2)) + return lecam_loss + + +class ReconstructionLoss_Stage1(torch.nn.Module): + def __init__( + self, + config + ): + super().__init__() + loss_config = config.losses + self.quantizer_weight = loss_config.quantizer_weight + self.target_codebook_size = 1024 + + def forward(self, + target_codes: torch.Tensor, + reconstructions: torch.Tensor, + quantizer_loss: torch.Tensor, + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + return self._forward_generator(target_codes, reconstructions, quantizer_loss) + + def _forward_generator(self, + target_codes: torch.Tensor, + reconstructions: torch.Tensor, + quantizer_loss: Mapping[Text, torch.Tensor], + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + reconstructions = reconstructions.contiguous() + loss_fct = nn.CrossEntropyLoss(reduction="mean") + batch_size = reconstructions.shape[0] + reconstruction_loss = loss_fct(reconstructions.view(batch_size, self.target_codebook_size, -1), + target_codes.view(batch_size, -1)) + total_loss = reconstruction_loss + \ + self.quantizer_weight * quantizer_loss["quantizer_loss"] + + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + quantizer_loss=(self.quantizer_weight * quantizer_loss["quantizer_loss"]).detach(), + commitment_loss=quantizer_loss["commitment_loss"].detach(), + codebook_loss=quantizer_loss["codebook_loss"].detach(), + ) + + return total_loss, loss_dict + + +class ReconstructionLoss_Stage2(torch.nn.Module): + def __init__( + self, + config + ): + """Initializes the losses module. + + Args: + config: A dictionary, the configuration for the model and everything else. + """ + super().__init__() + loss_config = config.losses + self.discriminator = NLayerDiscriminator() + + self.reconstruction_loss = loss_config.reconstruction_loss + self.reconstruction_weight = loss_config.reconstruction_weight + self.quantizer_weight = loss_config.quantizer_weight + self.perceptual_loss = PerceptualLoss( + loss_config.perceptual_loss).eval() + self.perceptual_weight = loss_config.perceptual_weight + self.discriminator_iter_start = loss_config.discriminator_start + + self.discriminator_factor = loss_config.discriminator_factor + self.discriminator_weight = loss_config.discriminator_weight + self.lecam_regularization_weight = loss_config.lecam_regularization_weight + self.lecam_ema_decay = loss_config.get("lecam_ema_decay", 0.999) + if self.lecam_regularization_weight > 0.0: + self.register_buffer("ema_real_logits_mean", torch.zeros((1))) + self.register_buffer("ema_fake_logits_mean", torch.zeros((1))) + + self.config = config + + @autocast(enabled=False) + def forward(self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + extra_result_dict: Mapping[Text, torch.Tensor], + global_step: int, + mode: str = "generator", + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + # Both inputs and reconstructions are in range [0, 1]. 
+        inputs = inputs.float()
+        reconstructions = reconstructions.float()
+
+        if mode == "generator":
+            return self._forward_generator(inputs, reconstructions, extra_result_dict, global_step)
+        elif mode == "discriminator":
+            return self._forward_discriminator(inputs, reconstructions, global_step)
+        else:
+            raise ValueError(f"Unsupported mode {mode}")
+
+    def should_discriminator_be_trained(self, global_step : int):
+        return global_step >= self.discriminator_iter_start
+
+    def _forward_generator(self,
+                           inputs: torch.Tensor,
+                           reconstructions: torch.Tensor,
+                           extra_result_dict: Mapping[Text, torch.Tensor],
+                           global_step: int
+                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        """Generator training step."""
+        inputs = inputs.contiguous()
+        reconstructions = reconstructions.contiguous()
+        if self.reconstruction_loss == "l1":
+            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
+        elif self.reconstruction_loss == "l2":
+            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
+        else:
+            raise ValueError(f"Unsupported reconstruction_loss {self.reconstruction_loss}")
+        reconstruction_loss *= self.reconstruction_weight
+
+        # Compute perceptual loss.
+        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()
+
+        # Compute discriminator loss.
+        generator_loss = torch.zeros((), device=inputs.device)
+        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
+        d_weight = 1.0
+        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
+            # Disable discriminator gradients.
+            for param in self.discriminator.parameters():
+                param.requires_grad = False
+            logits_fake = self.discriminator(reconstructions)
+            generator_loss = -torch.mean(logits_fake)
+
+        d_weight *= self.discriminator_weight
+
+        # Compute quantizer loss.
+        quantizer_loss = extra_result_dict["quantizer_loss"]
+        total_loss = (
+            reconstruction_loss +
+            self.perceptual_weight * perceptual_loss +
+            self.quantizer_weight * quantizer_loss +
+            d_weight * discriminator_factor * generator_loss
+        )
+        loss_dict = dict(
+            total_loss=total_loss.clone().detach(),
+            reconstruction_loss=reconstruction_loss.detach(),
+            perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
+            quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
+            weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
+            discriminator_factor=torch.tensor(discriminator_factor),
+            commitment_loss=extra_result_dict["commitment_loss"].detach(),
+            codebook_loss=extra_result_dict["codebook_loss"].detach(),
+            d_weight=d_weight,
+            gan_loss=generator_loss.detach(),
+        )
+
+        return total_loss, loss_dict
+
+    def _forward_discriminator(self,
+                               inputs: torch.Tensor,
+                               reconstructions: torch.Tensor,
+                               global_step: int,
+                               ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        """Discriminator training step."""
+        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
+        loss_dict = {}
+        # Turn the gradients on.
+        for param in self.discriminator.parameters():
+            param.requires_grad = True
+
+        real_images = inputs.detach().requires_grad_(True)
+        logits_real = self.discriminator(real_images)
+        logits_fake = self.discriminator(reconstructions.detach())
+
+        discriminator_loss = discriminator_factor * hinge_d_loss(logits_real=logits_real, logits_fake=logits_fake)
+
+        # optional lecam regularization
+        lecam_loss = torch.zeros((), device=inputs.device)
+        if self.lecam_regularization_weight > 0.0:
+            lecam_loss = compute_lecam_loss(
+                torch.mean(logits_real),
+                torch.mean(logits_fake),
+                self.ema_real_logits_mean,
+                self.ema_fake_logits_mean
+            ) * self.lecam_regularization_weight
+
+            self.ema_real_logits_mean = self.ema_real_logits_mean * self.lecam_ema_decay + torch.mean(logits_real).detach() * (1 - self.lecam_ema_decay)
+            self.ema_fake_logits_mean = self.ema_fake_logits_mean * self.lecam_ema_decay + torch.mean(logits_fake).detach() * (1 - self.lecam_ema_decay)
+
+        discriminator_loss += lecam_loss
+
+        loss_dict = dict(
+            discriminator_loss=discriminator_loss.detach(),
+            logits_real=logits_real.detach().mean(),
+            logits_fake=logits_fake.detach().mean(),
+            lecam_loss=lecam_loss.detach(),
+        )
+        return discriminator_loss, loss_dict
+
+
+class ReconstructionLoss_Single_Stage(ReconstructionLoss_Stage2):
+    def __init__(
+        self,
+        config
+    ):
+        super().__init__(config)
+        loss_config = config.losses
+        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
+
+        if self.quantize_mode == "vae":
+            self.kl_weight = loss_config.get("kl_weight", 1e-6)
+            logvar_init = loss_config.get("logvar_init", 0.0)
+            self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init, requires_grad=False)
+
+    def _forward_generator(self,
+                           inputs: torch.Tensor,
+                           reconstructions: torch.Tensor,
+                           extra_result_dict: Mapping[Text, torch.Tensor],
+                           global_step: int
+                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        """Generator training step."""
+        inputs = inputs.contiguous()
+        reconstructions = reconstructions.contiguous()
+        if self.reconstruction_loss == "l1":
+            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
+        elif self.reconstruction_loss == "l2":
+            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
+        else:
+            raise ValueError(f"Unsupported reconstruction_loss {self.reconstruction_loss}")
+        reconstruction_loss *= self.reconstruction_weight
+
+        # Compute perceptual loss.
+        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()
+
+        # Compute discriminator loss.
+        generator_loss = torch.zeros((), device=inputs.device)
+        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
+        d_weight = 1.0
+        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
+            # Disable discriminator gradients.
+            for param in self.discriminator.parameters():
+                param.requires_grad = False
+            logits_fake = self.discriminator(reconstructions)
+            generator_loss = -torch.mean(logits_fake)
+
+        d_weight *= self.discriminator_weight
+
+        if self.quantize_mode in ["vq", "mvq", "softvq"]:
+            # Compute quantizer loss.
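+            # The generator objective assembled below is
+            #   total = w_rec*L_rec + w_perceptual*L_perceptual + w_quant*L_quant
+            #           + d_weight*discriminator_factor*L_gan
+            # with the weights taken from config.losses (summary of the code that follows).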
+ quantizer_loss = extra_result_dict["quantizer_loss"] + total_loss = ( + reconstruction_loss + + self.perceptual_weight * perceptual_loss + + self.quantizer_weight * quantizer_loss + + d_weight * discriminator_factor * generator_loss + ) + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(), + quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(), + weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(), + discriminator_factor=torch.tensor(discriminator_factor), + commitment_loss=extra_result_dict["commitment_loss"].detach(), + codebook_loss=extra_result_dict["codebook_loss"].detach(), + d_weight=d_weight, + gan_loss=generator_loss.detach(), + ) + elif self.quantize_mode == "vae": + # Compute kl loss. + reconstruction_loss = reconstruction_loss / torch.exp(self.logvar) + posteriors = extra_result_dict + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + total_loss = ( + reconstruction_loss + + self.perceptual_weight * perceptual_loss + + self.kl_weight * kl_loss + + d_weight * discriminator_factor * generator_loss + ) + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(), + kl_loss=(self.kl_weight * kl_loss).detach(), + weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(), + discriminator_factor=torch.tensor(discriminator_factor), + d_weight=d_weight, + gan_loss=generator_loss.detach(), + ) + else: + raise NotImplementedError + + return total_loss, loss_dict \ No newline at end of file diff --git a/modeling/modules/lpips.py b/modeling/modules/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..b204699e1f2bf242dfbfc7eef4d34f2caad0fde7 --- /dev/null +++ b/modeling/modules/lpips.py @@ -0,0 +1,181 @@ +"""LPIPS perceptual loss. 
+
+Reference:
+    https://github.com/richzhang/PerceptualSimilarity/
+    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py
+    https://github.com/CompVis/taming-transformers/blob/master/taming/util.py
+"""
+
+import os
+import hashlib
+import requests
+from collections import namedtuple
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+
+from torchvision import models
+
+_LPIPS_MEAN = [-0.030, -0.088, -0.188]
+_LPIPS_STD = [0.458, 0.448, 0.450]
+
+
+URL_MAP = {
+    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+    "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+
+def download(url, local_path, chunk_size=1024):
+    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+    with requests.get(url, stream=True) as r:
+        total_size = int(r.headers.get("content-length", 0))
+        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+            with open(local_path, "wb") as f:
+                for data in r.iter_content(chunk_size=chunk_size):
+                    if data:
+                        f.write(data)
+                        pbar.update(chunk_size)
+
+
+def md5_hash(path):
+    with open(path, "rb") as f:
+        content = f.read()
+    return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+    assert name in URL_MAP
+    path = os.path.join(root, CKPT_MAP[name])
+    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+        download(URL_MAP[name], path)
+        md5 = md5_hash(path)
+        assert md5 == MD5_MAP[name], md5
+    return path
+
+
+class LPIPS(nn.Module):
+    # Learned perceptual metric.
+    def __init__(self, use_dropout=True):
+        super().__init__()
+        self.scaling_layer = ScalingLayer()
+        self.chns = [64, 128, 256, 512, 512] # vgg16 features
+        self.net = vgg16(pretrained=True, requires_grad=False)
+        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+        self.load_pretrained()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def load_pretrained(self):
+        workspace = os.environ.get('WORKSPACE', '')
+        VGG_PATH = get_ckpt_path("vgg_lpips", os.path.join(workspace, "models/vgg_lpips.pth"), check=True)
+        self.load_state_dict(torch.load(VGG_PATH, map_location=torch.device("cpu"), weights_only=True), strict=False)
+
+    def forward(self, input, target):
+        # Notably, the LPIPS w/ pre-trained weights expect the input in the range of [-1, 1].
+        # However, our codebase assumes all inputs are in range of [0, 1], and thus a scaling is needed.
+        input = input * 2. - 1.
+        target = target * 2. - 1.
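+        # The distance below is a weighted sum over five VGG16 stages (relu1_2..relu5_3):
+        # unit-normalize each stage's features, take squared differences, apply the
+        # learned 1x1 weights (lin0..lin4), and average spatially.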
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer("shift", torch.Tensor(_LPIPS_MEAN)[None, :, None, None]) + self.register_buffer("scale", torch.Tensor(_LPIPS_STD)[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv.""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/modeling/modules/maskgit_vqgan.py b/modeling/modules/maskgit_vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..8275b811c83f5ad254d40ace51fd74c3857d34e6 --- /dev/null +++ b/modeling/modules/maskgit_vqgan.py @@ -0,0 +1,346 @@ +"""MaskGIT-VQGAN tokenizer. + +Reference: + https://github.com/huggingface/open-muse/blob/main/muse/modeling_maskgit_vqgan.py +""" + +r"""MaskGIT Tokenizer based on VQGAN. + +This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841] +with several modifications. The non-local layers are removed from VQGAN for +faster speed. 
+""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn + + +# Conv2D with same padding +class Conv2dSame(nn.Conv2d): + def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + ih, iw = x.size()[-2:] + + pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) + pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) + + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return super().forward(x) + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + dropout_prob: float = 0.0, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = Conv2dSame(self.in_channels, self.out_channels_, kernel_size=3, bias=False) + + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True) + self.dropout = nn.Dropout(dropout_prob) + self.conv2 = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=3, bias=False) + + if self.in_channels != self.out_channels_: + self.nin_shortcut = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=1, bias=False) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels_: + residual = self.nin_shortcut(hidden_states) + + return hidden_states + residual + + +class DownsamplingBlock(nn.Module): + def __init__(self, config, block_idx: int): + super().__init__() + + self.config = config + self.block_idx = block_idx + + in_channel_mult = (1,) + tuple(self.config.channel_mult) + block_in = self.config.hidden_channels * in_channel_mult[self.block_idx] + block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] + + res_blocks = nn.ModuleList() + for _ in range(self.config.num_res_blocks): + res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)) + block_in = block_out + self.block = res_blocks + + self.downsample = self.block_idx != self.config.num_resolutions - 1 + + def forward(self, hidden_states): + for res_block in self.block: + hidden_states = res_block(hidden_states) + + if self.downsample: + hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2) + + return hidden_states + + +class UpsamplingBlock(nn.Module): + def __init__(self, config, block_idx: int): + super().__init__() + + self.config = config + self.block_idx = block_idx + + if self.block_idx == self.config.num_resolutions - 1: + block_in = self.config.hidden_channels * self.config.channel_mult[-1] + else: + block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1] + + block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] + + res_blocks = [] + for _ in range(self.config.num_res_blocks): + 
res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)) + block_in = block_out + self.block = nn.ModuleList(res_blocks) + + self.add_upsample = self.block_idx != 0 + if self.add_upsample: + self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3) + + def forward(self, hidden_states): + for res_block in self.block: + hidden_states = res_block(hidden_states) + + if self.add_upsample: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.upsample_conv(hidden_states) + + return hidden_states + + +class Encoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # downsampling + self.conv_in = Conv2dSame(self.config.num_channels, self.config.hidden_channels, kernel_size=3, bias=False) + + downsample_blocks = [] + for i_level in range(self.config.num_resolutions): + downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level)) + self.down = nn.ModuleList(downsample_blocks) + + # middle + mid_channels = self.config.hidden_channels * self.config.channel_mult[-1] + res_blocks = nn.ModuleList() + for _ in range(self.config.num_res_blocks): + res_blocks.append(ResnetBlock(mid_channels, mid_channels, dropout_prob=self.config.dropout)) + self.mid = res_blocks + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True) + self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1) + + def forward(self, pixel_values): + # downsampling + hidden_states = self.conv_in(pixel_values) + for block in self.down: + hidden_states = block(hidden_states) + + # middle + for block in self.mid: + hidden_states = block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class Decoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + # compute in_channel_mult, block_in and curr_res at lowest res + block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1] + curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1) + self.z_shape = (1, self.config.z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3) + + # middle + res_blocks = nn.ModuleList() + for _ in range(self.config.num_res_blocks): + res_blocks.append(ResnetBlock(block_in, block_in, dropout_prob=self.config.dropout)) + self.mid = res_blocks + + # upsampling + upsample_blocks = [] + for i_level in reversed(range(self.config.num_resolutions)): + upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level)) + self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order + + # end + block_out = self.config.hidden_channels * self.config.channel_mult[0] + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True) + self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3) + + def forward(self, hidden_states): + # z to block_in + hidden_states = self.conv_in(hidden_states) + + # middle + for block in self.mid: + hidden_states = block(hidden_states) + + # upsampling + for block in reversed(self.up): + hidden_states = block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = 
self.conv_out(hidden_states) + + return hidden_states + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + Discretization bottleneck part of the VQ-VAE. + """ + + def __init__(self, num_embeddings, embedding_dim, commitment_cost): + r""" + Args: + num_embeddings: number of vectors in the quantized space. + embedding_dim: dimensionality of the tensors in the quantized space. + Inputs to the modules must be in this format as well. + commitment_cost: scalar which controls the weighting of the loss terms + (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta). + """ + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.commitment_cost = commitment_cost + + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings) + + def forward(self, hidden_states, return_loss=False): + """ + Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the + closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() + + distances = self.compute_distances(hidden_states) + min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1) + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape) + + # reshape to (batch, num_tokens) + min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1) + + # compute loss for embedding + loss = None + if return_loss: + loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean( + (z_q - hidden_states.detach()) ** 2 + ) + # preserve gradients + z_q = hidden_states + (z_q - hidden_states).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, min_encoding_indices, loss + + def compute_distances(self, hidden_states): + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim)) + emb_weights = self.embedding.weight.t() + + inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True) + codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True) + distances = torch.addmm( + inputs_norm_sq + codebook_t_norm_sq, + hidden_states_flattended, + emb_weights, + alpha=-2.0, + ) + return distances + + def get_codebook_entry(self, indices): + # indices are expected to be of shape (batch, num_tokens) + # get quantized latent vectors + if len(indices.shape) == 2: + batch, num_tokens = indices.shape + z_q = self.embedding(indices) + z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2) + elif len(indices.shape) == 3: + batch, height, width = indices.shape + indices = indices.view(batch, -1) + z_q = self.embedding(indices) + z_q = z_q.reshape(batch, height, width, -1).permute(0, 3, 1, 2) + else: + print(indices.shape) + raise 
NotImplementedError + return z_q + + # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372 + def get_soft_code(self, hidden_states, temp=1.0, stochastic=False): + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel) + distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings) + + soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings) + if stochastic: + code = torch.multinomial(soft_code, 1) # (batch * height * width, 1) + else: + code = distances.argmin(dim=-1) # (batch * height * width) + + code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width) + batch, num_tokens = code.shape + soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings) + return soft_code, code + + def get_code(self, hidden_states): + # reshape z -> (batch, height, width, channel) + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() + distances = self.compute_distances(hidden_states) + indices = torch.argmin(distances, axis=1).unsqueeze(1) + indices = indices.reshape(hidden_states.shape[0], -1) + return indices \ No newline at end of file diff --git a/modeling/modules/perceptual_loss.py b/modeling/modules/perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..92a6f0c8fe334cf17812f2d703acce86cfda1348 --- /dev/null +++ b/modeling/modules/perceptual_loss.py @@ -0,0 +1,101 @@ +"""Perceptual loss module using LPIPS and ConvNeXt-S.""" + +import torch +import torch.nn.functional as F + +from torchvision import models +from .lpips import LPIPS + +_IMAGENET_MEAN = [0.485, 0.456, 0.406] +_IMAGENET_STD = [0.229, 0.224, 0.225] + + +class PerceptualLoss(torch.nn.Module): + def __init__(self, model_name: str = "convnext_s"): + """Initializes the PerceptualLoss class. + + Args: + model_name: A string, the name of the perceptual loss model to use. + + Raise: + ValueError: If the model_name does not contain "lpips" or "convnext_s". + """ + super().__init__() + if ("lpips" not in model_name) and ( + "convnext_s" not in model_name): + raise ValueError(f"Unsupported Perceptual Loss model name {model_name}") + self.lpips = None + self.convnext = None + self.loss_weight_lpips = None + self.loss_weight_convnext = None + + # Parsing the model name. We support name formatted in + # "lpips-convnext_s-{float_number}-{float_number}", where the + # {float_number} refers to the loss weight for each component. + # E.g., lpips-convnext_s-1.0-2.0 refers to compute the perceptual loss + # using both the convnext_s and lpips, and average the final loss with + # (1.0 * loss(lpips) + 2.0 * loss(convnext_s)) / (1.0 + 2.0). 
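+        # Accepted values therefore include "lpips", "convnext_s", or the combined form
+        # above; when a single backbone is used its loss gets an implicit weight of 1
+        # (see the `loss_weight_* is None` branches in forward()).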
+ if "lpips" in model_name: + self.lpips = LPIPS().eval() + + if "convnext_s" in model_name: + self.convnext = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).eval() + + if "lpips" in model_name and "convnext_s" in model_name: + loss_config = model_name.split('-')[-2:] + self.loss_weight_lpips, self.loss_weight_convnext = float(loss_config[0]), float(loss_config[1]) + print(f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}") + + self.register_buffer("imagenet_mean", torch.Tensor(_IMAGENET_MEAN)[None, :, None, None]) + self.register_buffer("imagenet_std", torch.Tensor(_IMAGENET_STD)[None, :, None, None]) + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """Computes the perceptual loss. + + Args: + input: A tensor of shape (B, C, H, W), the input image. Normalized to [0, 1]. + target: A tensor of shape (B, C, H, W), the target image. Normalized to [0, 1]. + + Returns: + A scalar tensor, the perceptual loss. + """ + # Always in eval mode. + self.eval() + loss = 0. + num_losses = 0. + lpips_loss = 0. + convnext_loss = 0. + # Computes LPIPS loss, if available. + if self.lpips is not None: + lpips_loss = self.lpips(input, target) + if self.loss_weight_lpips is None: + loss += lpips_loss + num_losses += 1 + else: + num_losses += self.loss_weight_lpips + loss += self.loss_weight_lpips * lpips_loss + + if self.convnext is not None: + # Computes ConvNeXt-s loss, if available. + input = torch.nn.functional.interpolate(input, size=224, mode="bilinear", align_corners=False, antialias=True) + target = torch.nn.functional.interpolate(target, size=224, mode="bilinear", align_corners=False, antialias=True) + pred_input = self.convnext((input - self.imagenet_mean) / self.imagenet_std) + pred_target = self.convnext((target - self.imagenet_mean) / self.imagenet_std) + convnext_loss = torch.nn.functional.mse_loss( + pred_input, + pred_target, + reduction="mean") + + if self.loss_weight_convnext is None: + num_losses += 1 + loss += convnext_loss + else: + num_losses += self.loss_weight_convnext + loss += self.loss_weight_convnext * convnext_loss + + # weighted avg. 
+ loss = loss / num_losses + return loss \ No newline at end of file diff --git a/modeling/quantizer/__init__.py b/modeling/quantizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4707078c10391a8ba7edffcb3e72178adbd5a128 --- /dev/null +++ b/modeling/quantizer/__init__.py @@ -0,0 +1,3 @@ +from .quantizer import VectorQuantizer, DiagonalGaussianDistribution +from .mvq import VectorQuantizerMVQ +from .softvq import SoftVectorQuantizer \ No newline at end of file diff --git a/modeling/quantizer/dist.py b/modeling/quantizer/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..48d7fc9962a26decb036487720ad41e50082984c --- /dev/null +++ b/modeling/quantizer/dist.py @@ -0,0 +1,302 @@ +import datetime +import functools +import os +import sys +from typing import List +from typing import Union + +import pytz +import torch +import torch.distributed as tdist +import torch.multiprocessing as mp + +__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu' +__rank_str_zfill = '0' +__initialized = False + + +def initialized(): + return __initialized + + +def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30): + global __device + if not torch.cuda.is_available(): + print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr) + return + elif 'RANK' not in os.environ: + torch.cuda.set_device(gpu_id_if_not_distibuted) + __device = torch.empty(1).cuda().device + print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr) + return + # then 'RANK' must exist + global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count() + local_rank = global_rank % num_gpus + torch.cuda.set_device(local_rank) + + # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 + if mp.get_start_method(allow_none=True) is None: + method = 'fork' if fork else 'spawn' + print(f'[dist initialize] mp method={method}') + mp.set_start_method(method) + tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60)) + + global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill + __local_rank = local_rank + __rank, __world_size = tdist.get_rank(), tdist.get_world_size() + __rank_str_zfill = str(__rank).zfill(len(str(__world_size))) + __device = torch.empty(1).cuda().device + __initialized = True + + assert tdist.is_initialized(), 'torch.distributed is not initialized!' 
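+    # (This branch is reached only under a distributed launcher such as torchrun,
+    # which sets the RANK environment variable checked above.)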
+ print(f'[lrk={get_local_rank()}, rk={get_rank()}]') + + +def get_rank(): + return __rank + + +def get_rank_str_zfill(): + return __rank_str_zfill + + +def get_local_rank(): + return __local_rank + + +def get_world_size(): + return __world_size + + +def get_device(): + return __device + + +def set_gpu_id(gpu_id: int): + if gpu_id is None: return + global __device + if isinstance(gpu_id, (str, int)): + torch.cuda.set_device(int(gpu_id)) + __device = torch.empty(1).cuda().device + else: + raise NotImplementedError + + +def is_master(): + return __rank == 0 + + +def is_local_master(): + return __local_rank == 0 + + +def new_group(ranks: List[int]): + if __initialized: + return tdist.new_group(ranks=ranks) + return None + + +def new_local_machine_group(): + if __initialized: + cur_subgroup, subgroups = tdist.new_subgroups() + return cur_subgroup + return None + + +def barrier(): + if __initialized: + tdist.barrier() + + +def allreduce(t: torch.Tensor, async_op=False): + if __initialized: + if not t.is_cuda: + cu = t.detach().cuda() + ret = tdist.all_reduce(cu, async_op=async_op) + t.copy_(cu.cpu()) + else: + ret = tdist.all_reduce(t, async_op=async_op) + return ret + return None + + +def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: + if __initialized: + if not t.is_cuda: + t = t.cuda() + ls = [torch.empty_like(t) for _ in range(__world_size)] + tdist.all_gather(ls, t) + else: + ls = [t] + if cat: + ls = torch.cat(ls, dim=0) + return ls + + +def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: + if __initialized: + if not t.is_cuda: + t = t.cuda() + + t_size = torch.tensor(t.size(), device=t.device) + ls_size = [torch.empty_like(t_size) for _ in range(__world_size)] + tdist.all_gather(ls_size, t_size) + + max_B = max(size[0].item() for size in ls_size) + pad = max_B - t_size[0].item() + if pad: + pad_size = (pad, *t.size()[1:]) + t = torch.cat((t, t.new_empty(pad_size)), dim=0) + + ls_padded = [torch.empty_like(t) for _ in range(__world_size)] + tdist.all_gather(ls_padded, t) + ls = [] + for t, size in zip(ls_padded, ls_size): + ls.append(t[:size[0].item()]) + else: + ls = [t] + if cat: + ls = torch.cat(ls, dim=0) + return ls + + +def broadcast(t: torch.Tensor, src_rank) -> None: + if __initialized: + if not t.is_cuda: + cu = t.detach().cuda() + tdist.broadcast(cu, src=src_rank) + t.copy_(cu.cpu()) + else: + tdist.broadcast(t, src=src_rank) + + +def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]: + if not initialized(): + return torch.tensor([val]) if fmt is None else [fmt % val] + + ts = torch.zeros(__world_size) + ts[__rank] = val + allreduce(ts) + if fmt is None: + return ts + return [fmt % v for v in ts.cpu().numpy().tolist()] + + +def master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + force = kwargs.pop('force', False) + if force or is_master(): + ret = func(*args, **kwargs) + else: + ret = None + barrier() + return ret + return wrapper + + +def local_master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + force = kwargs.pop('force', False) + if force or is_local_master(): + ret = func(*args, **kwargs) + else: + ret = None + barrier() + return ret + return wrapper + + +def for_visualize(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_master(): + # with torch.no_grad(): + ret = func(*args, **kwargs) + else: + ret = None + return ret + return wrapper + + +def finalize(): + if __initialized: + 
tdist.destroy_process_group() + + +def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30): + try: + __initialize(fork=False, timeout_minutes=timeout_minutes) + barrier() + except RuntimeError as e: + print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True) + raise e + + if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True) + _change_builtin_print(is_local_master()) + if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path): + sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False) + + +def _change_builtin_print(is_master): + import builtins as __builtin__ + + builtin_print = __builtin__.print + if type(builtin_print) != type(open): + return + + def prt(*args, **kwargs): + force = kwargs.pop('force', False) + clean = kwargs.pop('clean', False) + deeper = kwargs.pop('deeper', False) + if is_master or force: + if not clean: + f_back = sys._getframe().f_back + if deeper and f_back.f_back is not None: + f_back = f_back.f_back + file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] + time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]') + builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs) + else: + builtin_print(*args, **kwargs) + + __builtin__.print = prt + + +class BackupStreamToFile(object): + def __init__(self, local_output_dir, for_stdout=True): + self.for_stdout = for_stdout + self.terminal_stream = sys.stdout if for_stdout else sys.stderr + fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt') + existing = os.path.exists(fname) + self.file_stream = open(fname, 'a') + if existing: + time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]') + self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n') + self.file_stream.flush() + self.enabled = True + + def write(self, message): + self.terminal_stream.write(message) + self.file_stream.write(message) + + def flush(self): + self.terminal_stream.flush() + self.file_stream.flush() + + def close(self): + if not self.enabled: + return + self.enabled = False + self.file_stream.flush() + self.file_stream.close() + if self.for_stdout: + sys.stdout = self.terminal_stream + sys.stdout.flush() + else: + sys.stderr = self.terminal_stream + sys.stderr.flush() + + def __del__(self): + self.close() \ No newline at end of file diff --git a/modeling/quantizer/mvq.py b/modeling/quantizer/mvq.py new file mode 100644 index 0000000000000000000000000000000000000000..97dd699299b65af40bfc0690ff240f749a71ad29 --- /dev/null +++ b/modeling/quantizer/mvq.py @@ -0,0 +1,159 @@ +import torch +from typing import List, Tuple +from torch.nn import functional as F +from torch import distributed as tdist, nn as nn + +from .quantizer import VectorQuantizer + +def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau): + E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False) + E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1) # E_dist: (N, vocab_size) + logits = -E_dist.float().mul_(inv_entropy_tau) + # calc per_sample_entropy + prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size) + per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1)) + # calc 
codebook_entropy + avg_prob = prob.mean(dim=0) # (vocab_size,) + log_avg_prob = torch.log(avg_prob + 1e-7) + codebook_entropy = (-avg_prob * log_avg_prob).sum() + # calc entropy_loss + entropy_loss = per_sample_entropy - codebook_entropy + return entropy_loss + + +class NormalizedEmbedding(nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + # self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32)) + + def forward(self, idx): + return F.embedding( + idx, F.normalize(self.weight, dim=1), self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) + + def get_norm_weight(self): + return F.normalize(self.weight, dim=1) + + +class ResConv(nn.Conv2d): + def __init__(self, embed_dim, quant_resi): + ks = 3 if quant_resi < 0 else 1 + super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2) + self.resi_ratio = abs(quant_resi) + + def forward(self, h_BChw): + return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio) + +class VectorQuantizerMVQ(nn.Module): + def __init__( + self, + codebook_size, + token_size, + commitment_cost=0.25, + use_l2_norm=False, + # entropy_temp=0.01, # we do not use this + clustering_vq=False, + num_codebooks=16 + ): + super().__init__() + self.num_codebooks = num_codebooks + self.codebooks = nn.ModuleList() + for _ in range(num_codebooks): + codebook = VectorQuantizer( + codebook_size=codebook_size // num_codebooks, + token_size=token_size // num_codebooks, + commitment_cost=commitment_cost, + use_l2_norm=use_l2_norm, + clustering_vq=clustering_vq, + ) + self.codebooks.append(codebook) + + def init_vocab(self, eini: float): + for codebook in self.codebooks: + codebook.init_vocab(eini) + + def f_to_idx(self, features): + indices = [] + chunk_size = features.shape[-1] // self.num_codebooks + splited_features = features.split(chunk_size, dim=-1) + for i, codebook in enumerate(self.codebooks): + indices.append(codebook.f_to_idx(splited_features[i])) + indices = torch.stack(indices, dim=1) + return indices + + def idx_to_f(self, indices): + assert indices.shape[1] == self.num_codebooks + latent_features = [] + for i, codebook in enumerate(self.codebooks): + sub_indices = indices[:, i].flatten(start_dim=1) + latent_feature = codebook.codebook(sub_indices) + latent_features.append(latent_feature) + latent_features = torch.cat(latent_features, dim=-1) + return latent_features + + def get_codebook_entry(self, indices): + """Get codebook entries for multi-codebook indices. 
+ + Args: + indices: Tensor of shape (N, num_codebooks) or (N, num_codebooks, H, W) + + Returns: + z_quantized: Quantized features + """ + if len(indices.shape) == 2: + # indices shape: (N, num_codebooks) + latent_features = [] + for i, codebook in enumerate(self.codebooks): + sub_indices = indices[:, i] + latent_feature = codebook.get_codebook_entry(sub_indices) + latent_features.append(latent_feature) + return torch.cat(latent_features, dim=-1) + elif len(indices.shape) == 4: + # indices shape: (B, num_codebooks, H, W) + batch_size, _, height, width = indices.shape + latent_features = [] + for i, codebook in enumerate(self.codebooks): + sub_indices = indices[:, i] # (B, H, W) + latent_feature = codebook.get_codebook_entry(sub_indices.flatten()) + # Reshape to (B, H, W, token_size // num_codebooks) + latent_feature = latent_feature.view(batch_size, height, width, -1) + latent_features.append(latent_feature) + # Concatenate along the last dimension and rearrange to (B, C, H, W) + latent_features = torch.cat(latent_features, dim=-1) # (B, H, W, C) + return latent_features.permute(0, 3, 1, 2).contiguous() # (B, C, H, W) + else: + raise NotImplementedError(f"Unsupported indices shape: {indices.shape}") + + def forward(self, features): + latent_features = [] + all_result_dicts = [] + chunk_size = features.shape[1] // self.num_codebooks + splited_features = features.split(chunk_size, dim=1) + + for i, codebook in enumerate(self.codebooks): + latent_feature, result_dict = codebook(splited_features[i].float()) + latent_features.append(latent_feature.to(features.dtype)) + all_result_dicts.append(result_dict) + + # Concatenate latent features + z_quantized = torch.cat(latent_features, dim=1) # Concatenate along channel dimension + + # Calculate global losses + global_quantizer_loss = sum(rd['quantizer_loss'] for rd in all_result_dicts) / self.num_codebooks + global_commitment_loss = sum(rd['commitment_loss'] for rd in all_result_dicts) / self.num_codebooks + global_codebook_loss = sum(rd['codebook_loss'] for rd in all_result_dicts) / self.num_codebooks + + # Collect all min_encoding_indices + # Each codebook returns indices of shape (B, H, W) + # Stack them to get shape (B, num_codebooks, H, W) + all_indices = torch.stack([rd['min_encoding_indices'] for rd in all_result_dicts], dim=1) + + result_dict = dict( + quantizer_loss=global_quantizer_loss, + commitment_loss=global_commitment_loss, + codebook_loss=global_codebook_loss, + min_encoding_indices=all_indices + ) + + return z_quantized, result_dict \ No newline at end of file diff --git a/modeling/quantizer/quantizer.py b/modeling/quantizer/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..db529f1dfe0c8584d0af1a50ff4d3c4ab3949c81 --- /dev/null +++ b/modeling/quantizer/quantizer.py @@ -0,0 +1,158 @@ +"""Vector quantizer. 
+ +Reference: + https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py + https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py + https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py + https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py +""" +from typing import Mapping, Text, Tuple + +import torch +from einops import rearrange +from accelerate.utils.operations import gather +from torch.cuda.amp import autocast + +class VectorQuantizer(torch.nn.Module): + def __init__(self, + codebook_size: int = 1024, + token_size: int = 256, + commitment_cost: float = 0.25, + use_l2_norm: bool = False, + clustering_vq: bool = False + ): + super().__init__() + self.codebook_size = codebook_size + self.token_size = token_size + self.commitment_cost = commitment_cost + + self.embedding = torch.nn.Embedding(codebook_size, token_size) + self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) + self.use_l2_norm = use_l2_norm + + self.clustering_vq = clustering_vq + if clustering_vq: + self.decay = 0.99 + self.register_buffer("embed_prob", torch.zeros(self.codebook_size)) + + # Ensure quantization is performed using f32 + @autocast(enabled=False) + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + z = z.float() + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = rearrange(z, 'b h w c -> (b h w) c') + unnormed_z_flattened = z_flattened + + if self.use_l2_norm: + z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1) + embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1) + else: + embedding = self.embedding.weight + d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \ + torch.sum(embedding**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, embedding.T) + + min_encoding_indices = torch.argmin(d, dim=1) # num_ele + z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) + + if self.use_l2_norm: + z = torch.nn.functional.normalize(z, dim=-1) + + # compute loss for embedding + commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2) + codebook_loss = torch.mean((z_quantized - z.detach()) **2) + + if self.clustering_vq and self.training: + with torch.no_grad(): + # Gather distance matrix from all GPUs. + encoding_indices = gather(min_encoding_indices) + if len(min_encoding_indices.shape) != 1: + raise ValueError(f"min_encoding_indices in a wrong shape, {min_encoding_indices.shape}") + # Compute and update the usage of each entry in the codebook. + encodings = torch.zeros(encoding_indices.shape[0], self.codebook_size, device=z.device) + encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) + avg_probs = torch.mean(encodings, dim=0) + self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1-self.decay) + # Closest sampling to update the codebook. + all_d = gather(d) + all_unnormed_z_flattened = gather(unnormed_z_flattened).detach() + if all_d.shape[0] != all_unnormed_z_flattened.shape[0]: + raise ValueError( + "all_d and all_unnormed_z_flattened have different length" + + f"{all_d.shape}, {all_unnormed_z_flattened.shape}") + indices = torch.argmin(all_d, dim=0) + random_feat = all_unnormed_z_flattened[indices] + # Decay parameter based on the average usage. 
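+                # Entries with near-zero usage get decay close to 1 and are pulled toward
+                # the closest encoder features gathered above, while frequently used entries
+                # keep their weights (decay close to 0); this acts as a codebook-revival
+                # heuristic in the spirit of the CVQ-VAE reference in the file header.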
+ decay = torch.exp(-(self.embed_prob * self.codebook_size * 10) / + (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.token_size) + self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay + + loss = commitment_loss + codebook_loss + + # preserve gradients + z_quantized = z + (z_quantized - z).detach() + + # reshape back to match original input shape + z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous() + + result_dict = dict( + quantizer_loss=loss, + commitment_loss=commitment_loss, + codebook_loss=codebook_loss, + min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]) + ) + + return z_quantized, result_dict + + @autocast(enabled=False) + def get_codebook_entry(self, indices): + indices = indices.long() + if len(indices.shape) == 1: + z_quantized = self.embedding(indices) + elif len(indices.shape) == 2: + z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight) + else: + raise NotImplementedError + if self.use_l2_norm: + z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) + return z_quantized + + +class DiagonalGaussianDistribution(object): + @autocast(enabled=False) + def __init__(self, parameters, deterministic=False): + """Initializes a Gaussian distribution instance given the parameters. + + Args: + parameters (torch.Tensor): The parameters for the Gaussian distribution. It is expected + to be in shape [B, 2 * C, *], where B is batch size, and C is the embedding dimension. + First C channels are used for mean and last C are used for logvar in the Gaussian distribution. + deterministic (bool): Whether to use deterministic sampling. When it is true, the sampling results + is purely based on mean (i.e., std = 0). 
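+
+        Example (illustrative shapes):
+            posterior = DiagonalGaussianDistribution(torch.randn(4, 2 * 16, 32))
+            z = posterior.sample()    # (4, 16, 32): mean + std * eps
+            kl = posterior.kl()       # (4,): summed over the non-batch dims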
+ """ + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + @autocast(enabled=False) + def sample(self): + x = self.mean.float() + self.std.float() * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + @autocast(enabled=False) + def mode(self): + return self.mean + + @autocast(enabled=False) + def kl(self): + if self.deterministic: + return torch.Tensor([0.]) + else: + return 0.5 * torch.sum(torch.pow(self.mean.float(), 2) + + self.var.float() - 1.0 - self.logvar.float(), + dim=[1, 2]) diff --git a/modeling/quantizer/softvq.py b/modeling/quantizer/softvq.py new file mode 100644 index 0000000000000000000000000000000000000000..68ad4413a9b2b78e1e44d81e0bd2158f3d7d9db5 --- /dev/null +++ b/modeling/quantizer/softvq.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Mapping, Text, Tuple +from einops import rearrange +from torch.cuda.amp import autocast + + +class SoftVectorQuantizer(torch.nn.Module): + def __init__(self, + codebook_size: int = 1024, + token_size: int = 256, + commitment_cost: float = 0.25, + use_l2_norm: bool = False, + clustering_vq: bool = False, + entropy_loss_ratio: float = 0.01, + tau: float = 0.07, + num_codebooks: int = 1, + show_usage: bool = False + ): + super().__init__() + # Map new parameter names to internal names for compatibility + self.codebook_size = codebook_size + self.token_size = token_size + self.commitment_cost = commitment_cost + self.use_l2_norm = use_l2_norm + self.clustering_vq = clustering_vq + + # Keep soft quantization specific parameters + self.num_codebooks = num_codebooks + self.n_e = codebook_size + self.e_dim = token_size + self.entropy_loss_ratio = entropy_loss_ratio + self.l2_norm = use_l2_norm + self.show_usage = show_usage + self.tau = tau + + # Single embedding layer for all codebooks + self.embedding = nn.Parameter(torch.randn(num_codebooks, codebook_size, token_size)) + self.embedding.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + if self.l2_norm: + self.embedding.data = F.normalize(self.embedding.data, p=2, dim=-1) + + if self.show_usage: + self.register_buffer("codebook_used", torch.zeros(num_codebooks, 65536)) + + # Ensure quantization is performed using f32 + @autocast(enabled=False) + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + z = z.float() + original_shape = z.shape + + # Handle input reshaping to match VectorQuantizer format + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z = z.view(z.size(0), -1, z.size(-1)) + + batch_size, seq_length, _ = z.shape + + # Ensure sequence length is divisible by number of codebooks + assert seq_length % self.num_codebooks == 0, \ + f"Sequence length ({seq_length}) must be divisible by number of codebooks ({self.num_codebooks})" + + segment_length = seq_length // self.num_codebooks + z_segments = z.view(batch_size, self.num_codebooks, segment_length, self.e_dim) + + # Apply L2 norm if needed + embedding = F.normalize(self.embedding, p=2, dim=-1) if self.l2_norm else self.embedding + if self.l2_norm: + z_segments = F.normalize(z_segments, p=2, dim=-1) + + z_flat = z_segments.permute(1, 0, 2, 3).contiguous().view(self.num_codebooks, -1, 
self.e_dim) + + logits = torch.einsum('nbe, nke -> nbk', z_flat, embedding.detach()) + + # Calculate probabilities (soft quantization) + probs = F.softmax(logits / self.tau, dim=-1) + + # Soft quantize + z_q = torch.einsum('nbk, nke -> nbe', probs, embedding) + + # Reshape back + z_q = z_q.view(self.num_codebooks, batch_size, segment_length, self.e_dim).permute(1, 0, 2, 3).contiguous() + + # Calculate cosine similarity + with torch.no_grad(): + zq_z_cos = F.cosine_similarity( + z_segments.view(-1, self.e_dim), + z_q.view(-1, self.e_dim), + dim=-1 + ).mean() + + # Get indices for usage tracking + indices = torch.argmax(probs, dim=-1) # (num_codebooks, batch_size * segment_length) + indices = indices.transpose(0, 1).contiguous() # (batch_size * segment_length, num_codebooks) + + # Track codebook usage + if self.show_usage and self.training: + for k in range(self.num_codebooks): + cur_len = indices.size(0) + self.codebook_used[k, :-cur_len].copy_(self.codebook_used[k, cur_len:].clone()) + self.codebook_used[k, -cur_len:].copy_(indices[:, k]) + + # Calculate losses if training + if self.training: + # Soft quantization doesn't have traditional commitment/codebook loss + # Map entropy loss to quantizer_loss for compatibility + entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(logits.view(-1, self.n_e)) + quantizer_loss = entropy_loss + commitment_loss = torch.tensor(0.0, device=z.device) + codebook_loss = torch.tensor(0.0, device=z.device) + else: + quantizer_loss = torch.tensor(0.0, device=z.device) + commitment_loss = torch.tensor(0.0, device=z.device) + codebook_loss = torch.tensor(0.0, device=z.device) + + # Calculate codebook usage + codebook_usage = torch.tensor([ + len(torch.unique(self.codebook_used[k])) / self.n_e + for k in range(self.num_codebooks) + ]).mean() if self.show_usage else 0 + + z_q = z_q.view(batch_size, -1, self.e_dim) + + # Reshape back to original input shape to match VectorQuantizer + z_q = z_q.view(batch_size, original_shape[2], original_shape[3], original_shape[1]) + z_quantized = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + # Calculate average probabilities + avg_probs = torch.mean(torch.mean(probs, dim=-1)) + max_probs = torch.mean(torch.max(probs, dim=-1)[0]) + + # Return format matching VectorQuantizer + result_dict = dict( + quantizer_loss=quantizer_loss, + commitment_loss=commitment_loss, + codebook_loss=codebook_loss, + min_encoding_indices=indices.view(batch_size, self.num_codebooks, segment_length).view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]) + ) + + return z_quantized, result_dict + + def get_codebook_entry(self, indices): + """Added for compatibility with VectorQuantizer API""" + if len(indices.shape) == 1: + # For single codebook case + z_quantized = self.embedding[0][indices] + elif len(indices.shape) == 2: + z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding[0]) + else: + raise NotImplementedError + if self.use_l2_norm: + z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) + return z_quantized + + +def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + probs = F.softmax(flat_affinity, dim=-1) + log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) + if loss_type == "softmax": + target_probs = probs + else: + raise ValueError("Entropy loss {} not supported".format(loss_type)) + avg_probs = torch.mean(target_probs, dim=0) + avg_entropy = - torch.sum(avg_probs * 
torch.log(avg_probs + 1e-6)) + sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1)) + loss = sample_entropy - avg_entropy + return loss \ No newline at end of file diff --git a/modeling/vibetoken_model.py b/modeling/vibetoken_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4a77d2a60f9caf3074a5c6d4f1a6c89980ee488b --- /dev/null +++ b/modeling/vibetoken_model.py @@ -0,0 +1,219 @@ +"""VibeToken model definition.""" + +import torch +import torch.nn as nn +from einops import rearrange + +from modeling.modules.base_model import BaseModel +from modeling.modules.encoder_decoder import ResolutionEncoder, ResolutionDecoder +from modeling.quantizer import VectorQuantizer, DiagonalGaussianDistribution, VectorQuantizerMVQ, SoftVectorQuantizer +from modeling.modules.maskgit_vqgan import Encoder as Pixel_Eecoder +from modeling.modules.maskgit_vqgan import Decoder as Pixel_Decoder +from modeling.modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer +import json +from omegaconf import OmegaConf +from pathlib import Path + +from huggingface_hub import PyTorchModelHubMixin + + +class PretrainedTokenizer(nn.Module): + def __init__(self, pretrained_weight): + super().__init__() + conf = OmegaConf.create( + {"channel_mult": [1, 1, 2, 2, 4], + "num_resolutions": 5, + "dropout": 0.0, + "hidden_channels": 128, + "num_channels": 3, + "num_res_blocks": 2, + "resolution": 256, + "z_channels": 256}) + self.encoder = Pixel_Eecoder(conf) + self.decoder = Pixel_Decoder(conf) + self.quantize = Pixel_Quantizer( + num_embeddings=1024, embedding_dim=256, commitment_cost=0.25) + # Load pretrained weights + self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True) + + self.eval() + for param in self.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode(self, x): + hidden_states = self.encoder(x) + quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states) + return codebook_indices.detach() + + @torch.no_grad() + def decode(self, codes): + quantized_states = self.quantize.get_codebook_entry(codes) + rec_images = self.decoder(quantized_states) + rec_images = torch.clamp(rec_images, 0.0, 1.0) + return rec_images.detach() + + @torch.no_grad() + def decode_tokens(self, codes): + return self.decode(codes) + + +class VibeTokenModel(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-tokenization"]): + def __init__(self, config): + + if isinstance(config, dict): + config = OmegaConf.create(config) + + super().__init__() + self.config = config + # This should be False for stage1 and True for stage2. 
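+        # Stage 1 trains the encoder, quantizer and decoder; in stage 2
+        # (finetune_decoder=True) the encoder, quantizer and latent tokens are frozen
+        # and only the decoder is finetuned, with the MaskGiT-VQGAN pixel
+        # quantizer/decoder attached below to map decoder outputs back to pixels.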
+ self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True) + + self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq") + if self.quantize_mode not in ["vq", "vae", "softvq", "mvq"]: + raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.") + + if self.finetune_decoder and self.quantize_mode not in ["vq", "softvq", "mvq"]: + raise ValueError("Only supprot finetune_decoder with vq quantization for now.") + + self.encoder = ResolutionEncoder(config) + self.decoder = ResolutionDecoder(config) + + self.num_latent_tokens = config.model.vq_model.num_latent_tokens + scale = self.encoder.width ** -0.5 + self.latent_tokens = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.encoder.width)) + + self.apply(self._init_weights) + + if self.quantize_mode == "vq": + self.quantize = VectorQuantizer( + codebook_size=config.model.vq_model.codebook_size, + token_size=config.model.vq_model.token_size, + commitment_cost=config.model.vq_model.commitment_cost, + use_l2_norm=config.model.vq_model.use_l2_norm,) + elif self.quantize_mode == "vae": + self.quantize = DiagonalGaussianDistribution + elif self.quantize_mode == "mvq": + self.quantize = VectorQuantizerMVQ( + codebook_size=config.model.vq_model.codebook_size, + token_size=config.model.vq_model.token_size, + commitment_cost=config.model.vq_model.commitment_cost, + use_l2_norm=config.model.vq_model.use_l2_norm, + num_codebooks=config.model.vq_model.num_codebooks, + ) + elif self.quantize_mode == "softvq": + self.quantize = SoftVectorQuantizer( + codebook_size=config.model.vq_model.codebook_size, + token_size=config.model.vq_model.token_size, + commitment_cost=config.model.vq_model.commitment_cost, + use_l2_norm=config.model.vq_model.use_l2_norm, + num_codebooks=config.model.vq_model.num_codebooks, + ) + else: + raise NotImplementedError + + if self.finetune_decoder: + # Freeze encoder/quantizer/latent tokens + self.latent_tokens.requires_grad_(False) + self.encoder.eval() + self.encoder.requires_grad_(False) + self.quantize.eval() + self.quantize.requires_grad_(False) + + # Include MaskGiT-VQGAN's quantizer and decoder + self.pixel_quantize = Pixel_Quantizer( + num_embeddings=1024, embedding_dim=256, commitment_cost=0.25) + self.pixel_decoder = Pixel_Decoder(OmegaConf.create( + {"channel_mult": [1, 1, 2, 2, 4], + "num_resolutions": 5, + "dropout": 0.0, + "hidden_channels": 128, + "num_channels": 3, + "num_res_blocks": 2, + "resolution": 256, + "z_channels": 256})) + + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights and config to a local directory.""" + # Assume 'self.config' is your DictConfig object + # Convert to a regular dictionary + dict_config = OmegaConf.to_container(self.config) + # Save as JSON + file_path = Path(save_directory) / "config.json" + with open(file_path, 'w') as json_file: + json.dump(dict_config, json_file, indent=4) + super()._save_pretrained(save_directory) + + def _init_weights(self, module): + """ Initialize the weights. 
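+        Linear/Conv and Embedding weights are drawn from a truncated normal with
+        std 0.02; LayerNorm is reset to weight=1, bias=0.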
+ :param: + module -> torch.nn.Module: module to initialize + """ + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d): + module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def encode(self, x, attention_mask=None, encode_patch_size=None, train=True, length=None): + if self.finetune_decoder: + with torch.no_grad(): + self.encoder.eval() + self.quantize.eval() + z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train) + z_quantized, result_dict = self.quantize(z) + result_dict["quantizer_loss"] *= 0 + result_dict["commitment_loss"] *= 0 + result_dict["codebook_loss"] *= 0 + else: + if length is not None: + attention_mask = None + latent_tokens = self.latent_tokens[:length+1] + else: + latent_tokens = self.latent_tokens + z = self.encoder(pixel_values=x, latent_tokens=latent_tokens, attention_mask=attention_mask, encode_patch_size=encode_patch_size, train=train) + if self.quantize_mode in ["vq", "mvq", "softvq"]: + z_quantized, result_dict = self.quantize(z) + elif self.quantize_mode == "vae": + posteriors = self.quantize(z) + z_quantized = posteriors.sample() + result_dict = posteriors + + return z_quantized, result_dict + + def decode(self, z_quantized, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True): + decoded = self.decoder(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train) + if self.finetune_decoder: + quantized_states = torch.einsum( + 'nchw,cd->ndhw', decoded.softmax(1), + self.pixel_quantize.embedding.weight) + decoded = self.pixel_decoder(quantized_states) + return decoded + + def decode_tokens(self, tokens, attention_mask=None, height=None, width=None, decode_patch_size=None, train=True): + if self.quantize_mode in ["vq", "softvq"]: + tokens = tokens.squeeze(1) + batch, seq_len = tokens.shape # B x N + z_quantized = self.quantize.get_codebook_entry( + tokens.reshape(-1)).reshape(batch, 1, seq_len, -1) + z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous() + elif self.quantize_mode == "mvq": + z_quantized = self.quantize.get_codebook_entry(tokens) + elif self.quantize_mode == "vae": + z_quantized = tokens + z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype) + decoded = self.decode(z_quantized, attention_mask=attention_mask, height=height, width=width, decode_patch_size=decode_patch_size, train=train) + return decoded + + def forward(self, x, key_attention_mask=None, height=None, width=None, train=True): + if height is None: + batch_size, channels, height, width = x.shape + z_quantized, result_dict = self.encode(x, attention_mask=key_attention_mask, train=train) + z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype) + decoded = self.decode(z_quantized, attention_mask=key_attention_mask, height=height, width=width, train=train) + return decoded, result_dict \ No newline at end of file diff --git a/reconstruct.py b/reconstruct.py new file mode 100644 index 0000000000000000000000000000000000000000..29e9fe4ef0fe5aaef4a7bab363c861af16617c63 --- /dev/null +++ b/reconstruct.py @@ -0,0 
+1,148 @@ +#!/usr/bin/env python3 +"""Simple reconstruction script for VibeToken. + +Usage: + # Auto mode (recommended) - automatically determines optimal settings + python reconstruct.py --auto \ + --config configs/vibetoken_ll.yaml \ + --checkpoint /path/to/checkpoint.bin \ + --image assets/example_1.jpg \ + --output assets/reconstructed.png + + # Manual mode - specify all parameters + python reconstruct.py \ + --config configs/vibetoken_ll.yaml \ + --checkpoint /path/to/checkpoint.bin \ + --image assets/example_1.jpg \ + --output assets/reconstructed.png \ + --input_height 512 --input_width 512 \ + --encoder_patch_size 16,32 \ + --decoder_patch_size 16 +""" + +import argparse +from PIL import Image +from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple + + +def parse_patch_size(value): + """Parse patch size from string. Supports single int or tuple (e.g., '16' or '16,32').""" + if value is None: + return None + if ',' in value: + parts = value.split(',') + return (int(parts[0]), int(parts[1])) + return int(value) + + +def main(): + parser = argparse.ArgumentParser(description="VibeToken image reconstruction") + parser.add_argument("--config", type=str, default="configs/vibetoken_ll.yaml", + help="Path to config YAML") + parser.add_argument("--checkpoint", type=str, required=True, + help="Path to model checkpoint") + parser.add_argument("--image", type=str, default="assets/example_1.jpg", + help="Path to input image") + parser.add_argument("--output", type=str, default="./assets/reconstructed.png", + help="Path to output image") + parser.add_argument("--device", type=str, default="cuda", + help="Device (cuda/cpu)") + + # Auto mode + parser.add_argument("--auto", action="store_true", + help="Auto mode: automatically determine optimal input resolution and patch sizes") + + # Input resolution (optional - resize input before encoding) + parser.add_argument("--input_height", type=int, default=None, + help="Resize input to this height before encoding (default: original)") + parser.add_argument("--input_width", type=int, default=None, + help="Resize input to this width before encoding (default: original)") + + # Output resolution (optional - decode to this size) + parser.add_argument("--output_height", type=int, default=None, + help="Decode to this height (default: same as input)") + parser.add_argument("--output_width", type=int, default=None, + help="Decode to this width (default: same as input)") + + # Patch sizes (optional) - supports single int or tuple like "16,32" + parser.add_argument("--encoder_patch_size", type=str, default=None, + help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)") + parser.add_argument("--decoder_patch_size", type=str, default=None, + help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)") + + args = parser.parse_args() + + # Load tokenizer + print(f"Loading tokenizer from {args.config}") + tokenizer = VibeTokenTokenizer.from_config( + args.config, + args.checkpoint, + device=args.device, + ) + + # Load image + print(f"Loading image from {args.image}") + image = Image.open(args.image).convert("RGB") + original_size = image.size # (W, H) + print(f"Original image size: {original_size[0]}x{original_size[1]}") + + if args.auto: + # AUTO MODE - use centralized auto_preprocess_image + print("\n=== AUTO MODE ===") + image, patch_size, info = auto_preprocess_image(image, verbose=True) + input_width, input_height = info["cropped_size"] + output_width, output_height = input_width, 
input_height + encoder_patch_size = patch_size + decoder_patch_size = patch_size + print("=================\n") + + else: + # MANUAL MODE + # Parse patch sizes + encoder_patch_size = parse_patch_size(args.encoder_patch_size) + decoder_patch_size = parse_patch_size(args.decoder_patch_size) + + # Resize input if specified + if args.input_width or args.input_height: + input_width = args.input_width or original_size[0] + input_height = args.input_height or original_size[1] + print(f"Resizing input to {input_width}x{input_height}") + image = image.resize((input_width, input_height), Image.LANCZOS) + + # Always center crop to ensure dimensions divisible by 32 + image = center_crop_to_multiple(image, multiple=32) + input_width, input_height = image.size + if (input_width, input_height) != original_size: + print(f"Center cropped to {input_width}x{input_height} (divisible by 32)") + + # Determine output size + output_height = args.output_height or input_height + output_width = args.output_width or input_width + + # Encode image to tokens + print("Encoding image to tokens...") + if encoder_patch_size: + print(f" Using encoder patch size: {encoder_patch_size}") + tokens = tokenizer.encode(image, patch_size=encoder_patch_size) + print(f"Token shape: {tokens.shape}") + + # Decode back to image + print(f"Decoding to {output_width}x{output_height}...") + if decoder_patch_size: + print(f" Using decoder patch size: {decoder_patch_size}") + reconstructed = tokenizer.decode( + tokens, + height=output_height, + width=output_width, + patch_size=decoder_patch_size + ) + print(f"Reconstructed shape: {reconstructed.shape}") + + # Convert tensor to PIL and save + output_images = tokenizer.to_pil(reconstructed) + output_images[0].save(args.output) + print(f"Saved reconstructed image to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..756560a57d27a87a7e3e5a847181fde70663474c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +spaces +torch>=2.0.0 +torchvision +einops>=0.6.0 +omegaconf>=2.3.0 +pillow>=9.0.0 +numpy>=1.20.0 +huggingface_hub>=0.16.0 +accelerate +wandb +webdataset +timm +open_clip_torch +transformers +scipy +torch-fidelity +torchinfo +termcolor +iopath +opencv-python +diffusers +gdown +tqdm +requests +datasets +gradio>=4.0.0 \ No newline at end of file diff --git a/scripts/train_vibetoken.py b/scripts/train_vibetoken.py new file mode 100644 index 0000000000000000000000000000000000000000..241bb12540474854a43048331036975d5e6b8705 --- /dev/null +++ b/scripts/train_vibetoken.py @@ -0,0 +1,223 @@ +"""Training script for VibeToken. + +Reference: + https://github.com/huggingface/open-muse +""" +import math +import os +import sys +from pathlib import Path +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) +sys.path.append(parent_dir) + +from accelerate.utils import set_seed +from accelerate import Accelerator + +import torch +import wandb +from omegaconf import OmegaConf +from utils.logger import setup_logger + +from utils.train_utils import ( + get_config, create_pretrained_tokenizer, + create_model_and_loss_module, + create_optimizer, create_lr_scheduler, create_dataloader, + create_evaluator, auto_resume, save_checkpoint, + train_one_epoch) + + +def main(): + workspace = os.environ.get('WORKSPACE', '') + if workspace: + torch.hub.set_dir(workspace + "/models/hub") + + config = get_config() + # Enable TF32 on Ampere GPUs. 
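+    # TF32 keeps the FP32 dynamic range but rounds the matmul mantissa, which speeds
+    # up matmuls/convolutions on Ampere and newer GPUs at a small precision cost;
+    # disable it in the config if bitwise-reproducible FP32 results are needed.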
+ if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + output_dir = config.experiment.output_dir + os.makedirs(output_dir, exist_ok=True) + config.experiment.logging_dir = os.path.join(output_dir, "logs") + + # Whether logging to Wandb or Tensorboard. + tracker = "tensorboard" + if config.training.enable_wandb: + tracker = "wandb" + + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with=tracker, + project_dir=config.experiment.logging_dir, + split_batches=False, + ) + + logger = setup_logger(name="VibeToken", log_level="INFO", + output_file=f"{output_dir}/log{accelerator.process_index}.txt") + + if accelerator.is_main_process: + if config.training.enable_wandb: + wandb_config = config.training.get("wandb", {}) + wandb_project = wandb_config.get("project", config.experiment.project) + wandb_entity = wandb_config.get("entity", None) + wandb_name = wandb_config.get("name", config.experiment.name) + wandb_tags = list(wandb_config.get("tags", [])) + wandb_notes = wandb_config.get("notes", None) + wandb_resume_id = wandb_config.get("resume_id", None) + + wandb_init_kwargs = { + "wandb": { + "name": wandb_name, + "dir": output_dir, + "resume": "allow", + } + } + if wandb_entity: + wandb_init_kwargs["wandb"]["entity"] = wandb_entity + if wandb_tags: + wandb_init_kwargs["wandb"]["tags"] = wandb_tags + if wandb_notes: + wandb_init_kwargs["wandb"]["notes"] = wandb_notes + if wandb_resume_id: + wandb_init_kwargs["wandb"]["id"] = wandb_resume_id + + accelerator.init_trackers( + project_name=wandb_project, + config=OmegaConf.to_container(config, resolve=True), + init_kwargs=wandb_init_kwargs, + ) + logger.info(f"WandB initialized - Project: {wandb_project}, Name: {wandb_name}") + else: + accelerator.init_trackers(config.experiment.name) + + config_path = Path(output_dir) / "config.yaml" + logger.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + logger.info(f"Config:\n{OmegaConf.to_yaml(config)}") + + # If passed along, set the training seed now. 
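+    # device_specific=True offsets the seed by the process index, so each rank gets a
+    # different but still deterministic stream (e.g. for augmentation and shuffling).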
+ if config.training.seed is not None: + set_seed(config.training.seed, device_specific=True) + + accelerator.wait_for_everyone() + + # Create pretrained tokenizer in a synchronized manner + if config.model.vq_model.is_legacy: + if accelerator.is_main_process: + logger.info("Creating pretrained tokenizer on main process...") + accelerator.wait_for_everyone() + pretrained_tokenizer = create_pretrained_tokenizer(config, accelerator) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info("Pretrained tokenizer creation completed.") + else: + pretrained_tokenizer = None + + if accelerator.is_main_process: + logger.info("Creating model and loss module...") + accelerator.wait_for_everyone() + + model, ema_model, loss_module = create_model_and_loss_module( + config, logger, accelerator, model_type="vibetoken") + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info("Model creation completed.") + + optimizer, discriminator_optimizer = create_optimizer(config, logger, model, loss_module, model_type="vibetoken") + + lr_scheduler, discriminator_lr_scheduler = create_lr_scheduler( + config, logger, accelerator, optimizer, discriminator_optimizer) + + if accelerator.is_main_process: + logger.info("Creating dataloaders...") + train_dataloader, eval_dataloader = create_dataloader(config, logger, accelerator) + accelerator.wait_for_everyone() + + # Set up evaluator. + if accelerator.is_main_process: + logger.info("Setting up evaluator...") + evaluator = create_evaluator(config, logger, accelerator) + + # Prepare everything with accelerator. + logger.info("Preparing model, optimizer and dataloaders") + # The dataloader are already aware of distributed training, so we don't need to prepare them. + if config.model.vq_model.is_legacy: + if config.model.vq_model.finetune_decoder: + model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare( + model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler + ) + else: + model, optimizer, lr_scheduler = accelerator.prepare( + model, optimizer, lr_scheduler + ) + else: + model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare( + model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler + ) + + if config.training.use_ema: + ema_model.to(accelerator.device) + + total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes + num_batches = math.ceil( + config.experiment.max_train_examples / total_batch_size_without_accum) + num_update_steps_per_epoch = math.ceil(num_batches / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + # Start training. + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + logger.info(f" Instantaneous batch size per gpu = { config.training.per_gpu_batch_size}") + logger.info(f""" Total train batch size (w. 
parallel, distributed & accumulation) = {( + config.training.per_gpu_batch_size * + accelerator.num_processes * + config.training.gradient_accumulation_steps)}""") + global_step = 0 + first_epoch = 0 + + global_step, first_epoch = auto_resume( + config, logger, accelerator, ema_model, num_update_steps_per_epoch, + strict=True) + + for current_epoch in range(first_epoch, num_train_epochs): + accelerator.print(f"Epoch {current_epoch}/{num_train_epochs-1} started.") + global_step = train_one_epoch(config, logger, accelerator, + model, ema_model, loss_module, + optimizer, discriminator_optimizer, + lr_scheduler, discriminator_lr_scheduler, + train_dataloader, eval_dataloader, + evaluator, + global_step, + pretrained_tokenizer=pretrained_tokenizer, + model_type="vibetoken") + # Stop training if max steps is reached. + if global_step >= config.training.max_train_steps: + accelerator.print( + f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}" + ) + break + + accelerator.wait_for_everyone() + # Save checkpoint at the end of training. + save_checkpoint(model, output_dir, accelerator, global_step, logger=logger) + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + if config.training.use_ema: + ema_model.copy_to(model.parameters()) + model.save_pretrained_weight(output_dir) + + if accelerator.is_main_process and config.training.enable_wandb: + wandb.finish() + logger.info("WandB run finished") + accelerator.end_training() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..4ca3020f3dcd0efb2dee70ed6e7cc0d05b9053ac --- /dev/null +++ b/setup.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Data preparation script for VibeToken training. +# Set DATA_DIR to control where datasets are stored (defaults to ./data). 
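+# Note: ImageNet-1k on Hugging Face is gated, so accept the dataset terms on the Hub
+# and authenticate with `huggingface-cli login` before running this script. Make sure
+# DATA_DIR has room for both the raw download and the converted WebDataset shards.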
+# +# Usage: +# export DATA_DIR=/mnt/fastssd/datasets # optional, defaults to ./data +# bash setup.sh + +DATA_DIR="${DATA_DIR:-./data}" + +echo "Using DATA_DIR=${DATA_DIR}" + +# Download ImageNet-1k via HuggingFace +export HF_HUB_ENABLE_HF_TRANSFER=1 +huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir "${DATA_DIR}/imagenet-1k" + +# Convert to WebDataset format +python data/convert_imagenet_to_wds.py \ + --input_dir "${DATA_DIR}/imagenet-1k" \ + --output_dir "${DATA_DIR}/imagenet_wds" diff --git a/train_tokenvibe.sh b/train_tokenvibe.sh new file mode 100644 index 0000000000000000000000000000000000000000..517edf995d0b12071e0fb86c857646f953771dff --- /dev/null +++ b/train_tokenvibe.sh @@ -0,0 +1,14 @@ +# Run training with 8 GPUs across 2 nodes (4 GPUs per node) +NODE_RANK=${RANK:-1} +MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} +MASTER_PORT=${MASTER_PORT:-9871} + +accelerate launch \ + --num_machines=1 \ + --num_processes=8 \ + --machine_rank=$NODE_RANK \ + --main_process_ip=$MASTER_ADDR \ + --main_process_port=$MASTER_PORT \ + --same_network \ + scripts/train_tokenvibe.py \ + config=configs/training/VibeToken_small.yaml \ No newline at end of file diff --git a/train_vibetoken.sh b/train_vibetoken.sh new file mode 100644 index 0000000000000000000000000000000000000000..ab7ebdb0a4df3774d4aab3d0afaf778dee78effb --- /dev/null +++ b/train_vibetoken.sh @@ -0,0 +1,14 @@ +# Run training with 8 GPUs across 2 nodes (4 GPUs per node) +NODE_RANK=${RANK:-1} +MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} +MASTER_PORT=${MASTER_PORT:-9871} + +accelerate launch \ + --num_machines=1 \ + --num_processes=8 \ + --machine_rank=$NODE_RANK \ + --main_process_ip=$MASTER_ADDR \ + --main_process_port=$MASTER_PORT \ + --same_network \ + scripts/train_vibetoken.py \ + config=configs/training/VibeToken_small.yaml \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7e94c7b322c0a04d213d587ed94d666c9d073d --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,69 @@ +"""Util functions supporting logging to terminal and files.""" + +import functools +import sys +from accelerate.logging import MultiProcessAdapter +import logging +from termcolor import colored + +from iopath.common.file_io import PathManager as PathManagerClass + +__all__ = ["setup_logger", "PathManager"] + +PathManager = PathManagerClass() + + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." + self._abbrev_name = kwargs.pop("abbrev_name", self._root_name) + if len(self._abbrev_name): + self._abbrev_name = self._abbrev_name + "." 
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + record.name = record.name.replace(self._root_name, self._abbrev_name) + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +@functools.lru_cache() +def setup_logger(name="TiTok", log_level: str = None, color=True, use_accelerate=True, + output_file=None): + logger = logging.getLogger(name) + if log_level is None: + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(log_level.upper()) + + plain_formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" + ) + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + + if output_file is not None: + fileHandler = logging.FileHandler(output_file) + fileHandler.setFormatter(formatter) + logger.addHandler(fileHandler) + + if use_accelerate: + return MultiProcessAdapter(logger, {}) + else: + return logger \ No newline at end of file diff --git a/utils/lr_schedulers.py b/utils/lr_schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..0d5ff550c6accb988b2e8f827ea76f61cbbdd939 --- /dev/null +++ b/utils/lr_schedulers.py @@ -0,0 +1,129 @@ +"""Learning rate schedulers. + +Reference: + https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py +""" +import math +from enum import Enum +from typing import Optional, Union + +import torch + + +class SchedulerType(Enum): + COSINE = "cosine" + CONSTANT = "constant" + +def get_cosine_schedule_with_warmup( + optimizer: torch.optim.Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, + base_lr: float = 1e-4, + end_lr: float = 0.0, +): + """Creates a cosine learning rate schedule with warm-up and ending learning rate. + + Args: + optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate. + num_warmup_steps: An integer, the number of steps for the warmup phase. + num_training_steps: An integer, the total number of training steps. + num_cycles : A float, the number of periods of the cosine function in a schedule (the default is to + just decrease from the max value to 0 following a half-cosine). + last_epoch: An integer, the index of the last epoch when resuming training. + base_lr: A float, the base learning rate. + end_lr: A float, the final learning rate. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
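+
+    Note:
+        The returned schedule is a multiplicative factor on the optimizer's lr, so the
+        optimizer should be constructed with lr=base_lr for the schedule to reach
+        base_lr right after warmup and end_lr at the end of training.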
+ """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / \ + float(max(1, num_training_steps - num_warmup_steps)) + ratio = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + return (end_lr + (base_lr - end_lr) * ratio) / base_lr + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_constant_schedule_with_warmup( + optimizer: torch.optim.Optimizer, + num_warmup_steps: int, + num_training_steps: int, + base_lr: float = 1e-4, + end_lr: float = 0.0, +): + """UViT: Creates a constant learning rate schedule with warm-up. + + Args: + optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate. + num_warmup_steps: An integer, the number of steps for the warmup phase. + num_training_steps: An integer, the total number of training steps. + num_cycles : A float, the number of periods of the cosine function in a schedule (the default is to + just decrease from the max value to 0 following a half-cosine). + last_epoch: An integer, the index of the last epoch when resuming training. + base_lr: A float, the base learning rate. + end_lr: A float, the final learning rate. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + else: + return 1.0 + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule_with_warmup, +} + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: torch.optim.Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + base_lr: float = 1e-4, + end_lr: float = 0.0, +): + """Retrieves a learning rate scheduler from the given name and optimizer. + + Args: + name: A string or SchedulerType, the name of the scheduler to retrieve. + optimizer: torch.optim.Optimizer. The optimizer to use with the scheduler. + num_warmup_steps: An integer, the number of warmup steps. + num_training_steps: An integer, the total number of training steps. + base_lr: A float, the base learning rate. + end_lr: A float, the final learning rate. + + Returns: + A instance of torch.optim.lr_scheduler.LambdaLR + + Raises: + ValueError: If num_warmup_steps or num_training_steps is not provided. 
+ """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + base_lr=base_lr, + end_lr=end_lr, + ) \ No newline at end of file diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..154b64488c5150a8221ed5f53f23ba942d8852f9 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,342 @@ +"""This file is borrowed from https://github.com/LTH14/mar/blob/main/util/misc.py +""" +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +if TORCH_MAJOR == 1 and TORCH_MINOR < 8: + from torch._six import inf +else: + from torch import inf +import copy + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, 
header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + 
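+    # Blocking rendezvous: every rank must reach this call with the same dist_url,
+    # world_size and backend before any of them can proceed.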
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def add_weight_decay(model, weight_decay=1e-5, skip_list=()): + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name: + no_decay.append(param) # no weight decay on bias, norm and diffloss + else: + decay.append(param) + return [ + {'params': no_decay, 'weight_decay': 0.}, + {'params': decay, 'weight_decay': weight_decay}] + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None): + if epoch_name is None: + epoch_name = str(epoch) + output_dir = Path(args.output_dir) + checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name) + + # ema + if ema_params is not None: + ema_state_dict = copy.deepcopy(model_without_ddp.state_dict()) + for i, (name, _value) in enumerate(model_without_ddp.named_parameters()): + assert name in ema_state_dict + ema_state_dict[name] = ema_params[i] + else: + ema_state_dict = None + + to_save = { + 'model': model_without_ddp.state_dict(), + 'model_ema': ema_state_dict, + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + } + save_on_master(to_save, checkpoint_path) + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x \ No newline at end of file diff --git a/utils/train_utils.py b/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6f399c4fdaa4f50fd482382bb61bb40e9f891816 --- /dev/null +++ b/utils/train_utils.py @@ -0,0 +1,836 @@ +"""Training utils for VibeToken.""" +import json +import os +import time +import math +from pathlib import 
Path +import pprint +import glob +from collections import defaultdict +import random +import gc + +from data import SimpleImageDataset, PretoeknizedDataSetJSONL, PretokenizedWebDataset +import torch +from torch.utils.data import DataLoader +from omegaconf import OmegaConf +from torch.optim import AdamW +from utils.lr_schedulers import get_scheduler +from modeling.modules import EMAModel, ReconstructionLoss_Single_Stage +from modeling.vibetoken_model import VibeTokenModel, PretrainedTokenizer +from evaluator import VQGANEvaluator + +from utils.viz_utils import make_viz_from_samples +from torchinfo import summary +import accelerate + +def get_config(): + """Reads configs from a yaml file and terminal.""" + cli_conf = OmegaConf.from_cli() + + yaml_conf = OmegaConf.load(cli_conf.config) + conf = OmegaConf.merge(yaml_conf, cli_conf) + + return conf + + +class AverageMeter(object): + """Computes and stores the average and current value. + + This class is borrowed from + https://github.com/pytorch/examples/blob/main/imagenet/main.py#L423 + """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def create_pretrained_tokenizer(config, accelerator=None): + if config.model.vq_model.finetune_decoder: + pretrianed_tokenizer = None + else: + pretrianed_tokenizer = PretrainedTokenizer(config.model.vq_model.pretrained_tokenizer_weight) + if accelerator is not None: + pretrianed_tokenizer.to(accelerator.device) + return pretrianed_tokenizer + + +def create_model_and_loss_module(config, logger, accelerator, + model_type="vibetoken"): + """Creates model and loss module.""" + logger.info("Creating model and loss module.") + if model_type == "vibetoken": + if config.model.sub_model_type == "vibetoken": + model_cls = VibeTokenModel + loss_cls = ReconstructionLoss_Single_Stage + else: + raise ValueError(f"Unsupported sub_model_type {config.model.sub_model_type}") + else: + raise ValueError(f"Unsupported model_type {model_type}") + model = model_cls(config) + + if config.experiment.get("init_weight", ""): + model_weight = torch.load(config.experiment.init_weight, map_location="cpu") + if config.model.vq_model.finetune_decoder: + pretrained_tokenizer_weight = torch.load( + config.model.vq_model.pretrained_tokenizer_weight, map_location="cpu" + ) + pretrained_tokenizer_weight = {"pixel_" + k:v for k,v in pretrained_tokenizer_weight.items() if not "encoder." in k} + model_weight.update(pretrained_tokenizer_weight) + + msg = model.load_state_dict(model_weight, strict=False) + logger.info(f"loading weight from {config.experiment.init_weight}, msg: {msg}") + + # Create the EMA model. 
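+    # The EMA model (decay 0.999) shadows the raw parameters; the save/load hooks
+    # registered below persist it with the accelerator state so that resumed runs
+    # keep the averaged weights alongside the optimizer state.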
+ ema_model = None + if config.training.use_ema: + ema_model = EMAModel(model.parameters(), decay=0.999, + model_cls=model_cls, config=config) + def load_model_hook(models, input_dir): + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), + model_cls=model_cls, config=config) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + ema_model.save_pretrained(os.path.join(output_dir, "ema_model")) + + accelerator.register_load_state_pre_hook(load_model_hook) + accelerator.register_save_state_pre_hook(save_model_hook) + + loss_module = loss_cls(config=config) if loss_cls is not None else None + + if accelerator.is_main_process: + if model_type in ["vibetoken"]: + logger.info("VibeToken model summary not implemented yet.") + else: + raise NotImplementedError + + return model, ema_model, loss_module + + +def create_optimizer(config, logger, model, loss_module, + model_type="vibetoken", need_discrminator=True): + """Creates optimizer for model and discriminator.""" + logger.info("Creating optimizers.") + optimizer_config = config.optimizer.params + learning_rate = optimizer_config.learning_rate + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer_cls = AdamW + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + exclude = (lambda n, p: p.ndim < 2 or "ln" in n or "bias" in n or 'latent_tokens' in n + or 'mask_token' in n or 'embedding' in n or 'norm' in n or 'gamma' in n or 'embed' in n) + include = lambda n, p: not exclude(n, p) + named_parameters = list(model.named_parameters()) + gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + optimizer = optimizer_cls( + [ + {"params": gain_or_bias_params, "weight_decay": 0.}, + {"params": rest_params, "weight_decay": optimizer_config.weight_decay}, + ], + lr=learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2) + ) + + if (config.model.vq_model.finetune_decoder or model_type == "vibetoken") and need_discrminator: + discriminator_learning_rate = optimizer_config.discriminator_learning_rate + discriminator_named_parameters = list(loss_module.named_parameters()) + discriminator_gain_or_bias_params = [p for n, p in discriminator_named_parameters if exclude(n, p) and p.requires_grad] + discriminator_rest_params = [p for n, p in discriminator_named_parameters if include(n, p) and p.requires_grad] + + discriminator_optimizer = optimizer_cls( + [ + {"params": discriminator_gain_or_bias_params, "weight_decay": 0.}, + {"params": discriminator_rest_params, "weight_decay": optimizer_config.weight_decay}, + ], + lr=discriminator_learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2) + ) + else: + discriminator_optimizer = None + + assert discriminator_optimizer is not None, "Discriminator optimizer is None with condition values: {config.model.vq_model.finetune_decoder} {model_type} {need_discrminator}" + + return optimizer, discriminator_optimizer + + +def create_lr_scheduler(config, logger, accelerator, optimizer, discriminator_optimizer=None): + """Creates learning rate scheduler for model and discriminator.""" + logger.info("Creating lr_schedulers.") + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps * 
accelerator.num_processes, + num_warmup_steps=config.lr_scheduler.params.warmup_steps * accelerator.num_processes, + base_lr=config.lr_scheduler.params.learning_rate, + end_lr=config.lr_scheduler.params.end_lr, + ) + if discriminator_optimizer is not None: + discriminator_lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=discriminator_optimizer, + num_training_steps=config.training.max_train_steps * accelerator.num_processes - config.losses.discriminator_start, + num_warmup_steps=config.lr_scheduler.params.warmup_steps * accelerator.num_processes, + base_lr=config.lr_scheduler.params.learning_rate, + end_lr=config.lr_scheduler.params.end_lr, + ) + else: + discriminator_lr_scheduler = None + return lr_scheduler, discriminator_lr_scheduler + + +def create_dataloader(config, logger, accelerator): + """Creates data loader for training and testing.""" + logger.info("Creating dataloaders.") + total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes + total_batch_size = ( + config.training.per_gpu_batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + if dataset_config.get("pretokenization", "") and dataset_config.get("dataset_with_text_label", False) is True: + dataset = PretokenizedWebDataset( + train_shards_path=dataset_config.train_shards_path_or_url, + eval_shards_path=dataset_config.eval_shards_path_or_url, + num_train_examples=config.experiment.max_train_examples, + per_gpu_batch_size=config.training.per_gpu_batch_size, + global_batch_size=total_batch_size_without_accum, + num_workers_per_gpu=dataset_config.num_workers_per_gpu, + resize_shorter_edge=preproc_config.resize_shorter_edge, + crop_size=preproc_config.crop_size, + random_crop=preproc_config.random_crop, + random_flip=preproc_config.random_flip, + normalize_mean=preproc_config.normalize_mean, + normalize_std=preproc_config.normalize_std, + process_recap=preproc_config.get("preproc_recap", True), + use_recap_prob=preproc_config.get("use_recap_prob", 0.95) + ) + train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader + elif dataset_config.get("pretokenization", "") and dataset_config.get("dataset_with_text_label", False) is False: + dataset = SimpleImageDataset( + train_shards_path=dataset_config.train_shards_path_or_url, + eval_shards_path=dataset_config.eval_shards_path_or_url, + num_train_examples=config.experiment.max_train_examples, + per_gpu_batch_size=config.training.per_gpu_batch_size, + global_batch_size=total_batch_size_without_accum, + num_workers_per_gpu=dataset_config.num_workers_per_gpu, + resize_shorter_edge=preproc_config.resize_shorter_edge, + crop_size=preproc_config.crop_size, + random_crop=preproc_config.random_crop, + random_flip=preproc_config.random_flip, + dataset_with_class_label=dataset_config.get("dataset_with_class_label", True), + dataset_with_text_label=dataset_config.get("dataset_with_text_label", False), + res_ratio_filtering=preproc_config.get("res_ratio_filtering", False), + min_tokens=preproc_config.min_tokens, + max_tokens=preproc_config.max_tokens, + ) + train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader + else: + if dataset_config.get("pretokenization", ""): + train_dataloader = DataLoader( + PretoeknizedDataSetJSONL(dataset_config.pretokenization), + batch_size=config.training.per_gpu_batch_size, + shuffle=True, drop_last=True, 
pin_memory=True) + train_dataloader.num_batches = math.ceil( + config.experiment.max_train_examples / total_batch_size_without_accum) + + return train_dataloader, eval_dataloader + + +class LazyVQGANEvaluator: + """A lazy-loading wrapper for VQGANEvaluator that delays inception model initialization.""" + + def __init__(self, device, enable_rfid=True, enable_inception_score=True, + enable_codebook_usage_measure=False, enable_codebook_entropy_measure=False, + num_codebook_entries=1024, accelerator=None): + self._device = device + self._enable_rfid = enable_rfid + self._enable_inception_score = enable_inception_score + self._enable_codebook_usage_measure = enable_codebook_usage_measure + self._enable_codebook_entropy_measure = enable_codebook_entropy_measure + self._num_codebook_entries = num_codebook_entries + self._accelerator = accelerator + self._evaluator = None + self._initialized = False + + def _ensure_initialized(self): + """Initialize the real evaluator only when needed.""" + if not self._initialized: + if self._accelerator and self._accelerator.num_processes > 1: + if self._accelerator.is_main_process: + try: + from evaluator.inception import get_inception_model + _ = get_inception_model() + except Exception as e: + print(f"Warning: Failed to pre-load inception model: {e}") + + if self._accelerator: + self._accelerator.wait_for_everyone() + + try: + self._evaluator = VQGANEvaluator( + device=self._device, + enable_rfid=self._enable_rfid, + enable_inception_score=self._enable_inception_score, + enable_codebook_usage_measure=self._enable_codebook_usage_measure, + enable_codebook_entropy_measure=self._enable_codebook_entropy_measure, + num_codebook_entries=self._num_codebook_entries + ) + self._initialized = True + except Exception as e: + print(f"Warning: Failed to create VQGANEvaluator, using dummy: {e}") + class DummyEvaluator: + def reset_metrics(self): pass + def update(self, real_images, fake_images, codebook_indices=None): pass + def result(self): + return {"InceptionScore": 0.0, "rFID": 0.0, "CodebookUsage": 0.0, "CodebookEntropy": 0.0} + self._evaluator = DummyEvaluator() + self._initialized = True + + def reset_metrics(self): + self._ensure_initialized() + return self._evaluator.reset_metrics() + + def update(self, real_images, fake_images, codebook_indices=None): + self._ensure_initialized() + return self._evaluator.update(real_images, fake_images, codebook_indices) + + def result(self): + self._ensure_initialized() + return self._evaluator.result() + + +def create_evaluator(config, logger, accelerator): + """Creates evaluator.""" + logger.info("Creating evaluator.") + + if config.model.vq_model.get("quantize_mode", "vq") in ["vq", "softvq", "mvq"]: + evaluator = LazyVQGANEvaluator( + device=accelerator.device, + enable_rfid=True, + enable_inception_score=True, + enable_codebook_usage_measure=True, + enable_codebook_entropy_measure=True, + num_codebook_entries=config.model.vq_model.codebook_size, + accelerator=accelerator + ) + elif config.model.vq_model.get("quantize_mode", "vq") == "vae": + evaluator = LazyVQGANEvaluator( + device=accelerator.device, + enable_rfid=True, + enable_inception_score=True, + enable_codebook_usage_measure=False, + enable_codebook_entropy_measure=False, + accelerator=accelerator + ) + else: + raise NotImplementedError + + logger.info("Lazy evaluator creation completed.") + return evaluator + + +def auto_resume(config, logger, accelerator, ema_model, + num_update_steps_per_epoch, strict=True): + """Auto resuming the training.""" + global_step = 0 
+ first_epoch = 0 + if config.experiment.resume: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + local_ckpt_list = list(glob.glob(os.path.join( + config.experiment.output_dir, "checkpoint*"))) + logger.info(f"All globbed checkpoints are: {local_ckpt_list}") + else: + local_ckpt_list = [] + + if accelerator.num_processes > 1: + checkpoint_count = torch.tensor(len(local_ckpt_list), device=accelerator.device) + accelerate.utils.broadcast(checkpoint_count, 0) + + if checkpoint_count > 0: + if accelerator.is_main_process: + if len(local_ckpt_list) > 1: + fn = lambda x: int(x.split('/')[-1].split('-')[-1]) + checkpoint_paths = sorted(local_ckpt_list, key=fn, reverse=True) + else: + checkpoint_paths = local_ckpt_list + latest_checkpoint = checkpoint_paths[0] + else: + latest_checkpoint = "" + + if accelerator.is_main_process: + checkpoint_path_tensor = torch.tensor([ord(c) for c in latest_checkpoint], device=accelerator.device, dtype=torch.long) + path_length = torch.tensor(len(latest_checkpoint), device=accelerator.device) + else: + path_length = torch.tensor(0, device=accelerator.device) + + accelerate.utils.broadcast(path_length, 0) + + if not accelerator.is_main_process: + checkpoint_path_tensor = torch.zeros(path_length.item(), device=accelerator.device, dtype=torch.long) + + accelerate.utils.broadcast(checkpoint_path_tensor, 0) + + if not accelerator.is_main_process: + latest_checkpoint = ''.join([chr(c.item()) for c in checkpoint_path_tensor]) + + global_step = load_checkpoint( + Path(latest_checkpoint), + accelerator, + logger=logger, + strict=strict + ) + if config.training.use_ema: + ema_model.set_step(global_step) + first_epoch = global_step // num_update_steps_per_epoch + else: + logger.info("Training from scratch.") + else: + if len(local_ckpt_list) >= 1: + if len(local_ckpt_list) > 1: + fn = lambda x: int(x.split('/')[-1].split('-')[-1]) + checkpoint_paths = sorted(local_ckpt_list, key=fn, reverse=True) + else: + checkpoint_paths = local_ckpt_list + global_step = load_checkpoint( + Path(checkpoint_paths[0]), + accelerator, + logger=logger, + strict=strict + ) + if config.training.use_ema: + ema_model.set_step(global_step) + first_epoch = global_step // num_update_steps_per_epoch + else: + logger.info("Training from scratch.") + + accelerator.wait_for_everyone() + return global_step, first_epoch + + +def train_one_epoch(config, logger, accelerator, + model, ema_model, loss_module, + optimizer, discriminator_optimizer, + lr_scheduler, discriminator_lr_scheduler, + train_dataloader, eval_dataloader, + evaluator, + global_step, + model_type="vibetoken", + clip_tokenizer=None, + clip_encoder=None, + pretrained_tokenizer=None): + """One epoch training.""" + batch_time_meter = AverageMeter() + data_time_meter = AverageMeter() + end = time.time() + + model.train() + + autoencoder_logs = defaultdict(float) + discriminator_logs = defaultdict(float) + for i, batch in enumerate(train_dataloader): + model.train() + if "image" in batch: + images = batch["image"].to( + accelerator.device, memory_format=torch.contiguous_format, non_blocking=True + ) + if config.training.get("variable_resolution", False): + any2any = config.training.variable_resolution.get("any2any", True) + + dims = config.training.variable_resolution.dim + ratios = config.training.variable_resolution.ratio + assert len(dims) == len(ratios), "dims and ratios must have the same length" + input_res = tuple(random.choices(dims, weights=ratios, k=1)[0]) + + if any2any: + output_res = tuple(random.choices(dims, 
weights=ratios, k=1)[0]) + else: + output_res = input_res + + images = torch.nn.functional.interpolate(images, size=output_res, mode="bilinear", align_corners=False) + input_images = torch.nn.functional.interpolate(images, size=input_res, mode="bilinear", align_corners=False) + else: + input_images = images + output_res = (None, None) + + fnames = batch["__key__"] + data_time_meter.update(time.time() - end) + + if pretrained_tokenizer is not None: + pretrained_tokenizer.eval() + proxy_codes = pretrained_tokenizer.encode(images) + else: + proxy_codes = None + + with accelerator.accumulate([model, loss_module]): + additional_args = {} + if config.model.get("train_with_attention", False): + additional_args["key_attention_mask"] = batch["attention_mask"].to( + accelerator.device, memory_format=torch.contiguous_format, non_blocking=True + ) + reconstructed_images, extra_results_dict = model(input_images, height=output_res[0], width=output_res[1], **additional_args) + autoencoder_loss, loss_dict = loss_module( + images, + reconstructed_images, + extra_results_dict, + global_step, + mode="generator", + ) + + autoencoder_logs = {} + for k, v in loss_dict.items(): + if k in ["discriminator_factor", "d_weight"]: + if type(v) == torch.Tensor: + autoencoder_logs["train/" + k] = v.cpu().item() + else: + autoencoder_logs["train/" + k] = v + else: + gathered_tensor = accelerator.gather(v) + autoencoder_logs["train/" + k] = gathered_tensor.mean().item() + del gathered_tensor + + torch.cuda.empty_cache() + accelerator.backward(autoencoder_loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + # Train discriminator. 
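+            # Editor's note (assumed from usage): loss_module acts both as the
+            # reconstruction/GAN loss and as the discriminator. The generator pass
+            # above used mode="generator"; the block below updates only the
+            # discriminator parameters with mode="discriminator", and stays skipped
+            # until should_discriminator_be_trained() turns on (around
+            # config.losses.discriminator_start steps).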
+ discriminator_logs = defaultdict(float) + if (config.model.vq_model.finetune_decoder or model_type == "vibetoken") and accelerator.unwrap_model(loss_module).should_discriminator_be_trained(global_step): + discriminator_logs = defaultdict(float) + discriminator_loss, loss_dict_discriminator = loss_module( + images, + reconstructed_images, + extra_results_dict, + global_step=global_step, + mode="discriminator", + ) + + for k, v in loss_dict_discriminator.items(): + if k in ["logits_real", "logits_fake"]: + if type(v) == torch.Tensor: + discriminator_logs["train/" + k] = v.cpu().item() + else: + discriminator_logs["train/" + k] = v + else: + gathered_tensor = accelerator.gather(v) + discriminator_logs["train/" + k] = gathered_tensor.mean().item() + del gathered_tensor + + torch.cuda.empty_cache() + accelerator.backward(discriminator_loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(loss_module.parameters(), config.training.max_grad_norm) + + discriminator_optimizer.step() + discriminator_lr_scheduler.step() + + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(loss_module, accelerator, global_step + 1) + + discriminator_optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + if config.training.use_ema: + ema_model.step(model.parameters()) + batch_time_meter.update(time.time() - end) + end = time.time() + + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * config.training.per_gpu_batch_size / batch_time_meter.val + ) + + lr = lr_scheduler.get_last_lr()[0] + logger.info( + f"Data (t): {data_time_meter.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_meter.val:0.4f} " + f"LR: {lr:0.6f} " + f"Step: {global_step + 1} " + f"Total Loss: {autoencoder_logs['train/total_loss']:0.4f} " + f"Recon Loss: {autoencoder_logs['train/reconstruction_loss']:0.4f} " + ) + logs = { + "lr": lr, + "lr/generator": lr, + "samples/sec/gpu": samples_per_second_per_gpu, + "time/data_time": data_time_meter.val, + "time/batch_time": batch_time_meter.val, + } + logs.update(autoencoder_logs) + logs.update(discriminator_logs) + accelerator.log(logs, step=global_step + 1) + + del autoencoder_logs, discriminator_logs, logs + gc.collect() + + batch_time_meter.reset() + data_time_meter.reset() + + # Save model checkpoint. + if (global_step + 1) % config.experiment.save_every == 0: + save_path = save_checkpoint( + model, config.experiment.output_dir, accelerator, global_step + 1, logger=logger) + accelerator.wait_for_everyone() + + # Generate images. + if (global_step + 1) % config.experiment.generate_every == 0: + if accelerator.is_main_process: + if config.training.get("use_ema", False): + ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) + + reconstruct_images( + model, + images[:config.training.num_generated_images], + fnames[:config.training.num_generated_images], + accelerator, + global_step + 1, + config.experiment.output_dir, + logger=logger, + config=config, + pretrained_tokenizer=pretrained_tokenizer + ) + + if config.training.get("use_ema", False): + ema_model.restore(model.parameters()) + + accelerator.wait_for_everyone() + + + # Evaluate reconstruction. 
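+            # Editor's note: when EMA is enabled, the evaluation below temporarily
+            # swaps the EMA weights into the model and restores the training weights
+            # afterwards, so the reported metrics correspond to the EMA checkpoint.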
+ if eval_dataloader is not None and (global_step + 1) % config.experiment.eval_every == 0: + logger.info(f"Computing metrics on the validation set.") + if config.training.get("use_ema", False): + ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) + eval_scores = eval_reconstruction( + config, + model, + eval_dataloader, + accelerator, + evaluator, + pretrained_tokenizer=pretrained_tokenizer + ) + logger.info( + f"EMA EVALUATION " + f"Step: {global_step + 1} " + ) + logger.info(pprint.pformat(eval_scores)) + if accelerator.is_main_process: + eval_log = {f'ema_eval/'+k: v for k, v in eval_scores.items()} + accelerator.log(eval_log, step=global_step + 1) + if config.training.get("use_ema", False): + ema_model.restore(model.parameters()) + else: + eval_scores = eval_reconstruction( + config, + model, + eval_dataloader, + accelerator, + evaluator, + pretrained_tokenizer=pretrained_tokenizer + ) + + logger.info( + f"Non-EMA EVALUATION " + f"Step: {global_step + 1} " + ) + logger.info(pprint.pformat(eval_scores)) + if accelerator.is_main_process: + eval_log = {f'eval/'+k: v for k, v in eval_scores.items()} + accelerator.log(eval_log, step=global_step + 1) + + accelerator.wait_for_everyone() + + global_step += 1 + + if global_step >= config.training.max_train_steps: + accelerator.print( + f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}" + ) + break + + + return global_step + + +@torch.no_grad() +def eval_reconstruction( + config, + model, + eval_loader, + accelerator, + evaluator, + pretrained_tokenizer=None +): + model.eval() + evaluator.reset_metrics() + local_model = accelerator.unwrap_model(model) + + accelerator.wait_for_everyone() + + for batch in eval_loader: + images = batch["image"].to( + accelerator.device, memory_format=torch.contiguous_format, non_blocking=True + ) + + original_images = torch.clone(images) + additional_args = {} + if config.model.get("eval_with_attention", False): + additional_args["key_attention_mask"] = batch["attention_mask"].to( + accelerator.device, memory_format=torch.contiguous_format, non_blocking=True + ) + reconstructed_images, model_dict = local_model(images, **additional_args) + + if pretrained_tokenizer is not None: + reconstructed_images = pretrained_tokenizer.decode(reconstructed_images.argmax(1)) + reconstructed_images = torch.clamp(reconstructed_images, 0.0, 1.0) + reconstructed_images = torch.round(reconstructed_images * 255.0) / 255.0 + original_images = torch.clamp(original_images, 0.0, 1.0) + + if isinstance(model_dict, dict): + evaluator.update(original_images, reconstructed_images.squeeze(2), model_dict["min_encoding_indices"]) + else: + evaluator.update(original_images, reconstructed_images.squeeze(2), None) + + accelerator.wait_for_everyone() + + local_results = evaluator.result() + + if accelerator.num_processes > 1: + gathered_results = {} + for key, value in local_results.items(): + if isinstance(value, (int, float)): + value_tensor = torch.tensor(value, device=accelerator.device) + gathered_values = accelerator.gather(value_tensor) + gathered_results[key] = gathered_values.mean().item() + else: + gathered_results[key] = value + + accelerator.wait_for_everyone() + model.train() + return gathered_results + else: + model.train() + return local_results + + +@torch.no_grad() +def reconstruct_images(model, original_images, fnames, accelerator, + global_step, output_dir, logger, config=None, + pretrained_tokenizer=None): + logger.info("Reconstructing images...") 
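+    # Editor's note: cloning keeps the caller's tensors untouched, since the viz
+    # helper below scales the original images to [0, 255] in place.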
+ original_images = torch.clone(original_images) + _, _, height, width = original_images.shape + model.eval() + dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + dtype = torch.bfloat16 + + with torch.autocast("cuda", dtype=dtype, enabled=accelerator.mixed_precision != "no"): + enc_tokens, encoder_dict = accelerator.unwrap_model(model).encode(original_images) + reconstructed_images = accelerator.unwrap_model(model).decode(enc_tokens, height=height, width=width) + if pretrained_tokenizer is not None: + reconstructed_images = pretrained_tokenizer.decode(reconstructed_images.argmax(1)) + + images_for_saving, images_for_logging = make_viz_from_samples( + original_images, + reconstructed_images + ) + if config.training.enable_wandb: + accelerator.get_tracker("wandb").log_images( + {f"Train Reconstruction": images_for_saving}, + step=global_step + ) + else: + accelerator.get_tracker("tensorboard").log_images( + {"Train Reconstruction": images_for_logging}, step=global_step + ) + root = Path(output_dir) / "train_images" + os.makedirs(root, exist_ok=True) + for i,img in enumerate(images_for_saving): + filename = f"{global_step:08}_s-{i:03}-{fnames[i]}.png" + path = os.path.join(root, filename) + img.save(path) + + model.train() + + +def save_checkpoint(model, output_dir, accelerator, global_step, logger) -> Path: + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained_weight( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + accelerator.save_state(save_path) + return save_path + + +def load_checkpoint(checkpoint_path: Path, accelerator, logger, strict=True): + logger.info(f"Load checkpoint from {checkpoint_path}") + + accelerator.load_state(checkpoint_path, strict=strict) + + with open(checkpoint_path / "metadata.json", "r") as f: + global_step = int(json.load(f)["global_step"]) + + logger.info(f"Resuming at global_step {global_step}") + return global_step + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) diff --git a/utils/viz_utils.py b/utils/viz_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f7e6c28f2af996b6fd489d56c34dced895379ac --- /dev/null +++ b/utils/viz_utils.py @@ -0,0 +1,83 @@ +"""Utils functions for visualization.""" + +import torch +import torchvision.transforms.functional as F +from einops import rearrange +from PIL import Image, ImageDraw, ImageFont + +def make_viz_from_samples( + original_images, + reconstructed_images +): + """Generates visualization images from original images and reconstructed images. + + Args: + original_images: A torch.Tensor, original images. + reconstructed_images: A torch.Tensor, reconstructed images. + + Returns: + A tuple containing two lists - images_for_saving and images_for_logging. 
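+
+    Example (illustrative; `originals` and `recons` are placeholder names for
+    (B, 3, H, W) float tensors with values in [0, 1]):
+
+        pil_images, grid = make_viz_from_samples(originals, recons)
+        pil_images[0].save("sample_0.png")  # original | reconstruction | abs-diff strip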
+ """ + reconstructed_images = torch.clamp(reconstructed_images, 0.0, 1.0) + reconstructed_images = reconstructed_images * 255.0 + reconstructed_images = reconstructed_images.cpu() + + original_images = torch.clamp(original_images, 0.0, 1.0) + original_images *= 255.0 + original_images = original_images.cpu() + + diff_img = torch.abs(original_images - reconstructed_images) + to_stack = [original_images, reconstructed_images, diff_img] + + images_for_logging = rearrange( + torch.stack(to_stack), + "(l1 l2) b c h w -> b c (l1 h) (l2 w)", + l1=1).byte() + images_for_saving = [F.to_pil_image(image) for image in images_for_logging] + + return images_for_saving, images_for_logging + + +def make_viz_from_samples_generation( + generated_images, +): + generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0 + images_for_logging = rearrange( + generated, + "(l1 l2) c h w -> c (l1 h) (l2 w)", + l1=2) + + images_for_logging = images_for_logging.cpu().byte() + images_for_saving = F.to_pil_image(images_for_logging) + + return images_for_saving, images_for_logging + + +def make_viz_from_samples_t2i_generation( + generated_images, + captions, +): + generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0 + images_for_logging = rearrange( + generated, + "(l1 l2) c h w -> c (l1 h) (l2 w)", + l1=2) + + images_for_logging = images_for_logging.cpu().byte() + images_for_saving = F.to_pil_image(images_for_logging) + + # Create a new image with space for captions + width, height = images_for_saving.size + caption_height = 20 * len(captions) + 10 + new_height = height + caption_height + new_image = Image.new("RGB", (width, new_height), "black") + new_image.paste(images_for_saving, (0, 0)) + + # Adding captions below the image + draw = ImageDraw.Draw(new_image) + font = ImageFont.load_default() + + for i, caption in enumerate(captions): + draw.text((10, height + 10 + i * 20), caption, fill="white", font=font) + + return new_image, images_for_logging \ No newline at end of file diff --git a/vibetoken/__init__.py b/vibetoken/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee1543c09c194ce2ac9ed6dc4fd213506323206d --- /dev/null +++ b/vibetoken/__init__.py @@ -0,0 +1,16 @@ +"""VibeToken - Minimal inference library for VibeToken image tokenizer.""" + +from .tokenizer import ( + VibeTokenTokenizer, + center_crop_to_multiple, + get_auto_patch_size, + auto_preprocess_image, +) + +__version__ = "0.1.0" +__all__ = [ + "VibeTokenTokenizer", + "center_crop_to_multiple", + "get_auto_patch_size", + "auto_preprocess_image", +] diff --git a/vibetoken/modeling/__init__.py b/vibetoken/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fab0334ac4e28c40f68a3261677d55ba214d9c23 --- /dev/null +++ b/vibetoken/modeling/__init__.py @@ -0,0 +1,7 @@ +"""VibeToken modeling components.""" + +from .vibetoken import VibeToken +from .encoder import ResolutionEncoder +from .decoder import ResolutionDecoder + +__all__ = ["VibeToken", "ResolutionEncoder", "ResolutionDecoder"] diff --git a/vibetoken/modeling/blocks.py b/vibetoken/modeling/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5c70e95ccf7385802aeb80e78d3cc68ce61b4c --- /dev/null +++ b/vibetoken/modeling/blocks.py @@ -0,0 +1,288 @@ +"""Transformer building blocks for VibeToken. 
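+
+Covers the attention and MLP blocks, U-ViT blocks with skip connections, stochastic
+depth (DropPath), and a resizable anti-aliasing blur used for resampling decoder
+outputs.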
+ +Reference: + https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py + https://github.com/baofff/U-ViT/blob/main/libs/timm.py +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict +from typing import Optional +import einops + + +# Determine attention mode based on available implementations +if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + try: + import xformers + import xformers.ops + ATTENTION_MODE = 'xformers' + except ImportError: + ATTENTION_MODE = 'math' + + +class Attention(nn.Module): + """Multi-head self-attention with support for flash/xformers/math backends.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, L, C = x.shape + + qkv = self.qkv(x) + if ATTENTION_MODE == 'flash': + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() + q, k, v = qkv[0], qkv[1], qkv[2] + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'xformers': + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] + x = xformers.ops.memory_efficient_attention(q, k, v) + x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) + else: # math + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class ResidualAttentionBlock(nn.Module): + """Residual attention block with MLP.""" + + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: type = nn.GELU, + norm_layer: type = nn.LayerNorm, + ): + super().__init__() + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.mlp_ratio = mlp_ratio + + if mlp_ratio > 0: + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + + def attention( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.attn(x, x, x, attn_mask=attention_mask, need_weights=False)[0] + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_output = self.attention(x=self.ln_1(x), attention_mask=attention_mask) + x = x + attn_output + if self.mlp_ratio > 0: + x = x + self.mlp(self.ln_2(x)) + return x + + +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """Drop paths (Stochastic Depth) per sample.""" + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * 
(x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample.""" + + def __init__(self, drop_prob: float = 0.0): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + """MLP block with GELU activation.""" + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: type = nn.GELU, + drop: float = 0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class UViTBlock(nn.Module): + """U-ViT block with optional skip connection.""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: type = nn.GELU, + norm_layer: type = nn.LayerNorm, + skip: bool = False, + use_checkpoint: bool = False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, + qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, skip, use_reentrant=False) + return self._forward(x, skip) + + def _forward(self, x: torch.Tensor, skip: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.skip_linear is not None and skip is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class ResizableBlur(nn.Module): + """Anti-aliasing layer for downsampling with learnable blur kernel.""" + + def __init__( + self, + channels: int, + max_kernel_size: int = 9, + init_type: str = "gaussian", + ): + super().__init__() + self.C = channels + K = max_kernel_size + assert K % 2 == 1, "kernel must be odd" + + if init_type == "gaussian": + ax = torch.arange(-(K // 2), K // 2 + 1) + g1d = torch.exp(-0.5 * (ax / (K / 6.0)) ** 2) + g2d = torch.outer(g1d, g1d) + kernel = g2d / g2d.sum() + elif init_type == "lanczos": + a = K // 2 + x = torch.arange(-a, a + 1).float() + sinc = lambda t: torch.where( + t == 0, torch.ones_like(t), + torch.sin(torch.pi * t) / (torch.pi * t) + ) + k1d = sinc(x) * sinc(x / a) + k2d = torch.outer(k1d, k1d) + kernel = k2d / k2d.sum() + else: + raise ValueError(f"Unknown init_type: {init_type}") + + self.weight = 
nn.Parameter(kernel.unsqueeze(0).unsqueeze(0)) + + @staticmethod + def _resize_and_normalise(weight: torch.Tensor, k_size: int) -> torch.Tensor: + if weight.shape[-1] != k_size: + weight = F.interpolate(weight, size=(k_size, k_size), mode="bilinear", align_corners=True) + weight = weight / weight.sum(dim=(-2, -1), keepdim=True).clamp(min=1e-8) + return weight + + def forward(self, x: torch.Tensor, input_size: tuple, target_size: tuple) -> torch.Tensor: + input_h, input_w = input_size + target_h, target_w = target_size + + scale_h = input_h / target_h + scale_w = input_w / target_w + + k_size_h = min(self.weight.shape[-1], max(1, int(2 * scale_h + 3))) + k_size_w = min(self.weight.shape[-1], max(1, int(2 * scale_w + 3))) + k_size_h = k_size_h if k_size_h % 2 == 1 else k_size_h + 1 + k_size_w = k_size_w if k_size_w % 2 == 1 else k_size_w + 1 + k_size = max(k_size_h, k_size_w) + + stride_h = max(1, round(scale_h)) + stride_w = max(1, round(scale_w)) + pad_h = k_size_h // 2 + pad_w = k_size_w // 2 + + k = self._resize_and_normalise(self.weight, k_size) + k = k.repeat(self.C, 1, 1, 1) + + result = F.conv2d(x, weight=k, stride=(stride_h, stride_w), + padding=(pad_h, pad_w), groups=self.C) + + if result.shape[2:] != target_size: + result = F.interpolate(result, size=target_size, mode='bilinear', align_corners=True) + + return result + + +def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor: + """Expand a single token to batch size.""" + return token.unsqueeze(0).expand(batch_size, -1, -1) diff --git a/vibetoken/modeling/decoder.py b/vibetoken/modeling/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e265c648fcefdd89702a5ff102cbc5cf0c993b --- /dev/null +++ b/vibetoken/modeling/decoder.py @@ -0,0 +1,206 @@ +"""Resolution-aware decoder for VibeToken. + +Vision Transformer-based decoder with flexible output resolutions. +""" + +import torch +import torch.nn as nn +from typing import Optional, Tuple +from einops.layers.torch import Rearrange + +from .blocks import ResidualAttentionBlock, ResizableBlur, _expand_token +from .embeddings import FuzzyEmbedding + + +class ResolutionDecoder(nn.Module): + """Vision Transformer decoder with flexible resolution support. + + Decodes latent tokens back to images with support for variable + output resolutions and patch sizes. + """ + + # Model size configurations + MODEL_CONFIGS = { + "small": {"width": 512, "num_layers": 8, "num_heads": 8}, + "base": {"width": 768, "num_layers": 12, "num_heads": 12}, + "large": {"width": 1024, "num_layers": 24, "num_heads": 16}, + } + + def __init__(self, config): + """Initialize ResolutionDecoder. + + Args: + config: OmegaConf config with model parameters. 
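+
+        Note: transformer width/depth are looked up in MODEL_CONFIGS via
+        `vit_dec_model_size` ("small", "base", or "large"). Illustrative sketch,
+        with `cfg` as a placeholder OmegaConf config:
+
+            decoder = ResolutionDecoder(cfg)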
+ """ + super().__init__() + self.config = config + + # Extract config values + vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model + self.image_size = getattr(config.dataset.preprocessing, 'crop_size', 512) if hasattr(config, 'dataset') else 512 + self.patch_size = getattr(vq_config, 'vit_dec_patch_size', 32) + self.model_size = getattr(vq_config, 'vit_dec_model_size', 'large') + self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256) + self.token_size = getattr(vq_config, 'token_size', 256) + self.is_legacy = getattr(vq_config, 'is_legacy', False) + + if self.is_legacy: + raise NotImplementedError("Legacy mode is not supported in this inference-only version") + + # Get model dimensions + model_cfg = self.MODEL_CONFIGS[self.model_size] + self.width = model_cfg["width"] + self.num_layers = model_cfg["num_layers"] + self.num_heads = model_cfg["num_heads"] + + # Input projection + self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True) + + # Embeddings + scale = self.width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) + self.positional_embedding = FuzzyEmbedding(1024, scale, self.width) + self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width)) + self.latent_token_positional_embedding = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.width) + ) + self.ln_pre = nn.LayerNorm(self.width) + + # Transformer layers + self.transformer = nn.ModuleList([ + ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0) + for _ in range(self.num_layers) + ]) + + # Output projection + self.ln_post = nn.LayerNorm(self.width) + self.ffn = nn.Conv2d( + self.width, self.patch_size * self.patch_size * 3, + kernel_size=1, padding=0, bias=True + ) + self.rearrange = Rearrange( + 'b (p1 p2 c) h w -> b c (h p1) (w p2)', + p1=self.patch_size, p2=self.patch_size + ) + self.down_scale = ResizableBlur(channels=3, max_kernel_size=9, init_type="lanczos") + self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True) + + def _select_patch_size(self, height: int, width: int) -> int: + """Select appropriate patch size based on target resolution. + + Args: + height: Target image height. + width: Target image width. + + Returns: + Selected patch size. + """ + total_pixels = height * width + min_patches, max_patches = 256, 1024 + + possible_sizes = [] + for ps in [8, 16, 32]: + grid_h = height // ps + grid_w = width // ps + total_patches = grid_h * grid_w + if min_patches <= total_patches <= max_patches: + possible_sizes.append(ps) + + if not possible_sizes: + # Find closest to target range + patch_counts = [] + for ps in [8, 16, 32]: + grid_h = height // ps + grid_w = width // ps + patch_counts.append((ps, grid_h * grid_w)) + patch_counts.sort(key=lambda x: min(abs(x[1] - min_patches), abs(x[1] - max_patches))) + return patch_counts[0][0] + + return possible_sizes[0] + + def forward( + self, + z_quantized: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + decode_patch_size: Optional[int] = None, + ) -> torch.Tensor: + """Decode latent tokens to images. + + Args: + z_quantized: Quantized latent features (B, C, H, W). + attention_mask: Optional attention mask. + height: Target image height. + width: Target image width. + decode_patch_size: Optional custom patch size for decoding. + + Returns: + Decoded images (B, 3, height, width), values in [0, 1]. 
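+
+        Example (illustrative sketch; `decoder` is a constructed ResolutionDecoder
+        and `z_q` holds quantized latents of shape (B, token_size, 1, num_latent)):
+
+            imgs = decoder(z_q, height=1024, width=768)  # -> (B, 3, 1024, 768)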
+ """ + N, C, H, W = z_quantized.shape + + # Reshape and project input + x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1) # (N, seq_len, C*H) + x = self.decoder_embed(x) + + batchsize, seq_len, _ = x.shape + + # Default output size + if height is None: + height = self.image_size + if width is None: + width = self.image_size + + # Determine patch size + if decode_patch_size is None: + selected_patch_size = self._select_patch_size(height, width) + else: + selected_patch_size = decode_patch_size + + if isinstance(selected_patch_size, int): + selected_patch_size = (selected_patch_size, selected_patch_size) + + grid_height = height // selected_patch_size[0] + grid_width = width // selected_patch_size[1] + + # Create mask tokens for output positions + mask_tokens = self.mask_token.repeat(batchsize, grid_height * grid_width, 1).to(x.dtype) + mask_tokens = torch.cat([ + _expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype), + mask_tokens + ], dim=1) + + # Add positional embeddings + mask_tokens = mask_tokens + self.positional_embedding( + grid_height, grid_width, train=False + ).to(mask_tokens.dtype) + + x = x + self.latent_token_positional_embedding[:seq_len] + x = torch.cat([mask_tokens, x], dim=1) + + # Pre-norm and reshape for transformer + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # (seq_len, B, width) + + # Apply transformer layers + for layer in self.transformer: + x = layer(x, attention_mask=None) + + x = x.permute(1, 0, 2) # (B, seq_len, width) + + # Extract output tokens (excluding class token and latent tokens) + x = x[:, 1:1 + grid_height * grid_width] + x = self.ln_post(x) + + # Reshape to spatial format and project to pixels + x = x.permute(0, 2, 1).reshape(batchsize, self.width, grid_height, grid_width) + x = self.ffn(x.contiguous()) + x = self.rearrange(x) + + # Downsample to target resolution + _, _, org_h, org_w = x.shape + x = self.down_scale(x, input_size=(org_h, org_w), target_size=(height, width)) + x = self.conv_out(x) + + return x diff --git a/vibetoken/modeling/embeddings.py b/vibetoken/modeling/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf3c89fb5ccf4e7a648fb7ef0426475c82ac424 --- /dev/null +++ b/vibetoken/modeling/embeddings.py @@ -0,0 +1,193 @@ +"""Embedding modules for VibeToken. + +Includes positional embeddings with fuzzy interpolation for variable resolutions. +""" + +import math +from typing import Tuple, Any +import collections.abc +from itertools import repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F +import einops +import numpy as np +from einops import rearrange +from torch import Tensor, vmap + + +def to_2tuple(x: Any) -> Tuple: + """Convert input to 2-tuple.""" + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, 2)) + + +class FuzzyEmbedding(nn.Module): + """Fuzzy positional embedding with bilinear interpolation. + + Supports variable-resolution inputs by interpolating from a base + positional embedding grid. + """ + + def __init__( + self, + grid_size: int, + scale: float, + width: int, + apply_fuzzy: bool = False, + ): + """Initialize FuzzyEmbedding. + + Args: + grid_size: Base grid size (must be 1024). + scale: Initialization scale for embeddings. + width: Embedding dimension. + apply_fuzzy: Whether to add noise during training. 
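+
+        Example (illustrative; a width of 768 is just a placeholder value):
+
+            pos = FuzzyEmbedding(grid_size=1024, scale=768 ** -0.5, width=768)
+            emb = pos(grid_height=16, grid_width=24)  # (1 + 16 * 24, 768)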
+ """ + super().__init__() + assert grid_size == 1024, "grid_size must be 1024" + + self.grid_size = grid_size + self.scale = scale + self.width = width + self.apply_fuzzy = apply_fuzzy + + self.positional_embedding = nn.Parameter(scale * torch.randn(grid_size, width)) + self.class_positional_embedding = nn.Parameter(scale * torch.randn(1, width)) + + @torch.amp.autocast('cuda', enabled=False) + def forward( + self, + grid_height: int, + grid_width: int, + train: bool = False, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """Compute positional embeddings for given grid size. + + Args: + grid_height: Target grid height. + grid_width: Target grid width. + train: Whether in training mode (affects fuzzy noise). + dtype: Output dtype. + + Returns: + Positional embeddings of shape (1 + grid_height * grid_width, width). + """ + device = self.positional_embedding.device + + meshx, meshy = torch.meshgrid( + torch.arange(grid_height, device=device), + torch.arange(grid_width, device=device), + indexing='ij' + ) + meshx = meshx.to(dtype) + meshy = meshy.to(dtype) + + # Normalize coordinates to [-1, 1] + meshx = 2 * (meshx / max(grid_height - 1, 1)) - 1 + meshy = 2 * (meshy / max(grid_width - 1, 1)) - 1 + + if self.apply_fuzzy and train: + noise_x = torch.rand_like(meshx) * 0.0008 - 0.0004 + noise_y = torch.rand_like(meshy) * 0.0008 - 0.0004 + meshx = meshx + noise_x + meshy = meshy + noise_y + + grid = torch.stack((meshy, meshx), 2).unsqueeze(0) + + # Reshape positional embedding to 2D grid + base_size = int(math.sqrt(self.grid_size)) + positional_embedding = einops.rearrange( + self.positional_embedding, "(h w) d -> d h w", h=base_size, w=base_size + ) + positional_embedding = positional_embedding.to(dtype).unsqueeze(0) + + # Bilinear interpolation + fuzzy_embedding = F.grid_sample(positional_embedding, grid, align_corners=False) + fuzzy_embedding = fuzzy_embedding.to(dtype) + fuzzy_embedding = einops.rearrange(fuzzy_embedding, "b d h w -> b (h w) d").squeeze(0) + + final_embedding = torch.cat([self.class_positional_embedding, fuzzy_embedding], dim=0) + return final_embedding + + +class FlexiPatchEmbed: + """Flexible patch embedding utilities for variable patch sizes. + + Based on FlexiViT: https://arxiv.org/abs/2212.08013 + """ + + def __init__(self, base_patch_size: int, width: int): + """Initialize FlexiPatchEmbed. + + Args: + base_patch_size: Original patch size of the model. + width: Embedding dimension. + """ + self.base_patch_size = base_patch_size + self.width = width + self.pinvs = {} + + def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor: + """Bilinear resize of 2D tensor.""" + x_resized = F.interpolate( + x[None, None, ...], + shape, + mode="bilinear", + antialias=False, + ) + return x_resized[0, 0, ...] + + def _calculate_pinv( + self, + old_shape: Tuple[int, int], + new_shape: Tuple[int, int], + device: torch.device, + ) -> Tensor: + """Calculate pseudo-inverse of resize matrix.""" + mat = [] + for i in range(np.prod(old_shape)): + basis_vec = torch.zeros(old_shape, device=device) + basis_vec[np.unravel_index(i, old_shape)] = 1.0 + mat.append(self._resize(basis_vec, new_shape).reshape(-1)) + resize_matrix = torch.stack(mat) + return torch.linalg.pinv(resize_matrix) + + def resize_patch_embed( + self, + patch_embed: Tensor, + base_patch_size: Tuple[int, int], + new_patch_size: Tuple[int, int], + ) -> Tensor: + """Resize patch embedding kernel to new patch size. + + Args: + patch_embed: Original patch embedding weight (out_ch, in_ch, H, W). 
+ base_patch_size: Original patch size. + new_patch_size: Target patch size. + + Returns: + Resized patch embedding weight. + """ + if base_patch_size == new_patch_size: + return patch_embed + + if new_patch_size not in self.pinvs: + self.pinvs[new_patch_size] = self._calculate_pinv( + base_patch_size, new_patch_size, device=patch_embed.device + ) + pinv = self.pinvs[new_patch_size] + + def resample_patch_embed(patch_embed: Tensor): + h, w = new_patch_size + original_dtype = patch_embed.dtype + patch_embed_float = patch_embed.float() + resampled_kernel = pinv @ patch_embed_float.reshape(-1) + resampled_kernel = resampled_kernel.to(original_dtype) + return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w) + + v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1) + return v_resample_patch_embed(patch_embed) diff --git a/vibetoken/modeling/encoder.py b/vibetoken/modeling/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d78db0b765557a0d1c97595369c22a4b557bfe --- /dev/null +++ b/vibetoken/modeling/encoder.py @@ -0,0 +1,234 @@ +"""Resolution-aware encoder for VibeToken. + +Vision Transformer-based encoder with flexible patch sizes for variable-resolution inputs. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +from einops import rearrange +from torch import Tensor, vmap +import numpy as np + +from .blocks import ResidualAttentionBlock, _expand_token +from .embeddings import FuzzyEmbedding, to_2tuple + + +class ResolutionEncoder(nn.Module): + """Vision Transformer encoder with flexible resolution support. + + Encodes images into latent tokens using a ViT architecture with + support for variable input resolutions and patch sizes. + """ + + # Model size configurations + MODEL_CONFIGS = { + "small": {"width": 512, "num_layers": 8, "num_heads": 8}, + "base": {"width": 768, "num_layers": 12, "num_heads": 12}, + "large": {"width": 1024, "num_layers": 24, "num_heads": 16}, + } + + def __init__(self, config): + """Initialize ResolutionEncoder. + + Args: + config: OmegaConf config with model parameters. 
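+
+        Note (assumed from usage): the learnable latent query tokens are owned by
+        the parent VibeToken module and passed to `forward` at call time.
+        Illustrative sketch, with `cfg`, `images`, and `vt` as placeholders:
+
+            enc = ResolutionEncoder(cfg)
+            z = enc(images, latent_tokens=vt.latent_tokens)  # (B, token_size, 1, N)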
+ """ + super().__init__() + self.config = config + + # Extract config values + vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model + self.patch_size = getattr(vq_config, 'vit_enc_patch_size', 32) + self.model_size = getattr(vq_config, 'vit_enc_model_size', 'large') + self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256) + self.token_size = getattr(vq_config, 'token_size', 256) + self.is_legacy = getattr(vq_config, 'is_legacy', False) + + # Handle VAE mode (doubles token size for mean+std) + quantize_mode = getattr(vq_config, 'quantize_mode', 'vq') + if quantize_mode == "vae": + self.token_size = self.token_size * 2 + + # Get model dimensions from config + model_cfg = self.MODEL_CONFIGS[self.model_size] + self.width = model_cfg["width"] + self.num_layers = model_cfg["num_layers"] + self.num_heads = model_cfg["num_heads"] + + # Patch embedding + self.patch_embed = nn.Conv2d( + in_channels=3, out_channels=self.width, + kernel_size=self.patch_size, stride=self.patch_size, bias=True + ) + + # Embeddings + scale = self.width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) + self.positional_embedding = FuzzyEmbedding(1024, scale, self.width) + self.latent_token_positional_embedding = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.width) + ) + self.ln_pre = nn.LayerNorm(self.width) + + # Transformer layers + self.transformer = nn.ModuleList([ + ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0) + for _ in range(self.num_layers) + ]) + + # Output projection + self.ln_post = nn.LayerNorm(self.width) + self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True) + + # Cache for pseudo-inverse matrices + self.pinvs = {} + + def _resize(self, x: Tensor, shape: Tuple[int, int]) -> Tensor: + """Bilinear resize of 2D tensor.""" + x_resized = F.interpolate( + x[None, None, ...], shape, mode="bilinear", antialias=False + ) + return x_resized[0, 0, ...] + + def _calculate_pinv( + self, + old_shape: Tuple[int, int], + new_shape: Tuple[int, int], + device: torch.device, + ) -> Tensor: + """Calculate pseudo-inverse of resize matrix for FlexiViT.""" + mat = [] + for i in range(np.prod(old_shape)): + basis_vec = torch.zeros(old_shape, device=device) + basis_vec[np.unravel_index(i, old_shape)] = 1.0 + mat.append(self._resize(basis_vec, new_shape).reshape(-1)) + resize_matrix = torch.stack(mat) + return torch.linalg.pinv(resize_matrix) + + def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: Tuple[int, int]) -> Tensor: + """Resize patch embedding kernel to new patch size (FlexiViT). + + Args: + patch_embed: Original weight tensor (out_ch, in_ch, H, W). + new_patch_size: Target (H, W) patch size. + + Returns: + Resized weight tensor. 
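+
+        Example (illustrative; `encoder` is a placeholder instance, resampling the
+        default 32x32 kernel to 16x16):
+
+            w16 = encoder.resize_patch_embed(encoder.patch_embed.weight, (16, 16))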
+ """ + base_size = to_2tuple(self.patch_size) + if base_size == new_patch_size: + return patch_embed + + if new_patch_size not in self.pinvs: + self.pinvs[new_patch_size] = self._calculate_pinv( + base_size, new_patch_size, device=patch_embed.device + ) + pinv = self.pinvs[new_patch_size] + + def resample_patch_embed(pe: Tensor) -> Tensor: + h, w = new_patch_size + original_dtype = pe.dtype + resampled = pinv @ pe.float().reshape(-1) + return rearrange(resampled.to(original_dtype), "(h w) -> h w", h=h, w=w) + + v_resample = vmap(vmap(resample_patch_embed, 0, 0), 1, 1) + return v_resample(patch_embed) + + def apply_flexivit_patch_embed(self, x: Tensor, target_patch_size: Tuple[int, int]) -> Tensor: + """Apply patch embedding with flexible patch size. + + Args: + x: Input image tensor (B, 3, H, W). + target_patch_size: Target patch size (H, W). + + Returns: + Patch embeddings (B, C, grid_H, grid_W). + """ + patch_size = to_2tuple(target_patch_size) + + if patch_size == to_2tuple(self.patch_size): + weight = self.patch_embed.weight + else: + weight = self.resize_patch_embed(self.patch_embed.weight, patch_size) + + return F.conv2d(x, weight, bias=self.patch_embed.bias, stride=patch_size) + + def forward( + self, + pixel_values: torch.Tensor, + latent_tokens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encode_patch_size: Optional[Tuple[int, int]] = None, + ) -> torch.Tensor: + """Encode images to latent tokens. + + Args: + pixel_values: Input images (B, 3, H, W), values in [0, 1]. + latent_tokens: Learnable latent tokens (num_latent, width). + attention_mask: Optional attention mask. + encode_patch_size: Optional custom patch size for encoding. + + Returns: + Encoded latent features (B, token_size, 1, num_latent). + """ + batch_size, _, H, W = pixel_values.shape + + # Determine patch size + if encode_patch_size is None: + target_patch_size = (self.patch_size, self.patch_size) + elif isinstance(encode_patch_size, int): + target_patch_size = (encode_patch_size, encode_patch_size) + else: + target_patch_size = encode_patch_size + + # Apply flexible patch embedding + x = self.apply_flexivit_patch_embed(pixel_values, target_patch_size) + + # Flatten spatial dimensions + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) # (B, num_patches, width) + + # Add class embedding + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + + # Compute grid dimensions + grid_height = H // target_patch_size[0] + grid_width = W // target_patch_size[1] + + # Add positional embeddings to latent tokens + num_latent = latent_tokens.shape[0] + latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype) + latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)[:num_latent] + + # Add positional embeddings to image patches + x = x + self.positional_embedding(grid_height, grid_width, train=False, dtype=x.dtype) + + # Concatenate image patches and latent tokens + x = torch.cat([x, latent_tokens], dim=1) + + # Pre-norm and reshape for transformer + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # (seq_len, B, width) + + # Apply transformer layers + for layer in self.transformer: + x = layer(x, attention_mask=None) + + x = x.permute(1, 0, 2) # (B, seq_len, width) + + # Extract latent tokens + latent_tokens = x[:, 1 + grid_height * grid_width:] + latent_tokens = self.ln_post(latent_tokens) + + # Reshape and project to token size + if self.is_legacy: + latent_tokens = latent_tokens.reshape(batch_size, self.width, num_latent, 
1) + else: + latent_tokens = latent_tokens.reshape(batch_size, num_latent, self.width, 1).permute(0, 2, 1, 3) + + latent_tokens = self.conv_out(latent_tokens) + latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, num_latent) + + return latent_tokens diff --git a/vibetoken/modeling/vibetoken.py b/vibetoken/modeling/vibetoken.py new file mode 100644 index 0000000000000000000000000000000000000000..f60b76472dba6dc78659bee63baead4625c9ea99 --- /dev/null +++ b/vibetoken/modeling/vibetoken.py @@ -0,0 +1,273 @@ +"""VibeToken model - Main model class for image tokenization. + +Combines encoder, quantizer, and decoder into a unified model for +encoding images to discrete tokens and decoding back to images. +""" + +import torch +import torch.nn as nn +from typing import Dict, Any, Optional, Tuple, Union +from einops import rearrange +from omegaconf import OmegaConf + +from .encoder import ResolutionEncoder +from .decoder import ResolutionDecoder +from ..quantizer import VectorQuantizer, VectorQuantizerMVQ + + +class VibeToken(nn.Module): + """VibeToken image tokenizer model. + + A Vision Transformer-based image tokenizer that encodes images into + discrete tokens and decodes them back to images. Supports multiple + quantization modes (VQ, MVQ, VAE) and variable resolutions. + """ + + def __init__(self, config: Union[dict, Any]): + """Initialize VibeToken model. + + Args: + config: Configuration dict or OmegaConf object with model parameters. + """ + super().__init__() + + if isinstance(config, dict): + config = OmegaConf.create(config) + + self.config = config + + # Get model config + vq_config = config.model.vq_model if hasattr(config.model, 'vq_model') else config.model + + # Quantization mode + self.quantize_mode = getattr(vq_config, 'quantize_mode', 'vq') + if self.quantize_mode not in ["vq", "vae", "mvq"]: + raise ValueError(f"Unsupported quantize mode: {self.quantize_mode}") + + # Build encoder and decoder + self.encoder = ResolutionEncoder(config) + self.decoder = ResolutionDecoder(config) + + # Latent tokens (learnable queries for encoder) + self.num_latent_tokens = getattr(vq_config, 'num_latent_tokens', 256) + scale = self.encoder.width ** -0.5 + self.latent_tokens = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.encoder.width) + ) + + # Build quantizer based on mode + if self.quantize_mode == "vq": + self.quantize = VectorQuantizer( + codebook_size=getattr(vq_config, 'codebook_size', 32768), + token_size=getattr(vq_config, 'token_size', 256), + commitment_cost=getattr(vq_config, 'commitment_cost', 0.25), + use_l2_norm=getattr(vq_config, 'use_l2_norm', False), + ) + elif self.quantize_mode == "mvq": + self.quantize = VectorQuantizerMVQ( + codebook_size=getattr(vq_config, 'codebook_size', 32768), + token_size=getattr(vq_config, 'token_size', 256), + commitment_cost=getattr(vq_config, 'commitment_cost', 0.25), + use_l2_norm=getattr(vq_config, 'use_l2_norm', False), + num_codebooks=getattr(vq_config, 'num_codebooks', 8), + ) + elif self.quantize_mode == "vae": + # VAE mode uses DiagonalGaussianDistribution directly in forward + from ..quantizer.vector_quantizer import DiagonalGaussianDistribution + self.quantize = DiagonalGaussianDistribution + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, module: nn.Module): + """Initialize weights for module.""" + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + 
nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + def encode( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encode_patch_size: Optional[Tuple[int, int]] = None, + length: Optional[int] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Encode images to quantized latent tokens. + + Args: + x: Input images (B, 3, H, W), values in [0, 1]. + attention_mask: Optional attention mask. + encode_patch_size: Optional custom patch size for encoding. + length: Optional token length limit. + + Returns: + z_quantized: Quantized latent features. + result_dict: Dictionary with token indices and optional losses. + """ + # Select latent tokens + if length is not None: + latent_tokens = self.latent_tokens[:length + 1] + else: + latent_tokens = self.latent_tokens + + # Encode through ViT encoder + z = self.encoder( + pixel_values=x, + latent_tokens=latent_tokens, + attention_mask=attention_mask, + encode_patch_size=encode_patch_size, + ) + + # Quantize + if self.quantize_mode in ["vq", "mvq"]: + z_quantized, result_dict = self.quantize(z) + elif self.quantize_mode == "vae": + posteriors = self.quantize(z) + z_quantized = posteriors.sample() + result_dict = {"posteriors": posteriors} + + return z_quantized, result_dict + + def decode( + self, + z_quantized: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + decode_patch_size: Optional[int] = None, + ) -> torch.Tensor: + """Decode quantized features to images. + + Args: + z_quantized: Quantized latent features. + attention_mask: Optional attention mask. + height: Target image height. + width: Target image width. + decode_patch_size: Optional custom patch size for decoding. + + Returns: + Decoded images (B, 3, height, width), values approximately in [0, 1]. + """ + z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype) + decoded = self.decoder( + z_quantized, + attention_mask=attention_mask, + height=height, + width=width, + decode_patch_size=decode_patch_size, + ) + return decoded + + def decode_tokens( + self, + tokens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + decode_patch_size: Optional[int] = None, + ) -> torch.Tensor: + """Decode discrete tokens directly to images. + + This is the main interface used by generators. + + Args: + tokens: Token indices. Shape depends on quantize_mode: + - VQ: (B, seq_len) or (B, 1, seq_len) + - MVQ: (B, num_codebooks, seq_len) or (B, seq_len, 1) + - VAE: Continuous latent (B, C, H, W) + attention_mask: Optional attention mask. + height: Target image height. + width: Target image width. + decode_patch_size: Optional custom patch size for decoding. + + Returns: + Decoded images (B, 3, height, width), values approximately in [0, 1]. 
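+
+        Example (illustrative sketch; shapes assume an ``mvq`` checkpoint with
+        8 codebooks of 4096 entries each and 256 latent tokens):
+            >>> tokens = torch.randint(0, 4096, (1, 8, 256))
+            >>> images = model.decode_tokens(tokens, height=512, width=512)
+            >>> images.shape
+            torch.Size([1, 3, 512, 512])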
+ """ + if self.quantize_mode == "vq": + tokens = tokens.squeeze(1) + batch, seq_len = tokens.shape + z_quantized = self.quantize.get_codebook_entry(tokens.reshape(-1)) + z_quantized = z_quantized.reshape(batch, 1, seq_len, -1) + z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous() + elif self.quantize_mode == "mvq": + z_quantized = self.quantize.get_codebook_entry(tokens) + elif self.quantize_mode == "vae": + z_quantized = tokens + + z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype) + decoded = self.decode( + z_quantized, + attention_mask=attention_mask, + height=height, + width=width, + decode_patch_size=decode_patch_size, + ) + return decoded + + def forward( + self, + x: torch.Tensor, + key_attention_mask: Optional[torch.Tensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Full encode-decode forward pass. + + Args: + x: Input images (B, 3, H, W), values in [0, 1]. + key_attention_mask: Optional attention mask. + height: Target output height (default: same as input). + width: Target output width (default: same as input). + + Returns: + decoded: Reconstructed images. + result_dict: Dictionary with token indices and losses. + """ + if height is None: + _, _, height, width = x.shape + + z_quantized, result_dict = self.encode(x, attention_mask=key_attention_mask) + z_quantized = z_quantized.to(self.decoder.decoder_embed.weight.dtype) + decoded = self.decode(z_quantized, attention_mask=key_attention_mask, height=height, width=width) + + return decoded, result_dict + + @classmethod + def from_pretrained( + cls, + config_path: str, + checkpoint_path: str, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + ) -> "VibeToken": + """Load pretrained model from config and checkpoint. + + Args: + config_path: Path to YAML config file. + checkpoint_path: Path to model checkpoint (.pt or .bin). + device: Device to load model on. + dtype: Optional dtype for model parameters. + + Returns: + Loaded VibeToken model. + """ + config = OmegaConf.load(config_path) + model = cls(config) + + state_dict = torch.load(checkpoint_path, map_location="cpu") + model.load_state_dict(state_dict) + + model.eval() + model.requires_grad_(False) + model.to(device) + + if dtype is not None: + model.to(dtype) + + return model diff --git a/vibetoken/quantizer/__init__.py b/vibetoken/quantizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2044a6a5301a12a660da4e20c79a63a89bf5c0ed --- /dev/null +++ b/vibetoken/quantizer/__init__.py @@ -0,0 +1,6 @@ +"""Quantizer modules for VibeToken.""" + +from .vector_quantizer import VectorQuantizer +from .mvq import VectorQuantizerMVQ + +__all__ = ["VectorQuantizer", "VectorQuantizerMVQ"] diff --git a/vibetoken/quantizer/mvq.py b/vibetoken/quantizer/mvq.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8033572653783cfc25b4174b6228069b365d8a --- /dev/null +++ b/vibetoken/quantizer/mvq.py @@ -0,0 +1,183 @@ +"""Multi-codebook Vector Quantizer (MVQ) for VibeToken. + +Uses multiple independent codebooks for richer discrete representations. +""" + +import torch +import torch.nn as nn +from typing import Tuple, Dict, Any + +from .vector_quantizer import VectorQuantizer + + +class VectorQuantizerMVQ(nn.Module): + """Multi-codebook Vector Quantizer. + + Splits the latent representation into multiple parts, each quantized + by an independent codebook. This allows for more expressive discrete + representations. 
+ """ + + def __init__( + self, + codebook_size: int, + token_size: int, + commitment_cost: float = 0.25, + use_l2_norm: bool = False, + num_codebooks: int = 8, + ): + """Initialize MVQ. + + Args: + codebook_size: Total codebook size (divided among codebooks). + token_size: Total token dimension (divided among codebooks). + commitment_cost: Weight for commitment loss. + use_l2_norm: Whether to L2-normalize embeddings. + num_codebooks: Number of independent codebooks. + """ + super().__init__() + self.num_codebooks = num_codebooks + self.codebooks = nn.ModuleList() + + for _ in range(num_codebooks): + codebook = VectorQuantizer( + codebook_size=codebook_size // num_codebooks, + token_size=token_size // num_codebooks, + commitment_cost=commitment_cost, + use_l2_norm=use_l2_norm, + ) + self.codebooks.append(codebook) + + def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Quantize features using multiple codebooks. + + Args: + features: Input features of shape (B, C, H, W). + + Returns: + z_quantized: Quantized features of shape (B, C, H, W). + result_dict: Dictionary with losses and indices. + """ + latent_features = [] + all_result_dicts = [] + chunk_size = features.shape[1] // self.num_codebooks + splited_features = features.split(chunk_size, dim=1) + + for i, codebook in enumerate(self.codebooks): + latent_feature, result_dict = codebook(splited_features[i].float()) + latent_features.append(latent_feature.to(features.dtype)) + all_result_dicts.append(result_dict) + + # Concatenate quantized features + z_quantized = torch.cat(latent_features, dim=1) + + # Aggregate losses + global_quantizer_loss = sum(rd['quantizer_loss'] for rd in all_result_dicts) / self.num_codebooks + global_commitment_loss = sum(rd['commitment_loss'] for rd in all_result_dicts) / self.num_codebooks + global_codebook_loss = sum(rd['codebook_loss'] for rd in all_result_dicts) / self.num_codebooks + + # Stack indices: shape (B, num_codebooks, H, W) + all_indices = torch.stack([rd['min_encoding_indices'] for rd in all_result_dicts], dim=1) + + result_dict = dict( + quantizer_loss=global_quantizer_loss, + commitment_loss=global_commitment_loss, + codebook_loss=global_codebook_loss, + min_encoding_indices=all_indices + ) + + return z_quantized, result_dict + + def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor: + """Get codebook entries for multi-codebook indices. + + Args: + indices: Tensor of shape: + - (B, num_codebooks): single token per codebook + - (B, num_codebooks, seq_len): sequence of tokens per codebook + - (B, num_codebooks, H, W): 2D spatial tokens per codebook + - (B, seq_len, 1): generator format (single codebook index per position) + + Returns: + z_quantized: Quantized features. 
+ """ + if len(indices.shape) == 2: + # Shape: (B, num_codebooks) - each entry is a token index + latent_features = [] + for i, codebook in enumerate(self.codebooks): + sub_indices = indices[:, i] + latent_feature = codebook.get_codebook_entry(sub_indices) + latent_features.append(latent_feature) + return torch.cat(latent_features, dim=-1) + + elif len(indices.shape) == 3: + batch_size, dim1, dim2 = indices.shape + + # Check if this is (B, num_codebooks, seq_len) or (B, seq_len, 1) + if dim1 == self.num_codebooks: + # Shape: (B, num_codebooks, seq_len) - from encode() + seq_len = dim2 + latent_features = [] + for i, codebook in enumerate(self.codebooks): + sub_indices = indices[:, i, :] # (B, seq_len) + latent_feature = codebook.get_codebook_entry(sub_indices.flatten()) + latent_feature = latent_feature.view(batch_size, seq_len, -1) + latent_features.append(latent_feature) + + # Concatenate along feature dimension: (B, seq_len, C) + z_quantized = torch.cat(latent_features, dim=-1) + # Reshape to (B, C, 1, seq_len) for decoder + z_quantized = z_quantized.permute(0, 2, 1).unsqueeze(2) + return z_quantized + + elif dim2 == 1: + # Shape: (B, seq_len, 1) - common format from generator + indices = indices.squeeze(-1) # (B, seq_len) + seq_len = dim1 + + # For generator format, all codebooks use the same indices + latent_features = [] + for i, codebook in enumerate(self.codebooks): + latent_feature = codebook.get_codebook_entry(indices.flatten()) + latent_feature = latent_feature.view(batch_size, seq_len, -1) + latent_features.append(latent_feature) + + z_quantized = torch.cat(latent_features, dim=-1) # (B, seq_len, C) + z_quantized = z_quantized.permute(0, 2, 1).unsqueeze(2) + return z_quantized + else: + raise ValueError(f"Ambiguous 3D indices shape: {indices.shape}. " + f"Expected (B, {self.num_codebooks}, seq_len) or (B, seq_len, 1)") + + elif len(indices.shape) == 4: + # Shape: (B, num_codebooks, H, W) + batch_size, _, height, width = indices.shape + latent_features = [] + for i, codebook in enumerate(self.codebooks): + sub_indices = indices[:, i] # (B, H, W) + latent_feature = codebook.get_codebook_entry(sub_indices.flatten()) + latent_feature = latent_feature.view(batch_size, height, width, -1) + latent_features.append(latent_feature) + + # Concatenate and permute to (B, C, H, W) + latent_features = torch.cat(latent_features, dim=-1) + return latent_features.permute(0, 3, 1, 2).contiguous() + else: + raise NotImplementedError(f"Unsupported indices shape: {indices.shape}") + + def f_to_idx(self, features: torch.Tensor) -> torch.Tensor: + """Convert features directly to indices without quantization. + + Args: + features: Input features. + + Returns: + indices: Token indices for each codebook. + """ + indices = [] + chunk_size = features.shape[-1] // self.num_codebooks + splited_features = features.split(chunk_size, dim=-1) + for i, codebook in enumerate(self.codebooks): + indices.append(codebook.f_to_idx(splited_features[i])) + indices = torch.stack(indices, dim=1) + return indices diff --git a/vibetoken/quantizer/vector_quantizer.py b/vibetoken/quantizer/vector_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2bfc32961e252d3fa58ce8a8845268b5f9f88e42 --- /dev/null +++ b/vibetoken/quantizer/vector_quantizer.py @@ -0,0 +1,174 @@ +"""Vector Quantizer for VibeToken. + +Simplified for inference-only use. Training-specific features removed. 
+ +Reference: + https://github.com/CompVis/taming-transformers + https://github.com/google-research/magvit +""" + +from typing import Mapping, Text, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.amp import autocast + + +class VectorQuantizer(nn.Module): + """Vector Quantizer module for discrete tokenization. + + Converts continuous latent representations to discrete tokens using + a learned codebook. + """ + + def __init__( + self, + codebook_size: int = 1024, + token_size: int = 256, + commitment_cost: float = 0.25, + use_l2_norm: bool = False, + ): + """Initialize VectorQuantizer. + + Args: + codebook_size: Number of entries in the codebook. + token_size: Dimension of each codebook entry. + commitment_cost: Weight for commitment loss (unused in inference). + use_l2_norm: Whether to L2-normalize embeddings. + """ + super().__init__() + self.codebook_size = codebook_size + self.token_size = token_size + self.commitment_cost = commitment_cost + self.use_l2_norm = use_l2_norm + + self.embedding = nn.Embedding(codebook_size, token_size) + self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) + + @autocast('cuda', enabled=False) + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Quantize input tensor. + + Args: + z: Input tensor of shape (B, C, H, W). + + Returns: + z_quantized: Quantized tensor of shape (B, C, H, W). + result_dict: Dictionary containing min_encoding_indices and losses. + """ + z = z.float() + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = rearrange(z, 'b h w c -> (b h w) c') + + if self.use_l2_norm: + z_flattened = nn.functional.normalize(z_flattened, dim=-1) + embedding = nn.functional.normalize(self.embedding.weight, dim=-1) + else: + embedding = self.embedding.weight + + # Compute distances to codebook entries + d = (torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) - + 2 * torch.einsum('bd,dn->bn', z_flattened, embedding.T)) + + min_encoding_indices = torch.argmin(d, dim=1) + z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) + + if self.use_l2_norm: + z_norm = nn.functional.normalize(z, dim=-1) + else: + z_norm = z + + # Compute losses (for compatibility, not used in inference) + commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z_norm) ** 2) + codebook_loss = torch.mean((z_quantized - z_norm.detach()) ** 2) + loss = commitment_loss + codebook_loss + + # Straight-through estimator: preserve gradients + z_quantized = z_norm + (z_quantized - z_norm).detach() + + # Reshape back to original format + z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous() + + result_dict = dict( + quantizer_loss=loss, + commitment_loss=commitment_loss, + codebook_loss=codebook_loss, + min_encoding_indices=min_encoding_indices.view( + z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3] + ) + ) + + return z_quantized, result_dict + + @autocast('cuda', enabled=False) + def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor: + """Get codebook entries for given indices. + + Args: + indices: Token indices, shape (N,) or (N, vocab_size) for soft indices. + + Returns: + Codebook entries, shape (N, token_size). 
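+
+        For 1D hard indices this is a plain embedding lookup; for 2D soft
+        indices (one weight per codebook entry) it returns a weighted sum of
+        codebook embeddings. With ``use_l2_norm=True`` the result is
+        additionally L2-normalized.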
+ """ + indices = indices.long() + if len(indices.shape) == 1: + z_quantized = self.embedding(indices) + elif len(indices.shape) == 2: + # Soft indices (weighted sum of embeddings) + z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight) + else: + raise NotImplementedError(f"Unsupported indices shape: {indices.shape}") + + if self.use_l2_norm: + z_quantized = nn.functional.normalize(z_quantized, dim=-1) + + return z_quantized + + +class DiagonalGaussianDistribution: + """Diagonal Gaussian distribution for VAE-style quantization. + + Used when quantize_mode='vae' instead of discrete VQ. + """ + + @autocast('cuda', enabled=False) + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + """Initialize Gaussian distribution. + + Args: + parameters: Tensor of shape (B, 2*C, H, W) containing mean and logvar. + deterministic: If True, sample() returns mean (no noise). + """ + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + @autocast('cuda', enabled=False) + def sample(self) -> torch.Tensor: + """Sample from the distribution.""" + x = self.mean.float() + self.std.float() * torch.randn( + self.mean.shape, device=self.parameters.device + ) + return x + + @autocast('cuda', enabled=False) + def mode(self) -> torch.Tensor: + """Return the mode (mean) of the distribution.""" + return self.mean + + @autocast('cuda', enabled=False) + def kl(self) -> torch.Tensor: + """Compute KL divergence from standard Gaussian.""" + if self.deterministic: + return torch.Tensor([0.0]) + return 0.5 * torch.sum( + torch.pow(self.mean.float(), 2) + self.var.float() - 1.0 - self.logvar.float(), + dim=[1, 2] + ) diff --git a/vibetoken/tokenizer.py b/vibetoken/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6435b6c3ff44f506ab18f67f0cdda4bb96a4475a --- /dev/null +++ b/vibetoken/tokenizer.py @@ -0,0 +1,439 @@ +"""High-level VibeToken Tokenizer API. + +Provides a clean, user-friendly interface for image tokenization. +""" + +import os +from pathlib import Path +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from PIL import Image +import numpy as np +from omegaconf import OmegaConf + +from .modeling.vibetoken import VibeToken + + +def center_crop_to_multiple(image: Image.Image, multiple: int = 32) -> Image.Image: + """Center crop image to dimensions divisible by multiple. + + Args: + image: PIL Image + multiple: Dimensions must be divisible by this value (default: 32) + + Returns: + Cropped PIL Image + """ + w, h = image.size + new_w = (w // multiple) * multiple + new_h = (h // multiple) * multiple + + if new_w == w and new_h == h: + return image + + left = (w - new_w) // 2 + top = (h - new_h) // 2 + right = left + new_w + bottom = top + new_h + + return image.crop((left, top, right, bottom)) + + +def get_auto_patch_size(height: int, width: int) -> int: + """Automatically determine optimal patch size based on resolution. 
+ + Rules: + - <= 256x256: patch size 8 + - <= 512x512: patch size 16 + - > 512x512: patch size 32 + + Args: + height: Image height + width: Image width + + Returns: + Optimal patch size as int + """ + # Check <= 256x256 + if height <= 256 and width <= 256: + return 8 + + # Check <= 512x512 + if height <= 512 and width <= 512: + return 16 + + # Larger resolutions + return 32 + + +def auto_preprocess_image(image: Image.Image, verbose: bool = True) -> Tuple[Image.Image, int, dict]: + """Automatically preprocess image: center crop to divisible by 32 and determine optimal patch size. + + Args: + image: PIL Image + verbose: Whether to print preprocessing info + + Returns: + image: Preprocessed PIL Image + patch_size: Optimal patch size + info: Dictionary with preprocessing details + """ + original_size = image.size # (W, H) + + # Center crop to dimensions divisible by 32 + image = center_crop_to_multiple(image, multiple=32) + cropped_size = image.size # (W, H) + + # Get optimal patch size + patch_size = get_auto_patch_size(cropped_size[1], cropped_size[0]) # (H, W) + + info = { + "original_size": original_size, + "cropped_size": cropped_size, + "patch_size": patch_size, + "was_cropped": original_size != cropped_size, + } + + if verbose: + if info["was_cropped"]: + print(f"Center cropped: {original_size[0]}x{original_size[1]} -> {cropped_size[0]}x{cropped_size[1]} (divisible by 32)") + print(f"Auto patch size: {patch_size}x{patch_size}") + + return image, patch_size, info + + +class VibeTokenTokenizer: + """High-level API for VibeToken image tokenization. + + Provides simple encode/decode methods for converting images to/from + discrete tokens. + + Example: + >>> tokenizer = VibeTokenTokenizer.from_config("configs/vibetoken_ll.yaml", "model.bin") + >>> tokens = tokenizer.encode(images) # (B, num_codebooks, seq_len) + >>> reconstructed = tokenizer.decode(tokens, height=512, width=512) + """ + + def __init__( + self, + model: VibeToken, + device: str = "cuda", + ): + """Initialize tokenizer with a VibeToken model. + + Args: + model: Loaded VibeToken model. + device: Device for inference. + """ + self.model = model + self.device = device + self.num_codebooks = getattr(model.config.model.vq_model, 'num_codebooks', 1) \ + if hasattr(model.config.model, 'vq_model') else 1 + + @classmethod + def from_config( + cls, + config_path: str, + checkpoint_path: str, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + ) -> "VibeTokenTokenizer": + """Create tokenizer from config and checkpoint files. + + Args: + config_path: Path to YAML config file. + checkpoint_path: Path to model checkpoint. + device: Device for inference. + dtype: Optional dtype for model. + + Returns: + VibeTokenTokenizer instance. + """ + model = VibeToken.from_pretrained( + config_path=config_path, + checkpoint_path=checkpoint_path, + device=device, + dtype=dtype, + ) + return cls(model=model, device=device) + + @classmethod + def from_pretrained( + cls, + model_name: str, + checkpoint_path: str, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + ) -> "VibeTokenTokenizer": + """Load tokenizer by model name. + + Args: + model_name: Model variant name ("ll" for Large-Large, "sl" for Small-Large). + checkpoint_path: Path to model checkpoint. + device: Device for inference. + dtype: Optional dtype for model. + + Returns: + VibeTokenTokenizer instance. 
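+
+        Example (illustrative; the checkpoint path is a placeholder):
+            >>> tokenizer = VibeTokenTokenizer.from_pretrained(
+            ...     "ll", checkpoint_path="./checkpoints/VibeToken_LL.bin"
+            ... )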
+ """ + # Get config path based on model name + package_dir = Path(__file__).parent.parent + config_map = { + "ll": "vibetoken_ll.yaml", + "sl": "vibetoken_sl.yaml", + } + + if model_name not in config_map: + raise ValueError(f"Unknown model name: {model_name}. Choose from: {list(config_map.keys())}") + + config_path = package_dir / "configs" / config_map[model_name] + + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + return cls.from_config( + config_path=str(config_path), + checkpoint_path=checkpoint_path, + device=device, + dtype=dtype, + ) + + def preprocess( + self, + images: Union[torch.Tensor, np.ndarray, Image.Image, list], + ) -> torch.Tensor: + """Preprocess images for encoding. + + Args: + images: Input images. Can be: + - torch.Tensor (B, 3, H, W) or (3, H, W), values in [0, 1] + - numpy array (B, H, W, 3) or (H, W, 3), values in [0, 255] + - PIL Image + - List of PIL Images + + Returns: + Preprocessed tensor (B, 3, H, W), values in [0, 1]. + """ + if isinstance(images, Image.Image): + images = [images] + + if isinstance(images, list): + # List of PIL images + tensors = [] + for img in images: + if isinstance(img, Image.Image): + arr = np.array(img.convert("RGB")) + tensor = torch.from_numpy(arr).float().permute(2, 0, 1) / 255.0 + tensors.append(tensor) + else: + raise ValueError(f"Unsupported image type in list: {type(img)}") + images = torch.stack(tensors) + elif isinstance(images, np.ndarray): + if images.ndim == 3: + images = images[None] # Add batch dim + # Assume (B, H, W, C) format + images = torch.from_numpy(images).float().permute(0, 3, 1, 2) / 255.0 + elif isinstance(images, torch.Tensor): + if images.ndim == 3: + images = images.unsqueeze(0) # Add batch dim + # Assume already in [0, 1] range + if images.max() > 1.0: + images = images / 255.0 + else: + raise ValueError(f"Unsupported image type: {type(images)}") + + return images.to(self.device) + + @torch.no_grad() + def encode( + self, + images: Union[torch.Tensor, np.ndarray, Image.Image, list], + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + """Encode images to discrete tokens. + + Args: + images: Input images (see preprocess for supported formats). + patch_size: Optional encoding patch size. Can be int or tuple (H, W). + return_dict: Whether to return result dictionary with additional info. + + Returns: + tokens: Token indices. Shape depends on quantize_mode: + - VQ: (B, seq_len) + - MVQ: (B, num_codebooks, seq_len) + result_dict (optional): Dictionary with additional encoding info. 
+ """ + images = self.preprocess(images) + + # Convert int to tuple if needed + if isinstance(patch_size, int): + encode_ps = (patch_size, patch_size) + else: + encode_ps = patch_size # Already tuple or None + _, result_dict = self.model.encode(images, encode_patch_size=encode_ps) + + # Extract token indices + tokens = result_dict['min_encoding_indices'] + + # Reshape based on quantize mode + if self.model.quantize_mode == "mvq": + # Already in (B, num_codebooks, H, W) format, flatten spatial + B, num_cb, H, W = tokens.shape + tokens = tokens.reshape(B, num_cb, H * W) + else: + # VQ mode: (B, H, W) -> (B, H*W) + B, H, W = tokens.shape + tokens = tokens.reshape(B, H * W) + + if return_dict: + return tokens, result_dict + return tokens + + @torch.no_grad() + def decode( + self, + tokens: torch.Tensor, + height: int, + width: int, + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + clamp: bool = True, + ) -> torch.Tensor: + """Decode tokens back to images. + + Args: + tokens: Token indices from encode(). + height: Target image height. + width: Target image width. + patch_size: Optional decoding patch size. Can be int or tuple (H, W). + clamp: Whether to clamp output to [0, 1]. + + Returns: + Decoded images (B, 3, height, width), values in [0, 1]. + """ + tokens = tokens.to(self.device) + + # Convert int to tuple if needed + if isinstance(patch_size, int): + decode_ps = (patch_size, patch_size) + else: + decode_ps = patch_size # Already tuple or None + + decoded = self.model.decode_tokens( + tokens, + height=height, + width=width, + decode_patch_size=decode_ps, + ) + + if clamp: + decoded = torch.clamp(decoded, 0.0, 1.0) + + return decoded + + @torch.no_grad() + def reconstruct( + self, + images: Union[torch.Tensor, np.ndarray, Image.Image, list], + encode_patch_size: Optional[Union[int, Tuple[int, int]]] = None, + decode_patch_size: Optional[Union[int, Tuple[int, int]]] = None, + target_height: Optional[int] = None, + target_width: Optional[int] = None, + ) -> torch.Tensor: + """Encode and decode images (full reconstruction). + + Args: + images: Input images. + encode_patch_size: Optional encoding patch size. Can be int or tuple (H, W). + decode_patch_size: Optional decoding patch size. Can be int or tuple (H, W). + target_height: Target output height (default: same as input). + target_width: Target output width (default: same as input). + + Returns: + Reconstructed images (B, 3, H, W), values in [0, 1]. + """ + images = self.preprocess(images) + B, C, H, W = images.shape + + if target_height is None: + target_height = H + if target_width is None: + target_width = W + + # Encode with optional patch size + tokens = self.encode(images, patch_size=encode_patch_size) + + # Decode with optional patch size + decoded = self.decode( + tokens, + height=target_height, + width=target_width, + patch_size=decode_patch_size, + ) + + return decoded + + def to_pil(self, images: torch.Tensor) -> list: + """Convert tensor images to PIL Images. + + Args: + images: Tensor (B, 3, H, W), values in [0, 1]. + + Returns: + List of PIL Images. + """ + images = torch.clamp(images * 255, 0, 255) + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + return [Image.fromarray(img) for img in images] + + @torch.no_grad() + def auto_reconstruct( + self, + image: Image.Image, + verbose: bool = True, + ) -> Tuple[torch.Tensor, dict]: + """Automatically reconstruct image with optimal settings. + + This method: + 1. Center crops input to dimensions divisible by 32 + 2. 
Auto-determines optimal patch sizes based on resolution + 3. Returns output at same size as (cropped) input + + Args: + image: Input PIL Image + verbose: Whether to print auto-selected settings + + Returns: + reconstructed: Reconstructed image tensor (1, 3, H, W) + info: Dictionary with auto-selected settings + """ + # Use centralized auto_preprocess_image + original_size = image.size + image, patch_size, preprocess_info = auto_preprocess_image(image, verbose=verbose) + width, height = preprocess_info["cropped_size"] + + # Encode and decode + tokens = self.encode(image, patch_size=patch_size) + reconstructed = self.decode(tokens, height=height, width=width, patch_size=patch_size) + + info = { + "original_size": original_size, + "cropped_size": (width, height), + "patch_size": patch_size, + "token_shape": tokens.shape, + } + + return reconstructed, info + + @property + def codebook_size(self) -> int: + """Get total codebook size.""" + vq_config = self.model.config.model.vq_model if hasattr(self.model.config.model, 'vq_model') else self.model.config.model + return getattr(vq_config, 'codebook_size', 32768) + + @property + def num_latent_tokens(self) -> int: + """Get number of latent tokens.""" + return self.model.num_latent_tokens diff --git a/vibetokengen/__init__.py b/vibetokengen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vibetokengen/generate.py b/vibetokengen/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..1b61934df10639adc238bdd30dff9f9b9b0986d7 --- /dev/null +++ b/vibetokengen/generate.py @@ -0,0 +1,219 @@ +# Modified from: +# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py +# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py +import torch +import torch.nn as nn +from torch.nn import functional as F +import torch._dynamo.config +import torch._inductor.config +import copy + + +# torch._inductor.config.coordinate_descent_tuning = True +# torch._inductor.config.triton.unique_kernel_names = True +# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + +### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html +def top_k_top_p_filtering( + logits, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True): + logits = logits[:, -1, :] / max(temperature, 1e-5) + if top_k > 0 or top_p < 1.0: + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + probs = F.softmax(logits, dim=-1) + if sample_logits: + idx = torch.multinomial(probs, num_samples=1) + else: + _, idx = torch.topk(probs, k=1, dim=-1) + return idx, probs + + +def logits_to_probs(logits, temperature: float = 1.0, top_p: float = 1.0, top_k: int = None, **kwargs): + logits = logits / max(temperature, 1e-5) + if top_k > 0 or top_p < 1.0: + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, guidance_scale: float, guidance_scale_pow: float, target_h: float, target_w: float, **sampling_kwargs): + scale_pow = torch.ones((1), device=cond_idx.device) * guidance_scale_pow + scale_step = (1 + torch.cos( + ((0 / 256) ** scale_pow) * torch.pi)) / 2 + cfg_scale = (guidance_scale - 1) * scale_step + 1 + + next_token = model(None, cond_idx, input_pos, target_h=target_h, target_w=target_w) + model.setup_head_caches( + next_token.shape[0], + model.num_codebooks, + dtype=next_token.dtype, + device=next_token.device + ) + indices = [] + bs = next_token.shape[0] + for i in range(model.num_codebooks): + start_pos = torch.tensor([i], dtype=torch.int) + mask = model.output.head_causal_mask[:bs, None, start_pos] + if i == 0: + logits = model.output.forward_test(next_token, input_pos=start_pos, mask=mask) + else: + if guidance_scale > 0.0: + pred_idx = torch.cat([pred_idx, pred_idx]) + logits = model.output.forward_test(next_token, idx=pred_idx, input_pos=start_pos, mask=mask) + if guidance_scale > 0.0: + cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0) + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + pred_idx = sample(logits, **sampling_kwargs)[0] + else: + pred_idx = sample(logits, **sampling_kwargs)[0] + indices.append(pred_idx) + indices = torch.stack(indices, dim=1) + 
return indices + + +def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, guidance_scale: float, guidance_scale_pow: float, cfg_flag: bool, target_h: float, target_w: float, token_idx: int, **sampling_kwargs): + assert input_pos.shape[-1] == 1 + scale_pow = torch.ones((1), device=x.device) * guidance_scale_pow + scale_step = (1 + torch.cos( + ((token_idx / 256) ** scale_pow) * torch.pi)) / 2 + cfg_scale = (guidance_scale - 1) * scale_step + 1 + + if guidance_scale > 0.0: + x_combined = torch.cat([x, x]) + next_token = model(x_combined, cond_idx=None, input_pos=input_pos, target_h=target_h, target_w=target_w) + else: + next_token = model(x, cond_idx=None, input_pos=input_pos, target_h=target_h, target_w=target_w) + model.setup_head_caches( + next_token.shape[0], + model.num_codebooks, + dtype=next_token.dtype, + device=next_token.device + ) + indices = [] + bs = next_token.shape[0] + for i in range(model.num_codebooks): + start_pos = torch.tensor([i], dtype=torch.int) + mask = model.output.head_causal_mask[:bs, None, start_pos] + if i == 0: + logits = model.output.forward_test(next_token, input_pos=start_pos, mask=mask) + else: + if guidance_scale > 0.0: + pred_idx = torch.cat([pred_idx, pred_idx]) + logits = model.output.forward_test(next_token, idx=pred_idx, input_pos=start_pos, mask=mask) + if guidance_scale > 0.0: + cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0) + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + pred_idx = sample(logits, **sampling_kwargs)[0] + else: + pred_idx = sample(logits, **sampling_kwargs)[0] + indices.append(pred_idx) + indices = torch.stack(indices, dim=1) + return indices + + +def decode_n_tokens( + model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, + guidance_scale: float, guidance_scale_pow: float, cfg_interval: int, target_h: float, target_w: float, **sampling_kwargs + ): + new_tokens = [] + cfg_flag = True + for i in range(num_new_tokens): + # Actually better for Inductor to codegen attention here + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + next_token = decode_one_token(model, cur_token, input_pos, guidance_scale, guidance_scale_pow, cfg_flag, target_h, target_w, i+1, **sampling_kwargs) + input_pos += 1 + new_tokens.append(next_token.clone()) + cur_token = next_token + return new_tokens + + +@torch.no_grad() +def generate(model, cond, max_new_tokens, num_codebooks, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, + target_h=None, target_w=None, guidance_scale=16.0, guidance_scale_pow=2.5, **sampling_kwargs): + + if model.model_type == 'c2i': + if guidance_scale > 0.0: + cond_null = torch.ones_like(cond) * model.num_classes + cond_combined = torch.cat([cond, cond_null]) + else: + cond_combined = cond + T = 1 + else: + raise Exception("please check model type") + + T_new = T + max_new_tokens + max_seq_length = T_new + max_batch_size = cond.shape[0] + + device = cond.device + with torch.device(device): + max_batch_size_cfg = max_batch_size * 2 if guidance_scale > 0.0 else max_batch_size + model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, + dtype=model.tok_embeddings.codebooks[0].weight.dtype) + + if emb_masks is not None: + assert emb_masks.shape[0] == max_batch_size + assert emb_masks.shape[-1] == T + if guidance_scale > 0.0: + model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1) + else: + model.causal_mask[:, :, :T] = 
model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1) + + eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device) + model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix + + # create an empty tensor of the expected final shape and fill in the current tokens + seq = torch.empty((max_batch_size, num_codebooks, T_new), dtype=torch.int, device=device) + + input_pos = torch.arange(0, T, device=device) + next_token = prefill(model, cond_combined, input_pos, guidance_scale, guidance_scale_pow, target_h, target_w, **sampling_kwargs) + seq[:, :, T:T+1] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + generated_tokens = decode_n_tokens(model, next_token, input_pos, max_new_tokens - 1, guidance_scale, guidance_scale_pow, cfg_interval, target_h, target_w, + **sampling_kwargs) + seq[:, :, T+1:] = torch.cat(generated_tokens, dim=-1) + + return seq[:, :, T:] + + diff --git a/vibetokengen/model.py b/vibetokengen/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f45c58d6d74f0428a26e6e6b3ac241db270b6191 --- /dev/null +++ b/vibetokengen/model.py @@ -0,0 +1,777 @@ +# Modified from: +# VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py +# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py +# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py +# llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py +# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py +from dataclasses import dataclass +from typing import Optional, List + +import math +import torch +import torch.nn as nn +from torch.nn import functional as F +from contextlib import nullcontext +import logging + +logger = logging.getLogger(__name__) + +def find_multiple(n: int, k: int): + if n % k == 0: + return n + return n + k - (n % k) + +class DropPath(torch.nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
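+
+    Note: ``forward`` delegates to a module-level ``drop_path`` helper (as in
+    timm). Blocks are constructed with ``nn.Identity`` instead whenever the
+    drop rate is 0, so this path is inactive for inference.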
+ """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layer: int = 32 + n_head: int = 32 + n_kv_head: Optional[int] = None + n_output_layer: int = 1 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + rope_base: float = 10000 + norm_eps: float = 1e-5 + initializer_range: float = 0.02 + + token_dropout_p: float = 0.1 + attn_dropout_p: float = 0.0 + resid_dropout_p: float = 0.1 + ffn_dropout_p: float = 0.1 + drop_path_rate: float = 0.0 + + num_codebooks: int = 8 + num_classes: int = 1000 + caption_dim: int = 2048 + class_dropout_prob: float = 0.1 + model_type: str = 'c2i' + extra_layers: str = None + capping: float = 50.0 + + vocab_size: int = 16384 + cls_token_num: int = 1 + block_size: int = 256 + max_batch_size: int = 32 + max_seq_len: int = 2048 + + # Dynamic V1 + dynamic: bool = True + extra_cls_token_num: int = 2 + + + +################################################################################# +# Capping for Attention Softmax # +################################################################################# + +def custom_scaled_dot_product_attention( + query, key, value, + attn_mask=None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale=None, + enable_gqa: bool = False, + logit_cap: float = 50.0, # set to None to disable tanh-capping +) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = (1 / math.sqrt(query.size(-1))) if scale is None else scale + + # additive bias for masking (same dtype/device as query) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + + if is_causal: + assert attn_mask is None + causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril() + attn_bias = attn_bias.masked_fill(~causal, float("-inf")) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf")) + else: + attn_bias = attn_bias + attn_mask.to(dtype=query.dtype, device=query.device) + + if enable_gqa: + repeat = query.size(-3) // key.size(-3) + key = key.repeat_interleave(repeat, dim=-3) + value = value.repeat_interleave(repeat, dim=-3) + + # logits (pre-softmax) + logits = (query @ key.transpose(-2, -1)) * scale_factor + logits = logits + attn_bias + + # --- tanh cap: softmax(tanh(logits/cap)*cap), but keep -inf from masks --- + if logit_cap is not None: + finite = torch.isfinite(logits) + logits = torch.where( + finite, + torch.tanh(logits / logit_cap) * logit_cap, + logits, # keeps -inf (and +inf if any) + ) + + # softmax in float32 for stability, then back to value dtype + attn_weight = torch.softmax(logits.float(), dim=-1).to(value.dtype) + + if dropout_p: + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + + return attn_weight @ value + + +def custom_scaled_dot_product_attention_v2( + query, key, value, + attn_mask=None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale=None, + enable_gqa: bool = False, + logit_cap: float = 50.0, # set to None to disable tanh-capping + force_true_fp32_qk: bool = True, # disable TF32 for QK^T matmul + force_true_fp32_out: bool = False, # set True if you also want (P@V) 
in true FP32 +) -> torch.Tensor: + """ + Shapes follow (..., H, L, D). Works with bf16/fp16 autocast; stabilizes numerics by + doing QK^T + softmax in real FP32 (not TF32), then casting back. + """ + device = query.device + q_dtype = query.dtype + v_dtype = value.dtype + + L, S = query.size(-2), key.size(-2) + scale_factor = (1.0 / math.sqrt(query.size(-1))) if scale is None else float(scale) + + # --- broadcastable FP32 bias (safer for -inf) --- + attn_bias = torch.zeros(L, S, dtype=torch.float32, device=device) + + if is_causal: + assert attn_mask is None + causal = torch.ones(L, S, dtype=torch.bool, device=device).tril() + attn_bias = attn_bias.masked_fill(~causal, float("-inf")) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf")) + else: + attn_bias = attn_bias + attn_mask.to(dtype=torch.float32, device=device) + + if enable_gqa: + repeat = query.size(-3) // key.size(-3) + key = key.repeat_interleave(repeat, dim=-3) + value = value.repeat_interleave(repeat, dim=-3) + + # --- QK^T in true FP32 (optionally disable TF32 just here) --- + q32 = query.to(torch.float32) + k32T = key.to(torch.float32).transpose(-2, -1) + tf32_ctx = torch.backends.cuda.matmul.allow_tf32(False) if force_true_fp32_qk else nullcontext() + with tf32_ctx: + logits = (q32 @ k32T) * scale_factor # FP32 matmul + + # add FP32 bias/mask + logits = logits + attn_bias # still FP32 + + # optional tanh-cap, preserving -inf from masks + if logit_cap is not None: + finite = torch.isfinite(logits) + logits = torch.where(finite, + torch.tanh(logits / logit_cap) * logit_cap, + logits) + + # --- softmax in FP32, then cast --- + attn_weight = torch.softmax(logits, dim=-1) # FP32 + attn_weight = attn_weight.to(v_dtype) # back to value dtype for matmul + + if dropout_p: + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + + # --- (P @ V): you can keep speed (bf16) or force true FP32 here too --- + if force_true_fp32_out: + P32 = attn_weight.to(torch.float32) + V32 = value.to(torch.float32) + with torch.backends.cuda.matmul.allow_tf32(False): + out = P32 @ V32 + return out.to(v_dtype) + else: + return attn_weight @ value + + +################################################################################# +# Embedding Layers for Class Labels # +################################################################################# +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
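+
+        For example, ``force_drop_ids=torch.tensor([1, 0])`` replaces the first
+        label in the batch with the unconditional class index ``num_classes``
+        and leaves the second unchanged.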
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels).unsqueeze(1) + return embeddings + + +################################################################################# +# Embedding Layers for Text Feature # +################################################################################# +class CaptionEmbedder(nn.Module): + """ + Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120): + super().__init__() + self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size) + self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + embeddings = self.cap_proj(caption) + return embeddings + + +class TokenEmbedder(nn.Module): + def __init__(self, vocab_size, hidden_size, num_codebooks): + super().__init__() + self.num_codebooks = num_codebooks + self.codebooks = nn.ModuleList() + for _ in range(num_codebooks): + codebook = nn.Embedding(vocab_size // num_codebooks, hidden_size) + self.codebooks.append(codebook) + + def forward(self, indices): + assert indices.shape[1] == self.num_codebooks + latent_features = [] + for i in range(self.num_codebooks): + latent_feature = self.codebooks[i](indices[:, i]) + latent_features.append(latent_feature) + latent_features = torch.stack(latent_features).sum(dim=0) + return latent_features + + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features, out_features): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=False) + self.act = nn.GELU(approximate='tanh') + self.fc2 = nn.Linear(hidden_features, out_features, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +################################################################################# +# GPT Model # +################################################################################# +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + 
+class FeedForward(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + hidden_dim = 4 * config.dim + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if config.ffn_dim_multiplier is not None: + hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) + hidden_dim = find_multiple(hidden_dim, config.multiple_of) + + self.w1 = nn.Linear(config.dim, hidden_dim, bias=False) + self.w3 = nn.Linear(config.dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, config.dim, bias=False) + self.ffn_dropout = nn.Dropout(config.ffn_dropout_p) + + def forward(self, x): + return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) + + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype): + super().__init__() + cache_shape = (max_batch_size, n_head, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + return k_out, v_out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + self.dim = config.dim + self.head_dim = config.dim // config.n_head + self.n_head = config.n_head + self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head + total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim + + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + # regularization + self.attn_dropout_p = config.attn_dropout_p + self.resid_dropout = nn.Dropout(config.resid_dropout_p) + + self.capping = config.capping if 'cap' in config.extra_layers else None + + # Add QK or QKV normalization based on config + if 'QKV' in config.extra_layers: + self.q_norm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.v_norm = RMSNorm(self.head_dim, eps=config.norm_eps) + elif 'QK' in config.extra_layers: + self.q_norm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, x: torch.Tensor, freqs_cis: torch.Tensor = None, + input_pos: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + full_attn: Optional[bool] = False, + ): + bsz, seqlen, _ = x.shape + kv_size = self.n_kv_head * self.head_dim + xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim) + + # Apply QK or QKV normalization if configured + if hasattr(self, 'q_norm'): + xq = self.q_norm(xq) + xk = self.k_norm(xk) + if hasattr(self, 'v_norm'): + xv = self.v_norm(xv) + + if freqs_cis is not None: + xq = apply_rotary_emb(xq, freqs_cis) + xk = apply_rotary_emb(xk, freqs_cis) + + xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) + + if self.kv_cache is not None: + keys, values = self.kv_cache.update(input_pos, xk, xv) + else: + keys, values = xk, xv + keys = keys.repeat_interleave(self.n_head // 
self.n_kv_head, dim=1) + values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1) + + output = custom_scaled_dot_product_attention( + xq, keys, values, + attn_mask=mask, + is_causal=True if (mask is None and full_attn == False) else False, # is_causal=False is for KV cache + dropout_p=self.attn_dropout_p if self.training else 0, + logit_cap=self.capping if self.capping is not None else None + ) + + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + output = self.resid_dropout(self.wo(output)) + return output + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs, drop_path: float): + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None, + full_attn=False): + h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask, full_attn)) + out = h + self.drop_path(self.feed_forward(self.ffn_norm(h))) + return out + + +class AutoRegressiveHead(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_codebooks = config.num_codebooks + self.sub_vocab_size = config.vocab_size // config.num_codebooks + + self.codebooks = nn.ModuleList() + for _ in range(self.num_codebooks - 1): + codebook = nn.Embedding(self.sub_vocab_size, config.dim) + self.codebooks.append(codebook) + + self.layers = torch.nn.ModuleList() + for _ in range(config.n_output_layer): + self.layers.append(TransformerBlock(config, drop_path=0.)) + + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.linear_head = nn.Linear(config.dim, self.sub_vocab_size) + + def forward_train(self, base_tokens, targets): + K = targets.shape[1] + B, L, C = base_tokens.shape + base_tokens = base_tokens.reshape(B * L, 1, C) + targets = targets.permute(0, 2, 1).reshape(B * L, K)[:, :-1] + index_embeddings = [] + for i in range(self.num_codebooks - 1): + index_embed = self.codebooks[i](targets[:, i]) + index_embeddings.append(index_embed) + index_embeddings = torch.stack(index_embeddings, dim=1) + h = torch.cat((base_tokens, index_embeddings), dim=1) + for layer in self.layers: + h = layer(h, freqs_cis=None, start_pos=None, mask=None) + h = self.norm(h) + logits = self.linear_head(h) + logits = logits.reshape(B, L, K, -1).permute(0, 2, 1, 3) + return logits + + def forward_test(self, base_tokens, idx=None, input_pos=None, mask=None): + if idx is not None: + h = self.codebooks[input_pos - 1](idx) + else: + h = base_tokens + for layer in self.layers: + h = layer(h, freqs_cis=None, start_pos=input_pos, mask=mask) + h = self.norm(h) + logits = self.linear_head(h) + return logits + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.n_layer = config.n_layer + self.block_size = config.block_size + self.num_classes = config.num_classes + self.model_type = config.model_type + self.num_codebooks = config.num_codebooks + self.cls_token_num = config.cls_token_num + self.extra_layers = config.extra_layers + self.dynamic = config.dynamic + self.extra_cls_token_num = config.extra_cls_token_num + + if self.extra_layers not in [None, 'QK', 'QKV', 'FC', 'cap', 'clip', 'QK_cap', 'QKV_cap', 
+class Transformer(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.n_layer = config.n_layer
+        self.block_size = config.block_size
+        self.num_classes = config.num_classes
+        self.model_type = config.model_type
+        self.num_codebooks = config.num_codebooks
+        self.cls_token_num = config.cls_token_num
+        self.extra_layers = config.extra_layers
+        self.dynamic = config.dynamic
+        self.extra_cls_token_num = config.extra_cls_token_num
+
+        if self.extra_layers not in [None, 'QK', 'QKV', 'FC', 'cap', 'clip', 'QK_cap', 'QKV_cap',
+                                     'QK_clip', 'QKV_clip', 'QK_FC_cap', 'QKV_FC_cap', 'QK_FC_clip', 'QKV_FC_clip']:
+            raise ValueError(f"Invalid extra layers: {self.extra_layers}")
+
+        if self.model_type == 'c2i':
+            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
+        elif self.model_type == 't2i':
+            self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
+        else:
+            raise ValueError(f"Unknown model type: {self.model_type} (expected 'c2i' or 't2i')")
+        # self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+        self.tok_embeddings = TokenEmbedder(config.vocab_size, config.dim, config.num_codebooks)
+        self.tok_dropout = nn.Dropout(config.token_dropout_p)
+
+        # transformer blocks
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(config.n_layer):
+            self.layers.append(TransformerBlock(config, dpr[layer_id]))
+
+        # dynamic resolution: embed the target height and width as extra conditioning
+        if self.dynamic:
+            self.target_h_embedder = nn.Linear(1, config.dim, bias=True)
+            self.target_w_embedder = nn.Linear(1, config.dim, bias=True)
+        else:
+            raise ValueError("Only dynamic (resolution-conditioned) mode is supported")
+
+        # output layer
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.output = AutoRegressiveHead(config)
+
+        # 2d rotary pos embedding
+        # grid_size = int(self.block_size ** 0.5)
+        # assert grid_size * grid_size == self.block_size
+        # self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head,
+        #                                          self.config.rope_base, self.cls_token_num)
+        self.freqs_cis = precompute_freqs_cis(self.block_size, self.config.dim // self.config.n_head,
+                                              self.config.rope_base, self.cls_token_num)
+
+        # KVCache
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        # Initialize nn.Linear and nn.Embedding
+        self.apply(self._init_weights)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.output.linear_head.weight, 0)
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+
+    def setup_caches(self, max_batch_size, max_seq_length, dtype):
+        # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+        #     return
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        for b in self.layers:
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)
+
+        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
+        # grid_size = int(self.config.block_size ** 0.5)
+        # assert grid_size * grid_size == self.block_size
+        # self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head,
+        #                                          self.config.rope_base, self.cls_token_num)
+        self.freqs_cis = precompute_freqs_cis(self.block_size, self.config.dim // self.config.n_head,
+                                              self.config.rope_base, self.cls_token_num)
+
+    def setup_head_caches(self, max_batch_size, max_seq_length, dtype, device):
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        for b in self.output.layers:
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype).to(device)
+        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
+        self.output.head_causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
+
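+    # Illustrative note on the forward() arguments below (shapes inferred from this file,
+    # not a documented contract): `target_h` and `target_w` are per-sample float tensors of
+    # shape (B, 1); each passes through an nn.Linear(1, dim) and is added to the class
+    # embedding, which is how the model is conditioned on the requested output resolution.
+    # During KV-cache decoding, `idx` carries only the newly sampled token(s) and
+    # `input_pos` gives their absolute positions in the sequence.
+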
+    def forward(
+        self,
+        idx: torch.Tensor,
+        cond_idx: torch.Tensor,  # cond_idx_or_embed
+        input_pos: Optional[torch.Tensor] = None,
+        targets: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        valid: Optional[torch.Tensor] = None,
+        target_h: Optional[torch.Tensor] = None,
+        target_w: Optional[torch.Tensor] = None,
+    ):
+        if idx is not None and cond_idx is not None:  # training or naive inference
+            cond_embeddings = self.cls_embedding(cond_idx, train=self.training)[:, :self.cls_token_num]
+            token_embeddings = self.tok_embeddings(idx)
+
+            if self.dynamic:
+                target_h_embed = self.target_h_embedder(target_h).unsqueeze(1)
+                target_w_embed = self.target_w_embedder(target_w).unsqueeze(1)
+
+                cond_embeddings = cond_embeddings + target_h_embed + target_w_embed
+
+            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
+            h = self.tok_dropout(token_embeddings)
+            self.freqs_cis = self.freqs_cis.to(h.device)
+
+        else:
+            if cond_idx is not None:  # prefill in inference
+                token_embeddings = self.cls_embedding(cond_idx, train=self.training)[:, :self.cls_token_num]
+                if self.dynamic:
+                    target_h_embed = self.target_h_embedder(target_h).unsqueeze(1)
+                    target_w_embed = self.target_w_embedder(target_w).unsqueeze(1)
+
+                    token_embeddings = token_embeddings + target_h_embed + target_w_embed
+
+            else:  # decode_n_tokens (KV cache) in inference
+                token_embeddings = self.tok_embeddings(idx)
+
+            bs = token_embeddings.shape[0]
+            mask = self.causal_mask[:bs, None, input_pos]
+            h = self.tok_dropout(token_embeddings)
+
+        if self.training:
+            freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
+        else:
+            freqs_cis = self.freqs_cis[input_pos]
+
+        # transformer blocks
+        for layer in self.layers:
+            h = layer(h, freqs_cis, input_pos, mask)
+
+        # output layers
+        h = self.norm(h)
+        if targets is None:
+            return h
+        else:
+            logits = self.output.forward_train(h, targets=targets)
+
+        if self.training:
+            logits = logits[:, self.cls_token_num - 1:].contiguous()
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+        if valid is not None:
+            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+            valid_all = valid[:, None].repeat(1, targets.shape[1]).view(-1)
+            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
+        elif targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+        return logits, loss
+
+    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
+        return list(self.layers)
+
+
+#################################################################################
+#                   Rotary Positional Embedding Functions                      #
+#################################################################################
+# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
+    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)  # (seq_len, head_dim // 2)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)  # (seq_len, head_dim // 2, 2)
+    cond_cache = torch.cat(
+        [torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+seq_len, head_dim // 2, 2)
+    return cond_cache
+
+
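+# Illustrative note (derived from the function above, values chosen for the example only):
+# the cache prepends `cls_token_num` rows for the conditioning tokens, so a hypothetical call
+#
+#     freqs_cis = precompute_freqs_cis(seq_len=256, n_elem=64, cls_token_num=120)
+#
+# returns a tensor of shape (120 + 256, 32, 2). This matches the slicing in
+# Transformer.forward, where training takes the prefix freqs_cis[:num_input_tokens]
+# and KV-cache decoding indexes freqs_cis[input_pos].
+
+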
+def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
+    # split the dimension into half, one for x and one for y
+    half_dim = n_elem // 2
+    freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
+    t = torch.arange(grid_size, device=freqs.device)
+    freqs = torch.outer(t, freqs)  # (grid_size, head_dim // 2)
+    freqs_grid = torch.concat([
+        freqs[:, None, :].expand(-1, grid_size, -1),
+        freqs[None, :, :].expand(grid_size, -1, -1),
+    ], dim=-1)  # (grid_size, grid_size, head_dim // 2)
+    cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)],
+                             dim=-1)  # (grid_size, grid_size, head_dim // 2, 2)
+    cache = cache_grid.flatten(0, 1)
+    cond_cache = torch.cat(
+        [torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+grid_size**2, head_dim // 2, 2)
+    return cond_cache
+
+
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
+    # x: (bs, seq_len, n_head, head_dim)
+    # freqs_cis: (seq_len, head_dim // 2, 2)
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq_len, n_head, head_dim // 2, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)  # (1, seq_len, 1, head_dim // 2, 2)
+    x_out2 = torch.stack([
+        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+    ], dim=-1)
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
+
+#################################################################################
+#                                 GPT Configs                                   #
+#################################################################################
+### text-conditional
+def GPT_7B(**kwargs):
+    return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs))  # 6.6B
+
+
+def GPT_3B(**kwargs):
+    return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs))  # 3.1B
+
+
+def GPT_1B(**kwargs):
+    return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs))  # 1.2B
+
+
+### class-conditional
+def GPT_XXXL(**kwargs):
+    return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs))  # 3.9B
+
+
+def GPT_XXL(**kwargs):
+    return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs))  # 1.4B
+
+
+def GPT_XL(**kwargs):
+    return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs))  # 775M
+
+
+def GPT_L(**kwargs):
+    return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs))  # 343M
+
+
+def GPT_B(**kwargs):
+    return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs))  # 111M
+
+
+GPT_models = {
+    'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
+    'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
+}
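+
+
+# Example usage (an illustrative sketch, not part of the repository's documented API; the
+# ModelArgs fields such as vocab_size, num_codebooks, block_size, num_classes, and dynamic
+# are defined earlier in this file, and the concrete values below are placeholders only):
+#
+#     model = GPT_models['GPT-L'](
+#         vocab_size=4096 * 8,    # total vocabulary, split across num_codebooks sub-vocabularies
+#         num_codebooks=8,
+#         block_size=256,         # maximum number of image tokens
+#         num_classes=1000,
+#         model_type='c2i',
+#         dynamic=True,
+#     )
+#     model.setup_caches(max_batch_size=8, max_seq_length=model.cls_token_num + 256, dtype=torch.bfloat16)
+#     # Prefill with the class label plus the target resolution embeddings, then decode the
+#     # image tokens one position at a time, sampling each position's sub-tokens from
+#     # model.output (see Transformer.forward and AutoRegressiveHead.forward_test above).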