Yang2001 committed on
Commit
c272f3c
·
verified ·
1 Parent(s): ab8266a

Upload Pixal3D-D Space

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +6 -7
  2. app.py +311 -0
  3. packages.txt +1 -0
  4. pixal3d/__init__.py +44 -0
  5. pixal3d/__pycache__/__init__.cpython-310.pyc +0 -0
  6. pixal3d/models/__init__.py +1 -0
  7. pixal3d/models/__pycache__/__init__.cpython-310.pyc +0 -0
  8. pixal3d/models/autoencoders/__pycache__/base.cpython-310.pyc +0 -0
  9. pixal3d/models/autoencoders/__pycache__/decoder.cpython-310.pyc +0 -0
  10. pixal3d/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc +0 -0
  11. pixal3d/models/autoencoders/__pycache__/distributions.cpython-310.pyc +0 -0
  12. pixal3d/models/autoencoders/__pycache__/encoder.cpython-310.pyc +0 -0
  13. pixal3d/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc +0 -0
  14. pixal3d/models/autoencoders/base.py +118 -0
  15. pixal3d/models/autoencoders/decoder.py +353 -0
  16. pixal3d/models/autoencoders/dense_vae.py +401 -0
  17. pixal3d/models/autoencoders/distributions.py +51 -0
  18. pixal3d/models/autoencoders/encoder.py +133 -0
  19. pixal3d/models/autoencoders/ss_vae.py +129 -0
  20. pixal3d/models/conditional_encoders/__init__.py +2 -0
  21. pixal3d/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc +0 -0
  22. pixal3d/models/conditional_encoders/__pycache__/dinov2_project_grid.cpython-310.pyc +0 -0
  23. pixal3d/models/conditional_encoders/dinov2_project_grid.py +750 -0
  24. pixal3d/models/transformers/__init__.py +2 -0
  25. pixal3d/models/transformers/__pycache__/__init__.cpython-310.pyc +0 -0
  26. pixal3d/models/transformers/__pycache__/dense_dit.cpython-310.pyc +0 -0
  27. pixal3d/models/transformers/__pycache__/sparse_dit.cpython-310.pyc +0 -0
  28. pixal3d/models/transformers/dense_dit.py +298 -0
  29. pixal3d/models/transformers/sparse_dit.py +469 -0
  30. pixal3d/modules/__pycache__/norm.cpython-310.pyc +0 -0
  31. pixal3d/modules/__pycache__/spatial.cpython-310.pyc +0 -0
  32. pixal3d/modules/__pycache__/utils.cpython-310.pyc +0 -0
  33. pixal3d/modules/attention/__init__.py +35 -0
  34. pixal3d/modules/attention/__pycache__/__init__.cpython-310.pyc +0 -0
  35. pixal3d/modules/attention/__pycache__/full_attn.cpython-310.pyc +0 -0
  36. pixal3d/modules/attention/__pycache__/modules.cpython-310.pyc +0 -0
  37. pixal3d/modules/attention/full_attn.py +140 -0
  38. pixal3d/modules/attention/modules.py +164 -0
  39. pixal3d/modules/norm.py +25 -0
  40. pixal3d/modules/sparse/__init__.py +105 -0
  41. pixal3d/modules/sparse/__pycache__/__init__.cpython-310.pyc +0 -0
  42. pixal3d/modules/sparse/__pycache__/basic.cpython-310.pyc +0 -0
  43. pixal3d/modules/sparse/__pycache__/linear.cpython-310.pyc +0 -0
  44. pixal3d/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc +0 -0
  45. pixal3d/modules/sparse/__pycache__/norm.cpython-310.pyc +0 -0
  46. pixal3d/modules/sparse/__pycache__/spatial.cpython-310.pyc +0 -0
  47. pixal3d/modules/sparse/attention/__init__.py +5 -0
  48. pixal3d/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc +0 -0
  49. pixal3d/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc +0 -0
  50. pixal3d/modules/sparse/attention/__pycache__/modules.cpython-310.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: Pixal3D D
3
- emoji: 👁
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.14.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Pixal3D-D
3
+ emoji: 🎨
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.29.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ extra_gated_eu_disallowed: true
12
  ---
 
 
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pixal3D Gradio App
3
+ Upload an image and generate a 3D mesh. Supports both automatic (MoGe) and fixed camera parameters.
4
+ """
5
+
6
+ import os
7
+ os.environ["no_proxy"] = os.environ.get("no_proxy", "") + ",localhost,127.0.0.1"
8
+
9
+ import torch
10
+ import tempfile
11
+ import numpy as np
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+
15
+ import gradio as gr
16
+
17
+ from pixal3dpipeline2stage import Pixal3DPipeline2Stage
18
+ from pixal3dpipeline import Pixal3DPipeline
19
+
20
+
21
+ import trimesh
22
+ from trimesh.visual.material import PBRMaterial
23
+ from trimesh.transformations import rotation_matrix
24
+ # Static files directory for model viewer
25
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
26
+ SAVE_DIR = os.path.join(CURRENT_DIR, "gradio_outputs")
27
+
28
+ # Global pipeline reference
29
+ pipeline = None
30
+ rmbg = None
31
+
32
+
33
def load_pipeline(ckpt_dir="./ckpt", repo_id="Pixal3D/Pixal3D"):
    """Populate the module-level ``pipeline`` and ``rmbg`` globals.

    First builds the two-stage Pixal3D pipeline via
    ``Pixal3DPipeline2Stage.from_pretrained`` (MoGe and dense check enabled),
    then loads the BiRefNet segmentation model, moves it to ``cuda:0`` and
    switches it to eval mode.

    Args:
        ckpt_dir: Forwarded to ``from_pretrained`` as ``ckpt_dir``.
        repo_id: Forwarded to ``from_pretrained`` as ``repo_id``.
    """
    global pipeline, rmbg

    print("Loading Pixal3D 2-Stage pipeline (with MoGe + dense_check)...")
    pipeline_options = {
        "ckpt_dir": ckpt_dir,
        "repo_id": repo_id,
        "use_moge": True,
        "use_dense_check": True,
    }
    pipeline = Pixal3DPipeline2Stage.from_pretrained(**pipeline_options)
    print("Pipeline loaded!")

    print("Loading BiRefNet for background removal...")
    # Imported lazily, only once the main pipeline has loaded successfully.
    from transformers import AutoModelForImageSegmentation

    segmenter = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet",
        trust_remote_code=True,
    )
    segmenter = segmenter.to("cuda:0")
    segmenter.eval()
    rmbg = segmenter
    print("BiRefNet loaded!")
53
+
54
+
55
def remove_background(image_np):
    """Predict a foreground mask with BiRefNet and attach it as alpha.

    The first three channels of ``image_np`` are converted to an RGB PIL
    image, resized to 1024x1024, normalised and run through the global
    ``rmbg`` model on ``cuda:0``. The sigmoid of the model's last output is
    resized back to the input size and appended as a fourth channel.

    Args:
        image_np: numpy array (H, W, C) with at least 3 channels, RGB first.

    Returns:
        numpy array (H, W, 4) RGBA.
    """
    rgb_image = Image.fromarray(image_np[:, :, :3]).convert('RGB')

    to_model_input = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    batch = to_model_input(rgb_image).unsqueeze(0).to("cuda:0")

    with torch.no_grad():
        # The model returns a sequence; only its last element is used.
        probabilities = rmbg(batch)[-1].sigmoid().cpu()

    alpha_small = transforms.ToPILImage()(probabilities[0].squeeze())
    alpha = np.array(alpha_small.resize(rgb_image.size))

    # Stack the (H, W) mask behind the (H, W, 3) colour channels.
    return np.dstack([np.array(rgb_image), alpha])
76
+
77
+
78
def preprocess_image(image, use_rmbg):
    """Step 1: produce the image that will be fed to 3D inference.

    Args:
        image: Input image as a numpy array, or ``None`` when nothing was
            uploaded.
        use_rmbg: ``True`` runs BiRefNet (``remove_background``) to produce an
            RGBA image; ``False`` passes the original image through unchanged.

    Returns:
        The processed numpy image, or ``None`` when there is no input image or
        the background-removal model has not been loaded.
    """
    if image is None:
        return None

    if use_rmbg:
        # Run background removal
        if rmbg is None:
            gr.Warning("Background removal model not loaded.")
            return None
        processed = remove_background(image)
    else:
        # Directly use original image, no background removal
        processed = image

    # Keep a debug copy next to app.py (SAVE_DIR) instead of a path relative
    # to whatever the current working directory happens to be.
    os.makedirs(SAVE_DIR, exist_ok=True)
    Image.fromarray(processed).save(os.path.join(SAVE_DIR, "processed.png"))
    return processed
100
+
101
+
102
def infer_mesh(
    processed,
    use_fixed_camera,
    camera_angle_x,
    mesh_scale,
    dense_steps,
    dense_guidance_scale,
    dense_seed,
    sparse_512_steps,
    sparse_512_guidance_scale,
    sparse_1024_steps,
    sparse_1024_guidance_scale,
    sparse_seed,
    dense_threshold,
    mc_threshold,
):
    """Step 2: run 3D inference on the already-processed image.

    The image is written to a temporary PNG, passed to the pipeline, and the
    resulting mesh is exported twice: as ``.ply`` (download) and as a grey,
    X-rotated ``.glb`` (viewer).

    Args:
        processed: Image array produced by ``preprocess_image`` (or ``None``).
        use_fixed_camera: ``True`` calls ``Pixal3DPipeline.infer`` with the
            given ``camera_angle_x``; ``False`` calls ``pipeline.infer`` with
            mesh-scale optimisation enabled.
        camera_angle_x: Camera angle, only used in fixed-camera mode.
        mesh_scale: Initial mesh scale passed to the pipeline.
        dense_steps, dense_guidance_scale, dense_seed: Dense-stage settings.
        sparse_512_steps, sparse_512_guidance_scale: Sparse 512 settings.
        sparse_1024_steps, sparse_1024_guidance_scale: Sparse 1024 settings.
        sparse_seed: Seed for the sparse stages.
        dense_threshold, mc_threshold: Thresholds forwarded to the pipeline.

    Returns:
        ``(glb_path, ply_path)`` on success, ``(None, None)`` when there is
        nothing to process, the pipeline is not loaded, or inference fails.
    """
    if processed is None or pipeline is None:
        return None, None

    input_path = None
    try:
        # Only the file name is needed; close the handle right away so the
        # image can be written (and later unlinked) without an open descriptor.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_input:
            input_path = tmp_input.name
        Image.fromarray(processed).save(input_path)

        # Settings shared by both camera modes. Built inside the ``try`` so a
        # bad numeric field (e.g. an empty Number box) is reported, not raised.
        common_kwargs = dict(
            image=input_path,
            mesh_scale=mesh_scale,
            dense_steps=int(dense_steps),
            dense_guidance_scale=dense_guidance_scale,
            dense_seed=int(dense_seed),
            sparse_512_steps=int(sparse_512_steps),
            sparse_512_guidance_scale=sparse_512_guidance_scale,
            sparse_1024_steps=int(sparse_1024_steps),
            sparse_1024_guidance_scale=sparse_1024_guidance_scale,
            sparse_seed=int(sparse_seed),
            dense_threshold=dense_threshold,
            mc_threshold=mc_threshold,
        )

        if use_fixed_camera:
            # Explicitly call Pixal3DPipeline.infer on the loaded pipeline
            # object instead of the 2-stage pipeline's own ``infer``.
            mesh = Pixal3DPipeline.infer(
                pipeline,
                camera_angle_x=camera_angle_x,
                **common_kwargs,
            )
        else:
            mesh = pipeline.infer(
                optimize_mesh_scale=True,
                target_padding=3,
                max_optim_iterations=2,
                **common_kwargs,
            )

        # Output files must outlive this call (Gradio serves them), hence
        # delete=False; the handles themselves are closed immediately.
        with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as ply_file:
            ply_path = ply_file.name
        with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as glb_file:
            glb_path = glb_file.name
        mesh.export(ply_path)

        # Export GLB with PBR material (same as hunyuan_app)
        material = PBRMaterial(baseColorFactor=[102, 102, 102, 255])
        clean_mesh = trimesh.Trimesh(mesh.vertices, mesh.faces)
        clean_mesh.visual = trimesh.visual.TextureVisuals(material=material)
        # Rotate mesh to desired view angle (only X rotation needed)
        rot_x = rotation_matrix(np.radians(-90), [1, 0, 0])
        clean_mesh.apply_transform(rot_x)
        clean_mesh.export(glb_path)

        return glb_path, ply_path

    except Exception as e:
        import traceback
        traceback.print_exc()
        # Still best-effort (empty outputs), but tell the user why.
        gr.Warning(f"3D generation failed: {e}")
        return None, None
    finally:
        if input_path is not None and os.path.exists(input_path):
            os.unlink(input_path)
188
+
189
+
190
def build_ui():
    """Build the Gradio Blocks interface and return it (without launching).

    Layout: a left column with the input image, processed-image preview,
    option checkboxes, camera / mesh-scale fields and advanced sampler
    settings; a right column with the 3D viewer and the ``.ply`` download.

    Returns:
        The ``gr.Blocks`` app with its queue configured (``api_open=False``).
    """
    # Custom CSS to hide the download button in Model3D
    # (selectors cover both the Chinese and English button labels).
    custom_css = """
    #model3d-viewer button[aria-label="下载"],
    #model3d-viewer button[aria-label="Download"],
    #model3d-viewer button[title="下载"],
    #model3d-viewer button[title="Download"] {
        display: none !important;
    }
    """

    with gr.Blocks(title="Pixal3D", theme=gr.themes.Soft(), css=custom_css) as demo:
        gr.Markdown("# Pixal3D: Pixel-Aligned 3D Generation from Images")

        with gr.Row():
            # Left column: input (scale=1)
            with gr.Column(scale=1):
                image_input = gr.Image(label="Input Image", type="numpy", image_mode=None)

                # Read-only preview of what preprocess_image returned; it is
                # also the component infer_mesh reads its input from.
                processed_image = gr.Image(
                    label="Processed Image",
                    image_mode="RGBA",
                    type="numpy",
                    interactive=False,
                )

                use_rmbg = gr.Checkbox(
                    label="Remove Background",
                    value=True,
                    info="Checked: auto remove background via BiRefNet. Unchecked: use original image directly.",
                )

                # NOTE(review): the info text mentions FOV/distance, but only
                # camera_angle_x is exposed below — confirm the wording.
                use_fixed_camera = gr.Checkbox(
                    label="Use Fixed Camera Parameters",
                    value=False,
                    info="If checked, use manually set FOV/distance/mesh_scale instead of MoGe auto-estimation.",
                )

                # Hidden until "Use Fixed Camera Parameters" is checked.
                with gr.Group(visible=False) as fixed_camera_group:
                    gr.Markdown("### Camera Parameters (fixed mode)")
                    camera_angle_x = gr.Number(value=0.2, label="camera_angle_x (rad)", step=0.01)

                with gr.Group():
                    gr.Markdown("### Mesh Scale")
                    mesh_scale = gr.Number(value=0.5, label="mesh_scale", step=0.01,
                                           info="Initial mesh scale. Fixed mode default: 0.9, Auto mode default: 0.5")

                with gr.Accordion("Advanced Inference Parameters", open=False):
                    dense_steps = gr.Number(value=50, label="Dense Steps", step=1, precision=0)
                    dense_guidance_scale = gr.Number(value=7.0, label="Dense Guidance Scale", step=0.1)
                    dense_seed = gr.Number(value=0, label="Dense Seed", step=1, precision=0)
                    sparse_512_steps = gr.Number(value=30, label="Sparse 512 Steps", step=1, precision=0)
                    sparse_512_guidance_scale = gr.Number(value=7.0, label="Sparse 512 Guidance Scale", step=0.1)
                    sparse_1024_steps = gr.Number(value=15, label="Sparse 1024 Steps", step=1, precision=0)
                    sparse_1024_guidance_scale = gr.Number(value=7.0, label="Sparse 1024 Guidance Scale", step=0.1)
                    sparse_seed = gr.Number(value=0, label="Sparse Seed", step=1, precision=0)
                    dense_threshold = gr.Number(value=0.1, label="Dense Threshold", step=0.01)
                    mc_threshold = gr.Number(value=0.2, label="MC Threshold", step=0.01)

                run_btn = gr.Button("Generate 3D Mesh", variant="primary", size="lg")

            # Right column: output (scale=2)
            with gr.Column(scale=2):
                model_viewer = gr.Model3D(label="3D Mesh Preview", interactive=False, clear_color=[1.0, 1.0, 1.0, 1.0], elem_id="model3d-viewer")
                output_file = gr.File(label="Download .ply")

        # Toggle fixed camera group visibility and mesh_scale default
        def on_toggle_fixed(use_fixed):
            # Switching modes resets mesh_scale to that mode's default,
            # overwriting any value the user typed.
            new_scale = 0.9 if use_fixed else 0.5
            return gr.update(visible=use_fixed), gr.update(value=new_scale)

        use_fixed_camera.change(
            fn=on_toggle_fixed,
            inputs=[use_fixed_camera],
            outputs=[fixed_camera_group, mesh_scale],
        )

        # Step 1: preprocess image → show processed image immediately
        # Step 2: run 3D inference → show mesh and download
        run_btn.click(
            fn=preprocess_image,
            inputs=[image_input, use_rmbg],
            outputs=[processed_image],
        ).then(
            fn=infer_mesh,
            # Order must match infer_mesh's positional parameters.
            inputs=[
                processed_image,
                use_fixed_camera,
                camera_angle_x,
                mesh_scale,
                dense_steps,
                dense_guidance_scale,
                dense_seed,
                sparse_512_steps,
                sparse_512_guidance_scale,
                sparse_1024_steps,
                sparse_1024_guidance_scale,
                sparse_seed,
                dense_threshold,
                mc_threshold,
            ],
            outputs=[model_viewer, output_file],
        )

    demo.queue(api_open=False)
    return demo
296
+
297
+
298
# Script entry point: parse CLI options, load all models, then serve the UI.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # NOTE(review): this default differs from load_pipeline's own default
    # repo_id ("Pixal3D/Pixal3D") — confirm which repository is intended.
    parser.add_argument("--repo_id", type=str, default="TencentARC/Pixal3D-D")
    args = parser.parse_args()

    # Models are loaded before the UI is built so the first request is not
    # delayed by weight loading.
    load_pipeline(repo_id=args.repo_id)

    demo = build_ui()
    # NOTE(review): binds to loopback only and requests a public share link;
    # presumably meant for local runs — verify this is what a hosted Space
    # needs (hosted deployments usually bind 0.0.0.0 and ignore share).
    demo.launch(
        server_name="127.0.0.1",
        share=True,
    )
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libsparsehash-dev
pixal3d/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
# Registry of classes made available under a string name via @register.
__modules__ = {}


def register(name):
    """Return a class decorator that stores the class under ``name``.

    Registering the same name twice silently replaces the earlier entry.
    """
    def decorator(cls):
        # Allow re-registration for checkpoint loading compatibility
        # When torch.load triggers module re-import, the same class may be registered again
        __modules__[name] = cls
        return cls

    return decorator


def find(name):
    """Resolve ``name`` to a registered class or an importable attribute.

    Lookup order:
      1. an entry created with ``@register(name)``;
      2. otherwise ``name`` is treated as ``"package.module.Attr"``: the part
         before the last dot is imported and the last part is fetched from it.

    Raises:
        ValueError: if neither lookup succeeds; the underlying import or
            attribute error is attached as ``__cause__``.
    """
    if name in __modules__:
        return __modules__[name]

    # "a.b.C" -> ("a.b", ".", "C"); a name without a dot gives an empty module
    # string, which import_module rejects and is reported as ValueError below.
    module_string, _, cls_name = name.rpartition(".")
    try:
        module = importlib.import_module(module_string, package=None)
        return getattr(module, cls_name)
    except Exception as e:
        raise ValueError(f"Module {name} not found!") from e
27
+
28
+
29
### syntactic sugar for logging utilities ###
import logging

logger = logging.getLogger("pixal3d")


def debug(*args, **kwargs):
    """Log at DEBUG level on the ``pixal3d`` logger.

    Arguments are forwarded to ``logger.debug``. ``stacklevel`` is bumped by
    one so the record's file/line/function point at the caller of this
    helper rather than at the helper itself.
    """
    kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
    logger.debug(*args, **kwargs)


def info(*args, **kwargs):
    """Log at INFO level on the ``pixal3d`` logger (caller-attributed)."""
    kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
    logger.info(*args, **kwargs)


def warn(*args, **kwargs):
    """Log at WARNING level on the ``pixal3d`` logger (caller-attributed)."""
    kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
    logger.warning(*args, **kwargs)
pixal3d/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.25 kB). View file
 
pixal3d/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import conditional_encoders, transformers
pixal3d/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (241 Bytes). View file
 
pixal3d/models/autoencoders/__pycache__/base.cpython-310.pyc ADDED
Binary file (4.39 kB). View file
 
pixal3d/models/autoencoders/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (8.77 kB). View file
 
pixal3d/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc ADDED
Binary file (11.7 kB). View file
 
pixal3d/models/autoencoders/__pycache__/distributions.cpython-310.pyc ADDED
Binary file (2.09 kB). View file
 
pixal3d/models/autoencoders/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (3.77 kB). View file
 
pixal3d/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc ADDED
Binary file (3.86 kB). View file
 
pixal3d/models/autoencoders/base.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32
6
+ from ...modules import sparse as sp
7
+ from ...modules.transformer import AbsolutePositionEmbedder
8
+ from ...modules.sparse.transformer import SparseTransformerBlock
9
+
10
+
11
def block_attn_config(self):
    """
    Yield one attention configuration tuple per transformer block.

    Each item is ``(attn_mode, window_size, shift_sequence, shift_window,
    serialize_mode)``, derived from ``self.attn_mode``, ``self.window_size``
    and the block index (alternating blocks get alternating shifts/orders).

    Raises:
        ValueError: if ``self.attn_mode`` is not one of the supported modes.
            Previously an unknown mode yielded nothing, which silently built
            a model with zero transformer blocks.
    """
    supported = ("shift_window", "shift_sequence", "shift_order", "full", "swin")
    if self.attn_mode not in supported:
        raise ValueError(f"Invalid attention mode {self.attn_mode}")

    for i in range(self.num_blocks):
        if self.attn_mode == "shift_window":
            # Odd blocks shift the window by 16 along all three axes.
            yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_sequence":
            # Odd blocks shift the sequence by half a window.
            yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_order":
            # Cycle through four serialization orders.
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
        elif self.attn_mode == "full":
            yield "full", None, None, None, None
        elif self.attn_mode == "swin":
            # Odd blocks use a window shifted by half the window size.
            yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
26
+
27
+
28
class SparseTransformerBase(nn.Module):
    """
    Sparse Transformer without output layers.
    Serve as the base class for encoder and decoder.

    The model is an input ``SparseLinear`` projection, an optional absolute
    position embedding (``pe_mode == "ape"``) and ``num_blocks``
    ``SparseTransformerBlock`` layers configured by ``block_attn_config``.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.window_size = window_size
        # Explicit head count wins; otherwise derive it from the per-head width.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.attn_mode = attn_mode
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.qk_rms_norm = qk_rms_norm
        # dtype the hidden features are cast to in forward().
        self.dtype = torch.float16 if use_fp16 else torch.float32

        # The absolute position embedder only exists in "ape" mode; in "rope"
        # mode the blocks are built with use_rope=True instead.
        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)
        self.input_layer = sp.SparseLinear(in_channels, model_channels)
        # NOTE: the comprehension variables attn_mode / window_size below are
        # the per-block values from block_attn_config and shadow the
        # constructor arguments of the same name inside the comprehension.
        self.blocks = nn.ModuleList([
            SparseTransformerBlock(
                model_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode=attn_mode,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
            )
            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
        ])

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.

        Taken from the first parameter returned by ``self.parameters()``.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the model to float16.

        NOTE(review): this applies the conversion to the whole module
        (``self.apply``), whereas ``convert_to_fp32`` only touches
        ``self.blocks`` — confirm the asymmetry is intended.
        """
        # self.blocks.apply(convert_module_to_f16)
        self.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.

        Only ``self.blocks`` is converted.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-uniform init for every ``nn.Linear`` weight; zero its bias."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: sp.SparseTensor, factor: Optional[float] = None) -> sp.SparseTensor:
        """
        Project the input, add the position embedding (in "ape" mode), cast
        to ``self.dtype`` and run all transformer blocks.

        Args:
            x: Input sparse tensor; ``x.coords[:, 1:]`` is passed to the
                position embedder (column 0 is presumably the batch index —
                TODO confirm).
            factor: Passed through to the position embedder unchanged.

        Returns:
            The hidden sparse tensor after the last block.
        """
        h = self.input_layer(x)
        if self.pe_mode == "ape":
            h = h + self.pos_embedder(x.coords[:, 1:], factor)
        h = h.type(self.dtype)
        for block in self.blocks:
            h = block(h)
        return h
pixal3d/models/autoencoders/decoder.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7
+ from ...modules import sparse as sp
8
+ from .base import SparseTransformerBase
9
+
10
+
11
class SparseSubdivideBlock3d(nn.Module):
    """
    Sparse conv + SiLU, then ``SparseSubdivide``, then a second sparse
    conv + SiLU.

    Args:
        channels: Number of input feature channels.
        out_channels: Number of output channels; defaults to ``channels``.
        use_checkpoint: Run ``_forward`` through ``torch.utils.checkpoint``.
    """

    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint

        # Channel change happens before subdivision (channels -> out_channels).
        self.act_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
            sp.SparseSiLU()
        )

        self.sub = sp.SparseSubdivide()

        # Refinement at the subdivided resolution, channel count unchanged.
        self.out_layers = nn.Sequential(
            sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
            sp.SparseSiLU(),
        )

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Apply conv/activation, subdivision, then conv/activation."""
        h = self.act_layers(x)
        h = self.sub(h)
        h = self.out_layers(h)
        return h

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Run ``_forward``, optionally under gradient checkpointing."""
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
47
+
48
+
49
class SparseSDFDecoder(SparseTransformerBase):
    """
    Sparse transformer decoder followed by three subdivision blocks and a
    tanh-activated linear output layer.

    With ``chunk_size <= 1`` the whole volume is upsampled at once. With a
    larger ``chunk_size`` the transformer output is cut into padded
    sub-volumes that are upsampled separately: one random chunk per call in
    training mode, every chunk (then cropped and concatenated) in eval mode.

    Args:
        resolution: Side length of the latent grid used for chunking.
        model_channels: Transformer width; the subdivision blocks reduce it
            to ``model_channels // 16`` before the output layer.
        latent_channels: Channel count of the input latent.
        num_blocks, num_heads, num_head_channels, mlp_ratio, attn_mode,
        window_size, pe_mode, use_fp16, use_checkpoint, qk_rms_norm:
            Forwarded to ``SparseTransformerBase``.
        representation_config: Stored as ``self.rep_config``.
        out_channels: Channels produced by the output layer.
        chunk_size: Number of chunks per axis (``<= 1`` disables chunking).
    """

    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
        out_channels: int = 1,
        chunk_size: int = 1,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self.out_channels = out_channels
        self.chunk_size = chunk_size
        # Three subdivision stages: C -> C/4 -> C/8 -> C/16.
        self.upsample = nn.ModuleList([
            SparseSubdivideBlock3d(
                channels=model_channels,
                out_channels=model_channels // 4,
                use_checkpoint=use_checkpoint,
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                out_channels=model_channels // 8,
                use_checkpoint=use_checkpoint,
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 8,
                out_channels=model_channels // 16,
                use_checkpoint=use_checkpoint,
            )
        ])

        self.out_layer = sp.SparseLinear(model_channels // 16, self.out_channels)
        self.out_active = sp.SparseTanh()

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        """Run the base initialisation, then zero the output layer."""
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def convert_to_fp16(self) -> None:
        """
        Convert the model (including the subdivision blocks) to float16.
        """
        super().convert_to_fp16()
        self.upsample.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the transformer blocks and subdivision blocks to float32.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)

    @staticmethod
    def _box_mask(coords: torch.Tensor, box) -> torch.Tensor:
        """Mask rows of ``coords`` whose columns 1..3 lie in the half-open box.

        ``box`` is ``(x0, x1, y0, y1, z0, z1)``; a row is kept when
        ``x0 <= coords[:, 1] < x1`` and likewise for columns 2 and 3.
        """
        x0, x1, y0, y1, z0, z1 = box
        in_x = torch.logical_and(coords[:, 1] >= x0, coords[:, 1] < x1)
        in_y = torch.logical_and(coords[:, 2] >= y0, coords[:, 2] < y1)
        in_z = torch.logical_and(coords[:, 3] >= z0, coords[:, 3] < z1)
        return torch.logical_and(torch.logical_and(in_x, in_y), in_z)

    @staticmethod
    def _scaled_bounds(box, ratio: int):
        """Scale an unpadded latent-space box to the upsampled grid.

        Start values are multiplied by ``ratio``; end values are multiplied
        by ``ratio`` and extended by ``ratio - 1``.
        NOTE(review): with the exclusive ``<`` test used for cropping, the
        ``ratio - 1`` extension makes neighbouring chunks overlap at their
        upper faces — confirm this is intended.
        """
        x0, x1, y0, y1, z0, z1 = box
        tail = ratio - 1
        return (
            x0 * ratio, x1 * ratio + tail,
            y0 * ratio, y1 * ratio + tail,
            z0 * ratio, z1 * ratio + tail,
        )

    @torch.no_grad()
    def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4):
        """Cut ``x`` into ``chunk_size**3`` padded sub-volumes.

        Empty sub-volumes are dropped. Each returned chunk has coordinates
        shifted to its own origin and a ``bounds`` dict with the upsampled
        crop box (``'original'``) and the upsampled offset (``'offsets'``)
        needed to place it back in the global grid.
        """
        sub_resolution = self.resolution // chunk_size
        upsample_ratio = 8  # hard-coded here
        assert sub_resolution % padding == 0
        out = []

        for i in range(chunk_size):
            for j in range(chunk_size):
                for k in range(chunk_size):
                    # Calculate padded boundaries
                    start_x = max(0, i * sub_resolution - padding)
                    end_x = min((i + 1) * sub_resolution + padding, self.resolution)
                    start_y = max(0, j * sub_resolution - padding)
                    end_y = min((j + 1) * sub_resolution + padding, self.resolution)
                    start_z = max(0, k * sub_resolution - padding)
                    end_z = min((k + 1) * sub_resolution + padding, self.resolution)

                    # Original (unpadded) boundaries for later cropping
                    unpadded = (
                        i * sub_resolution, (i + 1) * sub_resolution,
                        j * sub_resolution, (j + 1) * sub_resolution,
                        k * sub_resolution, (k + 1) * sub_resolution,
                    )

                    mask = self._box_mask(
                        x.coords, (start_x, end_x, start_y, end_y, start_z, end_z)
                    )

                    if mask.sum() > 0:
                        # Get the coordinates and shift them to local space
                        coords = x.coords[mask].clone()
                        coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
                                                                     device=coords.device).view(1, 3)

                        chunk_tensor = sp.SparseTensor(x.feats[mask], coords)
                        # Store the boundaries and offsets as metadata for later reconstruction
                        chunk_tensor.bounds = {
                            'original': self._scaled_bounds(unpadded, upsample_ratio),
                            'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio)
                        }
                        out.append(chunk_tensor)

                    del mask
                    torch.cuda.empty_cache()
        return out

    @torch.no_grad()
    def split_single_chunk(self, x: sp.SparseTensor, chunk_size=4, padding=4):
        """Pick one random non-empty sub-volume and return its bounds.

        Random origins are drawn until the unpadded box contains at least one
        coordinate of ``x``.
        NOTE(review): if ``x`` has no coordinates inside the grid this loop
        never terminates — confirm callers guarantee a non-empty input.

        Returns:
            dict with ``'original'`` (upsampled crop box), ``'start'``
            (padded latent-space box) and ``'offsets'`` (upsampled offset).
        """
        sub_resolution = self.resolution // chunk_size
        upsample_ratio = 8  # hard-coded here
        assert sub_resolution % padding == 0

        mask_sum = -1
        while mask_sum < 1:
            orig_start_x = random.randint(0, self.resolution - sub_resolution)
            orig_end_x = orig_start_x + sub_resolution
            orig_start_y = random.randint(0, self.resolution - sub_resolution)
            orig_end_y = orig_start_y + sub_resolution
            orig_start_z = random.randint(0, self.resolution - sub_resolution)
            orig_end_z = orig_start_z + sub_resolution
            start_x = max(0, orig_start_x - padding)
            end_x = min(orig_end_x + padding, self.resolution)
            start_y = max(0, orig_start_y - padding)
            end_y = min(orig_end_y + padding, self.resolution)
            start_z = max(0, orig_start_z - padding)
            end_z = min(orig_end_z + padding, self.resolution)

            unpadded = (orig_start_x, orig_end_x, orig_start_y, orig_end_y, orig_start_z, orig_end_z)
            mask_sum = self._box_mask(x.coords, unpadded).sum()

        # Store the boundaries and offsets as metadata for later reconstruction
        bounds = {
            'original': self._scaled_bounds(unpadded, upsample_ratio),
            'start': (start_x, end_x, start_y, end_y, start_z, end_z),
            'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio)
        }
        return bounds

    def forward_single_chunk(self, x: sp.SparseTensor, padding=4):
        """Upsample one random padded chunk of ``x`` and crop it to its box.

        Used by ``forward`` in training mode when chunking is enabled.
        """
        bounds = self.split_single_chunk(x, self.chunk_size, padding=padding)

        start_x, end_x, start_y, end_y, start_z, end_z = bounds['start']
        mask = self._box_mask(x.coords, bounds['start'])

        # Shift to local coordinates
        coords = x.coords.clone()
        coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
                                                     device=coords.device).view(1, 3)

        chunk = sp.SparseTensor(x.feats[mask], coords[mask])

        chunk_result = self.upsamples(chunk)

        coords = chunk_result.coords.clone()

        # Restore global coordinates
        offsets = torch.tensor(bounds['offsets'],
                               device=coords.device).view(1, 3)
        coords[:, 1:] = coords[:, 1:] + offsets

        # Filter points within original bounds
        within_bounds = self._box_mask(coords, bounds['original'])

        final_coords = coords[within_bounds]
        final_feats = chunk_result.feats[within_bounds]

        return sp.SparseTensor(final_feats, final_coords)

    def upsamples(self, x, return_feat: bool = False):
        """Run the subdivision blocks and the output head on ``x``.

        Features are cast back to the dtype ``x`` had on entry before the
        output layer. With ``return_feat=True`` the pre-output features are
        returned alongside the activated output.
        """
        dtype = x.dtype
        for block in self.upsample:
            x = block(x)
        x = x.type(dtype)

        output = self.out_active(self.out_layer(x))

        if return_feat:
            return output, x
        else:
            return output

    def forward(self, x: sp.SparseTensor, factor: float = None, return_feat: bool = False):
        """Decode the latent ``x`` into an upsampled sparse tensor.

        Args:
            x: Input sparse latent.
            factor: Passed to ``SparseTransformerBase.forward``.
            return_feat: Only honoured when ``chunk_size <= 1``; also returns
                the features fed to the output layer.

        Returns:
            The activated output (plus features if requested and unchunked).
        """
        h = super().forward(x, factor)
        if self.chunk_size <= 1:
            for block in self.upsample:
                h = block(h)
            # Cast back to the *input* dtype before the output head.
            h = h.type(x.dtype)

            if return_feat:
                return self.out_active(self.out_layer(h)), h

            h = self.out_layer(h)
            h = self.out_active(h)
            return h

        if self.training:
            return self.forward_single_chunk(h)

        # Eval mode: upsample every chunk, crop it to its unpadded box and
        # stitch the pieces back together in global coordinates.
        batch_size = x.shape[0]
        chunks = self.split_for_meshing(h, chunk_size=self.chunk_size)
        all_coords, all_feats = [], []
        for chunk in chunks:
            chunk_result = self.upsamples(chunk)

            for b in range(batch_size):
                mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1)
                if mask.numel() > 0:
                    coords = chunk_result.coords[mask].clone()

                    # Restore global coordinates
                    offsets = torch.tensor(chunk.bounds['offsets'],
                                           device=coords.device).view(1, 3)
                    coords[:, 1:] = coords[:, 1:] + offsets

                    # Filter points within original bounds
                    within_bounds = self._box_mask(coords, chunk.bounds['original'])

                    if within_bounds.any():
                        all_coords.append(coords[within_bounds])
                        all_feats.append(chunk_result.feats[mask][within_bounds])

            # Release cached blocks between chunks to limit peak memory.
            torch.cuda.empty_cache()

        # NOTE(review): torch.cat raises if no chunk produced any point.
        final_coords = torch.cat(all_coords)
        final_feats = torch.cat(all_feats)

        return sp.SparseTensor(final_feats, final_coords)
353
+
pixal3d/models/autoencoders/dense_vae.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import trimesh
6
+ from skimage import measure
7
+ from ...modules.norm import GroupNorm32, ChannelLayerNorm32
8
+ from ...modules.spatial import pixel_shuffle_3d
9
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
10
+ from .distributions import DiagonalGaussianDistribution
11
+
12
+
13
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
    """
    Build the normalization layer selected by ``norm_type``.

    ``"group"`` yields ``GroupNorm32`` with 32 groups, ``"layer"`` yields
    ``ChannelLayerNorm32``; the remaining arguments are forwarded to the
    chosen constructor.

    Raises:
        ValueError: if ``norm_type`` is neither ``"group"`` nor ``"layer"``.
    """
    if norm_type == "layer":
        return ChannelLayerNorm32(*args, **kwargs)
    if norm_type == "group":
        return GroupNorm32(32, *args, **kwargs)
    raise ValueError(f"Invalid norm type {norm_type}")
23
+
24
+
25
class ResBlock3d(nn.Module):
    """
    Dense 3D residual block: (norm -> SiLU -> 3x3x3 conv) twice, plus a skip path.

    The skip path is the identity when the channel count is preserved and a
    1x1x1 convolution otherwise. The second convolution is wrapped in
    ``zero_module`` (presumably zero-initialised -- confirm in modules.utils).
    """

    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        width = self.out_channels

        self.norm1 = norm_layer(norm_type, channels)
        self.norm2 = norm_layer(norm_type, width)
        self.conv1 = nn.Conv3d(channels, width, 3, padding=1)
        self.conv2 = zero_module(nn.Conv3d(width, width, 3, padding=1))
        if channels != width:
            self.skip_connection = nn.Conv3d(channels, width, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the residual branch added to the (possibly projected) input."""
        branch = self.conv1(F.silu(self.norm1(x)))
        branch = self.conv2(F.silu(self.norm2(branch)))
        return branch + self.skip_connection(x)
51
+
52
+
53
class DownsampleBlock3d(nn.Module):
    """
    Halve the spatial resolution of a dense 3D feature volume.

    ``mode="conv"`` uses a learned 2x2x2 convolution with stride 2 (and may
    change the channel count); ``mode="avgpool"`` uses parameter-free average
    pooling and therefore requires ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "avgpool":
            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
        elif mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` by 2 along each spatial axis."""
        # The presence of ``conv`` is what distinguishes the two modes.
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.avg_pool3d(x, 2)
        return conv(x)
76
+
77
+
78
class UpsampleBlock3d(nn.Module):
    """
    Double the spatial resolution of a dense 3D feature volume.

    ``mode="conv"`` applies a 3x3x3 convolution producing ``out_channels * 8``
    channels and rearranges them with ``pixel_shuffle_3d(x, 2)``;
    ``mode="nearest"`` uses nearest-neighbour interpolation and therefore
    requires ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "nearest"] = "conv",
    ):
        assert mode in ["conv", "nearest"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "nearest":
            assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
        elif mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` by 2 along each spatial axis."""
        # The presence of ``conv`` is what distinguishes the two modes.
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.interpolate(x, scale_factor=2, mode="nearest")
        expanded = conv(x)
        return pixel_shuffle_3d(expanded, 2)
102
+
103
+
104
class SparseStructureEncoder(nn.Module):
    r"""
    Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).

    Args:
        in_channels (int): Channels of the input.
        latent_channels (int): Channels of the latent representation. The
            output has ``2 * latent_channels`` channels.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the encoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer
            used in the output head.
        use_fp16 (bool): Whether to run the torso in FP16.
        use_checkpoint (bool): Stored on the instance; not used by this class.
    """
    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.use_checkpoint = use_checkpoint

        self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)

        # NOTE(review): the residual blocks are built with ResBlock3d's default
        # norm_type; ``norm_type`` only reaches ``out_layer``. Left unchanged
        # so existing checkpoints keep their behaviour.
        self.blocks = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    DownsampleBlock3d(ch, channels[i+1])
                )

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])

        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
        )

        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode a dense volume.

        Returns the raw ``out_layer`` activations with ``2 * latent_channels``
        channels; splitting them into mean / log-variance is left to the caller.
        """
        h = self.input_layer(x)

        # Only the torso (blocks + middle_block) is converted by
        # convert_to_fp16, while input_layer / out_layer keep their dtype.
        # Cast the activations at both boundaries so the mixed-precision path
        # does not hit a dtype mismatch. Assumes convert_module_to_f16 casts
        # the torso weights to half, as its name suggests -- TODO confirm.
        if self.use_fp16:
            h = h.type(self.dtype)

        for block in self.blocks:
            h = block(h)
        h = self.middle_block(h)

        if self.use_fp16:
            h = h.type(x.dtype)

        h = self.out_layer(h)

        return h
201
+
202
+
203
class SparseStructureDecoder(nn.Module):
    r"""
    Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).

    Args:
        out_channels (int): Channels of the output.
        latent_channels (int): Channels of the latent representation.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the decoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer
            used in the output head.
        use_fp16 (bool): Whether to use FP16.
        use_checkpoint (bool): Stored on the instance; not used by this class.
    """
    def __init__(
        self,
        out_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.use_checkpoint = use_checkpoint

        self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[0], channels[0])
            for _ in range(num_res_blocks_middle)
        ])

        # NOTE(review): the residual blocks are built with ResBlock3d's default
        # norm_type; ``norm_type`` only reaches ``out_layer``.
        self.blocks = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    UpsampleBlock3d(ch, channels[i+1])
                )

        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], out_channels, 3, padding=1)
        )

        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the whole model (stem, torso and head) to float16.

        Unlike the encoder, the conversion is applied to every sub-module, so
        ``forward`` performs no casts and callers must supply inputs of the
        matching dtype.
        """
        self.use_fp16 = True
        self.dtype = torch.float16
        self.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the whole model back to float32.

        Mirrors ``convert_to_fp16``: converting only ``blocks`` and
        ``middle_block`` would leave ``input_layer`` / ``out_layer`` in half
        precision after a round trip.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.apply(convert_module_to_f32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a latent volume through stem, middle blocks, torso and head."""
        h = self.input_layer(x)

        h = self.middle_block(h)
        for block in self.blocks:
            h = block(h)

        h = self.out_layer(h)
        return h
300
+
301
+
302
class DenseShapeVAE(nn.Module):
    """
    Dense-volume VAE built from ``SparseStructureEncoder`` / ``SparseStructureDecoder``.

    ``latents_scale`` and ``latents_shift`` are stored on the instance but not
    used by any method of this class.
    """

    def __init__(self,
                 embed_dim: int = 0,
                 model_channels_encoder: list = [32, 128, 512],
                 model_channels_decoder: list = [512, 128, 32],
                 num_res_blocks_encoder: int = 2,
                 num_res_blocks_middle_encoder: int = 2,
                 num_res_blocks_decoder: int = 2,
                 num_res_blocks_middle_decoder: int=2,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 use_fp16: bool = False,
                 use_checkpoint: bool = False,
                 latents_scale: float = 1.0,
                 latents_shift: float = 0.0):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.latents_scale = latents_scale
        self.latents_shift = latents_shift

        # The sub-modules keep a reference to the channel lists. Pass copies so
        # that the shared default lists in the signature can never be mutated
        # through one instance and leak into the next.
        self.encoder = SparseStructureEncoder(
            in_channels=in_channels,
            latent_channels=embed_dim,
            num_res_blocks=num_res_blocks_encoder,
            channels=list(model_channels_encoder),
            num_res_blocks_middle=num_res_blocks_middle_encoder,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
        )

        self.decoder = SparseStructureDecoder(
            num_res_blocks=num_res_blocks_decoder,
            num_res_blocks_middle=num_res_blocks_middle_decoder,
            channels=list(model_channels_decoder),
            latent_channels=embed_dim,
            out_channels=out_channels,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
        )

        self.embed_dim = embed_dim

    def encode(self, batch, sample_posterior: bool = True):
        """
        Encode ``batch['dense_index']`` into a latent.

        The input is affinely remapped with ``* 2.0 - 1.0`` before encoding
        (presumably from {0, 1} occupancy to {-1, 1} -- confirm with the
        data loader).

        Returns:
            ``(z, posterior)`` where ``z`` is a sample of the posterior when
            ``sample_posterior`` is true and its mode otherwise.
        """
        x = batch['dense_index'] * 2.0 - 1.0
        h = self.encoder(x)
        posterior = DiagonalGaussianDistribution(h, feat_dim=1)
        z = posterior.sample() if sample_posterior else posterior.mode()

        return z, posterior

    def forward(self, batch):
        """Encode with posterior sampling, decode, and return both results in a dict."""
        z, posterior = self.encode(batch)
        reconst_x = self.decoder(z)
        outputs = {'reconst_x': reconst_x, 'posterior': posterior}

        return outputs

    def decode_mesh(self,
                    latents,
                    voxel_resolution: int = 64,
                    mc_threshold: float = 0.5,
                    return_index: bool = False):
        """
        Decode latents into occupied-voxel indices or meshes.

        Args:
            latents: Input passed to the decoder.
            voxel_resolution: Forwarded to ``dense2mesh`` (unused when
                ``return_index`` is true).
            mc_threshold: Occupancy threshold on the sigmoid output when
                ``return_index`` is true; otherwise forwarded to ``dense2mesh``.
            return_index: If true, return one ``nonzero()`` index tensor per
                batch item instead of meshes.
        """
        x = self.decoder(latents)
        if return_index:
            outputs = []
            for logits in x:
                occ = logits.sigmoid()
                occ = (occ >= mc_threshold).float().squeeze(0)
                index = occ.unsqueeze(0).nonzero()
                outputs.append(index)
        else:
            outputs = self.dense2mesh(x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold)

        return outputs

    def dense2mesh(self,
                   x: torch.FloatTensor,
                   voxel_resolution: int = 64,
                   mc_threshold: float = 0.5):
        """
        Run marching cubes on each decoded volume and return ``trimesh.Trimesh`` objects.

        Vertices are rescaled with ``v / voxel_resolution * 2 - 1``.

        NOTE(review): the volume is binarised at a hard-coded 0.1 before
        marching cubes, so ``mc_threshold`` only selects the iso-level inside
        a {0, 1} volume -- confirm this is intended.
        """
        meshes = []
        for logits in x:
            occ = logits.sigmoid()
            occ = (occ >= 0.1).float().squeeze(0).cpu().detach().numpy()
            vertices, faces, _, _ = measure.marching_cubes(
                occ,
                mc_threshold,
                method="lewiner",
            )
            vertices = vertices / voxel_resolution * 2 - 1
            meshes.append(trimesh.Trimesh(vertices, faces))

        return meshes
pixal3d/models/autoencoders/distributions.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import Union, List
4
+
5
+
6
class DiagonalGaussianDistribution:
    """
    Diagonal Gaussian parameterised by a mean and a log-variance.

    Args:
        parameters: Either a list ``[mean, logvar]`` or a single tensor that is
            split into two equal halves (mean first) along ``feat_dim``.
        deterministic: If true, variance and standard deviation are zero, so
            ``sample()`` returns the mean and ``kl`` / ``nll`` return zero.
        feat_dim: Dimension along which a single parameter tensor is split.

    The log-variance is clamped to ``[-30, 20]`` before use.
    """

    def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
        self.feat_dim = feat_dim
        self.parameters = parameters

        if isinstance(parameters, list):
            self.mean = parameters[0]
            self.logvar = parameters[1]
        else:
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def _zero(self) -> torch.Tensor:
        """Return a one-element zero tensor on the same device as the mean."""
        # A CPU tensor here cannot be combined with CUDA losses downstream.
        return torch.zeros(1, device=self.mean.device)

    def sample(self):
        """Draw one reparameterised sample ``mean + std * eps``."""
        x = self.mean + self.std * torch.randn_like(self.mean)
        return x

    def kl(self, other=None, dims=(1, 2, 3)):
        """
        KL divergence to ``other`` (or to a standard normal when ``other`` is None).

        The element-wise terms are averaged (not summed) over ``dims``.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.mean(torch.pow(self.mean, 2)
                                    + self.var - 1.0 - self.logvar,
                                    dim=dims)
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=dims)

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
pixal3d/models/autoencoders/encoder.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ...modules import sparse as sp
6
+ from .base import SparseTransformerBase
7
+
8
+
9
class SparseDownBlock3d(nn.Module):
    """
    Sparse residual block that downsamples by a factor of 2.

    Main path: group-norm -> SiLU -> downsample -> conv -> group-norm -> SiLU -> conv.
    Skip path: downsample, followed by a 1x1x1 sparse convolution when the
    channel count changes. Optionally wrapped in activation checkpointing.
    """

    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        width = self.out_channels

        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU()
        )

        self.down = sp.SparseDownsample(2)
        self.out_layers = nn.Sequential(
            sp.SparseConv3d(channels, width, 3, padding=1),
            sp.SparseGroupNorm32(num_groups, width),
            sp.SparseSiLU(),
            sp.SparseConv3d(width, width, 3, padding=1),
        )

        self.skip_connection = (
            sp.SparseConv3d(channels, width, 1) if width != channels else nn.Identity()
        )

        self.use_checkpoint = use_checkpoint

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Un-checkpointed forward pass."""
        # Both paths go through the same downsample module, main path first.
        main = self.down(self.act_layers(x))
        shortcut = self.down(x)
        main = self.out_layers(main)
        return main + self.skip_connection(shortcut)

    def forward(self, x: torch.Tensor):
        """Run ``_forward``, under ``torch.utils.checkpoint`` when enabled."""
        if not self.use_checkpoint:
            return self._forward(x)
        return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
55
+
56
+
57
class SparseSDFEncoder(SparseTransformerBase):
    """
    Sparse SDF encoder: a linear stem, three sparse downsampling blocks and the
    ``SparseTransformerBase`` torso, followed by a zero-initialised linear head
    that outputs ``2 * latent_channels`` features per voxel.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )

        # Stem lifts the single input feature to model_channels // 16.
        self.input_layer1 = sp.SparseLinear(1, model_channels // 16)

        # Channel width at the entry of each stage, plus the final width.
        stage_widths = [
            model_channels // 16,
            model_channels // 8,
            model_channels // 4,
            model_channels,
        ]
        self.downsample = nn.ModuleList([
            SparseDownBlock3d(
                channels=width_in,
                out_channels=width_out,
                use_checkpoint=use_checkpoint,
            )
            for width_in, width_out in zip(stage_widths[:-1], stage_widths[1:])
        ])

        self.resolution = resolution
        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        """Run the parent initialisation, then zero the output head."""
        super().initialize_weights()
        # Zero-out output layers:
        for param in (self.out_layer.weight, self.out_layer.bias):
            nn.init.constant_(param, 0)

    def forward(self, x: sp.SparseTensor, factor: float = None):
        """
        Encode a sparse tensor.

        ``factor`` is passed through unchanged to the parent ``forward``. The
        torso output is cast back to the dtype of its input, layer-normalised
        over the feature dimension (no affine parameters) and projected by
        ``out_layer``.
        """
        x = self.input_layer1(x)
        for stage in self.downsample:
            x = stage(x)

        h = super().forward(x, factor)
        h = h.type(x.dtype)
        normed_feats = F.layer_norm(h.feats, h.feats.shape[-1:])
        h = h.replace(normed_feats)
        return self.out_layer(h)
pixal3d/models/autoencoders/ss_vae.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import trimesh
7
+ from skimage import measure
8
+
9
+ from ...modules import sparse as sp
10
+ from .encoder import SparseSDFEncoder
11
+ from .decoder import SparseSDFDecoder
12
+ from .distributions import DiagonalGaussianDistribution
13
+
14
+
15
class SparseSDFVAE(nn.Module):
    """
    Sparse SDF VAE built from ``SparseSDFEncoder`` / ``SparseSDFDecoder``.

    All constructor arguments are keyword-only. ``latents_scale`` and
    ``latents_shift`` are stored on the instance but not used by any method of
    this class.
    """

    def __init__(self, *,
                 embed_dim: int = 0,
                 resolution: int = 64,
                 model_channels_encoder: int = 512,
                 num_blocks_encoder: int = 4,
                 num_heads_encoder: int = 8,
                 num_head_channels_encoder: int = 64,
                 model_channels_decoder: int = 512,
                 num_blocks_decoder: int = 4,
                 num_heads_decoder: int = 8,
                 num_head_channels_decoder: int = 64,
                 out_channels: int = 1,
                 use_fp16: bool = False,
                 use_checkpoint: bool = False,
                 chunk_size: int = 1,
                 latents_scale: float = 1.0,
                 latents_shift: float = 0.0):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.resolution = resolution
        self.latents_scale = latents_scale
        self.latents_shift = latents_shift

        self.encoder = SparseSDFEncoder(
            resolution=resolution,
            in_channels=model_channels_encoder,
            model_channels=model_channels_encoder,
            latent_channels=embed_dim,
            num_blocks=num_blocks_encoder,
            num_heads=num_heads_encoder,
            num_head_channels=num_head_channels_encoder,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
        )

        self.decoder = SparseSDFDecoder(
            resolution=resolution,
            model_channels=model_channels_decoder,
            latent_channels=embed_dim,
            num_blocks=num_blocks_decoder,
            num_heads=num_heads_decoder,
            num_head_channels=num_head_channels_decoder,
            out_channels=out_channels,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            chunk_size=chunk_size,
        )
        self.embed_dim = embed_dim

    def forward(self, batch):
        """Encode with posterior sampling, decode, and return both results in a dict."""
        z, posterior = self.encode(batch)

        reconst_x = self.decoder(z)
        outputs = {'reconst_x': reconst_x, 'posterior': posterior}
        return outputs

    def encode(self, batch, sample_posterior: bool = True):
        """
        Encode a batch dict into a sparse latent.

        Reads ``batch['sparse_sdf']`` (features), ``batch['sparse_index']``
        (voxel coordinates), ``batch['batch_idx']`` and the optional
        ``batch['factor']``.

        Returns:
            ``(z, posterior)`` where ``z`` is the encoder output with its
            features replaced by a posterior sample (or the mode when
            ``sample_posterior`` is false).
        """
        feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx']
        if feat.ndim == 1:
            feat = feat.unsqueeze(-1)
        # Coordinates are laid out as [batch index, then the columns of xyz].
        coords = torch.cat([batch_idx.unsqueeze(-1), xyz], dim=-1).int()

        x = sp.SparseTensor(feat, coords)
        h = self.encoder(x, batch.get('factor', None))
        posterior = DiagonalGaussianDistribution(h.feats, feat_dim=1)
        z = posterior.sample() if sample_posterior else posterior.mode()
        z = h.replace(z)

        return z, posterior

    def decode_mesh(self,
                    latents,
                    voxel_resolution: int = 512,
                    mc_threshold: float = 0.2,
                    return_feat: bool = False,
                    factor: float = 1.0):
        """
        Decode latents into meshes.

        ``voxel_resolution`` is divided by ``factor`` (truncated to int) before
        meshing. When ``return_feat`` is true the decoder output is returned
        as-is and no mesh is extracted.
        """
        voxel_resolution = int(voxel_resolution / factor)
        reconst_x = self.decoder(latents, factor=factor, return_feat=return_feat)
        if return_feat:
            return reconst_x
        outputs = self.sparse2mesh(reconst_x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold)

        return outputs

    def sparse2mesh(self,
                    reconst_x: torch.FloatTensor,
                    voxel_resolution: int = 512,
                    mc_threshold: float = 0.0):
        """
        Scatter sparse SDF values into a dense grid per batch item and run marching cubes.

        Voxels absent from the sparse tensor keep the fill value 1. Vertices
        are rescaled with ``v / voxel_resolution * 2 - 1``.

        Args:
            reconst_x: Object exposing ``feats`` and ``coords`` (column 0 of
                ``coords`` is the batch index).
            voxel_resolution: Edge length of the dense grid.
            mc_threshold: Iso-level passed to marching cubes.

        Returns:
            A list with one ``trimesh.Trimesh`` per batch item.
        """
        # Detach up front: writing grad-carrying values into the dense grid
        # would make it require grad and break the ``.numpy()`` call below.
        sparse_sdf = reconst_x.feats.detach().float()
        sparse_index = reconst_x.coords
        batch_size = int(sparse_index[..., 0].max().item()) + 1

        meshes = []
        for i in range(batch_size):
            idx = sparse_index[..., 0] == i
            sparse_sdf_i = sparse_sdf[idx].squeeze(-1).cpu()
            sparse_index_i = sparse_index[idx][..., 1:].detach().cpu()
            sdf = torch.ones((voxel_resolution, voxel_resolution, voxel_resolution))
            sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i
            vertices, faces, _, _ = measure.marching_cubes(
                sdf.numpy(),
                mc_threshold,
                method="lewiner",
            )
            vertices = vertices / voxel_resolution * 2 - 1
            meshes.append(trimesh.Trimesh(vertices, faces))

        return meshes
pixal3d/models/conditional_encoders/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import dinov2_project_grid
2
+
pixal3d/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (238 Bytes). View file
 
pixal3d/models/conditional_encoders/__pycache__/dinov2_project_grid.cpython-310.pyc ADDED
Binary file (16 kB). View file
 
pixal3d/models/conditional_encoders/dinov2_project_grid.py ADDED
@@ -0,0 +1,750 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DINOv2 Project Grid Encoders
3
+ Includes single-view and multi-view DINOv2 encoders with 3D grid projection support
4
+ """
5
+
6
+ import random
7
+ from dataclasses import dataclass
8
+ from typing import List, Dict, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torchvision import transforms
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+
16
+ import pixal3d
17
+ from pixal3d.utils.base import BaseModule
18
+
19
# Best-effort backend selection, executed at import time: ask PyTorch to prefer
# the cuSOLVER library for CUDA linear-algebra operations. Any failure is
# deliberately ignored so that importing this module never breaks.
# NOTE(review): the previous comment said this is done "to avoid cusolver
# errors" while the call selects "cusolver" -- confirm which backend is intended.
try:
    torch.backends.cuda.preferred_linalg_library("cusolver")
except Exception:
    pass
24
+
25
+
26
+ # =============================================================================
27
+ # Base DINOv2 Encoder
28
+ # =============================================================================
29
+
30
@pixal3d.register("dinov2-encoder")
class DinoEncoder(BaseModule, ModelMixin):
    """Base DINOv2 image encoder loaded through ``torch.hub``."""

    @dataclass
    class Config(BaseModule.Config):
        model: str = "facebookresearch/dinov2"  # torch.hub repository
        version: str = "dinov2_vitl14_reg"  # entry point inside that repository
        size: int = 518  # resize / centre-crop edge length
        empty_embeds_ratio: float = 0.1  # chance of zeroing embeddings when is_training

    cfg: Config

    def configure(self) -> None:
        """Load the backbone in eval mode and build the preprocessing pipeline."""
        super().configure()
        self.empty_embeds_ratio = self.cfg.empty_embeds_ratio

        # Load DINOv2 model
        backbone = torch.hub.load(
            self.cfg.model, self.cfg.version, pretrained=True
        )
        self.encoder = backbone.eval()

        # Image preprocessing: resize -> centre crop -> normalise
        resize = transforms.Resize(
            self.cfg.size,
            transforms.InterpolationMode.BILINEAR,
            antialias=True
        )
        crop = transforms.CenterCrop(self.cfg.size)
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        self.transform = transforms.Compose([resize, crop, normalize])

    def forward(self, image, image_mask=None, is_training=False):
        """
        Encode ``image`` into layer-normalised ``'x_prenorm'`` backbone tokens.

        The backbone is always called with ``is_training=True`` and its
        ``'x_prenorm'`` entry is read; the ``is_training`` argument of this
        method only controls the random zeroing of the embeddings.

        Returns:
            The token tensor, or ``(tokens, patch_mask)`` when ``image_mask``
            is given. ``patch_mask`` is ``image_mask`` max-pooled with a 14x14
            window and stride, squeezed on dim 1 and compared with ``> 0``.
        """
        pixels = self.transform(image)
        tokens = self.encoder(pixels, is_training=True)['x_prenorm']
        tokens = F.layer_norm(tokens, tokens.shape[-1:])

        # The random draw only happens when is_training is true.
        drop_condition = is_training and random.random() < self.empty_embeds_ratio
        if drop_condition:
            # zero out embeddings
            tokens = tokens * 0

        if image_mask is None:
            return tokens

        pooled_mask = F.max_pool2d(image_mask, kernel_size=14, stride=14)
        image_mask_patch = pooled_mask.squeeze(1) > 0
        return tokens, image_mask_patch
85
+
86
+
87
+ # =============================================================================
88
+ # 3D Projection Utility Functions
89
+ # =============================================================================
90
+
91
def project_points_to_image_batch(
    points_3d: torch.Tensor,
    transform_matrix: torch.Tensor,
    camera_angle_x: torch.Tensor,
    resolution: int = 518
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Project 3D points to 2D image coordinates with batch support

    Non-tensor inputs (lists, arrays) are converted to float32 tensors.

    Args:
        points_3d: [N, 3] or [B, N, 3], 3D point coordinates (in [-1, 1] range)
        transform_matrix: [B, 4, 4], batch of camera transformation matrices
            (camera-to-world; they are inverted here)
        camera_angle_x: [B], batch of camera horizontal FOV angles (radians)
        resolution: Rendering image resolution

    Returns:
        points_2d: [B, N, 2], image coordinates [x, y]
        depth: [B, N], depth values (negated camera-space z)
        valid_mask: [B, N], mask indicating if points are within view
    """
    # Pick the device BEFORE touching any tensor-only attribute: reading
    # ``points_3d.device`` first would make the conversions below unreachable
    # for the very inputs they exist for.
    device = next(
        (
            arg.device
            for arg in (points_3d, transform_matrix, camera_angle_x)
            if isinstance(arg, torch.Tensor)
        ),
        torch.device("cpu"),
    )

    # Ensure inputs are torch.Tensor
    if not isinstance(transform_matrix, torch.Tensor):
        transform_matrix = torch.tensor(
            transform_matrix, dtype=torch.float32, device=device
        )
    if not isinstance(points_3d, torch.Tensor):
        points_3d = torch.tensor(
            points_3d, dtype=torch.float32, device=device
        )
    if not isinstance(camera_angle_x, torch.Tensor):
        camera_angle_x = torch.tensor(
            camera_angle_x, dtype=torch.float32, device=device
        )

    B = transform_matrix.shape[0]

    # Expand points_3d to batch dimension
    if points_3d.dim() == 2:
        points_3d_batch = points_3d.unsqueeze(0).expand(B, -1, -1)
    else:
        points_3d_batch = points_3d

    N = points_3d_batch.shape[1]

    # Add homogeneous coordinates
    ones = torch.ones(B, N, 1, device=device)
    points_homogeneous = torch.cat([points_3d_batch, ones], dim=-1)

    # World to camera transformation (row vectors, hence the transpose)
    world_to_camera = torch.linalg.inv(transform_matrix)
    points_camera = torch.bmm(
        points_homogeneous,
        world_to_camera.transpose(-2, -1)
    )[..., :3]

    # Extract camera coordinates
    x_cam = points_camera[..., 0]
    y_cam = points_camera[..., 1]
    z_cam = points_camera[..., 2]

    # Depth values: the camera looks down its negative z axis
    depth = -z_cam

    # Compute camera intrinsics: 16.0 is half of the 32.0 sensor width
    sensor_width = 32.0
    focal_length = 16.0 / torch.tan(camera_angle_x / 2.0)
    focal_length_pixels = focal_length * resolution / sensor_width
    focal_length_pixels = focal_length_pixels.unsqueeze(1)

    # Perspective projection. Points with z_cam == 0 give inf/nan here and are
    # rejected by the ``depth > 0`` test below.
    x_ndc = focal_length_pixels * x_cam / (-z_cam)
    y_ndc = focal_length_pixels * y_cam / (-z_cam)

    # Convert to image coordinates (origin at the top-left, y pointing down)
    x_pixel = x_ndc + resolution / 2.0
    y_pixel = -y_ndc + resolution / 2.0

    # Validity mask
    valid_mask = (
        (x_pixel >= 0) & (x_pixel < resolution) &
        (y_pixel >= 0) & (y_pixel < resolution) &
        (depth > 0)
    )

    points_2d = torch.stack([x_pixel, y_pixel], dim=-1)
    return points_2d, depth, valid_mask
178
+
179
+
180
def project_points_to_image(
    points_3d: torch.Tensor,
    transform_matrix: torch.Tensor,
    camera_angle_x: float,
    resolution: int = 512
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Project 3D points to 2D image coordinates (single-view version)

    Non-tensor ``points_3d`` / ``transform_matrix`` are converted to float32
    tensors.

    Args:
        points_3d: [N, 3], 3D point coordinates
        transform_matrix: [4, 4], camera transformation matrix
            (camera-to-world; it is inverted here)
        camera_angle_x: Camera horizontal FOV angle (radians)
        resolution: Rendering image resolution

    Returns:
        points_2d: [N, 2], image coordinates [x, y]
        depth: [N], depth values (negated camera-space z)
        valid_mask: [N], mask indicating if points are within view
    """
    # Pick the device BEFORE touching any tensor-only attribute: reading
    # ``points_3d.device`` first would make the conversions below unreachable
    # for the very inputs they exist for.
    device = next(
        (
            arg.device
            for arg in (points_3d, transform_matrix)
            if isinstance(arg, torch.Tensor)
        ),
        torch.device("cpu"),
    )

    if not isinstance(transform_matrix, torch.Tensor):
        transform_matrix = torch.tensor(
            transform_matrix, dtype=torch.float32, device=device
        )
    if not isinstance(points_3d, torch.Tensor):
        points_3d = torch.tensor(
            points_3d, dtype=torch.float32, device=device
        )

    N = points_3d.shape[0]
    points_homogeneous = torch.cat([
        points_3d,
        torch.ones(N, 1, device=device)
    ], dim=1)

    # World to camera transformation (row vectors, hence the transpose)
    camera_to_world = transform_matrix
    world_to_camera = torch.linalg.inv(camera_to_world)
    points_camera = torch.matmul(
        points_homogeneous,
        world_to_camera.T
    )[:, :3]

    x_cam = points_camera[:, 0]
    y_cam = points_camera[:, 1]
    z_cam = points_camera[:, 2]
    # The camera looks down its negative z axis
    depth = -z_cam

    # Camera intrinsics: 16.0 is half of the 32.0 sensor width.
    # as_tensor accepts both a Python float and a tensor without the
    # copy-construct warning torch.tensor() emits for tensor inputs; detach()
    # keeps the FOV out of the autograd graph, as before.
    sensor_width = 32.0
    half_fov = torch.as_tensor(camera_angle_x / 2.0, device=device).detach()
    focal_length = 16.0 / torch.tan(half_fov)
    focal_length_pixels = focal_length * resolution / sensor_width

    # Perspective projection. Points with z_cam == 0 give inf/nan here and are
    # rejected by the ``depth > 0`` test below.
    x_ndc = focal_length_pixels * x_cam / (-z_cam)
    y_ndc = focal_length_pixels * y_cam / (-z_cam)

    # Image coordinates (origin at the top-left, y pointing down)
    x_pixel = x_ndc + resolution / 2.0
    y_pixel = -y_ndc + resolution / 2.0

    valid_mask = (
        (x_pixel >= 0) & (x_pixel < resolution) &
        (y_pixel >= 0) & (y_pixel < resolution) &
        (depth > 0)
    )

    points_2d = torch.stack([x_pixel, y_pixel], dim=1)
    return points_2d, depth, valid_mask
251
+
252
+
253
def sample_features(
    fmap: torch.Tensor,
    queries_ndc: torch.Tensor
) -> torch.Tensor:
    """
    Bilinearly sample a feature map at a set of query locations.

    Args:
        fmap: [B, C, H, W], feature map
        queries_ndc: [B, K, 2], query coordinates in the normalized
            [-1, 1] convention expected by ``F.grid_sample``

    Returns:
        [B, C, K] features, one column per query. Queries falling outside
        the map take the nearest border value (``padding_mode='border'``).
    """
    batch, _, _, _ = fmap.shape
    query_batch, num_queries, _ = queries_ndc.shape
    assert query_batch == batch, "batch 不一致"

    # grid_sample wants a [B, H_out, W_out, 2] grid; lay the K queries out
    # as a K x 1 "image" and drop the dummy width axis afterwards.
    sampled = F.grid_sample(
        fmap,
        queries_ndc.view(batch, num_queries, 1, 2),
        mode='bilinear',
        padding_mode='border',
        align_corners=False,
    )
    return sampled.squeeze(-1)
277
+
278
+
279
+ # =============================================================================
280
+ # Projection Grid Module
281
+ # =============================================================================
282
+
283
class ProjGrid(nn.Module):
    """Samples a 2D feature map at the image projections of a fixed 3D grid."""

    def __init__(self, grid_resolution: int = 16):
        super().__init__()
        self.grid_resolution = grid_resolution
        # Pixel resolution used when projecting grid points to the image.
        self.image_resolution = 518

        # Regular lattice of grid_resolution^3 points spanning [-1, 1]^3.
        axis = torch.linspace(-1, 1, grid_resolution)
        lattice = torch.stack(
            torch.meshgrid(axis, axis, axis, indexing='ij'), dim=-1
        )

        # Axis swap (x, y, z) -> (x, -z, y) to align with Blender.
        to_blender = torch.tensor([
            [1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0],
            [0.0, 1.0, 0.0]
        ])
        lattice = torch.matmul(lattice, to_blender.T).reshape(-1, 3)
        self.register_buffer('grid_points', lattice)

        # Default front-view camera pose; the [1, 3] entry is overwritten
        # with -distance in forward().
        self.register_buffer(
            "front_view_transform_matrix",
            torch.tensor([
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, -2.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 1.0]
            ])
        )

    def forward(
        self,
        features_map: torch.Tensor,
        camera_angle_x: torch.Tensor,
        distance: torch.Tensor,
        mesh_scale: torch.Tensor,
        transform_matrix: torch.Tensor = None,
        BHWC: bool = True
    ) -> torch.Tensor:
        """
        Gather one feature vector per grid point from an image feature map.

        Args:
            features_map: [B, H, W, C] if ``BHWC`` else [B, C, H, W]
            camera_angle_x: [B], horizontal FOV passed to the projection
            distance: [B], only used to build the default front-view pose
                when ``transform_matrix`` is None
            mesh_scale: [B], grid points are divided by ``2 * mesh_scale``
            transform_matrix: [B, 4, 4] camera pose, or None for front view
            BHWC: Whether ``features_map`` is channels-last

        Returns:
            [B, K, C] sampled features, K = grid_resolution^3
        """
        # The 4-way unpack also enforces a 4-D input.
        batch_size, _, _, _ = features_map.shape

        # Broadcast the shared lattice over the batch and rescale per sample.
        scale = mesh_scale[..., None, None]
        grid_points = self.grid_points.expand(batch_size, -1, -1) / scale / 2

        if transform_matrix is None:
            # Fall back to the front view placed at the requested distance.
            transform_matrix = self.front_view_transform_matrix.expand(
                batch_size, -1, -1
            ).clone()
            transform_matrix[:, 1, 3] = -distance

        # Depth and validity mask are returned by the projection but unused.
        image_points, _depth, _valid = project_points_to_image_batch(
            grid_points, transform_matrix, camera_angle_x, self.image_resolution
        )

        # Pixel coordinates -> [-1, 1] sampling coordinates.
        image_points_norm = (image_points + 0.5) / self.image_resolution * 2 - 1

        # sample_features expects channels-first maps.
        if BHWC:
            features_map = features_map.permute(0, 3, 1, 2)

        return sample_features(features_map, image_points_norm).permute(0, 2, 1)
374
+
375
+
376
+
377
+
378
+
379
+ # =============================================================================
380
+ # DINOv2 Encoder with Projection
381
+ # =============================================================================
382
+
383
@pixal3d.register("dinov2-encoder-proj")
class DinoEncoderProj(BaseModule, ModelMixin):
    """Single-view DINOv2 encoder whose patch tokens are projected onto a 3D grid."""

    @dataclass
    class Config(BaseModule.Config):
        # torch.hub repository and entry point of the backbone
        model: str = "facebookresearch/dinov2"
        version: str = "dinov2_vitl14_reg"
        # Input image side length; determines the patch grid size
        size: int = 518
        # Probability of zeroing the condition when is_training is True
        empty_embeds_ratio: float = 0.1
        grid_resolution: int = 16
        use_upsample: bool = False
        # Not read anywhere in this class
        use_geo_feats: bool = False

    cfg: Config

    def configure(self) -> None:
        """Load the backbone (and optional upsampler) and build the projector."""
        super().configure()
        cfg = self.cfg
        self.grid_resolution = cfg.grid_resolution
        self.empty_embeds_ratio = cfg.empty_embeds_ratio
        self.use_upsample = cfg.use_upsample

        # DINOv2 backbone, kept in eval mode
        self.encoder = torch.hub.load(
            cfg.model, cfg.version, pretrained=True
        ).eval()

        # Optional feature upsampler, also in eval mode
        if self.use_upsample:
            self.upsampler = torch.hub.load(
                "valeoai/NAF", "naf", pretrained=True
            ).eval()

        # Normalization only: no resizing or tensor conversion happens here
        self.transform = transforms.Compose([
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        self.patch_size = self.encoder.patch_size
        self.patch_number = cfg.size // self.patch_size
        self.proj_grid = ProjGrid(grid_resolution=cfg.grid_resolution)

    def forward(
        self,
        image: torch.Tensor,
        image_mask: torch.Tensor = None,
        camera_angle_x: torch.Tensor = None,
        distance: torch.Tensor = None,
        mesh_scale: torch.Tensor = None,
        transform_matrix: torch.Tensor = None,
        is_training: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode an image and lift its patch features onto the 3D grid.

        Args:
            image: [B, C, H, W]
            image_mask: accepted but not used by this method
            camera_angle_x: [B]
            distance: [B]
            mesh_scale: [B]
            transform_matrix: accepted but not used by this method; the
                projector always falls back to its front-view pose
            is_training: enables random condition dropping

        Returns:
            z_global: [B, num_global, C], CLS + register tokens
            z: [B, grid_resolution^3, C], grid-projected patch features
        """
        image = self.transform(image)

        # The backbone runs without gradients; normalization is applied
        # outside the no_grad scope, as in the original ordering.
        with torch.no_grad():
            tokens = self.encoder(image, is_training=True)['x_prenorm']
        tokens = F.layer_norm(tokens, tokens.shape[-1:])

        # Token layout: [CLS | register tokens | patch tokens]
        num_reg = self.encoder.num_register_tokens
        cls_token = tokens[:, 0:1]
        reg_tokens = tokens[:, 1:num_reg + 1]
        patch_tokens = tokens[:, 1 + num_reg:]
        patch_tokens = patch_tokens.reshape(
            patch_tokens.shape[0], self.patch_number, self.patch_number, -1
        )

        # Project the channels-last patch map to the grid
        z = self.proj_grid(patch_tokens, camera_angle_x, distance, mesh_scale)

        # Optionally add grid features sampled from the upsampled map
        if self.use_upsample:
            dense_map = self.upsampler(
                image, patch_tokens.permute(0, 3, 1, 2), output_size=(518, 518)
            )
            z = z + self.proj_grid(
                dense_map, camera_angle_x, distance, mesh_scale, BHWC=False
            )

        z_global = torch.cat([cls_token, reg_tokens], dim=1)
        z_global = z_global.expand(z.shape[0], -1, -1)

        # Classifier-free guidance: occasionally replace the condition by zeros
        drop_condition = is_training and random.random() < self.empty_embeds_ratio
        if drop_condition:
            z_global = z_global * 0
            z = z * 0

        return z_global, z
500
+
501
+
502
+ # =============================================================================
503
+ # Multi-View Projection Encoder Helper Functions
504
+ # =============================================================================
505
+
506
def compute_calc_mat(
    true_view_mat: torch.Tensor,
    ext_true_view_mat: torch.Tensor,
    fix_mat: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Re-express a set of camera matrices relative to a fixed reference matrix.

    For every view: ``relative = inv(true_view) @ ext_true_view`` and
    ``calc = fix @ relative``.

    Args:
        true_view_mat: [B, 1, 4, 4], reference camera matrix
        ext_true_view_mat: [B, N, 4, 4], camera matrices of all views
        fix_mat: [B, 1, 4, 4], fixed matrix substituted for the reference

    Returns:
        calc_mat: [B, N, 4, 4] (float32)
        relative_transform: [B, N, 4, 4] (float32)
    """
    batch, views = ext_true_view_mat.shape[:2]
    flat_shape = (batch * views, 4, 4)

    # Broadcast the per-batch matrices across views, then fold the view axis
    # into the batch axis so bmm can process everything in one call.
    ref_flat = true_view_mat.expand(batch, views, 4, 4).reshape(flat_shape)
    ext_flat = ext_true_view_mat.reshape(flat_shape)
    fix_flat = fix_mat.expand(batch, views, 4, 4).reshape(flat_shape)

    # Autocast is switched off and inputs are cast up so the inverse and the
    # products run in float32.
    with torch.amp.autocast('cuda', enabled=False):
        ref_flat = ref_flat.float()
        ext_flat = ext_flat.float()
        fix_flat = fix_flat.float()

        rel_flat = torch.bmm(torch.linalg.inv(ref_flat), ext_flat)
        calc_flat = torch.bmm(fix_flat, rel_flat)

    return (
        calc_flat.view(batch, views, 4, 4),
        rel_flat.view(batch, views, 4, 4),
    )
550
+
551
+
552
+ # =============================================================================
553
+ # Multi-View DINOv2 Projection Encoder
554
+ # =============================================================================
555
+
556
@pixal3d.register("dinov2-encoder-proj-multi-view")
class DinoEncoderProjMultiView(BaseModule, ModelMixin):
    """Multi-view DINOv2 encoder: per-view grid projections averaged over views."""

    @dataclass
    class Config(BaseModule.Config):
        # torch.hub repository and entry point of the backbone
        model: str = "facebookresearch/dinov2"
        version: str = "dinov2_vitl14_reg"
        # Input image side length; determines the patch grid size
        size: int = 518
        # Probability of zeroing the condition when is_training is True
        empty_embeds_ratio: float = 0.1
        grid_resolution: int = 16
        use_upsample: bool = False

    cfg: Config

    def configure(self) -> None:
        """Load the backbone (and optional upsampler) and build the projector."""
        super().configure()
        cfg = self.cfg
        self.grid_resolution = cfg.grid_resolution
        self.empty_embeds_ratio = cfg.empty_embeds_ratio
        self.use_upsample = cfg.use_upsample

        # DINOv2 backbone, kept in eval mode
        self.encoder = torch.hub.load(
            cfg.model, cfg.version, pretrained=True
        ).eval()

        # Optional feature upsampler, also in eval mode
        if self.use_upsample:
            self.upsampler = torch.hub.load(
                "valeoai/NAF", "naf", pretrained=True
            ).eval()

        # Normalization only
        self.transform = transforms.Compose([
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        self.patch_size = self.encoder.patch_size
        self.patch_number = cfg.size // self.patch_size
        self.proj_grid = ProjGrid(grid_resolution=cfg.grid_resolution)

        # Canonical pose of the first view; the [1, 3] entry is replaced by
        # -distance in get_relative_transform().
        self.register_buffer("fix_transform_matrix", torch.tensor([
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, -1.0, -2.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0]
        ]))

    def forward(
        self,
        image: torch.Tensor,
        image_mask: torch.Tensor = None,
        camera_angle_x: torch.Tensor = None,
        distance: torch.Tensor = None,
        mesh_scale: torch.Tensor = None,
        transform_matrix: torch.Tensor = None,
        is_training: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode all views and fuse their grid-projected features.

        Args:
            image: [B, num_views, C, H, W]
            image_mask: accepted but not used by this method
            camera_angle_x: [B, num_views]
            distance: [B, num_views]
            mesh_scale: [B]
            transform_matrix: [B, num_views, 4, 4]
            is_training: enables random condition dropping

        Returns:
            z_global: [B, num_global, C], view-averaged CLS + register tokens
            z: [B, grid_resolution^3, C], view-averaged grid features
        """
        B, num_views, C, H, W = image.shape
        total = B * num_views

        # Fold views into the batch axis for the backbone.
        image = image.reshape(total, C, H, W)
        image = self.transform(image)

        with torch.no_grad():
            tokens = self.encoder(image, is_training=True)['x_prenorm']
        tokens = F.layer_norm(tokens, tokens.shape[-1:])

        # Token layout: [CLS | register tokens | patch tokens]
        num_reg = self.encoder.num_register_tokens
        cls_token = tokens[:, 0:1]
        reg_tokens = tokens[:, 1:num_reg + 1]
        patch_tokens = tokens[:, 1 + num_reg:]
        patch_tokens = patch_tokens.reshape(
            patch_tokens.shape[0], self.patch_number, self.patch_number, -1
        )

        # Camera poses re-expressed relative to the first view.
        calc_mat, _relative = self.get_relative_transform(
            transform_matrix, distance
        )
        calc_mat = calc_mat.reshape(total, 4, 4)

        # Per-(batch, view) camera parameters, flattened like the images.
        flat_scale = mesh_scale[:, None].expand(B, num_views).reshape(total)
        flat_fov = camera_angle_x.reshape(total)
        flat_distance = distance.reshape(total)

        patch_tokens_chw = (
            patch_tokens.permute(0, 3, 1, 2) if self.use_upsample else None
        )

        # Process one view at a time and keep only a running sum, so memory
        # does not grow with the number of views.
        fused = None
        with torch.no_grad():
            for view_idx in range(num_views):
                # Rows of the flattened batch that belong to this view.
                rows = torch.arange(
                    view_idx, total, num_views, device=patch_tokens.device
                )
                camera = (
                    flat_fov[rows],
                    flat_distance[rows],
                    flat_scale[rows],
                    calc_mat[rows],
                )

                view_feats = self.proj_grid(patch_tokens[rows], *camera)

                if self.use_upsample:
                    dense_map = self.upsampler(
                        image[rows],
                        patch_tokens_chw[rows],
                        output_size=(518, 518)
                    )
                    view_feats = view_feats + self.proj_grid(
                        dense_map, *camera, BHWC=False
                    )
                    del dense_map

                fused = view_feats.clone() if fused is None else fused + view_feats
                del view_feats, camera

        del patch_tokens_chw

        # Mean over views
        z = fused / num_views

        # Global tokens: un-flatten the view axis and average over it
        z_global = torch.cat([cls_token, reg_tokens], dim=1)
        z_global = z_global.reshape(
            B, num_views, z_global.shape[-2], z_global.shape[-1]
        ).mean(dim=1)

        # Classifier-free guidance: occasionally replace the condition by zeros
        drop_condition = is_training and random.random() < self.empty_embeds_ratio
        if drop_condition:
            z_global = z_global * 0
            z = z * 0

        return z_global, z

    def get_relative_transform(
        self,
        transform_matrix: torch.Tensor,
        distance: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Express every view's pose relative to the first view, re-anchored at
        the canonical front-view pose.

        Args:
            transform_matrix: [B, num_views, 4, 4]
            distance: [B, num_views]; only the first view's distance is used

        Returns:
            calc_mat: [B, num_views, 4, 4]
            relative_transform: [B, num_views, 4, 4]
        """
        batch, _, _, _ = transform_matrix.shape

        # The first view acts as the reference pose.
        reference_pose = transform_matrix[:, 0:1]

        # Canonical pose with the first view's camera distance written in.
        canonical = self.fix_transform_matrix.unsqueeze(0).expand(
            batch, -1, -1
        ).clone()
        canonical[:, 1, 3] = -distance[:, 0]

        return compute_calc_mat(
            reference_pose, transform_matrix, canonical.unsqueeze(1)
        )
pixal3d/models/transformers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import sparse_dit
2
+ from . import dense_dit
pixal3d/models/transformers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (253 Bytes). View file
 
pixal3d/models/transformers/__pycache__/dense_dit.cpython-310.pyc ADDED
Binary file (9.49 kB). View file
 
pixal3d/models/transformers/__pycache__/sparse_dit.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
pixal3d/models/transformers/dense_dit.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from dataclasses import dataclass
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32
8
+ from ...modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
9
+ from ...modules.spatial import patchify, unpatchify
10
+ from ...utils.base import BaseModule
11
+ import pixal3d
12
+ from huggingface_hub import hf_hub_download
13
+ import os
14
+
15
class TimestepEmbedder(nn.Module):
    """
    Turns scalar timesteps into vectors: sinusoidal features followed by a
    two-layer MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep embeddings.

        Args:
            t: a 1-D Tensor of N indices, one per batch element.
                These may be fractional.
            dim: the dimension of the output.
            max_period: controls the minimum frequency of the embeddings.

        Returns:
            an (N, dim) Tensor laid out as [cos | sin], with one trailing
            zero column when ``dim`` is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder from 1 down towards 1 / max_period.
        exponents = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-np.log(max_period) * exponents / half).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            # Odd width: pad with a single zero column.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        """Embed timesteps ``t`` ([N]) into [N, hidden_size] vectors."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        # Match the MLP weight dtype (e.g. after an fp16 conversion).
        features = features.to(self.mlp[0].weight.dtype)
        return self.mlp(features)
58
+
59
+
60
class DenseDiT(nn.Module):
    """
    Diffusion transformer over a dense cubic latent volume.

    The input volume is patchified into tokens, passed through a stack of
    timestep-modulated transformer blocks that attend to the conditioning,
    and unpatchified back into a volume.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        patch_size: int = 2,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        share_mod: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        latent_shape: list = [8, 16, 16, 16],
        image_attn_mode:str = "cross",
        load_ckpt:bool = True,
    ):
        """
        Args:
            resolution: side length of the cubic input volume.
            in_channels: channels of the input volume.
            model_channels: transformer width.
            cond_channels: channel width of the conditioning tokens.
            out_channels: channels of the output volume.
            num_blocks: number of transformer blocks.
            num_heads: attention heads; derived from ``num_head_channels``
                when falsy.
            num_head_channels: channels per head, used only when
                ``num_heads`` is not given.
            mlp_ratio: forwarded to every block.
            patch_size: side length of the cubic patches.
            pe_mode: "ape" adds a precomputed absolute position embedding;
                "rope" enables ``use_rope`` in the blocks instead.
            use_fp16: convert the model to float16 after initialization.
            use_checkpoint: forwarded to every block.
            share_mod: use one shared adaLN modulation MLP for all blocks.
            qk_rms_norm: forwarded to every block.
            qk_rms_norm_cross: forwarded to every block.
            latent_shape: stored on the instance; not otherwise used here.
            image_attn_mode: forwarded to every block; "proj" makes
                ``forward`` expect a ``(global_cond, proj_cond)`` pair.
            load_ckpt: accepted for configuration compatibility; not used
                by this class.
        """
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.patch_size = patch_size
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = torch.float16 if use_fp16 else torch.float32
        # Copy so instances never share (or mutate) the default list object.
        self.latent_shape = list(latent_shape)
        self.image_attn_mode = image_attn_mode

        self.t_embedder = TimestepEmbedder(model_channels)
        if share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )

        if pe_mode == "ape":
            # One embedding per patch position on the (resolution // patch_size)^3 grid.
            # self.device is valid here because t_embedder already owns parameters.
            pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            pos_emb = pos_embedder(coords)
            self.register_buffer("pos_emb", pos_emb)

        self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)

        self.blocks = nn.ModuleList([
            ModulatedTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                share_mod=share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
                image_attn_mode = self.image_attn_mode,
            )
            for _ in range(num_blocks)
        ])

        self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Apply the float16 conversion helper to every submodule.
        """
        self.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the transformer blocks to float32.
        """
        # NOTE(review): convert_to_fp16 is applied to the whole model while
        # this only touches self.blocks, so the pair is not symmetric --
        # confirm whether that is intended.
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-init all linears, then apply the DiT-specific overrides."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        if self.share_mod:
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Run the transformer on a latent volume.

        Args:
            x: [B, in_channels, resolution, resolution, resolution].
            t: timesteps, one per batch element.
            cond: conditioning tensor, or a ``(global_cond, proj_cond)`` pair
                when ``image_attn_mode == 'proj'``.

        Returns:
            Output volume in ``x``'s dtype.

        Raises:
            AssertionError: if ``x`` does not have the expected shape.
        """
        assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
                f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"

        # Volume -> token sequence [B, L, in_channels * patch_size^3]
        h = patchify(x, self.patch_size)
        h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
        h = self.input_layer(h)
        # The pos_emb buffer only exists in "ape" mode; in "rope" mode the
        # blocks handle positions themselves, so adding it unconditionally
        # raised AttributeError.
        if self.pe_mode == "ape":
            h = h + self.pos_emb[None]
        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = t_emb.type(self.dtype)
        h = h.type(self.dtype)
        if self.image_attn_mode=='proj':
            global_cond,proj_cond = cond
            global_cond = global_cond.type(self.dtype)
            proj_cond = proj_cond.type(self.dtype)
            cond = (global_cond, proj_cond)
        else:
            cond = cond.type(self.dtype)
        for block in self.blocks:
            h = block(h, t_emb, cond)
        # Back to the caller's dtype before the final norm and projection.
        h = h.type(x.dtype)
        h = F.layer_norm(h, h.shape[-1:])
        h = self.out_layer(h)

        # Token sequence -> volume
        h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
        h = unpatchify(h, self.patch_size).contiguous()

        return h
220
+
221
+
222
+ # ===== Align to sparse_dit style: ModelOutput + Denoiser wrapper (Lightning-friendly) =====
223
+
224
@dataclass
class DenseDiTModelOutput:
    """Structured return value of the dense DiT denoiser wrapper."""
    # Prediction produced by the wrapped DenseDiT model
    sample: torch.Tensor
227
+
228
+
229
@pixal3d.register("dense-dit-denoiser")
class DenseDiTDenoiser(BaseModule):
    """Config-driven wrapper that builds a DenseDiT and returns structured outputs."""

    @dataclass
    class Config(BaseModule.Config):
        # Fields mirror the DenseDiT constructor arguments one-to-one.
        resolution: int = 64
        in_channels: int = 16
        model_channels: int = 1024
        cond_channels: int = 1024
        out_channels: int = 16
        num_blocks: int = 24
        num_heads: Optional[int] = None
        num_head_channels: Optional[int] = 64
        mlp_ratio: float = 4.0
        patch_size: int = 2
        pe_mode: str = "ape"  # "ape" | "rope"
        use_fp16: bool = False
        use_checkpoint: bool = False
        share_mod: bool = False
        qk_rms_norm: bool = False
        qk_rms_norm_cross: bool = False
        latent_shape: list = (8, 16, 16, 16)
        image_attn_mode: str = "cross"
        load_ckpt:bool = True

    cfg: Config

    def configure(self) -> None:
        """Instantiate the wrapped DenseDiT from the configuration."""
        cfg = self.cfg

        # Sequences are copied into a plain list; anything else is passed on
        # untouched.
        latent_shape = cfg.latent_shape
        if isinstance(latent_shape, (list, tuple)):
            latent_shape = list(latent_shape)

        # Arguments forwarded verbatim from the config.
        passthrough = (
            "resolution",
            "in_channels",
            "model_channels",
            "cond_channels",
            "out_channels",
            "num_blocks",
            "num_heads",
            "num_head_channels",
            "mlp_ratio",
            "patch_size",
            "pe_mode",
            "use_fp16",
            "use_checkpoint",
            "share_mod",
            "qk_rms_norm",
            "qk_rms_norm_cross",
            "image_attn_mode",
            "load_ckpt",
        )
        dit_kwargs = {name: getattr(cfg, name) for name in passthrough}
        self.dit_model = DenseDiT(latent_shape=latent_shape, **dit_kwargs)

        # Exposed for a consistent external API (some callers read it).
        self.out_channels = cfg.out_channels

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: torch.Tensor,
        **kwargs,
    ) -> DenseDiTModelOutput:
        """Run the wrapped model and wrap its prediction, diffusers-style.

        Args:
            x: [B, C, D, H, W] dense latent tensor.
            t: [B] or [1] timestep tensor.
            cond: conditioning passed through to the transformer blocks.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            DenseDiTModelOutput whose ``sample`` is the model prediction.
        """
        prediction = self.dit_model(x, t, cond)
        return DenseDiTModelOutput(sample=prediction)
pixal3d/models/transformers/sparse_dit.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Some parts of this file are adapted from the SparseDiT implementation
2
+ import os
3
+ from typing import Any, Dict, Optional, Union, Tuple, Literal
4
+ from dataclasses import dataclass
5
+ import numpy as np
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.loaders import PeftAdapterMixin
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.utils import logging
15
+
16
+ import pixal3d
17
+ from pixal3d.utils.base import BaseModule
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ # Import sparse operations
21
+
22
+ from ...modules import sparse as sp
23
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32
24
+ from ...modules.transformer import AbsolutePositionEmbedder
25
+ from ...modules.sparse.transformer.modulated import ModulatedSparseTransformerCrossBlock
26
+ SPARSE_AVAILABLE = True
27
+ # except ImportError:
28
+ # print("Warning: sparse modules not found. Please ensure it's in your Python path.")
29
+ # sp = None
30
+ # convert_module_to_f16 = None
31
+ # convert_module_to_f32 = None
32
+ # AbsolutePositionEmbedder = None
33
+ # ModulatedSparseTransformerCrossBlock = None
34
+ # SPARSE_AVAILABLE = False
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
@dataclass
class SparseDiTModelOutput:
    """Structured return value of the sparse DiT model."""
    # Either a torch.FloatTensor or an sp.SparseTensor
    sample: Any
42
+
43
+
44
class TimestepEmbedder(nn.Module):
    """
    Scalar timestep -> vector: sinusoidal features fed through a small MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t: 1-D tensor of N (possibly fractional) timesteps.
            dim: width of the returned embedding.
            max_period: controls the minimum frequency.

        Returns:
            (N, dim) tensor laid out as [cos | sin]; a zero column is
            appended when ``dim`` is odd.
        """
        half = dim // 2
        # Frequencies are built on CPU in float32, then moved next to `t`.
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-np.log(max_period) * steps / half).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        parts = [torch.cos(angles), torch.sin(angles)]
        embedding = torch.cat(parts, dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        """Return the [N, hidden_size] embedding of timesteps ``t``."""
        sinusoid = self.timestep_embedding(t, self.frequency_embedding_size)
        # Cast to the MLP's parameter dtype before the linear layers.
        sinusoid = sinusoid.to(self.mlp[0].weight.dtype)
        return self.mlp(sinusoid)
77
+
78
+
79
class SparseDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    Sparse Diffusion Transformer model for 3D shape generation.

    This model processes sparse 3D data using sparse attention mechanisms.

    Pipeline: sparse input projection -> optional absolute position
    embedding -> ``num_blocks`` timestep-modulated cross-attention blocks ->
    parameter-free layer norm -> sparse output projection.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        resolution: int = 64,
        in_channels: int = 16,
        model_channels: int = 1024,
        cond_channels: int = 1024,
        out_channels: int = 16,
        num_blocks: int = 24,
        num_heads: int = 32,
        num_head_channels: int = 64,
        num_kv_heads: int = 2,
        compression_block_size: int = 4,
        selection_block_size: int = 8,
        topk: int = 32,
        compression_version: str = 'v2',
        mlp_ratio: float = 4.0,
        pe_mode: str = "ape",
        use_fp16: bool = True,
        use_checkpoint: bool = True,
        share_mod: bool = False,
        qk_rms_norm: bool = True,
        qk_rms_norm_cross: bool = False,
        sparse_conditions: bool = True,
        factor: float = 1.0,
        window_size: int = 8,
        use_shift: bool = True,
        image_attn_mode: str = 'cross',
        load_ckpt: bool = True,
        version: Optional[str] = 'V10',
    ):
        """
        Build the model.

        Notes:
            * ``num_head_channels`` is only used to derive the head count
              when ``num_heads`` is falsy.
            * ``load_ckpt`` and ``version`` are accepted but not read in this
              constructor (presumably kept so they are recorded in the
              config -- TODO confirm).
            * ``pe_mode == "ape"`` adds an absolute position embedding to the
              input; ``pe_mode == "rope"`` enables ``use_rope`` in the blocks.
            * ``image_attn_mode == 'proj'`` makes ``forward`` expect a
              ``(global_cond, sparse_cond)`` pair as the condition.

        Raises:
            ImportError: If the sparse modules are unavailable.
        """
        super().__init__()

        if not SPARSE_AVAILABLE:
            raise ImportError("sparse modules not found.")

        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        # Compute dtype used for every activation cast in forward().
        self._dtype = torch.float16 if use_fp16 else torch.float32
        self.sparse_conditions = sparse_conditions
        self.factor = factor
        self.compression_block_size = compression_block_size
        self.selection_block_size = selection_block_size
        self.image_attn_mode = image_attn_mode

        # Timestep embedding
        self.t_embedder = TimestepEmbedder(model_channels)

        # Shared modulation if enabled
        if share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True),
            )

        # Condition processing for sparse conditions
        if sparse_conditions:
            self.cond_proj = sp.SparseLinear(cond_channels, cond_channels)
            self.pos_embedder_cond = AbsolutePositionEmbedder(model_channels, in_channels=3)

        # Position embedding
        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        # Input projection
        self.input_layer = sp.SparseLinear(in_channels, model_channels)

        # Transformer blocks
        # NOTE(review): with use_shift=False every block receives
        # shift_window = window_size // 2 (not 0) -- confirm this is intended.
        self.blocks = nn.ModuleList([
            ModulatedSparseTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                num_kv_heads=num_kv_heads,
                compression_block_size=compression_block_size,
                selection_block_size=selection_block_size,
                topk=topk,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                compression_version=compression_version,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                share_mod=self.share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
                resolution=resolution,
                window_size=window_size,
                shift_window=window_size // 2 * (i % 2) if use_shift else window_size // 2,
                image_attn_mode=image_attn_mode,
            )
            for i in range(num_blocks)
        ])

        # Output projection
        self.out_layer = sp.SparseLinear(model_channels, out_channels)

        # Initialize weights
        self.initialize_weights()

        self.gradient_checkpointing = False

        if use_fp16:
            print("Converting model to float16 ============================")
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """Return the device of the model."""
        return next(self.parameters()).device

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on ``module`` if it has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def convert_to_fp16(self) -> None:
        """Convert the model to float16."""
        self.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """Convert the model to float32."""
        self.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Initialize model weights."""
        # Xavier-uniform weights and zero biases for every nn.Linear.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers
        if self.share_mod:
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(
        self,
        hidden_states: Any,  # sp.SparseTensor
        timestep: torch.Tensor,
        encoder_hidden_states: Optional[Any] = None,  # torch.Tensor or sp.SparseTensor
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[SparseDiTModelOutput, Tuple]:
        """
        Forward pass of the SparseDiT model.

        Args:
            hidden_states: Input sparse tensor.
            timestep: Timestep tensor.
            encoder_hidden_states: Condition. When ``image_attn_mode`` is
                ``'proj'`` this must be a ``(global_cond, sparse_cond)``
                pair; otherwise a single condition tensor.
            attention_kwargs: Not supported; must be ``None``.
            return_dict: Whether to return a ``SparseDiTModelOutput``
                instead of a 1-tuple.

        Returns:
            ``SparseDiTModelOutput(sample=h)`` or ``(h,)``.
        """
        assert attention_kwargs is None, "attention_kwargs not supported in SparseDiT"

        # Process input
        h = self.input_layer(hidden_states).type(self._dtype)

        # Process timestep
        t_emb = self.t_embedder(timestep)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = t_emb.type(self._dtype)

        # Process conditions
        cond = encoder_hidden_states
        if self.image_attn_mode == 'proj':
            global_cond, sparse_cond = cond
            if sparse_cond is not None:
                sparse_cond = sparse_cond.type(self._dtype)
            global_cond = global_cond.type(self._dtype)
            if self.sparse_conditions and isinstance(sparse_cond, sp.SparseTensor):
                sparse_cond = self.cond_proj(sparse_cond)
                # coords[:, 1:] drops the first coords column before embedding
                # (presumably the batch index -- TODO confirm).
                sparse_cond = sparse_cond + self.pos_embedder_cond(sparse_cond.coords[:, 1:]).type(self._dtype)
            cond = (global_cond, sparse_cond)
        elif self.sparse_conditions:
            cond = self.cond_proj(cond)
            # Cast with the explicit compute dtype (`_dtype`), as everywhere
            # else in this method, rather than the parameter-derived `dtype`.
            cond = cond + self.pos_embedder_cond(cond.coords[:, 1:]).type(self._dtype)

        # Add positional embeddings
        if self.pe_mode == "ape":
            h = h + self.pos_embedder(h.coords[:, 1:], factor=self.factor).type(self._dtype)

        # Wrapper used for activation checkpointing of a single block.
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        # Process through transformer blocks
        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                h = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    h, t_emb, cond
                )
            else:
                h = block(h, t_emb, cond)

        # Final layer norm (no learned affine) and output projection
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h.type(hidden_states.dtype))

        if not return_dict:
            return (h,)

        return SparseDiTModelOutput(sample=h)
329
+
330
+
331
@pixal3d.register("sparse-dit-denoiser")
class SparseDiTDenoiser(BaseModule):
    """
    Sparse DiT Denoiser wrapper for pixal3d framework.

    Builds a ``SparseDiTModel`` from ``Config``, optionally creates
    projection heads for conditions whose width differs from
    ``cond_channels``, and optionally loads a checkpoint.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Model architecture
        resolution: int = 64
        in_channels: int = 16
        model_channels: int = 1024
        cond_channels: int = 1024
        out_channels: int = 16
        num_blocks: int = 24
        num_heads: int = 32
        num_kv_heads: int = 2
        compression_block_size: int = 4
        selection_block_size: int = 8
        topk: int = 32
        compression_version: str = 'v2'
        mlp_ratio: float = 4.0
        pe_mode: str = "ape"
        use_fp16: bool = True
        use_checkpoint: bool = True
        qk_rms_norm: bool = True
        qk_rms_norm_cross: bool = False
        sparse_conditions: bool = True
        factor: float = 1.0
        window_size: int = 8
        use_shift: bool = True

        # Condition settings
        use_visual_condition: bool = True
        visual_condition_dim: int = 1024
        use_caption_condition: bool = False
        caption_condition_dim: int = 1024
        use_label_condition: bool = False
        label_condition_dim: int = 1024

        # Training settings
        pretrained_model_name_or_path: Optional[str] = None

        image_attn_mode: Optional[str] = 'cross'
        load_ckpt: bool = True
        version: Optional[str] = 'V10'

    cfg: Config

    def configure(self) -> None:
        """Configure the SparseDiT model."""
        cfg = self.cfg

        # Create the core SparseDiT model
        self.dit_model = SparseDiTModel(
            resolution=cfg.resolution,
            in_channels=cfg.in_channels,
            model_channels=cfg.model_channels,
            cond_channels=cfg.cond_channels,
            out_channels=cfg.out_channels,
            num_blocks=cfg.num_blocks,
            num_heads=cfg.num_heads,
            num_kv_heads=cfg.num_kv_heads,
            compression_block_size=cfg.compression_block_size,
            selection_block_size=cfg.selection_block_size,
            topk=cfg.topk,
            compression_version=cfg.compression_version,
            mlp_ratio=cfg.mlp_ratio,
            pe_mode=cfg.pe_mode,
            use_fp16=cfg.use_fp16,
            use_checkpoint=cfg.use_checkpoint,
            # Forward the QK-norm switches: they are declared in Config, so
            # overriding them must actually reach the model. The Config
            # defaults equal the model defaults, so default behaviour is
            # unchanged.
            qk_rms_norm=cfg.qk_rms_norm,
            qk_rms_norm_cross=cfg.qk_rms_norm_cross,
            sparse_conditions=cfg.sparse_conditions,
            factor=cfg.factor,
            window_size=cfg.window_size,
            use_shift=cfg.use_shift,
            image_attn_mode=cfg.image_attn_mode,
            load_ckpt=cfg.load_ckpt,
            version=cfg.version,
        )

        # Condition projectors: only created when the condition width differs
        # from cond_channels. They are not applied in forward() below.
        if cfg.use_visual_condition and cfg.visual_condition_dim != cfg.cond_channels:
            self.proj_visual_condition = nn.Sequential(
                nn.RMSNorm(cfg.visual_condition_dim),
                nn.Linear(cfg.visual_condition_dim, cfg.cond_channels),
            )

        if cfg.use_caption_condition and cfg.caption_condition_dim != cfg.cond_channels:
            self.proj_caption_condition = nn.Sequential(
                nn.RMSNorm(cfg.caption_condition_dim),
                nn.Linear(cfg.caption_condition_dim, cfg.cond_channels),
            )

        if cfg.use_label_condition and cfg.label_condition_dim != cfg.cond_channels:
            self.proj_label_condition = nn.Sequential(
                nn.RMSNorm(cfg.label_condition_dim),
                nn.Linear(cfg.label_condition_dim, cfg.cond_channels),
            )

        # Load pretrained weights if specified
        if cfg.pretrained_model_name_or_path:
            print(f"Loading pretrained SparseDiT model from {cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(
                cfg.pretrained_model_name_or_path,
                map_location="cpu",
                weights_only=True,
            )
            # Some checkpoints nest the weights under "state_dict".
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            self.load_state_dict(ckpt, strict=True)

    def forward(
        self,
        x: Any,  # sp.SparseTensor
        t: torch.Tensor,
        cond: Optional[Any] = None,
    ):
        """
        Forward pass of the denoiser.

        Args:
            x: Input sparse tensor, passed to the model as ``hidden_states``.
            t: Timestep tensor, passed as ``timestep``.
            cond: Condition, passed unchanged as ``encoder_hidden_states``.

        Returns:
            Whatever ``SparseDiTModel.forward`` returns with its default
            ``return_dict=True`` (a ``SparseDiTModelOutput``).
        """
        output = self.dit_model(
            hidden_states=x,
            timestep=t,
            encoder_hidden_states=cond,
        )

        return output
468
+
469
+
pixal3d/modules/__pycache__/norm.cpython-310.pyc ADDED
Binary file (1.43 kB). View file
 
pixal3d/modules/__pycache__/spatial.cpython-310.pyc ADDED
Binary file (2.49 kB). View file
 
pixal3d/modules/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.55 kB). View file
 
pixal3d/modules/attention/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ BACKEND = 'flash_attn'
3
+ DEBUG = False
4
+
5
def __from_env():
    """
    Override the module-level BACKEND / DEBUG settings from the environment.

    ``ATTN_BACKEND`` is honoured only when it names one of the supported
    backends; any other value is ignored. ``ATTN_DEBUG`` enables debug mode
    only when it equals ``'1'``.
    """
    import os

    global BACKEND
    global DEBUG

    requested_backend = os.environ.get('ATTN_BACKEND')
    requested_debug = os.environ.get('ATTN_DEBUG')

    # An unset variable yields None, which is never a member of the tuple.
    if requested_backend in ('xformers', 'flash_attn', 'sdpa', 'naive'):
        BACKEND = requested_backend
    if requested_debug is not None:
        DEBUG = requested_debug == '1'

    print(f"[ATTENTION] Using backend: {BACKEND}")
20
+
21
+
22
+ __from_env()
23
+
24
+
25
def set_backend(backend: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
    """
    Set the module-level attention BACKEND.

    The annotation lists all four backends accepted by the environment
    loader in this module. The value itself is not validated here.

    Note:
        ``full_attn`` copies ``BACKEND`` with ``from . import BACKEND`` and
        imports only that backend's library at import time, so calling this
        after ``full_attn`` has been imported does not change the kernel it
        dispatches to.
    """
    global BACKEND
    BACKEND = backend
28
+
29
def set_debug(debug: bool):
    """Set the module-level DEBUG flag to ``debug``."""
    # Equivalent to `global DEBUG; DEBUG = debug`.
    globals().update(DEBUG=debug)
32
+
33
+
34
+ from .full_attn import *
35
+ from .modules import *
pixal3d/modules/attention/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (940 Bytes). View file
 
pixal3d/modules/attention/__pycache__/full_attn.cpython-310.pyc ADDED
Binary file (4.15 kB). View file
 
pixal3d/modules/attention/__pycache__/modules.cpython-310.pyc ADDED
Binary file (6.21 kB). View file
 
pixal3d/modules/attention/full_attn.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from . import DEBUG, BACKEND
5
+
6
+ if BACKEND == 'xformers':
7
+ import xformers.ops as xops
8
+ elif BACKEND == 'flash_attn':
9
+ import flash_attn
10
+ elif BACKEND == 'sdpa':
11
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
12
+ elif BACKEND == 'naive':
13
+ pass
14
+ else:
15
+ raise ValueError(f"Unknown attention backend: {BACKEND}")
16
+
17
+
18
+ __all__ = [
19
+ 'scaled_dot_product_attention',
20
+ ]
21
+
22
+
23
def _naive_sdpa(q, k, v):
    """
    Reference scaled dot product attention built from plain tensor ops.

    Inputs are laid out as [N, L, H, C]; the result uses the same layout.
    """
    # Move heads in front of the sequence axis: [N, L, H, C] -> [N, H, L, C]
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
    scale_factor = 1 / math.sqrt(q.size(-1))
    scores = q @ k.transpose(-2, -1) * scale_factor
    weights = torch.softmax(scores, dim=-1)
    # Back to [N, L, H, C]
    return (weights @ v).permute(0, 2, 1, 3)
36
+
37
+
38
@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
    """
    Attention from a single packed tensor.

    Args:
        qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    """
    Attention from queries plus packed keys/values.

    Args:
        q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
        kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Attention from separate query, key and value tensors.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

def scaled_dot_product_attention(*args, **kwargs):
    """
    Dispatch scaled dot product attention to the configured BACKEND.

    Accepts one (``qkv``), two (``q``, ``kv``) or three (``q``, ``k``, ``v``)
    tensors, positionally and/or by keyword, validates their shapes and
    returns the attention output from the selected backend.
    """
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"

    # Bind every expected name: positionals first, the remainder from kwargs.
    expected_names = arg_names_dict[num_all_args]
    bound = dict(zip(expected_names, args))
    for key in expected_names[len(args):]:
        assert key in kwargs, f"Missing argument {key}"
        bound[key] = kwargs[key]

    if num_all_args == 1:
        qkv = bound['qkv']
        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
    elif num_all_args == 2:
        q, kv = bound['q'], bound['kv']
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
    else:
        q, k, v = bound['q'], bound['k'], bound['v']
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"

    # flash_attn has dedicated entry points for the packed layouts.
    if BACKEND == 'flash_attn':
        if num_all_args == 1:
            return flash_attn.flash_attn_qkvpacked_func(qkv)
        if num_all_args == 2:
            return flash_attn.flash_attn_kvpacked_func(q, kv)
        return flash_attn.flash_attn_func(q, k, v)

    if BACKEND not in ('xformers', 'sdpa', 'naive'):
        raise ValueError(f"Unknown attention module: {BACKEND}")

    # The remaining backends all take separate q, k, v tensors.
    if num_all_args == 1:
        q, k, v = qkv.unbind(dim=2)
    elif num_all_args == 2:
        k, v = kv.unbind(dim=2)

    if BACKEND == 'xformers':
        return xops.memory_efficient_attention(q, k, v)

    if BACKEND == 'sdpa':
        q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
        k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
        v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
        out = sdpa(q, k, v)         # [N, H, L, C]
        return out.permute(0, 2, 1, 3)   # [N, L, H, C]

    return _naive_sdpa(q, k, v)
pixal3d/modules/attention/modules.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .full_attn import scaled_dot_product_attention
6
+
7
+
8
class MultiHeadRMSNorm(nn.Module):
    """
    Per-head normalisation with a learned gain.

    The last axis is L2-normalised in float32, multiplied by the per-head
    gain ``gamma`` (shape ``[heads, dim]``) and by ``sqrt(dim)``, then cast
    back to the input dtype.
    """

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        unit = F.normalize(x.float(), dim=-1)
        rescaled = unit * self.gamma * self.scale
        return rescaled.to(x.dtype)
16
+
17
+
18
class RotaryPositionEmbedder(nn.Module):
    """
    Rotary position embedding over multi-channel positions.

    Each of the ``in_channels`` position channels gets
    ``hidden_size // in_channels // 2`` rotation frequencies. When these do
    not fill ``hidden_size // 2`` complex pairs, the remaining pairs are
    left unrotated.
    """

    def __init__(self, hidden_size: int, in_channels: int = 3):
        super().__init__()
        assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.freq_dim = hidden_size // in_channels // 2
        # Plain tensor attribute (not a registered buffer); it is moved to
        # the right device lazily in _get_phases.
        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** self.freqs)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        """Return unit complex phases for flat ``indices``, shape ``[len(indices), freq_dim]``."""
        self.freqs = self.freqs.to(indices.device)
        phases = torch.outer(indices, self.freqs)
        phases = torch.polar(torch.ones_like(phases), phases)
        return phases

    def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        """Rotate consecutive feature pairs of ``x`` by ``phases`` and restore ``x``'s dtype."""
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = x_complex * phases
        x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
        return x_embed

    def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            q (torch.Tensor): [..., N, D] tensor of queries
            k (torch.Tensor): [..., N, D] tensor of keys
            indices (torch.Tensor): [..., N, C] tensor of spatial positions.
                Defaults to sequence positions ``0..N-1`` as one channel.

        Returns:
            The rotated ``(q, k)`` pair, same shapes and dtypes as the inputs.
        """
        if indices is None:
            indices = torch.arange(q.shape[-2], device=q.device)
            if len(q.shape) > 2:
                indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
            # Add the trailing channel axis so the default follows the
            # documented [..., N, C] layout (C == 1). Without it the N axis
            # was folded into the phase axis below.
            indices = indices.unsqueeze(-1)

        phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
        # Measure and pad along the LAST axis. Indexing axis 1 is only
        # correct for un-batched [N, C] indices.
        missing = self.hidden_size // 2 - phases.shape[-1]
        if missing > 0:
            # Zero-angle phases (1 + 0j) leave the remaining pairs unrotated.
            phases = torch.cat([phases, torch.polar(
                torch.ones(*phases.shape[:-1], missing, device=phases.device),
                torch.zeros(*phases.shape[:-1], missing, device=phases.device)
            )], dim=-1)
        q_embed = self._rotary_embedding(q, phases)
        k_embed = self._rotary_embedding(k, phases)
        return q_embed, k_embed
61
+
62
+
63
class MultiHeadAttention(nn.Module):
    """
    Dense multi-head self- or cross-attention.

    Only ``attn_mode="full"`` is implemented; ``"windowed"`` raises
    ``NotImplementedError``. Optional per-head RMS normalisation of Q/K and
    optional rotary embedding (self-attention only) are supported.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int]=None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"

        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        # Context width falls back to the query width when not given.
        self.ctx_channels = channels if ctx_channels is None else ctx_channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        is_self = self._type == "self"
        if is_self:
            # Fused Q/K/V projection.
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            # Queries come from x, keys/values from the context.
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            # NOTE(review): the embedder is sized with `channels`, while
            # forward() applies it to tensors split per head -- confirm.
            self.rope = RotaryPositionEmbedder(channels)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply attention to ``x`` ([B, L, C]).

        ``context`` ([B, Lkv, ctx_channels]) is read only for cross-attention;
        ``indices`` is forwarded to the rotary embedder only for
        self-attention with ``use_rope``. Returns a [B, L, C] tensor.
        """
        B, L, C = x.shape
        heads = self.num_heads

        if self._type != "self":
            # Cross-attention path.
            Lkv = context.shape[1]
            q = self.to_q(x).reshape(B, L, heads, -1)
            kv = self.to_kv(context).reshape(B, Lkv, 2, heads, -1)
            if not self.qk_rms_norm:
                h = scaled_dot_product_attention(q, kv)
            else:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
        else:
            # Self-attention path.
            qkv = self.to_qkv(x).reshape(B, L, 3, heads, -1)
            if self.use_rope:
                q, k, v = qkv.unbind(dim=2)
                q, k = self.rope(q, k, indices)
                qkv = torch.stack([q, k, v], dim=2)
            if self.attn_mode == "full":
                if not self.qk_rms_norm:
                    h = scaled_dot_product_attention(qkv)
                else:
                    q, k, v = qkv.unbind(dim=2)
                    q = self.q_rms_norm(q)
                    k = self.k_rms_norm(k)
                    h = scaled_dot_product_attention(q, k, v)
            elif self.attn_mode == "windowed":
                raise NotImplementedError("Windowed attention is not yet implemented")

        # Merge heads and project out.
        merged = h.reshape(B, L, -1)
        return self.to_out(merged)
147
+
148
+
149
class ProjectAttention(nn.Module):
    """
    Combine cross-attention over a global context with a projected context.

    ``context`` is indexed as a pair: element 0 is fed to the wrapped
    cross-attention block together with ``x``; element 1 is added to that
    result, and ``x`` is added on top as a residual.
    """

    def __init__(self, cross_attn_block: nn.Module):
        super().__init__()
        self.cross_attn_block = cross_attn_block
        # Stored but not read anywhere in this class.
        self.global_token_length = 5

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        global_part, projected_part = context[0], context[1]
        attended = self.cross_attn_block(x, global_part)
        fused = projected_part + attended
        return fused + x
163
+
164
+
pixal3d/modules/norm.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class LayerNorm32(nn.LayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass.

    The result is cast back to the dtype of the input.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(input_dtype)
8
+
9
+
10
class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm that runs in float32.

    The input is upcast with ``.float()`` before normalisation and the
    result is cast back to the input's dtype.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(input_dtype)
16
+
17
+
18
class ChannelLayerNorm32(LayerNorm32):
    """
    LayerNorm32 applied over axis 1 of the input.

    Axis 1 is moved to the end, the parent layer norm is applied, and the
    original axis order is restored.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ndim = x.dim()
        to_last = (0, *range(2, ndim), 1)
        to_second = (0, ndim - 1, *range(1, ndim - 1))
        moved = x.permute(*to_last).contiguous()
        normed = super().forward(moved)
        return normed.permute(*to_second).contiguous()
25
+
pixal3d/modules/sparse/__init__.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ BACKEND = 'torchsparse'
4
+ DEBUG = False
5
+ ATTN = 'flash_attn'
6
+
7
def __from_env():
    """
    Override BACKEND / DEBUG / ATTN from environment variables.

    * ``SPARSE_BACKEND``: used only if it is 'spconv' or 'torchsparse'.
    * ``SPARSE_DEBUG``: debug is on only when the value is '1'.
    * ``SPARSE_ATTN_BACKEND`` (falling back to ``ATTN_BACKEND``): used only
      if it is 'xformers' or 'flash_attn'.

    Unrecognised values are ignored and the current settings are kept.
    """
    import os

    global BACKEND
    global DEBUG
    global ATTN

    env = os.environ
    sparse_backend = env.get('SPARSE_BACKEND')
    sparse_debug = env.get('SPARSE_DEBUG')
    # The sparse-specific variable wins; the generic one is the fallback.
    sparse_attn = env.get('SPARSE_ATTN_BACKEND', env.get('ATTN_BACKEND'))

    if sparse_backend in ('spconv', 'torchsparse'):
        BACKEND = sparse_backend
    if sparse_debug is not None:
        DEBUG = sparse_debug == '1'
    if sparse_attn in ('xformers', 'flash_attn'):
        ATTN = sparse_attn

    print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
28
+
29
+
30
+ __from_env()
31
+
32
+
33
def set_backend(backend: Literal['spconv', 'torchsparse']):
    """Set the module-level sparse BACKEND to ``backend`` (not validated)."""
    # Equivalent to `global BACKEND; BACKEND = backend`.
    globals().update(BACKEND=backend)
36
+
37
def set_debug(debug: bool):
    """Set the module-level sparse DEBUG flag to ``debug``."""
    # Equivalent to `global DEBUG; DEBUG = debug`.
    globals().update(DEBUG=debug)
40
+
41
def set_attn(attn: Literal['xformers', 'flash_attn']):
    """Set the module-level sparse attention backend ATTN to ``attn`` (not validated)."""
    # Equivalent to `global ATTN; ATTN = attn`.
    globals().update(ATTN=attn)
44
+
45
+
46
+ import importlib
47
+
48
+ __attributes = {
49
+ 'SparseTensor': 'basic',
50
+ 'sparse_batch_broadcast': 'basic',
51
+ 'sparse_batch_op': 'basic',
52
+ 'sparse_cat': 'basic',
53
+ 'sparse_unbind': 'basic',
54
+ 'SparseGroupNorm': 'norm',
55
+ 'SparseLayerNorm': 'norm',
56
+ 'SparseGroupNorm32': 'norm',
57
+ 'SparseLayerNorm32': 'norm',
58
+ 'SparseSigmoid': 'nonlinearity',
59
+ 'SparseReLU': 'nonlinearity',
60
+ 'SparseSiLU': 'nonlinearity',
61
+ 'SparseGELU': 'nonlinearity',
62
+ 'SparseTanh': 'nonlinearity',
63
+ 'SparseActivation': 'nonlinearity',
64
+ 'SparseLinear': 'linear',
65
+ 'sparse_scaled_dot_product_attention': 'attention',
66
+ 'SerializeMode': 'attention',
67
+ 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
68
+ 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
69
+ 'SparseMultiHeadAttention': 'attention',
70
+ 'SparseConv3d': 'conv',
71
+ 'SparseInverseConv3d': 'conv',
72
+ 'sparseconv3d_func': 'conv',
73
+ 'SparseDownsample': 'spatial',
74
+ 'SparseUpsample': 'spatial',
75
+ 'SparseSubdivide' : 'spatial'
76
+ }
77
+
78
+ __submodules = ['transformer']
79
+
80
+ __all__ = list(__attributes.keys()) + __submodules
81
+
82
def __getattr__(name):
    """
    Lazily resolve public names of this package on first access.

    Names listed in ``__attributes`` are imported from their submodule;
    names in ``__submodules`` are imported as modules. The result is cached
    in the module globals. Unknown names raise ``AttributeError``.
    """
    namespace = globals()
    if name in namespace:
        return namespace[name]

    if name in __attributes:
        # Import the owning submodule and pull the attribute out of it.
        owner = importlib.import_module(f".{__attributes[name]}", __name__)
        namespace[name] = getattr(owner, name)
    elif name in __submodules:
        namespace[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return namespace[name]
94
+
95
+
96
+ # For Pylance
97
+ if __name__ == '__main__':
98
+ from .basic import *
99
+ from .norm import *
100
+ from .nonlinearity import *
101
+ from .linear import *
102
+ from .attention import *
103
+ from .conv import *
104
+ from .spatial import *
105
+ import transformer
pixal3d/modules/sparse/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.64 kB). View file
 
pixal3d/modules/sparse/__pycache__/basic.cpython-310.pyc ADDED
Binary file (15.2 kB). View file
 
pixal3d/modules/sparse/__pycache__/linear.cpython-310.pyc ADDED
Binary file (884 Bytes). View file
 
pixal3d/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc ADDED
Binary file (2.17 kB). View file
 
pixal3d/modules/sparse/__pycache__/norm.cpython-310.pyc ADDED
Binary file (2.7 kB). View file
 
pixal3d/modules/sparse/__pycache__/spatial.cpython-310.pyc ADDED
Binary file (5.11 kB). View file
 
pixal3d/modules/sparse/attention/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .full_attn import *
2
+ from .serialized_attn import *
3
+ from .windowed_attn import *
4
+ from .modules import *
5
+ from .spatial_sparse_attention.module.spatial_sparse_attention import SpatialSparseAttention
pixal3d/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (391 Bytes). View file
 
pixal3d/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc ADDED
Binary file (7.3 kB). View file
 
pixal3d/modules/sparse/attention/__pycache__/modules.cpython-310.pyc ADDED
Binary file (5.86 kB). View file