diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..98f252900c6064513f2378cd1973102f2cf6f681 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,76 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.webp filter=xet diff=xet merge=xet -text
+*.exr filter=xet diff=xet merge=xet -text
+*.png filter=xet diff=xet merge=xet -text
+*.ply filter=xet diff=xet merge=xet -text
+assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/T.png filter=lfs diff=lfs merge=lfs -text
+assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp filter=lfs diff=lfs merge=lfs -text
+assets/example_texturing/the_forgotten_knight.ply filter=lfs diff=lfs merge=lfs -text
+assets/hdri/city.exr filter=lfs diff=lfs merge=lfs -text
+assets/hdri/courtyard.exr filter=lfs diff=lfs merge=lfs -text
+assets/hdri/forest.exr filter=lfs diff=lfs merge=lfs -text
+assets/hdri/interior.exr filter=lfs diff=lfs merge=lfs -text
+assets/hdri/night.exr filter=lfs diff=lfs merge=lfs -text
+assets/hdri/sunrise.exr filter=lfs diff=lfs merge=lfs -text
+assets/hdri/sunset.exr filter=lfs diff=lfs merge=lfs -text
+assets/teaser.webp filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..4711186cc74ab8d2ae812895aedc3daf4fea8429
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,19 @@
+__pycache__/
+*.pyc
+.venv/
+venv/
+*.ckpt
+*.pt
+*.bin
+*.safetensors
+wandb/
+.wandb/
+node_modules/
+*.egg-info/
+.gradio/
+example.py
+
+outputs*/
+results*/
+ckpts*/
+tmp/example.py
diff --git a/README.md b/README.md
index 867da31dc29ac7f954c535fc6046951a1d8b1c3f..4c344761ef566aa6646e0383932994757e61d5fe 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,14 @@
---
title: Pixal3D T
-emoji: 📈
+emoji: 🏆
colorFrom: indigo
-colorTo: green
+colorTo: gray
sdk: gradio
sdk_version: 6.13.0
+python_version: "3.10"
app_file: app.py
pinned: false
+license: mit
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..604598ae0aaad102acbb36a4628fb7fd7d646882
--- /dev/null
+++ b/app.py
@@ -0,0 +1,584 @@
+"""
+Pixal3D (TRELLIS.2 Backbone) - Gradio App
+
+Image-to-3D generation using Proj-mode Cascade inference (512->1024/1536).
+
+"""
+
+import gradio as gr
+
+import os
+import subprocess
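+# Install a pinned utils3d wheel up front so the trellis2 imports below pick it up (Spaces-style runtime install).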
+subprocess.run([
+ "pip", "install", "--force-reinstall", "--no-deps",
+ "https://github.com/LDYang694/Storages/releases/download/20260430/utils3d-0.0.2-py3-none-any.whl"
+], check=True)
+
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import argparse
+import math
+import time
+from datetime import datetime
+import shutil
+import cv2
+from typing import *
+import torch
+import numpy as np
+from PIL import Image
+import base64
+import io
+from trellis2.modules.sparse import SparseTensor
+from trellis2.pipelines import Pixal3DImageTo3DPipeline
+from trellis2.renderers import EnvMap
+from trellis2.utils import render_utils
+import o_voxel
+
+
+# ============================================================================
+# Constants & Defaults
+# ============================================================================
+
+MAX_SEED = np.iinfo(np.int32).max
+TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+MODES = [
+ {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
+ {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
+ {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
+ {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
+ {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
+ {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
+]
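+# STEPS view angles are pre-rendered per mode; DEFAULT_MODE / DEFAULT_STEP pick the frame shown first in the previewer.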
+STEPS = 8
+DEFAULT_MODE = 3
+DEFAULT_STEP = 3
+
+# Cascade parameters
+CASCADE_LR_RESOLUTION = 512
+CASCADE_MAX_NUM_TOKENS = 49152
+
+# MoGe defaults
+MOGE_MODEL_NAME = "Ruicheng/moge-2-vitl"
+WILD_MESH_SCALE = 1.0
+WILD_EXTEND_PIXEL = 0
+WILD_IMAGE_RESOLUTION = 512
+
+# Image Cond Model configs (extracted from training configs, hardcoded)
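+# grid_resolution matches the DINOv3 ViT-L/16 patch grid (image_size / 16); naf_target_size sets the NAF-upsampled feature resolution.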
+IMAGE_COND_CONFIGS = {
+ "ss": {
+ "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
+ "image_size": 512,
+ "grid_resolution": 16,
+ },
+ "shape_512": {
+ "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
+ "image_size": 512,
+ "grid_resolution": 32,
+ "use_naf_upsample": True,
+ "naf_target_size": 512,
+ },
+ "shape_1024": {
+ "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
+ "image_size": 1024,
+ "grid_resolution": 64,
+ "use_naf_upsample": True,
+ "naf_target_size": 512,
+ },
+ "tex_1024": {
+ "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
+ "image_size": 1024,
+ "grid_resolution": 64,
+ "use_naf_upsample": True,
+ "naf_target_size": 1024,
+ },
+}
+
+
+# ============================================================================
+# CSS & JS
+# ============================================================================
+
+css = """
+.stepper-wrapper { padding: 0; }
+.stepper-container { padding: 0; align-items: center; }
+.step-button { flex-direction: row; }
+.step-connector { transform: none; }
+.step-number { width: 16px; height: 16px; }
+.step-label { position: relative; bottom: 0; }
+.wrap.center.full { inset: 0; height: 100%; }
+.wrap.center.full.translucent { background: var(--block-background-fill); }
+.meta-text-center {
+ display: block !important; position: absolute !important;
+ top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important;
+}
+.previewer-container {
+ position: relative;
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
+ width: 100%; height: 722px; margin: 0 auto; padding: 20px;
+ display: flex; flex-direction: column; align-items: center; justify-content: center;
+}
+.previewer-container .tips-icon {
+ position: absolute; right: 10px; top: 10px; z-index: 10;
+ border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none;
+}
+.previewer-container .tips-text {
+ position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent);
+ border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10;
+ transition: all 0.3s; opacity: 0%; user-select: none;
+}
+.previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
+.tips-icon:hover + .tips-text { display: block; opacity: 100%; }
+.previewer-container .mode-row {
+ width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap;
+}
+.previewer-container .mode-btn {
+ width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5;
+ transition: all 0.2s; border: 2px solid #ddd; object-fit: cover;
+}
+.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
+.previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
+.previewer-container .display-row {
+ margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1;
+ display: flex; justify-content: center; align-items: center;
+}
+.previewer-container .previewer-main-image {
+ max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none;
+}
+.previewer-container .previewer-main-image.visible { display: block; }
+.previewer-container .slider-row {
+ width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px;
+}
+.previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
+.previewer-container input[type=range]::-webkit-slider-runnable-track {
+ width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px;
+}
+.previewer-container input[type=range]::-webkit-slider-thumb {
+ height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent);
+ cursor: pointer; -webkit-appearance: none; margin-top: -6px;
+ box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s;
+}
+.previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
+.gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
+.gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
+"""
+
+head = """
+
+"""
+
+empty_html = f"""
+
+"""
+
+
+# ============================================================================
+# Model Loading Utilities
+# ============================================================================
+
+def build_image_cond_model(config: dict):
+ """Build DinoV3ProjFeatureExtractor."""
+ from trellis2.trainers.flow_matching.mixins.image_conditioned_proj import DinoV3ProjFeatureExtractor
+ model = DinoV3ProjFeatureExtractor(**config)
+ model.eval()
+ return model
+
+
+# ============================================================================
+# Camera Parameter Utilities
+# ============================================================================
+
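+# Pinhole relation: tan(fov_x / 2) = (W / 2) / f, so f in pixels is (resolution / 2) / tan(camera_angle_x / 2).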
+def compute_f_pixels(camera_angle_x: float, resolution: int) -> float:
+ focal_length = 16.0 / torch.tan(torch.tensor(camera_angle_x / 2.0))
+ f_pixels = focal_length * resolution / 32.0
+ return float(f_pixels.item())
+
+
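+# Solve the pinhole projection x_ndc = f * x / (d + y) for the camera distance d that places a grid point at the target pixel.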
+def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution):
+ rotation_matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
+ gp = grid_point.to(torch.float32) @ rotation_matrix.T
+ gp = gp / mesh_scale / 2
+ xw, yw, zw = gp[0].item(), gp[1].item(), gp[2].item()
+ xt, yt = float(target_point[0].item()), float(target_point[1].item())
+ f_pixels = compute_f_pixels(camera_angle_x, image_resolution)
+ x_ndc = xt - image_resolution / 2.0
+ y_ndc = -(yt - image_resolution / 2.0)
+ distance_x = f_pixels * xw / x_ndc - yw
+ return {"distance_from_x": float(distance_x), "f_pixels": float(f_pixels)}
+
+
+def load_moge_model(device="cuda", model_name=MOGE_MODEL_NAME):
+ print(f"[MoGe-2] Loading model {model_name}...")
+ from moge.model.v2 import MoGeModel
+ moge_model = MoGeModel.from_pretrained(model_name).to(device)
+ moge_model.eval()
+ print("[MoGe-2] Model loaded!")
+ return moge_model
+
+
+def get_camera_params_wild_moge(image, moge_model, device="cuda",
+ mesh_scale=1.0, extend_pixel=0, image_resolution=512):
+ """Estimate camera parameters via MoGe-2."""
+ if isinstance(image, str):
+ pil_image = Image.open(image).convert("RGB")
+ elif isinstance(image, Image.Image):
+ pil_image = image.convert("RGB")
+ else:
+ raise ValueError(f"Unsupported image type: {type(image)}")
+ width, height = pil_image.size
+ image_np = np.array(pil_image).astype(np.float32) / 255.0
+ image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).to(device)
+ with torch.no_grad():
+ output = moge_model.infer(image_tensor)
+ intrinsics = output["intrinsics"].squeeze().cpu().numpy()
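+    # MoGe-2 reports intrinsics normalized by the image size; convert fx to pixels and derive the horizontal FOV.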
+ fx_normalized = intrinsics[0, 0]
+ fx = fx_normalized * width
+ camera_angle_x = 2 * math.atan(width / (2 * fx))
+
+ grid_point = torch.tensor([-1.0, 0.0, 0.0])
+ distance = distance_from_fov(
+ camera_angle_x, grid_point,
+ torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel]),
+ mesh_scale, image_resolution
+ )["distance_from_x"]
+ return {'camera_angle_x': camera_angle_x, 'distance': distance, 'mesh_scale': mesh_scale}
+
+
+# ============================================================================
+# UI Utilities
+# ============================================================================
+
+def image_to_base64(image):
+ buffered = io.BytesIO()
+ image = image.convert("RGB")
+ image.save(buffered, format="jpeg", quality=85)
+ img_str = base64.b64encode(buffered.getvalue()).decode()
+ return f"data:image/jpeg;base64,{img_str}"
+
+
+def start_session(req: gr.Request):
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+ os.makedirs(user_dir, exist_ok=True)
+
+
+def end_session(req: gr.Request):
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+ if os.path.exists(user_dir):
+ shutil.rmtree(user_dir)
+
+
+def preprocess_image(image: Image.Image) -> Image.Image:
+ return pipeline.preprocess_image(image)
+
+
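+# Latents are stashed in gr.State as CPU numpy arrays so they persist between Generate and Extract GLB without holding GPU memory.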
+def pack_state(shape_slat, tex_slat, res):
+ return {
+ 'shape_slat_feats': shape_slat.feats.cpu().numpy(),
+ 'tex_slat_feats': tex_slat.feats.cpu().numpy(),
+ 'coords': shape_slat.coords.cpu().numpy(),
+ 'res': res,
+ }
+
+
+def unpack_state(state):
+ shape_slat = SparseTensor(
+ feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
+ coords=torch.from_numpy(state['coords']).cuda(),
+ )
+ tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
+ return shape_slat, tex_slat, state['res']
+
+
+def get_seed(randomize_seed, seed):
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
+
+
+# ============================================================================
+# Core Inference
+# ============================================================================
+
+def image_to_3d(
+ image, seed, resolution,
+ ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
+ shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
+ tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
+ req: gr.Request,
+ progress=gr.Progress(track_tqdm=True),
+):
+ device = pipeline.device
+ torch.manual_seed(seed)
+ hr_resolution = int(resolution)
+
+ total_t0 = time.time()
+ print(f"\n{'='*60}")
+ print(f" [Generate] Start | seed={seed}, resolution={hr_resolution}")
+ print(f"{'='*60}")
+
+ # Preprocessing
+ image_preprocessed = pipeline.preprocess_image(image)
+
+ # Camera estimation via MoGe-2
+ camera_params = get_camera_params_wild_moge(
+ image_preprocessed, moge_model, device=str(device),
+ mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL,
+ image_resolution=WILD_IMAGE_RESOLUTION,
+ )
+
+ ss_sampler_override = {"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength,
+ "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t}
+ shape_sampler_override = {"steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength,
+ "guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t}
+ tex_sampler_override = {"steps": tex_slat_sampling_steps, "guidance_strength": tex_slat_guidance_strength,
+ "guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t}
+
+ # Run pipeline
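+    # "1024_cascade" / "1536_cascade": a coarse pass at CASCADE_LR_RESOLUTION (512) followed by refinement at the requested resolution.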
+ pipeline_type = f"{hr_resolution}_cascade"
+ mesh_list, (shape_slat, tex_slat, res) = pipeline.run(
+ image_preprocessed,
+ camera_params=camera_params,
+ seed=seed,
+ sparse_structure_sampler_params=ss_sampler_override,
+ shape_slat_sampler_params=shape_sampler_override,
+ tex_slat_sampler_params=tex_sampler_override,
+ preprocess_image=False,
+ return_latent=True,
+ pipeline_type=pipeline_type,
+ max_num_tokens=CASCADE_MAX_NUM_TOKENS,
+ )
+ mesh = mesh_list[0]
+ state = pack_state(shape_slat, tex_slat, res)
+ del shape_slat, tex_slat, mesh_list
+ torch.cuda.empty_cache()
+
+ # Render
+ mesh.simplify(16777216)
+ images = render_utils.render_proj_aligned_video(
+ mesh, camera_angle_x=camera_params['camera_angle_x'],
+ distance=camera_params['distance'], resolution=1024,
+ num_frames=STEPS, envmap=envmap,
+ )
+ del mesh
+ torch.cuda.empty_cache()
+ print(f"\n [Generate] Total time: {time.time()-total_t0:.2f}s")
+
+ # Build HTML
+ images_html = ""
+ for m_idx, mode in enumerate(MODES):
+ for s_idx in range(STEPS):
+ unique_id = f"view-m{m_idx}-s{s_idx}"
+ is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
+ vis_class = "visible" if is_visible else ""
+ img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
+            images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}"/>'
+
+ btns_html = ""
+ for idx, mode in enumerate(MODES):
+ active_class = "active" if idx == DEFAULT_MODE else ""
+        btns_html += f'<img class="mode-btn {active_class}" src="{mode["icon_base64"]}" title="{mode["name"]}"/>'
+
+    # Previewer markup; class names follow the CSS block defined above.
+    full_html = f"""
+    <div class="previewer-container">
+        <div class="tips-icon">Tips</div>
+        <div class="tips-text">
+            <p>Render Mode - Click circular buttons to switch render modes.</p>
+            <p>View Angle - Drag the slider to change the view angle.</p>
+        </div>
+        <div class="display-row">
+            {images_html}
+        </div>
+        <div class="mode-row">
+            {btns_html}
+        </div>
+        <div class="slider-row">
+            <input type="range" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1"/>
+        </div>
+    </div>
+ """
+ return state, full_html
+
+
+def extract_glb(state, decimation_target, texture_size, req: gr.Request, progress=gr.Progress(track_tqdm=True)):
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+ shape_slat, tex_slat, res = unpack_state(state)
+ mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
+ glb = o_voxel.postprocess.to_glb(
+ vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
+ coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
+ grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+ decimation_target=decimation_target, texture_size=texture_size,
+ remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True,
+ )
+ now = datetime.now()
+ timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
+ os.makedirs(user_dir, exist_ok=True)
+ glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
+ glb.export(glb_path, extension_webp=True)
+ torch.cuda.empty_cache()
+ return glb_path, glb_path
+
+
+# ============================================================================
+# Gradio UI
+# ============================================================================
+
+with gr.Blocks(delete_cache=(600, 600)) as demo:
+ gr.Markdown("""
+ ## Image to 3D Asset with Pixal3D (TRELLIS.2 Backbone)
+    * Upload an image and click **Generate** to create a 3D asset using Pixal3D with the TRELLIS.2 backbone.
+ * Click **Extract GLB** to export and download the generated GLB file.
+ * Camera parameters are estimated automatically via MoGe-2.
+ """)
+
+ with gr.Row():
+ with gr.Column(scale=1, min_width=360):
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
+ resolution = gr.Radio(["1024", "1536"], label="Resolution", value="1536")
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=42, step=1)
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+ decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=1000000, step=10000)
+ texture_size = gr.Slider(1024, 4096, label="Texture Size", value=4096, step=1024)
+ generate_btn = gr.Button("Generate")
+
+ with gr.Accordion(label="Advanced Settings", open=False):
+ gr.Markdown("Stage 1: Sparse Structure Generation")
+ with gr.Row():
+ ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+ ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+ ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
+ gr.Markdown("Stage 2: Shape Generation")
+ with gr.Row():
+ shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+ shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
+ shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+ shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
+ gr.Markdown("Stage 3: Material Generation")
+ with gr.Row():
+ tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
+ tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
+ tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+ tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
+
+ with gr.Column(scale=10):
+ with gr.Walkthrough(selected=0) as walkthrough:
+ with gr.Step("Preview", id=0):
+ preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
+ extract_btn = gr.Button("Extract GLB")
+ with gr.Step("Extract", id=1):
+ glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
+ download_btn = gr.DownloadButton(label="Download GLB")
+
+ with gr.Column(scale=1, min_width=172):
+ examples = gr.Examples(
+ examples=[f'assets/example_image/{image}' for image in os.listdir("assets/example_image")],
+ inputs=[image_prompt], fn=preprocess_image, outputs=[image_prompt],
+ run_on_click=True, examples_per_page=18,
+ )
+
+ output_buf = gr.State()
+
+ demo.load(start_session)
+ demo.unload(end_session)
+ image_prompt.upload(preprocess_image, inputs=[image_prompt], outputs=[image_prompt])
+
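+    # Generate: resolve the seed, reset the walkthrough to the Preview step, then run the full image-to-3D pipeline.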
+ generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
+ lambda: gr.Walkthrough(selected=0), outputs=walkthrough
+ ).then(
+ image_to_3d,
+ inputs=[image_prompt, seed, resolution,
+ ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
+ shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
+ tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t],
+ outputs=[output_buf, preview_output],
+ )
+
+ extract_btn.click(lambda: gr.Walkthrough(selected=1), outputs=walkthrough).then(
+ extract_glb, inputs=[output_buf, decimation_target, texture_size], outputs=[glb_output, download_btn],
+ )
+
+
+# ============================================================================
+# Launch
+# ============================================================================
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Pixal3D Gradio App")
+ parser.add_argument("--model_path", type=str, default="TencentARC/Pixal3D-T",
+ help="HuggingFace repo ID or local path (default: TencentARC/Pixal3D-T)")
+ parser.add_argument("--port", type=int, default=7860)
+ parser.add_argument("--share", action="store_true", default=True)
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(TMP_DIR, exist_ok=True)
+
+ # Construct UI icon base64
+ for i in range(len(MODES)):
+ icon = Image.open(MODES[i]['icon'])
+ MODES[i]['icon_base64'] = image_to_base64(icon)
+
+ # Load pipeline from HuggingFace or local path
+ print(f"[Pipeline] Loading from {args.model_path}...")
+ pipeline = Pixal3DImageTo3DPipeline.from_pretrained(args.model_path)
+
+ # Load environment maps
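+    # Reading .exr with cv2 requires OPENCV_IO_ENABLE_OPENEXR=1, which is set at the top of this file.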
+ envmap = {
+ 'forest': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
+ 'sunset': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
+ 'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
+ }
+
+ # Build image cond models and set on pipeline
+ print("[ImageCond] Building DinoV3ProjFeatureExtractor models...")
+ pipeline.image_cond_model_ss = build_image_cond_model(IMAGE_COND_CONFIGS["ss"])
+ pipeline.image_cond_model_shape_512 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_512"])
+ pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
+ pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
+
+ pipeline.cuda()
+
+ # Pre-download NAF model (avoid lazy-loading during inference)
+ print("[NAF] Pre-loading NAF upsampler model...")
+ for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
+ model = getattr(pipeline, attr, None)
+ if model is not None and getattr(model, 'use_naf_upsample', False):
+ model._load_naf()
+ print("[NAF] NAF model loaded.")
+
+ # Load MoGe-2
+ print("\n[MoGe-2] Loading model for camera estimation...")
+ moge_model = load_moge_model(device="cuda")
+
+ print(f"\n{'=' * 60}")
+ print(f" Pixal3D ready! Model loaded from: {args.model_path}")
+ print(f" Cascade: {CASCADE_LR_RESOLUTION} -> 1024/1536")
+ print(f"{'=' * 60}\n")
+
+ demo.launch(css=css, head=head, server_port=args.port, share=args.share)
diff --git a/assets/app/basecolor.png b/assets/app/basecolor.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7dbeaf6344b4b292757964eca2c6960e7e10a68
Binary files /dev/null and b/assets/app/basecolor.png differ
diff --git a/assets/app/clay.png b/assets/app/clay.png
new file mode 100644
index 0000000000000000000000000000000000000000..e02866a15d1101d1d1cd3f6098c7025c0763cfdb
Binary files /dev/null and b/assets/app/clay.png differ
diff --git a/assets/app/hdri_city.png b/assets/app/hdri_city.png
new file mode 100644
index 0000000000000000000000000000000000000000..43e1c2a54c3bb8aa376613c698268fca0aea9b29
Binary files /dev/null and b/assets/app/hdri_city.png differ
diff --git a/assets/app/hdri_courtyard.png b/assets/app/hdri_courtyard.png
new file mode 100644
index 0000000000000000000000000000000000000000..4261ad62862163d32c22ac38495defdbdf3bebd0
Binary files /dev/null and b/assets/app/hdri_courtyard.png differ
diff --git a/assets/app/hdri_forest.png b/assets/app/hdri_forest.png
new file mode 100644
index 0000000000000000000000000000000000000000..7617fe19adf536c12db3f4bf7bc1f6a42f54f8b7
Binary files /dev/null and b/assets/app/hdri_forest.png differ
diff --git a/assets/app/hdri_interior.png b/assets/app/hdri_interior.png
new file mode 100644
index 0000000000000000000000000000000000000000..e00c1d656c94c71402e93067cac64725425754d0
Binary files /dev/null and b/assets/app/hdri_interior.png differ
diff --git a/assets/app/hdri_night.png b/assets/app/hdri_night.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0423d221069904b0de32c82956675a210e8c375
Binary files /dev/null and b/assets/app/hdri_night.png differ
diff --git a/assets/app/hdri_studio.png b/assets/app/hdri_studio.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f5a4e8d5c5717b6ca4217e6e081e695e2bec264
Binary files /dev/null and b/assets/app/hdri_studio.png differ
diff --git a/assets/app/hdri_sunrise.png b/assets/app/hdri_sunrise.png
new file mode 100644
index 0000000000000000000000000000000000000000..9cee3bb066a0f01ddf84b088da31fba1af6648d0
Binary files /dev/null and b/assets/app/hdri_sunrise.png differ
diff --git a/assets/app/hdri_sunset.png b/assets/app/hdri_sunset.png
new file mode 100644
index 0000000000000000000000000000000000000000..bd67070912b846b7cb353d13696f4aff4be40831
Binary files /dev/null and b/assets/app/hdri_sunset.png differ
diff --git a/assets/app/normal.png b/assets/app/normal.png
new file mode 100644
index 0000000000000000000000000000000000000000..352e92b5750414ee1c2562d1c719468ce5da2883
Binary files /dev/null and b/assets/app/normal.png differ
diff --git a/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp b/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp
new file mode 100644
index 0000000000000000000000000000000000000000..8cb4652f561eca8ab3aa32f6e00457bc9b9f194a
--- /dev/null
+++ b/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83d3765ff57511f11054d62e6beaf52648d277006b6dcb3c1d5f9e03ef502c49
+size 108096
diff --git a/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp b/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b9dbb51554fe1f2daeb448c0e47cfaa512a07eff
--- /dev/null
+++ b/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a31ba3b084757faa66e4bb91a85bd956fcc1eaaaf298a091fd08044351cd0293
+size 184424
diff --git a/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp b/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp
new file mode 100644
index 0000000000000000000000000000000000000000..62488149f1f5d00ac044ff6f95b91efa6117b7ed
--- /dev/null
+++ b/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8d419ac16852b73fbc7b92c6857ba5d4be2f2dd8056bf44af707e6af4d62207
+size 156794
diff --git a/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp b/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..f6a749a33c8ffd111cf198f499bda2b499537d51
--- /dev/null
+++ b/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1a36d67f1cd7bba5c78b429c85346990aa3f269e19eee60b6d8b5d632cd743a
+size 100332
diff --git a/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp b/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp
new file mode 100644
index 0000000000000000000000000000000000000000..3c3ac5dfc7a6d30528928e30eca04f36870f33c3
Binary files /dev/null and b/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp differ
diff --git a/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp b/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e596271b736281ad07491dd7b7d3544214e2c7aa
Binary files /dev/null and b/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp differ
diff --git a/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp b/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp
new file mode 100644
index 0000000000000000000000000000000000000000..287db1bed9d597e5ff7e4a49da4a63c5529f51c7
--- /dev/null
+++ b/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb7cdcea3b238d40c7b838659288f5c12febd77292b22d44f54f785e2712d49a
+size 159636
diff --git a/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp b/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp
new file mode 100644
index 0000000000000000000000000000000000000000..bdfc4f52348dc4ea59cae734f9ee0f4aa29484da
--- /dev/null
+++ b/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4202eb19d6802319aa140e65385b3474d09aaa7f008218acf1e1f5186b923c90
+size 167372
diff --git a/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp b/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp
new file mode 100644
index 0000000000000000000000000000000000000000..ef5029ebcfa044e437caa35458bb6ab860dbda8a
--- /dev/null
+++ b/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d15269d0c0b427eabcca39d6d093cc2cfdaab19cb8a40b6158988a70a91ffe45
+size 207084
diff --git a/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp b/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp
new file mode 100644
index 0000000000000000000000000000000000000000..89889d4b276ba78b3fc451ec6a504c1b5089c859
--- /dev/null
+++ b/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6cd45dd941dc559191c71f49ac5841fc6d06d32c7427fcf535f38c007611cd3f
+size 109042
diff --git a/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp b/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp
new file mode 100644
index 0000000000000000000000000000000000000000..962281565b2c60a013e64776313145fd8186e4ef
--- /dev/null
+++ b/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4e26e3e2f3de96cf396c64f1cef39ac82c13f231e2567fe0fe902dc103bf949
+size 169506
diff --git a/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp b/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0af9ee139499ab072b6068e5fb1d974f99103a30
--- /dev/null
+++ b/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88c65db47f93beed7d6b184c32f9fa27c5f64ced2e3d4f535ad628179763cff8
+size 115508
diff --git a/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp b/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7411bc618bfd5af692073d70281a857df6932b22
Binary files /dev/null and b/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp differ
diff --git a/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp b/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5f00cc2b4562fb026033a555cf95c946a3922246
Binary files /dev/null and b/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp differ
diff --git a/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp b/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0a3a3a893b1fa0914e63d79e82d7894afe552483
--- /dev/null
+++ b/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47be8271d2d1d27fc334cc470d63534807ad44c6f50cee18a204125a1c6b66a4
+size 140210
diff --git a/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp b/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..6ae86128b367dc0618fd87e605cd08832b6b931b
--- /dev/null
+++ b/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae7ad70cbb2a616fca63bc73d319172d5fda0930c4aaf9d897e21c09ccf17ad7
+size 167670
diff --git a/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp b/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4cb9bdb5ac7fe4bd2ea84771de21eff56dda4cf4
--- /dev/null
+++ b/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b5ca8a0b47aa2dbd5480cd356e7ab574dadd565ed738ac5ed655f36a4349eea
+size 175080
diff --git a/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp b/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp
new file mode 100644
index 0000000000000000000000000000000000000000..01a571a7a2ad661e29c1a4c3b5ba68353451cd70
--- /dev/null
+++ b/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:933c10aeebb2920b08cb34a08ab1878817b64eb9e30efdcc3d76731069fc0849
+size 131488
diff --git a/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp b/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp
new file mode 100644
index 0000000000000000000000000000000000000000..60375c981e776e3667287f8df99b04f12098f1df
--- /dev/null
+++ b/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d5224235eeadadfdb93ab37664055ef55ffd930b085d268cd62c5faf9d101de
+size 136872
diff --git a/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp b/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp
new file mode 100644
index 0000000000000000000000000000000000000000..ddd363de151adc06a1815d148a69fce18a238942
--- /dev/null
+++ b/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86b909e6847118f2cec6e9c0945e6a78154ce5918f4ba14c0b0afa0be2e8647f
+size 150020
diff --git a/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp b/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4151759cf55814774c5e1ff29e219e965a21a423
--- /dev/null
+++ b/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62b7097ce57d01e730a0ce2cc120b6c9d27585026c3de84d6b1a0dcaf5fea9d3
+size 128572
diff --git a/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp b/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp
new file mode 100644
index 0000000000000000000000000000000000000000..cd9120b3d0c12fcec01cf68a27060eba6c658e37
--- /dev/null
+++ b/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47cbe7b6f29adb0f7d02c2cd73f6f8bd90a54b12d416368499458a8c571e35c5
+size 178460
diff --git a/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp b/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7a10d813fb288d647e49f2b8bb2f5cdf578f873d
--- /dev/null
+++ b/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbd98cf5da79c56f8efc6cf86804391e71bdad2e935d21a7262472653a0674dc
+size 126914
diff --git a/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp b/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp
new file mode 100644
index 0000000000000000000000000000000000000000..021980a74ef652ae55bef0272527aab211275753
--- /dev/null
+++ b/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfad86b88eb81da36a5acf77891822c042e991436bb004c7d75d8b19e89c45bd
+size 110562
diff --git a/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp b/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp
new file mode 100644
index 0000000000000000000000000000000000000000..3fe995897ad1712af150b391a5d91063c3e417b5
--- /dev/null
+++ b/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd73dcc5ae5b42a442c87f1522f370869f3d2e63922e6ecf769fd90fdeed1e66
+size 124876
diff --git a/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp b/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp
new file mode 100644
index 0000000000000000000000000000000000000000..8f787a298a8ba70c64beff1ad88cc9b96d4d58b9
Binary files /dev/null and b/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp differ
diff --git a/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp b/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp
new file mode 100644
index 0000000000000000000000000000000000000000..54a78c84746298a677529c8abe5858d6d734bbe7
Binary files /dev/null and b/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp differ
diff --git a/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp b/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp
new file mode 100644
index 0000000000000000000000000000000000000000..31b50241ad8cc38c4d2b9a00b36dc258dc172b73
--- /dev/null
+++ b/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0dd605ac2927b56f6ba5ced2d0b56aaba1be9f7dc2ab7f8d558d0f37856f6b5
+size 169830
diff --git a/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp b/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b7a521fcdb71be9f5b095bd0c2a3a8b3ed9a32fc
Binary files /dev/null and b/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp differ
diff --git a/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp b/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp
new file mode 100644
index 0000000000000000000000000000000000000000..51557cf27478f8c301978278157bf71c1d4069f1
--- /dev/null
+++ b/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41fd0f2a1cd66f4615a3947b6fdbddf527d4d628935a6fd208ca32a3b46849b1
+size 111992
diff --git a/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp b/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b07990950bc8f3aee982664d6fd1264570906752
--- /dev/null
+++ b/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5a7835a908470c1142678cddfbfbe093ba0ec1085852ed072ae597f1eae5c37
+size 111836
diff --git a/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp b/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5f10f35ab7a6f2d1d95d40e82c4bb91e30f165b5
--- /dev/null
+++ b/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7807391ef76523dab937bfcae5f9fc4803a65b86ab49310c17df0714288b7c48
+size 195232
diff --git a/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp b/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp
new file mode 100644
index 0000000000000000000000000000000000000000..ab02dea2e3abcb802b48c82d6e7d670783f1be5e
--- /dev/null
+++ b/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3977e4b323080e122ec297f2fafe1203e717fef8417e82b6ed523e8438f829c6
+size 191642
diff --git a/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp b/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp
new file mode 100644
index 0000000000000000000000000000000000000000..60dd646a277d93fb76441ec47f88e22c4cf6967a
--- /dev/null
+++ b/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6707893380d3c791a9a2d33bc0caa5f10719f7168d25f8d447fa26f4420919db
+size 108942
diff --git a/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp b/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp
new file mode 100644
index 0000000000000000000000000000000000000000..3d6c84ab1ccc878d9f292a04b579642b92c14fc9
--- /dev/null
+++ b/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9558307682a1e723f86291af377f0b53970c0d301304c1f43dec741fedc209d7
+size 117870
diff --git a/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp b/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp
new file mode 100644
index 0000000000000000000000000000000000000000..f12cbf93875f086f3084efc29072f8fb579f3143
--- /dev/null
+++ b/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a319ace2549835da92a6ffa5db73eebd7fce29079e5865cb32dfbdac21d9b900
+size 100414
diff --git a/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp b/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp
new file mode 100644
index 0000000000000000000000000000000000000000..afc4b658a9e2dfefd67a66e91cf60f009b74e8bd
--- /dev/null
+++ b/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b8373ce90506089d23907895d7efac4c37defb603ea7434bee9c4de870bf7da
+size 179774
diff --git a/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp b/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d4298ffbb33598e6bb6392741f65cc9bc648a468
--- /dev/null
+++ b/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:643b3d426830d89dc62634027f802e3af5450eb45591eeb413ae67d5ba56f557
+size 114408
diff --git a/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp b/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp
new file mode 100644
index 0000000000000000000000000000000000000000..bbd93cd629c4850d0ddfcd65763478d1897cbb1c
--- /dev/null
+++ b/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4e65a295b78ca3505337eb1cc58d0e5ff7effc0d500d01e503c4d5243af0a68
+size 103346
diff --git a/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp b/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a0795fe45f2477d4811a07bd2168f08ef9460960
--- /dev/null
+++ b/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23b3f7294731c66a3a9db7ccbd593708e1ba56819cff970b29fd9fbf286588e0
+size 187208
diff --git a/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp b/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b668a21a0f5e8a3c0730377da02357cd6e30b869
--- /dev/null
+++ b/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d42562b49df17554583f2acbe18f8a26c9075cdab278caefab27aa7587df4ac
+size 115424
diff --git a/assets/example_image/T.png b/assets/example_image/T.png
new file mode 100644
index 0000000000000000000000000000000000000000..171f72e68f117c7d2507f94970a6a8d2d5e7e563
--- /dev/null
+++ b/assets/example_image/T.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db468bad8a04f1474a8d68140c07501b013b3ec6124b911fb7852675d64c05ee
+size 1718470
diff --git a/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp b/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7f764c337d834d1077aef043aa277f360b20d9ba
--- /dev/null
+++ b/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94b508edb4a3c49ff4e5c4c583b6980a7d0ef84a2f16a301ac9b671d421b8a52
+size 103072
diff --git a/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp b/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e4cf1566ad8daa3be8f938d3d3e5b00b297c7414
--- /dev/null
+++ b/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71c2035d7968538b51a5edb0d853a310f3f03b9626b705ac7f14398bb72f6093
+size 101638
diff --git a/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp b/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..29b7868f213043532468e2fa1c169c3a9c98ce26
--- /dev/null
+++ b/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a8dbf4af05a60591cb157bcbb19c58c05c1d5622aaecddaf5aaabf292e89ea0
+size 208094
diff --git a/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp b/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp
new file mode 100644
index 0000000000000000000000000000000000000000..242c31ce26348b99d98a69f789cf646ca472db05
--- /dev/null
+++ b/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3786b5f2c6715d02a20abde45f4465883eb8b791a807f9dfc86dc78a17b7c58
+size 180294
diff --git a/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp b/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp
new file mode 100644
index 0000000000000000000000000000000000000000..00790a403686c1fb1a78e30b9e0bad269dcb80a6
--- /dev/null
+++ b/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69f3ab098a28bf0c06194ab523ba10e30ce3db331871e2e8fa97cfd5967853b0
+size 180186
diff --git a/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp b/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp
new file mode 100644
index 0000000000000000000000000000000000000000..dc0630884d160addf2997b2435fd15ea3bece805
--- /dev/null
+++ b/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc3d5f2d5228c9314d8037105cc24798499aea1aaf311cde332df74d17f96da6
+size 130468
diff --git a/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp b/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp
new file mode 100644
index 0000000000000000000000000000000000000000..cfa8a87a3b0c0a5ed90f249c030815517317403c
--- /dev/null
+++ b/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a77a4ab99beb302c5934efaeaeef2effc48b2c07b9d68cadb06d68396a48abea
+size 106998
diff --git a/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp b/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9b23b87c2a30e9c704de86bc7c2214a9629661a8
--- /dev/null
+++ b/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01962bf62febbf916100201f915d27ab881c747a9e283228d8b9c890c1eb2fb8
+size 155612
diff --git a/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp b/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4ae373686902684181a5f7c84f976a02f3fea9d5
--- /dev/null
+++ b/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed898f0689bd19f163fbfb601b9c9d31c422ab2ad409ec07f61b70471636c164
+size 155628
diff --git a/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp b/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4282564790b1a57ea1d4a7c2f77dd84775cade17
--- /dev/null
+++ b/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:961adcb3c9c10db565957bfcf5d0be49e18c23012363a78bebcf8b49dbd7153c
+size 206140
diff --git a/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp b/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4203f61711b705c5e5e0e6f49e52f9b555c878a9
--- /dev/null
+++ b/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e211aff13d1e7dcdf592d68ea34e03794c4464c5b159224d7c508d9ecae5eb59
+size 151940
diff --git a/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp b/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp
new file mode 100644
index 0000000000000000000000000000000000000000..fc11f391f770d92966ef73037c8769f807337875
--- /dev/null
+++ b/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4533ce41604e7aff386c71f37f0b2727242a4615ef0e37c3cd62273678ad1809
+size 142682
diff --git a/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp b/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp
new file mode 100644
index 0000000000000000000000000000000000000000..898e801ef377765e4082aa5968db56dd6e696a06
--- /dev/null
+++ b/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec6060536a5db96092d419be6e0f9d14985f6954ced1605e9d473d415ff29368
+size 764638
diff --git a/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp b/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp
new file mode 100644
index 0000000000000000000000000000000000000000..6d79bbc33c14ba034cbc6f5dc7ab443e0f92d445
--- /dev/null
+++ b/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eaf705e6205f5101af261fac1a3477121cf377f92ff93feaf506eada70de98d7
+size 217676
diff --git a/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp b/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp
new file mode 100644
index 0000000000000000000000000000000000000000..66bc8713ed9f5a86463850a3dd130f6d9a4953c0
--- /dev/null
+++ b/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c0c09536a1a6c1c44e7c8649b1fda58ed4404ffa26debdc6f1719aaa7fdc59b
+size 235370
diff --git a/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp b/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp
new file mode 100644
index 0000000000000000000000000000000000000000..2573d75a2a428aef7d4d42daaa29d087b293a461
--- /dev/null
+++ b/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:373afe2a56a038d1356f3fc5f1c8fb04778e90cb2216ad1c60ba7e2d2877705f
+size 153922
diff --git a/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp b/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp
new file mode 100644
index 0000000000000000000000000000000000000000..65b852da0703baeaee573de95f33b9ce366731d8
Binary files /dev/null and b/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp differ
diff --git a/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp b/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a773e5c7e9a3e0078acdf30dbace00de6e22cd18
Binary files /dev/null and b/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp differ
diff --git a/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp b/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9ffe4be1a5349b8a65b220b6063386831816db4b
--- /dev/null
+++ b/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:309713c97b32a579a232faea918aea7c49d3638b07c47f29072523bb4e95c5a0
+size 167178
diff --git a/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp b/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp
new file mode 100644
index 0000000000000000000000000000000000000000..335531a6f0d695d3463e0bbfc478c08ad19c93ed
--- /dev/null
+++ b/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:15e90c0ce28673b896e754ca56060dd840799cb8b95fb4ebeb06661675f8ed07
+size 182650
diff --git a/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp b/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp
new file mode 100644
index 0000000000000000000000000000000000000000..114d6ffaa3545023654eea953aee8c22031e40a1
--- /dev/null
+++ b/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05a0cabaea29f4d03c6468d039fe09af8eb8c3647952f3bcb342421445ea00a2
+size 229896
diff --git a/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp b/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e64acdec6e24981b82c80f35bafb44766ac6cc31
Binary files /dev/null and b/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp differ
diff --git a/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp b/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp
new file mode 100644
index 0000000000000000000000000000000000000000..08850da10b51022be1509076a797e4b4ea2887a1
--- /dev/null
+++ b/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:477c0c9ef774e796a27645e1c143e7bd8d574cb6f31f2a0a9f57aaf66e80c0be
+size 106434
diff --git a/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp b/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp
new file mode 100644
index 0000000000000000000000000000000000000000..ba44a2d028fc04d8b74cef6935f8db9e9df17ce5
--- /dev/null
+++ b/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b169654d84f98e808a875bc80b47e0cfe988af495250dba4368edf3eea4f633
+size 304284
diff --git a/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp b/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b9e06fd37eba38547fe24b88ce68c442086a83c7
--- /dev/null
+++ b/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:242db48e9c2e6e45a888fdd24f133b966f460b583f1e77324252b869d1a7d65d
+size 254674
diff --git a/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp b/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp
new file mode 100644
index 0000000000000000000000000000000000000000..16921f8bf9843166dd1e7695fcb5fa13008c75e1
--- /dev/null
+++ b/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f669f15496ecc1f3ec9819b40c8ca46ce2c080f49eab342724f2872f5c9dcff
+size 186202
diff --git a/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp b/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c2ab696162c9dae5f90114d4825af5ec30cb28cb
--- /dev/null
+++ b/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4d093a8f99f3fc25a588785ceee95180427ff8dd063e2fda9f1e7f41cf40940
+size 247040
diff --git a/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp b/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp
new file mode 100644
index 0000000000000000000000000000000000000000..1ae36239b3659bf579bd2588f930cdcd02d68b13
--- /dev/null
+++ b/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d65e9fdbe6e4b432e34fbae16ea8a3d98e3cef81a9b1613f294039e96805e77
+size 208356
diff --git a/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp b/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9c24d0eb0a3d0c2c7ddc69e3ae39f4e9d85b9243
Binary files /dev/null and b/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp differ
diff --git a/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp b/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp
new file mode 100644
index 0000000000000000000000000000000000000000..f90db415387320470e55cc5bdcf62635e7b45726
Binary files /dev/null and b/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp differ
diff --git a/assets/example_texturing/image.webp b/assets/example_texturing/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a69f81dc924ae63e0c0a4f3212f476280d95439b
Binary files /dev/null and b/assets/example_texturing/image.webp differ
diff --git a/assets/example_texturing/readme b/assets/example_texturing/readme
new file mode 100644
index 0000000000000000000000000000000000000000..7f6c799d95a410000c9116e459665339a701c88a
--- /dev/null
+++ b/assets/example_texturing/readme
@@ -0,0 +1,11 @@
+## Asset Information
+
+* Title: The Forgotten Knight
+* Author: dark_igorek
+* Source: https://sketchfab.com/3d-models/the-forgotten-knight-d14eb14d83bd4e7ba7cbe443d76a10fd
+* License: Creative Commons Attribution (CC BY)
+
+## Usage
+
+The asset is used for research purposes only.
+Please credit the original author and include the Sketchfab link when using or redistributing this model.
\ No newline at end of file
diff --git a/assets/example_texturing/the_forgotten_knight.ply b/assets/example_texturing/the_forgotten_knight.ply
new file mode 100644
index 0000000000000000000000000000000000000000..919ffdc534b30c9e19af438705d50a6f1814531f
--- /dev/null
+++ b/assets/example_texturing/the_forgotten_knight.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56b16fba10017445de9529b2c2c1f68a97d2ca883293b27fa5efce116c6489b1
+size 4753123
diff --git a/assets/hdri/city.exr b/assets/hdri/city.exr
new file mode 100644
index 0000000000000000000000000000000000000000..fac69a5403a2a3b51d725bca679174d2f0208b25
--- /dev/null
+++ b/assets/hdri/city.exr
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e42abcd2fa3231e5c2485ca6dd64800534d157b7194ab7fe9cc3bf5a56d0256
+size 204863
diff --git a/assets/hdri/courtyard.exr b/assets/hdri/courtyard.exr
new file mode 100644
index 0000000000000000000000000000000000000000..142504138fb47eba3024495d69923211ebeda9b0
--- /dev/null
+++ b/assets/hdri/courtyard.exr
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6690b47725965531559380121e6878373eb599655e35fefd109a4bd0911366f3
+size 255126
diff --git a/assets/hdri/forest.exr b/assets/hdri/forest.exr
new file mode 100644
index 0000000000000000000000000000000000000000..7809de1141fe935457de0cc73d8fe87ecf924e1a
--- /dev/null
+++ b/assets/hdri/forest.exr
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bdf2298244affa0f85509380fd130ac6d4dfaa3c856df065998f7f4c1a93dc0d
+size 552641
diff --git a/assets/hdri/interior.exr b/assets/hdri/interior.exr
new file mode 100644
index 0000000000000000000000000000000000000000..ddb91daa8b222bedea152d60221499768d0e0fd7
--- /dev/null
+++ b/assets/hdri/interior.exr
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e945ff5c1ddd7a3aaf05e9fb5c3bc9cb93c5518414febd44ed2c394e013f0cbd
+size 189416
diff --git a/assets/hdri/license.txt b/assets/hdri/license.txt
new file mode 100644
index 0000000000000000000000000000000000000000..eba5cb8037c9eb45e8c8638c99b39f7160a69214
--- /dev/null
+++ b/assets/hdri/license.txt
@@ -0,0 +1,15 @@
+All HDRIs are licensed as CC0.
+
+These were created by Greg Zaal (Poly Haven https://polyhaven.com).
+Originals used for each HDRI:
+- City: https://polyhaven.com/a/portland_landing_pad
+- Courtyard: https://polyhaven.com/a/courtyard
+- Forest: https://polyhaven.com/a/ninomaru_teien
+- Interior: https://polyhaven.com/a/hotel_room
+- Night: Probably https://polyhaven.com/a/moonless_golf
+- Studio: Probably https://polyhaven.com/a/studio_small_01
+- Sunrise: https://polyhaven.com/a/spruit_sunrise
+- Sunset: https://polyhaven.com/a/venice_sunset
+
+1K resolution of each was taken, and compressed with oiiotool:
+oiiotool input.exr --ch R,G,B -d float --compression dwab:300 --clamp:min=0.0:max=32000.0 -o output.exr
diff --git a/assets/hdri/night.exr b/assets/hdri/night.exr
new file mode 100644
index 0000000000000000000000000000000000000000..380591f26bebbd1419b703145dedb20cab765f59
--- /dev/null
+++ b/assets/hdri/night.exr
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17480af5547465d160f7307c92585cf30820ca563f87f066395decbad8ac32a4
+size 139959
diff --git a/assets/hdri/studio.exr b/assets/hdri/studio.exr
new file mode 100644
index 0000000000000000000000000000000000000000..baf478da00acddbd70593e36d90146eb15c8939d
Binary files /dev/null and b/assets/hdri/studio.exr differ
diff --git a/assets/hdri/sunrise.exr b/assets/hdri/sunrise.exr
new file mode 100644
index 0000000000000000000000000000000000000000..b0689a58e51c6527ed8810186644dd482f96799a
--- /dev/null
+++ b/assets/hdri/sunrise.exr
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a1180126e4db7d01f134c5430ea43c1b263ab1a12faf58a444d9ce9c03f3a84
+size 251877
diff --git a/assets/hdri/sunset.exr b/assets/hdri/sunset.exr
new file mode 100644
index 0000000000000000000000000000000000000000..bf0414b39faf538b48baa62257aa6e5ff168928f
--- /dev/null
+++ b/assets/hdri/sunset.exr
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bcafdda4f2d7b9759cc1d73004d34d721a274c29b1a0947be88a602dbac426b
+size 171164
diff --git a/assets/teaser.webp b/assets/teaser.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d8b89bd17c1fdf5831bb79a3989bf428086c6a7b
--- /dev/null
+++ b/assets/teaser.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b110c904037017d996933cc318ba948bb2dec7691d5a709568e2e0d9ae86068b
+size 290204
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a6295dd81c7a9a4d865420bfcf05e3f867b9aa90
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+--extra-index-url https://download.pytorch.org/whl/cu124
+
+torch==2.6.0
+torchvision==0.21.0
+triton==3.2.0
+pillow==12.0.0
+imageio==2.37.2
+imageio-ffmpeg==0.6.0
+tqdm==4.67.1
+easydict==1.13
+opencv-python-headless==4.12.0.88
+trimesh==4.10.1
+transformers==4.57.3
+zstandard==0.25.0
+kornia==0.8.2
+timm==1.0.22
+diffusers==0.37.1
+accelerate==1.13.0
+gradio
+
+git+https://github.com/microsoft/MoGe.git
+https://github.com/LDYang694/Storages/releases/download/20260430/natten-0.21.0-cp310-cp310-linux_x86_64.whl
+https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
+https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl
+https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl
+https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl
+https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/nvdiffrast-0.4.0-cp310-cp310-linux_x86_64.whl
+https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/nvdiffrec_render-0.0.0-cp310-cp310-linux_x86_64.whl
\ No newline at end of file
diff --git a/trellis2/__init__.py b/trellis2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20d240afc9c26a21aee76954628b3d4ef9a1ccbd
--- /dev/null
+++ b/trellis2/__init__.py
@@ -0,0 +1,6 @@
+from . import models
+from . import modules
+from . import pipelines
+from . import renderers
+from . import representations
+from . import utils
diff --git a/trellis2/datasets/__init__.py b/trellis2/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eadadadb935480ed30d44ae11ebe41f785ed772b
--- /dev/null
+++ b/trellis2/datasets/__init__.py
@@ -0,0 +1,52 @@
+import importlib
+
+__attributes = {
+ 'FlexiDualGridDataset': 'flexi_dual_grid',
+ 'SparseVoxelPbrDataset':'sparse_voxel_pbr',
+
+ 'SparseStructureLatent': 'sparse_structure_latent',
+ 'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
+ 'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
+ 'SparseStructureLatentView': 'sparse_structure_latent',
+ 'ViewImageConditionedSparseStructureLatentView': 'sparse_structure_latent',
+
+ 'SLat': 'structured_latent',
+ 'ImageConditionedSLat': 'structured_latent',
+ 'SLatShape': 'structured_latent_shape',
+ 'ImageConditionedSLatShape': 'structured_latent_shape',
+ 'SLatShapeView': 'structured_latent_shape',
+ 'ViewImageConditionedSLatShapeView': 'structured_latent_shape',
+ 'SLatPbr': 'structured_latent_svpbr',
+ 'ImageConditionedSLatPbr': 'structured_latent_svpbr',
+ 'SLatPbrView': 'structured_latent_svpbr',
+ 'ViewImageConditionedSLatPbrView': 'structured_latent_svpbr',
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+ if name not in globals():
+ if name in __attributes:
+ module_name = __attributes[name]
+ module = importlib.import_module(f".{module_name}", __name__)
+ globals()[name] = getattr(module, name)
+ elif name in __submodules:
+ module = importlib.import_module(f".{name}", __name__)
+ globals()[name] = module
+ else:
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+ return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+ from .flexi_dual_grid import FlexiDualGridDataset
+ from .sparse_voxel_pbr import SparseVoxelPbrDataset
+
+ from .sparse_structure_latent import SparseStructureLatent, ImageConditionedSparseStructureLatent
+ from .structured_latent import SLat, ImageConditionedSLat
+ from .structured_latent_shape import SLatShape, ImageConditionedSLatShape
+ from .structured_latent_svpbr import SLatPbr, ImageConditionedSLatPbr
+
\ No newline at end of file
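A minimal sketch of how the lazy-import hook above behaves; it assumes trellis2 and its dependencies are importable and is illustrative only, not part of the patch:

    import trellis2.datasets as datasets

    # First attribute access falls through to __getattr__, which imports
    # trellis2.datasets.sparse_structure_latent and caches the class in the
    # package globals, so later lookups bypass the hook entirely.
    cls = datasets.SparseStructureLatent
    assert 'SparseStructureLatent' in vars(datasets)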
diff --git a/trellis2/datasets/components.py b/trellis2/datasets/components.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ebc93c1e87b16a28bb24fef8da61bc43502db60
--- /dev/null
+++ b/trellis2/datasets/components.py
@@ -0,0 +1,349 @@
+from typing import *
+import json
+from abc import abstractmethod
+import os
+import json
+import torch
+import numpy as np
+import pandas as pd
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class StandardDatasetBase(Dataset):
+ """
+ Base class for standard datasets.
+
+ Args:
+ roots (str): paths to the dataset
+ skip_list (str, optional): path to a file containing sha256 hashes to skip (one per line)
+ Format: "dataset/sha256" (e.g., "ABO/6a79dbb5...")
+ skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check
+ (e.g., ["texverse"] for datasets without aesthetic_score)
+ """
+
+ def __init__(self,
+ roots: str,
+ skip_list: Optional[str] = None,
+ skip_aesthetic_score_datasets: Optional[List[str]] = None,
+ ):
+ super().__init__()
+
+ # Datasets to skip aesthetic score check
+ self.skip_aesthetic_score_datasets = set(skip_aesthetic_score_datasets or [])
+
+ # Load skip list if provided
+ self.skip_set = set()
+ if skip_list is not None and os.path.exists(skip_list):
+ with open(skip_list, 'r') as f:
+ for line in f:
+ line = line.strip()
+ if line and not line.startswith('#'):
+ self.skip_set.add(line)
+ print(f'Loaded {len(self.skip_set)} items from skip_list: {skip_list}')
+
+ try:
+ self.roots = json.loads(roots)
+ root_type = 'obj'
+        except Exception:
+ self.roots = roots.split(',')
+ root_type = 'list'
+ self.instances = []
+ self.metadata = pd.DataFrame()
+
+ self._stats = {}
+ if root_type == 'obj':
+ for key, root in self.roots.items():
+ self._stats[key] = {}
+ metadata = pd.DataFrame(columns=['sha256']).set_index('sha256')
+
+                # Merge key fields only from ss_latent and render_cond.
+                # Do not include base: cond_rendered=False in base/metadata.csv would wrongly overwrite the real values.
+ for sub_key, r in root.items():
+ if sub_key == 'base':
+                        continue # skip the base directory
+ metadata_file = os.path.join(r, 'metadata.csv')
+ if os.path.exists(metadata_file):
+ metadata = metadata.combine_first(pd.read_csv(metadata_file).set_index('sha256'))
+
+                # Read aesthetic_score from base alone (without other columns that might conflict)
+                if 'base' in root:
+                    base_metadata_file = os.path.join(root['base'], 'metadata.csv')
+                    if os.path.exists(base_metadata_file):
+                        base_df = pd.read_csv(base_metadata_file).set_index('sha256')
+                        if 'aesthetic_score' in base_df.columns and 'aesthetic_score' not in metadata.columns:
+                            metadata['aesthetic_score'] = base_df['aesthetic_score']
+
+ self._stats[key]['Total'] = len(metadata)
+ metadata, stats = self.filter_metadata(metadata, dataset_name=key)
+ self._stats[key].update(stats)
+
+ # Filter out items in skip_list
+ skipped_count = 0
+ for sha256 in metadata.index.values:
+ skip_key = f'{key}/{sha256}'
+ if skip_key in self.skip_set:
+ skipped_count += 1
+ else:
+ self.instances.append((root, sha256, key))
+ if skipped_count > 0:
+ self._stats[key]['Skipped (skip_list)'] = skipped_count
+ self._stats[key]['After skip_list'] = len(metadata) - skipped_count
+
+ self.metadata = pd.concat([self.metadata, metadata])
+ else:
+ for root in self.roots:
+ key = os.path.basename(root)
+ self._stats[key] = {}
+ metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
+ self._stats[key]['Total'] = len(metadata)
+ metadata, stats = self.filter_metadata(metadata, dataset_name=key)
+ self._stats[key].update(stats)
+
+ # Filter out items in skip_list
+ skipped_count = 0
+ for sha256 in metadata['sha256'].values:
+ skip_key = f'{key}/{sha256}'
+ if skip_key in self.skip_set:
+ skipped_count += 1
+ else:
+ self.instances.append((root, sha256, key))
+ if skipped_count > 0:
+ self._stats[key]['Skipped (skip_list)'] = skipped_count
+ self._stats[key]['After skip_list'] = len(metadata) - skipped_count
+ metadata.set_index('sha256', inplace=True)
+ self.metadata = pd.concat([self.metadata, metadata])
+
+ @abstractmethod
+ def filter_metadata(self, metadata: pd.DataFrame, dataset_name: str = None) -> Tuple[pd.DataFrame, Dict[str, int]]:
+ pass
+
+ @abstractmethod
+ def get_instance(self, root, instance: str) -> Dict[str, Any]:
+ pass
+
+ def __len__(self):
+ return len(self.instances)
+
+ def __getitem__(self, index) -> Dict[str, Any]:
+ try:
+ root, instance, dataset_name = self.instances[index]
+ pack = self.get_instance(root, instance)
+ pack['_dataset_name'] = dataset_name
+ pack['_sha256'] = instance
+ return pack
+ except Exception as e:
+ print(f'Error loading {self.instances[index][1]}: {e}')
+ return self.__getitem__(np.random.randint(0, len(self)))
+
+ def __str__(self):
+ lines = []
+ lines.append(self.__class__.__name__)
+ lines.append(f' - Total instances: {len(self)}')
+ lines.append(f' - Sources:')
+ for key, stats in self._stats.items():
+ lines.append(f' - {key}:')
+ for k, v in stats.items():
+ lines.append(f' - {k}: {v}')
+ return '\n'.join(lines)
+
+
+class ImageConditionedMixin:
+ def __init__(self, roots, *, image_size=518, **kwargs):
+ self.image_size = image_size
+ super().__init__(roots, **kwargs)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ metadata, stats = super().filter_metadata(metadata, dataset_name=dataset_name)
+ metadata = metadata[metadata['cond_rendered'].notna()]
+ stats['Cond rendered'] = len(metadata)
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ pack = super().get_instance(root, instance)
+
+ image_root = os.path.join(root['render_cond'], instance)
+ with open(os.path.join(image_root, 'transforms.json')) as f:
+ metadata = json.load(f)
+ n_views = len(metadata['frames'])
+ view = np.random.randint(n_views)
+ metadata = metadata['frames'][view]
+
+ image_path = os.path.join(image_root, metadata['file_path'])
+ image = Image.open(image_path)
+
+ alpha = np.array(image.getchannel(3))
+ bbox = np.array(alpha).nonzero()
+ bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
+ center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
+ hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
+ aug_hsize = hsize
+ aug_center_offset = [0, 0]
+ aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
+ aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
+ image = image.crop(aug_bbox)
+
+ image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+ alpha = image.getchannel(3)
+ image = image.convert('RGB')
+ image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
+ alpha = torch.tensor(np.array(alpha)).float() / 255.0
+ image = image * alpha.unsqueeze(0)
+ pack['cond'] = image
+
+ return pack
+
+
+class ViewImageConditionedMixin:
+ """
+ Mixin for view-based image-conditioned datasets.
+
+ This mixin is designed for datasets where ss_latent is stored per-view (view{XX}.npz),
+ and needs to load the corresponding view image and scale from view{XX}_scale.json.
+
+ Args:
+ image_size: Target image size
+ load_camera_info: Whether to load camera information for view-aligned conditioning
+ """
+ def __init__(self, roots, *, image_size=518, load_camera_info=False, **kwargs):
+ self.image_size = image_size
+ # self.load_camera_info = load_camera_info
+ super().__init__(roots, **kwargs)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ metadata, stats = super().filter_metadata(metadata, dataset_name=dataset_name)
+ metadata = metadata[metadata['cond_rendered'].notna()]
+ stats['Cond rendered'] = len(metadata)
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ """
+ Get instance with view-aligned image and camera info.
+
+ Expects parent class to set:
+ - pack['x_0']: the latent tensor
+ - self._current_view_idx: the selected view index
+ - self._current_latent_dir: the latent directory path
+ """
+ pack = super().get_instance(root, instance)
+
+ # Get view_idx from parent class (set by SparseStructureLatentView)
+ if not hasattr(self, '_current_view_idx'):
+ raise RuntimeError("Parent class must set '_current_view_idx' before calling ViewImageConditionedMixin.get_instance")
+ if not hasattr(self, '_current_latent_dir'):
+ raise RuntimeError("Parent class must set '_current_latent_dir' before calling ViewImageConditionedMixin.get_instance")
+ view_idx = self._current_view_idx
+ latent_dir = self._current_latent_dir
+
+ # Load image metadata
+ image_root = os.path.join(root['render_cond'], instance)
+ with open(os.path.join(image_root, 'transforms.json')) as f:
+ metadata = json.load(f)
+
+ # Load corresponding image for this view
+ frame_metadata = metadata['frames'][view_idx]
+ image_path = os.path.join(image_root, frame_metadata['file_path'])
+ image = Image.open(image_path)
+
+ image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+ alpha = image.getchannel(3)
+ image = image.convert('RGB')
+ image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
+ alpha = torch.tensor(np.array(alpha)).float() / 255.0
+ image = image * alpha.unsqueeze(0)
+ pack['cond'] = image
+
+ # Load camera info if requested
+
+ # camera_angle_x: check frame first, then root metadata
+ if 'camera_angle_x' in frame_metadata:
+ camera_angle_x = float(frame_metadata['camera_angle_x'])
+ elif 'camera_angle_x' in metadata:
+ camera_angle_x = float(metadata['camera_angle_x'])
+ else:
+ raise KeyError(f"'camera_angle_x' not found in transforms.json for {instance}")
+ pack['camera_angle_x'] = torch.tensor(camera_angle_x, dtype=torch.float32)
+
+ # transform_matrix
+ if 'transform_matrix' not in frame_metadata:
+ raise KeyError(f"'transform_matrix' not found in frame {view_idx} for {instance}")
+ transform_matrix = torch.tensor(frame_metadata['transform_matrix'], dtype=torch.float32)
+ distance = torch.norm(transform_matrix[:3, 3]).item()
+
+ pack['camera_distance'] = torch.tensor(distance, dtype=torch.float32)
+ # NOTE: Do NOT pass transform_matrix to ProjGrid.
+ # shape_latent space objects are already rotated to front-view by transform_mesh,
+ # so ProjGrid should use the default front_view_transform_matrix + distance.
+ # pack['transform_matrix'] = transform_matrix
+
+ # Load mesh_scale from ss_latent directory's view{XX}_scale.json
+ scale_json_path = os.path.join(latent_dir, f'view{view_idx:02d}_scale.json')
+ if not os.path.exists(scale_json_path):
+ raise FileNotFoundError(f"Scale file not found: {scale_json_path}")
+ with open(scale_json_path) as f:
+ scale_data = json.load(f)
+ if 'total_scale' not in scale_data:
+ raise KeyError(f"'total_scale' not found in {scale_json_path}")
+ pack['mesh_scale'] = torch.tensor(float(scale_data['total_scale']), dtype=torch.float32)
+
+ return pack
+
+
+class MultiImageConditionedMixin:
+ def __init__(self, roots, *, image_size=518, max_image_cond_view = 4, **kwargs):
+ self.image_size = image_size
+ self.max_image_cond_view = max_image_cond_view
+ super().__init__(roots, **kwargs)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ metadata, stats = super().filter_metadata(metadata, dataset_name=dataset_name)
+ metadata = metadata[metadata['cond_rendered'].notna()]
+ stats['Cond rendered'] = len(metadata)
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ pack = super().get_instance(root, instance)
+
+ image_root = os.path.join(root['render_cond'], instance)
+ with open(os.path.join(image_root, 'transforms.json')) as f:
+ metadata = json.load(f)
+
+ n_views = len(metadata['frames'])
+ n_sample_views = np.random.randint(1, self.max_image_cond_view+1)
+
+ assert n_views >= n_sample_views, f'Not enough views to sample {n_sample_views} unique images.'
+
+ sampled_views = np.random.choice(n_views, size=n_sample_views, replace=False)
+
+ cond_images = []
+ for v in sampled_views:
+ frame_info = metadata['frames'][v]
+ image_path = os.path.join(image_root, frame_info['file_path'])
+ image = Image.open(image_path)
+
+ alpha = np.array(image.getchannel(3))
+ bbox = np.array(alpha).nonzero()
+ bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
+ center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
+ hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
+ aug_hsize = hsize
+ aug_center = center
+ aug_bbox = [
+ int(aug_center[0] - aug_hsize),
+ int(aug_center[1] - aug_hsize),
+ int(aug_center[0] + aug_hsize),
+ int(aug_center[1] + aug_hsize),
+ ]
+
+ img = image.crop(aug_bbox)
+ img = img.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+ alpha = img.getchannel(3)
+ img = img.convert('RGB')
+ img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
+ alpha = torch.tensor(np.array(alpha)).float() / 255.0
+ img = img * alpha.unsqueeze(0)
+
+ cond_images.append(img)
+
+ pack['cond'] = [torch.stack(cond_images, dim=0)] # (V,3,H,W)
+ return pack
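A hedged sketch of how the mixins above compose through cooperative super() calls; ToyLatentDataset is hypothetical, and instantiating it would still require a dataset root with metadata.csv and render_cond images:

    import torch
    from trellis2.datasets.components import StandardDatasetBase, ImageConditionedMixin

    class ToyLatentDataset(StandardDatasetBase):
        def filter_metadata(self, metadata, dataset_name=None):
            return metadata, {}

        def get_instance(self, root, instance):
            return {'x_0': torch.zeros(8)}

    # MRO: ImageConditionedToy -> ImageConditionedMixin -> ToyLatentDataset -> StandardDatasetBase.
    # get_instance() first builds the latent pack in ToyLatentDataset, then the
    # mixin crops, resizes and alpha-composites the conditioning view into pack['cond'].
    class ImageConditionedToy(ImageConditionedMixin, ToyLatentDataset):
        pass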
diff --git a/trellis2/datasets/flexi_dual_grid.py b/trellis2/datasets/flexi_dual_grid.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3a935349c105116653acfe9a7ab4ef91820d38d
--- /dev/null
+++ b/trellis2/datasets/flexi_dual_grid.py
@@ -0,0 +1,173 @@
+import os
+import numpy as np
+import pickle
+import torch
+import utils3d
+from .components import StandardDatasetBase
+from ..modules import sparse as sp
+from ..renderers import MeshRenderer
+from ..representations import Mesh
+from ..utils.data_utils import load_balanced_group_indices
+import o_voxel
+
+
+class FlexiDualGridVisMixin:
+ @torch.no_grad()
+ def visualize_sample(self, x: dict):
+ mesh = x['mesh']
+
+ renderer = MeshRenderer({'near': 1, 'far': 3})
+ renderer.rendering_options.resolution = 512
+ renderer.rendering_options.ssaa = 4
+
+ # Build camera
+ yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+ yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
+ yaws = [y + yaws_offset for y in yaws]
+        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
+
+        exts = []
+        ints = []
+        for yaw, pitch in zip(yaws, pitches):
+ orig = torch.tensor([
+ np.sin(yaw) * np.cos(pitch),
+ np.cos(yaw) * np.cos(pitch),
+ np.sin(pitch),
+ ]).float().cuda() * 2
+ fov = torch.deg2rad(torch.tensor(30)).cuda()
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+ exts.append(extrinsics)
+ ints.append(intrinsics)
+
+ # Build each representation
+ images = []
+ for m in mesh:
+ image = torch.zeros(3, 1024, 1024).cuda()
+ tile = [2, 2]
+ for j, (ext, intr) in enumerate(zip(exts, ints)):
+ image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = \
+ renderer.render(m.cuda(), ext, intr)['normal']
+ images.append(image)
+ images = torch.stack(images)
+
+ return images
+
+
+class FlexiDualGridDataset(FlexiDualGridVisMixin, StandardDatasetBase):
+ """
+ Flexible Dual Grid Dataset
+
+ Args:
+ roots (str): path to the dataset
+ resolution (int): resolution of the voxel grid
+ min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
+ """
+
+ def __init__(
+ self,
+ roots,
+ resolution: int = 1024,
+ max_active_voxels: int = 1000000,
+ max_num_faces: int = None,
+ min_aesthetic_score: float = 5.0,
+ ):
+ self.resolution = resolution
+ self.min_aesthetic_score = min_aesthetic_score
+ self.max_active_voxels = max_active_voxels
+ self.max_num_faces = max_num_faces
+ self.value_range = (0, 1)
+
+ super().__init__(roots)
+
+ self.loads = [self.metadata.loc[sha256, f'dual_grid_size'] for _, sha256, _ in self.instances]
+
+ def __str__(self):
+ lines = [
+ super().__str__(),
+ f' - Resolution: {self.resolution}',
+ ]
+ return '\n'.join(lines)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ stats = {}
+ metadata = metadata[metadata[f'dual_grid_converted'] == True]
+ stats['Dual Grid Converted'] = len(metadata)
+ if self.min_aesthetic_score is not None:
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+ metadata = metadata[metadata[f'dual_grid_size'] <= self.max_active_voxels]
+ stats[f'Active Voxels <= {self.max_active_voxels}'] = len(metadata)
+ if self.max_num_faces is not None:
+ metadata = metadata[metadata['num_faces'] <= self.max_num_faces]
+ stats[f'Faces <= {self.max_num_faces}'] = len(metadata)
+ return metadata, stats
+
+ def read_mesh(self, root, instance):
+ with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f:
+ dump = pickle.load(f)
+ start = 0
+ vertices = []
+ faces = []
+ for obj in dump['objects']:
+ if obj['vertices'].size == 0 or obj['faces'].size == 0:
+ continue
+ vertices.append(obj['vertices'])
+ faces.append(obj['faces'] + start)
+ start += len(obj['vertices'])
+ vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float()
+ faces = torch.from_numpy(np.concatenate(faces, axis=0)).long()
+ vertices_min = vertices.min(dim=0)[0]
+ vertices_max = vertices.max(dim=0)[0]
+ center = (vertices_min + vertices_max) / 2
+ scale = 0.99999 / (vertices_max - vertices_min).max()
+ vertices = (vertices - center) * scale
+ assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
+ return {'mesh': [Mesh(vertices=vertices, faces=faces)]}
+
+ def read_dual_grid(self, root, instance):
+ coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4)
+ vertices = sp.SparseTensor(
+ (attr['vertices'] / 255.0).float(),
+ torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
+ )
+ intersected = vertices.replace(torch.cat([
+ attr['intersected'] % 2,
+ attr['intersected'] // 2 % 2,
+ attr['intersected'] // 4 % 2,
+ ], dim=-1).bool())
+ return {'vertices': vertices, 'intersected': intersected}
+
+ def get_instance(self, root, instance):
+ mesh = self.read_mesh(root['mesh_dump'], instance)
+ dual_grid = self.read_dual_grid(root['dual_grid'], instance)
+ return {**mesh, **dual_grid}
+
+ @staticmethod
+ def collate_fn(batch, split_size=None):
+ if split_size is None:
+ group_idx = [list(range(len(batch)))]
+ else:
+ group_idx = load_balanced_group_indices([b['vertices'].feats.shape[0] for b in batch], split_size)
+ packs = []
+ for group in group_idx:
+ sub_batch = [batch[i] for i in group]
+ pack = {}
+
+ keys = [k for k in sub_batch[0].keys()]
+ for k in keys:
+ if isinstance(sub_batch[0][k], torch.Tensor):
+ pack[k] = torch.stack([b[k] for b in sub_batch])
+ elif isinstance(sub_batch[0][k], sp.SparseTensor):
+ pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0)
+ elif isinstance(sub_batch[0][k], list):
+ pack[k] = sum([b[k] for b in sub_batch], [])
+ else:
+ pack[k] = [b[k] for b in sub_batch]
+
+ packs.append(pack)
+
+ if split_size is None:
+ return packs[0]
+ return packs
+
\ No newline at end of file
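A hedged sketch of the type dispatch in FlexiDualGridDataset.collate_fn above (dense tensors are stacked, sparse tensors concatenated, lists flattened); the toy batch is hypothetical, and importing the module requires its heavy dependencies (o_voxel, utils3d, etc.):

    import torch
    from trellis2.datasets.flexi_dual_grid import FlexiDualGridDataset

    batch = [{'x': torch.zeros(3)}, {'x': torch.ones(3)}]
    pack = FlexiDualGridDataset.collate_fn(batch)   # split_size=None -> single pack
    assert pack['x'].shape == (2, 3)                # tensors are stacked along dim 0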
diff --git a/trellis2/datasets/sparse_structure_latent.py b/trellis2/datasets/sparse_structure_latent.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd24d241e92c88591226f5719f09065db78bc2e8
--- /dev/null
+++ b/trellis2/datasets/sparse_structure_latent.py
@@ -0,0 +1,408 @@
+import os
+import json
+from typing import *
+import numpy as np
+import torch
+import utils3d
+from PIL import Image
+from ..representations import Voxel
+from ..renderers import VoxelRenderer
+from .components import StandardDatasetBase, ImageConditionedMixin, ViewImageConditionedMixin
+from .. import models
+from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics
+
+
+class SparseStructureLatentVisMixin:
+ def __init__(
+ self,
+ *args,
+ pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16.json',
+ ss_dec_path: Optional[str] = None,
+ ss_dec_ckpt: Optional[str] = None,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.ss_dec = None
+ self.pretrained_ss_dec = pretrained_ss_dec
+ self.ss_dec_path = ss_dec_path
+ self.ss_dec_ckpt = ss_dec_ckpt
+
+ def _loading_ss_dec(self):
+ if self.ss_dec is not None:
+ return
+ if self.ss_dec_path is not None:
+ cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r'))
+ decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
+ ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
+ decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
+ else:
+ decoder = models.from_pretrained(self.pretrained_ss_dec)
+ self.ss_dec = decoder.cuda().eval()
+
+ def _delete_ss_dec(self):
+ del self.ss_dec
+ self.ss_dec = None
+
+ @torch.no_grad()
+ def decode_latent(self, z, batch_size=4):
+ self._loading_ss_dec()
+ ss = []
+ if self.normalization:
+ z = z * self.std.to(z.device) + self.mean.to(z.device)
+ for i in range(0, z.shape[0], batch_size):
+ ss.append(self.ss_dec(z[i:i+batch_size]))
+ ss = torch.cat(ss, dim=0)
+ self._delete_ss_dec()
+ return ss
+
+ @torch.no_grad()
+ def visualize_sample(
+ self,
+ x_0: Union[torch.Tensor, dict],
+ camera_angle_x: Optional[torch.Tensor] = None,
+ camera_distance: Optional[torch.Tensor] = None,
+ mesh_scale: Optional[torch.Tensor] = None,
+ ):
+ """
+ Visualize sparse structure samples.
+
+ Args:
+ x_0: Latent tensor [B, C, D, H, W] or dict containing 'x_0'
+ camera_angle_x: Optional [B] camera FOV angle in radians
+ camera_distance: Optional [B] camera distance for GT view rendering
+ mesh_scale: Optional [B] mesh scale factor for coordinate alignment
+
+ Returns:
+ dict with:
+ 'multiview': [B, 3, 1024, 1024] - 4 fixed views rendered in 2x2 grid
+ 'gt_view': [B, 3, 512, 512] - GT camera view (if camera params provided)
+ """
+ x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
+ x_0 = self.decode_latent(x_0.cuda())
+
+ renderer = VoxelRenderer()
+ renderer.rendering_options.resolution = 512
+ renderer.rendering_options.ssaa = 4
+
+ # Build fixed camera views (4 views: 0°, 90°, 180°, 270°)
+ yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
+ yaw_offset = -16 / 180 * np.pi
+ yaw = [y + yaw_offset for y in yaw]
+ pitch = [20 / 180 * np.pi for _ in range(4)]
+ fixed_exts, fixed_ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)
+
+ # Check if we have GT camera parameters for front view rendering
+ # GT view uses the fixed front_view_transform_matrix from image_conditioned_proj.py
+ has_gt_camera = (
+ camera_angle_x is not None and
+ camera_distance is not None and
+ mesh_scale is not None
+ )
+
+ multiview_images = []
+ gt_view_images = []
+
+ # Build each representation
+ x_0 = x_0.cuda()
+ for i in range(x_0.shape[0]):
+ coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
+ resolution = x_0.shape[-1]
+ color = coords / resolution
+
+ # Standard voxel for fixed multiview rendering (origin at [-0.5, -0.5, -0.5])
+ rep = Voxel(
+ origin=[-0.5, -0.5, -0.5],
+ voxel_size=1/resolution,
+ coords=coords,
+ attrs=color,
+ layout={
+ 'color': slice(0, 3),
+ }
+ )
+
+ # Render 4 fixed views (2x2 grid)
+ image = torch.zeros(3, 1024, 1024).cuda()
+ tile = [2, 2]
+ for j, (ext, intr) in enumerate(zip(fixed_exts, fixed_ints)):
+ res = renderer.render(rep, ext, intr, colors_overwrite=color)
+ image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
+ multiview_images.append(image)
+
+ # Render GT camera view using the fixed front view from image_conditioned_proj.py
+ if has_gt_camera:
+ # The GT view should match exactly how ProjGrid projects 3D points to 2D.
+ #
+ # In image_conditioned_proj.py (ProjGrid.forward):
+ # 1. grid_points are in [-1, 1]^3 (from torch.linspace(-1, 1, res))
+ # 2. grid_points are rotated by rotation_matrix (Y-Z swap): x'=x, y'=-z, z'=y
+ # 3. grid_points are scaled: grid_points / mesh_scale / 2
+ # 4. Points are projected using front_view_transform_matrix with distance
+ #
+ # front_view_transform_matrix (camera-to-world):
+ # [[1, 0, 0, 0],
+ # [0, 0, -1, -distance],
+ # [0, 1, 0, 0],
+ # [0, 0, 0, 1]]
+ #
+ # Camera is at (0, -distance, 0) in Blender coords (Z-up), looking at origin.
+ #
+ # To match this in VoxelRenderer:
+ # 1. Voxel coords [0, res-1] map to positions via: pos = (coords + 0.5) * voxel_size + origin
+ # 2. We need these positions to match ProjGrid's transformed grid_points
+ # 3. Apply rotation by swapping/flipping coords, then scale voxel_size and origin
+
+ scale = mesh_scale[i].item()
+ distance = camera_distance[i].item()
+ fov = camera_angle_x[i].item()
+
+ # Coordinate transformation to match ProjGrid's rotation (x'=x, y'=-z, z'=y)
+ # new_coords maps to rotated positions in the same grid structure
+ new_coords = torch.zeros_like(coords)
+ new_coords[:, 0] = coords[:, 0] # x stays
+ new_coords[:, 1] = (resolution - 1) - coords[:, 2] # y' = -z (flip for negation)
+ new_coords[:, 2] = coords[:, 1] # z' = y
+
+ # Voxel position calculation:
+ # Original: pos = (coords + 0.5) / res - 0.5 -> range [-0.5, 0.5]
+ # We need: pos = (coords + 0.5) * 2 / res - 1 -> range [-1, 1] (like ProjGrid)
+ # Then: pos_final = pos / scale / 2 -> range [-0.5/scale, 0.5/scale]
+ #
+ # Combined: pos_final = ((coords + 0.5) * 2 / res - 1) / scale / 2
+ # = (coords + 0.5) / res / scale - 0.5 / scale
+ # = (coords + 0.5) * voxel_size + origin
+ # where: voxel_size = 1 / res / scale
+ # origin = -0.5 / scale
+
+ scaled_voxel_size = 1.0 / resolution / scale
+ scaled_origin = [-0.5 / scale, -0.5 / scale, -0.5 / scale]
+
+ rep_scaled = Voxel(
+ origin=scaled_origin,
+ voxel_size=scaled_voxel_size,
+ coords=new_coords,
+ attrs=color,
+ layout={
+ 'color': slice(0, 3),
+ }
+ )
+
+ # Build the fixed front view camera (same as front_view_transform_matrix)
+ # Camera at (0, -distance, 0), looking at origin, up is Z
+ cam_pos = torch.tensor([0.0, -distance, 0.0], device=coords.device)
+ look_at = torch.tensor([0.0, 0.0, 0.0], device=coords.device)
+ cam_up = torch.tensor([0.0, 0.0, 1.0], device=coords.device)
+
+ gt_ext = utils3d.torch.extrinsics_look_at(cam_pos, look_at, cam_up)
+ gt_int = utils3d.torch.intrinsics_from_fov_xy(
+ torch.tensor(fov, device=coords.device),
+ torch.tensor(fov, device=coords.device)
+ )
+
+ # Ensure tensors are on the correct device (utils3d may not preserve device)
+ gt_ext = gt_ext.to(coords.device)
+ gt_int = gt_int.to(coords.device)
+
+ gt_res = renderer.render(rep_scaled, gt_ext, gt_int, colors_overwrite=color)
+ gt_view_images.append(gt_res['color'])
+
+ result = {
+ 'multiview': torch.stack(multiview_images),
+ }
+
+ if has_gt_camera and len(gt_view_images) > 0:
+ result['gt_view'] = torch.stack(gt_view_images)
+
+ return result
+
+
+class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
+ """
+ Sparse structure latent dataset
+
+ Args:
+ roots (str): path to the dataset
+ min_aesthetic_score (float): minimum aesthetic score
+ normalization (dict): normalization stats
+ pretrained_ss_dec (str): name of the pretrained sparse structure decoder
+ ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
+ ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
+ skip_list (str, optional): path to a file containing sha256 hashes to skip
+ skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check
+ """
+ def __init__(self,
+ roots: str,
+ *,
+ min_aesthetic_score: float = 5.0,
+ normalization: Optional[dict] = None,
+ pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
+ ss_dec_path: Optional[str] = None,
+ ss_dec_ckpt: Optional[str] = None,
+ skip_list: Optional[str] = None,
+ skip_aesthetic_score_datasets: Optional[list] = None,
+ ):
+ self.min_aesthetic_score = min_aesthetic_score
+ self.normalization = normalization
+ self.value_range = (0, 1)
+
+ super().__init__(
+ roots,
+ pretrained_ss_dec=pretrained_ss_dec,
+ ss_dec_path=ss_dec_path,
+ ss_dec_ckpt=ss_dec_ckpt,
+ skip_list=skip_list,
+ skip_aesthetic_score_datasets=skip_aesthetic_score_datasets,
+ )
+
+ if self.normalization is not None:
+ self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
+ self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ stats = {}
+ metadata = metadata[metadata['ss_latent_encoded'] == True]
+ stats['With latent'] = len(metadata)
+ # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist
+ skip_aesthetic = (
+ (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or
+ ('aesthetic_score' not in metadata.columns)
+ )
+ if skip_aesthetic:
+ stats[f'Aesthetic score check skipped'] = len(metadata)
+ else:
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ latent = np.load(os.path.join(root['ss_latent'], f'{instance}.npz'))
+ z = torch.tensor(latent['z']).float()
+ if self.normalization is not None:
+ z = (z - self.mean) / self.std
+
+ pack = {
+ 'x_0': z,
+ }
+ return pack
+
+
+class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent):
+ """
+ Image-conditioned sparse structure dataset
+ """
+ pass
+
+
+class SparseStructureLatentView(SparseStructureLatentVisMixin, StandardDatasetBase):
+ """
+ View-based sparse structure latent dataset.
+
+ Data format: {sha256}/view{XX}.npz where each npz contains 'z' key.
+
+ Args:
+ num_views (int): Number of views to use (0 to num_views-1). Default is 2.
+ skip_list (str, optional): path to a file containing sha256 hashes to skip
+ skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check
+ """
+ def __init__(self,
+ roots: str,
+ *,
+ min_aesthetic_score: float = 5.0,
+ normalization: Optional[dict] = None,
+ num_views: int = 2,
+ pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
+ ss_dec_path: Optional[str] = None,
+ ss_dec_ckpt: Optional[str] = None,
+ skip_list: Optional[str] = None,
+ skip_aesthetic_score_datasets: Optional[list] = None,
+ ):
+ self.min_aesthetic_score = min_aesthetic_score
+ self.normalization = normalization
+ self.num_views = num_views
+ self.value_range = (0, 1)
+
+ super().__init__(
+ roots,
+ pretrained_ss_dec=pretrained_ss_dec,
+ ss_dec_path=ss_dec_path,
+ ss_dec_ckpt=ss_dec_ckpt,
+ skip_list=skip_list,
+ skip_aesthetic_score_datasets=skip_aesthetic_score_datasets,
+ )
+
+ if self.normalization is not None:
+ self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
+ self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ stats = {}
+ # View-based ss_latent uses columns like:
+ # ss_latent_view00_encoded, ss_latent_view01_encoded, ... (view format)
+ # ss_latent_view_scale00_encoded, ss_latent_view_scale01_encoded, ... (view_scale format)
+ # Check both formats and use whichever exists (prefer view_scale over view)
+ required_view_cols = [f'ss_latent_view_scale{i:02d}_encoded' for i in range(self.num_views)]
+ existing_view_cols = [col for col in required_view_cols if col in metadata.columns]
+
+ if not existing_view_cols:
+ # Fallback to view format
+ required_view_cols = [f'ss_latent_view{i:02d}_encoded' for i in range(self.num_views)]
+ existing_view_cols = [col for col in required_view_cols if col in metadata.columns]
+
+ if existing_view_cols:
+ # Filter rows where all required views are encoded
+            # Note: NaN must be treated as False, so compare explicitly with == True
+ has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
+ metadata = metadata[has_all_views]
+ stats[f'With {self.num_views} view latents'] = len(metadata)
+ else:
+ # Fallback: check ss_latent_encoded column
+ if 'ss_latent_encoded' in metadata.columns:
+ metadata = metadata[metadata['ss_latent_encoded'] == True]
+ stats['With latent'] = len(metadata)
+ else:
+ raise ValueError(f'No view columns found in metadata: {metadata.columns.tolist()}')
+ # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist
+ skip_aesthetic = (
+ (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or
+ ('aesthetic_score' not in metadata.columns)
+ )
+ if skip_aesthetic:
+            stats['Aesthetic score check skipped'] = len(metadata)
+ else:
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ # View-based format: directory with view{XX}.npz files
+ latent_dir = os.path.join(root['ss_latent'], instance)
+
+ # Randomly select a view from the configured range
+ view_idx = np.random.randint(0, self.num_views)
+ view_file = f'view{view_idx:02d}.npz'
+
+ # Store view info for ViewImageConditionedMixin
+ self._current_view_idx = view_idx
+ self._current_latent_dir = latent_dir
+
+ latent = np.load(os.path.join(latent_dir, view_file))
+ z = torch.tensor(latent['z']).float()
+ if self.normalization is not None:
+ z = (z - self.mean) / self.std
+
+ pack = {
+ 'x_0': z,
+ 'view_idx': view_idx,
+ }
+ return pack
+
+
+class ViewImageConditionedSparseStructureLatentView(ViewImageConditionedMixin, SparseStructureLatentView):
+ """
+ Image-conditioned view-based sparse structure dataset.
+
+ Loads ss_latent from {sha256}/view{XX}.npz format and pairs with
+ corresponding view from render_cond.
+
+ Uses ViewImageConditionedMixin which reads mesh_scale from view{XX}_scale.json.
+ """
+ pass
diff --git a/trellis2/datasets/sparse_voxel_pbr.py b/trellis2/datasets/sparse_voxel_pbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9146b164111df390665af6dfeab19a9a2ca64e1
--- /dev/null
+++ b/trellis2/datasets/sparse_voxel_pbr.py
@@ -0,0 +1,298 @@
+import os
+import io
+from typing import Union
+import numpy as np
+import pickle
+import torch
+from PIL import Image
+import o_voxel
+import utils3d
+from .components import StandardDatasetBase
+from ..modules import sparse as sp
+from ..renderers import VoxelRenderer
+from ..representations import Voxel
+from ..representations.mesh import MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
+
+from ..utils.data_utils import load_balanced_group_indices
+
+
+def is_power_of_two(n: int) -> bool:
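+    # A positive power of two has exactly one set bit, so n & (n - 1) clears it and yields zero.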
+ return n > 0 and (n & (n - 1)) == 0
+
+
+def nearest_power_of_two(n: int) -> int:
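+    # Round n to the closest power of two (ties round up); used to snap texture sizes to power-of-two resolutions.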
+ if n < 1:
+ raise ValueError("n must be >= 1")
+ if is_power_of_two(n):
+ return n
+ lower = 2 ** (n.bit_length() - 1)
+ upper = 2 ** n.bit_length()
+ if n - lower < upper - n:
+ return lower
+ else:
+ return upper
+
+
+class SparseVoxelPbrVisMixin:
+ @torch.no_grad()
+ def visualize_sample(self, x: Union[sp.SparseTensor, dict]):
+ x = x if isinstance(x, sp.SparseTensor) else x['x']
+
+ renderer = VoxelRenderer()
+ renderer.rendering_options.resolution = 512
+ renderer.rendering_options.ssaa = 4
+
+ # Build camera
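+        # Four azimuths 90 degrees apart with a shared random yaw offset; each view gets its own random pitch.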
+ yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+ yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
+ yaws = [y + yaws_offset for y in yaws]
+        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
+
+ exts = []
+ ints = []
+        for yaw, pitch in zip(yaws, pitches):
+ orig = torch.tensor([
+ np.sin(yaw) * np.cos(pitch),
+ np.cos(yaw) * np.cos(pitch),
+ np.sin(pitch),
+ ]).float().cuda() * 2
+ fov = torch.deg2rad(torch.tensor(30)).cuda()
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+ exts.append(extrinsics)
+ ints.append(intrinsics)
+
+ images = {k: [] for k in self.layout}
+
+ # Build each representation
+ x = x.cuda()
+ for i in range(x.shape[0]):
+ rep = Voxel(
+ origin=[-0.5, -0.5, -0.5],
+ voxel_size=1/self.resolution,
+ coords=x[i].coords[:, 1:].contiguous(),
+ attrs=None,
+ layout={
+ 'color': slice(0, 3),
+ }
+ )
+ for k in self.layout:
+ image = torch.zeros(3, 1024, 1024).cuda()
+ tile = [2, 2]
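+                # Render the four views at 512x512 and tile them into a single 1024x1024 image (2x2 grid).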
+ for j, (ext, intr) in enumerate(zip(exts, ints)):
+ attr = x[i].feats[:, self.layout[k]].expand(-1, 3)
+ res = renderer.render(rep, ext, intr, colors_overwrite=attr)
+ image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
+ images[k].append(image)
+
+ for k in self.layout:
+ images[k] = torch.stack(images[k])
+
+ return images
+
+
+class SparseVoxelPbrDataset(SparseVoxelPbrVisMixin, StandardDatasetBase):
+ """
+ Sparse Voxel PBR dataset.
+
+ Args:
+        roots (str): path to the dataset
+        resolution (int): resolution of the voxel grid
+        max_active_voxels (int): maximum number of active voxels per instance
+        max_num_faces (int, optional): maximum number of mesh faces per instance
+        min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
+        attrs (list): PBR attributes to be loaded
+        with_mesh (bool): whether to also load the textured mesh
+ """
+
+ def __init__(
+ self,
+ roots,
+ resolution: int = 1024,
+ max_active_voxels: int = 1000000,
+ max_num_faces: int = None,
+ min_aesthetic_score: float = 5.0,
+ attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'],
+ with_mesh: bool = True,
+ ):
+ self.resolution = resolution
+ self.min_aesthetic_score = min_aesthetic_score
+ self.max_active_voxels = max_active_voxels
+ self.max_num_faces = max_num_faces
+ self.with_mesh = with_mesh
+ self.value_range = (-1, 1)
+ self.channels = {
+ 'base_color': 3,
+ 'metallic': 1,
+ 'roughness': 1,
+ 'emissive': 3,
+ 'alpha': 1,
+ }
+ self.layout = {}
+ start = 0
+ for attr in attrs:
+ self.layout[attr] = slice(start, start + self.channels[attr])
+ start += self.channels[attr]
+
+ super().__init__(roots)
+
+        self.loads = [self.metadata.loc[sha256, 'num_pbr_voxels'] for _, sha256, _ in self.instances]
+
+ def __str__(self):
+ lines = [
+ super().__str__(),
+ f' - Resolution: {self.resolution}',
+ f' - Attributes: {list(self.layout.keys())}',
+ ]
+ return '\n'.join(lines)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ stats = {}
+ metadata = metadata[metadata['pbr_voxelized'] == True]
+ stats['PBR Voxelized'] = len(metadata)
+ if self.min_aesthetic_score is not None:
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+ metadata = metadata[metadata['num_pbr_voxels'] <= self.max_active_voxels]
+ stats[f'Active voxels <= {self.max_active_voxels}'] = len(metadata)
+ if self.max_num_faces is not None:
+ metadata = metadata[metadata['num_faces'] <= self.max_num_faces]
+ stats[f'Faces <= {self.max_num_faces}'] = len(metadata)
+ return metadata, stats
+
+ @staticmethod
+ def _texture_from_dump(pack) -> Texture:
+ png_bytes = pack['image']
+ image = Image.open(io.BytesIO(png_bytes))
+ if image.width != image.height or not is_power_of_two(image.width):
+ size = nearest_power_of_two(max(image.width, image.height))
+ image = image.resize((size, size), Image.LANCZOS)
+ texture = torch.tensor(np.array(image) / 255.0, dtype=torch.float32).reshape(image.height, image.width, -1)
+ filter_mode = {
+ 'Linear': TextureFilterMode.LINEAR,
+ 'Closest': TextureFilterMode.CLOSEST,
+ 'Cubic': TextureFilterMode.LINEAR,
+ 'Smart': TextureFilterMode.LINEAR,
+ }[pack['interpolation']]
+ wrap_mode = {
+ 'REPEAT': TextureWrapMode.REPEAT,
+ 'EXTEND': TextureWrapMode.CLAMP_TO_EDGE,
+ 'CLIP': TextureWrapMode.CLAMP_TO_EDGE,
+ 'MIRROR': TextureWrapMode.MIRRORED_REPEAT,
+ }[pack['extension']]
+ return Texture(texture, filter_mode=filter_mode, wrap_mode=wrap_mode)
+
+ def read_mesh_with_texture(self, root, instance):
+ with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f:
+ dump = pickle.load(f)
+
+ # Fix dump alpha map
+ for mat in dump['materials']:
+ if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE':
+ mat['alphaMode'] = 'BLEND'
+
+ # process material
+ materials = []
+ for mat in dump['materials']:
+ materials.append(PbrMaterial(
+ base_color_texture=self._texture_from_dump(mat['baseColorTexture']) if mat['baseColorTexture'] is not None else None,
+ base_color_factor=mat['baseColorFactor'],
+ metallic_texture=self._texture_from_dump(mat['metallicTexture']) if mat['metallicTexture'] is not None else None,
+ metallic_factor=mat['metallicFactor'],
+ roughness_texture=self._texture_from_dump(mat['roughnessTexture']) if mat['roughnessTexture'] is not None else None,
+ roughness_factor=mat['roughnessFactor'],
+ alpha_texture=self._texture_from_dump(mat['alphaTexture']) if mat['alphaTexture'] is not None else None,
+ alpha_factor=mat['alphaFactor'],
+ alpha_mode={
+ 'OPAQUE': AlphaMode.OPAQUE,
+ 'MASK': AlphaMode.MASK,
+ 'BLEND': AlphaMode.BLEND,
+ }[mat['alphaMode']],
+ alpha_cutoff=mat['alphaCutoff'],
+ ))
+ materials.append(PbrMaterial(
+ base_color_factor=[0.8, 0.8, 0.8],
+ alpha_factor=1.0,
+ metallic_factor=0.0,
+ roughness_factor=0.5,
+ alpha_mode=AlphaMode.OPAQUE,
+ alpha_cutoff=0.5,
+ )) # append default material
+
+ # process mesh
+ start = 0
+ vertices = []
+ faces = []
+ material_ids = []
+ uv_coords = []
+ for obj in dump['objects']:
+ if obj['vertices'].size == 0 or obj['faces'].size == 0:
+ continue
+ vertices.append(obj['vertices'])
+ faces.append(obj['faces'] + start)
+ obj['mat_ids'][obj['mat_ids'] == -1] = len(materials) - 1
+ material_ids.append(obj['mat_ids'])
+ uv_coords.append(obj['uvs'] if obj['uvs'] is not None else np.zeros((obj['faces'].shape[0], 3, 2), dtype=np.float32))
+ start += len(obj['vertices'])
+
+ vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float()
+ faces = torch.from_numpy(np.concatenate(faces, axis=0)).long()
+ material_ids = torch.from_numpy(np.concatenate(material_ids, axis=0)).long()
+ uv_coords = torch.from_numpy(np.concatenate(uv_coords, axis=0)).float()
+
+ # Normalize vertices
+ vertices_min = vertices.min(dim=0)[0]
+ vertices_max = vertices.max(dim=0)[0]
+ center = (vertices_min + vertices_max) / 2
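+        # The 0.99999 factor keeps the longest extent just under 1 so every vertex lies strictly inside [-0.5, 0.5].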
+ scale = 0.99999 / (vertices_max - vertices_min).max()
+ vertices = (vertices - center) * scale
+ assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
+
+ return {'mesh': [MeshWithPbrMaterial(
+ vertices=vertices,
+ faces=faces,
+ material_ids=material_ids,
+ uv_coords=uv_coords,
+ materials=materials,
+ )]}
+
+ def read_pbr_voxel(self, root, instance):
+ coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4)
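+        # Concatenate the attributes in layout order and map uint8 values from [0, 255] to [-1, 1].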
+ feats = torch.concat([attr[k] for k in self.layout], dim=-1) / 255.0 * 2 - 1
+ x = sp.SparseTensor(
+ feats.float(),
+ torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
+ )
+ return {'x': x}
+
+ def get_instance(self, root, instance):
+ if self.with_mesh:
+ mesh = self.read_mesh_with_texture(root['pbr_dump'], instance)
+ pbr_voxel = self.read_pbr_voxel(root['pbr_voxel'], instance)
+ return {**mesh, **pbr_voxel}
+ else:
+ return self.read_pbr_voxel(root['pbr_voxel'], instance)
+
+ @staticmethod
+ def collate_fn(batch, split_size=None):
+ if split_size is None:
+ group_idx = [list(range(len(batch)))]
+ else:
+ group_idx = load_balanced_group_indices([b['x'].feats.shape[0] for b in batch], split_size)
+ packs = []
+ for group in group_idx:
+ sub_batch = [batch[i] for i in group]
+ pack = {}
+
+ keys = [k for k in sub_batch[0].keys()]
+ for k in keys:
+ if isinstance(sub_batch[0][k], torch.Tensor):
+ pack[k] = torch.stack([b[k] for b in sub_batch])
+ elif isinstance(sub_batch[0][k], sp.SparseTensor):
+ pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0)
+ elif isinstance(sub_batch[0][k], list):
+ pack[k] = sum([b[k] for b in sub_batch], [])
+ else:
+ pack[k] = [b[k] for b in sub_batch]
+
+ packs.append(pack)
+
+ if split_size is None:
+ return packs[0]
+ return packs
diff --git a/trellis2/datasets/structured_latent.py b/trellis2/datasets/structured_latent.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cc0bccd4128187cb948c259a80b892a6d715a76
--- /dev/null
+++ b/trellis2/datasets/structured_latent.py
@@ -0,0 +1,224 @@
+import json
+import os
+from typing import *
+import numpy as np
+import torch
+import utils3d.torch
+from .components import StandardDatasetBase, ImageConditionedMixin
+from ..modules.sparse.basic import SparseTensor
+from .. import models
+from ..utils.render_utils import get_renderer
+from ..utils.data_utils import load_balanced_group_indices
+
+
+class SLatVisMixin:
+ def __init__(
+ self,
+ *args,
+ pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
+ slat_dec_path: Optional[str] = None,
+ slat_dec_ckpt: Optional[str] = None,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.slat_dec = None
+ self.pretrained_slat_dec = pretrained_slat_dec
+ self.slat_dec_path = slat_dec_path
+ self.slat_dec_ckpt = slat_dec_ckpt
+
+ def _loading_slat_dec(self):
+ if self.slat_dec is not None:
+ return
+ if self.slat_dec_path is not None:
+ cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
+ decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
+ ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
+ decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
+ else:
+ decoder = models.from_pretrained(self.pretrained_slat_dec)
+ self.slat_dec = decoder.cuda().eval()
+
+ def _delete_slat_dec(self):
+ del self.slat_dec
+ self.slat_dec = None
+
+ @torch.no_grad()
+ def decode_latent(self, z, batch_size=4):
+ self._loading_slat_dec()
+ reps = []
+ if self.normalization is not None:
+ z = z * self.std.to(z.device) + self.mean.to(z.device)
+ for i in range(0, z.shape[0], batch_size):
+ reps.append(self.slat_dec(z[i:i+batch_size]))
+ reps = sum(reps, [])
+ self._delete_slat_dec()
+ return reps
+
+ @torch.no_grad()
+ def visualize_sample(self, x_0: Union[SparseTensor, dict]):
+ x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
+ reps = self.decode_latent(x_0.cuda())
+
+ # Build camera
+ yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+ yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
+ yaws = [y + yaws_offset for y in yaws]
+        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
+
+ exts = []
+ ints = []
+        for yaw, pitch in zip(yaws, pitches):
+ orig = torch.tensor([
+ np.sin(yaw) * np.cos(pitch),
+ np.cos(yaw) * np.cos(pitch),
+ np.sin(pitch),
+ ]).float().cuda() * 2
+ fov = torch.deg2rad(torch.tensor(40)).cuda()
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+ exts.append(extrinsics)
+ ints.append(intrinsics)
+
+ renderer = get_renderer(reps[0])
+ images = []
+ for representation in reps:
+ image = torch.zeros(3, 1024, 1024).cuda()
+ tile = [2, 2]
+ for j, (ext, intr) in enumerate(zip(exts, ints)):
+ res = renderer.render(representation, ext, intr)
+ image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
+ images.append(image)
+ images = torch.stack(images)
+
+ return images
+
+
+class SLat(SLatVisMixin, StandardDatasetBase):
+ """
+ structured latent V2 dataset
+
+ Args:
+ roots (str): path to the dataset
+ min_aesthetic_score (float): minimum aesthetic score
+ max_tokens (int): maximum number of tokens
+ latent_key (str): key of the latent to be used
+ normalization (dict): normalization stats
+ pretrained_slat_dec (str): name of the pretrained slat decoder
+ slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
+ slat_dec_ckpt (str): name of the slat decoder checkpoint
+ skip_list (str, optional): path to a file containing sha256 hashes to skip
+ skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check
+ """
+ def __init__(self,
+ roots: str,
+ *,
+ min_aesthetic_score: float = 5.0,
+ max_tokens: int = 32768,
+ latent_key: str = 'shape_latent',
+ normalization: Optional[dict] = None,
+ pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
+ slat_dec_path: Optional[str] = None,
+ slat_dec_ckpt: Optional[str] = None,
+ skip_list: Optional[str] = None,
+ skip_aesthetic_score_datasets: Optional[list] = None,
+ ):
+ self.normalization = normalization
+ self.min_aesthetic_score = min_aesthetic_score
+ self.max_tokens = max_tokens
+ self.latent_key = latent_key
+ self.value_range = (0, 1)
+
+ super().__init__(
+ roots,
+ pretrained_slat_dec=pretrained_slat_dec,
+ slat_dec_path=slat_dec_path,
+ slat_dec_ckpt=slat_dec_ckpt,
+ skip_list=skip_list,
+ skip_aesthetic_score_datasets=skip_aesthetic_score_datasets,
+ )
+
+ self.loads = [self.metadata.loc[sha256, f'{latent_key}_tokens'] for _, sha256, _ in self.instances]
+
+ if self.normalization is not None:
+ self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
+ self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ stats = {}
+ metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True]
+ stats['With latent'] = len(metadata)
+ # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist
+ skip_aesthetic = (
+ (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or
+ ('aesthetic_score' not in metadata.columns)
+ )
+ if skip_aesthetic:
+            stats['Aesthetic score check skipped'] = len(metadata)
+ else:
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+ metadata = metadata[metadata[f'{self.latent_key}_tokens'] <= self.max_tokens]
+ stats[f'Num tokens <= {self.max_tokens}'] = len(metadata)
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ data = np.load(os.path.join(root[self.latent_key], f'{instance}.npz'))
+ coords = torch.tensor(data['coords']).int()
+ feats = torch.tensor(data['feats']).float()
+ if self.normalization is not None:
+ feats = (feats - self.mean) / self.std
+ return {
+ 'coords': coords,
+ 'feats': feats,
+ }
+
+ @staticmethod
+ def collate_fn(batch, split_size=None):
+ if split_size is None:
+ group_idx = [list(range(len(batch)))]
+ else:
+ group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
+ packs = []
+ for group in group_idx:
+ sub_batch = [batch[i] for i in group]
+ pack = {}
+ coords = []
+ feats = []
+ layout = []
+ start = 0
+ for i, b in enumerate(sub_batch):
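+                # Prepend the in-group batch index so each coordinate row becomes (batch_idx, *spatial_coords).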
+ coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
+ feats.append(b['feats'])
+ layout.append(slice(start, start + b['coords'].shape[0]))
+ start += b['coords'].shape[0]
+ coords = torch.cat(coords)
+ feats = torch.cat(feats)
+ pack['x_0'] = SparseTensor(
+ coords=coords,
+ feats=feats,
+ )
+ pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
+ pack['x_0'].register_spatial_cache('layout', layout)
+
+ # collate other data
+ keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
+ for k in keys:
+ if isinstance(sub_batch[0][k], torch.Tensor):
+ pack[k] = torch.stack([b[k] for b in sub_batch])
+ elif isinstance(sub_batch[0][k], list):
+ pack[k] = sum([b[k] for b in sub_batch], [])
+ else:
+ pack[k] = [b[k] for b in sub_batch]
+
+ packs.append(pack)
+
+ if split_size is None:
+ return packs[0]
+ return packs
+
+
+class ImageConditionedSLat(ImageConditionedMixin, SLat):
+ """
+ Image conditioned structured latent dataset
+ """
+ pass
diff --git a/trellis2/datasets/structured_latent_shape.py b/trellis2/datasets/structured_latent_shape.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b40e0a61a1caadb34eecf91bf38d6c8128e2c5
--- /dev/null
+++ b/trellis2/datasets/structured_latent_shape.py
@@ -0,0 +1,402 @@
+import os
+import json
+from typing import *
+import numpy as np
+import torch
+import utils3d
+from .. import models
+from .components import ImageConditionedMixin, ViewImageConditionedMixin
+from ..modules.sparse import SparseTensor
+from .structured_latent import SLatVisMixin, SLat
+from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics
+from ..utils.data_utils import load_balanced_group_indices
+
+
+class SLatShapeVisMixin(SLatVisMixin):
+ def _loading_slat_dec(self):
+ if self.slat_dec is not None:
+ return
+ if self.slat_dec_path is not None:
+ cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
+ decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
+ ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
+ decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
+ else:
+ decoder = models.from_pretrained(self.pretrained_slat_dec)
+ decoder.set_resolution(self.resolution)
+ self.slat_dec = decoder.cuda().eval()
+
+ @torch.no_grad()
+ def visualize_sample(
+ self,
+ x_0: Union[SparseTensor, dict],
+ camera_angle_x: Optional[torch.Tensor] = None,
+ camera_distance: Optional[torch.Tensor] = None,
+ mesh_scale: Optional[torch.Tensor] = None,
+ ):
+ """
+ Visualize shape samples.
+
+ Args:
+ x_0: SparseTensor or dict containing 'x_0'
+ camera_angle_x: Optional [B] camera FOV angle in radians
+ camera_distance: Optional [B] camera distance for GT view rendering
+ mesh_scale: Optional [B] mesh scale factor for coordinate alignment
+
+ Returns:
+ dict with:
+ 'multiview': [B, 3, 1024, 1024] - 4 fixed views rendered in 2x2 grid (normal)
+ 'gt_view': [B, 3, 512, 512] - GT camera view (if camera params provided)
+ """
+ x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
+ reps = self.decode_latent(x_0.cuda())
+
+ # build fixed camera views (4 views: 0°, 90°, 180°, 270°)
+ yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
+ yaw_offset = -16 / 180 * np.pi
+ yaw = [y + yaw_offset for y in yaw]
+ pitch = [20 / 180 * np.pi for _ in range(4)]
+ fixed_exts, fixed_ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)
+
+ # Check if we have GT camera parameters for GT view rendering
+ has_gt_camera = (
+ camera_angle_x is not None and
+ camera_distance is not None and
+ mesh_scale is not None
+ )
+
+ # render
+ renderer = get_renderer(reps[0])
+ multiview_images = []
+ gt_view_images = []
+
+ for i, representation in enumerate(reps):
+ # Render 4 fixed views (2x2 grid)
+ image = torch.zeros(3, 1024, 1024).cuda()
+ tile = [2, 2]
+
+ # Validate mesh data before rasterization
+ verts = representation.vertices
+ faces = representation.faces
+ if verts.shape[0] == 0 or faces.shape[0] == 0:
+ print(f"[visualize_sample] Warning: sample {i} has empty mesh, skipping")
+ multiview_images.append(image)
+ continue
+ if faces.max() >= verts.shape[0]:
+ print(f"[visualize_sample] Warning: sample {i} has out-of-bound face indices "
+ f"(max face idx={faces.max().item()}, num verts={verts.shape[0]}), skipping")
+ multiview_images.append(image)
+ continue
+ if torch.isnan(verts).any() or torch.isinf(verts).any():
+ print(f"[visualize_sample] Warning: sample {i} has NaN/Inf vertices, skipping")
+ multiview_images.append(image)
+ continue
+
+ try:
+ for j, (ext, intr) in enumerate(zip(fixed_exts, fixed_ints)):
+ res = renderer.render(representation, ext, intr)
+ image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['normal']
+ except RuntimeError as e:
+ print(f"[visualize_sample] Warning: render failed for sample {i}: {e}")
+ image = torch.zeros(3, 1024, 1024).cuda()
+ multiview_images.append(image)
+
+ # Render GT camera view using the fixed front view (same as sparse_structure_latent.py)
+ if has_gt_camera:
+ # The GT view should match exactly how ProjGrid projects 3D points to 2D.
+ #
+ # In image_conditioned_proj.py (ProjGrid.forward):
+ # 1. grid_points are in [-1, 1]^3 (from torch.linspace(-1, 1, res))
+ # 2. grid_points are rotated by rotation_matrix (Y-Z swap): x'=x, y'=-z, z'=y
+ # 3. grid_points are scaled: grid_points / mesh_scale / 2
+ # 4. Points are projected using front_view_transform_matrix with distance
+ #
+ # Mesh vertices are in [-0.5, 0.5]^3. To match ProjGrid's coordinate space,
+ # we need to scale them: vertices / mesh_scale -> [-0.5/s, 0.5/s]^3
+ # This is equivalent to ProjGrid's: [-1,1]^3 / scale / 2 -> [-0.5/s, 0.5/s]^3
+ #
+ # Camera position: ProjGrid camera is at (0, -distance, 0) in Blender coords (Z-up).
+ # After inverse rotation to mesh space, camera is at (0, 0, distance).
+
+ scale = mesh_scale[i].item()
+ distance = camera_distance[i].item()
+ fov = camera_angle_x[i].item()
+ device = representation.vertices.device
+
+ # Scale mesh vertices to match ProjGrid's projection space
+ from ..representations import Mesh
+ scaled_rep = Mesh(
+ vertices=representation.vertices / scale,
+ faces=representation.faces,
+ )
+
+ cam_pos = torch.tensor([0.0, 0.0, distance], device=device)
+ look_at = torch.tensor([0.0, 0.0, 0.0], device=device)
+ cam_up = torch.tensor([0.0, 1.0, 0.0], device=device)
+
+ gt_ext = utils3d.torch.extrinsics_look_at(cam_pos, look_at, cam_up)
+ gt_int = utils3d.torch.intrinsics_from_fov_xy(
+ torch.tensor(fov, device=device),
+ torch.tensor(fov, device=device)
+ )
+
+ gt_ext = gt_ext.to(device)
+ gt_int = gt_int.to(device)
+
+ # Use scaled mesh renderer with appropriate near/far for smaller mesh
+ mesh_half_size = 0.5 / scale
+ renderer.rendering_options.near = max(0.01, distance - mesh_half_size - 0.5)
+ renderer.rendering_options.far = distance + mesh_half_size + 0.5
+
+ try:
+ gt_res = renderer.render(scaled_rep, gt_ext, gt_int)
+ gt_view_images.append(gt_res['normal'])
+ except RuntimeError as e:
+ print(f"[visualize_sample] Warning: GT view render failed for sample {i}: {e}")
+ gt_view_images.append(torch.full((3, 512, 512), 0.5, device=device))
+
+ result = {
+ 'multiview': torch.stack(multiview_images),
+ }
+
+ if has_gt_camera and len(gt_view_images) > 0:
+ result['gt_view'] = torch.stack(gt_view_images)
+
+ return result
+
+
+class SLatShape(SLatShapeVisMixin, SLat):
+ """
+ structured latent for shape generation
+
+ Args:
+ roots (str): path to the dataset
+ resolution (int): resolution of the shape
+ min_aesthetic_score (float): minimum aesthetic score
+ max_tokens (int): maximum number of tokens
+ normalization (dict): normalization stats
+ pretrained_slat_dec (str): name of the pretrained slat decoder
+ slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
+ slat_dec_ckpt (str): name of the slat decoder checkpoint
+ skip_list (str, optional): path to a file containing sha256 hashes to skip
+ skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check
+ """
+ def __init__(self,
+ roots: str,
+ *,
+ resolution: int,
+ min_aesthetic_score: float = 5.0,
+ max_tokens: int = 32768,
+ normalization: Optional[dict] = None,
+ pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
+ slat_dec_path: Optional[str] = None,
+ slat_dec_ckpt: Optional[str] = None,
+ skip_list: Optional[str] = None,
+ skip_aesthetic_score_datasets: Optional[list] = None,
+ ):
+ super().__init__(
+ roots,
+ min_aesthetic_score=min_aesthetic_score,
+ max_tokens=max_tokens,
+ latent_key='shape_latent',
+ normalization=normalization,
+ pretrained_slat_dec=pretrained_slat_dec,
+ slat_dec_path=slat_dec_path,
+ slat_dec_ckpt=slat_dec_ckpt,
+ skip_list=skip_list,
+ skip_aesthetic_score_datasets=skip_aesthetic_score_datasets,
+ )
+ self.resolution = resolution
+
+
+class ImageConditionedSLatShape(ImageConditionedMixin, SLatShape):
+ """
+ Image conditioned structured latent for shape generation
+ """
+ pass
+
+
+class SLatShapeView(SLatShapeVisMixin, SLat):
+ """
+ View-based structured latent for shape generation.
+
+ Data format: {sha256}/view{XX}.npz where each npz contains 'coords' and 'feats' keys.
+
+ Args:
+ roots (str): path to the dataset
+ resolution (int): resolution of the shape
+ min_aesthetic_score (float): minimum aesthetic score
+ max_tokens (int): maximum number of tokens
+ num_views (int): Number of views to use (0 to num_views-1). Default is 2.
+ normalization (dict): normalization stats
+ pretrained_slat_dec (str): name of the pretrained slat decoder
+ slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
+ slat_dec_ckpt (str): name of the slat decoder checkpoint
+ skip_list (str, optional): path to a file containing sha256 hashes to skip
+ skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check
+ """
+ def __init__(self,
+ roots: str,
+ *,
+ resolution: int,
+ min_aesthetic_score: float = 5.0,
+ max_tokens: int = 32768,
+ num_views: int = 2,
+ normalization: Optional[dict] = None,
+ pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
+ slat_dec_path: Optional[str] = None,
+ slat_dec_ckpt: Optional[str] = None,
+ skip_list: Optional[str] = None,
+ skip_aesthetic_score_datasets: Optional[list] = None,
+ ):
+ self.normalization = normalization
+ self.min_aesthetic_score = min_aesthetic_score
+ self.max_tokens = max_tokens
+ self.num_views = num_views
+ self.latent_key = 'shape_latent'
+ self.value_range = (0, 1)
+
+ # Initialize parent with SLatVisMixin parameters
+ from .components import StandardDatasetBase
+ SLatVisMixin.__init__(
+ self,
+ roots,
+ pretrained_slat_dec=pretrained_slat_dec,
+ slat_dec_path=slat_dec_path,
+ slat_dec_ckpt=slat_dec_ckpt,
+ )
+ StandardDatasetBase.__init__(self, roots, skip_list=skip_list, skip_aesthetic_score_datasets=skip_aesthetic_score_datasets)
+
+ self.resolution = resolution
+
+ # Calculate loads for load balancing
+ self.loads = []
+ for _, sha256, _ in self.instances:
+ if f'{self.latent_key}_tokens' in self.metadata.columns:
+ try:
+ self.loads.append(self.metadata.loc[sha256, f'{self.latent_key}_tokens'])
+                except KeyError:
+ self.loads.append(self.max_tokens)
+ else:
+ self.loads.append(self.max_tokens)
+
+ if self.normalization is not None:
+ self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
+ self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ stats = {}
+ # View-based shape_latent uses columns like shape_latent_view00_encoded, shape_latent_view01_encoded, etc.
+ required_view_cols = [f'shape_latent_view{i:02d}_encoded' for i in range(self.num_views)]
+ existing_view_cols = [col for col in required_view_cols if col in metadata.columns]
+
+ if existing_view_cols:
+ # Filter rows where all required views are encoded
+            # Note: NaN must be treated as False, so compare explicitly with == True
+ has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
+ metadata = metadata[has_all_views]
+ stats[f'With {self.num_views} view latents'] = len(metadata)
+ else:
+ # Fallback: check shape_latent_encoded column
+ if f'{self.latent_key}_encoded' in metadata.columns:
+ metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True]
+ stats['With latent'] = len(metadata)
+
+ # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist
+ skip_aesthetic = (
+ (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or
+ ('aesthetic_score' not in metadata.columns)
+ )
+ if skip_aesthetic:
+            stats['Aesthetic score check skipped'] = len(metadata)
+ else:
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+
+ # Filter by max_tokens if column exists
+ tokens_col = f'{self.latent_key}_tokens'
+ if tokens_col in metadata.columns:
+ metadata = metadata[metadata[tokens_col] <= self.max_tokens]
+ stats[f'Num tokens <= {self.max_tokens}'] = len(metadata)
+
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ # View-based format: directory with view{XX}.npz files
+ latent_dir = os.path.join(root[self.latent_key], instance)
+
+ # Randomly select a view from the configured range
+ view_idx = np.random.randint(0, self.num_views)
+ view_file = f'view{view_idx:02d}.npz'
+
+ # Store view info for ViewImageConditionedMixin
+ self._current_view_idx = view_idx
+ self._current_latent_dir = latent_dir
+
+ data = np.load(os.path.join(latent_dir, view_file))
+ coords = torch.tensor(data['coords']).int()
+ feats = torch.tensor(data['feats']).float()
+ if self.normalization is not None:
+ feats = (feats - self.mean) / self.std
+ return {
+ 'coords': coords,
+ 'feats': feats,
+ 'view_idx': view_idx,
+ }
+
+ @staticmethod
+ def collate_fn(batch, split_size=None):
+ if split_size is None:
+ group_idx = [list(range(len(batch)))]
+ else:
+ group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
+ packs = []
+ for group in group_idx:
+ sub_batch = [batch[i] for i in group]
+ pack = {}
+ coords = []
+ feats = []
+ layout = []
+ start = 0
+ for i, b in enumerate(sub_batch):
+ coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
+ feats.append(b['feats'])
+ layout.append(slice(start, start + b['coords'].shape[0]))
+ start += b['coords'].shape[0]
+ coords = torch.cat(coords)
+ feats = torch.cat(feats)
+ pack['x_0'] = SparseTensor(
+ coords=coords,
+ feats=feats,
+ )
+ pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
+ pack['x_0'].register_spatial_cache('layout', layout)
+
+ # collate other data
+ keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
+ for k in keys:
+ if isinstance(sub_batch[0][k], torch.Tensor):
+ pack[k] = torch.stack([b[k] for b in sub_batch])
+ elif isinstance(sub_batch[0][k], list):
+ pack[k] = sum([b[k] for b in sub_batch], [])
+ else:
+ pack[k] = [b[k] for b in sub_batch]
+
+ packs.append(pack)
+
+ if split_size is None:
+ return packs[0]
+ return packs
+
+
+class ViewImageConditionedSLatShapeView(ViewImageConditionedMixin, SLatShapeView):
+ """
+ Image-conditioned view-based structured latent for shape generation.
+
+ Loads shape_latent from {sha256}/view{XX}.npz format and pairs with
+ corresponding view from render_cond.
+
+ Uses ViewImageConditionedMixin which reads mesh_scale from view{XX}_scale.json.
+ """
+ pass
diff --git a/trellis2/datasets/structured_latent_svpbr.py b/trellis2/datasets/structured_latent_svpbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aeb1579ed23a2ac5016b63683697ec7cc7ccf82
--- /dev/null
+++ b/trellis2/datasets/structured_latent_svpbr.py
@@ -0,0 +1,666 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+import json
+from typing import *
+import numpy as np
+import torch
+import cv2
+import utils3d
+from .. import models
+from .components import StandardDatasetBase, ImageConditionedMixin, ViewImageConditionedMixin
+from ..modules.sparse import SparseTensor, sparse_cat
+from ..representations import MeshWithVoxel
+from ..renderers import PbrMeshRenderer, EnvMap
+from ..utils.data_utils import load_balanced_group_indices
+from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics
+
+
+class SLatPbrVisMixin:
+ def __init__(
+ self,
+ *args,
+ pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16',
+ pbr_slat_dec_path: Optional[str] = None,
+ pbr_slat_dec_ckpt: Optional[str] = None,
+ pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
+ shape_slat_dec_path: Optional[str] = None,
+ shape_slat_dec_ckpt: Optional[str] = None,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.pbr_slat_dec = None
+ self.pretrained_pbr_slat_dec = pretrained_pbr_slat_dec
+ self.pbr_slat_dec_path = pbr_slat_dec_path
+ self.pbr_slat_dec_ckpt = pbr_slat_dec_ckpt
+ self.shape_slat_dec = None
+ self.pretrained_shape_slat_dec = pretrained_shape_slat_dec
+ self.shape_slat_dec_path = shape_slat_dec_path
+ self.shape_slat_dec_ckpt = shape_slat_dec_ckpt
+
+ def _loading_slat_dec(self):
+ if self.pbr_slat_dec is not None and self.shape_slat_dec is not None:
+ return
+ if self.pbr_slat_dec_path is not None:
+ cfg = json.load(open(os.path.join(self.pbr_slat_dec_path, 'config.json'), 'r'))
+ decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
+ ckpt_path = os.path.join(self.pbr_slat_dec_path, 'ckpts', f'decoder_{self.pbr_slat_dec_ckpt}.pt')
+ decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
+ else:
+ decoder = models.from_pretrained(self.pretrained_pbr_slat_dec)
+ self.pbr_slat_dec = decoder.cuda().eval()
+
+ if self.shape_slat_dec_path is not None:
+ cfg = json.load(open(os.path.join(self.shape_slat_dec_path, 'config.json'), 'r'))
+ decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
+ ckpt_path = os.path.join(self.shape_slat_dec_path, 'ckpts', f'decoder_{self.shape_slat_dec_ckpt}.pt')
+ decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
+ else:
+ decoder = models.from_pretrained(self.pretrained_shape_slat_dec)
+ decoder.set_resolution(self.resolution)
+ self.shape_slat_dec = decoder.cuda().eval()
+
+ def _delete_slat_dec(self):
+ del self.pbr_slat_dec
+ self.pbr_slat_dec = None
+ del self.shape_slat_dec
+ self.shape_slat_dec = None
+
+ @torch.no_grad()
+ def decode_latent(self, z, shape_z, batch_size=4):
+ self._loading_slat_dec()
+ reps = []
+ if self.shape_slat_normalization is not None:
+ shape_z = shape_z * self.shape_slat_std.to(z.device) + self.shape_slat_mean.to(z.device)
+ if self.pbr_slat_normalization is not None:
+ z = z * self.pbr_slat_std.to(z.device) + self.pbr_slat_mean.to(z.device)
+ for i in range(0, z.shape[0], batch_size):
+ mesh, subs = self.shape_slat_dec(shape_z[i:i+batch_size], return_subs=True)
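+            # Rescale the texture decoder output from [-1, 1] to [0, 1] before packing it as voxel attributes.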
+ vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs) * 0.5 + 0.5
+ reps.extend([
+ MeshWithVoxel(
+ m.vertices, m.faces,
+ origin = [-0.5, -0.5, -0.5],
+ voxel_size = 1 / self.resolution,
+ coords = v.coords[:, 1:],
+ attrs = v.feats,
+ voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
+ layout = self.layout,
+ )
+ for m, v in zip(mesh, vox)
+ ])
+ self._delete_slat_dec()
+ return reps
+
+ @torch.no_grad()
+ def visualize_sample(self, sample: dict):
+ shape_z = sample['concat_cond'].cuda()
+ z = sample['x_0'].cuda()
+ reps = self.decode_latent(z, shape_z)
+
+ # Extract camera parameters for GT view rendering (if available)
+ camera_angle_x = sample.get('camera_angle_x')
+ camera_distance = sample.get('camera_distance')
+ mesh_scale = sample.get('mesh_scale')
+ has_gt_camera = (
+ camera_angle_x is not None and
+ camera_distance is not None and
+ mesh_scale is not None
+ )
+
+ # build camera
+ yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
+ yaw_offset = -16 / 180 * np.pi
+ yaw = [y + yaw_offset for y in yaw]
+ pitch = [20 / 180 * np.pi for _ in range(4)]
+ exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)
+
+ # render
+ renderer = PbrMeshRenderer()
+ renderer.rendering_options.resolution = 512
+ renderer.rendering_options.near = 1
+ renderer.rendering_options.far = 100
+ renderer.rendering_options.ssaa = 2
+ renderer.rendering_options.peel_layers = 8
+ envmap = EnvMap(torch.tensor(
+ cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
+ dtype=torch.float32, device='cuda'
+ ))
+
+ images = {}
+ gt_view_images = {}
+ for i, representation in enumerate(reps):
+ # Validate mesh data before rasterization (same as shape training)
+ verts = representation.vertices
+ faces = representation.faces
+ if verts.shape[0] == 0 or faces.shape[0] == 0:
+ print(f"[visualize_sample] Warning: sample {i} has empty mesh, skipping")
+ continue
+ if faces.max() >= verts.shape[0]:
+ print(f"[visualize_sample] Warning: sample {i} has out-of-bound face indices "
+ f"(max face idx={faces.max().item()}, num verts={verts.shape[0]}), skipping")
+ continue
+ if torch.isnan(verts).any() or torch.isinf(verts).any():
+ print(f"[visualize_sample] Warning: sample {i} has NaN/Inf vertices, skipping")
+ continue
+
+ image = {}
+ tile = [2, 2]
+ try:
+ for j, (ext, intr) in enumerate(zip(exts, ints)):
+ res = renderer.render(representation, ext, intr, envmap=envmap)
+ for k, v in res.items():
+ if k not in images:
+ images[k] = []
+ if k not in image:
+ image[k] = torch.zeros(3, 1024, 1024).cuda()
+ image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = v
+ for k in images.keys():
+ images[k].append(image[k])
+ except RuntimeError as e:
+ print(f"[visualize_sample] Warning: render failed for sample {i}: {e}")
+ try:
+ torch.cuda.synchronize()
+ except Exception:
+ pass
+ try:
+ torch.cuda.empty_cache()
+ except Exception:
+ pass
+ continue
+
+ # Render GT camera view
+ # Must scale mesh vertices by / mesh_scale to match ProjGrid's projection space.
+ # ProjGrid maps [-1,1]^3 -> / scale / 2 -> [-0.5/s, 0.5/s]^3
+ # Mesh vertices in [-0.5, 0.5]^3 -> / scale -> [-0.5/s, 0.5/s]^3 (equivalent)
+ if has_gt_camera:
+ try:
+ scale = mesh_scale[i].item()
+ distance = camera_distance[i].item()
+ fov = camera_angle_x[i].item()
+ device = representation.vertices.device
+
+ # Scale mesh and voxel to match ProjGrid's projection space
+ scaled_rep = MeshWithVoxel(
+ vertices=representation.vertices / scale,
+ faces=representation.faces,
+ origin=(representation.origin / scale).tolist(),
+ voxel_size=representation.voxel_size / scale,
+ coords=representation.coords,
+ attrs=representation.attrs,
+ voxel_shape=representation.voxel_shape,
+ layout=representation.layout,
+ )
+
+ cam_pos = torch.tensor([0.0, 0.0, distance], device=device)
+ look_at = torch.tensor([0.0, 0.0, 0.0], device=device)
+ cam_up = torch.tensor([0.0, 1.0, 0.0], device=device)
+
+ gt_ext = utils3d.torch.extrinsics_look_at(cam_pos, look_at, cam_up)
+ gt_int = utils3d.torch.intrinsics_from_fov_xy(
+ torch.tensor(fov, device=device),
+ torch.tensor(fov, device=device)
+ )
+ gt_ext = gt_ext.to(device)
+ gt_int = gt_int.to(device)
+
+ # Update near/far for the smaller scaled mesh
+ mesh_half_size = 0.5 / scale
+ renderer.rendering_options.near = max(0.01, distance - mesh_half_size - 0.5)
+ renderer.rendering_options.far = distance + mesh_half_size + 0.5
+
+ gt_res = renderer.render(scaled_rep, gt_ext, gt_int, envmap=envmap)
+ for k, v in gt_res.items():
+ gt_key = f'gt_view_{k}'
+ if gt_key not in gt_view_images:
+ gt_view_images[gt_key] = []
+ gt_view_images[gt_key].append(v)
+ except RuntimeError as e:
+ print(f"[visualize_sample] Warning: GT view render failed for sample {i}: {e}")
+ try:
+ torch.cuda.synchronize()
+ except Exception:
+ pass
+ try:
+ torch.cuda.empty_cache()
+ except Exception:
+ pass
+
+ for k in images.keys():
+ images[k] = torch.stack(images[k], dim=0)
+
+ for k, v in gt_view_images.items():
+ images[k] = torch.stack(v)
+
+ return images
+
+
+class SLatPbr(SLatPbrVisMixin, StandardDatasetBase):
+ """
+ structured latent for sparse voxel pbr dataset
+
+ Args:
+        roots (str): path to the dataset
+        resolution (int): resolution of decoded sparse voxel
+        min_aesthetic_score (float): minimum aesthetic score
+        max_tokens (int): maximum number of tokens
+        full_pbr (bool): whether to only keep instances with full PBR textures (base color, metallic, roughness)
+        pbr_slat_normalization (dict): normalization stats for the PBR latent
+        shape_slat_normalization (dict): normalization stats for the shape latent
+        attrs (list): attributes to be decoded
+        pretrained_pbr_slat_dec (str): name of the pretrained PBR slat decoder
+        pbr_slat_dec_path (str): path to the PBR slat decoder, if given, will override pretrained_pbr_slat_dec
+        pbr_slat_dec_ckpt (str): name of the PBR slat decoder checkpoint
+        pretrained_shape_slat_dec (str): name of the pretrained shape slat decoder
+        shape_slat_dec_path (str): path to the shape slat decoder, if given, will override pretrained_shape_slat_dec
+        shape_slat_dec_ckpt (str): name of the shape slat decoder checkpoint
+ """
+ def __init__(self,
+ roots: str,
+ *,
+ resolution: int,
+ min_aesthetic_score: float = 5.0,
+ max_tokens: int = 32768,
+ full_pbr: bool = False,
+ pbr_slat_normalization: Optional[dict] = None,
+ shape_slat_normalization: Optional[dict] = None,
+ attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'],
+ pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16',
+ pbr_slat_dec_path: Optional[str] = None,
+ pbr_slat_dec_ckpt: Optional[str] = None,
+ pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
+ shape_slat_dec_path: Optional[str] = None,
+ shape_slat_dec_ckpt: Optional[str] = None,
+ **kwargs
+ ):
+ self.resolution = resolution
+ self.pbr_slat_normalization = pbr_slat_normalization
+ self.shape_slat_normalization = shape_slat_normalization
+ self.min_aesthetic_score = min_aesthetic_score
+ self.max_tokens = max_tokens
+ self.full_pbr = full_pbr
+ self.value_range = (0, 1)
+
+ super().__init__(
+ roots,
+ pretrained_pbr_slat_dec=pretrained_pbr_slat_dec,
+ pbr_slat_dec_path=pbr_slat_dec_path,
+ pbr_slat_dec_ckpt=pbr_slat_dec_ckpt,
+ pretrained_shape_slat_dec=pretrained_shape_slat_dec,
+ shape_slat_dec_path=shape_slat_dec_path,
+ shape_slat_dec_ckpt=shape_slat_dec_ckpt,
+ **kwargs
+ )
+
+ self.loads = [self.metadata.loc[sha256, 'pbr_latent_tokens'] for _, sha256, _ in self.instances]
+
+ if self.pbr_slat_normalization is not None:
+ self.pbr_slat_mean = torch.tensor(self.pbr_slat_normalization['mean']).reshape(1, -1)
+ self.pbr_slat_std = torch.tensor(self.pbr_slat_normalization['std']).reshape(1, -1)
+
+ if self.shape_slat_normalization is not None:
+ self.shape_slat_mean = torch.tensor(self.shape_slat_normalization['mean']).reshape(1, -1)
+ self.shape_slat_std = torch.tensor(self.shape_slat_normalization['std']).reshape(1, -1)
+
+ self.attrs = attrs
+ self.channels = {
+ 'base_color': 3,
+ 'metallic': 1,
+ 'roughness': 1,
+ 'emissive': 3,
+ 'alpha': 1,
+ }
+ self.layout = {}
+ start = 0
+ for attr in attrs:
+ self.layout[attr] = slice(start, start + self.channels[attr])
+ start += self.channels[attr]
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ stats = {}
+ metadata = metadata[metadata['pbr_latent_encoded'] == True]
+ stats['With PBR latent'] = len(metadata)
+ metadata = metadata[metadata['shape_latent_encoded'] == True]
+ stats['With shape latent'] = len(metadata)
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+ metadata = metadata[metadata['pbr_latent_tokens'] <= self.max_tokens]
+ stats[f'Num tokens <= {self.max_tokens}'] = len(metadata)
+ if self.full_pbr:
+ metadata = metadata[metadata['num_basecolor_tex'] > 0]
+ metadata = metadata[metadata['num_metallic_tex'] > 0]
+ metadata = metadata[metadata['num_roughness_tex'] > 0]
+ stats['Full PBR'] = len(metadata)
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ # PBR latent
+ data = np.load(os.path.join(root['pbr_latent'], f'{instance}.npz'))
+ coords = torch.tensor(data['coords']).int()
+ coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1)
+ feats = torch.tensor(data['feats']).float()
+ if self.pbr_slat_normalization is not None:
+ feats = (feats - self.pbr_slat_mean) / self.pbr_slat_std
+ pbr_z = SparseTensor(feats, coords)
+
+ # Shape latent
+ data = np.load(os.path.join(root['shape_latent'], f'{instance}.npz'))
+ coords = torch.tensor(data['coords']).int()
+ coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1)
+ feats = torch.tensor(data['feats']).float()
+ if self.shape_slat_normalization is not None:
+ feats = (feats - self.shape_slat_mean) / self.shape_slat_std
+ shape_z = SparseTensor(feats, coords)
+
+ assert torch.equal(shape_z.coords, pbr_z.coords), \
+ f"Shape latent and PBR latent have different coordinates: {shape_z.coords.shape} vs {pbr_z.coords.shape}"
+
+ return {
+ 'x_0': pbr_z,
+ 'concat_cond': shape_z,
+ }
+
+ @staticmethod
+ def collate_fn(batch, split_size=None):
+ if split_size is None:
+ group_idx = [list(range(len(batch)))]
+ else:
+ group_idx = load_balanced_group_indices([b['x_0'].feats.shape[0] for b in batch], split_size)
+ packs = []
+ for group in group_idx:
+ sub_batch = [batch[i] for i in group]
+ pack = {}
+
+ keys = [k for k in sub_batch[0].keys()]
+ for k in keys:
+ if isinstance(sub_batch[0][k], torch.Tensor):
+ pack[k] = torch.stack([b[k] for b in sub_batch])
+ elif isinstance(sub_batch[0][k], SparseTensor):
+ pack[k] = sparse_cat([b[k] for b in sub_batch], dim=0)
+ elif isinstance(sub_batch[0][k], list):
+ pack[k] = sum([b[k] for b in sub_batch], [])
+ else:
+ pack[k] = [b[k] for b in sub_batch]
+
+ packs.append(pack)
+
+ if split_size is None:
+ return packs[0]
+ return packs
+
+
+class ImageConditionedSLatPbr(ImageConditionedMixin, SLatPbr):
+ """
+ Image conditioned structured latent dataset
+ """
+ pass
+
+
+class SLatPbrView(SLatPbrVisMixin, StandardDatasetBase):
+ """
+ View-based structured latent for PBR/texture generation with view-aligned projection.
+
+ Data format:
+ PBR latent: {sha256}/view{XX}.npz (coords + feats)
+ Shape latent: {sha256}/view{XX}.npz (coords + feats, from shape_latent_view dir)
+
+ Each view's PBR latent and Shape latent share the same sparse coordinates.
+
+ Args:
+ roots (str): path to the dataset
+ resolution (int): resolution of decoded sparse voxel
+ min_aesthetic_score (float): minimum aesthetic score
+ max_tokens (int): maximum number of tokens
+ num_views (int): Number of views to use (0 to num_views-1). Default is 2.
+ full_pbr (bool): Whether to require full PBR textures
+ pbr_slat_normalization (dict): normalization stats for PBR latent
+ shape_slat_normalization (dict): normalization stats for shape latent
+ attrs (list): PBR attributes to decode
+ pretrained_pbr_slat_dec (str): pretrained PBR decoder name
+ pretrained_shape_slat_dec (str): pretrained shape decoder name
+ skip_list (str, optional): path to a file containing sha256 hashes to skip
+ skip_aesthetic_score_datasets (list, optional): datasets to skip aesthetic score check
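+
+    Example (illustrative sketch; the dataset root path is a placeholder and the
+    standard torch DataLoader workflow is assumed):
+        dataset = SLatPbrView('datasets/example', resolution=1024, num_views=2)
+        loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)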
+ """
+ def __init__(self,
+ roots: str,
+ *,
+ resolution: int,
+ min_aesthetic_score: float = 5.0,
+ max_tokens: int = 32768,
+ num_views: int = 2,
+ full_pbr: bool = False,
+ pbr_slat_normalization: Optional[dict] = None,
+ shape_slat_normalization: Optional[dict] = None,
+ attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'],
+ pretrained_pbr_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16',
+ pbr_slat_dec_path: Optional[str] = None,
+ pbr_slat_dec_ckpt: Optional[str] = None,
+ pretrained_shape_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
+ shape_slat_dec_path: Optional[str] = None,
+ shape_slat_dec_ckpt: Optional[str] = None,
+ skip_list: Optional[str] = None,
+ skip_aesthetic_score_datasets: Optional[list] = None,
+ ):
+ self.resolution = resolution
+ self.pbr_slat_normalization = pbr_slat_normalization
+ self.shape_slat_normalization = shape_slat_normalization
+ self.min_aesthetic_score = min_aesthetic_score
+ self.max_tokens = max_tokens
+ self.num_views = num_views
+ self.full_pbr = full_pbr
+ self.value_range = (0, 1)
+ self.skip_aesthetic_score_datasets = set(skip_aesthetic_score_datasets or [])
+
+ # Initialize visualization mixin
+ SLatPbrVisMixin.__init__(
+ self,
+ roots,
+ pretrained_pbr_slat_dec=pretrained_pbr_slat_dec,
+ pbr_slat_dec_path=pbr_slat_dec_path,
+ pbr_slat_dec_ckpt=pbr_slat_dec_ckpt,
+ pretrained_shape_slat_dec=pretrained_shape_slat_dec,
+ shape_slat_dec_path=shape_slat_dec_path,
+ shape_slat_dec_ckpt=shape_slat_dec_ckpt,
+ )
+ StandardDatasetBase.__init__(
+ self, roots,
+ skip_list=skip_list,
+ skip_aesthetic_score_datasets=skip_aesthetic_score_datasets,
+ )
+
+ # Calculate loads for load balancing
+ self.loads = []
+ for _, sha256, _ in self.instances:
+ if 'pbr_latent_tokens' in self.metadata.columns:
+ try:
+ self.loads.append(self.metadata.loc[sha256, 'pbr_latent_tokens'])
+                except KeyError:
+ self.loads.append(self.max_tokens)
+ else:
+ self.loads.append(self.max_tokens)
+
+ if self.pbr_slat_normalization is not None:
+ self.pbr_slat_mean = torch.tensor(self.pbr_slat_normalization['mean']).reshape(1, -1)
+ self.pbr_slat_std = torch.tensor(self.pbr_slat_normalization['std']).reshape(1, -1)
+
+ if self.shape_slat_normalization is not None:
+ self.shape_slat_mean = torch.tensor(self.shape_slat_normalization['mean']).reshape(1, -1)
+ self.shape_slat_std = torch.tensor(self.shape_slat_normalization['std']).reshape(1, -1)
+
+ self.attrs = attrs
+ self.channels = {
+ 'base_color': 3,
+ 'metallic': 1,
+ 'roughness': 1,
+ 'emissive': 3,
+ 'alpha': 1,
+ }
+ self.layout = {}
+ start = 0
+ for attr in attrs:
+ self.layout[attr] = slice(start, start + self.channels[attr])
+ start += self.channels[attr]
+
+ def filter_metadata(self, metadata, dataset_name=None):
+ stats = {}
+ # View-based PBR latent uses columns like pbr_latent_view00_encoded, etc.
+ required_pbr_view_cols = [f'pbr_latent_view{i:02d}_encoded' for i in range(self.num_views)]
+ existing_pbr_view_cols = [col for col in required_pbr_view_cols if col in metadata.columns]
+
+ if existing_pbr_view_cols:
+ has_all_pbr_views = (metadata[existing_pbr_view_cols] == True).all(axis=1)
+ metadata = metadata[has_all_pbr_views]
+ stats[f'With {self.num_views} PBR view latents'] = len(metadata)
+ else:
+ # Fallback: check pbr_latent_encoded
+ if 'pbr_latent_encoded' in metadata.columns:
+ metadata = metadata[metadata['pbr_latent_encoded'] == True]
+ stats['With PBR latent'] = len(metadata)
+
+ # Also require shape latent views
+ required_shape_view_cols = [f'shape_latent_view{i:02d}_encoded' for i in range(self.num_views)]
+ existing_shape_view_cols = [col for col in required_shape_view_cols if col in metadata.columns]
+
+ if existing_shape_view_cols:
+ has_all_shape_views = (metadata[existing_shape_view_cols] == True).all(axis=1)
+ metadata = metadata[has_all_shape_views]
+ stats[f'With {self.num_views} shape view latents'] = len(metadata)
+ else:
+ if 'shape_latent_encoded' in metadata.columns:
+ metadata = metadata[metadata['shape_latent_encoded'] == True]
+ stats['With shape latent'] = len(metadata)
+
+ # Skip aesthetic score check for specified datasets
+ skip_aesthetic = (
+ (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or
+ ('aesthetic_score' not in metadata.columns)
+ )
+ if skip_aesthetic:
+            stats['Aesthetic score check skipped'] = len(metadata)
+ else:
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+
+ # Filter by max_tokens if column exists
+ if 'pbr_latent_tokens' in metadata.columns:
+ metadata = metadata[metadata['pbr_latent_tokens'] <= self.max_tokens]
+ stats[f'Num tokens <= {self.max_tokens}'] = len(metadata)
+
+ if self.full_pbr:
+ if 'num_basecolor_tex' in metadata.columns:
+ metadata = metadata[metadata['num_basecolor_tex'] > 0]
+ if 'num_metallic_tex' in metadata.columns:
+ metadata = metadata[metadata['num_metallic_tex'] > 0]
+ if 'num_roughness_tex' in metadata.columns:
+ metadata = metadata[metadata['num_roughness_tex'] > 0]
+ stats['Full PBR'] = len(metadata)
+
+ return metadata, stats
+
+ def get_instance(self, root, instance):
+ # Randomly select a view from the configured range
+ view_idx = np.random.randint(0, self.num_views)
+ view_file = f'view{view_idx:02d}.npz'
+
+ # Store view info for ViewImageConditionedMixin
+ self._current_view_idx = view_idx
+
+ # Load PBR latent for this view
+ pbr_latent_dir = os.path.join(root['pbr_latent'], instance)
+ self._current_latent_dir = pbr_latent_dir
+
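+        # each view{XX}.npz is expected to provide 'coords' (per-voxel integer indices) and
+        # 'feats' (per-voxel latent features), matching the keys read below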
+ data = np.load(os.path.join(pbr_latent_dir, view_file))
+ pbr_coords = torch.tensor(data['coords']).int()
+ pbr_feats = torch.tensor(data['feats']).float()
+ if self.pbr_slat_normalization is not None:
+ pbr_feats = (pbr_feats - self.pbr_slat_mean) / self.pbr_slat_std
+
+ # Load Shape latent for this view (as concat_cond)
+ shape_latent_dir = os.path.join(root['shape_latent'], instance)
+ data = np.load(os.path.join(shape_latent_dir, view_file))
+ shape_coords = torch.tensor(data['coords']).int()
+ shape_feats = torch.tensor(data['feats']).float()
+ if self.shape_slat_normalization is not None:
+ shape_feats = (shape_feats - self.shape_slat_mean) / self.shape_slat_std
+
+ # Verify coordinates match
+ assert torch.equal(pbr_coords, shape_coords), \
+ f"PBR and shape latent coordinates mismatch for {instance}/view{view_idx:02d}"
+
+ return {
+ 'coords': pbr_coords,
+ 'pbr_feats': pbr_feats,
+ 'shape_feats': shape_feats,
+ 'view_idx': view_idx,
+ }
+
+ @staticmethod
+ def collate_fn(batch, split_size=None):
+ if split_size is None:
+ group_idx = [list(range(len(batch)))]
+ else:
+ group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
+ packs = []
+ for group in group_idx:
+ sub_batch = [batch[i] for i in group]
+ pack = {}
+
+ # Build x_0 (PBR latent) and concat_cond (shape latent) as SparseTensors
+ coords_list = []
+ pbr_feats_list = []
+ shape_feats_list = []
+ layout = []
+ start = 0
+ for i, b in enumerate(sub_batch):
+ batch_coords = torch.cat([
+ torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32),
+ b['coords']
+ ], dim=-1)
+ coords_list.append(batch_coords)
+ pbr_feats_list.append(b['pbr_feats'])
+ shape_feats_list.append(b['shape_feats'])
+ layout.append(slice(start, start + b['coords'].shape[0]))
+ start += b['coords'].shape[0]
+
+ all_coords = torch.cat(coords_list)
+
+ # x_0: PBR latent
+ pack['x_0'] = SparseTensor(
+ coords=all_coords,
+ feats=torch.cat(pbr_feats_list),
+ )
+ pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['pbr_feats'].shape[1:]])
+ pack['x_0'].register_spatial_cache('layout', layout)
+
+ # concat_cond: Shape latent (same coordinates)
+ pack['concat_cond'] = SparseTensor(
+ coords=all_coords.clone(),
+ feats=torch.cat(shape_feats_list),
+ )
+ pack['concat_cond']._shape = torch.Size([len(group), *sub_batch[0]['shape_feats'].shape[1:]])
+ pack['concat_cond'].register_spatial_cache('layout', layout)
+
+ # collate other data (excluding already handled fields)
+ skip_keys = {'coords', 'pbr_feats', 'shape_feats'}
+ keys = [k for k in sub_batch[0].keys() if k not in skip_keys]
+ for k in keys:
+ if isinstance(sub_batch[0][k], torch.Tensor):
+ pack[k] = torch.stack([b[k] for b in sub_batch])
+ elif isinstance(sub_batch[0][k], list):
+ pack[k] = sum([b[k] for b in sub_batch], [])
+ else:
+ pack[k] = [b[k] for b in sub_batch]
+
+ packs.append(pack)
+
+ if split_size is None:
+ return packs[0]
+ return packs
+
+
+class ViewImageConditionedSLatPbrView(ViewImageConditionedMixin, SLatPbrView):
+ """
+ Image-conditioned view-based structured latent for PBR/texture generation
+ with view-aligned projection.
+
+ Loads PBR latent and shape latent from {sha256}/view{XX}.npz format and pairs
+ with corresponding view from render_cond.
+
+ Uses ViewImageConditionedMixin which reads mesh_scale from view{XX}_scale.json
+ and provides camera parameters for 3D-to-2D projection.
+ """
+ pass
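+
+
+# Minimal usage sketch for the collate_fn defined above (illustrative only; assumes a
+# configured dataset instance `dataset` and a standard PyTorch DataLoader):
+#
+#   from torch.utils.data import DataLoader
+#   loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
+#   batch = next(iter(loader))
+#   # batch['x_0'] (PBR latent) and batch['concat_cond'] (shape latent) are SparseTensors
+#   # sharing the same coords; batch['x_0'].coords[:, 0] holds the per-sample batch index.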
diff --git a/trellis2/models/__init__.py b/trellis2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4fed035ed7c1f6352d93e9787f1aaed072876d5
--- /dev/null
+++ b/trellis2/models/__init__.py
@@ -0,0 +1,78 @@
+import importlib
+
+__attributes = {
+ # Sparse Structure
+ 'SparseStructureEncoder': 'sparse_structure_vae',
+ 'SparseStructureDecoder': 'sparse_structure_vae',
+ 'SparseStructureFlowModel': 'sparse_structure_flow',
+
+ # SLat Generation
+ 'SLatFlowModel': 'structured_latent_flow',
+ 'ElasticSLatFlowModel': 'structured_latent_flow',
+
+ # SC-VAEs
+ 'SparseUnetVaeEncoder': 'sc_vaes.sparse_unet_vae',
+ 'SparseUnetVaeDecoder': 'sc_vaes.sparse_unet_vae',
+ 'FlexiDualGridVaeEncoder': 'sc_vaes.fdg_vae',
+ 'FlexiDualGridVaeDecoder': 'sc_vaes.fdg_vae'
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+ if name not in globals():
+ if name in __attributes:
+ module_name = __attributes[name]
+ module = importlib.import_module(f".{module_name}", __name__)
+ globals()[name] = getattr(module, name)
+ elif name in __submodules:
+ module = importlib.import_module(f".{name}", __name__)
+ globals()[name] = module
+ else:
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+ return globals()[name]
+
+
+def from_pretrained(path: str, **kwargs):
+ """
+ Load a model from a pretrained checkpoint.
+
+ Args:
+        path: The path to the checkpoint. Can be either a local path or a Hugging Face model name.
+            NOTE: the config and model files are expected to be named f'{path}.json' and f'{path}.safetensors', respectively.
+ **kwargs: Additional arguments for the model constructor.
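+
+    Example (paths are illustrative):
+        model = from_pretrained('ckpts/ss_flow')            # local: ckpts/ss_flow.json + ckpts/ss_flow.safetensors
+        model = from_pretrained('org/repo/ckpts/ss_flow')   # repo 'org/repo' on the Hugging Face Hub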
+ """
+ import os
+ import json
+ from safetensors.torch import load_file
+ is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
+
+ if is_local:
+ config_file = f"{path}.json"
+ model_file = f"{path}.safetensors"
+ else:
+ from huggingface_hub import hf_hub_download
+ path_parts = path.split('/')
+ repo_id = f'{path_parts[0]}/{path_parts[1]}'
+ model_name = '/'.join(path_parts[2:])
+ config_file = hf_hub_download(repo_id, f"{model_name}.json")
+ model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
+
+ with open(config_file, 'r') as f:
+ config = json.load(f)
+ model = __getattr__(config['name'])(**config['args'], **kwargs)
+ model.load_state_dict(load_file(model_file), strict=False)
+
+ return model
+
+
+# For Pylance
+if __name__ == '__main__':
+ from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
+ from .sparse_structure_flow import SparseStructureFlowModel
+ from .structured_latent_flow import SLatFlowModel, ElasticSLatFlowModel
+
+ from .sc_vaes.sparse_unet_vae import SparseUnetVaeEncoder, SparseUnetVaeDecoder
+ from .sc_vaes.fdg_vae import FlexiDualGridVaeEncoder, FlexiDualGridVaeDecoder
diff --git a/trellis2/models/sc_vaes/fdg_vae.py b/trellis2/models/sc_vaes/fdg_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..0209e690ea6f4096dd42a4d137dbc0bee1f51367
--- /dev/null
+++ b/trellis2/models/sc_vaes/fdg_vae.py
@@ -0,0 +1,110 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ...modules import sparse as sp
+from .sparse_unet_vae import (
+ SparseResBlock3d,
+ SparseConvNeXtBlock3d,
+
+ SparseResBlockDownsample3d,
+ SparseResBlockUpsample3d,
+ SparseResBlockS2C3d,
+ SparseResBlockC2S3d,
+)
+from .sparse_unet_vae import (
+ SparseUnetVaeEncoder,
+ SparseUnetVaeDecoder,
+)
+from ...representations import Mesh
+from o_voxel.convert import flexible_dual_grid_to_mesh
+
+
+class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder):
+ def __init__(
+ self,
+ model_channels: List[int],
+ latent_channels: int,
+ num_blocks: List[int],
+ block_type: List[str],
+ down_block_type: List[str],
+ block_args: List[Dict[str, Any]],
+ use_fp16: bool = False,
+ ):
+ super().__init__(
+ 6,
+ model_channels,
+ latent_channels,
+ num_blocks,
+ block_type,
+ down_block_type,
+ block_args,
+ use_fp16,
+ )
+
+ def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False):
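+        # The 6 encoder input channels come from concatenating the dual-grid vertex offsets
+        # with the voxel intersection flags, both shifted by 0.5 to be roughly zero-centered.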
+ x = vertices.replace(torch.cat([
+ vertices.feats - 0.5,
+ intersected.feats.float() - 0.5,
+ ], dim=1))
+ return super().forward(x, sample_posterior, return_raw)
+
+
+class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
+ def __init__(
+ self,
+ resolution: int,
+ model_channels: List[int],
+ latent_channels: int,
+ num_blocks: List[int],
+ block_type: List[str],
+ up_block_type: List[str],
+ block_args: List[Dict[str, Any]],
+ voxel_margin: float = 0.5,
+ use_fp16: bool = False,
+ ):
+ self.resolution = resolution
+ self.voxel_margin = voxel_margin
+
+ super().__init__(
+ 7,
+ model_channels,
+ latent_channels,
+ num_blocks,
+ block_type,
+ up_block_type,
+ block_args,
+ use_fp16,
+ )
+
+ def set_resolution(self, resolution: int) -> None:
+ self.resolution = resolution
+
+ def forward(self, x: sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs):
+ decoded = super().forward(x, **kwargs)
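+        # The 7 decoded channels are interpreted as 3 dual-vertex offsets (sigmoid, rescaled to
+        # [-voxel_margin, 1 + voxel_margin]), 3 intersection logits, and 1 quad-lerp weight.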
+ if self.training:
+ h, subs_gt, subs = decoded
+ vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
+ intersected_logits = h.replace(h.feats[..., 3:6])
+ quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
+ mesh = [Mesh(*flexible_dual_grid_to_mesh(
+ v.coords[:, 1:], v.feats, i.feats, q.feats,
+ aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+ grid_size=self.resolution,
+ train=True
+ )) for v, i, q in zip(vertices, gt_intersected, quad_lerp)]
+ return mesh, vertices, intersected_logits, subs_gt, subs
+ else:
+ out_list = list(decoded) if isinstance(decoded, tuple) else [decoded]
+ h = out_list[0]
+ vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
+ intersected = h.replace(h.feats[..., 3:6] > 0)
+ quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
+ mesh = [Mesh(*flexible_dual_grid_to_mesh(
+ v.coords[:, 1:], v.feats, i.feats, q.feats,
+ aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+ grid_size=self.resolution,
+ train=False
+ )) for v, i, q in zip(vertices, intersected, quad_lerp)]
+ out_list[0] = mesh
+ return out_list[0] if len(out_list) == 1 else tuple(out_list)
diff --git a/trellis2/models/sc_vaes/sparse_unet_vae.py b/trellis2/models/sc_vaes/sparse_unet_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9902a155a8a85c5c616a4503be92f43f6fdde27
--- /dev/null
+++ b/trellis2/models/sc_vaes/sparse_unet_vae.py
@@ -0,0 +1,522 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from ...modules.utils import convert_module_to_f16, convert_module_to_f32, zero_module
+from ...modules import sparse as sp
+from ...modules.norm import LayerNorm32
+
+
+class SparseResBlock3d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ downsample: bool = False,
+ upsample: bool = False,
+ resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest',
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.downsample = downsample
+ self.upsample = upsample
+ self.resample_mode = resample_mode
+ self.use_checkpoint = use_checkpoint
+
+ assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
+
+ self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+ if resample_mode == 'nearest':
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
+        elif resample_mode == 'spatial2channel' and not self.downsample:
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
+        elif resample_mode == 'spatial2channel' and self.downsample:
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
+ self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+ if resample_mode == 'nearest':
+ self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
+        elif resample_mode == 'spatial2channel' and self.downsample:
+ self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
+        elif resample_mode == 'spatial2channel' and not self.downsample:
+ self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
+ self.updown = None
+ if self.downsample:
+ if resample_mode == 'nearest':
+ self.updown = sp.SparseDownsample(2)
+            elif resample_mode == 'spatial2channel':
+ self.updown = sp.SparseSpatial2Channel(2)
+ elif self.upsample:
+ self.to_subdiv = sp.SparseLinear(channels, 8)
+ if resample_mode == 'nearest':
+ self.updown = sp.SparseUpsample(2)
+            elif resample_mode == 'spatial2channel':
+ self.updown = sp.SparseChannel2Spatial(2)
+
+ def _updown(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+ if self.downsample:
+ x = self.updown(x)
+ elif self.upsample:
+ x = self.updown(x, subdiv.replace(subdiv.feats > 0))
+ return x
+
+ def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ subdiv = None
+ if self.upsample:
+ subdiv = self.to_subdiv(x)
+ h = x.replace(self.norm1(x.feats))
+ h = h.replace(F.silu(h.feats))
+ if self.resample_mode == 'spatial2channel':
+ h = self.conv1(h)
+ h = self._updown(h, subdiv)
+ x = self._updown(x, subdiv)
+ if self.resample_mode == 'nearest':
+ h = self.conv1(h)
+ h = h.replace(self.norm2(h.feats))
+ h = h.replace(F.silu(h.feats))
+ h = self.conv2(h)
+ h = h + self.skip_connection(x)
+ if self.upsample:
+ return h, subdiv
+ return h
+
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class SparseResBlockDownsample3d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_checkpoint = use_checkpoint
+
+ self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
+ self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+ self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
+ self.updown = sp.SparseDownsample(2)
+
+ def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ h = x.replace(self.norm1(x.feats))
+ h = h.replace(F.silu(h.feats))
+ h = self.updown(h)
+ x = self.updown(x)
+ h = self.conv1(h)
+ h = h.replace(self.norm2(h.feats))
+ h = h.replace(F.silu(h.feats))
+ h = self.conv2(h)
+ h = h + self.skip_connection(x)
+ return h
+
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class SparseResBlockUpsample3d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ use_checkpoint: bool = False,
+ pred_subdiv: bool = True,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_checkpoint = use_checkpoint
+ self.pred_subdiv = pred_subdiv
+
+ self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
+ self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+ self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
+ if self.pred_subdiv:
+ self.to_subdiv = sp.SparseLinear(channels, 8)
+ self.updown = sp.SparseUpsample(2)
+
+ def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+ if self.pred_subdiv:
+ subdiv = self.to_subdiv(x)
+ h = x.replace(self.norm1(x.feats))
+ h = h.replace(F.silu(h.feats))
+ subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
+ h = self.updown(h, subdiv_binarized)
+ x = self.updown(x, subdiv_binarized)
+ h = self.conv1(h)
+ h = h.replace(self.norm2(h.feats))
+ h = h.replace(F.silu(h.feats))
+ h = self.conv2(h)
+ h = h + self.skip_connection(x)
+ if self.pred_subdiv:
+ return h, subdiv
+ else:
+ return h
+
+    def forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+        # accept an externally provided subdiv so decoders with pred_subdiv=False can pass
+        # guide_subs through, mirroring SparseResBlockC2S3d.forward
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
+        else:
+            return self._forward(x, subdiv)
+
+
+class SparseResBlockS2C3d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_checkpoint = use_checkpoint
+
+ self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
+ self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+ self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
+ self.updown = sp.SparseSpatial2Channel(2)
+
+ def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ h = x.replace(self.norm1(x.feats))
+ h = h.replace(F.silu(h.feats))
+ h = self.conv1(h)
+ h = self.updown(h)
+ x = self.updown(x)
+ h = h.replace(self.norm2(h.feats))
+ h = h.replace(F.silu(h.feats))
+ h = self.conv2(h)
+ h = h + self.skip_connection(x)
+ return h
+
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class SparseResBlockC2S3d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ use_checkpoint: bool = False,
+ pred_subdiv: bool = True,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_checkpoint = use_checkpoint
+ self.pred_subdiv = pred_subdiv
+
+ self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
+ self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+ self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
+ if pred_subdiv:
+ self.to_subdiv = sp.SparseLinear(channels, 8)
+ self.updown = sp.SparseChannel2Spatial(2)
+
+ def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+ if self.pred_subdiv:
+ subdiv = self.to_subdiv(x)
+ h = x.replace(self.norm1(x.feats))
+ h = h.replace(F.silu(h.feats))
+ h = self.conv1(h)
+ subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
+ h = self.updown(h, subdiv_binarized)
+ x = self.updown(x, subdiv_binarized)
+ h = h.replace(self.norm2(h.feats))
+ h = h.replace(F.silu(h.feats))
+ h = self.conv2(h)
+ h = h + self.skip_connection(x)
+ if self.pred_subdiv:
+ return h, subdiv
+ else:
+ return h
+
+ def forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
+ else:
+ return self._forward(x, subdiv)
+
+
+class SparseConvNeXtBlock3d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ mlp_ratio: float = 4.0,
+ use_checkpoint: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.use_checkpoint = use_checkpoint
+
+ self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.conv = sp.SparseConv3d(channels, channels, 3)
+ self.mlp = nn.Sequential(
+ nn.Linear(channels, int(channels * mlp_ratio)),
+ nn.SiLU(),
+ zero_module(nn.Linear(int(channels * mlp_ratio), channels)),
+ )
+
+ def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ h = self.conv(x)
+ h = h.replace(self.norm(h.feats))
+ h = h.replace(self.mlp(h.feats))
+ return h + x
+
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class SparseUnetVaeEncoder(nn.Module):
+ """
+    Sparse U-Net VAE encoder built from configurable sparse convolutional blocks.
+ """
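+    # block_type / down_block_type entries name classes defined in this module
+    # (e.g. 'SparseResBlock3d' or 'SparseConvNeXtBlock3d' for blocks, and
+    # 'SparseResBlockDownsample3d' or 'SparseResBlockS2C3d' for downsampling);
+    # they are resolved via globals() in __init__.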
+ def __init__(
+ self,
+ in_channels: int,
+ model_channels: List[int],
+ latent_channels: int,
+ num_blocks: List[int],
+ block_type: List[str],
+ down_block_type: List[str],
+ block_args: List[Dict[str, Any]],
+ use_fp16: bool = False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.num_blocks = num_blocks
+        self.dtype = torch.float16 if use_fp16 else torch.float32
+
+ self.input_layer = sp.SparseLinear(in_channels, model_channels[0])
+ self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels)
+
+ self.blocks = nn.ModuleList([])
+ for i in range(len(num_blocks)):
+ self.blocks.append(nn.ModuleList([]))
+ for j in range(num_blocks[i]):
+ self.blocks[-1].append(
+ globals()[block_type[i]](
+ model_channels[i],
+ **block_args[i],
+ )
+ )
+ if i < len(num_blocks) - 1:
+ self.blocks[-1].append(
+ globals()[down_block_type[i]](
+ model_channels[i],
+ model_channels[i+1],
+ **block_args[i],
+ )
+ )
+
+ self.initialize_weights()
+ if use_fp16:
+ self.convert_to_fp16()
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ self.blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.blocks.apply(convert_module_to_f32)
+
+ def initialize_weights(self) -> None:
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False):
+ h = self.input_layer(x)
+ h = h.type(self.dtype)
+ for i, res in enumerate(self.blocks):
+ for j, block in enumerate(res):
+ h = block(h)
+ h = h.type(x.dtype)
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+ h = self.to_latent(h)
+
+ # Sample from the posterior distribution
+ mean, logvar = h.feats.chunk(2, dim=-1)
+ if sample_posterior:
+ std = torch.exp(0.5 * logvar)
+ z = mean + std * torch.randn_like(std)
+ else:
+ z = mean
+ z = h.replace(z)
+
+ if return_raw:
+ return z, mean, logvar
+ else:
+ return z
+
+
+class SparseUnetVaeDecoder(nn.Module):
+ """
+    Sparse U-Net VAE decoder built from configurable sparse convolutional blocks.
+ """
+ def __init__(
+ self,
+ out_channels: int,
+ model_channels: List[int],
+ latent_channels: int,
+ num_blocks: List[int],
+ block_type: List[str],
+ up_block_type: List[str],
+ block_args: List[Dict[str, Any]],
+ use_fp16: bool = False,
+ pred_subdiv: bool = True,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.model_channels = model_channels
+ self.num_blocks = num_blocks
+ self.use_fp16 = use_fp16
+ self.pred_subdiv = pred_subdiv
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ self.low_vram = False
+
+ self.output_layer = sp.SparseLinear(model_channels[-1], out_channels)
+ self.from_latent = sp.SparseLinear(latent_channels, model_channels[0])
+
+ self.blocks = nn.ModuleList([])
+ for i in range(len(num_blocks)):
+ self.blocks.append(nn.ModuleList([]))
+ for j in range(num_blocks[i]):
+ self.blocks[-1].append(
+ globals()[block_type[i]](
+ model_channels[i],
+ **block_args[i],
+ )
+ )
+ if i < len(num_blocks) - 1:
+ self.blocks[-1].append(
+ globals()[up_block_type[i]](
+ model_channels[i],
+ model_channels[i+1],
+ pred_subdiv=pred_subdiv,
+ **block_args[i],
+ )
+ )
+
+ self.initialize_weights()
+ if use_fp16:
+ self.convert_to_fp16()
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ self.blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.blocks.apply(convert_module_to_f32)
+
+ def initialize_weights(self) -> None:
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False) -> sp.SparseTensor:
+ assert guide_subs is None or self.pred_subdiv == False, "Only decoders with pred_subdiv=False can be used with guide_subs"
+ assert return_subs == False or self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with return_subs"
+
+ h = self.from_latent(x)
+ h = h.type(self.dtype)
+ subs_gt = []
+ subs = []
+ for i, res in enumerate(self.blocks):
+ for j, block in enumerate(res):
+ if i < len(self.blocks) - 1 and j == len(res) - 1:
+ if self.pred_subdiv:
+ if self.training:
+ subs_gt.append(h.get_spatial_cache('subdivision'))
+ h, sub = block(h)
+ subs.append(sub)
+ else:
+ h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
+ else:
+ h = block(h)
+ h = h.type(x.dtype)
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+ h = self.output_layer(h)
+ if self.training and self.pred_subdiv:
+ return h, subs_gt, subs
+ else:
+ if return_subs:
+ return h, subs
+ else:
+ return h
+
+ def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor:
+ assert self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with upsampling"
+
+ h = self.from_latent(x)
+ h = h.type(self.dtype)
+ for i, res in enumerate(self.blocks):
+ if i == upsample_times:
+ return h.coords
+ for j, block in enumerate(res):
+ if i < len(self.blocks) - 1 and j == len(res) - 1:
+ h, sub = block(h)
+ else:
+ h = block(h)
+
\ No newline at end of file
diff --git a/trellis2/models/sparse_elastic_mixin.py b/trellis2/models/sparse_elastic_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..66d204c89bedabc2afd1795cdfc6f5d58a6b1ac0
--- /dev/null
+++ b/trellis2/models/sparse_elastic_mixin.py
@@ -0,0 +1,24 @@
+from contextlib import contextmanager
+from typing import *
+import math
+from ..modules import sparse as sp
+from ..utils.elastic_utils import ElasticModuleMixin
+
+
+class SparseTransformerElasticMixin(ElasticModuleMixin):
+ def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
+ return x.feats.shape[0]
+
+ @contextmanager
+ def with_mem_ratio(self, mem_ratio=1.0):
+ if mem_ratio == 1.0:
+ yield 1.0
+ return
+ num_blocks = len(self.blocks)
+ num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
+ exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
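+        # e.g. with 24 blocks and mem_ratio=0.5: ceil(0.5 * 24) + 1 = 13 blocks get checkpointed,
+        # so exact_mem_ratio = 1 - 12 / 24 = 0.5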
+ for i in range(num_blocks):
+ self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
+ yield exact_mem_ratio
+ for i in range(num_blocks):
+ self.blocks[i].use_checkpoint = False
diff --git a/trellis2/models/sparse_structure_flow.py b/trellis2/models/sparse_structure_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..be6b50eac0095d3f57b24c134b893342edc6f101
--- /dev/null
+++ b/trellis2/models/sparse_structure_flow.py
@@ -0,0 +1,298 @@
+from typing import *
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ..modules.utils import convert_module_to, manual_cast, str_to_dtype
+from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
+from ..modules.attention import RotaryPositionEmbedder
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ Args:
+ t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ dim: the dimension of the output.
+ max_period: controls the minimum frequency of the embeddings.
+
+ Returns:
+ an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class SparseStructureFlowModel(nn.Module):
+ """
+ Sparse Structure Flow Model for 3D generation.
+
+    Supports three conditioning modes:
+    - "cross": Standard cross-attention with image features
+    - "proj": View-aligned projection attention with camera-aware features
+    - "gated_proj": Gated projection attention with separate semantic and color projection features
+ """
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ model_channels: int,
+ cond_channels: int,
+ out_channels: int,
+ num_blocks: int,
+ num_heads: Optional[int] = None,
+ num_head_channels: Optional[int] = 64,
+ mlp_ratio: float = 4,
+ pe_mode: Literal["ape", "rope"] = "ape",
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ dtype: str = 'float32',
+ use_checkpoint: bool = False,
+ share_mod: bool = False,
+ initialization: str = 'vanilla',
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross",
+ proj_in_channels: Optional[int] = None,
+ vae_in_channels: Optional[int] = None,
+ **kwargs
+ ):
+ super().__init__()
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.cond_channels = cond_channels
+ self.out_channels = out_channels
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads or model_channels // num_head_channels
+ self.mlp_ratio = mlp_ratio
+ self.pe_mode = pe_mode
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.initialization = initialization
+ self.qk_rms_norm = qk_rms_norm
+ self.qk_rms_norm_cross = qk_rms_norm_cross
+ self.image_attn_mode = image_attn_mode
+ self.proj_in_channels = proj_in_channels
+ self.vae_in_channels = vae_in_channels
+ self.dtype = str_to_dtype(dtype)
+
+ self.t_embedder = TimestepEmbedder(model_channels)
+ if share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
+ )
+
+ if pe_mode == "ape":
+ pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
+ coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3)
+ pos_emb = pos_embedder(coords)
+ self.register_buffer("pos_emb", pos_emb)
+ elif pe_mode == "rope":
+ pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3)
+ coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3)
+ rope_phases = pos_embedder(coords)
+ self.register_buffer("rope_phases", rope_phases)
+
+ if pe_mode != "rope":
+ self.rope_phases = None
+
+ self.input_layer = nn.Linear(in_channels, model_channels)
+
+ self.blocks = nn.ModuleList([
+ ModulatedTransformerCrossBlock(
+ model_channels,
+ cond_channels,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ attn_mode='full',
+ use_checkpoint=self.use_checkpoint,
+ use_rope=(pe_mode == "rope"),
+ rope_freq=rope_freq,
+ share_mod=share_mod,
+ qk_rms_norm=self.qk_rms_norm,
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
+ image_attn_mode=image_attn_mode,
+ proj_in_channels=proj_in_channels,
+ vae_in_channels=vae_in_channels,
+ )
+ for _ in range(num_blocks)
+ ])
+
+ self.out_layer = nn.Linear(model_channels, out_channels)
+
+ self.initialize_weights()
+ self.convert_to(self.dtype)
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to(self, dtype: torch.dtype) -> None:
+ """
+ Convert the torso of the model to the specified dtype.
+ """
+ self.dtype = dtype
+ self.blocks.apply(partial(convert_module_to, dtype=dtype))
+
+ def initialize_weights(self) -> None:
+ if self.initialization == 'vanilla':
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ if self.share_mod:
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+ else:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.out_layer.weight, 0)
+ nn.init.constant_(self.out_layer.bias, 0)
+
+ elif self.initialization == 'scaled':
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Scaled init for to_out and ffn2
+ def _scaled_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ for block in self.blocks:
+ block.self_attn.to_out.apply(_scaled_init)
+ # Handle cross, proj, and gated_proj modes
+ if self.image_attn_mode in ("proj", "gated_proj"):
+ block.cross_attn.cross_attn_block.to_out.apply(_scaled_init)
+ else:
+ block.cross_attn.to_out.apply(_scaled_init)
+ block.mlp.mlp[2].apply(_scaled_init)
+
+ # Initialize input layer to make the initial representation have variance 1
+ nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
+ nn.init.zeros_(self.input_layer.bias)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ if self.share_mod:
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+ else:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.out_layer.weight, 0)
+ nn.init.constant_(self.out_layer.bias, 0)
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass.
+
+ Args:
+ x: Input tensor [B, C, D, H, W]
+ t: Timestep tensor [B]
+            cond: Conditioning tensor. For "cross" mode: [B, N, D].
+                For "proj" mode: dict {'global': global_cond, 'proj': proj_cond}
+                or a tuple (global_cond, proj_cond).
+                For "gated_proj" mode: dict with 'global', 'proj_semantic', and 'proj_color'.
+
+ Returns:
+ Output tensor [B, C, D, H, W]
+ """
+ assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
+ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
+
+ h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
+
+ h = self.input_layer(h)
+ if self.pe_mode == "ape":
+ h = h + self.pos_emb[None]
+ t_emb = self.t_embedder(t)
+ if self.share_mod:
+ t_emb = self.adaLN_modulation(t_emb)
+ t_emb = manual_cast(t_emb, self.dtype)
+ h = manual_cast(h, self.dtype)
+
+ # Handle different conditioning modes
+ if self.image_attn_mode == 'proj':
+ if isinstance(cond, dict):
+ global_cond = cond['global']
+ proj_cond = cond['proj']
+ else:
+ global_cond, proj_cond = cond
+ global_cond = manual_cast(global_cond, self.dtype)
+ proj_cond = manual_cast(proj_cond, self.dtype)
+ cond = (global_cond, proj_cond)
+ elif self.image_attn_mode == 'gated_proj':
+ global_cond = manual_cast(cond['global'], self.dtype)
+ proj_semantic = manual_cast(cond['proj_semantic'], self.dtype)
+ proj_color = manual_cast(cond['proj_color'], self.dtype)
+ cond = {'global': global_cond, 'proj_semantic': proj_semantic, 'proj_color': proj_color}
+ else:
+ cond = manual_cast(cond, self.dtype)
+
+ for block in self.blocks:
+ h = block(h, t_emb, cond, self.rope_phases)
+ h = manual_cast(h, x.dtype)
+ h = F.layer_norm(h, h.shape[-1:])
+ h = self.out_layer(h)
+
+ h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
+
+ return h
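+
+
+# Informal shape flow through SparseStructureFlowModel.forward:
+#   [B, in_channels, D, H, W] -> flatten to [B, D*H*W, model_channels] tokens ->
+#   transformer blocks -> [B, D*H*W, out_channels] -> reshape to [B, out_channels, D, H, W]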
diff --git a/trellis2/models/sparse_structure_vae.py b/trellis2/models/sparse_structure_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e09136cf294c4c1b47b0f09fa6ee57bad2166d
--- /dev/null
+++ b/trellis2/models/sparse_structure_vae.py
@@ -0,0 +1,306 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..modules.norm import GroupNorm32, ChannelLayerNorm32
+from ..modules.spatial import pixel_shuffle_3d
+from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
+
+
+def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
+ """
+ Return a normalization layer.
+ """
+ if norm_type == "group":
+ return GroupNorm32(32, *args, **kwargs)
+ elif norm_type == "layer":
+ return ChannelLayerNorm32(*args, **kwargs)
+ else:
+ raise ValueError(f"Invalid norm type {norm_type}")
+
+
+class ResBlock3d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ norm_type: Literal["group", "layer"] = "layer",
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.norm1 = norm_layer(norm_type, channels)
+ self.norm2 = norm_layer(norm_type, self.out_channels)
+ self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
+ self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
+ self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.norm1(x)
+ h = F.silu(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = F.silu(h)
+ h = self.conv2(h)
+ h = h + self.skip_connection(x)
+ return h
+
+
+class DownsampleBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ mode: Literal["conv", "avgpool"] = "conv",
+ ):
+ assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
+
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if mode == "conv":
+ self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
+ elif mode == "avgpool":
+ assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if hasattr(self, "conv"):
+ return self.conv(x)
+ else:
+ return F.avg_pool3d(x, 2)
+
+
+class UpsampleBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ mode: Literal["conv", "nearest"] = "conv",
+ ):
+ assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
+
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if mode == "conv":
+ self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
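+            # the 8x channel expansion lets pixel_shuffle_3d in forward rearrange the extra
+            # features into a 2x larger grid along each spatial axis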
+ elif mode == "nearest":
+ assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if hasattr(self, "conv"):
+ x = self.conv(x)
+ return pixel_shuffle_3d(x, 2)
+ else:
+ return F.interpolate(x, scale_factor=2, mode="nearest")
+
+
+class SparseStructureEncoder(nn.Module):
+    r"""
+ Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
+
+ Args:
+ in_channels (int): Channels of the input.
+ latent_channels (int): Channels of the latent representation.
+ num_res_blocks (int): Number of residual blocks at each resolution.
+ channels (List[int]): Channels of the encoder blocks.
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
+ use_fp16 (bool): Whether to use FP16.
+ """
+ def __init__(
+ self,
+ in_channels: int,
+ latent_channels: int,
+ num_res_blocks: int,
+ channels: List[int],
+ num_res_blocks_middle: int = 2,
+ norm_type: Literal["group", "layer"] = "layer",
+ use_fp16: bool = False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.latent_channels = latent_channels
+ self.num_res_blocks = num_res_blocks
+ self.channels = channels
+ self.num_res_blocks_middle = num_res_blocks_middle
+ self.norm_type = norm_type
+ self.use_fp16 = use_fp16
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+
+ self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+ for i, ch in enumerate(channels):
+ self.blocks.extend([
+ ResBlock3d(ch, ch)
+ for _ in range(num_res_blocks)
+ ])
+ if i < len(channels) - 1:
+ self.blocks.append(
+ DownsampleBlock3d(ch, channels[i+1])
+ )
+
+ self.middle_block = nn.Sequential(*[
+ ResBlock3d(channels[-1], channels[-1])
+ for _ in range(num_res_blocks_middle)
+ ])
+
+ self.out_layer = nn.Sequential(
+ norm_layer(norm_type, channels[-1]),
+ nn.SiLU(),
+ nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
+ )
+
+ if use_fp16:
+ self.convert_to_fp16()
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ self.use_fp16 = True
+ self.dtype = torch.float16
+ self.blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.use_fp16 = False
+ self.dtype = torch.float32
+ self.blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
+ h = self.input_layer(x)
+ h = h.type(self.dtype)
+
+ for block in self.blocks:
+ h = block(h)
+ h = self.middle_block(h)
+
+ h = h.type(x.dtype)
+ h = self.out_layer(h)
+
+ mean, logvar = h.chunk(2, dim=1)
+
+ if sample_posterior:
+ std = torch.exp(0.5 * logvar)
+ z = mean + std * torch.randn_like(std)
+ else:
+ z = mean
+
+ if return_raw:
+ return z, mean, logvar
+ return z
+
+
+class SparseStructureDecoder(nn.Module):
+    r"""
+ Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
+
+ Args:
+ out_channels (int): Channels of the output.
+ latent_channels (int): Channels of the latent representation.
+ num_res_blocks (int): Number of residual blocks at each resolution.
+ channels (List[int]): Channels of the decoder blocks.
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
+ use_fp16 (bool): Whether to use FP16.
+ """
+ def __init__(
+ self,
+ out_channels: int,
+ latent_channels: int,
+ num_res_blocks: int,
+ channels: List[int],
+ num_res_blocks_middle: int = 2,
+ norm_type: Literal["group", "layer"] = "layer",
+ use_fp16: bool = False,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.latent_channels = latent_channels
+ self.num_res_blocks = num_res_blocks
+ self.channels = channels
+ self.num_res_blocks_middle = num_res_blocks_middle
+ self.norm_type = norm_type
+ self.use_fp16 = use_fp16
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+
+ self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
+
+ self.middle_block = nn.Sequential(*[
+ ResBlock3d(channels[0], channels[0])
+ for _ in range(num_res_blocks_middle)
+ ])
+
+ self.blocks = nn.ModuleList([])
+ for i, ch in enumerate(channels):
+ self.blocks.extend([
+ ResBlock3d(ch, ch)
+ for _ in range(num_res_blocks)
+ ])
+ if i < len(channels) - 1:
+ self.blocks.append(
+ UpsampleBlock3d(ch, channels[i+1])
+ )
+
+ self.out_layer = nn.Sequential(
+ norm_layer(norm_type, channels[-1]),
+ nn.SiLU(),
+ nn.Conv3d(channels[-1], out_channels, 3, padding=1)
+ )
+
+ if use_fp16:
+ self.convert_to_fp16()
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to_fp16(self) -> None:
+ """
+ Convert the torso of the model to float16.
+ """
+ self.use_fp16 = True
+ self.dtype = torch.float16
+ self.blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self) -> None:
+ """
+ Convert the torso of the model to float32.
+ """
+ self.use_fp16 = False
+ self.dtype = torch.float32
+ self.blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.input_layer(x)
+
+ h = h.type(self.dtype)
+
+ h = self.middle_block(h)
+ for block in self.blocks:
+ h = block(h)
+
+ h = h.type(x.dtype)
+ h = self.out_layer(h)
+ return h
diff --git a/trellis2/models/structured_latent_flow.py b/trellis2/models/structured_latent_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..23d78ad65c62aadf3442de2d81efd12dc2490ce2
--- /dev/null
+++ b/trellis2/models/structured_latent_flow.py
@@ -0,0 +1,265 @@
+from typing import *
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ..modules.utils import convert_module_to, manual_cast, str_to_dtype
+from ..modules.transformer import AbsolutePositionEmbedder
+from ..modules import sparse as sp
+from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
+from .sparse_structure_flow import TimestepEmbedder
+from .sparse_elastic_mixin import SparseTransformerElasticMixin
+
+
+class SLatFlowModel(nn.Module):
+ """
+ Structured Latent Flow Model for 3D generation.
+
+    Supports three conditioning modes:
+    - "cross": Standard cross-attention with image features
+    - "proj": View-aligned projection attention with camera-aware features
+    - "gated_proj": Gated projection attention with separate semantic and color projection features
+ """
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ model_channels: int,
+ cond_channels: int,
+ out_channels: int,
+ num_blocks: int,
+ num_heads: Optional[int] = None,
+ num_head_channels: Optional[int] = 64,
+ mlp_ratio: float = 4,
+ pe_mode: Literal["ape", "rope"] = "ape",
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ dtype: str = 'float32',
+ use_checkpoint: bool = False,
+ share_mod: bool = False,
+ initialization: str = 'vanilla',
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross",
+ proj_in_channels: Optional[int] = None,
+ vae_in_channels: Optional[int] = None,
+ **kwargs
+ ):
+ super().__init__()
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.cond_channels = cond_channels
+ self.out_channels = out_channels
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads or model_channels // num_head_channels
+ self.mlp_ratio = mlp_ratio
+ self.pe_mode = pe_mode
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.initialization = initialization
+ self.qk_rms_norm = qk_rms_norm
+ self.qk_rms_norm_cross = qk_rms_norm_cross
+ self.image_attn_mode = image_attn_mode
+ self.proj_in_channels = proj_in_channels
+ self.vae_in_channels = vae_in_channels
+ self.dtype = str_to_dtype(dtype)
+
+ self.t_embedder = TimestepEmbedder(model_channels)
+ if share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
+ )
+
+ if pe_mode == "ape":
+ self.pos_embedder = AbsolutePositionEmbedder(model_channels)
+
+ self.input_layer = sp.SparseLinear(in_channels, model_channels)
+
+ self.blocks = nn.ModuleList([
+ ModulatedSparseTransformerCrossBlock(
+ model_channels,
+ cond_channels,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ attn_mode='full',
+ use_checkpoint=self.use_checkpoint,
+ use_rope=(pe_mode == "rope"),
+ rope_freq=rope_freq,
+ share_mod=self.share_mod,
+ qk_rms_norm=self.qk_rms_norm,
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
+ image_attn_mode=image_attn_mode,
+ proj_in_channels=proj_in_channels,
+ vae_in_channels=vae_in_channels,
+ )
+ for _ in range(num_blocks)
+ ])
+
+ self.out_layer = sp.SparseLinear(model_channels, out_channels)
+
+ self.initialize_weights()
+ self.convert_to(self.dtype)
+
+ @property
+ def device(self) -> torch.device:
+ """
+ Return the device of the model.
+ """
+ return next(self.parameters()).device
+
+ def convert_to(self, dtype: torch.dtype) -> None:
+ """
+ Convert the torso of the model to the specified dtype.
+ """
+ self.dtype = dtype
+ self.blocks.apply(partial(convert_module_to, dtype=dtype))
+
+ def initialize_weights(self) -> None:
+ if self.initialization == 'vanilla':
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ if self.share_mod:
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+ else:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.out_layer.weight, 0)
+ nn.init.constant_(self.out_layer.bias, 0)
+
+ elif self.initialization == 'scaled':
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Scaled init for to_out and ffn2
+ def _scaled_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ for block in self.blocks:
+ block.self_attn.to_out.apply(_scaled_init)
+ # Handle cross, proj, and gated_proj modes
+ if self.image_attn_mode in ("proj", "gated_proj"):
+ block.cross_attn.cross_attn_block.to_out.apply(_scaled_init)
+ else:
+ block.cross_attn.to_out.apply(_scaled_init)
+ block.mlp.mlp[2].apply(_scaled_init)
+
+ # Initialize input layer to make the initial representation have variance 1
+ nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
+ nn.init.zeros_(self.input_layer.bias)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ if self.share_mod:
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+ else:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.out_layer.weight, 0)
+ nn.init.constant_(self.out_layer.bias, 0)
+
+ def forward(
+ self,
+ x: sp.SparseTensor,
+ t: torch.Tensor,
+ cond: Union[torch.Tensor, List[torch.Tensor], Dict[str, Union[torch.Tensor, sp.SparseTensor]], Tuple],
+ concat_cond: Optional[sp.SparseTensor] = None,
+ **kwargs
+ ) -> sp.SparseTensor:
+ """
+ Forward pass.
+
+ Args:
+ x: SparseTensor input
+ t: Timestep tensor [B]
+            cond: Conditioning tensor. For "cross" mode: list of tensors or tensor.
+                For "proj" mode: dict {'global': global_cond, 'proj': proj_cond}
+                or a tuple (global_cond, proj_cond).
+                For "gated_proj" mode: dict with 'global', 'proj_semantic', and 'proj_color'.
+ concat_cond: Optional concatenation condition
+
+ Returns:
+ SparseTensor output
+ """
+ if concat_cond is not None:
+ x = sp.sparse_cat([x, concat_cond], dim=-1)
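+            # note: in_channels is expected to cover the concatenated feature width of x and
+            # concat_cond, since the combined tensor is fed through input_layer below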
+
+ h = self.input_layer(x)
+ h = manual_cast(h, self.dtype)
+ t_emb = self.t_embedder(t)
+ if self.share_mod:
+ t_emb = self.adaLN_modulation(t_emb)
+ t_emb = manual_cast(t_emb, self.dtype)
+
+ if self.pe_mode == "ape":
+ pe = self.pos_embedder(h.coords[:, 1:])
+ h = h + manual_cast(pe, self.dtype)
+
+ # Handle different conditioning modes
+ if self.image_attn_mode == 'proj':
+ if isinstance(cond, dict):
+ global_cond = cond['global']
+ proj_cond = cond['proj']
+ else:
+ global_cond, proj_cond = cond
+ if isinstance(global_cond, list):
+ global_cond = sp.VarLenTensor.from_tensor_list(global_cond)
+ global_cond = manual_cast(global_cond, self.dtype)
+ proj_cond = manual_cast(proj_cond, self.dtype)
+ cond = (global_cond, proj_cond)
+ elif self.image_attn_mode == 'gated_proj':
+ global_cond = cond['global']
+ if isinstance(global_cond, list):
+ global_cond = sp.VarLenTensor.from_tensor_list(global_cond)
+ global_cond = manual_cast(global_cond, self.dtype)
+ proj_semantic = manual_cast(cond['proj_semantic'], self.dtype)
+ proj_color = manual_cast(cond['proj_color'], self.dtype)
+ cond = {'global': global_cond, 'proj_semantic': proj_semantic, 'proj_color': proj_color}
+ else:
+ if isinstance(cond, list):
+ cond = sp.VarLenTensor.from_tensor_list(cond)
+ cond = manual_cast(cond, self.dtype)
+
+ for block in self.blocks:
+ h = block(h, t_emb, cond)
+
+ h = manual_cast(h, x.dtype)
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+ h = self.out_layer(h)
+ return h
+
+
+class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
+ """
+ SLat Flow Model with elastic memory management.
+ Used for training with low VRAM.
+ """
+ pass
diff --git a/trellis2/modules/attention/__init__.py b/trellis2/modules/attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..318c6d3b13cf0f2d6b814876a8efc7c9941806a8
--- /dev/null
+++ b/trellis2/modules/attention/__init__.py
@@ -0,0 +1,4 @@
+from .full_attn import *
+from .modules import *
+from .rope import *
+from .proj_attention import ProjectAttention, GatedProjectAttention
diff --git a/trellis2/modules/attention/config.py b/trellis2/modules/attention/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..579db837b858c8b2f424732c8633489a577079e2
--- /dev/null
+++ b/trellis2/modules/attention/config.py
@@ -0,0 +1,32 @@
+from typing import *
+
+BACKEND = 'flash_attn'
+DEBUG = False
+
+def __from_env():
+ import os
+
+ global BACKEND
+ global DEBUG
+
+ env_attn_backend = os.environ.get('ATTN_BACKEND')
+ env_attn_debug = os.environ.get('ATTN_DEBUG')
+
+ if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'flash_attn_4', 'sdpa', 'naive']:
+ BACKEND = env_attn_backend
+ if env_attn_debug is not None:
+ DEBUG = env_attn_debug == '1'
+
+ print(f"[ATTENTION] Using backend: {BACKEND}")
+
+
+__from_env()
+
+
+def set_backend(backend: Literal['xformers', 'flash_attn', 'flash_attn_3', 'flash_attn_4', 'sdpa', 'naive']):
+ global BACKEND
+ BACKEND = backend
+
+def set_debug(debug: bool):
+ global DEBUG
+ DEBUG = debug
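+
+
+# The backend can also be selected via environment variables before this module is imported, e.g.:
+#   ATTN_BACKEND=sdpa ATTN_DEBUG=1 python your_script.py
+# (the script name is a placeholder; only ATTN_BACKEND and ATTN_DEBUG are read by __from_env)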
diff --git a/trellis2/modules/attention/full_attn.py b/trellis2/modules/attention/full_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0216288bb3e4b52673488b0b011e004eda68fd
--- /dev/null
+++ b/trellis2/modules/attention/full_attn.py
@@ -0,0 +1,153 @@
+from typing import *
+import torch
+import math
+from . import config
+
+
+__all__ = [
+ 'scaled_dot_product_attention',
+]
+
+
+def _naive_sdpa(q, k, v):
+ """
+ Naive implementation of scaled dot product attention.
+ """
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
+ scale_factor = 1 / math.sqrt(q.size(-1))
+ attn_weight = q @ k.transpose(-2, -1) * scale_factor
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ out = attn_weight @ v
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
+ return out
+
+
+@overload
+def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention.
+
+ Args:
+ qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
+ """
+ ...
+
+@overload
+def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention.
+
+ Args:
+ q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
+ kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
+ """
+ ...
+
+@overload
+def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention.
+
+ Args:
+ q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
+ k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
+ v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
+
+ Note:
+ k and v are assumed to have the same coordinate map.
+ """
+ ...
+
+def scaled_dot_product_attention(*args, **kwargs):
+ arg_names_dict = {
+ 1: ['qkv'],
+ 2: ['q', 'kv'],
+ 3: ['q', 'k', 'v']
+ }
+ num_all_args = len(args) + len(kwargs)
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
+ for key in arg_names_dict[num_all_args][len(args):]:
+ assert key in kwargs, f"Missing argument {key}"
+
+ if num_all_args == 1:
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
+ assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
+ device = qkv.device
+
+ elif num_all_args == 2:
+ q = args[0] if len(args) > 0 else kwargs['q']
+ kv = args[1] if len(args) > 1 else kwargs['kv']
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
+ device = q.device
+
+ elif num_all_args == 3:
+ q = args[0] if len(args) > 0 else kwargs['q']
+ k = args[1] if len(args) > 1 else kwargs['k']
+ v = args[2] if len(args) > 2 else kwargs['v']
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
+ device = q.device
+
+ if config.BACKEND == 'xformers':
+ if 'xops' not in globals():
+ import xformers.ops as xops
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=2)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=2)
+ out = xops.memory_efficient_attention(q, k, v)
+ elif config.BACKEND == 'flash_attn':
+ if 'flash_attn' not in globals():
+ import flash_attn
+ if num_all_args == 1:
+ out = flash_attn.flash_attn_qkvpacked_func(qkv)
+ elif num_all_args == 2:
+ out = flash_attn.flash_attn_kvpacked_func(q, kv)
+ elif num_all_args == 3:
+ out = flash_attn.flash_attn_func(q, k, v)
+ elif config.BACKEND == 'flash_attn_3':
+ if 'flash_attn_3' not in globals():
+ import flash_attn_interface as flash_attn_3
+ if num_all_args == 1:
+ out = flash_attn_3.flash_attn_qkvpacked_func(qkv)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=2)
+ out = flash_attn_3.flash_attn_func(q, k, v)
+ elif num_all_args == 3:
+ out = flash_attn_3.flash_attn_func(q, k, v)
+ elif config.BACKEND == 'flash_attn_4':
+ if 'flash_attn_4' not in globals():
+ from flash_attn.cute import flash_attn_func as flash_attn_4_func
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=2)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=2)
+ out, _ = flash_attn_4_func(q, k, v)
+ elif config.BACKEND == 'sdpa':
+ if 'sdpa' not in globals():
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=2)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=2)
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
+ out = sdpa(q, k, v) # [N, H, L, C]
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
+ elif config.BACKEND == 'naive':
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=2)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=2)
+ out = _naive_sdpa(q, k, v)
+ else:
+ raise ValueError(f"Unknown attention module: {config.BACKEND}")
+
+ return out
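+
+
+if __name__ == '__main__':
+    # Minimal sanity-check sketch (illustrative): the three calling conventions
+    # accept packed or split Q/K/V and agree with each other. Uses the CPU-only
+    # 'naive' backend so no optional attention library is required.
+    config.set_backend('naive')
+    N, L, H, C = 2, 16, 4, 32
+    qkv = torch.randn(N, L, 3, H, C)
+    q, k, v = qkv.unbind(dim=2)
+    kv = torch.stack([k, v], dim=2)
+    out_packed = scaled_dot_product_attention(qkv)
+    out_kvpacked = scaled_dot_product_attention(q, kv)
+    out_split = scaled_dot_product_attention(q, k, v)
+    assert out_packed.shape == (N, L, H, C)
+    assert torch.allclose(out_packed, out_kvpacked, atol=1e-5)
+    assert torch.allclose(out_packed, out_split, atol=1e-5)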
diff --git a/trellis2/modules/attention/modules.py b/trellis2/modules/attention/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..492784c7ba8f572c4820b604f51c924ed564ab00
--- /dev/null
+++ b/trellis2/modules/attention/modules.py
@@ -0,0 +1,102 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .full_attn import scaled_dot_product_attention
+from .rope import RotaryPositionEmbedder
+
+
+class MultiHeadRMSNorm(nn.Module):
+ def __init__(self, dim: int, heads: int):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ ctx_channels: Optional[int]=None,
+ type: Literal["self", "cross"] = "self",
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ qkv_bias: bool = True,
+ use_rope: bool = False,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ ):
+ super().__init__()
+ assert channels % num_heads == 0
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
+ assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
+
+ if attn_mode == "windowed":
+ raise NotImplementedError("Windowed attention is not yet implemented")
+
+ self.channels = channels
+ self.head_dim = channels // num_heads
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
+ self.num_heads = num_heads
+ self._type = type
+ self.attn_mode = attn_mode
+ self.window_size = window_size
+ self.shift_window = shift_window
+ self.use_rope = use_rope
+ self.qk_rms_norm = qk_rms_norm
+
+ if self._type == "self":
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
+ else:
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
+
+ if self.qk_rms_norm:
+ self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
+ self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
+
+ self.to_out = nn.Linear(channels, channels)
+
+ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ B, L, C = x.shape
+ if self._type == "self":
+ qkv = self.to_qkv(x)
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
+
+ if self.attn_mode == "full":
+ if self.qk_rms_norm or self.use_rope:
+ q, k, v = qkv.unbind(dim=2)
+ if self.qk_rms_norm:
+ q = self.q_rms_norm(q)
+ k = self.k_rms_norm(k)
+ if self.use_rope:
+ assert phases is not None, "Phases must be provided for RoPE"
+ q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
+ k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
+ h = scaled_dot_product_attention(q, k, v)
+ else:
+ h = scaled_dot_product_attention(qkv)
+ elif self.attn_mode == "windowed":
+ raise NotImplementedError("Windowed attention is not yet implemented")
+ else:
+ Lkv = context.shape[1]
+ q = self.to_q(x)
+ kv = self.to_kv(context)
+ q = q.reshape(B, L, self.num_heads, -1)
+ kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
+ if self.qk_rms_norm:
+ q = self.q_rms_norm(q)
+ k, v = kv.unbind(dim=2)
+ k = self.k_rms_norm(k)
+ h = scaled_dot_product_attention(q, k, v)
+ else:
+ h = scaled_dot_product_attention(q, kv)
+ h = h.reshape(B, L, -1)
+ h = self.to_out(h)
+ return h
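+
+
+if __name__ == '__main__':
+    # Shape-level usage sketch (illustrative): self-attention consumes [B, L, C];
+    # cross-attention additionally takes a [B, M, ctx_channels] context. Run as
+    # `python -m trellis2.modules.attention.modules`; the 'naive' backend avoids
+    # any optional attention dependency.
+    from . import config
+    config.set_backend('naive')
+    B, L, M, C = 2, 8, 5, 64
+    self_attn = MultiHeadAttention(channels=C, num_heads=4, qk_rms_norm=True)
+    cross_attn = MultiHeadAttention(channels=C, num_heads=4, ctx_channels=32, type="cross")
+    x = torch.randn(B, L, C)
+    ctx = torch.randn(B, M, 32)
+    assert self_attn(x).shape == (B, L, C)
+    assert cross_attn(x, ctx).shape == (B, L, C)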
diff --git a/trellis2/modules/attention/proj_attention.py b/trellis2/modules/attention/proj_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..77011c1189e596e3c926ed31a94cfa2d14d6fe46
--- /dev/null
+++ b/trellis2/modules/attention/proj_attention.py
@@ -0,0 +1,101 @@
+"""
+View-Aligned Projection Attention Module for TRELLIS2
+
+This module implements the projection-based attention mechanism that combines
+global cross-attention with view-aligned projected features.
+
+Supports two modes:
+- "proj": Standard projection (DINOv3 only), per-block proj_linear
+- "gated_proj": Gated fusion of DINOv3 (semantic) + VAE (color) features
+"""
+
+from typing import *
+import torch
+import torch.nn as nn
+
+
+class ProjectAttention(nn.Module):
+ """
+ Projection-based Attention Module with per-block proj_linear.
+
+ Combines global cross-attention with view-aligned projected features.
+ Each block owns a proj_linear that projects DINOv3 features from
+ proj_in_channels (e.g. 1024) to model_channels (e.g. 1536).
+
+ The module receives:
+ - x: Input features from the transformer
+ - context: A dict with keys:
+ - 'global': Global image features, shape [B, M, ctx_channels]
+ - 'proj': View-aligned projected features, shape [B, N, proj_in_channels]
+
+ The output combines the cross-attention result with the projected context.
+ """
+ def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int):
+ super().__init__()
+ self.cross_attn_block = cross_attn_block
+ self.proj_linear = nn.Linear(proj_in_channels, channels, bias=True)
+
+ def forward(self, x: torch.Tensor, context: Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
+ if isinstance(context, dict):
+ global_context = context['global']
+ proj_context = context['proj']
+ else:
+ global_context, proj_context = context
+
+ global_out = self.cross_attn_block(x, global_context)
+ proj_out = self.proj_linear(proj_context)
+ context_combined = proj_out + global_out
+ return context_combined
+
+
+class GatedProjectAttention(nn.Module):
+ """
+ Concat-Projection Attention Module for DINOv3 (semantic) + VAE (color) features.
+
+ Concatenates DINOv3 and VAE projected features and applies a single linear
+    projection to model_channels. This is mathematically equivalent to two
+    separate proj_linears followed by addition (the concatenated weight matrix
+    decomposes into one block per input modality); using a single layer simply
+    keeps the fusion in one place.
+
+ Zero-initialized for stable training: at init, fused=0 so only global
+ cross-attention contributes; color+semantic signals are gradually learned.
+
+ The module receives:
+ - x: Input features from the transformer
+ - context: A dict with keys:
+ - 'global': Global image features, shape [B, M, ctx_channels]
+ - 'proj_semantic': DINOv3 projected features, shape [B, N, dino_channels]
+ - 'proj_color': VAE projected features, shape [B, N, vae_channels]
+ """
+ def __init__(
+ self,
+ cross_attn_block: nn.Module,
+ channels: int,
+ dino_in_channels: int,
+ vae_in_channels: int,
+ ):
+ """
+ Args:
+ cross_attn_block: The underlying cross-attention module
+ channels: Model channels (output dimension)
+ dino_in_channels: DINOv3 proj feature dimension (e.g. 1024)
+ vae_in_channels: VAE latent feature dimension (e.g. 16)
+ """
+ super().__init__()
+ self.cross_attn_block = cross_attn_block
+ self.proj_linear = nn.Linear(dino_in_channels + vae_in_channels, channels, bias=True)
+ # Zero-init: at start, fused=0, only global cross-attn contributes
+ nn.init.zeros_(self.proj_linear.weight)
+ nn.init.zeros_(self.proj_linear.bias)
+
+ def forward(self, x: torch.Tensor, context: Union[Dict[str, torch.Tensor], Tuple]) -> torch.Tensor:
+ if isinstance(context, dict):
+ global_context = context['global']
+ proj_semantic = context['proj_semantic']
+ proj_color = context['proj_color']
+ else:
+ global_context, proj_semantic, proj_color = context
+
+ global_out = self.cross_attn_block(x, global_context)
+ fused = self.proj_linear(torch.cat([proj_semantic, proj_color], dim=-1))
+ return fused + global_out
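+
+
+if __name__ == '__main__':
+    # Illustrative sketch of the zero-init property. The stand-in cross-attention
+    # block below (a linear on x that ignores its context) is NOT the real block
+    # used by the transformer; it only shows that, at initialization, the output
+    # equals the cross-attention output because proj_linear starts at zero.
+    class _StandInCrossAttn(nn.Module):
+        def __init__(self, channels: int):
+            super().__init__()
+            self.proj = nn.Linear(channels, channels)
+
+        def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
+            return self.proj(x)
+
+    B, N, M, C = 2, 10, 7, 64
+    attn = GatedProjectAttention(_StandInCrossAttn(C), channels=C, dino_in_channels=32, vae_in_channels=8)
+    x = torch.randn(B, N, C)
+    context = {
+        'global': torch.randn(B, M, C),
+        'proj_semantic': torch.randn(B, N, 32),
+        'proj_color': torch.randn(B, N, 8),
+    }
+    out = attn(x, context)
+    assert torch.allclose(out, attn.cross_attn_block(x, context['global']))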
diff --git a/trellis2/modules/attention/rope.py b/trellis2/modules/attention/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cf6c5b321c2417443ffad3804df97ff5fbe6658
--- /dev/null
+++ b/trellis2/modules/attention/rope.py
@@ -0,0 +1,48 @@
+from typing import *
+import torch
+import torch.nn as nn
+
+
+class RotaryPositionEmbedder(nn.Module):
+ def __init__(
+ self,
+ head_dim: int,
+ dim: int = 3,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0)
+ ):
+ super().__init__()
+ assert head_dim % 2 == 0, "Head dim must be divisible by 2"
+ self.head_dim = head_dim
+ self.dim = dim
+ self.rope_freq = rope_freq
+ self.freq_dim = head_dim // 2 // dim
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
+ self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
+
+ def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
+ self.freqs = self.freqs.to(indices.device)
+ phases = torch.outer(indices, self.freqs)
+ phases = torch.polar(torch.ones_like(phases), phases)
+ return phases
+
+ @staticmethod
+ def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ x_rotated = x_complex * phases.unsqueeze(-2)
+ x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
+ return x_embed
+
+ def forward(self, indices: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ indices (torch.Tensor): [..., N, C] tensor of spatial positions
+ """
+ assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
+ phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
+ if phases.shape[-1] < self.head_dim // 2:
+ padn = self.head_dim // 2 - phases.shape[-1]
+ phases = torch.cat([phases, torch.polar(
+ torch.ones(*phases.shape[:-1], padn, device=phases.device),
+ torch.zeros(*phases.shape[:-1], padn, device=phases.device)
+ )], dim=-1)
+ return phases
\ No newline at end of file
diff --git a/trellis2/modules/image_feature_extractor.py b/trellis2/modules/image_feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3cb515ad1152752b09fef9999cc5ff44bd885f7
--- /dev/null
+++ b/trellis2/modules/image_feature_extractor.py
@@ -0,0 +1,118 @@
+from typing import *
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+from transformers import DINOv3ViTModel
+import numpy as np
+from PIL import Image
+
+
+class DinoV2FeatureExtractor:
+ """
+ Feature extractor for DINOv2 models.
+ """
+ def __init__(self, model_name: str):
+ self.model_name = model_name
+ self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
+ self.model.eval()
+ self.transform = transforms.Compose([
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ def to(self, device):
+ self.model.to(device)
+
+ def cuda(self):
+ self.model.cuda()
+
+ def cpu(self):
+ self.model.cpu()
+
+ @torch.no_grad()
+ def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
+ """
+ Extract features from the image.
+
+ Args:
+ image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
+
+ Returns:
+ A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
+ """
+ if isinstance(image, torch.Tensor):
+ assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
+ elif isinstance(image, list):
+ assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
+ image = [i.resize((518, 518), Image.LANCZOS) for i in image]
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+ image = torch.stack(image).cuda()
+ else:
+ raise ValueError(f"Unsupported type of image: {type(image)}")
+
+ image = self.transform(image).cuda()
+ features = self.model(image, is_training=True)['x_prenorm']
+ patchtokens = F.layer_norm(features, features.shape[-1:])
+ return patchtokens
+
+
+class DinoV3FeatureExtractor:
+ """
+ Feature extractor for DINOv3 models.
+ """
+ def __init__(self, model_name: str, image_size=512):
+ self.model_name = model_name
+ self.model = DINOv3ViTModel.from_pretrained(model_name)
+ self.model.eval()
+ self.image_size = image_size
+ self.transform = transforms.Compose([
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ def to(self, device):
+ self.model.to(device)
+
+ def cuda(self):
+ self.model.cuda()
+
+ def cpu(self):
+ self.model.cpu()
+
+ def extract_features(self, image: torch.Tensor) -> torch.Tensor:
+ image = image.to(self.model.embeddings.patch_embeddings.weight.dtype)
+ hidden_states = self.model.embeddings(image, bool_masked_pos=None)
+ position_embeddings = self.model.rope_embeddings(image)
+
+ for i, layer_module in enumerate(self.model.layer):
+ hidden_states = layer_module(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ )
+
+ return F.layer_norm(hidden_states, hidden_states.shape[-1:])
+
+ @torch.no_grad()
+ def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
+ """
+ Extract features from the image.
+
+ Args:
+ image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
+
+ Returns:
+ A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
+ """
+ if isinstance(image, torch.Tensor):
+ assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
+ elif isinstance(image, list):
+ assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
+ image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+ image = torch.stack(image).cuda()
+ else:
+ raise ValueError(f"Unsupported type of image: {type(image)}")
+
+ image = self.transform(image).cuda()
+ features = self.extract_features(image)
+ return features
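+
+
+# Usage sketch (illustrative; the checkpoint id below is a placeholder, not a
+# verified model name, and both extractors move data to CUDA internally, so a
+# GPU is assumed):
+#
+#   extractor = DinoV3FeatureExtractor('<dinov3-checkpoint-id>', image_size=512)
+#   extractor.cuda()
+#   feats = extractor([Image.open('example.png')])   # -> (1, num_patches, feat_dim)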
diff --git a/trellis2/modules/norm.py b/trellis2/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..78675d0f850d2e34d3c90b5d6bc14db708e5b400
--- /dev/null
+++ b/trellis2/modules/norm.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+from .utils import manual_cast
+
+
+class LayerNorm32(nn.LayerNorm):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x_dtype = x.dtype
+ x = manual_cast(x, torch.float32)
+ o = super().forward(x)
+ return manual_cast(o, x_dtype)
+
+
+class GroupNorm32(nn.GroupNorm):
+ """
+ A GroupNorm layer that converts to float32 before the forward pass.
+ """
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x_dtype = x.dtype
+ x = manual_cast(x, torch.float32)
+ o = super().forward(x)
+ return manual_cast(o, x_dtype)
+
+
+class ChannelLayerNorm32(LayerNorm32):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ DIM = x.dim()
+ x = x.permute(0, *range(2, DIM), 1).contiguous()
+ x = super().forward(x)
+ x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
+ return x
+
\ No newline at end of file
diff --git a/trellis2/modules/sparse/__init__.py b/trellis2/modules/sparse/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e73f232abc6f31cafeac4172a96150906ba20b7b
--- /dev/null
+++ b/trellis2/modules/sparse/__init__.py
@@ -0,0 +1,69 @@
+from . import config
+import importlib
+
+__attributes = {
+ 'VarLenTensor': 'basic',
+ 'varlen_cat': 'basic',
+ 'varlen_unbind': 'basic',
+ 'SparseTensor': 'basic',
+ 'sparse_cat': 'basic',
+ 'sparse_unbind': 'basic',
+ 'SparseGroupNorm': 'norm',
+ 'SparseLayerNorm': 'norm',
+ 'SparseGroupNorm32': 'norm',
+ 'SparseLayerNorm32': 'norm',
+ 'SparseReLU': 'nonlinearity',
+ 'SparseSiLU': 'nonlinearity',
+ 'SparseGELU': 'nonlinearity',
+ 'SparseActivation': 'nonlinearity',
+ 'SparseLinear': 'linear',
+ 'sparse_scaled_dot_product_attention': 'attention',
+ 'SerializeMode': 'attention',
+ 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
+ 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
+ 'sparse_windowed_scaled_dot_product_cross_attention': 'attention',
+ 'SparseRotaryPositionEmbedder': 'attention',
+ 'SparseMultiHeadAttention': 'attention',
+ 'SparseConv3d': 'conv',
+ 'SparseInverseConv3d': 'conv',
+ 'SparseDownsample': 'spatial',
+ 'SparseUpsample': 'spatial',
+ 'SparseSubdivide': 'spatial',
+ 'SparseSpatial2Channel': 'spatial',
+ 'SparseChannel2Spatial': 'spatial',
+ 'sparse_nearest_interpolate': 'spatial',
+ 'sparse_trilinear_interpolate': 'spatial',
+ 'encode_seq': 'serialize',
+ 'decode_seq': 'serialize',
+}
+
+__submodules = ['transformer', 'conv']
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+ if name not in globals():
+ if name in __attributes:
+ module_name = __attributes[name]
+ module = importlib.import_module(f".{module_name}", __name__)
+ globals()[name] = getattr(module, name)
+ elif name in __submodules:
+ module = importlib.import_module(f".{name}", __name__)
+ globals()[name] = module
+ else:
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+ return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+ from .basic import *
+ from .norm import *
+ from .nonlinearity import *
+ from .linear import *
+ from .attention import *
+ from .conv import *
+ from .spatial import *
+ from .serialize import *
+    from . import transformer
+    from . import conv
diff --git a/trellis2/modules/sparse/attention/__init__.py b/trellis2/modules/sparse/attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..695105b1997bb0ad13c94efa4ca04a95861adb1b
--- /dev/null
+++ b/trellis2/modules/sparse/attention/__init__.py
@@ -0,0 +1,4 @@
+from .full_attn import *
+from .windowed_attn import *
+from .modules import *
+from .proj_attention import *
diff --git a/trellis2/modules/sparse/attention/full_attn.py b/trellis2/modules/sparse/attention/full_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7d559768fb3d0c5d0b67b835a12084f09ccbf0
--- /dev/null
+++ b/trellis2/modules/sparse/attention/full_attn.py
@@ -0,0 +1,238 @@
+from typing import *
+import torch
+from .. import VarLenTensor
+from .. import config
+
+
+__all__ = [
+ 'sparse_scaled_dot_product_attention',
+]
+
+
+@overload
+def sparse_scaled_dot_product_attention(qkv: VarLenTensor) -> VarLenTensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ qkv (VarLenTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: VarLenTensor, kv: Union[VarLenTensor, torch.Tensor]) -> VarLenTensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (VarLenTensor): A [N, *, H, C] sparse tensor containing Qs.
+ kv (VarLenTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: VarLenTensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
+ kv (VarLenTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: VarLenTensor, k: VarLenTensor, v: VarLenTensor) -> VarLenTensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs.
+ k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks.
+ v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs.
+
+ Note:
+ k and v are assumed to have the same coordinate map.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: VarLenTensor, k: torch.Tensor, v: torch.Tensor) -> VarLenTensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs.
+ k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
+ v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
+ """
+ ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: torch.Tensor, k: VarLenTensor, v: VarLenTensor) -> torch.Tensor:
+ """
+ Apply scaled dot product attention to a sparse tensor.
+
+ Args:
+ q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
+ k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks.
+ v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs.
+ """
+ ...
+
+def sparse_scaled_dot_product_attention(*args, **kwargs):
+ arg_names_dict = {
+ 1: ['qkv'],
+ 2: ['q', 'kv'],
+ 3: ['q', 'k', 'v']
+ }
+ num_all_args = len(args) + len(kwargs)
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
+ for key in arg_names_dict[num_all_args][len(args):]:
+ assert key in kwargs, f"Missing argument {key}"
+
+ if num_all_args == 1:
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
+ assert isinstance(qkv, VarLenTensor), f"qkv must be a VarLenTensor, got {type(qkv)}"
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
+ device = qkv.device
+
+ s = qkv
+ q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
+ kv_seqlen = q_seqlen
+ qkv = qkv.feats # [T, 3, H, C]
+
+ elif num_all_args == 2:
+ q = args[0] if len(args) > 0 else kwargs['q']
+ kv = args[1] if len(args) > 1 else kwargs['kv']
+ assert isinstance(q, VarLenTensor) and isinstance(kv, (VarLenTensor, torch.Tensor)) or \
+ isinstance(q, torch.Tensor) and isinstance(kv, VarLenTensor), \
+ f"Invalid types, got {type(q)} and {type(kv)}"
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
+ device = q.device
+
+ if isinstance(q, VarLenTensor):
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
+ s = q
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
+ q = q.feats # [T_Q, H, C]
+ else:
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
+ s = None
+ N, L, H, C = q.shape
+ q_seqlen = [L] * N
+ q = q.reshape(N * L, H, C) # [T_Q, H, C]
+
+ if isinstance(kv, VarLenTensor):
+ assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
+ kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
+ kv = kv.feats # [T_KV, 2, H, C]
+ else:
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
+ N, L, _, H, C = kv.shape
+ kv_seqlen = [L] * N
+ kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
+
+ elif num_all_args == 3:
+ q = args[0] if len(args) > 0 else kwargs['q']
+ k = args[1] if len(args) > 1 else kwargs['k']
+ v = args[2] if len(args) > 2 else kwargs['v']
+ assert isinstance(q, VarLenTensor) and isinstance(k, (VarLenTensor, torch.Tensor)) and type(k) == type(v) or \
+ isinstance(q, torch.Tensor) and isinstance(k, VarLenTensor) and isinstance(v, VarLenTensor), \
+ f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
+ device = q.device
+
+ if isinstance(q, VarLenTensor):
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
+ s = q
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
+ q = q.feats # [T_Q, H, Ci]
+ else:
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
+ s = None
+ N, L, H, CI = q.shape
+ q_seqlen = [L] * N
+ q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
+
+ if isinstance(k, VarLenTensor):
+ assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
+ assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
+ kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
+ k = k.feats # [T_KV, H, Ci]
+ v = v.feats # [T_KV, H, Co]
+ else:
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
+ N, L, H, CI, CO = *k.shape, v.shape[-1]
+ kv_seqlen = [L] * N
+ k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
+ v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
+
+ if config.ATTN == 'xformers':
+ if 'xops' not in globals():
+ import xformers.ops as xops
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=1)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=1)
+ q = q.unsqueeze(0)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+ out = xops.memory_efficient_attention(q, k, v, mask)[0]
+ elif config.ATTN == 'flash_attn':
+ if 'flash_attn' not in globals():
+ import flash_attn
+ cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
+ if num_all_args in [2, 3]:
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+ if num_all_args == 1:
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
+ elif num_all_args == 2:
+ out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+ elif num_all_args == 3:
+ out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+ elif config.ATTN == 'flash_attn_3':
+ if 'flash_attn_3' not in globals():
+ import flash_attn_interface as flash_attn_3
+ cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=1)
+ cu_seqlens_kv = cu_seqlens_q.clone()
+ max_q_seqlen = max_kv_seqlen = max(q_seqlen)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=1)
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+ max_q_seqlen = max(q_seqlen)
+ max_kv_seqlen = max(kv_seqlen)
+ elif num_all_args == 3:
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+ max_q_seqlen = max(q_seqlen)
+ max_kv_seqlen = max(kv_seqlen)
+ out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
+ elif config.ATTN == 'flash_attn_4':
+ if 'flash_attn_4' not in globals():
+ from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func
+ cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
+ if num_all_args == 1:
+ q, k, v = qkv.unbind(dim=1)
+ cu_seqlens_kv = cu_seqlens_q.clone()
+ max_q_seqlen = max_kv_seqlen = max(q_seqlen)
+ elif num_all_args == 2:
+ k, v = kv.unbind(dim=1)
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+ max_q_seqlen = max(q_seqlen)
+ max_kv_seqlen = max(kv_seqlen)
+ elif num_all_args == 3:
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+ max_q_seqlen = max(q_seqlen)
+ max_kv_seqlen = max(kv_seqlen)
+ out, _ = flash_attn_4_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
+ else:
+ raise ValueError(f"Unknown attention module: {config.ATTN}")
+
+ if s is not None:
+ return s.replace(out)
+ else:
+ return out.reshape(N, L, H, -1)
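+
+
+# Calling-convention sketch (illustrative; running it requires one of the varlen
+# backends above, e.g. flash_attn):
+#
+#   q  : VarLenTensor with feats of shape [T_Q, H, C]   (variable-length queries)
+#   kv : dense tensor of shape [N, L, 2, H, C]          (fixed-length packed keys/values)
+#   out = sparse_scaled_dot_product_attention(q, kv)    # -> VarLenTensor, feats [T_Q, H, C]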
diff --git a/trellis2/modules/sparse/attention/modules.py b/trellis2/modules/sparse/attention/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..d762b4b2e01335d26d856b3fe82eca31a2735123
--- /dev/null
+++ b/trellis2/modules/sparse/attention/modules.py
@@ -0,0 +1,141 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .. import VarLenTensor, SparseTensor
+from .full_attn import sparse_scaled_dot_product_attention
+from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
+from .rope import SparseRotaryPositionEmbedder
+
+
+class SparseMultiHeadRMSNorm(nn.Module):
+ def __init__(self, dim: int, heads: int):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
+
+ def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
+ x_type = x.dtype
+ x = x.float()
+ if isinstance(x, VarLenTensor):
+ x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
+ else:
+ x = F.normalize(x, dim=-1) * self.gamma * self.scale
+ return x.to(x_type)
+
+
+class SparseMultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ ctx_channels: Optional[int] = None,
+ type: Literal["self", "cross"] = "self",
+ attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ qkv_bias: bool = True,
+ use_rope: bool = False,
+        rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ ):
+ super().__init__()
+ assert channels % num_heads == 0
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
+ assert attn_mode in ["full", "windowed", "double_windowed"], f"Invalid attention mode: {attn_mode}"
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
+ assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
+ if attn_mode == 'double_windowed':
+ assert window_size % 2 == 0, "Window size must be even for double windowed attention"
+ assert num_heads % 2 == 0, "Number of heads must be even for double windowed attention"
+ self.channels = channels
+ self.head_dim = channels // num_heads
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
+ self.num_heads = num_heads
+ self._type = type
+ self.attn_mode = attn_mode
+ self.window_size = window_size
+ self.shift_window = shift_window
+ self.use_rope = use_rope
+ self.qk_rms_norm = qk_rms_norm
+
+ if self._type == "self":
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
+ else:
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
+
+ if self.qk_rms_norm:
+ self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
+ self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
+
+ self.to_out = nn.Linear(channels, channels)
+
+ if use_rope:
+ self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)
+
+ @staticmethod
+ def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
+ if isinstance(x, VarLenTensor):
+ return x.replace(module(x.feats))
+ else:
+ return module(x)
+
+ @staticmethod
+ def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]:
+ if isinstance(x, VarLenTensor):
+ return x.reshape(*shape)
+ else:
+ return x.reshape(*x.shape[:2], *shape)
+
+ def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]:
+ if isinstance(x, VarLenTensor):
+ x_feats = x.feats.unsqueeze(0)
+ else:
+ x_feats = x
+ x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
+ return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
+
+ def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
+ if self._type == "self":
+ qkv = self._linear(self.to_qkv, x)
+ qkv = self._fused_pre(qkv, num_fused=3)
+ if self.qk_rms_norm or self.use_rope:
+ q, k, v = qkv.unbind(dim=-3)
+ if self.qk_rms_norm:
+ q = self.q_rms_norm(q)
+ k = self.k_rms_norm(k)
+ if self.use_rope:
+ q, k = self.rope(q, k)
+ qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
+ if self.attn_mode == "full":
+ h = sparse_scaled_dot_product_attention(qkv)
+ elif self.attn_mode == "windowed":
+ h = sparse_windowed_scaled_dot_product_self_attention(
+ qkv, self.window_size, shift_window=self.shift_window
+ )
+ elif self.attn_mode == "double_windowed":
+ qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
+ qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
+ h0 = sparse_windowed_scaled_dot_product_self_attention(
+ qkv0, self.window_size, shift_window=(0, 0, 0)
+ )
+ h1 = sparse_windowed_scaled_dot_product_self_attention(
+ qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
+ )
+ h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
+ else:
+ q = self._linear(self.to_q, x)
+ q = self._reshape_chs(q, (self.num_heads, -1))
+ kv = self._linear(self.to_kv, context)
+ kv = self._fused_pre(kv, num_fused=2)
+ if self.qk_rms_norm:
+ q = self.q_rms_norm(q)
+ k, v = kv.unbind(dim=-3)
+ k = self.k_rms_norm(k)
+ h = sparse_scaled_dot_product_attention(q, k, v)
+ else:
+ h = sparse_scaled_dot_product_attention(q, kv)
+ h = self._reshape_chs(h, (-1,))
+ h = self._linear(self.to_out, h)
+ return h
diff --git a/trellis2/modules/sparse/attention/proj_attention.py b/trellis2/modules/sparse/attention/proj_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..185c1dbf8bb480b1cc2267002c655f878e667afd
--- /dev/null
+++ b/trellis2/modules/sparse/attention/proj_attention.py
@@ -0,0 +1,99 @@
+"""
+Sparse View-Aligned Projection Attention Module for TRELLIS2
+
+Sparse versions of ProjectAttention and GatedProjectAttention.
+
+Supports two modes:
+- "proj": Standard projection (DINOv3 only)
+- "gated_proj": Gated fusion of DINOv3 (semantic) + VAE (color) features
+"""
+
+from typing import *
+import torch
+import torch.nn as nn
+from ..basic import SparseTensor, VarLenTensor
+
+
+class SparseProjectAttention(nn.Module):
+ """
+ Sparse Projection-based Attention Module with per-block proj_linear.
+ """
+ def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int):
+ super().__init__()
+ self.cross_attn_block = cross_attn_block
+ self.proj_linear = nn.Linear(proj_in_channels, channels, bias=True)
+
+ def forward(
+ self,
+ x: SparseTensor,
+ context: Union[Dict[str, Union[torch.Tensor, VarLenTensor, SparseTensor]],
+ Tuple[Union[torch.Tensor, VarLenTensor], SparseTensor]]
+ ) -> SparseTensor:
+ if isinstance(context, dict):
+ global_context = context['global']
+ proj_context = context['proj']
+ else:
+ global_context, proj_context = context
+
+ global_out = self.cross_attn_block(x, global_context)
+
+        if isinstance(proj_context, SparseTensor):
+            proj_feats = self.proj_linear(proj_context.feats)
+        else:
+            proj_feats = self.proj_linear(proj_context)
+        combined_feats = proj_feats + global_out.feats
+
+ return global_out.replace(combined_feats)
+
+
+class SparseGatedProjectAttention(nn.Module):
+ """
+ Sparse Concat-Projection Attention Module for DINOv3 + VAE features.
+
+ Concatenates DINOv3 and VAE projected features and applies a single linear
+ projection to model_channels. Zero-initialized for stable training.
+
+ Context dict must contain:
+ - 'global': Global image features for cross-attention
+ - 'proj_semantic': DINOv3 projected features (SparseTensor or Tensor)
+ - 'proj_color': VAE projected features (SparseTensor or Tensor)
+ """
+ def __init__(
+ self,
+ cross_attn_block: nn.Module,
+ channels: int,
+ dino_in_channels: int,
+ vae_in_channels: int,
+ ):
+ super().__init__()
+ self.cross_attn_block = cross_attn_block
+ self.proj_linear = nn.Linear(dino_in_channels + vae_in_channels, channels, bias=True)
+ # Zero-init: at start, fused=0, only global cross-attn contributes
+ nn.init.zeros_(self.proj_linear.weight)
+ nn.init.zeros_(self.proj_linear.bias)
+
+ def _get_feats(self, t):
+ return t.feats if isinstance(t, SparseTensor) else t
+
+ def forward(
+ self,
+ x: SparseTensor,
+ context: Union[Dict[str, Union[torch.Tensor, VarLenTensor, SparseTensor]], Tuple],
+ ) -> SparseTensor:
+ if isinstance(context, dict):
+ global_context = context['global']
+ proj_semantic = context['proj_semantic']
+ proj_color = context['proj_color']
+ else:
+ global_context, proj_semantic, proj_color = context
+
+ global_out = self.cross_attn_block(x, global_context)
+
+ fused = self.proj_linear(torch.cat([
+ self._get_feats(proj_semantic),
+ self._get_feats(proj_color),
+ ], dim=-1))
+ combined_feats = fused + global_out.feats
+
+ return global_out.replace(combined_feats)
diff --git a/trellis2/modules/sparse/attention/rope.py b/trellis2/modules/sparse/attention/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb877291f3430f2cff4329a67b3592e0e3c3f137
--- /dev/null
+++ b/trellis2/modules/sparse/attention/rope.py
@@ -0,0 +1,58 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..basic import SparseTensor
+
+
+class SparseRotaryPositionEmbedder(nn.Module):
+ def __init__(
+ self,
+ head_dim: int,
+ dim: int = 3,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0)
+ ):
+ super().__init__()
+ assert head_dim % 2 == 0, "Head dim must be divisible by 2"
+ self.head_dim = head_dim
+ self.dim = dim
+ self.rope_freq = rope_freq
+ self.freq_dim = head_dim // 2 // dim
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
+ self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
+
+ def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
+ self.freqs = self.freqs.to(indices.device)
+ phases = torch.outer(indices, self.freqs)
+ phases = torch.polar(torch.ones_like(phases), phases)
+ return phases
+
+ def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ x_rotated = x_complex * phases.unsqueeze(-2)
+ x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
+ return x_embed
+
+    def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Union[SparseTensor, Tuple[SparseTensor, SparseTensor]]:
+        """
+        Args:
+            q (SparseTensor): [..., N, H, D] tensor of queries
+            k (SparseTensor, optional): [..., N, H, D] tensor of keys; if omitted,
+                only the rotated queries are returned
+        """
+ assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1"
+ phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}'
+ phases = q.get_spatial_cache(phases_cache_name)
+ if phases is None:
+ coords = q.coords[..., 1:]
+ phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1)
+ if phases.shape[-1] < self.head_dim // 2:
+ padn = self.head_dim // 2 - phases.shape[-1]
+ phases = torch.cat([phases, torch.polar(
+ torch.ones(*phases.shape[:-1], padn, device=phases.device),
+ torch.zeros(*phases.shape[:-1], padn, device=phases.device)
+ )], dim=-1)
+ q.register_spatial_cache(phases_cache_name, phases)
+ q_embed = q.replace(self._rotary_embedding(q.feats, phases))
+ if k is None:
+ return q_embed
+ k_embed = k.replace(self._rotary_embedding(k.feats, phases))
+ return q_embed, k_embed
\ No newline at end of file
diff --git a/trellis2/modules/sparse/attention/windowed_attn.py b/trellis2/modules/sparse/attention/windowed_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..675aa0870d98ea014206961ed947704bc6262970
--- /dev/null
+++ b/trellis2/modules/sparse/attention/windowed_attn.py
@@ -0,0 +1,205 @@
+from typing import *
+import torch
+import math
+from .. import SparseTensor
+from .. import config
+
+
+__all__ = [
+ 'sparse_windowed_scaled_dot_product_self_attention',
+ 'sparse_windowed_scaled_dot_product_cross_attention',
+]
+
+
+def calc_window_partition(
+ tensor: SparseTensor,
+ window_size: Union[int, Tuple[int, ...]],
+ shift_window: Union[int, Tuple[int, ...]] = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
+ """
+ Calculate serialization and partitioning for a set of coordinates.
+
+ Args:
+ tensor (SparseTensor): The input tensor.
+ window_size (int): The window size to use.
+ shift_window (Tuple[int, ...]): The shift of serialized coordinates.
+
+ Returns:
+ (torch.Tensor): Forwards indices.
+ (torch.Tensor): Backwards indices.
+ (torch.Tensor): Sequence lengths.
+ (dict): Attn func args.
+ """
+ DIM = tensor.coords.shape[1] - 1
+ shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
+ window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
+ shifted_coords = tensor.coords.clone().detach()
+ shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
+
+ MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
+ NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
+ OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
+
+ shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
+ shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
+ fwd_indices = torch.argsort(shifted_indices)
+ bwd_indices = torch.empty_like(fwd_indices)
+ bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
+ seq_lens = torch.bincount(shifted_indices)
+ mask = seq_lens != 0
+ seq_lens = seq_lens[mask]
+
+ if config.ATTN == 'xformers':
+ if 'xops' not in globals():
+ import xformers.ops as xops
+ attn_func_args = {
+ 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
+ }
+ elif config.ATTN == 'flash_attn':
+ attn_func_args = {
+ 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
+ 'max_seqlen': torch.max(seq_lens)
+ }
+ elif config.ATTN == 'flash_attn_4':
+ attn_func_args = {
+ 'cu_seqlens_q': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
+ 'cu_seqlens_k': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
+ 'max_seqlen_q': torch.max(seq_lens),
+ 'max_seqlen_k': torch.max(seq_lens),
+        }
+    else:
+        raise ValueError(f"Unsupported attention backend for windowed attention: {config.ATTN}")
+
+ return fwd_indices, bwd_indices, seq_lens, attn_func_args
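+
+# Worked example (illustrative): with window_size=2 and shift_window=0, a voxel at
+# coords (batch, x, y, z) = (0, 3, 0, 5) is assigned to window (0, 1, 0, 2).
+# fwd_indices groups all voxels of a window contiguously, seq_lens gives each
+# window's length for the varlen attention kernel, and bwd_indices restores the
+# original voxel order afterwards.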
+
+
+def sparse_windowed_scaled_dot_product_self_attention(
+ qkv: SparseTensor,
+ window_size: int,
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
+) -> SparseTensor:
+ """
+ Apply windowed scaled dot product self attention to a sparse tensor.
+
+ Args:
+ qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
+ window_size (int): The window size to use.
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
+
+ Returns:
+ (SparseTensor): [N, *, H, C] sparse tensor containing the output features.
+ """
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
+
+ serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
+ serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
+ if serialization_spatial_cache is None:
+ fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
+ qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
+ else:
+ fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
+
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
+
+ if config.DEBUG:
+ start = 0
+ qkv_coords = qkv.coords[fwd_indices]
+ for i in range(len(seq_lens)):
+ seq_coords = qkv_coords[start:start+seq_lens[i]]
+ assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
+ f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
+ start += seq_lens[i]
+
+ if config.ATTN == 'xformers':
+ if 'xops' not in globals():
+ import xformers.ops as xops
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
+ q = q.unsqueeze(0) # [1, M, H, C]
+ k = k.unsqueeze(0) # [1, M, H, C]
+ v = v.unsqueeze(0) # [1, M, H, C]
+ out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
+ elif config.ATTN == 'flash_attn':
+ if 'flash_attn' not in globals():
+ import flash_attn
+        out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
+    else:
+        raise ValueError(f"Unsupported attention backend for windowed self-attention: {config.ATTN}")
+
+ out = out[bwd_indices] # [T, H, C]
+
+ if config.DEBUG:
+ qkv_coords = qkv_coords[bwd_indices]
+ assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
+
+ return qkv.replace(out)
+
+
+def sparse_windowed_scaled_dot_product_cross_attention(
+ q: SparseTensor,
+ kv: SparseTensor,
+ q_window_size: int,
+ kv_window_size: int,
+ q_shift_window: Tuple[int, int, int] = (0, 0, 0),
+ kv_shift_window: Tuple[int, int, int] = (0, 0, 0),
+) -> SparseTensor:
+ """
+ Apply windowed scaled dot product cross attention to two sparse tensors.
+
+ Args:
+ q (SparseTensor): [N, *, H, C] sparse tensor containing Qs.
+ kv (SparseTensor): [N, *, 2, H, C] sparse tensor containing Ks and Vs.
+ q_window_size (int): The window size to use for Qs.
+ kv_window_size (int): The window size to use for Ks and Vs.
+ q_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Qs.
+ kv_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Ks and Vs.
+
+ Returns:
+ (SparseTensor): [N, *, H, C] sparse tensor containing the output features.
+ """
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
+ assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
+
+ q_serialization_spatial_cache_name = f'windowed_attention_{q_window_size}_{q_shift_window}'
+ q_serialization_spatial_cache = q.get_spatial_cache(q_serialization_spatial_cache_name)
+ if q_serialization_spatial_cache is None:
+ q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = calc_window_partition(q, q_window_size, q_shift_window)
+ q.register_spatial_cache(q_serialization_spatial_cache_name, (q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args))
+ else:
+ q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = q_serialization_spatial_cache
+ kv_serialization_spatial_cache_name = f'windowed_attention_{kv_window_size}_{kv_shift_window}'
+ kv_serialization_spatial_cache = kv.get_spatial_cache(kv_serialization_spatial_cache_name)
+ if kv_serialization_spatial_cache is None:
+ kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = calc_window_partition(kv, kv_window_size, kv_shift_window)
+ kv.register_spatial_cache(kv_serialization_spatial_cache_name, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args))
+ else:
+ kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = kv_serialization_spatial_cache
+
+ assert len(q_seq_lens) == len(kv_seq_lens), "Number of sequences in q and kv must match"
+
+ q_feats = q.feats[q_fwd_indices] # [M, H, C]
+ kv_feats = kv.feats[kv_fwd_indices] # [M, 2, H, C]
+
+ if config.ATTN == 'xformers':
+ if 'xops' not in globals():
+ import xformers.ops as xops
+ k, v = kv_feats.unbind(dim=1) # [M, H, C]
+ q = q.unsqueeze(0) # [1, M, H, C]
+ k = k.unsqueeze(0) # [1, M, H, C]
+ v = v.unsqueeze(0) # [1, M, H, C]
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens)
+ out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)[0] # [M, H, C]
+ elif config.ATTN == 'flash_attn':
+ if 'flash_attn' not in globals():
+ import flash_attn
+ out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats, kv_feats,
+ cu_seqlens_q=q_attn_func_args['cu_seqlens'], cu_seqlens_k=kv_attn_func_args['cu_seqlens'],
+ max_seqlen_q=q_attn_func_args['max_seqlen'], max_seqlen_k=kv_attn_func_args['max_seqlen'],
+ ) # [M, H, C]
+ elif config.ATTN == 'flash_attn_4':
+ if 'flash_attn_4' not in globals():
+ from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func
+ k, v = kv_feats.unbind(dim=1)
+ out, _ = flash_attn_4_varlen_func(q_feats, k, v,
+ cu_seqlens_q=q_attn_func_args['cu_seqlens_q'], cu_seqlens_k=kv_attn_func_args['cu_seqlens_k'],
+ max_seqlen_q=q_attn_func_args['max_seqlen_q'], max_seqlen_k=kv_attn_func_args['max_seqlen_k'],
+ ) # [M, H, C]
+
+ out = out[q_bwd_indices] # [T, H, C]
+
+ return q.replace(out)
diff --git a/trellis2/modules/sparse/basic.py b/trellis2/modules/sparse/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..880973b8dd6bdafcfca4ca7c529308d2ef2ad266
--- /dev/null
+++ b/trellis2/modules/sparse/basic.py
@@ -0,0 +1,836 @@
+from typing import *
+from fractions import Fraction
+import torch
+from . import config
+
+
+__all__ = [
+ 'VarLenTensor',
+ 'varlen_cat',
+ 'varlen_unbind',
+ 'SparseTensor',
+ 'sparse_cat',
+ 'sparse_unbind',
+]
+
+
+class VarLenTensor:
+ """
+ Sequential tensor with variable length.
+
+ Args:
+ feats (torch.Tensor): Features of the varlen tensor.
+ layout (List[slice]): Layout of the varlen tensor for each batch
+ """
+ def __init__(self, feats: torch.Tensor, layout: List[slice]=None):
+ self.feats = feats
+ self.layout = layout if layout is not None else [slice(0, feats.shape[0])]
+ self._cache = {}
+
+ @staticmethod
+ def layout_from_seqlen(seqlen: list) -> List[slice]:
+ """
+        Create a layout from a list of sequence lengths.
+ """
+ layout = []
+ start = 0
+ for l in seqlen:
+ layout.append(slice(start, start + l))
+ start += l
+ return layout
+
+ @staticmethod
+ def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor':
+ """
+ Create a VarLenTensor from a list of tensors.
+ """
+ feats = torch.cat(tensor_list, dim=0)
+ layout = []
+ start = 0
+ for tensor in tensor_list:
+ layout.append(slice(start, start + tensor.shape[0]))
+ start += tensor.shape[0]
+ return VarLenTensor(feats, layout)
+
+ def to_tensor_list(self) -> List[torch.Tensor]:
+ """
+ Convert a VarLenTensor to a list of tensors.
+ """
+ tensor_list = []
+ for s in self.layout:
+ tensor_list.append(self.feats[s])
+ return tensor_list
+
+ def __len__(self) -> int:
+ return len(self.layout)
+
+ @property
+ def shape(self) -> torch.Size:
+ return torch.Size([len(self.layout), *self.feats.shape[1:]])
+
+ def dim(self) -> int:
+ return len(self.shape)
+
+ @property
+ def ndim(self) -> int:
+ return self.dim()
+
+ @property
+ def dtype(self):
+ return self.feats.dtype
+
+ @property
+ def device(self):
+ return self.feats.device
+
+ @property
+ def seqlen(self) -> torch.LongTensor:
+ if 'seqlen' not in self._cache:
+ self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
+ return self._cache['seqlen']
+
+ @property
+ def cum_seqlen(self) -> torch.LongTensor:
+ if 'cum_seqlen' not in self._cache:
+ self._cache['cum_seqlen'] = torch.cat([
+ torch.tensor([0], dtype=torch.long, device=self.device),
+ self.seqlen.cumsum(dim=0)
+ ], dim=0)
+ return self._cache['cum_seqlen']
+
+ @property
+ def batch_boardcast_map(self) -> torch.LongTensor:
+ """
+ Get the broadcast map for the varlen tensor.
+ """
+ if 'batch_boardcast_map' not in self._cache:
+ self._cache['batch_boardcast_map'] = torch.repeat_interleave(
+ torch.arange(len(self.layout), device=self.device),
+ self.seqlen,
+ )
+ return self._cache['batch_boardcast_map']
+
+ @overload
+ def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...
+
+ @overload
+ def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...
+
+ def to(self, *args, **kwargs) -> 'VarLenTensor':
+ device = None
+ dtype = None
+ if len(args) == 2:
+ device, dtype = args
+ elif len(args) == 1:
+ if isinstance(args[0], torch.dtype):
+ dtype = args[0]
+ else:
+ device = args[0]
+ if 'dtype' in kwargs:
+ assert dtype is None, "to() received multiple values for argument 'dtype'"
+ dtype = kwargs['dtype']
+ if 'device' in kwargs:
+ assert device is None, "to() received multiple values for argument 'device'"
+ device = kwargs['device']
+ non_blocking = kwargs.get('non_blocking', False)
+ copy = kwargs.get('copy', False)
+
+ new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
+ return self.replace(new_feats)
+
+ def type(self, dtype):
+ new_feats = self.feats.type(dtype)
+ return self.replace(new_feats)
+
+ def cpu(self) -> 'VarLenTensor':
+ new_feats = self.feats.cpu()
+ return self.replace(new_feats)
+
+ def cuda(self) -> 'VarLenTensor':
+ new_feats = self.feats.cuda()
+ return self.replace(new_feats)
+
+ def half(self) -> 'VarLenTensor':
+ new_feats = self.feats.half()
+ return self.replace(new_feats)
+
+ def float(self) -> 'VarLenTensor':
+ new_feats = self.feats.float()
+ return self.replace(new_feats)
+
+ def detach(self) -> 'VarLenTensor':
+ new_feats = self.feats.detach()
+ return self.replace(new_feats)
+
+ def reshape(self, *shape) -> 'VarLenTensor':
+ new_feats = self.feats.reshape(self.feats.shape[0], *shape)
+ return self.replace(new_feats)
+
+ def unbind(self, dim: int) -> List['VarLenTensor']:
+ return varlen_unbind(self, dim)
+
+ def replace(self, feats: torch.Tensor) -> 'VarLenTensor':
+ new_tensor = VarLenTensor(
+ feats=feats,
+ layout=self.layout,
+ )
+ new_tensor._cache = self._cache
+ return new_tensor
+
+ def to_dense(self, max_length=None) -> Tuple[torch.Tensor, torch.BoolTensor]:
+ """
+ Convert a VarLenTensor to a padded dense representation without an explicit Python loop.
+
+ Returns:
+ dense (torch.Tensor): (N, L, C) dense tensor
+ mask (torch.BoolTensor): (N, L) mask indicating valid positions
+ """
+ N = len(self)
+ L = max_length or self.seqlen.max().item()
+ spatial = self.feats.shape[1:]
+ idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L)
+ mask = (idx < self.seqlen.unsqueeze(1))
+ mapping = mask.reshape(-1).cumsum(dim=0) - 1
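+ # e.g. for seqlen = [2, 3] and L = 3: mask.reshape(-1) = [1, 1, 0, 1, 1, 1],
+ # so mapping = [0, 1, 1, 2, 3, 4]; valid positions gather rows 0..4 of feats in
+ # order, while padded positions repeat a neighbouring row and are masked out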
+ dense = self.feats[mapping]
+ dense = dense.reshape(N, L, *spatial)
+ return dense, mask
+
+ def __neg__(self) -> 'VarLenTensor':
+ return self.replace(-self.feats)
+
+ def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor':
+ if isinstance(other, torch.Tensor):
+ try:
+ other = torch.broadcast_to(other, self.shape)
+ other = other[self.batch_boardcast_map]
+ except Exception:
+ # broadcasting against the batch dimension failed; fall through to the raw op
+ pass
+ if isinstance(other, VarLenTensor):
+ other = other.feats
+ new_feats = op(self.feats, other)
+ new_tensor = self.replace(new_feats)
+ return new_tensor
+
+ def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
+ return self.__elemwise__(other, torch.add)
+
+ def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
+ return self.__elemwise__(other, torch.add)
+
+ def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
+ return self.__elemwise__(other, torch.sub)
+
+ def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
+ return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
+
+ def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
+ return self.__elemwise__(other, torch.mul)
+
+ def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
+ return self.__elemwise__(other, torch.mul)
+
+ def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
+ return self.__elemwise__(other, torch.div)
+
+ def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
+ return self.__elemwise__(other, lambda x, y: torch.div(y, x))
+
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ idx = [idx]
+ elif isinstance(idx, slice):
+ idx = range(*idx.indices(self.shape[0]))
+ elif isinstance(idx, list):
+ assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
+ elif isinstance(idx, torch.Tensor):
+ if idx.dtype == torch.bool:
+ assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
+ idx = idx.nonzero().squeeze(1)
+ elif idx.dtype in [torch.int32, torch.int64]:
+ assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
+ else:
+ raise ValueError(f"Unknown index type: {idx.dtype}")
+ else:
+ raise ValueError(f"Unknown index type: {type(idx)}")
+
+ new_feats = []
+ new_layout = []
+ start = 0
+ for new_idx, old_idx in enumerate(idx):
+ new_feats.append(self.feats[self.layout[old_idx]])
+ new_layout.append(slice(start, start + len(new_feats[-1])))
+ start += len(new_feats[-1])
+ new_feats = torch.cat(new_feats, dim=0).contiguous()
+ new_tensor = VarLenTensor(feats=new_feats, layout=new_layout)
+ return new_tensor
+
+ def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
+ if isinstance(dim, int):
+ dim = (dim,)
+
+ if op == 'mean':
+ red = self.feats.mean(dim=dim, keepdim=keepdim)
+ elif op == 'sum':
+ red = self.feats.sum(dim=dim, keepdim=keepdim)
+ elif op == 'prod':
+ red = self.feats.prod(dim=dim, keepdim=keepdim)
+ else:
+ raise ValueError(f"Unsupported reduce operation: {op}")
+
+ if dim is None or 0 in dim:
+ return red
+
+ red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen)
+ return red
+
+ def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
+ return self.reduce(op='mean', dim=dim, keepdim=keepdim)
+
+ def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
+ return self.reduce(op='sum', dim=dim, keepdim=keepdim)
+
+ def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
+ return self.reduce(op='prod', dim=dim, keepdim=keepdim)
+
+ def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
+ mean = self.mean(dim=dim, keepdim=True)
+ mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True)
+ std = (mean2 - mean ** 2).sqrt()
+ return std
+
+ def __repr__(self) -> str:
+ return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
+
+
+def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor:
+ """
+ Concatenate a list of varlen tensors.
+
+ Args:
+ inputs (List[VarLenTensor]): List of varlen tensors to concatenate.
+ """
+ if dim == 0:
+ new_feats = torch.cat([input.feats for input in inputs], dim=0)
+ start = 0
+ new_layout = []
+ for input in inputs:
+ for l in input.layout:
+ new_layout.append(slice(start, start + l.stop - l.start))
+ start += l.stop - l.start
+ output = VarLenTensor(feats=new_feats, layout=new_layout)
+ else:
+ feats = torch.cat([input.feats for input in inputs], dim=dim)
+ output = inputs[0].replace(feats)
+
+ return output
+
+
+def varlen_unbind(input: VarLenTensor, dim: int) -> List[VarLenTensor]:
+ """
+ Unbind a varlen tensor along a dimension.
+
+ Args:
+ input (VarLenTensor): Varlen tensor to unbind.
+ dim (int): Dimension to unbind.
+ """
+ if dim == 0:
+ return [input[i] for i in range(len(input))]
+ else:
+ feats = input.feats.unbind(dim)
+ return [input.replace(f) for f in feats]
+
+
+class SparseTensor(VarLenTensor):
+ """
+ Sparse tensor with support for multiple storage backends (torchsparse, spconv, or a plain feats/coords dict used by the other backends).
+
+ Parameters:
+ - feats (torch.Tensor): Features of the sparse tensor.
+ - coords (torch.Tensor): Coordinates of the sparse tensor.
+ - shape (torch.Size): Shape of the sparse tensor.
+ - layout (List[slice]): Layout of the sparse tensor for each batch
+ - data (SparseTensorData): Sparse tensor data used for convolution
+
+ NOTE:
+ - Data belonging to the same batch must be contiguous.
+ - Coords should be in [0, 1023]
+ """
+ SparseTensorData = None
+
+ @overload
+ def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ...
+
+ @overload
+ def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ...
+
+ def __init__(self, *args, **kwargs):
+ # Lazy import of sparse tensor backend
+ if self.SparseTensorData is None:
+ import importlib
+ if config.CONV == 'torchsparse':
+ self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
+ elif config.CONV == 'spconv':
+ self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
+
+ method_id = 0
+ if len(args) != 0:
+ method_id = 0 if isinstance(args[0], torch.Tensor) else 1
+ else:
+ method_id = 1 if 'data' in kwargs else 0
+
+ if method_id == 0:
+ feats, coords, shape = args + (None,) * (3 - len(args))
+ if 'feats' in kwargs:
+ feats = kwargs['feats']
+ del kwargs['feats']
+ if 'coords' in kwargs:
+ coords = kwargs['coords']
+ del kwargs['coords']
+ if 'shape' in kwargs:
+ shape = kwargs['shape']
+ del kwargs['shape']
+
+ if config.CONV == 'torchsparse':
+ self.data = self.SparseTensorData(feats, coords, **kwargs)
+ elif config.CONV == 'spconv':
+ spatial_shape = list(coords.max(0)[0] + 1)
+ self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
+ self.data._features = feats
+ else:
+ self.data = {
+ 'feats': feats,
+ 'coords': coords,
+ }
+ elif method_id == 1:
+ data, shape = args + (None,) * (2 - len(args))
+ if 'data' in kwargs:
+ data = kwargs['data']
+ del kwargs['data']
+ if 'shape' in kwargs:
+ shape = kwargs['shape']
+ del kwargs['shape']
+
+ self.data = data
+
+ self._shape = shape
+ self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)))
+ self._spatial_cache = kwargs.get('spatial_cache', {})
+
+ if config.DEBUG:
+ try:
+ assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
+ assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
+ assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
+ for i in range(self.shape[0]):
+ assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
+ except Exception as e:
+ print('Debugging information:')
+ print(f"- Shape: {self.shape}")
+ print(f"- Layout: {self.layout}")
+ print(f"- Scale: {self._scale}")
+ print(f"- Coords: {self.coords}")
+ raise e
+
+ @staticmethod
+ def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor':
+ """
+ Create a SparseTensor from a list of tensors.
+ """
+ feats = torch.cat(feats_list, dim=0)
+ coords = []
+ for i, coord in enumerate(coords_list):
+ coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1)
+ coords.append(coord)
+ coords = torch.cat(coords, dim=0)
+ return SparseTensor(feats, coords)
+
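+ # Illustrative usage (a sketch): column 0 of `coords` is the batch index, so a
+ # single-sample tensor with 4 voxels could be built as
+ #   coords = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=torch.int32)
+ #   feats = torch.randn(4, 16)
+ #   st = SparseTensor(feats=feats, coords=coords)   # st.shape -> torch.Size([1, 16])
+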
+ def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ """
+ Convert a SparseTensor to list of tensors.
+ """
+ feats_list = []
+ coords_list = []
+ for s in self.layout:
+ feats_list.append(self.feats[s])
+ coords_list.append(self.coords[s])
+ return feats_list, coords_list
+
+ def __len__(self) -> int:
+ return len(self.layout)
+
+ def __cal_shape(self, feats, coords):
+ shape = []
+ shape.append(coords[:, 0].max().item() + 1)
+ shape.extend([*feats.shape[1:]])
+ return torch.Size(shape)
+
+ def __cal_layout(self, coords, batch_size):
+ seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
+ offset = torch.cumsum(seq_len, dim=0)
+ layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
+ return layout
+
+ def __cal_spatial_shape(self, coords):
+ return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist())
+
+ @property
+ def shape(self) -> torch.Size:
+ if self._shape is None:
+ self._shape = self.__cal_shape(self.feats, self.coords)
+ return self._shape
+
+ @property
+ def layout(self) -> List[slice]:
+ layout = self.get_spatial_cache('layout')
+ if layout is None:
+ layout = self.__cal_layout(self.coords, self.shape[0])
+ self.register_spatial_cache('layout', layout)
+ return layout
+
+ @property
+ def spatial_shape(self) -> torch.Size:
+ spatial_shape = self.get_spatial_cache('shape')
+ if spatial_shape is None:
+ spatial_shape = self.__cal_spatial_shape(self.coords)
+ self.register_spatial_cache('shape', spatial_shape)
+ return spatial_shape
+
+ @property
+ def feats(self) -> torch.Tensor:
+ if config.CONV == 'torchsparse':
+ return self.data.F
+ elif config.CONV == 'spconv':
+ return self.data.features
+ else:
+ return self.data['feats']
+
+ @feats.setter
+ def feats(self, value: torch.Tensor):
+ if config.CONV == 'torchsparse':
+ self.data.F = value
+ elif config.CONV == 'spconv':
+ self.data.features = value
+ else:
+ self.data['feats'] = value
+
+ @property
+ def coords(self) -> torch.Tensor:
+ if config.CONV == 'torchsparse':
+ return self.data.C
+ elif config.CONV == 'spconv':
+ return self.data.indices
+ else:
+ return self.data['coords']
+
+ @coords.setter
+ def coords(self, value: torch.Tensor):
+ if config.CONV == 'torchsparse':
+ self.data.C = value
+ elif config.CONV == 'spconv':
+ self.data.indices = value
+ else:
+ self.data['coords'] = value
+
+ @property
+ def dtype(self):
+ return self.feats.dtype
+
+ @property
+ def device(self):
+ return self.feats.device
+
+ @property
+ def seqlen(self) -> torch.LongTensor:
+ seqlen = self.get_spatial_cache('seqlen')
+ if seqlen is None:
+ seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
+ self.register_spatial_cache('seqlen', seqlen)
+ return seqlen
+
+ @property
+ def cum_seqlen(self) -> torch.LongTensor:
+ cum_seqlen = self.get_spatial_cache('cum_seqlen')
+ if cum_seqlen is None:
+ cum_seqlen = torch.cat([
+ torch.tensor([0], dtype=torch.long, device=self.device),
+ self.seqlen.cumsum(dim=0)
+ ], dim=0)
+ self.register_spatial_cache('cum_seqlen', cum_seqlen)
+ return cum_seqlen
+
+ @property
+ def batch_boardcast_map(self) -> torch.LongTensor:
+ """
+ Get the broadcast map for the sparse tensor, i.e. the batch index of each feature row.
+ """
+ batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map')
+ if batch_boardcast_map is None:
+ batch_boardcast_map = torch.repeat_interleave(
+ torch.arange(len(self.layout), device=self.device),
+ self.seqlen,
+ )
+ self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map)
+ return batch_boardcast_map
+
+ @overload
+ def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...
+
+ @overload
+ def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...
+
+ def to(self, *args, **kwargs) -> 'SparseTensor':
+ device = None
+ dtype = None
+ if len(args) == 2:
+ device, dtype = args
+ elif len(args) == 1:
+ if isinstance(args[0], torch.dtype):
+ dtype = args[0]
+ else:
+ device = args[0]
+ if 'dtype' in kwargs:
+ assert dtype is None, "to() received multiple values for argument 'dtype'"
+ dtype = kwargs['dtype']
+ if 'device' in kwargs:
+ assert device is None, "to() received multiple values for argument 'device'"
+ device = kwargs['device']
+ non_blocking = kwargs.get('non_blocking', False)
+ copy = kwargs.get('copy', False)
+
+ new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
+ new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy)
+ return self.replace(new_feats, new_coords)
+
+ def type(self, dtype):
+ new_feats = self.feats.type(dtype)
+ return self.replace(new_feats)
+
+ def cpu(self) -> 'SparseTensor':
+ new_feats = self.feats.cpu()
+ new_coords = self.coords.cpu()
+ return self.replace(new_feats, new_coords)
+
+ def cuda(self) -> 'SparseTensor':
+ new_feats = self.feats.cuda()
+ new_coords = self.coords.cuda()
+ return self.replace(new_feats, new_coords)
+
+ def half(self) -> 'SparseTensor':
+ new_feats = self.feats.half()
+ return self.replace(new_feats)
+
+ def float(self) -> 'SparseTensor':
+ new_feats = self.feats.float()
+ return self.replace(new_feats)
+
+ def detach(self) -> 'SparseTensor':
+ new_coords = self.coords.detach()
+ new_feats = self.feats.detach()
+ return self.replace(new_feats, new_coords)
+
+ def reshape(self, *shape) -> 'SparseTensor':
+ new_feats = self.feats.reshape(self.feats.shape[0], *shape)
+ return self.replace(new_feats)
+
+ def unbind(self, dim: int) -> List['SparseTensor']:
+ return sparse_unbind(self, dim)
+
+ def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
+ if config.CONV == 'torchsparse':
+ new_data = self.SparseTensorData(
+ feats=feats,
+ coords=self.data.coords if coords is None else coords,
+ stride=self.data.stride,
+ spatial_range=self.data.spatial_range,
+ )
+ new_data._caches = self.data._caches
+ elif config.CONV == 'spconv':
+ new_data = self.SparseTensorData(
+ self.data.features.reshape(self.data.features.shape[0], -1),
+ self.data.indices,
+ self.data.spatial_shape,
+ self.data.batch_size,
+ self.data.grid,
+ self.data.voxel_num,
+ self.data.indice_dict
+ )
+ new_data._features = feats
+ new_data.benchmark = self.data.benchmark
+ new_data.benchmark_record = self.data.benchmark_record
+ new_data.thrust_allocator = self.data.thrust_allocator
+ new_data._timer = self.data._timer
+ new_data.force_algo = self.data.force_algo
+ new_data.int8_scale = self.data.int8_scale
+ if coords is not None:
+ new_data.indices = coords
+ else:
+ new_data = {
+ 'feats': feats,
+ 'coords': self.data['coords'] if coords is None else coords,
+ }
+ new_tensor = SparseTensor(
+ new_data,
+ shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None,
+ scale=self._scale,
+ spatial_cache=self._spatial_cache
+ )
+ return new_tensor
+
+ def to_dense(self) -> torch.Tensor:
+ if config.CONV == 'torchsparse':
+ return self.data.dense()
+ elif config.CONV == 'spconv':
+ return self.data.dense()
+ else:
+ spatial_shape = self.spatial_shape
+ ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device)
+ idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1)
+ ret[tuple(idx)] = self.feats
+ return ret
+
+ @staticmethod
+ def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
+ N, C = dim
+ x = torch.arange(aabb[0], aabb[3] + 1)
+ y = torch.arange(aabb[1], aabb[4] + 1)
+ z = torch.arange(aabb[2], aabb[5] + 1)
+ coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
+ coords = torch.cat([
+ torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
+ coords.repeat(N, 1),
+ ], dim=1).to(dtype=torch.int32, device=device)
+ feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
+ return SparseTensor(feats=feats, coords=coords)
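+ # Illustrative usage (a sketch): a dense grid of constant features for N samples, e.g.
+ #   st = SparseTensor.full((0, 0, 0, 15, 15, 15), (2, 8), 1.0)
+ #   st.shape          # torch.Size([2, 8])
+ #   st.coords.shape   # torch.Size([2 * 16 ** 3, 4])
+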
+
+ def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
+ new_cache = {}
+ for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
+ if k in self._spatial_cache:
+ new_cache[k] = self._spatial_cache[k]
+ if k in other._spatial_cache:
+ if k not in new_cache:
+ new_cache[k] = other._spatial_cache[k]
+ else:
+ new_cache[k].update(other._spatial_cache[k])
+ return new_cache
+
+ def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor':
+ if isinstance(other, torch.Tensor):
+ try:
+ other = torch.broadcast_to(other, self.shape)
+ other = other[self.batch_boardcast_map]
+ except Exception:
+ # broadcasting against the batch dimension failed; fall through to the raw op
+ pass
+ if isinstance(other, VarLenTensor):
+ other = other.feats
+ new_feats = op(self.feats, other)
+ new_tensor = self.replace(new_feats)
+ if isinstance(other, SparseTensor):
+ new_tensor._spatial_cache = self.__merge_sparse_cache(other)
+ return new_tensor
+
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ idx = [idx]
+ elif isinstance(idx, slice):
+ idx = range(*idx.indices(self.shape[0]))
+ elif isinstance(idx, list):
+ assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
+ elif isinstance(idx, torch.Tensor):
+ if idx.dtype == torch.bool:
+ assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
+ idx = idx.nonzero().squeeze(1)
+ elif idx.dtype in [torch.int32, torch.int64]:
+ assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
+ else:
+ raise ValueError(f"Unknown index type: {idx.dtype}")
+ else:
+ raise ValueError(f"Unknown index type: {type(idx)}")
+
+ new_coords = []
+ new_feats = []
+ new_layout = []
+ new_shape = torch.Size([len(idx)] + list(self.shape[1:]))
+ start = 0
+ for new_idx, old_idx in enumerate(idx):
+ new_coords.append(self.coords[self.layout[old_idx]].clone())
+ new_coords[-1][:, 0] = new_idx
+ new_feats.append(self.feats[self.layout[old_idx]])
+ new_layout.append(slice(start, start + len(new_coords[-1])))
+ start += len(new_coords[-1])
+ new_coords = torch.cat(new_coords, dim=0).contiguous()
+ new_feats = torch.cat(new_feats, dim=0).contiguous()
+ new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape)
+ new_tensor.register_spatial_cache('layout', new_layout)
+ return new_tensor
+
+ def clear_spatial_cache(self) -> None:
+ """
+ Clear all spatial caches.
+ """
+ self._spatial_cache = {}
+
+ def register_spatial_cache(self, key, value) -> None:
+ """
+ Register a spatial cache.
+ The spatial cache can hold anything you want to cache.
+ Registration and retrieval of the cache are keyed by the current scale.
+ """
+ scale_key = str(self._scale)
+ if scale_key not in self._spatial_cache:
+ self._spatial_cache[scale_key] = {}
+ self._spatial_cache[scale_key][key] = value
+
+ def get_spatial_cache(self, key=None):
+ """
+ Get a spatial cache.
+ """
+ scale_key = str(self._scale)
+ cur_scale_cache = self._spatial_cache.get(scale_key, {})
+ if key is None:
+ return cur_scale_cache
+ return cur_scale_cache.get(key, None)
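+ # Illustrative usage (a sketch; 'my_key' and `some_value` are hypothetical):
+ #   st.register_spatial_cache('my_key', some_value)
+ #   st.get_spatial_cache('my_key')   # -> some_value, looked up at the same scale
+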
+
+ def __repr__(self) -> str:
+ return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
+
+def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
+ """
+ Concatenate a list of sparse tensors.
+
+ Args:
+ inputs (List[SparseTensor]): List of sparse tensors to concatenate.
+ """
+ if dim == 0:
+ start = 0
+ coords = []
+ for input in inputs:
+ coords.append(input.coords.clone())
+ coords[-1][:, 0] += start
+ start += input.shape[0]
+ coords = torch.cat(coords, dim=0)
+ feats = torch.cat([input.feats for input in inputs], dim=0)
+ output = SparseTensor(
+ coords=coords,
+ feats=feats,
+ )
+ else:
+ feats = torch.cat([input.feats for input in inputs], dim=dim)
+ output = inputs[0].replace(feats)
+
+ return output
+
+
+def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
+ """
+ Unbind a sparse tensor along a dimension.
+
+ Args:
+ input (SparseTensor): Sparse tensor to unbind.
+ dim (int): Dimension to unbind.
+ """
+ if dim == 0:
+ return [input[i] for i in range(input.shape[0])]
+ else:
+ feats = input.feats.unbind(dim)
+ return [input.replace(f) for f in feats]
diff --git a/trellis2/modules/sparse/config.py b/trellis2/modules/sparse/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..74e995366efaa6413a2ab64fe521a8eda119508a
--- /dev/null
+++ b/trellis2/modules/sparse/config.py
@@ -0,0 +1,43 @@
+from typing import *
+
+CONV = 'flex_gemm'
+DEBUG = False
+ATTN = 'flash_attn'
+
+def __from_env():
+ import os
+
+ global CONV
+ global DEBUG
+ global ATTN
+
+ env_sparse_conv_backend = os.environ.get('SPARSE_CONV_BACKEND')
+ env_sparse_debug = os.environ.get('SPARSE_DEBUG')
+ env_sparse_attn_backend = os.environ.get('SPARSE_ATTN_BACKEND')
+ if env_sparse_attn_backend is None:
+ env_sparse_attn_backend = os.environ.get('ATTN_BACKEND')
+
+ if env_sparse_conv_backend is not None and env_sparse_conv_backend in ['none', 'spconv', 'torchsparse', 'flex_gemm']:
+ CONV = env_sparse_conv_backend
+ if env_sparse_debug is not None:
+ DEBUG = env_sparse_debug == '1'
+ if env_sparse_attn_backend is not None and env_sparse_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'flash_attn_4']:
+ ATTN = env_sparse_attn_backend
+
+ print(f"[SPARSE] Conv backend: {CONV}; Attention backend: {ATTN}")
+
+
+__from_env()
+
+
+def set_conv_backend(backend: Literal['none', 'spconv', 'torchsparse', 'flex_gemm']):
+ global CONV
+ CONV = backend
+
+def set_debug(debug: bool):
+ global DEBUG
+ DEBUG = debug
+
+def set_attn_backend(backend: Literal['xformers', 'flash_attn', 'flash_attn_3', 'flash_attn_4']):
+ global ATTN
+ ATTN = backend
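+
+# Illustrative usage (a sketch): backends can be selected via environment
+# variables before this module is first imported, e.g.
+#   SPARSE_CONV_BACKEND=spconv SPARSE_ATTN_BACKEND=xformers python train.py
+# (train.py is a placeholder script name), or programmatically after import:
+#   from trellis2.modules.sparse import config
+#   config.set_conv_backend('spconv')
+#   config.set_attn_backend('xformers')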
diff --git a/trellis2/modules/sparse/conv/__init__.py b/trellis2/modules/sparse/conv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f5911f2cb266f93e52cfcdea9f63f39be172c6
--- /dev/null
+++ b/trellis2/modules/sparse/conv/__init__.py
@@ -0,0 +1,2 @@
+from .conv import SparseConv3d, SparseInverseConv3d
+from . import config
diff --git a/trellis2/modules/sparse/conv/config.py b/trellis2/modules/sparse/conv/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac0848906703e7811300235e32d14e50ad5aac51
--- /dev/null
+++ b/trellis2/modules/sparse/conv/config.py
@@ -0,0 +1,3 @@
+SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
+FLEX_GEMM_ALGO = 'masked_implicit_gemm_splitk' # 'explicit_gemm', 'implicit_gemm', 'implicit_gemm_splitk', 'masked_implicit_gemm', 'masked_implicit_gemm_splitk'
+FLEX_GEMM_HASHMAP_RATIO = 2.0 # Ratio of hashmap size to input size
diff --git a/trellis2/modules/sparse/conv/conv.py b/trellis2/modules/sparse/conv/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c7d40707a24e26ba2a11a90c41dbc9eb11e7ab2
--- /dev/null
+++ b/trellis2/modules/sparse/conv/conv.py
@@ -0,0 +1,30 @@
+from .. import config
+import importlib
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+
+
+_backends = {}
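+# Backend implementation modules (conv_spconv, conv_torchsparse, conv_flex_gemm)
+# are imported lazily on first use and cached here, keyed by config.CONV.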
+
+
+class SparseConv3d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
+ super(SparseConv3d, self).__init__()
+ if config.CONV not in _backends:
+ _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__)
+ _backends[config.CONV].sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key)
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ return _backends[config.CONV].sparse_conv3d_forward(self, x)
+
+
+class SparseInverseConv3d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+ super(SparseInverseConv3d, self).__init__()
+ if config.CONV not in _backends:
+ _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__)
+ _backends[config.CONV].sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key)
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ return _backends[config.CONV].sparse_inverse_conv3d_forward(self, x)
diff --git a/trellis2/modules/sparse/conv/conv_flex_gemm.py b/trellis2/modules/sparse/conv/conv_flex_gemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25619475e4bf39307f97d47b3828b19c48cd7da
--- /dev/null
+++ b/trellis2/modules/sparse/conv/conv_flex_gemm.py
@@ -0,0 +1,68 @@
+import math
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+from . import config
+import flex_gemm
+from flex_gemm.ops.spconv import sparse_submanifold_conv3d
+
+
+def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
+ assert stride == 1 and (padding is None), 'The flex_gemm backend currently only supports submanifold sparse convolution (stride=1, padding=None)'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3
+ self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3
+ self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3
+
+ self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
+ if bias:
+ self.bias = nn.Parameter(torch.empty(out_channels))
+ else:
+ self.register_parameter("bias", None)
+
+ # initialize parameters
+ torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+ if fan_in != 0:
+ bound = 1 / math.sqrt(fan_in)
+ torch.nn.init.uniform_(self.bias, -bound, bound)
+
+ # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci)
+ self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())
+
+
+def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
+ flex_gemm.ops.spconv.set_algorithm(config.FLEX_GEMM_ALGO)
+ flex_gemm.ops.spconv.set_hashmap_ratio(config.FLEX_GEMM_HASHMAP_RATIO)
+
+ # check if neighbor map is already computed
+ Co, Kd, Kh, Kw, Ci = self.weight.shape
+ neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
+ neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
+
+ out, neighbor_cache_ = sparse_submanifold_conv3d(
+ x.feats,
+ x.coords,
+ torch.Size([*x.shape, *x.spatial_shape]),
+ self.weight,
+ self.bias,
+ neighbor_cache,
+ self.dilation
+ )
+
+ if neighbor_cache is None:
+ x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)
+
+ out = x.replace(out)
+ return out
+
+
+def sparse_inverse_conv3d_init(self, *args, **kwargs):
+ raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet')
+
+
+def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
+ raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet')
diff --git a/trellis2/modules/sparse/conv/conv_spconv.py b/trellis2/modules/sparse/conv/conv_spconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..f709708d4f8ee3bb98930c68b28e7f3b897fc591
--- /dev/null
+++ b/trellis2/modules/sparse/conv/conv_spconv.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+from . import config
+import spconv.pytorch as spconv
+
+
+def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
+ algo = None
+ if config.SPCONV_ALGO == 'native':
+ algo = spconv.ConvAlgo.Native
+ elif config.SPCONV_ALGO == 'implicit_gemm':
+ algo = spconv.ConvAlgo.MaskImplicitGemm
+ if stride == 1 and (padding is None):
+ self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
+ else:
+ self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
+ self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
+ self.padding = padding
+
+
+def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
+ spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
+ new_data = self.conv(x.data)
+ new_shape = [x.shape[0], self.conv.out_channels]
+ new_layout = None if spatial_changed else x.layout
+
+ if spatial_changed and (x.shape[0] != 1):
+ # spconv with a non-1 stride breaks the batch-contiguity of the output tensor, so sort by the batch index in coords
+ fwd = new_data.indices[:, 0].argsort()
+ bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
+ sorted_feats = new_data.features[fwd]
+ sorted_coords = new_data.indices[fwd]
+ unsorted_data = new_data
+ new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
+
+ out = SparseTensor(
+ new_data, shape=torch.Size(new_shape), layout=new_layout,
+ scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
+ spatial_cache=x._spatial_cache,
+ )
+
+ if spatial_changed and (x.shape[0] != 1):
+ out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
+ out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
+
+ return out
+
+
+def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+ self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
+ self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
+
+
+def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
+ spatial_changed = any(s != 1 for s in self.stride)
+ if spatial_changed:
+ # recover the original spconv order
+ data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
+ bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
+ data = data.replace_feature(x.feats[bwd])
+ else:
+ data = x.data
+
+ new_data = self.conv(data)
+ new_shape = [x.shape[0], self.conv.out_channels]
+ new_layout = None if spatial_changed else x.layout
+ out = SparseTensor(
+ new_data, shape=torch.Size(new_shape), layout=new_layout,
+ scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
+ spatial_cache=x._spatial_cache,
+ )
+ return out
diff --git a/trellis2/modules/sparse/conv/conv_torchsparse.py b/trellis2/modules/sparse/conv/conv_torchsparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..5234bd15553aa8d71df280672475a898ffe56af7
--- /dev/null
+++ b/trellis2/modules/sparse/conv/conv_torchsparse.py
@@ -0,0 +1,30 @@
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+import torchsparse
+
+
+def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
+ self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
+
+
+def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
+ out = self.conv(x.data)
+ new_shape = [x.shape[0], self.conv.out_channels]
+ out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
+ out._spatial_cache = x._spatial_cache
+ out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
+ return out
+
+
+def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+ self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
+
+
+def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
+ out = self.conv(x.data)
+ new_shape = [x.shape[0], self.conv.out_channels]
+ out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
+ out._spatial_cache = x._spatial_cache
+ out._scale = tuple([s / stride for s, stride in zip(x._scale, self.conv.stride)])
+ return out
diff --git a/trellis2/modules/sparse/linear.py b/trellis2/modules/sparse/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..44317709ab16b17fa0132fd48e48519ea0ef9ea9
--- /dev/null
+++ b/trellis2/modules/sparse/linear.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+from . import VarLenTensor
+
+__all__ = [
+ 'SparseLinear'
+]
+
+
+class SparseLinear(nn.Linear):
+ def __init__(self, in_features, out_features, bias=True):
+ super(SparseLinear, self).__init__(in_features, out_features, bias)
+
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
+ return input.replace(super().forward(input.feats))
diff --git a/trellis2/modules/sparse/nonlinearity.py b/trellis2/modules/sparse/nonlinearity.py
new file mode 100644
index 0000000000000000000000000000000000000000..950e5c03c997905162e39fec701db49a2640700c
--- /dev/null
+++ b/trellis2/modules/sparse/nonlinearity.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+from . import VarLenTensor
+
+__all__ = [
+ 'SparseReLU',
+ 'SparseSiLU',
+ 'SparseGELU',
+ 'SparseActivation'
+]
+
+
+class SparseReLU(nn.ReLU):
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
+ return input.replace(super().forward(input.feats))
+
+
+class SparseSiLU(nn.SiLU):
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
+ return input.replace(super().forward(input.feats))
+
+
+class SparseGELU(nn.GELU):
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
+ return input.replace(super().forward(input.feats))
+
+
+class SparseActivation(nn.Module):
+ def __init__(self, activation: nn.Module):
+ super().__init__()
+ self.activation = activation
+
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
+ return input.replace(self.activation(input.feats))
+
diff --git a/trellis2/modules/sparse/norm.py b/trellis2/modules/sparse/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..95711203f0adbae1c7ea845e2500a3823997d652
--- /dev/null
+++ b/trellis2/modules/sparse/norm.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+from ..utils import manual_cast
+from . import VarLenTensor
+from . import config
+
+__all__ = [
+ 'SparseGroupNorm',
+ 'SparseLayerNorm',
+ 'SparseGroupNorm32',
+ 'SparseLayerNorm32',
+]
+
+
+class SparseGroupNorm(nn.GroupNorm):
+ def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
+ super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
+
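+ # Each sample is normalized independently: with variable sequence lengths the
+ # per-sample features cannot be stacked into one regular (N, C, L) batch.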
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
+ nfeats = torch.zeros_like(input.feats)
+ for k in range(input.shape[0]):
+ bfeats = input.feats[input.layout[k]]
+ bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
+ bfeats = super().forward(bfeats)
+ bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
+ nfeats[input.layout[k]] = bfeats
+ return input.replace(nfeats)
+
+
+class SparseLayerNorm(nn.LayerNorm):
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+ super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
+
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
+ nfeats = torch.zeros_like(input.feats)
+ for k in range(input.shape[0]):
+ bfeats = input.feats[input.layout[k]]
+ bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
+ bfeats = super().forward(bfeats)
+ bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
+ nfeats[input.layout[k]] = bfeats
+ return input.replace(nfeats)
+
+
+class SparseGroupNorm32(SparseGroupNorm):
+ """
+ A GroupNorm layer that converts to float32 before the forward pass.
+ """
+ def forward(self, x: VarLenTensor) -> VarLenTensor:
+ x_dtype = x.dtype
+ x = manual_cast(x, torch.float32)
+ o = super().forward(x)
+ return manual_cast(o, x_dtype)
+
+
+class SparseLayerNorm32(SparseLayerNorm):
+ """
+ A LayerNorm layer that converts to float32 before the forward pass.
+ """
+ def forward(self, x: VarLenTensor) -> VarLenTensor:
+ x_dtype = x.dtype
+ x = manual_cast(x, torch.float32)
+ o = super().forward(x)
+ return manual_cast(o, x_dtype)
diff --git a/trellis2/modules/sparse/spatial/__init__.py b/trellis2/modules/sparse/spatial/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e27425f165d271fd16a2c6f7b7684d4a81202ebd
--- /dev/null
+++ b/trellis2/modules/sparse/spatial/__init__.py
@@ -0,0 +1,2 @@
+from .basic import *
+from .spatial2channel import *
diff --git a/trellis2/modules/sparse/spatial/basic.py b/trellis2/modules/sparse/spatial/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaeb8afefd889d8a94812c579383e528b8b56699
--- /dev/null
+++ b/trellis2/modules/sparse/spatial/basic.py
@@ -0,0 +1,109 @@
+from typing import *
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+
+__all__ = [
+ 'SparseDownsample',
+ 'SparseUpsample',
+]
+
+
+class SparseDownsample(nn.Module):
+ """
+ Downsample a sparse tensor by a factor of `factor`.
+ Implemented as mean or max pooling over each `factor`-sized cell, selected by `mode`.
+ """
+ def __init__(self, factor: int, mode: Literal['mean', 'max'] = 'mean'):
+ super(SparseDownsample, self).__init__()
+ self.factor = factor
+ self.mode = mode
+ assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}'
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ cache = x.get_spatial_cache(f'downsample_{self.factor}')
+ if cache is None:
+ DIM = x.coords.shape[-1] - 1
+
+ coord = list(x.coords.unbind(dim=-1))
+ for i in range(DIM):
+ coord[i+1] = coord[i+1] // self.factor
+
+ MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
+ OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
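+ # pack (batch, downsampled coords) of every voxel into a single integer code so
+ # that voxels falling into the same cell collide; unique() then yields one output
+ # voxel per cell and `idx` maps each input voxel to its output cell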
+ code = sum([c * o for c, o in zip(coord, OFFSET)])
+ code, idx = code.unique(return_inverse=True)
+
+ new_coords = torch.stack(
+ [code // OFFSET[0]] +
+ [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
+ dim=-1
+ )
+ else:
+ new_coords, idx = cache
+
+ new_feats = torch.scatter_reduce(
+ torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype),
+ dim=0,
+ index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]),
+ src=x.feats,
+ reduce=self.mode,
+ include_self=False,
+ )
+ out = SparseTensor(new_feats, new_coords, x._shape)
+ out._scale = tuple([s * self.factor for s in x._scale])
+ out._spatial_cache = x._spatial_cache
+
+ if cache is None:
+ x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx))
+ out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx))
+ out.register_spatial_cache(f'shape', torch.Size(MAX))
+ if self.training:
+ subidx = x.coords[:, 1:] % self.factor
+ subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
+ subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
+ subdivision[idx, subidx] = True
+ out.register_spatial_cache(f'subdivision', subdivision)
+
+ return out
+
+
+class SparseUpsample(nn.Module):
+ """
+ Upsample a sparse tensor by a factor of `factor`.
+ Implemented as nearest neighbor interpolation.
+ """
+ def __init__(
+ self, factor: int
+ ):
+ super(SparseUpsample, self).__init__()
+ self.factor = factor
+
+ def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
+ DIM = x.coords.shape[-1] - 1
+
+ cache = x.get_spatial_cache(f'upsample_{self.factor}')
+ if cache is None:
+ if subdivision is None:
+ raise ValueError('Cache not found. Provide subdivision tensor or pair SparseUpsample with SparseDownsample.')
+ else:
+ sub = subdivision.feats
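+ # `sub` marks which of the factor ** DIM children of each coarse voxel exist;
+ # each parent coordinate is repeated once per active child and the child offset
+ # is decoded from `subidx` below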
+ N_leaf = sub.sum(dim=-1)
+ subidx = sub.nonzero()[:, -1]
+ new_coords = x.coords.clone().detach()
+ new_coords[:, 1:] *= self.factor
+ new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
+ for i in range(DIM):
+ new_coords[:, i+1] += subidx // self.factor ** i % self.factor
+ idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
+ else:
+ new_coords, idx = cache
+
+ new_feats = x.feats[idx]
+ out = SparseTensor(new_feats, new_coords, x._shape)
+ out._scale = tuple([s / self.factor for s in x._scale])
+ if cache is not None: # only propagate the cache when this upsample reuses a cached downsample
+ out._spatial_cache = x._spatial_cache
+
+ return out
+
\ No newline at end of file
diff --git a/trellis2/modules/sparse/spatial/spatial2channel.py b/trellis2/modules/sparse/spatial/spatial2channel.py
new file mode 100644
index 0000000000000000000000000000000000000000..577f36d208726f64422f8774c3556a1d643f1e2d
--- /dev/null
+++ b/trellis2/modules/sparse/spatial/spatial2channel.py
@@ -0,0 +1,93 @@
+from typing import *
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+
+
+class SparseSpatial2Channel(nn.Module):
+ """
+ Downsample a sparse tensor by a factor of `factor`.
+ Implemented as rearranging its features from spatial to channel.
+ """
+ def __init__(self, factor: int = 2):
+ super(SparseSpatial2Channel, self).__init__()
+ self.factor = factor
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ DIM = x.coords.shape[-1] - 1
+ cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
+ if cache is None:
+ coord = list(x.coords.unbind(dim=-1))
+ for i in range(DIM):
+ coord[i+1] = coord[i+1] // self.factor
+ subidx = x.coords[:, 1:] % self.factor
+ subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
+
+ MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
+ OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
+ code = sum([c * o for c, o in zip(coord, OFFSET)])
+ code, idx = code.unique(return_inverse=True)
+
+ new_coords = torch.stack(
+ [code // OFFSET[0]] +
+ [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
+ dim=-1
+ )
+ else:
+ new_coords, idx, subidx = cache
+
+ new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
+ new_feats[idx * self.factor ** DIM + subidx] = x.feats
+
+ out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
+ out._scale = tuple([s * self.factor for s in x._scale])
+ out._spatial_cache = x._spatial_cache
+
+ if cache is None:
+ x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
+ out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
+ out.register_spatial_cache(f'shape', torch.Size(MAX))
+ if self.training:
+ subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
+ subdivision[idx, subidx] = True
+ out.register_spatial_cache(f'subdivision', subdivision)
+
+ return out
+
+
+class SparseChannel2Spatial(nn.Module):
+ """
+ Upsample a sparse tensor by a factor of `factor`.
+ Implemented as rearranging its features from channel to spatial.
+ """
+ def __init__(self, factor: int = 2):
+ super(SparseChannel2Spatial, self).__init__()
+ self.factor = factor
+
+ def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
+ DIM = x.coords.shape[-1] - 1
+
+ cache = x.get_spatial_cache(f'channel2spatial_{self.factor}')
+ if cache is None:
+ if subdivision is None:
+ raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
+ else:
+ sub = subdivision.feats # [N, self.factor ** DIM]
+ N_leaf = sub.sum(dim=-1) # [N]
+ subidx = sub.nonzero()[:, -1]
+ new_coords = x.coords.clone().detach()
+ new_coords[:, 1:] *= self.factor
+ new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
+ for i in range(DIM):
+ new_coords[:, i+1] += subidx // self.factor ** i % self.factor
+ idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
+ else:
+ new_coords, idx, subidx = cache
+
+ x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
+ new_feats = x_feats[idx * self.factor ** DIM + subidx]
+ out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
+ out._scale = tuple([s / self.factor for s in x._scale])
+ if cache is not None: # only propagate the cache when this reuses a cached SparseSpatial2Channel
+ out._spatial_cache = x._spatial_cache
+ return out
diff --git a/trellis2/modules/sparse/transformer/__init__.py b/trellis2/modules/sparse/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd
--- /dev/null
+++ b/trellis2/modules/sparse/transformer/__init__.py
@@ -0,0 +1,2 @@
+from .blocks import *
+from .modulated import *
\ No newline at end of file
diff --git a/trellis2/modules/sparse/transformer/blocks.py b/trellis2/modules/sparse/transformer/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d1ec600404fba490894872109d44be6b6477186
--- /dev/null
+++ b/trellis2/modules/sparse/transformer/blocks.py
@@ -0,0 +1,145 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..basic import VarLenTensor, SparseTensor
+from ..linear import SparseLinear
+from ..nonlinearity import SparseGELU
+from ..attention import SparseMultiHeadAttention
+from ...norm import LayerNorm32
+
+
+class SparseFeedForwardNet(nn.Module):
+ def __init__(self, channels: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ SparseLinear(channels, int(channels * mlp_ratio)),
+ SparseGELU(approximate="tanh"),
+ SparseLinear(int(channels * mlp_ratio), channels),
+ )
+
+ def forward(self, x: VarLenTensor) -> VarLenTensor:
+ return self.mlp(x)
+
+
+class SparseTransformerBlock(nn.Module):
+ """
+ Sparse Transformer block (MSA + FFN).
+ """
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "swin"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ qkv_bias: bool = True,
+ ln_affine: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.attn = SparseMultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ rope_freq=rope_freq,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.mlp = SparseFeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+
+ def _forward(self, x: SparseTensor) -> SparseTensor:
+ h = x.replace(self.norm1(x.feats))
+ h = self.attn(h)
+ x = x + h
+ h = x.replace(self.norm2(x.feats))
+ h = self.mlp(h)
+ x = x + h
+ return x
+
+ def forward(self, x: SparseTensor) -> SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+ else:
+ return self._forward(x)
+
+
+class SparseTransformerCrossBlock(nn.Module):
+ """
+ Sparse Transformer cross-attention block (MSA + MCA + FFN).
+ """
+ def __init__(
+ self,
+ channels: int,
+ ctx_channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "swin"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ qkv_bias: bool = True,
+ ln_affine: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.self_attn = SparseMultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ type="self",
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.cross_attn = SparseMultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.mlp = SparseFeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+
+ def _forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
+ h = x.replace(self.norm1(x.feats))
+ h = self.self_attn(h)
+ x = x + h
+ h = x.replace(self.norm2(x.feats))
+ h = self.cross_attn(h, context)
+ x = x + h
+ h = x.replace(self.norm3(x.feats))
+ h = self.mlp(h)
+ x = x + h
+ return x
+
+ def forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
+ else:
+ return self._forward(x, context)
diff --git a/trellis2/modules/sparse/transformer/modulated.py b/trellis2/modules/sparse/transformer/modulated.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc8fa56513da50906b3ba503c6cdce52441cc99e
--- /dev/null
+++ b/trellis2/modules/sparse/transformer/modulated.py
@@ -0,0 +1,205 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..basic import VarLenTensor, SparseTensor
+from ..attention import SparseMultiHeadAttention, SparseProjectAttention, SparseGatedProjectAttention
+from ...norm import LayerNorm32
+from .blocks import SparseFeedForwardNet
+
+
+class ModulatedSparseTransformerBlock(nn.Module):
+ """
+ Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
+ """
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "swin"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ qkv_bias: bool = True,
+ share_mod: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.attn = SparseMultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ rope_freq=rope_freq,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.mlp = SparseFeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+ if not share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 6 * channels, bias=True)
+ )
+ else:
+ self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
+
+ def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
+ if self.share_mod:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+ h = x.replace(self.norm1(x.feats))
+ h = h * (1 + scale_msa) + shift_msa
+ h = self.attn(h)
+ h = h * gate_msa
+ x = x + h
+ h = x.replace(self.norm2(x.feats))
+ h = h * (1 + scale_mlp) + shift_mlp
+ h = self.mlp(h)
+ h = h * gate_mlp
+ x = x + h
+ return x
+
+ def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
+ else:
+ return self._forward(x, mod)
+
+
+class ModulatedSparseTransformerCrossBlock(nn.Module):
+ """
+ Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
+
+ Supports three image attention modes:
+ - "cross": Standard cross-attention with image features
+ - "proj": Projection-based attention with view-aligned features
+ - "gated_proj": Gated variant of the projection-based attention that additionally takes VAE feature channels
+ """
+ def __init__(
+ self,
+ channels: int,
+ ctx_channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "swin"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ qkv_bias: bool = True,
+ share_mod: bool = False,
+ image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross",
+ proj_in_channels: Optional[int] = None,
+ vae_in_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.image_attn_mode = image_attn_mode
+
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.self_attn = SparseMultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ type="self",
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ rope_freq=rope_freq,
+ qk_rms_norm=qk_rms_norm,
+ )
+
+ # Build cross attention based on mode
+ if image_attn_mode == "cross":
+ self.cross_attn = SparseMultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ elif image_attn_mode == "proj":
+ _proj_in = proj_in_channels if proj_in_channels is not None else ctx_channels
+ cross_attn_block = SparseMultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.cross_attn = SparseProjectAttention(cross_attn_block, channels, _proj_in)
+ elif image_attn_mode == "gated_proj":
+ _dino_in = proj_in_channels if proj_in_channels is not None else ctx_channels
+ _vae_in = vae_in_channels if vae_in_channels is not None else 16
+ cross_attn_block = SparseMultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.cross_attn = SparseGatedProjectAttention(cross_attn_block, channels, _dino_in, _vae_in)
+ else:
+ raise ValueError(f"Unknown image attention mode: {image_attn_mode}")
+
+ self.mlp = SparseFeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+ if not share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 6 * channels, bias=True)
+ )
+ else:
+ self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
+
+ def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
+ if self.share_mod:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+ h = x.replace(self.norm1(x.feats))
+ h = h * (1 + scale_msa) + shift_msa
+ h = self.self_attn(h)
+ h = h * gate_msa
+ x = x + h
+ h = x.replace(self.norm2(x.feats))
+ h = self.cross_attn(h, context)
+ x = x + h
+ h = x.replace(self.norm3(x.feats))
+ h = h * (1 + scale_mlp) + shift_mlp
+ h = self.mlp(h)
+ h = h * gate_mlp
+ x = x + h
+ return x
+
+ def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
+ else:
+ return self._forward(x, mod, context)
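The adaLN conditioning used throughout the blocks above reduces to a six-way chunk of the modulation vector. The following standalone sketch (plain torch, illustrative shapes only, no project imports) shows the shift/scale/gate split applied around a sub-layer.

import torch

B, C = 2, 8
mod = torch.randn(B, 6 * C)   # output of adaLN_modulation, or shared parameter + mod
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)

feats = torch.randn(B, C)                 # stand-in for normalized token features
h = feats * (1 + scale_msa) + shift_msa   # modulate before the attention sub-layer
h = h * gate_msa                          # gate the residual branch after it
out = feats + h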
diff --git a/trellis2/modules/spatial.py b/trellis2/modules/spatial.py
new file mode 100644
index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3
--- /dev/null
+++ b/trellis2/modules/spatial.py
@@ -0,0 +1,48 @@
+import torch
+
+
+def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
+ """
+ 3D pixel shuffle.
+ """
+ B, C, H, W, D = x.shape
+ C_ = C // scale_factor**3
+ x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
+ x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
+ return x
+
+
+def patchify(x: torch.Tensor, patch_size: int):
+ """
+ Patchify a tensor.
+
+ Args:
+ x (torch.Tensor): (N, C, *spatial) tensor
+ patch_size (int): Patch size
+ """
+ DIM = x.dim() - 2
+ for d in range(2, DIM + 2):
+ assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
+
+ x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
+ x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
+ x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
+ return x
+
+
+def unpatchify(x: torch.Tensor, patch_size: int):
+ """
+ Unpatchify a tensor.
+
+ Args:
+ x (torch.Tensor): (N, C, *spatial) tensor
+ patch_size (int): Patch size
+ """
+ DIM = x.dim() - 2
+ assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
+
+ x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
+ x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
+ x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
+ return x
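A minimal round-trip sketch for the helpers above (assuming the trellis2 package is importable): patchify folds each patch_size**DIM block of spatial cells into the channel dimension, and unpatchify inverts it exactly.

import torch
from trellis2.modules.spatial import patchify, unpatchify

x = torch.randn(2, 8, 16, 16, 16)       # (N, C, D, H, W) voxel features
p = patchify(x, patch_size=2)           # -> (2, 8 * 2**3, 8, 8, 8)
assert p.shape == (2, 64, 8, 8, 8)
assert torch.equal(unpatchify(p, patch_size=2), x)   # lossless round trip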
diff --git a/trellis2/modules/transformer/__init__.py b/trellis2/modules/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd
--- /dev/null
+++ b/trellis2/modules/transformer/__init__.py
@@ -0,0 +1,2 @@
+from .blocks import *
+from .modulated import *
\ No newline at end of file
diff --git a/trellis2/modules/transformer/blocks.py b/trellis2/modules/transformer/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb6f5eb5462fec62aa5edc062104f643fca03bfa
--- /dev/null
+++ b/trellis2/modules/transformer/blocks.py
@@ -0,0 +1,186 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..attention import MultiHeadAttention
+from ..norm import LayerNorm32
+
+
+class AbsolutePositionEmbedder(nn.Module):
+ """
+ Embeds spatial positions into vector representations.
+ """
+ def __init__(self, channels: int, in_channels: int = 3):
+ super().__init__()
+ self.channels = channels
+ self.in_channels = in_channels
+ self.freq_dim = channels // in_channels // 2
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
+ self.freqs = 1.0 / (10000 ** self.freqs)
+
+ def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Create sinusoidal position embeddings.
+
+ Args:
+ x: a 1-D Tensor of N indices
+
+ Returns:
+ an (N, D) Tensor of positional embeddings.
+ """
+ self.freqs = self.freqs.to(x.device)
+ out = torch.outer(x, self.freqs)
+ out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
+ return out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): (N, D) tensor of spatial positions
+ """
+ N, D = x.shape
+ assert D == self.in_channels, "Input dimension must match number of input channels"
+ embed = self._sin_cos_embedding(x.reshape(-1))
+ embed = embed.reshape(N, -1)
+ if embed.shape[1] < self.channels:
+ embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
+ return embed
+
+
+class FeedForwardNet(nn.Module):
+ def __init__(self, channels: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(channels, int(channels * mlp_ratio)),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(int(channels * mlp_ratio), channels),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.mlp(x)
+
+
+class TransformerBlock(nn.Module):
+ """
+ Transformer block (MSA + FFN).
+ """
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[int] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ qkv_bias: bool = True,
+ ln_affine: bool = True,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.attn = MultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ rope_freq=rope_freq,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.mlp = FeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+
+ def _forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ h = self.norm1(x)
+ h = self.attn(h, phases=phases)
+ x = x + h
+ h = self.norm2(x)
+ h = self.mlp(h)
+ x = x + h
+ return x
+
+ def forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, phases, use_reentrant=False)
+ else:
+ return self._forward(x, phases)
+
+
+class TransformerCrossBlock(nn.Module):
+ """
+ Transformer cross-attention block (MSA + MCA + FFN).
+ """
+ def __init__(
+ self,
+ channels: int,
+ ctx_channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ qkv_bias: bool = True,
+ ln_affine: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+ self.self_attn = MultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ type="self",
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ rope_freq=rope_freq,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.cross_attn = MultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.mlp = FeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+
+ def _forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ h = self.norm1(x)
+ h = self.self_attn(h, phases=phases)
+ x = x + h
+ h = self.norm2(x)
+ h = self.cross_attn(h, context)
+ x = x + h
+ h = self.norm3(x)
+ h = self.mlp(h)
+ x = x + h
+ return x
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, context, phases, use_reentrant=False)
+ else:
+ return self._forward(x, context, phases)
+
\ No newline at end of file
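A brief usage sketch for AbsolutePositionEmbedder above (a sketch, assuming trellis2 is importable): 3-D positions are embedded per axis with sin/cos features and zero-padded up to the requested channel count.

import torch
from trellis2.modules.transformer.blocks import AbsolutePositionEmbedder

embedder = AbsolutePositionEmbedder(channels=1024, in_channels=3)
coords = torch.randint(0, 64, (4096, 3)).float()   # (N, 3) voxel positions
pos = embedder(coords)                             # (N, 1024); the last 4 channels are zero padding here
assert pos.shape == (4096, 1024)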
diff --git a/trellis2/modules/transformer/modulated.py b/trellis2/modules/transformer/modulated.py
new file mode 100644
index 0000000000000000000000000000000000000000..4444b208c5d7559029b0908eb8519282c512973c
--- /dev/null
+++ b/trellis2/modules/transformer/modulated.py
@@ -0,0 +1,205 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..attention import MultiHeadAttention, ProjectAttention, GatedProjectAttention
+from ..norm import LayerNorm32
+from .blocks import FeedForwardNet
+
+
+class ModulatedTransformerBlock(nn.Module):
+ """
+ Transformer block (MSA + FFN) with adaptive layer norm conditioning.
+ """
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ qkv_bias: bool = True,
+ share_mod: bool = False,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.attn = MultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ rope_freq=rope_freq,
+ qk_rms_norm=qk_rms_norm,
+ )
+ self.mlp = FeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+ if not share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 6 * channels, bias=True)
+ )
+ else:
+ self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
+
+ def _forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.share_mod:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+ h = self.norm1(x)
+ h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+ h = self.attn(h, phases=phases)
+ h = h * gate_msa.unsqueeze(1)
+ x = x + h
+ h = self.norm2(x)
+ h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+ h = self.mlp(h)
+ h = h * gate_mlp.unsqueeze(1)
+ x = x + h
+ return x
+
+ def forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, phases, use_reentrant=False)
+ else:
+ return self._forward(x, mod, phases)
+
+
+class ModulatedTransformerCrossBlock(nn.Module):
+ """
+ Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
+
+ Supports three image attention modes:
+ - "cross": Standard cross-attention with image features
+ - "proj": Projection-based attention with view-aligned features
+ - "gated_proj": Gated projection attention combining projected (DINO) and VAE-aligned features
+ """
+ def __init__(
+ self,
+ channels: int,
+ ctx_channels: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: Literal["full", "windowed"] = "full",
+ window_size: Optional[int] = None,
+ shift_window: Optional[Tuple[int, int, int]] = None,
+ use_checkpoint: bool = False,
+ use_rope: bool = False,
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
+ qk_rms_norm: bool = False,
+ qk_rms_norm_cross: bool = False,
+ qkv_bias: bool = True,
+ share_mod: bool = False,
+ image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross",
+ proj_in_channels: Optional[int] = None,
+ vae_in_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ self.use_checkpoint = use_checkpoint
+ self.share_mod = share_mod
+ self.image_attn_mode = image_attn_mode
+
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+ self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+ self.self_attn = MultiHeadAttention(
+ channels,
+ num_heads=num_heads,
+ type="self",
+ attn_mode=attn_mode,
+ window_size=window_size,
+ shift_window=shift_window,
+ qkv_bias=qkv_bias,
+ use_rope=use_rope,
+ rope_freq=rope_freq,
+ qk_rms_norm=qk_rms_norm,
+ )
+
+ # Build cross attention based on mode
+ if image_attn_mode == "cross":
+ self.cross_attn = MultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ elif image_attn_mode == "proj":
+ _proj_in = proj_in_channels if proj_in_channels is not None else ctx_channels
+ cross_attn_block = MultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.cross_attn = ProjectAttention(cross_attn_block, channels, _proj_in)
+ elif image_attn_mode == "gated_proj":
+ _dino_in = proj_in_channels if proj_in_channels is not None else ctx_channels
+ _vae_in = vae_in_channels if vae_in_channels is not None else 16
+ cross_attn_block = MultiHeadAttention(
+ channels,
+ ctx_channels=ctx_channels,
+ num_heads=num_heads,
+ type="cross",
+ attn_mode="full",
+ qkv_bias=qkv_bias,
+ qk_rms_norm=qk_rms_norm_cross,
+ )
+ self.cross_attn = GatedProjectAttention(cross_attn_block, channels, _dino_in, _vae_in)
+ else:
+ raise ValueError(f"Unknown image attention mode: {image_attn_mode}")
+
+ self.mlp = FeedForwardNet(
+ channels,
+ mlp_ratio=mlp_ratio,
+ )
+ if not share_mod:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 6 * channels, bias=True)
+ )
+ else:
+ self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
+
+ def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.share_mod:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+ h = self.norm1(x)
+ h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+ h = self.self_attn(h, phases=phases)
+ h = h * gate_msa.unsqueeze(1)
+ x = x + h
+ h = self.norm2(x)
+ h = self.cross_attn(h, context)
+ x = x + h
+ h = self.norm3(x)
+ h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+ h = self.mlp(h)
+ h = h * gate_mlp.unsqueeze(1)
+ x = x + h
+ return x
+
+ def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False)
+ else:
+ return self._forward(x, mod, context, phases)
+
\ No newline at end of file
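A usage sketch for the dense modulated block above. This is a sketch under assumptions: it relies on the default attention backend wired up in ..attention being available, and the modulation vector is typically a timestep embedding of width `channels`.

import torch
from trellis2.modules.transformer.modulated import ModulatedTransformerBlock

block = ModulatedTransformerBlock(channels=512, num_heads=8)   # attn_mode="full" by default
x = torch.randn(2, 256, 512)     # (B, L, C) tokens
mod = torch.randn(2, 512)        # (B, C) conditioning, e.g. a timestep embedding
out = block(x, mod)              # same shape as x
assert out.shape == x.shape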
diff --git a/trellis2/modules/utils.py b/trellis2/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d92d7d5b0ab07972fb6e0397c9c8e525e196211
--- /dev/null
+++ b/trellis2/modules/utils.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+from ..modules import sparse as sp
+
+MIX_PRECISION_MODULES = (
+ nn.Conv1d,
+ nn.Conv2d,
+ nn.Conv3d,
+ nn.ConvTranspose1d,
+ nn.ConvTranspose2d,
+ nn.ConvTranspose3d,
+ nn.Linear,
+ sp.SparseConv3d,
+ sp.SparseInverseConv3d,
+ sp.SparseLinear,
+)
+
+
+def convert_module_to_f16(l):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(l, MIX_PRECISION_MODULES):
+ for p in l.parameters():
+ p.data = p.data.half()
+
+
+def convert_module_to_f32(l):
+ """
+ Convert primitive modules to float32, undoing convert_module_to_f16().
+ """
+ if isinstance(l, MIX_PRECISION_MODULES):
+ for p in l.parameters():
+ p.data = p.data.float()
+
+
+def convert_module_to(l, dtype):
+ """
+ Convert primitive modules to the given dtype.
+ """
+ if isinstance(l, MIX_PRECISION_MODULES):
+ for p in l.parameters():
+ p.data = p.data.to(dtype)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+def manual_cast(tensor, dtype):
+ """
+ Cast if autocast is not enabled.
+ """
+ if not torch.is_autocast_enabled():
+ return tensor.type(dtype)
+ return tensor
+
+
+def str_to_dtype(dtype_str: str):
+ return {
+ 'f16': torch.float16,
+ 'fp16': torch.float16,
+ 'float16': torch.float16,
+ 'bf16': torch.bfloat16,
+ 'bfloat16': torch.bfloat16,
+ 'f32': torch.float32,
+ 'fp32': torch.float32,
+ 'float32': torch.float32,
+ }[dtype_str]
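A small sketch of the precision helpers above (importing trellis2.modules.utils also pulls in the sparse backend, so this assumes that backend is installed): only modules listed in MIX_PRECISION_MODULES are cast, which keeps norm layers in float32.

import torch.nn as nn
from trellis2.modules.utils import convert_module_to, str_to_dtype

model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64), nn.Linear(64, 8))
model.apply(lambda m: convert_module_to(m, str_to_dtype('bf16')))
# Linear weights are now bfloat16; the LayerNorm parameters remain float32.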
diff --git a/trellis2/pipelines/__init__.py b/trellis2/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..807eae0d16c6bce6d4b9dfdf8ae60ada8e43e180
--- /dev/null
+++ b/trellis2/pipelines/__init__.py
@@ -0,0 +1,54 @@
+import importlib
+
+__attributes = {
+ "Trellis2ImageTo3DPipeline": "trellis2_image_to_3d",
+ "Trellis2TexturingPipeline": "trellis2_texturing",
+ "Pixal3DImageTo3DPipeline": "pixal3d_image_to_3d",
+}
+
+__submodules = ['samplers', 'rembg']
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+ if name not in globals():
+ if name in __attributes:
+ module_name = __attributes[name]
+ module = importlib.import_module(f".{module_name}", __name__)
+ globals()[name] = getattr(module, name)
+ elif name in __submodules:
+ module = importlib.import_module(f".{name}", __name__)
+ globals()[name] = module
+ else:
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+ return globals()[name]
+
+
+def from_pretrained(path: str):
+ """
+ Load a pipeline from a model folder or a Hugging Face model hub.
+
+ Args:
+ path: The path to the model. Can be either local path or a Hugging Face model name.
+ """
+ import os
+ import json
+ is_local = os.path.exists(f"{path}/pipeline.json")
+
+ if is_local:
+ config_file = f"{path}/pipeline.json"
+ else:
+ from huggingface_hub import hf_hub_download
+ config_file = hf_hub_download(path, "pipeline.json")
+
+ with open(config_file, 'r') as f:
+ config = json.load(f)
+ return globals()[config['name']].from_pretrained(path)
+
+
+# For PyLance
+if __name__ == '__main__':
+ from . import samplers, rembg
+ from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline
+ from .trellis2_texturing import Trellis2TexturingPipeline
+ from .pixal3d_image_to_3d import Pixal3DImageTo3DPipeline
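Loading goes through the module-level from_pretrained above; a usage sketch (the path is a placeholder for any local folder or Hugging Face repo containing pipeline.json):

from trellis2 import pipelines

pipe = pipelines.from_pretrained("path/to/checkpoint-or-hf-repo")   # placeholder path
pipe.cuda()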
diff --git a/trellis2/pipelines/base.py b/trellis2/pipelines/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..331e1ed979b47b23ef1d45c8851658df5d49fb69
--- /dev/null
+++ b/trellis2/pipelines/base.py
@@ -0,0 +1,72 @@
+from typing import *
+import torch
+import torch.nn as nn
+from .. import models
+
+
+class Pipeline:
+ """
+ A base class for pipelines.
+ """
+ def __init__(
+ self,
+ models: dict[str, nn.Module] = None,
+ ):
+ if models is None:
+ return
+ self.models = models
+ for model in self.models.values():
+ model.eval()
+
+ @classmethod
+ def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pipeline":
+ """
+ Load a pretrained model.
+ """
+ import os
+ import json
+ is_local = os.path.exists(f"{path}/{config_file}")
+
+ if is_local:
+ config_file = f"{path}/{config_file}"
+ else:
+ from huggingface_hub import hf_hub_download
+ config_file = hf_hub_download(path, config_file)
+
+ with open(config_file, 'r') as f:
+ args = json.load(f)['args']
+
+ _models = {}
+ for k, v in args['models'].items():
+ if hasattr(cls, 'model_names_to_load') and k not in cls.model_names_to_load:
+ continue
+ try:
+ _models[k] = models.from_pretrained(f"{path}/{v}")
+ except Exception as e:
+ _models[k] = models.from_pretrained(v)
+
+ new_pipeline = cls(_models)
+ new_pipeline._pretrained_args = args
+ return new_pipeline
+
+ @property
+ def device(self) -> torch.device:
+ if hasattr(self, '_device'):
+ return self._device
+ for model in self.models.values():
+ if hasattr(model, 'device'):
+ return model.device
+ for model in self.models.values():
+ if hasattr(model, 'parameters'):
+ return next(model.parameters()).device
+ raise RuntimeError("No device found.")
+
+ def to(self, device: torch.device) -> None:
+ for model in self.models.values():
+ model.to(device)
+
+ def cuda(self) -> None:
+ self.to(torch.device("cuda"))
+
+ def cpu(self) -> None:
+ self.to(torch.device("cpu"))
\ No newline at end of file
diff --git a/trellis2/pipelines/pixal3d_image_to_3d.py b/trellis2/pipelines/pixal3d_image_to_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1826c8d791411f13d22fc304e7b5dd379b3d01
--- /dev/null
+++ b/trellis2/pipelines/pixal3d_image_to_3d.py
@@ -0,0 +1,783 @@
+from typing import *
+import torch
+import torch.nn as nn
+import numpy as np
+from PIL import Image
+from .base import Pipeline
+from . import samplers, rembg
+from ..modules.sparse import SparseTensor
+from ..modules import image_feature_extractor
+from ..representations import Mesh, MeshWithVoxel
+
+
+class Pixal3DImageTo3DPipeline(Pipeline):
+ """
+ Pipeline for inferring Pixal3D (proj mode) image-to-3D models.
+
+ Based on the Trellis2 pipeline, this runs inference in proj mode.
+ Each stage (SS, Shape 512, Shape 1024, Tex 1024) has its own image_cond_model (DinoV3ProjFeatureExtractor).
+ Conditioning is built with camera-aware projection (requires the camera_angle_x, distance, and mesh_scale parameters).
+
+ Args:
+ models (dict[str, nn.Module]): The models to use in the pipeline.
+ sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
+ shape_slat_sampler (samplers.Sampler): The sampler for the structured latent.
+ tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
+ sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler.
+ shape_slat_sampler_params (dict): The parameters for the structured latent sampler.
+ tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
+ shape_slat_normalization (dict): The normalization parameters for the structured latent.
+ tex_slat_normalization (dict): The normalization parameters for the texture latent.
+ image_cond_model_ss (nn.Module): Proj image cond model for sparse structure stage.
+ image_cond_model_shape_512 (nn.Module): Proj image cond model for shape LR (512) stage.
+ image_cond_model_shape_1024 (nn.Module): Proj image cond model for shape HR (1024) stage.
+ image_cond_model_tex_1024 (nn.Module): Proj image cond model for texture (1024) stage.
+ rembg_model (Callable): The model for removing background.
+ low_vram (bool): Whether to use low-VRAM mode.
+ """
+ model_names_to_load = [
+ 'sparse_structure_flow_model',
+ 'sparse_structure_decoder',
+ 'shape_slat_flow_model_512',
+ 'shape_slat_flow_model_1024',
+ 'shape_slat_decoder',
+ 'tex_slat_flow_model_512',
+ 'tex_slat_flow_model_1024',
+ 'tex_slat_decoder',
+ ]
+
+ def __init__(
+ self,
+ models: dict[str, nn.Module] = None,
+ sparse_structure_sampler: samplers.Sampler = None,
+ shape_slat_sampler: samplers.Sampler = None,
+ tex_slat_sampler: samplers.Sampler = None,
+ sparse_structure_sampler_params: dict = None,
+ shape_slat_sampler_params: dict = None,
+ tex_slat_sampler_params: dict = None,
+ shape_slat_normalization: dict = None,
+ tex_slat_normalization: dict = None,
+ image_cond_model_ss: nn.Module = None,
+ image_cond_model_shape_512: nn.Module = None,
+ image_cond_model_shape_1024: nn.Module = None,
+ image_cond_model_tex_1024: nn.Module = None,
+ rembg_model: Callable = None,
+ low_vram: bool = True,
+ default_pipeline_type: str = '1024_cascade',
+ ):
+ if models is None:
+ return
+ super().__init__(models)
+ self.sparse_structure_sampler = sparse_structure_sampler
+ self.shape_slat_sampler = shape_slat_sampler
+ self.tex_slat_sampler = tex_slat_sampler
+ self.sparse_structure_sampler_params = sparse_structure_sampler_params
+ self.shape_slat_sampler_params = shape_slat_sampler_params
+ self.tex_slat_sampler_params = tex_slat_sampler_params
+ self.shape_slat_normalization = shape_slat_normalization
+ self.tex_slat_normalization = tex_slat_normalization
+ self.image_cond_model_ss = image_cond_model_ss
+ self.image_cond_model_shape_512 = image_cond_model_shape_512
+ self.image_cond_model_shape_1024 = image_cond_model_shape_1024
+ self.image_cond_model_tex_1024 = image_cond_model_tex_1024
+ self.rembg_model = rembg_model
+ self.low_vram = low_vram
+ self.default_pipeline_type = default_pipeline_type
+ self.pbr_attr_layout = {
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+ self._device = 'cpu'
+
+ @classmethod
+ def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pixal3DImageTo3DPipeline":
+ """
+ Load a pretrained model.
+
+ Args:
+ path (str): The path to the model. Can be either local path or a Hugging Face repository.
+ """
+ pipeline = super().from_pretrained(path, config_file)
+ args = pipeline._pretrained_args
+
+ pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
+ pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
+
+ pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
+ pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
+
+ pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
+ pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
+
+ pipeline.shape_slat_normalization = args['shape_slat_normalization']
+ pipeline.tex_slat_normalization = args['tex_slat_normalization']
+
+ # Proj mode: the image_cond_models must be loaded and assigned externally; initialize them to None here
+ pipeline.image_cond_model_ss = None
+ pipeline.image_cond_model_shape_512 = None
+ pipeline.image_cond_model_shape_1024 = None
+ pipeline.image_cond_model_tex_1024 = None
+
+ pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
+
+ pipeline.low_vram = args.get('low_vram', True)
+ pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
+ pipeline.pbr_attr_layout = {
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+ pipeline._device = 'cpu'
+
+ return pipeline
+
+ def to(self, device: torch.device) -> None:
+ self._device = device
+ if not self.low_vram:
+ super().to(device)
+ if self.rembg_model is not None:
+ self.rembg_model.to(device)
+
+ def preprocess_image(self, input: Image.Image, bg_color: tuple = (0, 0, 0)) -> Image.Image:
+ """
+ Preprocess the input image.
+
+ Args:
+ input: Input image (RGB or RGBA).
+ bg_color: Background color (R, G, B) in 0~255. Default black (0,0,0).
+ """
+ # if has alpha channel, use it directly; otherwise, remove background
+ has_alpha = False
+ if input.mode == 'RGBA':
+ alpha = np.array(input)[:, :, 3]
+ if not np.all(alpha == 255):
+ has_alpha = True
+ max_size = max(input.size)
+ scale = min(1, 1024 / max_size)
+ if scale < 1:
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+ if has_alpha:
+ output = input
+ else:
+ input = input.convert('RGB')
+ if self.low_vram:
+ self.rembg_model.to(self.device)
+ output = self.rembg_model(input)
+ if self.low_vram:
+ self.rembg_model.cpu()
+ output_np = np.array(output)
+ alpha = output_np[:, :, 3]
+ bbox = np.argwhere(alpha > 0.8 * 255)
+ bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+ size = int(size * 1.1)
+ bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+ output = output.crop(bbox) # type: ignore
+ output = np.array(output).astype(np.float32) / 255
+ rgb = output[:, :, :3]
+ a = output[:, :, 3:4]
+ bg = np.array(bg_color, dtype=np.float32) / 255.0
+ output = rgb * a + bg * (1.0 - a)
+ output = Image.fromarray((np.clip(output, 0, 1) * 255).astype(np.uint8))
+ return output
+
+ # =========================================================================
+ # Proj-mode conditioning construction
+ # =========================================================================
+
+ @torch.no_grad()
+ def get_proj_cond_ss(
+ self,
+ image: list,
+ camera_angle_x: float = 0.8575560450553894,
+ distance: float = 2.0,
+ mesh_scale: float = 1.0,
+ ) -> dict:
+ """
+ Get proj conditioning for sparse structure stage.
+
+ Args:
+ image: List of PIL images.
+ camera_angle_x: Camera horizontal FOV in radians.
+ distance: Camera distance.
+ mesh_scale: Mesh scale.
+
+ Returns:
+ dict with 'cond' and 'neg_cond', each containing {'global': ..., 'proj': ...}
+ """
+ device = self.device
+ image_cond_model = self.image_cond_model_ss
+ if self.low_vram:
+ image_cond_model.to(device)
+ cam_angle = torch.tensor([camera_angle_x], device=device)
+ dist_tensor = torch.tensor([distance], device=device)
+ scale_tensor = torch.tensor([mesh_scale], device=device)
+ z_global, z_proj = image_cond_model(
+ image, camera_angle_x=cam_angle, distance=dist_tensor, mesh_scale=scale_tensor,
+ )
+ if self.low_vram:
+ image_cond_model.cpu()
+ return {
+ 'cond': {'global': z_global, 'proj': z_proj},
+ 'neg_cond': {'global': torch.zeros_like(z_global), 'proj': torch.zeros_like(z_proj)},
+ }
+
+ @torch.no_grad()
+ def get_proj_cond_shape(
+ self,
+ image_cond_model: nn.Module,
+ image: list,
+ coords: torch.Tensor,
+ camera_angle_x: float = 0.8575560450553894,
+ distance: float = 2.0,
+ mesh_scale: float = 1.0,
+ grid_resolution_override: int = None,
+ ) -> dict:
+ """
+ Get proj conditioning for shape/texture stages (sparse-token aligned).
+
+ Args:
+ image_cond_model: The proj image cond model for this stage.
+ image: List of PIL images.
+ coords: Sparse structure coordinates [N, 4] (batch_idx, x, y, z).
+ camera_angle_x: Camera horizontal FOV in radians.
+ distance: Camera distance.
+ mesh_scale: Mesh scale.
+ grid_resolution_override: Override the grid resolution if not None.
+
+ Returns:
+ dict with 'cond' and 'neg_cond', each containing {'global': ..., 'proj': SparseTensor}
+ """
+ device = self.device
+ if self.low_vram:
+ image_cond_model.to(device)
+
+ orig_grid_res = image_cond_model.grid_resolution
+ if grid_resolution_override is not None and grid_resolution_override != orig_grid_res:
+ image_cond_model.grid_resolution = grid_resolution_override
+ image_cond_model.proj_grid = image_cond_model.proj_grid.__class__(
+ grid_resolution=grid_resolution_override,
+ image_resolution=image_cond_model.proj_grid.image_resolution,
+ ).to(device)
+
+ B = 1
+ cam_angle = torch.tensor([camera_angle_x], device=device)
+ dist_tensor = torch.tensor([distance], device=device)
+ scale_tensor = torch.tensor([mesh_scale], device=device)
+ z_global, z_proj = image_cond_model(
+ image, camera_angle_x=cam_angle, distance=dist_tensor, mesh_scale=scale_tensor,
+ )
+ grid_res = image_cond_model.grid_resolution
+ z_proj_grid = z_proj.reshape(B, grid_res, grid_res, grid_res, -1)
+ batch_indices = coords[:, 0].long()
+ x_coords = coords[:, 1].long()
+ y_coords = coords[:, 2].long()
+ z_coords = coords[:, 3].long()
+ z_proj_sparse = z_proj_grid[batch_indices, x_coords, y_coords, z_coords]
+ z_proj_st = SparseTensor(feats=z_proj_sparse, coords=coords)
+
+ if grid_resolution_override is not None and grid_resolution_override != orig_grid_res:
+ image_cond_model.grid_resolution = orig_grid_res
+ image_cond_model.proj_grid = image_cond_model.proj_grid.__class__(
+ grid_resolution=orig_grid_res,
+ image_resolution=image_cond_model.proj_grid.image_resolution,
+ ).to(device)
+
+ if self.low_vram:
+ image_cond_model.cpu()
+ return {
+ 'cond': {'global': z_global, 'proj': z_proj_st},
+ 'neg_cond': {'global': torch.zeros_like(z_global), 'proj': SparseTensor(feats=torch.zeros_like(z_proj_sparse), coords=coords)},
+ }
+
+ # =========================================================================
+ # Sampling methods (kept consistent with Trellis2)
+ # =========================================================================
+
+ def sample_sparse_structure(
+ self,
+ cond: dict,
+ resolution: int,
+ num_samples: int = 1,
+ sampler_params: dict = {},
+ ) -> torch.Tensor:
+ """
+ Sample sparse structures with the given conditioning.
+
+ Args:
+ cond (dict): The conditioning information.
+ resolution (int): The resolution of the sparse structure.
+ num_samples (int): The number of samples to generate.
+ sampler_params (dict): Additional parameters for the sampler.
+ """
+ # Sample sparse structure latent
+ flow_model = self.models['sparse_structure_flow_model']
+ reso = flow_model.resolution
+ in_channels = flow_model.in_channels
+ noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device)
+ sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ z_s = self.sparse_structure_sampler.sample(
+ flow_model,
+ noise,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling sparse structure (proj)",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ # Decode sparse structure latent
+ decoder = self.models['sparse_structure_decoder']
+ if self.low_vram:
+ decoder.to(self.device)
+ decoded = decoder(z_s)>0
+ if self.low_vram:
+ decoder.cpu()
+ if resolution != decoded.shape[2]:
+ ratio = decoded.shape[2] // resolution
+ decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
+ coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
+
+ return coords
+
+ def sample_shape_slat(
+ self,
+ cond: dict,
+ flow_model,
+ coords: torch.Tensor,
+ sampler_params: dict = {},
+ ) -> SparseTensor:
+ """
+ Sample structured latent with the given conditioning.
+
+ Args:
+ cond (dict): The conditioning information.
+ flow_model: The shape SLat flow model to sample with.
+ coords (torch.Tensor): The coordinates of the sparse structure.
+ sampler_params (dict): Additional parameters for the sampler.
+ """
+ # Sample structured latent
+ noise = SparseTensor(
+ feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
+ coords=coords,
+ )
+ sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ slat = self.shape_slat_sampler.sample(
+ flow_model,
+ noise,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling shape SLat (proj)",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ return slat
+
+ def sample_shape_slat_cascade(
+ self,
+ lr_cond: dict,
+ cond: dict,
+ flow_model_lr,
+ flow_model,
+ lr_resolution: int,
+ resolution: int,
+ coords: torch.Tensor,
+ sampler_params: dict = {},
+ max_num_tokens: int = 49152,
+ ) -> SparseTensor:
+ """
+ Sample structured latent with cascade (LR → HR).
+
+ Args:
+ lr_cond (dict): The conditioning information for LR stage.
+ cond (dict): The conditioning information for HR stage.
+ flow_model_lr: LR flow model.
+ flow_model: HR flow model.
+ lr_resolution (int): LR resolution.
+ resolution (int): Target HR resolution.
+ coords (torch.Tensor): The coordinates of the sparse structure.
+ sampler_params (dict): Additional parameters for the sampler.
+ max_num_tokens (int): Maximum number of tokens.
+ """
+ # LR
+ noise = SparseTensor(
+ feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device),
+ coords=coords,
+ )
+ sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model_lr.to(self.device)
+ slat = self.shape_slat_sampler.sample(
+ flow_model_lr,
+ noise,
+ **lr_cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling LR shape SLat (proj, 512)",
+ ).samples
+ if self.low_vram:
+ flow_model_lr.cpu()
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ # Upsample
+ if self.low_vram:
+ self.models['shape_slat_decoder'].to(self.device)
+ self.models['shape_slat_decoder'].low_vram = True
+ hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4)
+ if self.low_vram:
+ self.models['shape_slat_decoder'].cpu()
+ self.models['shape_slat_decoder'].low_vram = False
+ hr_resolution = resolution
+ while True:
+ quant_coords = torch.cat([
+ hr_coords[:, :1],
+ ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
+ ], dim=1)
+ coords = quant_coords.unique(dim=0)
+ num_tokens = coords.shape[0]
+ if num_tokens < max_num_tokens or hr_resolution == 1024:
+ if hr_resolution != resolution:
+ print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.")
+ break
+ hr_resolution -= 128
+
+ # Sample structured latent (HR)
+ noise = SparseTensor(
+ feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
+ coords=coords,
+ )
+ sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ slat = self.shape_slat_sampler.sample(
+ flow_model,
+ noise,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc=f"Sampling HR shape SLat (proj, {hr_resolution})",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ return slat, hr_resolution
+
+ def decode_shape_slat(
+ self,
+ slat: SparseTensor,
+ resolution: int,
+ ) -> Tuple[List[Mesh], List[SparseTensor]]:
+ """
+ Decode the structured latent.
+
+ Args:
+ slat (SparseTensor): The structured latent.
+ resolution (int): The target voxel resolution for decoding.
+
+ Returns:
+ List[Mesh]: The decoded meshes.
+ List[SparseTensor]: The decoded substructures.
+ """
+ self.models['shape_slat_decoder'].set_resolution(resolution)
+ if self.low_vram:
+ self.models['shape_slat_decoder'].to(self.device)
+ self.models['shape_slat_decoder'].low_vram = True
+ ret = self.models['shape_slat_decoder'](slat, return_subs=True)
+ if self.low_vram:
+ self.models['shape_slat_decoder'].cpu()
+ self.models['shape_slat_decoder'].low_vram = False
+ return ret
+
+ def sample_tex_slat(
+ self,
+ cond: dict,
+ flow_model,
+ shape_slat: SparseTensor,
+ sampler_params: dict = {},
+ ) -> SparseTensor:
+ """
+ Sample texture structured latent with the given conditioning.
+
+ Args:
+ cond (dict): The conditioning information.
+ flow_model: The texture SLat flow model to sample with.
+ shape_slat (SparseTensor): The structured latent for shape.
+ sampler_params (dict): Additional parameters for the sampler.
+ """
+ # Sample structured latent
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
+ shape_slat = (shape_slat - mean) / std
+
+ in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
+ noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
+ sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ slat = self.tex_slat_sampler.sample(
+ flow_model,
+ noise,
+ concat_cond=shape_slat,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling texture SLat (proj)",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ return slat
+
+ def decode_tex_slat(
+ self,
+ slat: SparseTensor,
+ subs: List[SparseTensor],
+ ) -> SparseTensor:
+ """
+ Decode the structured latent.
+
+ Args:
+ slat (SparseTensor): The structured latent.
+ subs (List[SparseTensor]): Substructures from the shape decoder used to guide decoding.
+
+ Returns:
+ SparseTensor: The decoded texture voxels
+ """
+ if self.low_vram:
+ self.models['tex_slat_decoder'].to(self.device)
+ ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5
+ if self.low_vram:
+ self.models['tex_slat_decoder'].cpu()
+ return ret
+
+ @torch.no_grad()
+ def decode_latent(
+ self,
+ shape_slat: SparseTensor,
+ tex_slat: SparseTensor,
+ resolution: int,
+ ) -> List[MeshWithVoxel]:
+ """
+ Decode the latent codes.
+
+ Args:
+ shape_slat (SparseTensor): The structured latent for shape.
+ tex_slat (SparseTensor): The structured latent for texture.
+ resolution (int): The resolution of the output.
+ """
+ meshes, subs = self.decode_shape_slat(shape_slat, resolution)
+ tex_voxels = self.decode_tex_slat(tex_slat, subs)
+ out_mesh = []
+ torch.cuda.synchronize()
+ for m, v in zip(meshes, tex_voxels):
+ m.fill_holes()
+ out_mesh.append(
+ MeshWithVoxel(
+ m.vertices, m.faces,
+ origin = [-0.5, -0.5, -0.5],
+ voxel_size = 1 / resolution,
+ coords = v.coords[:, 1:],
+ attrs = v.feats,
+ voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
+ layout=self.pbr_attr_layout
+ )
+ )
+ return out_mesh
+
+ @torch.no_grad()
+ def run(
+ self,
+ image: Image.Image,
+ camera_params: dict,
+ num_samples: int = 1,
+ seed: int = 42,
+ sparse_structure_sampler_params: dict = {},
+ shape_slat_sampler_params: dict = {},
+ tex_slat_sampler_params: dict = {},
+ preprocess_image: bool = True,
+ return_latent: bool = False,
+ pipeline_type: Optional[str] = None,
+ max_num_tokens: int = 49152,
+ ) -> List[MeshWithVoxel]:
+ """
+ Run the Pixal3D pipeline (proj mode, cascade).
+
+ Args:
+ image (Image.Image): The image prompt.
+ camera_params (dict): Camera parameters with keys:
+ - camera_angle_x (float): Horizontal FOV in radians.
+ - distance (float): Camera distance.
+ - mesh_scale (float): Mesh scale factor.
+ num_samples (int): The number of samples to generate.
+ seed (int): The random seed.
+ sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
+ shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler.
+ tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler.
+ preprocess_image (bool): Whether to preprocess the image.
+ return_latent (bool): Whether to return the latent codes.
+ pipeline_type (str): The type of the pipeline. Options: '1024_cascade', '1536_cascade'.
+ max_num_tokens (int): The maximum number of tokens to use.
+ """
+ # Check pipeline type
+ pipeline_type = pipeline_type or self.default_pipeline_type
+ if pipeline_type == '1024_cascade':
+ assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
+ assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
+ assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
+ hr_resolution = 1024
+ elif pipeline_type == '1536_cascade':
+ assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
+ assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
+ assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
+ hr_resolution = 1536
+ else:
+ raise ValueError(f"Invalid pipeline type for Pixal3D proj mode: {pipeline_type}. "
+ f"Supported: '1024_cascade', '1536_cascade'.")
+
+ # Validate image_cond_models are set
+ assert self.image_cond_model_ss is not None, "image_cond_model_ss not set."
+ assert self.image_cond_model_shape_512 is not None, "image_cond_model_shape_512 not set."
+ assert self.image_cond_model_shape_1024 is not None, "image_cond_model_shape_1024 not set."
+ assert self.image_cond_model_tex_1024 is not None, "image_cond_model_tex_1024 not set."
+
+ # Extract camera params
+ camera_angle_x = camera_params['camera_angle_x']
+ distance = camera_params['distance']
+ mesh_scale = camera_params.get('mesh_scale', 1.0)
+
+ if preprocess_image:
+ image = self.preprocess_image(image)
+ torch.manual_seed(seed)
+
+ # ---- Stage 1: Sparse Structure (proj) ----
+ cond_ss = self.get_proj_cond_ss(
+ [image],
+ camera_angle_x=camera_angle_x,
+ distance=distance,
+ mesh_scale=mesh_scale,
+ )
+ ss_res = 32
+ coords = self.sample_sparse_structure(
+ cond_ss, ss_res,
+ num_samples, sparse_structure_sampler_params
+ )
+ del cond_ss
+ torch.cuda.empty_cache()
+
+ # ---- Stage 2: Shape LR 512 (proj) ----
+ cond_shape_lr = self.get_proj_cond_shape(
+ self.image_cond_model_shape_512, [image], coords,
+ camera_angle_x=camera_angle_x,
+ distance=distance,
+ mesh_scale=mesh_scale,
+ )
+ lr_slat = self.sample_shape_slat(
+ cond_shape_lr, self.models['shape_slat_flow_model_512'],
+ coords, shape_slat_sampler_params
+ )
+ del cond_shape_lr
+ torch.cuda.empty_cache()
+
+ # ---- Stage 3a: Upsample LR → HR ----
+ if self.low_vram:
+ self.models['shape_slat_decoder'].to(self.device)
+ self.models['shape_slat_decoder'].low_vram = True
+ hr_coords = self.models['shape_slat_decoder'].upsample(lr_slat, upsample_times=4)
+ if self.low_vram:
+ self.models['shape_slat_decoder'].cpu()
+ self.models['shape_slat_decoder'].low_vram = False
+
+ lr_resolution = 512
+ actual_hr_resolution = hr_resolution
+ while True:
+ grid_res = actual_hr_resolution // 16
+ quant_coords = torch.cat([
+ hr_coords[:, :1],
+ ((hr_coords[:, 1:] + 0.5) / lr_resolution * (grid_res - 1)).round().int(),
+ ], dim=1)
+ hr_coords_unique = quant_coords.unique(dim=0)
+ num_tokens = hr_coords_unique.shape[0]
+ if num_tokens < max_num_tokens or actual_hr_resolution == 1024:
+ break
+ actual_hr_resolution -= 128
+
+ actual_grid_res = actual_hr_resolution // 16
+ del lr_slat, hr_coords, quant_coords
+ torch.cuda.empty_cache()
+
+ # ---- Stage 3b: Shape HR (proj) ----
+ cond_shape_hr = self.get_proj_cond_shape(
+ self.image_cond_model_shape_1024, [image], hr_coords_unique,
+ camera_angle_x=camera_angle_x,
+ distance=distance,
+ mesh_scale=mesh_scale,
+ grid_resolution_override=actual_grid_res,
+ )
+ noise_hr = SparseTensor(
+ feats=torch.randn(hr_coords_unique.shape[0], self.models['shape_slat_flow_model_1024'].in_channels).to(self.device),
+ coords=hr_coords_unique,
+ )
+ sampler_params_hr = {**self.shape_slat_sampler_params, **shape_slat_sampler_params}
+ flow_model_hr = self.models['shape_slat_flow_model_1024']
+ if self.low_vram:
+ flow_model_hr.to(self.device)
+ hr_slat = self.shape_slat_sampler.sample(
+ flow_model_hr,
+ noise_hr,
+ **cond_shape_hr,
+ **sampler_params_hr,
+ verbose=True,
+ tqdm_desc=f"Sampling HR shape SLat (proj, {actual_hr_resolution})",
+ ).samples
+ if self.low_vram:
+ flow_model_hr.cpu()
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(hr_slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(hr_slat.device)
+ shape_slat = hr_slat * std + mean
+ del cond_shape_hr, noise_hr, hr_slat, hr_coords_unique
+ torch.cuda.empty_cache()
+
+ # ---- Stage 4: Texture (proj) ----
+ tex_grid_res = actual_hr_resolution // 16
+ cond_tex = self.get_proj_cond_shape(
+ self.image_cond_model_tex_1024, [image], shape_slat.coords,
+ camera_angle_x=camera_angle_x,
+ distance=distance,
+ mesh_scale=mesh_scale,
+ grid_resolution_override=tex_grid_res,
+ )
+ tex_slat = self.sample_tex_slat(
+ cond_tex, self.models['tex_slat_flow_model_1024'],
+ shape_slat, tex_slat_sampler_params
+ )
+ del cond_tex
+ torch.cuda.empty_cache()
+
+ # ---- Stage 5: Decode ----
+ res = actual_hr_resolution
+ out_mesh = self.decode_latent(shape_slat, tex_slat, res)
+ if return_latent:
+ return out_mesh, (shape_slat, tex_slat, res)
+ else:
+ return out_mesh
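An end-to-end usage sketch for the proj-mode pipeline above. The checkpoint path is a placeholder, and, as noted in from_pretrained, the four image_cond_model_* extractors must be constructed and assigned separately before run() is called.

import torch
from PIL import Image
from trellis2.pipelines import Pixal3DImageTo3DPipeline

pipe = Pixal3DImageTo3DPipeline.from_pretrained("path/to/pixal3d-checkpoint")   # placeholder path
pipe.cuda()
# pipe.image_cond_model_ss = ...           # DinoV3ProjFeatureExtractor instances,
# pipe.image_cond_model_shape_512 = ...    # loaded externally (see from_pretrained)
# pipe.image_cond_model_shape_1024 = ...
# pipe.image_cond_model_tex_1024 = ...

image = Image.open("input.png")
camera_params = {"camera_angle_x": 0.8575560450553894, "distance": 2.0, "mesh_scale": 1.0}
meshes = pipe.run(image, camera_params, seed=42, pipeline_type="1024_cascade")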
diff --git a/trellis2/pipelines/rembg/BiRefNet.py b/trellis2/pipelines/rembg/BiRefNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c71a99274823aefe6f18ab921a5beb074177de18
--- /dev/null
+++ b/trellis2/pipelines/rembg/BiRefNet.py
@@ -0,0 +1,42 @@
+from typing import *
+from transformers import AutoModelForImageSegmentation
+import torch
+from torchvision import transforms
+from PIL import Image
+
+
+class BiRefNet:
+ def __init__(self, model_name: str = "ZhengPeng7/BiRefNet"):
+ self.model = AutoModelForImageSegmentation.from_pretrained(
+ model_name, trust_remote_code=True
+ )
+ self.model.eval()
+ self.transform_image = transforms.Compose(
+ [
+ transforms.Resize((1024, 1024)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+
+ def to(self, device: str):
+ self.model.to(device)
+
+ def cuda(self):
+ self.model.cuda()
+
+ def cpu(self):
+ self.model.cpu()
+
+ def __call__(self, image: Image.Image) -> Image.Image:
+ image_size = image.size
+ input_images = self.transform_image(image).unsqueeze(0).to("cuda")
+ # Prediction
+ with torch.no_grad():
+ preds = self.model(input_images)[-1].sigmoid().cpu()
+ pred = preds[0].squeeze()
+ pred_pil = transforms.ToPILImage()(pred)
+ mask = pred_pil.resize(image_size)
+ image.putalpha(mask)
+ return image
+
\ No newline at end of file
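A usage sketch for the BiRefNet wrapper above (weights are fetched from the Hugging Face hub on first construction; note that __call__ moves inputs to CUDA unconditionally, so the model should be on a GPU):

from PIL import Image
from trellis2.pipelines.rembg import BiRefNet

rembg = BiRefNet()
rembg.cuda()
rgba = rembg(Image.open("photo.jpg").convert("RGB"))   # RGBA image with predicted alpha matte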
diff --git a/trellis2/pipelines/rembg/__init__.py b/trellis2/pipelines/rembg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc1eed1ba0c962478fc48132432f2e649fe03411
--- /dev/null
+++ b/trellis2/pipelines/rembg/__init__.py
@@ -0,0 +1 @@
+from .BiRefNet import *
diff --git a/trellis2/pipelines/samplers/__init__.py b/trellis2/pipelines/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a69b95e5963c6d3a837518ed9c3cf6972235d60
--- /dev/null
+++ b/trellis2/pipelines/samplers/__init__.py
@@ -0,0 +1,6 @@
+from .base import Sampler
+from .flow_euler import (
+ FlowEulerSampler,
+ FlowEulerCfgSampler,
+ FlowEulerGuidanceIntervalSampler,
+)
\ No newline at end of file
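The re-exports above are what the pipelines instantiate from pipeline.json; a construction sketch (the sigma_min value here is illustrative, the real value comes from the config):

from trellis2.pipelines.samplers import FlowEulerGuidanceIntervalSampler

sampler = FlowEulerGuidanceIntervalSampler(sigma_min=1e-5)   # sigma_min is read from pipeline.json in practice
# sampler.sample(flow_model, noise, cond, neg_cond, steps=25,
#                guidance_strength=5.0, guidance_interval=(0.2, 1.0))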
diff --git a/trellis2/pipelines/samplers/base.py b/trellis2/pipelines/samplers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1966ce787009a5ee0c1ed06dce491525ff1dbcbf
--- /dev/null
+++ b/trellis2/pipelines/samplers/base.py
@@ -0,0 +1,20 @@
+from typing import *
+from abc import ABC, abstractmethod
+
+
+class Sampler(ABC):
+ """
+ A base class for samplers.
+ """
+
+ @abstractmethod
+ def sample(
+ self,
+ model,
+ **kwargs
+ ):
+ """
+ Sample from a model.
+ """
+ pass
+
\ No newline at end of file
diff --git a/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c7a4da4c1324bed5319f89916b532a159471d84
--- /dev/null
+++ b/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py
@@ -0,0 +1,29 @@
+from typing import *
+
+
+class ClassifierFreeGuidanceSamplerMixin:
+ """
+ A mixin class for samplers that apply classifier-free guidance.
+ """
+
+ def _inference_model(self, model, x_t, t, cond, neg_cond, guidance_strength, guidance_rescale=0.0, **kwargs):
+ if guidance_strength == 1:
+ return super()._inference_model(model, x_t, t, cond, **kwargs)
+ elif guidance_strength == 0:
+ return super()._inference_model(model, x_t, t, neg_cond, **kwargs)
+ else:
+ pred_pos = super()._inference_model(model, x_t, t, cond, **kwargs)
+ pred_neg = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
+ pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg
+
+ # CFG rescale
+ if guidance_rescale > 0:
+ x_0_pos = self._pred_to_xstart(x_t, t, pred_pos)
+ x_0_cfg = self._pred_to_xstart(x_t, t, pred)
+ std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True)
+ std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True)
+ x_0_rescaled = x_0_cfg * (std_pos / std_cfg)
+ x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg
+ pred = self._xstart_to_pred(x_t, t, x_0)
+
+ return pred
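The guidance blend above is the standard classifier-free-guidance combination; a standalone numeric sketch of the identity it relies on:

import torch

pred_pos = torch.tensor([1.0, 2.0])
pred_neg = torch.tensor([0.5, 0.5])
s = 3.0
blend = s * pred_pos + (1 - s) * pred_neg               # as computed in _inference_model
assert torch.allclose(blend, pred_neg + s * (pred_pos - pred_neg))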
diff --git a/trellis2/pipelines/samplers/flow_euler.py b/trellis2/pipelines/samplers/flow_euler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ff72b84221f210f6cb06684bdf13aedc6cf20c3
--- /dev/null
+++ b/trellis2/pipelines/samplers/flow_euler.py
@@ -0,0 +1,208 @@
+from typing import *
+import torch
+import numpy as np
+from tqdm import tqdm
+from easydict import EasyDict as edict
+from .base import Sampler
+from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
+from .guidance_interval_mixin import GuidanceIntervalSamplerMixin
+
+
+class FlowEulerSampler(Sampler):
+ """
+ Generate samples from a flow-matching model using Euler sampling.
+
+ Args:
+ sigma_min: The minimum scale of noise in flow.
+ """
+ def __init__(
+ self,
+ sigma_min: float,
+ ):
+ self.sigma_min = sigma_min
+
+ def _eps_to_xstart(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)
+
+ def _xstart_to_eps(self, x_t, t, x_0):
+ assert x_t.shape == x_0.shape
+ return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)
+
+ def _v_to_xstart_eps(self, x_t, t, v):
+ assert x_t.shape == v.shape
+ eps = (1 - t) * v + x_t
+ x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
+ return x_0, eps
+
+ def _pred_to_xstart(self, x_t, t, pred):
+ return (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * pred
+
+ def _xstart_to_pred(self, x_t, t, x_0):
+ return ((1 - self.sigma_min) * x_t - x_0) / (self.sigma_min + (1 - self.sigma_min) * t)
+
+ def _inference_model(self, model, x_t, t, cond=None, **kwargs):
+ t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
+ return model(x_t, t, cond, **kwargs)
+
+ def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
+ pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
+ pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
+ return pred_x_0, pred_eps, pred_v
+
+ @torch.no_grad()
+ def sample_once(
+ self,
+ model,
+ x_t,
+ t: float,
+ t_prev: float,
+ cond: Optional[Any] = None,
+ **kwargs
+ ):
+ """
+ Sample x_{t-1} from the model using Euler method.
+
+ Args:
+ model: The model to sample from.
+ x_t: The [N x C x ...] tensor of noisy inputs at time t.
+ t: The current timestep.
+ t_prev: The previous timestep.
+ cond: conditional information.
+ **kwargs: Additional arguments for model inference.
+
+ Returns:
+ a dict containing the following
+ - 'pred_x_prev': x_{t-1}.
+ - 'pred_x_0': a prediction of x_0.
+ """
+ pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
+ pred_x_prev = x_t - (t - t_prev) * pred_v
+ return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})
+
+ @torch.no_grad()
+ def sample(
+ self,
+ model,
+ noise,
+ cond: Optional[Any] = None,
+ steps: int = 50,
+ rescale_t: float = 1.0,
+ verbose: bool = True,
+ tqdm_desc: str = "Sampling",
+ **kwargs
+ ):
+ """
+ Generate samples from the model using Euler method.
+
+ Args:
+ model: The model to sample from.
+ noise: The initial noise tensor.
+ cond: conditional information.
+ steps: The number of steps to sample.
+ rescale_t: The rescale factor for t.
+ verbose: If True, show a progress bar.
+ tqdm_desc: A customized tqdm desc.
+ **kwargs: Additional arguments for model_inference.
+
+ Returns:
+ a dict containing the following
+ - 'samples': the model samples.
+ - 'pred_x_t': a list of prediction of x_t.
+ - 'pred_x_0': a list of prediction of x_0.
+ """
+ sample = noise
+ t_seq = np.linspace(1, 0, steps + 1)
+ t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
+ t_seq = t_seq.tolist()
+ t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
+ ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
+ for t, t_prev in tqdm(t_pairs, desc=tqdm_desc, disable=not verbose):
+ out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
+ sample = out.pred_x_prev
+ ret.pred_x_t.append(out.pred_x_prev)
+ ret.pred_x_0.append(out.pred_x_0)
+ ret.samples = sample
+ return ret
+
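+ # A minimal usage sketch (illustrative only; `velocity_model` is a hypothetical
+ # module mapping (x_t, t, cond) to a velocity prediction, and sigma_min is arbitrary):
+ #
+ #     sampler = FlowEulerSampler(sigma_min=1e-5)
+ #     noise = torch.randn(1, 8, 16, 16, 16, device='cuda')
+ #     out = sampler.sample(velocity_model, noise, cond=None, steps=50)
+ #     sample = out.samples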
+
+class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
+ """
+ Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
+ """
+ @torch.no_grad()
+ def sample(
+ self,
+ model,
+ noise,
+ cond,
+ neg_cond,
+ steps: int = 50,
+ rescale_t: float = 1.0,
+ guidance_strength: float = 3.0,
+ verbose: bool = True,
+ **kwargs
+ ):
+ """
+ Generate samples from the model using Euler method.
+
+ Args:
+ model: The model to sample from.
+ noise: The initial noise tensor.
+ cond: conditional information.
+ neg_cond: negative conditional information.
+ steps: The number of steps to sample.
+ rescale_t: The rescale factor for t.
+ guidance_strength: The strength of classifier-free guidance.
+ verbose: If True, show a progress bar.
+ **kwargs: Additional arguments for model_inference.
+
+ Returns:
+ a dict containing the following
+ - 'samples': the model samples.
+ - 'pred_x_t': a list of prediction of x_t.
+ - 'pred_x_0': a list of prediction of x_0.
+ """
+ return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, **kwargs)
+
+
+class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
+ """
+ Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
+ """
+ @torch.no_grad()
+ def sample(
+ self,
+ model,
+ noise,
+ cond,
+ neg_cond,
+ steps: int = 50,
+ rescale_t: float = 1.0,
+ guidance_strength: float = 3.0,
+ guidance_interval: Tuple[float, float] = (0.0, 1.0),
+ verbose: bool = True,
+ **kwargs
+ ):
+ """
+ Generate samples from the model using Euler method.
+
+ Args:
+ model: The model to sample from.
+ noise: The initial noise tensor.
+ cond: conditional information.
+ neg_cond: negative conditional information.
+ steps: The number of steps to sample.
+ rescale_t: The rescale factor for t.
+ guidance_strength: The strength of classifier-free guidance.
+ guidance_interval: The interval for classifier-free guidance.
+ verbose: If True, show a progress bar.
+ **kwargs: Additional arguments for model_inference.
+
+ Returns:
+ a dict containing the following
+ - 'samples': the model samples.
+ - 'pred_x_t': a list of prediction of x_t.
+ - 'pred_x_0': a list of prediction of x_0.
+ """
+ return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, guidance_interval=guidance_interval, **kwargs)
diff --git a/trellis2/pipelines/samplers/guidance_interval_mixin.py b/trellis2/pipelines/samplers/guidance_interval_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f57869a17d1626f5b2c58eb3c477127bf464abf
--- /dev/null
+++ b/trellis2/pipelines/samplers/guidance_interval_mixin.py
@@ -0,0 +1,13 @@
+from typing import *
+
+
+class GuidanceIntervalSamplerMixin:
+ """
+ A mixin class for samplers that apply classifier-free guidance with interval.
+ """
+
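+ # Note: in FlowEulerSampler, t runs from 1 (pure noise) down to 0 (clean data), so
+ # guidance_interval is interpreted on that same [0, 1] timestep axis; outside the
+ # interval the model is queried with guidance_strength=1 (conditional prediction only).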
+ def _inference_model(self, model, x_t, t, cond, guidance_strength, guidance_interval, **kwargs):
+ if guidance_interval[0] <= t <= guidance_interval[1]:
+ return super()._inference_model(model, x_t, t, cond, guidance_strength=guidance_strength, **kwargs)
+ else:
+ return super()._inference_model(model, x_t, t, cond, guidance_strength=1, **kwargs)
diff --git a/trellis2/pipelines/trellis2_image_to_3d.py b/trellis2/pipelines/trellis2_image_to_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b4d16c5a2106a7dfc545603846452b0600018f6
--- /dev/null
+++ b/trellis2/pipelines/trellis2_image_to_3d.py
@@ -0,0 +1,610 @@
+from typing import *
+import torch
+import torch.nn as nn
+import numpy as np
+from PIL import Image
+from .base import Pipeline
+from . import samplers, rembg
+from ..modules.sparse import SparseTensor
+from ..modules import image_feature_extractor
+from ..representations import Mesh, MeshWithVoxel
+
+
+class Trellis2ImageTo3DPipeline(Pipeline):
+ """
+ Pipeline for inferring Trellis2 image-to-3D models.
+
+ Args:
+ models (dict[str, nn.Module]): The models to use in the pipeline.
+ sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
+ shape_slat_sampler (samplers.Sampler): The sampler for the structured latent.
+ tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
+ sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler.
+ shape_slat_sampler_params (dict): The parameters for the structured latent sampler.
+ tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
+ shape_slat_normalization (dict): The normalization parameters for the structured latent.
+ tex_slat_normalization (dict): The normalization parameters for the texture latent.
+ image_cond_model (Callable): The image conditioning model.
+ rembg_model (Callable): The model for removing background.
+ low_vram (bool): Whether to use low-VRAM mode.
+ """
+ model_names_to_load = [
+ 'sparse_structure_flow_model',
+ 'sparse_structure_decoder',
+ 'shape_slat_flow_model_512',
+ 'shape_slat_flow_model_1024',
+ 'shape_slat_decoder',
+ 'tex_slat_flow_model_512',
+ 'tex_slat_flow_model_1024',
+ 'tex_slat_decoder',
+ ]
+
+ def __init__(
+ self,
+ models: dict[str, nn.Module] = None,
+ sparse_structure_sampler: samplers.Sampler = None,
+ shape_slat_sampler: samplers.Sampler = None,
+ tex_slat_sampler: samplers.Sampler = None,
+ sparse_structure_sampler_params: dict = None,
+ shape_slat_sampler_params: dict = None,
+ tex_slat_sampler_params: dict = None,
+ shape_slat_normalization: dict = None,
+ tex_slat_normalization: dict = None,
+ image_cond_model: Callable = None,
+ rembg_model: Callable = None,
+ low_vram: bool = True,
+ default_pipeline_type: str = '1024_cascade',
+ ):
+ if models is None:
+ return
+ super().__init__(models)
+ self.sparse_structure_sampler = sparse_structure_sampler
+ self.shape_slat_sampler = shape_slat_sampler
+ self.tex_slat_sampler = tex_slat_sampler
+ self.sparse_structure_sampler_params = sparse_structure_sampler_params
+ self.shape_slat_sampler_params = shape_slat_sampler_params
+ self.tex_slat_sampler_params = tex_slat_sampler_params
+ self.shape_slat_normalization = shape_slat_normalization
+ self.tex_slat_normalization = tex_slat_normalization
+ self.image_cond_model = image_cond_model
+ self.rembg_model = rembg_model
+ self.low_vram = low_vram
+ self.default_pipeline_type = default_pipeline_type
+ self.pbr_attr_layout = {
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+ self._device = 'cpu'
+
+ @classmethod
+ def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2ImageTo3DPipeline":
+ """
+ Load a pretrained model.
+
+ Args:
+ path (str): The path to the model. Can be either local path or a Hugging Face repository.
+ """
+ pipeline = super().from_pretrained(path, config_file)
+ args = pipeline._pretrained_args
+
+ pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
+ pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
+
+ pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
+ pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
+
+ pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
+ pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
+
+ pipeline.shape_slat_normalization = args['shape_slat_normalization']
+ pipeline.tex_slat_normalization = args['tex_slat_normalization']
+
+ # HACK: replace the dinov3 model source with the camenduru mirror
+ image_cond_args = args['image_cond_model']['args'].copy()
+ if image_cond_args.get('model_name') == 'facebook/dinov3-vitl16-pretrain-lvd1689m':
+ image_cond_args['model_name'] = 'camenduru/dinov3-vitl16-pretrain-lvd1689m'
+ pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**image_cond_args)
+ pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
+
+ pipeline.low_vram = args.get('low_vram', True)
+ pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
+ pipeline.pbr_attr_layout = {
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+ pipeline._device = 'cpu'
+
+ return pipeline
+
+ def to(self, device: torch.device) -> None:
+ self._device = device
+ if not self.low_vram:
+ super().to(device)
+ self.image_cond_model.to(device)
+ if self.rembg_model is not None:
+ self.rembg_model.to(device)
+
+ def preprocess_image(self, input: Image.Image, bg_color: tuple = (0, 0, 0)) -> Image.Image:
+ """
+ Preprocess the input image.
+
+ Args:
+ input: Input image (RGB or RGBA).
+ bg_color: Background color (R, G, B) in 0~255. Default black (0,0,0).
+ """
+ # if has alpha channel, use it directly; otherwise, remove background
+ has_alpha = False
+ if input.mode == 'RGBA':
+ alpha = np.array(input)[:, :, 3]
+ if not np.all(alpha == 255):
+ has_alpha = True
+ max_size = max(input.size)
+ scale = min(1, 1024 / max_size)
+ if scale < 1:
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+ if has_alpha:
+ output = input
+ else:
+ input = input.convert('RGB')
+ if self.low_vram:
+ self.rembg_model.to(self.device)
+ output = self.rembg_model(input)
+ if self.low_vram:
+ self.rembg_model.cpu()
+ output_np = np.array(output)
+ alpha = output_np[:, :, 3]
+ bbox = np.argwhere(alpha > 0.8 * 255)
+ bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+ size = int(size * 1.1)
+ bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+ output = output.crop(bbox) # type: ignore
+ output = np.array(output).astype(np.float32) / 255
+ rgb = output[:, :, :3]
+ a = output[:, :, 3:4]
+ bg = np.array(bg_color, dtype=np.float32) / 255.0
+ output = rgb * a + bg * (1.0 - a)
+ output = Image.fromarray((np.clip(output, 0, 1) * 255).astype(np.uint8))
+ return output
+
+ def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
+ """
+ Get the conditioning information for the model.
+
+ Args:
+ image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
+ resolution (int): The image size fed to the image feature extractor.
+ include_neg_cond (bool): Whether to also return a zero negative condition.
+
+ Returns:
+ dict: The conditioning information
+ """
+ self.image_cond_model.image_size = resolution
+ if self.low_vram:
+ self.image_cond_model.to(self.device)
+ cond = self.image_cond_model(image)
+ if self.low_vram:
+ self.image_cond_model.cpu()
+ if not include_neg_cond:
+ return {'cond': cond}
+ neg_cond = torch.zeros_like(cond)
+ return {
+ 'cond': cond,
+ 'neg_cond': neg_cond,
+ }
+
+ def sample_sparse_structure(
+ self,
+ cond: dict,
+ resolution: int,
+ num_samples: int = 1,
+ sampler_params: dict = {},
+ ) -> torch.Tensor:
+ """
+ Sample sparse structures with the given conditioning.
+
+ Args:
+ cond (dict): The conditioning information.
+ resolution (int): The resolution of the sparse structure.
+ num_samples (int): The number of samples to generate.
+ sampler_params (dict): Additional parameters for the sampler.
+ """
+ # Sample sparse structure latent
+ flow_model = self.models['sparse_structure_flow_model']
+ reso = flow_model.resolution
+ in_channels = flow_model.in_channels
+ noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device)
+ sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ z_s = self.sparse_structure_sampler.sample(
+ flow_model,
+ noise,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling sparse structure",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ # Decode sparse structure latent
+ decoder = self.models['sparse_structure_decoder']
+ if self.low_vram:
+ decoder.to(self.device)
+ decoded = decoder(z_s) > 0
+ if self.low_vram:
+ decoder.cpu()
+ if resolution != decoded.shape[2]:
+ ratio = decoded.shape[2] // resolution
+ decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
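+ # Occupancy is [N, 1, D, H, W]; keep the (batch, spatial) indices and drop the channel dim.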
+ coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
+
+ return coords
+
+ def sample_shape_slat(
+ self,
+ cond: dict,
+ flow_model,
+ coords: torch.Tensor,
+ sampler_params: dict = {},
+ ) -> SparseTensor:
+ """
+ Sample structured latent with the given conditioning.
+
+ Args:
+ cond (dict): The conditioning information.
+ flow_model: The shape SLat flow model to sample with.
+ coords (torch.Tensor): The coordinates of the sparse structure.
+ sampler_params (dict): Additional parameters for the sampler.
+ """
+ # Sample structured latent
+ noise = SparseTensor(
+ feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
+ coords=coords,
+ )
+ sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ slat = self.shape_slat_sampler.sample(
+ flow_model,
+ noise,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling shape SLat",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ return slat
+
+ def sample_shape_slat_cascade(
+ self,
+ lr_cond: dict,
+ cond: dict,
+ flow_model_lr,
+ flow_model,
+ lr_resolution: int,
+ resolution: int,
+ coords: torch.Tensor,
+ sampler_params: dict = {},
+ max_num_tokens: int = 49152,
+ ) -> SparseTensor:
+ """
+ Sample structured latent with a low-resolution to high-resolution cascade.
+
+ Args:
+ lr_cond (dict): The conditioning information for the low-resolution stage.
+ cond (dict): The conditioning information for the high-resolution stage.
+ flow_model_lr: The low-resolution shape SLat flow model.
+ flow_model: The high-resolution shape SLat flow model.
+ lr_resolution (int): The resolution of the low-resolution stage.
+ resolution (int): The target resolution of the high-resolution stage.
+ coords (torch.Tensor): The coordinates of the sparse structure.
+ sampler_params (dict): Additional parameters for the sampler.
+ max_num_tokens (int): The maximum number of tokens allowed at high resolution.
+ """
+ # LR
+ noise = SparseTensor(
+ feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device),
+ coords=coords,
+ )
+ sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model_lr.to(self.device)
+ slat = self.shape_slat_sampler.sample(
+ flow_model_lr,
+ noise,
+ **lr_cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling shape SLat",
+ ).samples
+ if self.low_vram:
+ flow_model_lr.cpu()
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ # Upsample
+ if self.low_vram:
+ self.models['shape_slat_decoder'].to(self.device)
+ self.models['shape_slat_decoder'].low_vram = True
+ hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4)
+ if self.low_vram:
+ self.models['shape_slat_decoder'].cpu()
+ self.models['shape_slat_decoder'].low_vram = False
+ hr_resolution = resolution
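+ # Search for the largest resolution (stepping down by 128) whose quantized
+ # coords stay within the token budget; 1024 is the floor.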
+ while True:
+ quant_coords = torch.cat([
+ hr_coords[:, :1],
+ ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
+ ], dim=1)
+ coords = quant_coords.unique(dim=0)
+ num_tokens = coords.shape[0]
+ if num_tokens < max_num_tokens or hr_resolution == 1024:
+ if hr_resolution != resolution:
+ print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.")
+ break
+ hr_resolution -= 128
+
+ # Sample structured latent
+ noise = SparseTensor(
+ feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
+ coords=coords,
+ )
+ sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ slat = self.shape_slat_sampler.sample(
+ flow_model,
+ noise,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling shape SLat",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ return slat, hr_resolution
+
+ def decode_shape_slat(
+ self,
+ slat: SparseTensor,
+ resolution: int,
+ ) -> Tuple[List[Mesh], List[SparseTensor]]:
+ """
+ Decode the structured latent.
+
+ Args:
+ slat (SparseTensor): The structured latent.
+ resolution (int): The resolution at which to decode the mesh.
+
+ Returns:
+ List[Mesh]: The decoded meshes.
+ List[SparseTensor]: The decoded substructures.
+ """
+ self.models['shape_slat_decoder'].set_resolution(resolution)
+ if self.low_vram:
+ self.models['shape_slat_decoder'].to(self.device)
+ self.models['shape_slat_decoder'].low_vram = True
+ ret = self.models['shape_slat_decoder'](slat, return_subs=True)
+ if self.low_vram:
+ self.models['shape_slat_decoder'].cpu()
+ self.models['shape_slat_decoder'].low_vram = False
+ return ret
+
+ def sample_tex_slat(
+ self,
+ cond: dict,
+ flow_model,
+ shape_slat: SparseTensor,
+ sampler_params: dict = {},
+ ) -> SparseTensor:
+ """
+ Sample structured latent with the given conditioning.
+
+ Args:
+ cond (dict): The conditioning information.
+ flow_model: The texture SLat flow model to sample with.
+ shape_slat (SparseTensor): The structured latent for shape.
+ sampler_params (dict): Additional parameters for the sampler.
+ """
+ # Sample structured latent
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
+ shape_slat = (shape_slat - mean) / std
+
+ in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
+ noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
+ sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ slat = self.tex_slat_sampler.sample(
+ flow_model,
+ noise,
+ concat_cond=shape_slat,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling texture SLat",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ return slat
+
+ def decode_tex_slat(
+ self,
+ slat: SparseTensor,
+ subs: List[SparseTensor],
+ ) -> SparseTensor:
+ """
+ Decode the structured latent.
+
+ Args:
+ slat (SparseTensor): The structured latent.
+ subs (List[SparseTensor]): Shape substructures used to guide decoding.
+
+ Returns:
+ SparseTensor: The decoded texture voxels
+ """
+ if self.low_vram:
+ self.models['tex_slat_decoder'].to(self.device)
+ ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5
+ if self.low_vram:
+ self.models['tex_slat_decoder'].cpu()
+ return ret
+
+ @torch.no_grad()
+ def decode_latent(
+ self,
+ shape_slat: SparseTensor,
+ tex_slat: SparseTensor,
+ resolution: int,
+ ) -> List[MeshWithVoxel]:
+ """
+ Decode the latent codes.
+
+ Args:
+ shape_slat (SparseTensor): The structured latent for shape.
+ tex_slat (SparseTensor): The structured latent for texture.
+ resolution (int): The resolution of the output.
+ """
+ meshes, subs = self.decode_shape_slat(shape_slat, resolution)
+ tex_voxels = self.decode_tex_slat(tex_slat, subs)
+ out_mesh = []
+ torch.cuda.synchronize()
+ for m, v in zip(meshes, tex_voxels):
+ # NOTE: fill_holes may raise a RuntimeError on some PyTorch/cumesh combinations.
+ m.fill_holes()
+ out_mesh.append(
+ MeshWithVoxel(
+ m.vertices, m.faces,
+ origin = [-0.5, -0.5, -0.5],
+ voxel_size = 1 / resolution,
+ coords = v.coords[:, 1:],
+ attrs = v.feats,
+ voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
+ layout=self.pbr_attr_layout
+ )
+ )
+ return out_mesh
+
+ @torch.no_grad()
+ def run(
+ self,
+ image: Image.Image,
+ num_samples: int = 1,
+ seed: int = 42,
+ sparse_structure_sampler_params: dict = {},
+ shape_slat_sampler_params: dict = {},
+ tex_slat_sampler_params: dict = {},
+ preprocess_image: bool = True,
+ return_latent: bool = False,
+ pipeline_type: Optional[str] = None,
+ max_num_tokens: int = 49152,
+ ) -> List[MeshWithVoxel]:
+ """
+ Run the pipeline.
+
+ Args:
+ image (Image.Image): The image prompt.
+ num_samples (int): The number of samples to generate.
+ seed (int): The random seed.
+ sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
+ shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler.
+ tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler.
+ preprocess_image (bool): Whether to preprocess the image.
+ return_latent (bool): Whether to return the latent codes.
+ pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
+ max_num_tokens (int): The maximum number of tokens to use.
+ """
+ # Check pipeline type
+ pipeline_type = pipeline_type or self.default_pipeline_type
+ if pipeline_type == '512':
+ assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
+ assert 'tex_slat_flow_model_512' in self.models, "No 512 resolution texture SLat flow model found."
+ elif pipeline_type == '1024':
+ assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
+ assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
+ elif pipeline_type == '1024_cascade':
+ assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
+ assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
+ assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
+ elif pipeline_type == '1536_cascade':
+ assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
+ assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
+ assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
+ else:
+ raise ValueError(f"Invalid pipeline type: {pipeline_type}")
+
+ if preprocess_image:
+ image = self.preprocess_image(image)
+ torch.manual_seed(seed)
+ cond_512 = self.get_cond([image], 512)
+ cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None
+ ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type]
+ coords = self.sample_sparse_structure(
+ cond_512, ss_res,
+ num_samples, sparse_structure_sampler_params
+ )
+ if pipeline_type == '512':
+ shape_slat = self.sample_shape_slat(
+ cond_512, self.models['shape_slat_flow_model_512'],
+ coords, shape_slat_sampler_params
+ )
+ tex_slat = self.sample_tex_slat(
+ cond_512, self.models['tex_slat_flow_model_512'],
+ shape_slat, tex_slat_sampler_params
+ )
+ res = 512
+ elif pipeline_type == '1024':
+ shape_slat = self.sample_shape_slat(
+ cond_1024, self.models['shape_slat_flow_model_1024'],
+ coords, shape_slat_sampler_params
+ )
+ tex_slat = self.sample_tex_slat(
+ cond_1024, self.models['tex_slat_flow_model_1024'],
+ shape_slat, tex_slat_sampler_params
+ )
+ res = 1024
+ elif pipeline_type == '1024_cascade':
+ shape_slat, res = self.sample_shape_slat_cascade(
+ cond_512, cond_1024,
+ self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
+ 512, 1024,
+ coords, shape_slat_sampler_params,
+ max_num_tokens
+ )
+ tex_slat = self.sample_tex_slat(
+ cond_1024, self.models['tex_slat_flow_model_1024'],
+ shape_slat, tex_slat_sampler_params
+ )
+ elif pipeline_type == '1536_cascade':
+ shape_slat, res = self.sample_shape_slat_cascade(
+ cond_512, cond_1024,
+ self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
+ 512, 1536,
+ coords, shape_slat_sampler_params,
+ max_num_tokens
+ )
+ tex_slat = self.sample_tex_slat(
+ cond_1024, self.models['tex_slat_flow_model_1024'],
+ shape_slat, tex_slat_sampler_params
+ )
+ torch.cuda.empty_cache()
+ out_mesh = self.decode_latent(shape_slat, tex_slat, res)
+ if return_latent:
+ return out_mesh, (shape_slat, tex_slat, res)
+ else:
+ return out_mesh
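+
+ # A minimal usage sketch (the checkpoint path is illustrative, not an official repo id):
+ #
+ #     pipeline = Trellis2ImageTo3DPipeline.from_pretrained("path/or/hf-repo/TRELLIS2")
+ #     pipeline.to('cuda')
+ #     image = Image.open("input.png")
+ #     meshes = pipeline.run(image, seed=42, pipeline_type='1024_cascade')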
diff --git a/trellis2/pipelines/trellis2_texturing.py b/trellis2/pipelines/trellis2_texturing.py
new file mode 100644
index 0000000000000000000000000000000000000000..c184b5e73ab49d5f29256e16ff900abc92be73be
--- /dev/null
+++ b/trellis2/pipelines/trellis2_texturing.py
@@ -0,0 +1,408 @@
+from typing import *
+import torch
+import torch.nn as nn
+import numpy as np
+from PIL import Image
+import trimesh
+from .base import Pipeline
+from . import samplers, rembg
+from ..modules.sparse import SparseTensor
+from ..modules import image_feature_extractor
+import o_voxel
+import cumesh
+import nvdiffrast.torch as dr
+import cv2
+import flex_gemm
+
+
+class Trellis2TexturingPipeline(Pipeline):
+ """
+ Pipeline for texturing an existing mesh with Trellis2 models.
+
+ Args:
+ models (dict[str, nn.Module]): The models to use in the pipeline.
+ tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
+ tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
+ shape_slat_normalization (dict): The normalization parameters for the structured latent.
+ tex_slat_normalization (dict): The normalization parameters for the texture latent.
+ image_cond_model (Callable): The image conditioning model.
+ rembg_model (Callable): The model for removing background.
+ low_vram (bool): Whether to use low-VRAM mode.
+ """
+ model_names_to_load = [
+ 'shape_slat_encoder',
+ 'tex_slat_decoder',
+ 'tex_slat_flow_model_512',
+ 'tex_slat_flow_model_1024'
+ ]
+
+ def __init__(
+ self,
+ models: dict[str, nn.Module] = None,
+ tex_slat_sampler: samplers.Sampler = None,
+ tex_slat_sampler_params: dict = None,
+ shape_slat_normalization: dict = None,
+ tex_slat_normalization: dict = None,
+ image_cond_model: Callable = None,
+ rembg_model: Callable = None,
+ low_vram: bool = True,
+ ):
+ if models is None:
+ return
+ super().__init__(models)
+ self.tex_slat_sampler = tex_slat_sampler
+ self.tex_slat_sampler_params = tex_slat_sampler_params
+ self.shape_slat_normalization = shape_slat_normalization
+ self.tex_slat_normalization = tex_slat_normalization
+ self.image_cond_model = image_cond_model
+ self.rembg_model = rembg_model
+ self.low_vram = low_vram
+ self.pbr_attr_layout = {
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+ self._device = 'cpu'
+
+ @classmethod
+ def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2TexturingPipeline":
+ """
+ Load a pretrained model.
+
+ Args:
+ path (str): The path to the model. Can be either local path or a Hugging Face repository.
+ """
+ pipeline = super().from_pretrained(path, config_file)
+ args = pipeline._pretrained_args
+
+ pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
+ pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
+
+ pipeline.shape_slat_normalization = args['shape_slat_normalization']
+ pipeline.tex_slat_normalization = args['tex_slat_normalization']
+
+ pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
+ pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
+
+ pipeline.low_vram = args.get('low_vram', True)
+ pipeline.pbr_attr_layout = {
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+ pipeline._device = 'cpu'
+ return pipeline
+
+ def to(self, device: torch.device) -> None:
+ self._device = device
+ if not self.low_vram:
+ super().to(device)
+ self.image_cond_model.to(device)
+ if self.rembg_model is not None:
+ self.rembg_model.to(device)
+
+ def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
+ """
+ Preprocess the input mesh.
+ """
+ vertices = mesh.vertices
+ vertices_min = vertices.min(axis=0)
+ vertices_max = vertices.max(axis=0)
+ center = (vertices_min + vertices_max) / 2
+ scale = 0.99999 / (vertices_max - vertices_min).max()
+ vertices = (vertices - center) * scale
+ tmp = vertices[:, 1].copy()
+ vertices[:, 1] = -vertices[:, 2]
+ vertices[:, 2] = tmp
+ assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range'
+ return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False)
+
+ def preprocess_image(self, input: Image.Image) -> Image.Image:
+ """
+ Preprocess the input image.
+ """
+ # if has alpha channel, use it directly; otherwise, remove background
+ has_alpha = False
+ if input.mode == 'RGBA':
+ alpha = np.array(input)[:, :, 3]
+ if not np.all(alpha == 255):
+ has_alpha = True
+ max_size = max(input.size)
+ scale = min(1, 1024 / max_size)
+ if scale < 1:
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+ if has_alpha:
+ output = input
+ else:
+ input = input.convert('RGB')
+ if self.low_vram:
+ self.rembg_model.to(self.device)
+ output = self.rembg_model(input)
+ if self.low_vram:
+ self.rembg_model.cpu()
+ output_np = np.array(output)
+ alpha = output_np[:, :, 3]
+ bbox = np.argwhere(alpha > 0.8 * 255)
+ bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+ size = int(size * 1)
+ bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+ output = output.crop(bbox) # type: ignore
+ output = np.array(output).astype(np.float32) / 255
+ output = output[:, :, :3] * output[:, :, 3:4]
+ output = Image.fromarray((output * 255).astype(np.uint8))
+ return output
+
+ def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
+ """
+ Get the conditioning information for the model.
+
+ Args:
+ image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
+ resolution (int): The image size fed to the image feature extractor.
+ include_neg_cond (bool): Whether to also return a zero negative condition.
+
+ Returns:
+ dict: The conditioning information
+ """
+ self.image_cond_model.image_size = resolution
+ if self.low_vram:
+ self.image_cond_model.to(self.device)
+ cond = self.image_cond_model(image)
+ if self.low_vram:
+ self.image_cond_model.cpu()
+ if not include_neg_cond:
+ return {'cond': cond}
+ neg_cond = torch.zeros_like(cond)
+ return {
+ 'cond': cond,
+ 'neg_cond': neg_cond,
+ }
+
+ def encode_shape_slat(
+ self,
+ mesh: trimesh.Trimesh,
+ resolution: int = 1024,
+ ) -> SparseTensor:
+ """
+ Encode the meshes to structured latent.
+
+ Args:
+ mesh (trimesh.Trimesh): The mesh to encode.
+ resolution (int): The resolution of mesh
+
+ Returns:
+ SparseTensor: The encoded structured latent.
+ """
+ vertices = torch.from_numpy(mesh.vertices).float()
+ faces = torch.from_numpy(mesh.faces).long()
+
+ voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
+ vertices.cpu(), faces.cpu(),
+ grid_size=resolution,
+ aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]],
+ face_weight=1.0,
+ boundary_weight=0.2,
+ regularization_weight=1e-2,
+ timing=True,
+ )
+
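+ # Dual vertices are stored as voxel-local offsets (position * resolution minus the
+ # integer voxel index), batched under a single batch index 0.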
+ vertices = SparseTensor(
+ feats=dual_vertices * resolution - voxel_indices,
+ coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1)
+ ).to(self.device)
+ intersected = vertices.replace(intersected).to(self.device)
+
+ if self.low_vram:
+ self.models['shape_slat_encoder'].to(self.device)
+ shape_slat = self.models['shape_slat_encoder'](vertices, intersected)
+ if self.low_vram:
+ self.models['shape_slat_encoder'].cpu()
+ return shape_slat
+
+ def sample_tex_slat(
+ self,
+ cond: dict,
+ flow_model,
+ shape_slat: SparseTensor,
+ sampler_params: dict = {},
+ ) -> SparseTensor:
+ """
+ Sample structured latent with the given conditioning.
+
+ Args:
+ cond (dict): The conditioning information.
+ flow_model: The texture SLat flow model to sample with.
+ shape_slat (SparseTensor): The structured latent for shape.
+ sampler_params (dict): Additional parameters for the sampler.
+ """
+ # Sample structured latent
+ std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
+ mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
+ shape_slat = (shape_slat - mean) / std
+
+ in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
+ noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
+ sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
+ if self.low_vram:
+ flow_model.to(self.device)
+ slat = self.tex_slat_sampler.sample(
+ flow_model,
+ noise,
+ concat_cond=shape_slat,
+ **cond,
+ **sampler_params,
+ verbose=True,
+ tqdm_desc="Sampling texture SLat",
+ ).samples
+ if self.low_vram:
+ flow_model.cpu()
+
+ std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
+ mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
+ slat = slat * std + mean
+
+ return slat
+
+ def decode_tex_slat(
+ self,
+ slat: SparseTensor,
+ ) -> SparseTensor:
+ """
+ Decode the structured latent.
+
+ Args:
+ slat (SparseTensor): The structured latent.
+
+ Returns:
+ SparseTensor: The decoded texture voxels
+ """
+ if self.low_vram:
+ self.models['tex_slat_decoder'].to(self.device)
+ ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5
+ if self.low_vram:
+ self.models['tex_slat_decoder'].cpu()
+ return ret
+
+ def postprocess_mesh(
+ self,
+ mesh: trimesh.Trimesh,
+ pbr_voxel: SparseTensor,
+ resolution: int = 1024,
+ texture_size: int = 1024,
+ ) -> trimesh.Trimesh:
+ vertices = mesh.vertices
+ faces = mesh.faces
+ normals = mesh.vertex_normals
+ vertices_torch = torch.from_numpy(vertices).float().cuda()
+ faces_torch = torch.from_numpy(faces).int().cuda()
+ if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
+ uvs = mesh.visual.uv.copy()
+ uvs[:, 1] = 1 - uvs[:, 1]
+ uvs_torch = torch.from_numpy(uvs).float().cuda()
+ else:
+ _cumesh = cumesh.CuMesh()
+ _cumesh.init(vertices_torch, faces_torch)
+ vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True)
+ vertices_torch = vertices_torch.cuda()
+ faces_torch = faces_torch.cuda()
+ uvs_torch = uvs_torch.cuda()
+ vertices = vertices_torch.cpu().numpy()
+ faces = faces_torch.cpu().numpy()
+ uvs = uvs_torch.cpu().numpy()
+ normals = normals[vmap.cpu().numpy()]
+
+ # rasterize
+ ctx = dr.RasterizeCudaContext()
+ uvs_torch = torch.cat([uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])], dim=-1).unsqueeze(0)
+ rast, _ = dr.rasterize(
+ ctx, uvs_torch, faces_torch,
+ resolution=[texture_size, texture_size],
+ )
+ mask = rast[0, ..., 3] > 0
+ pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0]
+
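+ # Sample the voxelized PBR attributes at each covered texel's 3D position
+ # (the mesh is assumed to live in [-0.5, 0.5]^3, matching preprocess_mesh).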
+ attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device)
+ attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d(
+ pbr_voxel.feats,
+ pbr_voxel.coords,
+ shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]),
+ grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3),
+ mode='trilinear',
+ )
+
+ # construct mesh
+ mask = mask.cpu().numpy()
+ base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
+ metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
+ roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
+ alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
+
+ # extend
+ mask = (~mask).astype(np.uint8)
+ base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA)
+ metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None]
+ roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None]
+ alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None]
+
+ material = trimesh.visual.material.PBRMaterial(
+ baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)),
+ baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
+ metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)),
+ metallicFactor=1.0,
+ roughnessFactor=1.0,
+ alphaMode='OPAQUE',
+ doubleSided=True,
+ )
+
+ # Undo the preprocessing axis swap (y <- z, z <- -y) for GLB export
+ vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1]
+ normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1]
+ uvs[:, 1] = 1 - uvs[:, 1] # Flip UV V-coordinate
+
+ textured_mesh = trimesh.Trimesh(
+ vertices=vertices,
+ faces=faces,
+ vertex_normals=normals,
+ process=False,
+ visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)
+ )
+
+ return textured_mesh
+
+
+ @torch.no_grad()
+ def run(
+ self,
+ mesh: trimesh.Trimesh,
+ image: Image.Image,
+ seed: int = 42,
+ tex_slat_sampler_params: dict = {},
+ preprocess_image: bool = True,
+ resolution: int = 1024,
+ texture_size: int = 2048,
+ ) -> trimesh.Trimesh:
+ """
+ Run the pipeline.
+
+ Args:
+ mesh (trimesh.Trimesh): The mesh to texture.
+ image (Image.Image): The image prompt.
+ seed (int): The random seed.
+ tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler.
+ preprocess_image (bool): Whether to preprocess the image.
+ resolution (int): The voxel grid resolution used for encoding and texturing.
+ texture_size (int): The side length of the baked texture maps.
+ """
+ if preprocess_image:
+ image = self.preprocess_image(image)
+ mesh = self.preprocess_mesh(mesh)
+ torch.manual_seed(seed)
+ cond = self.get_cond([image], 512) if resolution == 512 else self.get_cond([image], 1024)
+ shape_slat = self.encode_shape_slat(mesh, resolution)
+ tex_model = self.models['tex_slat_flow_model_512'] if resolution == 512 else self.models['tex_slat_flow_model_1024']
+ tex_slat = self.sample_tex_slat(
+ cond, tex_model,
+ shape_slat, tex_slat_sampler_params
+ )
+ pbr_voxel = self.decode_tex_slat(tex_slat)
+ out_mesh = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size)
+ return out_mesh
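+
+ # A minimal usage sketch (paths are illustrative):
+ #
+ #     pipeline = Trellis2TexturingPipeline.from_pretrained("path/or/hf-repo/TRELLIS2")
+ #     pipeline.to('cuda')
+ #     mesh = trimesh.load("input.glb", force='mesh')
+ #     image = Image.open("reference.png")
+ #     textured = pipeline.run(mesh, image, resolution=1024, texture_size=2048)
+ #     textured.export("textured.glb")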
diff --git a/trellis2/renderers/__init__.py b/trellis2/renderers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..de3203d1bb16065f912bff039e431f609911782d
--- /dev/null
+++ b/trellis2/renderers/__init__.py
@@ -0,0 +1,33 @@
+import importlib
+
+__attributes = {
+ 'MeshRenderer': 'mesh_renderer',
+ 'VoxelRenderer': 'voxel_renderer',
+ 'PbrMeshRenderer': 'pbr_mesh_renderer',
+ 'EnvMap': 'pbr_mesh_renderer',
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+ if name not in globals():
+ if name in __attributes:
+ module_name = __attributes[name]
+ module = importlib.import_module(f".{module_name}", __name__)
+ globals()[name] = getattr(module, name)
+ elif name in __submodules:
+ module = importlib.import_module(f".{name}", __name__)
+ globals()[name] = module
+ else:
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+ return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+ from .mesh_renderer import MeshRenderer
+ from .voxel_renderer import VoxelRenderer
+ from .pbr_mesh_renderer import PbrMeshRenderer, EnvMap
+
\ No newline at end of file
diff --git a/trellis2/renderers/mesh_renderer.py b/trellis2/renderers/mesh_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e20efc59c26fadcbd9c11cbafa2e31b78eec395c
--- /dev/null
+++ b/trellis2/renderers/mesh_renderer.py
@@ -0,0 +1,414 @@
+from typing import *
+import torch
+from easydict import EasyDict as edict
+from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode
+import torch.nn.functional as F
+
+
+def intrinsics_to_projection(
+ intrinsics: torch.Tensor,
+ near: float,
+ far: float,
+ ) -> torch.Tensor:
+ """
+ OpenCV intrinsics to OpenGL perspective matrix
+
+ Args:
+ intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+ near (float): near plane to clip
+ far (float): far plane to clip
+ Returns:
+ (torch.Tensor): [4, 4] OpenGL perspective matrix
+ """
+ fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+ cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+ ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+ ret[0, 0] = 2 * fx
+ ret[1, 1] = 2 * fy
+ ret[0, 2] = 2 * cx - 1
+ ret[1, 2] = - 2 * cy + 1
+ ret[2, 2] = (far + near) / (far - near)
+ ret[2, 3] = 2 * near * far / (near - far)
+ ret[3, 2] = 1.
+ return ret
+
+
+class MeshRenderer:
+ """
+ Renderer for the Mesh representation.
+
+ Args:
+ rendering_options (dict): Rendering options.
+ """
+ def __init__(self, rendering_options={}, device='cuda'):
+ if 'dr' not in globals():
+ import nvdiffrast.torch as dr
+
+ self.rendering_options = edict({
+ "resolution": None,
+ "near": None,
+ "far": None,
+ "ssaa": 1,
+ "chunk_size": None,
+ "antialias": True,
+ "clamp_barycentric_coords": False,
+ })
+ self.rendering_options.update(rendering_options)
+ self.glctx = dr.RasterizeCudaContext(device=device)
+ self.device=device
+
+ def render(
+ self,
+ mesh : Mesh,
+ extrinsics: torch.Tensor,
+ intrinsics: torch.Tensor,
+ return_types = ["mask", "normal", "depth"],
+ transformation : Optional[torch.Tensor] = None
+ ) -> edict:
+ """
+ Render the mesh.
+
+ Args:
+ mesh (Mesh): The mesh to render.
+ extrinsics (torch.Tensor): (4, 4) camera extrinsics
+ intrinsics (torch.Tensor): (3, 3) camera intrinsics
+ return_types (list): list of return types, can be "attr", "mask", "depth", "coord", "normal"
+
+ Returns:
+ edict based on return_types containing:
+ attr (torch.Tensor): [C, H, W] rendered attr image
+ depth (torch.Tensor): [H, W] rendered depth image
+ normal (torch.Tensor): [3, H, W] rendered normal image
+ mask (torch.Tensor): [H, W] rendered mask image
+ """
+ if 'dr' not in globals():
+ import nvdiffrast.torch as dr
+
+ resolution = self.rendering_options["resolution"]
+ near = self.rendering_options["near"]
+ far = self.rendering_options["far"]
+ ssaa = self.rendering_options["ssaa"]
+ chunk_size = self.rendering_options["chunk_size"]
+ antialias = self.rendering_options["antialias"]
+ clamp_barycentric_coords = self.rendering_options["clamp_barycentric_coords"]
+
+ if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
+ ret_dict = edict()
+ for type in return_types:
+ if type == "mask" :
+ ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device)
+ elif type == "depth":
+ ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device)
+ elif type == "normal":
+ ret_dict[type] = torch.full((3, resolution, resolution), 0.5, dtype=torch.float32, device=self.device)
+ elif type == "coord":
+ ret_dict[type] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device)
+ elif type == "attr":
+ if isinstance(mesh, MeshWithVoxel):
+ ret_dict[type] = torch.zeros((mesh.attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device)
+ else:
+ ret_dict[type] = torch.zeros((mesh.vertex_attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device)
+ return ret_dict
+
+ perspective = intrinsics_to_projection(intrinsics, near, far)
+
+ full_proj = (perspective @ extrinsics).unsqueeze(0)
+ extrinsics = extrinsics.unsqueeze(0)
+
+ vertices = mesh.vertices.unsqueeze(0)
+ vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
+ if transformation is not None:
+ vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2))
+ vertices = vertices_homo[..., :3].contiguous()
+ vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2))
+ vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
+ faces = mesh.faces
+
+ if 'normal' in return_types:
+ v0 = vertices_camera[0, mesh.faces[:, 0], :3]
+ v1 = vertices_camera[0, mesh.faces[:, 1], :3]
+ v2 = vertices_camera[0, mesh.faces[:, 2], :3]
+ e0 = v1 - v0
+ e1 = v2 - v0
+ face_normal = torch.cross(e0, e1, dim=1)
+ face_normal = F.normalize(face_normal, dim=1)
+ face_normal = torch.where(torch.sum(face_normal * v0, dim=1, keepdim=True) > 0, face_normal, -face_normal)
+
+ out_dict = edict()
+ if chunk_size is None:
+ rast, rast_db = dr.rasterize(
+ self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa)
+ )
+ if clamp_barycentric_coords:
+ rast[..., :2] = torch.clamp(rast[..., :2], 0, 1)
+ rast[..., :2] /= torch.where(rast[..., :2].sum(dim=-1, keepdim=True) > 1, rast[..., :2].sum(dim=-1, keepdim=True), torch.ones_like(rast[..., :2]))
+ for type in return_types:
+ img = None
+ if type == "mask" :
+ img = (rast[..., -1:] > 0).float()
+ if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
+ elif type == "depth":
+ img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0]
+ if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
+ elif type == "normal" :
+ img = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0]
+ if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
+ img = (img + 1) / 2
+ elif type == "coord":
+ img = dr.interpolate(vertices, rast, faces)[0]
+ if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
+ elif type == "attr":
+ if isinstance(mesh, MeshWithVoxel):
+ if 'grid_sample_3d' not in globals():
+ from flex_gemm.ops.grid_sample import grid_sample_3d
+ mask = rast[..., -1:] > 0
+ xyz = dr.interpolate(vertices, rast, faces)[0]
+ xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3)
+ img = grid_sample_3d(
+ mesh.attrs,
+ torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1),
+ mesh.voxel_shape,
+ xyz,
+ mode='trilinear'
+ )
+ img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask
+ elif isinstance(mesh, MeshWithPbrMaterial):
+ tri_id = rast[0, :, :, -1:]
+ mask = tri_id > 0
+ uv_coords = mesh.uv_coords.reshape(1, -1, 2)
+ texc, texd = dr.interpolate(
+ uv_coords,
+ rast,
+ torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3),
+ rast_db=rast_db,
+ diff_attrs='all'
+ )
+ # Fix problematic texture coordinates
+ texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3)
+ texc = torch.clamp(texc, min=-1e3, max=1e3)
+ texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3)
+ texd = torch.clamp(texd, min=-1e3, max=1e3)
+ mid = mesh.material_ids[(tri_id - 1).long()]
+ imgs = {
+ 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device),
+ 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device),
+ 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device),
+ 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
+ }
+ for id, mat in enumerate(mesh.materials):
+ mat_mask = (mid == id).float() * mask.float()
+ mat_texc = texc * mat_mask
+ mat_texd = texd * mat_mask
+
+ if mat.base_color_texture is not None:
+ base_color = dr.texture(
+ mat.base_color_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ imgs['base_color'] += base_color * mat.base_color_factor * mat_mask
+ else:
+ imgs['base_color'] += mat.base_color_factor * mat_mask
+
+ if mat.metallic_texture is not None:
+ metallic = dr.texture(
+ mat.metallic_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ imgs['metallic'] += metallic * mat.metallic_factor * mat_mask
+ else:
+ imgs['metallic'] += mat.metallic_factor * mat_mask
+
+ if mat.roughness_texture is not None:
+ roughness = dr.texture(
+ mat.roughness_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ imgs['roughness'] += roughness * mat.roughness_factor * mat_mask
+ else:
+ imgs['roughness'] += mat.roughness_factor * mat_mask
+
+ if mat.alpha_mode == AlphaMode.OPAQUE:
+ imgs['alpha'] += 1.0 * mat_mask
+ else:
+ if mat.alpha_texture is not None:
+ alpha = dr.texture(
+ mat.alpha_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ if mat.alpha_mode == AlphaMode.MASK:
+ imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
+ elif mat.alpha_mode == AlphaMode.BLEND:
+ imgs['alpha'] += alpha * mat.alpha_factor * mat_mask
+ else:
+ if mat.alpha_mode == AlphaMode.MASK:
+ imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
+ elif mat.alpha_mode == AlphaMode.BLEND:
+ imgs['alpha'] += mat.alpha_factor * mat_mask
+
+ img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0)
+ else:
+ img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces)[0]
+ if antialias: img = dr.antialias(img, rast, vertices_clip, faces)
+
+ out_dict[type] = img
+ else:
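+ # Rasterize the faces in chunks to bound memory, keeping only the nearest
+ # (smallest-depth) fragment per pixel via a manual z-buffer.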
+ z_buffer = torch.full((1, resolution * ssaa, resolution * ssaa), torch.inf, device=self.device, dtype=torch.float32)
+ for i in range(0, faces.shape[0], chunk_size):
+ faces_chunk = faces[i:i+chunk_size]
+ rast, rast_db = dr.rasterize(
+ self.glctx, vertices_clip, faces_chunk, (resolution * ssaa, resolution * ssaa)
+ )
+ z_filter = torch.logical_and(
+ rast[..., 3] != 0,
+ rast[..., 2] < z_buffer
+ )
+ z_buffer[z_filter] = rast[z_filter][..., 2]
+
+ for type in return_types:
+ img = None
+ if type == "mask" :
+ img = (rast[..., -1:] > 0).float()
+ elif type == "depth":
+ img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_chunk)[0]
+ elif type == "normal" :
+ face_normal_chunk = face_normal[i:i+chunk_size]
+ img = dr.interpolate(face_normal_chunk.unsqueeze(0), rast, torch.arange(face_normal_chunk.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0]
+ img = (img + 1) / 2
+ elif type == "coord":
+ img = dr.interpolate(vertices, rast, faces_chunk)[0]
+ elif type == "attr":
+ if isinstance(mesh, MeshWithVoxel):
+ if 'grid_sample_3d' not in globals():
+ from flex_gemm.ops.grid_sample import grid_sample_3d
+ mask = rast[..., -1:] > 0
+ xyz = dr.interpolate(vertices, rast, faces_chunk)[0]
+ xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3)
+ img = grid_sample_3d(
+ mesh.attrs,
+ torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1),
+ mesh.voxel_shape,
+ xyz,
+ mode='trilinear'
+ )
+ img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask
+ elif isinstance(mesh, MeshWithPbrMaterial):
+ tri_id = rast[0, :, :, -1:]
+ mask = tri_id > 0
+ uv_coords = mesh.uv_coords.reshape(1, -1, 2)
+ texc, texd = dr.interpolate(
+ uv_coords,
+ rast,
+ torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3),
+ rast_db=rast_db,
+ diff_attrs='all'
+ )
+ # Fix problematic texture coordinates
+ texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3)
+ texc = torch.clamp(texc, min=-1e3, max=1e3)
+ texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3)
+ texd = torch.clamp(texd, min=-1e3, max=1e3)
+ mid = mesh.material_ids[(tri_id - 1).long()]
+ imgs = {
+ 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device),
+ 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device),
+ 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device),
+ 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
+ }
+ for id, mat in enumerate(mesh.materials):
+ mat_mask = (mid == id).float() * mask.float()
+ mat_texc = texc * mat_mask
+ mat_texd = texd * mat_mask
+
+ if mat.base_color_texture is not None:
+ base_color = dr.texture(
+ mat.base_color_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ imgs['base_color'] += base_color * mat.base_color_factor * mat_mask
+ else:
+ imgs['base_color'] += mat.base_color_factor * mat_mask
+
+ if mat.metallic_texture is not None:
+ metallic = dr.texture(
+ mat.metallic_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ imgs['metallic'] += metallic * mat.metallic_factor * mat_mask
+ else:
+ imgs['metallic'] += mat.metallic_factor * mat_mask
+
+ if mat.roughness_texture is not None:
+ roughness = dr.texture(
+ mat.roughness_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ imgs['roughness'] += roughness * mat.roughness_factor * mat_mask
+ else:
+ imgs['roughness'] += mat.roughness_factor * mat_mask
+
+ if mat.alpha_mode == AlphaMode.OPAQUE:
+ imgs['alpha'] += 1.0 * mat_mask
+ else:
+ if mat.alpha_texture is not None:
+ alpha = dr.texture(
+ mat.alpha_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ if mat.alpha_mode == AlphaMode.MASK:
+ imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
+ elif mat.alpha_mode == AlphaMode.BLEND:
+ imgs['alpha'] += alpha * mat.alpha_factor * mat_mask
+ else:
+ if mat.alpha_mode == AlphaMode.MASK:
+ imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
+ elif mat.alpha_mode == AlphaMode.BLEND:
+ imgs['alpha'] += mat.alpha_factor * mat_mask
+
+ img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0)
+ else:
+ img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces_chunk)[0]
+
+ if type not in out_dict:
+ out_dict[type] = img
+ else:
+ out_dict[type][z_filter] = img[z_filter]
+
+ for type in return_types:
+ img = out_dict[type]
+ if ssaa > 1:
+ img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
+ img = img.squeeze()
+ else:
+ img = img.permute(0, 3, 1, 2).squeeze()
+ out_dict[type] = img
+
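+    # Split the packed 'attr' channels back into named maps using the mesh's layout slices.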
+ if isinstance(mesh, (MeshWithVoxel, MeshWithPbrMaterial)) and 'attr' in return_types:
+ for k, s in mesh.layout.items():
+ out_dict[k] = out_dict['attr'][s]
+ del out_dict['attr']
+
+ return out_dict
diff --git a/trellis2/renderers/pbr_mesh_renderer.py b/trellis2/renderers/pbr_mesh_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..14ded300c861e163e15c1bff4c3dd7ee27a08e7c
--- /dev/null
+++ b/trellis2/renderers/pbr_mesh_renderer.py
@@ -0,0 +1,521 @@
+from typing import *
+import torch
+from easydict import EasyDict as edict
+import numpy as np
+import utils3d
+from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode
+import torch.nn.functional as F
+
+
+def cube_to_dir(s, x, y):
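+    # Map 2D grid coordinates (x, y) on cube face index s to an (unnormalized) 3D direction.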
+ if s == 0: rx, ry, rz = torch.ones_like(x), -x, -y
+ elif s == 1: rx, ry, rz = -torch.ones_like(x), x, -y
+ elif s == 2: rx, ry, rz = x, y, torch.ones_like(x)
+ elif s == 3: rx, ry, rz = x, -y, -torch.ones_like(x)
+ elif s == 4: rx, ry, rz = x, torch.ones_like(x), -y
+ elif s == 5: rx, ry, rz = -x, -torch.ones_like(x), -y
+ return torch.stack((rx, ry, rz), dim=-1)
+
+
+def latlong_to_cubemap(latlong_map, res):
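+    # Resample an equirectangular (lat-long) environment map [H, W, C] into a [6, res[0], res[1], C]
+    # cubemap by converting each cubemap texel to a direction and looking up the lat-long texture.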
+ if 'dr' not in globals():
+ import nvdiffrast.torch as dr
+ cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda')
+ for s in range(6):
+ gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'),
+ torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),
+ indexing='ij')
+ v = F.normalize(cube_to_dir(s, gx, gy), dim=-1)
+
+ tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5
+ tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi
+ texcoord = torch.cat((tu, tv), dim=-1)
+
+ cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0]
+ return cubemap
+
+
+class EnvMap:
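+    # Wraps an equirectangular HDR environment map and lazily builds an nvdiffrec EnvironmentLight
+    # (cubemap plus mip chain) the first time it is needed for shading or background sampling.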
+ def __init__(self, image: torch.Tensor):
+ self.image = image
+
+ @property
+ def _backend(self):
+ if not hasattr(self, '_nvdiffrec_envlight'):
+ if 'EnvironmentLight' not in globals():
+ from nvdiffrec_render.light import EnvironmentLight
+ cubemap = latlong_to_cubemap(self.image, [512, 512])
+ self._nvdiffrec_envlight = EnvironmentLight(cubemap)
+ self._nvdiffrec_envlight.build_mips()
+ return self._nvdiffrec_envlight
+
+ def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True):
+ return self._backend.shade(gb_pos, gb_normal, kd, ks, view_pos, specular)
+
+ def sample(self, directions: torch.Tensor):
+ if 'dr' not in globals():
+ import nvdiffrast.torch as dr
+ return dr.texture(
+ self._backend.base.unsqueeze(0),
+ directions.unsqueeze(0),
+ boundary_mode='cube',
+ )[0]
+
+
+def intrinsics_to_projection(
+ intrinsics: torch.Tensor,
+ near: float,
+ far: float,
+ ) -> torch.Tensor:
+ """
+ OpenCV intrinsics to OpenGL perspective matrix
+
+ Args:
+ intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+ near (float): near plane to clip
+ far (float): far plane to clip
+ Returns:
+ (torch.Tensor): [4, 4] OpenGL perspective matrix
+ """
+ fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+ cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+ ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+ ret[0, 0] = 2 * fx
+ ret[1, 1] = 2 * fy
+ ret[0, 2] = 2 * cx - 1
+ ret[1, 2] = - 2 * cy + 1
+ ret[2, 2] = (far + near) / (far - near)
+ ret[2, 3] = 2 * near * far / (near - far)
+ ret[3, 2] = 1.
+ return ret
+
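+# Usage sketch (assuming intrinsics normalized to the unit image plane, hypothetical values):
+#   intrinsics = torch.tensor([[1.5, 0.0, 0.5], [0.0, 1.5, 0.5], [0.0, 0.0, 1.0]], device='cuda')
+#   perspective = intrinsics_to_projection(intrinsics, near=0.1, far=10.0)  # [4, 4] OpenGL projection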
+
+def screen_space_ambient_occlusion(
+ depth: torch.Tensor,
+ normal: torch.Tensor,
+ perspective: torch.Tensor,
+ radius: float = 0.1,
+ bias: float = 1e-6,
+ samples: int = 64,
+ intensity: float = 1.0,
+) -> torch.Tensor:
+ """
+ Screen space ambient occlusion (SSAO)
+
+ Args:
+ depth (torch.Tensor): [H, W, 1] depth image
+ normal (torch.Tensor): [H, W, 3] normal image
+ perspective (torch.Tensor): [4, 4] camera projection matrix
+ radius (float): radius of the SSAO kernel
+ bias (float): bias to avoid self-occlusion
+ samples (int): number of samples to use for the SSAO kernel
+ intensity (float): intensity of the SSAO effect
+ Returns:
+ (torch.Tensor): [H, W, 1] SSAO image
+ """
+ device = depth.device
+ H, W, _ = depth.shape
+
+ fx = perspective[0, 0]
+ fy = perspective[1, 1]
+ cx = perspective[0, 2]
+ cy = perspective[1, 2]
+
+ y_grid, x_grid = torch.meshgrid(
+ (torch.arange(H, device=device) + 0.5) / H * 2 - 1,
+ (torch.arange(W, device=device) + 0.5) / W * 2 - 1,
+ indexing='ij'
+ )
+ x_view = (x_grid.float() - cx) * depth[..., 0] / fx
+ y_view = (y_grid.float() - cy) * depth[..., 0] / fy
+ view_pos = torch.stack([x_view, y_view, depth[..., 0]], dim=-1) # [H, W, 3]
+
+ depth_feat = depth.permute(2, 0, 1).unsqueeze(0)
+ occlusion = torch.zeros((H, W), device=device)
+
+ # start sampling
+ for _ in range(samples):
+ # sample normal distribution, if inside, flip the sign
+ rnd_vec = torch.randn(H, W, 3, device=device)
+ rnd_vec = F.normalize(rnd_vec, p=2, dim=-1)
+ dot_val = torch.sum(rnd_vec * normal, dim=-1, keepdim=True)
+ sample_dir = torch.sign(dot_val) * rnd_vec
+ scale = torch.rand(H, W, 1, device=device)
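+        # Squaring the uniform radius concentrates samples near the shaded point for finer near-field occlusion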
+ scale = scale * scale
+ sample_pos = view_pos + sample_dir * radius * scale
+ sample_z = sample_pos[..., 2]
+
+ # project to screen space
+ z_safe = torch.clamp(sample_pos[..., 2], min=1e-5)
+ proj_u = (sample_pos[..., 0] * fx / z_safe) + cx
+ proj_v = (sample_pos[..., 1] * fy / z_safe) + cy
+ grid = torch.stack([proj_u, proj_v], dim=-1).unsqueeze(0)
+ geo_z = F.grid_sample(depth_feat, grid, mode='nearest', padding_mode='border').squeeze()
+ range_check = torch.abs(geo_z - sample_z) < radius
+ is_occluded = (geo_z <= sample_z - bias) & range_check
+ occlusion += is_occluded.float()
+
+ f_occ = occlusion / samples * intensity
+ f_occ = torch.clamp(f_occ, 0.0, 1.0)
+
+ return f_occ.unsqueeze(-1)
+
+
+def aces_tonemapping(x: torch.Tensor) -> torch.Tensor:
+ """
+ Applies ACES tone mapping curve to an HDR image tensor.
+ Input: x - HDR tensor, shape (..., 3), range [0, +inf)
+ Output: LDR tensor, same shape, range [0, 1]
+
+ NOTE: This causes bright base_color objects to over-expose (appear white)
+ compared to Blender's standard sRGB display transform. Use linear_to_srgb()
+ instead for renders_cond alignment. See pbr_color_diff_debug_guide.md #5.
+ """
+ a = 2.51
+ b = 0.03
+ c = 2.43
+ d = 0.59
+ e = 0.14
+
+ # Apply the ACES fitted curve
+ mapped = (x * (a * x + b)) / (x * (c * x + d) + e)
+
+ # Clamp to [0, 1] for display or saving
+ return torch.clamp(mapped, 0.0, 1.0)
+
+
+def gamma_correction(x: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
+ """
+ Applies gamma correction to an HDR image tensor.
+ """
+ return torch.clamp(x ** (1.0 / gamma), 0.0, 1.0)
+
+
+def linear_to_srgb(x: torch.Tensor) -> torch.Tensor:
+ """
+ Standard linear-to-sRGB conversion matching Blender's sRGB display transform.
+
+ Applies the official sRGB EOTF inverse:
+ - For values <= 0.0031308: sRGB = 12.92 * linear
+ - For values > 0.0031308: sRGB = 1.055 * linear^(1/2.4) - 0.055
+
+ Input: x - linear HDR tensor, clamped to [0, +inf)
+ Output: sRGB tensor, range [0, 1]
+ """
+ x = torch.clamp(x, 0.0)
+ low = 12.92 * x
+ high = 1.055 * torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) - 0.055
+ return torch.clamp(torch.where(x <= 0.0031308, low, high), 0.0, 1.0)
+
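+# Quick sanity check: linear_to_srgb(torch.tensor([0.0, 0.002, 0.5, 1.0])) is roughly [0.0000, 0.0258, 0.7354, 1.0000].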
+
+class PbrMeshRenderer:
+ """
+ Renderer for the PBR mesh.
+
+ Args:
+ rendering_options (dict): Rendering options.
+ """
+ def __init__(self, rendering_options={}, device='cuda'):
+ if 'dr' not in globals():
+ import nvdiffrast.torch as dr
+
+ self.rendering_options = edict({
+ "resolution": None,
+ "near": None,
+ "far": None,
+ "ssaa": 1,
+ "peel_layers": 8,
+ })
+ self.rendering_options.update(rendering_options)
+ self.glctx = dr.RasterizeCudaContext(device=device)
+        self.device = device
+
+ def render(
+ self,
+ mesh : Mesh,
+ extrinsics: torch.Tensor,
+ intrinsics: torch.Tensor,
+ envmap : Union[EnvMap, Dict[str, EnvMap]],
+ use_envmap_bg : bool = False,
+ transformation : Optional[torch.Tensor] = None
+ ) -> edict:
+ """
+ Render the mesh.
+
+ Args:
+            mesh (Mesh): mesh to render
+ extrinsics (torch.Tensor): (4, 4) camera extrinsics
+ intrinsics (torch.Tensor): (3, 3) camera intrinsics
+ envmap (Union[EnvMap, Dict[str, EnvMap]]): environment map or a dictionary of environment maps
+ use_envmap_bg (bool): whether to use envmap as background
+ transformation (torch.Tensor): (4, 4) transformation matrix
+
+ Returns:
+            edict containing (one `shaded_{name}` entry per envmap when a dict of envmaps is given):
+ shaded (torch.Tensor): [3, H, W] shaded color image
+ normal (torch.Tensor): [3, H, W] normal image
+ base_color (torch.Tensor): [3, H, W] base color image
+ metallic (torch.Tensor): [H, W] metallic image
+ roughness (torch.Tensor): [H, W] roughness image
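+                mask (torch.Tensor): [H, W] coverage mask
+                alpha (torch.Tensor): [H, W] alpha image
+                clay (torch.Tensor): [H, W] ambient-occlusion (clay) image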
+ """
+ if 'dr' not in globals():
+ import nvdiffrast.torch as dr
+
+ if not isinstance(envmap, dict):
+ envmap = {'' : envmap}
+ num_envmaps = len(envmap)
+
+ resolution = self.rendering_options["resolution"]
+ near = self.rendering_options["near"]
+ far = self.rendering_options["far"]
+ ssaa = self.rendering_options["ssaa"]
+
+ if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0 or \
+ not torch.isfinite(mesh.vertices).all() or \
+ mesh.faces.max() >= mesh.vertices.shape[0] or mesh.faces.min() < 0:
+ if mesh.vertices.shape[0] > 0 and not torch.isfinite(mesh.vertices).all():
+ print(f"[PbrMeshRenderer] WARNING: mesh vertices contain NaN/Inf, returning blank image.")
+ if mesh.faces.shape[0] > 0 and mesh.vertices.shape[0] > 0 and \
+ (mesh.faces.max() >= mesh.vertices.shape[0] or mesh.faces.min() < 0):
+ print(f"[PbrMeshRenderer] WARNING: mesh faces contain out-of-bound indices "
+ f"(max={mesh.faces.max().item()}, min={mesh.faces.min().item()}, "
+ f"num_vertices={mesh.vertices.shape[0]}), returning blank image.")
+ out_dict = edict(
+ normal=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
+ mask=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+ base_color=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
+ metallic=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+ roughness=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+ alpha=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+ clay=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
+ )
+ for i, k in enumerate(envmap.keys()):
+ shaded_key = f"shaded_{k}" if k != '' else "shaded"
+ out_dict[shaded_key] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device)
+ return out_dict
+
+ rays_o, rays_d = utils3d.torch.get_image_rays(
+ extrinsics, intrinsics, resolution * ssaa, resolution * ssaa
+ )
+
+ perspective = intrinsics_to_projection(intrinsics, near, far)
+
+ full_proj = (perspective @ extrinsics).unsqueeze(0)
+ extrinsics = extrinsics.unsqueeze(0)
+
+ vertices = mesh.vertices.unsqueeze(0)
+ vertices_orig = vertices.clone()
+ vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
+ if transformation is not None:
+ vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2))
+ vertices = vertices_homo[..., :3].contiguous()
+ vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2))
+ vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
+ faces = mesh.faces
+
+ v0 = vertices[0, mesh.faces[:, 0], :3]
+ v1 = vertices[0, mesh.faces[:, 1], :3]
+ v2 = vertices[0, mesh.faces[:, 2], :3]
+ e0 = v1 - v0
+ e1 = v2 - v0
+ face_normal = torch.cross(e0, e1, dim=1)
+ face_normal = F.normalize(face_normal, dim=1)
+
+ out_dict = edict()
+ shaded = torch.zeros((num_envmaps, resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
+ depth = torch.full((resolution * ssaa, resolution * ssaa, 1), 1e10, dtype=torch.float32, device=self.device)
+ normal = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
+ max_w = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
+ alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
+ with dr.DepthPeeler(self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa)) as peeler:
+ for _ in range(self.rendering_options["peel_layers"]):
+ rast, rast_db = peeler.rasterize_next_layer()
+
+ # Pos
+ pos = dr.interpolate(vertices, rast, faces)[0][0]
+
+ # Depth
+ gb_depth = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0][0]
+
+ # Normal
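+                # Flat shading: each face indexes its own normal for all three vertices; normals
+                # pointing away from the camera ray are flipped to face the viewer.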
+ gb_normal = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0][0]
+ gb_normal = torch.where(
+ torch.sum(gb_normal * (pos - rays_o), dim=-1, keepdim=True) > 0,
+ -gb_normal,
+ gb_normal
+ )
+ gb_cam_normal = (extrinsics[..., :3, :3].reshape(1, 1, 3, 3) @ gb_normal.unsqueeze(-1)).squeeze(-1)
+ if _ == 0:
+ out_dict.normal = -gb_cam_normal * 0.5 + 0.5
+ mask = (rast[0, ..., -1:] > 0).float()
+ out_dict.mask = mask
+
+ # PBR attributes
+ if isinstance(mesh, MeshWithVoxel):
+ if 'grid_sample_3d' not in globals():
+ from flex_gemm.ops.grid_sample import grid_sample_3d
+ mask = rast[..., -1:] > 0
+ xyz = dr.interpolate(vertices_orig, rast, faces)[0]
+ xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3)
+ img = grid_sample_3d(
+ mesh.attrs,
+ torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1),
+ mesh.voxel_shape,
+ xyz,
+ mode='trilinear'
+ )
+ img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask
+ gb_basecolor = img[0, ..., mesh.layout['base_color']]
+ gb_metallic = img[0, ..., mesh.layout['metallic']]
+ gb_roughness = img[0, ..., mesh.layout['roughness']]
+ gb_alpha = img[0, ..., mesh.layout['alpha']]
+ elif isinstance(mesh, MeshWithPbrMaterial):
+ tri_id = rast[0, :, :, -1:]
+ mask = tri_id > 0
+ uv_coords = mesh.uv_coords.reshape(1, -1, 2)
+ texc, texd = dr.interpolate(
+ uv_coords,
+ rast,
+ torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3),
+ rast_db=rast_db,
+ diff_attrs='all'
+ )
+ # Fix problematic texture coordinates
+ texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3)
+ texc = torch.clamp(texc, min=-1e3, max=1e3)
+ texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3)
+ texd = torch.clamp(texd, min=-1e3, max=1e3)
+ mid = mesh.material_ids[(tri_id - 1).long()]
+ gb_basecolor = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
+ gb_metallic = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
+ gb_roughness = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
+ gb_alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
+ for id, mat in enumerate(mesh.materials):
+ mat_mask = (mid == id).float() * mask.float()
+ mat_texc = texc * mat_mask
+ mat_texd = texd * mat_mask
+
+ if mat.base_color_texture is not None:
+ bc = dr.texture(
+ mat.base_color_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ gb_basecolor += bc * mat.base_color_factor * mat_mask
+ else:
+ gb_basecolor += mat.base_color_factor * mat_mask
+
+ if mat.metallic_texture is not None:
+ m = dr.texture(
+ mat.metallic_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ gb_metallic += m * mat.metallic_factor * mat_mask
+ else:
+ gb_metallic += mat.metallic_factor * mat_mask
+
+ if mat.roughness_texture is not None:
+ r = dr.texture(
+ mat.roughness_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ gb_roughness += r * mat.roughness_factor * mat_mask
+ else:
+ gb_roughness += mat.roughness_factor * mat_mask
+
+ if mat.alpha_mode == AlphaMode.OPAQUE:
+ gb_alpha += 1.0 * mat_mask
+ else:
+ if mat.alpha_texture is not None:
+ a = dr.texture(
+ mat.alpha_texture.image.unsqueeze(0),
+ mat_texc,
+ mat_texd,
+ filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
+ boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
+ )[0]
+ if mat.alpha_mode == AlphaMode.MASK:
+ gb_alpha += (a * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
+ elif mat.alpha_mode == AlphaMode.BLEND:
+ gb_alpha += a * mat.alpha_factor * mat_mask
+ else:
+ if mat.alpha_mode == AlphaMode.MASK:
+ gb_alpha += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
+ elif mat.alpha_mode == AlphaMode.BLEND:
+ gb_alpha += mat.alpha_factor * mat_mask
+ if _ == 0:
+ out_dict.base_color = gb_basecolor
+ out_dict.metallic = gb_metallic
+ out_dict.roughness = gb_roughness
+ out_dict.alpha = gb_alpha
+
+ # Shading
+ #TODO
+ gb_basecolor = torch.clamp(gb_basecolor, 0.0, 1.0) ** 2.2
+ # gb_basecolor = torch.clamp(gb_basecolor, 0.0, 1.0)
+ gb_metallic = torch.clamp(gb_metallic, 0.0, 1.0)
+ gb_roughness = torch.clamp(gb_roughness, 0.0, 1.0)
+ gb_alpha = torch.clamp(gb_alpha, 0.0, 1.0)
+ gb_orm = torch.cat([
+ torch.zeros_like(gb_metallic),
+ gb_roughness,
+ gb_metallic,
+ ], dim=-1)
+ gb_shaded = torch.stack([
+ e.shade(
+ pos.unsqueeze(0),
+ gb_normal.unsqueeze(0),
+ gb_basecolor.unsqueeze(0),
+ gb_orm.unsqueeze(0),
+ rays_o,
+ specular=True,
+ )[0]
+ for e in envmap.values()
+ ], dim=0)
+
+ # Compositing
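+                # Front-to-back compositing across depth-peeled layers: w is this layer's visible
+                # contribution, (1 - accumulated alpha) * layer alpha.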
+ w = (1 - alpha) * gb_alpha
+ depth = torch.where(w > max_w, gb_depth, depth)
+ normal = torch.where(w > max_w, gb_cam_normal, normal)
+ max_w = torch.maximum(max_w, w)
+ shaded += w * gb_shaded
+ alpha += w
+
+        # Ambient occlusion
+ f_occ = screen_space_ambient_occlusion(
+ depth, normal, perspective, intensity=1.5
+ )
+ shaded *= (1 - f_occ)
+ out_dict.clay = (1 - f_occ)
+
+ # Background
+ if use_envmap_bg:
+ bg = torch.stack([e.sample(rays_d) for e in envmap.values()], dim=0)
+ shaded += (1 - alpha) * bg
+
+ for i, k in enumerate(envmap.keys()):
+ shaded_key = f"shaded_{k}" if k != '' else "shaded"
+ out_dict[shaded_key] = shaded[i]
+
+ # SSAA
+ for k in out_dict.keys():
+ if ssaa > 1:
+ out_dict[k] = F.interpolate(out_dict[k].unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
+ else:
+ out_dict[k] = out_dict[k].permute(2, 0, 1)
+ out_dict[k] = out_dict[k].squeeze()
+
+ # Post processing: linear → sRGB (matches Blender's display transform)
+ for k in envmap.keys():
+ shaded_key = f"shaded_{k}" if k != '' else "shaded"
+ out_dict[shaded_key] = linear_to_srgb(out_dict[shaded_key])
+
+ return out_dict
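+
+
+# Usage sketch (hypothetical values; the envmap image is a [H, W, 3] lat-long HDR tensor on CUDA):
+#   renderer = PbrMeshRenderer({'resolution': 512, 'near': 0.1, 'far': 10.0, 'ssaa': 2})
+#   out = renderer.render(mesh, extrinsics, intrinsics, EnvMap(hdr_image))
+#   shaded, normal, mask = out['shaded'], out['normal'], out['mask']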
diff --git a/trellis2/renderers/voxel_renderer.py b/trellis2/renderers/voxel_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfe28ad8d341ed62c1d7a5ab739fed6cace30a5f
--- /dev/null
+++ b/trellis2/renderers/voxel_renderer.py
@@ -0,0 +1,68 @@
+import torch
+from easydict import EasyDict as edict
+from ..representations import Voxel
+
+
+class VoxelRenderer:
+ """
+ Renderer for the Voxel representation.
+
+ Args:
+ rendering_options (dict): Rendering options.
+ """
+
+ def __init__(self, rendering_options={}) -> None:
+ self.rendering_options = edict({
+ "resolution": None,
+ "near": 0.1,
+ "far": 10.0,
+ "ssaa": 1,
+ })
+ self.rendering_options.update(rendering_options)
+
+ def render(
+ self,
+ voxel: Voxel,
+ extrinsics: torch.Tensor,
+ intrinsics: torch.Tensor,
+ colors_overwrite: torch.Tensor = None
+ ) -> edict:
+ """
+        Render the voxel representation.
+
+ Args:
+ voxel (Voxel): Voxel representation.
+ extrinsics (torch.Tensor): (4, 4) camera extrinsics
+ intrinsics (torch.Tensor): (3, 3) camera intrinsics
+ colors_overwrite (torch.Tensor): (N, 3) override color
+
+ Returns:
+ edict containing:
+ color (torch.Tensor): (3, H, W) rendered color image
+ depth (torch.Tensor): (H, W) rendered depth
+ alpha (torch.Tensor): (H, W) rendered alpha
+ ...
+ """
+ # lazy import
+ if 'o_voxel' not in globals():
+ import o_voxel
+ renderer = o_voxel.rasterize.VoxelRenderer(self.rendering_options)
+ positions = voxel.position
+ attrs = voxel.attrs if colors_overwrite is None else colors_overwrite
+ voxel_size = voxel.voxel_size
+
+ # Render
+ render_ret = renderer.render(positions, attrs, voxel_size, extrinsics, intrinsics)
+
+ ret = {
+ 'depth': render_ret['depth'],
+ 'alpha': render_ret['alpha'],
+ }
+ if colors_overwrite is not None:
+ ret['color'] = render_ret['attr']
+ else:
+ for k, s in voxel.layout.items():
+ ret[k] = render_ret['attr'][s]
+
+ return ret
diff --git a/trellis2/representations/__init__.py b/trellis2/representations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e7d9299f866c344e81e27f748f837e5ce81ed8b
--- /dev/null
+++ b/trellis2/representations/__init__.py
@@ -0,0 +1,31 @@
+import importlib
+
+__attributes = {
+ 'Mesh': 'mesh',
+ 'Voxel': 'voxel',
+ 'MeshWithVoxel': 'mesh',
+ 'MeshWithPbrMaterial': 'mesh',
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
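+# PEP 562 lazy loading: attributes are resolved on first access so heavy submodules are only
+# imported when actually used.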
+def __getattr__(name):
+ if name not in globals():
+ if name in __attributes:
+ module_name = __attributes[name]
+ module = importlib.import_module(f".{module_name}", __name__)
+ globals()[name] = getattr(module, name)
+ elif name in __submodules:
+ module = importlib.import_module(f".{name}", __name__)
+ globals()[name] = module
+ else:
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+ return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+ from .mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial
+ from .voxel import Voxel
diff --git a/trellis2/representations/mesh/__init__.py b/trellis2/representations/mesh/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff4c99a97f764a7c695d87d6a60bd03d61e2106
--- /dev/null
+++ b/trellis2/representations/mesh/__init__.py
@@ -0,0 +1 @@
+from .base import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
diff --git a/trellis2/representations/mesh/base.py b/trellis2/representations/mesh/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bdb47fe89047233628b44900124723770e66502
--- /dev/null
+++ b/trellis2/representations/mesh/base.py
@@ -0,0 +1,234 @@
+from typing import *
+import torch
+from ..voxel import Voxel
+import cumesh
+from flex_gemm.ops.grid_sample import grid_sample_3d
+
+
+class Mesh:
+ def __init__(self,
+ vertices,
+ faces,
+ vertex_attrs=None
+ ):
+ self.vertices = vertices.float()
+ self.faces = faces.int()
+ self.vertex_attrs = vertex_attrs
+
+ @property
+ def device(self):
+ return self.vertices.device
+
+ def to(self, device, non_blocking=False):
+ return Mesh(
+ self.vertices.to(device, non_blocking=non_blocking),
+ self.faces.to(device, non_blocking=non_blocking),
+ self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
+ )
+
+ def cuda(self, non_blocking=False):
+ return self.to('cuda', non_blocking=non_blocking)
+
+ def cpu(self):
+ return self.to('cpu')
+
+ def fill_holes(self, max_hole_perimeter=3e-2):
+ vertices = self.vertices.clone().cuda().contiguous()
+ faces = self.faces.clone().cuda().contiguous()
+
+ mesh = cumesh.CuMesh()
+ mesh.init(vertices, faces)
+ mesh.get_edges()
+ mesh.get_boundary_info()
+ if mesh.num_boundaries == 0:
+ return
+ mesh.get_vertex_edge_adjacency()
+ mesh.get_vertex_boundary_adjacency()
+ mesh.get_manifold_boundary_adjacency()
+ mesh.read_manifold_boundary_adjacency()
+ mesh.get_boundary_connected_components()
+ mesh.get_boundary_loops()
+ if mesh.num_boundary_loops == 0:
+ return
+ mesh.fill_holes(max_hole_perimeter=max_hole_perimeter)
+ new_vertices, new_faces = mesh.read()
+
+ self.vertices = new_vertices.to(self.device)
+ self.faces = new_faces.to(self.device)
+
+ def remove_faces(self, face_mask: torch.Tensor):
+ vertices = self.vertices.clone().cuda().contiguous()
+ faces = self.faces.clone().cuda().contiguous()
+
+ mesh = cumesh.CuMesh()
+ mesh.init(vertices, faces)
+ mesh.remove_faces(face_mask)
+ new_vertices, new_faces = mesh.read()
+
+ self.vertices = new_vertices.to(self.device)
+ self.faces = new_faces.to(self.device)
+
+ def simplify(self, target=1000000, verbose: bool=False, options: dict={}):
+ vertices = self.vertices.clone().cuda().contiguous()
+ faces = self.faces.clone().cuda().contiguous()
+
+ mesh = cumesh.CuMesh()
+ mesh.init(vertices, faces)
+ mesh.simplify(target, verbose=verbose, options=options)
+ new_vertices, new_faces = mesh.read()
+
+ self.vertices = new_vertices.to(self.device)
+ self.faces = new_faces.to(self.device)
+
+
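+# Enumerations mirroring glTF sampler filter/wrap modes and material alpha modes.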
+class TextureFilterMode:
+ CLOSEST = 0
+ LINEAR = 1
+
+
+class TextureWrapMode:
+ CLAMP_TO_EDGE = 0
+ REPEAT = 1
+ MIRRORED_REPEAT = 2
+
+
+class AlphaMode:
+ OPAQUE = 0
+ MASK = 1
+ BLEND = 2
+
+
+class Texture:
+ def __init__(
+ self,
+ image: torch.Tensor,
+ filter_mode: TextureFilterMode = TextureFilterMode.LINEAR,
+ wrap_mode: TextureWrapMode = TextureWrapMode.REPEAT
+ ):
+ self.image = image
+ self.filter_mode = filter_mode
+ self.wrap_mode = wrap_mode
+
+ def to(self, device, non_blocking=False):
+ return Texture(
+ self.image.to(device, non_blocking=non_blocking),
+ self.filter_mode,
+ self.wrap_mode,
+ )
+
+
+class PbrMaterial:
+ def __init__(
+ self,
+ base_color_texture: Optional[Texture] = None,
+ base_color_factor: Union[torch.Tensor, List[float]] = [1.0, 1.0, 1.0],
+ metallic_texture: Optional[Texture] = None,
+ metallic_factor: float = 1.0,
+ roughness_texture: Optional[Texture] = None,
+ roughness_factor: float = 1.0,
+ alpha_texture: Optional[Texture] = None,
+ alpha_factor: float = 1.0,
+ alpha_mode: AlphaMode = AlphaMode.OPAQUE,
+ alpha_cutoff: float = 0.5,
+ ):
+ self.base_color_texture = base_color_texture
+ self.base_color_factor = torch.tensor(base_color_factor, dtype=torch.float32)[:3]
+ self.metallic_texture = metallic_texture
+ self.metallic_factor = metallic_factor
+ self.roughness_texture = roughness_texture
+ self.roughness_factor = roughness_factor
+ self.alpha_texture = alpha_texture
+ self.alpha_factor = alpha_factor
+ self.alpha_mode = alpha_mode
+ self.alpha_cutoff = alpha_cutoff
+
+ def to(self, device, non_blocking=False):
+ return PbrMaterial(
+ base_color_texture=self.base_color_texture.to(device, non_blocking=non_blocking) if self.base_color_texture is not None else None,
+ base_color_factor=self.base_color_factor.to(device, non_blocking=non_blocking),
+ metallic_texture=self.metallic_texture.to(device, non_blocking=non_blocking) if self.metallic_texture is not None else None,
+ metallic_factor=self.metallic_factor,
+ roughness_texture=self.roughness_texture.to(device, non_blocking=non_blocking) if self.roughness_texture is not None else None,
+ roughness_factor=self.roughness_factor,
+ alpha_texture=self.alpha_texture.to(device, non_blocking=non_blocking) if self.alpha_texture is not None else None,
+ alpha_factor=self.alpha_factor,
+ alpha_mode=self.alpha_mode,
+ alpha_cutoff=self.alpha_cutoff,
+ )
+
+
+class MeshWithPbrMaterial(Mesh):
+ def __init__(self,
+ vertices,
+ faces,
+ material_ids,
+ uv_coords,
+ materials: List[PbrMaterial],
+ ):
+ self.vertices = vertices.float()
+ self.faces = faces.int()
+ self.material_ids = material_ids # [M]
+ self.uv_coords = uv_coords # [M, 3, 2]
+ self.materials = materials
+ self.layout = {
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+
+ def to(self, device, non_blocking=False):
+ return MeshWithPbrMaterial(
+ self.vertices.to(device, non_blocking=non_blocking),
+ self.faces.to(device, non_blocking=non_blocking),
+ self.material_ids.to(device, non_blocking=non_blocking),
+ self.uv_coords.to(device, non_blocking=non_blocking),
+ [material.to(device, non_blocking=non_blocking) for material in self.materials],
+ )
+
+
+class MeshWithVoxel(Mesh, Voxel):
+ def __init__(self,
+ vertices: torch.Tensor,
+ faces: torch.Tensor,
+ origin: list,
+ voxel_size: float,
+ coords: torch.Tensor,
+ attrs: torch.Tensor,
+ voxel_shape: torch.Size,
+ layout: Dict = {},
+ ):
+ self.vertices = vertices.float()
+ self.faces = faces.int()
+ self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
+ self.voxel_size = voxel_size
+ self.coords = coords
+ self.attrs = attrs
+ self.voxel_shape = voxel_shape
+ self.layout = layout
+
+ def to(self, device, non_blocking=False):
+ return MeshWithVoxel(
+ self.vertices.to(device, non_blocking=non_blocking),
+ self.faces.to(device, non_blocking=non_blocking),
+ self.origin.tolist(),
+ self.voxel_size,
+ self.coords.to(device, non_blocking=non_blocking),
+ self.attrs.to(device, non_blocking=non_blocking),
+ self.voxel_shape,
+ self.layout,
+ )
+
+ def query_attrs(self, xyz):
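+        # Trilinearly sample the sparse voxel attribute grid at world-space points xyz.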
+ grid = ((xyz - self.origin) / self.voxel_size).reshape(1, -1, 3)
+ vertex_attrs = grid_sample_3d(
+ self.attrs,
+ torch.cat([torch.zeros_like(self.coords[..., :1]), self.coords], dim=-1),
+ self.voxel_shape,
+ grid,
+ mode='trilinear'
+ )[0]
+ return vertex_attrs
+
+ def query_vertex_attrs(self):
+ return self.query_attrs(self.vertices)
diff --git a/trellis2/representations/voxel/__init__.py b/trellis2/representations/voxel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5792ea14b2371a96c4130371eb976f0aff4b5dd
--- /dev/null
+++ b/trellis2/representations/voxel/__init__.py
@@ -0,0 +1 @@
+from .voxel_model import Voxel
\ No newline at end of file
diff --git a/trellis2/representations/voxel/voxel_model.py b/trellis2/representations/voxel/voxel_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9317ab22db61c5da96514ca46a780378392f3cc8
--- /dev/null
+++ b/trellis2/representations/voxel/voxel_model.py
@@ -0,0 +1,54 @@
+from typing import Dict
+import torch
+
+
+class Voxel:
+ def __init__(
+ self,
+ origin: list,
+ voxel_size: float,
+ coords: torch.Tensor = None,
+ attrs: torch.Tensor = None,
+ layout: Dict = {},
+ device: torch.device = 'cuda'
+ ):
+ self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
+ self.voxel_size = voxel_size
+ self.coords = coords
+ self.attrs = attrs
+ self.layout = layout
+ self.device = device
+
+ @property
+ def position(self):
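+        # World-space voxel centers: (integer coords + 0.5) scaled by voxel_size and offset by origin.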
+ return (self.coords + 0.5) * self.voxel_size + self.origin[None, :]
+
+ def split_attrs(self):
+ return {
+ k: self.attrs[:, self.layout[k]]
+ for k in self.layout
+ }
+
+ def save(self, path):
+ # lazy import
+ if 'o_voxel' not in globals():
+ import o_voxel
+ o_voxel.io.write(
+ path,
+ self.coords,
+ self.split_attrs(),
+ )
+
+ def load(self, path):
+ # lazy import
+ if 'o_voxel' not in globals():
+ import o_voxel
+ coord, attrs = o_voxel.io.read(path)
+ self.coords = coord.int().to(self.device)
+ self.attrs = torch.cat([attrs[k] for k in attrs], dim=1).to(self.device)
+ # build layout
+ start = 0
+ self.layout = {}
+ for k in attrs:
+ self.layout[k] = slice(start, start + attrs[k].shape[1])
+ start += attrs[k].shape[1]
diff --git a/trellis2/trainers/__init__.py b/trellis2/trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a107517e58c4582b6de4150b507012077700010
--- /dev/null
+++ b/trellis2/trainers/__init__.py
@@ -0,0 +1,74 @@
+import importlib
+
+__attributes = {
+ 'BasicTrainer': 'basic',
+
+ 'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
+ 'ShapeVaeTrainer': 'vae.shape_vae',
+ 'PbrVaeTrainer': 'vae.pbr_vae',
+
+ 'FlowMatchingTrainer': 'flow_matching.flow_matching',
+ 'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
+ 'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
+ 'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
+ 'ImageConditionedProjFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
+
+ 'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
+ 'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
+ 'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
+ 'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
+ 'MultiImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
+ 'ImageConditionedProjSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
+
+ 'DinoV2FeatureExtractor': 'flow_matching.mixins.image_conditioned',
+ 'DinoV3FeatureExtractor': 'flow_matching.mixins.image_conditioned',
+ 'DinoV3ProjFeatureExtractor': 'flow_matching.mixins.image_conditioned_proj',
+ 'DinoV3VaeProjFeatureExtractor': 'flow_matching.mixins.image_conditioned_proj',
+ 'ImageConditionedProjMixin': 'flow_matching.mixins.image_conditioned_proj',
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+ if name not in globals():
+ if name in __attributes:
+ module_name = __attributes[name]
+ module = importlib.import_module(f".{module_name}", __name__)
+ globals()[name] = getattr(module, name)
+ elif name in __submodules:
+ module = importlib.import_module(f".{name}", __name__)
+ globals()[name] = module
+ else:
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+ return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+ from .basic import BasicTrainer
+
+ from .vae.sparse_structure_vae import SparseStructureVaeTrainer
+ from .vae.shape_vae import ShapeVaeTrainer
+ from .vae.pbr_vae import PbrVaeTrainer
+
+ from .flow_matching.flow_matching import (
+ FlowMatchingTrainer,
+ FlowMatchingCFGTrainer,
+ TextConditionedFlowMatchingCFGTrainer,
+ ImageConditionedFlowMatchingCFGTrainer,
+ ImageConditionedProjFlowMatchingCFGTrainer,
+ )
+
+ from .flow_matching.sparse_flow_matching import (
+ SparseFlowMatchingTrainer,
+ SparseFlowMatchingCFGTrainer,
+ TextConditionedSparseFlowMatchingCFGTrainer,
+        ImageConditionedSparseFlowMatchingCFGTrainer,
+        MultiImageConditionedSparseFlowMatchingCFGTrainer,
+        ImageConditionedProjSparseFlowMatchingCFGTrainer,
+    )
+
+    from .flow_matching.mixins.image_conditioned import (
+        DinoV2FeatureExtractor,
+        DinoV3FeatureExtractor,
+    )
+
+    from .flow_matching.mixins.image_conditioned_proj import (
+        DinoV3ProjFeatureExtractor,
+        DinoV3VaeProjFeatureExtractor,
+        ImageConditionedProjMixin,
+    )
diff --git a/trellis2/trainers/basic.py b/trellis2/trainers/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ac701b32d7376a8bdcb66415d30d0c7890f1819
--- /dev/null
+++ b/trellis2/trainers/basic.py
@@ -0,0 +1,1293 @@
+from abc import abstractmethod
+import os
+import time
+import json
+import copy
+import threading
+from functools import partial
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+from torch.nn.parallel import DistributedDataParallel as DDP
+import numpy as np
+
+from torchvision import utils
+
+try:
+ import wandb
+ WANDB_AVAILABLE = True
+except ImportError:
+ WANDB_AVAILABLE = False
+
+from .utils import *
+from ..utils.general_utils import *
+from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
+from ..utils.dist_utils import *
+from ..utils import grad_clip_utils, elastic_utils
+
+
+class BasicTrainer:
+ """
+ Trainer for basic training loop.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ mix_precision_mode (str):
+ - None: No mixed precision.
+            - 'inflat_all': Hold an inflated fp32 master copy of all params.
+ - 'amp': Automatic mixed precision.
+ mix_precision_dtype (str): Mixed precision dtype.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ parallel_mode (str): Parallel mode. Options are 'ddp'.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+ """
+ def __init__(self,
+ models,
+ dataset,
+ *,
+ output_dir,
+ load_dir,
+ step,
+ max_steps,
+ batch_size=None,
+ batch_size_per_gpu=None,
+ batch_split=None,
+ optimizer={},
+ lr_scheduler=None,
+ elastic=None,
+ grad_clip=None,
+ ema_rate=0.9999,
+ fp16_mode=None,
+ mix_precision_mode='inflat_all',
+ mix_precision_dtype='float16',
+ fp16_scale_growth=1e-3,
+ parallel_mode='ddp',
+ finetune_ckpt=None,
+ log_param_stats=False,
+ prefetch_data=True,
+ snapshot_batch_size=4,
+ snapshot_num_samples=64,
+ num_workers=None,
+ debug=False,
+ i_print=1000,
+ i_log=500,
+ i_sample=10000,
+ i_save=10000,
+ i_ddpcheck=10000,
+ wandb_run=None, # wandb run object
+ **kwargs
+ ):
+ assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
+
+ self.models = models
+ self.dataset = dataset
+ self.batch_split = batch_split if batch_split is not None else 1
+ self.max_steps = max_steps
+ self.debug = debug
+ self.optimizer_config = optimizer
+ self.lr_scheduler_config = lr_scheduler
+ self.elastic_controller_config = elastic
+ self.grad_clip = grad_clip
+ self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
+ if fp16_mode is not None:
+ mix_precision_dtype = 'float16'
+ mix_precision_mode = fp16_mode
+ self.mix_precision_mode = mix_precision_mode
+ self.mix_precision_dtype = str_to_dtype(mix_precision_dtype)
+ self.fp16_scale_growth = fp16_scale_growth
+ self.parallel_mode = parallel_mode
+ self.log_param_stats = log_param_stats
+ self.prefetch_data = prefetch_data
+ self.snapshot_batch_size = snapshot_batch_size
+ self.snapshot_num_samples = snapshot_num_samples
+ self.num_workers = num_workers
+ self.log = []
+ if self.prefetch_data:
+ self._data_prefetched = None
+
+ self.output_dir = output_dir
+ from datetime import datetime
+ self._log_file = os.path.join(self.output_dir, f'log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt')
+ self.i_print = i_print
+ self.i_log = i_log
+ self.i_sample = i_sample
+ self.i_save = i_save
+ self.i_ddpcheck = i_ddpcheck
+
+ if dist.is_initialized():
+ # Multi-GPU params
+ self.world_size = dist.get_world_size()
+ self.rank = dist.get_rank()
+ self.local_rank = dist.get_rank() % torch.cuda.device_count()
+ self.is_master = self.rank == 0
+ else:
+ # Single-GPU params
+ self.world_size = 1
+ self.rank = 0
+ self.local_rank = 0
+ self.is_master = True
+
+ self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
+ self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
+ assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
+ assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
+
+ self.init_models_and_more(**kwargs)
+ self.prepare_dataloader(**kwargs)
+
+ # Load checkpoint
+ self.step = 0
+ if load_dir is not None and step is not None:
+ self.load(load_dir, step)
+ elif finetune_ckpt is not None:
+ self.finetune_from(finetune_ckpt)
+
+ if self.is_master:
+ os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
+ os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
+ self.writer = None # TensorBoard disabled (S3 FUSE does not support append)
+ # Initialize wandb
+ self.wandb_run = wandb_run
+ if self.wandb_run is not None:
+ print(f'Wandb logging enabled: {self.wandb_run.url}')
+
+ if self.parallel_mode == 'ddp' and self.world_size > 1:
+ self.check_ddp()
+
+ if self.is_master:
+ print('\n\nTrainer initialized.')
+ print(self)
+
+ def __str__(self):
+ lines = []
+ lines.append(self.__class__.__name__)
+ lines.append(f' - Models:')
+ for name, model in self.models.items():
+ lines.append(f' - {name}: {model.__class__.__name__}')
+ lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
+ lines.append(f' - Dataloader:')
+ lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
+ lines.append(f' - Num workers: {self.dataloader.num_workers}')
+ lines.append(f' - Number of steps: {self.max_steps}')
+ lines.append(f' - Number of GPUs: {self.world_size}')
+ lines.append(f' - Batch size: {self.batch_size}')
+ lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
+ lines.append(f' - Batch split: {self.batch_split}')
+ lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
+ lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
+ if self.lr_scheduler_config is not None:
+ lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
+ if self.elastic_controller_config is not None:
+ lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
+ if self.grad_clip is not None:
+ lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
+ lines.append(f' - EMA rate: {self.ema_rate}')
+ lines.append(f' - Mixed precision dtype: {self.mix_precision_dtype}')
+ lines.append(f' - Mixed precision mode: {self.mix_precision_mode}')
+ if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
+ lines.append(f' - FP16 scale growth: {self.fp16_scale_growth}')
+ lines.append(f' - Parallel mode: {self.parallel_mode}')
+ return '\n'.join(lines)
+
+ @property
+ def device(self):
+ for _, model in self.models.items():
+ if hasattr(model, 'device'):
+ return model.device
+ return next(list(self.models.values())[0].parameters()).device
+
+ def init_models_and_more(self, **kwargs):
+ """
+ Initialize models and more.
+ """
+ if self.world_size > 1:
+ # Prepare distributed data parallel
+ self.training_models = {
+ name: DDP(
+ model,
+ device_ids=[self.local_rank],
+ output_device=self.local_rank,
+ bucket_cap_mb=128,
+ find_unused_parameters=False
+ )
+ for name, model in self.models.items()
+ }
+ else:
+ self.training_models = self.models
+
+ # Build master params
+ self.model_params = sum(
+ [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
+ , [])
+ if self.mix_precision_mode == 'amp':
+ self.master_params = self.model_params
+ if self.mix_precision_dtype == torch.float16:
+ self.scaler = torch.GradScaler()
+ elif self.mix_precision_mode == 'inflat_all':
+ self.master_params = make_master_params(self.model_params)
+ if self.mix_precision_dtype == torch.float16:
+ self.log_scale = 20.0
+ elif self.mix_precision_mode is None:
+ self.master_params = self.model_params
+ else:
+ raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.')
+
+ # Build EMA params
+ if self.is_master:
+ self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
+
+ # Initialize optimizer
+ if hasattr(torch.optim, self.optimizer_config['name']):
+ self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
+ else:
+ self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
+
+        # Initialize learning rate scheduler
+ if self.lr_scheduler_config is not None:
+ if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
+ self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
+ else:
+ self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
+
+ # Initialize elastic memory controller
+ if self.elastic_controller_config is not None:
+ assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
+ 'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
+ self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
+ for model in self.models.values():
+ if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
+ model.register_memory_controller(self.elastic_controller)
+
+ # Initialize gradient clipper
+ if self.grad_clip is not None:
+ if isinstance(self.grad_clip, (float, int)):
+ self.grad_clip = float(self.grad_clip)
+ else:
+ self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
+
+ def prepare_dataloader(self, **kwargs):
+ """
+ Prepare dataloader.
+ """
+ self.data_sampler = ResumableSampler(
+ self.dataset,
+ shuffle=True,
+ )
+ if self.num_workers is None or self.num_workers == -1:
+ num_workers = max(1, int(np.ceil((os.cpu_count() - 16) / torch.cuda.device_count())))
+ else:
+ num_workers = self.num_workers
+
+ self.dataloader = DataLoader(
+ self.dataset,
+ batch_size=self.batch_size_per_gpu,
+ num_workers=num_workers,
+ pin_memory=True,
+ drop_last=True,
+ persistent_workers=True,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ sampler=self.data_sampler,
+ )
+ self.data_iterator = cycle(self.dataloader)
+
+ def _master_params_to_state_dicts(self, master_params):
+ """
+ Convert master params to dict of state_dicts.
+ """
+ if self.mix_precision_mode == 'inflat_all':
+ master_params = unflatten_master_params(self.model_params, master_params)
+ state_dicts = {name: model.state_dict() for name, model in self.models.items()}
+ master_params_names = sum(
+ [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
+ , [])
+ for i, (model_name, param_name) in enumerate(master_params_names):
+ state_dicts[model_name][param_name] = master_params[i]
+ return state_dicts
+
+ def _state_dicts_to_master_params(self, master_params, state_dicts):
+ """
+ Convert a state_dict to master params.
+ """
+ master_params_names = sum(
+ [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
+ , [])
+ params = [state_dicts[name][param_name] for name, param_name in master_params_names]
+ if self.mix_precision_mode == 'inflat_all':
+ model_params_to_master_params(params, master_params)
+ else:
+ for i, param in enumerate(params):
+ master_params[i].data.copy_(param.data)
+
+ def load(self, load_dir, step=0):
+ """
+ Load a checkpoint.
+ Should be called by all processes.
+ """
+ if self.is_master:
+ print(f'\nLoading checkpoint from step {step}...', end='')
+
+ model_ckpts = {}
+ for name, model in self.models.items():
+ model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
+ model_ckpts[name] = model_ckpt
+ model.load_state_dict(model_ckpt)
+ self._state_dicts_to_master_params(self.master_params, model_ckpts)
+ del model_ckpts
+
+ if self.is_master:
+ for i, ema_rate in enumerate(self.ema_rate):
+ ema_ckpts = {}
+ for name, model in self.models.items():
+ ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
+ ema_ckpts[name] = ema_ckpt
+ self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
+ del ema_ckpts
+
+ misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
+ self.optimizer.load_state_dict(misc_ckpt['optimizer'])
+ self.step = misc_ckpt['step']
+ self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
+ if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
+ self.scaler.load_state_dict(misc_ckpt['scaler'])
+ elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
+ self.log_scale = misc_ckpt['log_scale']
+ if self.lr_scheduler_config is not None:
+ self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
+ if self.elastic_controller_config is not None:
+ self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
+ if self.grad_clip is not None and not isinstance(self.grad_clip, float):
+ self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
+ del misc_ckpt
+
+ if self.world_size > 1:
+ dist.barrier()
+ if self.is_master:
+ print(' Done.')
+
+ if self.world_size > 1:
+ self.check_ddp()
+
+ def save(self, non_blocking=True):
+ """
+ Save a checkpoint.
+ Should be called only by the rank 0 process.
+ """
+ assert self.is_master, 'save() should be called only by the rank 0 process.'
+ print(f'\nSaving checkpoint at step {self.step}...', end='')
+
+ model_ckpts = self._master_params_to_state_dicts(self.master_params)
+ for name, model_ckpt in model_ckpts.items():
+ model_ckpt = {k: v.cpu() for k, v in model_ckpt.items()} # Move to CPU for saving
+ if non_blocking:
+ threading.Thread(
+ target=torch.save,
+ args=(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')),
+ ).start()
+ else:
+ torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
+
+ for i, ema_rate in enumerate(self.ema_rate):
+ ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
+ for name, ema_ckpt in ema_ckpts.items():
+ ema_ckpt = {k: v.cpu() for k, v in ema_ckpt.items()} # Move to CPU for saving
+ if non_blocking:
+ threading.Thread(
+ target=torch.save,
+ args=(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')),
+ ).start()
+ else:
+ torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
+
+ misc_ckpt = {
+ 'optimizer': self.optimizer.state_dict(),
+ 'step': self.step,
+ 'data_sampler': self.data_sampler.state_dict(),
+ }
+ if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
+ misc_ckpt['scaler'] = self.scaler.state_dict()
+ elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
+ misc_ckpt['log_scale'] = self.log_scale
+ if self.lr_scheduler_config is not None:
+ misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
+ if self.elastic_controller_config is not None:
+ misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
+ if self.grad_clip is not None and not isinstance(self.grad_clip, float):
+ misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
+ if non_blocking:
+ threading.Thread(
+ target=torch.save,
+ args=(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')),
+ ).start()
+ else:
+ torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
+ print(' Done.')
+
+ def _remap_checkpoint_keys(self, model_ckpt, model_state_dict):
+ """
+ Remap checkpoint keys to match model state dict.
+
+ Handles structural changes like:
+ - cross_attn.xxx -> cross_attn.cross_attn_block.xxx (for ProjectAttention wrapper)
+
+ Args:
+ model_ckpt: Checkpoint state dict
+ model_state_dict: Model state dict
+
+ Returns:
+ Remapped checkpoint dict
+ """
+ remapped_ckpt = {}
+ remapped_count = 0
+
+ for ckpt_key, ckpt_value in model_ckpt.items():
+ # Check if key exists directly
+ if ckpt_key in model_state_dict:
+ remapped_ckpt[ckpt_key] = ckpt_value
+ continue
+
+ # Try remapping: cross_attn.xxx -> cross_attn.cross_attn_block.xxx
+ # This handles the case when cross_attn is wrapped by ProjectAttention
+ if '.cross_attn.' in ckpt_key:
+ # Split at .cross_attn.
+ parts = ckpt_key.split('.cross_attn.')
+ if len(parts) == 2:
+ new_key = f'{parts[0]}.cross_attn.cross_attn_block.{parts[1]}'
+ if new_key in model_state_dict:
+ remapped_ckpt[new_key] = ckpt_value
+ remapped_count += 1
+ continue
+
+ # Key not remapped, keep original (will be handled by missing key logic)
+ remapped_ckpt[ckpt_key] = ckpt_value
+
+ if remapped_count > 0 and self.is_master:
+ print(f'Info: Remapped {remapped_count} cross_attn keys to cross_attn.cross_attn_block')
+
+ return remapped_ckpt
+
+ def finetune_from(self, finetune_ckpt):
+ """
+ Finetune from a checkpoint.
+ Should be called by all processes.
+ """
+        # Keys that are allowed to be missing from the checkpoint (e.g., parameters registered via register_buffer)
+ ALLOWED_MISSING_KEYS = {'rope_phases'}
+
+ if self.is_master:
+ print('\nFinetuning from:')
+ for name, path in finetune_ckpt.items():
+ print(f' - {name}: {path}')
+
+ model_ckpts = {}
+ for name, model in self.models.items():
+ model_state_dict = model.state_dict()
+ if name in finetune_ckpt:
+ model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
+
+ # Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper)
+ model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict)
+
+                # Check for unexpected keys (present in the checkpoint but not in the model)
+ for k, v in model_ckpt.items():
+ if k not in model_state_dict:
+ if self.is_master:
+ print(f'Warning: {k} not found in model_state_dict, skipped.')
+ model_ckpt[k] = None
+ elif model_ckpt[k].shape != model_state_dict[k].shape:
+ if self.is_master:
+ print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
+ model_ckpt[k] = model_state_dict[k]
+ model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None}
+
+                # Check for missing keys (present in the model but not in the checkpoint)
+ missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys())
+ unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS
+ if unexpected_missing and self.is_master:
+ print(f'Error: Missing keys in checkpoint: {unexpected_missing}')
+ raise RuntimeError(f'Missing keys in checkpoint: {unexpected_missing}')
+ if missing_keys & ALLOWED_MISSING_KEYS and self.is_master:
+ print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}')
+
+                # Fill in missing keys with the model's initialized values
+ for k in missing_keys:
+ model_ckpt[k] = model_state_dict[k]
+
+ model_ckpts[name] = model_ckpt
+ model.load_state_dict(model_ckpt)
+ else:
+ if self.is_master:
+ print(f'Warning: {name} not found in finetune_ckpt, skipped.')
+ model_ckpts[name] = model_state_dict
+ self._state_dicts_to_master_params(self.master_params, model_ckpts)
+ if self.is_master:
+ for i, ema_rate in enumerate(self.ema_rate):
+ self._state_dicts_to_master_params(self.ema_params[i], model_ckpts)
+ del model_ckpts
+
+ if self.world_size > 1:
+ dist.barrier()
+ if self.is_master:
+ print('Done.')
+
+ if self.world_size > 1:
+ self.check_ddp()
+
+ @abstractmethod
+ def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
+ """
+ Run a snapshot of the model.
+ """
+ pass
+
+ @torch.no_grad()
+ def visualize_sample(self, sample):
+ """
+ Convert a sample to an image.
+ """
+ if hasattr(self.dataset, 'visualize_sample'):
+ return self.dataset.visualize_sample(sample)
+ else:
+ return sample
+
+ @torch.no_grad()
+ def snapshot_dataset(self, num_samples=100, batch_size=4):
+ """
+ Sample images from the dataset.
+ """
+ dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=batch_size,
+ num_workers=0,
+ shuffle=True,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ )
+ save_cfg = {}
+ for i in range(0, num_samples, batch_size):
+ data = next(iter(dataloader))
+ data = {k: v[:min(num_samples - i, batch_size)] for k, v in data.items()}
+ data = recursive_to_device(data, self.device)
+ try:
+ vis = self.visualize_sample(data)
+            except Exception as e:
+ print(f'\033[93m[WARN] snapshot_dataset visualize_sample failed (batch {i}), skipping: {e}\033[0m')
+ torch.cuda.empty_cache()
+ continue
+ if isinstance(vis, dict):
+ for k, v in vis.items():
+ if f'dataset_{k}' not in save_cfg:
+ save_cfg[f'dataset_{k}'] = []
+ save_cfg[f'dataset_{k}'].append(v)
+ else:
+ if 'dataset' not in save_cfg:
+ save_cfg['dataset'] = []
+ save_cfg['dataset'].append(vis)
+ for name, image in save_cfg.items():
+ utils.save_image(
+ torch.cat(image, dim=0),
+ os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
+ nrow=int(np.sqrt(num_samples)),
+ normalize=True,
+ value_range=self.dataset.value_range,
+ )
+
+ @torch.no_grad()
+ def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
+ """
+ Sample images from the model.
+ NOTE: When num_samples >= 4, this function should be called by all processes.
+ When num_samples < 4, only master runs snapshot (other ranks skip via barrier).
+ """
+ # Free cached GPU memory before snapshot to avoid OOM / illegal address errors
+ import gc
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if self.is_master:
+ print(f'\nSampling {num_samples} images...', end='')
+
+ if suffix is None:
+ suffix = f'step{self.step:07d}'
+
+ # When num_samples < 4, only master runs snapshot to avoid multi-rank gather issues
+ master_only = num_samples < 4
+
+ sample_metadata = None # Will hold list of "dataset_name/sha256" strings
+
+ if master_only and self.world_size > 1:
+ if not self.is_master:
+ # Non-master ranks just wait at barrier
+ dist.barrier()
+ return
+
+ # Master runs snapshot alone
+ amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext
+ with amp_context():
+ samples = self.run_snapshot(num_samples, batch_size=batch_size, verbose=verbose)
+
+ # Extract metadata before preprocessing
+ sample_metadata = samples.pop('_metadata', None)
+
+ # Free GPU memory after sampling, before decode + render
+ torch.cuda.empty_cache()
+
+ # Preprocess images
+ for key in list(samples.keys()):
+ if samples[key]['type'] == 'sample':
+ try:
+ vis = self.visualize_sample(samples[key]['value'])
+ except RuntimeError as e:
+ print(f"[Snapshot] WARNING: visualize_sample failed for '{key}': {e}")
+ # Reset CUDA error state and skip this sample
+ try:
+ torch.cuda.synchronize()
+ except RuntimeError:
+ pass
+ torch.cuda.empty_cache()
+ del samples[key]
+ continue
+ if isinstance(vis, dict):
+ for k, v in vis.items():
+ samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
+ del samples[key]
+ else:
+ samples[key] = {'value': vis, 'type': 'image'}
+
+ # No gather needed, master already has all samples
+ dist.barrier()
+ else:
+ # Distribute sampling across all ranks
+ num_samples_per_process = int(np.ceil(num_samples / self.world_size))
+ amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext
+
+ with amp_context():
+ samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
+
+ # Extract metadata before preprocessing
+ local_metadata = samples.pop('_metadata', None)
+
+ # Free GPU memory after sampling, before decode + render
+ torch.cuda.empty_cache()
+
+ # Preprocess images
+ for key in list(samples.keys()):
+ if samples[key]['type'] == 'sample':
+ try:
+ vis = self.visualize_sample(samples[key]['value'])
+ except RuntimeError as e:
+ print(f"[Snapshot] WARNING: visualize_sample failed for '{key}': {e}")
+ torch.cuda.synchronize()
+ del samples[key]
+ continue
+ if isinstance(vis, dict):
+ for k, v in vis.items():
+ samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
+ del samples[key]
+ else:
+ samples[key] = {'value': vis, 'type': 'image'}
+
+ # Gather results
+ if self.world_size > 1:
+ for key in samples.keys():
+ samples[key]['value'] = samples[key]['value'].contiguous()
+ if self.is_master:
+ all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
+ else:
+ all_images = []
+ dist.gather(samples[key]['value'], all_images, dst=0)
+ if self.is_master:
+ samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
+
+ # Gather metadata across ranks
+ if local_metadata is not None:
+ all_metadata = [None] * self.world_size
+ dist.all_gather_object(all_metadata, local_metadata)
+ if self.is_master:
+ sample_metadata = sum(all_metadata, [])[:num_samples]
+ else:
+ sample_metadata = None
+ else:
+ sample_metadata = local_metadata
+
+ # Save images
+ if self.is_master:
+ os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
+ wandb_images = {} # Collect images for wandb logging
+ nrow = int(np.sqrt(num_samples))
+ vr = self.dataset.value_range
+
+ # Build metadata caption string for wandb
+ metadata_caption = ''
+ if sample_metadata:
+ metadata_caption = '\n' + ' | '.join(sample_metadata)
+ # Also save metadata to file
+ with open(os.path.join(self.output_dir, 'samples', suffix, 'metadata.txt'), 'w') as f:
+ for i, m in enumerate(sample_metadata):
+ f.write(f'{i}: {m}\n')
+
+ # Helper: make a normalized grid tensor from a batch of images
+ def _make_grid(tensor):
+ return utils.make_grid(tensor, nrow=nrow, normalize=True, value_range=vr)
+
+ # Helper: resize grid to target height (keep aspect ratio)
+ def _resize_to_height(grid, target_h):
+ import torch.nn.functional as F
+ _, h, w = grid.shape
+ if h == target_h:
+ return grid
+ target_w = int(round(w * target_h / h))
+ return F.interpolate(grid.unsqueeze(0), size=(target_h, target_w), mode='bilinear', align_corners=False).squeeze(0)
+
+ # --- Save individual images (original behavior) ---
+ for key in samples.keys():
+ if samples[key]['type'] == 'image':
+ image_path = os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg')
+ utils.save_image(
+ samples[key]['value'],
+ image_path,
+ nrow=nrow,
+ normalize=True,
+ value_range=vr,
+ )
+ # Collect for wandb
+ if self.wandb_run is not None:
+ grid = _make_grid(samples[key]['value'])
+ grid_np = grid.permute(1, 2, 0).cpu().numpy()
+ grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8)
+ wandb_images[f'samples/{key}'] = wandb.Image(grid_np, caption=f'{key} at step {self.step}{metadata_caption}')
+ elif samples[key]['type'] == 'number':
+ val_min = samples[key]['value'].min()
+ val_max = samples[key]['value'].max()
+ images = (samples[key]['value'] - val_min) / (val_max - val_min)
+ images = utils.make_grid(
+ images,
+ nrow=nrow,
+ normalize=False,
+ )
+ save_image_with_notes(
+ images,
+ os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
+ notes=f'{key} min: {val_min}, max: {val_max}',
+ )
+
+ # --- Save combined images ---
+ sample_keys = set(samples.keys())
+
+ # Combined 1: image + sample_gt_view + sample_gt_gt_view (shape)
+ # image + sample_gt_view_{attr} + sample_gt_gt_view_{attr} (tex, per attribute)
+ # Detect gt_view attribute suffixes from sample keys
+ gt_view_attrs = set()
+ for k in sample_keys:
+ if k.startswith('sample_gt_view_'):
+ attr = k[len('sample_gt_view_'):]
+ gt_view_attrs.add(attr)
+
+ if gt_view_attrs:
+ # Tex mode: generate combined view for each PBR attribute
+ for attr in sorted(gt_view_attrs):
+ combo1_keys = ['image', f'sample_gt_view_{attr}', f'sample_gt_gt_view_{attr}']
+ combo1_present = [k for k in combo1_keys if k in sample_keys and samples[k]['type'] == 'image']
+ if len(combo1_present) >= 2:
+ grids = [_make_grid(samples[k]['value']) for k in combo1_present]
+ target_h = max(g.shape[1] for g in grids)
+ grids = [_resize_to_height(g, target_h) for g in grids]
+ combined = torch.cat(grids, dim=2)
+ combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_views_{attr}_{suffix}.jpg')
+ utils.save_image(combined, combined_path, normalize=False)
+ if self.wandb_run is not None:
+ grid_np = combined.permute(1, 2, 0).cpu().numpy()
+ grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8)
+ label = ' | '.join(combo1_present)
+ wandb_images[f'samples/combined_views_{attr}'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}')
+ else:
+ # Shape mode: single gt_view
+ combo1_keys = ['image', 'sample_gt_view', 'sample_gt_gt_view']
+ combo1_present = [k for k in combo1_keys if k in sample_keys and samples[k]['type'] == 'image']
+ if len(combo1_present) >= 2:
+ grids = [_make_grid(samples[k]['value']) for k in combo1_present]
+ target_h = max(g.shape[1] for g in grids)
+ grids = [_resize_to_height(g, target_h) for g in grids]
+ combined = torch.cat(grids, dim=2)
+ combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_views_{suffix}.jpg')
+ utils.save_image(combined, combined_path, normalize=False)
+ if self.wandb_run is not None:
+ grid_np = combined.permute(1, 2, 0).cpu().numpy()
+ grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8)
+ label = ' | '.join(combo1_present)
+ wandb_images[f'samples/combined_views'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}')
+
+ # Combined 2: sample_multiview + sample_gt_multiview
+ combo2_keys = ['sample_multiview', 'sample_gt_multiview']
+ combo2_present = [k for k in combo2_keys if k in sample_keys and samples[k]['type'] == 'image']
+ if len(combo2_present) >= 2:
+ grids = [_make_grid(samples[k]['value']) for k in combo2_present]
+ target_h = max(g.shape[1] for g in grids)
+ grids = [_resize_to_height(g, target_h) for g in grids]
+ combined = torch.cat(grids, dim=2) # concatenate along width
+ combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_multiview_{suffix}.jpg')
+ utils.save_image(combined, combined_path, normalize=False)
+ if self.wandb_run is not None:
+ grid_np = combined.permute(1, 2, 0).cpu().numpy()
+ grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8)
+ label = ' | '.join(combo2_present)
+ wandb_images[f'samples/combined_multiview'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}')
+
+ # Log images to wandb
+ if self.wandb_run is not None and wandb_images:
+ self.wandb_run.log(wandb_images, step=self.step)
+
+ if self.is_master:
+ print(' Done.')
+
+ def update_ema(self):
+ """
+ Update exponential moving average.
+ Should only be called by the rank 0 process.
+ """
+ assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
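+ # Each pass of the loop below applies the standard EMA update
+ # ema = ema_rate * ema + (1 - ema_rate) * master_param, once per tracked rate.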
+ for i, ema_rate in enumerate(self.ema_rate):
+ for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
+ ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
+
+ def check_ddp(self):
+ """
+ Check if DDP is working properly.
+ Should be called by all processes.
+ """
+ if self.is_master:
+ print('\nPerforming DDP check...')
+
+ if self.is_master:
+ print('Checking if parameters are consistent across processes...')
+ dist.barrier()
+ try:
+ for p in self.master_params:
+ # split to avoid OOM
+ for i in range(0, p.numel(), 10000000):
+ sub_size = min(10000000, p.numel() - i)
+ sub_p = p.detach().view(-1)[i:i+sub_size]
+ # gather from all processes
+ sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
+ dist.all_gather(sub_p_gather, sub_p)
+ # check if equal
+ assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
+ except AssertionError as e:
+ if self.is_master:
+ print(f'\n\033[91mError: {e}\033[0m')
+ print('DDP check failed.')
+ raise e
+
+ dist.barrier()
+ if self.is_master:
+ print('Done.')
+
+ def _verify_gradient_sync(self):
+ """
+ Verify that DDP gradient synchronization is actually taking effect.
+ DDP's backward pass all-reduces gradients automatically, so after synchronization every rank should hold exactly the same gradients.
+
+ Verification procedure:
+ 1. Compute the total gradient norm over all parameters on this rank.
+ 2. Gather the gradient norms from all ranks.
+ 3. If DDP synchronization works, every rank reports exactly the same norm.
+ 4. If it does not, the norms differ, because each rank processed different data.
+ """
+ # Compute the total gradient norm over all parameters on this rank
+ total_grad_norm_sq = 0.0
+ grad_count = 0
+ for p in self.model_params:
+ if p.grad is not None:
+ total_grad_norm_sq += p.grad.detach().float().norm().item() ** 2
+ grad_count += 1
+
+ if grad_count == 0:
+ return
+
+ local_grad_norm = total_grad_norm_sq ** 0.5
+
+ # Make sure all ranks have reached this point
+ dist.barrier()
+
+ # Gather the gradient norms from all ranks
+ grad_norm_tensor = torch.tensor([local_grad_norm], dtype=torch.float64, device=self.device)
+ all_grad_norms = [torch.zeros(1, dtype=torch.float64, device=self.device) for _ in range(self.world_size)]
+ dist.all_gather(all_grad_norms, grad_norm_tensor)
+ all_grad_norms = [g.item() for g in all_grad_norms]
+
+ # Check that all ranks report the same gradient norm (relative tolerance of 0.1%)
+ ref_norm = all_grad_norms[0]
+ if ref_norm > 0:
+ is_synced = all(abs(g - ref_norm) / ref_norm < 1e-3 for g in all_grad_norms)
+ else:
+ is_synced = all(abs(g) < 1e-10 for g in all_grad_norms)
+
+ if self.is_master:
+ print(f'\n{"="*60}')
+ print(f'[Step {self.step}] DDP Gradient Sync Verification:')
+ for i, g in enumerate(all_grad_norms):
+ print(f' Rank {i} grad_norm: {g:.8f}')
+ if is_synced:
+ print(f' \033[92m✓ PASS: All gradients are synchronized!\033[0m')
+ else:
+ max_diff = max(abs(g - ref_norm) for g in all_grad_norms)
+ print(f' \033[91m✗ FAIL: Gradients are NOT synchronized! Max diff: {max_diff:.8f}\033[0m')
+ print(f'{"="*60}\n')
+
+ @abstractmethod
+ def training_losses(self, **mb_data):
+ """
+ Compute training losses.
+ """
+ pass
+
+ def load_data(self):
+ """
+ Load data.
+ """
+ if self.prefetch_data:
+ if self._data_prefetched is None:
+ self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
+ data = self._data_prefetched
+ self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
+ else:
+ data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
+
+ # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
+ if isinstance(data, dict):
+ if self.batch_split == 1:
+ data_list = [data]
+ else:
+ batch_size = list(data.values())[0].shape[0]
+ data_list = [
+ {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
+ for i in range(self.batch_split)
+ ]
+ elif isinstance(data, list):
+ data_list = data
+ else:
+ raise ValueError('Data must be a dict or a list of dicts.')
+
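+ # Illustrative example: with a per-GPU batch of 8 and batch_split=2, the dict is
+ # sliced into two micro-batches of 4; run_step() then accumulates gradients over
+ # the returned list before taking a single optimizer step.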
+ return data_list
+
+ def run_step(self, data_list):
+ """
+ Run a training step.
+ """
+ step_log = {'loss': {}, 'status': {}}
+ amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext
+ elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
+
+ # Train
+ losses = []
+ statuses = []
+ elastic_controller_logs = []
+ zero_grad(self.model_params)
+ for i, mb_data in enumerate(data_list):
+ ## sync at the end of each batch split
+ sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
+ with nested_contexts(*sync_contexts), elastic_controller_context():
+ with amp_context():
+ loss, status = self.training_losses(**mb_data)
+ l = loss['loss'] / len(data_list)
+
+ # DEBUG: print the per-rank loss
+ if self.debug:
+ print(f'[Rank {self.rank}/{self.world_size}] Step {self.step} batch {i}: loss={loss["loss"].item():.6f}')
+
+ ## backward
+ if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
+ self.scaler.scale(l).backward()
+ elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
+ scaled_l = l * (2 ** self.log_scale)
+ scaled_l.backward()
+ else:
+ l.backward()
+ ## log
+ losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
+ statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
+ if self.elastic_controller_config is not None:
+ elastic_controller_logs.append(self.elastic_controller.log())
+
+ # ============================================================
+ # DEBUG: verify DDP gradient synchronization.
+ # Check that the gradients are identical across ranks after backward.
+ # DDP all-reduces gradients during the backward pass of the last batch split,
+ # so afterwards every rank should hold exactly the same gradients.
+ # ============================================================
+ if self.debug and self.world_size > 1:
+ self._verify_gradient_sync()
+
+ ## gradient clip
+ if self.grad_clip is not None:
+ if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
+ self.scaler.unscale_(self.optimizer)
+ elif self.mix_precision_mode == 'inflat_all':
+ model_grads_to_master_grads(self.model_params, self.master_params)
+ if self.mix_precision_dtype == torch.float16:
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
+ if isinstance(self.grad_clip, float):
+ grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
+ else:
+ grad_norm = self.grad_clip(self.master_params)
+ if torch.isfinite(grad_norm):
+ statuses[-1]['grad_norm'] = grad_norm.item()
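+ # The update below takes one of three paths: (1) amp + fp16 steps through the
+ # GradScaler; (2) 'inflat_all' copies model grads into the fp32 master params,
+ # undoes the 2**log_scale loss scaling for fp16, steps, copies back, and grows
+ # log_scale on success (or shrinks it when non-finite grads appear); (3) otherwise
+ # the optimizer steps directly, skipping the update when NaN/Inf gradients are found.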
+ ## step
+ if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
+ prev_scale = self.scaler.get_scale()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ elif self.mix_precision_mode == 'inflat_all':
+ if self.mix_precision_dtype == torch.float16:
+ prev_scale = 2 ** self.log_scale
+ if all(p.grad.isfinite().all() for p in self.model_params):
+ if self.grad_clip is None:
+ model_grads_to_master_grads(self.model_params, self.master_params)
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
+ self.optimizer.step()
+ master_params_to_model_params(self.model_params, self.master_params)
+ self.log_scale += self.fp16_scale_growth
+ else:
+ self.log_scale -= 1
+ else:
+ prev_scale = 1.0
+ if self.grad_clip is None:
+ model_grads_to_master_grads(self.model_params, self.master_params)
+ if all(p.grad.isfinite().all() for p in self.master_params):
+ self.optimizer.step()
+ master_params_to_model_params(self.model_params, self.master_params)
+ else:
+ print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
+ else:
+ prev_scale = 1.0
+ if all(p.grad.isfinite().all() for p in self.model_params):
+ self.optimizer.step()
+ else:
+ print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
+ ## adjust learning rate
+ if self.lr_scheduler_config is not None:
+ statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
+ self.lr_scheduler.step()
+
+ # Logs
+ step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
+ step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
+ if self.elastic_controller_config is not None:
+ step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
+ if self.grad_clip is not None:
+ step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
+
+ # Check grad and norm of each param
+ if self.log_param_stats:
+ param_norms = {}
+ param_grads = {}
+ for model_name, model in self.models.items():
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ param_norms[f'{model_name}.{name}'] = param.norm().item()
+ if param.grad is not None and torch.isfinite(param.grad).all():
+ param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / prev_scale
+ step_log['param_norms'] = param_norms
+ step_log['param_grads'] = param_grads
+
+ # Update exponential moving average
+ if self.is_master:
+ self.update_ema()
+
+ return step_log
+
+ def save_logs(self):
+ log_str = '\n'.join([
+ f'{step}: {json.dumps(dict_foreach(log, lambda x: float(x)))}' for step, log in self.log
+ ])
+
+ # Accumulate logs in memory and overwrite file each time (S3 FUSE does not support append)
+ if not hasattr(self, '_log_buffer'):
+ self._log_buffer = []
+ self._log_buffer.append(log_str)
+ try:
+ with open(self._log_file, 'w') as log_file:
+ log_file.write('\n'.join(self._log_buffer) + '\n')
+ except Exception as e:
+ print(f'\033[93m[WARN] Failed to write log file: {e}\033[0m')
+
+ # show with mlflow
+ log_show = [l for _, l in self.log if not dict_any(l, lambda x: np.isnan(x))]
+ log_show = dict_reduce(log_show, lambda x: np.mean(x))
+ log_show = dict_flatten(log_show, sep='/')
+ if self.writer is not None:
+ for key, value in log_show.items():
+ self.writer.add_scalar(key, value, self.step)
+
+ # Log to wandb
+ if self.wandb_run is not None:
+ wandb_log = {key: value for key, value in log_show.items()}
+ wandb_log['step'] = self.step
+ self.wandb_run.log(wandb_log, step=self.step)
+
+ self.log = []
+
+ def check_abort(self):
+ """
+ Check if training should be aborted due to certain conditions.
+ """
+ # 1. If log_scale in inflat_all mode is less than 0
+ if self.mix_precision_dtype == torch.float16 and \
+ self.mix_precision_mode == 'inflat_all' and \
+ self.log_scale < 0:
+ if self.is_master:
+ print ('\n\n\033[91m')
+ print (f'ABORT: log_scale in inflat_all mode is less than 0 at step {self.step}.')
+ print ('This indicates that the model is diverging. You should look into the model and the data.')
+ print ('\033[0m')
+ self.save(non_blocking=False)
+ self.save_logs()
+ if self.world_size > 1:
+ dist.barrier()
+ raise ValueError('ABORT: log_scale in inflat_all mode is less than 0.')
+
+ def run(self):
+ """
+ Run training.
+ """
+ if self.is_master:
+ print('\nStarting training...')
+ if self.i_sample != -1:
+ try:
+ self.snapshot_dataset(num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size)
+ except Exception as e:
+ print(f'\033[93m[WARN] snapshot_dataset failed, skipping: {e}\033[0m')
+ torch.cuda.empty_cache()
+ else:
+ print('[INFO] i_sample=-1, all snapshots disabled.')
+ if self.i_sample != -1:
+ if self.step == 0:
+ try:
+ self.snapshot(suffix='init', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size)
+ except Exception as e:
+ print(f'\033[93m[WARN] snapshot (init) failed, skipping: {e}\033[0m')
+ torch.cuda.empty_cache()
+ else: # resume
+ try:
+ self.snapshot(suffix=f'resume_step{self.step:07d}', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size)
+ except Exception as e:
+ print(f'\033[93m[WARN] snapshot (resume) failed, skipping: {e}\033[0m')
+ torch.cuda.empty_cache()
+
+ time_last_print = 0.0
+ time_elapsed = 0.0
+ while self.step < self.max_steps:
+ time_start = time.time()
+
+ data_list = self.load_data()
+ step_log = self.run_step(data_list)
+
+ time_end = time.time()
+ time_elapsed += time_end - time_start
+
+ self.step += 1
+
+ # Print progress
+ if self.is_master and self.step % self.i_print == 0:
+ speed = self.i_print / (time_elapsed - time_last_print) * 3600
+ columns = [
+ f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
+ f'Elapsed: {time_elapsed / 3600:.2f} h',
+ f'Speed: {speed:.2f} steps/h',
+ f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
+ ]
+ print(' | '.join([c.ljust(25) for c in columns]), flush=True)
+ time_last_print = time_elapsed
+
+ # Check ddp
+ if self.parallel_mode == 'ddp' and self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
+ self.check_ddp()
+
+ # Sample images
+ if self.i_sample != -1 and self.step % self.i_sample == 0:
+ try:
+ self.snapshot(num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size)
+ except Exception as e:
+ if self.is_master:
+ print(f'\033[93m[WARN] snapshot at step {self.step} failed, skipping: {e}\033[0m')
+ try:
+ torch.cuda.empty_cache()
+ except Exception:
+ pass
+
+ if self.is_master:
+ self.log.append((self.step, {}))
+
+ # Log time
+ self.log[-1][1]['time'] = {
+ 'step': time_end - time_start,
+ 'elapsed': time_elapsed,
+ }
+
+ # Log losses
+ if step_log is not None:
+ self.log[-1][1].update(step_log)
+
+ # Log scale
+ if self.mix_precision_dtype == torch.float16:
+ if self.mix_precision_mode == 'amp':
+ self.log[-1][1]['scale'] = self.scaler.get_scale()
+ elif self.mix_precision_mode == 'inflat_all':
+ self.log[-1][1]['log_scale'] = self.log_scale
+
+ # Save log
+ if self.step % self.i_log == 0:
+ self.save_logs()
+
+ # Save checkpoint
+ if self.step % self.i_save == 0:
+ self.save()
+
+ # Check abort
+ self.check_abort()
+
+ if self.i_sample != -1:
+ try:
+ self.snapshot(suffix='final', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size)
+ except Exception as e:
+ if self.is_master:
+ print(f'\033[93m[WARN] snapshot (final) failed, skipping: {e}\033[0m')
+ torch.cuda.empty_cache()
+ if self.world_size > 1:
+ dist.barrier()
+ if self.is_master:
+ if self.writer is not None:
+ self.writer.close()
+ print('Training finished.')
+
+ def profile(self, wait=2, warmup=3, active=5):
+ """
+ Profile the training loop.
+ """
+ with torch.profiler.profile(
+ schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
+ profile_memory=True,
+ with_stack=True,
+ ) as prof:
+ for _ in range(wait + warmup + active):
+ self.run_step(self.load_data())
+ prof.step()
diff --git a/trellis2/trainers/flow_matching/flow_matching.py b/trellis2/trainers/flow_matching/flow_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ab1bea2e3ca747ec3fdd521309b07e71b5e12d9
--- /dev/null
+++ b/trellis2/trainers/flow_matching/flow_matching.py
@@ -0,0 +1,655 @@
+from typing import *
+import copy
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import numpy as np
+from easydict import EasyDict as edict
+
+from ..basic import BasicTrainer
+from ...pipelines import samplers
+from ...utils.general_utils import dict_reduce
+from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
+from .mixins.text_conditioned import TextConditionedMixin
+from .mixins.image_conditioned import ImageConditionedMixin
+from .mixins.image_conditioned_proj import ImageConditionedProjMixin
+
+
+class FlowMatchingTrainer(BasicTrainer):
+ """
+ Trainer for diffusion model with flow matching objective.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ """
+ def __init__(
+ self,
+ *args,
+ t_schedule: dict = {
+ 'name': 'logitNormal',
+ 'args': {
+ 'mean': 0.0,
+ 'std': 1.0,
+ }
+ },
+ sigma_min: float = 1e-5,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.t_schedule = t_schedule
+ self.sigma_min = sigma_min
+
+ def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+
+ Args:
+ x_0: The [N x C x ...] tensor of noiseless inputs.
+ t: The [N] tensor of diffusion steps [0-1].
+ noise: If specified, use this noise instead of generating new noise.
+
+ Returns:
+ x_t, the noisy version of x_0 under timestep t.
+ """
+ if noise is None:
+ noise = torch.randn_like(x_0)
+ assert noise.shape == x_0.shape, "noise must have same shape as x_0"
+
+ t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
+ x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
+
+ return x_t
+
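+ # Note (illustrative): with sigma_min = 0 the interpolation above reduces to the
+ # straight path x_t = (1 - t) * x_0 + t * noise, so x_t equals x_0 at t = 0 and pure
+ # noise at t = 1; reverse_diffuse() below inverts this map for the same noise sample.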
+ def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
+ """
+ Get original image from noisy version under timestep t.
+ """
+ assert noise.shape == x_t.shape, "noise must have same shape as x_t"
+ t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
+ x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
+ return x_0
+
+ def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the velocity of the diffusion process at time t.
+ """
+ return (1 - self.sigma_min) * noise - x_0
+
+ def get_cond(self, cond, **kwargs):
+ """
+ Get the conditioning data.
+ """
+ return cond
+
+ def get_inference_cond(self, cond, **kwargs):
+ """
+ Get the conditioning data for inference.
+ """
+ return {'cond': cond, **kwargs}
+
+ def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
+ """
+ Get the sampler for the diffusion process.
+ """
+ return samplers.FlowEulerSampler(self.sigma_min)
+
+ def vis_cond(self, **kwargs):
+ """
+ Visualize the conditioning data.
+ """
+ return {}
+
+ def sample_t(self, batch_size: int) -> torch.Tensor:
+ """
+ Sample timesteps.
+ """
+ if self.t_schedule['name'] == 'uniform':
+ t = torch.rand(batch_size)
+ elif self.t_schedule['name'] == 'logitNormal':
+ mean = self.t_schedule['args']['mean']
+ std = self.t_schedule['args']['std']
+ t = torch.sigmoid(torch.randn(batch_size) * std + mean)
+ else:
+ raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
+ return t
+
+ def training_losses(
+ self,
+ x_0: torch.Tensor,
+ cond=None,
+ **kwargs
+ ) -> Tuple[Dict, Dict]:
+ """
+ Compute training losses for a single timestep.
+
+ Args:
+ x_0: The [N x C x ...] tensor of noiseless inputs.
+ cond: The [N x ...] tensor of additional conditions.
+ kwargs: Additional arguments to pass to the backbone.
+
+ Returns:
+ a dict with the key "loss" containing a tensor of shape [N].
+ may also contain other keys for different terms.
+ """
+ noise = torch.randn_like(x_0)
+ t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
+ x_t = self.diffuse(x_0, t, noise=noise)
+ cond = self.get_cond(cond, **kwargs)
+
+ pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
+ assert pred.shape == noise.shape == x_0.shape
+ target = self.get_v(x_0, noise, t)
+ terms = edict()
+ terms["mse"] = F.mse_loss(pred, target)
+ terms["loss"] = terms["mse"]
+
+ # log loss with time bins
+ mse_per_instance = np.array([
+ F.mse_loss(pred[i], target[i]).item()
+ for i in range(x_0.shape[0])
+ ])
+ time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
+ for i in range(10):
+ if (time_bin == i).sum() != 0:
+ terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
+
+ return terms, {}
+
+ @torch.no_grad()
+ def run_snapshot(
+ self,
+ num_samples: int,
+ batch_size: int,
+ verbose: bool = False,
+ ) -> Dict:
+ # Use current step as seed to ensure different samples for each snapshot
+ import random
+ snapshot_seed = self.step
+ random.seed(snapshot_seed)
+ np.random.seed(snapshot_seed)
+
+ g = torch.Generator()
+ g.manual_seed(snapshot_seed)
+
+ dataloader = DataLoader(
+ copy.deepcopy(self.dataset),
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=0,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ generator=g,
+ )
+
+ # inference
+ sampler = self.get_sampler()
+ sample_gt = []
+ sample = []
+ cond_vis = []
+ sample_metadata = []
+ for i in range(0, num_samples, batch_size):
+ batch = min(batch_size, num_samples - i)
+ data = next(iter(dataloader))
+ data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
+
+ # Collect metadata (dataset_name and sha256) for wandb display
+ if '_dataset_name' in data and '_sha256' in data:
+ for j in range(batch):
+ sample_metadata.append(f"{data['_dataset_name'][j]}/{data['_sha256'][j]}")
+
+ # Remove metadata fields before inference
+ data.pop('_dataset_name', None)
+ data.pop('_sha256', None)
+
+ noise = torch.randn_like(data['x_0'])
+ sample_gt.append(data['x_0'])
+ cond_vis.append(self.vis_cond(**data))
+ del data['x_0']
+ args = self.get_inference_cond(**data)
+ res = sampler.sample(
+ self.models['denoiser'],
+ noise=noise,
+ **args,
+ steps=50, guidance_strength=3.0, verbose=verbose,
+ )
+ sample.append(res.samples)
+
+ sample_gt = torch.cat(sample_gt, dim=0)
+ sample = torch.cat(sample, dim=0)
+ sample_dict = {
+ 'sample_gt': {'value': sample_gt, 'type': 'sample'},
+ 'sample': {'value': sample, 'type': 'sample'},
+ }
+ if sample_metadata:
+ sample_dict['_metadata'] = sample_metadata
+ sample_dict.update(dict_reduce(cond_vis, None, {
+ 'value': lambda x: torch.cat(x, dim=0),
+ 'type': lambda x: x[0],
+ }))
+
+ return sample_dict
+
+
+class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
+ """
+ Trainer for diffusion model with flow matching objective and classifier-free guidance.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ """
+ pass
+
+
+class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
+ """
+ Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ text_cond_model(str): Text conditioning model.
+ """
+ pass
+
+
+class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
+ """
+ Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ image_cond_model (str): Image conditioning model.
+ """
+ pass
+
+
+class ImageConditionedProjFlowMatchingCFGTrainer(ImageConditionedProjMixin, FlowMatchingCFGTrainer):
+ """
+ Trainer for image-conditioned diffusion model with view-aligned projection.
+
+ Uses ImageConditionedProjMixin for 3D-to-2D feature projection with camera parameters.
+ CFG dropout is handled by ClassifierFreeGuidanceMixin (via p_uncond parameter).
+
+ Args:
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ image_cond_model (dict): Image conditioning model config (DinoV3ProjFeatureExtractor).
+ run_projection_test (bool): Whether to run projection visualization test before training.
+ """
+
+ def __init__(self, *args, run_projection_test: bool = True, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.run_projection_test = run_projection_test
+
+ def training_losses(
+ self,
+ x_0: torch.Tensor,
+ cond=None,
+ **kwargs
+ ) -> Tuple[Dict, Dict]:
+ """
+ Compute training losses for a single timestep.
+
+ Overridden to avoid passing extra kwargs to the model.
+
+ Args:
+ x_0: The [N x C x ...] tensor of noiseless inputs.
+ cond: The [N x ...] tensor of additional conditions.
+ kwargs: Additional arguments (camera info, view_idx, etc.) for conditioning.
+
+ Returns:
+ a dict with the key "loss" containing a tensor of shape [N].
+ may also contain other keys for different terms.
+ """
+ noise = torch.randn_like(x_0)
+ t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
+ x_t = self.diffuse(x_0, t, noise=noise)
+ cond = self.get_cond(cond, **kwargs)
+
+ # Note: SparseStructureFlowModel.forward() only accepts (x, t, cond)
+ # Do not pass extra kwargs to avoid unexpected keyword argument errors
+ pred = self.training_models['denoiser'](x_t, t * 1000, cond)
+ assert pred.shape == noise.shape == x_0.shape
+ target = self.get_v(x_0, noise, t)
+ terms = edict()
+ terms["mse"] = F.mse_loss(pred, target)
+ terms["loss"] = terms["mse"]
+
+ # log loss with time bins
+ mse_per_instance = np.array([
+ F.mse_loss(pred[i], target[i]).item()
+ for i in range(x_0.shape[0])
+ ])
+ time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
+ for i in range(10):
+ if (time_bin == i).sum() != 0:
+ terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
+
+ return terms, {}
+
+ @torch.no_grad()
+ def run_snapshot(
+ self,
+ num_samples: int,
+ batch_size: int,
+ verbose: bool = False,
+ ) -> Dict:
+ """
+ Run snapshot with camera parameters for GT view rendering.
+
+ Overrides parent to include camera parameters in sample dicts for
+ visualizing the GT camera view alongside the standard 4-view rendering.
+ """
+ # Use current step as seed to ensure different samples for each snapshot
+ import random
+ snapshot_seed = self.step
+ random.seed(snapshot_seed)
+ np.random.seed(snapshot_seed)
+
+ g = torch.Generator()
+ g.manual_seed(snapshot_seed)
+
+ dataloader = DataLoader(
+ copy.deepcopy(self.dataset),
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=0,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ generator=g,
+ )
+
+ # inference
+ sampler = self.get_sampler()
+ sample_gt_list = []
+ sample_list = []
+ cond_vis = []
+ sample_metadata = []
+
+ # Camera params for GT view rendering
+ camera_distances = []
+ camera_angles = []
+ mesh_scales = []
+
+ for i in range(0, num_samples, batch_size):
+ batch = min(batch_size, num_samples - i)
+ data = next(iter(dataloader))
+ data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
+
+ # Collect metadata (dataset_name and sha256) for wandb display
+ if '_dataset_name' in data and '_sha256' in data:
+ for j in range(batch):
+ sample_metadata.append(f"{data['_dataset_name'][j]}/{data['_sha256'][j]}")
+
+ # Remove metadata fields before inference
+ data.pop('_dataset_name', None)
+ data.pop('_sha256', None)
+
+ noise = torch.randn_like(data['x_0'])
+
+ # Save GT sample
+ sample_gt_list.append(data['x_0'])
+ cond_vis.append(self.vis_cond(**data))
+
+ # Save camera parameters for GT view rendering (if available)
+ if 'camera_distance' in data:
+ camera_distances.append(data['camera_distance'])
+ if 'camera_angle_x' in data:
+ camera_angles.append(data['camera_angle_x'])
+ if 'mesh_scale' in data:
+ mesh_scales.append(data['mesh_scale'])
+
+ # Remove x_0 before inference
+ del data['x_0']
+ args = self.get_inference_cond(**data)
+ res = sampler.sample(
+ self.models['denoiser'],
+ noise=noise,
+ **args,
+ steps=50, guidance_strength=3.0, verbose=verbose,
+ )
+ sample_list.append(res.samples)
+
+ # Concatenate samples
+ sample_gt = torch.cat(sample_gt_list, dim=0)
+ sample = torch.cat(sample_list, dim=0)
+
+ # Build sample dicts with camera info for GT view rendering
+ sample_gt_value = {'x_0': sample_gt}
+ sample_value = {'x_0': sample}
+
+ # Add camera params if available
+ if len(camera_distances) > 0:
+ camera_distance = torch.cat(camera_distances, dim=0)
+ sample_gt_value['camera_distance'] = camera_distance
+ sample_value['camera_distance'] = camera_distance
+ if len(camera_angles) > 0:
+ camera_angle_x = torch.cat(camera_angles, dim=0)
+ sample_gt_value['camera_angle_x'] = camera_angle_x
+ sample_value['camera_angle_x'] = camera_angle_x
+ if len(mesh_scales) > 0:
+ mesh_scale = torch.cat(mesh_scales, dim=0)
+ sample_gt_value['mesh_scale'] = mesh_scale
+ sample_value['mesh_scale'] = mesh_scale
+
+ sample_dict = {
+ 'sample_gt': {'value': sample_gt_value, 'type': 'sample'},
+ 'sample': {'value': sample_value, 'type': 'sample'},
+ }
+ if sample_metadata:
+ sample_dict['_metadata'] = sample_metadata
+ sample_dict.update(dict_reduce(cond_vis, None, {
+ 'value': lambda x: torch.cat(x, dim=0),
+ 'type': lambda x: x[0],
+ }))
+
+ return sample_dict
+
+ @torch.no_grad()
+ def visualize_sample(self, sample):
+ """
+ Convert a sample to images, including GT camera view if available.
+
+ Args:
+ sample: Either a tensor or dict containing:
+ - 'x_0': latent tensor [B, C, D, H, W]
+ - 'camera_angle_x': [B] (optional)
+ - 'camera_distance': [B] (optional)
+ - 'mesh_scale': [B] (optional)
+
+ Returns:
+ dict with visualization images or tensor
+ """
+ if hasattr(self.dataset, 'visualize_sample'):
+ if isinstance(sample, dict):
+ # Extract camera params if available
+ camera_angle_x = sample.get('camera_angle_x')
+ camera_distance = sample.get('camera_distance')
+ mesh_scale = sample.get('mesh_scale')
+ x_0 = sample.get('x_0', sample)
+
+ return self.dataset.visualize_sample(
+ x_0,
+ camera_angle_x=camera_angle_x,
+ camera_distance=camera_distance,
+ mesh_scale=mesh_scale,
+ )
+ else:
+ return self.dataset.visualize_sample(sample)
+ else:
+ if isinstance(sample, dict):
+ return sample.get('x_0', sample)
+ return sample
+
+ def run(self):
+ """
+ Run training with projection visualization test before starting.
+ """
+ # Run projection visualization test before training starts (if enabled)
+ if self.run_projection_test and self.is_master:
+ print('\n' + '='*60)
+ print('Running projection visualization test...')
+ print('='*60)
+ self._run_projection_visualization_test()
+
+ super().run()
+
+ @torch.no_grad()
+ def _run_projection_visualization_test(self, num_samples: int = 4):
+ """
+ Run projection visualization test on a few samples before training starts.
+
+ This helps verify that the 3D-to-2D projection is working correctly.
+ """
+ import os
+ from torch.utils.data import DataLoader
+
+ # Create a small dataloader
+ dataloader = DataLoader(
+ self.dataset,
+ batch_size=min(num_samples, self.snapshot_batch_size),
+ shuffle=True,
+ num_workers=0,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ )
+
+ # Get one batch
+ data = next(iter(dataloader))
+ data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
+
+ # Extract condition image
+ cond = data.get('cond')
+ if cond is None:
+ print("Warning: No 'cond' field in data, skipping projection visualization test")
+ return
+
+ # Save directory
+ save_dir = os.path.join(self.output_dir, 'samples', 'projection_test')
+
+ # Call visualization method
+ if hasattr(self, 'visualize_projection_test'):
+ # Need to pass camera info as kwargs
+ kwargs = {k: v for k, v in data.items() if k != 'cond' and k != 'x_0'}
+ self.visualize_projection_test(
+ cond=cond,
+ save_dir=save_dir,
+ prefix="proj_test",
+ **kwargs
+ )
+ print(f"Projection visualization saved to: {save_dir}")
+ else:
+ print("Warning: visualize_projection_test not available")
diff --git a/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py b/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ae3a4ed19791e3f5f8e44eae5ceed02fcd4fc83
--- /dev/null
+++ b/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py
@@ -0,0 +1,60 @@
+import torch
+import numpy as np
+from ....utils.general_utils import dict_foreach
+from ....pipelines import samplers
+
+
+class ClassifierFreeGuidanceMixin:
+ def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.p_uncond = p_uncond
+
+ def get_cond(self, cond, neg_cond=None, **kwargs):
+ """
+ Get the conditioning data.
+ """
+ assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
+
+ if self.p_uncond > 0:
+ # randomly drop the class label
+ def get_batch_size(cond):
+ if isinstance(cond, torch.Tensor):
+ return cond.shape[0]
+ elif isinstance(cond, list):
+ return len(cond)
+ else:
+ raise ValueError(f"Unsupported type of cond: {type(cond)}")
+
+ ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
+ B = get_batch_size(ref_cond)
+
+ def select(cond, neg_cond, mask):
+ if isinstance(cond, torch.Tensor):
+ mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
+ return torch.where(mask, neg_cond, cond)
+ elif isinstance(cond, list):
+ return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
+ else:
+ raise ValueError(f"Unsupported type of cond: {type(cond)}")
+
+ mask = list(np.random.rand(B) < self.p_uncond)
+ if not isinstance(cond, dict):
+ cond = select(cond, neg_cond, mask)
+ else:
+ # Apply select to each key in the dict
+ cond = {k: select(cond[k], neg_cond[k], mask) for k in cond.keys()}
+
+ return cond
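+ # Rough usage sketch (assuming a plain tensor condition):
+ #   cond = torch.randn(8, 1024); neg_cond = torch.zeros_like(cond)
+ #   mixed = self.get_cond(cond, neg_cond=neg_cond)  # with p_uncond=0.1, ~10% of rows become neg_cond
+ # Dropping conditions this way is what trains the unconditional branch used for CFG sampling.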
+
+ def get_inference_cond(self, cond, neg_cond=None, **kwargs):
+ """
+ Get the conditioning data for inference.
+ """
+ assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
+ return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
+
+ def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
+ """
+ Get the sampler for the diffusion process.
+ """
+ return samplers.FlowEulerCfgSampler(self.sigma_min)
diff --git a/trellis2/trainers/flow_matching/mixins/image_conditioned.py b/trellis2/trainers/flow_matching/mixins/image_conditioned.py
new file mode 100644
index 0000000000000000000000000000000000000000..13b932eb714084ad7fc4ac56d3136d5c90db4350
--- /dev/null
+++ b/trellis2/trainers/flow_matching/mixins/image_conditioned.py
@@ -0,0 +1,249 @@
+from typing import *
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+from transformers import DINOv3ViTModel
+import numpy as np
+from PIL import Image
+
+from ....utils import dist_utils
+
+
+class DinoV2FeatureExtractor:
+ """
+ Feature extractor for DINOv2 models.
+ """
+ def __init__(self, model_name: str):
+ self.model_name = model_name
+ self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
+ self.model.eval()
+ self.transform = transforms.Compose([
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ def to(self, device):
+ self.model.to(device)
+
+ def cuda(self):
+ self.model.cuda()
+
+ def cpu(self):
+ self.model.cpu()
+
+ @torch.no_grad()
+ def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
+ """
+ Extract features from the image.
+
+ Args:
+ image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
+
+ Returns:
+ A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
+ """
+ if isinstance(image, torch.Tensor):
+ assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
+ elif isinstance(image, list):
+ assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
+ image = [i.resize((518, 518), Image.LANCZOS) for i in image]
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+ image = torch.stack(image).cuda()
+ else:
+ raise ValueError(f"Unsupported type of image: {type(image)}")
+
+ image = self.transform(image).cuda()
+ features = self.model(image, is_training=True)['x_prenorm']
+ patchtokens = F.layer_norm(features, features.shape[-1:])
+ return patchtokens
+
+
+class DinoV3FeatureExtractor:
+ """
+ Feature extractor for DINOv3 models.
+ """
+ def __init__(self, model_name: str, image_size=512):
+ self.model_name = model_name
+ self.model = DINOv3ViTModel.from_pretrained(model_name)
+ self.model.eval()
+ self.image_size = image_size
+ self.transform = transforms.Compose([
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ def to(self, device):
+ self.model.to(device)
+
+ def cuda(self):
+ self.model.cuda()
+
+ def cpu(self):
+ self.model.cpu()
+
+ def extract_features(self, image: torch.Tensor) -> torch.Tensor:
+ image = image.to(self.model.embeddings.patch_embeddings.weight.dtype)
+ hidden_states = self.model.embeddings(image, bool_masked_pos=None)
+ position_embeddings = self.model.rope_embeddings(image)
+
+ for i, layer_module in enumerate(self.model.layer):
+ hidden_states = layer_module(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ )
+
+ return F.layer_norm(hidden_states, hidden_states.shape[-1:])
+
+ @torch.no_grad()
+ def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
+ """
+ Extract features from the image.
+
+ Args:
+ image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
+
+ Returns:
+ A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
+ """
+ if isinstance(image, torch.Tensor):
+ assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
+ elif isinstance(image, list):
+ assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
+ image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+ image = torch.stack(image).cuda()
+ else:
+ raise ValueError(f"Unsupported type of image: {type(image)}")
+
+ image = self.transform(image).cuda()
+ features = self.extract_features(image)
+ return features
+
+
+class ImageConditionedMixin:
+ """
+ Mixin for image-conditioned models.
+
+ Args:
+ image_cond_model: The image conditioning model.
+ """
+ def __init__(self, *args, image_cond_model: dict, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.image_cond_model_config = image_cond_model
+ self.image_cond_model = None # the model is init lazily
+
+ def _init_image_cond_model(self):
+ """
+ Initialize the image conditioning model.
+ """
+ with dist_utils.local_master_first():
+ self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {}))
+ self.image_cond_model.cuda()
+
+ @torch.no_grad()
+ def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
+ """
+ Encode the image.
+ """
+ if self.image_cond_model is None:
+ self._init_image_cond_model()
+ features = self.image_cond_model(image)
+ return features
+
+ def get_cond(self, cond, **kwargs):
+ """
+ Get the conditioning data.
+ """
+ cond = self.encode_image(cond)
+ kwargs['neg_cond'] = torch.zeros_like(cond)
+ cond = super().get_cond(cond, **kwargs)
+ return cond
+
+ def get_inference_cond(self, cond, **kwargs):
+ """
+ Get the conditioning data for inference.
+ """
+ cond = self.encode_image(cond)
+ kwargs['neg_cond'] = torch.zeros_like(cond)
+ cond = super().get_inference_cond(cond, **kwargs)
+ return cond
+
+ def vis_cond(self, cond, **kwargs):
+ """
+ Visualize the conditioning data.
+ """
+ return {'image': {'value': cond, 'type': 'image'}}
+
+
+class MultiImageConditionedMixin:
+ """
+ Mixin for multiple-image-conditioned models.
+
+ Args:
+ image_cond_model: The image conditioning model.
+ """
+ def __init__(self, *args, image_cond_model: dict, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.image_cond_model_config = image_cond_model
+ self.image_cond_model = None # the model is init lazily
+
+ def _init_image_cond_model(self):
+ """
+ Initialize the image conditioning model.
+ """
+ with dist_utils.local_master_first():
+ self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {}))
+
+ @torch.no_grad()
+ def encode_images(self, images: Union[List[torch.Tensor], List[List[Image.Image]]]) -> List[torch.Tensor]:
+ """
+ Encode the image.
+ """
+ if self.image_cond_model is None:
+ self._init_image_cond_model()
+ seqlen = [len(i) for i in images]
+ images = torch.cat(images, dim=0) if isinstance(images[0], torch.Tensor) else sum(images, [])
+ features = self.image_cond_model(images)
+ features = torch.split(features, seqlen)
+ features = [feature.reshape(-1, feature.shape[-1]) for feature in features]
+ return features
+
+ def get_cond(self, cond, **kwargs):
+ """
+ Get the conditioning data.
+ """
+ cond = self.encode_images(cond)
+ kwargs['neg_cond'] = [
+ torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond))
+ ]
+ cond = super().get_cond(cond, **kwargs)
+ return cond
+
+ def get_inference_cond(self, cond, **kwargs):
+ """
+ Get the conditioning data for inference.
+ """
+ cond = self.encode_images(cond)
+ kwargs['neg_cond'] = [
+ torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond))
+ ]
+ cond = super().get_inference_cond(cond, **kwargs)
+ return cond
+
+ def vis_cond(self, cond, **kwargs):
+ """
+ Visualize the conditioning data.
+ """
+ H, W = cond[0].shape[-2:]
+ vis = []
+ for images in cond:
+ canvas = torch.zeros(3, H * 2, W * 2, device=images.device, dtype=images.dtype)
+ for i, image in enumerate(images):
+ if i == 4:
+ break
+ kh = i // 2
+ kw = i % 2
+ canvas[:, kh*H:(kh+1)*H, kw*W:(kw+1)*W] = image
+ vis.append(canvas)
+ vis = torch.stack(vis)
+ return {'image': {'value': vis, 'type': 'image'}}
diff --git a/trellis2/trainers/flow_matching/mixins/image_conditioned_proj.py b/trellis2/trainers/flow_matching/mixins/image_conditioned_proj.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e9f326ab066db0c6c94cd31267d9edf10d4c64b
--- /dev/null
+++ b/trellis2/trainers/flow_matching/mixins/image_conditioned_proj.py
@@ -0,0 +1,1530 @@
+"""
+View-Aligned (Projection) Image Conditioned Mixin for TRELLIS2
+
+This module implements DINOv3-based feature extraction with view-aligned projection,
+supporting camera-aware 3D-to-2D feature mapping.
+"""
+
+from typing import *
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+from transformers import DINOv3ViTModel
+import numpy as np
+from PIL import Image, ImageDraw
+
+import torch.distributed as dist
+from ....utils import dist_utils
+from ....utils.dist_utils import read_file_dist
+
+
+# =============================================================================
+# Projection Utilities
+# =============================================================================
+
+def project_points_to_image_batch(
+ points_3d: torch.Tensor,
+ transform_matrix: torch.Tensor,
+ camera_angle_x: torch.Tensor,
+ resolution: int = 518
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Project 3D points to 2D image coordinates (batch processing).
+
+ Args:
+ points_3d: torch.Tensor, shape [N, 3] or [B, N, 3], 3D point coordinates (in [-1, 1] range)
+ transform_matrix: torch.Tensor, shape [B, 4, 4], camera transformation matrix
+ camera_angle_x: torch.Tensor, shape [B], horizontal field of view angle (radians)
+ resolution: int, image resolution, default 518
+
+ Returns:
+ points_2d: torch.Tensor, shape [B, N, 2], image coordinates [x, y]
+ depth: torch.Tensor, shape [B, N], depth values
+ valid_mask: torch.Tensor, shape [B, N], mask for points within view
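+
+ Example (hypothetical values; illustrates shapes only):
+ >>> pts = torch.rand(4096, 3) * 2 - 1 # [N, 3] points in [-1, 1]
+ >>> c2w = torch.eye(4).unsqueeze(0) # [1, 4, 4] camera-to-world matrix
+ >>> fov_x = torch.tensor([0.8575]) # ~49 degrees, shape [1]
+ >>> xy, depth, mask = project_points_to_image_batch(pts, c2w, fov_x)
+ >>> xy.shape, depth.shape, mask.shape
+ (torch.Size([1, 4096, 2]), torch.Size([1, 4096]), torch.Size([1, 4096]))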
+ """
+ device = points_3d.device
+ B = transform_matrix.shape[0]
+
+ # Ensure inputs are torch.Tensor on correct device
+ if not isinstance(transform_matrix, torch.Tensor):
+ transform_matrix = torch.tensor(transform_matrix, dtype=torch.float32, device=device)
+ if not isinstance(points_3d, torch.Tensor):
+ points_3d = torch.tensor(points_3d, dtype=torch.float32, device=device)
+ if not isinstance(camera_angle_x, torch.Tensor):
+ camera_angle_x = torch.tensor(camera_angle_x, dtype=torch.float32, device=device)
+
+ # Expand points_3d to batch dimension: [N, 3] -> [B, N, 3]
+ if points_3d.dim() == 2:
+ points_3d_batch = points_3d.unsqueeze(0).expand(B, -1, -1)
+ else:
+ points_3d_batch = points_3d
+ N = points_3d_batch.shape[1]
+
+ # Add homogeneous coordinates: [B, N, 3] -> [B, N, 4]
+ ones = torch.ones(B, N, 1, device=device, dtype=points_3d_batch.dtype)
+ points_homogeneous = torch.cat([points_3d_batch, ones], dim=-1) # [B, N, 4]
+
+ # Compute world to camera transformation matrix
+ world_to_camera = torch.linalg.inv(transform_matrix) # [B, 4, 4]
+
+ # Batch transform to camera coordinate system: [B, N, 4] @ [B, 4, 4]^T -> [B, N, 3]
+ points_camera = torch.bmm(points_homogeneous, world_to_camera.transpose(-2, -1))[..., :3] # [B, N, 3]
+
+ # Extract camera coordinates
+ x_cam = points_camera[..., 0] # [B, N]
+ y_cam = points_camera[..., 1] # [B, N]
+ z_cam = points_camera[..., 2] # [B, N]
+
+ # Depth value (Z value in camera coordinate system, note Blender camera faces -Z direction)
+ depth = -z_cam # [B, N]
+
+ # Compute camera intrinsics (batch processing)
+ sensor_width = 32.0 # mm
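+ # Pinhole model: focal_length_mm = (sensor_width / 2) / tan(fov_x / 2), hence the 16.0 (= 32 / 2) below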
+ focal_length = 16.0 / torch.tan(camera_angle_x / 2.0) # [B]
+ focal_length_pixels = focal_length * resolution / sensor_width # [B]
+
+ # Expand focal_length_pixels dimension for broadcasting: [B] -> [B, 1]
+ focal_length_pixels = focal_length_pixels.unsqueeze(1) # [B, 1]
+
+ # Perspective projection to NDC coordinates
+ x_ndc = focal_length_pixels * x_cam / (-z_cam + 1e-8) # [B, N]
+ y_ndc = focal_length_pixels * y_cam / (-z_cam + 1e-8) # [B, N]
+
+ # Convert to image coordinates (pixel coordinates)
+ x_pixel = x_ndc + resolution / 2.0 # [B, N]
+ y_pixel = -y_ndc + resolution / 2.0 # [B, N], flip Y axis
+
+ # Create validity mask (points within image range and in front of camera)
+ valid_mask = (
+ (x_pixel >= 0) & (x_pixel < resolution) &
+ (y_pixel >= 0) & (y_pixel < resolution) &
+ (depth > 0) # In front of camera
+ ) # [B, N]
+
+ points_2d = torch.stack([x_pixel, y_pixel], dim=-1) # [B, N, 2]
+
+ return points_2d, depth, valid_mask
+
+
+def sample_features(fmap: torch.Tensor, queries_ndc: torch.Tensor) -> torch.Tensor:
+ """
+ Sample features from feature map at specified NDC coordinates.
+
+ Args:
+ fmap: torch.Tensor, shape [B, C, H, W], feature map
+ queries_ndc: torch.Tensor, shape [B, K, 2], normalized device coordinates
+
+ Returns:
+ torch.Tensor, shape [B, C, K], sampled features
+ """
+ B, C, H, W = fmap.shape
+ Bq, K, _ = queries_ndc.shape
+ assert Bq == B, "Batch size mismatch"
+
+ # grid_sample requires (B, out_h, out_w, 2), here we want K points -> out_h=K, out_w=1
+ grid = queries_ndc.view(B, K, 1, 2) # (B, K, 1, 2)
+
+ # Bilinear interpolation, align_corners=False (consistent with [-1,1] pixel center convention)
+ feat = F.grid_sample(
+ fmap, grid, mode='bilinear',
+ align_corners=False, padding_mode='border' # border avoids out-of-bound becoming 0
+ ) # (B, C, K, 1)
+
+ return feat.squeeze(-1) # (B, C, K)
+
+
+# =============================================================================
+# Projection Grid Module
+# =============================================================================
+
+class ProjGrid(nn.Module):
+ """
+ 3D Grid Projection Module.
+
+ Projects a 3D grid of points to 2D image coordinates and samples features
+ from the image feature map at those locations.
+
+ This is the core module for view-aligned feature extraction.
+ """
+ def __init__(self, grid_resolution: int = 16, image_resolution: int = 518):
+ super().__init__()
+ self.grid_resolution = grid_resolution
+ self.image_resolution = image_resolution
+
+ # Create 3D grid points
+ one_dim = torch.linspace(-1, 1, grid_resolution)
+ x, y, z = torch.meshgrid(one_dim, one_dim, one_dim, indexing='ij')
+ grid_points = torch.stack((x, y, z), dim=-1)
+
+ # Rotation matrix to align with Blender coordinate system
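+ # (maps (x, y, z) -> (x, -z, y), i.e. rotates the y-up grid into Blender's z-up convention)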
+ rotation_matrix = torch.tensor([
+ [1.0, 0.0, 0.0],
+ [0.0, 0.0, -1.0],
+ [0.0, 1.0, 0.0]
+ ])
+ grid_points = torch.matmul(grid_points, rotation_matrix.T)
+ grid_points = grid_points.reshape(-1, 3)
+ self.register_buffer('grid_points', grid_points) # [R³, 3]
+
+ # Default front view transformation matrix
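+ # Camera on the -Y axis looking at the origin; the Y translation (-2 here) is overwritten by `distance` in forward()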
+ front_view_transform_matrix = torch.tensor([
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, -1.0, -2.0],
+ [0.0, 1.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0]
+ ])
+ self.register_buffer("front_view_transform_matrix", front_view_transform_matrix)
+
+ def forward(
+ self,
+ features_map: torch.Tensor,
+ camera_angle_x: torch.Tensor,
+ distance: torch.Tensor,
+ mesh_scale: torch.Tensor,
+ transform_matrix: Optional[torch.Tensor] = None,
+ BHWC: bool = True
+ ) -> torch.Tensor:
+ """
+ Project 3D grid points to image and sample features.
+
+ Args:
+ features_map: Feature map, shape [B, H, W, C] if BHWC else [B, C, H, W]
+ camera_angle_x: Camera FOV angle, shape [B]
+ distance: Camera distance, shape [B]
+ mesh_scale: Mesh scale factor, shape [B]
+ transform_matrix: Optional camera transform matrix, shape [B, 4, 4]
+ BHWC: Whether features_map is in BHWC format
+
+ Returns:
+ Projected features, shape [B, grid_resolution³, C]
+ """
+ if BHWC:
+ B, H, W, C = features_map.shape
+ else:
+ B, C, H, W = features_map.shape
+
+ grid_points = self.grid_points
+ grid_points = grid_points.expand(B, -1, -1)
+ grid_points = grid_points / mesh_scale.unsqueeze(-1).unsqueeze(-1) / 2 # Scale alignment
+ assert transform_matrix is None, "custom transform_matrix is not supported; only the default front view is used"
+ if transform_matrix is None:
+ transform_matrix = self.front_view_transform_matrix
+ transform_matrix = transform_matrix.expand(B, -1, -1).clone()
+ transform_matrix[:, 1, 3] = -distance # Set camera distance
+
+ # Project to image coordinates (simulate Blender projection)
+ image_points, depth, valid_mask = project_points_to_image_batch(
+ grid_points, transform_matrix, camera_angle_x, self.image_resolution
+ )
+
+ # Normalize to [-1, 1] for grid_sample
+ image_points_norm = (image_points + 0.5) / self.image_resolution * 2 - 1
+
+ if BHWC:
+ features_map = features_map.permute(0, 3, 1, 2) # [B, C, H, W]
+
+ # Sample features from DINOv3 patch feature map
+ x = sample_features(features_map, image_points_norm) # [B, C, K]
+ x = x.permute(0, 2, 1) # [B, K, C]
+
+ return x
+
+ def visualize_projection(
+ self,
+ image: torch.Tensor,
+ camera_angle_x: torch.Tensor,
+ distance: torch.Tensor,
+ mesh_scale: torch.Tensor,
+ transform_matrix: Optional[torch.Tensor] = None,
+ save_dir: Optional[str] = None,
+ prefix: str = "proj_vis",
+ ) -> List[Image.Image]:
+ """
+ Visualize the projected 3D grid points on the input image.
+
+ Args:
+ image: Input image tensor [B, C, H, W], assumed to be in [0, 1] range
+ camera_angle_x: Camera FOV angle, shape [B]
+ distance: Camera distance, shape [B]
+ mesh_scale: Mesh scale factor, shape [B]
+ transform_matrix: Optional camera transform matrix, shape [B, 4, 4]
+ save_dir: Directory to save visualizations (optional)
+ prefix: Prefix for saved files
+
+ Returns:
+ List of PIL Images with projected points overlaid
+ """
+ B = image.shape[0]
+
+ # Get projected points
+ grid_points = self.grid_points.expand(B, -1, -1)
+ grid_points = grid_points / mesh_scale.unsqueeze(-1).unsqueeze(-1) / 2
+ assert transform_matrix is None, "custom transform_matrix is not supported; only the default front view is used"
+ if transform_matrix is None:
+ transform_matrix = self.front_view_transform_matrix
+ transform_matrix = transform_matrix.expand(B, -1, -1).clone()
+ transform_matrix[:, 1, 3] = -distance
+
+ image_points, depth, valid_mask = project_points_to_image_batch(
+ grid_points, transform_matrix, camera_angle_x, self.image_resolution
+ )
+
+ # Convert image to PIL for visualization
+ vis_images = []
+ for b in range(B):
+ # Convert tensor to PIL image
+ img_np = image[b].cpu().permute(1, 2, 0).numpy()
+ img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
+
+ # Resize to image_resolution if needed
+ pil_img = Image.fromarray(img_np)
+ if pil_img.size != (self.image_resolution, self.image_resolution):
+ pil_img = pil_img.resize((self.image_resolution, self.image_resolution), Image.LANCZOS)
+
+ # Create a copy for drawing
+ vis_img = pil_img.copy()
+ draw = ImageDraw.Draw(vis_img)
+
+ # Get points for this batch
+ pts = image_points[b].cpu().numpy() # [K, 2]
+ depths = depth[b].cpu().numpy() # [K]
+ mask = valid_mask[b].cpu().numpy() # [K]
+
+ # Normalize depth for coloring
+ valid_depths = depths[mask]
+ if len(valid_depths) > 0:
+ d_min, d_max = valid_depths.min(), valid_depths.max()
+ if d_max - d_min > 1e-6:
+ depths_norm = (depths - d_min) / (d_max - d_min)
+ else:
+ depths_norm = np.ones_like(depths) * 0.5
+ else:
+ depths_norm = np.ones_like(depths) * 0.5
+
+ # Draw projected points
+ R = self.grid_resolution
+ for i, (pt, d, m, dn) in enumerate(zip(pts, depths, mask, depths_norm)):
+ if not m:
+ continue
+
+ x, y = pt
+
+ # Color by depth (blue=near, red=far)
+ r = int(255 * dn)
+ g = int(255 * (1 - abs(2 * dn - 1)))
+ b_color = int(255 * (1 - dn))
+ color = (r, g, b_color)
+
+ # Draw small circle
+ radius = 2
+ draw.ellipse(
+ [x - radius, y - radius, x + radius, y + radius],
+ fill=color,
+ outline=color
+ )
+
+ vis_images.append(vis_img)
+
+ # Save if directory is specified
+ if save_dir is not None:
+ os.makedirs(save_dir, exist_ok=True)
+ save_path = os.path.join(save_dir, f"{prefix}_batch{b}.png")
+ vis_img.save(save_path)
+ print(f"Saved projection visualization to: {save_path}")
+
+ return vis_images
+
+
+# =============================================================================
+# DINOv3 Feature Extractor with Projection
+# =============================================================================
+
+class DinoV3ProjFeatureExtractor(nn.Module):
+ """
+ DINOv3 Feature Extractor with View-Aligned Projection.
+
+ This extractor produces both:
+ 1. Global features (CLS token + register tokens) in embed_dim
+ 2. View-aligned projected features (3D grid projected to 2D and sampled)
+ - Without NAF: [B, R³, embed_dim]
+ - With NAF: [B, R³, embed_dim * 2] (concat of lr and hr features)
+
+ NOTE: proj_linear has been moved to per-block ProjectAttention / SparseProjectAttention.
+ This module now outputs raw DINOv3 features for proj (optionally concatenated with NAF-upsampled features).
+
+ Args:
+ model_name: Name of the pretrained DINOv3 model
+ image_size: Input image size (default: 512)
+ grid_resolution: Resolution of the 3D projection grid (default: 16)
+ use_naf_upsample: Whether to use NAF to upsample features (default: False)
+ naf_target_size: Target spatial size for NAF upsampling (default: [128, 128])
+ """
+ def __init__(
+ self,
+ model_name: str,
+ image_size: int = 512,
+ grid_resolution: int = 16,
+ use_naf_upsample: bool = False,
+ naf_target_size: Optional[List[int]] = None,
+ ):
+ super().__init__()
+ self.model_name = model_name
+ self.image_size = image_size
+ self.grid_resolution = grid_resolution
+ self.use_naf_upsample = use_naf_upsample
+ if naf_target_size is None:
+ self.naf_target_size = (128, 128)
+ elif isinstance(naf_target_size, int):
+ self.naf_target_size = (naf_target_size, naf_target_size)
+ else:
+ self.naf_target_size = tuple(naf_target_size)
+
+ # Load DINOv3 model (frozen, no trainable params in this module)
+ self.model = DINOv3ViTModel.from_pretrained(model_name)
+ self.model.eval()
+ self.model.requires_grad_(False)
+
+ # Image transform (only normalize, no resize - assume already resized)
+ self.transform = transforms.Compose([
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ # Get patch info
+ self.patch_size = self.model.config.patch_size
+ self.patch_number = image_size // self.patch_size
+ self.embed_dim = self.model.config.hidden_size
+
+ # Projection grid for view-aligned features
+ self.proj_grid = ProjGrid(
+ grid_resolution=grid_resolution,
+ image_resolution=image_size
+ )
+
+ # NAF upsampler (frozen, no trainable params)
+ self.naf_model = None # Lazy-loaded on first use to avoid import if not needed
+
+ # proj_channels: the output dimension of proj features
+ # Without NAF: embed_dim (e.g. 1024)
+ # With NAF: embed_dim * 2 (e.g. 2048, concat of lr and hr)
+ self.proj_channels = self.embed_dim * 2 if use_naf_upsample else self.embed_dim
+
+ # NOTE: proj_linear removed — now lives in each denoiser block's ProjectAttention
+
+ def _load_naf(self):
+ """Lazy-load pretrained NAF model."""
+ if self.naf_model is None:
+ import torch.hub
+ device = next(self.model.parameters()).device
+ self.naf_model = torch.hub.load(
+ "valeoai/NAF", "naf", pretrained=True, device=device, trust_repo=True
+ )
+ self.naf_model.eval()
+ self.naf_model.requires_grad_(False)
+
+ def to(self, device):
+ super().to(device)
+ self.model.to(device)
+ self.proj_grid.to(device)
+ if self.naf_model is not None:
+ self.naf_model.to(device)
+ return self
+
+ def cuda(self):
+ super().cuda()
+ self.model.cuda()
+ self.proj_grid.cuda()
+ if self.naf_model is not None:
+ self.naf_model.cuda()
+ return self
+
+ def cpu(self):
+ super().cpu()
+ self.model.cpu()
+ self.proj_grid.cpu()
+ if self.naf_model is not None:
+ self.naf_model.cpu()
+ return self
+
+ def extract_features(self, image: torch.Tensor) -> torch.Tensor:
+ """Extract features using DINOv3."""
+ image = image.to(self.model.embeddings.patch_embeddings.weight.dtype)
+ hidden_states = self.model.embeddings(image, bool_masked_pos=None)
+ position_embeddings = self.model.rope_embeddings(image)
+
+ for layer_module in self.model.layer:
+ hidden_states = layer_module(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ )
+
+ return F.layer_norm(hidden_states, hidden_states.shape[-1:])
+
+ def forward(
+ self,
+ image: Union[torch.Tensor, List[Image.Image]],
+ camera_angle_x: Optional[torch.Tensor] = None,
+ distance: Optional[torch.Tensor] = None,
+ mesh_scale: Optional[torch.Tensor] = None,
+ transform_matrix: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Extract view-aligned features from the image.
+
+ Args:
+ image: Input image tensor [B, C, H, W] or list of PIL images
+ camera_angle_x: Camera FOV angle in radians [B]
+ distance: Camera distance [B]
+ mesh_scale: Mesh scale factor [B]
+ transform_matrix: Optional camera transform matrix [B, 4, 4]
+
+ Returns:
+ Tuple of (global_features, proj_features):
+ - global_features: [B, num_global_tokens, embed_dim]
+ - proj_features: [B, grid_resolution³, proj_channels]
+ where proj_channels = embed_dim (no NAF) or embed_dim*2 (with NAF)
+ """
+ # Handle input types
+ if isinstance(image, torch.Tensor):
+ assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
+ elif isinstance(image, list):
+ assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
+ image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+ image = torch.stack(image).cuda()
+ else:
+ raise ValueError(f"Unsupported type of image: {type(image)}")
+
+ B = image.shape[0]
+
+ # Keep a copy of the unnormalized image for NAF guide
+ if self.use_naf_upsample:
+ image_for_naf = image.clone() # [B, 3, H, W], in [0, 1] range
+
+ # Apply transform (ImageNet normalization)
+ image = self.transform(image)
+
+ # Extract DINOv3 features (frozen, no gradients)
+ with torch.no_grad():
+ z = self.extract_features(image)
+
+ # Split into CLS token, register tokens, and patch tokens
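+ # Assumed DINOv3 token layout: [CLS, reg_1 .. reg_num_reg, patch_1 .. patch_{h*w}]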
+ z_clstoken = z[:, 0:1] # [B, 1, D]
+ num_reg = getattr(self.model.config, 'num_register_tokens', 4)
+ z_regtokens = z[:, 1:1+num_reg] # [B, num_reg, D]
+ z_patchtokens = z[:, 1+num_reg:] # [B, num_patches, D]
+
+ # Reshape patch tokens to spatial grid: [B, h, w, D]
+ z_patchtokens_spatial = z_patchtokens.reshape(
+ B, self.patch_number, self.patch_number, -1
+ ) # [B, h, w, D]
+
+ if camera_angle_x is None or distance is None or mesh_scale is None:
+ raise ValueError("camera_angle_x, distance, and mesh_scale must be provided")
+
+ # --- Low-resolution branch: sample from DINOv3 patch feature map ---
+ z_proj_lr = self.proj_grid(
+ z_patchtokens_spatial,
+ camera_angle_x,
+ distance,
+ mesh_scale,
+ transform_matrix
+ ) # [B, grid_res³, D]
+
+ # --- High-resolution branch (NAF): upsample then sample ---
+ if self.use_naf_upsample:
+ self._load_naf()
+ # NAF expects: guide [B, 3, H, W], lr_features [B, C, h, w], target_size (H', W')
+ lr_features_bchw = z_patchtokens_spatial.permute(0, 3, 1, 2) # [B, D, h, w]
+ hr_features = self.naf_model(
+ image_for_naf, lr_features_bchw, self.naf_target_size
+ ) # [B, D, H', W']
+
+ # Sample from high-res feature map using same projection coordinates
+ z_proj_hr = self.proj_grid(
+ hr_features,
+ camera_angle_x,
+ distance,
+ mesh_scale,
+ transform_matrix,
+ BHWC=False # hr_features is [B, C, H', W']
+ ) # [B, grid_res³, D]
+
+ # Concatenate lr and hr: [B, grid_res³, D*2]
+ z_proj = torch.cat([z_proj_lr, z_proj_hr], dim=-1)
+ else:
+ z_proj = z_proj_lr # [B, grid_res³, D]
+
+ # Combine global tokens
+ z_global = torch.cat([z_clstoken, z_regtokens], dim=1) # [B, 1+num_reg, D]
+
+ # proj_linear has been moved to per-block ProjectAttention
+ # z_proj stays in proj_channels, each block will project independently
+
+ return z_global, z_proj
+
+ @torch.no_grad()
+ def visualize_projection(
+ self,
+ image: torch.Tensor,
+ camera_angle_x: torch.Tensor,
+ distance: torch.Tensor,
+ mesh_scale: torch.Tensor,
+ transform_matrix: Optional[torch.Tensor] = None,
+ save_dir: Optional[str] = None,
+ prefix: str = "proj_vis",
+ ) -> List[Image.Image]:
+ """
+ Visualize the projected 3D grid points on the input image.
+
+ This is a convenience method that delegates to ProjGrid.visualize_projection.
+
+ Args:
+ image: Input image tensor [B, C, H, W], in [0, 1] range (before ImageNet normalization)
+ camera_angle_x: Camera FOV angle, shape [B]
+ distance: Camera distance, shape [B]
+ mesh_scale: Mesh scale factor, shape [B]
+ transform_matrix: Optional camera transform matrix, shape [B, 4, 4]
+ save_dir: Directory to save visualizations (optional)
+ prefix: Prefix for saved files
+
+ Returns:
+ List of PIL Images with projected points overlaid
+ """
+ return self.proj_grid.visualize_projection(
+ image=image,
+ camera_angle_x=camera_angle_x,
+ distance=distance,
+ mesh_scale=mesh_scale,
+ transform_matrix=transform_matrix,
+ save_dir=save_dir,
+ prefix=prefix,
+ )
+
+
+# =============================================================================
+# DINOv3 + VAE Gated Feature Extractor with Projection
+# =============================================================================
+
+class DinoV3VaeProjFeatureExtractor(nn.Module):
+ """
+ DINOv3 + Flux VAE Feature Extractor with Gated Fusion and View-Aligned Projection.
+
+ Produces three outputs for GatedProjectAttention:
+ 1. Global features (CLS + register tokens from DINOv3) for cross-attention
+ 2. Semantic proj features (DINOv3 patch tokens projected to 3D grid)
+ 3. Color proj features (Flux VAE latent projected to 3D grid)
+
+ Both DINOv3 and VAE are frozen. The gated fusion happens inside each
+ denoiser block's GatedProjectAttention module (trainable gate + proj_linears).
+
+ Args:
+ dino_model_name: Pretrained DINOv3 model name
+ vae_model_name: Pretrained Flux VAE model name
+ image_size: Input image size (default: 512)
+ grid_resolution: Resolution of the 3D projection grid (default: 16)
+ """
+ def __init__(
+ self,
+ dino_model_name: str,
+ vae_model_name: str = "black-forest-labs/FLUX.1-dev",
+ image_size: int = 512,
+ grid_resolution: int = 16,
+ ):
+ super().__init__()
+ self.image_size = image_size
+ self.grid_resolution = grid_resolution
+
+ # --- DINOv3 backbone (frozen) ---
+ self.dino_model = DINOv3ViTModel.from_pretrained(dino_model_name)
+ self.dino_model.eval()
+ self.dino_model.requires_grad_(False)
+
+ self.dino_transform = transforms.Compose([
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ self.patch_size = self.dino_model.config.patch_size
+ self.patch_number = image_size // self.patch_size
+ self.embed_dim = self.dino_model.config.hidden_size # e.g. 1024
+
+ # --- Flux VAE encoder (frozen, lazy-loaded) ---
+ self.vae_model_name = vae_model_name
+ self._vae = None
+ self.vae_channels = 16 # Flux VAE outputs 16 channels
+ self.vae_downsample = 8 # Flux VAE downsamples by 8x
+
+ # --- Projection grid (shared) ---
+ self.proj_grid = ProjGrid(
+ grid_resolution=grid_resolution,
+ image_resolution=image_size,
+ )
+
+ # Expose dimensions for denoiser block construction
+ self.dino_proj_channels = self.embed_dim # e.g. 1024
+ self.vae_proj_channels = self.vae_channels # 16
+ # proj_channels is kept for backward compat with _proj_channels in mixin
+ self.proj_channels = self.embed_dim
+
+ def _load_vae(self):
+ """Lazy-load Flux VAE encoder."""
+ if self._vae is not None:
+ return
+ from diffusers import AutoencoderKL
+ device = next(self.dino_model.parameters()).device
+ vae = AutoencoderKL.from_pretrained(
+ self.vae_model_name,
+ subfolder="vae",
+ torch_dtype=torch.float32,
+ )
+ vae.eval()
+ vae.requires_grad_(False)
+ vae.to(device)
+ self._vae = vae
+
+ def to(self, device):
+ super().to(device)
+ self.dino_model.to(device)
+ self.proj_grid.to(device)
+ if self._vae is not None:
+ self._vae.to(device)
+ return self
+
+ def cuda(self):
+ super().cuda()
+ self.dino_model.cuda()
+ self.proj_grid.cuda()
+ if self._vae is not None:
+ self._vae.cuda()
+ return self
+
+ def cpu(self):
+ super().cpu()
+ self.dino_model.cpu()
+ self.proj_grid.cpu()
+ if self._vae is not None:
+ self._vae.cpu()
+ return self
+
+ def _extract_dino_features(self, image: torch.Tensor) -> torch.Tensor:
+ """Extract DINOv3 features from normalized image."""
+ image = image.to(self.dino_model.embeddings.patch_embeddings.weight.dtype)
+ hidden_states = self.dino_model.embeddings(image, bool_masked_pos=None)
+ position_embeddings = self.dino_model.rope_embeddings(image)
+ for layer_module in self.dino_model.layer:
+ hidden_states = layer_module(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ )
+ return F.layer_norm(hidden_states, hidden_states.shape[-1:])
+
+ @torch.no_grad()
+ def _extract_vae_latent(self, image: torch.Tensor) -> torch.Tensor:
+ """Extract Flux VAE latent from unnormalized image [0,1]."""
+ self._load_vae()
+ image_normalized = image * 2.0 - 1.0
+ image_normalized = image_normalized.to(self._vae.dtype)
+ posterior = self._vae.encode(image_normalized)
+ latent = posterior.latent_dist.mode()
+ latent = latent * self._vae.config.scaling_factor
+ return latent.float() # [B, 16, H/8, W/8]
+
+ def forward(
+ self,
+ image: Union[torch.Tensor, List[Image.Image]],
+ camera_angle_x: Optional[torch.Tensor] = None,
+ distance: Optional[torch.Tensor] = None,
+ mesh_scale: Optional[torch.Tensor] = None,
+ transform_matrix: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Extract gated features from the image.
+
+ Returns:
+ Tuple of (global_features, proj_semantic, proj_color):
+ - global_features: [B, num_global_tokens, embed_dim] (DINOv3 CLS + registers)
+ - proj_semantic: [B, grid_res³, embed_dim] (DINOv3 projected features)
+ - proj_color: [B, grid_res³, vae_channels] (VAE projected features)
+ """
+ # Handle input types
+ if isinstance(image, torch.Tensor):
+ assert image.ndim == 4
+ elif isinstance(image, list):
+ assert all(isinstance(i, Image.Image) for i in image)
+ image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+ image = torch.stack(image).cuda()
+ else:
+ raise ValueError(f"Unsupported type of image: {type(image)}")
+
+ B = image.shape[0]
+ image_raw = image.clone() # Keep unnormalized copy for VAE
+
+ if camera_angle_x is None or distance is None or mesh_scale is None:
+ raise ValueError("camera_angle_x, distance, and mesh_scale must be provided")
+
+ with torch.no_grad():
+ # --- DINOv3 branch ---
+ dino_input = self.dino_transform(image)
+ z = self._extract_dino_features(dino_input)
+
+ z_clstoken = z[:, 0:1]
+ num_reg = getattr(self.dino_model.config, 'num_register_tokens', 4)
+ z_regtokens = z[:, 1:1+num_reg]
+ z_patchtokens = z[:, 1+num_reg:]
+
+ z_patchtokens_spatial = z_patchtokens.reshape(
+ B, self.patch_number, self.patch_number, -1
+ ) # [B, h, w, D]
+
+ proj_semantic = self.proj_grid(
+ z_patchtokens_spatial,
+ camera_angle_x, distance, mesh_scale, transform_matrix,
+ ) # [B, R³, embed_dim]
+
+ z_global = torch.cat([z_clstoken, z_regtokens], dim=1) # [B, 1+num_reg, D]
+
+ # --- VAE branch ---
+ vae_latent = self._extract_vae_latent(image_raw) # [B, 16, H/8, W/8]
+
+ proj_color = self.proj_grid(
+ vae_latent,
+ camera_angle_x, distance, mesh_scale, transform_matrix,
+ BHWC=False, # VAE latent is [B, C, H, W]
+ ) # [B, R³, 16]
+
+ return z_global, proj_semantic, proj_color
+
+
+# =============================================================================
+# Image Conditioned Mixin with Projection Support
+# =============================================================================
+
+class ImageConditionedProjMixin:
+ """
+ Mixin for image-conditioned models with view-aligned projection.
+
+ This mixin adds support for extracting view-aligned features from images
+ using camera parameters.
+
+ Args:
+ image_cond_model: Configuration for the image conditioning model.
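+
+ Example config (keys as consumed by ``_init_image_cond_model``; values are placeholders):
+ image_cond_model = {
+ 'name': 'DinoV3ProjFeatureExtractor',
+ 'args': {'model_name': '<pretrained-dinov3-id>', 'image_size': 512, 'grid_resolution': 16},
+ 'image_attn_mode': 'proj', # one of 'cross' (default), 'proj', 'gated_proj'
+ }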
+ """
+ def __init__(self, *args, image_cond_model: dict, **kwargs):
+ # Store config before super().__init__ which calls init_models_and_more
+ self.image_cond_model_config = image_cond_model
+ self.image_cond_model = None # Will be initialized in init_models_and_more
+ self.image_attn_mode = image_cond_model.get('image_attn_mode',
+ image_cond_model.get('args', {}).get('image_attn_mode', 'cross'))
+ super().__init__(*args, **kwargs)
+
+ def _init_image_cond_model(self):
+ """Initialize the image conditioning model."""
+ with dist_utils.local_master_first():
+ model_name = self.image_cond_model_config['name']
+ model_args = self.image_cond_model_config.get('args', {})
+
+ if model_name == 'DinoV3ProjFeatureExtractor':
+ self.image_cond_model = DinoV3ProjFeatureExtractor(**model_args)
+ elif model_name == 'DinoV3VaeProjFeatureExtractor':
+ self.image_cond_model = DinoV3VaeProjFeatureExtractor(**model_args)
+ else:
+ # Fallback to standard extractors
+ from . import image_conditioned
+ self.image_cond_model = getattr(image_conditioned, model_name)(**model_args)
+
+ self.image_cond_model.cuda()
+
+ # Expose proj_channels for denoiser to know the correct proj_in_channels
+ if hasattr(self.image_cond_model, 'proj_channels'):
+ self._proj_channels = self.image_cond_model.proj_channels
+ else:
+ self._proj_channels = getattr(self.image_cond_model, 'embed_dim', None)
+ # Expose vae_proj_channels for gated_proj mode
+ self._vae_proj_channels = getattr(self.image_cond_model, 'vae_proj_channels', None)
+
+ def init_models_and_more(self, **kwargs):
+ """
+ Override to handle image_cond_model initialization.
+
+ Since proj_linear has been moved to per-block ProjectAttention in the denoiser,
+ image_cond_model no longer has any trainable parameters (DINOv3 backbone is frozen,
+ ProjGrid only has register_buffers). Therefore we do NOT add it to self.models
+ (which would trigger DDP wrapping and fail). We just initialize it and keep it
+ as a standalone module for inference.
+ """
+ # Initialize image_cond_model first
+ if self.image_cond_model is None:
+ self._init_image_cond_model()
+
+ # Keep a reference to the unwrapped module for attribute access
+ self._image_cond_module = self.image_cond_model # for .grid_resolution etc.
+
+ # Log that image_cond has no trainable params
+ proj_params = [p for p in self.image_cond_model.parameters() if p.requires_grad]
+ if self.is_master:
+ if proj_params:
+ print(f'\nWARNING: image_cond_model has {len(proj_params)} trainable params, '
+ f'but is NOT registered in self.models. These will NOT be trained!')
+ else:
+ print(f'\nimage_cond_model has no trainable parameters, skipping DDP/optimizer registration.')
+
+ # Call base class to set up DDP, optimizer, EMA, etc. (without image_cond)
+ super().init_models_and_more(**kwargs)
+
+ # ------------------------------------------------------------------
+ # Checkpoint save/load overrides: skip DINOv3 backbone weights
+ # ------------------------------------------------------------------
+
+ # Keys in image_cond state_dict that belong to the frozen DINOv3 backbone.
+ # Everything under "model." is DINOv3; we only keep proj_grid.*
+ _IMAGE_COND_BACKBONE_PREFIX = 'model.'
+
+ def _filter_image_cond_state_dict(self, state_dict: dict) -> dict:
+ """Keep only non-backbone keys (proj_grid, etc.) from image_cond state_dict."""
+ return {k: v for k, v in state_dict.items()
+ if not k.startswith(self._IMAGE_COND_BACKBONE_PREFIX)}
+
+ def _fill_denoiser_proj_linear_from_image_cond(
+ self,
+ denoiser_ckpt: dict,
+ denoiser_state_dict: dict,
+ image_cond_ckpt_path: Optional[str] = None,
+ ) -> dict:
+ """
+ Fill missing per-block proj_linear weights in denoiser checkpoint
+ from the old-style image_cond proj_linear (broadcast to all blocks).
+
+ Also handles shape mismatch when NAF is enabled: old proj_linear has shape
+ [model_ch, embed_dim] but new model expects [model_ch, embed_dim*2].
+ In this case, the old weights are placed in the lr half and the hr half is zero-padded.
+
+ Compatibility strategy:
+ 1. If denoiser_ckpt already contains per-block proj_linear keys with correct shape -> do nothing.
+ 2. If shape mismatch (embed_dim vs embed_dim*2) -> zero-pad the weight.
+ 3. If keys missing, try to load proj_linear from image_cond checkpoint -> broadcast (with optional pad).
+
+ Args:
+ denoiser_ckpt: The loaded denoiser state dict
+ denoiser_state_dict: The model's current state dict (to find expected keys)
+ image_cond_ckpt_path: Path to image_cond checkpoint file (optional)
+
+ Returns:
+ Updated denoiser_ckpt with proj_linear keys filled if needed
+ """
+ if self.image_attn_mode != 'proj':
+ return denoiser_ckpt
+
+ # Find all per-block proj_linear keys expected by the model
+ proj_linear_keys = [k for k in denoiser_state_dict.keys()
+ if '.cross_attn.proj_linear.' in k]
+ if not proj_linear_keys:
+ return denoiser_ckpt
+
+ # --- Phase 1: Handle shape mismatch for existing keys (NAF upgrade) ---
+ for k in proj_linear_keys:
+ if k in denoiser_ckpt:
+ expected_shape = denoiser_state_dict[k].shape
+ actual_shape = denoiser_ckpt[k].shape
+ if expected_shape != actual_shape:
+ if k.endswith('.weight') and len(expected_shape) == 2:
+ # Weight shape: [out_features, in_features]
+ # Old: [model_ch, embed_dim], New: [model_ch, embed_dim*2]
+ out_f, new_in_f = expected_shape
+ _, old_in_f = actual_shape
+ if new_in_f > old_in_f and out_f == actual_shape[0]:
+ if self.is_master:
+ print(f'\n [NAF Compat] Padding proj_linear weight {k}: '
+ f'{actual_shape} -> {expected_shape} (zero-pad hr half)')
+ new_w = torch.zeros(expected_shape, dtype=denoiser_ckpt[k].dtype,
+ device=denoiser_ckpt[k].device)
+ new_w[:, :old_in_f] = denoiser_ckpt[k]
+ denoiser_ckpt[k] = new_w
+ else:
+ if self.is_master:
+ print(f'\n Warning: proj_linear {k} shape mismatch '
+ f'{actual_shape} vs {expected_shape}, using model init')
+ denoiser_ckpt[k] = denoiser_state_dict[k]
+ # bias shape should match (out_features only), no padding needed
+
+ # --- Phase 2: Handle completely missing keys ---
+ missing_proj_keys = [k for k in proj_linear_keys if k not in denoiser_ckpt]
+ if not missing_proj_keys:
+ return denoiser_ckpt
+
+ if self.is_master:
+ print(f'\n [Compat] Denoiser ckpt missing {len(missing_proj_keys)} per-block proj_linear keys.')
+ print(f' Attempting to load from image_cond proj_linear: {image_cond_ckpt_path}')
+
+ # Try to find proj_linear weights from image_cond checkpoint
+ old_proj_linear_w = None
+ old_proj_linear_b = None
+
+ if image_cond_ckpt_path is not None:
+ import os as _os
+ if _os.path.exists(image_cond_ckpt_path):
+ try:
+ ic_ckpt = torch.load(image_cond_ckpt_path, map_location=self.device, weights_only=True)
+ old_proj_linear_w = ic_ckpt.get('proj_linear.weight')
+ old_proj_linear_b = ic_ckpt.get('proj_linear.bias')
+ except Exception as e:
+ if self.is_master:
+ print(f' Warning: Failed to load image_cond ckpt: {e}')
+
+ if old_proj_linear_w is None:
+ raise RuntimeError(
+ f'Denoiser checkpoint is missing per-block proj_linear keys '
+ f'(e.g. {missing_proj_keys[0]}), and no image_cond proj_linear '
+ f'was found to broadcast from. Cannot proceed.'
+ )
+
+ if self.is_master:
+ print(f' Found image_cond proj_linear: weight {old_proj_linear_w.shape}, bias {old_proj_linear_b.shape}')
+ print(f' Broadcasting to {len(missing_proj_keys)} per-block keys...')
+
+ for k in missing_proj_keys:
+ if k.endswith('.weight'):
+ expected_shape = denoiser_state_dict[k].shape
+ if expected_shape != old_proj_linear_w.shape:
+ # Pad for NAF: [model_ch, embed_dim] -> [model_ch, embed_dim*2]
+ out_f, new_in_f = expected_shape
+ _, old_in_f = old_proj_linear_w.shape
+ if new_in_f > old_in_f and out_f == old_proj_linear_w.shape[0]:
+ new_w = torch.zeros(expected_shape, dtype=old_proj_linear_w.dtype)
+ new_w[:, :old_in_f] = old_proj_linear_w
+ denoiser_ckpt[k] = new_w
+ else:
+ denoiser_ckpt[k] = denoiser_state_dict[k]
+ else:
+ denoiser_ckpt[k] = old_proj_linear_w.clone()
+ elif k.endswith('.bias'):
+ denoiser_ckpt[k] = old_proj_linear_b.clone()
+
+ return denoiser_ckpt
+
+ def _master_params_to_state_dicts(self, master_params):
+ """Override to skip image_cond checkpoint entirely.
+
+ image_cond model no longer has trainable parameters:
+ - proj_linear has been moved to per-block ProjectAttention in the denoiser
+ - DINOv3 backbone is frozen and loaded from pretrained weights
+ - ProjGrid only contains fixed register_buffers (grid_points, front_view_transform_matrix)
+ So there is nothing worth saving for image_cond.
+ """
+ state_dicts = super()._master_params_to_state_dicts(master_params)
+ state_dicts.pop('image_cond', None)
+ return state_dicts
+
+ def load(self, load_dir, step=0):
+ """
+ Override to handle:
+ 1. Old checkpoints that don't have image_cond_step*.pt
+ 2. Partial image_cond checkpoints (only proj_linear + proj_grid, no DINOv3 backbone)
+ """
+ import os as _os
+
+ if self.is_master:
+ print(f'\nLoading checkpoint from step {step}...', end='')
+
+ model_ckpts = {}
+ for name, model in self.models.items():
+ ckpt_path = _os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')
+
+ if name == 'image_cond':
+ # --- handle missing or partial image_cond checkpoint ---
+ if not _os.path.exists(ckpt_path):
+ if self.is_master:
+ print(f'\n image_cond checkpoint not found at {ckpt_path}, using freshly initialised weights.')
+ model_ckpts[name] = model.state_dict()
+ continue
+
+ try:
+ model_ckpt = torch.load(
+ read_file_dist(ckpt_path),
+ map_location=self.device, weights_only=True)
+ except Exception as e:
+ if self.is_master:
+ print(f'\n Failed to load image_cond checkpoint: {e}. Using freshly initialised weights.')
+ model_ckpts[name] = model.state_dict()
+ continue
+
+ # Partial ckpt (no backbone) → load with strict=False
+ missing, unexpected = model.load_state_dict(model_ckpt, strict=False)
+ # All missing keys should be the frozen DINOv3 backbone; verify
+ non_backbone_missing = [k for k in missing
+ if not k.startswith(self._IMAGE_COND_BACKBONE_PREFIX)]
+ if non_backbone_missing and self.is_master:
+ print(f'\n Warning: unexpected missing keys in image_cond ckpt: {non_backbone_missing}')
+ if unexpected and self.is_master:
+ print(f'\n Warning: unexpected keys in image_cond ckpt: {unexpected}')
+
+ # Build a full state_dict for master_params sync
+ full_sd = model.state_dict()
+ full_sd.update(model_ckpt)
+ model_ckpts[name] = full_sd
+ else:
+ model_ckpt = torch.load(
+ read_file_dist(ckpt_path),
+ map_location=self.device, weights_only=True)
+ # For denoiser: handle old ckpts missing per-block proj_linear
+ if name == 'denoiser':
+ ic_ckpt_path = _os.path.join(load_dir, 'ckpts', f'image_cond_step{step:07d}.pt')
+ model_ckpt = self._fill_denoiser_proj_linear_from_image_cond(
+ model_ckpt, model.state_dict(), ic_ckpt_path)
+ model_ckpts[name] = model_ckpt
+ model.load_state_dict(model_ckpt)
+
+ self._state_dicts_to_master_params(self.master_params, model_ckpts)
+ del model_ckpts
+
+ if self.is_master:
+ for i, ema_rate in enumerate(self.ema_rate):
+ ema_ckpts = {}
+ for name, model in self.models.items():
+ ema_path = _os.path.join(
+ load_dir, 'ckpts',
+ f'{name}_ema{ema_rate}_step{step:07d}.pt')
+ if name == 'image_cond':
+ if not _os.path.exists(ema_path):
+ ema_ckpts[name] = model.state_dict()
+ continue
+ try:
+ ema_ckpt = torch.load(ema_path, map_location=self.device, weights_only=True)
+ except Exception:
+ ema_ckpts[name] = model.state_dict()
+ continue
+ full_sd = model.state_dict()
+ full_sd.update(ema_ckpt)
+ ema_ckpts[name] = full_sd
+ else:
+ ema_ckpt = torch.load(ema_path, map_location=self.device, weights_only=True)
+ if name == 'denoiser':
+ ic_ema_path = _os.path.join(
+ load_dir, 'ckpts',
+ f'image_cond_ema{ema_rate}_step{step:07d}.pt')
+ ema_ckpt = self._fill_denoiser_proj_linear_from_image_cond(
+ ema_ckpt, model.state_dict(), ic_ema_path)
+ ema_ckpts[name] = ema_ckpt
+ self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
+ del ema_ckpts
+
+ misc_ckpt = torch.load(
+ read_file_dist(_os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')),
+ map_location=torch.device('cpu'), weights_only=False)
+ # Optimizer state may mismatch when loading old checkpoints that were
+ # saved before image_cond was added to self.models, or when the number
+ # of trainable parameters changed (e.g. backbone freeze, NAF upgrade).
+ # In that case we skip restoring optimizer state and let it re-initialise.
+ try:
+ self.optimizer.load_state_dict(misc_ckpt['optimizer'])
+ # Verify optimizer state shapes match parameters.
+ # load_state_dict may succeed even when shapes mismatch (keys are
+ # integer indices), causing a crash later in optimizer.step().
+ _shape_ok = True
+ for group in self.optimizer.param_groups:
+ for p in group['params']:
+ state = self.optimizer.state.get(p)
+ if state is not None:
+ for sv in state.values():
+ if isinstance(sv, torch.Tensor) and sv.shape != () and sv.shape != p.shape:
+ _shape_ok = False
+ break
+ if not _shape_ok:
+ break
+ if not _shape_ok:
+ break
+ if not _shape_ok:
+ if self.is_master:
+ print(f'\n Warning: optimizer state shape mismatch (likely NAF upgrade). '
+ f'Optimizer will start fresh.')
+ self.optimizer.state.clear()
+ except (ValueError, RuntimeError) as e:
+ if self.is_master:
+ print(f'\n Warning: could not load optimizer state ({e}). '
+ f'Optimizer will start fresh.')
+ self.step = misc_ckpt['step']
+ self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
+ if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
+ self.scaler.load_state_dict(misc_ckpt['scaler'])
+ elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
+ self.log_scale = misc_ckpt['log_scale']
+ if self.lr_scheduler_config is not None:
+ self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
+ if self.elastic_controller_config is not None:
+ self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
+ if self.grad_clip is not None and not isinstance(self.grad_clip, float):
+ self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
+ del misc_ckpt
+
+ if self.world_size > 1:
+ dist.barrier()
+ if self.is_master:
+ print(' Done.')
+
+ if self.world_size > 1:
+ self.check_ddp()
+
+ def finetune_from(self, finetune_ckpt):
+ """
+ Override to tolerate DINOv3 backbone keys missing from image_cond checkpoint.
+ For image_cond, the checkpoint only stores proj_linear + proj_grid (no backbone),
+ so we treat all backbone keys as allowed-missing.
+ """
+ ALLOWED_MISSING_KEYS = {'rope_phases'}
+
+ if self.is_master:
+ print('\nFinetuning from:')
+ for name, path in finetune_ckpt.items():
+ print(f' - {name}: {path}')
+
+ model_ckpts = {}
+ for name, model in self.models.items():
+ model_state_dict = model.state_dict()
+ if name in finetune_ckpt:
+ model_ckpt = torch.load(
+ read_file_dist(finetune_ckpt[name]),
+ map_location=self.device, weights_only=True)
+
+ model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict)
+
+ for k, v in model_ckpt.items():
+ if k not in model_state_dict:
+ if self.is_master:
+ print(f'Warning: {k} not found in model_state_dict, skipped.')
+ model_ckpt[k] = None
+ elif model_ckpt[k].shape != model_state_dict[k].shape:
+ # For proj_linear weights, try zero-pad instead of skipping
+ # This handles NAF upgrade: [model_ch, embed_dim] -> [model_ch, embed_dim*2]
+ if '.cross_attn.proj_linear.weight' in k and len(model_ckpt[k].shape) == 2:
+ old_shape = model_ckpt[k].shape
+ new_shape = model_state_dict[k].shape
+ if new_shape[0] == old_shape[0] and new_shape[1] > old_shape[1]:
+ if self.is_master:
+ print(f'Info: Zero-padding proj_linear weight {k}: {old_shape} -> {new_shape}')
+ new_w = torch.zeros(new_shape, dtype=model_ckpt[k].dtype)
+ new_w[:, :old_shape[1]] = model_ckpt[k]
+ model_ckpt[k] = new_w
+ else:
+ if self.is_master:
+ print(f'Warning: {k} shape mismatch, {old_shape} vs {new_shape}, skipped.')
+ model_ckpt[k] = model_state_dict[k]
+ else:
+ if self.is_master:
+ print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
+ model_ckpt[k] = model_state_dict[k]
+ model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None}
+
+ missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys())
+
+ # For denoiser: fill per-block proj_linear from image_cond if missing
+ if name == 'denoiser':
+ ic_path = finetune_ckpt.get('image_cond')
+ proj_linear_missing = {k for k in missing_keys if '.cross_attn.proj_linear.' in k}
+ if proj_linear_missing:
+ model_ckpt = self._fill_denoiser_proj_linear_from_image_cond(
+ model_ckpt, model_state_dict, ic_path)
+ # Recalculate missing_keys after filling
+ missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys())
+
+ # For image_cond, DINOv3 backbone keys are expected to be missing
+ allowed = set(ALLOWED_MISSING_KEYS)
+ if name == 'image_cond':
+ backbone_missing = {k for k in missing_keys
+ if k.startswith(self._IMAGE_COND_BACKBONE_PREFIX)}
+ allowed |= backbone_missing
+ if backbone_missing and self.is_master:
+ print(f'Info: image_cond: {len(backbone_missing)} DINOv3 backbone keys '
+ f'not in ckpt (expected, using pretrained weights)')
+ # Old ckpts may have proj_linear.* which has moved to denoiser blocks
+ proj_linear_missing = {k for k in missing_keys if k.startswith('proj_linear.')}
+ allowed |= proj_linear_missing
+
+ unexpected_missing = missing_keys - allowed
+ if unexpected_missing and self.is_master:
+ print(f'Error: Missing keys in checkpoint: {unexpected_missing}')
+ raise RuntimeError(f'Missing keys in checkpoint: {unexpected_missing}')
+ if missing_keys & ALLOWED_MISSING_KEYS and self.is_master:
+ print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}')
+
+ for k in missing_keys:
+ model_ckpt[k] = model_state_dict[k]
+
+ model_ckpts[name] = model_ckpt
+ model.load_state_dict(model_ckpt)
+ else:
+ if self.is_master:
+ print(f'Warning: {name} not found in finetune_ckpt, skipped.')
+ model_ckpts[name] = model_state_dict
+
+ self._state_dicts_to_master_params(self.master_params, model_ckpts)
+ if self.is_master:
+ for i, ema_rate in enumerate(self.ema_rate):
+ self._state_dicts_to_master_params(self.ema_params[i], model_ckpts)
+ del model_ckpts
+
+ if self.world_size > 1:
+ dist.barrier()
+ if self.is_master:
+ print('Done.')
+
+ if self.world_size > 1:
+ self.check_ddp()
+
+ def encode_image_proj(
+ self,
+ image: torch.Tensor,
+ camera_angle_x: Optional[torch.Tensor] = None,
+ distance: Optional[torch.Tensor] = None,
+ mesh_scale: Optional[torch.Tensor] = None,
+ transform_matrix: Optional[torch.Tensor] = None,
+ coords: Optional[torch.Tensor] = None,
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+ """
+ Encode the image with view-aligned projection.
+
+ Supports both 'proj' mode (DINOv3 only, 2 outputs) and
+ 'gated_proj' mode (DINOv3 + VAE, 3 outputs).
+ """
+ if self.image_cond_model is None:
+ self._init_image_cond_model()
+
+ outputs = self.image_cond_model(
+ image,
+ camera_angle_x=camera_angle_x,
+ distance=distance,
+ mesh_scale=mesh_scale,
+ transform_matrix=transform_matrix,
+ )
+
+ is_gated = self.image_attn_mode == 'gated_proj'
+
+ if is_gated:
+ cond_global, cond_proj_semantic, cond_proj_color = outputs
+ else:
+ cond_global, cond_proj = outputs
+
+ # If coords provided, extract features at sparse positions
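+ # coords is assumed to be an [N, 4] tensor of (batch_idx, x, y, z) voxel indices into the R^3 grid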
+ if coords is not None:
+ B = cond_global.shape[0]
+ module = getattr(self, '_image_cond_module', self.image_cond_model)
+ grid_res = module.grid_resolution
+ batch_indices = coords[:, 0].long()
+ x_coords = coords[:, 1].long()
+ y_coords = coords[:, 2].long()
+ z_coords = coords[:, 3].long()
+
+ if is_gated:
+ cond_proj_semantic = cond_proj_semantic.reshape(B, grid_res, grid_res, grid_res, -1)
+ cond_proj_semantic = cond_proj_semantic[batch_indices, x_coords, y_coords, z_coords]
+ cond_proj_color = cond_proj_color.reshape(B, grid_res, grid_res, grid_res, -1)
+ cond_proj_color = cond_proj_color[batch_indices, x_coords, y_coords, z_coords]
+ else:
+ cond_proj = cond_proj.reshape(B, grid_res, grid_res, grid_res, -1)
+ cond_proj = cond_proj[batch_indices, x_coords, y_coords, z_coords]
+
+ if is_gated:
+ cond = {
+ 'global': cond_global,
+ 'proj_semantic': cond_proj_semantic,
+ 'proj_color': cond_proj_color,
+ }
+ uncond = {
+ 'global': torch.zeros_like(cond_global),
+ 'proj_semantic': torch.zeros_like(cond_proj_semantic),
+ 'proj_color': torch.zeros_like(cond_proj_color),
+ }
+ else:
+ cond = {'global': cond_global, 'proj': cond_proj}
+ uncond = {'global': torch.zeros_like(cond_global), 'proj': torch.zeros_like(cond_proj)}
+
+ return cond, uncond
+
+ @torch.no_grad()
+ def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
+ """
+ Encode the image (standard mode without projection).
+ """
+ if self.image_cond_model is None:
+ self._init_image_cond_model()
+
+ if self.image_attn_mode == 'proj':
+ # For proj mode, return dict
+ global_feat, proj_feat = self.image_cond_model(image)
+ return {'global': global_feat, 'proj': proj_feat}
+ else:
+ # Standard mode
+ features = self.image_cond_model(image)
+ return features
+
+ def _extract_camera_info(self, kwargs):
+ """
+ Extract camera info from kwargs.
+
+ Supports two formats:
+ 1. 'camera_info' dict: {'camera_angle_x': ..., 'distance': ..., 'mesh_scale': ..., 'transform_matrix': ..., 'coords': ...}
+ 2. Flat fields: 'camera_angle_x', 'camera_distance', 'mesh_scale', 'transform_matrix', 'coords' in kwargs
+
+ Returns:
+ camera_info dict or None if not available
+ """
+ if 'camera_info' in kwargs:
+ return kwargs.pop('camera_info')
+
+ # Try to extract from flat fields (as returned by ViewImageConditionedMixin)
+ camera_angle_x = kwargs.pop('camera_angle_x', None)
+ camera_distance = kwargs.pop('camera_distance', None)
+ mesh_scale = kwargs.pop('mesh_scale', None)
+ transform_matrix = kwargs.pop('transform_matrix', None)
+ coords = kwargs.pop('coords', None)
+
+ if camera_angle_x is not None and camera_distance is not None and mesh_scale is not None:
+ return {
+ 'camera_angle_x': camera_angle_x,
+ 'distance': camera_distance,
+ 'mesh_scale': mesh_scale,
+ 'transform_matrix': transform_matrix,
+ 'coords': coords,
+ }
+
+ return None
+
+ def get_cond(self, cond, **kwargs):
+ """Get the conditioning data."""
+ kwargs.pop('view_idx', None)
+
+ if self.image_attn_mode in ('proj', 'gated_proj'):
+ # Handle projection mode (both standard proj and gated_proj)
+ camera_info = self._extract_camera_info(kwargs)
+ if camera_info is not None:
+ coords = camera_info.get('coords')
+ cond, neg_cond = self.encode_image_proj(
+ cond,
+ camera_angle_x=camera_info.get('camera_angle_x'),
+ distance=camera_info.get('distance'),
+ mesh_scale=camera_info.get('mesh_scale'),
+ transform_matrix=camera_info.get('transform_matrix'),
+ coords=coords,
+ )
+
+ # For sparse mode (coords provided), handle CFG dropout ourselves
+ if coords is not None and hasattr(self, 'p_uncond') and self.p_uncond > 0:
+ import numpy as np
+ B = cond['global'].shape[0]
+ mask = np.random.rand(B) < self.p_uncond
+
+ global_tensor = cond['global']
+ global_mask_shape = [B] + [1] * (global_tensor.ndim - 1)
+ global_mask = torch.tensor(mask, device=global_tensor.device).reshape(global_mask_shape)
+ cond['global'] = torch.where(global_mask, neg_cond['global'], cond['global'])
+
+ batch_indices = coords[:, 0].long()
+ # Handle all sparse proj keys (proj, or proj_semantic + proj_color)
+ for key in list(cond.keys()):
+ if key.startswith('proj'):
+ device = cond[key].device
+ sparse_mask = torch.tensor(mask, device=device)[batch_indices].reshape(-1, 1)
+ cond[key] = torch.where(sparse_mask, neg_cond[key], cond[key])
+
+ return cond
+ else:
+ kwargs['neg_cond'] = neg_cond
+ else:
+ cond = self.encode_image(cond)
+ if isinstance(cond, dict) and 'global' in cond:
+ kwargs['neg_cond'] = {k: torch.zeros_like(v) for k, v in cond.items()}
+ else:
+ kwargs['neg_cond'] = torch.zeros_like(cond)
+ else:
+ cond = self.encode_image(cond)
+ kwargs['neg_cond'] = torch.zeros_like(cond)
+
+ cond = super().get_cond(cond, **kwargs)
+ return cond
+
+ def get_inference_cond(self, cond, **kwargs):
+ """Get the conditioning data for inference."""
+ kwargs.pop('view_idx', None)
+
+ if self.image_attn_mode in ('proj', 'gated_proj'):
+ camera_info = self._extract_camera_info(kwargs)
+ if camera_info is not None:
+ cond, neg_cond = self.encode_image_proj(
+ cond,
+ camera_angle_x=camera_info.get('camera_angle_x'),
+ distance=camera_info.get('distance'),
+ mesh_scale=camera_info.get('mesh_scale'),
+ transform_matrix=camera_info.get('transform_matrix'),
+ coords=camera_info.get('coords'),
+ )
+ kwargs['neg_cond'] = neg_cond
+ else:
+ cond = self.encode_image(cond)
+ if isinstance(cond, dict) and 'global' in cond:
+ kwargs['neg_cond'] = {k: torch.zeros_like(v) for k, v in cond.items()}
+ else:
+ kwargs['neg_cond'] = torch.zeros_like(cond)
+ else:
+ cond = self.encode_image(cond)
+ kwargs['neg_cond'] = torch.zeros_like(cond)
+
+ cond = super().get_inference_cond(cond, **kwargs)
+ return cond
+
+ def vis_cond(self, cond, **kwargs):
+ """Visualize the conditioning data."""
+ return {'image': {'value': cond, 'type': 'image'}}
+
+ @torch.no_grad()
+ def visualize_projection_test(
+ self,
+ cond: torch.Tensor,
+ save_dir: str,
+ prefix: str = "proj_vis",
+ **kwargs
+ ) -> Optional[List[Image.Image]]:
+ """
+ Visualize projection points on the condition images.
+
+ This should be called once before training starts to verify the projection is correct.
+
+ Args:
+ cond: Condition image tensor [B, C, H, W], in [0, 1] range
+ save_dir: Directory to save visualizations
+ prefix: Prefix for saved files
+ **kwargs: Should contain camera_angle_x, camera_distance, mesh_scale, transform_matrix
+
+ Returns:
+ List of PIL Images with projected points overlaid, or None if not in proj mode
+ """
+ if self.image_attn_mode != 'proj':
+ return None
+
+ if self.image_cond_model is None:
+ self._init_image_cond_model()
+
+ # Use _image_cond_module for attribute access (image_cond_model may be DDP-wrapped)
+ module = getattr(self, '_image_cond_module', self.image_cond_model)
+
+ # Check if the model has visualization capability
+ if not hasattr(module, 'visualize_projection'):
+ print("Warning: image_cond_model does not support visualize_projection")
+ return None
+
+ # Extract camera info
+ camera_info = self._extract_camera_info(kwargs)
+ if camera_info is None:
+ print("Warning: No camera info available for projection visualization")
+ return None
+
+ return module.visualize_projection(
+ image=cond,
+ camera_angle_x=camera_info.get('camera_angle_x'),
+ distance=camera_info.get('distance'),
+ mesh_scale=camera_info.get('mesh_scale'),
+ transform_matrix=camera_info.get('transform_matrix'),
+ save_dir=save_dir,
+ prefix=prefix,
+ )
diff --git a/trellis2/trainers/flow_matching/mixins/text_conditioned.py b/trellis2/trainers/flow_matching/mixins/text_conditioned.py
new file mode 100644
index 0000000000000000000000000000000000000000..85f1dcf2582c07edd9b629c07181265f24a90134
--- /dev/null
+++ b/trellis2/trainers/flow_matching/mixins/text_conditioned.py
@@ -0,0 +1,68 @@
+from typing import *
+import os
+os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+import torch
+from transformers import AutoTokenizer, CLIPTextModel
+
+from ....utils import dist_utils
+
+
+class TextConditionedMixin:
+ """
+ Mixin for text-conditioned models.
+
+ Args:
+ text_cond_model: The text conditioning model.
+ """
+ def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
+ super().__init__(*args, **kwargs)
+ self.text_cond_model_name = text_cond_model
+ self.text_cond_model = None # the model is initialized lazily
+
+ def _init_text_cond_model(self):
+ """
+ Initialize the text conditioning model.
+ """
+ # load model
+ with dist_utils.local_master_first():
+ model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
+ tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
+ model.eval()
+ model = model.cuda()
+ self.text_cond_model = {
+ 'model': model,
+ 'tokenizer': tokenizer,
+ }
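+ # Cache the empty-prompt embedding; it serves as the negative condition for classifier-free guidance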
+ self.text_cond_model['null_cond'] = self.encode_text([''])
+
+ @torch.no_grad()
+ def encode_text(self, text: List[str]) -> torch.Tensor:
+ """
+ Encode the text.
+ """
+ assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
+ if self.text_cond_model is None:
+ self._init_text_cond_model()
+ encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
+ tokens = encoding['input_ids'].cuda()
+ embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
+
+ return embeddings
+
+ def get_cond(self, cond, **kwargs):
+ """
+ Get the conditioning data.
+ """
+ cond = self.encode_text(cond)
+ kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
+ cond = super().get_cond(cond, **kwargs)
+ return cond
+
+ def get_inference_cond(self, cond, **kwargs):
+ """
+ Get the conditioning data for inference.
+ """
+ cond = self.encode_text(cond)
+ kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
+ cond = super().get_inference_cond(cond, **kwargs)
+ return cond
diff --git a/trellis2/trainers/flow_matching/sparse_flow_matching.py b/trellis2/trainers/flow_matching/sparse_flow_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..65f211fb3983d0b407dd995b2be5a4e944d3e5d0
--- /dev/null
+++ b/trellis2/trainers/flow_matching/sparse_flow_matching.py
@@ -0,0 +1,615 @@
+from typing import *
+import os
+import copy
+import functools
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import numpy as np
+from easydict import EasyDict as edict
+
+from ...modules import sparse as sp
+from ...utils.general_utils import dict_reduce
+from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler
+from .flow_matching import FlowMatchingTrainer
+from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
+from .mixins.text_conditioned import TextConditionedMixin
+from .mixins.image_conditioned import ImageConditionedMixin, MultiImageConditionedMixin
+from .mixins.image_conditioned_proj import ImageConditionedProjMixin
+
+
+class SparseFlowMatchingTrainer(FlowMatchingTrainer):
+ """
+ Trainer for sparse diffusion model with flow matching objective.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ """
+
+ def prepare_dataloader(self, **kwargs):
+ """
+ Prepare dataloader.
+ """
+ self.data_sampler = BalancedResumableSampler(
+ self.dataset,
+ shuffle=True,
+ batch_size=self.batch_size_per_gpu,
+ )
+ if self.num_workers is None or self.num_workers == -1:
+ num_workers = max(1, int(np.ceil((os.cpu_count() - 16) / torch.cuda.device_count())))
+ else:
+ num_workers = self.num_workers
+
+ self.dataloader = DataLoader(
+ self.dataset,
+ batch_size=self.batch_size_per_gpu,
+ num_workers=num_workers,
+ pin_memory=True,
+ drop_last=True,
+ persistent_workers=True,
+ collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
+ sampler=self.data_sampler,
+ )
+ self.data_iterator = cycle(self.dataloader)
+
+ def training_losses(
+ self,
+ x_0: sp.SparseTensor,
+ cond=None,
+ **kwargs
+ ) -> Tuple[Dict, Dict]:
+ """
+ Compute training losses for a single timestep.
+
+ Args:
+ x_0: The [N x ... x C] sparse tensor of the inputs.
+ cond: The [N x ...] tensor of additional conditions.
+ kwargs: Additional arguments to pass to the backbone.
+
+ Returns:
+ a dict with the key "loss" containing a tensor of shape [N].
+ may also contain other keys for different terms.
+ """
+ noise = x_0.replace(torch.randn_like(x_0.feats))
+ t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
+ x_t = self.diffuse(x_0, t, noise=noise)
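+ # Note: the base FlowMatchingTrainer is assumed to follow the usual rectified-flow
+ # setup, where diffuse() linearly interpolates between x_0 and noise according to
+ # t and sigma_min, and get_v() below returns the corresponding constant-velocity
+ # target (e.g. v = (1 - sigma_min) * noise - x_0), so the MSE on pred vs. target
+ # is the flow matching objective.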
+ cond = self.get_cond(cond, **kwargs)
+
+ pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
+ assert pred.shape == noise.shape == x_0.shape
+ target = self.get_v(x_0, noise, t)
+ terms = edict()
+ terms["mse"] = F.mse_loss(pred.feats, target.feats)
+ terms["loss"] = terms["mse"]
+
+ # log loss with time bins
+ mse_per_instance = np.array([
+ F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
+ for i in range(x_0.shape[0])
+ ])
+ time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
+ for i in range(10):
+ if (time_bin == i).sum() != 0:
+ terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
+
+ return terms, {}
+
+ @torch.no_grad()
+ def run_snapshot(
+ self,
+ num_samples: int,
+ batch_size: int,
+ verbose: bool = False,
+ ) -> Dict:
+ # Use current step as seed to ensure different samples for each snapshot
+ import random
+ snapshot_seed = self.step
+ random.seed(snapshot_seed)
+ np.random.seed(snapshot_seed)
+
+ g = torch.Generator()
+ g.manual_seed(snapshot_seed)
+
+ dataloader = DataLoader(
+ copy.deepcopy(self.dataset),
+ batch_size=num_samples,
+ shuffle=True,
+ num_workers=0,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ generator=g,
+ )
+ data = next(iter(dataloader))
+
+ # Collect metadata (dataset_name and sha256) for wandb display
+ sample_metadata = []
+ if '_dataset_name' in data and '_sha256' in data:
+ for j in range(min(num_samples, len(data['_dataset_name']))):
+ sample_metadata.append(f"{data['_dataset_name'][j]}/{data['_sha256'][j]}")
+ # Remove metadata fields before inference
+ data.pop('_dataset_name', None)
+ data.pop('_sha256', None)
+
+ # inference
+ sampler = self.get_sampler()
+ sample = []
+ cond_vis = []
+ for i in range(0, num_samples, batch_size):
+ batch_data = {k: v[i:i+batch_size] for k, v in data.items()}
+ batch_data = recursive_to_device(batch_data, 'cuda')
+ noise = batch_data['x_0'].replace(torch.randn_like(batch_data['x_0'].feats))
+ cond_vis.append(self.vis_cond(**batch_data))
+ del batch_data['x_0']
+ args = self.get_inference_cond(**batch_data)
+ res = sampler.sample(
+ self.models['denoiser'],
+ noise=noise,
+ **args,
+ steps=12, guidance_strength=3.0, verbose=verbose,
+ )
+ sample.append(res.samples)
+ sample = sp.sparse_cat(sample)
+
+ sample_gt = {k: v for k, v in data.items()}
+ sample = {k: v if k != 'x_0' else sample for k, v in data.items()}
+ sample_dict = {
+ 'sample_gt': {'value': sample_gt, 'type': 'sample'},
+ 'sample': {'value': sample, 'type': 'sample'},
+ }
+ if sample_metadata:
+ sample_dict['_metadata'] = sample_metadata
+ sample_dict.update(dict_reduce(cond_vis, None, {
+ 'value': lambda x: torch.cat(x, dim=0),
+ 'type': lambda x: x[0],
+ }))
+
+ return sample_dict
+
+
+class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
+ """
+ Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ """
+ pass
+
+
+class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
+ """
+ Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ text_cond_model (str): Text conditioning model.
+ """
+ pass
+
+
+class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
+ """
+ Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'inflat_all': Hold a inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ image_cond_model (str): Image conditioning model.
+ """
+ pass
+
+
+class MultiImageConditionedSparseFlowMatchingCFGTrainer(MultiImageConditionedMixin, SparseFlowMatchingCFGTrainer):
+ """
+ Trainer for sparse multi-image-conditioned diffusion model with flow matching objective and classifier-free guidance.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ image_cond_model (str): Image conditioning model.
+ """
+ pass
+
+
+class ImageConditionedProjSparseFlowMatchingCFGTrainer(ImageConditionedProjMixin, SparseFlowMatchingCFGTrainer):
+ """
+ Trainer for sparse image-conditioned diffusion model with view-aligned projection.
+
+ Uses ImageConditionedProjMixin for 3D-to-2D feature projection with camera parameters.
+ CFG dropout is handled by ClassifierFreeGuidanceMixin (via p_uncond parameter).
+
+ The projection grid outputs a full [B, R, R, R, D] tensor, and this trainer extracts
+ features at sparse coordinates using advanced indexing.
+
+ Args:
+ t_schedule (dict): Time schedule for flow matching.
+ sigma_min (float): Minimum noise level.
+ p_uncond (float): Probability of dropping conditions.
+ image_cond_model (dict): Image conditioning model config (DinoV3ProjFeatureExtractor).
+ run_projection_test (bool): Whether to run projection visualization test before training.
+ """
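+ # Sketch of the sparse extraction described above (names are illustrative, not
+ # part of this trainer): given a dense projection grid `grid` of shape
+ # [B, R, R, R, D] and sparse voxel coordinates `coords` of shape [num_voxels, 4]
+ # in (batch, x, y, z) order,
+ #   b, x, y, z = coords[:, 0], coords[:, 1], coords[:, 2], coords[:, 3]
+ #   sparse_feats = grid[b, x, y, z]   # -> [num_voxels, D]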
+
+ def __init__(self, *args, run_projection_test: bool = True, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.run_projection_test = run_projection_test
+
+ def training_losses(
+ self,
+ x_0: sp.SparseTensor,
+ cond=None,
+ **kwargs
+ ) -> Tuple[Dict, Dict]:
+ """
+ Compute training losses for a single timestep.
+
+ Overridden to pass coords from x_0 to get_cond for sparse feature extraction.
+
+ Args:
+ x_0: The [N x ... x C] sparse tensor of the inputs.
+ cond: The [N x ...] tensor of additional conditions.
+ kwargs: Additional arguments to pass to the backbone.
+
+ Returns:
+ a dict with the key "loss" containing a tensor of shape [N].
+ may also contain other keys for different terms.
+ """
+ noise = x_0.replace(torch.randn_like(x_0.feats))
+ t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
+ x_t = self.diffuse(x_0, t, noise=noise)
+
+ # Pass coords to get_cond for sparse feature extraction from full grid
+ kwargs['coords'] = x_0.coords
+ cond = self.get_cond(cond, **kwargs)
+
+ # Pass concat_cond to denoiser if present (needed for PBR/texture training
+ # where shape latent is concatenated with PBR latent as input)
+ denoiser_kwargs = {}
+ if 'concat_cond' in kwargs:
+ denoiser_kwargs['concat_cond'] = kwargs['concat_cond']
+ pred = self.training_models['denoiser'](x_t, t * 1000, cond, **denoiser_kwargs)
+
+ assert pred.shape == noise.shape == x_0.shape
+ target = self.get_v(x_0, noise, t)
+ terms = edict()
+ terms["mse"] = F.mse_loss(pred.feats, target.feats)
+ terms["loss"] = terms["mse"]
+
+ # log loss with time bins
+ mse_per_instance = np.array([
+ F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
+ for i in range(x_0.shape[0])
+ ])
+ time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
+ for i in range(10):
+ if (time_bin == i).sum() != 0:
+ terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
+
+ return terms, {}
+
+ @torch.no_grad()
+ def run_snapshot(
+ self,
+ num_samples: int,
+ batch_size: int,
+ verbose: bool = False,
+ ) -> Dict:
+ """
+ Run snapshot with coords passed to get_inference_cond for sparse feature extraction.
+
+ For projection mode, we need to pass coords to properly extract features at
+ sparse positions from the full projection grid.
+ """
+ # Use current step as seed to ensure different samples for each snapshot
+ import random
+ snapshot_seed = self.step
+ random.seed(snapshot_seed)
+ np.random.seed(snapshot_seed)
+
+ g = torch.Generator()
+ g.manual_seed(snapshot_seed)
+
+ dataloader = DataLoader(
+ copy.deepcopy(self.dataset),
+ batch_size=num_samples,
+ shuffle=True,
+ num_workers=0,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ generator=g,
+ )
+ data = next(iter(dataloader))
+
+ # Collect metadata (dataset_name and sha256) for wandb display
+ sample_metadata = []
+ if '_dataset_name' in data and '_sha256' in data:
+ for j in range(min(num_samples, len(data['_dataset_name']))):
+ sample_metadata.append(f"{data['_dataset_name'][j]}/{data['_sha256'][j]}")
+ # Remove metadata fields before inference
+ data.pop('_dataset_name', None)
+ data.pop('_sha256', None)
+
+ # inference
+ sampler = self.get_sampler()
+ sample = []
+ cond_vis = []
+ for i in range(0, num_samples, batch_size):
+ batch_data = {k: v[i:i+batch_size] for k, v in data.items()}
+ batch_data = recursive_to_device(batch_data, 'cuda')
+ noise = batch_data['x_0'].replace(torch.randn_like(batch_data['x_0'].feats))
+ cond_vis.append(self.vis_cond(**batch_data))
+
+ # Save coords before deleting x_0 (needed for projection feature extraction)
+ coords = batch_data['x_0'].coords
+ del batch_data['x_0']
+
+ # Pass coords to get_inference_cond for sparse feature extraction
+ batch_data['coords'] = coords
+ args = self.get_inference_cond(**batch_data)
+
+ res = sampler.sample(
+ self.models['denoiser'],
+ noise=noise,
+ **args,
+ steps=12, guidance_strength=3.0, verbose=verbose,
+ )
+ sample.append(res.samples)
+ sample = sp.sparse_cat(sample)
+
+ sample_gt = {k: v for k, v in data.items()}
+ sample = {k: v if k != 'x_0' else sample for k, v in data.items()}
+ sample_dict = {
+ 'sample_gt': {'value': sample_gt, 'type': 'sample'},
+ 'sample': {'value': sample, 'type': 'sample'},
+ }
+ if sample_metadata:
+ sample_dict['_metadata'] = sample_metadata
+ sample_dict.update(dict_reduce(cond_vis, None, {
+ 'value': lambda x: torch.cat(x, dim=0),
+ 'type': lambda x: x[0],
+ }))
+
+ return sample_dict
+
+ @torch.no_grad()
+ def visualize_sample(self, sample):
+ """
+ Convert a sample to images, including GT camera view if available.
+
+ Args:
+ sample: Either a SparseTensor or dict containing:
+ - 'x_0': SparseTensor
+ - 'camera_angle_x': [B] (optional)
+ - 'camera_distance': [B] (optional)
+ - 'mesh_scale': [B] (optional)
+
+ Returns:
+ dict with visualization images or tensor
+ """
+ if hasattr(self.dataset, 'visualize_sample'):
+ if isinstance(sample, dict):
+ # Extract camera params and pass them explicitly, since some
+ # dataset.visualize_sample() implementations (e.g. SLatShapeVisMixin)
+ # expect separate keyword arguments rather than a single dict.
+ camera_kwargs = {}
+ for k in ('camera_angle_x', 'camera_distance', 'mesh_scale'):
+ if k in sample:
+ camera_kwargs[k] = sample[k]
+
+ # Try passing camera kwargs explicitly first; fall back to
+ # passing the entire dict if the dataset method doesn't accept them
+ # (e.g. SLatPbrVisMixin expects a dict with 'x_0' + 'concat_cond').
+ import inspect
+ sig = inspect.signature(self.dataset.visualize_sample)
+ params = list(sig.parameters.keys())
+ if 'camera_angle_x' in params:
+ # Shape-style: visualize_sample(x_0, camera_angle_x=, ...)
+ x_0 = sample.get('x_0', sample)
+ return self.dataset.visualize_sample(x_0, **camera_kwargs)
+ else:
+ # Tex/PBR-style: visualize_sample(sample_dict)
+ return self.dataset.visualize_sample(sample)
+ else:
+ return self.dataset.visualize_sample(sample)
+ else:
+ if isinstance(sample, dict):
+ return sample.get('x_0', sample)
+ return sample
+
+ def run(self):
+ """
+ Run training with projection visualization test before starting.
+ """
+ # Run projection visualization test before training starts (if enabled)
+ if self.run_projection_test and self.is_master:
+ print('\n' + '='*60)
+ print('Running projection visualization test...')
+ print('='*60)
+ self._run_projection_visualization_test()
+
+ super().run()
+
+ @torch.no_grad()
+ def _run_projection_visualization_test(self, num_samples: int = 4):
+ """
+ Run projection visualization test on a few samples before training starts.
+
+ This helps verify that the 3D-to-2D projection is working correctly.
+ """
+
+ # Create a small dataloader
+ dataloader = DataLoader(
+ self.dataset,
+ batch_size=min(num_samples, self.snapshot_batch_size),
+ shuffle=True,
+ num_workers=0,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ )
+
+ # Get one batch
+ data = next(iter(dataloader))
+ data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
+
+ # Extract condition image
+ cond = data.get('cond')
+ if cond is None:
+ print("Warning: No 'cond' field in data, skipping projection visualization test")
+ return
+
+ # Save directory
+ save_dir = os.path.join(self.output_dir, 'samples', 'projection_test')
+
+ # Call visualization method
+ if hasattr(self, 'visualize_projection_test'):
+ # Need to pass camera info as kwargs
+ kwargs = {k: v for k, v in data.items() if k != 'cond' and k != 'x_0'}
+ self.visualize_projection_test(
+ cond=cond,
+ save_dir=save_dir,
+ prefix="proj_test",
+ **kwargs
+ )
+ print(f"Projection visualization saved to: {save_dir}")
+ else:
+ print("Warning: visualize_projection_test not available")
diff --git a/trellis2/trainers/utils.py b/trellis2/trainers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e4286d02b6fc102c364d2ef7571ef97c9fcd41
--- /dev/null
+++ b/trellis2/trainers/utils.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+
+# FP16 utils
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+
+def str_to_dtype(dtype_str: str):
+ return {
+ 'f16': torch.float16,
+ 'fp16': torch.float16,
+ 'float16': torch.float16,
+ 'bf16': torch.bfloat16,
+ 'bfloat16': torch.bfloat16,
+ 'f32': torch.float32,
+ 'fp32': torch.float32,
+ 'float32': torch.float32,
+ }[dtype_str]
+
+
+def make_master_params(model_params):
+ """
+ Copy model parameters into an inflated tensor of full-precision parameters.
+ """
+ master_params = _flatten_dense_tensors(
+ [param.detach().float() for param in model_params]
+ )
+ master_params = nn.Parameter(master_params)
+ master_params.requires_grad = True
+ return [master_params]
+
+
+def unflatten_master_params(model_params, master_params):
+ """
+ Unflatten the master parameters to look like model_params.
+ """
+ return _unflatten_dense_tensors(master_params[0].detach(), model_params)
+
+
+def model_params_to_master_params(model_params, master_params):
+ """
+ Copy the model parameter data into the master parameters.
+ """
+ master_params[0].detach().copy_(
+ _flatten_dense_tensors([param.detach().float() for param in model_params])
+ )
+
+
+def master_params_to_model_params(model_params, master_params):
+ """
+ Copy the master parameter data back into the model parameters.
+ """
+ for param, master_param in zip(
+ model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
+ ):
+ param.detach().copy_(master_param)
+
+
+def model_grads_to_master_grads(model_params, master_params):
+ """
+ Copy the gradients from the model parameters into the master parameters
+ from make_master_params().
+ """
+ master_params[0].grad = _flatten_dense_tensors(
+ [param.grad.data.detach().float() for param in model_params]
+ )
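+
+# Typical round trip with these helpers (sketch; `model` and `opt` are assumed):
+#   params = list(model.parameters())
+#   master = make_master_params(params)             # build the optimizer over `master`
+#   loss.backward()
+#   model_grads_to_master_grads(params, master)     # fp16 grads -> flat fp32 grad
+#   opt.step()
+#   master_params_to_model_params(params, master)   # copy updated weights back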
+
+
+def zero_grad(model_params):
+ for param in model_params:
+ if param.grad is not None:
+ if param.grad.grad_fn is not None:
+ param.grad.detach_()
+ else:
+ param.grad.requires_grad_(False)
+ param.grad.zero_()
+
+
+# LR Schedulers
+from torch.optim.lr_scheduler import LambdaLR
+
+class LinearWarmupLRScheduler(LambdaLR):
+ def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+ self.warmup_steps = warmup_steps
+ super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+ def lr_lambda(self, current_step):
+ if current_step < self.warmup_steps:
+ return float(current_step + 1) / self.warmup_steps
+ return 1.0
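+
+# Usage sketch (optimizer and step counts are assumed):
+#   scheduler = LinearWarmupLRScheduler(optimizer, warmup_steps=1000)
+#   for step in range(max_steps):
+#       loss.backward(); optimizer.step(); scheduler.step()
+#   # the learning rate ramps linearly over the first 1000 steps, then stays constant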
+
\ No newline at end of file
diff --git a/trellis2/trainers/vae/pbr_vae.py b/trellis2/trainers/vae/pbr_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..28d3646233ce7752796d679c5cd4e1737eece474
--- /dev/null
+++ b/trellis2/trainers/vae/pbr_vae.py
@@ -0,0 +1,291 @@
+from typing import *
+import os
+import copy
+import functools
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import utils3d
+from easydict import EasyDict as edict
+
+from ..basic import BasicTrainer
+from ...modules import sparse as sp
+from ...renderers import MeshRenderer
+from ...representations import Mesh, MeshWithPbrMaterial, MeshWithVoxel
+from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler
+from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
+
+
+class PbrVaeTrainer(BasicTrainer):
+ """
+ Trainer for PBR attributes VAE
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ loss_type (str): Loss type, either 'l1' or 'l2'.
+ lambda_kl (float): KL loss weight.
+ lambda_ssim (float): SSIM loss weight.
+ lambda_lpips (float): LPIPS loss weight.
+ lambda_render (float): Rendering loss weight.
+ render_resolution (int): Resolution of the rendered images.
+ camera_randomization_config (dict): Config for random camera sampling.
+ """
+
+ def __init__(
+ self,
+ *args,
+ loss_type: str = 'l1',
+ lambda_kl: float = 1e-6,
+ lambda_ssim: float = 0.2,
+ lambda_lpips: float = 0.2,
+ lambda_render: float = 1.0,
+ render_resolution: int = 1024,
+ camera_randomization_config: dict = {
+ 'radius_range': [2, 100],
+ },
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.loss_type = loss_type
+ self.lambda_kl = lambda_kl
+ self.lambda_ssim = lambda_ssim
+ self.lambda_lpips = lambda_lpips
+ self.lambda_render = lambda_render
+ self.camera_randomization_config = camera_randomization_config
+
+ self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device)
+
+ def prepare_dataloader(self, **kwargs):
+ """
+ Prepare dataloader.
+ """
+ self.data_sampler = BalancedResumableSampler(
+ self.dataset,
+ shuffle=True,
+ batch_size=self.batch_size_per_gpu,
+ )
+ self.dataloader = DataLoader(
+ self.dataset,
+ batch_size=self.batch_size_per_gpu,
+ num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
+ pin_memory=True,
+ drop_last=True,
+ persistent_workers=True,
+ collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
+ sampler=self.data_sampler,
+ )
+ self.data_iterator = cycle(self.dataloader)
+
+ def _randomize_camera(self, num_samples: int):
+ # sample radius and fov
+ r_min, r_max = self.camera_randomization_config['radius_range']
+ k_min = 1 / r_max**2
+ k_max = 1 / r_min**2
+ ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min
+ radius = 1 / torch.sqrt(ks)
+ fov = 2 * torch.arcsin(0.5 / radius)
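+ # Sampling k = 1/r^2 uniformly biases cameras toward smaller radii; the FOV is
+ # then chosen so that a sphere of radius 0.5 (the normalized object bound)
+ # exactly fills the view at the sampled distance.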
+ origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1)
+
+ # build camera
+ extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device))
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+ near = [np.random.uniform(r - 1, r) for r in radius.tolist()]
+
+ return {
+ 'extrinsics': extrinsics,
+ 'intrinsics': intrinsics,
+ 'near': near,
+ }
+
+ def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Render a batch of representations.
+
+ Args:
+ reps: The list of mesh representations to render.
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
+
+ Returns:
+ a dict with
+ base_color : [N x 3 x H x W] tensor of base color.
+ metallic : [N x 1 x H x W] tensor of metallic.
+ roughness : [N x 1 x H x W] tensor of roughness.
+ alpha : [N x 1 x H x W] tensor of alpha.
+ """
+ ret = {k : [] for k in ['base_color', 'metallic', 'roughness', 'alpha']}
+ for i, rep in enumerate(reps):
+ self.renderer.rendering_options['near'] = near[i]
+ self.renderer.rendering_options['far'] = near[i] + 2
+ out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=['attr'])
+ for k in out_dict:
+ ret[k].append(out_dict[k])
+ for k in ret:
+ ret[k] = torch.stack(ret[k])
+ return ret
+
+ def training_losses(
+ self,
+ x: sp.SparseTensor,
+ mesh: List[MeshWithPbrMaterial] = None,
+ **kwargs
+ ) -> Tuple[Dict, Dict]:
+ """
+ Compute training losses.
+
+ Args:
+ x (SparseTensor): Input sparse tensor for pbr materials.
+ mesh (List[MeshWithPbrMaterial]): The list of meshes with PBR materials.
+
+ Returns:
+ a dict with the key "loss" containing a scalar tensor.
+ may also contain other keys for different terms.
+
+ """
+ z, mean, logvar = self.training_models['encoder'](x, sample_posterior=True, return_raw=True)
+ y = self.training_models['decoder'](z)
+
+ terms = edict(loss = 0.0)
+
+ # direct regression
+ if self.loss_type == 'l1':
+ terms["l1"] = l1_loss(x.feats, y.feats)
+ terms["loss"] = terms["loss"] + terms["l1"]
+ elif self.loss_type == 'l2':
+ terms["l2"] = l2_loss(x.feats, y.feats)
+ terms["loss"] = terms["loss"] + terms["l2"]
+ else:
+ raise ValueError(f'Invalid loss type {self.loss_type}')
+
+ # rendering loss
+ if self.lambda_render != 0.0:
+ recon = [MeshWithVoxel(
+ m.vertices,
+ m.faces,
+ [-0.5, -0.5, -0.5],
+ 1 / self.dataset.resolution,
+ v.coords[:, 1:],
+ v.feats * 0.5 + 0.5,
+ torch.Size([*v.shape, *v.spatial_shape]),
+ layout={
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+ ) for m, v in zip(mesh, y)]
+ cameras = self._randomize_camera(len(mesh))
+ gt_renders = self._render_batch(mesh, **cameras)
+ pred_renders = self._render_batch(recon, **cameras)
+ gt_base_color = gt_renders['base_color']
+ pred_base_color = pred_renders['base_color']
+ gt_mra = torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1)
+ pred_mra = torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1)
+ terms['render/base_color/ssim'] = 1 - ssim(pred_base_color, gt_base_color)
+ terms['render/base_color/lpips'] = lpips(pred_base_color, gt_base_color)
+ terms['render/mra/ssim'] = 1 - ssim(pred_mra, gt_mra)
+ terms['render/mra/lpips'] = lpips(pred_mra, gt_mra)
+ terms['loss'] = terms['loss'] + \
+ self.lambda_render * (self.lambda_ssim * terms['render/base_color/ssim'] + self.lambda_lpips * terms['render/base_color/lpips'] + \
+ self.lambda_ssim * terms['render/mra/ssim'] + self.lambda_lpips * terms['render/mra/lpips'])
+
+ # KL regularization
+ terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
+ terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
+
+ return terms, {}
+
+ @torch.no_grad()
+ def run_snapshot(
+ self,
+ num_samples: int,
+ batch_size: int,
+ verbose: bool = False,
+ ) -> Dict:
+ # Use current step as seed to ensure different samples for each snapshot
+ import random
+ snapshot_seed = self.step
+ random.seed(snapshot_seed)
+ np.random.seed(snapshot_seed)
+
+ g = torch.Generator()
+ g.manual_seed(snapshot_seed)
+
+ dataloader = DataLoader(
+ copy.deepcopy(self.dataset),
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=1,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ generator=g,
+ )
+ dataloader.dataset.with_mesh = True
+
+ # inference
+ gts = []
+ recons = []
+ self.models['encoder'].eval()
+ self.models['decoder'].eval()
+ for i in range(0, num_samples, batch_size):
+ batch = min(batch_size, num_samples - i)
+ data = next(iter(dataloader))
+ args = {k: v[:batch] for k, v in data.items()}
+ args = recursive_to_device(args, self.device)
+ z = self.models['encoder'](args['x'])
+ y = self.models['decoder'](z)
+ gts.extend(args['mesh'])
+ recons.extend([MeshWithVoxel(
+ m.vertices,
+ m.faces,
+ [-0.5, -0.5, -0.5],
+ 1 / self.dataset.resolution,
+ v.coords[:, 1:],
+ v.feats * 0.5 + 0.5,
+ torch.Size([*v.shape, *v.spatial_shape]),
+ layout={
+ 'base_color': slice(0, 3),
+ 'metallic': slice(3, 4),
+ 'roughness': slice(4, 5),
+ 'alpha': slice(5, 6),
+ }
+ ) for m, v in zip(args['mesh'], y)])
+ self.models['encoder'].train()
+ self.models['decoder'].train()
+
+ cameras = self._randomize_camera(num_samples)
+ gt_renders = self._render_batch(gts, **cameras)
+ pred_renders = self._render_batch(recons, **cameras)
+
+ sample_dict = {
+ 'gt_base_color': {'value': gt_renders['base_color'] * 2 - 1, 'type': 'image'},
+ 'pred_base_color': {'value': pred_renders['base_color'] * 2 - 1, 'type': 'image'},
+ 'gt_mra': {'value': torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'},
+ 'pred_mra': {'value': torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'},
+ }
+
+ return sample_dict
diff --git a/trellis2/trainers/vae/shape_vae.py b/trellis2/trainers/vae/shape_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..db4630ed7c43792bdb29df98be6be1eb0756e746
--- /dev/null
+++ b/trellis2/trainers/vae/shape_vae.py
@@ -0,0 +1,276 @@
+from typing import *
+import os
+import copy
+import functools
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import utils3d
+from easydict import EasyDict as edict
+
+from ..basic import BasicTrainer
+from ...modules import sparse as sp
+from ...renderers import MeshRenderer
+from ...representations import Mesh
+from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler
+from ...utils.loss_utils import l1_loss, ssim, lpips
+
+
+class ShapeVaeTrainer(BasicTrainer):
+ """
+ Trainer for Shape VAE
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ lambda_subdiv (float): Subdivision loss weight.
+ lambda_intersected (float): Intersected-flag loss weight.
+ lambda_vertice (float): Vertex regression loss weight.
+ lambda_mask (float): Rendered mask loss weight.
+ lambda_depth (float): Rendered depth loss weight.
+ lambda_normal (float): Rendered normal loss weight.
+ lambda_kl (float): KL loss weight.
+ lambda_ssim (float): SSIM loss weight.
+ lambda_lpips (float): LPIPS loss weight.
+ """
+
+ def __init__(
+ self,
+ *args,
+ lambda_subdiv: float = 0.1,
+ lambda_intersected: float = 0.1,
+ lambda_vertice: float = 1e-2,
+ lambda_mask: float = 1,
+ lambda_depth: float = 10,
+ lambda_normal: float = 1,
+ lambda_kl: float = 1e-6,
+ lambda_ssim: float = 0.2,
+ lambda_lpips: float = 0.2,
+ render_resolution: int = 1024,
+ camera_randomization_config: dict = {
+ 'radius_range': [2, 100],
+ },
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lambda_subdiv = lambda_subdiv
+ self.lambda_intersected = lambda_intersected
+ self.lambda_mask = lambda_mask
+ self.lambda_vertice = lambda_vertice
+ self.lambda_depth = lambda_depth
+ self.lambda_normal = lambda_normal
+ self.lambda_kl = lambda_kl
+ self.lambda_ssim = lambda_ssim
+ self.lambda_lpips = lambda_lpips
+ self.camera_randomization_config = camera_randomization_config
+
+ self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device)
+
+ def prepare_dataloader(self, **kwargs):
+ """
+ Prepare dataloader.
+ """
+ self.data_sampler = BalancedResumableSampler(
+ self.dataset,
+ shuffle=True,
+ batch_size=self.batch_size_per_gpu,
+ )
+ self.dataloader = DataLoader(
+ self.dataset,
+ batch_size=self.batch_size_per_gpu,
+ num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
+ pin_memory=True,
+ drop_last=True,
+ persistent_workers=True,
+ collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
+ sampler=self.data_sampler,
+ )
+ self.data_iterator = cycle(self.dataloader)
+
+ def _randomize_camera(self, num_samples: int):
+ # sample radius and fov
+ r_min, r_max = self.camera_randomization_config['radius_range']
+ k_min = 1 / r_max**2
+ k_max = 1 / r_min**2
+ ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min
+ radius = 1 / torch.sqrt(ks)
+ fov = 2 * torch.arcsin(0.5 / radius)
+ origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1)
+
+ # build camera
+ extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device))
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+ near = [np.random.uniform(r - 1, r) for r in radius.tolist()]
+
+ return {
+ 'extrinsics': extrinsics,
+ 'intrinsics': intrinsics,
+ 'near': near,
+ }
+
+ def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List,
+ return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
+ """
+ Render a batch of representations.
+
+ Args:
+ reps: The list of mesh representations to render.
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
+ return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
+
+ Returns:
+ a dict with
+ mask : [N x 1 x H x W] tensor of rendered masks
+ normal : [N x 3 x H x W] tensor of rendered normals
+ depth : [N x 1 x H x W] tensor of rendered depths
+ """
+ ret = {k : [] for k in return_types}
+ for i, rep in enumerate(reps):
+ self.renderer.rendering_options['near'] = near[i]
+ self.renderer.rendering_options['far'] = near[i] + 2
+ out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
+ for k in out_dict:
+ ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
+ for k in ret:
+ ret[k] = torch.stack(ret[k])
+ return ret
+
+ def training_losses(
+ self,
+ vertices: sp.SparseTensor,
+ intersected: sp.SparseTensor,
+ mesh: List[Mesh],
+ ) -> Tuple[Dict, Dict]:
+ """
+ Compute training losses.
+
+ Args:
+ vertices (SparseTensor): vertices of each active voxel
+ intersected (SparseTensor): intersected flag of each active voxel
+ mesh (List[Mesh]): the list of meshes to render
+
+ Returns:
+ a dict with the key "loss" containing a scalar tensor.
+ may also contain other keys for different terms.
+ """
+ z, mean, logvar = self.training_models['encoder'](vertices, intersected, sample_posterior=True, return_raw=True)
+ recon, pred_vertice, pred_intersected, subs_gt, subs = self.training_models['decoder'](z, intersected)
+
+ terms = edict(loss = 0.0)
+
+ # direct regression
+ if self.lambda_intersected > 0:
+ terms["direct/intersected"] = F.binary_cross_entropy_with_logits(pred_intersected.feats.flatten(), intersected.feats.flatten().float())
+ terms["loss"] = terms["loss"] + self.lambda_intersected * terms["direct/intersected"]
+ if self.lambda_vertice > 0:
+ terms["direct/vertice"] = F.mse_loss(pred_vertice.feats, vertices.feats)
+ terms["loss"] = terms["loss"] + self.lambda_vertice * terms["direct/vertice"]
+
+ # subdivision prediction loss
+ for i, (sub_gt, sub) in enumerate(zip(subs_gt, subs)):
+ terms[f"bce_sub{i}"] = F.binary_cross_entropy_with_logits(sub.feats, sub_gt.float())
+ terms["loss"] = terms["loss"] + self.lambda_subdiv * terms[f"bce_sub{i}"]
+
+ # rendering loss
+ cameras = self._randomize_camera(len(mesh))
+ gt_renders = self._render_batch(mesh, **cameras, return_types=['mask', 'normal', 'depth'])
+ pred_renders = self._render_batch(recon, **cameras, return_types=['mask', 'normal', 'depth'])
+ terms['render/mask'] = l1_loss(pred_renders['mask'], gt_renders['mask'])
+ terms['render/depth'] = l1_loss(pred_renders['depth'], gt_renders['depth'])
+ terms['render/normal/l1'] = l1_loss(pred_renders['normal'], gt_renders['normal'])
+ terms['render/normal/ssim'] = 1 - ssim(pred_renders['normal'], gt_renders['normal'])
+ terms['render/normal/lpips'] = lpips(pred_renders['normal'], gt_renders['normal'])
+ terms['loss'] = terms['loss'] + \
+ self.lambda_mask * terms['render/mask'] + \
+ self.lambda_depth * terms['render/depth'] + \
+ self.lambda_normal * (terms['render/normal/l1'] + self.lambda_ssim * terms['render/normal/ssim'] + self.lambda_lpips * terms['render/normal/lpips'])
+
+ # KL regularization
+ terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
+ terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
+
+ return terms, {}
+
+ @torch.no_grad()
+ def run_snapshot(
+ self,
+ num_samples: int,
+ batch_size: int,
+ verbose: bool = False,
+ ) -> Dict:
+ # Use current step as seed to ensure different samples for each snapshot
+ import random
+ snapshot_seed = self.step
+ random.seed(snapshot_seed)
+ np.random.seed(snapshot_seed)
+
+ g = torch.Generator()
+ g.manual_seed(snapshot_seed)
+
+ dataloader = DataLoader(
+ copy.deepcopy(self.dataset),
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=1,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ generator=g,
+ )
+
+ # inference
+ gts = []
+ recons = []
+ recons2 = []
+ self.models['encoder'].eval()
+ for i in range(0, num_samples, batch_size):
+ batch = min(batch_size, num_samples - i)
+ data = next(iter(dataloader))
+ args = {k: v[:batch] for k, v in data.items()}
+ args = recursive_to_device(args, self.device)
+ z = self.models['encoder'](args['vertices'], args['intersected'])
+ self.models['decoder'].train()
+ y = self.models['decoder'](z, args['intersected'])[0]
+ z.clear_spatial_cache()
+ self.models['decoder'].eval()
+ y2 = self.models['decoder'](z)
+ gts.extend(args['mesh'])
+ recons.extend(y)
+ recons2.extend(y2)
+ self.models['encoder'].train()
+ self.models['decoder'].train()
+
+ cameras = self._randomize_camera(num_samples)
+ gt_renders = self._render_batch(gts, **cameras, return_types=['normal'])
+ recons_renders = self._render_batch(recons, **cameras, return_types=['normal'])
+ recons2_renders = self._render_batch(recons2, **cameras, return_types=['normal'])
+
+ sample_dict = {
+ 'gt': {'value': gt_renders['normal'], 'type': 'image'},
+ 'rec': {'value': recons_renders['normal'], 'type': 'image'},
+ 'rec2': {'value': recons2_renders['normal'], 'type': 'image'},
+ }
+
+ return sample_dict
diff --git a/trellis2/trainers/vae/sparse_structure_vae.py b/trellis2/trainers/vae/sparse_structure_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd9d7baae7036e40ac80b90156823941e0ec964
--- /dev/null
+++ b/trellis2/trainers/vae/sparse_structure_vae.py
@@ -0,0 +1,140 @@
+from typing import *
+import copy
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from easydict import EasyDict as edict
+
+from ..basic import BasicTrainer
+
+
+class SparseStructureVaeTrainer(BasicTrainer):
+ """
+ Trainer for Sparse Structure VAE.
+
+ Args:
+ models (dict[str, nn.Module]): Models to train.
+ dataset (torch.utils.data.Dataset): Dataset.
+ output_dir (str): Output directory.
+ load_dir (str): Load directory.
+ step (int): Step to load.
+ batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
+ batch_split (int): Split batch with gradient accumulation.
+ max_steps (int): Max steps.
+ optimizer (dict): Optimizer config.
+ lr_scheduler (dict): Learning rate scheduler config.
+ elastic (dict): Elastic memory management config.
+ grad_clip (float or dict): Gradient clip config.
+ ema_rate (float or list): Exponential moving average rates.
+ fp16_mode (str): FP16 mode.
+ - None: No FP16.
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
+ - 'amp': Automatic mixed precision.
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
+ finetune_ckpt (dict): Finetune checkpoint.
+ log_param_stats (bool): Log parameter stats.
+ i_print (int): Print interval.
+ i_log (int): Log interval.
+ i_sample (int): Sample interval.
+ i_save (int): Save interval.
+ i_ddpcheck (int): DDP check interval.
+
+ loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
+ lambda_kl (float): KL divergence loss weight.
+ """
+
+ def __init__(
+ self,
+ *args,
+ loss_type='bce',
+ lambda_kl=1e-6,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.loss_type = loss_type
+ self.lambda_kl = lambda_kl
+
+ def training_losses(
+ self,
+ ss: torch.Tensor,
+ **kwargs
+ ) -> Tuple[Dict, Dict]:
+ """
+ Compute training losses.
+
+ Args:
+ ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
+
+ Returns:
+ a dict with the key "loss" containing a scalar tensor.
+ may also contain other keys for different terms.
+ """
+ z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True)
+ logits = self.training_models['decoder'](z)
+
+ terms = edict(loss = 0.0)
+ if self.loss_type == 'bce':
+ terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean')
+ terms["loss"] = terms["loss"] + terms["bce"]
+ elif self.loss_type == 'l1':
+ terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean')
+ terms["loss"] = terms["loss"] + terms["l1"]
+ elif self.loss_type == 'dice':
+ logits = F.sigmoid(logits)
+ terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
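+ # Soft Dice: the +1 smoothing in numerator and denominator avoids division
+ # by zero for empty volumes and stabilizes gradients for tiny structures.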
+ terms["loss"] = terms["loss"] + terms["dice"]
+ else:
+ raise ValueError(f'Invalid loss type {self.loss_type}')
+ terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
+ terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
+
+ return terms, {}
+
+ @torch.no_grad()
+ def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
+ super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)
+
+ @torch.no_grad()
+ def run_snapshot(
+ self,
+ num_samples: int,
+ batch_size: int,
+ verbose: bool = False,
+ ) -> Dict:
+ # Use current step as seed to ensure different samples for each snapshot
+ import random
+ snapshot_seed = self.step
+ random.seed(snapshot_seed)
+ np.random.seed(snapshot_seed)
+
+ g = torch.Generator()
+ g.manual_seed(snapshot_seed)
+
+ dataloader = DataLoader(
+ copy.deepcopy(self.dataset),
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=0,
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+ generator=g,
+ )
+
+ # inference
+ gts = []
+ recons = []
+ for i in range(0, num_samples, batch_size):
+ batch = min(batch_size, num_samples - i)
+ data = next(iter(dataloader))
+ args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
+ z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
+ logits = self.models['decoder'](z)
+ recon = (logits > 0).long()
+ gts.append(args['ss'])
+ recons.append(recon)
+
+ sample_dict = {
+ 'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
+ 'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
+ }
+ return sample_dict
diff --git a/trellis2/utils/__init__.py b/trellis2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/trellis2/utils/data_utils.py b/trellis2/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..805b6cc118106857e7ef767ab4bfd133dbd78e6f
--- /dev/null
+++ b/trellis2/utils/data_utils.py
@@ -0,0 +1,226 @@
+from typing import *
+import math
+import torch
+import numpy as np
+from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
+import torch.distributed as dist
+
+
+def recursive_to_device(
+ data: Any,
+ device: torch.device,
+ non_blocking: bool = False,
+) -> Any:
+ """
+ Recursively move all tensors in a data structure to a device.
+ """
+ if hasattr(data, "to"):
+ return data.to(device, non_blocking=non_blocking)
+ elif isinstance(data, (list, tuple)):
+ return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
+ elif isinstance(data, dict):
+ return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
+ else:
+ return data
+
+
+def load_balanced_group_indices(
+ load: List[int],
+ num_groups: int,
+ equal_size: bool = False,
+) -> List[List[int]]:
+ """
+ Split indices into groups with balanced load.
+ """
+ if equal_size:
+ group_size = len(load) // num_groups
+ indices = np.argsort(load)[::-1]
+ groups = [[] for _ in range(num_groups)]
+ group_load = np.zeros(num_groups)
+ for idx in indices:
+ min_group_idx = np.argmin(group_load)
+ groups[min_group_idx].append(idx)
+ if equal_size and len(groups[min_group_idx]) == group_size:
+ group_load[min_group_idx] = float('inf')
+ else:
+ group_load[min_group_idx] += load[idx]
+ return groups
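+# Example: load_balanced_group_indices([5, 1, 4, 2], num_groups=2, equal_size=True)
+# greedily assigns the heaviest items first, giving groups [[0, 1], [2, 3]] with
+# total loads of 6 and 6.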
+
+
+def cycle(data_loader: DataLoader) -> Iterator:
+ while True:
+ for data in data_loader:
+ if isinstance(data_loader.sampler, ResumableSampler):
+ data_loader.sampler.idx += data_loader.batch_size # type: ignore[attr-defined]
+ yield data
+ if isinstance(data_loader.sampler, DistributedSampler):
+ data_loader.sampler.epoch += 1
+ if isinstance(data_loader.sampler, ResumableSampler):
+ data_loader.sampler.epoch += 1
+ data_loader.sampler.idx = 0
+
+
+class ResumableSampler(Sampler):
+ """
+ Distributed sampler that is resumable.
+
+ Args:
+ dataset: Dataset used for sampling.
+ rank (int, optional): Rank of the current process within :attr:`num_replicas`.
+ By default, :attr:`rank` is retrieved from the current distributed
+ group.
+ shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
+ indices.
+ seed (int, optional): random seed used to shuffle the sampler if
+ :attr:`shuffle=True`. This number should be identical across all
+ processes in the distributed group. Default: ``0``.
+ drop_last (bool, optional): if ``True``, then the sampler will drop the
+ tail of the data to make it evenly divisible across the number of
+ replicas. If ``False``, the sampler will add extra indices to make
+ the data evenly divisible across the replicas. Default: ``False``.
+ """
+
+ def __init__(
+ self,
+ dataset: Dataset,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ ) -> None:
+ self.dataset = dataset
+ self.epoch = 0
+ self.idx = 0
+ self.drop_last = drop_last
+ self.world_size = dist.get_world_size() if dist.is_initialized() else 1
+ self.rank = dist.get_rank() if dist.is_initialized() else 0
+ # If the dataset length is evenly divisible by # of replicas, then there
+ # is no need to drop any data, since the dataset will be split equally.
+ if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type]
+ # Split to nearest available length that is evenly divisible.
+ # This is to ensure each rank receives the same amount of data when
+ # using this Sampler.
+ self.num_samples = math.ceil(
+ (len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type]
+ )
+ else:
+ self.num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type]
+ self.total_size = self.num_samples * self.world_size
+ self.shuffle = shuffle
+ self.seed = seed
+
+ def __iter__(self) -> Iterator:
+ if self.shuffle:
+ # deterministically shuffle based on epoch and seed
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
+ else:
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
+
+ if not self.drop_last:
+ # add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * math.ceil(padding_size / len(indices)))[
+ :padding_size
+ ]
+ else:
+ # remove tail of data to make it evenly divisible.
+ indices = indices[: self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.world_size]
+
+ # resume from previous state
+ indices = indices[self.idx:]
+
+ return iter(indices)
+
+ def __len__(self) -> int:
+ return self.num_samples
+
+ def state_dict(self) -> dict[str, int]:
+ return {
+ 'epoch': self.epoch,
+ 'idx': self.idx,
+ }
+
+ def load_state_dict(self, state_dict):
+ self.epoch = state_dict['epoch']
+ self.idx = state_dict['idx']
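+ # Typical use (sketch): pair the sampler with a DataLoader and checkpoint its
+ # state so training can resume mid-epoch, e.g.
+ #   sampler = ResumableSampler(dataset, shuffle=True)
+ #   loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)
+ #   ckpt['sampler'] = sampler.state_dict()  # restore later via load_state_dict()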
+
+
+class BalancedResumableSampler(ResumableSampler):
+ """
+ Distributed sampler that is resumable and balances the load among the processes.
+
+ Args:
+ dataset: Dataset used for sampling.
+ rank (int, optional): Rank of the current process within :attr:`num_replicas`.
+ By default, :attr:`rank` is retrieved from the current distributed
+ group.
+ shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
+ indices.
+ seed (int, optional): random seed used to shuffle the sampler if
+ :attr:`shuffle=True`. This number should be identical across all
+ processes in the distributed group. Default: ``0``.
+ drop_last (bool, optional): if ``True``, then the sampler will drop the
+ tail of the data to make it evenly divisible across the number of
+ replicas. If ``False``, the sampler will add extra indices to make
+ the data evenly divisible across the replicas. Default: ``False``.
+ """
+
+ def __init__(
+ self,
+ dataset: Dataset,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ batch_size: int = 1,
+ ) -> None:
+ assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
+ super().__init__(dataset, shuffle, seed, drop_last)
+ self.batch_size = batch_size
+ self.loads = dataset.loads
+
+ def __iter__(self) -> Iterator:
+ if self.shuffle:
+ # deterministically shuffle based on epoch and seed
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
+ else:
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
+
+ if not self.drop_last:
+ # add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * math.ceil(padding_size / len(indices)))[
+ :padding_size
+ ]
+ else:
+ # remove tail of data to make it evenly divisible.
+ indices = indices[: self.total_size]
+ assert len(indices) == self.total_size
+
+ # balance load among processes
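+ # Each global batch of batch_size * world_size indices is re-partitioned so that
+ # every rank receives exactly batch_size indices with roughly equal total load,
+ # keeping per-step memory and compute balanced across processes.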
+ num_batches = len(indices) // (self.batch_size * self.world_size)
+ balanced_indices = []
+ for i in range(num_batches):
+ start_idx = i * self.batch_size * self.world_size
+ end_idx = (i + 1) * self.batch_size * self.world_size
+ batch_indices = indices[start_idx:end_idx]
+ batch_loads = [self.loads[idx] for idx in batch_indices]
+ groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
+ balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])
+
+ # resume from previous state
+ indices = balanced_indices[self.idx:]
+
+ return iter(indices)
diff --git a/trellis2/utils/dist_utils.py b/trellis2/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f57e2060ea6d7b208a7322662fcbee13acf1d096
--- /dev/null
+++ b/trellis2/utils/dist_utils.py
@@ -0,0 +1,94 @@
+import os
+import io
+from datetime import timedelta
+from contextlib import contextmanager
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+def setup_dist(rank, local_rank, world_size, master_addr, master_port):
+ os.environ['MASTER_ADDR'] = master_addr
+ os.environ['MASTER_PORT'] = master_port
+ os.environ['WORLD_SIZE'] = str(world_size)
+ os.environ['RANK'] = str(rank)
+ os.environ['LOCAL_RANK'] = str(local_rank)
+ torch.cuda.set_device(local_rank)
+ dist.init_process_group('nccl', rank=rank, world_size=world_size, timeout=timedelta(hours=10))
+
+
+def read_file_dist(path):
+ """
+ Read the binary file distributedly.
+ File is only read once by the rank 0 process and broadcasted to other processes.
+
+ Returns:
+ data (io.BytesIO): The binary data read from the file.
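+
+    Example (illustrative sketch; assumes the process group is already initialized):
+        buffer = read_file_dist('ckpts/model.pt')
+        state_dict = torch.load(buffer, map_location='cpu')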
+ """
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ # read file
+ size = torch.LongTensor(1).cuda()
+ if dist.get_rank() == 0:
+ with open(path, 'rb') as f:
+ data = f.read()
+ data = torch.ByteTensor(
+ torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
+ ).cuda()
+ size[0] = data.shape[0]
+ # broadcast size
+ dist.broadcast(size, src=0)
+ if dist.get_rank() != 0:
+ data = torch.ByteTensor(size[0].item()).cuda()
+ # broadcast data
+ dist.broadcast(data, src=0)
+ # convert to io.BytesIO
+ data = data.cpu().numpy().tobytes()
+ data = io.BytesIO(data)
+ return data
+ else:
+ with open(path, 'rb') as f:
+ data = f.read()
+ data = io.BytesIO(data)
+ return data
+
+
+def unwrap_dist(model):
+ """
+ Unwrap the model from distributed training.
+ """
+ if isinstance(model, DDP):
+ return model.module
+ return model
+
+
+@contextmanager
+def master_first():
+ """
+ A context manager that ensures master process executes first.
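+
+    Example (illustrative sketch; ``prepare_cache`` is a placeholder):
+        with master_first():
+            prepare_cache()  # rank 0 finishes before the other ranks enter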
+ """
+ if not dist.is_initialized():
+ yield
+ else:
+ if dist.get_rank() == 0:
+ yield
+ dist.barrier()
+ else:
+ dist.barrier()
+ yield
+
+
+@contextmanager
+def local_master_first():
+ """
+ A context manager that ensures local master process executes first.
+ """
+ if not dist.is_initialized():
+ yield
+ else:
+ if dist.get_rank() % torch.cuda.device_count() == 0:
+ yield
+ dist.barrier()
+ else:
+ dist.barrier()
+ yield
+
\ No newline at end of file
diff --git a/trellis2/utils/elastic_utils.py b/trellis2/utils/elastic_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cba3cf83836e5b58f5bc3333e809ffc932375a04
--- /dev/null
+++ b/trellis2/utils/elastic_utils.py
@@ -0,0 +1,228 @@
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Tuple
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class MemoryController:
+ """
+ Base class for memory management during training.
+ """
+
+ _last_input_size = None
+ _last_mem_ratio = []
+
+ @contextmanager
+ def record(self):
+ pass
+
+ def update_run_states(self, input_size=None, mem_ratio=None):
+ if self._last_input_size is None:
+ self._last_input_size = input_size
+        elif self._last_input_size != input_size:
+            raise ValueError('Input size should not change across ElasticModules within the same step.')
+ self._last_mem_ratio.append(mem_ratio)
+
+ @abstractmethod
+ def get_mem_ratio(self, input_size):
+ pass
+
+ @abstractmethod
+ def state_dict(self):
+ pass
+
+ @abstractmethod
+ def log(self):
+ pass
+
+
+class LinearMemoryController(MemoryController):
+ """
+ A simple controller for memory management during training.
+ The memory usage is modeled as a linear function of:
+ - the number of input parameters
+ - the ratio of memory the model use compared to the maximum usage (with no checkpointing)
+ memory_usage = k * input_size * mem_ratio + b
+ The controller keeps track of the memory usage and gives the
+ expected memory ratio to keep the memory usage under a target
+ """
+ def __init__(
+ self,
+ buffer_size=1000,
+ update_every=500,
+ target_ratio=0.8,
+ available_memory=None,
+ max_mem_ratio_start=0.1,
+ params=None,
+ device=None
+ ):
+ self.buffer_size = buffer_size
+ self.update_every = update_every
+ self.target_ratio = target_ratio
+ self.device = device or torch.cuda.current_device()
+ self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3
+
+ self._memory = np.zeros(buffer_size, dtype=np.float32)
+ self._input_size = np.zeros(buffer_size, dtype=np.float32)
+ self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
+ self._buffer_ptr = 0
+ self._buffer_length = 0
+ self._params = tuple(params) if params is not None else (0.0, 0.0)
+ self._max_mem_ratio = max_mem_ratio_start
+ self.step = 0
+
+ def __repr__(self):
+ return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'
+
+ def _add_sample(self, memory, input_size, mem_ratio):
+ self._memory[self._buffer_ptr] = memory
+ self._input_size[self._buffer_ptr] = input_size
+ self._mem_ratio[self._buffer_ptr] = mem_ratio
+ self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
+ self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
+
+ @contextmanager
+ def record(self):
+ torch.cuda.reset_peak_memory_stats(self.device)
+ self._last_input_size = None
+ self._last_mem_ratio = []
+ yield
+ self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
+ self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
+ self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
+ self.step += 1
+ if self.step % self.update_every == 0:
+ self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
+ self._fit_params()
+
+ def _fit_params(self):
+ memory_usage = self._memory[:self._buffer_length]
+ input_size = self._input_size[:self._buffer_length]
+ mem_ratio = self._mem_ratio[:self._buffer_length]
+
+ x = input_size * mem_ratio
+ y = memory_usage
+ k, b = np.polyfit(x, y, 1)
+ self._params = (k, b)
+ # self._visualize()
+
+ def _visualize(self):
+ import matplotlib.pyplot as plt
+ memory_usage = self._memory[:self._buffer_length]
+ input_size = self._input_size[:self._buffer_length]
+ mem_ratio = self._mem_ratio[:self._buffer_length]
+ k, b = self._params
+
+ plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
+ x = np.array([0.0, 20000.0])
+ plt.plot(x, k * x + b, c='r')
+ plt.savefig(f'linear_memory_controller_{self.step}.png')
+ plt.cla()
+
+ def get_mem_ratio(self, input_size):
+ k, b = self._params
+ if k == 0: return np.random.rand() * self._max_mem_ratio
+ pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
+ return min(self._max_mem_ratio, max(0.0, pred))
+
+ def state_dict(self):
+ return {
+ 'params': self._params,
+ }
+
+ def load_state_dict(self, state_dict):
+ self._params = tuple(state_dict['params'])
+
+ def log(self):
+ return {
+ 'params/k': self._params[0],
+ 'params/b': self._params[1],
+ 'memory': self._last_memory,
+ 'input_size': self._last_input_size,
+ 'mem_ratio': self._last_mem_ratio,
+ }
+
+
+class ElasticModule(nn.Module):
+ """
+ Module for training with elastic memory management.
+ """
+ def __init__(self):
+ super().__init__()
+ self._memory_controller: MemoryController = None
+
+ @abstractmethod
+ def _get_input_size(self, *args, **kwargs) -> int:
+ """
+ Get the size of the input data.
+
+ Returns:
+ int: The size of the input data.
+ """
+ pass
+
+ @abstractmethod
+ def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
+ """
+ Forward with a given memory ratio.
+ """
+ pass
+
+ def register_memory_controller(self, memory_controller: MemoryController):
+ self._memory_controller = memory_controller
+
+ def forward(self, *args, **kwargs):
+ if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
+ _, ret = self._forward_with_mem_ratio(*args, **kwargs)
+ else:
+ input_size = self._get_input_size(*args, **kwargs)
+ mem_ratio = self._memory_controller.get_mem_ratio(input_size)
+ mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
+ self._memory_controller.update_run_states(input_size, mem_ratio)
+ return ret
+
+
+class ElasticModuleMixin:
+ """
+ Mixin for training with elastic memory management.
+ """
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._memory_controller: MemoryController = None
+
+ @abstractmethod
+ def _get_input_size(self, *args, **kwargs) -> int:
+ """
+ Get the size of the input data.
+
+ Returns:
+ int: The size of the input data.
+ """
+ pass
+
+ @abstractmethod
+ @contextmanager
+ def with_mem_ratio(self, mem_ratio=1.0) -> float:
+ """
+ Context manager for training with a reduced memory ratio compared to the full memory usage.
+
+        Yields:
+            float: The exact memory ratio used during the forward pass.
+ """
+ pass
+
+ def register_memory_controller(self, memory_controller: MemoryController):
+ self._memory_controller = memory_controller
+
+ def forward(self, *args, **kwargs):
+ if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
+ ret = super().forward(*args, **kwargs)
+ else:
+ input_size = self._get_input_size(*args, **kwargs)
+ mem_ratio = self._memory_controller.get_mem_ratio(input_size)
+ with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
+ ret = super().forward(*args, **kwargs)
+ self._memory_controller.update_run_states(input_size, exact_mem_ratio)
+ return ret
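+
+
+# Example wiring (illustrative sketch of how the pieces above fit together):
+#     controller = LinearMemoryController(target_ratio=0.8)
+#     model.register_memory_controller(controller)   # model mixes in ElasticModuleMixin
+#     for batch in loader:
+#         with controller.record():
+#             loss = model(batch)
+#             loss.backward()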
diff --git a/trellis2/utils/general_utils.py b/trellis2/utils/general_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..589c103de8a777aea9994a899f97431cbab5447a
--- /dev/null
+++ b/trellis2/utils/general_utils.py
@@ -0,0 +1,373 @@
+import re
+import numpy as np
+import cv2
+import torch
+import contextlib
+
+
+# Dictionary utils
+def _dict_merge(dicta, dictb, prefix=''):
+ """
+ Merge two dictionaries.
+ """
+ assert isinstance(dicta, dict), 'input must be a dictionary'
+ assert isinstance(dictb, dict), 'input must be a dictionary'
+ dict_ = {}
+ all_keys = set(dicta.keys()).union(set(dictb.keys()))
+ for key in all_keys:
+ if key in dicta.keys() and key in dictb.keys():
+ if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
+ dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
+ else:
+ raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
+ elif key in dicta.keys():
+ dict_[key] = dicta[key]
+ else:
+ dict_[key] = dictb[key]
+ return dict_
+
+
+def dict_merge(dicta, dictb):
+ """
+ Merge two dictionaries.
+ """
+ return _dict_merge(dicta, dictb, prefix='')
+
+
+def dict_foreach(dic, func, special_func={}):
+ """
+ Recursively apply a function to all non-dictionary leaf values in a dictionary.
+ """
+ assert isinstance(dic, dict), 'input must be a dictionary'
+ for key in dic.keys():
+ if isinstance(dic[key], dict):
+ dic[key] = dict_foreach(dic[key], func)
+ else:
+ if key in special_func.keys():
+ dic[key] = special_func[key](dic[key])
+ else:
+ dic[key] = func(dic[key])
+ return dic
+
+
+def dict_reduce(dicts, func, special_func={}):
+ """
+ Reduce a list of dictionaries. Leaf values must be scalars.
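+
+    Example (sketch): dict_reduce([{'a': 1, 'b': {'c': 2}}, {'a': 3, 'b': {'c': 4}}], np.mean)
+    returns {'a': 2.0, 'b': {'c': 3.0}}.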
+ """
+ assert isinstance(dicts, list), 'input must be a list of dictionaries'
+ assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
+ assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
+ all_keys = set([key for dict_ in dicts for key in dict_.keys()])
+ reduced_dict = {}
+ for key in all_keys:
+ vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
+ if isinstance(vlist[0], dict):
+ reduced_dict[key] = dict_reduce(vlist, func, special_func)
+ else:
+ if key in special_func.keys():
+ reduced_dict[key] = special_func[key](vlist)
+ else:
+ reduced_dict[key] = func(vlist)
+ return reduced_dict
+
+
+def dict_any(dic, func):
+ """
+    Return True if ``func`` evaluates to True for any non-dictionary leaf value in a nested dictionary.
+ """
+ assert isinstance(dic, dict), 'input must be a dictionary'
+ for key in dic.keys():
+ if isinstance(dic[key], dict):
+ if dict_any(dic[key], func):
+ return True
+ else:
+ if func(dic[key]):
+ return True
+ return False
+
+
+def dict_all(dic, func):
+ """
+    Return True if ``func`` evaluates to True for every non-dictionary leaf value in a nested dictionary.
+ """
+ assert isinstance(dic, dict), 'input must be a dictionary'
+ for key in dic.keys():
+ if isinstance(dic[key], dict):
+ if not dict_all(dic[key], func):
+ return False
+ else:
+ if not func(dic[key]):
+ return False
+ return True
+
+
+def dict_flatten(dic, sep='.'):
+ """
+ Flatten a nested dictionary into a dictionary with no nested dictionaries.
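+
+    Example (sketch): dict_flatten({'a': {'b': 1}, 'c': 2}) returns {'a.b': 1, 'c': 2}.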
+ """
+ assert isinstance(dic, dict), 'input must be a dictionary'
+ flat_dict = {}
+ for key in dic.keys():
+ if isinstance(dic[key], dict):
+ sub_dict = dict_flatten(dic[key], sep=sep)
+ for sub_key in sub_dict.keys():
+ flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
+ else:
+ flat_dict[key] = dic[key]
+ return flat_dict
+
+
+# Context utils
+@contextlib.contextmanager
+def nested_contexts(*contexts):
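+    """
+    Enter several context managers at once. Each argument is a zero-argument callable
+    that returns a context manager (they are called and entered in order).
+    """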
+ with contextlib.ExitStack() as stack:
+ for ctx in contexts:
+ stack.enter_context(ctx())
+ yield
+
+
+# Image utils
+def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
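+    """
+    Tile a list of equally sized images into a single grid image.
+
+    Example (sketch): make_grid([img] * 6, nrow=2) lays the six images out on a 2x3 grid.
+    """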
+ num_images = len(images)
+ if nrow is None and ncol is None:
+ if aspect_ratio is not None:
+ nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
+ else:
+ nrow = int(np.sqrt(num_images))
+ ncol = (num_images + nrow - 1) // nrow
+ elif nrow is None and ncol is not None:
+ nrow = (num_images + ncol - 1) // ncol
+ elif nrow is not None and ncol is None:
+ ncol = (num_images + nrow - 1) // nrow
+ else:
+ assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
+
+ if images[0].ndim == 2:
+ grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
+ else:
+ grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
+ for i, img in enumerate(images):
+ row = i // ncol
+ col = i % ncol
+ grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
+ return grid
+
+
+def notes_on_image(img, notes=None):
+ img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ if notes is not None:
+ img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
+
+
+
+def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"):
+ """
+ Draw text on an image of the given resolution. The text is automatically wrapped
+ and scaled so that it fits completely within the image while preserving any explicit
+ line breaks and original spacing. Horizontal and vertical alignment can be controlled
+ via flags.
+
+ Parameters:
+ text (str): The input text. Newline characters and spacing are preserved.
+ resolution (tuple): The image resolution as (width, height).
+ max_size (float): The maximum font size.
+ h_align (str): Horizontal alignment. Options: "left", "center", "right".
+ v_align (str): Vertical alignment. Options: "top", "center", "bottom".
+
+ Returns:
+ numpy.ndarray: The resulting image (BGR format) with the text drawn.
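+
+    Example (sketch):
+        img = text_image("Hello, world!", resolution=(256, 128), h_align="center", v_align="top")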
+ """
+ width, height = resolution
+ # Create a white background image
+ img = np.full((height, width, 3), 255, dtype=np.uint8)
+
+ # Set margins and compute available drawing area
+ margin = 10
+ avail_width = width - 2 * margin
+ avail_height = height - 2 * margin
+
+ # Choose OpenCV font and text thickness
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ thickness = 1
+ # Ratio for additional spacing between lines (relative to the height of "A")
+ line_spacing_ratio = 0.5
+
+ def wrap_line(line, max_width, font, thickness, scale):
+ """
+ Wrap a single line of text into multiple lines such that each line's
+ width (measured at the given scale) does not exceed max_width.
+ This function preserves the original spacing by splitting the line into tokens
+ (words and whitespace) using a regular expression.
+
+ Parameters:
+ line (str): The input text line.
+ max_width (int): Maximum allowed width in pixels.
+ font (int): OpenCV font identifier.
+ thickness (int): Text thickness.
+ scale (float): The current font scale.
+
+ Returns:
+ List[str]: A list of wrapped lines.
+ """
+ # Split the line into tokens (words and whitespace), preserving spacing
+ tokens = re.split(r'(\s+)', line)
+ if not tokens:
+ return ['']
+
+ wrapped_lines = []
+ current_line = ""
+ for token in tokens:
+ candidate = current_line + token
+ candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0]
+ if candidate_width <= max_width:
+ current_line = candidate
+ else:
+ # If current_line is empty, the token itself is too wide;
+ # break the token character by character.
+ if current_line == "":
+ sub_token = ""
+ for char in token:
+ candidate_char = sub_token + char
+ if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width:
+ sub_token = candidate_char
+ else:
+ if sub_token:
+ wrapped_lines.append(sub_token)
+ sub_token = char
+ current_line = sub_token
+ else:
+ wrapped_lines.append(current_line)
+ current_line = token
+ if current_line:
+ wrapped_lines.append(current_line)
+ return wrapped_lines
+
+ def compute_text_block(scale):
+ """
+ Wrap the entire text (splitting at explicit newline characters) using the
+ provided scale, and then compute the overall width and height of the text block.
+
+ Returns:
+ wrapped_lines (List[str]): The list of wrapped lines.
+ block_width (int): Maximum width among the wrapped lines.
+ block_height (int): Total height of the text block including spacing.
+ sizes (List[tuple]): A list of (width, height) for each wrapped line.
+ spacing (int): The spacing between lines (computed from the scaled "A" height).
+ """
+ # Split text by explicit newlines
+ input_lines = text.splitlines() if text else ['']
+ wrapped_lines = []
+ for line in input_lines:
+ wrapped = wrap_line(line, avail_width, font, thickness, scale)
+ wrapped_lines.extend(wrapped)
+
+ sizes = []
+ for line in wrapped_lines:
+ (text_size, _) = cv2.getTextSize(line, font, scale, thickness)
+ sizes.append(text_size) # (width, height)
+
+ block_width = max((w for w, h in sizes), default=0)
+ # Use the height of "A" (at the current scale) to compute line spacing
+ base_height = cv2.getTextSize("A", font, scale, thickness)[0][1]
+ spacing = int(line_spacing_ratio * base_height)
+ block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0
+
+ return wrapped_lines, block_width, block_height, sizes, spacing
+
+ # Use binary search to find the maximum scale that allows the text block to fit
+ lo = 0.001
+ hi = max_size
+ eps = 0.001 # convergence threshold
+ best_scale = lo
+ best_result = None
+
+ while hi - lo > eps:
+ mid = (lo + hi) / 2
+ wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid)
+ # Ensure that both width and height constraints are met
+ if block_width <= avail_width and block_height <= avail_height:
+ best_scale = mid
+ best_result = (wrapped_lines, block_width, block_height, sizes, spacing)
+ lo = mid # try a larger scale
+ else:
+ hi = mid # reduce the scale
+
+ if best_result is None:
+ best_scale = 0.5
+ best_result = compute_text_block(best_scale)
+
+ wrapped_lines, block_width, block_height, sizes, spacing = best_result
+
+ # Compute starting y-coordinate based on vertical alignment flag
+ if v_align == "top":
+ y_top = margin
+ elif v_align == "center":
+ y_top = margin + (avail_height - block_height) // 2
+ elif v_align == "bottom":
+ y_top = margin + (avail_height - block_height)
+ else:
+ y_top = margin + (avail_height - block_height) // 2 # default to center if invalid flag
+
+ # For cv2.putText, the y coordinate represents the text baseline;
+ # so for the first line add its height.
+ y = y_top + (sizes[0][1] if sizes else 0)
+
+ # Draw each line with horizontal alignment based on the flag
+ for i, line in enumerate(wrapped_lines):
+ line_width, line_height = sizes[i]
+ if h_align == "left":
+ x = margin
+ elif h_align == "center":
+ x = margin + (avail_width - line_width) // 2
+ elif h_align == "right":
+ x = margin + (avail_width - line_width)
+ else:
+ x = margin # default to left if invalid flag
+
+ cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA)
+ y += line_height + spacing
+
+ return img
+
+
+def save_image_with_notes(img, path, notes=None):
+ """
+ Save an image with notes.
+ """
+ if isinstance(img, torch.Tensor):
+ img = img.cpu().numpy().transpose(1, 2, 0)
+ if img.dtype == np.float32 or img.dtype == np.float64:
+ img = np.clip(img * 255, 0, 255).astype(np.uint8)
+ img = notes_on_image(img, notes)
+ cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+
+# debug utils
+
+def atol(x, y):
+ """
+ Absolute tolerance.
+ """
+ return torch.abs(x - y)
+
+
+def rtol(x, y):
+ """
+ Relative tolerance.
+ """
+ return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
+
+
+# print utils
+def indent(s, n=4):
+ """
+    Indent every line of a string after the first by ``n`` spaces.
+ """
+ lines = s.split('\n')
+ for i in range(1, len(lines)):
+ lines[i] = ' ' * n + lines[i]
+ return '\n'.join(lines)
+
diff --git a/trellis2/utils/grad_clip_utils.py b/trellis2/utils/grad_clip_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..990a4352e24fc73bf732d8eb0f8ca9a07365b49e
--- /dev/null
+++ b/trellis2/utils/grad_clip_utils.py
@@ -0,0 +1,81 @@
+from typing import *
+import torch
+import numpy as np
+import torch.utils
+
+
+class AdaptiveGradClipper:
+ """
+ Adaptive gradient clipping for training.
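+
+    Keeps a rolling buffer of recent gradient norms; once the buffer is full, the clip
+    threshold is updated to the given percentile of that buffer (capped by ``max_norm``
+    when it is set).
+
+    Example (illustrative sketch):
+        clipper = AdaptiveGradClipper(clip_percentile=95.0)
+        loss.backward()
+        clipper(model.parameters())
+        optimizer.step()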
+ """
+ def __init__(
+ self,
+ max_norm=None,
+ clip_percentile=95.0,
+ buffer_size=1000,
+ ):
+ self.max_norm = max_norm
+ self.clip_percentile = clip_percentile
+ self.buffer_size = buffer_size
+
+ self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
+ self._max_norm = max_norm
+ self._buffer_ptr = 0
+ self._buffer_length = 0
+
+ def __repr__(self):
+ return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'
+
+ def state_dict(self):
+ return {
+ 'grad_norm': self._grad_norm,
+ 'max_norm': self._max_norm,
+ 'buffer_ptr': self._buffer_ptr,
+ 'buffer_length': self._buffer_length,
+ }
+
+ def load_state_dict(self, state_dict):
+ self._grad_norm = state_dict['grad_norm']
+ self._max_norm = state_dict['max_norm']
+ self._buffer_ptr = state_dict['buffer_ptr']
+ self._buffer_length = state_dict['buffer_length']
+
+ def log(self):
+ return {
+ 'max_norm': self._max_norm,
+ }
+
+ def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
+ """Clip the gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were
+ concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+ error_if_nonfinite (bool): if True, an error is thrown if the total
+ norm of the gradients from :attr:`parameters` is ``nan``,
+ ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+ foreach (bool): use the faster foreach-based implementation.
+ If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+ fall back to the slow implementation for other device types.
+ Default: ``None``
+
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """
+ max_norm = self._max_norm if self._max_norm is not None else float('inf')
+ grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)
+
+ if torch.isfinite(grad_norm):
+ self._grad_norm[self._buffer_ptr] = grad_norm
+ self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
+ self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
+ if self._buffer_length == self.buffer_size:
+ self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
+ self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm
+
+ return grad_norm
\ No newline at end of file
diff --git a/trellis2/utils/loss_utils.py b/trellis2/utils/loss_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..52049f69543f2700bc5525b09cbf2fb25c08aa9e
--- /dev/null
+++ b/trellis2/utils/loss_utils.py
@@ -0,0 +1,92 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+from math import exp
+from lpips import LPIPS
+
+
+def smooth_l1_loss(pred, target, beta=1.0):
+ diff = torch.abs(pred - target)
+ loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
+ return loss.mean()
+
+
+def l1_loss(network_output, gt):
+ return torch.abs((network_output - gt)).mean()
+
+
+def l2_loss(network_output, gt):
+ return ((network_output - gt) ** 2).mean()
+
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+ return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+ return window
+
+
+def psnr(img1, img2, max_val=1.0):
+ mse = F.mse_loss(img1, img2)
+ return 20 * torch.log10(max_val / torch.sqrt(mse))
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+ channel = img1.size(-3)
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return _ssim(img1, img2, window, window_size, channel, size_average)
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+ C1 = 0.01 ** 2
+ C2 = 0.03 ** 2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+
+loss_fn_vgg = None
+def lpips(img1, img2, value_range=(0, 1)):
+ global loss_fn_vgg
+ if loss_fn_vgg is None:
+ loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
+ # normalize to [-1, 1]
+ img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
+ img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
+ return loss_fn_vgg(img1, img2).mean()
+
+
+def normal_angle(pred, gt):
+ pred = pred * 2.0 - 1.0
+ gt = gt * 2.0 - 1.0
+ norms = pred.norm(dim=-1) * gt.norm(dim=-1)
+ cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
+ cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
+ ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
+ if ang.isnan():
+ return -1
+ return ang
diff --git a/trellis2/utils/mesh_utils.py b/trellis2/utils/mesh_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9f1451ebd8b89879eee79cc61a6f4161136f245
--- /dev/null
+++ b/trellis2/utils/mesh_utils.py
@@ -0,0 +1,268 @@
+from typing import Tuple, Dict
+import numpy as np
+from trimesh import grouping, util, remesh
+import struct
+import re
+from plyfile import PlyData, PlyElement
+
+
+def read_ply(filename):
+ """
+ Read a PLY file and return vertices, triangle faces, and quad faces.
+
+ Args:
+ filename (str): The file path to read from.
+
+ Returns:
+ vertices (np.ndarray): Array of shape [N, 3] containing vertex positions.
+ tris (np.ndarray): Array of shape [M, 3] containing triangle face indices (empty if none).
+ quads (np.ndarray): Array of shape [K, 4] containing quad face indices (empty if none).
+ """
+ with open(filename, 'rb') as f:
+ # Read the header until 'end_header' is encountered
+ header_bytes = b""
+ while True:
+ line = f.readline()
+ if not line:
+ raise ValueError("PLY header not found")
+ header_bytes += line
+ if b"end_header" in line:
+ break
+ header = header_bytes.decode('utf-8')
+
+ # Determine if the file is in ASCII or binary format
+ is_ascii = "ascii" in header
+
+ # Extract the number of vertices and faces from the header using regex
+ vertex_match = re.search(r'element vertex (\d+)', header)
+ if vertex_match:
+ num_vertices = int(vertex_match.group(1))
+ else:
+ raise ValueError("Vertex count not found in header")
+
+ face_match = re.search(r'element face (\d+)', header)
+ if face_match:
+ num_faces = int(face_match.group(1))
+ else:
+ raise ValueError("Face count not found in header")
+
+ vertices = []
+ tris = []
+ quads = []
+
+ if is_ascii:
+ # For ASCII format, read each line of vertex data (each line contains 3 floats)
+ for _ in range(num_vertices):
+ line = f.readline().decode('utf-8').strip()
+ if not line:
+ continue
+ parts = line.split()
+ vertices.append([float(parts[0]), float(parts[1]), float(parts[2])])
+
+ # Read face data, where the first number indicates the number of vertices for the face
+ for _ in range(num_faces):
+ line = f.readline().decode('utf-8').strip()
+ if not line:
+ continue
+ parts = line.split()
+ count = int(parts[0])
+ indices = list(map(int, parts[1:]))
+ if count == 3:
+ tris.append(indices)
+ elif count == 4:
+ quads.append(indices)
+ else:
+ # Skip faces with other numbers of vertices (can be extended as needed)
+ pass
+ else:
+ # For binary format: read directly from the binary stream
+ # Each vertex consists of 3 floats (12 bytes per vertex)
+ for _ in range(num_vertices):
+ data = f.read(12)
+ if len(data) < 12:
+ raise ValueError("Insufficient vertex data")
+                v = struct.unpack('<fff', data)
+                vertices.append(list(v))
+
+            # Read face data: a uchar vertex count followed by that many int32 indices
+            for _ in range(num_faces):
+                count = struct.unpack('<B', f.read(1))[0]
+                indices = list(struct.unpack('<' + 'i' * count, f.read(4 * count)))
+                if count == 3:
+                    tris.append(indices)
+                elif count == 4:
+                    quads.append(indices)
+
+    vertices = np.array(vertices, dtype=np.float32)
+    tris = np.array(tris, dtype=np.int32) if len(tris) > 0 else np.empty((0, 3), dtype=np.int32)
+ quads = np.array(quads, dtype=np.int32) if len(quads) > 0 else np.empty((0, 4), dtype=np.int32)
+
+ return vertices, tris, quads
+
+
+def write_ply(
+ filename: str,
+ vertices: np.ndarray,
+ tris: np.ndarray,
+ quads: np.ndarray,
+ vertex_colors: np.ndarray = None,
+ ascii: bool = False
+):
+ """
+ Write a mesh to a PLY file, with the option to save in ASCII or binary format,
+ and optional per-vertex colors.
+
+ Args:
+ filename (str): The filename to write to.
+ vertices (np.ndarray): [N, 3] The vertex positions.
+ tris (np.ndarray): [M, 3] The triangle indices.
+ quads (np.ndarray): [K, 4] The quad indices.
+ vertex_colors (np.ndarray, optional): [N, 3] or [N, 4] UInt8 colors for each vertex (RGB or RGBA).
+ ascii (bool): If True, write in ASCII format; otherwise binary little-endian.
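+
+    Example (sketch): write a triangle-only mesh in binary format.
+        write_ply('mesh.ply', vertices, tris, np.empty((0, 4), dtype=np.int32))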
+ """
+ import struct
+
+ num_vertices = len(vertices)
+ num_faces = len(tris) + len(quads)
+
+ # Build header
+ header_lines = [
+ "ply",
+ f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}",
+ f"element vertex {num_vertices}",
+ "property float x",
+ "property float y",
+ "property float z",
+ ]
+
+ # Add vertex color properties if provided
+ has_color = vertex_colors is not None
+ if has_color:
+ # Expect uint8 values 0-255
+ header_lines += [
+ "property uchar red",
+ "property uchar green",
+ "property uchar blue",
+ ]
+ # Include alpha if RGBA
+ if vertex_colors.shape[1] == 4:
+ header_lines.append("property uchar alpha")
+
+ header_lines += [
+ f"element face {num_faces}",
+ "property list uchar int vertex_index",
+ "end_header",
+ ""
+ ]
+ header = "\n".join(header_lines)
+
+ mode = 'w' if ascii else 'wb'
+ with open(filename, mode) as f:
+ # Write header
+ if ascii:
+ f.write(header)
+ else:
+ f.write(header.encode('utf-8'))
+
+ # Write vertex data
+ for i, v in enumerate(vertices):
+ if ascii:
+ line = f"{v[0]} {v[1]} {v[2]}"
+ if has_color:
+ col = vertex_colors[i]
+ line += ' ' + ' '.join(str(int(c)) for c in col)
+ f.write(line + '\n')
+ else:
+ # pack xyz as floats
+                f.write(struct.pack('<fff', *v))
+                if has_color:
+                    col = vertex_colors[i]
+                    f.write(struct.pack('<' + 'B' * len(col), *[int(c) for c in col]))
+
+        # Write face data: a uchar vertex count followed by the vertex indices
+        for face in list(tris) + list(quads):
+            if ascii:
+                f.write(str(len(face)) + ' ' + ' '.join(str(int(idx)) for idx in face) + '\n')
+            else:
+                f.write(struct.pack('<B', len(face)))
+                f.write(struct.pack('<' + 'i' * len(face), *[int(idx) for idx in face]))
diff --git a/trellis2/utils/random_utils.py b/trellis2/utils/random_utils.py
new file mode 100644
--- /dev/null
+++ b/trellis2/utils/random_utils.py
+import numpy as np
+
+PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]  # leading primes for the Halton sequence
+
+
+def radical_inverse(base, n):
+    val = 0
+    inv_base = 1.0 / base
+    inv_base_n = inv_base
+    while n > 0:
+ digit = n % base
+ val += digit * inv_base_n
+ n //= base
+ inv_base_n *= inv_base
+ return val
+
+def halton_sequence(dim, n):
+ return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]
+
+def hammersley_sequence(dim, n, num_samples):
+ return [n / num_samples] + halton_sequence(dim - 1, n)
+
+def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
+ u, v = hammersley_sequence(2, n, num_samples)
+ u += offset[0] / num_samples
+ v += offset[1]
+ if remap:
+ u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
+ theta = np.arccos(1 - 2 * u) - np.pi / 2
+ phi = v * 2 * np.pi
+ return [phi, theta]
\ No newline at end of file
diff --git a/trellis2/utils/render_utils.py b/trellis2/utils/render_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d7d3ede7a1974bcd905fbc16812d9c5eddea9ec
--- /dev/null
+++ b/trellis2/utils/render_utils.py
@@ -0,0 +1,225 @@
+import gc
+import torch
+import numpy as np
+from tqdm import tqdm
+import utils3d
+from PIL import Image
+
+from ..renderers import MeshRenderer, VoxelRenderer, PbrMeshRenderer
+from ..representations import Mesh, Voxel, MeshWithPbrMaterial, MeshWithVoxel
+from .random_utils import sphere_hammersley_sequence
+
+
+def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
+ is_list = isinstance(yaws, list)
+ if not is_list:
+ yaws = [yaws]
+ pitchs = [pitchs]
+ if not isinstance(rs, list):
+ rs = [rs] * len(yaws)
+ if not isinstance(fovs, list):
+ fovs = [fovs] * len(yaws)
+ extrinsics = []
+ intrinsics = []
+ for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
+ fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
+ yaw = torch.tensor(float(yaw)).cuda()
+ pitch = torch.tensor(float(pitch)).cuda()
+ orig = torch.tensor([
+ torch.sin(yaw) * torch.cos(pitch),
+ torch.cos(yaw) * torch.cos(pitch),
+ torch.sin(pitch),
+ ]).cuda() * r
+ extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+ intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+ extrinsics.append(extr)
+ intrinsics.append(intr)
+ if not is_list:
+ extrinsics = extrinsics[0]
+ intrinsics = intrinsics[0]
+ return extrinsics, intrinsics
+
+
+def get_renderer(sample, **kwargs):
+ if isinstance(sample, (MeshWithPbrMaterial, MeshWithVoxel)):
+ renderer = PbrMeshRenderer()
+ renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+ renderer.rendering_options.near = kwargs.get('near', 1)
+ renderer.rendering_options.far = kwargs.get('far', 100)
+ renderer.rendering_options.ssaa = kwargs.get('ssaa', 2)
+ renderer.rendering_options.peel_layers = kwargs.get('peel_layers', 8)
+ elif isinstance(sample, Mesh):
+ renderer = MeshRenderer()
+ renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+ renderer.rendering_options.near = kwargs.get('near', 1)
+ renderer.rendering_options.far = kwargs.get('far', 100)
+ renderer.rendering_options.ssaa = kwargs.get('ssaa', 2)
+ renderer.rendering_options.chunk_size = kwargs.get('chunk_size', None)
+ elif isinstance(sample, Voxel):
+ renderer = VoxelRenderer()
+ renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+ renderer.rendering_options.near = kwargs.get('near', 0.1)
+ renderer.rendering_options.far = kwargs.get('far', 10.0)
+ renderer.rendering_options.ssaa = kwargs.get('ssaa', 2)
+ else:
+ raise ValueError(f'Unsupported sample type: {type(sample)}')
+ return renderer
+
+
+@torch.no_grad()
+def render_frames(sample, extrinsics, intrinsics, options={}, verbose=True, renderer=None, **kwargs):
+ if renderer is None:
+ renderer = get_renderer(sample, **options)
+ rets = {}
+ for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), total=len(extrinsics), desc='Rendering', disable=not verbose):
+ res = renderer.render(sample, extr, intr, **kwargs)
+ for k, v in res.items():
+ if k not in rets: rets[k] = []
+ if v.dim() == 2: v = v[None].repeat(3, 1, 1)
+ rets[k].append(np.clip(v.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+ return rets
+
+
+def render_video(sample, resolution=1024, bg_color=(0, 0, 0), num_frames=40, r=2, fov=40,
+ start_yaw=None, start_pitch=None, **kwargs):
+ """
+ Render a turntable video of the sample.
+
+ Args:
+ start_yaw: Starting yaw angle in radians. If None, defaults to π/2.
+ start_pitch: Starting pitch angle in radians. If None, uses the default oscillating pitch
+ starting at ~0.25.
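+
+    Example (sketch):
+        frames = render_video(mesh, resolution=512, num_frames=60)['color']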
+ """
+ if start_yaw is None:
+ start_yaw = np.pi / 2
+ yaws = -torch.linspace(0, 2 * 3.1415, num_frames) + start_yaw
+ if start_pitch is not None:
+ pitch = start_pitch + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))
+ else:
+ pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))
+ yaws = yaws.tolist()
+ pitch = pitch.tolist()
+ extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
+ return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+
+
+def render_multiview(sample, resolution=512, nviews=30):
+ r = 2
+ fov = 40
+ cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
+ yaws = [cam[0] for cam in cams]
+ pitchs = [cam[1] for cam in cams]
+ extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
+ res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
+ return res['color'], extrinsics, intrinsics
+
+
+def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, nviews=4, **kwargs):
+ yaw = np.linspace(0, 2 * np.pi, nviews, endpoint=False)
+ yaw_offset = offset[0]
+ yaw = [y + yaw_offset for y in yaw]
+ pitch = [offset[1] for _ in range(nviews)]
+ extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
+ return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+
+
+def proj_camera_to_render_params(camera_angle_x, distance):
+ """
+ Convert proj camera parameters to renderer-compatible extrinsics + intrinsics.
+
+ The proj camera (Blender convention) views from the front. Through empirical
+ testing, this corresponds to extrinsics_look_at from (0, 0, +distance) with up=Y
+ in the mesh coordinate system used by the renderer.
+
+ Args:
+ camera_angle_x: horizontal FOV in radians
+ distance: camera distance
+
+ Returns:
+ extrinsics: [4, 4] OpenCV world-to-camera (on CUDA)
+ intrinsics: [3, 3] OpenCV normalized intrinsics (on CUDA)
+ """
+ orig = torch.tensor([0.0, 0.0, distance]).cuda()
+ target = torch.tensor([0.0, 0.0, 0.0]).cuda()
+ up = torch.tensor([0.0, 1.0, 0.0]).cuda()
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, target, up)
+
+ fov_tensor = torch.tensor(camera_angle_x, dtype=torch.float32).cuda()
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov_tensor, fov_tensor)
+
+ return extrinsics, intrinsics
+
+
+def render_proj_aligned_video(sample, camera_angle_x, distance, resolution=1024,
+ num_frames=40, bg_color=(0, 0, 0), **kwargs):
+ """
+ Render a turntable video starting from the proj camera viewpoint.
+
+ The first frame matches the proj input image exactly. Subsequent frames
+ rotate around the object (around Y axis, which is up in mesh space).
+
+ Args:
+ sample: mesh to render
+ camera_angle_x: proj camera FOV in radians
+ distance: proj camera distance
+ resolution: render resolution
+ num_frames: number of video frames
+ bg_color: background color
+ **kwargs: additional kwargs (e.g. envmap)
+
+ Returns:
+ render result dict (same as render_frames)
+ """
+ import math
+
+ extr_first, intr_first = proj_camera_to_render_params(camera_angle_x, distance)
+
+ extrinsics_list = []
+ intrinsics_list = []
+
+ angles = torch.linspace(0, 2 * math.pi, num_frames + 1)[:num_frames]
+
+ for angle in angles:
+ # Rotation around Y axis (up in mesh space)
+ c = torch.cos(angle)
+ s = torch.sin(angle)
+ R_y = torch.tensor([
+ [ c, 0, s, 0],
+ [ 0, 1, 0, 0],
+ [-s, 0, c, 0],
+ [ 0, 0, 0, 1],
+ ], dtype=torch.float32).cuda()
+
+ # world-to-camera for rotated world: extr @ R_y^{-1}
+ R_y_inv = R_y.clone()
+ R_y_inv[:3, :3] = R_y[:3, :3].T
+ extr_rotated = extr_first @ R_y_inv
+
+ extrinsics_list.append(extr_rotated)
+ intrinsics_list.append(intr_first)
+
+ return render_frames(sample, extrinsics_list, intrinsics_list,
+ {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+
+
+def make_pbr_vis_frames(result, resolution=1024):
+ num_frames = len(result['shaded'])
+ frames = []
+ for i in range(num_frames):
+ shaded = Image.fromarray(result['shaded'][i])
+ normal = Image.fromarray(result['normal'][i])
+ base_color = Image.fromarray(result['base_color'][i])
+ metallic = Image.fromarray(result['metallic'][i])
+ roughness = Image.fromarray(result['roughness'][i])
+ alpha = Image.fromarray(result['alpha'][i])
+ shaded = shaded.resize((resolution, resolution))
+ normal = normal.resize((resolution, resolution))
+ base_color = base_color.resize((resolution//2, resolution//2))
+ metallic = metallic.resize((resolution//2, resolution//2))
+ roughness = roughness.resize((resolution//2, resolution//2))
+ alpha = alpha.resize((resolution//2, resolution//2))
+ row1 = np.concatenate([shaded, normal], axis=1)
+ row2 = np.concatenate([base_color, metallic, roughness, alpha], axis=1)
+ frame = np.concatenate([row1, row2], axis=0)
+ frames.append(frame)
+ return frames
diff --git a/trellis2/utils/vis_utils.py b/trellis2/utils/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e5f58e564aea50e4d80b0265220ee3fb382cd69
--- /dev/null
+++ b/trellis2/utils/vis_utils.py
@@ -0,0 +1,44 @@
+from typing import *
+import numpy as np
+import torch
+from ..modules import sparse as sp
+from ..representations import Voxel
+from .render_utils import render_video
+
+
+def pca_color(feats: torch.Tensor, channels: Tuple[int, int, int] = (0, 1, 2)) -> torch.Tensor:
+ """
+    Project the features onto principal directions via SVD and normalize the selected
+    components to [0, 1] so they can be used as RGB colors.
+ """
+ feats = feats.detach()
+ u, s, v = torch.svd(feats)
+ color = u[:, channels]
+ color = (color - color.min(dim=0, keepdim=True)[0]) / (color.max(dim=0, keepdim=True)[0] - color.min(dim=0, keepdim=True)[0])
+ return color
+
+
+def vis_sparse_tensor(
+ x: sp.SparseTensor,
+ num_frames: int = 300,
+):
+ assert x.shape[0] == 1, "Only support batch size 1"
+ assert x.coords.shape[1] == 4, "Only support 3D coordinates"
+
+ coords = x.coords.cuda().detach()[:, 1:]
+ feats = x.feats.cuda().detach()
+ color = pca_color(feats)
+
+ resolution = max(list(x.spatial_shape))
+ resolution = int(2**np.ceil(np.log2(resolution)))
+
+ rep = Voxel(
+ origin=[-0.5, -0.5, -0.5],
+ voxel_size=1/resolution,
+ coords=coords,
+ attrs=color,
+ layout={
+ 'color': slice(0, 3),
+ }
+ )
+
+ return render_video(rep, colors_overwrite=color, num_frames=num_frames)['color']