diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..98f252900c6064513f2378cd1973102f2cf6f681 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,76 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.webp filter=xet diff=xet merge=xet -text +*.exr filter=xet diff=xet merge=xet -text +*.png filter=xet diff=xet merge=xet -text +*.ply filter=xet diff=xet merge=xet -text +assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/T.png filter=lfs diff=lfs merge=lfs -text +assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp filter=lfs diff=lfs merge=lfs -text +assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp filter=lfs diff=lfs merge=lfs -text +assets/example_texturing/the_forgotten_knight.ply filter=lfs diff=lfs merge=lfs -text +assets/hdri/city.exr filter=lfs diff=lfs merge=lfs -text +assets/hdri/courtyard.exr filter=lfs diff=lfs merge=lfs -text +assets/hdri/forest.exr filter=lfs diff=lfs merge=lfs -text +assets/hdri/interior.exr filter=lfs diff=lfs merge=lfs -text +assets/hdri/night.exr filter=lfs diff=lfs merge=lfs -text +assets/hdri/sunrise.exr filter=lfs diff=lfs merge=lfs -text +assets/hdri/sunset.exr filter=lfs diff=lfs merge=lfs -text +assets/teaser.webp filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4711186cc74ab8d2ae812895aedc3daf4fea8429 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +__pycache__/ +*.pyc +.venv/ +venv/ +*.ckpt +*.pt +*.bin +*.safetensors +wandb/ +.wandb/ +node_modules/ +*.egg-info/ +.gradio/ +example.py + +outputs*/ +results*/ +ckpts*/ +tmp/example.py diff --git a/README.md b/README.md index 867da31dc29ac7f954c535fc6046951a1d8b1c3f..4c344761ef566aa6646e0383932994757e61d5fe 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ --- title: Pixal3D T -emoji: 📈 +emoji: 🏆 colorFrom: indigo -colorTo: green +colorTo: gray sdk: gradio sdk_version: 6.13.0 +python_version: "3.10" app_file: app.py pinned: false +license: mit --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..604598ae0aaad102acbb36a4628fb7fd7d646882 --- /dev/null +++ b/app.py @@ -0,0 +1,584 @@ +""" +Pixal3D (TRELLIS.2 Backbone) - Gradio App + +Image-to-3D generation using Proj-mode Cascade inference (512->1024/1536). + +""" + +import gradio as gr + +import os +import subprocess +subprocess.run([ + "pip", "install", "--force-reinstall", "--no-deps", + "https://github.com/LDYang694/Storages/releases/download/20260430/utils3d-0.0.2-py3-none-any.whl" +], check=True) + +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import argparse +import math +import time +from datetime import datetime +import shutil +import cv2 +from typing import * +import torch +import numpy as np +from PIL import Image +import base64 +import io +from trellis2.modules.sparse import SparseTensor +from trellis2.pipelines import Pixal3DImageTo3DPipeline +from trellis2.renderers import EnvMap +from trellis2.utils import render_utils +import o_voxel + + +# ============================================================================ +# Constants & Defaults +# ============================================================================ + +MAX_SEED = np.iinfo(np.int32).max +TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') +MODES = [ + {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"}, + {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"}, + {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"}, + {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"}, + {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"}, + {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"}, +] +STEPS = 8 +DEFAULT_MODE = 3 +DEFAULT_STEP = 3 + +# Cascade parameters +CASCADE_LR_RESOLUTION = 512 +CASCADE_MAX_NUM_TOKENS = 49152 + +# MoGe defaults +MOGE_MODEL_NAME = "Ruicheng/moge-2-vitl" +WILD_MESH_SCALE = 1.0 +WILD_EXTEND_PIXEL = 0 +WILD_IMAGE_RESOLUTION = 512 + +# Image Cond Model configs (extracted from training configs, hardcoded) +IMAGE_COND_CONFIGS = { + "ss": { + "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m", + "image_size": 512, + "grid_resolution": 16, + }, + "shape_512": { + "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m", + "image_size": 512, + "grid_resolution": 32, + "use_naf_upsample": True, + "naf_target_size": 512, + }, + "shape_1024": { + "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m", + "image_size": 1024, + "grid_resolution": 64, + "use_naf_upsample": True, + "naf_target_size": 512, + }, + "tex_1024": { + "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m", + "image_size": 1024, + "grid_resolution": 64, + "use_naf_upsample": True, + "naf_target_size": 1024, + }, +} + + +# ============================================================================ +# CSS & JS +# ============================================================================ + +css = """ +.stepper-wrapper { padding: 0; } +.stepper-container { padding: 0; align-items: center; } +.step-button { flex-direction: row; } +.step-connector { transform: none; } +.step-number { width: 16px; height: 16px; } +.step-label { position: relative; bottom: 0; } +.wrap.center.full { inset: 0; height: 100%; } +.wrap.center.full.translucent { background: var(--block-background-fill); } +.meta-text-center { + display: block !important; position: absolute !important; + top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; +} +.previewer-container { + position: relative; + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; + width: 100%; height: 722px; margin: 0 auto; padding: 20px; + display: flex; flex-direction: column; align-items: center; justify-content: center; +} +.previewer-container .tips-icon { + position: absolute; right: 10px; top: 10px; z-index: 10; + border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; +} +.previewer-container .tips-text { + position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); + border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; + transition: all 0.3s; opacity: 0%; user-select: none; +} +.previewer-container .tips-text p { font-size: 14px; line-height: 1.2; } +.tips-icon:hover + .tips-text { display: block; opacity: 100%; } +.previewer-container .mode-row { + width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; +} +.previewer-container .mode-btn { + width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; + transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; +} +.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); } +.previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); } +.previewer-container .display-row { + margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; + display: flex; justify-content: center; align-items: center; +} +.previewer-container .previewer-main-image { + max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; +} +.previewer-container .previewer-main-image.visible { display: block; } +.previewer-container .slider-row { + width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; +} +.previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; } +.previewer-container input[type=range]::-webkit-slider-runnable-track { + width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; +} +.previewer-container input[type=range]::-webkit-slider-thumb { + height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); + cursor: pointer; -webkit-appearance: none; margin-top: -6px; + box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; +} +.previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); } +.gradio-container .padded:has(.previewer-container) { padding: 0 !important; } +.gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; } +""" + +head = """ + +""" + +empty_html = f""" +
+ +
+""" + + +# ============================================================================ +# Model Loading Utilities +# ============================================================================ + +def build_image_cond_model(config: dict): + """Build DinoV3ProjFeatureExtractor.""" + from trellis2.trainers.flow_matching.mixins.image_conditioned_proj import DinoV3ProjFeatureExtractor + model = DinoV3ProjFeatureExtractor(**config) + model.eval() + return model + + +# ============================================================================ +# Camera Parameter Utilities +# ============================================================================ + +def compute_f_pixels(camera_angle_x: float, resolution: int) -> float: + focal_length = 16.0 / torch.tan(torch.tensor(camera_angle_x / 2.0)) + f_pixels = focal_length * resolution / 32.0 + return float(f_pixels.item()) + + +def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution): + rotation_matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]]) + gp = grid_point.to(torch.float32) @ rotation_matrix.T + gp = gp / mesh_scale / 2 + xw, yw, zw = gp[0].item(), gp[1].item(), gp[2].item() + xt, yt = float(target_point[0].item()), float(target_point[1].item()) + f_pixels = compute_f_pixels(camera_angle_x, image_resolution) + x_ndc = xt - image_resolution / 2.0 + y_ndc = -(yt - image_resolution / 2.0) + distance_x = f_pixels * xw / x_ndc - yw + return {"distance_from_x": float(distance_x), "f_pixels": float(f_pixels)} + + +def load_moge_model(device="cuda", model_name=MOGE_MODEL_NAME): + print(f"[MoGe-2] Loading model {model_name}...") + from moge.model.v2 import MoGeModel + moge_model = MoGeModel.from_pretrained(model_name).to(device) + moge_model.eval() + print("[MoGe-2] Model loaded!") + return moge_model + + +def get_camera_params_wild_moge(image, moge_model, device="cuda", + mesh_scale=1.0, extend_pixel=0, image_resolution=512): + """Estimate camera parameters via MoGe-2.""" + if isinstance(image, str): + pil_image = Image.open(image).convert("RGB") + elif isinstance(image, Image.Image): + pil_image = image.convert("RGB") + else: + raise ValueError(f"Unsupported image type: {type(image)}") + width, height = pil_image.size + image_np = np.array(pil_image).astype(np.float32) / 255.0 + image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).to(device) + with torch.no_grad(): + output = moge_model.infer(image_tensor) + intrinsics = output["intrinsics"].squeeze().cpu().numpy() + fx_normalized = intrinsics[0, 0] + fx = fx_normalized * width + camera_angle_x = 2 * math.atan(width / (2 * fx)) + + grid_point = torch.tensor([-1.0, 0.0, 0.0]) + distance = distance_from_fov( + camera_angle_x, grid_point, + torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel]), + mesh_scale, image_resolution + )["distance_from_x"] + return {'camera_angle_x': camera_angle_x, 'distance': distance, 'mesh_scale': mesh_scale} + + +# ============================================================================ +# UI Utilities +# ============================================================================ + +def image_to_base64(image): + buffered = io.BytesIO() + image = image.convert("RGB") + image.save(buffered, format="jpeg", quality=85) + img_str = base64.b64encode(buffered.getvalue()).decode() + return f"data:image/jpeg;base64,{img_str}" + + +def start_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + + +def end_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + if os.path.exists(user_dir): + shutil.rmtree(user_dir) + + +def preprocess_image(image: Image.Image) -> Image.Image: + return pipeline.preprocess_image(image) + + +def pack_state(shape_slat, tex_slat, res): + return { + 'shape_slat_feats': shape_slat.feats.cpu().numpy(), + 'tex_slat_feats': tex_slat.feats.cpu().numpy(), + 'coords': shape_slat.coords.cpu().numpy(), + 'res': res, + } + + +def unpack_state(state): + shape_slat = SparseTensor( + feats=torch.from_numpy(state['shape_slat_feats']).cuda(), + coords=torch.from_numpy(state['coords']).cuda(), + ) + tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda()) + return shape_slat, tex_slat, state['res'] + + +def get_seed(randomize_seed, seed): + return np.random.randint(0, MAX_SEED) if randomize_seed else seed + + +# ============================================================================ +# Core Inference +# ============================================================================ + +def image_to_3d( + image, seed, resolution, + ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t, + shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t, + tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t, + req: gr.Request, + progress=gr.Progress(track_tqdm=True), +): + device = pipeline.device + torch.manual_seed(seed) + hr_resolution = int(resolution) + + total_t0 = time.time() + print(f"\n{'='*60}") + print(f" [Generate] Start | seed={seed}, resolution={hr_resolution}") + print(f"{'='*60}") + + # Preprocessing + image_preprocessed = pipeline.preprocess_image(image) + + # Camera estimation via MoGe-2 + camera_params = get_camera_params_wild_moge( + image_preprocessed, moge_model, device=str(device), + mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL, + image_resolution=WILD_IMAGE_RESOLUTION, + ) + + ss_sampler_override = {"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength, + "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t} + shape_sampler_override = {"steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength, + "guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t} + tex_sampler_override = {"steps": tex_slat_sampling_steps, "guidance_strength": tex_slat_guidance_strength, + "guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t} + + # Run pipeline + pipeline_type = f"{hr_resolution}_cascade" + mesh_list, (shape_slat, tex_slat, res) = pipeline.run( + image_preprocessed, + camera_params=camera_params, + seed=seed, + sparse_structure_sampler_params=ss_sampler_override, + shape_slat_sampler_params=shape_sampler_override, + tex_slat_sampler_params=tex_sampler_override, + preprocess_image=False, + return_latent=True, + pipeline_type=pipeline_type, + max_num_tokens=CASCADE_MAX_NUM_TOKENS, + ) + mesh = mesh_list[0] + state = pack_state(shape_slat, tex_slat, res) + del shape_slat, tex_slat, mesh_list + torch.cuda.empty_cache() + + # Render + mesh.simplify(16777216) + images = render_utils.render_proj_aligned_video( + mesh, camera_angle_x=camera_params['camera_angle_x'], + distance=camera_params['distance'], resolution=1024, + num_frames=STEPS, envmap=envmap, + ) + del mesh + torch.cuda.empty_cache() + print(f"\n [Generate] Total time: {time.time()-total_t0:.2f}s") + + # Build HTML + images_html = "" + for m_idx, mode in enumerate(MODES): + for s_idx in range(STEPS): + unique_id = f"view-m{m_idx}-s{s_idx}" + is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP) + vis_class = "visible" if is_visible else "" + img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx])) + images_html += f'' + + btns_html = "" + for idx, mode in enumerate(MODES): + active_class = "active" if idx == DEFAULT_MODE else "" + btns_html += f'' + + full_html = f""" +
+
+
Tips
+
+

Render Mode - Click circular buttons to switch render modes.

+

View Angle - Drag the slider to change the view angle.

+
+
+
{images_html}
+
{btns_html}
+
+ +
+
+ """ + return state, full_html + + +def extract_glb(state, decimation_target, texture_size, req: gr.Request, progress=gr.Progress(track_tqdm=True)): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + shape_slat, tex_slat, res = unpack_state(state) + mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0] + glb = o_voxel.postprocess.to_glb( + vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs, + coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout, + grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + decimation_target=decimation_target, texture_size=texture_size, + remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True, + ) + now = datetime.now() + timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}" + os.makedirs(user_dir, exist_ok=True) + glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb') + glb.export(glb_path, extension_webp=True) + torch.cuda.empty_cache() + return glb_path, glb_path + + +# ============================================================================ +# Gradio UI +# ============================================================================ + +with gr.Blocks(delete_cache=(600, 600)) as demo: + gr.Markdown(""" + ## Image to 3D Asset with Pixal3D (TRELLIS.2 Backbone) + * Upload an image and click **Generate** to create a 3D asset using Pixal3D with TRELLIS.2 backbone. + * Click **Extract GLB** to export and download the generated GLB file. + * Camera parameters are estimated automatically via MoGe-2. + """) + + with gr.Row(): + with gr.Column(scale=1, min_width=360): + image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400) + resolution = gr.Radio(["1024", "1536"], label="Resolution", value="1536") + seed = gr.Slider(0, MAX_SEED, label="Seed", value=42, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) + decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=1000000, step=10000) + texture_size = gr.Slider(1024, 4096, label="Texture Size", value=4096, step=1024) + generate_btn = gr.Button("Generate") + + with gr.Accordion(label="Advanced Settings", open=False): + gr.Markdown("Stage 1: Sparse Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01) + ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1) + gr.Markdown("Stage 2: Shape Generation") + with gr.Row(): + shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01) + shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) + gr.Markdown("Stage 3: Material Generation") + with gr.Row(): + tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1) + tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01) + tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) + + with gr.Column(scale=10): + with gr.Walkthrough(selected=0) as walkthrough: + with gr.Step("Preview", id=0): + preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True) + extract_btn = gr.Button("Extract GLB") + with gr.Step("Extract", id=1): + glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0)) + download_btn = gr.DownloadButton(label="Download GLB") + + with gr.Column(scale=1, min_width=172): + examples = gr.Examples( + examples=[f'assets/example_image/{image}' for image in os.listdir("assets/example_image")], + inputs=[image_prompt], fn=preprocess_image, outputs=[image_prompt], + run_on_click=True, examples_per_page=18, + ) + + output_buf = gr.State() + + demo.load(start_session) + demo.unload(end_session) + image_prompt.upload(preprocess_image, inputs=[image_prompt], outputs=[image_prompt]) + + generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then( + lambda: gr.Walkthrough(selected=0), outputs=walkthrough + ).then( + image_to_3d, + inputs=[image_prompt, seed, resolution, + ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t, + shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t, + tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t], + outputs=[output_buf, preview_output], + ) + + extract_btn.click(lambda: gr.Walkthrough(selected=1), outputs=walkthrough).then( + extract_glb, inputs=[output_buf, decimation_target, texture_size], outputs=[glb_output, download_btn], + ) + + +# ============================================================================ +# Launch +# ============================================================================ + +def parse_args(): + parser = argparse.ArgumentParser(description="Pixal3D Gradio App") + parser.add_argument("--model_path", type=str, default="TencentARC/Pixal3D-T", + help="HuggingFace repo ID or local path (default: TencentARC/Pixal3D-T)") + parser.add_argument("--port", type=int, default=7860) + parser.add_argument("--share", action="store_true", default=True) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + os.makedirs(TMP_DIR, exist_ok=True) + + # Construct UI icon base64 + for i in range(len(MODES)): + icon = Image.open(MODES[i]['icon']) + MODES[i]['icon_base64'] = image_to_base64(icon) + + # Load pipeline from HuggingFace or local path + print(f"[Pipeline] Loading from {args.model_path}...") + pipeline = Pixal3DImageTo3DPipeline.from_pretrained(args.model_path) + + # Load environment maps + envmap = { + 'forest': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')), + 'sunset': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')), + 'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')), + } + + # Build image cond models and set on pipeline + print("[ImageCond] Building DinoV3ProjFeatureExtractor models...") + pipeline.image_cond_model_ss = build_image_cond_model(IMAGE_COND_CONFIGS["ss"]) + pipeline.image_cond_model_shape_512 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_512"]) + pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"]) + pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"]) + + pipeline.cuda() + + # Pre-download NAF model (avoid lazy-loading during inference) + print("[NAF] Pre-loading NAF upsampler model...") + for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']: + model = getattr(pipeline, attr, None) + if model is not None and getattr(model, 'use_naf_upsample', False): + model._load_naf() + print("[NAF] NAF model loaded.") + + # Load MoGe-2 + print("\n[MoGe-2] Loading model for camera estimation...") + moge_model = load_moge_model(device="cuda") + + print(f"\n{'=' * 60}") + print(f" Pixal3D ready! Model loaded from: {args.model_path}") + print(f" Cascade: {CASCADE_LR_RESOLUTION} -> 1024/1536") + print(f"{'=' * 60}\n") + + demo.launch(css=css, head=head, server_port=args.port, share=args.share) diff --git a/assets/app/basecolor.png b/assets/app/basecolor.png new file mode 100644 index 0000000000000000000000000000000000000000..e7dbeaf6344b4b292757964eca2c6960e7e10a68 Binary files /dev/null and b/assets/app/basecolor.png differ diff --git a/assets/app/clay.png b/assets/app/clay.png new file mode 100644 index 0000000000000000000000000000000000000000..e02866a15d1101d1d1cd3f6098c7025c0763cfdb Binary files /dev/null and b/assets/app/clay.png differ diff --git a/assets/app/hdri_city.png b/assets/app/hdri_city.png new file mode 100644 index 0000000000000000000000000000000000000000..43e1c2a54c3bb8aa376613c698268fca0aea9b29 Binary files /dev/null and b/assets/app/hdri_city.png differ diff --git a/assets/app/hdri_courtyard.png b/assets/app/hdri_courtyard.png new file mode 100644 index 0000000000000000000000000000000000000000..4261ad62862163d32c22ac38495defdbdf3bebd0 Binary files /dev/null and b/assets/app/hdri_courtyard.png differ diff --git a/assets/app/hdri_forest.png b/assets/app/hdri_forest.png new file mode 100644 index 0000000000000000000000000000000000000000..7617fe19adf536c12db3f4bf7bc1f6a42f54f8b7 Binary files /dev/null and b/assets/app/hdri_forest.png differ diff --git a/assets/app/hdri_interior.png b/assets/app/hdri_interior.png new file mode 100644 index 0000000000000000000000000000000000000000..e00c1d656c94c71402e93067cac64725425754d0 Binary files /dev/null and b/assets/app/hdri_interior.png differ diff --git a/assets/app/hdri_night.png b/assets/app/hdri_night.png new file mode 100644 index 0000000000000000000000000000000000000000..f0423d221069904b0de32c82956675a210e8c375 Binary files /dev/null and b/assets/app/hdri_night.png differ diff --git a/assets/app/hdri_studio.png b/assets/app/hdri_studio.png new file mode 100644 index 0000000000000000000000000000000000000000..0f5a4e8d5c5717b6ca4217e6e081e695e2bec264 Binary files /dev/null and b/assets/app/hdri_studio.png differ diff --git a/assets/app/hdri_sunrise.png b/assets/app/hdri_sunrise.png new file mode 100644 index 0000000000000000000000000000000000000000..9cee3bb066a0f01ddf84b088da31fba1af6648d0 Binary files /dev/null and b/assets/app/hdri_sunrise.png differ diff --git a/assets/app/hdri_sunset.png b/assets/app/hdri_sunset.png new file mode 100644 index 0000000000000000000000000000000000000000..bd67070912b846b7cb353d13696f4aff4be40831 Binary files /dev/null and b/assets/app/hdri_sunset.png differ diff --git a/assets/app/normal.png b/assets/app/normal.png new file mode 100644 index 0000000000000000000000000000000000000000..352e92b5750414ee1c2562d1c719468ce5da2883 Binary files /dev/null and b/assets/app/normal.png differ diff --git a/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp b/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp new file mode 100644 index 0000000000000000000000000000000000000000..8cb4652f561eca8ab3aa32f6e00457bc9b9f194a --- /dev/null +++ b/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83d3765ff57511f11054d62e6beaf52648d277006b6dcb3c1d5f9e03ef502c49 +size 108096 diff --git a/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp b/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp new file mode 100644 index 0000000000000000000000000000000000000000..b9dbb51554fe1f2daeb448c0e47cfaa512a07eff --- /dev/null +++ b/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a31ba3b084757faa66e4bb91a85bd956fcc1eaaaf298a091fd08044351cd0293 +size 184424 diff --git a/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp b/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp new file mode 100644 index 0000000000000000000000000000000000000000..62488149f1f5d00ac044ff6f95b91efa6117b7ed --- /dev/null +++ b/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8d419ac16852b73fbc7b92c6857ba5d4be2f2dd8056bf44af707e6af4d62207 +size 156794 diff --git a/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp b/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp new file mode 100644 index 0000000000000000000000000000000000000000..f6a749a33c8ffd111cf198f499bda2b499537d51 --- /dev/null +++ b/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1a36d67f1cd7bba5c78b429c85346990aa3f269e19eee60b6d8b5d632cd743a +size 100332 diff --git a/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp b/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp new file mode 100644 index 0000000000000000000000000000000000000000..3c3ac5dfc7a6d30528928e30eca04f36870f33c3 Binary files /dev/null and b/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp differ diff --git a/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp b/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp new file mode 100644 index 0000000000000000000000000000000000000000..e596271b736281ad07491dd7b7d3544214e2c7aa Binary files /dev/null and b/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp differ diff --git a/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp b/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp new file mode 100644 index 0000000000000000000000000000000000000000..287db1bed9d597e5ff7e4a49da4a63c5529f51c7 --- /dev/null +++ b/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb7cdcea3b238d40c7b838659288f5c12febd77292b22d44f54f785e2712d49a +size 159636 diff --git a/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp b/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp new file mode 100644 index 0000000000000000000000000000000000000000..bdfc4f52348dc4ea59cae734f9ee0f4aa29484da --- /dev/null +++ b/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4202eb19d6802319aa140e65385b3474d09aaa7f008218acf1e1f5186b923c90 +size 167372 diff --git a/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp b/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp new file mode 100644 index 0000000000000000000000000000000000000000..ef5029ebcfa044e437caa35458bb6ab860dbda8a --- /dev/null +++ b/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d15269d0c0b427eabcca39d6d093cc2cfdaab19cb8a40b6158988a70a91ffe45 +size 207084 diff --git a/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp b/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp new file mode 100644 index 0000000000000000000000000000000000000000..89889d4b276ba78b3fc451ec6a504c1b5089c859 --- /dev/null +++ b/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cd45dd941dc559191c71f49ac5841fc6d06d32c7427fcf535f38c007611cd3f +size 109042 diff --git a/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp b/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp new file mode 100644 index 0000000000000000000000000000000000000000..962281565b2c60a013e64776313145fd8186e4ef --- /dev/null +++ b/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4e26e3e2f3de96cf396c64f1cef39ac82c13f231e2567fe0fe902dc103bf949 +size 169506 diff --git a/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp b/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp new file mode 100644 index 0000000000000000000000000000000000000000..0af9ee139499ab072b6068e5fb1d974f99103a30 --- /dev/null +++ b/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88c65db47f93beed7d6b184c32f9fa27c5f64ced2e3d4f535ad628179763cff8 +size 115508 diff --git a/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp b/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp new file mode 100644 index 0000000000000000000000000000000000000000..7411bc618bfd5af692073d70281a857df6932b22 Binary files /dev/null and b/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp differ diff --git a/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp b/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp new file mode 100644 index 0000000000000000000000000000000000000000..5f00cc2b4562fb026033a555cf95c946a3922246 Binary files /dev/null and b/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp differ diff --git a/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp b/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp new file mode 100644 index 0000000000000000000000000000000000000000..0a3a3a893b1fa0914e63d79e82d7894afe552483 --- /dev/null +++ b/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47be8271d2d1d27fc334cc470d63534807ad44c6f50cee18a204125a1c6b66a4 +size 140210 diff --git a/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp b/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp new file mode 100644 index 0000000000000000000000000000000000000000..6ae86128b367dc0618fd87e605cd08832b6b931b --- /dev/null +++ b/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae7ad70cbb2a616fca63bc73d319172d5fda0930c4aaf9d897e21c09ccf17ad7 +size 167670 diff --git a/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp b/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp new file mode 100644 index 0000000000000000000000000000000000000000..4cb9bdb5ac7fe4bd2ea84771de21eff56dda4cf4 --- /dev/null +++ b/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b5ca8a0b47aa2dbd5480cd356e7ab574dadd565ed738ac5ed655f36a4349eea +size 175080 diff --git a/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp b/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp new file mode 100644 index 0000000000000000000000000000000000000000..01a571a7a2ad661e29c1a4c3b5ba68353451cd70 --- /dev/null +++ b/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:933c10aeebb2920b08cb34a08ab1878817b64eb9e30efdcc3d76731069fc0849 +size 131488 diff --git a/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp b/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp new file mode 100644 index 0000000000000000000000000000000000000000..60375c981e776e3667287f8df99b04f12098f1df --- /dev/null +++ b/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d5224235eeadadfdb93ab37664055ef55ffd930b085d268cd62c5faf9d101de +size 136872 diff --git a/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp b/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp new file mode 100644 index 0000000000000000000000000000000000000000..ddd363de151adc06a1815d148a69fce18a238942 --- /dev/null +++ b/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86b909e6847118f2cec6e9c0945e6a78154ce5918f4ba14c0b0afa0be2e8647f +size 150020 diff --git a/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp b/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp new file mode 100644 index 0000000000000000000000000000000000000000..4151759cf55814774c5e1ff29e219e965a21a423 --- /dev/null +++ b/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62b7097ce57d01e730a0ce2cc120b6c9d27585026c3de84d6b1a0dcaf5fea9d3 +size 128572 diff --git a/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp b/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp new file mode 100644 index 0000000000000000000000000000000000000000..cd9120b3d0c12fcec01cf68a27060eba6c658e37 --- /dev/null +++ b/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47cbe7b6f29adb0f7d02c2cd73f6f8bd90a54b12d416368499458a8c571e35c5 +size 178460 diff --git a/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp b/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp new file mode 100644 index 0000000000000000000000000000000000000000..7a10d813fb288d647e49f2b8bb2f5cdf578f873d --- /dev/null +++ b/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbd98cf5da79c56f8efc6cf86804391e71bdad2e935d21a7262472653a0674dc +size 126914 diff --git a/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp b/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp new file mode 100644 index 0000000000000000000000000000000000000000..021980a74ef652ae55bef0272527aab211275753 --- /dev/null +++ b/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfad86b88eb81da36a5acf77891822c042e991436bb004c7d75d8b19e89c45bd +size 110562 diff --git a/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp b/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp new file mode 100644 index 0000000000000000000000000000000000000000..3fe995897ad1712af150b391a5d91063c3e417b5 --- /dev/null +++ b/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd73dcc5ae5b42a442c87f1522f370869f3d2e63922e6ecf769fd90fdeed1e66 +size 124876 diff --git a/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp b/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp new file mode 100644 index 0000000000000000000000000000000000000000..8f787a298a8ba70c64beff1ad88cc9b96d4d58b9 Binary files /dev/null and b/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp differ diff --git a/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp b/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp new file mode 100644 index 0000000000000000000000000000000000000000..54a78c84746298a677529c8abe5858d6d734bbe7 Binary files /dev/null and b/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp differ diff --git a/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp b/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp new file mode 100644 index 0000000000000000000000000000000000000000..31b50241ad8cc38c4d2b9a00b36dc258dc172b73 --- /dev/null +++ b/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0dd605ac2927b56f6ba5ced2d0b56aaba1be9f7dc2ab7f8d558d0f37856f6b5 +size 169830 diff --git a/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp b/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp new file mode 100644 index 0000000000000000000000000000000000000000..b7a521fcdb71be9f5b095bd0c2a3a8b3ed9a32fc Binary files /dev/null and b/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp differ diff --git a/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp b/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp new file mode 100644 index 0000000000000000000000000000000000000000..51557cf27478f8c301978278157bf71c1d4069f1 --- /dev/null +++ b/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41fd0f2a1cd66f4615a3947b6fdbddf527d4d628935a6fd208ca32a3b46849b1 +size 111992 diff --git a/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp b/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp new file mode 100644 index 0000000000000000000000000000000000000000..b07990950bc8f3aee982664d6fd1264570906752 --- /dev/null +++ b/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5a7835a908470c1142678cddfbfbe093ba0ec1085852ed072ae597f1eae5c37 +size 111836 diff --git a/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp b/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp new file mode 100644 index 0000000000000000000000000000000000000000..5f10f35ab7a6f2d1d95d40e82c4bb91e30f165b5 --- /dev/null +++ b/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7807391ef76523dab937bfcae5f9fc4803a65b86ab49310c17df0714288b7c48 +size 195232 diff --git a/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp b/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp new file mode 100644 index 0000000000000000000000000000000000000000..ab02dea2e3abcb802b48c82d6e7d670783f1be5e --- /dev/null +++ b/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3977e4b323080e122ec297f2fafe1203e717fef8417e82b6ed523e8438f829c6 +size 191642 diff --git a/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp b/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp new file mode 100644 index 0000000000000000000000000000000000000000..60dd646a277d93fb76441ec47f88e22c4cf6967a --- /dev/null +++ b/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6707893380d3c791a9a2d33bc0caa5f10719f7168d25f8d447fa26f4420919db +size 108942 diff --git a/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp b/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp new file mode 100644 index 0000000000000000000000000000000000000000..3d6c84ab1ccc878d9f292a04b579642b92c14fc9 --- /dev/null +++ b/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9558307682a1e723f86291af377f0b53970c0d301304c1f43dec741fedc209d7 +size 117870 diff --git a/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp b/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp new file mode 100644 index 0000000000000000000000000000000000000000..f12cbf93875f086f3084efc29072f8fb579f3143 --- /dev/null +++ b/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a319ace2549835da92a6ffa5db73eebd7fce29079e5865cb32dfbdac21d9b900 +size 100414 diff --git a/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp b/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp new file mode 100644 index 0000000000000000000000000000000000000000..afc4b658a9e2dfefd67a66e91cf60f009b74e8bd --- /dev/null +++ b/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b8373ce90506089d23907895d7efac4c37defb603ea7434bee9c4de870bf7da +size 179774 diff --git a/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp b/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp new file mode 100644 index 0000000000000000000000000000000000000000..d4298ffbb33598e6bb6392741f65cc9bc648a468 --- /dev/null +++ b/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:643b3d426830d89dc62634027f802e3af5450eb45591eeb413ae67d5ba56f557 +size 114408 diff --git a/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp b/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp new file mode 100644 index 0000000000000000000000000000000000000000..bbd93cd629c4850d0ddfcd65763478d1897cbb1c --- /dev/null +++ b/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4e65a295b78ca3505337eb1cc58d0e5ff7effc0d500d01e503c4d5243af0a68 +size 103346 diff --git a/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp b/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp new file mode 100644 index 0000000000000000000000000000000000000000..a0795fe45f2477d4811a07bd2168f08ef9460960 --- /dev/null +++ b/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23b3f7294731c66a3a9db7ccbd593708e1ba56819cff970b29fd9fbf286588e0 +size 187208 diff --git a/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp b/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp new file mode 100644 index 0000000000000000000000000000000000000000..b668a21a0f5e8a3c0730377da02357cd6e30b869 --- /dev/null +++ b/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d42562b49df17554583f2acbe18f8a26c9075cdab278caefab27aa7587df4ac +size 115424 diff --git a/assets/example_image/T.png b/assets/example_image/T.png new file mode 100644 index 0000000000000000000000000000000000000000..171f72e68f117c7d2507f94970a6a8d2d5e7e563 --- /dev/null +++ b/assets/example_image/T.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db468bad8a04f1474a8d68140c07501b013b3ec6124b911fb7852675d64c05ee +size 1718470 diff --git a/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp b/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp new file mode 100644 index 0000000000000000000000000000000000000000..7f764c337d834d1077aef043aa277f360b20d9ba --- /dev/null +++ b/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94b508edb4a3c49ff4e5c4c583b6980a7d0ef84a2f16a301ac9b671d421b8a52 +size 103072 diff --git a/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp b/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp new file mode 100644 index 0000000000000000000000000000000000000000..e4cf1566ad8daa3be8f938d3d3e5b00b297c7414 --- /dev/null +++ b/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71c2035d7968538b51a5edb0d853a310f3f03b9626b705ac7f14398bb72f6093 +size 101638 diff --git a/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp b/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp new file mode 100644 index 0000000000000000000000000000000000000000..29b7868f213043532468e2fa1c169c3a9c98ce26 --- /dev/null +++ b/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a8dbf4af05a60591cb157bcbb19c58c05c1d5622aaecddaf5aaabf292e89ea0 +size 208094 diff --git a/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp b/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp new file mode 100644 index 0000000000000000000000000000000000000000..242c31ce26348b99d98a69f789cf646ca472db05 --- /dev/null +++ b/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3786b5f2c6715d02a20abde45f4465883eb8b791a807f9dfc86dc78a17b7c58 +size 180294 diff --git a/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp b/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp new file mode 100644 index 0000000000000000000000000000000000000000..00790a403686c1fb1a78e30b9e0bad269dcb80a6 --- /dev/null +++ b/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69f3ab098a28bf0c06194ab523ba10e30ce3db331871e2e8fa97cfd5967853b0 +size 180186 diff --git a/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp b/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp new file mode 100644 index 0000000000000000000000000000000000000000..dc0630884d160addf2997b2435fd15ea3bece805 --- /dev/null +++ b/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc3d5f2d5228c9314d8037105cc24798499aea1aaf311cde332df74d17f96da6 +size 130468 diff --git a/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp b/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp new file mode 100644 index 0000000000000000000000000000000000000000..cfa8a87a3b0c0a5ed90f249c030815517317403c --- /dev/null +++ b/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a77a4ab99beb302c5934efaeaeef2effc48b2c07b9d68cadb06d68396a48abea +size 106998 diff --git a/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp b/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp new file mode 100644 index 0000000000000000000000000000000000000000..9b23b87c2a30e9c704de86bc7c2214a9629661a8 --- /dev/null +++ b/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01962bf62febbf916100201f915d27ab881c747a9e283228d8b9c890c1eb2fb8 +size 155612 diff --git a/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp b/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp new file mode 100644 index 0000000000000000000000000000000000000000..4ae373686902684181a5f7c84f976a02f3fea9d5 --- /dev/null +++ b/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed898f0689bd19f163fbfb601b9c9d31c422ab2ad409ec07f61b70471636c164 +size 155628 diff --git a/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp b/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp new file mode 100644 index 0000000000000000000000000000000000000000..4282564790b1a57ea1d4a7c2f77dd84775cade17 --- /dev/null +++ b/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:961adcb3c9c10db565957bfcf5d0be49e18c23012363a78bebcf8b49dbd7153c +size 206140 diff --git a/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp b/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp new file mode 100644 index 0000000000000000000000000000000000000000..4203f61711b705c5e5e0e6f49e52f9b555c878a9 --- /dev/null +++ b/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e211aff13d1e7dcdf592d68ea34e03794c4464c5b159224d7c508d9ecae5eb59 +size 151940 diff --git a/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp b/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp new file mode 100644 index 0000000000000000000000000000000000000000..fc11f391f770d92966ef73037c8769f807337875 --- /dev/null +++ b/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4533ce41604e7aff386c71f37f0b2727242a4615ef0e37c3cd62273678ad1809 +size 142682 diff --git a/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp b/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp new file mode 100644 index 0000000000000000000000000000000000000000..898e801ef377765e4082aa5968db56dd6e696a06 --- /dev/null +++ b/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec6060536a5db96092d419be6e0f9d14985f6954ced1605e9d473d415ff29368 +size 764638 diff --git a/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp b/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp new file mode 100644 index 0000000000000000000000000000000000000000..6d79bbc33c14ba034cbc6f5dc7ab443e0f92d445 --- /dev/null +++ b/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eaf705e6205f5101af261fac1a3477121cf377f92ff93feaf506eada70de98d7 +size 217676 diff --git a/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp b/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp new file mode 100644 index 0000000000000000000000000000000000000000..66bc8713ed9f5a86463850a3dd130f6d9a4953c0 --- /dev/null +++ b/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c0c09536a1a6c1c44e7c8649b1fda58ed4404ffa26debdc6f1719aaa7fdc59b +size 235370 diff --git a/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp b/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp new file mode 100644 index 0000000000000000000000000000000000000000..2573d75a2a428aef7d4d42daaa29d087b293a461 --- /dev/null +++ b/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:373afe2a56a038d1356f3fc5f1c8fb04778e90cb2216ad1c60ba7e2d2877705f +size 153922 diff --git a/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp b/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp new file mode 100644 index 0000000000000000000000000000000000000000..65b852da0703baeaee573de95f33b9ce366731d8 Binary files /dev/null and b/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp differ diff --git a/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp b/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp new file mode 100644 index 0000000000000000000000000000000000000000..a773e5c7e9a3e0078acdf30dbace00de6e22cd18 Binary files /dev/null and b/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp differ diff --git a/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp b/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp new file mode 100644 index 0000000000000000000000000000000000000000..9ffe4be1a5349b8a65b220b6063386831816db4b --- /dev/null +++ b/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:309713c97b32a579a232faea918aea7c49d3638b07c47f29072523bb4e95c5a0 +size 167178 diff --git a/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp b/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp new file mode 100644 index 0000000000000000000000000000000000000000..335531a6f0d695d3463e0bbfc478c08ad19c93ed --- /dev/null +++ b/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15e90c0ce28673b896e754ca56060dd840799cb8b95fb4ebeb06661675f8ed07 +size 182650 diff --git a/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp b/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp new file mode 100644 index 0000000000000000000000000000000000000000..114d6ffaa3545023654eea953aee8c22031e40a1 --- /dev/null +++ b/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05a0cabaea29f4d03c6468d039fe09af8eb8c3647952f3bcb342421445ea00a2 +size 229896 diff --git a/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp b/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp new file mode 100644 index 0000000000000000000000000000000000000000..e64acdec6e24981b82c80f35bafb44766ac6cc31 Binary files /dev/null and b/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp differ diff --git a/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp b/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp new file mode 100644 index 0000000000000000000000000000000000000000..08850da10b51022be1509076a797e4b4ea2887a1 --- /dev/null +++ b/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:477c0c9ef774e796a27645e1c143e7bd8d574cb6f31f2a0a9f57aaf66e80c0be +size 106434 diff --git a/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp b/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp new file mode 100644 index 0000000000000000000000000000000000000000..ba44a2d028fc04d8b74cef6935f8db9e9df17ce5 --- /dev/null +++ b/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b169654d84f98e808a875bc80b47e0cfe988af495250dba4368edf3eea4f633 +size 304284 diff --git a/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp b/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp new file mode 100644 index 0000000000000000000000000000000000000000..b9e06fd37eba38547fe24b88ce68c442086a83c7 --- /dev/null +++ b/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:242db48e9c2e6e45a888fdd24f133b966f460b583f1e77324252b869d1a7d65d +size 254674 diff --git a/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp b/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp new file mode 100644 index 0000000000000000000000000000000000000000..16921f8bf9843166dd1e7695fcb5fa13008c75e1 --- /dev/null +++ b/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f669f15496ecc1f3ec9819b40c8ca46ce2c080f49eab342724f2872f5c9dcff +size 186202 diff --git a/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp b/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp new file mode 100644 index 0000000000000000000000000000000000000000..c2ab696162c9dae5f90114d4825af5ec30cb28cb --- /dev/null +++ b/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4d093a8f99f3fc25a588785ceee95180427ff8dd063e2fda9f1e7f41cf40940 +size 247040 diff --git a/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp b/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp new file mode 100644 index 0000000000000000000000000000000000000000..1ae36239b3659bf579bd2588f930cdcd02d68b13 --- /dev/null +++ b/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d65e9fdbe6e4b432e34fbae16ea8a3d98e3cef81a9b1613f294039e96805e77 +size 208356 diff --git a/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp b/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp new file mode 100644 index 0000000000000000000000000000000000000000..9c24d0eb0a3d0c2c7ddc69e3ae39f4e9d85b9243 Binary files /dev/null and b/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp differ diff --git a/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp b/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp new file mode 100644 index 0000000000000000000000000000000000000000..f90db415387320470e55cc5bdcf62635e7b45726 Binary files /dev/null and b/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp differ diff --git a/assets/example_texturing/image.webp b/assets/example_texturing/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..a69f81dc924ae63e0c0a4f3212f476280d95439b Binary files /dev/null and b/assets/example_texturing/image.webp differ diff --git a/assets/example_texturing/readme b/assets/example_texturing/readme new file mode 100644 index 0000000000000000000000000000000000000000..7f6c799d95a410000c9116e459665339a701c88a --- /dev/null +++ b/assets/example_texturing/readme @@ -0,0 +1,11 @@ +## Asset Information + +* Title: The Forgotten Knight +* Author: dark_igorek +* Source: https://sketchfab.com/3d-models/the-forgotten-knight-d14eb14d83bd4e7ba7cbe443d76a10fd +* License: Creative Commons Attribution (CC BY) + +## Usage + +The asset is used for research purposes only. +Please credit the original author and include the Sketchfab link when using or redistributing this model. \ No newline at end of file diff --git a/assets/example_texturing/the_forgotten_knight.ply b/assets/example_texturing/the_forgotten_knight.ply new file mode 100644 index 0000000000000000000000000000000000000000..919ffdc534b30c9e19af438705d50a6f1814531f --- /dev/null +++ b/assets/example_texturing/the_forgotten_knight.ply @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56b16fba10017445de9529b2c2c1f68a97d2ca883293b27fa5efce116c6489b1 +size 4753123 diff --git a/assets/hdri/city.exr b/assets/hdri/city.exr new file mode 100644 index 0000000000000000000000000000000000000000..fac69a5403a2a3b51d725bca679174d2f0208b25 --- /dev/null +++ b/assets/hdri/city.exr @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e42abcd2fa3231e5c2485ca6dd64800534d157b7194ab7fe9cc3bf5a56d0256 +size 204863 diff --git a/assets/hdri/courtyard.exr b/assets/hdri/courtyard.exr new file mode 100644 index 0000000000000000000000000000000000000000..142504138fb47eba3024495d69923211ebeda9b0 --- /dev/null +++ b/assets/hdri/courtyard.exr @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6690b47725965531559380121e6878373eb599655e35fefd109a4bd0911366f3 +size 255126 diff --git a/assets/hdri/forest.exr b/assets/hdri/forest.exr new file mode 100644 index 0000000000000000000000000000000000000000..7809de1141fe935457de0cc73d8fe87ecf924e1a --- /dev/null +++ b/assets/hdri/forest.exr @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdf2298244affa0f85509380fd130ac6d4dfaa3c856df065998f7f4c1a93dc0d +size 552641 diff --git a/assets/hdri/interior.exr b/assets/hdri/interior.exr new file mode 100644 index 0000000000000000000000000000000000000000..ddb91daa8b222bedea152d60221499768d0e0fd7 --- /dev/null +++ b/assets/hdri/interior.exr @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e945ff5c1ddd7a3aaf05e9fb5c3bc9cb93c5518414febd44ed2c394e013f0cbd +size 189416 diff --git a/assets/hdri/license.txt b/assets/hdri/license.txt new file mode 100644 index 0000000000000000000000000000000000000000..eba5cb8037c9eb45e8c8638c99b39f7160a69214 --- /dev/null +++ b/assets/hdri/license.txt @@ -0,0 +1,15 @@ +All HDRIs are licensed as CC0. + +These were created by Greg Zaal (Poly Haven https://polyhaven.com). +Originals used for each HDRI: +- City: https://polyhaven.com/a/portland_landing_pad +- Courtyard: https://polyhaven.com/a/courtyard +- Forest: https://polyhaven.com/a/ninomaru_teien +- Interior: https://polyhaven.com/a/hotel_room +- Night: Probably https://polyhaven.com/a/moonless_golf +- Studio: Probably https://polyhaven.com/a/studio_small_01 +- Sunrise: https://polyhaven.com/a/spruit_sunrise +- Sunset: https://polyhaven.com/a/venice_sunset + +1K resolution of each was taken, and compressed with oiiotool: +oiiotool input.exr --ch R,G,B -d float --compression dwab:300 --clamp:min=0.0:max=32000.0 -o output.exr diff --git a/assets/hdri/night.exr b/assets/hdri/night.exr new file mode 100644 index 0000000000000000000000000000000000000000..380591f26bebbd1419b703145dedb20cab765f59 --- /dev/null +++ b/assets/hdri/night.exr @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17480af5547465d160f7307c92585cf30820ca563f87f066395decbad8ac32a4 +size 139959 diff --git a/assets/hdri/studio.exr b/assets/hdri/studio.exr new file mode 100644 index 0000000000000000000000000000000000000000..baf478da00acddbd70593e36d90146eb15c8939d Binary files /dev/null and b/assets/hdri/studio.exr differ diff --git a/assets/hdri/sunrise.exr b/assets/hdri/sunrise.exr new file mode 100644 index 0000000000000000000000000000000000000000..b0689a58e51c6527ed8810186644dd482f96799a --- /dev/null +++ b/assets/hdri/sunrise.exr @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a1180126e4db7d01f134c5430ea43c1b263ab1a12faf58a444d9ce9c03f3a84 +size 251877 diff --git a/assets/hdri/sunset.exr b/assets/hdri/sunset.exr new file mode 100644 index 0000000000000000000000000000000000000000..bf0414b39faf538b48baa62257aa6e5ff168928f --- /dev/null +++ b/assets/hdri/sunset.exr @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bcafdda4f2d7b9759cc1d73004d34d721a274c29b1a0947be88a602dbac426b +size 171164 diff --git a/assets/teaser.webp b/assets/teaser.webp new file mode 100644 index 0000000000000000000000000000000000000000..d8b89bd17c1fdf5831bb79a3989bf428086c6a7b --- /dev/null +++ b/assets/teaser.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b110c904037017d996933cc318ba948bb2dec7691d5a709568e2e0d9ae86068b +size 290204 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a6295dd81c7a9a4d865420bfcf05e3f867b9aa90 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +--extra-index-url https://download.pytorch.org/whl/cu124 + +torch==2.6.0 +torchvision==0.21.0 +triton==3.2.0 +pillow==12.0.0 +imageio==2.37.2 +imageio-ffmpeg==0.6.0 +tqdm==4.67.1 +easydict==1.13 +opencv-python-headless==4.12.0.88 +trimesh==4.10.1 +transformers==4.57.3 +zstandard==0.25.0 +kornia==0.8.2 +timm==1.0.22 +diffusers==0.37.1 +accelerate==1.13.0 +gradio + +git+https://github.com/microsoft/MoGe.git +https://github.com/LDYang694/Storages/releases/download/20260430/natten-0.21.0-cp310-cp310-linux_x86_64.whl +https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl +https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl +https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl +https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl +https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/nvdiffrast-0.4.0-cp310-cp310-linux_x86_64.whl +https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/nvdiffrec_render-0.0.0-cp310-cp310-linux_x86_64.whl \ No newline at end of file diff --git a/trellis2/__init__.py b/trellis2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20d240afc9c26a21aee76954628b3d4ef9a1ccbd --- /dev/null +++ b/trellis2/__init__.py @@ -0,0 +1,6 @@ +from . import models +from . import modules +from . import pipelines +from . import renderers +from . import representations +from . import utils diff --git a/trellis2/datasets/__init__.py b/trellis2/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eadadadb935480ed30d44ae11ebe41f785ed772b --- /dev/null +++ b/trellis2/datasets/__init__.py @@ -0,0 +1,52 @@ +import importlib + +__attributes = { + 'FlexiDualGridDataset': 'flexi_dual_grid', + 'SparseVoxelPbrDataset':'sparse_voxel_pbr', + + 'SparseStructureLatent': 'sparse_structure_latent', + 'TextConditionedSparseStructureLatent': 'sparse_structure_latent', + 'ImageConditionedSparseStructureLatent': 'sparse_structure_latent', + 'SparseStructureLatentView': 'sparse_structure_latent', + 'ViewImageConditionedSparseStructureLatentView': 'sparse_structure_latent', + + 'SLat': 'structured_latent', + 'ImageConditionedSLat': 'structured_latent', + 'SLatShape': 'structured_latent_shape', + 'ImageConditionedSLatShape': 'structured_latent_shape', + 'SLatShapeView': 'structured_latent_shape', + 'ViewImageConditionedSLatShapeView': 'structured_latent_shape', + 'SLatPbr': 'structured_latent_svpbr', + 'ImageConditionedSLatPbr': 'structured_latent_svpbr', + 'SLatPbrView': 'structured_latent_svpbr', + 'ViewImageConditionedSLatPbrView': 'structured_latent_svpbr', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .flexi_dual_grid import FlexiDualGridDataset + from .sparse_voxel_pbr import SparseVoxelPbrDataset + + from .sparse_structure_latent import SparseStructureLatent, ImageConditionedSparseStructureLatent + from .structured_latent import SLat, ImageConditionedSLat + from .structured_latent_shape import SLatShape, ImageConditionedSLatShape + from .structured_latent_svpbr import SLatPbr, ImageConditionedSLatPbr + \ No newline at end of file diff --git a/trellis2/datasets/components.py b/trellis2/datasets/components.py new file mode 100644 index 0000000000000000000000000000000000000000..8ebc93c1e87b16a28bb24fef8da61bc43502db60 --- /dev/null +++ b/trellis2/datasets/components.py @@ -0,0 +1,349 @@ +from typing import * +import json +from abc import abstractmethod +import os +import json +import torch +import numpy as np +import pandas as pd +from PIL import Image +from torch.utils.data import Dataset + + +class StandardDatasetBase(Dataset): + """ + Base class for standard datasets. + + Args: + roots (str): paths to the dataset + skip_list (str, optional): path to a file containing sha256 hashes to skip (one per line) + Format: "dataset/sha256" (e.g., "ABO/6a79dbb5...") + skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check + (e.g., ["texverse"] for datasets without aesthetic_score) + """ + + def __init__(self, + roots: str, + skip_list: Optional[str] = None, + skip_aesthetic_score_datasets: Optional[List[str]] = None, + ): + super().__init__() + + # Datasets to skip aesthetic score check + self.skip_aesthetic_score_datasets = set(skip_aesthetic_score_datasets or []) + + # Load skip list if provided + self.skip_set = set() + if skip_list is not None and os.path.exists(skip_list): + with open(skip_list, 'r') as f: + for line in f: + line = line.strip() + if line and not line.startswith('#'): + self.skip_set.add(line) + print(f'Loaded {len(self.skip_set)} items from skip_list: {skip_list}') + + try: + self.roots = json.loads(roots) + root_type = 'obj' + except: + self.roots = roots.split(',') + root_type = 'list' + self.instances = [] + self.metadata = pd.DataFrame() + + self._stats = {} + if root_type == 'obj': + for key, root in self.roots.items(): + self._stats[key] = {} + metadata = pd.DataFrame(columns=['sha256']).set_index('sha256') + + # 只从 ss_latent 和 render_cond 合并关键字段 + # 不包含 base,因为 base/metadata.csv 中的 cond_rendered=False 会错误覆盖真实值 + for sub_key, r in root.items(): + if sub_key == 'base': + continue # 跳过 base 目录 + metadata_file = os.path.join(r, 'metadata.csv') + if os.path.exists(metadata_file): + metadata = metadata.combine_first(pd.read_csv(metadata_file).set_index('sha256')) + + # 从 base 单独读取 aesthetic_score(不读取其他可能冲突的列) + if 'base' in root: + base_metadata_file = os.path.join(root['base'], 'metadata.csv') + if os.path.exists(base_metadata_file): + base_df = pd.read_csv(base_metadata_file).set_index('sha256') + if 'aesthetic_score' in base_df.columns and 'aesthetic_score' not in metadata.columns: + metadata['aesthetic_score'] = base_df['aesthetic_score'] + + self._stats[key]['Total'] = len(metadata) + metadata, stats = self.filter_metadata(metadata, dataset_name=key) + self._stats[key].update(stats) + + # Filter out items in skip_list + skipped_count = 0 + for sha256 in metadata.index.values: + skip_key = f'{key}/{sha256}' + if skip_key in self.skip_set: + skipped_count += 1 + else: + self.instances.append((root, sha256, key)) + if skipped_count > 0: + self._stats[key]['Skipped (skip_list)'] = skipped_count + self._stats[key]['After skip_list'] = len(metadata) - skipped_count + + self.metadata = pd.concat([self.metadata, metadata]) + else: + for root in self.roots: + key = os.path.basename(root) + self._stats[key] = {} + metadata = pd.read_csv(os.path.join(root, 'metadata.csv')) + self._stats[key]['Total'] = len(metadata) + metadata, stats = self.filter_metadata(metadata, dataset_name=key) + self._stats[key].update(stats) + + # Filter out items in skip_list + skipped_count = 0 + for sha256 in metadata['sha256'].values: + skip_key = f'{key}/{sha256}' + if skip_key in self.skip_set: + skipped_count += 1 + else: + self.instances.append((root, sha256, key)) + if skipped_count > 0: + self._stats[key]['Skipped (skip_list)'] = skipped_count + self._stats[key]['After skip_list'] = len(metadata) - skipped_count + metadata.set_index('sha256', inplace=True) + self.metadata = pd.concat([self.metadata, metadata]) + + @abstractmethod + def filter_metadata(self, metadata: pd.DataFrame, dataset_name: str = None) -> Tuple[pd.DataFrame, Dict[str, int]]: + pass + + @abstractmethod + def get_instance(self, root, instance: str) -> Dict[str, Any]: + pass + + def __len__(self): + return len(self.instances) + + def __getitem__(self, index) -> Dict[str, Any]: + try: + root, instance, dataset_name = self.instances[index] + pack = self.get_instance(root, instance) + pack['_dataset_name'] = dataset_name + pack['_sha256'] = instance + return pack + except Exception as e: + print(f'Error loading {self.instances[index][1]}: {e}') + return self.__getitem__(np.random.randint(0, len(self))) + + def __str__(self): + lines = [] + lines.append(self.__class__.__name__) + lines.append(f' - Total instances: {len(self)}') + lines.append(f' - Sources:') + for key, stats in self._stats.items(): + lines.append(f' - {key}:') + for k, v in stats.items(): + lines.append(f' - {k}: {v}') + return '\n'.join(lines) + + +class ImageConditionedMixin: + def __init__(self, roots, *, image_size=518, **kwargs): + self.image_size = image_size + super().__init__(roots, **kwargs) + + def filter_metadata(self, metadata, dataset_name=None): + metadata, stats = super().filter_metadata(metadata, dataset_name=dataset_name) + metadata = metadata[metadata['cond_rendered'].notna()] + stats['Cond rendered'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + pack = super().get_instance(root, instance) + + image_root = os.path.join(root['render_cond'], instance) + with open(os.path.join(image_root, 'transforms.json')) as f: + metadata = json.load(f) + n_views = len(metadata['frames']) + view = np.random.randint(n_views) + metadata = metadata['frames'][view] + + image_path = os.path.join(image_root, metadata['file_path']) + image = Image.open(image_path) + + alpha = np.array(image.getchannel(3)) + bbox = np.array(alpha).nonzero() + bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] + center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] + hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 + aug_hsize = hsize + aug_center_offset = [0, 0] + aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]] + aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)] + image = image.crop(aug_bbox) + + image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + alpha = image.getchannel(3) + image = image.convert('RGB') + image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0 + alpha = torch.tensor(np.array(alpha)).float() / 255.0 + image = image * alpha.unsqueeze(0) + pack['cond'] = image + + return pack + + +class ViewImageConditionedMixin: + """ + Mixin for view-based image-conditioned datasets. + + This mixin is designed for datasets where ss_latent is stored per-view (view{XX}.npz), + and needs to load the corresponding view image and scale from view{XX}_scale.json. + + Args: + image_size: Target image size + load_camera_info: Whether to load camera information for view-aligned conditioning + """ + def __init__(self, roots, *, image_size=518, load_camera_info=False, **kwargs): + self.image_size = image_size + # self.load_camera_info = load_camera_info + super().__init__(roots, **kwargs) + + def filter_metadata(self, metadata, dataset_name=None): + metadata, stats = super().filter_metadata(metadata, dataset_name=dataset_name) + metadata = metadata[metadata['cond_rendered'].notna()] + stats['Cond rendered'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + """ + Get instance with view-aligned image and camera info. + + Expects parent class to set: + - pack['x_0']: the latent tensor + - self._current_view_idx: the selected view index + - self._current_latent_dir: the latent directory path + """ + pack = super().get_instance(root, instance) + + # Get view_idx from parent class (set by SparseStructureLatentView) + if not hasattr(self, '_current_view_idx'): + raise RuntimeError("Parent class must set '_current_view_idx' before calling ViewImageConditionedMixin.get_instance") + if not hasattr(self, '_current_latent_dir'): + raise RuntimeError("Parent class must set '_current_latent_dir' before calling ViewImageConditionedMixin.get_instance") + view_idx = self._current_view_idx + latent_dir = self._current_latent_dir + + # Load image metadata + image_root = os.path.join(root['render_cond'], instance) + with open(os.path.join(image_root, 'transforms.json')) as f: + metadata = json.load(f) + + # Load corresponding image for this view + frame_metadata = metadata['frames'][view_idx] + image_path = os.path.join(image_root, frame_metadata['file_path']) + image = Image.open(image_path) + + image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + alpha = image.getchannel(3) + image = image.convert('RGB') + image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0 + alpha = torch.tensor(np.array(alpha)).float() / 255.0 + image = image * alpha.unsqueeze(0) + pack['cond'] = image + + # Load camera info if requested + + # camera_angle_x: check frame first, then root metadata + if 'camera_angle_x' in frame_metadata: + camera_angle_x = float(frame_metadata['camera_angle_x']) + elif 'camera_angle_x' in metadata: + camera_angle_x = float(metadata['camera_angle_x']) + else: + raise KeyError(f"'camera_angle_x' not found in transforms.json for {instance}") + pack['camera_angle_x'] = torch.tensor(camera_angle_x, dtype=torch.float32) + + # transform_matrix + if 'transform_matrix' not in frame_metadata: + raise KeyError(f"'transform_matrix' not found in frame {view_idx} for {instance}") + transform_matrix = torch.tensor(frame_metadata['transform_matrix'], dtype=torch.float32) + distance = torch.norm(transform_matrix[:3, 3]).item() + + pack['camera_distance'] = torch.tensor(distance, dtype=torch.float32) + # NOTE: Do NOT pass transform_matrix to ProjGrid. + # shape_latent space objects are already rotated to front-view by transform_mesh, + # so ProjGrid should use the default front_view_transform_matrix + distance. + # pack['transform_matrix'] = transform_matrix + + # Load mesh_scale from ss_latent directory's view{XX}_scale.json + scale_json_path = os.path.join(latent_dir, f'view{view_idx:02d}_scale.json') + if not os.path.exists(scale_json_path): + raise FileNotFoundError(f"Scale file not found: {scale_json_path}") + with open(scale_json_path) as f: + scale_data = json.load(f) + if 'total_scale' not in scale_data: + raise KeyError(f"'total_scale' not found in {scale_json_path}") + pack['mesh_scale'] = torch.tensor(float(scale_data['total_scale']), dtype=torch.float32) + + return pack + + +class MultiImageConditionedMixin: + def __init__(self, roots, *, image_size=518, max_image_cond_view = 4, **kwargs): + self.image_size = image_size + self.max_image_cond_view = max_image_cond_view + super().__init__(roots, **kwargs) + + def filter_metadata(self, metadata, dataset_name=None): + metadata, stats = super().filter_metadata(metadata, dataset_name=dataset_name) + metadata = metadata[metadata['cond_rendered'].notna()] + stats['Cond rendered'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + pack = super().get_instance(root, instance) + + image_root = os.path.join(root['render_cond'], instance) + with open(os.path.join(image_root, 'transforms.json')) as f: + metadata = json.load(f) + + n_views = len(metadata['frames']) + n_sample_views = np.random.randint(1, self.max_image_cond_view+1) + + assert n_views >= n_sample_views, f'Not enough views to sample {n_sample_views} unique images.' + + sampled_views = np.random.choice(n_views, size=n_sample_views, replace=False) + + cond_images = [] + for v in sampled_views: + frame_info = metadata['frames'][v] + image_path = os.path.join(image_root, frame_info['file_path']) + image = Image.open(image_path) + + alpha = np.array(image.getchannel(3)) + bbox = np.array(alpha).nonzero() + bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] + center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] + hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 + aug_hsize = hsize + aug_center = center + aug_bbox = [ + int(aug_center[0] - aug_hsize), + int(aug_center[1] - aug_hsize), + int(aug_center[0] + aug_hsize), + int(aug_center[1] + aug_hsize), + ] + + img = image.crop(aug_bbox) + img = img.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + alpha = img.getchannel(3) + img = img.convert('RGB') + img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0 + alpha = torch.tensor(np.array(alpha)).float() / 255.0 + img = img * alpha.unsqueeze(0) + + cond_images.append(img) + + pack['cond'] = [torch.stack(cond_images, dim=0)] # (V,3,H,W) + return pack diff --git a/trellis2/datasets/flexi_dual_grid.py b/trellis2/datasets/flexi_dual_grid.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a935349c105116653acfe9a7ab4ef91820d38d --- /dev/null +++ b/trellis2/datasets/flexi_dual_grid.py @@ -0,0 +1,173 @@ +import os +import numpy as np +import pickle +import torch +import utils3d +from .components import StandardDatasetBase +from ..modules import sparse as sp +from ..renderers import MeshRenderer +from ..representations import Mesh +from ..utils.data_utils import load_balanced_group_indices +import o_voxel + + +class FlexiDualGridVisMixin: + @torch.no_grad() + def visualize_sample(self, x: dict): + mesh = x['mesh'] + + renderer = MeshRenderer({'near': 1, 'far': 3}) + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + # Build each representation + images = [] + for m in mesh: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = \ + renderer.render(m.cuda(), ext, intr)['normal'] + images.append(image) + images = torch.stack(images) + + return images + + +class FlexiDualGridDataset(FlexiDualGridVisMixin, StandardDatasetBase): + """ + Flexible Dual Grid Dataset + + Args: + roots (str): path to the dataset + resolution (int): resolution of the voxel grid + min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset + """ + + def __init__( + self, + roots, + resolution: int = 1024, + max_active_voxels: int = 1000000, + max_num_faces: int = None, + min_aesthetic_score: float = 5.0, + ): + self.resolution = resolution + self.min_aesthetic_score = min_aesthetic_score + self.max_active_voxels = max_active_voxels + self.max_num_faces = max_num_faces + self.value_range = (0, 1) + + super().__init__(roots) + + self.loads = [self.metadata.loc[sha256, f'dual_grid_size'] for _, sha256, _ in self.instances] + + def __str__(self): + lines = [ + super().__str__(), + f' - Resolution: {self.resolution}', + ] + return '\n'.join(lines) + + def filter_metadata(self, metadata, dataset_name=None): + stats = {} + metadata = metadata[metadata[f'dual_grid_converted'] == True] + stats['Dual Grid Converted'] = len(metadata) + if self.min_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata[f'dual_grid_size'] <= self.max_active_voxels] + stats[f'Active Voxels <= {self.max_active_voxels}'] = len(metadata) + if self.max_num_faces is not None: + metadata = metadata[metadata['num_faces'] <= self.max_num_faces] + stats[f'Faces <= {self.max_num_faces}'] = len(metadata) + return metadata, stats + + def read_mesh(self, root, instance): + with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f: + dump = pickle.load(f) + start = 0 + vertices = [] + faces = [] + for obj in dump['objects']: + if obj['vertices'].size == 0 or obj['faces'].size == 0: + continue + vertices.append(obj['vertices']) + faces.append(obj['faces'] + start) + start += len(obj['vertices']) + vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float() + faces = torch.from_numpy(np.concatenate(faces, axis=0)).long() + vertices_min = vertices.min(dim=0)[0] + vertices_max = vertices.max(dim=0)[0] + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range' + return {'mesh': [Mesh(vertices=vertices, faces=faces)]} + + def read_dual_grid(self, root, instance): + coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4) + vertices = sp.SparseTensor( + (attr['vertices'] / 255.0).float(), + torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1), + ) + intersected = vertices.replace(torch.cat([ + attr['intersected'] % 2, + attr['intersected'] // 2 % 2, + attr['intersected'] // 4 % 2, + ], dim=-1).bool()) + return {'vertices': vertices, 'intersected': intersected} + + def get_instance(self, root, instance): + mesh = self.read_mesh(root['mesh_dump'], instance) + dual_grid = self.read_dual_grid(root['dual_grid'], instance) + return {**mesh, **dual_grid} + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['vertices'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], sp.SparseTensor): + pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + \ No newline at end of file diff --git a/trellis2/datasets/sparse_structure_latent.py b/trellis2/datasets/sparse_structure_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..dd24d241e92c88591226f5719f09065db78bc2e8 --- /dev/null +++ b/trellis2/datasets/sparse_structure_latent.py @@ -0,0 +1,408 @@ +import os +import json +from typing import * +import numpy as np +import torch +import utils3d +from PIL import Image +from ..representations import Voxel +from ..renderers import VoxelRenderer +from .components import StandardDatasetBase, ImageConditionedMixin, ViewImageConditionedMixin +from .. import models +from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics + + +class SparseStructureLatentVisMixin: + def __init__( + self, + *args, + pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16.json', + ss_dec_path: Optional[str] = None, + ss_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.ss_dec = None + self.pretrained_ss_dec = pretrained_ss_dec + self.ss_dec_path = ss_dec_path + self.ss_dec_ckpt = ss_dec_ckpt + + def _loading_ss_dec(self): + if self.ss_dec is not None: + return + if self.ss_dec_path is not None: + cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_ss_dec) + self.ss_dec = decoder.cuda().eval() + + def _delete_ss_dec(self): + del self.ss_dec + self.ss_dec = None + + @torch.no_grad() + def decode_latent(self, z, batch_size=4): + self._loading_ss_dec() + ss = [] + if self.normalization: + z = z * self.std.to(z.device) + self.mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + ss.append(self.ss_dec(z[i:i+batch_size])) + ss = torch.cat(ss, dim=0) + self._delete_ss_dec() + return ss + + @torch.no_grad() + def visualize_sample( + self, + x_0: Union[torch.Tensor, dict], + camera_angle_x: Optional[torch.Tensor] = None, + camera_distance: Optional[torch.Tensor] = None, + mesh_scale: Optional[torch.Tensor] = None, + ): + """ + Visualize sparse structure samples. + + Args: + x_0: Latent tensor [B, C, D, H, W] or dict containing 'x_0' + camera_angle_x: Optional [B] camera FOV angle in radians + camera_distance: Optional [B] camera distance for GT view rendering + mesh_scale: Optional [B] mesh scale factor for coordinate alignment + + Returns: + dict with: + 'multiview': [B, 3, 1024, 1024] - 4 fixed views rendered in 2x2 grid + 'gt_view': [B, 3, 512, 512] - GT camera view (if camera params provided) + """ + x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0'] + x_0 = self.decode_latent(x_0.cuda()) + + renderer = VoxelRenderer() + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # Build fixed camera views (4 views: 0°, 90°, 180°, 270°) + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + fixed_exts, fixed_ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + # Check if we have GT camera parameters for front view rendering + # GT view uses the fixed front_view_transform_matrix from image_conditioned_proj.py + has_gt_camera = ( + camera_angle_x is not None and + camera_distance is not None and + mesh_scale is not None + ) + + multiview_images = [] + gt_view_images = [] + + # Build each representation + x_0 = x_0.cuda() + for i in range(x_0.shape[0]): + coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False) + resolution = x_0.shape[-1] + color = coords / resolution + + # Standard voxel for fixed multiview rendering (origin at [-0.5, -0.5, -0.5]) + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/resolution, + coords=coords, + attrs=color, + layout={ + 'color': slice(0, 3), + } + ) + + # Render 4 fixed views (2x2 grid) + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(fixed_exts, fixed_ints)): + res = renderer.render(rep, ext, intr, colors_overwrite=color) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + multiview_images.append(image) + + # Render GT camera view using the fixed front view from image_conditioned_proj.py + if has_gt_camera: + # The GT view should match exactly how ProjGrid projects 3D points to 2D. + # + # In image_conditioned_proj.py (ProjGrid.forward): + # 1. grid_points are in [-1, 1]^3 (from torch.linspace(-1, 1, res)) + # 2. grid_points are rotated by rotation_matrix (Y-Z swap): x'=x, y'=-z, z'=y + # 3. grid_points are scaled: grid_points / mesh_scale / 2 + # 4. Points are projected using front_view_transform_matrix with distance + # + # front_view_transform_matrix (camera-to-world): + # [[1, 0, 0, 0], + # [0, 0, -1, -distance], + # [0, 1, 0, 0], + # [0, 0, 0, 1]] + # + # Camera is at (0, -distance, 0) in Blender coords (Z-up), looking at origin. + # + # To match this in VoxelRenderer: + # 1. Voxel coords [0, res-1] map to positions via: pos = (coords + 0.5) * voxel_size + origin + # 2. We need these positions to match ProjGrid's transformed grid_points + # 3. Apply rotation by swapping/flipping coords, then scale voxel_size and origin + + scale = mesh_scale[i].item() + distance = camera_distance[i].item() + fov = camera_angle_x[i].item() + + # Coordinate transformation to match ProjGrid's rotation (x'=x, y'=-z, z'=y) + # new_coords maps to rotated positions in the same grid structure + new_coords = torch.zeros_like(coords) + new_coords[:, 0] = coords[:, 0] # x stays + new_coords[:, 1] = (resolution - 1) - coords[:, 2] # y' = -z (flip for negation) + new_coords[:, 2] = coords[:, 1] # z' = y + + # Voxel position calculation: + # Original: pos = (coords + 0.5) / res - 0.5 -> range [-0.5, 0.5] + # We need: pos = (coords + 0.5) * 2 / res - 1 -> range [-1, 1] (like ProjGrid) + # Then: pos_final = pos / scale / 2 -> range [-0.5/scale, 0.5/scale] + # + # Combined: pos_final = ((coords + 0.5) * 2 / res - 1) / scale / 2 + # = (coords + 0.5) / res / scale - 0.5 / scale + # = (coords + 0.5) * voxel_size + origin + # where: voxel_size = 1 / res / scale + # origin = -0.5 / scale + + scaled_voxel_size = 1.0 / resolution / scale + scaled_origin = [-0.5 / scale, -0.5 / scale, -0.5 / scale] + + rep_scaled = Voxel( + origin=scaled_origin, + voxel_size=scaled_voxel_size, + coords=new_coords, + attrs=color, + layout={ + 'color': slice(0, 3), + } + ) + + # Build the fixed front view camera (same as front_view_transform_matrix) + # Camera at (0, -distance, 0), looking at origin, up is Z + cam_pos = torch.tensor([0.0, -distance, 0.0], device=coords.device) + look_at = torch.tensor([0.0, 0.0, 0.0], device=coords.device) + cam_up = torch.tensor([0.0, 0.0, 1.0], device=coords.device) + + gt_ext = utils3d.torch.extrinsics_look_at(cam_pos, look_at, cam_up) + gt_int = utils3d.torch.intrinsics_from_fov_xy( + torch.tensor(fov, device=coords.device), + torch.tensor(fov, device=coords.device) + ) + + # Ensure tensors are on the correct device (utils3d may not preserve device) + gt_ext = gt_ext.to(coords.device) + gt_int = gt_int.to(coords.device) + + gt_res = renderer.render(rep_scaled, gt_ext, gt_int, colors_overwrite=color) + gt_view_images.append(gt_res['color']) + + result = { + 'multiview': torch.stack(multiview_images), + } + + if has_gt_camera and len(gt_view_images) > 0: + result['gt_view'] = torch.stack(gt_view_images) + + return result + + +class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase): + """ + Sparse structure latent dataset + + Args: + roots (str): path to the dataset + min_aesthetic_score (float): minimum aesthetic score + normalization (dict): normalization stats + pretrained_ss_dec (str): name of the pretrained sparse structure decoder + ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec + ss_dec_ckpt (str): name of the sparse structure decoder checkpoint + skip_list (str, optional): path to a file containing sha256 hashes to skip + skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check + """ + def __init__(self, + roots: str, + *, + min_aesthetic_score: float = 5.0, + normalization: Optional[dict] = None, + pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16', + ss_dec_path: Optional[str] = None, + ss_dec_ckpt: Optional[str] = None, + skip_list: Optional[str] = None, + skip_aesthetic_score_datasets: Optional[list] = None, + ): + self.min_aesthetic_score = min_aesthetic_score + self.normalization = normalization + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_ss_dec=pretrained_ss_dec, + ss_dec_path=ss_dec_path, + ss_dec_ckpt=ss_dec_ckpt, + skip_list=skip_list, + skip_aesthetic_score_datasets=skip_aesthetic_score_datasets, + ) + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1) + self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1) + + def filter_metadata(self, metadata, dataset_name=None): + stats = {} + metadata = metadata[metadata['ss_latent_encoded'] == True] + stats['With latent'] = len(metadata) + # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist + skip_aesthetic = ( + (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or + ('aesthetic_score' not in metadata.columns) + ) + if skip_aesthetic: + stats[f'Aesthetic score check skipped'] = len(metadata) + else: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + latent = np.load(os.path.join(root['ss_latent'], f'{instance}.npz')) + z = torch.tensor(latent['z']).float() + if self.normalization is not None: + z = (z - self.mean) / self.std + + pack = { + 'x_0': z, + } + return pack + + +class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent): + """ + Image-conditioned sparse structure dataset + """ + pass + + +class SparseStructureLatentView(SparseStructureLatentVisMixin, StandardDatasetBase): + """ + View-based sparse structure latent dataset. + + Data format: {sha256}/view{XX}.npz where each npz contains 'z' key. + + Args: + num_views (int): Number of views to use (0 to num_views-1). Default is 2. + skip_list (str, optional): path to a file containing sha256 hashes to skip + skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check + """ + def __init__(self, + roots: str, + *, + min_aesthetic_score: float = 5.0, + normalization: Optional[dict] = None, + num_views: int = 2, + pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16', + ss_dec_path: Optional[str] = None, + ss_dec_ckpt: Optional[str] = None, + skip_list: Optional[str] = None, + skip_aesthetic_score_datasets: Optional[list] = None, + ): + self.min_aesthetic_score = min_aesthetic_score + self.normalization = normalization + self.num_views = num_views + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_ss_dec=pretrained_ss_dec, + ss_dec_path=ss_dec_path, + ss_dec_ckpt=ss_dec_ckpt, + skip_list=skip_list, + skip_aesthetic_score_datasets=skip_aesthetic_score_datasets, + ) + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1) + self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1) + + def filter_metadata(self, metadata, dataset_name=None): + stats = {} + # View-based ss_latent uses columns like: + # ss_latent_view00_encoded, ss_latent_view01_encoded, ... (view format) + # ss_latent_view_scale00_encoded, ss_latent_view_scale01_encoded, ... (view_scale format) + # Check both formats and use whichever exists (prefer view_scale over view) + required_view_cols = [f'ss_latent_view_scale{i:02d}_encoded' for i in range(self.num_views)] + existing_view_cols = [col for col in required_view_cols if col in metadata.columns] + + if not existing_view_cols: + # Fallback to view format + required_view_cols = [f'ss_latent_view{i:02d}_encoded' for i in range(self.num_views)] + existing_view_cols = [col for col in required_view_cols if col in metadata.columns] + + if existing_view_cols: + # Filter rows where all required views are encoded + # 注意:NaN 需要被视为 False,所以用 == True 显式比较 + has_all_views = (metadata[existing_view_cols] == True).all(axis=1) + metadata = metadata[has_all_views] + stats[f'With {self.num_views} view latents'] = len(metadata) + else: + # Fallback: check ss_latent_encoded column + if 'ss_latent_encoded' in metadata.columns: + metadata = metadata[metadata['ss_latent_encoded'] == True] + stats['With latent'] = len(metadata) + else: + raise ValueError(f'No view columns found in metadata: {metadata.columns.tolist()}') + # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist + skip_aesthetic = ( + (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or + ('aesthetic_score' not in metadata.columns) + ) + if skip_aesthetic: + stats[f'Aesthetic score check skipped'] = len(metadata) + else: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + # View-based format: directory with view{XX}.npz files + latent_dir = os.path.join(root['ss_latent'], instance) + + # Randomly select a view from the configured range + view_idx = np.random.randint(0, self.num_views) + view_file = f'view{view_idx:02d}.npz' + + # Store view info for ViewImageConditionedMixin + self._current_view_idx = view_idx + self._current_latent_dir = latent_dir + + latent = np.load(os.path.join(latent_dir, view_file)) + z = torch.tensor(latent['z']).float() + if self.normalization is not None: + z = (z - self.mean) / self.std + + pack = { + 'x_0': z, + 'view_idx': view_idx, + } + return pack + + +class ViewImageConditionedSparseStructureLatentView(ViewImageConditionedMixin, SparseStructureLatentView): + """ + Image-conditioned view-based sparse structure dataset. + + Loads ss_latent from {sha256}/view{XX}.npz format and pairs with + corresponding view from render_cond. + + Uses ViewImageConditionedMixin which reads mesh_scale from view{XX}_scale.json. + """ + pass diff --git a/trellis2/datasets/sparse_voxel_pbr.py b/trellis2/datasets/sparse_voxel_pbr.py new file mode 100644 index 0000000000000000000000000000000000000000..d9146b164111df390665af6dfeab19a9a2ca64e1 --- /dev/null +++ b/trellis2/datasets/sparse_voxel_pbr.py @@ -0,0 +1,298 @@ +import os +import io +from typing import Union +import numpy as np +import pickle +import torch +from PIL import Image +import o_voxel +import utils3d +from .components import StandardDatasetBase +from ..modules import sparse as sp +from ..renderers import VoxelRenderer +from ..representations import Voxel +from ..representations.mesh import MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture + +from ..utils.data_utils import load_balanced_group_indices + + +def is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +def nearest_power_of_two(n: int) -> int: + if n < 1: + raise ValueError("n must be >= 1") + if is_power_of_two(n): + return n + lower = 2 ** (n.bit_length() - 1) + upper = 2 ** n.bit_length() + if n - lower < upper - n: + return lower + else: + return upper + + +class SparseVoxelPbrVisMixin: + @torch.no_grad() + def visualize_sample(self, x: Union[sp.SparseTensor, dict]): + x = x if isinstance(x, sp.SparseTensor) else x['x'] + + renderer = VoxelRenderer() + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + images = {k: [] for k in self.layout} + + # Build each representation + x = x.cuda() + for i in range(x.shape[0]): + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/self.resolution, + coords=x[i].coords[:, 1:].contiguous(), + attrs=None, + layout={ + 'color': slice(0, 3), + } + ) + for k in self.layout: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + attr = x[i].feats[:, self.layout[k]].expand(-1, 3) + res = renderer.render(rep, ext, intr, colors_overwrite=attr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images[k].append(image) + + for k in self.layout: + images[k] = torch.stack(images[k]) + + return images + + +class SparseVoxelPbrDataset(SparseVoxelPbrVisMixin, StandardDatasetBase): + """ + Sparse Voxel PBR dataset. + + Args: + roots (str): path to the dataset + resolution (int): resolution of the voxel grid + min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset + """ + + def __init__( + self, + roots, + resolution: int = 1024, + max_active_voxels: int = 1000000, + max_num_faces: int = None, + min_aesthetic_score: float = 5.0, + attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'], + with_mesh: bool = True, + ): + self.resolution = resolution + self.min_aesthetic_score = min_aesthetic_score + self.max_active_voxels = max_active_voxels + self.max_num_faces = max_num_faces + self.with_mesh = with_mesh + self.value_range = (-1, 1) + self.channels = { + 'base_color': 3, + 'metallic': 1, + 'roughness': 1, + 'emissive': 3, + 'alpha': 1, + } + self.layout = {} + start = 0 + for attr in attrs: + self.layout[attr] = slice(start, start + self.channels[attr]) + start += self.channels[attr] + + super().__init__(roots) + + self.loads = [self.metadata.loc[sha256, f'num_pbr_voxels'] for _, sha256, _ in self.instances] + + def __str__(self): + lines = [ + super().__str__(), + f' - Resolution: {self.resolution}', + f' - Attributes: {list(self.layout.keys())}', + ] + return '\n'.join(lines) + + def filter_metadata(self, metadata, dataset_name=None): + stats = {} + metadata = metadata[metadata['pbr_voxelized'] == True] + stats['PBR Voxelized'] = len(metadata) + if self.min_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata['num_pbr_voxels'] <= self.max_active_voxels] + stats[f'Active voxels <= {self.max_active_voxels}'] = len(metadata) + if self.max_num_faces is not None: + metadata = metadata[metadata['num_faces'] <= self.max_num_faces] + stats[f'Faces <= {self.max_num_faces}'] = len(metadata) + return metadata, stats + + @staticmethod + def _texture_from_dump(pack) -> Texture: + png_bytes = pack['image'] + image = Image.open(io.BytesIO(png_bytes)) + if image.width != image.height or not is_power_of_two(image.width): + size = nearest_power_of_two(max(image.width, image.height)) + image = image.resize((size, size), Image.LANCZOS) + texture = torch.tensor(np.array(image) / 255.0, dtype=torch.float32).reshape(image.height, image.width, -1) + filter_mode = { + 'Linear': TextureFilterMode.LINEAR, + 'Closest': TextureFilterMode.CLOSEST, + 'Cubic': TextureFilterMode.LINEAR, + 'Smart': TextureFilterMode.LINEAR, + }[pack['interpolation']] + wrap_mode = { + 'REPEAT': TextureWrapMode.REPEAT, + 'EXTEND': TextureWrapMode.CLAMP_TO_EDGE, + 'CLIP': TextureWrapMode.CLAMP_TO_EDGE, + 'MIRROR': TextureWrapMode.MIRRORED_REPEAT, + }[pack['extension']] + return Texture(texture, filter_mode=filter_mode, wrap_mode=wrap_mode) + + def read_mesh_with_texture(self, root, instance): + with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f: + dump = pickle.load(f) + + # Fix dump alpha map + for mat in dump['materials']: + if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE': + mat['alphaMode'] = 'BLEND' + + # process material + materials = [] + for mat in dump['materials']: + materials.append(PbrMaterial( + base_color_texture=self._texture_from_dump(mat['baseColorTexture']) if mat['baseColorTexture'] is not None else None, + base_color_factor=mat['baseColorFactor'], + metallic_texture=self._texture_from_dump(mat['metallicTexture']) if mat['metallicTexture'] is not None else None, + metallic_factor=mat['metallicFactor'], + roughness_texture=self._texture_from_dump(mat['roughnessTexture']) if mat['roughnessTexture'] is not None else None, + roughness_factor=mat['roughnessFactor'], + alpha_texture=self._texture_from_dump(mat['alphaTexture']) if mat['alphaTexture'] is not None else None, + alpha_factor=mat['alphaFactor'], + alpha_mode={ + 'OPAQUE': AlphaMode.OPAQUE, + 'MASK': AlphaMode.MASK, + 'BLEND': AlphaMode.BLEND, + }[mat['alphaMode']], + alpha_cutoff=mat['alphaCutoff'], + )) + materials.append(PbrMaterial( + base_color_factor=[0.8, 0.8, 0.8], + alpha_factor=1.0, + metallic_factor=0.0, + roughness_factor=0.5, + alpha_mode=AlphaMode.OPAQUE, + alpha_cutoff=0.5, + )) # append default material + + # process mesh + start = 0 + vertices = [] + faces = [] + material_ids = [] + uv_coords = [] + for obj in dump['objects']: + if obj['vertices'].size == 0 or obj['faces'].size == 0: + continue + vertices.append(obj['vertices']) + faces.append(obj['faces'] + start) + obj['mat_ids'][obj['mat_ids'] == -1] = len(materials) - 1 + material_ids.append(obj['mat_ids']) + uv_coords.append(obj['uvs'] if obj['uvs'] is not None else np.zeros((obj['faces'].shape[0], 3, 2), dtype=np.float32)) + start += len(obj['vertices']) + + vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float() + faces = torch.from_numpy(np.concatenate(faces, axis=0)).long() + material_ids = torch.from_numpy(np.concatenate(material_ids, axis=0)).long() + uv_coords = torch.from_numpy(np.concatenate(uv_coords, axis=0)).float() + + # Normalize vertices + vertices_min = vertices.min(dim=0)[0] + vertices_max = vertices.max(dim=0)[0] + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range' + + return {'mesh': [MeshWithPbrMaterial( + vertices=vertices, + faces=faces, + material_ids=material_ids, + uv_coords=uv_coords, + materials=materials, + )]} + + def read_pbr_voxel(self, root, instance): + coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4) + feats = torch.concat([attr[k] for k in self.layout], dim=-1) / 255.0 * 2 - 1 + x = sp.SparseTensor( + feats.float(), + torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1), + ) + return {'x': x} + + def get_instance(self, root, instance): + if self.with_mesh: + mesh = self.read_mesh_with_texture(root['pbr_dump'], instance) + pbr_voxel = self.read_pbr_voxel(root['pbr_voxel'], instance) + return {**mesh, **pbr_voxel} + else: + return self.read_pbr_voxel(root['pbr_voxel'], instance) + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['x'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], sp.SparseTensor): + pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs diff --git a/trellis2/datasets/structured_latent.py b/trellis2/datasets/structured_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc0bccd4128187cb948c259a80b892a6d715a76 --- /dev/null +++ b/trellis2/datasets/structured_latent.py @@ -0,0 +1,224 @@ +import json +import os +from typing import * +import numpy as np +import torch +import utils3d.torch +from .components import StandardDatasetBase, ImageConditionedMixin +from ..modules.sparse.basic import SparseTensor +from .. import models +from ..utils.render_utils import get_renderer +from ..utils.data_utils import load_balanced_group_indices + + +class SLatVisMixin: + def __init__( + self, + *args, + pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.slat_dec = None + self.pretrained_slat_dec = pretrained_slat_dec + self.slat_dec_path = slat_dec_path + self.slat_dec_ckpt = slat_dec_ckpt + + def _loading_slat_dec(self): + if self.slat_dec is not None: + return + if self.slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_slat_dec) + self.slat_dec = decoder.cuda().eval() + + def _delete_slat_dec(self): + del self.slat_dec + self.slat_dec = None + + @torch.no_grad() + def decode_latent(self, z, batch_size=4): + self._loading_slat_dec() + reps = [] + if self.normalization is not None: + z = z * self.std.to(z.device) + self.mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + reps.append(self.slat_dec(z[i:i+batch_size])) + reps = sum(reps, []) + self._delete_slat_dec() + return reps + + @torch.no_grad() + def visualize_sample(self, x_0: Union[SparseTensor, dict]): + x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] + reps = self.decode_latent(x_0.cuda()) + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(40)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + renderer = get_renderer(reps[0]) + images = [] + for representation in reps: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images.append(image) + images = torch.stack(images) + + return images + + +class SLat(SLatVisMixin, StandardDatasetBase): + """ + structured latent V2 dataset + + Args: + roots (str): path to the dataset + min_aesthetic_score (float): minimum aesthetic score + max_tokens (int): maximum number of tokens + latent_key (str): key of the latent to be used + normalization (dict): normalization stats + pretrained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + skip_list (str, optional): path to a file containing sha256 hashes to skip + skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check + """ + def __init__(self, + roots: str, + *, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + latent_key: str = 'shape_latent', + normalization: Optional[dict] = None, + pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + skip_list: Optional[str] = None, + skip_aesthetic_score_datasets: Optional[list] = None, + ): + self.normalization = normalization + self.min_aesthetic_score = min_aesthetic_score + self.max_tokens = max_tokens + self.latent_key = latent_key + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_slat_dec=pretrained_slat_dec, + slat_dec_path=slat_dec_path, + slat_dec_ckpt=slat_dec_ckpt, + skip_list=skip_list, + skip_aesthetic_score_datasets=skip_aesthetic_score_datasets, + ) + + self.loads = [self.metadata.loc[sha256, f'{latent_key}_tokens'] for _, sha256, _ in self.instances] + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1) + self.std = torch.tensor(self.normalization['std']).reshape(1, -1) + + def filter_metadata(self, metadata, dataset_name=None): + stats = {} + metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True] + stats['With latent'] = len(metadata) + # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist + skip_aesthetic = ( + (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or + ('aesthetic_score' not in metadata.columns) + ) + if skip_aesthetic: + stats[f'Aesthetic score check skipped'] = len(metadata) + else: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata[f'{self.latent_key}_tokens'] <= self.max_tokens] + stats[f'Num tokens <= {self.max_tokens}'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + data = np.load(os.path.join(root[self.latent_key], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + feats = torch.tensor(data['feats']).float() + if self.normalization is not None: + feats = (feats - self.mean) / self.std + return { + 'coords': coords, + 'feats': feats, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + coords = [] + feats = [] + layout = [] + start = 0 + for i, b in enumerate(sub_batch): + coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1)) + feats.append(b['feats']) + layout.append(slice(start, start + b['coords'].shape[0])) + start += b['coords'].shape[0] + coords = torch.cat(coords) + feats = torch.cat(feats) + pack['x_0'] = SparseTensor( + coords=coords, + feats=feats, + ) + pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]]) + pack['x_0'].register_spatial_cache('layout', layout) + + # collate other data + keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class ImageConditionedSLat(ImageConditionedMixin, SLat): + """ + Image conditioned structured latent dataset + """ + pass diff --git a/trellis2/datasets/structured_latent_shape.py b/trellis2/datasets/structured_latent_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b40e0a61a1caadb34eecf91bf38d6c8128e2c5 --- /dev/null +++ b/trellis2/datasets/structured_latent_shape.py @@ -0,0 +1,402 @@ +import os +import json +from typing import * +import numpy as np +import torch +import utils3d +from .. import models +from .components import ImageConditionedMixin, ViewImageConditionedMixin +from ..modules.sparse import SparseTensor +from .structured_latent import SLatVisMixin, SLat +from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics +from ..utils.data_utils import load_balanced_group_indices + + +class SLatShapeVisMixin(SLatVisMixin): + def _loading_slat_dec(self): + if self.slat_dec is not None: + return + if self.slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_slat_dec) + decoder.set_resolution(self.resolution) + self.slat_dec = decoder.cuda().eval() + + @torch.no_grad() + def visualize_sample( + self, + x_0: Union[SparseTensor, dict], + camera_angle_x: Optional[torch.Tensor] = None, + camera_distance: Optional[torch.Tensor] = None, + mesh_scale: Optional[torch.Tensor] = None, + ): + """ + Visualize shape samples. + + Args: + x_0: SparseTensor or dict containing 'x_0' + camera_angle_x: Optional [B] camera FOV angle in radians + camera_distance: Optional [B] camera distance for GT view rendering + mesh_scale: Optional [B] mesh scale factor for coordinate alignment + + Returns: + dict with: + 'multiview': [B, 3, 1024, 1024] - 4 fixed views rendered in 2x2 grid (normal) + 'gt_view': [B, 3, 512, 512] - GT camera view (if camera params provided) + """ + x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] + reps = self.decode_latent(x_0.cuda()) + + # build fixed camera views (4 views: 0°, 90°, 180°, 270°) + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + fixed_exts, fixed_ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + # Check if we have GT camera parameters for GT view rendering + has_gt_camera = ( + camera_angle_x is not None and + camera_distance is not None and + mesh_scale is not None + ) + + # render + renderer = get_renderer(reps[0]) + multiview_images = [] + gt_view_images = [] + + for i, representation in enumerate(reps): + # Render 4 fixed views (2x2 grid) + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + + # Validate mesh data before rasterization + verts = representation.vertices + faces = representation.faces + if verts.shape[0] == 0 or faces.shape[0] == 0: + print(f"[visualize_sample] Warning: sample {i} has empty mesh, skipping") + multiview_images.append(image) + continue + if faces.max() >= verts.shape[0]: + print(f"[visualize_sample] Warning: sample {i} has out-of-bound face indices " + f"(max face idx={faces.max().item()}, num verts={verts.shape[0]}), skipping") + multiview_images.append(image) + continue + if torch.isnan(verts).any() or torch.isinf(verts).any(): + print(f"[visualize_sample] Warning: sample {i} has NaN/Inf vertices, skipping") + multiview_images.append(image) + continue + + try: + for j, (ext, intr) in enumerate(zip(fixed_exts, fixed_ints)): + res = renderer.render(representation, ext, intr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['normal'] + except RuntimeError as e: + print(f"[visualize_sample] Warning: render failed for sample {i}: {e}") + image = torch.zeros(3, 1024, 1024).cuda() + multiview_images.append(image) + + # Render GT camera view using the fixed front view (same as sparse_structure_latent.py) + if has_gt_camera: + # The GT view should match exactly how ProjGrid projects 3D points to 2D. + # + # In image_conditioned_proj.py (ProjGrid.forward): + # 1. grid_points are in [-1, 1]^3 (from torch.linspace(-1, 1, res)) + # 2. grid_points are rotated by rotation_matrix (Y-Z swap): x'=x, y'=-z, z'=y + # 3. grid_points are scaled: grid_points / mesh_scale / 2 + # 4. Points are projected using front_view_transform_matrix with distance + # + # Mesh vertices are in [-0.5, 0.5]^3. To match ProjGrid's coordinate space, + # we need to scale them: vertices / mesh_scale -> [-0.5/s, 0.5/s]^3 + # This is equivalent to ProjGrid's: [-1,1]^3 / scale / 2 -> [-0.5/s, 0.5/s]^3 + # + # Camera position: ProjGrid camera is at (0, -distance, 0) in Blender coords (Z-up). + # After inverse rotation to mesh space, camera is at (0, 0, distance). + + scale = mesh_scale[i].item() + distance = camera_distance[i].item() + fov = camera_angle_x[i].item() + device = representation.vertices.device + + # Scale mesh vertices to match ProjGrid's projection space + from ..representations import Mesh + scaled_rep = Mesh( + vertices=representation.vertices / scale, + faces=representation.faces, + ) + + cam_pos = torch.tensor([0.0, 0.0, distance], device=device) + look_at = torch.tensor([0.0, 0.0, 0.0], device=device) + cam_up = torch.tensor([0.0, 1.0, 0.0], device=device) + + gt_ext = utils3d.torch.extrinsics_look_at(cam_pos, look_at, cam_up) + gt_int = utils3d.torch.intrinsics_from_fov_xy( + torch.tensor(fov, device=device), + torch.tensor(fov, device=device) + ) + + gt_ext = gt_ext.to(device) + gt_int = gt_int.to(device) + + # Use scaled mesh renderer with appropriate near/far for smaller mesh + mesh_half_size = 0.5 / scale + renderer.rendering_options.near = max(0.01, distance - mesh_half_size - 0.5) + renderer.rendering_options.far = distance + mesh_half_size + 0.5 + + try: + gt_res = renderer.render(scaled_rep, gt_ext, gt_int) + gt_view_images.append(gt_res['normal']) + except RuntimeError as e: + print(f"[visualize_sample] Warning: GT view render failed for sample {i}: {e}") + gt_view_images.append(torch.full((3, 512, 512), 0.5, device=device)) + + result = { + 'multiview': torch.stack(multiview_images), + } + + if has_gt_camera and len(gt_view_images) > 0: + result['gt_view'] = torch.stack(gt_view_images) + + return result + + +class SLatShape(SLatShapeVisMixin, SLat): + """ + structured latent for shape generation + + Args: + roots (str): path to the dataset + resolution (int): resolution of the shape + min_aesthetic_score (float): minimum aesthetic score + max_tokens (int): maximum number of tokens + latent_key (str): key of the latent to be used + normalization (dict): normalization stats + pretrained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + skip_list (str, optional): path to a file containing sha256 hashes to skip + skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check + """ + def __init__(self, + roots: str, + *, + resolution: int, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + normalization: Optional[dict] = None, + pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + skip_list: Optional[str] = None, + skip_aesthetic_score_datasets: Optional[list] = None, + ): + super().__init__( + roots, + min_aesthetic_score=min_aesthetic_score, + max_tokens=max_tokens, + latent_key='shape_latent', + normalization=normalization, + pretrained_slat_dec=pretrained_slat_dec, + slat_dec_path=slat_dec_path, + slat_dec_ckpt=slat_dec_ckpt, + skip_list=skip_list, + skip_aesthetic_score_datasets=skip_aesthetic_score_datasets, + ) + self.resolution = resolution + + +class ImageConditionedSLatShape(ImageConditionedMixin, SLatShape): + """ + Image conditioned structured latent for shape generation + """ + pass + + +class SLatShapeView(SLatShapeVisMixin, SLat): + """ + View-based structured latent for shape generation. + + Data format: {sha256}/view{XX}.npz where each npz contains 'coords' and 'feats' keys. + + Args: + roots (str): path to the dataset + resolution (int): resolution of the shape + min_aesthetic_score (float): minimum aesthetic score + max_tokens (int): maximum number of tokens + num_views (int): Number of views to use (0 to num_views-1). Default is 2. + normalization (dict): normalization stats + pretrained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + skip_list (str, optional): path to a file containing sha256 hashes to skip + skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check + """ + def __init__(self, + roots: str, + *, + resolution: int, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + num_views: int = 2, + normalization: Optional[dict] = None, + pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + skip_list: Optional[str] = None, + skip_aesthetic_score_datasets: Optional[list] = None, + ): + self.normalization = normalization + self.min_aesthetic_score = min_aesthetic_score + self.max_tokens = max_tokens + self.num_views = num_views + self.latent_key = 'shape_latent' + self.value_range = (0, 1) + + # Initialize parent with SLatVisMixin parameters + from .components import StandardDatasetBase + SLatVisMixin.__init__( + self, + roots, + pretrained_slat_dec=pretrained_slat_dec, + slat_dec_path=slat_dec_path, + slat_dec_ckpt=slat_dec_ckpt, + ) + StandardDatasetBase.__init__(self, roots, skip_list=skip_list, skip_aesthetic_score_datasets=skip_aesthetic_score_datasets) + + self.resolution = resolution + + # Calculate loads for load balancing + self.loads = [] + for _, sha256, _ in self.instances: + if f'{self.latent_key}_tokens' in self.metadata.columns: + try: + self.loads.append(self.metadata.loc[sha256, f'{self.latent_key}_tokens']) + except: + self.loads.append(self.max_tokens) + else: + self.loads.append(self.max_tokens) + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1) + self.std = torch.tensor(self.normalization['std']).reshape(1, -1) + + def filter_metadata(self, metadata, dataset_name=None): + stats = {} + # View-based shape_latent uses columns like shape_latent_view00_encoded, shape_latent_view01_encoded, etc. + required_view_cols = [f'shape_latent_view{i:02d}_encoded' for i in range(self.num_views)] + existing_view_cols = [col for col in required_view_cols if col in metadata.columns] + + if existing_view_cols: + # Filter rows where all required views are encoded + # 注意:NaN 需要被视为 False,所以用 == True 显式比较 + has_all_views = (metadata[existing_view_cols] == True).all(axis=1) + metadata = metadata[has_all_views] + stats[f'With {self.num_views} view latents'] = len(metadata) + else: + # Fallback: check shape_latent_encoded column + if f'{self.latent_key}_encoded' in metadata.columns: + metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True] + stats['With latent'] = len(metadata) + + # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist + skip_aesthetic = ( + (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or + ('aesthetic_score' not in metadata.columns) + ) + if skip_aesthetic: + stats[f'Aesthetic score check skipped'] = len(metadata) + else: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + + # Filter by max_tokens if column exists + tokens_col = f'{self.latent_key}_tokens' + if tokens_col in metadata.columns: + metadata = metadata[metadata[tokens_col] <= self.max_tokens] + stats[f'Num tokens <= {self.max_tokens}'] = len(metadata) + + return metadata, stats + + def get_instance(self, root, instance): + # View-based format: directory with view{XX}.npz files + latent_dir = os.path.join(root[self.latent_key], instance) + + # Randomly select a view from the configured range + view_idx = np.random.randint(0, self.num_views) + view_file = f'view{view_idx:02d}.npz' + + # Store view info for ViewImageConditionedMixin + self._current_view_idx = view_idx + self._current_latent_dir = latent_dir + + data = np.load(os.path.join(latent_dir, view_file)) + coords = torch.tensor(data['coords']).int() + feats = torch.tensor(data['feats']).float() + if self.normalization is not None: + feats = (feats - self.mean) / self.std + return { + 'coords': coords, + 'feats': feats, + 'view_idx': view_idx, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + coords = [] + feats = [] + layout = [] + start = 0 + for i, b in enumerate(sub_batch): + coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1)) + feats.append(b['feats']) + layout.append(slice(start, start + b['coords'].shape[0])) + start += b['coords'].shape[0] + coords = torch.cat(coords) + feats = torch.cat(feats) + pack['x_0'] = SparseTensor( + coords=coords, + feats=feats, + ) + pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]]) + pack['x_0'].register_spatial_cache('layout', layout) + + # collate other data + keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class ViewImageConditionedSLatShapeView(ViewImageConditionedMixin, SLatShapeView): + """ + Image-conditioned view-based structured latent for shape generation. + + Loads shape_latent from {sha256}/view{XX}.npz format and pairs with + corresponding view from render_cond. + + Uses ViewImageConditionedMixin which reads mesh_scale from view{XX}_scale.json. + """ + pass diff --git a/trellis2/datasets/structured_latent_svpbr.py b/trellis2/datasets/structured_latent_svpbr.py new file mode 100644 index 0000000000000000000000000000000000000000..5aeb1579ed23a2ac5016b63683697ec7cc7ccf82 --- /dev/null +++ b/trellis2/datasets/structured_latent_svpbr.py @@ -0,0 +1,666 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +import json +from typing import * +import numpy as np +import torch +import cv2 +import utils3d +from .. import models +from .components import StandardDatasetBase, ImageConditionedMixin, ViewImageConditionedMixin +from ..modules.sparse import SparseTensor, sparse_cat +from ..representations import MeshWithVoxel +from ..renderers import PbrMeshRenderer, EnvMap +from ..utils.data_utils import load_balanced_group_indices +from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics + + +class SLatPbrVisMixin: + def __init__( + self, + *args, + pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16', + pbr_slat_dec_path: Optional[str] = None, + pbr_slat_dec_ckpt: Optional[str] = None, + pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + shape_slat_dec_path: Optional[str] = None, + shape_slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.pbr_slat_dec = None + self.pretrained_pbr_slat_dec = pretrained_pbr_slat_dec + self.pbr_slat_dec_path = pbr_slat_dec_path + self.pbr_slat_dec_ckpt = pbr_slat_dec_ckpt + self.shape_slat_dec = None + self.pretrained_shape_slat_dec = pretrained_shape_slat_dec + self.shape_slat_dec_path = shape_slat_dec_path + self.shape_slat_dec_ckpt = shape_slat_dec_ckpt + + def _loading_slat_dec(self): + if self.pbr_slat_dec is not None and self.shape_slat_dec is not None: + return + if self.pbr_slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.pbr_slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.pbr_slat_dec_path, 'ckpts', f'decoder_{self.pbr_slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_pbr_slat_dec) + self.pbr_slat_dec = decoder.cuda().eval() + + if self.shape_slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.shape_slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.shape_slat_dec_path, 'ckpts', f'decoder_{self.shape_slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_shape_slat_dec) + decoder.set_resolution(self.resolution) + self.shape_slat_dec = decoder.cuda().eval() + + def _delete_slat_dec(self): + del self.pbr_slat_dec + self.pbr_slat_dec = None + del self.shape_slat_dec + self.shape_slat_dec = None + + @torch.no_grad() + def decode_latent(self, z, shape_z, batch_size=4): + self._loading_slat_dec() + reps = [] + if self.shape_slat_normalization is not None: + shape_z = shape_z * self.shape_slat_std.to(z.device) + self.shape_slat_mean.to(z.device) + if self.pbr_slat_normalization is not None: + z = z * self.pbr_slat_std.to(z.device) + self.pbr_slat_mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + mesh, subs = self.shape_slat_dec(shape_z[i:i+batch_size], return_subs=True) + vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs) * 0.5 + 0.5 + reps.extend([ + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / self.resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout = self.layout, + ) + for m, v in zip(mesh, vox) + ]) + self._delete_slat_dec() + return reps + + @torch.no_grad() + def visualize_sample(self, sample: dict): + shape_z = sample['concat_cond'].cuda() + z = sample['x_0'].cuda() + reps = self.decode_latent(z, shape_z) + + # Extract camera parameters for GT view rendering (if available) + camera_angle_x = sample.get('camera_angle_x') + camera_distance = sample.get('camera_distance') + mesh_scale = sample.get('mesh_scale') + has_gt_camera = ( + camera_angle_x is not None and + camera_distance is not None and + mesh_scale is not None + ) + + # build camera + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + # render + renderer = PbrMeshRenderer() + renderer.rendering_options.resolution = 512 + renderer.rendering_options.near = 1 + renderer.rendering_options.far = 100 + renderer.rendering_options.ssaa = 2 + renderer.rendering_options.peel_layers = 8 + envmap = EnvMap(torch.tensor( + cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), + dtype=torch.float32, device='cuda' + )) + + images = {} + gt_view_images = {} + for i, representation in enumerate(reps): + # Validate mesh data before rasterization (same as shape training) + verts = representation.vertices + faces = representation.faces + if verts.shape[0] == 0 or faces.shape[0] == 0: + print(f"[visualize_sample] Warning: sample {i} has empty mesh, skipping") + continue + if faces.max() >= verts.shape[0]: + print(f"[visualize_sample] Warning: sample {i} has out-of-bound face indices " + f"(max face idx={faces.max().item()}, num verts={verts.shape[0]}), skipping") + continue + if torch.isnan(verts).any() or torch.isinf(verts).any(): + print(f"[visualize_sample] Warning: sample {i} has NaN/Inf vertices, skipping") + continue + + image = {} + tile = [2, 2] + try: + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr, envmap=envmap) + for k, v in res.items(): + if k not in images: + images[k] = [] + if k not in image: + image[k] = torch.zeros(3, 1024, 1024).cuda() + image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = v + for k in images.keys(): + images[k].append(image[k]) + except RuntimeError as e: + print(f"[visualize_sample] Warning: render failed for sample {i}: {e}") + try: + torch.cuda.synchronize() + except Exception: + pass + try: + torch.cuda.empty_cache() + except Exception: + pass + continue + + # Render GT camera view + # Must scale mesh vertices by / mesh_scale to match ProjGrid's projection space. + # ProjGrid maps [-1,1]^3 -> / scale / 2 -> [-0.5/s, 0.5/s]^3 + # Mesh vertices in [-0.5, 0.5]^3 -> / scale -> [-0.5/s, 0.5/s]^3 (equivalent) + if has_gt_camera: + try: + scale = mesh_scale[i].item() + distance = camera_distance[i].item() + fov = camera_angle_x[i].item() + device = representation.vertices.device + + # Scale mesh and voxel to match ProjGrid's projection space + scaled_rep = MeshWithVoxel( + vertices=representation.vertices / scale, + faces=representation.faces, + origin=(representation.origin / scale).tolist(), + voxel_size=representation.voxel_size / scale, + coords=representation.coords, + attrs=representation.attrs, + voxel_shape=representation.voxel_shape, + layout=representation.layout, + ) + + cam_pos = torch.tensor([0.0, 0.0, distance], device=device) + look_at = torch.tensor([0.0, 0.0, 0.0], device=device) + cam_up = torch.tensor([0.0, 1.0, 0.0], device=device) + + gt_ext = utils3d.torch.extrinsics_look_at(cam_pos, look_at, cam_up) + gt_int = utils3d.torch.intrinsics_from_fov_xy( + torch.tensor(fov, device=device), + torch.tensor(fov, device=device) + ) + gt_ext = gt_ext.to(device) + gt_int = gt_int.to(device) + + # Update near/far for the smaller scaled mesh + mesh_half_size = 0.5 / scale + renderer.rendering_options.near = max(0.01, distance - mesh_half_size - 0.5) + renderer.rendering_options.far = distance + mesh_half_size + 0.5 + + gt_res = renderer.render(scaled_rep, gt_ext, gt_int, envmap=envmap) + for k, v in gt_res.items(): + gt_key = f'gt_view_{k}' + if gt_key not in gt_view_images: + gt_view_images[gt_key] = [] + gt_view_images[gt_key].append(v) + except RuntimeError as e: + print(f"[visualize_sample] Warning: GT view render failed for sample {i}: {e}") + try: + torch.cuda.synchronize() + except Exception: + pass + try: + torch.cuda.empty_cache() + except Exception: + pass + + for k in images.keys(): + images[k] = torch.stack(images[k], dim=0) + + for k, v in gt_view_images.items(): + images[k] = torch.stack(v) + + return images + + +class SLatPbr(SLatPbrVisMixin, StandardDatasetBase): + """ + structured latent for sparse voxel pbr dataset + + Args: + roots (str): path to the dataset + latent_key (str): key of the latent to be used + min_aesthetic_score (float): minimum aesthetic score + normalization (dict): normalization stats + resolution (int): resolution of decoded sparse voxel + attrs (list): attributes to be decoded + pretained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + """ + def __init__(self, + roots: str, + *, + resolution: int, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + full_pbr: bool = False, + pbr_slat_normalization: Optional[dict] = None, + shape_slat_normalization: Optional[dict] = None, + attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'], + pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16', + pbr_slat_dec_path: Optional[str] = None, + pbr_slat_dec_ckpt: Optional[str] = None, + pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + shape_slat_dec_path: Optional[str] = None, + shape_slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + self.resolution = resolution + self.pbr_slat_normalization = pbr_slat_normalization + self.shape_slat_normalization = shape_slat_normalization + self.min_aesthetic_score = min_aesthetic_score + self.max_tokens = max_tokens + self.full_pbr = full_pbr + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_pbr_slat_dec=pretrained_pbr_slat_dec, + pbr_slat_dec_path=pbr_slat_dec_path, + pbr_slat_dec_ckpt=pbr_slat_dec_ckpt, + pretrained_shape_slat_dec=pretrained_shape_slat_dec, + shape_slat_dec_path=shape_slat_dec_path, + shape_slat_dec_ckpt=shape_slat_dec_ckpt, + **kwargs + ) + + self.loads = [self.metadata.loc[sha256, 'pbr_latent_tokens'] for _, sha256, _ in self.instances] + + if self.pbr_slat_normalization is not None: + self.pbr_slat_mean = torch.tensor(self.pbr_slat_normalization['mean']).reshape(1, -1) + self.pbr_slat_std = torch.tensor(self.pbr_slat_normalization['std']).reshape(1, -1) + + if self.shape_slat_normalization is not None: + self.shape_slat_mean = torch.tensor(self.shape_slat_normalization['mean']).reshape(1, -1) + self.shape_slat_std = torch.tensor(self.shape_slat_normalization['std']).reshape(1, -1) + + self.attrs = attrs + self.channels = { + 'base_color': 3, + 'metallic': 1, + 'roughness': 1, + 'emissive': 3, + 'alpha': 1, + } + self.layout = {} + start = 0 + for attr in attrs: + self.layout[attr] = slice(start, start + self.channels[attr]) + start += self.channels[attr] + + def filter_metadata(self, metadata, dataset_name=None): + stats = {} + metadata = metadata[metadata['pbr_latent_encoded'] == True] + stats['With PBR latent'] = len(metadata) + metadata = metadata[metadata['shape_latent_encoded'] == True] + stats['With shape latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata['pbr_latent_tokens'] <= self.max_tokens] + stats[f'Num tokens <= {self.max_tokens}'] = len(metadata) + if self.full_pbr: + metadata = metadata[metadata['num_basecolor_tex'] > 0] + metadata = metadata[metadata['num_metallic_tex'] > 0] + metadata = metadata[metadata['num_roughness_tex'] > 0] + stats['Full PBR'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + # PBR latent + data = np.load(os.path.join(root['pbr_latent'], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1) + feats = torch.tensor(data['feats']).float() + if self.pbr_slat_normalization is not None: + feats = (feats - self.pbr_slat_mean) / self.pbr_slat_std + pbr_z = SparseTensor(feats, coords) + + # Shape latent + data = np.load(os.path.join(root['shape_latent'], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1) + feats = torch.tensor(data['feats']).float() + if self.shape_slat_normalization is not None: + feats = (feats - self.shape_slat_mean) / self.shape_slat_std + shape_z = SparseTensor(feats, coords) + + assert torch.equal(shape_z.coords, pbr_z.coords), \ + f"Shape latent and PBR latent have different coordinates: {shape_z.coords.shape} vs {pbr_z.coords.shape}" + + return { + 'x_0': pbr_z, + 'concat_cond': shape_z, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['x_0'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], SparseTensor): + pack[k] = sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class ImageConditionedSLatPbr(ImageConditionedMixin, SLatPbr): + """ + Image conditioned structured latent dataset + """ + pass + + +class SLatPbrView(SLatPbrVisMixin, StandardDatasetBase): + """ + View-based structured latent for PBR/texture generation with view-aligned projection. + + Data format: + PBR latent: {sha256}/view{XX}.npz (coords + feats) + Shape latent: {sha256}/view{XX}.npz (coords + feats, from shape_latent_view dir) + + Each view's PBR latent and Shape latent share the same sparse coordinates. + + Args: + roots (str): path to the dataset + resolution (int): resolution of decoded sparse voxel + min_aesthetic_score (float): minimum aesthetic score + max_tokens (int): maximum number of tokens + num_views (int): Number of views to use (0 to num_views-1). Default is 2. + full_pbr (bool): Whether to require full PBR textures + pbr_slat_normalization (dict): normalization stats for PBR latent + shape_slat_normalization (dict): normalization stats for shape latent + attrs (list): PBR attributes to decode + pretrained_pbr_slat_dec (str): pretrained PBR decoder name + pretrained_shape_slat_dec (str): pretrained shape decoder name + skip_list (str, optional): path to a file containing sha256 hashes to skip + skip_aesthetic_score_datasets (list, optional): datasets to skip aesthetic score check + """ + def __init__(self, + roots: str, + *, + resolution: int, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + num_views: int = 2, + full_pbr: bool = False, + pbr_slat_normalization: Optional[dict] = None, + shape_slat_normalization: Optional[dict] = None, + attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'], + pretrained_pbr_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16', + pbr_slat_dec_path: Optional[str] = None, + pbr_slat_dec_ckpt: Optional[str] = None, + pretrained_shape_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + shape_slat_dec_path: Optional[str] = None, + shape_slat_dec_ckpt: Optional[str] = None, + skip_list: Optional[str] = None, + skip_aesthetic_score_datasets: Optional[list] = None, + ): + self.resolution = resolution + self.pbr_slat_normalization = pbr_slat_normalization + self.shape_slat_normalization = shape_slat_normalization + self.min_aesthetic_score = min_aesthetic_score + self.max_tokens = max_tokens + self.num_views = num_views + self.full_pbr = full_pbr + self.value_range = (0, 1) + self.skip_aesthetic_score_datasets = set(skip_aesthetic_score_datasets or []) + + # Initialize visualization mixin + SLatPbrVisMixin.__init__( + self, + roots, + pretrained_pbr_slat_dec=pretrained_pbr_slat_dec, + pbr_slat_dec_path=pbr_slat_dec_path, + pbr_slat_dec_ckpt=pbr_slat_dec_ckpt, + pretrained_shape_slat_dec=pretrained_shape_slat_dec, + shape_slat_dec_path=shape_slat_dec_path, + shape_slat_dec_ckpt=shape_slat_dec_ckpt, + ) + StandardDatasetBase.__init__( + self, roots, + skip_list=skip_list, + skip_aesthetic_score_datasets=skip_aesthetic_score_datasets, + ) + + # Calculate loads for load balancing + self.loads = [] + for _, sha256, _ in self.instances: + if 'pbr_latent_tokens' in self.metadata.columns: + try: + self.loads.append(self.metadata.loc[sha256, 'pbr_latent_tokens']) + except: + self.loads.append(self.max_tokens) + else: + self.loads.append(self.max_tokens) + + if self.pbr_slat_normalization is not None: + self.pbr_slat_mean = torch.tensor(self.pbr_slat_normalization['mean']).reshape(1, -1) + self.pbr_slat_std = torch.tensor(self.pbr_slat_normalization['std']).reshape(1, -1) + + if self.shape_slat_normalization is not None: + self.shape_slat_mean = torch.tensor(self.shape_slat_normalization['mean']).reshape(1, -1) + self.shape_slat_std = torch.tensor(self.shape_slat_normalization['std']).reshape(1, -1) + + self.attrs = attrs + self.channels = { + 'base_color': 3, + 'metallic': 1, + 'roughness': 1, + 'emissive': 3, + 'alpha': 1, + } + self.layout = {} + start = 0 + for attr in attrs: + self.layout[attr] = slice(start, start + self.channels[attr]) + start += self.channels[attr] + + def filter_metadata(self, metadata, dataset_name=None): + stats = {} + # View-based PBR latent uses columns like pbr_latent_view00_encoded, etc. + required_pbr_view_cols = [f'pbr_latent_view{i:02d}_encoded' for i in range(self.num_views)] + existing_pbr_view_cols = [col for col in required_pbr_view_cols if col in metadata.columns] + + if existing_pbr_view_cols: + has_all_pbr_views = (metadata[existing_pbr_view_cols] == True).all(axis=1) + metadata = metadata[has_all_pbr_views] + stats[f'With {self.num_views} PBR view latents'] = len(metadata) + else: + # Fallback: check pbr_latent_encoded + if 'pbr_latent_encoded' in metadata.columns: + metadata = metadata[metadata['pbr_latent_encoded'] == True] + stats['With PBR latent'] = len(metadata) + + # Also require shape latent views + required_shape_view_cols = [f'shape_latent_view{i:02d}_encoded' for i in range(self.num_views)] + existing_shape_view_cols = [col for col in required_shape_view_cols if col in metadata.columns] + + if existing_shape_view_cols: + has_all_shape_views = (metadata[existing_shape_view_cols] == True).all(axis=1) + metadata = metadata[has_all_shape_views] + stats[f'With {self.num_views} shape view latents'] = len(metadata) + else: + if 'shape_latent_encoded' in metadata.columns: + metadata = metadata[metadata['shape_latent_encoded'] == True] + stats['With shape latent'] = len(metadata) + + # Skip aesthetic score check for specified datasets + skip_aesthetic = ( + (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or + ('aesthetic_score' not in metadata.columns) + ) + if skip_aesthetic: + stats[f'Aesthetic score check skipped'] = len(metadata) + else: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + + # Filter by max_tokens if column exists + if 'pbr_latent_tokens' in metadata.columns: + metadata = metadata[metadata['pbr_latent_tokens'] <= self.max_tokens] + stats[f'Num tokens <= {self.max_tokens}'] = len(metadata) + + if self.full_pbr: + if 'num_basecolor_tex' in metadata.columns: + metadata = metadata[metadata['num_basecolor_tex'] > 0] + if 'num_metallic_tex' in metadata.columns: + metadata = metadata[metadata['num_metallic_tex'] > 0] + if 'num_roughness_tex' in metadata.columns: + metadata = metadata[metadata['num_roughness_tex'] > 0] + stats['Full PBR'] = len(metadata) + + return metadata, stats + + def get_instance(self, root, instance): + # Randomly select a view from the configured range + view_idx = np.random.randint(0, self.num_views) + view_file = f'view{view_idx:02d}.npz' + + # Store view info for ViewImageConditionedMixin + self._current_view_idx = view_idx + + # Load PBR latent for this view + pbr_latent_dir = os.path.join(root['pbr_latent'], instance) + self._current_latent_dir = pbr_latent_dir + + data = np.load(os.path.join(pbr_latent_dir, view_file)) + pbr_coords = torch.tensor(data['coords']).int() + pbr_feats = torch.tensor(data['feats']).float() + if self.pbr_slat_normalization is not None: + pbr_feats = (pbr_feats - self.pbr_slat_mean) / self.pbr_slat_std + + # Load Shape latent for this view (as concat_cond) + shape_latent_dir = os.path.join(root['shape_latent'], instance) + data = np.load(os.path.join(shape_latent_dir, view_file)) + shape_coords = torch.tensor(data['coords']).int() + shape_feats = torch.tensor(data['feats']).float() + if self.shape_slat_normalization is not None: + shape_feats = (shape_feats - self.shape_slat_mean) / self.shape_slat_std + + # Verify coordinates match + assert torch.equal(pbr_coords, shape_coords), \ + f"PBR and shape latent coordinates mismatch for {instance}/view{view_idx:02d}" + + return { + 'coords': pbr_coords, + 'pbr_feats': pbr_feats, + 'shape_feats': shape_feats, + 'view_idx': view_idx, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + # Build x_0 (PBR latent) and concat_cond (shape latent) as SparseTensors + coords_list = [] + pbr_feats_list = [] + shape_feats_list = [] + layout = [] + start = 0 + for i, b in enumerate(sub_batch): + batch_coords = torch.cat([ + torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), + b['coords'] + ], dim=-1) + coords_list.append(batch_coords) + pbr_feats_list.append(b['pbr_feats']) + shape_feats_list.append(b['shape_feats']) + layout.append(slice(start, start + b['coords'].shape[0])) + start += b['coords'].shape[0] + + all_coords = torch.cat(coords_list) + + # x_0: PBR latent + pack['x_0'] = SparseTensor( + coords=all_coords, + feats=torch.cat(pbr_feats_list), + ) + pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['pbr_feats'].shape[1:]]) + pack['x_0'].register_spatial_cache('layout', layout) + + # concat_cond: Shape latent (same coordinates) + pack['concat_cond'] = SparseTensor( + coords=all_coords.clone(), + feats=torch.cat(shape_feats_list), + ) + pack['concat_cond']._shape = torch.Size([len(group), *sub_batch[0]['shape_feats'].shape[1:]]) + pack['concat_cond'].register_spatial_cache('layout', layout) + + # collate other data (excluding already handled fields) + skip_keys = {'coords', 'pbr_feats', 'shape_feats'} + keys = [k for k in sub_batch[0].keys() if k not in skip_keys] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class ViewImageConditionedSLatPbrView(ViewImageConditionedMixin, SLatPbrView): + """ + Image-conditioned view-based structured latent for PBR/texture generation + with view-aligned projection. + + Loads PBR latent and shape latent from {sha256}/view{XX}.npz format and pairs + with corresponding view from render_cond. + + Uses ViewImageConditionedMixin which reads mesh_scale from view{XX}_scale.json + and provides camera parameters for 3D-to-2D projection. + """ + pass diff --git a/trellis2/models/__init__.py b/trellis2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4fed035ed7c1f6352d93e9787f1aaed072876d5 --- /dev/null +++ b/trellis2/models/__init__.py @@ -0,0 +1,78 @@ +import importlib + +__attributes = { + # Sparse Structure + 'SparseStructureEncoder': 'sparse_structure_vae', + 'SparseStructureDecoder': 'sparse_structure_vae', + 'SparseStructureFlowModel': 'sparse_structure_flow', + + # SLat Generation + 'SLatFlowModel': 'structured_latent_flow', + 'ElasticSLatFlowModel': 'structured_latent_flow', + + # SC-VAEs + 'SparseUnetVaeEncoder': 'sc_vaes.sparse_unet_vae', + 'SparseUnetVaeDecoder': 'sc_vaes.sparse_unet_vae', + 'FlexiDualGridVaeEncoder': 'sc_vaes.fdg_vae', + 'FlexiDualGridVaeDecoder': 'sc_vaes.fdg_vae' +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. + """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file), strict=False) + + return model + + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder + from .sparse_structure_flow import SparseStructureFlowModel + from .structured_latent_flow import SLatFlowModel, ElasticSLatFlowModel + + from .sc_vaes.sparse_unet_vae import SparseUnetVaeEncoder, SparseUnetVaeDecoder + from .sc_vaes.fdg_vae import FlexiDualGridVaeEncoder, FlexiDualGridVaeDecoder diff --git a/trellis2/models/sc_vaes/fdg_vae.py b/trellis2/models/sc_vaes/fdg_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..0209e690ea6f4096dd42a4d137dbc0bee1f51367 --- /dev/null +++ b/trellis2/models/sc_vaes/fdg_vae.py @@ -0,0 +1,110 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .sparse_unet_vae import ( + SparseResBlock3d, + SparseConvNeXtBlock3d, + + SparseResBlockDownsample3d, + SparseResBlockUpsample3d, + SparseResBlockS2C3d, + SparseResBlockC2S3d, +) +from .sparse_unet_vae import ( + SparseUnetVaeEncoder, + SparseUnetVaeDecoder, +) +from ...representations import Mesh +from o_voxel.convert import flexible_dual_grid_to_mesh + + +class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): + def __init__( + self, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__( + 6, + model_channels, + latent_channels, + num_blocks, + block_type, + down_block_type, + block_args, + use_fp16, + ) + + def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False): + x = vertices.replace(torch.cat([ + vertices.feats - 0.5, + intersected.feats.float() - 0.5, + ], dim=1)) + return super().forward(x, sample_posterior, return_raw) + + +class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): + def __init__( + self, + resolution: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + voxel_margin: float = 0.5, + use_fp16: bool = False, + ): + self.resolution = resolution + self.voxel_margin = voxel_margin + + super().__init__( + 7, + model_channels, + latent_channels, + num_blocks, + block_type, + up_block_type, + block_args, + use_fp16, + ) + + def set_resolution(self, resolution: int) -> None: + self.resolution = resolution + + def forward(self, x: sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs): + decoded = super().forward(x, **kwargs) + if self.training: + h, subs_gt, subs = decoded + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected_logits = h.replace(h.feats[..., 3:6]) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + v.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=True + )) for v, i, q in zip(vertices, gt_intersected, quad_lerp)] + return mesh, vertices, intersected_logits, subs_gt, subs + else: + out_list = list(decoded) if isinstance(decoded, tuple) else [decoded] + h = out_list[0] + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected = h.replace(h.feats[..., 3:6] > 0) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + v.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=False + )) for v, i, q in zip(vertices, intersected, quad_lerp)] + out_list[0] = mesh + return out_list[0] if len(out_list) == 1 else tuple(out_list) diff --git a/trellis2/models/sc_vaes/sparse_unet_vae.py b/trellis2/models/sc_vaes/sparse_unet_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..b9902a155a8a85c5c616a4503be92f43f6fdde27 --- /dev/null +++ b/trellis2/models/sc_vaes/sparse_unet_vae.py @@ -0,0 +1,522 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from ...modules.utils import convert_module_to_f16, convert_module_to_f32, zero_module +from ...modules import sparse as sp +from ...modules.norm import LayerNorm32 + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest', + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + self.resample_mode = resample_mode + self.use_checkpoint = use_checkpoint + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + if resample_mode == 'nearest': + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + elif resample_mode =='spatial2channel' and not self.downsample: + self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3) + elif resample_mode =='spatial2channel' and self.downsample: + self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + if resample_mode == 'nearest': + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + elif resample_mode =='spatial2channel' and self.downsample: + self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1)) + elif resample_mode =='spatial2channel' and not self.downsample: + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + self.updown = None + if self.downsample: + if resample_mode == 'nearest': + self.updown = sp.SparseDownsample(2) + elif resample_mode =='spatial2channel': + self.updown = sp.SparseSpatial2Channel(2) + elif self.upsample: + self.to_subdiv = sp.SparseLinear(channels, 8) + if resample_mode == 'nearest': + self.updown = sp.SparseUpsample(2) + elif resample_mode =='spatial2channel': + self.updown = sp.SparseChannel2Spatial(2) + + def _updown(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.downsample: + x = self.updown(x) + elif self.upsample: + x = self.updown(x, subdiv.replace(subdiv.feats > 0)) + return x + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + subdiv = None + if self.upsample: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + if self.resample_mode == 'spatial2channel': + h = self.conv1(h) + h = self._updown(h, subdiv) + x = self._updown(x, subdiv) + if self.resample_mode == 'nearest': + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.upsample: + return h, subdiv + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockDownsample3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = sp.SparseDownsample(2) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.updown(h) + x = self.updown(x) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockUpsample3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + if self.pred_subdiv: + self.to_subdiv = sp.SparseLinear(channels, 8) + self.updown = sp.SparseUpsample(2) + + def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockS2C3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1)) + self.updown = sp.SparseSpatial2Channel(2) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = self.updown(h) + x = self.updown(x) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockC2S3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + if pred_subdiv: + self.to_subdiv = sp.SparseLinear(channels, 8) + self.updown = sp.SparseChannel2Spatial(2) + + def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h + + def forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False) + else: + return self._forward(x, subdiv) + + +class SparseConvNeXtBlock3d(nn.Module): + def __init__( + self, + channels: int, + mlp_ratio: float = 4.0, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.use_checkpoint = use_checkpoint + + self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.conv = sp.SparseConv3d(channels, channels, 3) + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.SiLU(), + zero_module(nn.Linear(int(channels * mlp_ratio), channels)), + ) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.conv(x) + h = h.replace(self.norm(h.feats)) + h = h.replace(self.mlp(h.feats)) + return h + x + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseUnetVaeEncoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. + """ + def __init__( + self, + in_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = sp.SparseLinear(in_channels, model_channels[0]) + self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[down_block_type[i]]( + model_channels[i], + model_channels[i+1], + **block_args[i], + ) + ) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False): + h = self.input_layer(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.to_latent(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z + + +class SparseUnetVaeDecoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. + """ + def __init__( + self, + out_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.out_channels = out_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.use_fp16 = use_fp16 + self.pred_subdiv = pred_subdiv + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.low_vram = False + + self.output_layer = sp.SparseLinear(model_channels[-1], out_channels) + self.from_latent = sp.SparseLinear(latent_channels, model_channels[0]) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[up_block_type[i]]( + model_channels[i], + model_channels[i+1], + pred_subdiv=pred_subdiv, + **block_args[i], + ) + ) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False) -> sp.SparseTensor: + assert guide_subs is None or self.pred_subdiv == False, "Only decoders with pred_subdiv=False can be used with guide_subs" + assert return_subs == False or self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with return_subs" + + h = self.from_latent(x) + h = h.type(self.dtype) + subs_gt = [] + subs = [] + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + if self.pred_subdiv: + if self.training: + subs_gt.append(h.get_spatial_cache('subdivision')) + h, sub = block(h) + subs.append(sub) + else: + h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) + else: + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.output_layer(h) + if self.training and self.pred_subdiv: + return h, subs_gt, subs + else: + if return_subs: + return h, subs + else: + return h + + def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor: + assert self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with upsampling" + + h = self.from_latent(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + if i == upsample_times: + return h.coords + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + h, sub = block(h) + else: + h = block(h) + \ No newline at end of file diff --git a/trellis2/models/sparse_elastic_mixin.py b/trellis2/models/sparse_elastic_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..66d204c89bedabc2afd1795cdfc6f5d58a6b1ac0 --- /dev/null +++ b/trellis2/models/sparse_elastic_mixin.py @@ -0,0 +1,24 @@ +from contextlib import contextmanager +from typing import * +import math +from ..modules import sparse as sp +from ..utils.elastic_utils import ElasticModuleMixin + + +class SparseTransformerElasticMixin(ElasticModuleMixin): + def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs): + return x.feats.shape[0] + + @contextmanager + def with_mem_ratio(self, mem_ratio=1.0): + if mem_ratio == 1.0: + yield 1.0 + return + num_blocks = len(self.blocks) + num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks) + exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks + for i in range(num_blocks): + self.blocks[i].use_checkpoint = i < num_checkpoint_blocks + yield exact_mem_ratio + for i in range(num_blocks): + self.blocks[i].use_checkpoint = False diff --git a/trellis2/models/sparse_structure_flow.py b/trellis2/models/sparse_structure_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..be6b50eac0095d3f57b24c134b893342edc6f101 --- /dev/null +++ b/trellis2/models/sparse_structure_flow.py @@ -0,0 +1,298 @@ +from typing import * +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to, manual_cast, str_to_dtype +from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ..modules.attention import RotaryPositionEmbedder + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + """ + Sparse Structure Flow Model for 3D generation. + + Supports two conditioning modes: + - "cross": Standard cross-attention with image features + - "proj": View-aligned projection attention with camera-aware features + """ + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross", + proj_in_channels: Optional[int] = None, + vae_in_channels: Optional[int] = None, + **kwargs + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.image_attn_mode = image_attn_mode + self.proj_in_channels = proj_in_channels + self.vae_in_channels = vae_in_channels + self.dtype = str_to_dtype(dtype) + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + elif pe_mode == "rope": + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + rope_phases = pos_embedder(coords) + self.register_buffer("rope_phases", rope_phases) + + if pe_mode != "rope": + self.rope_phases = None + + self.input_layer = nn.Linear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + image_attn_mode=image_attn_mode, + proj_in_channels=proj_in_channels, + vae_in_channels=vae_in_channels, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to(self, dtype: torch.dtype) -> None: + """ + Convert the torso of the model to the specified dtype. + """ + self.dtype = dtype + self.blocks.apply(partial(convert_module_to, dtype=dtype)) + + def initialize_weights(self) -> None: + if self.initialization == 'vanilla': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + elif self.initialization == 'scaled': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels))) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Scaled init for to_out and ffn2 + def _scaled_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels)) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + for block in self.blocks: + block.self_attn.to_out.apply(_scaled_init) + # Handle cross, proj, and gated_proj modes + if self.image_attn_mode in ("proj", "gated_proj"): + block.cross_attn.cross_attn_block.to_out.apply(_scaled_init) + else: + block.cross_attn.to_out.apply(_scaled_init) + block.mlp.mlp[2].apply(_scaled_init) + + # Initialize input layer to make the initial representation have variance 1 + nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) + nn.init.zeros_(self.input_layer.bias) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + """ + Forward pass. + + Args: + x: Input tensor [B, C, D, H, W] + t: Timestep tensor [B] + cond: Conditioning tensor. For "cross" mode: [B, N, D]. + For "proj" mode: dict {'global': global_cond, 'proj': proj_cond} + or tuple of (global_cond, proj_cond) + + Returns: + Output tensor [B, C, D, H, W] + """ + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + if self.pe_mode == "ape": + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + h = manual_cast(h, self.dtype) + + # Handle different conditioning modes + if self.image_attn_mode == 'proj': + if isinstance(cond, dict): + global_cond = cond['global'] + proj_cond = cond['proj'] + else: + global_cond, proj_cond = cond + global_cond = manual_cast(global_cond, self.dtype) + proj_cond = manual_cast(proj_cond, self.dtype) + cond = (global_cond, proj_cond) + elif self.image_attn_mode == 'gated_proj': + global_cond = manual_cast(cond['global'], self.dtype) + proj_semantic = manual_cast(cond['proj_semantic'], self.dtype) + proj_color = manual_cast(cond['proj_color'], self.dtype) + cond = {'global': global_cond, 'proj_semantic': proj_semantic, 'proj_color': proj_color} + else: + cond = manual_cast(cond, self.dtype) + + for block in self.blocks: + h = block(h, t_emb, cond, self.rope_phases) + h = manual_cast(h, x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() + + return h diff --git a/trellis2/models/sparse_structure_vae.py b/trellis2/models/sparse_structure_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e09136cf294c4c1b47b0f09fa6ee57bad2166d --- /dev/null +++ b/trellis2/models/sparse_structure_vae.py @@ -0,0 +1,306 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..modules.norm import GroupNorm32, ChannelLayerNorm32 +from ..modules.spatial import pixel_shuffle_3d +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. + """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h diff --git a/trellis2/models/structured_latent_flow.py b/trellis2/models/structured_latent_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..23d78ad65c62aadf3442de2d81efd12dc2490ce2 --- /dev/null +++ b/trellis2/models/structured_latent_flow.py @@ -0,0 +1,265 @@ +from typing import * +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to, manual_cast, str_to_dtype +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder +from .sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SLatFlowModel(nn.Module): + """ + Structured Latent Flow Model for 3D generation. + + Supports two conditioning modes: + - "cross": Standard cross-attention with image features + - "proj": View-aligned projection attention with camera-aware features + """ + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross", + proj_in_channels: Optional[int] = None, + vae_in_channels: Optional[int] = None, + **kwargs + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.image_attn_mode = image_attn_mode + self.proj_in_channels = proj_in_channels + self.vae_in_channels = vae_in_channels + self.dtype = str_to_dtype(dtype) + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + image_attn_mode=image_attn_mode, + proj_in_channels=proj_in_channels, + vae_in_channels=vae_in_channels, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = sp.SparseLinear(model_channels, out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to(self, dtype: torch.dtype) -> None: + """ + Convert the torso of the model to the specified dtype. + """ + self.dtype = dtype + self.blocks.apply(partial(convert_module_to, dtype=dtype)) + + def initialize_weights(self) -> None: + if self.initialization == 'vanilla': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + elif self.initialization == 'scaled': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels))) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Scaled init for to_out and ffn2 + def _scaled_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels)) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + for block in self.blocks: + block.self_attn.to_out.apply(_scaled_init) + # Handle cross, proj, and gated_proj modes + if self.image_attn_mode in ("proj", "gated_proj"): + block.cross_attn.cross_attn_block.to_out.apply(_scaled_init) + else: + block.cross_attn.to_out.apply(_scaled_init) + block.mlp.mlp[2].apply(_scaled_init) + + # Initialize input layer to make the initial representation have variance 1 + nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) + nn.init.zeros_(self.input_layer.bias) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward( + self, + x: sp.SparseTensor, + t: torch.Tensor, + cond: Union[torch.Tensor, List[torch.Tensor], Dict[str, Union[torch.Tensor, sp.SparseTensor]], Tuple], + concat_cond: Optional[sp.SparseTensor] = None, + **kwargs + ) -> sp.SparseTensor: + """ + Forward pass. + + Args: + x: SparseTensor input + t: Timestep tensor [B] + cond: Conditioning tensor. For "cross" mode: list of tensors or tensor. + For "proj" mode: dict {'global': global_cond, 'proj': proj_cond} + or tuple of (global_cond, proj_cond) + concat_cond: Optional concatenation condition + + Returns: + SparseTensor output + """ + if concat_cond is not None: + x = sp.sparse_cat([x, concat_cond], dim=-1) + + h = self.input_layer(x) + h = manual_cast(h, self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + + if self.pe_mode == "ape": + pe = self.pos_embedder(h.coords[:, 1:]) + h = h + manual_cast(pe, self.dtype) + + # Handle different conditioning modes + if self.image_attn_mode == 'proj': + if isinstance(cond, dict): + global_cond = cond['global'] + proj_cond = cond['proj'] + else: + global_cond, proj_cond = cond + if isinstance(global_cond, list): + global_cond = sp.VarLenTensor.from_tensor_list(global_cond) + global_cond = manual_cast(global_cond, self.dtype) + proj_cond = manual_cast(proj_cond, self.dtype) + cond = (global_cond, proj_cond) + elif self.image_attn_mode == 'gated_proj': + global_cond = cond['global'] + if isinstance(global_cond, list): + global_cond = sp.VarLenTensor.from_tensor_list(global_cond) + global_cond = manual_cast(global_cond, self.dtype) + proj_semantic = manual_cast(cond['proj_semantic'], self.dtype) + proj_color = manual_cast(cond['proj_color'], self.dtype) + cond = {'global': global_cond, 'proj_semantic': proj_semantic, 'proj_color': proj_color} + else: + if isinstance(cond, list): + cond = sp.VarLenTensor.from_tensor_list(cond) + cond = manual_cast(cond, self.dtype) + + for block in self.blocks: + h = block(h, t_emb, cond) + + h = manual_cast(h, x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return h + + +class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel): + """ + SLat Flow Model with elastic memory management. + Used for training with low VRAM. + """ + pass diff --git a/trellis2/modules/attention/__init__.py b/trellis2/modules/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..318c6d3b13cf0f2d6b814876a8efc7c9941806a8 --- /dev/null +++ b/trellis2/modules/attention/__init__.py @@ -0,0 +1,4 @@ +from .full_attn import * +from .modules import * +from .rope import * +from .proj_attention import ProjectAttention, GatedProjectAttention diff --git a/trellis2/modules/attention/config.py b/trellis2/modules/attention/config.py new file mode 100644 index 0000000000000000000000000000000000000000..579db837b858c8b2f424732c8633489a577079e2 --- /dev/null +++ b/trellis2/modules/attention/config.py @@ -0,0 +1,32 @@ +from typing import * + +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_attn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'flash_attn_4', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_attn_debug is not None: + DEBUG = env_attn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn', 'flash_attn_3', 'flash_attn_4', 'sdpa', 'naive']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug diff --git a/trellis2/modules/attention/full_attn.py b/trellis2/modules/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0216288bb3e4b52673488b0b011e004eda68fd --- /dev/null +++ b/trellis2/modules/attention/full_attn.py @@ -0,0 +1,153 @@ +from typing import * +import torch +import math +from . import config + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if config.BACKEND == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif config.BACKEND == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif config.BACKEND == 'flash_attn_3': + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + if num_all_args == 1: + out = flash_attn_3.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = flash_attn_3.flash_attn_func(q, k, v) + elif num_all_args == 3: + out = flash_attn_3.flash_attn_func(q, k, v) + elif config.BACKEND == 'flash_attn_4': + if 'flash_attn_4' not in globals(): + from flash_attn.cute import flash_attn_func as flash_attn_4_func + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out, _ = flash_attn_4_func(q, k, v) + elif config.BACKEND == 'sdpa': + if 'sdpa' not in globals(): + from torch.nn.functional import scaled_dot_product_attention as sdpa + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif config.BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {config.BACKEND}") + + return out diff --git a/trellis2/modules/attention/modules.py b/trellis2/modules/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..492784c7ba8f572c4820b604f51c924ed564ab00 --- /dev/null +++ b/trellis2/modules/attention/modules.py @@ -0,0 +1,102 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention +from .rope import RotaryPositionEmbedder + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + + if self.attn_mode == "full": + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) + k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/trellis2/modules/attention/proj_attention.py b/trellis2/modules/attention/proj_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..77011c1189e596e3c926ed31a94cfa2d14d6fe46 --- /dev/null +++ b/trellis2/modules/attention/proj_attention.py @@ -0,0 +1,101 @@ +""" +View-Aligned Projection Attention Module for TRELLIS2 + +This module implements the projection-based attention mechanism that combines +global cross-attention with view-aligned projected features. + +Supports two modes: +- "proj": Standard projection (DINOv3 only), per-block proj_linear +- "gated_proj": Gated fusion of DINOv3 (semantic) + VAE (color) features +""" + +from typing import * +import torch +import torch.nn as nn + + +class ProjectAttention(nn.Module): + """ + Projection-based Attention Module with per-block proj_linear. + + Combines global cross-attention with view-aligned projected features. + Each block owns a proj_linear that projects DINOv3 features from + proj_in_channels (e.g. 1024) to model_channels (e.g. 1536). + + The module receives: + - x: Input features from the transformer + - context: A dict with keys: + - 'global': Global image features, shape [B, M, ctx_channels] + - 'proj': View-aligned projected features, shape [B, N, proj_in_channels] + + The output combines the cross-attention result with the projected context. + """ + def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int): + super().__init__() + self.cross_attn_block = cross_attn_block + self.proj_linear = nn.Linear(proj_in_channels, channels, bias=True) + + def forward(self, x: torch.Tensor, context: Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor: + if isinstance(context, dict): + global_context = context['global'] + proj_context = context['proj'] + else: + global_context, proj_context = context + + global_out = self.cross_attn_block(x, global_context) + proj_out = self.proj_linear(proj_context) + context_combined = proj_out + global_out + return context_combined + + +class GatedProjectAttention(nn.Module): + """ + Concat-Projection Attention Module for DINOv3 (semantic) + VAE (color) features. + + Concatenates DINOv3 and VAE projected features and applies a single linear + projection to model_channels. This is mathematically equivalent to two + separate proj_linears + addition, but allows cross-dimensional interactions + between semantic and color features through the shared weight matrix. + + Zero-initialized for stable training: at init, fused=0 so only global + cross-attention contributes; color+semantic signals are gradually learned. + + The module receives: + - x: Input features from the transformer + - context: A dict with keys: + - 'global': Global image features, shape [B, M, ctx_channels] + - 'proj_semantic': DINOv3 projected features, shape [B, N, dino_channels] + - 'proj_color': VAE projected features, shape [B, N, vae_channels] + """ + def __init__( + self, + cross_attn_block: nn.Module, + channels: int, + dino_in_channels: int, + vae_in_channels: int, + ): + """ + Args: + cross_attn_block: The underlying cross-attention module + channels: Model channels (output dimension) + dino_in_channels: DINOv3 proj feature dimension (e.g. 1024) + vae_in_channels: VAE latent feature dimension (e.g. 16) + """ + super().__init__() + self.cross_attn_block = cross_attn_block + self.proj_linear = nn.Linear(dino_in_channels + vae_in_channels, channels, bias=True) + # Zero-init: at start, fused=0, only global cross-attn contributes + nn.init.zeros_(self.proj_linear.weight) + nn.init.zeros_(self.proj_linear.bias) + + def forward(self, x: torch.Tensor, context: Union[Dict[str, torch.Tensor], Tuple]) -> torch.Tensor: + if isinstance(context, dict): + global_context = context['global'] + proj_semantic = context['proj_semantic'] + proj_color = context['proj_color'] + else: + global_context, proj_semantic, proj_color = context + + global_out = self.cross_attn_block(x, global_context) + fused = self.proj_linear(torch.cat([proj_semantic, proj_color], dim=-1)) + return fused + global_out diff --git a/trellis2/modules/attention/rope.py b/trellis2/modules/attention/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf6c5b321c2417443ffad3804df97ff5fbe6658 --- /dev/null +++ b/trellis2/modules/attention/rope.py @@ -0,0 +1,48 @@ +from typing import * +import torch +import torch.nn as nn + + +class RotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + @staticmethod + def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, indices: torch.Tensor) -> torch.Tensor: + """ + Args: + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + return phases \ No newline at end of file diff --git a/trellis2/modules/image_feature_extractor.py b/trellis2/modules/image_feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..c3cb515ad1152752b09fef9999cc5ff44bd885f7 --- /dev/null +++ b/trellis2/modules/image_feature_extractor.py @@ -0,0 +1,118 @@ +from typing import * +import torch +import torch.nn.functional as F +from torchvision import transforms +from transformers import DINOv3ViTModel +import numpy as np +from PIL import Image + + +class DinoV2FeatureExtractor: + """ + Feature extractor for DINOv2 models. + """ + def __init__(self, model_name: str): + self.model_name = model_name + self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True) + self.model.eval() + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.model(image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + +class DinoV3FeatureExtractor: + """ + Feature extractor for DINOv3 models. + """ + def __init__(self, model_name: str, image_size=512): + self.model_name = model_name + self.model = DINOv3ViTModel.from_pretrained(model_name) + self.model.eval() + self.image_size = image_size + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def extract_features(self, image: torch.Tensor) -> torch.Tensor: + image = image.to(self.model.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.model.embeddings(image, bool_masked_pos=None) + position_embeddings = self.model.rope_embeddings(image) + + for i, layer_module in enumerate(self.model.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + return F.layer_norm(hidden_states, hidden_states.shape[-1:]) + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.extract_features(image) + return features diff --git a/trellis2/modules/norm.py b/trellis2/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..78675d0f850d2e34d3c90b5d6bc14db708e5b400 --- /dev/null +++ b/trellis2/modules/norm.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn +from .utils import manual_cast + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/trellis2/modules/sparse/__init__.py b/trellis2/modules/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e73f232abc6f31cafeac4172a96150906ba20b7b --- /dev/null +++ b/trellis2/modules/sparse/__init__.py @@ -0,0 +1,69 @@ +from . import config +import importlib + +__attributes = { + 'VarLenTensor': 'basic', + 'varlen_cat': 'basic', + 'varlen_unbind': 'basic', + 'SparseTensor': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_cross_attention': 'attention', + 'SparseRotaryPositionEmbedder': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide': 'spatial', + 'SparseSpatial2Channel': 'spatial', + 'SparseChannel2Spatial': 'spatial', + 'sparse_nearest_interpolate': 'spatial', + 'sparse_trilinear_interpolate': 'spatial', + 'encode_seq': 'serialize', + 'decode_seq': 'serialize', +} + +__submodules = ['transformer', 'conv'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + from .serialize import * + import transformer + import conv diff --git a/trellis2/modules/sparse/attention/__init__.py b/trellis2/modules/sparse/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..695105b1997bb0ad13c94efa4ca04a95861adb1b --- /dev/null +++ b/trellis2/modules/sparse/attention/__init__.py @@ -0,0 +1,4 @@ +from .full_attn import * +from .windowed_attn import * +from .modules import * +from .proj_attention import * diff --git a/trellis2/modules/sparse/attention/full_attn.py b/trellis2/modules/sparse/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7d559768fb3d0c5d0b67b835a12084f09ccbf0 --- /dev/null +++ b/trellis2/modules/sparse/attention/full_attn.py @@ -0,0 +1,238 @@ +from typing import * +import torch +from .. import VarLenTensor +from .. import config + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: VarLenTensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (VarLenTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, kv: Union[VarLenTensor, torch.Tensor]) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (VarLenTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: VarLenTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs. + kv (VarLenTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, k: VarLenTensor, v: VarLenTensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, k: torch.Tensor, v: torch.Tensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: VarLenTensor, v: VarLenTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, VarLenTensor), f"qkv must be a VarLenTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, VarLenTensor) and isinstance(kv, (VarLenTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, VarLenTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, VarLenTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, VarLenTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, VarLenTensor) and isinstance(k, (VarLenTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, VarLenTensor) and isinstance(v, VarLenTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, VarLenTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, VarLenTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif config.ATTN == 'flash_attn_3': + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + cu_seqlens_kv = cu_seqlens_q.clone() + max_q_seqlen = max_kv_seqlen = max(q_seqlen) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + elif num_all_args == 3: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen) + elif config.ATTN == 'flash_attn_4': + if 'flash_attn_4' not in globals(): + from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + cu_seqlens_kv = cu_seqlens_q.clone() + max_q_seqlen = max_kv_seqlen = max(q_seqlen) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + elif num_all_args == 3: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + out, _ = flash_attn_4_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen) + else: + raise ValueError(f"Unknown attention module: {config.ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/trellis2/modules/sparse/attention/modules.py b/trellis2/modules/sparse/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..d762b4b2e01335d26d856b3fe82eca31a2735123 --- /dev/null +++ b/trellis2/modules/sparse/attention/modules.py @@ -0,0 +1,141 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. import VarLenTensor, SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from .rope import SparseRotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, VarLenTensor): + x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale) + else: + x = F.normalize(x, dim=-1) * self.gamma * self.scale + return x.to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed", "double_windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed", "double_windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + if attn_mode == 'double_windowed': + assert window_size % 2 == 0, "Window size must be even for double windowed attention" + assert num_heads % 2 == 0, "Number of heads must be even for double windowed attention" + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + + @staticmethod + def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats + + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + q, k = self.rope(q, k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + elif self.attn_mode == "double_windowed": + qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) + qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) + h0 = sparse_windowed_scaled_dot_product_self_attention( + qkv0, self.window_size, shift_window=(0, 0, 0) + ) + h1 = sparse_windowed_scaled_dot_product_self_attention( + qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) + ) + h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=-3) + k = self.k_rms_norm(k) + h = sparse_scaled_dot_product_attention(q, k, v) + else: + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/trellis2/modules/sparse/attention/proj_attention.py b/trellis2/modules/sparse/attention/proj_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..185c1dbf8bb480b1cc2267002c655f878e667afd --- /dev/null +++ b/trellis2/modules/sparse/attention/proj_attention.py @@ -0,0 +1,99 @@ +""" +Sparse View-Aligned Projection Attention Module for TRELLIS2 + +Sparse versions of ProjectAttention and GatedProjectAttention. + +Supports two modes: +- "proj": Standard projection (DINOv3 only) +- "gated_proj": Gated fusion of DINOv3 (semantic) + VAE (color) features +""" + +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor, VarLenTensor + + +class SparseProjectAttention(nn.Module): + """ + Sparse Projection-based Attention Module with per-block proj_linear. + """ + def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int): + super().__init__() + self.cross_attn_block = cross_attn_block + self.proj_linear = nn.Linear(proj_in_channels, channels, bias=True) + + def forward( + self, + x: SparseTensor, + context: Union[Dict[str, Union[torch.Tensor, VarLenTensor, SparseTensor]], + Tuple[Union[torch.Tensor, VarLenTensor], SparseTensor]] + ) -> SparseTensor: + if isinstance(context, dict): + global_context = context['global'] + proj_context = context['proj'] + else: + global_context, proj_context = context + + global_out = self.cross_attn_block(x, global_context) + + if isinstance(proj_context, SparseTensor): + proj_feats = self.proj_linear(proj_context.feats) + combined_feats = proj_feats + global_out.feats + else: + proj_feats = self.proj_linear(proj_context) + combined_feats = proj_feats + global_out.feats + + return global_out.replace(combined_feats) + + +class SparseGatedProjectAttention(nn.Module): + """ + Sparse Concat-Projection Attention Module for DINOv3 + VAE features. + + Concatenates DINOv3 and VAE projected features and applies a single linear + projection to model_channels. Zero-initialized for stable training. + + Context dict must contain: + - 'global': Global image features for cross-attention + - 'proj_semantic': DINOv3 projected features (SparseTensor or Tensor) + - 'proj_color': VAE projected features (SparseTensor or Tensor) + """ + def __init__( + self, + cross_attn_block: nn.Module, + channels: int, + dino_in_channels: int, + vae_in_channels: int, + ): + super().__init__() + self.cross_attn_block = cross_attn_block + self.proj_linear = nn.Linear(dino_in_channels + vae_in_channels, channels, bias=True) + # Zero-init: at start, fused=0, only global cross-attn contributes + nn.init.zeros_(self.proj_linear.weight) + nn.init.zeros_(self.proj_linear.bias) + + def _get_feats(self, t): + return t.feats if isinstance(t, SparseTensor) else t + + def forward( + self, + x: SparseTensor, + context: Union[Dict[str, Union[torch.Tensor, VarLenTensor, SparseTensor]], Tuple], + ) -> SparseTensor: + if isinstance(context, dict): + global_context = context['global'] + proj_semantic = context['proj_semantic'] + proj_color = context['proj_color'] + else: + global_context, proj_semantic, proj_color = context + + global_out = self.cross_attn_block(x, global_context) + + fused = self.proj_linear(torch.cat([ + self._get_feats(proj_semantic), + self._get_feats(proj_color), + ], dim=-1)) + combined_feats = fused + global_out.feats + + return global_out.replace(combined_feats) diff --git a/trellis2/modules/sparse/attention/rope.py b/trellis2/modules/sparse/attention/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..fb877291f3430f2cff4329a67b3592e0e3c3f137 --- /dev/null +++ b/trellis2/modules/sparse/attention/rope.py @@ -0,0 +1,58 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor + + +class SparseRotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (SparseTensor): [..., N, H, D] tensor of queries + k (SparseTensor): [..., N, H, D] tensor of keys + """ + assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" + phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' + phases = q.get_spatial_cache(phases_cache_name) + if phases is None: + coords = q.coords[..., 1:] + phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + q.register_spatial_cache(phases_cache_name, phases) + q_embed = q.replace(self._rotary_embedding(q.feats, phases)) + if k is None: + return q_embed + k_embed = k.replace(self._rotary_embedding(k.feats, phases)) + return q_embed, k_embed \ No newline at end of file diff --git a/trellis2/modules/sparse/attention/windowed_attn.py b/trellis2/modules/sparse/attention/windowed_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..675aa0870d98ea014206961ed947704bc6262970 --- /dev/null +++ b/trellis2/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,205 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import config + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', + 'sparse_windowed_scaled_dot_product_cross_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (torch.Tensor): Sequence lengths. + (dict): Attn func args. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)] + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + mask = seq_lens != 0 + seq_lens = seq_lens[mask] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + attn_func_args = { + 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + } + elif config.ATTN == 'flash_attn': + attn_func_args = { + 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'max_seqlen': torch.max(seq_lens) + } + elif config.ATTN == 'flash_attn_4': + attn_func_args = { + 'cu_seqlens_q': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'cu_seqlens_k': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'max_seqlen_q': torch.max(seq_lens), + 'max_seqlen_k': torch.max(seq_lens), + } + + return fwd_indices, bwd_indices, seq_lens, attn_func_args + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (SparseTensor): [N, *, H, C] sparse tensor containing the output features. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args)) + else: + fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if config.DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if config.DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) + + +def sparse_windowed_scaled_dot_product_cross_attention( + q: SparseTensor, + kv: SparseTensor, + q_window_size: int, + kv_window_size: int, + q_shift_window: Tuple[int, int, int] = (0, 0, 0), + kv_shift_window: Tuple[int, int, int] = (0, 0, 0), +) -> SparseTensor: + """ + Apply windowed scaled dot product cross attention to two sparse tensors. + + Args: + q (SparseTensor): [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor): [N, *, 2, H, C] sparse tensor containing Ks and Vs. + q_window_size (int): The window size to use for Qs. + kv_window_size (int): The window size to use for Ks and Vs. + q_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Qs. + kv_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Ks and Vs. + + Returns: + (SparseTensor): [N, *, H, C] sparse tensor containing the output features. + """ + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + + q_serialization_spatial_cache_name = f'windowed_attention_{q_window_size}_{q_shift_window}' + q_serialization_spatial_cache = q.get_spatial_cache(q_serialization_spatial_cache_name) + if q_serialization_spatial_cache is None: + q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = calc_window_partition(q, q_window_size, q_shift_window) + q.register_spatial_cache(q_serialization_spatial_cache_name, (q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args)) + else: + q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = q_serialization_spatial_cache + kv_serialization_spatial_cache_name = f'windowed_attention_{kv_window_size}_{kv_shift_window}' + kv_serialization_spatial_cache = kv.get_spatial_cache(kv_serialization_spatial_cache_name) + if kv_serialization_spatial_cache is None: + kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = calc_window_partition(kv, kv_window_size, kv_shift_window) + kv.register_spatial_cache(kv_serialization_spatial_cache_name, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args)) + else: + kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = kv_serialization_spatial_cache + + assert len(q_seq_lens) == len(kv_seq_lens), "Number of sequences in q and kv must match" + + q_feats = q.feats[q_fwd_indices] # [M, H, C] + kv_feats = kv.feats[kv_fwd_indices] # [M, 2, H, C] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + k, v = kv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens) + out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)[0] # [M, H, C] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats, kv_feats, + cu_seqlens_q=q_attn_func_args['cu_seqlens'], cu_seqlens_k=kv_attn_func_args['cu_seqlens'], + max_seqlen_q=q_attn_func_args['max_seqlen'], max_seqlen_k=kv_attn_func_args['max_seqlen'], + ) # [M, H, C] + elif config.ATTN == 'flash_attn_4': + if 'flash_attn_4' not in globals(): + from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func + k, v = kv_feats.unbind(dim=1) + out, _ = flash_attn_4_varlen_func(q_feats, k, v, + cu_seqlens_q=q_attn_func_args['cu_seqlens_q'], cu_seqlens_k=kv_attn_func_args['cu_seqlens_k'], + max_seqlen_q=q_attn_func_args['max_seqlen_q'], max_seqlen_k=kv_attn_func_args['max_seqlen_k'], + ) # [M, H, C] + + out = out[q_bwd_indices] # [T, H, C] + + return q.replace(out) diff --git a/trellis2/modules/sparse/basic.py b/trellis2/modules/sparse/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..880973b8dd6bdafcfca4ca7c529308d2ef2ad266 --- /dev/null +++ b/trellis2/modules/sparse/basic.py @@ -0,0 +1,836 @@ +from typing import * +from fractions import Fraction +import torch +from . import config + + +__all__ = [ + 'VarLenTensor', + 'varlen_cat', + 'varlen_unbind', + 'SparseTensor', + 'sparse_cat', + 'sparse_unbind', +] + + +class VarLenTensor: + """ + Sequential tensor with variable length. + + Args: + feats (torch.Tensor): Features of the varlen tensor. + layout (List[slice]): Layout of the varlen tensor for each batch + """ + def __init__(self, feats: torch.Tensor, layout: List[slice]=None): + self.feats = feats + self.layout = layout if layout is not None else [slice(0, feats.shape[0])] + self._cache = {} + + @staticmethod + def layout_from_seqlen(seqlen: list) -> List[slice]: + """ + Create a layout from a tensor of sequence lengths. + """ + layout = [] + start = 0 + for l in seqlen: + layout.append(slice(start, start + l)) + start += l + return layout + + @staticmethod + def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': + """ + Create a VarLenTensor from a list of tensors. + """ + feats = torch.cat(tensor_list, dim=0) + layout = [] + start = 0 + for tensor in tensor_list: + layout.append(slice(start, start + tensor.shape[0])) + start += tensor.shape[0] + return VarLenTensor(feats, layout) + + def to_tensor_list(self) -> List[torch.Tensor]: + """ + Convert a VarLenTensor to a list of tensors. + """ + tensor_list = [] + for s in self.layout: + tensor_list.append(self.feats[s]) + return tensor_list + + def __len__(self) -> int: + return len(self.layout) + + @property + def shape(self) -> torch.Size: + return torch.Size([len(self.layout), *self.feats.shape[1:]]) + + def dim(self) -> int: + return len(self.shape) + + @property + def ndim(self) -> int: + return self.dim() + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + if 'seqlen' not in self._cache: + self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + return self._cache['seqlen'] + + @property + def cum_seqlen(self) -> torch.LongTensor: + if 'cum_seqlen' not in self._cache: + self._cache['cum_seqlen'] = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + return self._cache['cum_seqlen'] + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + if 'batch_boardcast_map' not in self._cache: + self._cache['batch_boardcast_map'] = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + return self._cache['batch_boardcast_map'] + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + def to(self, *args, **kwargs) -> 'VarLenTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'VarLenTensor': + new_feats = self.feats.cpu() + return self.replace(new_feats) + + def cuda(self) -> 'VarLenTensor': + new_feats = self.feats.cuda() + return self.replace(new_feats) + + def half(self) -> 'VarLenTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'VarLenTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'VarLenTensor': + new_feats = self.feats.detach() + return self.replace(new_feats) + + def reshape(self, *shape) -> 'VarLenTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['VarLenTensor']: + return varlen_unbind(self, dim) + + def replace(self, feats: torch.Tensor) -> 'VarLenTensor': + new_tensor = VarLenTensor( + feats=feats, + layout=self.layout, + ) + new_tensor._cache = self._cache + return new_tensor + + def to_dense(self, max_length=None) -> torch.Tensor: + """ + Convert a VarLenTensor to a dense representation without for-loop. + + Returns: + dense (torch.Tensor): (N, L, C) dense tensor + mask (torch.BoolTensor): (N, L) mask indicating valid positions + """ + N = len(self) + L = max_length or self.seqlen.max().item() + spatial = self.feats.shape[1:] + idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L) + mask = (idx < self.seqlen.unsqueeze(1)) + mapping = mask.reshape(-1).cumsum(dim=0) - 1 + dense = self.feats[mapping] + dense = dense.reshape(N, L, *spatial) + return dense, mask + + def __neg__(self) -> 'VarLenTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_feats = [] + new_layout = [] + start = 0 + for new_idx, old_idx in enumerate(idx): + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_feats[-1]))) + start += len(new_feats[-1]) + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) + return new_tensor + + def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + if isinstance(dim, int): + dim = (dim,) + + if op =='mean': + red = self.feats.mean(dim=dim, keepdim=keepdim) + elif op =='sum': + red = self.feats.sum(dim=dim, keepdim=keepdim) + elif op == 'prod': + red = self.feats.prod(dim=dim, keepdim=keepdim) + else: + raise ValueError(f"Unsupported reduce operation: {op}") + + if dim is None or 0 in dim: + return red + + red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) + return red + + def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='mean', dim=dim, keepdim=keepdim) + + def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='sum', dim=dim, keepdim=keepdim) + + def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='prod', dim=dim, keepdim=keepdim) + + def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + mean = self.mean(dim=dim, keepdim=True) + mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) + std = (mean2 - mean ** 2).sqrt() + return std + + def __repr__(self) -> str: + return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + + +def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor: + """ + Concatenate a list of varlen tensors. + + Args: + inputs (List[VarLenTensor]): List of varlen tensors to concatenate. + """ + if dim == 0: + new_feats = torch.cat([input.feats for input in inputs], dim=0) + start = 0 + new_layout = [] + for input in inputs: + for l in input.layout: + new_layout.append(slice(start, start + l.stop - l.start)) + start += l.stop - l.start + output = VarLenTensor(feats=new_feats, layout=new_layout) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: + """ + Unbind a varlen tensor along a dimension. + + Args: + input (VarLenTensor): Varlen tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(len(input))] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + + +class SparseTensor(VarLenTensor): + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + SparseTensorData = None + + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + if self.SparseTensorData is None: + import importlib + if config.CONV == 'torchsparse': + self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif config.CONV == 'spconv': + self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape = args + (None,) * (3 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + if config.CONV == 'torchsparse': + self.data = self.SparseTensorData(feats, coords, **kwargs) + elif config.CONV == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1) + self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs) + self.data._features = feats + else: + self.data = { + 'feats': feats, + 'coords': coords, + } + elif method_id == 1: + data, shape = args + (None,) * (2 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + self.data = data + + self._shape = shape + self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1))) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if config.DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + @staticmethod + def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor': + """ + Create a SparseTensor from a list of tensors. + """ + feats = torch.cat(feats_list, dim=0) + coords = [] + for i, coord in enumerate(coords_list): + coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1) + coords.append(coord) + coords = torch.cat(coords, dim=0) + return SparseTensor(feats, coords) + + def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Convert a SparseTensor to list of tensors. + """ + feats_list = [] + coords_list = [] + for s in self.layout: + feats_list.append(self.feats[s]) + coords_list.append(self.coords[s]) + return feats_list, coords_list + + def __len__(self) -> int: + return len(self.layout) + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + def __cal_spatial_shape(self, coords): + return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist()) + + @property + def shape(self) -> torch.Size: + if self._shape is None: + self._shape = self.__cal_shape(self.feats, self.coords) + return self._shape + + @property + def layout(self) -> List[slice]: + layout = self.get_spatial_cache('layout') + if layout is None: + layout = self.__cal_layout(self.coords, self.shape[0]) + self.register_spatial_cache('layout', layout) + return layout + + @property + def spatial_shape(self) -> torch.Size: + spatial_shape = self.get_spatial_cache('shape') + if spatial_shape is None: + spatial_shape = self.__cal_spatial_shape(self.coords) + self.register_spatial_cache('shape', spatial_shape) + return spatial_shape + + @property + def feats(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.F + elif config.CONV == 'spconv': + return self.data.features + else: + return self.data['feats'] + + @feats.setter + def feats(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.F = value + elif config.CONV == 'spconv': + self.data.features = value + else: + self.data['feats'] = value + + @property + def coords(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.C + elif config.CONV == 'spconv': + return self.data.indices + else: + return self.data['coords'] + + @coords.setter + def coords(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.C = value + elif config.CONV == 'spconv': + self.data.indices = value + else: + self.data['coords'] = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + seqlen = self.get_spatial_cache('seqlen') + if seqlen is None: + seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + self.register_spatial_cache('seqlen', seqlen) + return seqlen + + @property + def cum_seqlen(self) -> torch.LongTensor: + cum_seqlen = self.get_spatial_cache('cum_seqlen') + if cum_seqlen is None: + cum_seqlen = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + self.register_spatial_cache('cum_seqlen', cum_seqlen) + return cum_seqlen + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map') + if batch_boardcast_map is None: + batch_boardcast_map = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map) + return batch_boardcast_map + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + if config.CONV == 'torchsparse': + new_data = self.SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif config.CONV == 'spconv': + new_data = self.SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + else: + new_data = { + 'feats': feats, + 'coords': self.data['coords'] if coords is None else coords, + } + new_tensor = SparseTensor( + new_data, + shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None, + scale=self._scale, + spatial_cache=self._spatial_cache + ) + return new_tensor + + def to_dense(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.dense() + elif config.CONV == 'spconv': + return self.data.dense() + else: + spatial_shape = self.spatial_shape + ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device) + idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1) + ret[tuple(idx)] = self.feats + return ret + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_coords = [] + new_feats = [] + new_layout = [] + new_shape = torch.Size([len(idx)] + list(self.shape[1:])) + start = 0 + for new_idx, old_idx in enumerate(idx): + new_coords.append(self.coords[self.layout[old_idx]].clone()) + new_coords[-1][:, 0] = new_idx + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_coords[-1]))) + start += len(new_coords[-1]) + new_coords = torch.cat(new_coords, dim=0).contiguous() + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape) + new_tensor.register_spatial_cache('layout', new_layout) + return new_tensor + + def clear_spatial_cache(self) -> None: + """ + Clear all spatial caches. + """ + self._spatial_cache = {} + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + def __repr__(self) -> str: + return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/trellis2/modules/sparse/config.py b/trellis2/modules/sparse/config.py new file mode 100644 index 0000000000000000000000000000000000000000..74e995366efaa6413a2ab64fe521a8eda119508a --- /dev/null +++ b/trellis2/modules/sparse/config.py @@ -0,0 +1,43 @@ +from typing import * + +CONV = 'flex_gemm' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global CONV + global DEBUG + global ATTN + + env_sparse_conv_backend = os.environ.get('SPARSE_CONV_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn_backend = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn_backend is None: + env_sparse_attn_backend = os.environ.get('ATTN_BACKEND') + + if env_sparse_conv_backend is not None and env_sparse_conv_backend in ['none', 'spconv', 'torchsparse', 'flex_gemm']: + CONV = env_sparse_conv_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn_backend is not None and env_sparse_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'flash_attn_4']: + ATTN = env_sparse_attn_backend + + print(f"[SPARSE] Conv backend: {CONV}; Attention backend: {ATTN}") + + +__from_env() + + +def set_conv_backend(backend: Literal['none', 'spconv', 'torchsparse', 'flex_gemm']): + global CONV + CONV = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn_backend(backend: Literal['xformers', 'flash_attn', 'flash_attn_3', 'flash_attn_4']): + global ATTN + ATTN = backend diff --git a/trellis2/modules/sparse/conv/__init__.py b/trellis2/modules/sparse/conv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f5911f2cb266f93e52cfcdea9f63f39be172c6 --- /dev/null +++ b/trellis2/modules/sparse/conv/__init__.py @@ -0,0 +1,2 @@ +from .conv import SparseConv3d, SparseInverseConv3d +from . import config diff --git a/trellis2/modules/sparse/conv/config.py b/trellis2/modules/sparse/conv/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0848906703e7811300235e32d14e50ad5aac51 --- /dev/null +++ b/trellis2/modules/sparse/conv/config.py @@ -0,0 +1,3 @@ +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' +FLEX_GEMM_ALGO = 'masked_implicit_gemm_splitk' # 'explicit_gemm', 'implicit_gemm', 'implicit_gemm_splitk', 'masked_implicit_gemm', 'masked_implicit_gemm_splitk' +FLEX_GEMM_HASHMAP_RATIO = 2.0 # Ratio of hashmap size to input size diff --git a/trellis2/modules/sparse/conv/conv.py b/trellis2/modules/sparse/conv/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7d40707a24e26ba2a11a90c41dbc9eb11e7ab2 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv.py @@ -0,0 +1,30 @@ +from .. import config +import importlib +import torch +import torch.nn as nn +from .. import SparseTensor + + +_backends = {} + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if config.CONV not in _backends: + _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__) + _backends[config.CONV].sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key) + + def forward(self, x: SparseTensor) -> SparseTensor: + return _backends[config.CONV].sparse_conv3d_forward(self, x) + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if config.CONV not in _backends: + _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__) + _backends[config.CONV].sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key) + + def forward(self, x: SparseTensor) -> SparseTensor: + return _backends[config.CONV].sparse_inverse_conv3d_forward(self, x) diff --git a/trellis2/modules/sparse/conv/conv_flex_gemm.py b/trellis2/modules/sparse/conv/conv_flex_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..d25619475e4bf39307f97d47b3828b19c48cd7da --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_flex_gemm.py @@ -0,0 +1,68 @@ +import math +import torch +import torch.nn as nn +from .. import SparseTensor +from . import config +import flex_gemm +from flex_gemm.ops.spconv import sparse_submanifold_conv3d + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + assert stride == 1 and (padding is None), 'Currently flex_gemm implementation only support submanifold sparse convolution (stride=1, padding=None)' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3 + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3 + self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3 + + self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size))) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + # initialize parameters + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci) + self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous()) + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + flex_gemm.ops.spconv.set_algorithm(config.FLEX_GEMM_ALGO) + flex_gemm.ops.spconv.set_hashmap_ratio(config.FLEX_GEMM_HASHMAP_RATIO) + + # check if neighbor map is already computed + Co, Kd, Kh, Kw, Ci = self.weight.shape + neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' + neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + + out, neighbor_cache_ = sparse_submanifold_conv3d( + x.feats, + x.coords, + torch.Size([*x.shape, *x.spatial_shape]), + self.weight, + self.bias, + neighbor_cache, + self.dilation + ) + + if neighbor_cache is None: + x.register_spatial_cache(neighbor_cache_key, neighbor_cache_) + + out = x.replace(out) + return out + + +def sparse_inverse_conv3d_init(self, *args, **kwargs): + raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet') + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet') diff --git a/trellis2/modules/sparse/conv/conv_spconv.py b/trellis2/modules/sparse/conv/conv_spconv.py new file mode 100644 index 0000000000000000000000000000000000000000..f709708d4f8ee3bb98930c68b28e7f3b897fc591 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from . import config +import spconv.pytorch as spconv + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + algo = None + if config.SPCONV_ALGO == 'native': + algo = spconv.ConvAlgo.Native + elif config.SPCONV_ALGO == 'implicit_gemm': + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features[fwd] + sorted_coords = new_data.indices[fwd] + unsorted_data = new_data + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/trellis2/modules/sparse/conv/conv_torchsparse.py b/trellis2/modules/sparse/conv/conv_torchsparse.py new file mode 100644 index 0000000000000000000000000000000000000000..5234bd15553aa8d71df280672475a898ffe56af7 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +import torchsparse + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + +def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s / stride for s, stride in zip(x._scale, self.conv.stride)]) + return out diff --git a/trellis2/modules/sparse/linear.py b/trellis2/modules/sparse/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..44317709ab16b17fa0132fd48e48519ea0ef9ea9 --- /dev/null +++ b/trellis2/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . import VarLenTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) diff --git a/trellis2/modules/sparse/nonlinearity.py b/trellis2/modules/sparse/nonlinearity.py new file mode 100644 index 0000000000000000000000000000000000000000..950e5c03c997905162e39fec701db49a2640700c --- /dev/null +++ b/trellis2/modules/sparse/nonlinearity.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from . import VarLenTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' +] + + +class SparseReLU(nn.ReLU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(self.activation(input.feats)) + diff --git a/trellis2/modules/sparse/norm.py b/trellis2/modules/sparse/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..95711203f0adbae1c7ea845e2500a3823997d652 --- /dev/null +++ b/trellis2/modules/sparse/norm.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +from ..utils import manual_cast +from . import VarLenTensor +from . import config + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: VarLenTensor) -> VarLenTensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: VarLenTensor) -> VarLenTensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) diff --git a/trellis2/modules/sparse/spatial/__init__.py b/trellis2/modules/sparse/spatial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e27425f165d271fd16a2c6f7b7684d4a81202ebd --- /dev/null +++ b/trellis2/modules/sparse/spatial/__init__.py @@ -0,0 +1,2 @@ +from .basic import * +from .spatial2channel import * diff --git a/trellis2/modules/sparse/spatial/basic.py b/trellis2/modules/sparse/spatial/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..eaeb8afefd889d8a94812c579383e528b8b56699 --- /dev/null +++ b/trellis2/modules/sparse/spatial/basic.py @@ -0,0 +1,109 @@ +from typing import * +import torch +import torch.nn as nn +from .. import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. + """ + def __init__(self, factor: int, mode: Literal['mean', 'max'] = 'mean'): + super(SparseDownsample, self).__init__() + self.factor = factor + self.mode = mode + assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}' + + def forward(self, x: SparseTensor) -> SparseTensor: + cache = x.get_spatial_cache(f'downsample_{self.factor}') + if cache is None: + DIM = x.coords.shape[-1] - 1 + + coord = list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx = cache + + new_feats = torch.scatter_reduce( + torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype), + dim=0, + index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]), + src=x.feats, + reduce=self.mode, + include_self=False, + ) + out = SparseTensor(new_feats, new_coords, x._shape) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx)) + out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx)) + out.register_spatial_cache(f'shape', torch.Size(MAX)) + if self.training: + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool) + subdivision[idx, subidx] = True + out.register_spatial_cache(f'subdivision', subdivision) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__( + self, factor: int + ): + super(SparseUpsample, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'upsample_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. Provide subdivision tensor or pair SparseUpsample with SparseDownsample.') + else: + sub = subdivision.feats + N_leaf = sub.sum(dim=-1) + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx = cache + + new_feats = x.feats[idx] + out = SparseTensor(new_feats, new_coords, x._shape) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + + return out + \ No newline at end of file diff --git a/trellis2/modules/sparse/spatial/spatial2channel.py b/trellis2/modules/sparse/spatial/spatial2channel.py new file mode 100644 index 0000000000000000000000000000000000000000..577f36d208726f64422f8774c3556a1d643f1e2d --- /dev/null +++ b/trellis2/modules/sparse/spatial/spatial2channel.py @@ -0,0 +1,93 @@ +from typing import * +import torch +import torch.nn as nn +from .. import SparseTensor + + +class SparseSpatial2Channel(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as rearranging its features from spatial to channel. + """ + def __init__(self, factor: int = 2): + super(SparseSpatial2Channel, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + cache = x.get_spatial_cache(f'spatial2channel_{self.factor}') + if cache is None: + coord = list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx, subidx = cache + + new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype) + new_feats[idx * self.factor ** DIM + subidx] = x.feats + + out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) + out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx)) + out.register_spatial_cache(f'shape', torch.Size(MAX)) + if self.training: + subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool) + subdivision[idx, subidx] = True + out.register_spatial_cache(f'subdivision', subdivision) + + return out + + +class SparseChannel2Spatial(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as rearranging its features from channel to spatial. + """ + def __init__(self, factor: int = 2): + super(SparseChannel2Spatial, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'channel2spatial_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.') + else: + sub = subdivision.feats # [N, self.factor ** DIM] + N_leaf = sub.sum(dim=-1) # [N] + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx, subidx = cache + + x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1) + new_feats = x_feats[idx * self.factor ** DIM + subidx] + out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM])) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + return out diff --git a/trellis2/modules/sparse/transformer/__init__.py b/trellis2/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/trellis2/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis2/modules/sparse/transformer/blocks.py b/trellis2/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1ec600404fba490894872109d44be6b6477186 --- /dev/null +++ b/trellis2/modules/sparse/transformer/blocks.py @@ -0,0 +1,145 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import VarLenTensor, SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: VarLenTensor) -> VarLenTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) diff --git a/trellis2/modules/sparse/transformer/modulated.py b/trellis2/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8fa56513da50906b3ba503c6cdce52441cc99e --- /dev/null +++ b/trellis2/modules/sparse/transformer/modulated.py @@ -0,0 +1,205 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import VarLenTensor, SparseTensor +from ..attention import SparseMultiHeadAttention, SparseProjectAttention, SparseGatedProjectAttention +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + + Supports two image attention modes: + - "cross": Standard cross-attention with image features + - "proj": Projection-based attention with view-aligned features + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross", + proj_in_channels: Optional[int] = None, + vae_in_channels: Optional[int] = None, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.image_attn_mode = image_attn_mode + + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + + # Build cross attention based on mode + if image_attn_mode == "cross": + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + elif image_attn_mode == "proj": + _proj_in = proj_in_channels if proj_in_channels is not None else ctx_channels + cross_attn_block = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.cross_attn = SparseProjectAttention(cross_attn_block, channels, _proj_in) + elif image_attn_mode == "gated_proj": + _dino_in = proj_in_channels if proj_in_channels is not None else ctx_channels + _vae_in = vae_in_channels if vae_in_channels is not None else 16 + cross_attn_block = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.cross_attn = SparseGatedProjectAttention(cross_attn_block, channels, _dino_in, _vae_in) + else: + raise ValueError(f"Unknown image attention mode: {image_attn_mode}") + + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/trellis2/modules/spatial.py b/trellis2/modules/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3 --- /dev/null +++ b/trellis2/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/trellis2/modules/transformer/__init__.py b/trellis2/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/trellis2/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis2/modules/transformer/blocks.py b/trellis2/modules/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6f5eb5462fec62aa5edc062104f643fca03bfa --- /dev/null +++ b/trellis2/modules/transformer/blocks.py @@ -0,0 +1,186 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. + """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = True, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h, phases=phases) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, phases, use_reentrant=False) + else: + return self._forward(x, phases) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.self_attn(h, phases=phases) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, phases, use_reentrant=False) + else: + return self._forward(x, context, phases) + \ No newline at end of file diff --git a/trellis2/modules/transformer/modulated.py b/trellis2/modules/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..4444b208c5d7559029b0908eb8519282c512973c --- /dev/null +++ b/trellis2/modules/transformer/modulated.py @@ -0,0 +1,205 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention, ProjectAttention, GatedProjectAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, phases, use_reentrant=False) + else: + return self._forward(x, mod, phases) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + + Supports two image attention modes: + - "cross": Standard cross-attention with image features + - "proj": Projection-based attention with view-aligned features + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross", + proj_in_channels: Optional[int] = None, + vae_in_channels: Optional[int] = None, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.image_attn_mode = image_attn_mode + + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + + # Build cross attention based on mode + if image_attn_mode == "cross": + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + elif image_attn_mode == "proj": + _proj_in = proj_in_channels if proj_in_channels is not None else ctx_channels + cross_attn_block = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.cross_attn = ProjectAttention(cross_attn_block, channels, _proj_in) + elif image_attn_mode == "gated_proj": + _dino_in = proj_in_channels if proj_in_channels is not None else ctx_channels + _vae_in = vae_in_channels if vae_in_channels is not None else 16 + cross_attn_block = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.cross_attn = GatedProjectAttention(cross_attn_block, channels, _dino_in, _vae_in) + else: + raise ValueError(f"Unknown image attention mode: {image_attn_mode}") + + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) + else: + return self._forward(x, mod, context, phases) + \ No newline at end of file diff --git a/trellis2/modules/utils.py b/trellis2/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5d92d7d5b0ab07972fb6e0397c9c8e525e196211 --- /dev/null +++ b/trellis2/modules/utils.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +from ..modules import sparse as sp + +MIX_PRECISION_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def convert_module_to(l, dtype): + """ + Convert primitive modules to the given dtype. + """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.to(dtype) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def manual_cast(tensor, dtype): + """ + Cast if autocast is not enabled. + """ + if not torch.is_autocast_enabled(): + return tensor.type(dtype) + return tensor + + +def str_to_dtype(dtype_str: str): + return { + 'f16': torch.float16, + 'fp16': torch.float16, + 'float16': torch.float16, + 'bf16': torch.bfloat16, + 'bfloat16': torch.bfloat16, + 'f32': torch.float32, + 'fp32': torch.float32, + 'float32': torch.float32, + }[dtype_str] diff --git a/trellis2/pipelines/__init__.py b/trellis2/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..807eae0d16c6bce6d4b9dfdf8ae60ada8e43e180 --- /dev/null +++ b/trellis2/pipelines/__init__.py @@ -0,0 +1,54 @@ +import importlib + +__attributes = { + "Trellis2ImageTo3DPipeline": "trellis2_image_to_3d", + "Trellis2TexturingPipeline": "trellis2_texturing", + "Pixal3DImageTo3DPipeline": "pixal3d_image_to_3d", +} + +__submodules = ['samplers', 'rembg'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str): + """ + Load a pipeline from a model folder or a Hugging Face model hub. + + Args: + path: The path to the model. Can be either local path or a Hugging Face model name. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + config = json.load(f) + return globals()[config['name']].from_pretrained(path) + + +# For PyLance +if __name__ == '__main__': + from . import samplers, rembg + from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline + from .trellis2_texturing import Trellis2TexturingPipeline + from .pixal3d_image_to_3d import Pixal3DImageTo3DPipeline diff --git a/trellis2/pipelines/base.py b/trellis2/pipelines/base.py new file mode 100644 index 0000000000000000000000000000000000000000..331e1ed979b47b23ef1d45c8851658df5d49fb69 --- /dev/null +++ b/trellis2/pipelines/base.py @@ -0,0 +1,72 @@ +from typing import * +import torch +import torch.nn as nn +from .. import models + + +class Pipeline: + """ + A base class for pipelines. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + ): + if models is None: + return + self.models = models + for model in self.models.values(): + model.eval() + + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pipeline": + """ + Load a pretrained model. + """ + import os + import json + is_local = os.path.exists(f"{path}/{config_file}") + + if is_local: + config_file = f"{path}/{config_file}" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, config_file) + + with open(config_file, 'r') as f: + args = json.load(f)['args'] + + _models = {} + for k, v in args['models'].items(): + if hasattr(cls, 'model_names_to_load') and k not in cls.model_names_to_load: + continue + try: + _models[k] = models.from_pretrained(f"{path}/{v}") + except Exception as e: + _models[k] = models.from_pretrained(v) + + new_pipeline = cls(_models) + new_pipeline._pretrained_args = args + return new_pipeline + + @property + def device(self) -> torch.device: + if hasattr(self, '_device'): + return self._device + for model in self.models.values(): + if hasattr(model, 'device'): + return model.device + for model in self.models.values(): + if hasattr(model, 'parameters'): + return next(model.parameters()).device + raise RuntimeError("No device found.") + + def to(self, device: torch.device) -> None: + for model in self.models.values(): + model.to(device) + + def cuda(self) -> None: + self.to(torch.device("cuda")) + + def cpu(self) -> None: + self.to(torch.device("cpu")) \ No newline at end of file diff --git a/trellis2/pipelines/pixal3d_image_to_3d.py b/trellis2/pipelines/pixal3d_image_to_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1826c8d791411f13d22fc304e7b5dd379b3d01 --- /dev/null +++ b/trellis2/pipelines/pixal3d_image_to_3d.py @@ -0,0 +1,783 @@ +from typing import * +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +from .base import Pipeline +from . import samplers, rembg +from ..modules.sparse import SparseTensor +from ..modules import image_feature_extractor +from ..representations import Mesh, MeshWithVoxel + + +class Pixal3DImageTo3DPipeline(Pipeline): + """ + Pipeline for inferring Pixal3D (proj mode) image-to-3D models. + + 基于 Trellis2 pipeline,使用 proj 模式进行推理。 + 每个 stage (SS, Shape 512, Shape 1024, Tex 1024) 有独立的 image_cond_model (DinoV3ProjFeatureExtractor)。 + 条件构建使用 camera-aware projection(需要 camera_angle_x, distance, mesh_scale 参数)。 + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + shape_slat_sampler (samplers.Sampler): The sampler for the structured latent. + tex_slat_sampler (samplers.Sampler): The sampler for the texture latent. + sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): The parameters for the structured latent sampler. + tex_slat_sampler_params (dict): The parameters for the texture latent sampler. + shape_slat_normalization (dict): The normalization parameters for the structured latent. + tex_slat_normalization (dict): The normalization parameters for the texture latent. + image_cond_model_ss (nn.Module): Proj image cond model for sparse structure stage. + image_cond_model_shape_512 (nn.Module): Proj image cond model for shape LR (512) stage. + image_cond_model_shape_1024 (nn.Module): Proj image cond model for shape HR (1024) stage. + image_cond_model_tex_1024 (nn.Module): Proj image cond model for texture (1024) stage. + rembg_model (Callable): The model for removing background. + low_vram (bool): Whether to use low-VRAM mode. + """ + model_names_to_load = [ + 'sparse_structure_flow_model', + 'sparse_structure_decoder', + 'shape_slat_flow_model_512', + 'shape_slat_flow_model_1024', + 'shape_slat_decoder', + 'tex_slat_flow_model_512', + 'tex_slat_flow_model_1024', + 'tex_slat_decoder', + ] + + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + shape_slat_sampler: samplers.Sampler = None, + tex_slat_sampler: samplers.Sampler = None, + sparse_structure_sampler_params: dict = None, + shape_slat_sampler_params: dict = None, + tex_slat_sampler_params: dict = None, + shape_slat_normalization: dict = None, + tex_slat_normalization: dict = None, + image_cond_model_ss: nn.Module = None, + image_cond_model_shape_512: nn.Module = None, + image_cond_model_shape_1024: nn.Module = None, + image_cond_model_tex_1024: nn.Module = None, + rembg_model: Callable = None, + low_vram: bool = True, + default_pipeline_type: str = '1024_cascade', + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.shape_slat_sampler = shape_slat_sampler + self.tex_slat_sampler = tex_slat_sampler + self.sparse_structure_sampler_params = sparse_structure_sampler_params + self.shape_slat_sampler_params = shape_slat_sampler_params + self.tex_slat_sampler_params = tex_slat_sampler_params + self.shape_slat_normalization = shape_slat_normalization + self.tex_slat_normalization = tex_slat_normalization + self.image_cond_model_ss = image_cond_model_ss + self.image_cond_model_shape_512 = image_cond_model_shape_512 + self.image_cond_model_shape_1024 = image_cond_model_shape_1024 + self.image_cond_model_tex_1024 = image_cond_model_tex_1024 + self.rembg_model = rembg_model + self.low_vram = low_vram + self.default_pipeline_type = default_pipeline_type + self.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + self._device = 'cpu' + + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pixal3DImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super().from_pretrained(path, config_file) + args = pipeline._pretrained_args + + pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args']) + pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params'] + + pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + + pipeline.shape_slat_normalization = args['shape_slat_normalization'] + pipeline.tex_slat_normalization = args['tex_slat_normalization'] + + # Proj mode: image_cond_models 需要外部加载后设置,这里先置为 None + pipeline.image_cond_model_ss = None + pipeline.image_cond_model_shape_512 = None + pipeline.image_cond_model_shape_1024 = None + pipeline.image_cond_model_tex_1024 = None + + pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + + pipeline.low_vram = args.get('low_vram', True) + pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade') + pipeline.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + pipeline._device = 'cpu' + + return pipeline + + def to(self, device: torch.device) -> None: + self._device = device + if not self.low_vram: + super().to(device) + if self.rembg_model is not None: + self.rembg_model.to(device) + + def preprocess_image(self, input: Image.Image, bg_color: tuple = (0, 0, 0)) -> Image.Image: + """ + Preprocess the input image. + + Args: + input: Input image (RGB or RGBA). + bg_color: Background color (R, G, B) in 0~255. Default black (0,0,0). + """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if has_alpha: + output = input + else: + input = input.convert('RGB') + if self.low_vram: + self.rembg_model.to(self.device) + output = self.rembg_model(input) + if self.low_vram: + self.rembg_model.cpu() + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1.1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + rgb = output[:, :, :3] + a = output[:, :, 3:4] + bg = np.array(bg_color, dtype=np.float32) / 255.0 + output = rgb * a + bg * (1.0 - a) + output = Image.fromarray((np.clip(output, 0, 1) * 255).astype(np.uint8)) + return output + + # ========================================================================= + # Proj 模式条件构建 + # ========================================================================= + + @torch.no_grad() + def get_proj_cond_ss( + self, + image: list, + camera_angle_x: float = 0.8575560450553894, + distance: float = 2.0, + mesh_scale: float = 1.0, + ) -> dict: + """ + Get proj conditioning for sparse structure stage. + + Args: + image: List of PIL images. + camera_angle_x: Camera horizontal FOV in radians. + distance: Camera distance. + mesh_scale: Mesh scale. + + Returns: + dict with 'cond' and 'neg_cond', each containing {'global': ..., 'proj': ...} + """ + device = self.device + image_cond_model = self.image_cond_model_ss + if self.low_vram: + image_cond_model.to(device) + cam_angle = torch.tensor([camera_angle_x], device=device) + dist_tensor = torch.tensor([distance], device=device) + scale_tensor = torch.tensor([mesh_scale], device=device) + z_global, z_proj = image_cond_model( + image, camera_angle_x=cam_angle, distance=dist_tensor, mesh_scale=scale_tensor, + ) + if self.low_vram: + image_cond_model.cpu() + return { + 'cond': {'global': z_global, 'proj': z_proj}, + 'neg_cond': {'global': torch.zeros_like(z_global), 'proj': torch.zeros_like(z_proj)}, + } + + @torch.no_grad() + def get_proj_cond_shape( + self, + image_cond_model: nn.Module, + image: list, + coords: torch.Tensor, + camera_angle_x: float = 0.8575560450553894, + distance: float = 2.0, + mesh_scale: float = 1.0, + grid_resolution_override: int = None, + ) -> dict: + """ + Get proj conditioning for shape/texture stages (sparse-token aligned). + + Args: + image_cond_model: The proj image cond model for this stage. + image: List of PIL images. + coords: Sparse structure coordinates [N, 4] (batch_idx, x, y, z). + camera_angle_x: Camera horizontal FOV in radians. + distance: Camera distance. + mesh_scale: Mesh scale. + grid_resolution_override: Override the grid resolution if not None. + + Returns: + dict with 'cond' and 'neg_cond', each containing {'global': ..., 'proj': SparseTensor} + """ + device = self.device + if self.low_vram: + image_cond_model.to(device) + + orig_grid_res = image_cond_model.grid_resolution + if grid_resolution_override is not None and grid_resolution_override != orig_grid_res: + image_cond_model.grid_resolution = grid_resolution_override + image_cond_model.proj_grid = image_cond_model.proj_grid.__class__( + grid_resolution=grid_resolution_override, + image_resolution=image_cond_model.proj_grid.image_resolution, + ).to(device) + + B = 1 + cam_angle = torch.tensor([camera_angle_x], device=device) + dist_tensor = torch.tensor([distance], device=device) + scale_tensor = torch.tensor([mesh_scale], device=device) + z_global, z_proj = image_cond_model( + image, camera_angle_x=cam_angle, distance=dist_tensor, mesh_scale=scale_tensor, + ) + grid_res = image_cond_model.grid_resolution + z_proj_grid = z_proj.reshape(B, grid_res, grid_res, grid_res, -1) + batch_indices = coords[:, 0].long() + x_coords = coords[:, 1].long() + y_coords = coords[:, 2].long() + z_coords = coords[:, 3].long() + z_proj_sparse = z_proj_grid[batch_indices, x_coords, y_coords, z_coords] + z_proj_st = SparseTensor(feats=z_proj_sparse, coords=coords) + + if grid_resolution_override is not None and grid_resolution_override != orig_grid_res: + image_cond_model.grid_resolution = orig_grid_res + image_cond_model.proj_grid = image_cond_model.proj_grid.__class__( + grid_resolution=orig_grid_res, + image_resolution=image_cond_model.proj_grid.image_resolution, + ).to(device) + + if self.low_vram: + image_cond_model.cpu() + return { + 'cond': {'global': z_global, 'proj': z_proj_st}, + 'neg_cond': {'global': torch.zeros_like(z_global), 'proj': SparseTensor(feats=torch.zeros_like(z_proj_sparse), coords=coords)}, + } + + # ========================================================================= + # Sampling methods (保持与 Trellis2 一致) + # ========================================================================= + + def sample_sparse_structure( + self, + cond: dict, + resolution: int, + num_samples: int = 1, + sampler_params: dict = {}, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + resolution (int): The resolution of the sparse structure. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample sparse structure latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + in_channels = flow_model.in_channels + noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling sparse structure (proj)", + ).samples + if self.low_vram: + flow_model.cpu() + + # Decode sparse structure latent + decoder = self.models['sparse_structure_decoder'] + if self.low_vram: + decoder.to(self.device) + decoded = decoder(z_s)>0 + if self.low_vram: + decoder.cpu() + if resolution != decoded.shape[2]: + ratio = decoded.shape[2] // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 + coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() + + return coords + + def sample_shape_slat( + self, + cond: dict, + flow_model, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat (proj)", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def sample_shape_slat_cascade( + self, + lr_cond: dict, + cond: dict, + flow_model_lr, + flow_model, + lr_resolution: int, + resolution: int, + coords: torch.Tensor, + sampler_params: dict = {}, + max_num_tokens: int = 49152, + ) -> SparseTensor: + """ + Sample structured latent with cascade (LR → HR). + + Args: + lr_cond (dict): The conditioning information for LR stage. + cond (dict): The conditioning information for HR stage. + flow_model_lr: LR flow model. + flow_model: HR flow model. + lr_resolution (int): LR resolution. + resolution (int): Target HR resolution. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + max_num_tokens (int): Maximum number of tokens. + """ + # LR + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model_lr.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model_lr, + noise, + **lr_cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling LR shape SLat (proj, 512)", + ).samples + if self.low_vram: + flow_model_lr.cpu() + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + # Upsample + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + hr_resolution = resolution + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + coords = quant_coords.unique(dim=0) + num_tokens = coords.shape[0] + if num_tokens < max_num_tokens or hr_resolution == 1024: + if hr_resolution != resolution: + print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.") + break + hr_resolution -= 128 + + # Sample structured latent (HR) + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc=f"Sampling HR shape SLat (proj, {hr_resolution})", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat, hr_resolution + + def decode_shape_slat( + self, + slat: SparseTensor, + resolution: int, + ) -> Tuple[List[Mesh], List[SparseTensor]]: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + + Returns: + List[Mesh]: The decoded meshes. + List[SparseTensor]: The decoded substructures. + """ + self.models['shape_slat_decoder'].set_resolution(resolution) + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + ret = self.models['shape_slat_decoder'](slat, return_subs=True) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + return ret + + def sample_tex_slat( + self, + cond: dict, + flow_model, + shape_slat: SparseTensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample texture structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + shape_slat (SparseTensor): The structured latent for shape. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) + shape_slat = (shape_slat - mean) / std + + in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels + noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)) + sampler_params = {**self.tex_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.tex_slat_sampler.sample( + flow_model, + noise, + concat_cond=shape_slat, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling texture SLat (proj)", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def decode_tex_slat( + self, + slat: SparseTensor, + subs: List[SparseTensor], + ) -> SparseTensor: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + + Returns: + SparseTensor: The decoded texture voxels + """ + if self.low_vram: + self.models['tex_slat_decoder'].to(self.device) + ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5 + if self.low_vram: + self.models['tex_slat_decoder'].cpu() + return ret + + @torch.no_grad() + def decode_latent( + self, + shape_slat: SparseTensor, + tex_slat: SparseTensor, + resolution: int, + ) -> List[MeshWithVoxel]: + """ + Decode the latent codes. + + Args: + shape_slat (SparseTensor): The structured latent for shape. + tex_slat (SparseTensor): The structured latent for texture. + resolution (int): The resolution of the output. + """ + meshes, subs = self.decode_shape_slat(shape_slat, resolution) + tex_voxels = self.decode_tex_slat(tex_slat, subs) + out_mesh = [] + torch.cuda.synchronize() + for m, v in zip(meshes, tex_voxels): + m.fill_holes() + out_mesh.append( + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout=self.pbr_attr_layout + ) + ) + return out_mesh + + @torch.no_grad() + def run( + self, + image: Image.Image, + camera_params: dict, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + shape_slat_sampler_params: dict = {}, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + return_latent: bool = False, + pipeline_type: Optional[str] = None, + max_num_tokens: int = 49152, + ) -> List[MeshWithVoxel]: + """ + Run the Pixal3D pipeline (proj mode, cascade). + + Args: + image (Image.Image): The image prompt. + camera_params (dict): Camera parameters with keys: + - camera_angle_x (float): Horizontal FOV in radians. + - distance (float): Camera distance. + - mesh_scale (float): Mesh scale factor. + num_samples (int): The number of samples to generate. + seed (int): The random seed. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler. + tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler. + preprocess_image (bool): Whether to preprocess the image. + return_latent (bool): Whether to return the latent codes. + pipeline_type (str): The type of the pipeline. Options: '1024_cascade', '1536_cascade'. + max_num_tokens (int): The maximum number of tokens to use. + """ + # Check pipeline type + pipeline_type = pipeline_type or self.default_pipeline_type + if pipeline_type == '1024_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + hr_resolution = 1024 + elif pipeline_type == '1536_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + hr_resolution = 1536 + else: + raise ValueError(f"Invalid pipeline type for Pixal3D proj mode: {pipeline_type}. " + f"Supported: '1024_cascade', '1536_cascade'.") + + # Validate image_cond_models are set + assert self.image_cond_model_ss is not None, "image_cond_model_ss not set." + assert self.image_cond_model_shape_512 is not None, "image_cond_model_shape_512 not set." + assert self.image_cond_model_shape_1024 is not None, "image_cond_model_shape_1024 not set." + assert self.image_cond_model_tex_1024 is not None, "image_cond_model_tex_1024 not set." + + # Extract camera params + camera_angle_x = camera_params['camera_angle_x'] + distance = camera_params['distance'] + mesh_scale = camera_params.get('mesh_scale', 1.0) + + if preprocess_image: + image = self.preprocess_image(image) + torch.manual_seed(seed) + + # ---- Stage 1: Sparse Structure (proj) ---- + cond_ss = self.get_proj_cond_ss( + [image], + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + ) + ss_res = 32 + coords = self.sample_sparse_structure( + cond_ss, ss_res, + num_samples, sparse_structure_sampler_params + ) + del cond_ss + torch.cuda.empty_cache() + + # ---- Stage 2: Shape LR 512 (proj) ---- + cond_shape_lr = self.get_proj_cond_shape( + self.image_cond_model_shape_512, [image], coords, + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + ) + lr_slat = self.sample_shape_slat( + cond_shape_lr, self.models['shape_slat_flow_model_512'], + coords, shape_slat_sampler_params + ) + del cond_shape_lr + torch.cuda.empty_cache() + + # ---- Stage 3a: Upsample LR → HR ---- + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + hr_coords = self.models['shape_slat_decoder'].upsample(lr_slat, upsample_times=4) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + + lr_resolution = 512 + actual_hr_resolution = hr_resolution + while True: + grid_res = actual_hr_resolution // 16 + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (grid_res - 1)).round().int(), + ], dim=1) + hr_coords_unique = quant_coords.unique(dim=0) + num_tokens = hr_coords_unique.shape[0] + if num_tokens < max_num_tokens or actual_hr_resolution == 1024: + break + actual_hr_resolution -= 128 + + actual_grid_res = actual_hr_resolution // 16 + del lr_slat, hr_coords, quant_coords + torch.cuda.empty_cache() + + # ---- Stage 3b: Shape HR (proj) ---- + cond_shape_hr = self.get_proj_cond_shape( + self.image_cond_model_shape_1024, [image], hr_coords_unique, + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + grid_resolution_override=actual_grid_res, + ) + noise_hr = SparseTensor( + feats=torch.randn(hr_coords_unique.shape[0], self.models['shape_slat_flow_model_1024'].in_channels).to(self.device), + coords=hr_coords_unique, + ) + sampler_params_hr = {**self.shape_slat_sampler_params, **shape_slat_sampler_params} + flow_model_hr = self.models['shape_slat_flow_model_1024'] + if self.low_vram: + flow_model_hr.to(self.device) + hr_slat = self.shape_slat_sampler.sample( + flow_model_hr, + noise_hr, + **cond_shape_hr, + **sampler_params_hr, + verbose=True, + tqdm_desc=f"Sampling HR shape SLat (proj, {actual_hr_resolution})", + ).samples + if self.low_vram: + flow_model_hr.cpu() + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(hr_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(hr_slat.device) + shape_slat = hr_slat * std + mean + del cond_shape_hr, noise_hr, hr_slat, hr_coords_unique + torch.cuda.empty_cache() + + # ---- Stage 4: Texture (proj) ---- + tex_grid_res = actual_hr_resolution // 16 + cond_tex = self.get_proj_cond_shape( + self.image_cond_model_tex_1024, [image], shape_slat.coords, + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + grid_resolution_override=tex_grid_res, + ) + tex_slat = self.sample_tex_slat( + cond_tex, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + del cond_tex + torch.cuda.empty_cache() + + # ---- Stage 5: Decode ---- + res = actual_hr_resolution + out_mesh = self.decode_latent(shape_slat, tex_slat, res) + if return_latent: + return out_mesh, (shape_slat, tex_slat, res) + else: + return out_mesh diff --git a/trellis2/pipelines/rembg/BiRefNet.py b/trellis2/pipelines/rembg/BiRefNet.py new file mode 100644 index 0000000000000000000000000000000000000000..c71a99274823aefe6f18ab921a5beb074177de18 --- /dev/null +++ b/trellis2/pipelines/rembg/BiRefNet.py @@ -0,0 +1,42 @@ +from typing import * +from transformers import AutoModelForImageSegmentation +import torch +from torchvision import transforms +from PIL import Image + + +class BiRefNet: + def __init__(self, model_name: str = "ZhengPeng7/BiRefNet"): + self.model = AutoModelForImageSegmentation.from_pretrained( + model_name, trust_remote_code=True + ) + self.model.eval() + self.transform_image = transforms.Compose( + [ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + + def to(self, device: str): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def __call__(self, image: Image.Image) -> Image.Image: + image_size = image.size + input_images = self.transform_image(image).unsqueeze(0).to("cuda") + # Prediction + with torch.no_grad(): + preds = self.model(input_images)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + image.putalpha(mask) + return image + \ No newline at end of file diff --git a/trellis2/pipelines/rembg/__init__.py b/trellis2/pipelines/rembg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1eed1ba0c962478fc48132432f2e649fe03411 --- /dev/null +++ b/trellis2/pipelines/rembg/__init__.py @@ -0,0 +1 @@ +from .BiRefNet import * diff --git a/trellis2/pipelines/samplers/__init__.py b/trellis2/pipelines/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a69b95e5963c6d3a837518ed9c3cf6972235d60 --- /dev/null +++ b/trellis2/pipelines/samplers/__init__.py @@ -0,0 +1,6 @@ +from .base import Sampler +from .flow_euler import ( + FlowEulerSampler, + FlowEulerCfgSampler, + FlowEulerGuidanceIntervalSampler, +) \ No newline at end of file diff --git a/trellis2/pipelines/samplers/base.py b/trellis2/pipelines/samplers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1966ce787009a5ee0c1ed06dce491525ff1dbcbf --- /dev/null +++ b/trellis2/pipelines/samplers/base.py @@ -0,0 +1,20 @@ +from typing import * +from abc import ABC, abstractmethod + + +class Sampler(ABC): + """ + A base class for samplers. + """ + + @abstractmethod + def sample( + self, + model, + **kwargs + ): + """ + Sample from a model. + """ + pass + \ No newline at end of file diff --git a/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7a4da4c1324bed5319f89916b532a159471d84 --- /dev/null +++ b/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py @@ -0,0 +1,29 @@ +from typing import * + + +class ClassifierFreeGuidanceSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, guidance_strength, guidance_rescale=0.0, **kwargs): + if guidance_strength == 1: + return super()._inference_model(model, x_t, t, cond, **kwargs) + elif guidance_strength == 0: + return super()._inference_model(model, x_t, t, neg_cond, **kwargs) + else: + pred_pos = super()._inference_model(model, x_t, t, cond, **kwargs) + pred_neg = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg + + # CFG rescale + if guidance_rescale > 0: + x_0_pos = self._pred_to_xstart(x_t, t, pred_pos) + x_0_cfg = self._pred_to_xstart(x_t, t, pred) + std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True) + std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True) + x_0_rescaled = x_0_cfg * (std_pos / std_cfg) + x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg + pred = self._xstart_to_pred(x_t, t, x_0) + + return pred diff --git a/trellis2/pipelines/samplers/flow_euler.py b/trellis2/pipelines/samplers/flow_euler.py new file mode 100644 index 0000000000000000000000000000000000000000..5ff72b84221f210f6cb06684bdf13aedc6cf20c3 --- /dev/null +++ b/trellis2/pipelines/samplers/flow_euler.py @@ -0,0 +1,208 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin + + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. + """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _pred_to_xstart(self, x_t, t, pred): + return (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * pred + + def _xstart_to_pred(self, x_t, t, x_0): + return ((1 - self.sigma_min) * x_t - x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + return model(x_t, t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + return pred_x_0, pred_eps, pred_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + tqdm_desc: str = "Sampling", + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + tqdm_desc: A customized tqdm desc. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_seq = t_seq.tolist() + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc=tqdm_desc, disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + guidance_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + guidance_strength: The strength of classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + guidance_strength: float = 3.0, + guidance_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + guidance_strength: The strength of classifier-free guidance. + guidance_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, guidance_interval=guidance_interval, **kwargs) diff --git a/trellis2/pipelines/samplers/guidance_interval_mixin.py b/trellis2/pipelines/samplers/guidance_interval_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..3f57869a17d1626f5b2c58eb3c477127bf464abf --- /dev/null +++ b/trellis2/pipelines/samplers/guidance_interval_mixin.py @@ -0,0 +1,13 @@ +from typing import * + + +class GuidanceIntervalSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance with interval. + """ + + def _inference_model(self, model, x_t, t, cond, guidance_strength, guidance_interval, **kwargs): + if guidance_interval[0] <= t <= guidance_interval[1]: + return super()._inference_model(model, x_t, t, cond, guidance_strength=guidance_strength, **kwargs) + else: + return super()._inference_model(model, x_t, t, cond, guidance_strength=1, **kwargs) diff --git a/trellis2/pipelines/trellis2_image_to_3d.py b/trellis2/pipelines/trellis2_image_to_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4d16c5a2106a7dfc545603846452b0600018f6 --- /dev/null +++ b/trellis2/pipelines/trellis2_image_to_3d.py @@ -0,0 +1,610 @@ +from typing import * +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +from .base import Pipeline +from . import samplers, rembg +from ..modules.sparse import SparseTensor +from ..modules import image_feature_extractor +from ..representations import Mesh, MeshWithVoxel + + +class Trellis2ImageTo3DPipeline(Pipeline): + """ + Pipeline for inferring Trellis2 image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + shape_slat_sampler (samplers.Sampler): The sampler for the structured latent. + tex_slat_sampler (samplers.Sampler): The sampler for the texture latent. + sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): The parameters for the structured latent sampler. + tex_slat_sampler_params (dict): The parameters for the texture latent sampler. + shape_slat_normalization (dict): The normalization parameters for the structured latent. + tex_slat_normalization (dict): The normalization parameters for the texture latent. + image_cond_model (Callable): The image conditioning model. + rembg_model (Callable): The model for removing background. + low_vram (bool): Whether to use low-VRAM mode. + """ + model_names_to_load = [ + 'sparse_structure_flow_model', + 'sparse_structure_decoder', + 'shape_slat_flow_model_512', + 'shape_slat_flow_model_1024', + 'shape_slat_decoder', + 'tex_slat_flow_model_512', + 'tex_slat_flow_model_1024', + 'tex_slat_decoder', + ] + + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + shape_slat_sampler: samplers.Sampler = None, + tex_slat_sampler: samplers.Sampler = None, + sparse_structure_sampler_params: dict = None, + shape_slat_sampler_params: dict = None, + tex_slat_sampler_params: dict = None, + shape_slat_normalization: dict = None, + tex_slat_normalization: dict = None, + image_cond_model: Callable = None, + rembg_model: Callable = None, + low_vram: bool = True, + default_pipeline_type: str = '1024_cascade', + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.shape_slat_sampler = shape_slat_sampler + self.tex_slat_sampler = tex_slat_sampler + self.sparse_structure_sampler_params = sparse_structure_sampler_params + self.shape_slat_sampler_params = shape_slat_sampler_params + self.tex_slat_sampler_params = tex_slat_sampler_params + self.shape_slat_normalization = shape_slat_normalization + self.tex_slat_normalization = tex_slat_normalization + self.image_cond_model = image_cond_model + self.rembg_model = rembg_model + self.low_vram = low_vram + self.default_pipeline_type = default_pipeline_type + self.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + self._device = 'cpu' + + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2ImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super().from_pretrained(path, config_file) + args = pipeline._pretrained_args + + pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args']) + pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params'] + + pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + + pipeline.shape_slat_normalization = args['shape_slat_normalization'] + pipeline.tex_slat_normalization = args['tex_slat_normalization'] + + # HACK: 替换 dinov3 模型源为 camenduru 镜像 + image_cond_args = args['image_cond_model']['args'].copy() + if image_cond_args.get('model_name') == 'facebook/dinov3-vitl16-pretrain-lvd1689m': + image_cond_args['model_name'] = 'camenduru/dinov3-vitl16-pretrain-lvd1689m' + pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**image_cond_args) + pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + + pipeline.low_vram = args.get('low_vram', True) + pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade') + pipeline.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + pipeline._device = 'cpu' + + return pipeline + + def to(self, device: torch.device) -> None: + self._device = device + if not self.low_vram: + super().to(device) + self.image_cond_model.to(device) + if self.rembg_model is not None: + self.rembg_model.to(device) + + def preprocess_image(self, input: Image.Image, bg_color: tuple = (0, 0, 0)) -> Image.Image: + """ + Preprocess the input image. + + Args: + input: Input image (RGB or RGBA). + bg_color: Background color (R, G, B) in 0~255. Default black (0,0,0). + """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if has_alpha: + output = input + else: + input = input.convert('RGB') + if self.low_vram: + self.rembg_model.to(self.device) + output = self.rembg_model(input) + if self.low_vram: + self.rembg_model.cpu() + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1.1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + rgb = output[:, :, :3] + a = output[:, :, 3:4] + bg = np.array(bg_color, dtype=np.float32) / 255.0 + output = rgb * a + bg * (1.0 - a) + output = Image.fromarray((np.clip(output, 0, 1) * 255).astype(np.uint8)) + return output + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + self.image_cond_model.image_size = resolution + if self.low_vram: + self.image_cond_model.to(self.device) + cond = self.image_cond_model(image) + if self.low_vram: + self.image_cond_model.cpu() + if not include_neg_cond: + return {'cond': cond} + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def sample_sparse_structure( + self, + cond: dict, + resolution: int, + num_samples: int = 1, + sampler_params: dict = {}, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + resolution (int): The resolution of the sparse structure. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample sparse structure latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + in_channels = flow_model.in_channels + noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling sparse structure", + ).samples + if self.low_vram: + flow_model.cpu() + + # Decode sparse structure latent + decoder = self.models['sparse_structure_decoder'] + if self.low_vram: + decoder.to(self.device) + decoded = decoder(z_s)>0 + if self.low_vram: + decoder.cpu() + if resolution != decoded.shape[2]: + ratio = decoded.shape[2] // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 + coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() + + return coords + + def sample_shape_slat( + self, + cond: dict, + flow_model, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def sample_shape_slat_cascade( + self, + lr_cond: dict, + cond: dict, + flow_model_lr, + flow_model, + lr_resolution: int, + resolution: int, + coords: torch.Tensor, + sampler_params: dict = {}, + max_num_tokens: int = 49152, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # LR + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model_lr.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model_lr, + noise, + **lr_cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model_lr.cpu() + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + # Upsample + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + hr_resolution = resolution + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + coords = quant_coords.unique(dim=0) + num_tokens = coords.shape[0] + if num_tokens < max_num_tokens or hr_resolution == 1024: + if hr_resolution != resolution: + print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.") + break + hr_resolution -= 128 + + # Sample structured latent + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat, hr_resolution + + def decode_shape_slat( + self, + slat: SparseTensor, + resolution: int, + ) -> Tuple[List[Mesh], List[SparseTensor]]: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + + Returns: + List[Mesh]: The decoded meshes. + List[SparseTensor]: The decoded substructures. + """ + self.models['shape_slat_decoder'].set_resolution(resolution) + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + ret = self.models['shape_slat_decoder'](slat, return_subs=True) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + return ret + + def sample_tex_slat( + self, + cond: dict, + flow_model, + shape_slat: SparseTensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + shape_slat (SparseTensor): The structured latent for shape + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) + shape_slat = (shape_slat - mean) / std + + in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels + noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)) + sampler_params = {**self.tex_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.tex_slat_sampler.sample( + flow_model, + noise, + concat_cond=shape_slat, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling texture SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def decode_tex_slat( + self, + slat: SparseTensor, + subs: List[SparseTensor], + ) -> SparseTensor: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + + Returns: + SparseTensor: The decoded texture voxels + """ + if self.low_vram: + self.models['tex_slat_decoder'].to(self.device) + ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5 + if self.low_vram: + self.models['tex_slat_decoder'].cpu() + return ret + + @torch.no_grad() + def decode_latent( + self, + shape_slat: SparseTensor, + tex_slat: SparseTensor, + resolution: int, + ) -> List[MeshWithVoxel]: + """ + Decode the latent codes. + + Args: + shape_slat (SparseTensor): The structured latent for shape. + tex_slat (SparseTensor): The structured latent for texture. + resolution (int): The resolution of the output. + """ + meshes, subs = self.decode_shape_slat(shape_slat, resolution) + tex_voxels = self.decode_tex_slat(tex_slat, subs) + out_mesh = [] + torch.cuda.synchronize() + for m, v in zip(meshes, tex_voxels): + # try: + m.fill_holes() + # except RuntimeError as e: + # print(f"[WARNING] fill_holes failed (likely PyTorch/cumesh compatibility issue), skipping: {e}") + out_mesh.append( + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout=self.pbr_attr_layout + ) + ) + return out_mesh + + @torch.no_grad() + def run( + self, + image: Image.Image, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + shape_slat_sampler_params: dict = {}, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + return_latent: bool = False, + pipeline_type: Optional[str] = None, + max_num_tokens: int = 49152, + ) -> List[MeshWithVoxel]: + """ + Run the pipeline. + + Args: + image (Image.Image): The image prompt. + num_samples (int): The number of samples to generate. + seed (int): The random seed. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler. + tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler. + preprocess_image (bool): Whether to preprocess the image. + return_latent (bool): Whether to return the latent codes. + pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'. + max_num_tokens (int): The maximum number of tokens to use. + """ + # Check pipeline type + pipeline_type = pipeline_type or self.default_pipeline_type + if pipeline_type == '512': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_512' in self.models, "No 512 resolution texture SLat flow model found." + elif pipeline_type == '1024': + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '1024_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '1536_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + else: + raise ValueError(f"Invalid pipeline type: {pipeline_type}") + + if preprocess_image: + image = self.preprocess_image(image) + torch.manual_seed(seed) + cond_512 = self.get_cond([image], 512) + cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None + ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type] + coords = self.sample_sparse_structure( + cond_512, ss_res, + num_samples, sparse_structure_sampler_params + ) + if pipeline_type == '512': + shape_slat = self.sample_shape_slat( + cond_512, self.models['shape_slat_flow_model_512'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_512, self.models['tex_slat_flow_model_512'], + shape_slat, tex_slat_sampler_params + ) + res = 512 + elif pipeline_type == '1024': + shape_slat = self.sample_shape_slat( + cond_1024, self.models['shape_slat_flow_model_1024'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + res = 1024 + elif pipeline_type == '1024_cascade': + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1024, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + elif pipeline_type == '1536_cascade': + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1536, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + torch.cuda.empty_cache() + out_mesh = self.decode_latent(shape_slat, tex_slat, res) + if return_latent: + return out_mesh, (shape_slat, tex_slat, res) + else: + return out_mesh diff --git a/trellis2/pipelines/trellis2_texturing.py b/trellis2/pipelines/trellis2_texturing.py new file mode 100644 index 0000000000000000000000000000000000000000..c184b5e73ab49d5f29256e16ff900abc92be73be --- /dev/null +++ b/trellis2/pipelines/trellis2_texturing.py @@ -0,0 +1,408 @@ +from typing import * +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +import trimesh +from .base import Pipeline +from . import samplers, rembg +from ..modules.sparse import SparseTensor +from ..modules import image_feature_extractor +import o_voxel +import cumesh +import nvdiffrast.torch as dr +import cv2 +import flex_gemm + + +class Trellis2TexturingPipeline(Pipeline): + """ + Pipeline for inferring Trellis2 image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + tex_slat_sampler (samplers.Sampler): The sampler for the texture latent. + tex_slat_sampler_params (dict): The parameters for the texture latent sampler. + shape_slat_normalization (dict): The normalization parameters for the structured latent. + tex_slat_normalization (dict): The normalization parameters for the texture latent. + image_cond_model (Callable): The image conditioning model. + rembg_model (Callable): The model for removing background. + low_vram (bool): Whether to use low-VRAM mode. + """ + model_names_to_load = [ + 'shape_slat_encoder', + 'tex_slat_decoder', + 'tex_slat_flow_model_512', + 'tex_slat_flow_model_1024' + ] + + def __init__( + self, + models: dict[str, nn.Module] = None, + tex_slat_sampler: samplers.Sampler = None, + tex_slat_sampler_params: dict = None, + shape_slat_normalization: dict = None, + tex_slat_normalization: dict = None, + image_cond_model: Callable = None, + rembg_model: Callable = None, + low_vram: bool = True, + ): + if models is None: + return + super().__init__(models) + self.tex_slat_sampler = tex_slat_sampler + self.tex_slat_sampler_params = tex_slat_sampler_params + self.shape_slat_normalization = shape_slat_normalization + self.tex_slat_normalization = tex_slat_normalization + self.image_cond_model = image_cond_model + self.rembg_model = rembg_model + self.low_vram = low_vram + self.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + self._device = 'cpu' + + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2TexturingPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super().from_pretrained(path, config_file) + args = pipeline._pretrained_args + + pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + + pipeline.shape_slat_normalization = args['shape_slat_normalization'] + pipeline.tex_slat_normalization = args['tex_slat_normalization'] + + pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args']) + pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + + pipeline.low_vram = args.get('low_vram', True) + pipeline.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + pipeline._device = 'cpu' + return pipeline + + def to(self, device: torch.device) -> None: + self._device = device + if not self.low_vram: + super().to(device) + self.image_cond_model.to(device) + if self.rembg_model is not None: + self.rembg_model.to(device) + + def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh: + """ + Preprocess the input mesh. + """ + vertices = mesh.vertices + vertices_min = vertices.min(axis=0) + vertices_max = vertices.max(axis=0) + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + tmp = vertices[:, 1].copy() + vertices[:, 1] = -vertices[:, 2] + vertices[:, 2] = tmp + assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range' + return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False) + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. + """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if has_alpha: + output = input + else: + input = input.convert('RGB') + if self.low_vram: + self.rembg_model.to(self.device) + output = self.rembg_model(input) + if self.low_vram: + self.rembg_model.cpu() + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + self.image_cond_model.image_size = resolution + if self.low_vram: + self.image_cond_model.to(self.device) + cond = self.image_cond_model(image) + if self.low_vram: + self.image_cond_model.cpu() + if not include_neg_cond: + return {'cond': cond} + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def encode_shape_slat( + self, + mesh: trimesh.Trimesh, + resolution: int = 1024, + ) -> SparseTensor: + """ + Encode the meshes to structured latent. + + Args: + mesh (trimesh.Trimesh): The mesh to encode. + resolution (int): The resolution of mesh + + Returns: + SparseTensor: The encoded structured latent. + """ + vertices = torch.from_numpy(mesh.vertices).float() + faces = torch.from_numpy(mesh.faces).long() + + voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid( + vertices.cpu(), faces.cpu(), + grid_size=resolution, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], + face_weight=1.0, + boundary_weight=0.2, + regularization_weight=1e-2, + timing=True, + ) + + vertices = SparseTensor( + feats=dual_vertices * resolution - voxel_indices, + coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1) + ).to(self.device) + intersected = vertices.replace(intersected).to(self.device) + + if self.low_vram: + self.models['shape_slat_encoder'].to(self.device) + shape_slat = self.models['shape_slat_encoder'](vertices, intersected) + if self.low_vram: + self.models['shape_slat_encoder'].cpu() + return shape_slat + + def sample_tex_slat( + self, + cond: dict, + flow_model, + shape_slat: SparseTensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + shape_slat (SparseTensor): The structured latent for shape + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) + shape_slat = (shape_slat - mean) / std + + in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels + noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)) + sampler_params = {**self.tex_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.tex_slat_sampler.sample( + flow_model, + noise, + concat_cond=shape_slat, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling texture SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def decode_tex_slat( + self, + slat: SparseTensor, + ) -> SparseTensor: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + + Returns: + SparseTensor: The decoded texture voxels + """ + if self.low_vram: + self.models['tex_slat_decoder'].to(self.device) + ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5 + if self.low_vram: + self.models['tex_slat_decoder'].cpu() + return ret + + def postprocess_mesh( + self, + mesh: trimesh.Trimesh, + pbr_voxel: SparseTensor, + resolution: int = 1024, + texture_size: int = 1024, + ) -> trimesh.Trimesh: + vertices = mesh.vertices + faces = mesh.faces + normals = mesh.vertex_normals + vertices_torch = torch.from_numpy(vertices).float().cuda() + faces_torch = torch.from_numpy(faces).int().cuda() + if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None: + uvs = mesh.visual.uv.copy() + uvs[:, 1] = 1 - uvs[:, 1] + uvs_torch = torch.from_numpy(uvs).float().cuda() + else: + _cumesh = cumesh.CuMesh() + _cumesh.init(vertices_torch, faces_torch) + vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True) + vertices_torch = vertices_torch.cuda() + faces_torch = faces_torch.cuda() + uvs_torch = uvs_torch.cuda() + vertices = vertices_torch.cpu().numpy() + faces = faces_torch.cpu().numpy() + uvs = uvs_torch.cpu().numpy() + normals = normals[vmap.cpu().numpy()] + + # rasterize + ctx = dr.RasterizeCudaContext() + uvs_torch = torch.cat([uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])], dim=-1).unsqueeze(0) + rast, _ = dr.rasterize( + ctx, uvs_torch, faces_torch, + resolution=[texture_size, texture_size], + ) + mask = rast[0, ..., 3] > 0 + pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0] + + attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device) + attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d( + pbr_voxel.feats, + pbr_voxel.coords, + shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]), + grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3), + mode='trilinear', + ) + + # construct mesh + mask = mask.cpu().numpy() + base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + + # extend + mask = (~mask).astype(np.uint8) + base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA) + metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None] + roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None] + alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None] + + material = trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)), + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8), + metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)), + metallicFactor=1.0, + roughnessFactor=1.0, + alphaMode='OPAQUE', + doubleSided=True, + ) + + # Swap Y and Z axes, invert Y (common conversion for GLB compatibility) + vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1] + normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1] + uvs[:, 1] = 1 - uvs[:, 1] # Flip UV V-coordinate + + textured_mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_normals=normals, + process=False, + visual=trimesh.visual.TextureVisuals(uv=uvs, material=material) + ) + + return textured_mesh + + + @torch.no_grad() + def run( + self, + mesh: trimesh.Trimesh, + image: Image.Image, + seed: int = 42, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + resolution: int = 1024, + texture_size: int = 2048, + ) -> trimesh.Trimesh: + """ + Run the pipeline. + + Args: + mesh (trimesh.Trimesh): The mesh to texture. + image (Image.Image): The image prompt. + seed (int): The random seed. + tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler. + preprocess_image (bool): Whether to preprocess the image. + """ + if preprocess_image: + image = self.preprocess_image(image) + mesh = self.preprocess_mesh(mesh) + torch.manual_seed(seed) + cond = self.get_cond([image], 512) if resolution == 512 else self.get_cond([image], 1024) + shape_slat = self.encode_shape_slat(mesh, resolution) + tex_model = self.models['tex_slat_flow_model_512'] if resolution == 512 else self.models['tex_slat_flow_model_1024'] + tex_slat = self.sample_tex_slat( + cond, tex_model, + shape_slat, tex_slat_sampler_params + ) + pbr_voxel = self.decode_tex_slat(tex_slat) + out_mesh = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size) + return out_mesh diff --git a/trellis2/renderers/__init__.py b/trellis2/renderers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de3203d1bb16065f912bff039e431f609911782d --- /dev/null +++ b/trellis2/renderers/__init__.py @@ -0,0 +1,33 @@ +import importlib + +__attributes = { + 'MeshRenderer': 'mesh_renderer', + 'VoxelRenderer': 'voxel_renderer', + 'PbrMeshRenderer': 'pbr_mesh_renderer', + 'EnvMap': 'pbr_mesh_renderer', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh_renderer import MeshRenderer + from .voxel_renderer import VoxelRenderer + from .pbr_mesh_renderer import PbrMeshRenderer, EnvMap + \ No newline at end of file diff --git a/trellis2/renderers/mesh_renderer.py b/trellis2/renderers/mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..e20efc59c26fadcbd9c11cbafa2e31b78eec395c --- /dev/null +++ b/trellis2/renderers/mesh_renderer.py @@ -0,0 +1,414 @@ +from typing import * +import torch +from easydict import EasyDict as edict +from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode +import torch.nn.functional as F + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = (far + near) / (far - near) + ret[2, 3] = 2 * near * far / (near - far) + ret[3, 2] = 1. + return ret + + +class MeshRenderer: + """ + Renderer for the Mesh representation. + + Args: + rendering_options (dict): Rendering options. + """ + def __init__(self, rendering_options={}, device='cuda'): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "chunk_size": None, + "antialias": True, + "clamp_barycentric_coords": False, + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : Mesh, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types = ["mask", "normal", "depth"], + transformation : Optional[torch.Tensor] = None + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + return_types (list): list of return types, can be "attr", "mask", "depth", "coord", "normal" + + Returns: + edict based on return_types containing: + attr (torch.Tensor): [C, H, W] rendered attr image + depth (torch.Tensor): [H, W] rendered depth image + normal (torch.Tensor): [3, H, W] rendered normal image + mask (torch.Tensor): [H, W] rendered mask image + """ + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + chunk_size = self.rendering_options["chunk_size"] + antialias = self.rendering_options["antialias"] + clamp_barycentric_coords = self.rendering_options["clamp_barycentric_coords"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + ret_dict = edict() + for type in return_types: + if type == "mask" : + ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "depth": + ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "normal": + ret_dict[type] = torch.full((3, resolution, resolution), 0.5, dtype=torch.float32, device=self.device) + elif type == "coord": + ret_dict[type] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + ret_dict[type] = torch.zeros((mesh.attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device) + else: + ret_dict[type] = torch.zeros((mesh.vertex_attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device) + return ret_dict + + perspective = intrinsics_to_projection(intrinsics, near, far) + + full_proj = (perspective @ extrinsics).unsqueeze(0) + extrinsics = extrinsics.unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + if transformation is not None: + vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2)) + vertices = vertices_homo[..., :3].contiguous() + vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces = mesh.faces + + if 'normal' in return_types: + v0 = vertices_camera[0, mesh.faces[:, 0], :3] + v1 = vertices_camera[0, mesh.faces[:, 1], :3] + v2 = vertices_camera[0, mesh.faces[:, 2], :3] + e0 = v1 - v0 + e1 = v2 - v0 + face_normal = torch.cross(e0, e1, dim=1) + face_normal = F.normalize(face_normal, dim=1) + face_normal = torch.where(torch.sum(face_normal * v0, dim=1, keepdim=True) > 0, face_normal, -face_normal) + + out_dict = edict() + if chunk_size is None: + rast, rast_db = dr.rasterize( + self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa) + ) + if clamp_barycentric_coords: + rast[..., :2] = torch.clamp(rast[..., :2], 0, 1) + rast[..., :2] /= torch.where(rast[..., :2].sum(dim=-1, keepdim=True) > 1, rast[..., :2].sum(dim=-1, keepdim=True), torch.ones_like(rast[..., :2])) + for type in return_types: + img = None + if type == "mask" : + img = (rast[..., -1:] > 0).float() + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "normal" : + img = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + img = (img + 1) / 2 + elif type == "coord": + img = dr.interpolate(vertices, rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices, rast, faces)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + imgs = { + 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device), + 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + } + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + base_color = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['base_color'] += base_color * mat.base_color_factor * mat_mask + else: + imgs['base_color'] += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + metallic = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['metallic'] += metallic * mat.metallic_factor * mat_mask + else: + imgs['metallic'] += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + roughness = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['roughness'] += roughness * mat.roughness_factor * mat_mask + else: + imgs['roughness'] += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + imgs['alpha'] += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + alpha = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += alpha * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += mat.alpha_factor * mat_mask + + img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0) + else: + img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + + out_dict[type] = img + else: + z_buffer = torch.full((1, resolution * ssaa, resolution * ssaa), torch.inf, device=self.device, dtype=torch.float32) + for i in range(0, faces.shape[0], chunk_size): + faces_chunk = faces[i:i+chunk_size] + rast, rast_db = dr.rasterize( + self.glctx, vertices_clip, faces_chunk, (resolution * ssaa, resolution * ssaa) + ) + z_filter = torch.logical_and( + rast[..., 3] != 0, + rast[..., 2] < z_buffer + ) + z_buffer[z_filter] = rast[z_filter][..., 2] + + for type in return_types: + img = None + if type == "mask" : + img = (rast[..., -1:] > 0).float() + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_chunk)[0] + elif type == "normal" : + face_normal_chunk = face_normal[i:i+chunk_size] + img = dr.interpolate(face_normal_chunk.unsqueeze(0), rast, torch.arange(face_normal_chunk.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0] + img = (img + 1) / 2 + elif type == "coord": + img = dr.interpolate(vertices, rast, faces_chunk)[0] + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices, rast, faces_chunk)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + imgs = { + 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device), + 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + } + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + base_color = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['base_color'] += base_color * mat.base_color_factor * mat_mask + else: + imgs['base_color'] += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + metallic = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['metallic'] += metallic * mat.metallic_factor * mat_mask + else: + imgs['metallic'] += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + roughness = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['roughness'] += roughness * mat.roughness_factor * mat_mask + else: + imgs['roughness'] += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + imgs['alpha'] += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + alpha = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += alpha * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += mat.alpha_factor * mat_mask + + img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0) + else: + img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces_chunk)[0] + + if type not in out_dict: + out_dict[type] = img + else: + out_dict[type][z_filter] = img[z_filter] + + for type in return_types: + img = out_dict[type] + if ssaa > 1: + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = img.squeeze() + else: + img = img.permute(0, 3, 1, 2).squeeze() + out_dict[type] = img + + if isinstance(mesh, (MeshWithVoxel, MeshWithPbrMaterial)) and 'attr' in return_types: + for k, s in mesh.layout.items(): + out_dict[k] = out_dict['attr'][s] + del out_dict['attr'] + + return out_dict diff --git a/trellis2/renderers/pbr_mesh_renderer.py b/trellis2/renderers/pbr_mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..14ded300c861e163e15c1bff4c3dd7ee27a08e7c --- /dev/null +++ b/trellis2/renderers/pbr_mesh_renderer.py @@ -0,0 +1,521 @@ +from typing import * +import torch +from easydict import EasyDict as edict +import numpy as np +import utils3d +from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode +import torch.nn.functional as F + + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -x, -y + elif s == 1: rx, ry, rz = -torch.ones_like(x), x, -y + elif s == 2: rx, ry, rz = x, y, torch.ones_like(x) + elif s == 3: rx, ry, rz = x, -y, -torch.ones_like(x) + elif s == 4: rx, ry, rz = x, torch.ones_like(x), -y + elif s == 5: rx, ry, rz = -x, -torch.ones_like(x), -y + return torch.stack((rx, ry, rz), dim=-1) + + +def latlong_to_cubemap(latlong_map, res): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = F.normalize(cube_to_dir(s, gx, gy), dim=-1) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + + +class EnvMap: + def __init__(self, image: torch.Tensor): + self.image = image + + @property + def _backend(self): + if not hasattr(self, '_nvdiffrec_envlight'): + if 'EnvironmentLight' not in globals(): + from nvdiffrec_render.light import EnvironmentLight + cubemap = latlong_to_cubemap(self.image, [512, 512]) + self._nvdiffrec_envlight = EnvironmentLight(cubemap) + self._nvdiffrec_envlight.build_mips() + return self._nvdiffrec_envlight + + def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True): + return self._backend.shade(gb_pos, gb_normal, kd, ks, view_pos, specular) + + def sample(self, directions: torch.Tensor): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + return dr.texture( + self._backend.base.unsqueeze(0), + directions.unsqueeze(0), + boundary_mode='cube', + )[0] + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = (far + near) / (far - near) + ret[2, 3] = 2 * near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def screen_space_ambient_occlusion( + depth: torch.Tensor, + normal: torch.Tensor, + perspective: torch.Tensor, + radius: float = 0.1, + bias: float = 1e-6, + samples: int = 64, + intensity: float = 1.0, +) -> torch.Tensor: + """ + Screen space ambient occlusion (SSAO) + + Args: + depth (torch.Tensor): [H, W, 1] depth image + normal (torch.Tensor): [H, W, 3] normal image + perspective (torch.Tensor): [4, 4] camera projection matrix + radius (float): radius of the SSAO kernel + bias (float): bias to avoid self-occlusion + samples (int): number of samples to use for the SSAO kernel + intensity (float): intensity of the SSAO effect + Returns: + (torch.Tensor): [H, W, 1] SSAO image + """ + device = depth.device + H, W, _ = depth.shape + + fx = perspective[0, 0] + fy = perspective[1, 1] + cx = perspective[0, 2] + cy = perspective[1, 2] + + y_grid, x_grid = torch.meshgrid( + (torch.arange(H, device=device) + 0.5) / H * 2 - 1, + (torch.arange(W, device=device) + 0.5) / W * 2 - 1, + indexing='ij' + ) + x_view = (x_grid.float() - cx) * depth[..., 0] / fx + y_view = (y_grid.float() - cy) * depth[..., 0] / fy + view_pos = torch.stack([x_view, y_view, depth[..., 0]], dim=-1) # [H, W, 3] + + depth_feat = depth.permute(2, 0, 1).unsqueeze(0) + occlusion = torch.zeros((H, W), device=device) + + # start sampling + for _ in range(samples): + # sample normal distribution, if inside, flip the sign + rnd_vec = torch.randn(H, W, 3, device=device) + rnd_vec = F.normalize(rnd_vec, p=2, dim=-1) + dot_val = torch.sum(rnd_vec * normal, dim=-1, keepdim=True) + sample_dir = torch.sign(dot_val) * rnd_vec + scale = torch.rand(H, W, 1, device=device) + scale = scale * scale + sample_pos = view_pos + sample_dir * radius * scale + sample_z = sample_pos[..., 2] + + # project to screen space + z_safe = torch.clamp(sample_pos[..., 2], min=1e-5) + proj_u = (sample_pos[..., 0] * fx / z_safe) + cx + proj_v = (sample_pos[..., 1] * fy / z_safe) + cy + grid = torch.stack([proj_u, proj_v], dim=-1).unsqueeze(0) + geo_z = F.grid_sample(depth_feat, grid, mode='nearest', padding_mode='border').squeeze() + range_check = torch.abs(geo_z - sample_z) < radius + is_occluded = (geo_z <= sample_z - bias) & range_check + occlusion += is_occluded.float() + + f_occ = occlusion / samples * intensity + f_occ = torch.clamp(f_occ, 0.0, 1.0) + + return f_occ.unsqueeze(-1) + + +def aces_tonemapping(x: torch.Tensor) -> torch.Tensor: + """ + Applies ACES tone mapping curve to an HDR image tensor. + Input: x - HDR tensor, shape (..., 3), range [0, +inf) + Output: LDR tensor, same shape, range [0, 1] + + NOTE: This causes bright base_color objects to over-expose (appear white) + compared to Blender's standard sRGB display transform. Use linear_to_srgb() + instead for renders_cond alignment. See pbr_color_diff_debug_guide.md #5. + """ + a = 2.51 + b = 0.03 + c = 2.43 + d = 0.59 + e = 0.14 + + # Apply the ACES fitted curve + mapped = (x * (a * x + b)) / (x * (c * x + d) + e) + + # Clamp to [0, 1] for display or saving + return torch.clamp(mapped, 0.0, 1.0) + + +def gamma_correction(x: torch.Tensor, gamma: float = 2.2) -> torch.Tensor: + """ + Applies gamma correction to an HDR image tensor. + """ + return torch.clamp(x ** (1.0 / gamma), 0.0, 1.0) + + +def linear_to_srgb(x: torch.Tensor) -> torch.Tensor: + """ + Standard linear-to-sRGB conversion matching Blender's sRGB display transform. + + Applies the official sRGB EOTF inverse: + - For values <= 0.0031308: sRGB = 12.92 * linear + - For values > 0.0031308: sRGB = 1.055 * linear^(1/2.4) - 0.055 + + Input: x - linear HDR tensor, clamped to [0, +inf) + Output: sRGB tensor, range [0, 1] + """ + x = torch.clamp(x, 0.0) + low = 12.92 * x + high = 1.055 * torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) - 0.055 + return torch.clamp(torch.where(x <= 0.0031308, low, high), 0.0, 1.0) + + +class PbrMeshRenderer: + """ + Renderer for the PBR mesh. + + Args: + rendering_options (dict): Rendering options. + """ + def __init__(self, rendering_options={}, device='cuda'): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "peel_layers": 8, + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : Mesh, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + envmap : Union[EnvMap, Dict[str, EnvMap]], + use_envmap_bg : bool = False, + transformation : Optional[torch.Tensor] = None + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + envmap (Union[EnvMap, Dict[str, EnvMap]]): environment map or a dictionary of environment maps + use_envmap_bg (bool): whether to use envmap as background + transformation (torch.Tensor): (4, 4) transformation matrix + + Returns: + edict based on return_types containing: + shaded (torch.Tensor): [3, H, W] shaded color image + normal (torch.Tensor): [3, H, W] normal image + base_color (torch.Tensor): [3, H, W] base color image + metallic (torch.Tensor): [H, W] metallic image + roughness (torch.Tensor): [H, W] roughness image + """ + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + if not isinstance(envmap, dict): + envmap = {'' : envmap} + num_envmaps = len(envmap) + + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0 or \ + not torch.isfinite(mesh.vertices).all() or \ + mesh.faces.max() >= mesh.vertices.shape[0] or mesh.faces.min() < 0: + if mesh.vertices.shape[0] > 0 and not torch.isfinite(mesh.vertices).all(): + print(f"[PbrMeshRenderer] WARNING: mesh vertices contain NaN/Inf, returning blank image.") + if mesh.faces.shape[0] > 0 and mesh.vertices.shape[0] > 0 and \ + (mesh.faces.max() >= mesh.vertices.shape[0] or mesh.faces.min() < 0): + print(f"[PbrMeshRenderer] WARNING: mesh faces contain out-of-bound indices " + f"(max={mesh.faces.max().item()}, min={mesh.faces.min().item()}, " + f"num_vertices={mesh.vertices.shape[0]}), returning blank image.") + out_dict = edict( + normal=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device), + mask=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device), + base_color=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device), + metallic=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device), + roughness=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device), + alpha=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device), + clay=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device), + ) + for i, k in enumerate(envmap.keys()): + shaded_key = f"shaded_{k}" if k != '' else "shaded" + out_dict[shaded_key] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device) + return out_dict + + rays_o, rays_d = utils3d.torch.get_image_rays( + extrinsics, intrinsics, resolution * ssaa, resolution * ssaa + ) + + perspective = intrinsics_to_projection(intrinsics, near, far) + + full_proj = (perspective @ extrinsics).unsqueeze(0) + extrinsics = extrinsics.unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + vertices_orig = vertices.clone() + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + if transformation is not None: + vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2)) + vertices = vertices_homo[..., :3].contiguous() + vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces = mesh.faces + + v0 = vertices[0, mesh.faces[:, 0], :3] + v1 = vertices[0, mesh.faces[:, 1], :3] + v2 = vertices[0, mesh.faces[:, 2], :3] + e0 = v1 - v0 + e1 = v2 - v0 + face_normal = torch.cross(e0, e1, dim=1) + face_normal = F.normalize(face_normal, dim=1) + + out_dict = edict() + shaded = torch.zeros((num_envmaps, resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device) + depth = torch.full((resolution * ssaa, resolution * ssaa, 1), 1e10, dtype=torch.float32, device=self.device) + normal = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device) + max_w = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + with dr.DepthPeeler(self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa)) as peeler: + for _ in range(self.rendering_options["peel_layers"]): + rast, rast_db = peeler.rasterize_next_layer() + + # Pos + pos = dr.interpolate(vertices, rast, faces)[0][0] + + # Depth + gb_depth = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0][0] + + # Normal + gb_normal = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0][0] + gb_normal = torch.where( + torch.sum(gb_normal * (pos - rays_o), dim=-1, keepdim=True) > 0, + -gb_normal, + gb_normal + ) + gb_cam_normal = (extrinsics[..., :3, :3].reshape(1, 1, 3, 3) @ gb_normal.unsqueeze(-1)).squeeze(-1) + if _ == 0: + out_dict.normal = -gb_cam_normal * 0.5 + 0.5 + mask = (rast[0, ..., -1:] > 0).float() + out_dict.mask = mask + + # PBR attributes + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices_orig, rast, faces)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + gb_basecolor = img[0, ..., mesh.layout['base_color']] + gb_metallic = img[0, ..., mesh.layout['metallic']] + gb_roughness = img[0, ..., mesh.layout['roughness']] + gb_alpha = img[0, ..., mesh.layout['alpha']] + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + gb_basecolor = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device) + gb_metallic = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + gb_roughness = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + gb_alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + bc = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_basecolor += bc * mat.base_color_factor * mat_mask + else: + gb_basecolor += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + m = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_metallic += m * mat.metallic_factor * mat_mask + else: + gb_metallic += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + r = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_roughness += r * mat.roughness_factor * mat_mask + else: + gb_roughness += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + gb_alpha += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + a = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + gb_alpha += (a * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + gb_alpha += a * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + gb_alpha += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + gb_alpha += mat.alpha_factor * mat_mask + if _ == 0: + out_dict.base_color = gb_basecolor + out_dict.metallic = gb_metallic + out_dict.roughness = gb_roughness + out_dict.alpha = gb_alpha + + # Shading + #TODO + gb_basecolor = torch.clamp(gb_basecolor, 0.0, 1.0) ** 2.2 + # gb_basecolor = torch.clamp(gb_basecolor, 0.0, 1.0) + gb_metallic = torch.clamp(gb_metallic, 0.0, 1.0) + gb_roughness = torch.clamp(gb_roughness, 0.0, 1.0) + gb_alpha = torch.clamp(gb_alpha, 0.0, 1.0) + gb_orm = torch.cat([ + torch.zeros_like(gb_metallic), + gb_roughness, + gb_metallic, + ], dim=-1) + gb_shaded = torch.stack([ + e.shade( + pos.unsqueeze(0), + gb_normal.unsqueeze(0), + gb_basecolor.unsqueeze(0), + gb_orm.unsqueeze(0), + rays_o, + specular=True, + )[0] + for e in envmap.values() + ], dim=0) + + # Compositing + w = (1 - alpha) * gb_alpha + depth = torch.where(w > max_w, gb_depth, depth) + normal = torch.where(w > max_w, gb_cam_normal, normal) + max_w = torch.maximum(max_w, w) + shaded += w * gb_shaded + alpha += w + + # Ambient occulusion + f_occ = screen_space_ambient_occlusion( + depth, normal, perspective, intensity=1.5 + ) + shaded *= (1 - f_occ) + out_dict.clay = (1 - f_occ) + + # Background + if use_envmap_bg: + bg = torch.stack([e.sample(rays_d) for e in envmap.values()], dim=0) + shaded += (1 - alpha) * bg + + for i, k in enumerate(envmap.keys()): + shaded_key = f"shaded_{k}" if k != '' else "shaded" + out_dict[shaded_key] = shaded[i] + + # SSAA + for k in out_dict.keys(): + if ssaa > 1: + out_dict[k] = F.interpolate(out_dict[k].unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + else: + out_dict[k] = out_dict[k].permute(2, 0, 1) + out_dict[k] = out_dict[k].squeeze() + + # Post processing: linear → sRGB (matches Blender's display transform) + for k in envmap.keys(): + shaded_key = f"shaded_{k}" if k != '' else "shaded" + out_dict[shaded_key] = linear_to_srgb(out_dict[shaded_key]) + + return out_dict diff --git a/trellis2/renderers/voxel_renderer.py b/trellis2/renderers/voxel_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe28ad8d341ed62c1d7a5ab739fed6cace30a5f --- /dev/null +++ b/trellis2/renderers/voxel_renderer.py @@ -0,0 +1,68 @@ +import torch +from easydict import EasyDict as edict +from ..representations import Voxel +from easydict import EasyDict as edict + + +class VoxelRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + self.rendering_options = edict({ + "resolution": None, + "near": 0.1, + "far": 10.0, + "ssaa": 1, + }) + self.rendering_options.update(rendering_options) + + def render( + self, + voxel: Voxel, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None + ) -> edict: + """ + Render the gausssian. + + Args: + voxel (Voxel): Voxel representation. + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color image + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + ... + """ + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + renderer = o_voxel.rasterize.VoxelRenderer(self.rendering_options) + positions = voxel.position + attrs = voxel.attrs if colors_overwrite is None else colors_overwrite + voxel_size = voxel.voxel_size + + # Render + render_ret = renderer.render(positions, attrs, voxel_size, extrinsics, intrinsics) + + ret = { + 'depth': render_ret['depth'], + 'alpha': render_ret['alpha'], + } + if colors_overwrite is not None: + ret['color'] = render_ret['attr'] + else: + for k, s in voxel.layout.items(): + ret[k] = render_ret['attr'][s] + + return ret diff --git a/trellis2/representations/__init__.py b/trellis2/representations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7d9299f866c344e81e27f748f837e5ce81ed8b --- /dev/null +++ b/trellis2/representations/__init__.py @@ -0,0 +1,31 @@ +import importlib + +__attributes = { + 'Mesh': 'mesh', + 'Voxel': 'voxel', + 'MeshWithVoxel': 'mesh', + 'MeshWithPbrMaterial': 'mesh', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial + from .voxel import Voxel diff --git a/trellis2/representations/mesh/__init__.py b/trellis2/representations/mesh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aff4c99a97f764a7c695d87d6a60bd03d61e2106 --- /dev/null +++ b/trellis2/representations/mesh/__init__.py @@ -0,0 +1 @@ +from .base import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture diff --git a/trellis2/representations/mesh/base.py b/trellis2/representations/mesh/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6bdb47fe89047233628b44900124723770e66502 --- /dev/null +++ b/trellis2/representations/mesh/base.py @@ -0,0 +1,234 @@ +from typing import * +import torch +from ..voxel import Voxel +import cumesh +from flex_gemm.ops.grid_sample import grid_sample_3d + + +class Mesh: + def __init__(self, + vertices, + faces, + vertex_attrs=None + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.vertex_attrs = vertex_attrs + + @property + def device(self): + return self.vertices.device + + def to(self, device, non_blocking=False): + return Mesh( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None, + ) + + def cuda(self, non_blocking=False): + return self.to('cuda', non_blocking=non_blocking) + + def cpu(self): + return self.to('cpu') + + def fill_holes(self, max_hole_perimeter=3e-2): + vertices = self.vertices.clone().cuda().contiguous() + faces = self.faces.clone().cuda().contiguous() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.get_edges() + mesh.get_boundary_info() + if mesh.num_boundaries == 0: + return + mesh.get_vertex_edge_adjacency() + mesh.get_vertex_boundary_adjacency() + mesh.get_manifold_boundary_adjacency() + mesh.read_manifold_boundary_adjacency() + mesh.get_boundary_connected_components() + mesh.get_boundary_loops() + if mesh.num_boundary_loops == 0: + return + mesh.fill_holes(max_hole_perimeter=max_hole_perimeter) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + def remove_faces(self, face_mask: torch.Tensor): + vertices = self.vertices.clone().cuda().contiguous() + faces = self.faces.clone().cuda().contiguous() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.remove_faces(face_mask) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + def simplify(self, target=1000000, verbose: bool=False, options: dict={}): + vertices = self.vertices.clone().cuda().contiguous() + faces = self.faces.clone().cuda().contiguous() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.simplify(target, verbose=verbose, options=options) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + +class TextureFilterMode: + CLOSEST = 0 + LINEAR = 1 + + +class TextureWrapMode: + CLAMP_TO_EDGE = 0 + REPEAT = 1 + MIRRORED_REPEAT = 2 + + +class AlphaMode: + OPAQUE = 0 + MASK = 1 + BLEND = 2 + + +class Texture: + def __init__( + self, + image: torch.Tensor, + filter_mode: TextureFilterMode = TextureFilterMode.LINEAR, + wrap_mode: TextureWrapMode = TextureWrapMode.REPEAT + ): + self.image = image + self.filter_mode = filter_mode + self.wrap_mode = wrap_mode + + def to(self, device, non_blocking=False): + return Texture( + self.image.to(device, non_blocking=non_blocking), + self.filter_mode, + self.wrap_mode, + ) + + +class PbrMaterial: + def __init__( + self, + base_color_texture: Optional[Texture] = None, + base_color_factor: Union[torch.Tensor, List[float]] = [1.0, 1.0, 1.0], + metallic_texture: Optional[Texture] = None, + metallic_factor: float = 1.0, + roughness_texture: Optional[Texture] = None, + roughness_factor: float = 1.0, + alpha_texture: Optional[Texture] = None, + alpha_factor: float = 1.0, + alpha_mode: AlphaMode = AlphaMode.OPAQUE, + alpha_cutoff: float = 0.5, + ): + self.base_color_texture = base_color_texture + self.base_color_factor = torch.tensor(base_color_factor, dtype=torch.float32)[:3] + self.metallic_texture = metallic_texture + self.metallic_factor = metallic_factor + self.roughness_texture = roughness_texture + self.roughness_factor = roughness_factor + self.alpha_texture = alpha_texture + self.alpha_factor = alpha_factor + self.alpha_mode = alpha_mode + self.alpha_cutoff = alpha_cutoff + + def to(self, device, non_blocking=False): + return PbrMaterial( + base_color_texture=self.base_color_texture.to(device, non_blocking=non_blocking) if self.base_color_texture is not None else None, + base_color_factor=self.base_color_factor.to(device, non_blocking=non_blocking), + metallic_texture=self.metallic_texture.to(device, non_blocking=non_blocking) if self.metallic_texture is not None else None, + metallic_factor=self.metallic_factor, + roughness_texture=self.roughness_texture.to(device, non_blocking=non_blocking) if self.roughness_texture is not None else None, + roughness_factor=self.roughness_factor, + alpha_texture=self.alpha_texture.to(device, non_blocking=non_blocking) if self.alpha_texture is not None else None, + alpha_factor=self.alpha_factor, + alpha_mode=self.alpha_mode, + alpha_cutoff=self.alpha_cutoff, + ) + + +class MeshWithPbrMaterial(Mesh): + def __init__(self, + vertices, + faces, + material_ids, + uv_coords, + materials: List[PbrMaterial], + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.material_ids = material_ids # [M] + self.uv_coords = uv_coords # [M, 3, 2] + self.materials = materials + self.layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + + def to(self, device, non_blocking=False): + return MeshWithPbrMaterial( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.material_ids.to(device, non_blocking=non_blocking), + self.uv_coords.to(device, non_blocking=non_blocking), + [material.to(device, non_blocking=non_blocking) for material in self.materials], + ) + + +class MeshWithVoxel(Mesh, Voxel): + def __init__(self, + vertices: torch.Tensor, + faces: torch.Tensor, + origin: list, + voxel_size: float, + coords: torch.Tensor, + attrs: torch.Tensor, + voxel_shape: torch.Size, + layout: Dict = {}, + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.voxel_shape = voxel_shape + self.layout = layout + + def to(self, device, non_blocking=False): + return MeshWithVoxel( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.origin.tolist(), + self.voxel_size, + self.coords.to(device, non_blocking=non_blocking), + self.attrs.to(device, non_blocking=non_blocking), + self.voxel_shape, + self.layout, + ) + + def query_attrs(self, xyz): + grid = ((xyz - self.origin) / self.voxel_size).reshape(1, -1, 3) + vertex_attrs = grid_sample_3d( + self.attrs, + torch.cat([torch.zeros_like(self.coords[..., :1]), self.coords], dim=-1), + self.voxel_shape, + grid, + mode='trilinear' + )[0] + return vertex_attrs + + def query_vertex_attrs(self): + return self.query_attrs(self.vertices) diff --git a/trellis2/representations/voxel/__init__.py b/trellis2/representations/voxel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5792ea14b2371a96c4130371eb976f0aff4b5dd --- /dev/null +++ b/trellis2/representations/voxel/__init__.py @@ -0,0 +1 @@ +from .voxel_model import Voxel \ No newline at end of file diff --git a/trellis2/representations/voxel/voxel_model.py b/trellis2/representations/voxel/voxel_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9317ab22db61c5da96514ca46a780378392f3cc8 --- /dev/null +++ b/trellis2/representations/voxel/voxel_model.py @@ -0,0 +1,54 @@ +from typing import Dict +import torch + + +class Voxel: + def __init__( + self, + origin: list, + voxel_size: float, + coords: torch.Tensor = None, + attrs: torch.Tensor = None, + layout: Dict = {}, + device: torch.device = 'cuda' + ): + self.origin = torch.tensor(origin, dtype=torch.float32, device=device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.layout = layout + self.device = device + + @property + def position(self): + return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] + + def split_attrs(self): + return { + k: self.attrs[:, self.layout[k]] + for k in self.layout + } + + def save(self, path): + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + o_voxel.io.write( + path, + self.coords, + self.split_attrs(), + ) + + def load(self, path): + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + coord, attrs = o_voxel.io.read(path) + self.coords = coord.int().to(self.device) + self.attrs = torch.cat([attrs[k] for k in attrs], dim=1).to(self.device) + # build layout + start = 0 + self.layout = {} + for k in attrs: + self.layout[k] = slice(start, start + attrs[k].shape[1]) + start += attrs[k].shape[1] diff --git a/trellis2/trainers/__init__.py b/trellis2/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a107517e58c4582b6de4150b507012077700010 --- /dev/null +++ b/trellis2/trainers/__init__.py @@ -0,0 +1,74 @@ +import importlib + +__attributes = { + 'BasicTrainer': 'basic', + + 'SparseStructureVaeTrainer': 'vae.sparse_structure_vae', + 'ShapeVaeTrainer': 'vae.shape_vae', + 'PbrVaeTrainer': 'vae.pbr_vae', + + 'FlowMatchingTrainer': 'flow_matching.flow_matching', + 'FlowMatchingCFGTrainer': 'flow_matching.flow_matching', + 'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', + 'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', + 'ImageConditionedProjFlowMatchingCFGTrainer': 'flow_matching.flow_matching', + + 'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching', + 'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'MultiImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'ImageConditionedProjSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + + 'DinoV2FeatureExtractor': 'flow_matching.mixins.image_conditioned', + 'DinoV3FeatureExtractor': 'flow_matching.mixins.image_conditioned', + 'DinoV3ProjFeatureExtractor': 'flow_matching.mixins.image_conditioned_proj', + 'DinoV3VaeProjFeatureExtractor': 'flow_matching.mixins.image_conditioned_proj', + 'ImageConditionedProjMixin': 'flow_matching.mixins.image_conditioned_proj', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import BasicTrainer + + from .vae.sparse_structure_vae import SparseStructureVaeTrainer + from .vae.shape_vae import ShapeVaeTrainer + from .vae.pbr_vae import PbrVaeTrainer + + from .flow_matching.flow_matching import ( + FlowMatchingTrainer, + FlowMatchingCFGTrainer, + TextConditionedFlowMatchingCFGTrainer, + ImageConditionedFlowMatchingCFGTrainer, + ImageConditionedProjFlowMatchingCFGTrainer, + ) + + from .flow_matching.sparse_flow_matching import ( + SparseFlowMatchingTrainer, + SparseFlowMatchingCFGTrainer, + TextConditionedSparseFlowMatchingCFGTrainer, + ImageConditionedSparseFlowMatchingCFGTrainer, + ) + + from .flow_matching.mixins.image_conditioned import ( + DinoV2FeatureExtractor, + DinoV3FeatureExtractor, + ) diff --git a/trellis2/trainers/basic.py b/trellis2/trainers/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac701b32d7376a8bdcb66415d30d0c7890f1819 --- /dev/null +++ b/trellis2/trainers/basic.py @@ -0,0 +1,1293 @@ +from abc import abstractmethod +import os +import time +import json +import copy +import threading +from functools import partial +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.nn.parallel import DistributedDataParallel as DDP +import numpy as np + +from torchvision import utils + +try: + import wandb + WANDB_AVAILABLE = True +except ImportError: + WANDB_AVAILABLE = False + +from .utils import * +from ..utils.general_utils import * +from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler +from ..utils.dist_utils import * +from ..utils import grad_clip_utils, elastic_utils + + +class BasicTrainer: + """ + Trainer for basic training loop. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + mix_precision_mode (str): + - None: No mixed precision. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + mix_precision_dtype (str): Mixed precision dtype. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + parallel_mode (str): Parallel mode. Options are 'ddp'. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + """ + def __init__(self, + models, + dataset, + *, + output_dir, + load_dir, + step, + max_steps, + batch_size=None, + batch_size_per_gpu=None, + batch_split=None, + optimizer={}, + lr_scheduler=None, + elastic=None, + grad_clip=None, + ema_rate=0.9999, + fp16_mode=None, + mix_precision_mode='inflat_all', + mix_precision_dtype='float16', + fp16_scale_growth=1e-3, + parallel_mode='ddp', + finetune_ckpt=None, + log_param_stats=False, + prefetch_data=True, + snapshot_batch_size=4, + snapshot_num_samples=64, + num_workers=None, + debug=False, + i_print=1000, + i_log=500, + i_sample=10000, + i_save=10000, + i_ddpcheck=10000, + wandb_run=None, # wandb run object + **kwargs + ): + assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.' + + self.models = models + self.dataset = dataset + self.batch_split = batch_split if batch_split is not None else 1 + self.max_steps = max_steps + self.debug = debug + self.optimizer_config = optimizer + self.lr_scheduler_config = lr_scheduler + self.elastic_controller_config = elastic + self.grad_clip = grad_clip + self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate + if fp16_mode is not None: + mix_precision_dtype = 'float16' + mix_precision_mode = fp16_mode + self.mix_precision_mode = mix_precision_mode + self.mix_precision_dtype = str_to_dtype(mix_precision_dtype) + self.fp16_scale_growth = fp16_scale_growth + self.parallel_mode = parallel_mode + self.log_param_stats = log_param_stats + self.prefetch_data = prefetch_data + self.snapshot_batch_size = snapshot_batch_size + self.snapshot_num_samples = snapshot_num_samples + self.num_workers = num_workers + self.log = [] + if self.prefetch_data: + self._data_prefetched = None + + self.output_dir = output_dir + from datetime import datetime + self._log_file = os.path.join(self.output_dir, f'log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt') + self.i_print = i_print + self.i_log = i_log + self.i_sample = i_sample + self.i_save = i_save + self.i_ddpcheck = i_ddpcheck + + if dist.is_initialized(): + # Multi-GPU params + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.local_rank = dist.get_rank() % torch.cuda.device_count() + self.is_master = self.rank == 0 + else: + # Single-GPU params + self.world_size = 1 + self.rank = 0 + self.local_rank = 0 + self.is_master = True + + self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size + self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size + assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.' + assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.' + + self.init_models_and_more(**kwargs) + self.prepare_dataloader(**kwargs) + + # Load checkpoint + self.step = 0 + if load_dir is not None and step is not None: + self.load(load_dir, step) + elif finetune_ckpt is not None: + self.finetune_from(finetune_ckpt) + + if self.is_master: + os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True) + os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True) + self.writer = None # TensorBoard disabled (S3 FUSE does not support append) + # Initialize wandb + self.wandb_run = wandb_run + if self.wandb_run is not None: + print(f'Wandb logging enabled: {self.wandb_run.url}') + + if self.parallel_mode == 'ddp' and self.world_size > 1: + self.check_ddp() + + if self.is_master: + print('\n\nTrainer initialized.') + print(self) + + def __str__(self): + lines = [] + lines.append(self.__class__.__name__) + lines.append(f' - Models:') + for name, model in self.models.items(): + lines.append(f' - {name}: {model.__class__.__name__}') + lines.append(f' - Dataset: {indent(str(self.dataset), 2)}') + lines.append(f' - Dataloader:') + lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}') + lines.append(f' - Num workers: {self.dataloader.num_workers}') + lines.append(f' - Number of steps: {self.max_steps}') + lines.append(f' - Number of GPUs: {self.world_size}') + lines.append(f' - Batch size: {self.batch_size}') + lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}') + lines.append(f' - Batch split: {self.batch_split}') + lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}') + lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}') + if self.lr_scheduler_config is not None: + lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}') + if self.elastic_controller_config is not None: + lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}') + if self.grad_clip is not None: + lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}') + lines.append(f' - EMA rate: {self.ema_rate}') + lines.append(f' - Mixed precision dtype: {self.mix_precision_dtype}') + lines.append(f' - Mixed precision mode: {self.mix_precision_mode}') + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + lines.append(f' - FP16 scale growth: {self.fp16_scale_growth}') + lines.append(f' - Parallel mode: {self.parallel_mode}') + return '\n'.join(lines) + + @property + def device(self): + for _, model in self.models.items(): + if hasattr(model, 'device'): + return model.device + return next(list(self.models.values())[0].parameters()).device + + def init_models_and_more(self, **kwargs): + """ + Initialize models and more. + """ + if self.world_size > 1: + # Prepare distributed data parallel + self.training_models = { + name: DDP( + model, + device_ids=[self.local_rank], + output_device=self.local_rank, + bucket_cap_mb=128, + find_unused_parameters=False + ) + for name, model in self.models.items() + } + else: + self.training_models = self.models + + # Build master params + self.model_params = sum( + [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()] + , []) + if self.mix_precision_mode == 'amp': + self.master_params = self.model_params + if self.mix_precision_dtype == torch.float16: + self.scaler = torch.GradScaler() + elif self.mix_precision_mode == 'inflat_all': + self.master_params = make_master_params(self.model_params) + if self.mix_precision_dtype == torch.float16: + self.log_scale = 20.0 + elif self.mix_precision_mode is None: + self.master_params = self.model_params + else: + raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.') + + # Build EMA params + if self.is_master: + self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate] + + # Initialize optimizer + if hasattr(torch.optim, self.optimizer_config['name']): + self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args']) + else: + self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args']) + + # Initalize learning rate scheduler + if self.lr_scheduler_config is not None: + if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']): + self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args']) + else: + self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args']) + + # Initialize elastic memory controller + if self.elastic_controller_config is not None: + assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \ + 'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin' + self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args']) + for model in self.models.values(): + if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)): + model.register_memory_controller(self.elastic_controller) + + # Initialize gradient clipper + if self.grad_clip is not None: + if isinstance(self.grad_clip, (float, int)): + self.grad_clip = float(self.grad_clip) + else: + self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args']) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = ResumableSampler( + self.dataset, + shuffle=True, + ) + if self.num_workers is None or self.num_workers == -1: + num_workers = max(1, int(np.ceil((os.cpu_count() - 16) / torch.cuda.device_count()))) + else: + num_workers = self.num_workers + + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def _master_params_to_state_dicts(self, master_params): + """ + Convert master params to dict of state_dicts. + """ + if self.mix_precision_mode == 'inflat_all': + master_params = unflatten_master_params(self.model_params, master_params) + state_dicts = {name: model.state_dict() for name, model in self.models.items()} + master_params_names = sum( + [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] + , []) + for i, (model_name, param_name) in enumerate(master_params_names): + state_dicts[model_name][param_name] = master_params[i] + return state_dicts + + def _state_dicts_to_master_params(self, master_params, state_dicts): + """ + Convert a state_dict to master params. + """ + master_params_names = sum( + [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] + , []) + params = [state_dicts[name][param_name] for name, param_name in master_params_names] + if self.mix_precision_mode == 'inflat_all': + model_params_to_master_params(params, master_params) + else: + for i, param in enumerate(params): + master_params[i].data.copy_(param.data) + + def load(self, load_dir, step=0): + """ + Load a checkpoint. + Should be called by all processes. + """ + if self.is_master: + print(f'\nLoading checkpoint from step {step}...', end='') + + model_ckpts = {} + for name, model in self.models.items(): + model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True) + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + self._state_dicts_to_master_params(self.master_params, model_ckpts) + del model_ckpts + + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = {} + for name, model in self.models.items(): + ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True) + ema_ckpts[name] = ema_ckpt + self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts) + del ema_ckpts + + misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False) + self.optimizer.load_state_dict(misc_ckpt['optimizer']) + self.step = misc_ckpt['step'] + self.data_sampler.load_state_dict(misc_ckpt['data_sampler']) + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.load_state_dict(misc_ckpt['scaler']) + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + self.log_scale = misc_ckpt['log_scale'] + if self.lr_scheduler_config is not None: + self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler']) + if self.elastic_controller_config is not None: + self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller']) + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + self.grad_clip.load_state_dict(misc_ckpt['grad_clip']) + del misc_ckpt + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print(' Done.') + + if self.world_size > 1: + self.check_ddp() + + def save(self, non_blocking=True): + """ + Save a checkpoint. + Should be called only by the rank 0 process. + """ + assert self.is_master, 'save() should be called only by the rank 0 process.' + print(f'\nSaving checkpoint at step {self.step}...', end='') + + model_ckpts = self._master_params_to_state_dicts(self.master_params) + for name, model_ckpt in model_ckpts.items(): + model_ckpt = {k: v.cpu() for k, v in model_ckpt.items()} # Move to CPU for saving + if non_blocking: + threading.Thread( + target=torch.save, + args=(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')), + ).start() + else: + torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')) + + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i]) + for name, ema_ckpt in ema_ckpts.items(): + ema_ckpt = {k: v.cpu() for k, v in ema_ckpt.items()} # Move to CPU for saving + if non_blocking: + threading.Thread( + target=torch.save, + args=(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')), + ).start() + else: + torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')) + + misc_ckpt = { + 'optimizer': self.optimizer.state_dict(), + 'step': self.step, + 'data_sampler': self.data_sampler.state_dict(), + } + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + misc_ckpt['scaler'] = self.scaler.state_dict() + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + misc_ckpt['log_scale'] = self.log_scale + if self.lr_scheduler_config is not None: + misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict() + if self.elastic_controller_config is not None: + misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict() + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + misc_ckpt['grad_clip'] = self.grad_clip.state_dict() + if non_blocking: + threading.Thread( + target=torch.save, + args=(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')), + ).start() + else: + torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')) + print(' Done.') + + def _remap_checkpoint_keys(self, model_ckpt, model_state_dict): + """ + Remap checkpoint keys to match model state dict. + + Handles structural changes like: + - cross_attn.xxx -> cross_attn.cross_attn_block.xxx (for ProjectAttention wrapper) + + Args: + model_ckpt: Checkpoint state dict + model_state_dict: Model state dict + + Returns: + Remapped checkpoint dict + """ + remapped_ckpt = {} + remapped_count = 0 + + for ckpt_key, ckpt_value in model_ckpt.items(): + # Check if key exists directly + if ckpt_key in model_state_dict: + remapped_ckpt[ckpt_key] = ckpt_value + continue + + # Try remapping: cross_attn.xxx -> cross_attn.cross_attn_block.xxx + # This handles the case when cross_attn is wrapped by ProjectAttention + if '.cross_attn.' in ckpt_key: + # Split at .cross_attn. + parts = ckpt_key.split('.cross_attn.') + if len(parts) == 2: + new_key = f'{parts[0]}.cross_attn.cross_attn_block.{parts[1]}' + if new_key in model_state_dict: + remapped_ckpt[new_key] = ckpt_value + remapped_count += 1 + continue + + # Key not remapped, keep original (will be handled by missing key logic) + remapped_ckpt[ckpt_key] = ckpt_value + + if remapped_count > 0 and self.is_master: + print(f'Info: Remapped {remapped_count} cross_attn keys to cross_attn.cross_attn_block') + + return remapped_ckpt + + def finetune_from(self, finetune_ckpt): + """ + Finetune from a checkpoint. + Should be called by all processes. + """ + # 允许缺失的 keys(如 register_buffer 的参数) + ALLOWED_MISSING_KEYS = {'rope_phases'} + + if self.is_master: + print('\nFinetuning from:') + for name, path in finetune_ckpt.items(): + print(f' - {name}: {path}') + + model_ckpts = {} + for name, model in self.models.items(): + model_state_dict = model.state_dict() + if name in finetune_ckpt: + model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True) + + # Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper) + model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict) + + # 检查多余的 keys(在 ckpt 中但不在 model 中) + for k, v in model_ckpt.items(): + if k not in model_state_dict: + if self.is_master: + print(f'Warning: {k} not found in model_state_dict, skipped.') + model_ckpt[k] = None + elif model_ckpt[k].shape != model_state_dict[k].shape: + if self.is_master: + print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.') + model_ckpt[k] = model_state_dict[k] + model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None} + + # 检查缺失的 keys(在 model 中但不在 ckpt 中) + missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys()) + unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS + if unexpected_missing and self.is_master: + print(f'Error: Missing keys in checkpoint: {unexpected_missing}') + raise RuntimeError(f'Missing keys in checkpoint: {unexpected_missing}') + if missing_keys & ALLOWED_MISSING_KEYS and self.is_master: + print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}') + + # 补充缺失的 keys(使用模型初始化值) + for k in missing_keys: + model_ckpt[k] = model_state_dict[k] + + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + else: + if self.is_master: + print(f'Warning: {name} not found in finetune_ckpt, skipped.') + model_ckpts[name] = model_state_dict + self._state_dicts_to_master_params(self.master_params, model_ckpts) + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + self._state_dicts_to_master_params(self.ema_params[i], model_ckpts) + del model_ckpts + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print('Done.') + + if self.world_size > 1: + self.check_ddp() + + @abstractmethod + def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs): + """ + Run a snapshot of the model. + """ + pass + + @torch.no_grad() + def visualize_sample(self, sample): + """ + Convert a sample to an image. + """ + if hasattr(self.dataset, 'visualize_sample'): + return self.dataset.visualize_sample(sample) + else: + return sample + + @torch.no_grad() + def snapshot_dataset(self, num_samples=100, batch_size=4): + """ + Sample images from the dataset. + """ + dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=batch_size, + num_workers=0, + shuffle=True, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + save_cfg = {} + for i in range(0, num_samples, batch_size): + data = next(iter(dataloader)) + data = {k: v[:min(num_samples - i, batch_size)] for k, v in data.items()} + data = recursive_to_device(data, self.device) + try: + vis = self.visualize_sample(data) + except (RuntimeError, Exception) as e: + print(f'\033[93m[WARN] snapshot_dataset visualize_sample failed (batch {i}), skipping: {e}\033[0m') + torch.cuda.empty_cache() + continue + if isinstance(vis, dict): + for k, v in vis.items(): + if f'dataset_{k}' not in save_cfg: + save_cfg[f'dataset_{k}'] = [] + save_cfg[f'dataset_{k}'].append(v) + else: + if 'dataset' not in save_cfg: + save_cfg['dataset'] = [] + save_cfg['dataset'].append(vis) + for name, image in save_cfg.items(): + utils.save_image( + torch.cat(image, dim=0), + os.path.join(self.output_dir, 'samples', f'{name}.jpg'), + nrow=int(np.sqrt(num_samples)), + normalize=True, + value_range=self.dataset.value_range, + ) + + @torch.no_grad() + def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False): + """ + Sample images from the model. + NOTE: When num_samples >= 4, this function should be called by all processes. + When num_samples < 4, only master runs snapshot (other ranks skip via barrier). + """ + # Free cached GPU memory before snapshot to avoid OOM / illegal address errors + import gc + gc.collect() + torch.cuda.empty_cache() + + if self.is_master: + print(f'\nSampling {num_samples} images...', end='') + + if suffix is None: + suffix = f'step{self.step:07d}' + + # When num_samples < 4, only master runs snapshot to avoid multi-rank gather issues + master_only = num_samples < 4 + + sample_metadata = None # Will hold list of "dataset_name/sha256" strings + + if master_only and self.world_size > 1: + if not self.is_master: + # Non-master ranks just wait at barrier + dist.barrier() + return + + # Master runs snapshot alone + amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext + with amp_context(): + samples = self.run_snapshot(num_samples, batch_size=batch_size, verbose=verbose) + + # Extract metadata before preprocessing + sample_metadata = samples.pop('_metadata', None) + + # Free GPU memory after sampling, before decode + render + torch.cuda.empty_cache() + + # Preprocess images + for key in list(samples.keys()): + if samples[key]['type'] == 'sample': + try: + vis = self.visualize_sample(samples[key]['value']) + except RuntimeError as e: + print(f"[Snapshot] WARNING: visualize_sample failed for '{key}': {e}") + # Reset CUDA error state and skip this sample + try: + torch.cuda.synchronize() + except RuntimeError: + pass + torch.cuda.empty_cache() + del samples[key] + continue + if isinstance(vis, dict): + for k, v in vis.items(): + samples[f'{key}_{k}'] = {'value': v, 'type': 'image'} + del samples[key] + else: + samples[key] = {'value': vis, 'type': 'image'} + + # No gather needed, master already has all samples + dist.barrier() + else: + # Distribute sampling across all ranks + num_samples_per_process = int(np.ceil(num_samples / self.world_size)) + amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext + + with amp_context(): + samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose) + + # Extract metadata before preprocessing + local_metadata = samples.pop('_metadata', None) + + # Free GPU memory after sampling, before decode + render + torch.cuda.empty_cache() + + # Preprocess images + for key in list(samples.keys()): + if samples[key]['type'] == 'sample': + try: + vis = self.visualize_sample(samples[key]['value']) + except RuntimeError as e: + print(f"[Snapshot] WARNING: visualize_sample failed for '{key}': {e}") + torch.cuda.synchronize() + del samples[key] + continue + if isinstance(vis, dict): + for k, v in vis.items(): + samples[f'{key}_{k}'] = {'value': v, 'type': 'image'} + del samples[key] + else: + samples[key] = {'value': vis, 'type': 'image'} + + # Gather results + if self.world_size > 1: + for key in samples.keys(): + samples[key]['value'] = samples[key]['value'].contiguous() + if self.is_master: + all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)] + else: + all_images = [] + dist.gather(samples[key]['value'], all_images, dst=0) + if self.is_master: + samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples] + + # Gather metadata across ranks + if local_metadata is not None: + all_metadata = [None] * self.world_size + dist.all_gather_object(all_metadata, local_metadata) + if self.is_master: + sample_metadata = sum(all_metadata, [])[:num_samples] + else: + sample_metadata = None + else: + sample_metadata = local_metadata + + # Save images + if self.is_master: + os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True) + wandb_images = {} # Collect images for wandb logging + nrow = int(np.sqrt(num_samples)) + vr = self.dataset.value_range + + # Build metadata caption string for wandb + metadata_caption = '' + if sample_metadata: + metadata_caption = '\n' + ' | '.join(sample_metadata) + # Also save metadata to file + with open(os.path.join(self.output_dir, 'samples', suffix, 'metadata.txt'), 'w') as f: + for i, m in enumerate(sample_metadata): + f.write(f'{i}: {m}\n') + + # Helper: make a normalized grid tensor from a batch of images + def _make_grid(tensor): + return utils.make_grid(tensor, nrow=nrow, normalize=True, value_range=vr) + + # Helper: resize grid to target height (keep aspect ratio) + def _resize_to_height(grid, target_h): + import torch.nn.functional as F + _, h, w = grid.shape + if h == target_h: + return grid + target_w = int(round(w * target_h / h)) + return F.interpolate(grid.unsqueeze(0), size=(target_h, target_w), mode='bilinear', align_corners=False).squeeze(0) + + # --- Save individual images (original behavior) --- + for key in samples.keys(): + if samples[key]['type'] == 'image': + image_path = os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg') + utils.save_image( + samples[key]['value'], + image_path, + nrow=nrow, + normalize=True, + value_range=vr, + ) + # Collect for wandb + if self.wandb_run is not None: + grid = _make_grid(samples[key]['value']) + grid_np = grid.permute(1, 2, 0).cpu().numpy() + grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8) + wandb_images[f'samples/{key}'] = wandb.Image(grid_np, caption=f'{key} at step {self.step}{metadata_caption}') + elif samples[key]['type'] == 'number': + val_min = samples[key]['value'].min() + val_max = samples[key]['value'].max() + images = (samples[key]['value'] - val_min) / (val_max - val_min) + images = utils.make_grid( + images, + nrow=nrow, + normalize=False, + ) + save_image_with_notes( + images, + os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'), + notes=f'{key} min: {val_min}, max: {val_max}', + ) + + # --- Save combined images --- + sample_keys = set(samples.keys()) + + # Combined 1: image + sample_gt_view + sample_gt_gt_view (shape) + # image + sample_gt_view_{attr} + sample_gt_gt_view_{attr} (tex, per attribute) + # Detect gt_view attribute suffixes from sample keys + gt_view_attrs = set() + for k in sample_keys: + if k.startswith('sample_gt_view_'): + attr = k[len('sample_gt_view_'):] + gt_view_attrs.add(attr) + + if gt_view_attrs: + # Tex mode: generate combined view for each PBR attribute + for attr in sorted(gt_view_attrs): + combo1_keys = ['image', f'sample_gt_view_{attr}', f'sample_gt_gt_view_{attr}'] + combo1_present = [k for k in combo1_keys if k in sample_keys and samples[k]['type'] == 'image'] + if len(combo1_present) >= 2: + grids = [_make_grid(samples[k]['value']) for k in combo1_present] + target_h = max(g.shape[1] for g in grids) + grids = [_resize_to_height(g, target_h) for g in grids] + combined = torch.cat(grids, dim=2) + combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_views_{attr}_{suffix}.jpg') + utils.save_image(combined, combined_path, normalize=False) + if self.wandb_run is not None: + grid_np = combined.permute(1, 2, 0).cpu().numpy() + grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8) + label = ' | '.join(combo1_present) + wandb_images[f'samples/combined_views_{attr}'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}') + else: + # Shape mode: single gt_view + combo1_keys = ['image', 'sample_gt_view', 'sample_gt_gt_view'] + combo1_present = [k for k in combo1_keys if k in sample_keys and samples[k]['type'] == 'image'] + if len(combo1_present) >= 2: + grids = [_make_grid(samples[k]['value']) for k in combo1_present] + target_h = max(g.shape[1] for g in grids) + grids = [_resize_to_height(g, target_h) for g in grids] + combined = torch.cat(grids, dim=2) + combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_views_{suffix}.jpg') + utils.save_image(combined, combined_path, normalize=False) + if self.wandb_run is not None: + grid_np = combined.permute(1, 2, 0).cpu().numpy() + grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8) + label = ' | '.join(combo1_present) + wandb_images[f'samples/combined_views'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}') + + # Combined 2: sample_multiview + sample_gt_multiview + combo2_keys = ['sample_multiview', 'sample_gt_multiview'] + combo2_present = [k for k in combo2_keys if k in sample_keys and samples[k]['type'] == 'image'] + if len(combo2_present) >= 2: + grids = [_make_grid(samples[k]['value']) for k in combo2_present] + target_h = max(g.shape[1] for g in grids) + grids = [_resize_to_height(g, target_h) for g in grids] + combined = torch.cat(grids, dim=2) # concatenate along width + combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_multiview_{suffix}.jpg') + utils.save_image(combined, combined_path, normalize=False) + if self.wandb_run is not None: + grid_np = combined.permute(1, 2, 0).cpu().numpy() + grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8) + label = ' | '.join(combo2_present) + wandb_images[f'samples/combined_multiview'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}') + + # Log images to wandb + if self.wandb_run is not None and wandb_images: + self.wandb_run.log(wandb_images, step=self.step) + + if self.is_master: + print(' Done.') + + def update_ema(self): + """ + Update exponential moving average. + Should only be called by the rank 0 process. + """ + assert self.is_master, 'update_ema() should be called only by the rank 0 process.' + for i, ema_rate in enumerate(self.ema_rate): + for master_param, ema_param in zip(self.master_params, self.ema_params[i]): + ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate) + + def check_ddp(self): + """ + Check if DDP is working properly. + Should be called by all process. + """ + if self.is_master: + print('\nPerforming DDP check...') + + if self.is_master: + print('Checking if parameters are consistent across processes...') + dist.barrier() + try: + for p in self.master_params: + # split to avoid OOM + for i in range(0, p.numel(), 10000000): + sub_size = min(10000000, p.numel() - i) + sub_p = p.detach().view(-1)[i:i+sub_size] + # gather from all processes + sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)] + dist.all_gather(sub_p_gather, sub_p) + # check if equal + assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes' + except AssertionError as e: + if self.is_master: + print(f'\n\033[91mError: {e}\033[0m') + print('DDP check failed.') + raise e + + dist.barrier() + if self.is_master: + print('Done.') + + def _verify_gradient_sync(self): + """ + 验证 DDP 梯度同步是否真正生效。 + DDP 的 backward 会自动对梯度进行 all_reduce,同步后所有卡的梯度应该完全相同。 + + 验证方法: + 1. 计算所有参数的总梯度 norm + 2. 收集各卡的梯度 norm + 3. 如果 DDP 同步正常,所有卡的梯度 norm 应该完全相同 + 4. 如果没有同步,各卡梯度 norm 会不同(因为各卡处理的数据不同) + """ + # 计算本卡所有参数的总梯度 norm + total_grad_norm_sq = 0.0 + grad_count = 0 + for p in self.model_params: + if p.grad is not None: + total_grad_norm_sq += p.grad.detach().float().norm().item() ** 2 + grad_count += 1 + + if grad_count == 0: + return + + local_grad_norm = total_grad_norm_sq ** 0.5 + + # 确保所有进程到达同一点 + dist.barrier() + + # 收集所有卡的梯度 norm + grad_norm_tensor = torch.tensor([local_grad_norm], dtype=torch.float64, device=self.device) + all_grad_norms = [torch.zeros(1, dtype=torch.float64, device=self.device) for _ in range(self.world_size)] + dist.all_gather(all_grad_norms, grad_norm_tensor) + all_grad_norms = [g.item() for g in all_grad_norms] + + # 验证所有卡的梯度 norm 是否相同(使用相对误差,容忍 0.1%) + ref_norm = all_grad_norms[0] + if ref_norm > 0: + is_synced = all(abs(g - ref_norm) / ref_norm < 1e-3 for g in all_grad_norms) + else: + is_synced = all(abs(g) < 1e-10 for g in all_grad_norms) + + if self.is_master: + print(f'\n{"="*60}') + print(f'[Step {self.step}] DDP Gradient Sync Verification:') + for i, g in enumerate(all_grad_norms): + print(f' Rank {i} grad_norm: {g:.8f}') + if is_synced: + print(f' \033[92m✓ PASS: All gradients are synchronized!\033[0m') + else: + max_diff = max(abs(g - ref_norm) for g in all_grad_norms) + print(f' \033[91m✗ FAIL: Gradients are NOT synchronized! Max diff: {max_diff:.8f}\033[0m') + print(f'{"="*60}\n') + + @abstractmethod + def training_losses(**mb_data): + """ + Compute training losses. + """ + pass + + def load_data(self): + """ + Load data. + """ + if self.prefetch_data: + if self._data_prefetched is None: + self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + data = self._data_prefetched + self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + else: + data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + + # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu + if isinstance(data, dict): + if self.batch_split == 1: + data_list = [data] + else: + batch_size = list(data.values())[0].shape[0] + data_list = [ + {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()} + for i in range(self.batch_split) + ] + elif isinstance(data, list): + data_list = data + else: + raise ValueError('Data must be a dict or a list of dicts.') + + return data_list + + def run_step(self, data_list): + """ + Run a training step. + """ + step_log = {'loss': {}, 'status': {}} + amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext + elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext + + # Train + losses = [] + statuses = [] + elastic_controller_logs = [] + zero_grad(self.model_params) + for i, mb_data in enumerate(data_list): + ## sync at the end of each batch split + sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext] + with nested_contexts(*sync_contexts), elastic_controller_context(): + with amp_context(): + loss, status = self.training_losses(**mb_data) + l = loss['loss'] / len(data_list) + + # DEBUG: 打印每个 rank 的 loss + if self.debug: + print(f'[Rank {self.rank}/{self.world_size}] Step {self.step} batch {i}: loss={loss["loss"].item():.6f}') + + ## backward + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.scale(l).backward() + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + scaled_l = l * (2 ** self.log_scale) + scaled_l.backward() + else: + l.backward() + ## log + losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) + statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) + if self.elastic_controller_config is not None: + elastic_controller_logs.append(self.elastic_controller.log()) + + # ============================================================ + # DEBUG: 验证 DDP 梯度同步 + # 检查 backward 后各卡梯度是否一致 + # DDP 在最后一个 batch_split 的 backward 时会自动 all_reduce 梯度 + # 同步后所有卡的梯度应该完全相同 + # ============================================================ + if self.debug and self.world_size > 1: + self._verify_gradient_sync() + + ## gradient clip + if self.grad_clip is not None: + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.unscale_(self.optimizer) + elif self.mix_precision_mode == 'inflat_all': + model_grads_to_master_grads(self.model_params, self.master_params) + if self.mix_precision_dtype == torch.float16: + self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) + if isinstance(self.grad_clip, float): + grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip) + else: + grad_norm = self.grad_clip(self.master_params) + if torch.isfinite(grad_norm): + statuses[-1]['grad_norm'] = grad_norm.item() + ## step + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + prev_scale = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + elif self.mix_precision_mode == 'inflat_all': + if self.mix_precision_dtype == torch.float16: + prev_scale = 2 ** self.log_scale + if not any(not p.grad.isfinite().all() for p in self.model_params): + if self.grad_clip is None: + model_grads_to_master_grads(self.model_params, self.master_params) + self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) + self.optimizer.step() + master_params_to_model_params(self.model_params, self.master_params) + self.log_scale += self.fp16_scale_growth + else: + self.log_scale -= 1 + else: + prev_scale = 1.0 + if self.grad_clip is None: + model_grads_to_master_grads(self.model_params, self.master_params) + if not any(not p.grad.isfinite().all() for p in self.master_params): + self.optimizer.step() + master_params_to_model_params(self.model_params, self.master_params) + else: + print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') + else: + prev_scale = 1.0 + if not any(not p.grad.isfinite().all() for p in self.model_params): + self.optimizer.step() + else: + print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') + ## adjust learning rate + if self.lr_scheduler_config is not None: + statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0] + self.lr_scheduler.step() + + # Logs + step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x)) + step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)}) + if self.elastic_controller_config is not None: + step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x)) + if self.grad_clip is not None: + step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log() + + # Check grad and norm of each param + if self.log_param_stats: + param_norms = {} + param_grads = {} + for model_name, model in self.models.items(): + for name, param in model.named_parameters(): + if param.requires_grad: + param_norms[f'{model_name}.{name}'] = param.norm().item() + if param.grad is not None and torch.isfinite(param.grad).all(): + param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / prev_scale + step_log['param_norms'] = param_norms + step_log['param_grads'] = param_grads + + # Update exponential moving average + if self.is_master: + self.update_ema() + + return step_log + + def save_logs(self): + log_str = '\n'.join([ + f'{step}: {json.dumps(dict_foreach(log, lambda x: float(x)))}' for step, log in self.log + ]) + + # Accumulate logs in memory and overwrite file each time (S3 FUSE does not support append) + if not hasattr(self, '_log_buffer'): + self._log_buffer = [] + self._log_buffer.append(log_str) + try: + with open(self._log_file, 'w') as log_file: + log_file.write('\n'.join(self._log_buffer) + '\n') + except Exception as e: + print(f'\033[93m[WARN] Failed to write log file: {e}\033[0m') + + # show with mlflow + log_show = [l for _, l in self.log if not dict_any(l, lambda x: np.isnan(x))] + log_show = dict_reduce(log_show, lambda x: np.mean(x)) + log_show = dict_flatten(log_show, sep='/') + if self.writer is not None: + for key, value in log_show.items(): + self.writer.add_scalar(key, value, self.step) + + # Log to wandb + if self.wandb_run is not None: + wandb_log = {key: value for key, value in log_show.items()} + wandb_log['step'] = self.step + self.wandb_run.log(wandb_log, step=self.step) + + self.log = [] + + def check_abort(self): + """ + Check if training should be aborted due to certain conditions. + """ + # 1. If log_scale in inflat_all mode is less than 0 + if self.mix_precision_dtype == torch.float16 and \ + self.mix_precision_mode == 'inflat_all' and \ + self.log_scale < 0: + if self.is_master: + print ('\n\n\033[91m') + print (f'ABORT: log_scale in inflat_all mode is less than 0 at step {self.step}.') + print ('This indicates that the model is diverging. You should look into the model and the data.') + print ('\033[0m') + self.save(non_blocking=False) + self.save_logs() + if self.world_size > 1: + dist.barrier() + raise ValueError('ABORT: log_scale in inflat_all mode is less than 0.') + + def run(self): + """ + Run training. + """ + if self.is_master: + print('\nStarting training...') + if self.i_sample != -1: + try: + self.snapshot_dataset(num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) + except (RuntimeError, Exception) as e: + print(f'\033[93m[WARN] snapshot_dataset failed, skipping: {e}\033[0m') + torch.cuda.empty_cache() + else: + print('[INFO] i_sample=-1, all snapshots disabled.') + if self.i_sample != -1: + if self.step == 0: + try: + self.snapshot(suffix='init', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) + except (RuntimeError, Exception) as e: + print(f'\033[93m[WARN] snapshot (init) failed, skipping: {e}\033[0m') + torch.cuda.empty_cache() + else: # resume + try: + self.snapshot(suffix=f'resume_step{self.step:07d}', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) + except (RuntimeError, Exception) as e: + print(f'\033[93m[WARN] snapshot (resume) failed, skipping: {e}\033[0m') + torch.cuda.empty_cache() + + time_last_print = 0.0 + time_elapsed = 0.0 + while self.step < self.max_steps: + time_start = time.time() + + data_list = self.load_data() + step_log = self.run_step(data_list) + + time_end = time.time() + time_elapsed += time_end - time_start + + self.step += 1 + + # Print progress + if self.is_master and self.step % self.i_print == 0: + speed = self.i_print / (time_elapsed - time_last_print) * 3600 + columns = [ + f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)', + f'Elapsed: {time_elapsed / 3600:.2f} h', + f'Speed: {speed:.2f} steps/h', + f'ETA: {(self.max_steps - self.step) / speed:.2f} h', + ] + print(' | '.join([c.ljust(25) for c in columns]), flush=True) + time_last_print = time_elapsed + + # Check ddp + if self.parallel_mode == 'ddp' and self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0: + self.check_ddp() + + # Sample images + if self.i_sample != -1 and self.step % self.i_sample == 0: + try: + self.snapshot(num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) + except (RuntimeError, Exception) as e: + if self.is_master: + print(f'\033[93m[WARN] snapshot at step {self.step} failed, skipping: {e}\033[0m') + try: + torch.cuda.empty_cache() + except Exception: + pass + + if self.is_master: + self.log.append((self.step, {})) + + # Log time + self.log[-1][1]['time'] = { + 'step': time_end - time_start, + 'elapsed': time_elapsed, + } + + # Log losses + if step_log is not None: + self.log[-1][1].update(step_log) + + # Log scale + if self.mix_precision_dtype == torch.float16: + if self.mix_precision_mode == 'amp': + self.log[-1][1]['scale'] = self.scaler.get_scale() + elif self.mix_precision_mode == 'inflat_all': + self.log[-1][1]['log_scale'] = self.log_scale + + # Save log + if self.step % self.i_log == 0: + self.save_logs() + + # Save checkpoint + if self.step % self.i_save == 0: + self.save() + + # Check abort + self.check_abort() + + if self.i_sample != -1: + try: + self.snapshot(suffix='final', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) + except (RuntimeError, Exception) as e: + if self.is_master: + print(f'\033[93m[WARN] snapshot (final) failed, skipping: {e}\033[0m') + torch.cuda.empty_cache() + if self.world_size > 1: + dist.barrier() + if self.is_master: + self.writer.close() + print('Training finished.') + + def profile(self, wait=2, warmup=3, active=5): + """ + Profile the training loop. + """ + with torch.profiler.profile( + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')), + profile_memory=True, + with_stack=True, + ) as prof: + for _ in range(wait + warmup + active): + self.run_step() + prof.step() diff --git a/trellis2/trainers/flow_matching/flow_matching.py b/trellis2/trainers/flow_matching/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab1bea2e3ca747ec3fdd521309b07e71b5e12d9 --- /dev/null +++ b/trellis2/trainers/flow_matching/flow_matching.py @@ -0,0 +1,655 @@ +from typing import * +import copy +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...pipelines import samplers +from ...utils.general_utils import dict_reduce +from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin +from .mixins.text_conditioned import TextConditionedMixin +from .mixins.image_conditioned import ImageConditionedMixin +from .mixins.image_conditioned_proj import ImageConditionedProjMixin + + +class FlowMatchingTrainer(BasicTrainer): + """ + Trainer for diffusion model with flow matching objective. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + """ + def __init__( + self, + *args, + t_schedule: dict = { + 'name': 'logitNormal', + 'args': { + 'mean': 0.0, + 'std': 1.0, + } + }, + sigma_min: float = 1e-5, + **kwargs + ): + super().__init__(*args, **kwargs) + self.t_schedule = t_schedule + self.sigma_min = sigma_min + + def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + t: The [N] tensor of diffusion steps [0-1]. + noise: If specified, use this noise instead of generating new noise. + + Returns: + x_t, the noisy version of x_0 under timestep t. + """ + if noise is None: + noise = torch.randn_like(x_0) + assert noise.shape == x_0.shape, "noise must have same shape as x_0" + + t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)]) + x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise + + return x_t + + def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: + """ + Get original image from noisy version under timestep t. + """ + assert noise.shape == x_t.shape, "noise must have same shape as x_t" + t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)]) + x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t) + return x_0 + + def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Compute the velocity of the diffusion process at time t. + """ + return (1 - self.sigma_min) * noise - x_0 + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + return {'cond': cond, **kwargs} + + def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler: + """ + Get the sampler for the diffusion process. + """ + return samplers.FlowEulerSampler(self.sigma_min) + + def vis_cond(self, **kwargs): + """ + Visualize the conditioning data. + """ + return {} + + def sample_t(self, batch_size: int) -> torch.Tensor: + """ + Sample timesteps. + """ + if self.t_schedule['name'] == 'uniform': + t = torch.rand(batch_size) + elif self.t_schedule['name'] == 'logitNormal': + mean = self.t_schedule['args']['mean'] + std = self.t_schedule['args']['std'] + t = torch.sigmoid(torch.randn(batch_size) * std + mean) + else: + raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}") + return t + + def training_losses( + self, + x_0: torch.Tensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments to pass to the backbone. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. + """ + noise = torch.randn_like(x_0) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + cond = self.get_cond(cond, **kwargs) + + pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs) + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred, target) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred[i], target[i]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + # Use current step as seed to ensure different samples for each snapshot + import random + snapshot_seed = self.step + random.seed(snapshot_seed) + np.random.seed(snapshot_seed) + + g = torch.Generator() + g.manual_seed(snapshot_seed) + + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + generator=g, + ) + + # inference + sampler = self.get_sampler() + sample_gt = [] + sample = [] + cond_vis = [] + sample_metadata = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()} + + # Collect metadata (dataset_name and sha256) for wandb display + if '_dataset_name' in data and '_sha256' in data: + for j in range(batch): + sample_metadata.append(f"{data['_dataset_name'][j]}/{data['_sha256'][j]}") + + # Remove metadata fields before inference + data.pop('_dataset_name', None) + data.pop('_sha256', None) + + noise = torch.randn_like(data['x_0']) + sample_gt.append(data['x_0']) + cond_vis.append(self.vis_cond(**data)) + del data['x_0'] + args = self.get_inference_cond(**data) + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=50, guidance_strength=3.0, verbose=verbose, + ) + sample.append(res.samples) + + sample_gt = torch.cat(sample_gt, dim=0) + sample = torch.cat(sample, dim=0) + sample_dict = { + 'sample_gt': {'value': sample_gt, 'type': 'sample'}, + 'sample': {'value': sample, 'type': 'sample'}, + } + if sample_metadata: + sample_dict['_metadata'] = sample_metadata + sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + +class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer): + """ + Trainer for diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + """ + pass + + +class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer): + """ + Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + text_cond_model(str): Text conditioning model. + """ + pass + + +class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer): + """ + Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass + + +class ImageConditionedProjFlowMatchingCFGTrainer(ImageConditionedProjMixin, FlowMatchingCFGTrainer): + """ + Trainer for image-conditioned diffusion model with view-aligned projection. + + Uses ImageConditionedProjMixin for 3D-to-2D feature projection with camera parameters. + CFG dropout is handled by ClassifierFreeGuidanceMixin (via p_uncond parameter). + + Args: + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (dict): Image conditioning model config (DinoV3ProjFeatureExtractor). + run_projection_test (bool): Whether to run projection visualization test before training. + """ + + def __init__(self, *args, run_projection_test: bool = True, **kwargs): + super().__init__(*args, **kwargs) + self.run_projection_test = run_projection_test + + def training_losses( + self, + x_0: torch.Tensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Overridden to avoid passing extra kwargs to the model. + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments (camera info, view_idx, etc.) for conditioning. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. + """ + noise = torch.randn_like(x_0) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + cond = self.get_cond(cond, **kwargs) + + # Note: SparseStructureFlowModel.forward() only accepts (x, t, cond) + # Do not pass extra kwargs to avoid unexpected keyword argument errors + pred = self.training_models['denoiser'](x_t, t * 1000, cond) + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred, target) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred[i], target[i]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + """ + Run snapshot with camera parameters for GT view rendering. + + Overrides parent to include camera parameters in sample dicts for + visualizing the GT camera view alongside the standard 4-view rendering. + """ + # Use current step as seed to ensure different samples for each snapshot + import random + snapshot_seed = self.step + random.seed(snapshot_seed) + np.random.seed(snapshot_seed) + + g = torch.Generator() + g.manual_seed(snapshot_seed) + + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + generator=g, + ) + + # inference + sampler = self.get_sampler() + sample_gt_list = [] + sample_list = [] + cond_vis = [] + sample_metadata = [] + + # Camera params for GT view rendering + camera_distances = [] + camera_angles = [] + mesh_scales = [] + + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()} + + # Collect metadata (dataset_name and sha256) for wandb display + if '_dataset_name' in data and '_sha256' in data: + for j in range(batch): + sample_metadata.append(f"{data['_dataset_name'][j]}/{data['_sha256'][j]}") + + # Remove metadata fields before inference + data.pop('_dataset_name', None) + data.pop('_sha256', None) + + noise = torch.randn_like(data['x_0']) + + # Save GT sample + sample_gt_list.append(data['x_0']) + cond_vis.append(self.vis_cond(**data)) + + # Save camera parameters for GT view rendering (if available) + if 'camera_distance' in data: + camera_distances.append(data['camera_distance']) + if 'camera_angle_x' in data: + camera_angles.append(data['camera_angle_x']) + if 'mesh_scale' in data: + mesh_scales.append(data['mesh_scale']) + + # Remove x_0 before inference + del data['x_0'] + args = self.get_inference_cond(**data) + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=50, guidance_strength=3.0, verbose=verbose, + ) + sample_list.append(res.samples) + + # Concatenate samples + sample_gt = torch.cat(sample_gt_list, dim=0) + sample = torch.cat(sample_list, dim=0) + + # Build sample dicts with camera info for GT view rendering + sample_gt_value = {'x_0': sample_gt} + sample_value = {'x_0': sample} + + # Add camera params if available + if len(camera_distances) > 0: + camera_distance = torch.cat(camera_distances, dim=0) + sample_gt_value['camera_distance'] = camera_distance + sample_value['camera_distance'] = camera_distance + if len(camera_angles) > 0: + camera_angle_x = torch.cat(camera_angles, dim=0) + sample_gt_value['camera_angle_x'] = camera_angle_x + sample_value['camera_angle_x'] = camera_angle_x + if len(mesh_scales) > 0: + mesh_scale = torch.cat(mesh_scales, dim=0) + sample_gt_value['mesh_scale'] = mesh_scale + sample_value['mesh_scale'] = mesh_scale + + sample_dict = { + 'sample_gt': {'value': sample_gt_value, 'type': 'sample'}, + 'sample': {'value': sample_value, 'type': 'sample'}, + } + if sample_metadata: + sample_dict['_metadata'] = sample_metadata + sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + @torch.no_grad() + def visualize_sample(self, sample): + """ + Convert a sample to images, including GT camera view if available. + + Args: + sample: Either a tensor or dict containing: + - 'x_0': latent tensor [B, C, D, H, W] + - 'camera_angle_x': [B] (optional) + - 'camera_distance': [B] (optional) + - 'mesh_scale': [B] (optional) + + Returns: + dict with visualization images or tensor + """ + if hasattr(self.dataset, 'visualize_sample'): + if isinstance(sample, dict): + # Extract camera params if available + camera_angle_x = sample.get('camera_angle_x') + camera_distance = sample.get('camera_distance') + mesh_scale = sample.get('mesh_scale') + x_0 = sample.get('x_0', sample) + + return self.dataset.visualize_sample( + x_0, + camera_angle_x=camera_angle_x, + camera_distance=camera_distance, + mesh_scale=mesh_scale, + ) + else: + return self.dataset.visualize_sample(sample) + else: + if isinstance(sample, dict): + return sample.get('x_0', sample) + return sample + + def run(self): + """ + Run training with projection visualization test before starting. + """ + # Run projection visualization test before training starts (if enabled) + if self.run_projection_test and self.is_master: + print('\n' + '='*60) + print('Running projection visualization test...') + print('='*60) + self._run_projection_visualization_test() + + super().run() + + @torch.no_grad() + def _run_projection_visualization_test(self, num_samples: int = 4): + """ + Run projection visualization test on a few samples before training starts. + + This helps verify that the 3D-to-2D projection is working correctly. + """ + import os + from torch.utils.data import DataLoader + + # Create a small dataloader + dataloader = DataLoader( + self.dataset, + batch_size=min(num_samples, self.snapshot_batch_size), + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # Get one batch + data = next(iter(dataloader)) + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + + # Extract condition image + cond = data.get('cond') + if cond is None: + print("Warning: No 'cond' field in data, skipping projection visualization test") + return + + # Save directory + save_dir = os.path.join(self.output_dir, 'samples', 'projection_test') + + # Call visualization method + if hasattr(self, 'visualize_projection_test'): + # Need to pass camera info as kwargs + kwargs = {k: v for k, v in data.items() if k != 'cond' and k != 'x_0'} + self.visualize_projection_test( + cond=cond, + save_dir=save_dir, + prefix="proj_test", + **kwargs + ) + print(f"Projection visualization saved to: {save_dir}") + else: + print("Warning: visualize_projection_test not available") diff --git a/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py b/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae3a4ed19791e3f5f8e44eae5ceed02fcd4fc83 --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py @@ -0,0 +1,60 @@ +import torch +import numpy as np +from ....utils.general_utils import dict_foreach +from ....pipelines import samplers + + +class ClassifierFreeGuidanceMixin: + def __init__(self, *args, p_uncond: float = 0.1, **kwargs): + super().__init__(*args, **kwargs) + self.p_uncond = p_uncond + + def get_cond(self, cond, neg_cond=None, **kwargs): + """ + Get the conditioning data. + """ + assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance" + + if self.p_uncond > 0: + # randomly drop the class label + def get_batch_size(cond): + if isinstance(cond, torch.Tensor): + return cond.shape[0] + elif isinstance(cond, list): + return len(cond) + else: + raise ValueError(f"Unsupported type of cond: {type(cond)}") + + ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]] + B = get_batch_size(ref_cond) + + def select(cond, neg_cond, mask): + if isinstance(cond, torch.Tensor): + mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1)) + return torch.where(mask, neg_cond, cond) + elif isinstance(cond, list): + return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)] + else: + raise ValueError(f"Unsupported type of cond: {type(cond)}") + + mask = list(np.random.rand(B) < self.p_uncond) + if not isinstance(cond, dict): + cond = select(cond, neg_cond, mask) + else: + # Apply select to each key in the dict + cond = {k: select(cond[k], neg_cond[k], mask) for k in cond.keys()} + + return cond + + def get_inference_cond(self, cond, neg_cond=None, **kwargs): + """ + Get the conditioning data for inference. + """ + assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance" + return {'cond': cond, 'neg_cond': neg_cond, **kwargs} + + def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler: + """ + Get the sampler for the diffusion process. + """ + return samplers.FlowEulerCfgSampler(self.sigma_min) diff --git a/trellis2/trainers/flow_matching/mixins/image_conditioned.py b/trellis2/trainers/flow_matching/mixins/image_conditioned.py new file mode 100644 index 0000000000000000000000000000000000000000..13b932eb714084ad7fc4ac56d3136d5c90db4350 --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/image_conditioned.py @@ -0,0 +1,249 @@ +from typing import * +import torch +import torch.nn.functional as F +from torchvision import transforms +from transformers import DINOv3ViTModel +import numpy as np +from PIL import Image + +from ....utils import dist_utils + + +class DinoV2FeatureExtractor: + """ + Feature extractor for DINOv2 models. + """ + def __init__(self, model_name: str): + self.model_name = model_name + self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True) + self.model.eval() + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.model(image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + +class DinoV3FeatureExtractor: + """ + Feature extractor for DINOv3 models. + """ + def __init__(self, model_name: str, image_size=512): + self.model_name = model_name + self.model = DINOv3ViTModel.from_pretrained(model_name) + self.model.eval() + self.image_size = image_size + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def extract_features(self, image: torch.Tensor) -> torch.Tensor: + image = image.to(self.model.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.model.embeddings(image, bool_masked_pos=None) + position_embeddings = self.model.rope_embeddings(image) + + for i, layer_module in enumerate(self.model.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + return F.layer_norm(hidden_states, hidden_states.shape[-1:]) + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.extract_features(image) + return features + + +class ImageConditionedMixin: + """ + Mixin for image-conditioned models. + + Args: + image_cond_model: The image conditioning model. + """ + def __init__(self, *args, image_cond_model: dict, **kwargs): + super().__init__(*args, **kwargs) + self.image_cond_model_config = image_cond_model + self.image_cond_model = None # the model is init lazily + + def _init_image_cond_model(self): + """ + Initialize the image conditioning model. + """ + with dist_utils.local_master_first(): + self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {})) + self.image_cond_model.cuda() + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Encode the image. + """ + if self.image_cond_model is None: + self._init_image_cond_model() + features = self.image_cond_model(image) + return features + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + cond = super().get_inference_cond(cond, **kwargs) + return cond + + def vis_cond(self, cond, **kwargs): + """ + Visualize the conditioning data. + """ + return {'image': {'value': cond, 'type': 'image'}} + + +class MultiImageConditionedMixin: + """ + Mixin for multiple-image-conditioned models. + + Args: + image_cond_model: The image conditioning model. + """ + def __init__(self, *args, image_cond_model: dict, **kwargs): + super().__init__(*args, **kwargs) + self.image_cond_model_config = image_cond_model + self.image_cond_model = None # the model is init lazily + + def _init_image_cond_model(self): + """ + Initialize the image conditioning model. + """ + with dist_utils.local_master_first(): + self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {})) + + @torch.no_grad() + def encode_images(self, images: Union[List[torch.Tensor], List[List[Image.Image]]]) -> List[torch.Tensor]: + """ + Encode the image. + """ + if self.image_cond_model is None: + self._init_image_cond_model() + seqlen = [len(i) for i in images] + images = torch.cat(images, dim=0) if isinstance(images[0], torch.Tensor) else sum(images, []) + features = self.image_cond_model(images) + features = torch.split(features, seqlen) + features = [feature.reshape(-1, feature.shape[-1]) for feature in features] + return features + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + cond = self.encode_images(cond) + kwargs['neg_cond'] = [ + torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond)) + ] + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_images(cond) + kwargs['neg_cond'] = [ + torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond)) + ] + cond = super().get_inference_cond(cond, **kwargs) + return cond + + def vis_cond(self, cond, **kwargs): + """ + Visualize the conditioning data. + """ + H, W = cond[0].shape[-2:] + vis = [] + for images in cond: + canvas = torch.zeros(3, H * 2, W * 2, device=images.device, dtype=images.dtype) + for i, image in enumerate(images): + if i == 4: + break + kh = i // 2 + kw = i % 2 + canvas[:, kh*H:(kh+1)*H, kw*W:(kw+1)*W] = image + vis.append(canvas) + vis = torch.stack(vis) + return {'image': {'value': vis, 'type': 'image'}} diff --git a/trellis2/trainers/flow_matching/mixins/image_conditioned_proj.py b/trellis2/trainers/flow_matching/mixins/image_conditioned_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9f326ab066db0c6c94cd31267d9edf10d4c64b --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/image_conditioned_proj.py @@ -0,0 +1,1530 @@ +""" +View-Aligned (Projection) Image Conditioned Mixin for TRELLIS2 + +This module implements DINOv3-based feature extraction with view-aligned projection, +supporting camera-aware 3D-to-2D feature mapping. +""" + +from typing import * +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from transformers import DINOv3ViTModel +import numpy as np +from PIL import Image, ImageDraw + +import torch.distributed as dist +from ....utils import dist_utils +from ....utils.dist_utils import read_file_dist + + +# ============================================================================= +# Projection Utilities +# ============================================================================= + +def project_points_to_image_batch( + points_3d: torch.Tensor, + transform_matrix: torch.Tensor, + camera_angle_x: torch.Tensor, + resolution: int = 518 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D image coordinates (batch processing). + + Args: + points_3d: torch.Tensor, shape [N, 3] or [B, N, 3], 3D point coordinates (in [-1, 1] range) + transform_matrix: torch.Tensor, shape [B, 4, 4], camera transformation matrix + camera_angle_x: torch.Tensor, shape [B], horizontal field of view angle (radians) + resolution: int, image resolution, default 518 + + Returns: + points_2d: torch.Tensor, shape [B, N, 2], image coordinates [x, y] + depth: torch.Tensor, shape [B, N], depth values + valid_mask: torch.Tensor, shape [B, N], mask for points within view + """ + device = points_3d.device + B = transform_matrix.shape[0] + + # Ensure inputs are torch.Tensor on correct device + if not isinstance(transform_matrix, torch.Tensor): + transform_matrix = torch.tensor(transform_matrix, dtype=torch.float32, device=device) + if not isinstance(points_3d, torch.Tensor): + points_3d = torch.tensor(points_3d, dtype=torch.float32, device=device) + if not isinstance(camera_angle_x, torch.Tensor): + camera_angle_x = torch.tensor(camera_angle_x, dtype=torch.float32, device=device) + + # Expand points_3d to batch dimension: [N, 3] -> [B, N, 3] + if points_3d.dim() == 2: + points_3d_batch = points_3d.unsqueeze(0).expand(B, -1, -1) + else: + points_3d_batch = points_3d + N = points_3d_batch.shape[1] + + # Add homogeneous coordinates: [B, N, 3] -> [B, N, 4] + ones = torch.ones(B, N, 1, device=device, dtype=points_3d_batch.dtype) + points_homogeneous = torch.cat([points_3d_batch, ones], dim=-1) # [B, N, 4] + + # Compute world to camera transformation matrix + world_to_camera = torch.linalg.inv(transform_matrix) # [B, 4, 4] + + # Batch transform to camera coordinate system: [B, N, 4] @ [B, 4, 4]^T -> [B, N, 3] + points_camera = torch.bmm(points_homogeneous, world_to_camera.transpose(-2, -1))[..., :3] # [B, N, 3] + + # Extract camera coordinates + x_cam = points_camera[..., 0] # [B, N] + y_cam = points_camera[..., 1] # [B, N] + z_cam = points_camera[..., 2] # [B, N] + + # Depth value (Z value in camera coordinate system, note Blender camera faces -Z direction) + depth = -z_cam # [B, N] + + # Compute camera intrinsics (batch processing) + sensor_width = 32.0 # mm + focal_length = 16.0 / torch.tan(camera_angle_x / 2.0) # [B] + focal_length_pixels = focal_length * resolution / sensor_width # [B] + + # Expand focal_length_pixels dimension for broadcasting: [B] -> [B, 1] + focal_length_pixels = focal_length_pixels.unsqueeze(1) # [B, 1] + + # Perspective projection to NDC coordinates + x_ndc = focal_length_pixels * x_cam / (-z_cam + 1e-8) # [B, N] + y_ndc = focal_length_pixels * y_cam / (-z_cam + 1e-8) # [B, N] + + # Convert to image coordinates (pixel coordinates) + x_pixel = x_ndc + resolution / 2.0 # [B, N] + y_pixel = -y_ndc + resolution / 2.0 # [B, N], flip Y axis + + # Create validity mask (points within image range and in front of camera) + valid_mask = ( + (x_pixel >= 0) & (x_pixel < resolution) & + (y_pixel >= 0) & (y_pixel < resolution) & + (depth > 0) # In front of camera + ) # [B, N] + + points_2d = torch.stack([x_pixel, y_pixel], dim=-1) # [B, N, 2] + + return points_2d, depth, valid_mask + + +def sample_features(fmap: torch.Tensor, queries_ndc: torch.Tensor) -> torch.Tensor: + """ + Sample features from feature map at specified NDC coordinates. + + Args: + fmap: torch.Tensor, shape [B, C, H, W], feature map + queries_ndc: torch.Tensor, shape [B, K, 2], normalized device coordinates + + Returns: + torch.Tensor, shape [B, C, K], sampled features + """ + B, C, H, W = fmap.shape + Bq, K, _ = queries_ndc.shape + assert Bq == B, "Batch size mismatch" + + # grid_sample requires (B, out_h, out_w, 2), here we want K points -> out_h=K, out_w=1 + grid = queries_ndc.view(B, K, 1, 2) # (B, K, 1, 2) + + # Bilinear interpolation, align_corners=False (consistent with [-1,1] pixel center convention) + feat = F.grid_sample( + fmap, grid, mode='bilinear', + align_corners=False, padding_mode='border' # border avoids out-of-bound becoming 0 + ) # (B, C, K, 1) + + return feat.squeeze(-1) # (B, C, K) + + +# ============================================================================= +# Projection Grid Module +# ============================================================================= + +class ProjGrid(nn.Module): + """ + 3D Grid Projection Module. + + Projects a 3D grid of points to 2D image coordinates and samples features + from the image feature map at those locations. + + This is the core module for view-aligned feature extraction. + """ + def __init__(self, grid_resolution: int = 16, image_resolution: int = 518): + super().__init__() + self.grid_resolution = grid_resolution + self.image_resolution = image_resolution + + # Create 3D grid points + one_dim = torch.linspace(-1, 1, grid_resolution) + x, y, z = torch.meshgrid(one_dim, one_dim, one_dim, indexing='ij') + grid_points = torch.stack((x, y, z), dim=-1) + + # Rotation matrix to align with Blender coordinate system + rotation_matrix = torch.tensor([ + [1.0, 0.0, 0.0], + [0.0, 0.0, -1.0], + [0.0, 1.0, 0.0] + ]) + grid_points = torch.matmul(grid_points, rotation_matrix.T) + grid_points = grid_points.reshape(-1, 3) + self.register_buffer('grid_points', grid_points) # [R³, 3] + + # Default front view transformation matrix + front_view_transform_matrix = torch.tensor([ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, -1.0, -2.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ]) + self.register_buffer("front_view_transform_matrix", front_view_transform_matrix) + + def forward( + self, + features_map: torch.Tensor, + camera_angle_x: torch.Tensor, + distance: torch.Tensor, + mesh_scale: torch.Tensor, + transform_matrix: Optional[torch.Tensor] = None, + BHWC: bool = True + ) -> torch.Tensor: + """ + Project 3D grid points to image and sample features. + + Args: + features_map: Feature map, shape [B, H, W, C] if BHWC else [B, C, H, W] + camera_angle_x: Camera FOV angle, shape [B] + distance: Camera distance, shape [B] + mesh_scale: Mesh scale factor, shape [B] + transform_matrix: Optional camera transform matrix, shape [B, 4, 4] + BHWC: Whether features_map is in BHWC format + + Returns: + Projected features, shape [B, grid_resolution³, C] + """ + if BHWC: + B, H, W, C = features_map.shape + else: + B, C, H, W = features_map.shape + + grid_points = self.grid_points + grid_points = grid_points.expand(B, -1, -1) + grid_points = grid_points / mesh_scale.unsqueeze(-1).unsqueeze(-1) / 2 # Scale alignment + assert transform_matrix is None, "transform_matrix is not None" + if transform_matrix is None: + transform_matrix = self.front_view_transform_matrix + transform_matrix = transform_matrix.expand(B, -1, -1).clone() + transform_matrix[:, 1, 3] = -distance # Set camera distance + + # Project to image coordinates (simulate Blender projection) + image_points, depth, valid_mask = project_points_to_image_batch( + grid_points, transform_matrix, camera_angle_x, self.image_resolution + ) + + # Normalize to [-1, 1] for grid_sample + image_points_norm = (image_points + 0.5) / self.image_resolution * 2 - 1 + + if BHWC: + features_map = features_map.permute(0, 3, 1, 2) # [B, C, H, W] + + # Sample features from DINOv3 patch feature map + x = sample_features(features_map, image_points_norm) # [B, C, K] + x = x.permute(0, 2, 1) # [B, K, C] + + return x + + def visualize_projection( + self, + image: torch.Tensor, + camera_angle_x: torch.Tensor, + distance: torch.Tensor, + mesh_scale: torch.Tensor, + transform_matrix: Optional[torch.Tensor] = None, + save_dir: Optional[str] = None, + prefix: str = "proj_vis", + ) -> List[Image.Image]: + """ + Visualize the projected 3D grid points on the input image. + + Args: + image: Input image tensor [B, C, H, W], assumed to be in [0, 1] range + camera_angle_x: Camera FOV angle, shape [B] + distance: Camera distance, shape [B] + mesh_scale: Mesh scale factor, shape [B] + transform_matrix: Optional camera transform matrix, shape [B, 4, 4] + save_dir: Directory to save visualizations (optional) + prefix: Prefix for saved files + + Returns: + List of PIL Images with projected points overlaid + """ + B = image.shape[0] + + # Get projected points + grid_points = self.grid_points.expand(B, -1, -1) + grid_points = grid_points / mesh_scale.unsqueeze(-1).unsqueeze(-1) / 2 + assert transform_matrix is None, "transform_matrix is not None" + if transform_matrix is None: + transform_matrix = self.front_view_transform_matrix + transform_matrix = transform_matrix.expand(B, -1, -1).clone() + transform_matrix[:, 1, 3] = -distance + + image_points, depth, valid_mask = project_points_to_image_batch( + grid_points, transform_matrix, camera_angle_x, self.image_resolution + ) + + # Convert image to PIL for visualization + vis_images = [] + for b in range(B): + # Convert tensor to PIL image + img_np = image[b].cpu().permute(1, 2, 0).numpy() + img_np = (img_np * 255).clip(0, 255).astype(np.uint8) + + # Resize to image_resolution if needed + pil_img = Image.fromarray(img_np) + if pil_img.size != (self.image_resolution, self.image_resolution): + pil_img = pil_img.resize((self.image_resolution, self.image_resolution), Image.LANCZOS) + + # Create a copy for drawing + vis_img = pil_img.copy() + draw = ImageDraw.Draw(vis_img) + + # Get points for this batch + pts = image_points[b].cpu().numpy() # [K, 2] + depths = depth[b].cpu().numpy() # [K] + mask = valid_mask[b].cpu().numpy() # [K] + + # Normalize depth for coloring + valid_depths = depths[mask] + if len(valid_depths) > 0: + d_min, d_max = valid_depths.min(), valid_depths.max() + if d_max - d_min > 1e-6: + depths_norm = (depths - d_min) / (d_max - d_min) + else: + depths_norm = np.ones_like(depths) * 0.5 + else: + depths_norm = np.ones_like(depths) * 0.5 + + # Draw projected points + R = self.grid_resolution + for i, (pt, d, m, dn) in enumerate(zip(pts, depths, mask, depths_norm)): + if not m: + continue + + x, y = pt + + # Color by depth (blue=near, red=far) + r = int(255 * dn) + g = int(255 * (1 - abs(2 * dn - 1))) + b_color = int(255 * (1 - dn)) + color = (r, g, b_color) + + # Draw small circle + radius = 2 + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=color, + outline=color + ) + + vis_images.append(vis_img) + + # Save if directory is specified + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, f"{prefix}_batch{b}.png") + vis_img.save(save_path) + print(f"Saved projection visualization to: {save_path}") + + return vis_images + + +# ============================================================================= +# DINOv3 Feature Extractor with Projection +# ============================================================================= + +class DinoV3ProjFeatureExtractor(nn.Module): + """ + DINOv3 Feature Extractor with View-Aligned Projection. + + This extractor produces both: + 1. Global features (CLS token + register tokens) in embed_dim + 2. View-aligned projected features (3D grid projected to 2D and sampled) + - Without NAF: [B, R³, embed_dim] + - With NAF: [B, R³, embed_dim * 2] (concat of lr and hr features) + + NOTE: proj_linear has been moved to per-block ProjectAttention / SparseProjectAttention. + This module now outputs raw DINOv3 features for proj (optionally concatenated with NAF-upsampled features). + + Args: + model_name: Name of the pretrained DINOv3 model + image_size: Input image size (default: 512) + grid_resolution: Resolution of the 3D projection grid (default: 16) + use_naf_upsample: Whether to use NAF to upsample features (default: False) + naf_target_size: Target spatial size for NAF upsampling (default: [128, 128]) + """ + def __init__( + self, + model_name: str, + image_size: int = 512, + grid_resolution: int = 16, + use_naf_upsample: bool = False, + naf_target_size: Optional[List[int]] = None, + ): + super().__init__() + self.model_name = model_name + self.image_size = image_size + self.grid_resolution = grid_resolution + self.use_naf_upsample = use_naf_upsample + if naf_target_size is None: + self.naf_target_size = (128, 128) + elif isinstance(naf_target_size, int): + self.naf_target_size = (naf_target_size, naf_target_size) + else: + self.naf_target_size = tuple(naf_target_size) + + # Load DINOv3 model (frozen, no trainable params in this module) + self.model = DINOv3ViTModel.from_pretrained(model_name) + self.model.eval() + self.model.requires_grad_(False) + + # Image transform (only normalize, no resize - assume already resized) + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + # Get patch info + self.patch_size = self.model.config.patch_size + self.patch_number = image_size // self.patch_size + self.embed_dim = self.model.config.hidden_size + + # Projection grid for view-aligned features + self.proj_grid = ProjGrid( + grid_resolution=grid_resolution, + image_resolution=image_size + ) + + # NAF upsampler (frozen, no trainable params) + self.naf_model = None # Lazy-loaded on first use to avoid import if not needed + + # proj_channels: the output dimension of proj features + # Without NAF: embed_dim (e.g. 1024) + # With NAF: embed_dim * 2 (e.g. 2048, concat of lr and hr) + self.proj_channels = self.embed_dim * 2 if use_naf_upsample else self.embed_dim + + # NOTE: proj_linear removed — now lives in each denoiser block's ProjectAttention + + def _load_naf(self): + """Lazy-load pretrained NAF model.""" + if self.naf_model is None: + import torch.hub + device = next(self.model.parameters()).device + self.naf_model = torch.hub.load( + "valeoai/NAF", "naf", pretrained=True, device=device, trust_repo=True + ) + self.naf_model.eval() + self.naf_model.requires_grad_(False) + + def to(self, device): + super().to(device) + self.model.to(device) + self.proj_grid.to(device) + if self.naf_model is not None: + self.naf_model.to(device) + return self + + def cuda(self): + super().cuda() + self.model.cuda() + self.proj_grid.cuda() + if self.naf_model is not None: + self.naf_model.cuda() + return self + + def cpu(self): + super().cpu() + self.model.cpu() + self.proj_grid.cpu() + if self.naf_model is not None: + self.naf_model.cpu() + return self + + def extract_features(self, image: torch.Tensor) -> torch.Tensor: + """Extract features using DINOv3.""" + image = image.to(self.model.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.model.embeddings(image, bool_masked_pos=None) + position_embeddings = self.model.rope_embeddings(image) + + for layer_module in self.model.layer: + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + return F.layer_norm(hidden_states, hidden_states.shape[-1:]) + + def forward( + self, + image: Union[torch.Tensor, List[Image.Image]], + camera_angle_x: Optional[torch.Tensor] = None, + distance: Optional[torch.Tensor] = None, + mesh_scale: Optional[torch.Tensor] = None, + transform_matrix: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Extract view-aligned features from the image. + + Args: + image: Input image tensor [B, C, H, W] or list of PIL images + camera_angle_x: Camera FOV angle in radians [B] + distance: Camera distance [B] + mesh_scale: Mesh scale factor [B] + transform_matrix: Optional camera transform matrix [B, 4, 4] + + Returns: + Tuple of (global_features, proj_features): + - global_features: [B, num_global_tokens, embed_dim] + - proj_features: [B, grid_resolution³, proj_channels] + where proj_channels = embed_dim (no NAF) or embed_dim*2 (with NAF) + """ + # Handle input types + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + B = image.shape[0] + + # Keep a copy of the unnormalized image for NAF guide + if self.use_naf_upsample: + image_for_naf = image.clone() # [B, 3, H, W], in [0, 1] range + + # Apply transform (ImageNet normalization) + image = self.transform(image) + + # Extract DINOv3 features (frozen, no gradients) + with torch.no_grad(): + z = self.extract_features(image) + + # Split into CLS token, register tokens, and patch tokens + z_clstoken = z[:, 0:1] # [B, 1, D] + num_reg = getattr(self.model.config, 'num_register_tokens', 4) + z_regtokens = z[:, 1:1+num_reg] # [B, num_reg, D] + z_patchtokens = z[:, 1+num_reg:] # [B, num_patches, D] + + # Reshape patch tokens to spatial grid: [B, h, w, D] + z_patchtokens_spatial = z_patchtokens.reshape( + B, self.patch_number, self.patch_number, -1 + ) # [B, h, w, D] + + if camera_angle_x is None or distance is None or mesh_scale is None: + raise ValueError("camera_angle_x, distance, and mesh_scale must be provided") + + # --- Low-resolution branch: sample from DINOv3 patch feature map --- + z_proj_lr = self.proj_grid( + z_patchtokens_spatial, + camera_angle_x, + distance, + mesh_scale, + transform_matrix + ) # [B, grid_res³, D] + + # --- High-resolution branch (NAF): upsample then sample --- + if self.use_naf_upsample: + self._load_naf() + # NAF expects: guide [B, 3, H, W], lr_features [B, C, h, w], target_size (H', W') + lr_features_bchw = z_patchtokens_spatial.permute(0, 3, 1, 2) # [B, D, h, w] + hr_features = self.naf_model( + image_for_naf, lr_features_bchw, self.naf_target_size + ) # [B, D, H', W'] + + # Sample from high-res feature map using same projection coordinates + z_proj_hr = self.proj_grid( + hr_features, + camera_angle_x, + distance, + mesh_scale, + transform_matrix, + BHWC=False # hr_features is [B, C, H', W'] + ) # [B, grid_res³, D] + + # Concatenate lr and hr: [B, grid_res³, D*2] + z_proj = torch.cat([z_proj_lr, z_proj_hr], dim=-1) + else: + z_proj = z_proj_lr # [B, grid_res³, D] + + # Combine global tokens + z_global = torch.cat([z_clstoken, z_regtokens], dim=1) # [B, 1+num_reg, D] + + # proj_linear has been moved to per-block ProjectAttention + # z_proj stays in proj_channels, each block will project independently + + return z_global, z_proj + + @torch.no_grad() + def visualize_projection( + self, + image: torch.Tensor, + camera_angle_x: torch.Tensor, + distance: torch.Tensor, + mesh_scale: torch.Tensor, + transform_matrix: Optional[torch.Tensor] = None, + save_dir: Optional[str] = None, + prefix: str = "proj_vis", + ) -> List[Image.Image]: + """ + Visualize the projected 3D grid points on the input image. + + This is a convenience method that delegates to ProjGrid.visualize_projection. + + Args: + image: Input image tensor [B, C, H, W], in [0, 1] range (before ImageNet normalization) + camera_angle_x: Camera FOV angle, shape [B] + distance: Camera distance, shape [B] + mesh_scale: Mesh scale factor, shape [B] + transform_matrix: Optional camera transform matrix, shape [B, 4, 4] + save_dir: Directory to save visualizations (optional) + prefix: Prefix for saved files + + Returns: + List of PIL Images with projected points overlaid + """ + return self.proj_grid.visualize_projection( + image=image, + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + transform_matrix=transform_matrix, + save_dir=save_dir, + prefix=prefix, + ) + + +# ============================================================================= +# DINOv3 + VAE Gated Feature Extractor with Projection +# ============================================================================= + +class DinoV3VaeProjFeatureExtractor(nn.Module): + """ + DINOv3 + Flux VAE Feature Extractor with Gated Fusion and View-Aligned Projection. + + Produces three outputs for GatedProjectAttention: + 1. Global features (CLS + register tokens from DINOv3) for cross-attention + 2. Semantic proj features (DINOv3 patch tokens projected to 3D grid) + 3. Color proj features (Flux VAE latent projected to 3D grid) + + Both DINOv3 and VAE are frozen. The gated fusion happens inside each + denoiser block's GatedProjectAttention module (trainable gate + proj_linears). + + Args: + dino_model_name: Pretrained DINOv3 model name + vae_model_name: Pretrained Flux VAE model name + image_size: Input image size (default: 512) + grid_resolution: Resolution of the 3D projection grid (default: 16) + """ + def __init__( + self, + dino_model_name: str, + vae_model_name: str = "black-forest-labs/FLUX.1-dev", + image_size: int = 512, + grid_resolution: int = 16, + ): + super().__init__() + self.image_size = image_size + self.grid_resolution = grid_resolution + + # --- DINOv3 backbone (frozen) --- + self.dino_model = DINOv3ViTModel.from_pretrained(dino_model_name) + self.dino_model.eval() + self.dino_model.requires_grad_(False) + + self.dino_transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + self.patch_size = self.dino_model.config.patch_size + self.patch_number = image_size // self.patch_size + self.embed_dim = self.dino_model.config.hidden_size # e.g. 1024 + + # --- Flux VAE encoder (frozen, lazy-loaded) --- + self.vae_model_name = vae_model_name + self._vae = None + self.vae_channels = 16 # Flux VAE outputs 16 channels + self.vae_downsample = 8 # Flux VAE downsamples by 8x + + # --- Projection grid (shared) --- + self.proj_grid = ProjGrid( + grid_resolution=grid_resolution, + image_resolution=image_size, + ) + + # Expose dimensions for denoiser block construction + self.dino_proj_channels = self.embed_dim # e.g. 1024 + self.vae_proj_channels = self.vae_channels # 16 + # proj_channels is kept for backward compat with _proj_channels in mixin + self.proj_channels = self.embed_dim + + def _load_vae(self): + """Lazy-load Flux VAE encoder.""" + if self._vae is not None: + return + from diffusers import AutoencoderKL + device = next(self.dino_model.parameters()).device + vae = AutoencoderKL.from_pretrained( + self.vae_model_name, + subfolder="vae", + torch_dtype=torch.float32, + ) + vae.eval() + vae.requires_grad_(False) + vae.to(device) + self._vae = vae + + def to(self, device): + super().to(device) + self.dino_model.to(device) + self.proj_grid.to(device) + if self._vae is not None: + self._vae.to(device) + return self + + def cuda(self): + super().cuda() + self.dino_model.cuda() + self.proj_grid.cuda() + if self._vae is not None: + self._vae.cuda() + return self + + def cpu(self): + super().cpu() + self.dino_model.cpu() + self.proj_grid.cpu() + if self._vae is not None: + self._vae.cpu() + return self + + def _extract_dino_features(self, image: torch.Tensor) -> torch.Tensor: + """Extract DINOv3 features from normalized image.""" + image = image.to(self.dino_model.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.dino_model.embeddings(image, bool_masked_pos=None) + position_embeddings = self.dino_model.rope_embeddings(image) + for layer_module in self.dino_model.layer: + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + return F.layer_norm(hidden_states, hidden_states.shape[-1:]) + + @torch.no_grad() + def _extract_vae_latent(self, image: torch.Tensor) -> torch.Tensor: + """Extract Flux VAE latent from unnormalized image [0,1].""" + self._load_vae() + image_normalized = image * 2.0 - 1.0 + image_normalized = image_normalized.to(self._vae.dtype) + posterior = self._vae.encode(image_normalized) + latent = posterior.latent_dist.mode() + latent = latent * self._vae.config.scaling_factor + return latent.float() # [B, 16, H/8, W/8] + + def forward( + self, + image: Union[torch.Tensor, List[Image.Image]], + camera_angle_x: Optional[torch.Tensor] = None, + distance: Optional[torch.Tensor] = None, + mesh_scale: Optional[torch.Tensor] = None, + transform_matrix: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Extract gated features from the image. + + Returns: + Tuple of (global_features, proj_semantic, proj_color): + - global_features: [B, num_global_tokens, embed_dim] (DINOv3 CLS + registers) + - proj_semantic: [B, grid_res³, embed_dim] (DINOv3 projected features) + - proj_color: [B, grid_res³, vae_channels] (VAE projected features) + """ + # Handle input types + if isinstance(image, torch.Tensor): + assert image.ndim == 4 + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image) + image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + B = image.shape[0] + image_raw = image.clone() # Keep unnormalized copy for VAE + + if camera_angle_x is None or distance is None or mesh_scale is None: + raise ValueError("camera_angle_x, distance, and mesh_scale must be provided") + + with torch.no_grad(): + # --- DINOv3 branch --- + dino_input = self.dino_transform(image) + z = self._extract_dino_features(dino_input) + + z_clstoken = z[:, 0:1] + num_reg = getattr(self.dino_model.config, 'num_register_tokens', 4) + z_regtokens = z[:, 1:1+num_reg] + z_patchtokens = z[:, 1+num_reg:] + + z_patchtokens_spatial = z_patchtokens.reshape( + B, self.patch_number, self.patch_number, -1 + ) # [B, h, w, D] + + proj_semantic = self.proj_grid( + z_patchtokens_spatial, + camera_angle_x, distance, mesh_scale, transform_matrix, + ) # [B, R³, embed_dim] + + z_global = torch.cat([z_clstoken, z_regtokens], dim=1) # [B, 1+num_reg, D] + + # --- VAE branch --- + vae_latent = self._extract_vae_latent(image_raw) # [B, 16, H/8, W/8] + + proj_color = self.proj_grid( + vae_latent, + camera_angle_x, distance, mesh_scale, transform_matrix, + BHWC=False, # VAE latent is [B, C, H, W] + ) # [B, R³, 16] + + return z_global, proj_semantic, proj_color + + +# ============================================================================= +# Image Conditioned Mixin with Projection Support +# ============================================================================= + +class ImageConditionedProjMixin: + """ + Mixin for image-conditioned models with view-aligned projection. + + This mixin adds support for extracting view-aligned features from images + using camera parameters. + + Args: + image_cond_model: Configuration for the image conditioning model. + """ + def __init__(self, *args, image_cond_model: dict, **kwargs): + # Store config before super().__init__ which calls init_models_and_more + self.image_cond_model_config = image_cond_model + self.image_cond_model = None # Will be initialized in init_models_and_more + self.image_attn_mode = image_cond_model.get('image_attn_mode', + image_cond_model.get('args', {}).get('image_attn_mode', 'cross')) + super().__init__(*args, **kwargs) + + def _init_image_cond_model(self): + """Initialize the image conditioning model.""" + with dist_utils.local_master_first(): + model_name = self.image_cond_model_config['name'] + model_args = self.image_cond_model_config.get('args', {}) + + if model_name == 'DinoV3ProjFeatureExtractor': + self.image_cond_model = DinoV3ProjFeatureExtractor(**model_args) + elif model_name == 'DinoV3VaeProjFeatureExtractor': + self.image_cond_model = DinoV3VaeProjFeatureExtractor(**model_args) + else: + # Fallback to standard extractors + from . import image_conditioned + self.image_cond_model = getattr(image_conditioned, model_name)(**model_args) + + self.image_cond_model.cuda() + + # Expose proj_channels for denoiser to know the correct proj_in_channels + if hasattr(self.image_cond_model, 'proj_channels'): + self._proj_channels = self.image_cond_model.proj_channels + else: + self._proj_channels = getattr(self.image_cond_model, 'embed_dim', None) + # Expose vae_proj_channels for gated_proj mode + self._vae_proj_channels = getattr(self.image_cond_model, 'vae_proj_channels', None) + + def init_models_and_more(self, **kwargs): + """ + Override to handle image_cond_model initialization. + + Since proj_linear has been moved to per-block ProjectAttention in the denoiser, + image_cond_model no longer has any trainable parameters (DINOv3 backbone is frozen, + ProjGrid only has register_buffers). Therefore we do NOT add it to self.models + (which would trigger DDP wrapping and fail). We just initialize it and keep it + as a standalone module for inference. + """ + # Initialize image_cond_model first + if self.image_cond_model is None: + self._init_image_cond_model() + + # Keep a reference to the unwrapped module for attribute access + self._image_cond_module = self.image_cond_model # for .grid_resolution etc. + + # Log that image_cond has no trainable params + proj_params = [p for p in self.image_cond_model.parameters() if p.requires_grad] + if self.is_master: + if proj_params: + print(f'\nWARNING: image_cond_model has {len(proj_params)} trainable params, ' + f'but is NOT registered in self.models. These will NOT be trained!') + else: + print(f'\nimage_cond_model has no trainable parameters, skipping DDP/optimizer registration.') + + # Call base class to set up DDP, optimizer, EMA, etc. (without image_cond) + super().init_models_and_more(**kwargs) + + # ------------------------------------------------------------------ + # Checkpoint save/load overrides: skip DINOv3 backbone weights + # ------------------------------------------------------------------ + + # Keys in image_cond state_dict that belong to the frozen DINOv3 backbone. + # Everything under "model." is DINOv3; we only keep proj_grid.* + _IMAGE_COND_BACKBONE_PREFIX = 'model.' + + def _filter_image_cond_state_dict(self, state_dict: dict) -> dict: + """Keep only non-backbone keys (proj_grid, etc.) from image_cond state_dict.""" + return {k: v for k, v in state_dict.items() + if not k.startswith(self._IMAGE_COND_BACKBONE_PREFIX)} + + def _fill_denoiser_proj_linear_from_image_cond( + self, + denoiser_ckpt: dict, + denoiser_state_dict: dict, + image_cond_ckpt_path: Optional[str] = None, + ) -> dict: + """ + Fill missing per-block proj_linear weights in denoiser checkpoint + from the old-style image_cond proj_linear (broadcast to all blocks). + + Also handles shape mismatch when NAF is enabled: old proj_linear has shape + [model_ch, embed_dim] but new model expects [model_ch, embed_dim*2]. + In this case, the old weights are placed in the lr half and the hr half is zero-padded. + + Compatibility strategy: + 1. If denoiser_ckpt already contains per-block proj_linear keys with correct shape -> do nothing. + 2. If shape mismatch (embed_dim vs embed_dim*2) -> zero-pad the weight. + 3. If keys missing, try to load proj_linear from image_cond checkpoint -> broadcast (with optional pad). + + Args: + denoiser_ckpt: The loaded denoiser state dict + denoiser_state_dict: The model's current state dict (to find expected keys) + image_cond_ckpt_path: Path to image_cond checkpoint file (optional) + + Returns: + Updated denoiser_ckpt with proj_linear keys filled if needed + """ + if self.image_attn_mode != 'proj': + return denoiser_ckpt + + # Find all per-block proj_linear keys expected by the model + proj_linear_keys = [k for k in denoiser_state_dict.keys() + if '.cross_attn.proj_linear.' in k] + if not proj_linear_keys: + return denoiser_ckpt + + # --- Phase 1: Handle shape mismatch for existing keys (NAF upgrade) --- + for k in proj_linear_keys: + if k in denoiser_ckpt: + expected_shape = denoiser_state_dict[k].shape + actual_shape = denoiser_ckpt[k].shape + if expected_shape != actual_shape: + if k.endswith('.weight') and len(expected_shape) == 2: + # Weight shape: [out_features, in_features] + # Old: [model_ch, embed_dim], New: [model_ch, embed_dim*2] + out_f, new_in_f = expected_shape + _, old_in_f = actual_shape + if new_in_f > old_in_f and out_f == actual_shape[0]: + if self.is_master: + print(f'\n [NAF Compat] Padding proj_linear weight {k}: ' + f'{actual_shape} -> {expected_shape} (zero-pad hr half)') + new_w = torch.zeros(expected_shape, dtype=denoiser_ckpt[k].dtype, + device=denoiser_ckpt[k].device) + new_w[:, :old_in_f] = denoiser_ckpt[k] + denoiser_ckpt[k] = new_w + else: + if self.is_master: + print(f'\n Warning: proj_linear {k} shape mismatch ' + f'{actual_shape} vs {expected_shape}, using model init') + denoiser_ckpt[k] = denoiser_state_dict[k] + # bias shape should match (out_features only), no padding needed + + # --- Phase 2: Handle completely missing keys --- + missing_proj_keys = [k for k in proj_linear_keys if k not in denoiser_ckpt] + if not missing_proj_keys: + return denoiser_ckpt + + if self.is_master: + print(f'\n [Compat] Denoiser ckpt missing {len(missing_proj_keys)} per-block proj_linear keys.') + print(f' Attempting to load from image_cond proj_linear: {image_cond_ckpt_path}') + + # Try to find proj_linear weights from image_cond checkpoint + old_proj_linear_w = None + old_proj_linear_b = None + + if image_cond_ckpt_path is not None: + import os as _os + if _os.path.exists(image_cond_ckpt_path): + try: + ic_ckpt = torch.load(image_cond_ckpt_path, map_location=self.device, weights_only=True) + old_proj_linear_w = ic_ckpt.get('proj_linear.weight') + old_proj_linear_b = ic_ckpt.get('proj_linear.bias') + except Exception as e: + if self.is_master: + print(f' Warning: Failed to load image_cond ckpt: {e}') + + if old_proj_linear_w is None: + raise RuntimeError( + f'Denoiser checkpoint is missing per-block proj_linear keys ' + f'(e.g. {missing_proj_keys[0]}), and no image_cond proj_linear ' + f'was found to broadcast from. Cannot proceed.' + ) + + if self.is_master: + print(f' Found image_cond proj_linear: weight {old_proj_linear_w.shape}, bias {old_proj_linear_b.shape}') + print(f' Broadcasting to {len(missing_proj_keys)} per-block keys...') + + for k in missing_proj_keys: + if k.endswith('.weight'): + expected_shape = denoiser_state_dict[k].shape + if expected_shape != old_proj_linear_w.shape: + # Pad for NAF: [model_ch, embed_dim] -> [model_ch, embed_dim*2] + out_f, new_in_f = expected_shape + _, old_in_f = old_proj_linear_w.shape + if new_in_f > old_in_f and out_f == old_proj_linear_w.shape[0]: + new_w = torch.zeros(expected_shape, dtype=old_proj_linear_w.dtype) + new_w[:, :old_in_f] = old_proj_linear_w + denoiser_ckpt[k] = new_w + else: + denoiser_ckpt[k] = denoiser_state_dict[k] + else: + denoiser_ckpt[k] = old_proj_linear_w.clone() + elif k.endswith('.bias'): + denoiser_ckpt[k] = old_proj_linear_b.clone() + + return denoiser_ckpt + + def _master_params_to_state_dicts(self, master_params): + """Override to skip image_cond checkpoint entirely. + + image_cond model no longer has trainable parameters: + - proj_linear has been moved to per-block ProjectAttention in the denoiser + - DINOv3 backbone is frozen and loaded from pretrained weights + - ProjGrid only contains fixed register_buffers (grid_points, front_view_transform_matrix) + So there is nothing worth saving for image_cond. + """ + state_dicts = super()._master_params_to_state_dicts(master_params) + state_dicts.pop('image_cond', None) + return state_dicts + + def load(self, load_dir, step=0): + """ + Override to handle: + 1. Old checkpoints that don't have image_cond_step*.pt + 2. Partial image_cond checkpoints (only proj_linear + proj_grid, no DINOv3 backbone) + """ + import os as _os + + if self.is_master: + print(f'\nLoading checkpoint from step {step}...', end='') + + model_ckpts = {} + for name, model in self.models.items(): + ckpt_path = _os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt') + + if name == 'image_cond': + # --- handle missing or partial image_cond checkpoint --- + if not _os.path.exists(ckpt_path): + if self.is_master: + print(f'\n image_cond checkpoint not found at {ckpt_path}, using freshly initialised weights.') + model_ckpts[name] = model.state_dict() + continue + + try: + model_ckpt = torch.load( + read_file_dist(ckpt_path), + map_location=self.device, weights_only=True) + except Exception as e: + if self.is_master: + print(f'\n Failed to load image_cond checkpoint: {e}. Using freshly initialised weights.') + model_ckpts[name] = model.state_dict() + continue + + # Partial ckpt (no backbone) → load with strict=False + missing, unexpected = model.load_state_dict(model_ckpt, strict=False) + # All missing keys should be the frozen DINOv3 backbone; verify + non_backbone_missing = [k for k in missing + if not k.startswith(self._IMAGE_COND_BACKBONE_PREFIX)] + if non_backbone_missing and self.is_master: + print(f'\n Warning: unexpected missing keys in image_cond ckpt: {non_backbone_missing}') + if unexpected and self.is_master: + print(f'\n Warning: unexpected keys in image_cond ckpt: {unexpected}') + + # Build a full state_dict for master_params sync + full_sd = model.state_dict() + full_sd.update(model_ckpt) + model_ckpts[name] = full_sd + else: + model_ckpt = torch.load( + read_file_dist(ckpt_path), + map_location=self.device, weights_only=True) + # For denoiser: handle old ckpts missing per-block proj_linear + if name == 'denoiser': + ic_ckpt_path = _os.path.join(load_dir, 'ckpts', f'image_cond_step{step:07d}.pt') + model_ckpt = self._fill_denoiser_proj_linear_from_image_cond( + model_ckpt, model.state_dict(), ic_ckpt_path) + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + + self._state_dicts_to_master_params(self.master_params, model_ckpts) + del model_ckpts + + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = {} + for name, model in self.models.items(): + ema_path = _os.path.join( + load_dir, 'ckpts', + f'{name}_ema{ema_rate}_step{step:07d}.pt') + if name == 'image_cond': + if not _os.path.exists(ema_path): + ema_ckpts[name] = model.state_dict() + continue + try: + ema_ckpt = torch.load(ema_path, map_location=self.device, weights_only=True) + except Exception: + ema_ckpts[name] = model.state_dict() + continue + full_sd = model.state_dict() + full_sd.update(ema_ckpt) + ema_ckpts[name] = full_sd + else: + ema_ckpt = torch.load(ema_path, map_location=self.device, weights_only=True) + if name == 'denoiser': + ic_ema_path = _os.path.join( + load_dir, 'ckpts', + f'image_cond_ema{ema_rate}_step{step:07d}.pt') + ema_ckpt = self._fill_denoiser_proj_linear_from_image_cond( + ema_ckpt, model.state_dict(), ic_ema_path) + ema_ckpts[name] = ema_ckpt + self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts) + del ema_ckpts + + misc_ckpt = torch.load( + read_file_dist(_os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), + map_location=torch.device('cpu'), weights_only=False) + # Optimizer state may mismatch when loading old checkpoints that were + # saved before image_cond was added to self.models, or when the number + # of trainable parameters changed (e.g. backbone freeze, NAF upgrade). + # In that case we skip restoring optimizer state and let it re-initialise. + try: + self.optimizer.load_state_dict(misc_ckpt['optimizer']) + # Verify optimizer state shapes match parameters. + # load_state_dict may succeed even when shapes mismatch (keys are + # integer indices), causing a crash later in optimizer.step(). + _shape_ok = True + for group in self.optimizer.param_groups: + for p in group['params']: + state = self.optimizer.state.get(p) + if state is not None: + for sv in state.values(): + if isinstance(sv, torch.Tensor) and sv.shape != () and sv.shape != p.shape: + _shape_ok = False + break + if not _shape_ok: + break + if not _shape_ok: + break + if not _shape_ok: + if self.is_master: + print(f'\n Warning: optimizer state shape mismatch (likely NAF upgrade). ' + f'Optimizer will start fresh.') + self.optimizer.state.clear() + except (ValueError, RuntimeError) as e: + if self.is_master: + print(f'\n Warning: could not load optimizer state ({e}). ' + f'Optimizer will start fresh.') + self.step = misc_ckpt['step'] + self.data_sampler.load_state_dict(misc_ckpt['data_sampler']) + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.load_state_dict(misc_ckpt['scaler']) + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + self.log_scale = misc_ckpt['log_scale'] + if self.lr_scheduler_config is not None: + self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler']) + if self.elastic_controller_config is not None: + self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller']) + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + self.grad_clip.load_state_dict(misc_ckpt['grad_clip']) + del misc_ckpt + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print(' Done.') + + if self.world_size > 1: + self.check_ddp() + + def finetune_from(self, finetune_ckpt): + """ + Override to tolerate DINOv3 backbone keys missing from image_cond checkpoint. + For image_cond, the checkpoint only stores proj_linear + proj_grid (no backbone), + so we treat all backbone keys as allowed-missing. + """ + ALLOWED_MISSING_KEYS = {'rope_phases'} + + if self.is_master: + print('\nFinetuning from:') + for name, path in finetune_ckpt.items(): + print(f' - {name}: {path}') + + model_ckpts = {} + for name, model in self.models.items(): + model_state_dict = model.state_dict() + if name in finetune_ckpt: + model_ckpt = torch.load( + read_file_dist(finetune_ckpt[name]), + map_location=self.device, weights_only=True) + + model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict) + + for k, v in model_ckpt.items(): + if k not in model_state_dict: + if self.is_master: + print(f'Warning: {k} not found in model_state_dict, skipped.') + model_ckpt[k] = None + elif model_ckpt[k].shape != model_state_dict[k].shape: + # For proj_linear weights, try zero-pad instead of skipping + # This handles NAF upgrade: [model_ch, embed_dim] -> [model_ch, embed_dim*2] + if '.cross_attn.proj_linear.weight' in k and len(model_ckpt[k].shape) == 2: + old_shape = model_ckpt[k].shape + new_shape = model_state_dict[k].shape + if new_shape[0] == old_shape[0] and new_shape[1] > old_shape[1]: + if self.is_master: + print(f'Info: Zero-padding proj_linear weight {k}: {old_shape} -> {new_shape}') + new_w = torch.zeros(new_shape, dtype=model_ckpt[k].dtype) + new_w[:, :old_shape[1]] = model_ckpt[k] + model_ckpt[k] = new_w + else: + if self.is_master: + print(f'Warning: {k} shape mismatch, {old_shape} vs {new_shape}, skipped.') + model_ckpt[k] = model_state_dict[k] + else: + if self.is_master: + print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.') + model_ckpt[k] = model_state_dict[k] + model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None} + + missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys()) + + # For denoiser: fill per-block proj_linear from image_cond if missing + if name == 'denoiser': + ic_path = finetune_ckpt.get('image_cond') + proj_linear_missing = {k for k in missing_keys if '.cross_attn.proj_linear.' in k} + if proj_linear_missing: + model_ckpt = self._fill_denoiser_proj_linear_from_image_cond( + model_ckpt, model_state_dict, ic_path) + # Recalculate missing_keys after filling + missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys()) + + # For image_cond, DINOv3 backbone keys are expected to be missing + allowed = set(ALLOWED_MISSING_KEYS) + if name == 'image_cond': + backbone_missing = {k for k in missing_keys + if k.startswith(self._IMAGE_COND_BACKBONE_PREFIX)} + allowed |= backbone_missing + if backbone_missing and self.is_master: + print(f'Info: image_cond: {len(backbone_missing)} DINOv3 backbone keys ' + f'not in ckpt (expected, using pretrained weights)') + # Old ckpts may have proj_linear.* which has moved to denoiser blocks + proj_linear_missing = {k for k in missing_keys if k.startswith('proj_linear.')} + allowed |= proj_linear_missing + + unexpected_missing = missing_keys - allowed + if unexpected_missing and self.is_master: + print(f'Error: Missing keys in checkpoint: {unexpected_missing}') + raise RuntimeError(f'Missing keys in checkpoint: {unexpected_missing}') + if missing_keys & ALLOWED_MISSING_KEYS and self.is_master: + print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}') + + for k in missing_keys: + model_ckpt[k] = model_state_dict[k] + + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + else: + if self.is_master: + print(f'Warning: {name} not found in finetune_ckpt, skipped.') + model_ckpts[name] = model_state_dict + + self._state_dicts_to_master_params(self.master_params, model_ckpts) + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + self._state_dicts_to_master_params(self.ema_params[i], model_ckpts) + del model_ckpts + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print('Done.') + + if self.world_size > 1: + self.check_ddp() + + def encode_image_proj( + self, + image: torch.Tensor, + camera_angle_x: Optional[torch.Tensor] = None, + distance: Optional[torch.Tensor] = None, + mesh_scale: Optional[torch.Tensor] = None, + transform_matrix: Optional[torch.Tensor] = None, + coords: Optional[torch.Tensor] = None, + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """ + Encode the image with view-aligned projection. + + Supports both 'proj' mode (DINOv3 only, 2 outputs) and + 'gated_proj' mode (DINOv3 + VAE, 3 outputs). + """ + if self.image_cond_model is None: + self._init_image_cond_model() + + outputs = self.image_cond_model( + image, + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + transform_matrix=transform_matrix, + ) + + is_gated = self.image_attn_mode == 'gated_proj' + + if is_gated: + cond_global, cond_proj_semantic, cond_proj_color = outputs + else: + cond_global, cond_proj = outputs + + # If coords provided, extract features at sparse positions + if coords is not None: + B = cond_global.shape[0] + module = getattr(self, '_image_cond_module', self.image_cond_model) + grid_res = module.grid_resolution + batch_indices = coords[:, 0].long() + x_coords = coords[:, 1].long() + y_coords = coords[:, 2].long() + z_coords = coords[:, 3].long() + + if is_gated: + cond_proj_semantic = cond_proj_semantic.reshape(B, grid_res, grid_res, grid_res, -1) + cond_proj_semantic = cond_proj_semantic[batch_indices, x_coords, y_coords, z_coords] + cond_proj_color = cond_proj_color.reshape(B, grid_res, grid_res, grid_res, -1) + cond_proj_color = cond_proj_color[batch_indices, x_coords, y_coords, z_coords] + else: + cond_proj = cond_proj.reshape(B, grid_res, grid_res, grid_res, -1) + cond_proj = cond_proj[batch_indices, x_coords, y_coords, z_coords] + + if is_gated: + cond = { + 'global': cond_global, + 'proj_semantic': cond_proj_semantic, + 'proj_color': cond_proj_color, + } + uncond = { + 'global': torch.zeros_like(cond_global), + 'proj_semantic': torch.zeros_like(cond_proj_semantic), + 'proj_color': torch.zeros_like(cond_proj_color), + } + else: + cond = {'global': cond_global, 'proj': cond_proj} + uncond = {'global': torch.zeros_like(cond_global), 'proj': torch.zeros_like(cond_proj)} + + return cond, uncond + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Encode the image (standard mode without projection). + """ + if self.image_cond_model is None: + self._init_image_cond_model() + + if self.image_attn_mode == 'proj': + # For proj mode, return dict + global_feat, proj_feat = self.image_cond_model(image) + return {'global': global_feat, 'proj': proj_feat} + else: + # Standard mode + features = self.image_cond_model(image) + return features + + def _extract_camera_info(self, kwargs): + """ + Extract camera info from kwargs. + + Supports two formats: + 1. 'camera_info' dict: {'camera_angle_x': ..., 'distance': ..., 'mesh_scale': ..., 'transform_matrix': ..., 'coords': ...} + 2. Flat fields: 'camera_angle_x', 'camera_distance', 'mesh_scale', 'transform_matrix', 'coords' in kwargs + + Returns: + camera_info dict or None if not available + """ + if 'camera_info' in kwargs: + return kwargs.pop('camera_info') + + # Try to extract from flat fields (as returned by ViewImageConditionedMixin) + camera_angle_x = kwargs.pop('camera_angle_x', None) + camera_distance = kwargs.pop('camera_distance', None) + mesh_scale = kwargs.pop('mesh_scale', None) + transform_matrix = kwargs.pop('transform_matrix', None) + coords = kwargs.pop('coords', None) + + if camera_angle_x is not None and camera_distance is not None and mesh_scale is not None: + return { + 'camera_angle_x': camera_angle_x, + 'distance': camera_distance, + 'mesh_scale': mesh_scale, + 'transform_matrix': transform_matrix, + 'coords': coords, + } + + return None + + def get_cond(self, cond, **kwargs): + """Get the conditioning data.""" + kwargs.pop('view_idx', None) + + if self.image_attn_mode in ('proj', 'gated_proj'): + # Handle projection mode (both standard proj and gated_proj) + camera_info = self._extract_camera_info(kwargs) + if camera_info is not None: + coords = camera_info.get('coords') + cond, neg_cond = self.encode_image_proj( + cond, + camera_angle_x=camera_info.get('camera_angle_x'), + distance=camera_info.get('distance'), + mesh_scale=camera_info.get('mesh_scale'), + transform_matrix=camera_info.get('transform_matrix'), + coords=coords, + ) + + # For sparse mode (coords provided), handle CFG dropout ourselves + if coords is not None and hasattr(self, 'p_uncond') and self.p_uncond > 0: + import numpy as np + B = cond['global'].shape[0] + mask = np.random.rand(B) < self.p_uncond + + global_tensor = cond['global'] + global_mask_shape = [B] + [1] * (global_tensor.ndim - 1) + global_mask = torch.tensor(mask, device=global_tensor.device).reshape(global_mask_shape) + cond['global'] = torch.where(global_mask, neg_cond['global'], cond['global']) + + batch_indices = coords[:, 0].long() + # Handle all sparse proj keys (proj, or proj_semantic + proj_color) + for key in list(cond.keys()): + if key.startswith('proj'): + device = cond[key].device + sparse_mask = torch.tensor(mask, device=device)[batch_indices].reshape(-1, 1) + cond[key] = torch.where(sparse_mask, neg_cond[key], cond[key]) + + return cond + else: + kwargs['neg_cond'] = neg_cond + else: + cond = self.encode_image(cond) + if isinstance(cond, dict) and 'global' in cond: + kwargs['neg_cond'] = {k: torch.zeros_like(v) for k, v in cond.items()} + else: + kwargs['neg_cond'] = torch.zeros_like(cond) + else: + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """Get the conditioning data for inference.""" + kwargs.pop('view_idx', None) + + if self.image_attn_mode in ('proj', 'gated_proj'): + camera_info = self._extract_camera_info(kwargs) + if camera_info is not None: + cond, neg_cond = self.encode_image_proj( + cond, + camera_angle_x=camera_info.get('camera_angle_x'), + distance=camera_info.get('distance'), + mesh_scale=camera_info.get('mesh_scale'), + transform_matrix=camera_info.get('transform_matrix'), + coords=camera_info.get('coords'), + ) + kwargs['neg_cond'] = neg_cond + else: + cond = self.encode_image(cond) + if isinstance(cond, dict) and 'global' in cond: + kwargs['neg_cond'] = {k: torch.zeros_like(v) for k, v in cond.items()} + else: + kwargs['neg_cond'] = torch.zeros_like(cond) + else: + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + + cond = super().get_inference_cond(cond, **kwargs) + return cond + + def vis_cond(self, cond, **kwargs): + """Visualize the conditioning data.""" + return {'image': {'value': cond, 'type': 'image'}} + + @torch.no_grad() + def visualize_projection_test( + self, + cond: torch.Tensor, + save_dir: str, + prefix: str = "proj_vis", + **kwargs + ) -> Optional[List[Image.Image]]: + """ + Visualize projection points on the condition images. + + This should be called once before training starts to verify the projection is correct. + + Args: + cond: Condition image tensor [B, C, H, W], in [0, 1] range + save_dir: Directory to save visualizations + prefix: Prefix for saved files + **kwargs: Should contain camera_angle_x, camera_distance, mesh_scale, transform_matrix + + Returns: + List of PIL Images with projected points overlaid, or None if not in proj mode + """ + if self.image_attn_mode != 'proj': + return None + + if self.image_cond_model is None: + self._init_image_cond_model() + + # Use _image_cond_module for attribute access (image_cond_model may be DDP-wrapped) + module = getattr(self, '_image_cond_module', self.image_cond_model) + + # Check if the model has visualization capability + if not hasattr(module, 'visualize_projection'): + print("Warning: image_cond_model does not support visualize_projection") + return None + + # Extract camera info + camera_info = self._extract_camera_info(kwargs) + if camera_info is None: + print("Warning: No camera info available for projection visualization") + return None + + return module.visualize_projection( + image=cond, + camera_angle_x=camera_info.get('camera_angle_x'), + distance=camera_info.get('distance'), + mesh_scale=camera_info.get('mesh_scale'), + transform_matrix=camera_info.get('transform_matrix'), + save_dir=save_dir, + prefix=prefix, + ) diff --git a/trellis2/trainers/flow_matching/mixins/text_conditioned.py b/trellis2/trainers/flow_matching/mixins/text_conditioned.py new file mode 100644 index 0000000000000000000000000000000000000000..85f1dcf2582c07edd9b629c07181265f24a90134 --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/text_conditioned.py @@ -0,0 +1,68 @@ +from typing import * +import os +os.environ['TOKENIZERS_PARALLELISM'] = 'true' +import torch +from transformers import AutoTokenizer, CLIPTextModel + +from ....utils import dist_utils + + +class TextConditionedMixin: + """ + Mixin for text-conditioned models. + + Args: + text_cond_model: The text conditioning model. + """ + def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs): + super().__init__(*args, **kwargs) + self.text_cond_model_name = text_cond_model + self.text_cond_model = None # the model is init lazily + + def _init_text_cond_model(self): + """ + Initialize the text conditioning model. + """ + # load model + with dist_utils.local_master_first(): + model = CLIPTextModel.from_pretrained(self.text_cond_model_name) + tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name) + model.eval() + model = model.cuda() + self.text_cond_model = { + 'model': model, + 'tokenizer': tokenizer, + } + self.text_cond_model['null_cond'] = self.encode_text(['']) + + @torch.no_grad() + def encode_text(self, text: List[str]) -> torch.Tensor: + """ + Encode the text. + """ + assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond" + if self.text_cond_model is None: + self._init_text_cond_model() + encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt') + tokens = encoding['input_ids'].cuda() + embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state + + return embeddings + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + cond = self.encode_text(cond) + kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_text(cond) + kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) + cond = super().get_inference_cond(cond, **kwargs) + return cond diff --git a/trellis2/trainers/flow_matching/sparse_flow_matching.py b/trellis2/trainers/flow_matching/sparse_flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..65f211fb3983d0b407dd995b2be5a4e944d3e5d0 --- /dev/null +++ b/trellis2/trainers/flow_matching/sparse_flow_matching.py @@ -0,0 +1,615 @@ +from typing import * +import os +import copy +import functools +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict + +from ...modules import sparse as sp +from ...utils.general_utils import dict_reduce +from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler +from .flow_matching import FlowMatchingTrainer +from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin +from .mixins.text_conditioned import TextConditionedMixin +from .mixins.image_conditioned import ImageConditionedMixin, MultiImageConditionedMixin +from .mixins.image_conditioned_proj import ImageConditionedProjMixin + + +class SparseFlowMatchingTrainer(FlowMatchingTrainer): + """ + Trainer for sparse diffusion model with flow matching objective. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + """ + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = BalancedResumableSampler( + self.dataset, + shuffle=True, + batch_size=self.batch_size_per_gpu, + ) + if self.num_workers is None or self.num_workers == -1: + num_workers = max(1, int(np.ceil((os.cpu_count() - 16) / torch.cuda.device_count()))) + else: + num_workers = self.num_workers + + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def training_losses( + self, + x_0: sp.SparseTensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Args: + x_0: The [N x ... x C] sparse tensor of the inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments to pass to the backbone. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. + """ + noise = x_0.replace(torch.randn_like(x_0.feats)) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + cond = self.get_cond(cond, **kwargs) + + pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs) + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred.feats, target.feats) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + # Use current step as seed to ensure different samples for each snapshot + import random + snapshot_seed = self.step + random.seed(snapshot_seed) + np.random.seed(snapshot_seed) + + g = torch.Generator() + g.manual_seed(snapshot_seed) + + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=num_samples, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + generator=g, + ) + data = next(iter(dataloader)) + + # Collect metadata (dataset_name and sha256) for wandb display + sample_metadata = [] + if '_dataset_name' in data and '_sha256' in data: + for j in range(min(num_samples, len(data['_dataset_name']))): + sample_metadata.append(f"{data['_dataset_name'][j]}/{data['_sha256'][j]}") + # Remove metadata fields before inference + data.pop('_dataset_name', None) + data.pop('_sha256', None) + + # inference + sampler = self.get_sampler() + sample = [] + cond_vis = [] + for i in range(0, num_samples, batch_size): + batch_data = {k: v[i:i+batch_size] for k, v in data.items()} + batch_data = recursive_to_device(batch_data, 'cuda') + noise = batch_data['x_0'].replace(torch.randn_like(batch_data['x_0'].feats)) + cond_vis.append(self.vis_cond(**batch_data)) + del batch_data['x_0'] + args = self.get_inference_cond(**batch_data) + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=12, guidance_strength=3.0, verbose=verbose, + ) + sample.append(res.samples) + sample = sp.sparse_cat(sample) + + sample_gt = {k: v for k, v in data.items()} + sample = {k: v if k != 'x_0' else sample for k, v in data.items()} + sample_dict = { + 'sample_gt': {'value': sample_gt, 'type': 'sample'}, + 'sample': {'value': sample, 'type': 'sample'}, + } + if sample_metadata: + sample_dict['_metadata'] = sample_metadata + sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + +class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer): + """ + Trainer for sparse diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + """ + pass + + +class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + text_cond_model(str): Text conditioning model. + """ + pass + + +class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass + + +class MultiImageConditionedSparseFlowMatchingCFGTrainer(MultiImageConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass + + +class ImageConditionedProjSparseFlowMatchingCFGTrainer(ImageConditionedProjMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse image-conditioned diffusion model with view-aligned projection. + + Uses ImageConditionedProjMixin for 3D-to-2D feature projection with camera parameters. + CFG dropout is handled by ClassifierFreeGuidanceMixin (via p_uncond parameter). + + The projection grid outputs a full [B, R, R, R, D] tensor, and this trainer extracts + features at sparse coordinates using advanced indexing. + + Args: + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (dict): Image conditioning model config (DinoV3ProjFeatureExtractor). + run_projection_test (bool): Whether to run projection visualization test before training. + """ + + def __init__(self, *args, run_projection_test: bool = True, **kwargs): + super().__init__(*args, **kwargs) + self.run_projection_test = run_projection_test + + def training_losses( + self, + x_0: sp.SparseTensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Overridden to pass coords from x_0 to get_cond for sparse feature extraction. + + Args: + x_0: The [N x ... x C] sparse tensor of the inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments to pass to the backbone. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. + """ + noise = x_0.replace(torch.randn_like(x_0.feats)) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + + # Pass coords to get_cond for sparse feature extraction from full grid + kwargs['coords'] = x_0.coords + cond = self.get_cond(cond, **kwargs) + + # Pass concat_cond to denoiser if present (needed for PBR/texture training + # where shape latent is concatenated with PBR latent as input) + denoiser_kwargs = {} + if 'concat_cond' in kwargs: + denoiser_kwargs['concat_cond'] = kwargs['concat_cond'] + pred = self.training_models['denoiser'](x_t, t * 1000, cond, **denoiser_kwargs) + + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred.feats, target.feats) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + """ + Run snapshot with coords passed to get_inference_cond for sparse feature extraction. + + For projection mode, we need to pass coords to properly extract features at + sparse positions from the full projection grid. + """ + # Use current step as seed to ensure different samples for each snapshot + import random + snapshot_seed = self.step + random.seed(snapshot_seed) + np.random.seed(snapshot_seed) + + g = torch.Generator() + g.manual_seed(snapshot_seed) + + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=num_samples, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + generator=g, + ) + data = next(iter(dataloader)) + + # Collect metadata (dataset_name and sha256) for wandb display + sample_metadata = [] + if '_dataset_name' in data and '_sha256' in data: + for j in range(min(num_samples, len(data['_dataset_name']))): + sample_metadata.append(f"{data['_dataset_name'][j]}/{data['_sha256'][j]}") + # Remove metadata fields before inference + data.pop('_dataset_name', None) + data.pop('_sha256', None) + + # inference + sampler = self.get_sampler() + sample = [] + cond_vis = [] + for i in range(0, num_samples, batch_size): + batch_data = {k: v[i:i+batch_size] for k, v in data.items()} + batch_data = recursive_to_device(batch_data, 'cuda') + noise = batch_data['x_0'].replace(torch.randn_like(batch_data['x_0'].feats)) + cond_vis.append(self.vis_cond(**batch_data)) + + # Save coords before deleting x_0 (needed for projection feature extraction) + coords = batch_data['x_0'].coords + del batch_data['x_0'] + + # Pass coords to get_inference_cond for sparse feature extraction + batch_data['coords'] = coords + args = self.get_inference_cond(**batch_data) + + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=12, guidance_strength=3.0, verbose=verbose, + ) + sample.append(res.samples) + sample = sp.sparse_cat(sample) + + sample_gt = {k: v for k, v in data.items()} + sample = {k: v if k != 'x_0' else sample for k, v in data.items()} + sample_dict = { + 'sample_gt': {'value': sample_gt, 'type': 'sample'}, + 'sample': {'value': sample, 'type': 'sample'}, + } + if sample_metadata: + sample_dict['_metadata'] = sample_metadata + sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + @torch.no_grad() + def visualize_sample(self, sample): + """ + Convert a sample to images, including GT camera view if available. + + Args: + sample: Either a SparseTensor or dict containing: + - 'x_0': SparseTensor + - 'camera_angle_x': [B] (optional) + - 'camera_distance': [B] (optional) + - 'mesh_scale': [B] (optional) + + Returns: + dict with visualization images or tensor + """ + if hasattr(self.dataset, 'visualize_sample'): + if isinstance(sample, dict): + # Extract camera params and pass them explicitly, since some + # dataset.visualize_sample() (e.g. SLatShapeVisMixin) expect + # separate keyword arguments rather than a single dict. + camera_kwargs = {} + for k in ('camera_angle_x', 'camera_distance', 'mesh_scale'): + if k in sample: + camera_kwargs[k] = sample[k] + + # Try passing camera kwargs explicitly first; fall back to + # passing the entire dict if the dataset method doesn't accept them + # (e.g. SLatPbrVisMixin expects a dict with 'x_0' + 'concat_cond'). + import inspect + sig = inspect.signature(self.dataset.visualize_sample) + params = list(sig.parameters.keys()) + if 'camera_angle_x' in params: + # Shape-style: visualize_sample(x_0, camera_angle_x=, ...) + x_0 = sample.get('x_0', sample) + return self.dataset.visualize_sample(x_0, **camera_kwargs) + else: + # Tex/PBR-style: visualize_sample(sample_dict) + return self.dataset.visualize_sample(sample) + else: + return self.dataset.visualize_sample(sample) + else: + if isinstance(sample, dict): + return sample.get('x_0', sample) + return sample + + def run(self): + """ + Run training with projection visualization test before starting. + """ + # Run projection visualization test before training starts (if enabled) + if self.run_projection_test and self.is_master: + print('\n' + '='*60) + print('Running projection visualization test...') + print('='*60) + self._run_projection_visualization_test() + + super().run() + + @torch.no_grad() + def _run_projection_visualization_test(self, num_samples: int = 4): + """ + Run projection visualization test on a few samples before training starts. + + This helps verify that the 3D-to-2D projection is working correctly. + """ + import os + from torch.utils.data import DataLoader + + # Create a small dataloader + dataloader = DataLoader( + self.dataset, + batch_size=min(num_samples, self.snapshot_batch_size), + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # Get one batch + data = next(iter(dataloader)) + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + + # Extract condition image + cond = data.get('cond') + if cond is None: + print("Warning: No 'cond' field in data, skipping projection visualization test") + return + + # Save directory + save_dir = os.path.join(self.output_dir, 'samples', 'projection_test') + + # Call visualization method + if hasattr(self, 'visualize_projection_test'): + # Need to pass camera info as kwargs + kwargs = {k: v for k, v in data.items() if k != 'cond' and k != 'x_0'} + self.visualize_projection_test( + cond=cond, + save_dir=save_dir, + prefix="proj_test", + **kwargs + ) + print(f"Projection visualization saved to: {save_dir}") + else: + print("Warning: visualize_projection_test not available") diff --git a/trellis2/trainers/utils.py b/trellis2/trainers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..23e4286d02b6fc102c364d2ef7571ef97c9fcd41 --- /dev/null +++ b/trellis2/trainers/utils.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +# FP16 utils +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +def str_to_dtype(dtype_str: str): + return { + 'f16': torch.float16, + 'fp16': torch.float16, + 'float16': torch.float16, + 'bf16': torch.bfloat16, + 'bfloat16': torch.bfloat16, + 'f32': torch.float32, + 'fp32': torch.float32, + 'float32': torch.float32, + }[dtype_str] + + +def make_master_params(model_params): + """ + Copy model parameters into a inflated tensor of full-precision parameters. + """ + master_params = _flatten_dense_tensors( + [param.detach().float() for param in model_params] + ) + master_params = nn.Parameter(master_params) + master_params.requires_grad = True + return [master_params] + + +def unflatten_master_params(model_params, master_params): + """ + Unflatten the master parameters to look like model_params. + """ + return _unflatten_dense_tensors(master_params[0].detach(), model_params) + + +def model_params_to_master_params(model_params, master_params): + """ + Copy the model parameter data into the master parameters. + """ + master_params[0].detach().copy_( + _flatten_dense_tensors([param.detach().float() for param in model_params]) + ) + + +def master_params_to_model_params(model_params, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + for param, master_param in zip( + model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params) + ): + param.detach().copy_(master_param) + + +def model_grads_to_master_grads(model_params, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + master_params[0].grad = _flatten_dense_tensors( + [param.grad.data.detach().float() for param in model_params] + ) + + +def zero_grad(model_params): + for param in model_params: + if param.grad is not None: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() + + +# LR Schedulers +from torch.optim.lr_scheduler import LambdaLR + +class LinearWarmupLRScheduler(LambdaLR): + def __init__(self, optimizer, warmup_steps, last_epoch=-1): + self.warmup_steps = warmup_steps + super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + + def lr_lambda(self, current_step): + if current_step < self.warmup_steps: + return float(current_step + 1) / self.warmup_steps + return 1.0 + \ No newline at end of file diff --git a/trellis2/trainers/vae/pbr_vae.py b/trellis2/trainers/vae/pbr_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..28d3646233ce7752796d679c5cd4e1737eece474 --- /dev/null +++ b/trellis2/trainers/vae/pbr_vae.py @@ -0,0 +1,291 @@ +from typing import * +import os +import copy +import functools +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import utils3d +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...modules import sparse as sp +from ...renderers import MeshRenderer +from ...representations import Mesh, MeshWithPbrMaterial, MeshWithVoxel +from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler +from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips + + +class PbrVaeTrainer(BasicTrainer): + """ + Trainer for PBR attributes VAE + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. + lambda_kl (float): KL loss weight. + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. + """ + + def __init__( + self, + *args, + loss_type: str = 'l1', + lambda_kl: float = 1e-6, + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + lambda_render: float = 1.0, + render_resolution: float = 1024, + camera_randomization_config: dict = { + 'radius_range': [2, 100], + }, + **kwargs + ): + super().__init__(*args, **kwargs) + self.loss_type = loss_type + self.lambda_kl = lambda_kl + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + self.lambda_render = lambda_render + self.camera_randomization_config = camera_randomization_config + + self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = BalancedResumableSampler( + self.dataset, + shuffle=True, + batch_size=self.batch_size_per_gpu, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def _randomize_camera(self, num_samples: int): + # sample radius and fov + r_min, r_max = self.camera_randomization_config['radius_range'] + k_min = 1 / r_max**2 + k_max = 1 / r_min**2 + ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min + radius = 1 / torch.sqrt(ks) + fov = 2 * torch.arcsin(0.5 / radius) + origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1) + + # build camera + extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device)) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + near = [np.random.uniform(r - 1, r) for r in radius.tolist()] + + return { + 'extrinsics': extrinsics, + 'intrinsics': intrinsics, + 'near': near, + } + + def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List, + ) -> Dict[str, torch.Tensor]: + """ + Render a batch of representations. + + Args: + reps: The dictionary of lists of representations. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + + Returns: + a dict with + base_color : [N x 3 x H x W] tensor of base color. + metallic : [N x 1 x H x W] tensor of metallic. + roughness : [N x 1 x H x W] tensor of roughness. + alpha : [N x 1 x H x W] tensor of alpha. + """ + ret = {k : [] for k in ['base_color', 'metallic', 'roughness', 'alpha']} + for i, rep in enumerate(reps): + self.renderer.rendering_options['near'] = near[i] + self.renderer.rendering_options['far'] = near[i] + 2 + out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=['attr']) + for k in out_dict: + ret[k].append(out_dict[k]) + for k in ret: + ret[k] = torch.stack(ret[k]) + return ret + + def training_losses( + self, + x: sp.SparseTensor, + mesh: List[MeshWithPbrMaterial] = None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + x (SparseTensor): Input sparse tensor for pbr materials. + mesh (List[MeshWithPbrMaterial]): The list of meshes with PBR materials. + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + + """ + z, mean, logvar = self.training_models['encoder'](x, sample_posterior=True, return_raw=True) + y = self.training_models['decoder'](z) + + terms = edict(loss = 0.0) + + # direct regression + if self.loss_type == 'l1': + terms["l1"] = l1_loss(x.feats, y.feats) + terms["loss"] = terms["loss"] + terms["l1"] + elif self.loss_type == 'l2': + terms["l2"] = l2_loss(x.feats, y.feats) + terms["loss"] = terms["loss"] + terms["l2"] + else: + raise ValueError(f'Invalid loss type {self.loss_type}') + + # rendering loss + if self.lambda_render != 0.0: + recon = [MeshWithVoxel( + m.vertices, + m.faces, + [-0.5, -0.5, -0.5], + 1 / self.dataset.resolution, + v.coords[:, 1:], + v.feats * 0.5 + 0.5, + torch.Size([*v.shape, *v.spatial_shape]), + layout={ + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + ) for m, v in zip(mesh, y)] + cameras = self._randomize_camera(len(mesh)) + gt_renders = self._render_batch(mesh, **cameras) + pred_renders = self._render_batch(recon, **cameras) + gt_base_color = gt_renders['base_color'] + pred_base_color = pred_renders['base_color'] + gt_mra = torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1) + pred_mra = torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1) + terms['render/base_color/ssim'] = 1 - ssim(pred_base_color, gt_base_color) + terms['render/base_color/lpips'] = lpips(pred_base_color, gt_base_color) + terms['render/mra/ssim'] = 1 - ssim(pred_mra, gt_mra) + terms['render/mra/lpips'] = lpips(pred_mra, gt_mra) + terms['loss'] = terms['loss'] + \ + self.lambda_render * (self.lambda_ssim * terms['render/base_color/ssim'] + self.lambda_lpips * terms['render/base_color/lpips'] + \ + self.lambda_ssim * terms['render/mra/ssim'] + self.lambda_lpips * terms['render/mra/lpips']) + + # KL regularization + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + # Use current step as seed to ensure different samples for each snapshot + import random + snapshot_seed = self.step + random.seed(snapshot_seed) + np.random.seed(snapshot_seed) + + g = torch.Generator() + g.manual_seed(snapshot_seed) + + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=1, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + generator=g, + ) + dataloader.dataset.with_mesh = True + + # inference + gts = [] + recons = [] + self.models['encoder'].eval() + self.models['decoder'].eval() + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch] for k, v in data.items()} + args = recursive_to_device(args, self.device) + z = self.models['encoder'](args['x']) + y = self.models['decoder'](z) + gts.extend(args['mesh']) + recons.extend([MeshWithVoxel( + m.vertices, + m.faces, + [-0.5, -0.5, -0.5], + 1 / self.dataset.resolution, + v.coords[:, 1:], + v.feats * 0.5 + 0.5, + torch.Size([*v.shape, *v.spatial_shape]), + layout={ + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + ) for m, v in zip(args['mesh'], y)]) + self.models['encoder'].train() + self.models['decoder'].train() + + cameras = self._randomize_camera(num_samples) + gt_renders = self._render_batch(gts, **cameras) + pred_renders = self._render_batch(recons, **cameras) + + sample_dict = { + 'gt_base_color': {'value': gt_renders['base_color'] * 2 - 1, 'type': 'image'}, + 'pred_base_color': {'value': pred_renders['base_color'] * 2 - 1, 'type': 'image'}, + 'gt_mra': {'value': torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'}, + 'pred_mra': {'value': torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'}, + } + + return sample_dict diff --git a/trellis2/trainers/vae/shape_vae.py b/trellis2/trainers/vae/shape_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..db4630ed7c43792bdb29df98be6be1eb0756e746 --- /dev/null +++ b/trellis2/trainers/vae/shape_vae.py @@ -0,0 +1,276 @@ +from typing import * +import os +import copy +import functools +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import utils3d +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...modules import sparse as sp +from ...renderers import MeshRenderer +from ...representations import Mesh +from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler +from ...utils.loss_utils import l1_loss, ssim, lpips + + +class ShapeVaeTrainer(BasicTrainer): + """ + Trainer for Shape VAE + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + lambda_subdiv (float): Subdivision loss weight. + lambda_intersected (float): Intersected loss weight. + lambda_vertice (float): Vertice loss weight. + lambda_kl (float): KL loss weight. + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. + """ + + def __init__( + self, + *args, + lambda_subdiv: float = 0.1, + lambda_intersected: float = 0.1, + lambda_vertice: float = 1e-2, + lambda_mask: float = 1, + lambda_depth: float = 10, + lambda_normal: float = 1, + lambda_kl: float = 1e-6, + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + render_resolution: float = 1024, + camera_randomization_config: dict = { + 'radius_range': [2, 100], + }, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lambda_subdiv = lambda_subdiv + self.lambda_intersected = lambda_intersected + self.lambda_mask = lambda_mask + self.lambda_vertice = lambda_vertice + self.lambda_depth = lambda_depth + self.lambda_normal = lambda_normal + self.lambda_kl = lambda_kl + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + self.camera_randomization_config = camera_randomization_config + + self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = BalancedResumableSampler( + self.dataset, + shuffle=True, + batch_size=self.batch_size_per_gpu, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def _randomize_camera(self, num_samples: int): + # sample radius and fov + r_min, r_max = self.camera_randomization_config['radius_range'] + k_min = 1 / r_max**2 + k_max = 1 / r_min**2 + ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min + radius = 1 / torch.sqrt(ks) + fov = 2 * torch.arcsin(0.5 / radius) + origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1) + + # build camera + extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device)) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + near = [np.random.uniform(r - 1, r) for r in radius.tolist()] + + return { + 'extrinsics': extrinsics, + 'intrinsics': intrinsics, + 'near': near, + } + + def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List, + return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]: + """ + Render a batch of representations. + + Args: + reps: The dictionary of lists of representations. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color'] + + Returns: + a dict with + mask : [N x 1 x H x W] tensor of rendered masks + normal : [N x 3 x H x W] tensor of rendered normals + depth : [N x 1 x H x W] tensor of rendered depths + """ + ret = {k : [] for k in return_types} + for i, rep in enumerate(reps): + self.renderer.rendering_options['near'] = near[i] + self.renderer.rendering_options['far'] = near[i] + 2 + out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types) + for k in out_dict: + ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k]) + for k in ret: + ret[k] = torch.stack(ret[k]) + return ret + + def training_losses( + self, + vertices: sp.SparseTensor, + intersected: sp.SparseTensor, + mesh: List[Mesh], + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + vertices (SparseTensor): vertices of each active voxel + intersected (SparseTensor): intersected flag of each active voxel + mesh (List[Mesh]): the list of meshes to render + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + """ + z, mean, logvar = self.training_models['encoder'](vertices, intersected, sample_posterior=True, return_raw=True) + recon, pred_vertice, pred_intersected, subs_gt, subs = self.training_models['decoder'](z, intersected) + + terms = edict(loss = 0.0) + + # direct regression + if self.lambda_intersected > 0: + terms["direct/intersected"] = F.binary_cross_entropy_with_logits(pred_intersected.feats.flatten(), intersected.feats.flatten().float()) + terms["loss"] = terms["loss"] + self.lambda_intersected * terms["direct/intersected"] + if self.lambda_vertice > 0: + terms["direct/vertice"] = F.mse_loss(pred_vertice.feats, vertices.feats) + terms["loss"] = terms["loss"] + self.lambda_vertice * terms["direct/vertice"] + + # subdivision prediction loss + for i, (sub_gt, sub) in enumerate(zip(subs_gt, subs)): + terms[f"bce_sub{i}"] = F.binary_cross_entropy_with_logits(sub.feats, sub_gt.float()) + terms["loss"] = terms["loss"] + self.lambda_subdiv * terms[f"bce_sub{i}"] + + # rendering loss + cameras = self._randomize_camera(len(mesh)) + gt_renders = self._render_batch(mesh, **cameras, return_types=['mask', 'normal', 'depth']) + pred_renders = self._render_batch(recon, **cameras, return_types=['mask', 'normal', 'depth']) + terms['render/mask'] = l1_loss(pred_renders['mask'], gt_renders['mask']) + terms['render/depth'] = l1_loss(pred_renders['depth'], gt_renders['depth']) + terms['render/normal/l1'] = l1_loss(pred_renders['normal'], gt_renders['normal']) + terms['render/normal/ssim'] = 1 - ssim(pred_renders['normal'], gt_renders['normal']) + terms['render/normal/lpips'] = lpips(pred_renders['normal'], gt_renders['normal']) + terms['loss'] = terms['loss'] + \ + self.lambda_mask * terms['render/mask'] + \ + self.lambda_depth * terms['render/depth'] + \ + self.lambda_normal * (terms['render/normal/l1'] + self.lambda_ssim * terms['render/normal/ssim'] + self.lambda_lpips * terms['render/normal/lpips']) + + # KL regularization + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + # Use current step as seed to ensure different samples for each snapshot + import random + snapshot_seed = self.step + random.seed(snapshot_seed) + np.random.seed(snapshot_seed) + + g = torch.Generator() + g.manual_seed(snapshot_seed) + + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=1, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + generator=g, + ) + + # inference + gts = [] + recons = [] + recons2 = [] + self.models['encoder'].eval() + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch] for k, v in data.items()} + args = recursive_to_device(args, self.device) + z = self.models['encoder'](args['vertices'], args['intersected']) + self.models['decoder'].train() + y = self.models['decoder'](z, args['intersected'])[0] + z.clear_spatial_cache() + self.models['decoder'].eval() + y2 = self.models['decoder'](z) + gts.extend(args['mesh']) + recons.extend(y) + recons2.extend(y2) + self.models['encoder'].train() + self.models['decoder'].train() + + cameras = self._randomize_camera(num_samples) + gt_renders = self._render_batch(gts, **cameras, return_types=['normal']) + recons_renders = self._render_batch(recons, **cameras, return_types=['normal']) + recons2_renders = self._render_batch(recons2, **cameras, return_types=['normal']) + + sample_dict = { + 'gt': {'value': gt_renders['normal'], 'type': 'image'}, + 'rec': {'value': recons_renders['normal'], 'type': 'image'}, + 'rec2': {'value': recons2_renders['normal'], 'type': 'image'}, + } + + return sample_dict diff --git a/trellis2/trainers/vae/sparse_structure_vae.py b/trellis2/trainers/vae/sparse_structure_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd9d7baae7036e40ac80b90156823941e0ec964 --- /dev/null +++ b/trellis2/trainers/vae/sparse_structure_vae.py @@ -0,0 +1,140 @@ +from typing import * +import copy +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from easydict import EasyDict as edict + +from ..basic import BasicTrainer + + +class SparseStructureVaeTrainer(BasicTrainer): + """ + Trainer for Sparse Structure VAE. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss. + lambda_kl (float): KL divergence loss weight. + """ + + def __init__( + self, + *args, + loss_type='bce', + lambda_kl=1e-6, + **kwargs + ): + super().__init__(*args, **kwargs) + self.loss_type = loss_type + self.lambda_kl = lambda_kl + + def training_losses( + self, + ss: torch.Tensor, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + ss: The [N x 1 x H x W x D] tensor of binary sparse structure. + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + """ + z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True) + logits = self.training_models['decoder'](z) + + terms = edict(loss = 0.0) + if self.loss_type == 'bce': + terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean') + terms["loss"] = terms["loss"] + terms["bce"] + elif self.loss_type == 'l1': + terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean') + terms["loss"] = terms["loss"] + terms["l1"] + elif self.loss_type == 'dice': + logits = F.sigmoid(logits) + terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1) + terms["loss"] = terms["loss"] + terms["dice"] + else: + raise ValueError(f'Invalid loss type {self.loss_type}') + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lamda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False): + super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose) + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + # Use current step as seed to ensure different samples for each snapshot + import random + snapshot_seed = self.step + random.seed(snapshot_seed) + np.random.seed(snapshot_seed) + + g = torch.Generator() + g.manual_seed(snapshot_seed) + + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + generator=g, + ) + + # inference + gts = [] + recons = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()} + z = self.models['encoder'](args['ss'].float(), sample_posterior=False) + logits = self.models['decoder'](z) + recon = (logits > 0).long() + gts.append(args['ss']) + recons.append(recon) + + sample_dict = { + 'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'}, + 'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'}, + } + return sample_dict diff --git a/trellis2/utils/__init__.py b/trellis2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/trellis2/utils/data_utils.py b/trellis2/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..805b6cc118106857e7ef767ab4bfd133dbd78e6f --- /dev/null +++ b/trellis2/utils/data_utils.py @@ -0,0 +1,226 @@ +from typing import * +import math +import torch +import numpy as np +from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler +import torch.distributed as dist + + +def recursive_to_device( + data: Any, + device: torch.device, + non_blocking: bool = False, +) -> Any: + """ + Recursively move all tensors in a data structure to a device. + """ + if hasattr(data, "to"): + return data.to(device, non_blocking=non_blocking) + elif isinstance(data, (list, tuple)): + return type(data)(recursive_to_device(d, device, non_blocking) for d in data) + elif isinstance(data, dict): + return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()} + else: + return data + + +def load_balanced_group_indices( + load: List[int], + num_groups: int, + equal_size: bool = False, +) -> List[List[int]]: + """ + Split indices into groups with balanced load. + """ + if equal_size: + group_size = len(load) // num_groups + indices = np.argsort(load)[::-1] + groups = [[] for _ in range(num_groups)] + group_load = np.zeros(num_groups) + for idx in indices: + min_group_idx = np.argmin(group_load) + groups[min_group_idx].append(idx) + if equal_size and len(groups[min_group_idx]) == group_size: + group_load[min_group_idx] = float('inf') + else: + group_load[min_group_idx] += load[idx] + return groups + + +def cycle(data_loader: DataLoader) -> Iterator: + while True: + for data in data_loader: + if isinstance(data_loader.sampler, ResumableSampler): + data_loader.sampler.idx += data_loader.batch_size # type: ignore[attr-defined] + yield data + if isinstance(data_loader.sampler, DistributedSampler): + data_loader.sampler.epoch += 1 + if isinstance(data_loader.sampler, ResumableSampler): + data_loader.sampler.epoch += 1 + data_loader.sampler.idx = 0 + + +class ResumableSampler(Sampler): + """ + Distributed sampler that is resumable. + + Args: + dataset: Dataset used for sampling. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + self.dataset = dataset + self.epoch = 0 + self.idx = 0 + self.drop_last = drop_last + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.rank = dist.get_rank() if dist.is_initialized() else 0 + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type] + self.total_size = self.num_samples * self.world_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.world_size] + + # resume from previous state + indices = indices[self.idx:] + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def state_dict(self) -> dict[str, int]: + return { + 'epoch': self.epoch, + 'idx': self.idx, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict['epoch'] + self.idx = state_dict['idx'] + + +class BalancedResumableSampler(ResumableSampler): + """ + Distributed sampler that is resumable and balances the load among the processes. + + Args: + dataset: Dataset used for sampling. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + batch_size: int = 1, + ) -> None: + assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler' + super().__init__(dataset, shuffle, seed, drop_last) + self.batch_size = batch_size + self.loads = dataset.loads + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # balance load among processes + num_batches = len(indices) // (self.batch_size * self.world_size) + balanced_indices = [] + for i in range(num_batches): + start_idx = i * self.batch_size * self.world_size + end_idx = (i + 1) * self.batch_size * self.world_size + batch_indices = indices[start_idx:end_idx] + batch_loads = [self.loads[idx] for idx in batch_indices] + groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True) + balanced_indices.extend([batch_indices[j] for j in groups[self.rank]]) + + # resume from previous state + indices = balanced_indices[self.idx:] + + return iter(indices) diff --git a/trellis2/utils/dist_utils.py b/trellis2/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f57e2060ea6d7b208a7322662fcbee13acf1d096 --- /dev/null +++ b/trellis2/utils/dist_utils.py @@ -0,0 +1,94 @@ +import os +import io +from datetime import timedelta +from contextlib import contextmanager +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + + +def setup_dist(rank, local_rank, world_size, master_addr, master_port): + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(local_rank) + torch.cuda.set_device(local_rank) + dist.init_process_group('nccl', rank=rank, world_size=world_size, timeout=timedelta(hours=10)) + + +def read_file_dist(path): + """ + Read the binary file distributedly. + File is only read once by the rank 0 process and broadcasted to other processes. + + Returns: + data (io.BytesIO): The binary data read from the file. + """ + if dist.is_initialized() and dist.get_world_size() > 1: + # read file + size = torch.LongTensor(1).cuda() + if dist.get_rank() == 0: + with open(path, 'rb') as f: + data = f.read() + data = torch.ByteTensor( + torch.UntypedStorage.from_buffer(data, dtype=torch.uint8) + ).cuda() + size[0] = data.shape[0] + # broadcast size + dist.broadcast(size, src=0) + if dist.get_rank() != 0: + data = torch.ByteTensor(size[0].item()).cuda() + # broadcast data + dist.broadcast(data, src=0) + # convert to io.BytesIO + data = data.cpu().numpy().tobytes() + data = io.BytesIO(data) + return data + else: + with open(path, 'rb') as f: + data = f.read() + data = io.BytesIO(data) + return data + + +def unwrap_dist(model): + """ + Unwrap the model from distributed training. + """ + if isinstance(model, DDP): + return model.module + return model + + +@contextmanager +def master_first(): + """ + A context manager that ensures master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + + +@contextmanager +def local_master_first(): + """ + A context manager that ensures local master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() % torch.cuda.device_count() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + \ No newline at end of file diff --git a/trellis2/utils/elastic_utils.py b/trellis2/utils/elastic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cba3cf83836e5b58f5bc3333e809ffc932375a04 --- /dev/null +++ b/trellis2/utils/elastic_utils.py @@ -0,0 +1,228 @@ +from abc import abstractmethod +from contextlib import contextmanager +from typing import Tuple +import torch +import torch.nn as nn +import numpy as np + + +class MemoryController: + """ + Base class for memory management during training. + """ + + _last_input_size = None + _last_mem_ratio = [] + + @contextmanager + def record(self): + pass + + def update_run_states(self, input_size=None, mem_ratio=None): + if self._last_input_size is None: + self._last_input_size = input_size + elif self._last_input_size!= input_size: + raise ValueError(f'Input size should not change for different ElasticModules.') + self._last_mem_ratio.append(mem_ratio) + + @abstractmethod + def get_mem_ratio(self, input_size): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def log(self): + pass + + +class LinearMemoryController(MemoryController): + """ + A simple controller for memory management during training. + The memory usage is modeled as a linear function of: + - the number of input parameters + - the ratio of memory the model use compared to the maximum usage (with no checkpointing) + memory_usage = k * input_size * mem_ratio + b + The controller keeps track of the memory usage and gives the + expected memory ratio to keep the memory usage under a target + """ + def __init__( + self, + buffer_size=1000, + update_every=500, + target_ratio=0.8, + available_memory=None, + max_mem_ratio_start=0.1, + params=None, + device=None + ): + self.buffer_size = buffer_size + self.update_every = update_every + self.target_ratio = target_ratio + self.device = device or torch.cuda.current_device() + self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3 + + self._memory = np.zeros(buffer_size, dtype=np.float32) + self._input_size = np.zeros(buffer_size, dtype=np.float32) + self._mem_ratio = np.zeros(buffer_size, dtype=np.float32) + self._buffer_ptr = 0 + self._buffer_length = 0 + self._params = tuple(params) if params is not None else (0.0, 0.0) + self._max_mem_ratio = max_mem_ratio_start + self.step = 0 + + def __repr__(self): + return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})' + + def _add_sample(self, memory, input_size, mem_ratio): + self._memory[self._buffer_ptr] = memory + self._input_size[self._buffer_ptr] = input_size + self._mem_ratio[self._buffer_ptr] = mem_ratio + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + + @contextmanager + def record(self): + torch.cuda.reset_peak_memory_stats(self.device) + self._last_input_size = None + self._last_mem_ratio = [] + yield + self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3 + self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio) + self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio) + self.step += 1 + if self.step % self.update_every == 0: + self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1) + self._fit_params() + + def _fit_params(self): + memory_usage = self._memory[:self._buffer_length] + input_size = self._input_size[:self._buffer_length] + mem_ratio = self._mem_ratio[:self._buffer_length] + + x = input_size * mem_ratio + y = memory_usage + k, b = np.polyfit(x, y, 1) + self._params = (k, b) + # self._visualize() + + def _visualize(self): + import matplotlib.pyplot as plt + memory_usage = self._memory[:self._buffer_length] + input_size = self._input_size[:self._buffer_length] + mem_ratio = self._mem_ratio[:self._buffer_length] + k, b = self._params + + plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis') + x = np.array([0.0, 20000.0]) + plt.plot(x, k * x + b, c='r') + plt.savefig(f'linear_memory_controller_{self.step}.png') + plt.cla() + + def get_mem_ratio(self, input_size): + k, b = self._params + if k == 0: return np.random.rand() * self._max_mem_ratio + pred = (self.available_memory * self.target_ratio - b) / (k * input_size) + return min(self._max_mem_ratio, max(0.0, pred)) + + def state_dict(self): + return { + 'params': self._params, + } + + def load_state_dict(self, state_dict): + self._params = tuple(state_dict['params']) + + def log(self): + return { + 'params/k': self._params[0], + 'params/b': self._params[1], + 'memory': self._last_memory, + 'input_size': self._last_input_size, + 'mem_ratio': self._last_mem_ratio, + } + + +class ElasticModule(nn.Module): + """ + Module for training with elastic memory management. + """ + def __init__(self): + super().__init__() + self._memory_controller: MemoryController = None + + @abstractmethod + def _get_input_size(self, *args, **kwargs) -> int: + """ + Get the size of the input data. + + Returns: + int: The size of the input data. + """ + pass + + @abstractmethod + def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]: + """ + Forward with a given memory ratio. + """ + pass + + def register_memory_controller(self, memory_controller: MemoryController): + self._memory_controller = memory_controller + + def forward(self, *args, **kwargs): + if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: + _, ret = self._forward_with_mem_ratio(*args, **kwargs) + else: + input_size = self._get_input_size(*args, **kwargs) + mem_ratio = self._memory_controller.get_mem_ratio(input_size) + mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs) + self._memory_controller.update_run_states(input_size, mem_ratio) + return ret + + +class ElasticModuleMixin: + """ + Mixin for training with elastic memory management. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._memory_controller: MemoryController = None + + @abstractmethod + def _get_input_size(self, *args, **kwargs) -> int: + """ + Get the size of the input data. + + Returns: + int: The size of the input data. + """ + pass + + @abstractmethod + @contextmanager + def with_mem_ratio(self, mem_ratio=1.0) -> float: + """ + Context manager for training with a reduced memory ratio compared to the full memory usage. + + Returns: + float: The exact memory ratio used during the forward pass. + """ + pass + + def register_memory_controller(self, memory_controller: MemoryController): + self._memory_controller = memory_controller + + def forward(self, *args, **kwargs): + if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: + ret = super().forward(*args, **kwargs) + else: + input_size = self._get_input_size(*args, **kwargs) + mem_ratio = self._memory_controller.get_mem_ratio(input_size) + with self.with_mem_ratio(mem_ratio) as exact_mem_ratio: + ret = super().forward(*args, **kwargs) + self._memory_controller.update_run_states(input_size, exact_mem_ratio) + return ret diff --git a/trellis2/utils/general_utils.py b/trellis2/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..589c103de8a777aea9994a899f97431cbab5447a --- /dev/null +++ b/trellis2/utils/general_utils.py @@ -0,0 +1,373 @@ +import re +import numpy as np +import cv2 +import torch +import contextlib + + +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. + """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. + """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +# Context utils +@contextlib.contextmanager +def nested_contexts(*contexts): + with contextlib.ExitStack() as stack: + for ctx in contexts: + stack.enter_context(ctx()) + yield + + +# Image utils +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + if images[0].ndim == 2: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype) + else: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if notes is not None: + img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + + +def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"): + """ + Draw text on an image of the given resolution. The text is automatically wrapped + and scaled so that it fits completely within the image while preserving any explicit + line breaks and original spacing. Horizontal and vertical alignment can be controlled + via flags. + + Parameters: + text (str): The input text. Newline characters and spacing are preserved. + resolution (tuple): The image resolution as (width, height). + max_size (float): The maximum font size. + h_align (str): Horizontal alignment. Options: "left", "center", "right". + v_align (str): Vertical alignment. Options: "top", "center", "bottom". + + Returns: + numpy.ndarray: The resulting image (BGR format) with the text drawn. + """ + width, height = resolution + # Create a white background image + img = np.full((height, width, 3), 255, dtype=np.uint8) + + # Set margins and compute available drawing area + margin = 10 + avail_width = width - 2 * margin + avail_height = height - 2 * margin + + # Choose OpenCV font and text thickness + font = cv2.FONT_HERSHEY_SIMPLEX + thickness = 1 + # Ratio for additional spacing between lines (relative to the height of "A") + line_spacing_ratio = 0.5 + + def wrap_line(line, max_width, font, thickness, scale): + """ + Wrap a single line of text into multiple lines such that each line's + width (measured at the given scale) does not exceed max_width. + This function preserves the original spacing by splitting the line into tokens + (words and whitespace) using a regular expression. + + Parameters: + line (str): The input text line. + max_width (int): Maximum allowed width in pixels. + font (int): OpenCV font identifier. + thickness (int): Text thickness. + scale (float): The current font scale. + + Returns: + List[str]: A list of wrapped lines. + """ + # Split the line into tokens (words and whitespace), preserving spacing + tokens = re.split(r'(\s+)', line) + if not tokens: + return [''] + + wrapped_lines = [] + current_line = "" + for token in tokens: + candidate = current_line + token + candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0] + if candidate_width <= max_width: + current_line = candidate + else: + # If current_line is empty, the token itself is too wide; + # break the token character by character. + if current_line == "": + sub_token = "" + for char in token: + candidate_char = sub_token + char + if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width: + sub_token = candidate_char + else: + if sub_token: + wrapped_lines.append(sub_token) + sub_token = char + current_line = sub_token + else: + wrapped_lines.append(current_line) + current_line = token + if current_line: + wrapped_lines.append(current_line) + return wrapped_lines + + def compute_text_block(scale): + """ + Wrap the entire text (splitting at explicit newline characters) using the + provided scale, and then compute the overall width and height of the text block. + + Returns: + wrapped_lines (List[str]): The list of wrapped lines. + block_width (int): Maximum width among the wrapped lines. + block_height (int): Total height of the text block including spacing. + sizes (List[tuple]): A list of (width, height) for each wrapped line. + spacing (int): The spacing between lines (computed from the scaled "A" height). + """ + # Split text by explicit newlines + input_lines = text.splitlines() if text else [''] + wrapped_lines = [] + for line in input_lines: + wrapped = wrap_line(line, avail_width, font, thickness, scale) + wrapped_lines.extend(wrapped) + + sizes = [] + for line in wrapped_lines: + (text_size, _) = cv2.getTextSize(line, font, scale, thickness) + sizes.append(text_size) # (width, height) + + block_width = max((w for w, h in sizes), default=0) + # Use the height of "A" (at the current scale) to compute line spacing + base_height = cv2.getTextSize("A", font, scale, thickness)[0][1] + spacing = int(line_spacing_ratio * base_height) + block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0 + + return wrapped_lines, block_width, block_height, sizes, spacing + + # Use binary search to find the maximum scale that allows the text block to fit + lo = 0.001 + hi = max_size + eps = 0.001 # convergence threshold + best_scale = lo + best_result = None + + while hi - lo > eps: + mid = (lo + hi) / 2 + wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid) + # Ensure that both width and height constraints are met + if block_width <= avail_width and block_height <= avail_height: + best_scale = mid + best_result = (wrapped_lines, block_width, block_height, sizes, spacing) + lo = mid # try a larger scale + else: + hi = mid # reduce the scale + + if best_result is None: + best_scale = 0.5 + best_result = compute_text_block(best_scale) + + wrapped_lines, block_width, block_height, sizes, spacing = best_result + + # Compute starting y-coordinate based on vertical alignment flag + if v_align == "top": + y_top = margin + elif v_align == "center": + y_top = margin + (avail_height - block_height) // 2 + elif v_align == "bottom": + y_top = margin + (avail_height - block_height) + else: + y_top = margin + (avail_height - block_height) // 2 # default to center if invalid flag + + # For cv2.putText, the y coordinate represents the text baseline; + # so for the first line add its height. + y = y_top + (sizes[0][1] if sizes else 0) + + # Draw each line with horizontal alignment based on the flag + for i, line in enumerate(wrapped_lines): + line_width, line_height = sizes[i] + if h_align == "left": + x = margin + elif h_align == "center": + x = margin + (avail_width - line_width) // 2 + elif h_align == "right": + x = margin + (avail_width - line_width) + else: + x = margin # default to left if invalid flag + + cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA) + y += line_height + spacing + + return img + + +def save_image_with_notes(img, path, notes=None): + """ + Save an image with notes. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy().transpose(1, 2, 0) + if img.dtype == np.float32 or img.dtype == np.float64: + img = np.clip(img * 255, 0, 255).astype(np.uint8) + img = notes_on_image(img, notes) + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + +# debug utils + +def atol(x, y): + """ + Absolute tolerance. + """ + return torch.abs(x - y) + + +def rtol(x, y): + """ + Relative tolerance. + """ + return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + + +# print utils +def indent(s, n=4): + """ + Indent a string. + """ + lines = s.split('\n') + for i in range(1, len(lines)): + lines[i] = ' ' * n + lines[i] + return '\n'.join(lines) + diff --git a/trellis2/utils/grad_clip_utils.py b/trellis2/utils/grad_clip_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..990a4352e24fc73bf732d8eb0f8ca9a07365b49e --- /dev/null +++ b/trellis2/utils/grad_clip_utils.py @@ -0,0 +1,81 @@ +from typing import * +import torch +import numpy as np +import torch.utils + + +class AdaptiveGradClipper: + """ + Adaptive gradient clipping for training. + """ + def __init__( + self, + max_norm=None, + clip_percentile=95.0, + buffer_size=1000, + ): + self.max_norm = max_norm + self.clip_percentile = clip_percentile + self.buffer_size = buffer_size + + self._grad_norm = np.zeros(buffer_size, dtype=np.float32) + self._max_norm = max_norm + self._buffer_ptr = 0 + self._buffer_length = 0 + + def __repr__(self): + return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})' + + def state_dict(self): + return { + 'grad_norm': self._grad_norm, + 'max_norm': self._max_norm, + 'buffer_ptr': self._buffer_ptr, + 'buffer_length': self._buffer_length, + } + + def load_state_dict(self, state_dict): + self._grad_norm = state_dict['grad_norm'] + self._max_norm = state_dict['max_norm'] + self._buffer_ptr = state_dict['buffer_ptr'] + self._buffer_length = state_dict['buffer_length'] + + def log(self): + return { + 'max_norm': self._max_norm, + } + + def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None): + """Clip the gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + max_norm = self._max_norm if self._max_norm is not None else float('inf') + grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach) + + if torch.isfinite(grad_norm): + self._grad_norm[self._buffer_ptr] = grad_norm + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + if self._buffer_length == self.buffer_size: + self._max_norm = np.percentile(self._grad_norm, self.clip_percentile) + self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm + + return grad_norm \ No newline at end of file diff --git a/trellis2/utils/loss_utils.py b/trellis2/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52049f69543f2700bc5525b09cbf2fb25c08aa9e --- /dev/null +++ b/trellis2/utils/loss_utils.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +from lpips import LPIPS + + +def smooth_l1_loss(pred, target, beta=1.0): + diff = torch.abs(pred - target) + loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta) + return loss.mean() + + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def psnr(img1, img2, max_val=1.0): + mse = F.mse_loss(img1, img2) + return 20 * torch.log10(max_val / torch.sqrt(mse)) + + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +loss_fn_vgg = None +def lpips(img1, img2, value_range=(0, 1)): + global loss_fn_vgg + if loss_fn_vgg is None: + loss_fn_vgg = LPIPS(net='vgg').cuda().eval() + # normalize to [-1, 1] + img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + return loss_fn_vgg(img1, img2).mean() + + +def normal_angle(pred, gt): + pred = pred * 2.0 - 1.0 + gt = gt * 2.0 - 1.0 + norms = pred.norm(dim=-1) * gt.norm(dim=-1) + cos_sim = (pred * gt).sum(-1) / (norms + 1e-9) + cos_sim = torch.clamp(cos_sim, -1.0, 1.0) + ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean() + if ang.isnan(): + return -1 + return ang diff --git a/trellis2/utils/mesh_utils.py b/trellis2/utils/mesh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f1451ebd8b89879eee79cc61a6f4161136f245 --- /dev/null +++ b/trellis2/utils/mesh_utils.py @@ -0,0 +1,268 @@ +from typing import Tuple, Dict +import numpy as np +from trimesh import grouping, util, remesh +import struct +import re +from plyfile import PlyData, PlyElement + + +def read_ply(filename): + """ + Read a PLY file and return vertices, triangle faces, and quad faces. + + Args: + filename (str): The file path to read from. + + Returns: + vertices (np.ndarray): Array of shape [N, 3] containing vertex positions. + tris (np.ndarray): Array of shape [M, 3] containing triangle face indices (empty if none). + quads (np.ndarray): Array of shape [K, 4] containing quad face indices (empty if none). + """ + with open(filename, 'rb') as f: + # Read the header until 'end_header' is encountered + header_bytes = b"" + while True: + line = f.readline() + if not line: + raise ValueError("PLY header not found") + header_bytes += line + if b"end_header" in line: + break + header = header_bytes.decode('utf-8') + + # Determine if the file is in ASCII or binary format + is_ascii = "ascii" in header + + # Extract the number of vertices and faces from the header using regex + vertex_match = re.search(r'element vertex (\d+)', header) + if vertex_match: + num_vertices = int(vertex_match.group(1)) + else: + raise ValueError("Vertex count not found in header") + + face_match = re.search(r'element face (\d+)', header) + if face_match: + num_faces = int(face_match.group(1)) + else: + raise ValueError("Face count not found in header") + + vertices = [] + tris = [] + quads = [] + + if is_ascii: + # For ASCII format, read each line of vertex data (each line contains 3 floats) + for _ in range(num_vertices): + line = f.readline().decode('utf-8').strip() + if not line: + continue + parts = line.split() + vertices.append([float(parts[0]), float(parts[1]), float(parts[2])]) + + # Read face data, where the first number indicates the number of vertices for the face + for _ in range(num_faces): + line = f.readline().decode('utf-8').strip() + if not line: + continue + parts = line.split() + count = int(parts[0]) + indices = list(map(int, parts[1:])) + if count == 3: + tris.append(indices) + elif count == 4: + quads.append(indices) + else: + # Skip faces with other numbers of vertices (can be extended as needed) + pass + else: + # For binary format: read directly from the binary stream + # Each vertex consists of 3 floats (12 bytes per vertex) + for _ in range(num_vertices): + data = f.read(12) + if len(data) < 12: + raise ValueError("Insufficient vertex data") + v = struct.unpack(' 0 else np.empty((0, 3), dtype=np.int32) + quads = np.array(quads, dtype=np.int32) if len(quads) > 0 else np.empty((0, 4), dtype=np.int32) + + return vertices, tris, quads + + +def write_ply( + filename: str, + vertices: np.ndarray, + tris: np.ndarray, + quads: np.ndarray, + vertex_colors: np.ndarray = None, + ascii: bool = False +): + """ + Write a mesh to a PLY file, with the option to save in ASCII or binary format, + and optional per-vertex colors. + + Args: + filename (str): The filename to write to. + vertices (np.ndarray): [N, 3] The vertex positions. + tris (np.ndarray): [M, 3] The triangle indices. + quads (np.ndarray): [K, 4] The quad indices. + vertex_colors (np.ndarray, optional): [N, 3] or [N, 4] UInt8 colors for each vertex (RGB or RGBA). + ascii (bool): If True, write in ASCII format; otherwise binary little-endian. + """ + import struct + + num_vertices = len(vertices) + num_faces = len(tris) + len(quads) + + # Build header + header_lines = [ + "ply", + f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}", + f"element vertex {num_vertices}", + "property float x", + "property float y", + "property float z", + ] + + # Add vertex color properties if provided + has_color = vertex_colors is not None + if has_color: + # Expect uint8 values 0-255 + header_lines += [ + "property uchar red", + "property uchar green", + "property uchar blue", + ] + # Include alpha if RGBA + if vertex_colors.shape[1] == 4: + header_lines.append("property uchar alpha") + + header_lines += [ + f"element face {num_faces}", + "property list uchar int vertex_index", + "end_header", + "" + ] + header = "\n".join(header_lines) + + mode = 'w' if ascii else 'wb' + with open(filename, mode) as f: + # Write header + if ascii: + f.write(header) + else: + f.write(header.encode('utf-8')) + + # Write vertex data + for i, v in enumerate(vertices): + if ascii: + line = f"{v[0]} {v[1]} {v[2]}" + if has_color: + col = vertex_colors[i] + line += ' ' + ' '.join(str(int(c)) for c in col) + f.write(line + '\n') + else: + # pack xyz as floats + f.write(struct.pack(' 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] \ No newline at end of file diff --git a/trellis2/utils/render_utils.py b/trellis2/utils/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7d3ede7a1974bcd905fbc16812d9c5eddea9ec --- /dev/null +++ b/trellis2/utils/render_utils.py @@ -0,0 +1,225 @@ +import gc +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from PIL import Image + +from ..renderers import MeshRenderer, VoxelRenderer, PbrMeshRenderer +from ..representations import Mesh, Voxel, MeshWithPbrMaterial, MeshWithVoxel +from .random_utils import sphere_hammersley_sequence + + +def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).cuda() + yaw = torch.tensor(float(yaw)).cuda() + pitch = torch.tensor(float(pitch)).cuda() + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda() * r + extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def get_renderer(sample, **kwargs): + if isinstance(sample, (MeshWithPbrMaterial, MeshWithVoxel)): + renderer = PbrMeshRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 1) + renderer.rendering_options.far = kwargs.get('far', 100) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + renderer.rendering_options.peel_layers = kwargs.get('peel_layers', 8) + elif isinstance(sample, Mesh): + renderer = MeshRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 1) + renderer.rendering_options.far = kwargs.get('far', 100) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + renderer.rendering_options.chunk_size = kwargs.get('chunk_size', None) + elif isinstance(sample, Voxel): + renderer = VoxelRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 0.1) + renderer.rendering_options.far = kwargs.get('far', 10.0) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + else: + raise ValueError(f'Unsupported sample type: {type(sample)}') + return renderer + + +@torch.no_grad() +def render_frames(sample, extrinsics, intrinsics, options={}, verbose=True, renderer=None, **kwargs): + if renderer is None: + renderer = get_renderer(sample, **options) + rets = {} + for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), total=len(extrinsics), desc='Rendering', disable=not verbose): + res = renderer.render(sample, extr, intr, **kwargs) + for k, v in res.items(): + if k not in rets: rets[k] = [] + if v.dim() == 2: v = v[None].repeat(3, 1, 1) + rets[k].append(np.clip(v.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + return rets + + +def render_video(sample, resolution=1024, bg_color=(0, 0, 0), num_frames=40, r=2, fov=40, + start_yaw=None, start_pitch=None, **kwargs): + """ + Render a turntable video of the sample. + + Args: + start_yaw: Starting yaw angle in radians. If None, defaults to π/2. + start_pitch: Starting pitch angle in radians. If None, uses the default oscillating pitch + starting at ~0.25. + """ + if start_yaw is None: + start_yaw = np.pi / 2 + yaws = -torch.linspace(0, 2 * 3.1415, num_frames) + start_yaw + if start_pitch is not None: + pitch = start_pitch + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + else: + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitch = pitch.tolist() + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) + return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def render_multiview(sample, resolution=512, nviews=30): + r = 2 + fov = 40 + cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] + yaws = [cam[0] for cam in cams] + pitchs = [cam[1] for cam in cams] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) + return res['color'], extrinsics, intrinsics + + +def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, nviews=4, **kwargs): + yaw = np.linspace(0, 2 * np.pi, nviews, endpoint=False) + yaw_offset = offset[0] + yaw = [y + yaw_offset for y in yaw] + pitch = [offset[1] for _ in range(nviews)] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) + return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def proj_camera_to_render_params(camera_angle_x, distance): + """ + Convert proj camera parameters to renderer-compatible extrinsics + intrinsics. + + The proj camera (Blender convention) views from the front. Through empirical + testing, this corresponds to extrinsics_look_at from (0, 0, +distance) with up=Y + in the mesh coordinate system used by the renderer. + + Args: + camera_angle_x: horizontal FOV in radians + distance: camera distance + + Returns: + extrinsics: [4, 4] OpenCV world-to-camera (on CUDA) + intrinsics: [3, 3] OpenCV normalized intrinsics (on CUDA) + """ + orig = torch.tensor([0.0, 0.0, distance]).cuda() + target = torch.tensor([0.0, 0.0, 0.0]).cuda() + up = torch.tensor([0.0, 1.0, 0.0]).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, target, up) + + fov_tensor = torch.tensor(camera_angle_x, dtype=torch.float32).cuda() + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov_tensor, fov_tensor) + + return extrinsics, intrinsics + + +def render_proj_aligned_video(sample, camera_angle_x, distance, resolution=1024, + num_frames=40, bg_color=(0, 0, 0), **kwargs): + """ + Render a turntable video starting from the proj camera viewpoint. + + The first frame matches the proj input image exactly. Subsequent frames + rotate around the object (around Y axis, which is up in mesh space). + + Args: + sample: mesh to render + camera_angle_x: proj camera FOV in radians + distance: proj camera distance + resolution: render resolution + num_frames: number of video frames + bg_color: background color + **kwargs: additional kwargs (e.g. envmap) + + Returns: + render result dict (same as render_frames) + """ + import math + + extr_first, intr_first = proj_camera_to_render_params(camera_angle_x, distance) + + extrinsics_list = [] + intrinsics_list = [] + + angles = torch.linspace(0, 2 * math.pi, num_frames + 1)[:num_frames] + + for angle in angles: + # Rotation around Y axis (up in mesh space) + c = torch.cos(angle) + s = torch.sin(angle) + R_y = torch.tensor([ + [ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1], + ], dtype=torch.float32).cuda() + + # world-to-camera for rotated world: extr @ R_y^{-1} + R_y_inv = R_y.clone() + R_y_inv[:3, :3] = R_y[:3, :3].T + extr_rotated = extr_first @ R_y_inv + + extrinsics_list.append(extr_rotated) + intrinsics_list.append(intr_first) + + return render_frames(sample, extrinsics_list, intrinsics_list, + {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def make_pbr_vis_frames(result, resolution=1024): + num_frames = len(result['shaded']) + frames = [] + for i in range(num_frames): + shaded = Image.fromarray(result['shaded'][i]) + normal = Image.fromarray(result['normal'][i]) + base_color = Image.fromarray(result['base_color'][i]) + metallic = Image.fromarray(result['metallic'][i]) + roughness = Image.fromarray(result['roughness'][i]) + alpha = Image.fromarray(result['alpha'][i]) + shaded = shaded.resize((resolution, resolution)) + normal = normal.resize((resolution, resolution)) + base_color = base_color.resize((resolution//2, resolution//2)) + metallic = metallic.resize((resolution//2, resolution//2)) + roughness = roughness.resize((resolution//2, resolution//2)) + alpha = alpha.resize((resolution//2, resolution//2)) + row1 = np.concatenate([shaded, normal], axis=1) + row2 = np.concatenate([base_color, metallic, roughness, alpha], axis=1) + frame = np.concatenate([row1, row2], axis=0) + frames.append(frame) + return frames diff --git a/trellis2/utils/vis_utils.py b/trellis2/utils/vis_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0e5f58e564aea50e4d80b0265220ee3fb382cd69 --- /dev/null +++ b/trellis2/utils/vis_utils.py @@ -0,0 +1,44 @@ +from typing import * +import numpy as np +import torch +from ..modules import sparse as sp +from ..representations import Voxel +from .render_utils import render_video + + +def pca_color(feats: torch.Tensor, channels: Tuple[int, int, int] = (0, 1, 2)) -> torch.Tensor: + """ + Apply PCA to the features and return the first three principal components. + """ + feats = feats.detach() + u, s, v = torch.svd(feats) + color = u[:, channels] + color = (color - color.min(dim=0, keepdim=True)[0]) / (color.max(dim=0, keepdim=True)[0] - color.min(dim=0, keepdim=True)[0]) + return color + + +def vis_sparse_tensor( + x: sp.SparseTensor, + num_frames: int = 300, +): + assert x.shape[0] == 1, "Only support batch size 1" + assert x.coords.shape[1] == 4, "Only support 3D coordinates" + + coords = x.coords.cuda().detach()[:, 1:] + feats = x.feats.cuda().detach() + color = pca_color(feats) + + resolution = max(list(x.spatial_shape)) + resolution = int(2**np.ceil(np.log2(resolution))) + + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/resolution, + coords=coords, + attrs=color, + layout={ + 'color': slice(0, 3), + } + ) + + return render_video(rep, colors_overwrite=color, num_frames=num_frames)['color']