Spaces:
Runtime error
Runtime error
Upload Pixal3D-D Space
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- README.md +6 -7
- app.py +311 -0
- packages.txt +1 -0
- pixal3d/__init__.py +44 -0
- pixal3d/__pycache__/__init__.cpython-310.pyc +0 -0
- pixal3d/models/__init__.py +1 -0
- pixal3d/models/__pycache__/__init__.cpython-310.pyc +0 -0
- pixal3d/models/autoencoders/__pycache__/base.cpython-310.pyc +0 -0
- pixal3d/models/autoencoders/__pycache__/decoder.cpython-310.pyc +0 -0
- pixal3d/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc +0 -0
- pixal3d/models/autoencoders/__pycache__/distributions.cpython-310.pyc +0 -0
- pixal3d/models/autoencoders/__pycache__/encoder.cpython-310.pyc +0 -0
- pixal3d/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc +0 -0
- pixal3d/models/autoencoders/base.py +118 -0
- pixal3d/models/autoencoders/decoder.py +353 -0
- pixal3d/models/autoencoders/dense_vae.py +401 -0
- pixal3d/models/autoencoders/distributions.py +51 -0
- pixal3d/models/autoencoders/encoder.py +133 -0
- pixal3d/models/autoencoders/ss_vae.py +129 -0
- pixal3d/models/conditional_encoders/__init__.py +2 -0
- pixal3d/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc +0 -0
- pixal3d/models/conditional_encoders/__pycache__/dinov2_project_grid.cpython-310.pyc +0 -0
- pixal3d/models/conditional_encoders/dinov2_project_grid.py +750 -0
- pixal3d/models/transformers/__init__.py +2 -0
- pixal3d/models/transformers/__pycache__/__init__.cpython-310.pyc +0 -0
- pixal3d/models/transformers/__pycache__/dense_dit.cpython-310.pyc +0 -0
- pixal3d/models/transformers/__pycache__/sparse_dit.cpython-310.pyc +0 -0
- pixal3d/models/transformers/dense_dit.py +298 -0
- pixal3d/models/transformers/sparse_dit.py +469 -0
- pixal3d/modules/__pycache__/norm.cpython-310.pyc +0 -0
- pixal3d/modules/__pycache__/spatial.cpython-310.pyc +0 -0
- pixal3d/modules/__pycache__/utils.cpython-310.pyc +0 -0
- pixal3d/modules/attention/__init__.py +35 -0
- pixal3d/modules/attention/__pycache__/__init__.cpython-310.pyc +0 -0
- pixal3d/modules/attention/__pycache__/full_attn.cpython-310.pyc +0 -0
- pixal3d/modules/attention/__pycache__/modules.cpython-310.pyc +0 -0
- pixal3d/modules/attention/full_attn.py +140 -0
- pixal3d/modules/attention/modules.py +164 -0
- pixal3d/modules/norm.py +25 -0
- pixal3d/modules/sparse/__init__.py +105 -0
- pixal3d/modules/sparse/__pycache__/__init__.cpython-310.pyc +0 -0
- pixal3d/modules/sparse/__pycache__/basic.cpython-310.pyc +0 -0
- pixal3d/modules/sparse/__pycache__/linear.cpython-310.pyc +0 -0
- pixal3d/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc +0 -0
- pixal3d/modules/sparse/__pycache__/norm.cpython-310.pyc +0 -0
- pixal3d/modules/sparse/__pycache__/spatial.cpython-310.pyc +0 -0
- pixal3d/modules/sparse/attention/__init__.py +5 -0
- pixal3d/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc +0 -0
- pixal3d/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc +0 -0
- pixal3d/modules/sparse/attention/__pycache__/modules.cpython-310.pyc +0 -0
README.md
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
---
|
| 2 |
-
title: Pixal3D
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
|
|
|
| 11 |
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Pixal3D-D
|
| 3 |
+
emoji: 🎨
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.29.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
+
extra_gated_eu_disallowed: true
|
| 12 |
---
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pixal3D Gradio App
|
| 3 |
+
Upload an image and generate a 3D mesh. Supports both automatic (MoGe) and fixed camera parameters.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
os.environ["no_proxy"] = os.environ.get("no_proxy", "") + ",localhost,127.0.0.1"
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import tempfile
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
|
| 15 |
+
import gradio as gr
|
| 16 |
+
|
| 17 |
+
from pixal3dpipeline2stage import Pixal3DPipeline2Stage
|
| 18 |
+
from pixal3dpipeline import Pixal3DPipeline
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
import trimesh
|
| 22 |
+
from trimesh.visual.material import PBRMaterial
|
| 23 |
+
from trimesh.transformations import rotation_matrix
|
| 24 |
+
# Static files directory for model viewer
|
| 25 |
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 26 |
+
SAVE_DIR = os.path.join(CURRENT_DIR, "gradio_outputs")
|
| 27 |
+
|
| 28 |
+
# Global pipeline reference
|
| 29 |
+
pipeline = None
|
| 30 |
+
rmbg = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_pipeline(ckpt_dir="./ckpt", repo_id="Pixal3D/Pixal3D"):
|
| 34 |
+
"""Load all weights at startup."""
|
| 35 |
+
global pipeline, rmbg
|
| 36 |
+
print("Loading Pixal3D 2-Stage pipeline (with MoGe + dense_check)...")
|
| 37 |
+
pipeline = Pixal3DPipeline2Stage.from_pretrained(
|
| 38 |
+
ckpt_dir=ckpt_dir,
|
| 39 |
+
repo_id=repo_id,
|
| 40 |
+
use_moge=True,
|
| 41 |
+
use_dense_check=True,
|
| 42 |
+
)
|
| 43 |
+
print("Pipeline loaded!")
|
| 44 |
+
print("Loading BiRefNet for background removal...")
|
| 45 |
+
from transformers import AutoModelForImageSegmentation
|
| 46 |
+
birefnet_model = AutoModelForImageSegmentation.from_pretrained(
|
| 47 |
+
'ZhengPeng7/BiRefNet',
|
| 48 |
+
trust_remote_code=True,
|
| 49 |
+
).to("cuda:0")
|
| 50 |
+
birefnet_model.eval()
|
| 51 |
+
rmbg = birefnet_model
|
| 52 |
+
print("BiRefNet loaded!")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def remove_background(image_np):
|
| 56 |
+
"""Use BiRefNet to remove background and add alpha channel.
|
| 57 |
+
Input: numpy array (H, W, 3) RGB
|
| 58 |
+
Output: numpy array (H, W, 4) RGBA
|
| 59 |
+
"""
|
| 60 |
+
pil_img = Image.fromarray(image_np[:, :, :3]).convert('RGB')
|
| 61 |
+
image_size = (1024, 1024)
|
| 62 |
+
transform_image = transforms.Compose([
|
| 63 |
+
transforms.Resize(image_size),
|
| 64 |
+
transforms.ToTensor(),
|
| 65 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 66 |
+
])
|
| 67 |
+
input_tensor = transform_image(pil_img).unsqueeze(0).to("cuda:0")
|
| 68 |
+
with torch.no_grad():
|
| 69 |
+
preds = rmbg(input_tensor)[-1].sigmoid().cpu()
|
| 70 |
+
pred = preds[0].squeeze()
|
| 71 |
+
pred_pil = transforms.ToPILImage()(pred)
|
| 72 |
+
mask = pred_pil.resize(pil_img.size)
|
| 73 |
+
mask = np.array(mask)
|
| 74 |
+
rgba = np.concatenate([np.array(pil_img), mask[..., None]], axis=-1)
|
| 75 |
+
return rgba
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def preprocess_image(image, use_rmbg):
|
| 79 |
+
"""Step 1: process image (background removal or use original), return immediately.
|
| 80 |
+
|
| 81 |
+
use_rmbg=True: run BiRefNet to remove background and generate RGBA
|
| 82 |
+
use_rmbg=False: directly use the original image (RGB or RGBA), skip background removal
|
| 83 |
+
"""
|
| 84 |
+
if image is None:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
if use_rmbg:
|
| 88 |
+
# Run background removal
|
| 89 |
+
if rmbg is None:
|
| 90 |
+
gr.Warning("Background removal model not loaded.")
|
| 91 |
+
return None
|
| 92 |
+
processed = remove_background(image)
|
| 93 |
+
else:
|
| 94 |
+
# Directly use original image, no background removal
|
| 95 |
+
processed = image
|
| 96 |
+
|
| 97 |
+
os.makedirs("./gradio_outputs", exist_ok=True)
|
| 98 |
+
Image.fromarray(processed).save("./gradio_outputs/processed.png")
|
| 99 |
+
return processed
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def infer_mesh(
|
| 103 |
+
processed,
|
| 104 |
+
use_fixed_camera,
|
| 105 |
+
camera_angle_x,
|
| 106 |
+
mesh_scale,
|
| 107 |
+
dense_steps,
|
| 108 |
+
dense_guidance_scale,
|
| 109 |
+
dense_seed,
|
| 110 |
+
sparse_512_steps,
|
| 111 |
+
sparse_512_guidance_scale,
|
| 112 |
+
sparse_1024_steps,
|
| 113 |
+
sparse_1024_guidance_scale,
|
| 114 |
+
sparse_seed,
|
| 115 |
+
dense_threshold,
|
| 116 |
+
mc_threshold,
|
| 117 |
+
):
|
| 118 |
+
"""Step 2: run 3D inference on the already-processed image."""
|
| 119 |
+
if processed is None or pipeline is None:
|
| 120 |
+
return None, None
|
| 121 |
+
|
| 122 |
+
tmp_input = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 123 |
+
Image.fromarray(processed).save(tmp_input.name)
|
| 124 |
+
input_path = tmp_input.name
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
if use_fixed_camera:
|
| 128 |
+
mesh = Pixal3DPipeline.infer(
|
| 129 |
+
pipeline,
|
| 130 |
+
image=input_path,
|
| 131 |
+
camera_angle_x=camera_angle_x,
|
| 132 |
+
mesh_scale=mesh_scale,
|
| 133 |
+
dense_steps=int(dense_steps),
|
| 134 |
+
dense_guidance_scale=dense_guidance_scale,
|
| 135 |
+
dense_seed=int(dense_seed),
|
| 136 |
+
sparse_512_steps=int(sparse_512_steps),
|
| 137 |
+
sparse_512_guidance_scale=sparse_512_guidance_scale,
|
| 138 |
+
sparse_1024_steps=int(sparse_1024_steps),
|
| 139 |
+
sparse_1024_guidance_scale=sparse_1024_guidance_scale,
|
| 140 |
+
sparse_seed=int(sparse_seed),
|
| 141 |
+
dense_threshold=dense_threshold,
|
| 142 |
+
mc_threshold=mc_threshold,
|
| 143 |
+
)
|
| 144 |
+
else:
|
| 145 |
+
mesh = pipeline.infer(
|
| 146 |
+
image=input_path,
|
| 147 |
+
mesh_scale=mesh_scale,
|
| 148 |
+
optimize_mesh_scale=True,
|
| 149 |
+
target_padding=3,
|
| 150 |
+
max_optim_iterations=2,
|
| 151 |
+
dense_steps=int(dense_steps),
|
| 152 |
+
dense_guidance_scale=dense_guidance_scale,
|
| 153 |
+
dense_seed=int(dense_seed),
|
| 154 |
+
sparse_512_steps=int(sparse_512_steps),
|
| 155 |
+
sparse_512_guidance_scale=sparse_512_guidance_scale,
|
| 156 |
+
sparse_1024_steps=int(sparse_1024_steps),
|
| 157 |
+
sparse_1024_guidance_scale=sparse_1024_guidance_scale,
|
| 158 |
+
sparse_seed=int(sparse_seed),
|
| 159 |
+
dense_threshold=dense_threshold,
|
| 160 |
+
mc_threshold=mc_threshold,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
ply_file = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
|
| 164 |
+
glb_file = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
|
| 165 |
+
ply_path = ply_file.name
|
| 166 |
+
glb_path = glb_file.name
|
| 167 |
+
ply_file.close()
|
| 168 |
+
glb_file.close()
|
| 169 |
+
mesh.export(ply_path)
|
| 170 |
+
# Export GLB with PBR material (same as hunyuan_app)
|
| 171 |
+
|
| 172 |
+
material = PBRMaterial(baseColorFactor=[102, 102, 102, 255])
|
| 173 |
+
clean_mesh = trimesh.Trimesh(mesh.vertices, mesh.faces)
|
| 174 |
+
clean_mesh.visual = trimesh.visual.TextureVisuals(material=material)
|
| 175 |
+
# Rotate mesh to desired view angle (only X rotation needed)
|
| 176 |
+
rot_x = rotation_matrix(np.radians(-90), [1, 0, 0])
|
| 177 |
+
clean_mesh.apply_transform(rot_x)
|
| 178 |
+
clean_mesh.export(glb_path)
|
| 179 |
+
|
| 180 |
+
return glb_path, ply_path
|
| 181 |
+
|
| 182 |
+
except Exception as e:
|
| 183 |
+
import traceback
|
| 184 |
+
traceback.print_exc()
|
| 185 |
+
return None, None
|
| 186 |
+
finally:
|
| 187 |
+
os.unlink(input_path)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def build_ui():
|
| 191 |
+
# Custom CSS to hide the download button in Model3D
|
| 192 |
+
custom_css = """
|
| 193 |
+
#model3d-viewer button[aria-label="下载"],
|
| 194 |
+
#model3d-viewer button[aria-label="Download"],
|
| 195 |
+
#model3d-viewer button[title="下载"],
|
| 196 |
+
#model3d-viewer button[title="Download"] {
|
| 197 |
+
display: none !important;
|
| 198 |
+
}
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
with gr.Blocks(title="Pixal3D", theme=gr.themes.Soft(), css=custom_css) as demo:
|
| 202 |
+
gr.Markdown("# Pixal3D: Pixel-Aligned 3D Generation from Images")
|
| 203 |
+
|
| 204 |
+
with gr.Row():
|
| 205 |
+
# Left column: input (scale=1)
|
| 206 |
+
with gr.Column(scale=1):
|
| 207 |
+
image_input = gr.Image(label="Input Image", type="numpy", image_mode=None)
|
| 208 |
+
|
| 209 |
+
processed_image = gr.Image(
|
| 210 |
+
label="Processed Image",
|
| 211 |
+
image_mode="RGBA",
|
| 212 |
+
type="numpy",
|
| 213 |
+
interactive=False,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
use_rmbg = gr.Checkbox(
|
| 217 |
+
label="Remove Background",
|
| 218 |
+
value=True,
|
| 219 |
+
info="Checked: auto remove background via BiRefNet. Unchecked: use original image directly.",
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
use_fixed_camera = gr.Checkbox(
|
| 223 |
+
label="Use Fixed Camera Parameters",
|
| 224 |
+
value=False,
|
| 225 |
+
info="If checked, use manually set FOV/distance/mesh_scale instead of MoGe auto-estimation.",
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
with gr.Group(visible=False) as fixed_camera_group:
|
| 229 |
+
gr.Markdown("### Camera Parameters (fixed mode)")
|
| 230 |
+
camera_angle_x = gr.Number(value=0.2, label="camera_angle_x (rad)", step=0.01)
|
| 231 |
+
|
| 232 |
+
with gr.Group():
|
| 233 |
+
gr.Markdown("### Mesh Scale")
|
| 234 |
+
mesh_scale = gr.Number(value=0.5, label="mesh_scale", step=0.01,
|
| 235 |
+
info="Initial mesh scale. Fixed mode default: 0.9, Auto mode default: 0.5")
|
| 236 |
+
|
| 237 |
+
with gr.Accordion("Advanced Inference Parameters", open=False):
|
| 238 |
+
dense_steps = gr.Number(value=50, label="Dense Steps", step=1, precision=0)
|
| 239 |
+
dense_guidance_scale = gr.Number(value=7.0, label="Dense Guidance Scale", step=0.1)
|
| 240 |
+
dense_seed = gr.Number(value=0, label="Dense Seed", step=1, precision=0)
|
| 241 |
+
sparse_512_steps = gr.Number(value=30, label="Sparse 512 Steps", step=1, precision=0)
|
| 242 |
+
sparse_512_guidance_scale = gr.Number(value=7.0, label="Sparse 512 Guidance Scale", step=0.1)
|
| 243 |
+
sparse_1024_steps = gr.Number(value=15, label="Sparse 1024 Steps", step=1, precision=0)
|
| 244 |
+
sparse_1024_guidance_scale = gr.Number(value=7.0, label="Sparse 1024 Guidance Scale", step=0.1)
|
| 245 |
+
sparse_seed = gr.Number(value=0, label="Sparse Seed", step=1, precision=0)
|
| 246 |
+
dense_threshold = gr.Number(value=0.1, label="Dense Threshold", step=0.01)
|
| 247 |
+
mc_threshold = gr.Number(value=0.2, label="MC Threshold", step=0.01)
|
| 248 |
+
|
| 249 |
+
run_btn = gr.Button("Generate 3D Mesh", variant="primary", size="lg")
|
| 250 |
+
|
| 251 |
+
# Right column: output (scale=2)
|
| 252 |
+
with gr.Column(scale=2):
|
| 253 |
+
model_viewer = gr.Model3D(label="3D Mesh Preview", interactive=False, clear_color=[1.0, 1.0, 1.0, 1.0], elem_id="model3d-viewer")
|
| 254 |
+
output_file = gr.File(label="Download .ply")
|
| 255 |
+
|
| 256 |
+
# Toggle fixed camera group visibility and mesh_scale default
|
| 257 |
+
def on_toggle_fixed(use_fixed):
|
| 258 |
+
new_scale = 0.9 if use_fixed else 0.5
|
| 259 |
+
return gr.update(visible=use_fixed), gr.update(value=new_scale)
|
| 260 |
+
|
| 261 |
+
use_fixed_camera.change(
|
| 262 |
+
fn=on_toggle_fixed,
|
| 263 |
+
inputs=[use_fixed_camera],
|
| 264 |
+
outputs=[fixed_camera_group, mesh_scale],
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Step 1: preprocess image → show processed image immediately
|
| 268 |
+
# Step 2: run 3D inference → show mesh and download
|
| 269 |
+
run_btn.click(
|
| 270 |
+
fn=preprocess_image,
|
| 271 |
+
inputs=[image_input, use_rmbg],
|
| 272 |
+
outputs=[processed_image],
|
| 273 |
+
).then(
|
| 274 |
+
fn=infer_mesh,
|
| 275 |
+
inputs=[
|
| 276 |
+
processed_image,
|
| 277 |
+
use_fixed_camera,
|
| 278 |
+
camera_angle_x,
|
| 279 |
+
mesh_scale,
|
| 280 |
+
dense_steps,
|
| 281 |
+
dense_guidance_scale,
|
| 282 |
+
dense_seed,
|
| 283 |
+
sparse_512_steps,
|
| 284 |
+
sparse_512_guidance_scale,
|
| 285 |
+
sparse_1024_steps,
|
| 286 |
+
sparse_1024_guidance_scale,
|
| 287 |
+
sparse_seed,
|
| 288 |
+
dense_threshold,
|
| 289 |
+
mc_threshold,
|
| 290 |
+
],
|
| 291 |
+
outputs=[model_viewer, output_file],
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
demo.queue(api_open=False)
|
| 295 |
+
return demo
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
if __name__ == "__main__":
|
| 299 |
+
import argparse
|
| 300 |
+
|
| 301 |
+
parser = argparse.ArgumentParser()
|
| 302 |
+
parser.add_argument("--repo_id", type=str, default="TencentARC/Pixal3D-D")
|
| 303 |
+
args = parser.parse_args()
|
| 304 |
+
|
| 305 |
+
load_pipeline(repo_id=args.repo_id)
|
| 306 |
+
|
| 307 |
+
demo = build_ui()
|
| 308 |
+
demo.launch(
|
| 309 |
+
server_name="127.0.0.1",
|
| 310 |
+
share=True,
|
| 311 |
+
)
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
libsparsehash-dev
|
pixal3d/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
|
| 3 |
+
__modules__ = {}
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def register(name):
|
| 7 |
+
def decorator(cls):
|
| 8 |
+
# Allow re-registration for checkpoint loading compatibility
|
| 9 |
+
# When torch.load triggers module re-import, the same class may be registered again
|
| 10 |
+
__modules__[name] = cls
|
| 11 |
+
return cls
|
| 12 |
+
|
| 13 |
+
return decorator
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def find(name):
|
| 17 |
+
if name in __modules__:
|
| 18 |
+
return __modules__[name]
|
| 19 |
+
else:
|
| 20 |
+
try:
|
| 21 |
+
module_string = ".".join(name.split(".")[:-1])
|
| 22 |
+
cls_name = name.split(".")[-1]
|
| 23 |
+
module = importlib.import_module(module_string, package=None)
|
| 24 |
+
return getattr(module, cls_name)
|
| 25 |
+
except Exception as e:
|
| 26 |
+
raise ValueError(f"Module {name} not found!")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
### grammar sugar for logging utilities ###
|
| 30 |
+
import logging
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger("pixal3d")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def debug(*args, **kwargs):
|
| 36 |
+
logger.debug(*args, **kwargs)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def info(*args, **kwargs):
|
| 40 |
+
logger.info(*args, **kwargs)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def warn(*args, **kwargs):
|
| 44 |
+
logger.warning(*args, **kwargs)
|
pixal3d/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.25 kB). View file
|
|
|
pixal3d/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import conditional_encoders, transformers
|
pixal3d/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (241 Bytes). View file
|
|
|
pixal3d/models/autoencoders/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (4.39 kB). View file
|
|
|
pixal3d/models/autoencoders/__pycache__/decoder.cpython-310.pyc
ADDED
|
Binary file (8.77 kB). View file
|
|
|
pixal3d/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
pixal3d/models/autoencoders/__pycache__/distributions.cpython-310.pyc
ADDED
|
Binary file (2.09 kB). View file
|
|
|
pixal3d/models/autoencoders/__pycache__/encoder.cpython-310.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
pixal3d/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc
ADDED
|
Binary file (3.86 kB). View file
|
|
|
pixal3d/models/autoencoders/base.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
|
| 6 |
+
from ...modules import sparse as sp
|
| 7 |
+
from ...modules.transformer import AbsolutePositionEmbedder
|
| 8 |
+
from ...modules.sparse.transformer import SparseTransformerBlock
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def block_attn_config(self):
|
| 12 |
+
"""
|
| 13 |
+
Return the attention configuration of the model.
|
| 14 |
+
"""
|
| 15 |
+
for i in range(self.num_blocks):
|
| 16 |
+
if self.attn_mode == "shift_window":
|
| 17 |
+
yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
|
| 18 |
+
elif self.attn_mode == "shift_sequence":
|
| 19 |
+
yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
|
| 20 |
+
elif self.attn_mode == "shift_order":
|
| 21 |
+
yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
|
| 22 |
+
elif self.attn_mode == "full":
|
| 23 |
+
yield "full", None, None, None, None
|
| 24 |
+
elif self.attn_mode == "swin":
|
| 25 |
+
yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SparseTransformerBase(nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
Sparse Transformer without output layers.
|
| 31 |
+
Serve as the base class for encoder and decoder.
|
| 32 |
+
"""
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
in_channels: int,
|
| 36 |
+
model_channels: int,
|
| 37 |
+
num_blocks: int,
|
| 38 |
+
num_heads: Optional[int] = None,
|
| 39 |
+
num_head_channels: Optional[int] = 64,
|
| 40 |
+
mlp_ratio: float = 4.0,
|
| 41 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
| 42 |
+
window_size: Optional[int] = None,
|
| 43 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 44 |
+
use_fp16: bool = False,
|
| 45 |
+
use_checkpoint: bool = False,
|
| 46 |
+
qk_rms_norm: bool = False,
|
| 47 |
+
):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.in_channels = in_channels
|
| 50 |
+
self.model_channels = model_channels
|
| 51 |
+
self.num_blocks = num_blocks
|
| 52 |
+
self.window_size = window_size
|
| 53 |
+
self.num_heads = num_heads or model_channels // num_head_channels
|
| 54 |
+
self.mlp_ratio = mlp_ratio
|
| 55 |
+
self.attn_mode = attn_mode
|
| 56 |
+
self.pe_mode = pe_mode
|
| 57 |
+
self.use_fp16 = use_fp16
|
| 58 |
+
self.use_checkpoint = use_checkpoint
|
| 59 |
+
self.qk_rms_norm = qk_rms_norm
|
| 60 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 61 |
+
|
| 62 |
+
if pe_mode == "ape":
|
| 63 |
+
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
| 64 |
+
self.input_layer = sp.SparseLinear(in_channels, model_channels)
|
| 65 |
+
self.blocks = nn.ModuleList([
|
| 66 |
+
SparseTransformerBlock(
|
| 67 |
+
model_channels,
|
| 68 |
+
num_heads=self.num_heads,
|
| 69 |
+
mlp_ratio=self.mlp_ratio,
|
| 70 |
+
attn_mode=attn_mode,
|
| 71 |
+
window_size=window_size,
|
| 72 |
+
shift_sequence=shift_sequence,
|
| 73 |
+
shift_window=shift_window,
|
| 74 |
+
serialize_mode=serialize_mode,
|
| 75 |
+
use_checkpoint=self.use_checkpoint,
|
| 76 |
+
use_rope=(pe_mode == "rope"),
|
| 77 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 78 |
+
)
|
| 79 |
+
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
|
| 80 |
+
])
|
| 81 |
+
|
| 82 |
+
@property
|
| 83 |
+
def device(self) -> torch.device:
|
| 84 |
+
"""
|
| 85 |
+
Return the device of the model.
|
| 86 |
+
"""
|
| 87 |
+
return next(self.parameters()).device
|
| 88 |
+
|
| 89 |
+
def convert_to_fp16(self) -> None:
|
| 90 |
+
"""
|
| 91 |
+
Convert the torso of the model to float16.
|
| 92 |
+
"""
|
| 93 |
+
# self.blocks.apply(convert_module_to_f16)
|
| 94 |
+
self.apply(convert_module_to_f16)
|
| 95 |
+
|
| 96 |
+
def convert_to_fp32(self) -> None:
|
| 97 |
+
"""
|
| 98 |
+
Convert the torso of the model to float32.
|
| 99 |
+
"""
|
| 100 |
+
self.blocks.apply(convert_module_to_f32)
|
| 101 |
+
|
| 102 |
+
def initialize_weights(self) -> None:
|
| 103 |
+
# Initialize transformer layers:
|
| 104 |
+
def _basic_init(module):
|
| 105 |
+
if isinstance(module, nn.Linear):
|
| 106 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 107 |
+
if module.bias is not None:
|
| 108 |
+
nn.init.constant_(module.bias, 0)
|
| 109 |
+
self.apply(_basic_init)
|
| 110 |
+
|
| 111 |
+
def forward(self, x: sp.SparseTensor, factor: float = None) -> sp.SparseTensor:
|
| 112 |
+
h = self.input_layer(x)
|
| 113 |
+
if self.pe_mode == "ape":
|
| 114 |
+
h = h + self.pos_embedder(x.coords[:, 1:], factor)
|
| 115 |
+
h = h.type(self.dtype)
|
| 116 |
+
for block in self.blocks:
|
| 117 |
+
h = block(h)
|
| 118 |
+
return h
|
pixal3d/models/autoencoders/decoder.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
|
| 7 |
+
from ...modules import sparse as sp
|
| 8 |
+
from .base import SparseTransformerBase
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SparseSubdivideBlock3d(nn.Module):
|
| 12 |
+
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
channels: int,
|
| 16 |
+
out_channels: Optional[int] = None,
|
| 17 |
+
use_checkpoint: bool = False,
|
| 18 |
+
):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.channels = channels
|
| 21 |
+
self.out_channels = out_channels or channels
|
| 22 |
+
self.use_checkpoint = use_checkpoint
|
| 23 |
+
|
| 24 |
+
self.act_layers = nn.Sequential(
|
| 25 |
+
sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
|
| 26 |
+
sp.SparseSiLU()
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
self.sub = sp.SparseSubdivide()
|
| 30 |
+
|
| 31 |
+
self.out_layers = nn.Sequential(
|
| 32 |
+
sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
|
| 33 |
+
sp.SparseSiLU(),
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
| 37 |
+
h = self.act_layers(x)
|
| 38 |
+
h = self.sub(h)
|
| 39 |
+
h = self.out_layers(h)
|
| 40 |
+
return h
|
| 41 |
+
|
| 42 |
+
def forward(self, x: torch.Tensor):
|
| 43 |
+
if self.use_checkpoint:
|
| 44 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
|
| 45 |
+
else:
|
| 46 |
+
return self._forward(x)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class SparseSDFDecoder(SparseTransformerBase):
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
resolution: int,
|
| 53 |
+
model_channels: int,
|
| 54 |
+
latent_channels: int,
|
| 55 |
+
num_blocks: int,
|
| 56 |
+
num_heads: Optional[int] = None,
|
| 57 |
+
num_head_channels: Optional[int] = 64,
|
| 58 |
+
mlp_ratio: float = 4,
|
| 59 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
|
| 60 |
+
window_size: int = 8,
|
| 61 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 62 |
+
use_fp16: bool = False,
|
| 63 |
+
use_checkpoint: bool = False,
|
| 64 |
+
qk_rms_norm: bool = False,
|
| 65 |
+
representation_config: dict = None,
|
| 66 |
+
out_channels: int = 1,
|
| 67 |
+
chunk_size: int = 1,
|
| 68 |
+
):
|
| 69 |
+
super().__init__(
|
| 70 |
+
in_channels=latent_channels,
|
| 71 |
+
model_channels=model_channels,
|
| 72 |
+
num_blocks=num_blocks,
|
| 73 |
+
num_heads=num_heads,
|
| 74 |
+
num_head_channels=num_head_channels,
|
| 75 |
+
mlp_ratio=mlp_ratio,
|
| 76 |
+
attn_mode=attn_mode,
|
| 77 |
+
window_size=window_size,
|
| 78 |
+
pe_mode=pe_mode,
|
| 79 |
+
use_fp16=use_fp16,
|
| 80 |
+
use_checkpoint=use_checkpoint,
|
| 81 |
+
qk_rms_norm=qk_rms_norm,
|
| 82 |
+
)
|
| 83 |
+
self.resolution = resolution
|
| 84 |
+
self.rep_config = representation_config
|
| 85 |
+
self.out_channels = out_channels
|
| 86 |
+
self.chunk_size = chunk_size
|
| 87 |
+
self.upsample = nn.ModuleList([
|
| 88 |
+
SparseSubdivideBlock3d(
|
| 89 |
+
channels=model_channels,
|
| 90 |
+
out_channels=model_channels // 4,
|
| 91 |
+
use_checkpoint=use_checkpoint,
|
| 92 |
+
),
|
| 93 |
+
SparseSubdivideBlock3d(
|
| 94 |
+
channels=model_channels // 4,
|
| 95 |
+
out_channels=model_channels // 8,
|
| 96 |
+
use_checkpoint=use_checkpoint,
|
| 97 |
+
),
|
| 98 |
+
SparseSubdivideBlock3d(
|
| 99 |
+
channels=model_channels // 8,
|
| 100 |
+
out_channels=model_channels // 16,
|
| 101 |
+
use_checkpoint=use_checkpoint,
|
| 102 |
+
)
|
| 103 |
+
])
|
| 104 |
+
|
| 105 |
+
self.out_layer = sp.SparseLinear(model_channels // 16, self.out_channels)
|
| 106 |
+
self.out_active = sp.SparseTanh()
|
| 107 |
+
|
| 108 |
+
self.initialize_weights()
|
| 109 |
+
if use_fp16:
|
| 110 |
+
self.convert_to_fp16()
|
| 111 |
+
|
| 112 |
+
def initialize_weights(self) -> None:
|
| 113 |
+
super().initialize_weights()
|
| 114 |
+
# Zero-out output layers:
|
| 115 |
+
nn.init.constant_(self.out_layer.weight, 0)
|
| 116 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
| 117 |
+
|
| 118 |
+
def convert_to_fp16(self) -> None:
|
| 119 |
+
"""
|
| 120 |
+
Convert the torso of the model to float16.
|
| 121 |
+
"""
|
| 122 |
+
super().convert_to_fp16()
|
| 123 |
+
self.upsample.apply(convert_module_to_f16)
|
| 124 |
+
|
| 125 |
+
def convert_to_fp32(self) -> None:
|
| 126 |
+
"""
|
| 127 |
+
Convert the torso of the model to float32.
|
| 128 |
+
"""
|
| 129 |
+
super().convert_to_fp32()
|
| 130 |
+
self.upsample.apply(convert_module_to_f32)
|
| 131 |
+
|
| 132 |
+
@torch.no_grad()
|
| 133 |
+
def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4):
|
| 134 |
+
|
| 135 |
+
sub_resolution = self.resolution // chunk_size
|
| 136 |
+
upsample_ratio = 8 # hard-coded here
|
| 137 |
+
assert sub_resolution % padding == 0
|
| 138 |
+
out = []
|
| 139 |
+
|
| 140 |
+
for i in range(chunk_size):
|
| 141 |
+
for j in range(chunk_size):
|
| 142 |
+
for k in range(chunk_size):
|
| 143 |
+
# Calculate padded boundaries
|
| 144 |
+
start_x = max(0, i * sub_resolution - padding)
|
| 145 |
+
end_x = min((i + 1) * sub_resolution + padding, self.resolution)
|
| 146 |
+
start_y = max(0, j * sub_resolution - padding)
|
| 147 |
+
end_y = min((j + 1) * sub_resolution + padding, self.resolution)
|
| 148 |
+
start_z = max(0, k * sub_resolution - padding)
|
| 149 |
+
end_z = min((k + 1) * sub_resolution + padding, self.resolution)
|
| 150 |
+
|
| 151 |
+
# Store original (unpadded) boundaries for later cropping
|
| 152 |
+
orig_start_x = i * sub_resolution
|
| 153 |
+
orig_end_x = (i + 1) * sub_resolution
|
| 154 |
+
orig_start_y = j * sub_resolution
|
| 155 |
+
orig_end_y = (j + 1) * sub_resolution
|
| 156 |
+
orig_start_z = k * sub_resolution
|
| 157 |
+
orig_end_z = (k + 1) * sub_resolution
|
| 158 |
+
|
| 159 |
+
mask = torch.logical_and(
|
| 160 |
+
torch.logical_and(
|
| 161 |
+
torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x),
|
| 162 |
+
torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y)
|
| 163 |
+
),
|
| 164 |
+
torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z)
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
if mask.sum() > 0:
|
| 168 |
+
# Get the coordinates and shift them to local space
|
| 169 |
+
coords = x.coords[mask].clone()
|
| 170 |
+
# Shift to local coordinates
|
| 171 |
+
coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
|
| 172 |
+
device=coords.device).view(1, 3)
|
| 173 |
+
|
| 174 |
+
chunk_tensor = sp.SparseTensor(x.feats[mask], coords)
|
| 175 |
+
# Store the boundaries and offsets as metadata for later reconstruction
|
| 176 |
+
chunk_tensor.bounds = {
|
| 177 |
+
'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
|
| 178 |
+
'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction
|
| 179 |
+
}
|
| 180 |
+
out.append(chunk_tensor)
|
| 181 |
+
|
| 182 |
+
del mask
|
| 183 |
+
torch.cuda.empty_cache()
|
| 184 |
+
return out
|
| 185 |
+
|
| 186 |
+
@torch.no_grad()
|
| 187 |
+
def split_single_chunk(self, x: sp.SparseTensor, chunk_size=4, padding=4):
|
| 188 |
+
sub_resolution = self.resolution // chunk_size
|
| 189 |
+
upsample_ratio = 8 # hard-coded here
|
| 190 |
+
assert sub_resolution % padding == 0
|
| 191 |
+
|
| 192 |
+
mask_sum = -1
|
| 193 |
+
while mask_sum < 1:
|
| 194 |
+
orig_start_x = random.randint(0, self.resolution - sub_resolution)
|
| 195 |
+
orig_end_x = orig_start_x + sub_resolution
|
| 196 |
+
orig_start_y = random.randint(0, self.resolution - sub_resolution)
|
| 197 |
+
orig_end_y = orig_start_y + sub_resolution
|
| 198 |
+
orig_start_z = random.randint(0, self.resolution - sub_resolution)
|
| 199 |
+
orig_end_z = orig_start_z + sub_resolution
|
| 200 |
+
start_x = max(0, orig_start_x - padding)
|
| 201 |
+
end_x = min(orig_end_x + padding, self.resolution)
|
| 202 |
+
start_y = max(0, orig_start_y - padding)
|
| 203 |
+
end_y = min(orig_end_y + padding, self.resolution)
|
| 204 |
+
start_z = max(0, orig_start_z - padding)
|
| 205 |
+
end_z = min(orig_end_z + padding, self.resolution)
|
| 206 |
+
|
| 207 |
+
mask_ori = torch.logical_and(
|
| 208 |
+
torch.logical_and(
|
| 209 |
+
torch.logical_and(x.coords[:, 1] >= orig_start_x, x.coords[:, 1] < orig_end_x),
|
| 210 |
+
torch.logical_and(x.coords[:, 2] >= orig_start_y, x.coords[:, 2] < orig_end_y)
|
| 211 |
+
),
|
| 212 |
+
torch.logical_and(x.coords[:, 3] >= orig_start_z, x.coords[:, 3] < orig_end_z)
|
| 213 |
+
)
|
| 214 |
+
mask_sum = mask_ori.sum()
|
| 215 |
+
|
| 216 |
+
# Store the boundaries and offsets as metadata for later reconstruction
|
| 217 |
+
bounds = {
|
| 218 |
+
'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
|
| 219 |
+
'start': (start_x, end_x, start_y, end_y, start_z, end_z),
|
| 220 |
+
'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction
|
| 221 |
+
}
|
| 222 |
+
return bounds
|
| 223 |
+
|
| 224 |
+
def forward_single_chunk(self, x: sp.SparseTensor, padding=4):
|
| 225 |
+
|
| 226 |
+
bounds = self.split_single_chunk(x, self.chunk_size, padding=padding)
|
| 227 |
+
|
| 228 |
+
start_x, end_x, start_y, end_y, start_z, end_z = bounds['start']
|
| 229 |
+
mask = torch.logical_and(
|
| 230 |
+
torch.logical_and(
|
| 231 |
+
torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x),
|
| 232 |
+
torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y)
|
| 233 |
+
),
|
| 234 |
+
torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z)
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# Shift to local coordinates
|
| 238 |
+
coords = x.coords.clone()
|
| 239 |
+
coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
|
| 240 |
+
device=coords.device).view(1, 3)
|
| 241 |
+
|
| 242 |
+
chunk = sp.SparseTensor(x.feats[mask], coords[mask])
|
| 243 |
+
|
| 244 |
+
chunk_result = self.upsamples(chunk)
|
| 245 |
+
|
| 246 |
+
coords = chunk_result.coords.clone()
|
| 247 |
+
|
| 248 |
+
# Restore global coordinates
|
| 249 |
+
offsets = torch.tensor(bounds['offsets'],
|
| 250 |
+
device=coords.device).view(1, 3)
|
| 251 |
+
coords[:, 1:] = coords[:, 1:] + offsets
|
| 252 |
+
|
| 253 |
+
# Filter points within original bounds
|
| 254 |
+
original = bounds['original']
|
| 255 |
+
within_bounds = torch.logical_and(
|
| 256 |
+
torch.logical_and(
|
| 257 |
+
torch.logical_and(
|
| 258 |
+
coords[:, 1] >= original[0],
|
| 259 |
+
coords[:, 1] < original[1]
|
| 260 |
+
),
|
| 261 |
+
torch.logical_and(
|
| 262 |
+
coords[:, 2] >= original[2],
|
| 263 |
+
coords[:, 2] < original[3]
|
| 264 |
+
)
|
| 265 |
+
),
|
| 266 |
+
torch.logical_and(
|
| 267 |
+
coords[:, 3] >= original[4],
|
| 268 |
+
coords[:, 3] < original[5]
|
| 269 |
+
)
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
final_coords = coords[within_bounds]
|
| 273 |
+
final_feats = chunk_result.feats[within_bounds]
|
| 274 |
+
|
| 275 |
+
return sp.SparseTensor(final_feats, final_coords)
|
| 276 |
+
|
| 277 |
+
def upsamples(self, x, return_feat: bool = False):
|
| 278 |
+
dtype = x.dtype
|
| 279 |
+
for block in self.upsample:
|
| 280 |
+
x = block(x)
|
| 281 |
+
x = x.type(dtype)
|
| 282 |
+
|
| 283 |
+
output = self.out_active(self.out_layer(x))
|
| 284 |
+
|
| 285 |
+
if return_feat:
|
| 286 |
+
return output, x
|
| 287 |
+
else:
|
| 288 |
+
return output
|
| 289 |
+
|
| 290 |
+
def forward(self, x: sp.SparseTensor, factor: float = None, return_feat: bool = False):
|
| 291 |
+
h = super().forward(x, factor)
|
| 292 |
+
if self.chunk_size <= 1:
|
| 293 |
+
for block in self.upsample:
|
| 294 |
+
h = block(h)
|
| 295 |
+
h = h.type(x.dtype)
|
| 296 |
+
|
| 297 |
+
if return_feat:
|
| 298 |
+
return self.out_active(self.out_layer(h)), h
|
| 299 |
+
|
| 300 |
+
h = self.out_layer(h)
|
| 301 |
+
h = self.out_active(h)
|
| 302 |
+
return h
|
| 303 |
+
else:
|
| 304 |
+
if self.training:
|
| 305 |
+
return self.forward_single_chunk(h)
|
| 306 |
+
else:
|
| 307 |
+
batch_size = x.shape[0]
|
| 308 |
+
chunks = self.split_for_meshing(h, chunk_size=self.chunk_size)
|
| 309 |
+
all_coords, all_feats = [], []
|
| 310 |
+
for chunk_idx, chunk in enumerate(chunks):
|
| 311 |
+
chunk_result = self.upsamples(chunk)
|
| 312 |
+
|
| 313 |
+
for b in range(batch_size):
|
| 314 |
+
mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1)
|
| 315 |
+
if mask.numel() > 0:
|
| 316 |
+
coords = chunk_result.coords[mask].clone()
|
| 317 |
+
|
| 318 |
+
# Restore global coordinates
|
| 319 |
+
offsets = torch.tensor(chunk.bounds['offsets'],
|
| 320 |
+
device=coords.device).view(1, 3)
|
| 321 |
+
coords[:, 1:] = coords[:, 1:] + offsets
|
| 322 |
+
|
| 323 |
+
# Filter points within original bounds
|
| 324 |
+
bounds = chunk.bounds['original']
|
| 325 |
+
within_bounds = torch.logical_and(
|
| 326 |
+
torch.logical_and(
|
| 327 |
+
torch.logical_and(
|
| 328 |
+
coords[:, 1] >= bounds[0],
|
| 329 |
+
coords[:, 1] < bounds[1]
|
| 330 |
+
),
|
| 331 |
+
torch.logical_and(
|
| 332 |
+
coords[:, 2] >= bounds[2],
|
| 333 |
+
coords[:, 2] < bounds[3]
|
| 334 |
+
)
|
| 335 |
+
),
|
| 336 |
+
torch.logical_and(
|
| 337 |
+
coords[:, 3] >= bounds[4],
|
| 338 |
+
coords[:, 3] < bounds[5]
|
| 339 |
+
)
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if within_bounds.any():
|
| 343 |
+
all_coords.append(coords[within_bounds])
|
| 344 |
+
all_feats.append(chunk_result.feats[mask][within_bounds])
|
| 345 |
+
|
| 346 |
+
if not self.training:
|
| 347 |
+
torch.cuda.empty_cache()
|
| 348 |
+
|
| 349 |
+
final_coords = torch.cat(all_coords)
|
| 350 |
+
final_feats = torch.cat(all_feats)
|
| 351 |
+
|
| 352 |
+
return sp.SparseTensor(final_feats, final_coords)
|
| 353 |
+
|
pixal3d/models/autoencoders/dense_vae.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import trimesh
|
| 6 |
+
from skimage import measure
|
| 7 |
+
from ...modules.norm import GroupNorm32, ChannelLayerNorm32
|
| 8 |
+
from ...modules.spatial import pixel_shuffle_3d
|
| 9 |
+
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
|
| 10 |
+
from .distributions import DiagonalGaussianDistribution
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
|
| 14 |
+
"""
|
| 15 |
+
Return a normalization layer.
|
| 16 |
+
"""
|
| 17 |
+
if norm_type == "group":
|
| 18 |
+
return GroupNorm32(32, *args, **kwargs)
|
| 19 |
+
elif norm_type == "layer":
|
| 20 |
+
return ChannelLayerNorm32(*args, **kwargs)
|
| 21 |
+
else:
|
| 22 |
+
raise ValueError(f"Invalid norm type {norm_type}")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ResBlock3d(nn.Module):
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
channels: int,
|
| 29 |
+
out_channels: Optional[int] = None,
|
| 30 |
+
norm_type: Literal["group", "layer"] = "layer",
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.channels = channels
|
| 34 |
+
self.out_channels = out_channels or channels
|
| 35 |
+
|
| 36 |
+
self.norm1 = norm_layer(norm_type, channels)
|
| 37 |
+
self.norm2 = norm_layer(norm_type, self.out_channels)
|
| 38 |
+
self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
|
| 39 |
+
self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
|
| 40 |
+
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
| 41 |
+
|
| 42 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 43 |
+
h = self.norm1(x)
|
| 44 |
+
h = F.silu(h)
|
| 45 |
+
h = self.conv1(h)
|
| 46 |
+
h = self.norm2(h)
|
| 47 |
+
h = F.silu(h)
|
| 48 |
+
h = self.conv2(h)
|
| 49 |
+
h = h + self.skip_connection(x)
|
| 50 |
+
return h
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class DownsampleBlock3d(nn.Module):
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
in_channels: int,
|
| 57 |
+
out_channels: int,
|
| 58 |
+
mode: Literal["conv", "avgpool"] = "conv",
|
| 59 |
+
):
|
| 60 |
+
assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
|
| 61 |
+
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.in_channels = in_channels
|
| 64 |
+
self.out_channels = out_channels
|
| 65 |
+
|
| 66 |
+
if mode == "conv":
|
| 67 |
+
self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
|
| 68 |
+
elif mode == "avgpool":
|
| 69 |
+
assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
|
| 70 |
+
|
| 71 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 72 |
+
if hasattr(self, "conv"):
|
| 73 |
+
return self.conv(x)
|
| 74 |
+
else:
|
| 75 |
+
return F.avg_pool3d(x, 2)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class UpsampleBlock3d(nn.Module):
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
in_channels: int,
|
| 82 |
+
out_channels: int,
|
| 83 |
+
mode: Literal["conv", "nearest"] = "conv",
|
| 84 |
+
):
|
| 85 |
+
assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
|
| 86 |
+
|
| 87 |
+
super().__init__()
|
| 88 |
+
self.in_channels = in_channels
|
| 89 |
+
self.out_channels = out_channels
|
| 90 |
+
|
| 91 |
+
if mode == "conv":
|
| 92 |
+
self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
|
| 93 |
+
elif mode == "nearest":
|
| 94 |
+
assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
|
| 95 |
+
|
| 96 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 97 |
+
if hasattr(self, "conv"):
|
| 98 |
+
x = self.conv(x)
|
| 99 |
+
return pixel_shuffle_3d(x, 2)
|
| 100 |
+
else:
|
| 101 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class SparseStructureEncoder(nn.Module):
|
| 105 |
+
"""
|
| 106 |
+
Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
in_channels (int): Channels of the input.
|
| 110 |
+
latent_channels (int): Channels of the latent representation.
|
| 111 |
+
num_res_blocks (int): Number of residual blocks at each resolution.
|
| 112 |
+
channels (List[int]): Channels of the encoder blocks.
|
| 113 |
+
num_res_blocks_middle (int): Number of residual blocks in the middle.
|
| 114 |
+
norm_type (Literal["group", "layer"]): Type of normalization layer.
|
| 115 |
+
use_fp16 (bool): Whether to use FP16.
|
| 116 |
+
"""
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
in_channels: int,
|
| 120 |
+
latent_channels: int,
|
| 121 |
+
num_res_blocks: int,
|
| 122 |
+
channels: List[int],
|
| 123 |
+
num_res_blocks_middle: int = 2,
|
| 124 |
+
norm_type: Literal["group", "layer"] = "layer",
|
| 125 |
+
use_fp16: bool = False,
|
| 126 |
+
use_checkpoint: bool = False,
|
| 127 |
+
):
|
| 128 |
+
super().__init__()
|
| 129 |
+
self.in_channels = in_channels
|
| 130 |
+
self.latent_channels = latent_channels
|
| 131 |
+
self.num_res_blocks = num_res_blocks
|
| 132 |
+
self.channels = channels
|
| 133 |
+
self.num_res_blocks_middle = num_res_blocks_middle
|
| 134 |
+
self.norm_type = norm_type
|
| 135 |
+
self.use_fp16 = use_fp16
|
| 136 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 137 |
+
self.use_checkpoint = use_checkpoint
|
| 138 |
+
|
| 139 |
+
self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
|
| 140 |
+
|
| 141 |
+
self.blocks = nn.ModuleList([])
|
| 142 |
+
for i, ch in enumerate(channels):
|
| 143 |
+
self.blocks.extend([
|
| 144 |
+
ResBlock3d(ch, ch)
|
| 145 |
+
for _ in range(num_res_blocks)
|
| 146 |
+
])
|
| 147 |
+
if i < len(channels) - 1:
|
| 148 |
+
self.blocks.append(
|
| 149 |
+
DownsampleBlock3d(ch, channels[i+1])
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
self.middle_block = nn.Sequential(*[
|
| 153 |
+
ResBlock3d(channels[-1], channels[-1])
|
| 154 |
+
for _ in range(num_res_blocks_middle)
|
| 155 |
+
])
|
| 156 |
+
|
| 157 |
+
self.out_layer = nn.Sequential(
|
| 158 |
+
norm_layer(norm_type, channels[-1]),
|
| 159 |
+
nn.SiLU(),
|
| 160 |
+
nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
if use_fp16:
|
| 164 |
+
self.convert_to_fp16()
|
| 165 |
+
|
| 166 |
+
@property
|
| 167 |
+
def device(self) -> torch.device:
|
| 168 |
+
"""
|
| 169 |
+
Return the device of the model.
|
| 170 |
+
"""
|
| 171 |
+
return next(self.parameters()).device
|
| 172 |
+
|
| 173 |
+
def convert_to_fp16(self) -> None:
|
| 174 |
+
"""
|
| 175 |
+
Convert the torso of the model to float16.
|
| 176 |
+
"""
|
| 177 |
+
self.use_fp16 = True
|
| 178 |
+
self.dtype = torch.float16
|
| 179 |
+
self.blocks.apply(convert_module_to_f16)
|
| 180 |
+
self.middle_block.apply(convert_module_to_f16)
|
| 181 |
+
|
| 182 |
+
def convert_to_fp32(self) -> None:
|
| 183 |
+
"""
|
| 184 |
+
Convert the torso of the model to float32.
|
| 185 |
+
"""
|
| 186 |
+
self.use_fp16 = False
|
| 187 |
+
self.dtype = torch.float32
|
| 188 |
+
self.blocks.apply(convert_module_to_f32)
|
| 189 |
+
self.middle_block.apply(convert_module_to_f32)
|
| 190 |
+
|
| 191 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 192 |
+
h = self.input_layer(x)
|
| 193 |
+
|
| 194 |
+
for block in self.blocks:
|
| 195 |
+
h = block(h)
|
| 196 |
+
h = self.middle_block(h)
|
| 197 |
+
|
| 198 |
+
h = self.out_layer(h)
|
| 199 |
+
|
| 200 |
+
return h
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class SparseStructureDecoder(nn.Module):
|
| 204 |
+
"""
|
| 205 |
+
Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
out_channels (int): Channels of the output.
|
| 209 |
+
latent_channels (int): Channels of the latent representation.
|
| 210 |
+
num_res_blocks (int): Number of residual blocks at each resolution.
|
| 211 |
+
channels (List[int]): Channels of the decoder blocks.
|
| 212 |
+
num_res_blocks_middle (int): Number of residual blocks in the middle.
|
| 213 |
+
norm_type (Literal["group", "layer"]): Type of normalization layer.
|
| 214 |
+
use_fp16 (bool): Whether to use FP16.
|
| 215 |
+
"""
|
| 216 |
+
def __init__(
|
| 217 |
+
self,
|
| 218 |
+
out_channels: int,
|
| 219 |
+
latent_channels: int,
|
| 220 |
+
num_res_blocks: int,
|
| 221 |
+
channels: List[int],
|
| 222 |
+
num_res_blocks_middle: int = 2,
|
| 223 |
+
norm_type: Literal["group", "layer"] = "layer",
|
| 224 |
+
use_fp16: bool = False,
|
| 225 |
+
use_checkpoint: bool = False,
|
| 226 |
+
):
|
| 227 |
+
super().__init__()
|
| 228 |
+
self.out_channels = out_channels
|
| 229 |
+
self.latent_channels = latent_channels
|
| 230 |
+
self.num_res_blocks = num_res_blocks
|
| 231 |
+
self.channels = channels
|
| 232 |
+
self.num_res_blocks_middle = num_res_blocks_middle
|
| 233 |
+
self.norm_type = norm_type
|
| 234 |
+
self.use_fp16 = use_fp16
|
| 235 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 236 |
+
self.use_checkpoint = use_checkpoint
|
| 237 |
+
|
| 238 |
+
self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
|
| 239 |
+
|
| 240 |
+
self.middle_block = nn.Sequential(*[
|
| 241 |
+
ResBlock3d(channels[0], channels[0])
|
| 242 |
+
for _ in range(num_res_blocks_middle)
|
| 243 |
+
])
|
| 244 |
+
|
| 245 |
+
self.blocks = nn.ModuleList([])
|
| 246 |
+
for i, ch in enumerate(channels):
|
| 247 |
+
self.blocks.extend([
|
| 248 |
+
ResBlock3d(ch, ch)
|
| 249 |
+
for _ in range(num_res_blocks)
|
| 250 |
+
])
|
| 251 |
+
if i < len(channels) - 1:
|
| 252 |
+
self.blocks.append(
|
| 253 |
+
UpsampleBlock3d(ch, channels[i+1])
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
self.out_layer = nn.Sequential(
|
| 257 |
+
norm_layer(norm_type, channels[-1]),
|
| 258 |
+
nn.SiLU(),
|
| 259 |
+
nn.Conv3d(channels[-1], out_channels, 3, padding=1)
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
if use_fp16:
|
| 263 |
+
self.convert_to_fp16()
|
| 264 |
+
|
| 265 |
+
@property
|
| 266 |
+
def device(self) -> torch.device:
|
| 267 |
+
"""
|
| 268 |
+
Return the device of the model.
|
| 269 |
+
"""
|
| 270 |
+
return next(self.parameters()).device
|
| 271 |
+
|
| 272 |
+
def convert_to_fp16(self) -> None:
|
| 273 |
+
"""
|
| 274 |
+
Convert the torso of the model to float16.
|
| 275 |
+
"""
|
| 276 |
+
self.use_fp16 = True
|
| 277 |
+
self.dtype = torch.float16
|
| 278 |
+
# self.blocks.apply(convert_module_to_f16)
|
| 279 |
+
# self.middle_block.apply(convert_module_to_f16)
|
| 280 |
+
self.apply(convert_module_to_f16)
|
| 281 |
+
|
| 282 |
+
def convert_to_fp32(self) -> None:
|
| 283 |
+
"""
|
| 284 |
+
Convert the torso of the model to float32.
|
| 285 |
+
"""
|
| 286 |
+
self.use_fp16 = False
|
| 287 |
+
self.dtype = torch.float32
|
| 288 |
+
self.blocks.apply(convert_module_to_f32)
|
| 289 |
+
self.middle_block.apply(convert_module_to_f32)
|
| 290 |
+
|
| 291 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 292 |
+
h = self.input_layer(x)
|
| 293 |
+
|
| 294 |
+
h = self.middle_block(h)
|
| 295 |
+
for block in self.blocks:
|
| 296 |
+
h = block(h)
|
| 297 |
+
|
| 298 |
+
h = self.out_layer(h)
|
| 299 |
+
return h
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class DenseShapeVAE(nn.Module):
|
| 303 |
+
def __init__(self,
|
| 304 |
+
embed_dim: int = 0,
|
| 305 |
+
model_channels_encoder: list = [32, 128, 512],
|
| 306 |
+
model_channels_decoder: list = [512, 128, 32],
|
| 307 |
+
num_res_blocks_encoder: int = 2,
|
| 308 |
+
num_res_blocks_middle_encoder: int = 2,
|
| 309 |
+
num_res_blocks_decoder: int = 2,
|
| 310 |
+
num_res_blocks_middle_decoder: int=2,
|
| 311 |
+
in_channels: int = 1,
|
| 312 |
+
out_channels: int = 1,
|
| 313 |
+
use_fp16: bool = False,
|
| 314 |
+
use_checkpoint: bool = False,
|
| 315 |
+
latents_scale: float = 1.0,
|
| 316 |
+
latents_shift: float = 0.0):
|
| 317 |
+
|
| 318 |
+
super().__init__()
|
| 319 |
+
|
| 320 |
+
self.use_checkpoint = use_checkpoint
|
| 321 |
+
self.latents_scale = latents_scale
|
| 322 |
+
self.latents_shift = latents_shift
|
| 323 |
+
|
| 324 |
+
self.encoder = SparseStructureEncoder(
|
| 325 |
+
in_channels=in_channels,
|
| 326 |
+
latent_channels=embed_dim,
|
| 327 |
+
num_res_blocks=num_res_blocks_encoder,
|
| 328 |
+
channels=model_channels_encoder,
|
| 329 |
+
num_res_blocks_middle=num_res_blocks_middle_encoder,
|
| 330 |
+
use_fp16=use_fp16,
|
| 331 |
+
use_checkpoint=use_checkpoint,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
self.decoder = SparseStructureDecoder(
|
| 335 |
+
num_res_blocks=num_res_blocks_decoder,
|
| 336 |
+
num_res_blocks_middle=num_res_blocks_middle_decoder,
|
| 337 |
+
channels=model_channels_decoder,
|
| 338 |
+
latent_channels=embed_dim,
|
| 339 |
+
out_channels=out_channels,
|
| 340 |
+
use_fp16=use_fp16,
|
| 341 |
+
use_checkpoint=use_checkpoint,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
self.embed_dim = embed_dim
|
| 345 |
+
|
| 346 |
+
def encode(self, batch, sample_posterior: bool = True):
|
| 347 |
+
|
| 348 |
+
x = batch['dense_index'] * 2.0 - 1.0
|
| 349 |
+
h = self.encoder(x)
|
| 350 |
+
posterior = DiagonalGaussianDistribution(h, feat_dim=1)
|
| 351 |
+
if sample_posterior:
|
| 352 |
+
z = posterior.sample()
|
| 353 |
+
else:
|
| 354 |
+
z = posterior.mode()
|
| 355 |
+
|
| 356 |
+
return z, posterior
|
| 357 |
+
|
| 358 |
+
def forward(self, batch):
|
| 359 |
+
|
| 360 |
+
z, posterior = self.encode(batch)
|
| 361 |
+
reconst_x = self.decoder(z)
|
| 362 |
+
outputs = {'reconst_x': reconst_x, 'posterior': posterior}
|
| 363 |
+
|
| 364 |
+
return outputs
|
| 365 |
+
|
| 366 |
+
def decode_mesh(self,
|
| 367 |
+
latents,
|
| 368 |
+
voxel_resolution: int = 64,
|
| 369 |
+
mc_threshold: float = 0.5,
|
| 370 |
+
return_index: bool = False):
|
| 371 |
+
x = self.decoder(latents)
|
| 372 |
+
if return_index:
|
| 373 |
+
outputs = []
|
| 374 |
+
for i in range(len(x)):
|
| 375 |
+
occ = x[i].sigmoid()
|
| 376 |
+
occ = (occ >= mc_threshold).float().squeeze(0)
|
| 377 |
+
index = occ.unsqueeze(0).nonzero()
|
| 378 |
+
outputs.append(index)
|
| 379 |
+
else:
|
| 380 |
+
outputs = self.dense2mesh(x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold)
|
| 381 |
+
|
| 382 |
+
return outputs
|
| 383 |
+
|
| 384 |
+
def dense2mesh(self,
|
| 385 |
+
x: torch.FloatTensor,
|
| 386 |
+
voxel_resolution: int = 64,
|
| 387 |
+
mc_threshold: float = 0.5):
|
| 388 |
+
|
| 389 |
+
meshes = []
|
| 390 |
+
for i in range(len(x)):
|
| 391 |
+
occ = x[i].sigmoid()
|
| 392 |
+
occ = (occ >= 0.1).float().squeeze(0).cpu().detach().numpy()
|
| 393 |
+
vertices, faces, _, _ = measure.marching_cubes(
|
| 394 |
+
occ,
|
| 395 |
+
mc_threshold,
|
| 396 |
+
method="lewiner",
|
| 397 |
+
)
|
| 398 |
+
vertices = vertices / voxel_resolution * 2 - 1
|
| 399 |
+
meshes.append(trimesh.Trimesh(vertices, faces))
|
| 400 |
+
|
| 401 |
+
return meshes
|
pixal3d/models/autoencoders/distributions.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Union, List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DiagonalGaussianDistribution(object):
|
| 7 |
+
def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
|
| 8 |
+
self.feat_dim = feat_dim
|
| 9 |
+
self.parameters = parameters
|
| 10 |
+
|
| 11 |
+
if isinstance(parameters, list):
|
| 12 |
+
self.mean = parameters[0]
|
| 13 |
+
self.logvar = parameters[1]
|
| 14 |
+
else:
|
| 15 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
|
| 16 |
+
|
| 17 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
| 18 |
+
self.deterministic = deterministic
|
| 19 |
+
self.std = torch.exp(0.5 * self.logvar)
|
| 20 |
+
self.var = torch.exp(self.logvar)
|
| 21 |
+
if self.deterministic:
|
| 22 |
+
self.var = self.std = torch.zeros_like(self.mean)
|
| 23 |
+
|
| 24 |
+
def sample(self):
|
| 25 |
+
x = self.mean + self.std * torch.randn_like(self.mean)
|
| 26 |
+
return x
|
| 27 |
+
|
| 28 |
+
def kl(self, other=None, dims=(1, 2, 3)):
|
| 29 |
+
if self.deterministic:
|
| 30 |
+
return torch.Tensor([0.])
|
| 31 |
+
else:
|
| 32 |
+
if other is None:
|
| 33 |
+
return 0.5 * torch.mean(torch.pow(self.mean, 2)
|
| 34 |
+
+ self.var - 1.0 - self.logvar,
|
| 35 |
+
dim=dims)
|
| 36 |
+
else:
|
| 37 |
+
return 0.5 * torch.mean(
|
| 38 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
| 39 |
+
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
| 40 |
+
dim=dims)
|
| 41 |
+
|
| 42 |
+
def nll(self, sample, dims=(1, 2, 3)):
|
| 43 |
+
if self.deterministic:
|
| 44 |
+
return torch.Tensor([0.])
|
| 45 |
+
logtwopi = np.log(2.0 * np.pi)
|
| 46 |
+
return 0.5 * torch.sum(
|
| 47 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
| 48 |
+
dim=dims)
|
| 49 |
+
|
| 50 |
+
def mode(self):
|
| 51 |
+
return self.mean
|
pixal3d/models/autoencoders/encoder.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from ...modules import sparse as sp
|
| 6 |
+
from .base import SparseTransformerBase
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SparseDownBlock3d(nn.Module):
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
channels: int,
|
| 14 |
+
out_channels: Optional[int] = None,
|
| 15 |
+
num_groups: int = 32,
|
| 16 |
+
use_checkpoint: bool = False,
|
| 17 |
+
):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.channels = channels
|
| 20 |
+
self.out_channels = out_channels or channels
|
| 21 |
+
|
| 22 |
+
self.act_layers = nn.Sequential(
|
| 23 |
+
sp.SparseGroupNorm32(num_groups, channels),
|
| 24 |
+
sp.SparseSiLU()
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
self.down = sp.SparseDownsample(2)
|
| 28 |
+
self.out_layers = nn.Sequential(
|
| 29 |
+
sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
|
| 30 |
+
sp.SparseGroupNorm32(num_groups, self.out_channels),
|
| 31 |
+
sp.SparseSiLU(),
|
| 32 |
+
sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
if self.out_channels == channels:
|
| 36 |
+
self.skip_connection = nn.Identity()
|
| 37 |
+
else:
|
| 38 |
+
self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1)
|
| 39 |
+
|
| 40 |
+
self.use_checkpoint = use_checkpoint
|
| 41 |
+
|
| 42 |
+
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
| 43 |
+
h = self.act_layers(x)
|
| 44 |
+
h = self.down(h)
|
| 45 |
+
x = self.down(x)
|
| 46 |
+
h = self.out_layers(h)
|
| 47 |
+
h = h + self.skip_connection(x)
|
| 48 |
+
return h
|
| 49 |
+
|
| 50 |
+
def forward(self, x: torch.Tensor):
|
| 51 |
+
if self.use_checkpoint:
|
| 52 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
|
| 53 |
+
else:
|
| 54 |
+
return self._forward(x)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SparseSDFEncoder(SparseTransformerBase):
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
resolution: int,
|
| 61 |
+
in_channels: int,
|
| 62 |
+
model_channels: int,
|
| 63 |
+
latent_channels: int,
|
| 64 |
+
num_blocks: int,
|
| 65 |
+
num_heads: Optional[int] = None,
|
| 66 |
+
num_head_channels: Optional[int] = 64,
|
| 67 |
+
mlp_ratio: float = 4,
|
| 68 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
|
| 69 |
+
window_size: int = 8,
|
| 70 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 71 |
+
use_fp16: bool = False,
|
| 72 |
+
use_checkpoint: bool = False,
|
| 73 |
+
qk_rms_norm: bool = False,
|
| 74 |
+
):
|
| 75 |
+
super().__init__(
|
| 76 |
+
in_channels=in_channels,
|
| 77 |
+
model_channels=model_channels,
|
| 78 |
+
num_blocks=num_blocks,
|
| 79 |
+
num_heads=num_heads,
|
| 80 |
+
num_head_channels=num_head_channels,
|
| 81 |
+
mlp_ratio=mlp_ratio,
|
| 82 |
+
attn_mode=attn_mode,
|
| 83 |
+
window_size=window_size,
|
| 84 |
+
pe_mode=pe_mode,
|
| 85 |
+
use_fp16=use_fp16,
|
| 86 |
+
use_checkpoint=use_checkpoint,
|
| 87 |
+
qk_rms_norm=qk_rms_norm,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
self.input_layer1 = sp.SparseLinear(1, model_channels // 16)
|
| 91 |
+
|
| 92 |
+
self.downsample = nn.ModuleList([
|
| 93 |
+
SparseDownBlock3d(
|
| 94 |
+
channels=model_channels//16,
|
| 95 |
+
out_channels=model_channels // 8,
|
| 96 |
+
use_checkpoint=use_checkpoint,
|
| 97 |
+
),
|
| 98 |
+
SparseDownBlock3d(
|
| 99 |
+
channels=model_channels // 8,
|
| 100 |
+
out_channels=model_channels // 4,
|
| 101 |
+
use_checkpoint=use_checkpoint,
|
| 102 |
+
),
|
| 103 |
+
SparseDownBlock3d(
|
| 104 |
+
channels=model_channels // 4,
|
| 105 |
+
out_channels=model_channels,
|
| 106 |
+
use_checkpoint=use_checkpoint,
|
| 107 |
+
)
|
| 108 |
+
])
|
| 109 |
+
|
| 110 |
+
self.resolution = resolution
|
| 111 |
+
self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
|
| 112 |
+
|
| 113 |
+
self.initialize_weights()
|
| 114 |
+
if use_fp16:
|
| 115 |
+
self.convert_to_fp16()
|
| 116 |
+
|
| 117 |
+
def initialize_weights(self) -> None:
|
| 118 |
+
super().initialize_weights()
|
| 119 |
+
# Zero-out output layers:
|
| 120 |
+
nn.init.constant_(self.out_layer.weight, 0)
|
| 121 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
| 122 |
+
|
| 123 |
+
def forward(self, x: sp.SparseTensor, factor: float = None):
|
| 124 |
+
|
| 125 |
+
x = self.input_layer1(x)
|
| 126 |
+
for block in self.downsample:
|
| 127 |
+
x = block(x)
|
| 128 |
+
h = super().forward(x, factor)
|
| 129 |
+
h = h.type(x.dtype)
|
| 130 |
+
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
| 131 |
+
h = self.out_layer(h)
|
| 132 |
+
|
| 133 |
+
return h
|
pixal3d/models/autoencoders/ss_vae.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
import trimesh
|
| 7 |
+
from skimage import measure
|
| 8 |
+
|
| 9 |
+
from ...modules import sparse as sp
|
| 10 |
+
from .encoder import SparseSDFEncoder
|
| 11 |
+
from .decoder import SparseSDFDecoder
|
| 12 |
+
from .distributions import DiagonalGaussianDistribution
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SparseSDFVAE(nn.Module):
|
| 16 |
+
def __init__(self, *,
|
| 17 |
+
embed_dim: int = 0,
|
| 18 |
+
resolution: int = 64,
|
| 19 |
+
model_channels_encoder: int = 512,
|
| 20 |
+
num_blocks_encoder: int = 4,
|
| 21 |
+
num_heads_encoder: int = 8,
|
| 22 |
+
num_head_channels_encoder: int = 64,
|
| 23 |
+
model_channels_decoder: int = 512,
|
| 24 |
+
num_blocks_decoder: int = 4,
|
| 25 |
+
num_heads_decoder: int = 8,
|
| 26 |
+
num_head_channels_decoder: int = 64,
|
| 27 |
+
out_channels: int = 1,
|
| 28 |
+
use_fp16: bool = False,
|
| 29 |
+
use_checkpoint: bool = False,
|
| 30 |
+
chunk_size: int = 1,
|
| 31 |
+
latents_scale: float = 1.0,
|
| 32 |
+
latents_shift: float = 0.0):
|
| 33 |
+
|
| 34 |
+
super().__init__()
|
| 35 |
+
|
| 36 |
+
self.use_checkpoint = use_checkpoint
|
| 37 |
+
self.resolution = resolution
|
| 38 |
+
self.latents_scale = latents_scale
|
| 39 |
+
self.latents_shift = latents_shift
|
| 40 |
+
|
| 41 |
+
self.encoder = SparseSDFEncoder(
|
| 42 |
+
resolution=resolution,
|
| 43 |
+
in_channels=model_channels_encoder,
|
| 44 |
+
model_channels=model_channels_encoder,
|
| 45 |
+
latent_channels=embed_dim,
|
| 46 |
+
num_blocks=num_blocks_encoder,
|
| 47 |
+
num_heads=num_heads_encoder,
|
| 48 |
+
num_head_channels=num_head_channels_encoder,
|
| 49 |
+
use_fp16=use_fp16,
|
| 50 |
+
use_checkpoint=use_checkpoint,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
self.decoder = SparseSDFDecoder(
|
| 54 |
+
resolution=resolution,
|
| 55 |
+
model_channels=model_channels_decoder,
|
| 56 |
+
latent_channels=embed_dim,
|
| 57 |
+
num_blocks=num_blocks_decoder,
|
| 58 |
+
num_heads=num_heads_decoder,
|
| 59 |
+
num_head_channels=num_head_channels_decoder,
|
| 60 |
+
out_channels=out_channels,
|
| 61 |
+
use_fp16=use_fp16,
|
| 62 |
+
use_checkpoint=use_checkpoint,
|
| 63 |
+
chunk_size=chunk_size,
|
| 64 |
+
)
|
| 65 |
+
self.embed_dim = embed_dim
|
| 66 |
+
|
| 67 |
+
def forward(self, batch):
|
| 68 |
+
|
| 69 |
+
z, posterior = self.encode(batch)
|
| 70 |
+
|
| 71 |
+
reconst_x = self.decoder(z)
|
| 72 |
+
outputs = {'reconst_x': reconst_x, 'posterior': posterior}
|
| 73 |
+
return outputs
|
| 74 |
+
|
| 75 |
+
def encode(self, batch, sample_posterior: bool = True):
|
| 76 |
+
|
| 77 |
+
feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx']
|
| 78 |
+
if feat.ndim == 1:
|
| 79 |
+
feat = feat.unsqueeze(-1)
|
| 80 |
+
coords = torch.cat([batch_idx.unsqueeze(-1), xyz], dim=-1).int()
|
| 81 |
+
|
| 82 |
+
x = sp.SparseTensor(feat, coords)
|
| 83 |
+
h = self.encoder(x, batch.get('factor', None))
|
| 84 |
+
posterior = DiagonalGaussianDistribution(h.feats, feat_dim=1)
|
| 85 |
+
if sample_posterior:
|
| 86 |
+
z = posterior.sample()
|
| 87 |
+
else:
|
| 88 |
+
z = posterior.mode()
|
| 89 |
+
z = h.replace(z)
|
| 90 |
+
|
| 91 |
+
return z, posterior
|
| 92 |
+
|
| 93 |
+
def decode_mesh(self,
|
| 94 |
+
latents,
|
| 95 |
+
voxel_resolution: int = 512,
|
| 96 |
+
mc_threshold: float = 0.2,
|
| 97 |
+
return_feat: bool = False,
|
| 98 |
+
factor: float = 1.0):
|
| 99 |
+
voxel_resolution = int(voxel_resolution / factor)
|
| 100 |
+
reconst_x = self.decoder(latents, factor=factor, return_feat=return_feat)
|
| 101 |
+
if return_feat:
|
| 102 |
+
return reconst_x
|
| 103 |
+
outputs = self.sparse2mesh(reconst_x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold)
|
| 104 |
+
|
| 105 |
+
return outputs
|
| 106 |
+
|
| 107 |
+
def sparse2mesh(self,
|
| 108 |
+
reconst_x: torch.FloatTensor,
|
| 109 |
+
voxel_resolution: int = 512,
|
| 110 |
+
mc_threshold: float = 0.0):
|
| 111 |
+
|
| 112 |
+
sparse_sdf, sparse_index = reconst_x.feats.float(), reconst_x.coords
|
| 113 |
+
batch_size = int(sparse_index[..., 0].max().cpu().numpy() + 1)
|
| 114 |
+
|
| 115 |
+
meshes = []
|
| 116 |
+
for i in range(batch_size):
|
| 117 |
+
idx = sparse_index[..., 0] == i
|
| 118 |
+
sparse_sdf_i, sparse_index_i = sparse_sdf[idx].squeeze(-1).cpu(), sparse_index[idx][..., 1:].detach().cpu()
|
| 119 |
+
sdf = torch.ones((voxel_resolution, voxel_resolution, voxel_resolution))
|
| 120 |
+
sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i
|
| 121 |
+
vertices, faces, _, _ = measure.marching_cubes(
|
| 122 |
+
sdf.numpy(),
|
| 123 |
+
mc_threshold,
|
| 124 |
+
method="lewiner",
|
| 125 |
+
)
|
| 126 |
+
vertices = vertices / voxel_resolution * 2 - 1
|
| 127 |
+
meshes.append(trimesh.Trimesh(vertices, faces))
|
| 128 |
+
|
| 129 |
+
return meshes
|
pixal3d/models/conditional_encoders/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import dinov2_project_grid
|
| 2 |
+
|
pixal3d/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (238 Bytes). View file
|
|
|
pixal3d/models/conditional_encoders/__pycache__/dinov2_project_grid.cpython-310.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
pixal3d/models/conditional_encoders/dinov2_project_grid.py
ADDED
|
@@ -0,0 +1,750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DINOv2 Project Grid Encoders
|
| 3 |
+
Includes single-view and multi-view DINOv2 encoders with 3D grid projection support
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import List, Dict, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 15 |
+
|
| 16 |
+
import pixal3d
|
| 17 |
+
from pixal3d.utils.base import BaseModule
|
| 18 |
+
|
| 19 |
+
# Set linear algebra backend to avoid cusolver errors
|
| 20 |
+
try:
|
| 21 |
+
torch.backends.cuda.preferred_linalg_library("cusolver")
|
| 22 |
+
except Exception:
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# =============================================================================
|
| 27 |
+
# Base DINOv2 Encoder
|
| 28 |
+
# =============================================================================
|
| 29 |
+
|
| 30 |
+
@pixal3d.register("dinov2-encoder")
|
| 31 |
+
class DinoEncoder(BaseModule, ModelMixin):
|
| 32 |
+
"""Base DINOv2 Encoder"""
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class Config(BaseModule.Config):
|
| 36 |
+
model: str = "facebookresearch/dinov2"
|
| 37 |
+
version: str = "dinov2_vitl14_reg"
|
| 38 |
+
size: int = 518
|
| 39 |
+
empty_embeds_ratio: float = 0.1
|
| 40 |
+
|
| 41 |
+
cfg: Config
|
| 42 |
+
|
| 43 |
+
def configure(self) -> None:
|
| 44 |
+
super().configure()
|
| 45 |
+
self.empty_embeds_ratio = self.cfg.empty_embeds_ratio
|
| 46 |
+
|
| 47 |
+
# Load DINOv2 model
|
| 48 |
+
dino_model = torch.hub.load(
|
| 49 |
+
self.cfg.model, self.cfg.version, pretrained=True
|
| 50 |
+
)
|
| 51 |
+
self.encoder = dino_model.eval()
|
| 52 |
+
|
| 53 |
+
# Image preprocessing
|
| 54 |
+
self.transform = transforms.Compose([
|
| 55 |
+
transforms.Resize(
|
| 56 |
+
self.cfg.size,
|
| 57 |
+
transforms.InterpolationMode.BILINEAR,
|
| 58 |
+
antialias=True
|
| 59 |
+
),
|
| 60 |
+
transforms.CenterCrop(self.cfg.size),
|
| 61 |
+
transforms.Normalize(
|
| 62 |
+
mean=[0.485, 0.456, 0.406],
|
| 63 |
+
std=[0.229, 0.224, 0.225],
|
| 64 |
+
),
|
| 65 |
+
])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def forward(self, image, image_mask=None, is_training=False):
|
| 71 |
+
z = self.encoder(self.transform(image), is_training=True)['x_prenorm']
|
| 72 |
+
z = F.layer_norm(z, z.shape[-1:])
|
| 73 |
+
|
| 74 |
+
if is_training and random.random() < self.empty_embeds_ratio:
|
| 75 |
+
# zero out embeddings
|
| 76 |
+
z = z * 0
|
| 77 |
+
|
| 78 |
+
if image_mask is not None:
|
| 79 |
+
image_mask_patch = F.max_pool2d(
|
| 80 |
+
image_mask, kernel_size=14, stride=14
|
| 81 |
+
).squeeze(1) > 0
|
| 82 |
+
return z, image_mask_patch
|
| 83 |
+
|
| 84 |
+
return z
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# =============================================================================
|
| 88 |
+
# 3D Projection Utility Functions
|
| 89 |
+
# =============================================================================
|
| 90 |
+
|
| 91 |
+
def project_points_to_image_batch(
|
| 92 |
+
points_3d: torch.Tensor,
|
| 93 |
+
transform_matrix: torch.Tensor,
|
| 94 |
+
camera_angle_x: torch.Tensor,
|
| 95 |
+
resolution: int = 518
|
| 96 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 97 |
+
"""
|
| 98 |
+
Project 3D points to 2D image coordinates with batch support
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
points_3d: [N, 3] or [B, N, 3], 3D point coordinates (in [-1, 1] range)
|
| 102 |
+
transform_matrix: [B, 4, 4], batch of camera transformation matrices
|
| 103 |
+
camera_angle_x: [B], batch of camera horizontal FOV angles (radians)
|
| 104 |
+
resolution: Rendering image resolution
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
points_2d: [B, N, 2], image coordinates [x, y]
|
| 108 |
+
depth: [B, N], depth values
|
| 109 |
+
valid_mask: [B, N], mask indicating if points are within view
|
| 110 |
+
"""
|
| 111 |
+
device = points_3d.device
|
| 112 |
+
B = transform_matrix.shape[0]
|
| 113 |
+
|
| 114 |
+
# Ensure inputs are torch.Tensor
|
| 115 |
+
if not isinstance(transform_matrix, torch.Tensor):
|
| 116 |
+
transform_matrix = torch.tensor(
|
| 117 |
+
transform_matrix, dtype=torch.float32, device=device
|
| 118 |
+
)
|
| 119 |
+
if not isinstance(points_3d, torch.Tensor):
|
| 120 |
+
points_3d = torch.tensor(
|
| 121 |
+
points_3d, dtype=torch.float32, device=device
|
| 122 |
+
)
|
| 123 |
+
if not isinstance(camera_angle_x, torch.Tensor):
|
| 124 |
+
camera_angle_x = torch.tensor(
|
| 125 |
+
camera_angle_x, dtype=torch.float32, device=device
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Expand points_3d to batch dimension
|
| 129 |
+
if points_3d.dim() == 2:
|
| 130 |
+
points_3d_batch = points_3d.unsqueeze(0).expand(B, -1, -1)
|
| 131 |
+
else:
|
| 132 |
+
points_3d_batch = points_3d
|
| 133 |
+
|
| 134 |
+
N = points_3d_batch.shape[1]
|
| 135 |
+
|
| 136 |
+
# Add homogeneous coordinates
|
| 137 |
+
ones = torch.ones(B, N, 1, device=device)
|
| 138 |
+
points_homogeneous = torch.cat([points_3d_batch, ones], dim=-1)
|
| 139 |
+
|
| 140 |
+
# World to camera transformation
|
| 141 |
+
world_to_camera = torch.linalg.inv(transform_matrix)
|
| 142 |
+
points_camera = torch.bmm(
|
| 143 |
+
points_homogeneous,
|
| 144 |
+
world_to_camera.transpose(-2, -1)
|
| 145 |
+
)[..., :3]
|
| 146 |
+
|
| 147 |
+
# Extract camera coordinates
|
| 148 |
+
x_cam = points_camera[..., 0]
|
| 149 |
+
y_cam = points_camera[..., 1]
|
| 150 |
+
z_cam = points_camera[..., 2]
|
| 151 |
+
|
| 152 |
+
# Depth values
|
| 153 |
+
depth = -z_cam
|
| 154 |
+
|
| 155 |
+
# Compute camera intrinsics
|
| 156 |
+
sensor_width = 32.0
|
| 157 |
+
focal_length = 16.0 / torch.tan(camera_angle_x / 2.0)
|
| 158 |
+
focal_length_pixels = focal_length * resolution / sensor_width
|
| 159 |
+
focal_length_pixels = focal_length_pixels.unsqueeze(1)
|
| 160 |
+
|
| 161 |
+
# Perspective projection
|
| 162 |
+
x_ndc = focal_length_pixels * x_cam / (-z_cam)
|
| 163 |
+
y_ndc = focal_length_pixels * y_cam / (-z_cam)
|
| 164 |
+
|
| 165 |
+
# Convert to image coordinates
|
| 166 |
+
x_pixel = x_ndc + resolution / 2.0
|
| 167 |
+
y_pixel = -y_ndc + resolution / 2.0
|
| 168 |
+
|
| 169 |
+
# Validity mask
|
| 170 |
+
valid_mask = (
|
| 171 |
+
(x_pixel >= 0) & (x_pixel < resolution) &
|
| 172 |
+
(y_pixel >= 0) & (y_pixel < resolution) &
|
| 173 |
+
(depth > 0)
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
points_2d = torch.stack([x_pixel, y_pixel], dim=-1)
|
| 177 |
+
return points_2d, depth, valid_mask
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def project_points_to_image(
|
| 181 |
+
points_3d: torch.Tensor,
|
| 182 |
+
transform_matrix: torch.Tensor,
|
| 183 |
+
camera_angle_x: float,
|
| 184 |
+
resolution: int = 512
|
| 185 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 186 |
+
"""
|
| 187 |
+
Project 3D points to 2D image coordinates (single-view version)
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
points_3d: [N, 3], 3D point coordinates
|
| 191 |
+
transform_matrix: [4, 4], camera transformation matrix
|
| 192 |
+
camera_angle_x: Camera horizontal FOV angle (radians)
|
| 193 |
+
resolution: Rendering image resolution
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
points_2d: [N, 2], image coordinates [x, y]
|
| 197 |
+
depth: [N], depth values
|
| 198 |
+
valid_mask: [N], mask indicating if points are within view
|
| 199 |
+
"""
|
| 200 |
+
device = points_3d.device
|
| 201 |
+
|
| 202 |
+
if not isinstance(transform_matrix, torch.Tensor):
|
| 203 |
+
transform_matrix = torch.tensor(
|
| 204 |
+
transform_matrix, dtype=torch.float32, device=device
|
| 205 |
+
)
|
| 206 |
+
if not isinstance(points_3d, torch.Tensor):
|
| 207 |
+
points_3d = torch.tensor(
|
| 208 |
+
points_3d, dtype=torch.float32, device=device
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
N = points_3d.shape[0]
|
| 212 |
+
points_homogeneous = torch.cat([
|
| 213 |
+
points_3d,
|
| 214 |
+
torch.ones(N, 1, device=device)
|
| 215 |
+
], dim=1)
|
| 216 |
+
|
| 217 |
+
# World to camera transformation
|
| 218 |
+
camera_to_world = transform_matrix
|
| 219 |
+
world_to_camera = torch.linalg.inv(camera_to_world)
|
| 220 |
+
points_camera = torch.matmul(
|
| 221 |
+
points_homogeneous,
|
| 222 |
+
world_to_camera.T
|
| 223 |
+
)[:, :3]
|
| 224 |
+
|
| 225 |
+
x_cam = points_camera[:, 0]
|
| 226 |
+
y_cam = points_camera[:, 1]
|
| 227 |
+
z_cam = points_camera[:, 2]
|
| 228 |
+
depth = -z_cam
|
| 229 |
+
|
| 230 |
+
# Camera intrinsics
|
| 231 |
+
sensor_width = 32.0
|
| 232 |
+
focal_length = 16.0 / torch.tan(torch.tensor(camera_angle_x / 2.0))
|
| 233 |
+
focal_length_pixels = focal_length * resolution / sensor_width
|
| 234 |
+
|
| 235 |
+
# Perspective projection
|
| 236 |
+
x_ndc = focal_length_pixels * x_cam / (-z_cam)
|
| 237 |
+
y_ndc = focal_length_pixels * y_cam / (-z_cam)
|
| 238 |
+
|
| 239 |
+
# Image coordinates
|
| 240 |
+
x_pixel = x_ndc + resolution / 2.0
|
| 241 |
+
y_pixel = -y_ndc + resolution / 2.0
|
| 242 |
+
|
| 243 |
+
valid_mask = (
|
| 244 |
+
(x_pixel >= 0) & (x_pixel < resolution) &
|
| 245 |
+
(y_pixel >= 0) & (y_pixel < resolution) &
|
| 246 |
+
(depth > 0)
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
points_2d = torch.stack([x_pixel, y_pixel], dim=1)
|
| 250 |
+
return points_2d, depth, valid_mask
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def sample_features(
|
| 254 |
+
fmap: torch.Tensor,
|
| 255 |
+
queries_ndc: torch.Tensor
|
| 256 |
+
) -> torch.Tensor:
|
| 257 |
+
"""
|
| 258 |
+
Sample features using grid_sample
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
fmap: [B, C, H, W], feature map
|
| 262 |
+
queries_ndc: [B, K, 2], NDC coordinates
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
feat: [B, C, K], sampled features
|
| 266 |
+
"""
|
| 267 |
+
B, C, H, W = fmap.shape
|
| 268 |
+
Bq, K, _ = queries_ndc.shape
|
| 269 |
+
assert Bq == B, "batch 不一致"
|
| 270 |
+
|
| 271 |
+
grid = queries_ndc.view(B, K, 1, 2)
|
| 272 |
+
feat = F.grid_sample(
|
| 273 |
+
fmap, grid, mode='bilinear',
|
| 274 |
+
align_corners=False, padding_mode='border'
|
| 275 |
+
)
|
| 276 |
+
return feat.squeeze(-1)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# =============================================================================
|
| 280 |
+
# Projection Grid Module
|
| 281 |
+
# =============================================================================
|
| 282 |
+
|
| 283 |
+
class ProjGrid(nn.Module):
|
| 284 |
+
"""3D Grid Projection Module"""
|
| 285 |
+
|
| 286 |
+
def __init__(self, grid_resolution: int = 16):
|
| 287 |
+
super().__init__()
|
| 288 |
+
self.grid_resolution = grid_resolution
|
| 289 |
+
self.image_resolution = 518
|
| 290 |
+
|
| 291 |
+
# Create 3D grid points
|
| 292 |
+
one_dim = torch.linspace(-1, 1, grid_resolution)
|
| 293 |
+
x, y, z = torch.meshgrid(one_dim, one_dim, one_dim, indexing='ij')
|
| 294 |
+
grid_points = torch.stack((x, y, z), dim=-1)
|
| 295 |
+
|
| 296 |
+
# Rotation matrix (align with Blender)
|
| 297 |
+
rotation_matrix = torch.tensor([
|
| 298 |
+
[1.0, 0.0, 0.0],
|
| 299 |
+
[0.0, 0.0, -1.0],
|
| 300 |
+
[0.0, 1.0, 0.0]
|
| 301 |
+
])
|
| 302 |
+
grid_points = torch.matmul(grid_points, rotation_matrix.T)
|
| 303 |
+
grid_points = grid_points.reshape(-1, 3)
|
| 304 |
+
self.register_buffer('grid_points', grid_points)
|
| 305 |
+
|
| 306 |
+
# Front view transformation matrix
|
| 307 |
+
front_view_transform_matrix = torch.tensor([
|
| 308 |
+
[1.0, 0.0, 0.0, 0.0],
|
| 309 |
+
[0.0, 0.0, -1.0, -2.0],
|
| 310 |
+
[0.0, 1.0, 0.0, 0.0],
|
| 311 |
+
[0.0, 0.0, 0.0, 1.0]
|
| 312 |
+
])
|
| 313 |
+
self.register_buffer(
|
| 314 |
+
"front_view_transform_matrix",
|
| 315 |
+
front_view_transform_matrix
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
def forward(
|
| 319 |
+
self,
|
| 320 |
+
features_map: torch.Tensor,
|
| 321 |
+
camera_angle_x: torch.Tensor,
|
| 322 |
+
distance: torch.Tensor,
|
| 323 |
+
mesh_scale: torch.Tensor,
|
| 324 |
+
transform_matrix: torch.Tensor = None,
|
| 325 |
+
BHWC: bool = True
|
| 326 |
+
) -> torch.Tensor:
|
| 327 |
+
"""
|
| 328 |
+
Project feature map to 3D grid
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
features_map: [B, H, W, C] or [B, C, H, W]
|
| 332 |
+
camera_angle_x: [B]
|
| 333 |
+
distance: [B]
|
| 334 |
+
mesh_scale: [B]
|
| 335 |
+
transform_matrix: [B, 4, 4] or None
|
| 336 |
+
BHWC: Whether input is in BHWC format
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
x: [B, K, C], projected features
|
| 340 |
+
"""
|
| 341 |
+
if BHWC:
|
| 342 |
+
B, H, W, C = features_map.shape
|
| 343 |
+
else:
|
| 344 |
+
B, C, H, W = features_map.shape
|
| 345 |
+
|
| 346 |
+
# Prepare grid points
|
| 347 |
+
grid_points = self.grid_points.expand(B, -1, -1)
|
| 348 |
+
grid_points = grid_points / mesh_scale.unsqueeze(-1).unsqueeze(-1) / 2
|
| 349 |
+
|
| 350 |
+
# Use default transformation matrix
|
| 351 |
+
if transform_matrix is None:
|
| 352 |
+
transform_matrix = self.front_view_transform_matrix
|
| 353 |
+
transform_matrix = transform_matrix.expand(B, -1, -1).clone()
|
| 354 |
+
transform_matrix[:, 1, 3] = -distance
|
| 355 |
+
|
| 356 |
+
# Project to image
|
| 357 |
+
image_points, depth, valid_mask = project_points_to_image_batch(
|
| 358 |
+
grid_points, transform_matrix, camera_angle_x, self.image_resolution
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# Normalize to [-1, 1]
|
| 362 |
+
|
| 363 |
+
image_points_norm = (image_points + 0.5) / self.image_resolution * 2 - 1
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# Adjust dimensions and sample
|
| 367 |
+
if BHWC:
|
| 368 |
+
features_map = features_map.permute(0, 3, 1, 2)
|
| 369 |
+
|
| 370 |
+
x = sample_features(features_map, image_points_norm)
|
| 371 |
+
x = x.permute(0, 2, 1)
|
| 372 |
+
|
| 373 |
+
return x
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# =============================================================================
|
| 380 |
+
# DINOv2 Encoder with Projection
|
| 381 |
+
# =============================================================================
|
| 382 |
+
|
| 383 |
+
@pixal3d.register("dinov2-encoder-proj")
|
| 384 |
+
class DinoEncoderProj(BaseModule, ModelMixin):
|
| 385 |
+
"""DINOv2 Encoder with 3D Grid Projection"""
|
| 386 |
+
|
| 387 |
+
@dataclass
|
| 388 |
+
class Config(BaseModule.Config):
|
| 389 |
+
model: str = "facebookresearch/dinov2"
|
| 390 |
+
version: str = "dinov2_vitl14_reg"
|
| 391 |
+
size: int = 518
|
| 392 |
+
empty_embeds_ratio: float = 0.1
|
| 393 |
+
grid_resolution: int = 16
|
| 394 |
+
use_upsample: bool = False
|
| 395 |
+
use_geo_feats: bool = False
|
| 396 |
+
|
| 397 |
+
cfg: Config
|
| 398 |
+
|
| 399 |
+
def configure(self) -> None:
|
| 400 |
+
super().configure()
|
| 401 |
+
self.grid_resolution = self.cfg.grid_resolution
|
| 402 |
+
self.empty_embeds_ratio = self.cfg.empty_embeds_ratio
|
| 403 |
+
self.use_upsample = self.cfg.use_upsample
|
| 404 |
+
|
| 405 |
+
# Load DINOv2
|
| 406 |
+
dino_model = torch.hub.load(
|
| 407 |
+
self.cfg.model, self.cfg.version, pretrained=True
|
| 408 |
+
)
|
| 409 |
+
self.encoder = dino_model.eval()
|
| 410 |
+
|
| 411 |
+
# Optional: load upsampler
|
| 412 |
+
if self.use_upsample:
|
| 413 |
+
upsampler = torch.hub.load("valeoai/NAF", "naf", pretrained=True)
|
| 414 |
+
self.upsampler = upsampler.eval()
|
| 415 |
+
|
| 416 |
+
# Image preprocessing (normalization only)
|
| 417 |
+
self.transform = transforms.Compose([
|
| 418 |
+
transforms.Normalize(
|
| 419 |
+
mean=[0.485, 0.456, 0.406],
|
| 420 |
+
std=[0.229, 0.224, 0.225],
|
| 421 |
+
),
|
| 422 |
+
])
|
| 423 |
+
|
| 424 |
+
self.patch_size = self.encoder.patch_size
|
| 425 |
+
self.patch_number = self.cfg.size // self.patch_size
|
| 426 |
+
self.proj_grid = ProjGrid(grid_resolution=self.cfg.grid_resolution)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def forward(
|
| 434 |
+
self,
|
| 435 |
+
image: torch.Tensor,
|
| 436 |
+
image_mask: torch.Tensor = None,
|
| 437 |
+
camera_angle_x: torch.Tensor = None,
|
| 438 |
+
distance: torch.Tensor = None,
|
| 439 |
+
mesh_scale: torch.Tensor = None,
|
| 440 |
+
transform_matrix: torch.Tensor = None,
|
| 441 |
+
is_training: bool = False
|
| 442 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 443 |
+
"""
|
| 444 |
+
Forward pass
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
image: [B, C, H, W]
|
| 448 |
+
camera_angle_x: [B]
|
| 449 |
+
distance: [B]
|
| 450 |
+
mesh_scale: [B]
|
| 451 |
+
is_training: Training mode flag
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
z_global: [B, num_global, C]
|
| 455 |
+
z: [B, grid_resolution^3, C]
|
| 456 |
+
"""
|
| 457 |
+
image = self.transform(image)
|
| 458 |
+
|
| 459 |
+
with torch.no_grad():
|
| 460 |
+
z = self.encoder(image, is_training=True)['x_prenorm']
|
| 461 |
+
z = F.layer_norm(z, z.shape[-1:])
|
| 462 |
+
|
| 463 |
+
# Split tokens
|
| 464 |
+
z_clstoken = z[:, 0:1]
|
| 465 |
+
z_regtokens = z[:, 1:self.encoder.num_register_tokens + 1]
|
| 466 |
+
z_patchtokens = z[:, 1 + self.encoder.num_register_tokens:]
|
| 467 |
+
z_patchtokens = z_patchtokens.reshape(
|
| 468 |
+
z_patchtokens.shape[0],
|
| 469 |
+
self.patch_number,
|
| 470 |
+
self.patch_number,
|
| 471 |
+
-1
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Project to grid
|
| 475 |
+
z = self.proj_grid(
|
| 476 |
+
z_patchtokens, camera_angle_x, distance, mesh_scale
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# Optional: upsample and fuse
|
| 480 |
+
if self.use_upsample:
|
| 481 |
+
z_patchtokens_permuted = z_patchtokens.permute(0, 3, 1, 2)
|
| 482 |
+
z_upsampled = self.upsampler(
|
| 483 |
+
image, z_patchtokens_permuted, output_size=(518, 518)
|
| 484 |
+
)
|
| 485 |
+
z_upsampled = self.proj_grid(
|
| 486 |
+
z_upsampled, camera_angle_x, distance, mesh_scale, BHWC=False
|
| 487 |
+
)
|
| 488 |
+
z = z + z_upsampled
|
| 489 |
+
|
| 490 |
+
# Global tokens
|
| 491 |
+
z_global = torch.cat([z_clstoken, z_regtokens], dim=1)
|
| 492 |
+
z_global = z_global.expand(z.shape[0], -1, -1)
|
| 493 |
+
|
| 494 |
+
# Classifier-free guidance: random drop
|
| 495 |
+
if is_training and random.random() < self.empty_embeds_ratio:
|
| 496 |
+
z_global = z_global * 0
|
| 497 |
+
z = z * 0
|
| 498 |
+
|
| 499 |
+
return z_global, z
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# =============================================================================
|
| 503 |
+
# Multi-View Projection Encoder Helper Functions
|
| 504 |
+
# =============================================================================
|
| 505 |
+
|
| 506 |
+
def compute_calc_mat(
|
| 507 |
+
true_view_mat: torch.Tensor,
|
| 508 |
+
ext_true_view_mat: torch.Tensor,
|
| 509 |
+
fix_mat: torch.Tensor
|
| 510 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 511 |
+
"""
|
| 512 |
+
Compute calc_mat using matrix relative transformation
|
| 513 |
+
|
| 514 |
+
Args:
|
| 515 |
+
true_view_mat: [B, 1, 4, 4], ground truth camera matrix
|
| 516 |
+
ext_true_view_mat: [B, N, 4, 4], extended ground truth camera matrices
|
| 517 |
+
fix_mat: [B, 1, 4, 4], fixed matrix
|
| 518 |
+
|
| 519 |
+
Returns:
|
| 520 |
+
calc_mat: [B, N, 4, 4]
|
| 521 |
+
relative_transform: [B, N, 4, 4]
|
| 522 |
+
"""
|
| 523 |
+
B, N = ext_true_view_mat.shape[:2]
|
| 524 |
+
|
| 525 |
+
# Expand to [B, N, 4, 4]
|
| 526 |
+
true_view_mat_exp = true_view_mat.expand(B, N, 4, 4)
|
| 527 |
+
fix_mat_exp = fix_mat.expand(B, N, 4, 4)
|
| 528 |
+
|
| 529 |
+
# Flatten to [B*N, 4, 4]
|
| 530 |
+
true_view_mat_flat = true_view_mat_exp.reshape(B * N, 4, 4)
|
| 531 |
+
ext_true_view_mat_flat = ext_true_view_mat.reshape(B * N, 4, 4)
|
| 532 |
+
fix_mat_flat = fix_mat_exp.reshape(B * N, 4, 4)
|
| 533 |
+
|
| 534 |
+
# Compute relative transformation (disable autocast for fp32 precision)
|
| 535 |
+
with torch.amp.autocast('cuda', enabled=False):
|
| 536 |
+
true_view_mat_flat = true_view_mat_flat.float()
|
| 537 |
+
ext_true_view_mat_flat = ext_true_view_mat_flat.float()
|
| 538 |
+
fix_mat_flat = fix_mat_flat.float()
|
| 539 |
+
|
| 540 |
+
relative_transform_flat = torch.bmm(
|
| 541 |
+
torch.linalg.inv(true_view_mat_flat),
|
| 542 |
+
ext_true_view_mat_flat
|
| 543 |
+
)
|
| 544 |
+
calc_mat_flat = torch.bmm(fix_mat_flat, relative_transform_flat)
|
| 545 |
+
|
| 546 |
+
calc_mat = calc_mat_flat.view(B, N, 4, 4)
|
| 547 |
+
relative_transform = relative_transform_flat.view(B, N, 4, 4)
|
| 548 |
+
|
| 549 |
+
return calc_mat, relative_transform
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# =============================================================================
|
| 553 |
+
# Multi-View DINOv2 Projection Encoder
|
| 554 |
+
# =============================================================================
|
| 555 |
+
|
| 556 |
+
@pixal3d.register("dinov2-encoder-proj-multi-view")
|
| 557 |
+
class DinoEncoderProjMultiView(BaseModule, ModelMixin):
|
| 558 |
+
"""Multi-View DINOv2 Projection Encoder"""
|
| 559 |
+
|
| 560 |
+
@dataclass
|
| 561 |
+
class Config(BaseModule.Config):
|
| 562 |
+
model: str = "facebookresearch/dinov2"
|
| 563 |
+
version: str = "dinov2_vitl14_reg"
|
| 564 |
+
size: int = 518
|
| 565 |
+
empty_embeds_ratio: float = 0.1
|
| 566 |
+
grid_resolution: int = 16
|
| 567 |
+
use_upsample: bool = False
|
| 568 |
+
|
| 569 |
+
cfg: Config
|
| 570 |
+
|
| 571 |
+
def configure(self) -> None:
|
| 572 |
+
super().configure()
|
| 573 |
+
self.grid_resolution = self.cfg.grid_resolution
|
| 574 |
+
self.empty_embeds_ratio = self.cfg.empty_embeds_ratio
|
| 575 |
+
self.use_upsample = self.cfg.use_upsample
|
| 576 |
+
|
| 577 |
+
# Load DINOv2
|
| 578 |
+
dino_model = torch.hub.load(
|
| 579 |
+
self.cfg.model, self.cfg.version, pretrained=True
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
self.encoder = dino_model.eval()
|
| 583 |
+
|
| 584 |
+
# Optional: upsampler
|
| 585 |
+
if self.use_upsample:
|
| 586 |
+
upsampler = torch.hub.load("valeoai/NAF", "naf", pretrained=True)
|
| 587 |
+
self.upsampler = upsampler.eval()
|
| 588 |
+
|
| 589 |
+
# Image preprocessing
|
| 590 |
+
self.transform = transforms.Compose([
|
| 591 |
+
transforms.Normalize(
|
| 592 |
+
mean=[0.485, 0.456, 0.406],
|
| 593 |
+
std=[0.229, 0.224, 0.225],
|
| 594 |
+
),
|
| 595 |
+
])
|
| 596 |
+
|
| 597 |
+
self.patch_size = self.encoder.patch_size
|
| 598 |
+
self.patch_number = self.cfg.size // self.patch_size
|
| 599 |
+
self.proj_grid = ProjGrid(grid_resolution=self.cfg.grid_resolution)
|
| 600 |
+
|
| 601 |
+
# Fixed transformation matrix
|
| 602 |
+
self.register_buffer("fix_transform_matrix", torch.tensor([
|
| 603 |
+
[1.0, 0.0, 0.0, 0.0],
|
| 604 |
+
[0.0, 0.0, -1.0, -2.0],
|
| 605 |
+
[0.0, 1.0, 0.0, 0.0],
|
| 606 |
+
[0.0, 0.0, 0.0, 1.0]
|
| 607 |
+
]))
|
| 608 |
+
|
| 609 |
+
def forward(
|
| 610 |
+
self,
|
| 611 |
+
image: torch.Tensor,
|
| 612 |
+
image_mask: torch.Tensor = None,
|
| 613 |
+
camera_angle_x: torch.Tensor = None,
|
| 614 |
+
distance: torch.Tensor = None,
|
| 615 |
+
mesh_scale: torch.Tensor = None,
|
| 616 |
+
transform_matrix: torch.Tensor = None,
|
| 617 |
+
is_training: bool = False
|
| 618 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 619 |
+
"""
|
| 620 |
+
Forward pass
|
| 621 |
+
|
| 622 |
+
Args:
|
| 623 |
+
image: [B, num_views, C, H, W]
|
| 624 |
+
camera_angle_x: [B, num_views]
|
| 625 |
+
distance: [B, num_views]
|
| 626 |
+
mesh_scale: [B]
|
| 627 |
+
transform_matrix: [B, num_views, 4, 4]
|
| 628 |
+
|
| 629 |
+
Returns:
|
| 630 |
+
z_global: [B, num_global, C]
|
| 631 |
+
z: [B, grid_resolution^3, C]
|
| 632 |
+
"""
|
| 633 |
+
B, num_views, C, H, W = image.shape
|
| 634 |
+
image = image.reshape(B * num_views, C, H, W)
|
| 635 |
+
image = self.transform(image)
|
| 636 |
+
|
| 637 |
+
with torch.no_grad():
|
| 638 |
+
z = self.encoder(image, is_training=True)['x_prenorm']
|
| 639 |
+
z = F.layer_norm(z, z.shape[-1:])
|
| 640 |
+
z_clstoken = z[:, 0:1]
|
| 641 |
+
z_regtokens = z[:, 1:self.encoder.num_register_tokens + 1]
|
| 642 |
+
z_patchtokens = z[:, 1 + self.encoder.num_register_tokens:]
|
| 643 |
+
z_patchtokens = z_patchtokens.reshape(
|
| 644 |
+
z_patchtokens.shape[0],
|
| 645 |
+
self.patch_number,
|
| 646 |
+
self.patch_number,
|
| 647 |
+
-1
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
# Compute relative transformation
|
| 651 |
+
calc_mat, relative_transform = self.get_relative_transform(
|
| 652 |
+
transform_matrix, distance
|
| 653 |
+
)
|
| 654 |
+
calc_mat = calc_mat.reshape(B * num_views, 4, 4)
|
| 655 |
+
|
| 656 |
+
# Prepare parameters
|
| 657 |
+
init_mesh_scale = mesh_scale[:, None].expand(B, num_views).reshape(B * num_views)
|
| 658 |
+
camera_angle_x_flat = camera_angle_x.reshape(B * num_views)
|
| 659 |
+
distance_flat = distance.reshape(B * num_views)
|
| 660 |
+
|
| 661 |
+
# Accumulate per-view (avoid OOM)
|
| 662 |
+
z_accumulated = None
|
| 663 |
+
z_patchtokens_permuted = z_patchtokens.permute(0, 3, 1, 2) if self.use_upsample else None
|
| 664 |
+
|
| 665 |
+
with torch.no_grad():
|
| 666 |
+
for view_idx in range(num_views):
|
| 667 |
+
indices = torch.arange(
|
| 668 |
+
view_idx, B * num_views, num_views, device=z_patchtokens.device
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
# Project current view
|
| 672 |
+
z_view = self.proj_grid(
|
| 673 |
+
z_patchtokens[indices],
|
| 674 |
+
camera_angle_x_flat[indices],
|
| 675 |
+
distance_flat[indices],
|
| 676 |
+
init_mesh_scale[indices],
|
| 677 |
+
calc_mat[indices]
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
# Optional: upsample
|
| 681 |
+
if self.use_upsample:
|
| 682 |
+
chunk_upsampled = self.upsampler(
|
| 683 |
+
image[indices],
|
| 684 |
+
z_patchtokens_permuted[indices],
|
| 685 |
+
output_size=(518, 518)
|
| 686 |
+
)
|
| 687 |
+
chunk_proj = self.proj_grid(
|
| 688 |
+
chunk_upsampled,
|
| 689 |
+
camera_angle_x_flat[indices],
|
| 690 |
+
distance_flat[indices],
|
| 691 |
+
init_mesh_scale[indices],
|
| 692 |
+
calc_mat[indices],
|
| 693 |
+
BHWC=False
|
| 694 |
+
)
|
| 695 |
+
z_view = z_view + chunk_proj
|
| 696 |
+
del chunk_upsampled, chunk_proj
|
| 697 |
+
|
| 698 |
+
# Accumulate
|
| 699 |
+
if z_accumulated is None:
|
| 700 |
+
z_accumulated = z_view.clone()
|
| 701 |
+
else:
|
| 702 |
+
z_accumulated = z_accumulated + z_view
|
| 703 |
+
del z_view
|
| 704 |
+
|
| 705 |
+
if z_patchtokens_permuted is not None:
|
| 706 |
+
del z_patchtokens_permuted
|
| 707 |
+
|
| 708 |
+
# Average
|
| 709 |
+
z = z_accumulated / num_views
|
| 710 |
+
|
| 711 |
+
# Average global tokens
|
| 712 |
+
z_global = torch.cat([z_clstoken, z_regtokens], dim=1)
|
| 713 |
+
z_global = z_global.reshape(B, num_views, z_global.shape[-2], z_global.shape[-1])
|
| 714 |
+
z_global = z_global.mean(dim=1)
|
| 715 |
+
|
| 716 |
+
# Classifier-free guidance
|
| 717 |
+
if is_training and random.random() < self.empty_embeds_ratio:
|
| 718 |
+
z_global = z_global * 0
|
| 719 |
+
z = z * 0
|
| 720 |
+
|
| 721 |
+
return z_global, z
|
| 722 |
+
|
| 723 |
+
def get_relative_transform(
|
| 724 |
+
self,
|
| 725 |
+
transform_matrix: torch.Tensor,
|
| 726 |
+
distance: torch.Tensor
|
| 727 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 728 |
+
"""
|
| 729 |
+
Compute relative transformation matrix
|
| 730 |
+
|
| 731 |
+
Args:
|
| 732 |
+
transform_matrix: [B, num_views, 4, 4]
|
| 733 |
+
distance: [B, num_views]
|
| 734 |
+
|
| 735 |
+
Returns:
|
| 736 |
+
calc_mat: [B, num_views, 4, 4]
|
| 737 |
+
relative_transform: [B, num_views, 4, 4]
|
| 738 |
+
"""
|
| 739 |
+
B, num_views, _, _ = transform_matrix.shape
|
| 740 |
+
init_transform_matrix = transform_matrix[:, 0:1]
|
| 741 |
+
|
| 742 |
+
fix_transform_matrix = self.fix_transform_matrix.unsqueeze(0).expand(B, -1, -1).clone()
|
| 743 |
+
init_distance = distance[:, 0]
|
| 744 |
+
fix_transform_matrix[:, 1, 3] = -init_distance
|
| 745 |
+
fix_transform_matrix = fix_transform_matrix.unsqueeze(1)
|
| 746 |
+
|
| 747 |
+
calc_mat, relative_transform = compute_calc_mat(
|
| 748 |
+
init_transform_matrix, transform_matrix, fix_transform_matrix
|
| 749 |
+
)
|
| 750 |
+
return calc_mat, relative_transform
|
pixal3d/models/transformers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import sparse_dit
|
| 2 |
+
from . import dense_dit
|
pixal3d/models/transformers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (253 Bytes). View file
|
|
|
pixal3d/models/transformers/__pycache__/dense_dit.cpython-310.pyc
ADDED
|
Binary file (9.49 kB). View file
|
|
|
pixal3d/models/transformers/__pycache__/sparse_dit.cpython-310.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
pixal3d/models/transformers/dense_dit.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
|
| 8 |
+
from ...modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
|
| 9 |
+
from ...modules.spatial import patchify, unpatchify
|
| 10 |
+
from ...utils.base import BaseModule
|
| 11 |
+
import pixal3d
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
class TimestepEmbedder(nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
Embeds scalar timesteps into vector representations.
|
| 18 |
+
"""
|
| 19 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.mlp = nn.Sequential(
|
| 22 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 23 |
+
nn.SiLU(),
|
| 24 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 25 |
+
)
|
| 26 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 27 |
+
|
| 28 |
+
@staticmethod
|
| 29 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 30 |
+
"""
|
| 31 |
+
Create sinusoidal timestep embeddings.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
t: a 1-D Tensor of N indices, one per batch element.
|
| 35 |
+
These may be fractional.
|
| 36 |
+
dim: the dimension of the output.
|
| 37 |
+
max_period: controls the minimum frequency of the embeddings.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
an (N, D) Tensor of positional embeddings.
|
| 41 |
+
"""
|
| 42 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 43 |
+
half = dim // 2
|
| 44 |
+
freqs = torch.exp(
|
| 45 |
+
-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 46 |
+
).to(device=t.device)
|
| 47 |
+
args = t[:, None].float() * freqs[None]
|
| 48 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 49 |
+
if dim % 2:
|
| 50 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 51 |
+
return embedding
|
| 52 |
+
|
| 53 |
+
def forward(self, t):
|
| 54 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 55 |
+
t_freq = t_freq.to(self.mlp[0].weight.dtype)
|
| 56 |
+
t_emb = self.mlp(t_freq)
|
| 57 |
+
return t_emb
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class DenseDiT(nn.Module):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
resolution: int,
|
| 64 |
+
in_channels: int,
|
| 65 |
+
model_channels: int,
|
| 66 |
+
cond_channels: int,
|
| 67 |
+
out_channels: int,
|
| 68 |
+
num_blocks: int,
|
| 69 |
+
num_heads: Optional[int] = None,
|
| 70 |
+
num_head_channels: Optional[int] = 64,
|
| 71 |
+
mlp_ratio: float = 4,
|
| 72 |
+
patch_size: int = 2,
|
| 73 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 74 |
+
use_fp16: bool = False,
|
| 75 |
+
use_checkpoint: bool = False,
|
| 76 |
+
share_mod: bool = False,
|
| 77 |
+
qk_rms_norm: bool = False,
|
| 78 |
+
qk_rms_norm_cross: bool = False,
|
| 79 |
+
latent_shape: list = [8, 16, 16, 16],
|
| 80 |
+
image_attn_mode:str = "cross",
|
| 81 |
+
load_ckpt:bool = True,
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.resolution = resolution
|
| 85 |
+
self.in_channels = in_channels
|
| 86 |
+
self.model_channels = model_channels
|
| 87 |
+
self.cond_channels = cond_channels
|
| 88 |
+
self.out_channels = out_channels
|
| 89 |
+
self.num_blocks = num_blocks
|
| 90 |
+
self.num_heads = num_heads or model_channels // num_head_channels
|
| 91 |
+
self.mlp_ratio = mlp_ratio
|
| 92 |
+
self.patch_size = patch_size
|
| 93 |
+
self.pe_mode = pe_mode
|
| 94 |
+
self.use_fp16 = use_fp16
|
| 95 |
+
self.use_checkpoint = use_checkpoint
|
| 96 |
+
self.share_mod = share_mod
|
| 97 |
+
self.qk_rms_norm = qk_rms_norm
|
| 98 |
+
self.qk_rms_norm_cross = qk_rms_norm_cross
|
| 99 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 100 |
+
self.latent_shape = latent_shape
|
| 101 |
+
self.image_attn_mode = image_attn_mode
|
| 102 |
+
|
| 103 |
+
self.t_embedder = TimestepEmbedder(model_channels)
|
| 104 |
+
if share_mod:
|
| 105 |
+
self.adaLN_modulation = nn.Sequential(
|
| 106 |
+
nn.SiLU(),
|
| 107 |
+
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if pe_mode == "ape":
|
| 111 |
+
pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
|
| 112 |
+
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
|
| 113 |
+
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
| 114 |
+
pos_emb = pos_embedder(coords)
|
| 115 |
+
self.register_buffer("pos_emb", pos_emb)
|
| 116 |
+
|
| 117 |
+
self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
|
| 118 |
+
|
| 119 |
+
self.blocks = nn.ModuleList([
|
| 120 |
+
ModulatedTransformerCrossBlock(
|
| 121 |
+
model_channels,
|
| 122 |
+
cond_channels,
|
| 123 |
+
num_heads=self.num_heads,
|
| 124 |
+
mlp_ratio=self.mlp_ratio,
|
| 125 |
+
attn_mode='full',
|
| 126 |
+
use_checkpoint=self.use_checkpoint,
|
| 127 |
+
use_rope=(pe_mode == "rope"),
|
| 128 |
+
share_mod=share_mod,
|
| 129 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 130 |
+
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
| 131 |
+
image_attn_mode = self.image_attn_mode,
|
| 132 |
+
|
| 133 |
+
)
|
| 134 |
+
for _ in range(num_blocks)
|
| 135 |
+
])
|
| 136 |
+
|
| 137 |
+
self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
|
| 138 |
+
|
| 139 |
+
self.initialize_weights()
|
| 140 |
+
if use_fp16:
|
| 141 |
+
self.convert_to_fp16()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@property
|
| 145 |
+
def device(self) -> torch.device:
|
| 146 |
+
"""
|
| 147 |
+
Return the device of the model.
|
| 148 |
+
"""
|
| 149 |
+
return next(self.parameters()).device
|
| 150 |
+
|
| 151 |
+
def convert_to_fp16(self) -> None:
|
| 152 |
+
"""
|
| 153 |
+
Convert the torso of the model to float16.
|
| 154 |
+
"""
|
| 155 |
+
# self.blocks.apply(convert_module_to_f16)
|
| 156 |
+
self.apply(convert_module_to_f16)
|
| 157 |
+
|
| 158 |
+
def convert_to_fp32(self) -> None:
|
| 159 |
+
"""
|
| 160 |
+
Convert the torso of the model to float32.
|
| 161 |
+
"""
|
| 162 |
+
self.blocks.apply(convert_module_to_f32)
|
| 163 |
+
|
| 164 |
+
def initialize_weights(self) -> None:
|
| 165 |
+
# Initialize transformer layers:
|
| 166 |
+
def _basic_init(module):
|
| 167 |
+
if isinstance(module, nn.Linear):
|
| 168 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 169 |
+
if module.bias is not None:
|
| 170 |
+
nn.init.constant_(module.bias, 0)
|
| 171 |
+
self.apply(_basic_init)
|
| 172 |
+
|
| 173 |
+
# Initialize timestep embedding MLP:
|
| 174 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 175 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 176 |
+
|
| 177 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
| 178 |
+
if self.share_mod:
|
| 179 |
+
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
| 180 |
+
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
| 181 |
+
else:
|
| 182 |
+
for block in self.blocks:
|
| 183 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 184 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 185 |
+
|
| 186 |
+
# Zero-out output layers:
|
| 187 |
+
nn.init.constant_(self.out_layer.weight, 0)
|
| 188 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
| 189 |
+
|
| 190 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
| 191 |
+
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
|
| 192 |
+
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
|
| 193 |
+
|
| 194 |
+
h = patchify(x, self.patch_size)
|
| 195 |
+
h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
|
| 196 |
+
h = self.input_layer(h)
|
| 197 |
+
h = h + self.pos_emb[None]
|
| 198 |
+
t_emb = self.t_embedder(t)
|
| 199 |
+
if self.share_mod:
|
| 200 |
+
t_emb = self.adaLN_modulation(t_emb)
|
| 201 |
+
t_emb = t_emb.type(self.dtype)
|
| 202 |
+
h = h.type(self.dtype)
|
| 203 |
+
if self.image_attn_mode=='proj':
|
| 204 |
+
global_cond,proj_cond = cond
|
| 205 |
+
global_cond = global_cond.type(self.dtype)
|
| 206 |
+
proj_cond = proj_cond.type(self.dtype)
|
| 207 |
+
cond = (global_cond, proj_cond)
|
| 208 |
+
else:
|
| 209 |
+
cond = cond.type(self.dtype)
|
| 210 |
+
for block in self.blocks:
|
| 211 |
+
h = block(h, t_emb, cond)
|
| 212 |
+
h = h.type(x.dtype)
|
| 213 |
+
h = F.layer_norm(h, h.shape[-1:])
|
| 214 |
+
h = self.out_layer(h)
|
| 215 |
+
|
| 216 |
+
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
|
| 217 |
+
h = unpatchify(h, self.patch_size).contiguous()
|
| 218 |
+
|
| 219 |
+
return h
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# ===== Align to sparse_dit style: ModelOutput + Denoiser wrapper (Lightning-friendly) =====
|
| 223 |
+
|
| 224 |
+
@dataclass
|
| 225 |
+
class DenseDiTModelOutput:
|
| 226 |
+
sample: torch.Tensor
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@pixal3d.register("dense-dit-denoiser")
|
| 230 |
+
class DenseDiTDenoiser(BaseModule):
|
| 231 |
+
@dataclass
|
| 232 |
+
class Config(BaseModule.Config):
|
| 233 |
+
# Mirror DenseDiT init signature with reasonable defaults
|
| 234 |
+
resolution: int = 64
|
| 235 |
+
in_channels: int = 16
|
| 236 |
+
model_channels: int = 1024
|
| 237 |
+
cond_channels: int = 1024
|
| 238 |
+
out_channels: int = 16
|
| 239 |
+
num_blocks: int = 24
|
| 240 |
+
num_heads: Optional[int] = None
|
| 241 |
+
num_head_channels: Optional[int] = 64
|
| 242 |
+
mlp_ratio: float = 4.0
|
| 243 |
+
patch_size: int = 2
|
| 244 |
+
pe_mode: str = "ape" # "ape" | "rope"
|
| 245 |
+
use_fp16: bool = False
|
| 246 |
+
use_checkpoint: bool = False
|
| 247 |
+
share_mod: bool = False
|
| 248 |
+
qk_rms_norm: bool = False
|
| 249 |
+
qk_rms_norm_cross: bool = False
|
| 250 |
+
latent_shape: list = (8, 16, 16, 16)
|
| 251 |
+
image_attn_mode: str = "cross"
|
| 252 |
+
load_ckpt:bool = True
|
| 253 |
+
|
| 254 |
+
cfg: Config
|
| 255 |
+
|
| 256 |
+
def configure(self) -> None:
|
| 257 |
+
# Instantiate the underlying DenseDiT model
|
| 258 |
+
self.dit_model = DenseDiT(
|
| 259 |
+
resolution=self.cfg.resolution,
|
| 260 |
+
in_channels=self.cfg.in_channels,
|
| 261 |
+
model_channels=self.cfg.model_channels,
|
| 262 |
+
cond_channels=self.cfg.cond_channels,
|
| 263 |
+
out_channels=self.cfg.out_channels,
|
| 264 |
+
num_blocks=self.cfg.num_blocks,
|
| 265 |
+
num_heads=self.cfg.num_heads,
|
| 266 |
+
num_head_channels=self.cfg.num_head_channels,
|
| 267 |
+
mlp_ratio=self.cfg.mlp_ratio,
|
| 268 |
+
patch_size=self.cfg.patch_size,
|
| 269 |
+
pe_mode=self.cfg.pe_mode,
|
| 270 |
+
use_fp16=self.cfg.use_fp16,
|
| 271 |
+
use_checkpoint=self.cfg.use_checkpoint,
|
| 272 |
+
share_mod=self.cfg.share_mod,
|
| 273 |
+
qk_rms_norm=self.cfg.qk_rms_norm,
|
| 274 |
+
qk_rms_norm_cross=self.cfg.qk_rms_norm_cross,
|
| 275 |
+
latent_shape=list(self.cfg.latent_shape) if isinstance(self.cfg.latent_shape, (list, tuple)) else self.cfg.latent_shape,
|
| 276 |
+
image_attn_mode=self.cfg.image_attn_mode,
|
| 277 |
+
load_ckpt=self.cfg.load_ckpt,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
# For a consistent external API (some systems may read out_channels)
|
| 281 |
+
self.out_channels = self.cfg.out_channels
|
| 282 |
+
|
| 283 |
+
def forward(
|
| 284 |
+
self,
|
| 285 |
+
x: torch.Tensor,
|
| 286 |
+
t: torch.Tensor,
|
| 287 |
+
cond: torch.Tensor,
|
| 288 |
+
**kwargs,
|
| 289 |
+
) -> DenseDiTModelOutput:
|
| 290 |
+
"""Forward wrapper returning a structured output like diffusers models.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
x: [B, C, D, H, W] dense latent tensor.
|
| 294 |
+
t: [B] or [1] timestep tensor.
|
| 295 |
+
cond: conditioning tensor matching the transformer blocks' expected dims.
|
| 296 |
+
"""
|
| 297 |
+
out = self.dit_model(x, t, cond)
|
| 298 |
+
return DenseDiTModelOutput(sample=out)
|
pixal3d/models/transformers/sparse_dit.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Some parts of this file are adapted from the SparseDiT implementation
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any, Dict, Optional, Union, Tuple, Literal
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.loaders import PeftAdapterMixin
|
| 13 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 14 |
+
from diffusers.utils import logging
|
| 15 |
+
|
| 16 |
+
import pixal3d
|
| 17 |
+
from pixal3d.utils.base import BaseModule
|
| 18 |
+
from huggingface_hub import hf_hub_download
|
| 19 |
+
|
| 20 |
+
# Import sparse operations
|
| 21 |
+
|
| 22 |
+
from ...modules import sparse as sp
|
| 23 |
+
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
|
| 24 |
+
from ...modules.transformer import AbsolutePositionEmbedder
|
| 25 |
+
from ...modules.sparse.transformer.modulated import ModulatedSparseTransformerCrossBlock
|
| 26 |
+
SPARSE_AVAILABLE = True
|
| 27 |
+
# except ImportError:
|
| 28 |
+
# print("Warning: sparse modules not found. Please ensure it's in your Python path.")
|
| 29 |
+
# sp = None
|
| 30 |
+
# convert_module_to_f16 = None
|
| 31 |
+
# convert_module_to_f32 = None
|
| 32 |
+
# AbsolutePositionEmbedder = None
|
| 33 |
+
# ModulatedSparseTransformerCrossBlock = None
|
| 34 |
+
# SPARSE_AVAILABLE = False
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class SparseDiTModelOutput:
|
| 41 |
+
sample: Any # Can be torch.FloatTensor or sp.SparseTensor
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TimestepEmbedder(nn.Module):
|
| 45 |
+
"""
|
| 46 |
+
Embeds scalar timesteps into vector representations.
|
| 47 |
+
"""
|
| 48 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.mlp = nn.Sequential(
|
| 51 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 52 |
+
nn.SiLU(),
|
| 53 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 54 |
+
)
|
| 55 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 59 |
+
"""
|
| 60 |
+
Create sinusoidal timestep embeddings.
|
| 61 |
+
"""
|
| 62 |
+
half = dim // 2
|
| 63 |
+
freqs = torch.exp(
|
| 64 |
+
-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 65 |
+
).to(device=t.device)
|
| 66 |
+
args = t[:, None].float() * freqs[None]
|
| 67 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 68 |
+
if dim % 2:
|
| 69 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 70 |
+
return embedding
|
| 71 |
+
|
| 72 |
+
def forward(self, t):
|
| 73 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 74 |
+
t_freq = t_freq.to(self.mlp[0].weight.dtype)
|
| 75 |
+
t_emb = self.mlp(t_freq)
|
| 76 |
+
return t_emb
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class SparseDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
| 80 |
+
"""
|
| 81 |
+
Sparse Diffusion Transformer model for 3D shape generation.
|
| 82 |
+
|
| 83 |
+
This model processes sparse 3D data using sparse attention mechanisms.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
_supports_gradient_checkpointing = True
|
| 87 |
+
|
| 88 |
+
@register_to_config
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
resolution: int = 64,
|
| 92 |
+
in_channels: int = 16,
|
| 93 |
+
model_channels: int = 1024,
|
| 94 |
+
cond_channels: int = 1024,
|
| 95 |
+
out_channels: int = 16,
|
| 96 |
+
num_blocks: int = 24,
|
| 97 |
+
num_heads: int = 32,
|
| 98 |
+
num_head_channels: int = 64,
|
| 99 |
+
num_kv_heads: int = 2,
|
| 100 |
+
compression_block_size: int = 4,
|
| 101 |
+
selection_block_size: int = 8,
|
| 102 |
+
topk: int = 32,
|
| 103 |
+
compression_version: str = 'v2',
|
| 104 |
+
mlp_ratio: float = 4.0,
|
| 105 |
+
pe_mode: str = "ape",
|
| 106 |
+
use_fp16: bool = True,
|
| 107 |
+
use_checkpoint: bool = True,
|
| 108 |
+
share_mod: bool = False,
|
| 109 |
+
qk_rms_norm: bool = True,
|
| 110 |
+
qk_rms_norm_cross: bool = False,
|
| 111 |
+
sparse_conditions: bool = True,
|
| 112 |
+
factor: float = 1.0,
|
| 113 |
+
window_size: int = 8,
|
| 114 |
+
use_shift: bool = True,
|
| 115 |
+
image_attn_mode:str='cross',
|
| 116 |
+
load_ckpt:bool=True,
|
| 117 |
+
version:Optional[str]='V10',
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
|
| 121 |
+
if not SPARSE_AVAILABLE:
|
| 122 |
+
raise ImportError("sparse modules not found.")
|
| 123 |
+
|
| 124 |
+
self.resolution = resolution
|
| 125 |
+
self.in_channels = in_channels
|
| 126 |
+
self.model_channels = model_channels
|
| 127 |
+
self.cond_channels = cond_channels
|
| 128 |
+
self.out_channels = out_channels
|
| 129 |
+
self.num_blocks = num_blocks
|
| 130 |
+
self.num_heads = num_heads or model_channels // num_head_channels
|
| 131 |
+
self.mlp_ratio = mlp_ratio
|
| 132 |
+
self.pe_mode = pe_mode
|
| 133 |
+
self.use_fp16 = use_fp16
|
| 134 |
+
self.use_checkpoint = use_checkpoint
|
| 135 |
+
self.share_mod = share_mod
|
| 136 |
+
self.qk_rms_norm = qk_rms_norm
|
| 137 |
+
self.qk_rms_norm_cross = qk_rms_norm_cross
|
| 138 |
+
self._dtype = torch.float16 if use_fp16 else torch.float32
|
| 139 |
+
self.sparse_conditions = sparse_conditions
|
| 140 |
+
self.factor = factor
|
| 141 |
+
self.compression_block_size = compression_block_size
|
| 142 |
+
self.selection_block_size = selection_block_size
|
| 143 |
+
self.image_attn_mode = image_attn_mode
|
| 144 |
+
|
| 145 |
+
# Timestep embedding
|
| 146 |
+
self.t_embedder = TimestepEmbedder(model_channels)
|
| 147 |
+
|
| 148 |
+
# Shared modulation if enabled
|
| 149 |
+
if share_mod:
|
| 150 |
+
self.adaLN_modulation = nn.Sequential(
|
| 151 |
+
nn.SiLU(),
|
| 152 |
+
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Condition processing for sparse conditions
|
| 156 |
+
if sparse_conditions:
|
| 157 |
+
self.cond_proj = sp.SparseLinear(cond_channels, cond_channels)
|
| 158 |
+
self.pos_embedder_cond = AbsolutePositionEmbedder(model_channels, in_channels=3)
|
| 159 |
+
|
| 160 |
+
# Position embedding
|
| 161 |
+
if pe_mode == "ape":
|
| 162 |
+
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
| 163 |
+
|
| 164 |
+
# Input projection
|
| 165 |
+
self.input_layer = sp.SparseLinear(in_channels, model_channels)
|
| 166 |
+
|
| 167 |
+
# Transformer blocks
|
| 168 |
+
self.blocks = nn.ModuleList([
|
| 169 |
+
ModulatedSparseTransformerCrossBlock(
|
| 170 |
+
model_channels,
|
| 171 |
+
cond_channels,
|
| 172 |
+
num_heads=self.num_heads,
|
| 173 |
+
num_kv_heads=num_kv_heads,
|
| 174 |
+
compression_block_size=compression_block_size,
|
| 175 |
+
selection_block_size=selection_block_size,
|
| 176 |
+
topk=topk,
|
| 177 |
+
mlp_ratio=self.mlp_ratio,
|
| 178 |
+
attn_mode='full',
|
| 179 |
+
compression_version=compression_version,
|
| 180 |
+
use_checkpoint=self.use_checkpoint,
|
| 181 |
+
use_rope=(pe_mode == "rope"),
|
| 182 |
+
share_mod=self.share_mod,
|
| 183 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 184 |
+
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
| 185 |
+
resolution=resolution,
|
| 186 |
+
window_size=window_size,
|
| 187 |
+
shift_window=window_size // 2 * (i % 2) if use_shift else window_size // 2,
|
| 188 |
+
image_attn_mode = image_attn_mode,
|
| 189 |
+
)
|
| 190 |
+
for i in range(num_blocks)
|
| 191 |
+
])
|
| 192 |
+
|
| 193 |
+
# Output projection
|
| 194 |
+
self.out_layer = sp.SparseLinear(model_channels, out_channels)
|
| 195 |
+
|
| 196 |
+
# Initialize weights
|
| 197 |
+
self.initialize_weights()
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
self.gradient_checkpointing = False
|
| 201 |
+
|
| 202 |
+
if use_fp16:
|
| 203 |
+
print("Converting model to float16 ============================")
|
| 204 |
+
self.convert_to_fp16()
|
| 205 |
+
# else:
|
| 206 |
+
# self.convert_to_fp32()
|
| 207 |
+
@property
|
| 208 |
+
def device(self) -> torch.device:
|
| 209 |
+
"""Return the device of the model."""
|
| 210 |
+
return next(self.parameters()).device
|
| 211 |
+
|
| 212 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 213 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 214 |
+
module.gradient_checkpointing = value
|
| 215 |
+
|
| 216 |
+
def convert_to_fp16(self) -> None:
|
| 217 |
+
"""Convert the model to float16."""
|
| 218 |
+
self.apply(convert_module_to_f16)
|
| 219 |
+
|
| 220 |
+
def convert_to_fp32(self) -> None:
|
| 221 |
+
"""Convert the model to float32."""
|
| 222 |
+
self.apply(convert_module_to_f32)
|
| 223 |
+
|
| 224 |
+
def initialize_weights(self) -> None:
|
| 225 |
+
"""Initialize model weights."""
|
| 226 |
+
# Initialize transformer layers
|
| 227 |
+
def _basic_init(module):
|
| 228 |
+
if isinstance(module, nn.Linear):
|
| 229 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 230 |
+
if module.bias is not None:
|
| 231 |
+
nn.init.constant_(module.bias, 0)
|
| 232 |
+
self.apply(_basic_init)
|
| 233 |
+
|
| 234 |
+
# Initialize timestep embedding MLP
|
| 235 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 236 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 237 |
+
|
| 238 |
+
# Zero-out adaLN modulation layers
|
| 239 |
+
if self.share_mod:
|
| 240 |
+
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
| 241 |
+
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
| 242 |
+
else:
|
| 243 |
+
for block in self.blocks:
|
| 244 |
+
# if hasattr(block, 'adaLN_modulation'):
|
| 245 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 246 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 247 |
+
|
| 248 |
+
# Zero-out output layers
|
| 249 |
+
nn.init.constant_(self.out_layer.weight, 0)
|
| 250 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
| 251 |
+
|
| 252 |
+
def forward(
|
| 253 |
+
self,
|
| 254 |
+
hidden_states: Any, # sp.SparseTensor
|
| 255 |
+
timestep: torch.Tensor,
|
| 256 |
+
encoder_hidden_states: Optional[Any] = None, # torch.Tensor or sp.SparseTensor
|
| 257 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 258 |
+
return_dict: bool = True,
|
| 259 |
+
) -> Union[SparseDiTModelOutput, Tuple]:
|
| 260 |
+
"""
|
| 261 |
+
Forward pass of the SparseDiT model.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
hidden_states: Input sparse tensor
|
| 265 |
+
timestep: Timestep tensor
|
| 266 |
+
encoder_hidden_states: Condition tensor (visual/text conditions)
|
| 267 |
+
attention_kwargs: Additional attention arguments
|
| 268 |
+
return_dict: Whether to return a dictionary
|
| 269 |
+
"""
|
| 270 |
+
# breakpoint()
|
| 271 |
+
# Process input
|
| 272 |
+
assert attention_kwargs is None, "attention_kwargs not supported in SparseDiT"
|
| 273 |
+
# breakpoint()
|
| 274 |
+
h = self.input_layer(hidden_states).type(self._dtype)
|
| 275 |
+
|
| 276 |
+
# Process timestep
|
| 277 |
+
t_emb = self.t_embedder(timestep)
|
| 278 |
+
if self.share_mod:
|
| 279 |
+
t_emb = self.adaLN_modulation(t_emb)
|
| 280 |
+
t_emb = t_emb.type(self._dtype)
|
| 281 |
+
|
| 282 |
+
# Process conditions
|
| 283 |
+
|
| 284 |
+
cond = encoder_hidden_states
|
| 285 |
+
if self.image_attn_mode=='proj':
|
| 286 |
+
global_cond,sparse_cond = cond
|
| 287 |
+
|
| 288 |
+
if sparse_cond is not None:
|
| 289 |
+
sparse_cond = sparse_cond.type(self._dtype)
|
| 290 |
+
global_cond = global_cond.type(self._dtype)
|
| 291 |
+
# breakpoint()
|
| 292 |
+
if self.sparse_conditions and isinstance(sparse_cond, sp.SparseTensor):
|
| 293 |
+
# breakpoint()
|
| 294 |
+
sparse_cond = self.cond_proj(sparse_cond)
|
| 295 |
+
sparse_cond = sparse_cond + self.pos_embedder_cond(sparse_cond.coords[:, 1:]).type(self._dtype)
|
| 296 |
+
cond = (global_cond,sparse_cond)
|
| 297 |
+
else:
|
| 298 |
+
if self.sparse_conditions:
|
| 299 |
+
cond = self.cond_proj(cond)
|
| 300 |
+
cond = cond + self.pos_embedder_cond(cond.coords[:, 1:]).type(self.dtype)
|
| 301 |
+
|
| 302 |
+
# Add positional embeddings
|
| 303 |
+
if self.pe_mode == "ape":
|
| 304 |
+
h = h + self.pos_embedder(h.coords[:, 1:], factor=self.factor).type(self._dtype)
|
| 305 |
+
|
| 306 |
+
# Process through transformer blocks
|
| 307 |
+
for block in self.blocks:
|
| 308 |
+
if self.training and self.gradient_checkpointing:
|
| 309 |
+
def create_custom_forward(module):
|
| 310 |
+
def custom_forward(*inputs):
|
| 311 |
+
return module(*inputs)
|
| 312 |
+
return custom_forward
|
| 313 |
+
|
| 314 |
+
h = torch.utils.checkpoint.checkpoint(
|
| 315 |
+
create_custom_forward(block),
|
| 316 |
+
h, t_emb, cond
|
| 317 |
+
)
|
| 318 |
+
else:
|
| 319 |
+
h = block(h, t_emb, cond)
|
| 320 |
+
|
| 321 |
+
# Final layer norm and output projection
|
| 322 |
+
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
| 323 |
+
h = self.out_layer(h.type(hidden_states.dtype))
|
| 324 |
+
|
| 325 |
+
if not return_dict:
|
| 326 |
+
return (h,)
|
| 327 |
+
|
| 328 |
+
return SparseDiTModelOutput(sample=h)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@pixal3d.register("sparse-dit-denoiser")
|
| 332 |
+
class SparseDiTDenoiser(BaseModule):
|
| 333 |
+
"""
|
| 334 |
+
Sparse DiT Denoiser wrapper for pixal3d framework.
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
@dataclass
|
| 338 |
+
class Config(BaseModule.Config):
|
| 339 |
+
# Model architecture
|
| 340 |
+
resolution: int = 64
|
| 341 |
+
in_channels: int = 16
|
| 342 |
+
model_channels: int = 1024
|
| 343 |
+
cond_channels: int = 1024
|
| 344 |
+
out_channels: int = 16
|
| 345 |
+
num_blocks: int = 24
|
| 346 |
+
num_heads: int = 32
|
| 347 |
+
num_kv_heads: int = 2
|
| 348 |
+
compression_block_size: int = 4
|
| 349 |
+
selection_block_size: int = 8
|
| 350 |
+
topk: int = 32
|
| 351 |
+
compression_version: str = 'v2'
|
| 352 |
+
mlp_ratio: float = 4.0
|
| 353 |
+
pe_mode: str = "ape"
|
| 354 |
+
use_fp16: bool = True
|
| 355 |
+
use_checkpoint: bool = True
|
| 356 |
+
qk_rms_norm: bool = True
|
| 357 |
+
qk_rms_norm_cross: bool = False
|
| 358 |
+
sparse_conditions: bool = True
|
| 359 |
+
factor: float = 1.0
|
| 360 |
+
window_size: int = 8
|
| 361 |
+
use_shift: bool = True
|
| 362 |
+
|
| 363 |
+
# Condition settings
|
| 364 |
+
use_visual_condition: bool = True
|
| 365 |
+
visual_condition_dim: int = 1024
|
| 366 |
+
use_caption_condition: bool = False
|
| 367 |
+
caption_condition_dim: int = 1024
|
| 368 |
+
use_label_condition: bool = False
|
| 369 |
+
label_condition_dim: int = 1024
|
| 370 |
+
|
| 371 |
+
# Training settings
|
| 372 |
+
pretrained_model_name_or_path: Optional[str] = None
|
| 373 |
+
|
| 374 |
+
image_attn_mode:Optional[str]='cross'
|
| 375 |
+
load_ckpt:bool =True
|
| 376 |
+
version:Optional[str]='V10'
|
| 377 |
+
|
| 378 |
+
cfg: Config
|
| 379 |
+
|
| 380 |
+
def configure(self) -> None:
|
| 381 |
+
"""Configure the SparseDiT model."""
|
| 382 |
+
|
| 383 |
+
# Create the core SparseDiT model
|
| 384 |
+
self.dit_model = SparseDiTModel(
|
| 385 |
+
resolution=self.cfg.resolution,
|
| 386 |
+
in_channels=self.cfg.in_channels,
|
| 387 |
+
model_channels=self.cfg.model_channels,
|
| 388 |
+
cond_channels=self.cfg.cond_channels,
|
| 389 |
+
out_channels=self.cfg.out_channels,
|
| 390 |
+
num_blocks=self.cfg.num_blocks,
|
| 391 |
+
num_heads=self.cfg.num_heads,
|
| 392 |
+
num_kv_heads=self.cfg.num_kv_heads,
|
| 393 |
+
compression_block_size=self.cfg.compression_block_size,
|
| 394 |
+
selection_block_size=self.cfg.selection_block_size,
|
| 395 |
+
topk=self.cfg.topk,
|
| 396 |
+
compression_version=self.cfg.compression_version,
|
| 397 |
+
mlp_ratio=self.cfg.mlp_ratio,
|
| 398 |
+
pe_mode=self.cfg.pe_mode,
|
| 399 |
+
use_fp16=self.cfg.use_fp16,
|
| 400 |
+
use_checkpoint=self.cfg.use_checkpoint,
|
| 401 |
+
sparse_conditions=self.cfg.sparse_conditions,
|
| 402 |
+
factor=self.cfg.factor,
|
| 403 |
+
window_size=self.cfg.window_size,
|
| 404 |
+
use_shift=self.cfg.use_shift,
|
| 405 |
+
image_attn_mode=self.cfg.image_attn_mode,
|
| 406 |
+
load_ckpt = self.cfg.load_ckpt,
|
| 407 |
+
version=self.cfg.version,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# Condition projectors
|
| 411 |
+
if self.cfg.use_visual_condition and self.cfg.visual_condition_dim != self.cfg.cond_channels:
|
| 412 |
+
self.proj_visual_condition = nn.Sequential(
|
| 413 |
+
nn.RMSNorm(self.cfg.visual_condition_dim),
|
| 414 |
+
nn.Linear(self.cfg.visual_condition_dim, self.cfg.cond_channels),
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
if self.cfg.use_caption_condition and self.cfg.caption_condition_dim != self.cfg.cond_channels:
|
| 418 |
+
self.proj_caption_condition = nn.Sequential(
|
| 419 |
+
nn.RMSNorm(self.cfg.caption_condition_dim),
|
| 420 |
+
nn.Linear(self.cfg.caption_condition_dim, self.cfg.cond_channels),
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
if self.cfg.use_label_condition and self.cfg.label_condition_dim != self.cfg.cond_channels:
|
| 424 |
+
self.proj_label_condition = nn.Sequential(
|
| 425 |
+
nn.RMSNorm(self.cfg.label_condition_dim),
|
| 426 |
+
nn.Linear(self.cfg.label_condition_dim, self.cfg.cond_channels),
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# Load pretrained weights if specified
|
| 430 |
+
if self.cfg.pretrained_model_name_or_path:
|
| 431 |
+
print(f"Loading pretrained SparseDiT model from {self.cfg.pretrained_model_name_or_path}")
|
| 432 |
+
ckpt = torch.load(
|
| 433 |
+
self.cfg.pretrained_model_name_or_path,
|
| 434 |
+
map_location="cpu",
|
| 435 |
+
weights_only=True,
|
| 436 |
+
)
|
| 437 |
+
if "state_dict" in ckpt.keys():
|
| 438 |
+
ckpt = ckpt["state_dict"]
|
| 439 |
+
self.load_state_dict(ckpt, strict=True)
|
| 440 |
+
|
| 441 |
+
def forward(
|
| 442 |
+
self,
|
| 443 |
+
x: Any, # sp.SparseTensor
|
| 444 |
+
t: torch.Tensor,
|
| 445 |
+
cond: Optional[Any] = None,
|
| 446 |
+
):
|
| 447 |
+
"""
|
| 448 |
+
Forward pass of the denoiser.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
model_input: Input sparse tensor [SparseTensor with features]
|
| 452 |
+
timestep: Timestep tensor [batch_size,]
|
| 453 |
+
visual_condition: Visual condition tensor
|
| 454 |
+
caption_condition: Caption condition tensor
|
| 455 |
+
label_condition: Label condition tensor
|
| 456 |
+
attention_kwargs: Additional attention arguments
|
| 457 |
+
return_dict: Whether to return a dictionary
|
| 458 |
+
"""
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
output = self.dit_model(
|
| 462 |
+
hidden_states=x,
|
| 463 |
+
timestep=t,
|
| 464 |
+
encoder_hidden_states=cond,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
return output
|
| 468 |
+
|
| 469 |
+
|
pixal3d/modules/__pycache__/norm.cpython-310.pyc
ADDED
|
Binary file (1.43 kB). View file
|
|
|
pixal3d/modules/__pycache__/spatial.cpython-310.pyc
ADDED
|
Binary file (2.49 kB). View file
|
|
|
pixal3d/modules/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (1.55 kB). View file
|
|
|
pixal3d/modules/attention/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
BACKEND = 'flash_attn'
|
| 3 |
+
DEBUG = False
|
| 4 |
+
|
| 5 |
+
def __from_env():
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
global BACKEND
|
| 9 |
+
global DEBUG
|
| 10 |
+
|
| 11 |
+
env_attn_backend = os.environ.get('ATTN_BACKEND')
|
| 12 |
+
env_sttn_debug = os.environ.get('ATTN_DEBUG')
|
| 13 |
+
|
| 14 |
+
if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
|
| 15 |
+
BACKEND = env_attn_backend
|
| 16 |
+
if env_sttn_debug is not None:
|
| 17 |
+
DEBUG = env_sttn_debug == '1'
|
| 18 |
+
|
| 19 |
+
print(f"[ATTENTION] Using backend: {BACKEND}")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
__from_env()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def set_backend(backend: Literal['xformers', 'flash_attn']):
|
| 26 |
+
global BACKEND
|
| 27 |
+
BACKEND = backend
|
| 28 |
+
|
| 29 |
+
def set_debug(debug: bool):
|
| 30 |
+
global DEBUG
|
| 31 |
+
DEBUG = debug
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
from .full_attn import *
|
| 35 |
+
from .modules import *
|
pixal3d/modules/attention/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (940 Bytes). View file
|
|
|
pixal3d/modules/attention/__pycache__/full_attn.cpython-310.pyc
ADDED
|
Binary file (4.15 kB). View file
|
|
|
pixal3d/modules/attention/__pycache__/modules.cpython-310.pyc
ADDED
|
Binary file (6.21 kB). View file
|
|
|
pixal3d/modules/attention/full_attn.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
from . import DEBUG, BACKEND
|
| 5 |
+
|
| 6 |
+
if BACKEND == 'xformers':
|
| 7 |
+
import xformers.ops as xops
|
| 8 |
+
elif BACKEND == 'flash_attn':
|
| 9 |
+
import flash_attn
|
| 10 |
+
elif BACKEND == 'sdpa':
|
| 11 |
+
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 12 |
+
elif BACKEND == 'naive':
|
| 13 |
+
pass
|
| 14 |
+
else:
|
| 15 |
+
raise ValueError(f"Unknown attention backend: {BACKEND}")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
'scaled_dot_product_attention',
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _naive_sdpa(q, k, v):
|
| 24 |
+
"""
|
| 25 |
+
Naive implementation of scaled dot product attention.
|
| 26 |
+
"""
|
| 27 |
+
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 28 |
+
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 29 |
+
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 30 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
| 31 |
+
attn_weight = q @ k.transpose(-2, -1) * scale_factor
|
| 32 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
| 33 |
+
out = attn_weight @ v
|
| 34 |
+
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
| 35 |
+
return out
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@overload
|
| 39 |
+
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
|
| 40 |
+
"""
|
| 41 |
+
Apply scaled dot product attention.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
|
| 45 |
+
"""
|
| 46 |
+
...
|
| 47 |
+
|
| 48 |
+
@overload
|
| 49 |
+
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
|
| 50 |
+
"""
|
| 51 |
+
Apply scaled dot product attention.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
|
| 55 |
+
kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
|
| 56 |
+
"""
|
| 57 |
+
...
|
| 58 |
+
|
| 59 |
+
@overload
|
| 60 |
+
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
| 61 |
+
"""
|
| 62 |
+
Apply scaled dot product attention.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
|
| 66 |
+
k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
|
| 67 |
+
v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
|
| 68 |
+
|
| 69 |
+
Note:
|
| 70 |
+
k and v are assumed to have the same coordinate map.
|
| 71 |
+
"""
|
| 72 |
+
...
|
| 73 |
+
|
| 74 |
+
def scaled_dot_product_attention(*args, **kwargs):
|
| 75 |
+
arg_names_dict = {
|
| 76 |
+
1: ['qkv'],
|
| 77 |
+
2: ['q', 'kv'],
|
| 78 |
+
3: ['q', 'k', 'v']
|
| 79 |
+
}
|
| 80 |
+
num_all_args = len(args) + len(kwargs)
|
| 81 |
+
assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
|
| 82 |
+
for key in arg_names_dict[num_all_args][len(args):]:
|
| 83 |
+
assert key in kwargs, f"Missing argument {key}"
|
| 84 |
+
|
| 85 |
+
if num_all_args == 1:
|
| 86 |
+
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
| 87 |
+
assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
|
| 88 |
+
device = qkv.device
|
| 89 |
+
|
| 90 |
+
elif num_all_args == 2:
|
| 91 |
+
q = args[0] if len(args) > 0 else kwargs['q']
|
| 92 |
+
kv = args[1] if len(args) > 1 else kwargs['kv']
|
| 93 |
+
assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
|
| 94 |
+
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
|
| 95 |
+
assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
|
| 96 |
+
device = q.device
|
| 97 |
+
|
| 98 |
+
elif num_all_args == 3:
|
| 99 |
+
q = args[0] if len(args) > 0 else kwargs['q']
|
| 100 |
+
k = args[1] if len(args) > 1 else kwargs['k']
|
| 101 |
+
v = args[2] if len(args) > 2 else kwargs['v']
|
| 102 |
+
assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
|
| 103 |
+
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
|
| 104 |
+
assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
|
| 105 |
+
assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
|
| 106 |
+
device = q.device
|
| 107 |
+
|
| 108 |
+
if BACKEND == 'xformers':
|
| 109 |
+
if num_all_args == 1:
|
| 110 |
+
q, k, v = qkv.unbind(dim=2)
|
| 111 |
+
elif num_all_args == 2:
|
| 112 |
+
k, v = kv.unbind(dim=2)
|
| 113 |
+
out = xops.memory_efficient_attention(q, k, v)
|
| 114 |
+
elif BACKEND == 'flash_attn':
|
| 115 |
+
if num_all_args == 1:
|
| 116 |
+
out = flash_attn.flash_attn_qkvpacked_func(qkv)
|
| 117 |
+
elif num_all_args == 2:
|
| 118 |
+
out = flash_attn.flash_attn_kvpacked_func(q, kv)
|
| 119 |
+
elif num_all_args == 3:
|
| 120 |
+
out = flash_attn.flash_attn_func(q, k, v)
|
| 121 |
+
elif BACKEND == 'sdpa':
|
| 122 |
+
if num_all_args == 1:
|
| 123 |
+
q, k, v = qkv.unbind(dim=2)
|
| 124 |
+
elif num_all_args == 2:
|
| 125 |
+
k, v = kv.unbind(dim=2)
|
| 126 |
+
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 127 |
+
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 128 |
+
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 129 |
+
out = sdpa(q, k, v) # [N, H, L, C]
|
| 130 |
+
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
| 131 |
+
elif BACKEND == 'naive':
|
| 132 |
+
if num_all_args == 1:
|
| 133 |
+
q, k, v = qkv.unbind(dim=2)
|
| 134 |
+
elif num_all_args == 2:
|
| 135 |
+
k, v = kv.unbind(dim=2)
|
| 136 |
+
out = _naive_sdpa(q, k, v)
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError(f"Unknown attention module: {BACKEND}")
|
| 139 |
+
|
| 140 |
+
return out
|
pixal3d/modules/attention/modules.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from .full_attn import scaled_dot_product_attention
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MultiHeadRMSNorm(nn.Module):
|
| 9 |
+
def __init__(self, dim: int, heads: int):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.scale = dim ** 0.5
|
| 12 |
+
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
| 13 |
+
|
| 14 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RotaryPositionEmbedder(nn.Module):
|
| 19 |
+
def __init__(self, hidden_size: int, in_channels: int = 3):
|
| 20 |
+
super().__init__()
|
| 21 |
+
assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
|
| 22 |
+
self.hidden_size = hidden_size
|
| 23 |
+
self.in_channels = in_channels
|
| 24 |
+
self.freq_dim = hidden_size // in_channels // 2
|
| 25 |
+
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
| 26 |
+
self.freqs = 1.0 / (10000 ** self.freqs)
|
| 27 |
+
|
| 28 |
+
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
| 29 |
+
self.freqs = self.freqs.to(indices.device)
|
| 30 |
+
phases = torch.outer(indices, self.freqs)
|
| 31 |
+
phases = torch.polar(torch.ones_like(phases), phases)
|
| 32 |
+
return phases
|
| 33 |
+
|
| 34 |
+
def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
| 35 |
+
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 36 |
+
x_rotated = x_complex * phases
|
| 37 |
+
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
| 38 |
+
return x_embed
|
| 39 |
+
|
| 40 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 41 |
+
"""
|
| 42 |
+
Args:
|
| 43 |
+
q (sp.SparseTensor): [..., N, D] tensor of queries
|
| 44 |
+
k (sp.SparseTensor): [..., N, D] tensor of keys
|
| 45 |
+
indices (torch.Tensor): [..., N, C] tensor of spatial positions
|
| 46 |
+
"""
|
| 47 |
+
if indices is None:
|
| 48 |
+
indices = torch.arange(q.shape[-2], device=q.device)
|
| 49 |
+
if len(q.shape) > 2:
|
| 50 |
+
indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
|
| 51 |
+
|
| 52 |
+
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
| 53 |
+
if phases.shape[1] < self.hidden_size // 2:
|
| 54 |
+
phases = torch.cat([phases, torch.polar(
|
| 55 |
+
torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
|
| 56 |
+
torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
|
| 57 |
+
)], dim=-1)
|
| 58 |
+
q_embed = self._rotary_embedding(q, phases)
|
| 59 |
+
k_embed = self._rotary_embedding(k, phases)
|
| 60 |
+
return q_embed, k_embed
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class MultiHeadAttention(nn.Module):
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
channels: int,
|
| 67 |
+
num_heads: int,
|
| 68 |
+
ctx_channels: Optional[int]=None,
|
| 69 |
+
type: Literal["self", "cross"] = "self",
|
| 70 |
+
attn_mode: Literal["full", "windowed"] = "full",
|
| 71 |
+
window_size: Optional[int] = None,
|
| 72 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 73 |
+
qkv_bias: bool = True,
|
| 74 |
+
use_rope: bool = False,
|
| 75 |
+
qk_rms_norm: bool = False,
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
assert channels % num_heads == 0
|
| 79 |
+
assert type in ["self", "cross"], f"Invalid attention type: {type}"
|
| 80 |
+
assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
|
| 81 |
+
assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
|
| 82 |
+
|
| 83 |
+
if attn_mode == "windowed":
|
| 84 |
+
raise NotImplementedError("Windowed attention is not yet implemented")
|
| 85 |
+
|
| 86 |
+
self.channels = channels
|
| 87 |
+
self.head_dim = channels // num_heads
|
| 88 |
+
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
| 89 |
+
self.num_heads = num_heads
|
| 90 |
+
self._type = type
|
| 91 |
+
self.attn_mode = attn_mode
|
| 92 |
+
self.window_size = window_size
|
| 93 |
+
self.shift_window = shift_window
|
| 94 |
+
self.use_rope = use_rope
|
| 95 |
+
self.qk_rms_norm = qk_rms_norm
|
| 96 |
+
|
| 97 |
+
if self._type == "self":
|
| 98 |
+
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
| 99 |
+
else:
|
| 100 |
+
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
| 101 |
+
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
| 102 |
+
|
| 103 |
+
if self.qk_rms_norm:
|
| 104 |
+
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
| 105 |
+
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
| 106 |
+
|
| 107 |
+
self.to_out = nn.Linear(channels, channels)
|
| 108 |
+
|
| 109 |
+
if use_rope:
|
| 110 |
+
self.rope = RotaryPositionEmbedder(channels)
|
| 111 |
+
|
| 112 |
+
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 113 |
+
B, L, C = x.shape
|
| 114 |
+
if self._type == "self":
|
| 115 |
+
qkv = self.to_qkv(x)
|
| 116 |
+
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
| 117 |
+
if self.use_rope:
|
| 118 |
+
q, k, v = qkv.unbind(dim=2)
|
| 119 |
+
q, k = self.rope(q, k, indices)
|
| 120 |
+
qkv = torch.stack([q, k, v], dim=2)
|
| 121 |
+
if self.attn_mode == "full":
|
| 122 |
+
if self.qk_rms_norm:
|
| 123 |
+
q, k, v = qkv.unbind(dim=2)
|
| 124 |
+
q = self.q_rms_norm(q)
|
| 125 |
+
k = self.k_rms_norm(k)
|
| 126 |
+
h = scaled_dot_product_attention(q, k, v)
|
| 127 |
+
else:
|
| 128 |
+
h = scaled_dot_product_attention(qkv)
|
| 129 |
+
elif self.attn_mode == "windowed":
|
| 130 |
+
raise NotImplementedError("Windowed attention is not yet implemented")
|
| 131 |
+
else:
|
| 132 |
+
Lkv = context.shape[1]
|
| 133 |
+
q = self.to_q(x)
|
| 134 |
+
kv = self.to_kv(context)
|
| 135 |
+
q = q.reshape(B, L, self.num_heads, -1)
|
| 136 |
+
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
| 137 |
+
if self.qk_rms_norm:
|
| 138 |
+
q = self.q_rms_norm(q)
|
| 139 |
+
k, v = kv.unbind(dim=2)
|
| 140 |
+
k = self.k_rms_norm(k)
|
| 141 |
+
h = scaled_dot_product_attention(q, k, v)
|
| 142 |
+
else:
|
| 143 |
+
h = scaled_dot_product_attention(q, kv)
|
| 144 |
+
h = h.reshape(B, L, -1)
|
| 145 |
+
h = self.to_out(h)
|
| 146 |
+
return h
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class ProjectAttention(nn.Module):
|
| 150 |
+
def __init__(self,cross_attn_block: nn.Module):
|
| 151 |
+
super().__init__()
|
| 152 |
+
self.cross_attn_block = cross_attn_block
|
| 153 |
+
self.global_token_length = 5
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 157 |
+
|
| 158 |
+
global_context = context[0]
|
| 159 |
+
proj_context = context[1]
|
| 160 |
+
global_context = self.cross_attn_block(x, global_context)
|
| 161 |
+
context = proj_context + global_context
|
| 162 |
+
return context + x
|
| 163 |
+
|
| 164 |
+
|
pixal3d/modules/norm.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class LayerNorm32(nn.LayerNorm):
|
| 6 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 7 |
+
return super().forward(x.float()).type(x.dtype)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GroupNorm32(nn.GroupNorm):
|
| 11 |
+
"""
|
| 12 |
+
A GroupNorm layer that converts to float32 before the forward pass.
|
| 13 |
+
"""
|
| 14 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
return super().forward(x.float()).type(x.dtype)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ChannelLayerNorm32(LayerNorm32):
|
| 19 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 20 |
+
DIM = x.dim()
|
| 21 |
+
x = x.permute(0, *range(2, DIM), 1).contiguous()
|
| 22 |
+
x = super().forward(x)
|
| 23 |
+
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
|
| 24 |
+
return x
|
| 25 |
+
|
pixal3d/modules/sparse/__init__.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
BACKEND = 'torchsparse'
|
| 4 |
+
DEBUG = False
|
| 5 |
+
ATTN = 'flash_attn'
|
| 6 |
+
|
| 7 |
+
def __from_env():
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
global BACKEND
|
| 11 |
+
global DEBUG
|
| 12 |
+
global ATTN
|
| 13 |
+
|
| 14 |
+
env_sparse_backend = os.environ.get('SPARSE_BACKEND')
|
| 15 |
+
env_sparse_debug = os.environ.get('SPARSE_DEBUG')
|
| 16 |
+
env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
|
| 17 |
+
if env_sparse_attn is None:
|
| 18 |
+
env_sparse_attn = os.environ.get('ATTN_BACKEND')
|
| 19 |
+
|
| 20 |
+
if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
|
| 21 |
+
BACKEND = env_sparse_backend
|
| 22 |
+
if env_sparse_debug is not None:
|
| 23 |
+
DEBUG = env_sparse_debug == '1'
|
| 24 |
+
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
|
| 25 |
+
ATTN = env_sparse_attn
|
| 26 |
+
|
| 27 |
+
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__from_env()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def set_backend(backend: Literal['spconv', 'torchsparse']):
|
| 34 |
+
global BACKEND
|
| 35 |
+
BACKEND = backend
|
| 36 |
+
|
| 37 |
+
def set_debug(debug: bool):
|
| 38 |
+
global DEBUG
|
| 39 |
+
DEBUG = debug
|
| 40 |
+
|
| 41 |
+
def set_attn(attn: Literal['xformers', 'flash_attn']):
|
| 42 |
+
global ATTN
|
| 43 |
+
ATTN = attn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
import importlib
|
| 47 |
+
|
| 48 |
+
__attributes = {
|
| 49 |
+
'SparseTensor': 'basic',
|
| 50 |
+
'sparse_batch_broadcast': 'basic',
|
| 51 |
+
'sparse_batch_op': 'basic',
|
| 52 |
+
'sparse_cat': 'basic',
|
| 53 |
+
'sparse_unbind': 'basic',
|
| 54 |
+
'SparseGroupNorm': 'norm',
|
| 55 |
+
'SparseLayerNorm': 'norm',
|
| 56 |
+
'SparseGroupNorm32': 'norm',
|
| 57 |
+
'SparseLayerNorm32': 'norm',
|
| 58 |
+
'SparseSigmoid': 'nonlinearity',
|
| 59 |
+
'SparseReLU': 'nonlinearity',
|
| 60 |
+
'SparseSiLU': 'nonlinearity',
|
| 61 |
+
'SparseGELU': 'nonlinearity',
|
| 62 |
+
'SparseTanh': 'nonlinearity',
|
| 63 |
+
'SparseActivation': 'nonlinearity',
|
| 64 |
+
'SparseLinear': 'linear',
|
| 65 |
+
'sparse_scaled_dot_product_attention': 'attention',
|
| 66 |
+
'SerializeMode': 'attention',
|
| 67 |
+
'sparse_serialized_scaled_dot_product_self_attention': 'attention',
|
| 68 |
+
'sparse_windowed_scaled_dot_product_self_attention': 'attention',
|
| 69 |
+
'SparseMultiHeadAttention': 'attention',
|
| 70 |
+
'SparseConv3d': 'conv',
|
| 71 |
+
'SparseInverseConv3d': 'conv',
|
| 72 |
+
'sparseconv3d_func': 'conv',
|
| 73 |
+
'SparseDownsample': 'spatial',
|
| 74 |
+
'SparseUpsample': 'spatial',
|
| 75 |
+
'SparseSubdivide' : 'spatial'
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
__submodules = ['transformer']
|
| 79 |
+
|
| 80 |
+
__all__ = list(__attributes.keys()) + __submodules
|
| 81 |
+
|
| 82 |
+
def __getattr__(name):
|
| 83 |
+
if name not in globals():
|
| 84 |
+
if name in __attributes:
|
| 85 |
+
module_name = __attributes[name]
|
| 86 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
| 87 |
+
globals()[name] = getattr(module, name)
|
| 88 |
+
elif name in __submodules:
|
| 89 |
+
module = importlib.import_module(f".{name}", __name__)
|
| 90 |
+
globals()[name] = module
|
| 91 |
+
else:
|
| 92 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
| 93 |
+
return globals()[name]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# For Pylance
|
| 97 |
+
if __name__ == '__main__':
|
| 98 |
+
from .basic import *
|
| 99 |
+
from .norm import *
|
| 100 |
+
from .nonlinearity import *
|
| 101 |
+
from .linear import *
|
| 102 |
+
from .attention import *
|
| 103 |
+
from .conv import *
|
| 104 |
+
from .spatial import *
|
| 105 |
+
import transformer
|
pixal3d/modules/sparse/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (2.64 kB). View file
|
|
|
pixal3d/modules/sparse/__pycache__/basic.cpython-310.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
pixal3d/modules/sparse/__pycache__/linear.cpython-310.pyc
ADDED
|
Binary file (884 Bytes). View file
|
|
|
pixal3d/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc
ADDED
|
Binary file (2.17 kB). View file
|
|
|
pixal3d/modules/sparse/__pycache__/norm.cpython-310.pyc
ADDED
|
Binary file (2.7 kB). View file
|
|
|
pixal3d/modules/sparse/__pycache__/spatial.cpython-310.pyc
ADDED
|
Binary file (5.11 kB). View file
|
|
|
pixal3d/modules/sparse/attention/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .full_attn import *
|
| 2 |
+
from .serialized_attn import *
|
| 3 |
+
from .windowed_attn import *
|
| 4 |
+
from .modules import *
|
| 5 |
+
from .spatial_sparse_attention.module.spatial_sparse_attention import SpatialSparseAttention
|
pixal3d/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (391 Bytes). View file
|
|
|
pixal3d/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc
ADDED
|
Binary file (7.3 kB). View file
|
|
|
pixal3d/modules/sparse/attention/__pycache__/modules.cpython-310.pyc
ADDED
|
Binary file (5.86 kB). View file
|
|
|