diff --git a/README.md b/README.md index c7fc018ba10a6180cf40c73c73dfcf600850bbaf..27d1eda7d75b0f2728e2ad9d68e9bec374855c0c 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,12 @@ --- -title: Pixal3D D -emoji: 👁 -colorFrom: indigo -colorTo: blue +title: Pixal3D-D +emoji: 🎨 +colorFrom: blue +colorTo: purple sdk: gradio -sdk_version: 6.14.0 +sdk_version: 5.29.0 app_file: app.py pinned: false license: apache-2.0 +extra_gated_eu_disallowed: true --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc25438162208a4730f6e64f41e01aed232e024 --- /dev/null +++ b/app.py @@ -0,0 +1,311 @@ +""" +Pixal3D Gradio App +Upload an image and generate a 3D mesh. Supports both automatic (MoGe) and fixed camera parameters. +""" + +import os +os.environ["no_proxy"] = os.environ.get("no_proxy", "") + ",localhost,127.0.0.1" + +import torch +import tempfile +import numpy as np +from PIL import Image +from torchvision import transforms + +import gradio as gr + +from pixal3dpipeline2stage import Pixal3DPipeline2Stage +from pixal3dpipeline import Pixal3DPipeline + + +import trimesh +from trimesh.visual.material import PBRMaterial +from trimesh.transformations import rotation_matrix +# Static files directory for model viewer +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +SAVE_DIR = os.path.join(CURRENT_DIR, "gradio_outputs") + +# Global pipeline reference +pipeline = None +rmbg = None + + +def load_pipeline(ckpt_dir="./ckpt", repo_id="Pixal3D/Pixal3D"): + """Load all weights at startup.""" + global pipeline, rmbg + print("Loading Pixal3D 2-Stage pipeline (with MoGe + dense_check)...") + pipeline = Pixal3DPipeline2Stage.from_pretrained( + ckpt_dir=ckpt_dir, + repo_id=repo_id, + use_moge=True, + use_dense_check=True, + ) + print("Pipeline loaded!") + print("Loading BiRefNet for background removal...") + from transformers import AutoModelForImageSegmentation + birefnet_model = AutoModelForImageSegmentation.from_pretrained( + 'ZhengPeng7/BiRefNet', + trust_remote_code=True, + ).to("cuda:0") + birefnet_model.eval() + rmbg = birefnet_model + print("BiRefNet loaded!") + + +def remove_background(image_np): + """Use BiRefNet to remove background and add alpha channel. + Input: numpy array (H, W, 3) RGB + Output: numpy array (H, W, 4) RGBA + """ + pil_img = Image.fromarray(image_np[:, :, :3]).convert('RGB') + image_size = (1024, 1024) + transform_image = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + input_tensor = transform_image(pil_img).unsqueeze(0).to("cuda:0") + with torch.no_grad(): + preds = rmbg(input_tensor)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(pil_img.size) + mask = np.array(mask) + rgba = np.concatenate([np.array(pil_img), mask[..., None]], axis=-1) + return rgba + + +def preprocess_image(image, use_rmbg): + """Step 1: process image (background removal or use original), return immediately. + + use_rmbg=True: run BiRefNet to remove background and generate RGBA + use_rmbg=False: directly use the original image (RGB or RGBA), skip background removal + """ + if image is None: + return None + + if use_rmbg: + # Run background removal + if rmbg is None: + gr.Warning("Background removal model not loaded.") + return None + processed = remove_background(image) + else: + # Directly use original image, no background removal + processed = image + + os.makedirs("./gradio_outputs", exist_ok=True) + Image.fromarray(processed).save("./gradio_outputs/processed.png") + return processed + + +def infer_mesh( + processed, + use_fixed_camera, + camera_angle_x, + mesh_scale, + dense_steps, + dense_guidance_scale, + dense_seed, + sparse_512_steps, + sparse_512_guidance_scale, + sparse_1024_steps, + sparse_1024_guidance_scale, + sparse_seed, + dense_threshold, + mc_threshold, +): + """Step 2: run 3D inference on the already-processed image.""" + if processed is None or pipeline is None: + return None, None + + tmp_input = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + Image.fromarray(processed).save(tmp_input.name) + input_path = tmp_input.name + + try: + if use_fixed_camera: + mesh = Pixal3DPipeline.infer( + pipeline, + image=input_path, + camera_angle_x=camera_angle_x, + mesh_scale=mesh_scale, + dense_steps=int(dense_steps), + dense_guidance_scale=dense_guidance_scale, + dense_seed=int(dense_seed), + sparse_512_steps=int(sparse_512_steps), + sparse_512_guidance_scale=sparse_512_guidance_scale, + sparse_1024_steps=int(sparse_1024_steps), + sparse_1024_guidance_scale=sparse_1024_guidance_scale, + sparse_seed=int(sparse_seed), + dense_threshold=dense_threshold, + mc_threshold=mc_threshold, + ) + else: + mesh = pipeline.infer( + image=input_path, + mesh_scale=mesh_scale, + optimize_mesh_scale=True, + target_padding=3, + max_optim_iterations=2, + dense_steps=int(dense_steps), + dense_guidance_scale=dense_guidance_scale, + dense_seed=int(dense_seed), + sparse_512_steps=int(sparse_512_steps), + sparse_512_guidance_scale=sparse_512_guidance_scale, + sparse_1024_steps=int(sparse_1024_steps), + sparse_1024_guidance_scale=sparse_1024_guidance_scale, + sparse_seed=int(sparse_seed), + dense_threshold=dense_threshold, + mc_threshold=mc_threshold, + ) + + ply_file = tempfile.NamedTemporaryFile(suffix=".ply", delete=False) + glb_file = tempfile.NamedTemporaryFile(suffix=".glb", delete=False) + ply_path = ply_file.name + glb_path = glb_file.name + ply_file.close() + glb_file.close() + mesh.export(ply_path) + # Export GLB with PBR material (same as hunyuan_app) + + material = PBRMaterial(baseColorFactor=[102, 102, 102, 255]) + clean_mesh = trimesh.Trimesh(mesh.vertices, mesh.faces) + clean_mesh.visual = trimesh.visual.TextureVisuals(material=material) + # Rotate mesh to desired view angle (only X rotation needed) + rot_x = rotation_matrix(np.radians(-90), [1, 0, 0]) + clean_mesh.apply_transform(rot_x) + clean_mesh.export(glb_path) + + return glb_path, ply_path + + except Exception as e: + import traceback + traceback.print_exc() + return None, None + finally: + os.unlink(input_path) + + +def build_ui(): + # Custom CSS to hide the download button in Model3D + custom_css = """ + #model3d-viewer button[aria-label="下载"], + #model3d-viewer button[aria-label="Download"], + #model3d-viewer button[title="下载"], + #model3d-viewer button[title="Download"] { + display: none !important; + } + """ + + with gr.Blocks(title="Pixal3D", theme=gr.themes.Soft(), css=custom_css) as demo: + gr.Markdown("# Pixal3D: Pixel-Aligned 3D Generation from Images") + + with gr.Row(): + # Left column: input (scale=1) + with gr.Column(scale=1): + image_input = gr.Image(label="Input Image", type="numpy", image_mode=None) + + processed_image = gr.Image( + label="Processed Image", + image_mode="RGBA", + type="numpy", + interactive=False, + ) + + use_rmbg = gr.Checkbox( + label="Remove Background", + value=True, + info="Checked: auto remove background via BiRefNet. Unchecked: use original image directly.", + ) + + use_fixed_camera = gr.Checkbox( + label="Use Fixed Camera Parameters", + value=False, + info="If checked, use manually set FOV/distance/mesh_scale instead of MoGe auto-estimation.", + ) + + with gr.Group(visible=False) as fixed_camera_group: + gr.Markdown("### Camera Parameters (fixed mode)") + camera_angle_x = gr.Number(value=0.2, label="camera_angle_x (rad)", step=0.01) + + with gr.Group(): + gr.Markdown("### Mesh Scale") + mesh_scale = gr.Number(value=0.5, label="mesh_scale", step=0.01, + info="Initial mesh scale. Fixed mode default: 0.9, Auto mode default: 0.5") + + with gr.Accordion("Advanced Inference Parameters", open=False): + dense_steps = gr.Number(value=50, label="Dense Steps", step=1, precision=0) + dense_guidance_scale = gr.Number(value=7.0, label="Dense Guidance Scale", step=0.1) + dense_seed = gr.Number(value=0, label="Dense Seed", step=1, precision=0) + sparse_512_steps = gr.Number(value=30, label="Sparse 512 Steps", step=1, precision=0) + sparse_512_guidance_scale = gr.Number(value=7.0, label="Sparse 512 Guidance Scale", step=0.1) + sparse_1024_steps = gr.Number(value=15, label="Sparse 1024 Steps", step=1, precision=0) + sparse_1024_guidance_scale = gr.Number(value=7.0, label="Sparse 1024 Guidance Scale", step=0.1) + sparse_seed = gr.Number(value=0, label="Sparse Seed", step=1, precision=0) + dense_threshold = gr.Number(value=0.1, label="Dense Threshold", step=0.01) + mc_threshold = gr.Number(value=0.2, label="MC Threshold", step=0.01) + + run_btn = gr.Button("Generate 3D Mesh", variant="primary", size="lg") + + # Right column: output (scale=2) + with gr.Column(scale=2): + model_viewer = gr.Model3D(label="3D Mesh Preview", interactive=False, clear_color=[1.0, 1.0, 1.0, 1.0], elem_id="model3d-viewer") + output_file = gr.File(label="Download .ply") + + # Toggle fixed camera group visibility and mesh_scale default + def on_toggle_fixed(use_fixed): + new_scale = 0.9 if use_fixed else 0.5 + return gr.update(visible=use_fixed), gr.update(value=new_scale) + + use_fixed_camera.change( + fn=on_toggle_fixed, + inputs=[use_fixed_camera], + outputs=[fixed_camera_group, mesh_scale], + ) + + # Step 1: preprocess image → show processed image immediately + # Step 2: run 3D inference → show mesh and download + run_btn.click( + fn=preprocess_image, + inputs=[image_input, use_rmbg], + outputs=[processed_image], + ).then( + fn=infer_mesh, + inputs=[ + processed_image, + use_fixed_camera, + camera_angle_x, + mesh_scale, + dense_steps, + dense_guidance_scale, + dense_seed, + sparse_512_steps, + sparse_512_guidance_scale, + sparse_1024_steps, + sparse_1024_guidance_scale, + sparse_seed, + dense_threshold, + mc_threshold, + ], + outputs=[model_viewer, output_file], + ) + + demo.queue(api_open=False) + return demo + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--repo_id", type=str, default="TencentARC/Pixal3D-D") + args = parser.parse_args() + + load_pipeline(repo_id=args.repo_id) + + demo = build_ui() + demo.launch( + server_name="127.0.0.1", + share=True, + ) diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..f31f1afb09a628ccaf92c924675a9c67164740be --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +libsparsehash-dev \ No newline at end of file diff --git a/pixal3d/__init__.py b/pixal3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4d0a7fcc21a6ac7a6c2d870eef6f2e873fc8de --- /dev/null +++ b/pixal3d/__init__.py @@ -0,0 +1,44 @@ +import importlib + +__modules__ = {} + + +def register(name): + def decorator(cls): + # Allow re-registration for checkpoint loading compatibility + # When torch.load triggers module re-import, the same class may be registered again + __modules__[name] = cls + return cls + + return decorator + + +def find(name): + if name in __modules__: + return __modules__[name] + else: + try: + module_string = ".".join(name.split(".")[:-1]) + cls_name = name.split(".")[-1] + module = importlib.import_module(module_string, package=None) + return getattr(module, cls_name) + except Exception as e: + raise ValueError(f"Module {name} not found!") + + +### grammar sugar for logging utilities ### +import logging + +logger = logging.getLogger("pixal3d") + + +def debug(*args, **kwargs): + logger.debug(*args, **kwargs) + + +def info(*args, **kwargs): + logger.info(*args, **kwargs) + + +def warn(*args, **kwargs): + logger.warning(*args, **kwargs) diff --git a/pixal3d/__pycache__/__init__.cpython-310.pyc b/pixal3d/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ce34b0aeafc8042f153c915b7b9436f4e5c809c Binary files /dev/null and b/pixal3d/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/models/__init__.py b/pixal3d/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6db3f994c209c3329de8736aac109a84a49c2654 --- /dev/null +++ b/pixal3d/models/__init__.py @@ -0,0 +1 @@ +from . import conditional_encoders, transformers diff --git a/pixal3d/models/__pycache__/__init__.cpython-310.pyc b/pixal3d/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8665f613e7b0d45c83fc5f71f06d33b7daea3af Binary files /dev/null and b/pixal3d/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/models/autoencoders/__pycache__/base.cpython-310.pyc b/pixal3d/models/autoencoders/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f15abc4c110be65bac8379638db2493f6037421 Binary files /dev/null and b/pixal3d/models/autoencoders/__pycache__/base.cpython-310.pyc differ diff --git a/pixal3d/models/autoencoders/__pycache__/decoder.cpython-310.pyc b/pixal3d/models/autoencoders/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cff0dc84057c61ee1e8ff5f662643a309de11b2 Binary files /dev/null and b/pixal3d/models/autoencoders/__pycache__/decoder.cpython-310.pyc differ diff --git a/pixal3d/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc b/pixal3d/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e4fcd4ec398e00813e58606af0eadccc98eb312 Binary files /dev/null and b/pixal3d/models/autoencoders/__pycache__/dense_vae.cpython-310.pyc differ diff --git a/pixal3d/models/autoencoders/__pycache__/distributions.cpython-310.pyc b/pixal3d/models/autoencoders/__pycache__/distributions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92743ac296768ffc1dffecb60c9b092f1d15bd74 Binary files /dev/null and b/pixal3d/models/autoencoders/__pycache__/distributions.cpython-310.pyc differ diff --git a/pixal3d/models/autoencoders/__pycache__/encoder.cpython-310.pyc b/pixal3d/models/autoencoders/__pycache__/encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..489272eddaecb538d0b82c9a1cfbe006f959b7fe Binary files /dev/null and b/pixal3d/models/autoencoders/__pycache__/encoder.cpython-310.pyc differ diff --git a/pixal3d/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc b/pixal3d/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef07be42c31b353120841224317bfdd48a71a8e6 Binary files /dev/null and b/pixal3d/models/autoencoders/__pycache__/ss_vae.cpython-310.pyc differ diff --git a/pixal3d/models/autoencoders/base.py b/pixal3d/models/autoencoders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ffeea9be7a588d0de1de9aae07ba00dbca99e1 --- /dev/null +++ b/pixal3d/models/autoencoders/base.py @@ -0,0 +1,118 @@ +from typing import * +import torch +import torch.nn as nn + +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer import SparseTransformerBlock + + +def block_attn_config(self): + """ + Return the attention configuration of the model. + """ + for i in range(self.num_blocks): + if self.attn_mode == "shift_window": + yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_sequence": + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_order": + yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] + elif self.attn_mode == "full": + yield "full", None, None, None, None + elif self.attn_mode == "swin": + yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None + + +class SparseTransformerBase(nn.Module): + """ + Sparse Transformer without output layers. + Serve as the base class for encoder and decoder. + """ + def __init__( + self, + in_channels: int, + model_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.window_size = window_size + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.attn_mode = attn_mode + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.qk_rms_norm = qk_rms_norm + self.dtype = torch.float16 if use_fp16 else torch.float32 + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + self.input_layer = sp.SparseLinear(in_channels, model_channels) + self.blocks = nn.ModuleList([ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) + ]) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + # self.blocks.apply(convert_module_to_f16) + self.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, factor: float = None) -> sp.SparseTensor: + h = self.input_layer(x) + if self.pe_mode == "ape": + h = h + self.pos_embedder(x.coords[:, 1:], factor) + h = h.type(self.dtype) + for block in self.blocks: + h = block(h) + return h \ No newline at end of file diff --git a/pixal3d/models/autoencoders/decoder.py b/pixal3d/models/autoencoders/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..209cccf24f80dfcd962610c31811dadfeaf50812 --- /dev/null +++ b/pixal3d/models/autoencoders/decoder.py @@ -0,0 +1,353 @@ +from typing import * +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from .base import SparseTransformerBase + + +class SparseSubdivideBlock3d(nn.Module): + + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + + self.act_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, padding=1), + sp.SparseSiLU() + ) + + self.sub = sp.SparseSubdivide() + + self.out_layers = nn.Sequential( + sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1), + sp.SparseSiLU(), + ) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.act_layers(x) + h = self.sub(h) + h = self.out_layers(h) + return h + + def forward(self, x: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseSDFDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + out_channels: int = 1, + chunk_size: int = 1, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self.out_channels = out_channels + self.chunk_size = chunk_size + self.upsample = nn.ModuleList([ + SparseSubdivideBlock3d( + channels=model_channels, + out_channels=model_channels // 4, + use_checkpoint=use_checkpoint, + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + out_channels=model_channels // 8, + use_checkpoint=use_checkpoint, + ), + SparseSubdivideBlock3d( + channels=model_channels // 8, + out_channels=model_channels // 16, + use_checkpoint=use_checkpoint, + ) + ]) + + self.out_layer = sp.SparseLinear(model_channels // 16, self.out_channels) + self.out_active = sp.SparseTanh() + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + super().convert_to_fp16() + self.upsample.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + super().convert_to_fp32() + self.upsample.apply(convert_module_to_f32) + + @torch.no_grad() + def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4): + + sub_resolution = self.resolution // chunk_size + upsample_ratio = 8 # hard-coded here + assert sub_resolution % padding == 0 + out = [] + + for i in range(chunk_size): + for j in range(chunk_size): + for k in range(chunk_size): + # Calculate padded boundaries + start_x = max(0, i * sub_resolution - padding) + end_x = min((i + 1) * sub_resolution + padding, self.resolution) + start_y = max(0, j * sub_resolution - padding) + end_y = min((j + 1) * sub_resolution + padding, self.resolution) + start_z = max(0, k * sub_resolution - padding) + end_z = min((k + 1) * sub_resolution + padding, self.resolution) + + # Store original (unpadded) boundaries for later cropping + orig_start_x = i * sub_resolution + orig_end_x = (i + 1) * sub_resolution + orig_start_y = j * sub_resolution + orig_end_y = (j + 1) * sub_resolution + orig_start_z = k * sub_resolution + orig_end_z = (k + 1) * sub_resolution + + mask = torch.logical_and( + torch.logical_and( + torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x), + torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y) + ), + torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z) + ) + + if mask.sum() > 0: + # Get the coordinates and shift them to local space + coords = x.coords[mask].clone() + # Shift to local coordinates + coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z], + device=coords.device).view(1, 3) + + chunk_tensor = sp.SparseTensor(x.feats[mask], coords) + # Store the boundaries and offsets as metadata for later reconstruction + chunk_tensor.bounds = { + 'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)), + 'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction + } + out.append(chunk_tensor) + + del mask + torch.cuda.empty_cache() + return out + + @torch.no_grad() + def split_single_chunk(self, x: sp.SparseTensor, chunk_size=4, padding=4): + sub_resolution = self.resolution // chunk_size + upsample_ratio = 8 # hard-coded here + assert sub_resolution % padding == 0 + + mask_sum = -1 + while mask_sum < 1: + orig_start_x = random.randint(0, self.resolution - sub_resolution) + orig_end_x = orig_start_x + sub_resolution + orig_start_y = random.randint(0, self.resolution - sub_resolution) + orig_end_y = orig_start_y + sub_resolution + orig_start_z = random.randint(0, self.resolution - sub_resolution) + orig_end_z = orig_start_z + sub_resolution + start_x = max(0, orig_start_x - padding) + end_x = min(orig_end_x + padding, self.resolution) + start_y = max(0, orig_start_y - padding) + end_y = min(orig_end_y + padding, self.resolution) + start_z = max(0, orig_start_z - padding) + end_z = min(orig_end_z + padding, self.resolution) + + mask_ori = torch.logical_and( + torch.logical_and( + torch.logical_and(x.coords[:, 1] >= orig_start_x, x.coords[:, 1] < orig_end_x), + torch.logical_and(x.coords[:, 2] >= orig_start_y, x.coords[:, 2] < orig_end_y) + ), + torch.logical_and(x.coords[:, 3] >= orig_start_z, x.coords[:, 3] < orig_end_z) + ) + mask_sum = mask_ori.sum() + + # Store the boundaries and offsets as metadata for later reconstruction + bounds = { + 'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)), + 'start': (start_x, end_x, start_y, end_y, start_z, end_z), + 'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction + } + return bounds + + def forward_single_chunk(self, x: sp.SparseTensor, padding=4): + + bounds = self.split_single_chunk(x, self.chunk_size, padding=padding) + + start_x, end_x, start_y, end_y, start_z, end_z = bounds['start'] + mask = torch.logical_and( + torch.logical_and( + torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x), + torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y) + ), + torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z) + ) + + # Shift to local coordinates + coords = x.coords.clone() + coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z], + device=coords.device).view(1, 3) + + chunk = sp.SparseTensor(x.feats[mask], coords[mask]) + + chunk_result = self.upsamples(chunk) + + coords = chunk_result.coords.clone() + + # Restore global coordinates + offsets = torch.tensor(bounds['offsets'], + device=coords.device).view(1, 3) + coords[:, 1:] = coords[:, 1:] + offsets + + # Filter points within original bounds + original = bounds['original'] + within_bounds = torch.logical_and( + torch.logical_and( + torch.logical_and( + coords[:, 1] >= original[0], + coords[:, 1] < original[1] + ), + torch.logical_and( + coords[:, 2] >= original[2], + coords[:, 2] < original[3] + ) + ), + torch.logical_and( + coords[:, 3] >= original[4], + coords[:, 3] < original[5] + ) + ) + + final_coords = coords[within_bounds] + final_feats = chunk_result.feats[within_bounds] + + return sp.SparseTensor(final_feats, final_coords) + + def upsamples(self, x, return_feat: bool = False): + dtype = x.dtype + for block in self.upsample: + x = block(x) + x = x.type(dtype) + + output = self.out_active(self.out_layer(x)) + + if return_feat: + return output, x + else: + return output + + def forward(self, x: sp.SparseTensor, factor: float = None, return_feat: bool = False): + h = super().forward(x, factor) + if self.chunk_size <= 1: + for block in self.upsample: + h = block(h) + h = h.type(x.dtype) + + if return_feat: + return self.out_active(self.out_layer(h)), h + + h = self.out_layer(h) + h = self.out_active(h) + return h + else: + if self.training: + return self.forward_single_chunk(h) + else: + batch_size = x.shape[0] + chunks = self.split_for_meshing(h, chunk_size=self.chunk_size) + all_coords, all_feats = [], [] + for chunk_idx, chunk in enumerate(chunks): + chunk_result = self.upsamples(chunk) + + for b in range(batch_size): + mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1) + if mask.numel() > 0: + coords = chunk_result.coords[mask].clone() + + # Restore global coordinates + offsets = torch.tensor(chunk.bounds['offsets'], + device=coords.device).view(1, 3) + coords[:, 1:] = coords[:, 1:] + offsets + + # Filter points within original bounds + bounds = chunk.bounds['original'] + within_bounds = torch.logical_and( + torch.logical_and( + torch.logical_and( + coords[:, 1] >= bounds[0], + coords[:, 1] < bounds[1] + ), + torch.logical_and( + coords[:, 2] >= bounds[2], + coords[:, 2] < bounds[3] + ) + ), + torch.logical_and( + coords[:, 3] >= bounds[4], + coords[:, 3] < bounds[5] + ) + ) + + if within_bounds.any(): + all_coords.append(coords[within_bounds]) + all_feats.append(chunk_result.feats[mask][within_bounds]) + + if not self.training: + torch.cuda.empty_cache() + + final_coords = torch.cat(all_coords) + final_feats = torch.cat(all_feats) + + return sp.SparseTensor(final_feats, final_coords) + \ No newline at end of file diff --git a/pixal3d/models/autoencoders/dense_vae.py b/pixal3d/models/autoencoders/dense_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..f7aa0cfdc8a307b65c3ed2746ee2e4a5f4537c45 --- /dev/null +++ b/pixal3d/models/autoencoders/dense_vae.py @@ -0,0 +1,401 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import trimesh +from skimage import measure +from ...modules.norm import GroupNorm32, ChannelLayerNorm32 +from ...modules.spatial import pixel_shuffle_3d +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from .distributions import DiagonalGaussianDistribution + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. + """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + use_checkpoint: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.use_checkpoint = use_checkpoint + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = self.out_layer(h) + + return h + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + use_checkpoint: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.use_checkpoint = use_checkpoint + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + # self.blocks.apply(convert_module_to_f16) + # self.middle_block.apply(convert_module_to_f16) + self.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = self.out_layer(h) + return h + + +class DenseShapeVAE(nn.Module): + def __init__(self, + embed_dim: int = 0, + model_channels_encoder: list = [32, 128, 512], + model_channels_decoder: list = [512, 128, 32], + num_res_blocks_encoder: int = 2, + num_res_blocks_middle_encoder: int = 2, + num_res_blocks_decoder: int = 2, + num_res_blocks_middle_decoder: int=2, + in_channels: int = 1, + out_channels: int = 1, + use_fp16: bool = False, + use_checkpoint: bool = False, + latents_scale: float = 1.0, + latents_shift: float = 0.0): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.latents_scale = latents_scale + self.latents_shift = latents_shift + + self.encoder = SparseStructureEncoder( + in_channels=in_channels, + latent_channels=embed_dim, + num_res_blocks=num_res_blocks_encoder, + channels=model_channels_encoder, + num_res_blocks_middle=num_res_blocks_middle_encoder, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + ) + + self.decoder = SparseStructureDecoder( + num_res_blocks=num_res_blocks_decoder, + num_res_blocks_middle=num_res_blocks_middle_decoder, + channels=model_channels_decoder, + latent_channels=embed_dim, + out_channels=out_channels, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + ) + + self.embed_dim = embed_dim + + def encode(self, batch, sample_posterior: bool = True): + + x = batch['dense_index'] * 2.0 - 1.0 + h = self.encoder(x) + posterior = DiagonalGaussianDistribution(h, feat_dim=1) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + return z, posterior + + def forward(self, batch): + + z, posterior = self.encode(batch) + reconst_x = self.decoder(z) + outputs = {'reconst_x': reconst_x, 'posterior': posterior} + + return outputs + + def decode_mesh(self, + latents, + voxel_resolution: int = 64, + mc_threshold: float = 0.5, + return_index: bool = False): + x = self.decoder(latents) + if return_index: + outputs = [] + for i in range(len(x)): + occ = x[i].sigmoid() + occ = (occ >= mc_threshold).float().squeeze(0) + index = occ.unsqueeze(0).nonzero() + outputs.append(index) + else: + outputs = self.dense2mesh(x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold) + + return outputs + + def dense2mesh(self, + x: torch.FloatTensor, + voxel_resolution: int = 64, + mc_threshold: float = 0.5): + + meshes = [] + for i in range(len(x)): + occ = x[i].sigmoid() + occ = (occ >= 0.1).float().squeeze(0).cpu().detach().numpy() + vertices, faces, _, _ = measure.marching_cubes( + occ, + mc_threshold, + method="lewiner", + ) + vertices = vertices / voxel_resolution * 2 - 1 + meshes.append(trimesh.Trimesh(vertices, faces)) + + return meshes diff --git a/pixal3d/models/autoencoders/distributions.py b/pixal3d/models/autoencoders/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..a702359a8d2dedcb28c2e654bb60221af7f72c8f --- /dev/null +++ b/pixal3d/models/autoencoders/distributions.py @@ -0,0 +1,51 @@ +import torch +import numpy as np +from typing import Union, List + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): + self.feat_dim = feat_dim + self.parameters = parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self): + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.mean(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=dims) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=dims) + + def nll(self, sample, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean diff --git a/pixal3d/models/autoencoders/encoder.py b/pixal3d/models/autoencoders/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..49fe567ff4dd1b62bae84c8627b27502ec4cdd09 --- /dev/null +++ b/pixal3d/models/autoencoders/encoder.py @@ -0,0 +1,133 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .base import SparseTransformerBase + + +class SparseDownBlock3d(nn.Module): + + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + num_groups: int = 32, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.down = sp.SparseDownsample(2) + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, padding=1), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1) + + self.use_checkpoint = use_checkpoint + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.act_layers(x) + h = self.down(h) + x = self.down(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h + + def forward(self, x: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseSDFEncoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + + self.input_layer1 = sp.SparseLinear(1, model_channels // 16) + + self.downsample = nn.ModuleList([ + SparseDownBlock3d( + channels=model_channels//16, + out_channels=model_channels // 8, + use_checkpoint=use_checkpoint, + ), + SparseDownBlock3d( + channels=model_channels // 8, + out_channels=model_channels // 4, + use_checkpoint=use_checkpoint, + ), + SparseDownBlock3d( + channels=model_channels // 4, + out_channels=model_channels, + use_checkpoint=use_checkpoint, + ) + ]) + + self.resolution = resolution + self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, factor: float = None): + + x = self.input_layer1(x) + for block in self.downsample: + x = block(x) + h = super().forward(x, factor) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + + return h \ No newline at end of file diff --git a/pixal3d/models/autoencoders/ss_vae.py b/pixal3d/models/autoencoders/ss_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..b32fd20767df4f90d7d5bc55b2b393676e8e47a3 --- /dev/null +++ b/pixal3d/models/autoencoders/ss_vae.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +import trimesh +from skimage import measure + +from ...modules import sparse as sp +from .encoder import SparseSDFEncoder +from .decoder import SparseSDFDecoder +from .distributions import DiagonalGaussianDistribution + + +class SparseSDFVAE(nn.Module): + def __init__(self, *, + embed_dim: int = 0, + resolution: int = 64, + model_channels_encoder: int = 512, + num_blocks_encoder: int = 4, + num_heads_encoder: int = 8, + num_head_channels_encoder: int = 64, + model_channels_decoder: int = 512, + num_blocks_decoder: int = 4, + num_heads_decoder: int = 8, + num_head_channels_decoder: int = 64, + out_channels: int = 1, + use_fp16: bool = False, + use_checkpoint: bool = False, + chunk_size: int = 1, + latents_scale: float = 1.0, + latents_shift: float = 0.0): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.resolution = resolution + self.latents_scale = latents_scale + self.latents_shift = latents_shift + + self.encoder = SparseSDFEncoder( + resolution=resolution, + in_channels=model_channels_encoder, + model_channels=model_channels_encoder, + latent_channels=embed_dim, + num_blocks=num_blocks_encoder, + num_heads=num_heads_encoder, + num_head_channels=num_head_channels_encoder, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + ) + + self.decoder = SparseSDFDecoder( + resolution=resolution, + model_channels=model_channels_decoder, + latent_channels=embed_dim, + num_blocks=num_blocks_decoder, + num_heads=num_heads_decoder, + num_head_channels=num_head_channels_decoder, + out_channels=out_channels, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + chunk_size=chunk_size, + ) + self.embed_dim = embed_dim + + def forward(self, batch): + + z, posterior = self.encode(batch) + + reconst_x = self.decoder(z) + outputs = {'reconst_x': reconst_x, 'posterior': posterior} + return outputs + + def encode(self, batch, sample_posterior: bool = True): + + feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx'] + if feat.ndim == 1: + feat = feat.unsqueeze(-1) + coords = torch.cat([batch_idx.unsqueeze(-1), xyz], dim=-1).int() + + x = sp.SparseTensor(feat, coords) + h = self.encoder(x, batch.get('factor', None)) + posterior = DiagonalGaussianDistribution(h.feats, feat_dim=1) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + z = h.replace(z) + + return z, posterior + + def decode_mesh(self, + latents, + voxel_resolution: int = 512, + mc_threshold: float = 0.2, + return_feat: bool = False, + factor: float = 1.0): + voxel_resolution = int(voxel_resolution / factor) + reconst_x = self.decoder(latents, factor=factor, return_feat=return_feat) + if return_feat: + return reconst_x + outputs = self.sparse2mesh(reconst_x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold) + + return outputs + + def sparse2mesh(self, + reconst_x: torch.FloatTensor, + voxel_resolution: int = 512, + mc_threshold: float = 0.0): + + sparse_sdf, sparse_index = reconst_x.feats.float(), reconst_x.coords + batch_size = int(sparse_index[..., 0].max().cpu().numpy() + 1) + + meshes = [] + for i in range(batch_size): + idx = sparse_index[..., 0] == i + sparse_sdf_i, sparse_index_i = sparse_sdf[idx].squeeze(-1).cpu(), sparse_index[idx][..., 1:].detach().cpu() + sdf = torch.ones((voxel_resolution, voxel_resolution, voxel_resolution)) + sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i + vertices, faces, _, _ = measure.marching_cubes( + sdf.numpy(), + mc_threshold, + method="lewiner", + ) + vertices = vertices / voxel_resolution * 2 - 1 + meshes.append(trimesh.Trimesh(vertices, faces)) + + return meshes diff --git a/pixal3d/models/conditional_encoders/__init__.py b/pixal3d/models/conditional_encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1daefabbaadb53997ef091c491108804331d4c9f --- /dev/null +++ b/pixal3d/models/conditional_encoders/__init__.py @@ -0,0 +1,2 @@ +from . import dinov2_project_grid + diff --git a/pixal3d/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc b/pixal3d/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..107c747dea5c89075f316927509616d50cdf2f23 Binary files /dev/null and b/pixal3d/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/models/conditional_encoders/__pycache__/dinov2_project_grid.cpython-310.pyc b/pixal3d/models/conditional_encoders/__pycache__/dinov2_project_grid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..750786971e35eaa35c10e4e7faafeab07eb518d2 Binary files /dev/null and b/pixal3d/models/conditional_encoders/__pycache__/dinov2_project_grid.cpython-310.pyc differ diff --git a/pixal3d/models/conditional_encoders/dinov2_project_grid.py b/pixal3d/models/conditional_encoders/dinov2_project_grid.py new file mode 100644 index 0000000000000000000000000000000000000000..a9dcf30d47461922df23494a50560faee0295aed --- /dev/null +++ b/pixal3d/models/conditional_encoders/dinov2_project_grid.py @@ -0,0 +1,750 @@ +""" +DINOv2 Project Grid Encoders +Includes single-view and multi-view DINOv2 encoders with 3D grid projection support +""" + +import random +from dataclasses import dataclass +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from diffusers.models.modeling_utils import ModelMixin + +import pixal3d +from pixal3d.utils.base import BaseModule + +# Set linear algebra backend to avoid cusolver errors +try: + torch.backends.cuda.preferred_linalg_library("cusolver") +except Exception: + pass + + +# ============================================================================= +# Base DINOv2 Encoder +# ============================================================================= + +@pixal3d.register("dinov2-encoder") +class DinoEncoder(BaseModule, ModelMixin): + """Base DINOv2 Encoder""" + + @dataclass + class Config(BaseModule.Config): + model: str = "facebookresearch/dinov2" + version: str = "dinov2_vitl14_reg" + size: int = 518 + empty_embeds_ratio: float = 0.1 + + cfg: Config + + def configure(self) -> None: + super().configure() + self.empty_embeds_ratio = self.cfg.empty_embeds_ratio + + # Load DINOv2 model + dino_model = torch.hub.load( + self.cfg.model, self.cfg.version, pretrained=True + ) + self.encoder = dino_model.eval() + + # Image preprocessing + self.transform = transforms.Compose([ + transforms.Resize( + self.cfg.size, + transforms.InterpolationMode.BILINEAR, + antialias=True + ), + transforms.CenterCrop(self.cfg.size), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ]) + + + + + def forward(self, image, image_mask=None, is_training=False): + z = self.encoder(self.transform(image), is_training=True)['x_prenorm'] + z = F.layer_norm(z, z.shape[-1:]) + + if is_training and random.random() < self.empty_embeds_ratio: + # zero out embeddings + z = z * 0 + + if image_mask is not None: + image_mask_patch = F.max_pool2d( + image_mask, kernel_size=14, stride=14 + ).squeeze(1) > 0 + return z, image_mask_patch + + return z + + +# ============================================================================= +# 3D Projection Utility Functions +# ============================================================================= + +def project_points_to_image_batch( + points_3d: torch.Tensor, + transform_matrix: torch.Tensor, + camera_angle_x: torch.Tensor, + resolution: int = 518 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D image coordinates with batch support + + Args: + points_3d: [N, 3] or [B, N, 3], 3D point coordinates (in [-1, 1] range) + transform_matrix: [B, 4, 4], batch of camera transformation matrices + camera_angle_x: [B], batch of camera horizontal FOV angles (radians) + resolution: Rendering image resolution + + Returns: + points_2d: [B, N, 2], image coordinates [x, y] + depth: [B, N], depth values + valid_mask: [B, N], mask indicating if points are within view + """ + device = points_3d.device + B = transform_matrix.shape[0] + + # Ensure inputs are torch.Tensor + if not isinstance(transform_matrix, torch.Tensor): + transform_matrix = torch.tensor( + transform_matrix, dtype=torch.float32, device=device + ) + if not isinstance(points_3d, torch.Tensor): + points_3d = torch.tensor( + points_3d, dtype=torch.float32, device=device + ) + if not isinstance(camera_angle_x, torch.Tensor): + camera_angle_x = torch.tensor( + camera_angle_x, dtype=torch.float32, device=device + ) + + # Expand points_3d to batch dimension + if points_3d.dim() == 2: + points_3d_batch = points_3d.unsqueeze(0).expand(B, -1, -1) + else: + points_3d_batch = points_3d + + N = points_3d_batch.shape[1] + + # Add homogeneous coordinates + ones = torch.ones(B, N, 1, device=device) + points_homogeneous = torch.cat([points_3d_batch, ones], dim=-1) + + # World to camera transformation + world_to_camera = torch.linalg.inv(transform_matrix) + points_camera = torch.bmm( + points_homogeneous, + world_to_camera.transpose(-2, -1) + )[..., :3] + + # Extract camera coordinates + x_cam = points_camera[..., 0] + y_cam = points_camera[..., 1] + z_cam = points_camera[..., 2] + + # Depth values + depth = -z_cam + + # Compute camera intrinsics + sensor_width = 32.0 + focal_length = 16.0 / torch.tan(camera_angle_x / 2.0) + focal_length_pixels = focal_length * resolution / sensor_width + focal_length_pixels = focal_length_pixels.unsqueeze(1) + + # Perspective projection + x_ndc = focal_length_pixels * x_cam / (-z_cam) + y_ndc = focal_length_pixels * y_cam / (-z_cam) + + # Convert to image coordinates + x_pixel = x_ndc + resolution / 2.0 + y_pixel = -y_ndc + resolution / 2.0 + + # Validity mask + valid_mask = ( + (x_pixel >= 0) & (x_pixel < resolution) & + (y_pixel >= 0) & (y_pixel < resolution) & + (depth > 0) + ) + + points_2d = torch.stack([x_pixel, y_pixel], dim=-1) + return points_2d, depth, valid_mask + + +def project_points_to_image( + points_3d: torch.Tensor, + transform_matrix: torch.Tensor, + camera_angle_x: float, + resolution: int = 512 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D image coordinates (single-view version) + + Args: + points_3d: [N, 3], 3D point coordinates + transform_matrix: [4, 4], camera transformation matrix + camera_angle_x: Camera horizontal FOV angle (radians) + resolution: Rendering image resolution + + Returns: + points_2d: [N, 2], image coordinates [x, y] + depth: [N], depth values + valid_mask: [N], mask indicating if points are within view + """ + device = points_3d.device + + if not isinstance(transform_matrix, torch.Tensor): + transform_matrix = torch.tensor( + transform_matrix, dtype=torch.float32, device=device + ) + if not isinstance(points_3d, torch.Tensor): + points_3d = torch.tensor( + points_3d, dtype=torch.float32, device=device + ) + + N = points_3d.shape[0] + points_homogeneous = torch.cat([ + points_3d, + torch.ones(N, 1, device=device) + ], dim=1) + + # World to camera transformation + camera_to_world = transform_matrix + world_to_camera = torch.linalg.inv(camera_to_world) + points_camera = torch.matmul( + points_homogeneous, + world_to_camera.T + )[:, :3] + + x_cam = points_camera[:, 0] + y_cam = points_camera[:, 1] + z_cam = points_camera[:, 2] + depth = -z_cam + + # Camera intrinsics + sensor_width = 32.0 + focal_length = 16.0 / torch.tan(torch.tensor(camera_angle_x / 2.0)) + focal_length_pixels = focal_length * resolution / sensor_width + + # Perspective projection + x_ndc = focal_length_pixels * x_cam / (-z_cam) + y_ndc = focal_length_pixels * y_cam / (-z_cam) + + # Image coordinates + x_pixel = x_ndc + resolution / 2.0 + y_pixel = -y_ndc + resolution / 2.0 + + valid_mask = ( + (x_pixel >= 0) & (x_pixel < resolution) & + (y_pixel >= 0) & (y_pixel < resolution) & + (depth > 0) + ) + + points_2d = torch.stack([x_pixel, y_pixel], dim=1) + return points_2d, depth, valid_mask + + +def sample_features( + fmap: torch.Tensor, + queries_ndc: torch.Tensor +) -> torch.Tensor: + """ + Sample features using grid_sample + + Args: + fmap: [B, C, H, W], feature map + queries_ndc: [B, K, 2], NDC coordinates + + Returns: + feat: [B, C, K], sampled features + """ + B, C, H, W = fmap.shape + Bq, K, _ = queries_ndc.shape + assert Bq == B, "batch 不一致" + + grid = queries_ndc.view(B, K, 1, 2) + feat = F.grid_sample( + fmap, grid, mode='bilinear', + align_corners=False, padding_mode='border' + ) + return feat.squeeze(-1) + + +# ============================================================================= +# Projection Grid Module +# ============================================================================= + +class ProjGrid(nn.Module): + """3D Grid Projection Module""" + + def __init__(self, grid_resolution: int = 16): + super().__init__() + self.grid_resolution = grid_resolution + self.image_resolution = 518 + + # Create 3D grid points + one_dim = torch.linspace(-1, 1, grid_resolution) + x, y, z = torch.meshgrid(one_dim, one_dim, one_dim, indexing='ij') + grid_points = torch.stack((x, y, z), dim=-1) + + # Rotation matrix (align with Blender) + rotation_matrix = torch.tensor([ + [1.0, 0.0, 0.0], + [0.0, 0.0, -1.0], + [0.0, 1.0, 0.0] + ]) + grid_points = torch.matmul(grid_points, rotation_matrix.T) + grid_points = grid_points.reshape(-1, 3) + self.register_buffer('grid_points', grid_points) + + # Front view transformation matrix + front_view_transform_matrix = torch.tensor([ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, -1.0, -2.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ]) + self.register_buffer( + "front_view_transform_matrix", + front_view_transform_matrix + ) + + def forward( + self, + features_map: torch.Tensor, + camera_angle_x: torch.Tensor, + distance: torch.Tensor, + mesh_scale: torch.Tensor, + transform_matrix: torch.Tensor = None, + BHWC: bool = True + ) -> torch.Tensor: + """ + Project feature map to 3D grid + + Args: + features_map: [B, H, W, C] or [B, C, H, W] + camera_angle_x: [B] + distance: [B] + mesh_scale: [B] + transform_matrix: [B, 4, 4] or None + BHWC: Whether input is in BHWC format + + Returns: + x: [B, K, C], projected features + """ + if BHWC: + B, H, W, C = features_map.shape + else: + B, C, H, W = features_map.shape + + # Prepare grid points + grid_points = self.grid_points.expand(B, -1, -1) + grid_points = grid_points / mesh_scale.unsqueeze(-1).unsqueeze(-1) / 2 + + # Use default transformation matrix + if transform_matrix is None: + transform_matrix = self.front_view_transform_matrix + transform_matrix = transform_matrix.expand(B, -1, -1).clone() + transform_matrix[:, 1, 3] = -distance + + # Project to image + image_points, depth, valid_mask = project_points_to_image_batch( + grid_points, transform_matrix, camera_angle_x, self.image_resolution + ) + + # Normalize to [-1, 1] + + image_points_norm = (image_points + 0.5) / self.image_resolution * 2 - 1 + + + # Adjust dimensions and sample + if BHWC: + features_map = features_map.permute(0, 3, 1, 2) + + x = sample_features(features_map, image_points_norm) + x = x.permute(0, 2, 1) + + return x + + + + + +# ============================================================================= +# DINOv2 Encoder with Projection +# ============================================================================= + +@pixal3d.register("dinov2-encoder-proj") +class DinoEncoderProj(BaseModule, ModelMixin): + """DINOv2 Encoder with 3D Grid Projection""" + + @dataclass + class Config(BaseModule.Config): + model: str = "facebookresearch/dinov2" + version: str = "dinov2_vitl14_reg" + size: int = 518 + empty_embeds_ratio: float = 0.1 + grid_resolution: int = 16 + use_upsample: bool = False + use_geo_feats: bool = False + + cfg: Config + + def configure(self) -> None: + super().configure() + self.grid_resolution = self.cfg.grid_resolution + self.empty_embeds_ratio = self.cfg.empty_embeds_ratio + self.use_upsample = self.cfg.use_upsample + + # Load DINOv2 + dino_model = torch.hub.load( + self.cfg.model, self.cfg.version, pretrained=True + ) + self.encoder = dino_model.eval() + + # Optional: load upsampler + if self.use_upsample: + upsampler = torch.hub.load("valeoai/NAF", "naf", pretrained=True) + self.upsampler = upsampler.eval() + + # Image preprocessing (normalization only) + self.transform = transforms.Compose([ + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ]) + + self.patch_size = self.encoder.patch_size + self.patch_number = self.cfg.size // self.patch_size + self.proj_grid = ProjGrid(grid_resolution=self.cfg.grid_resolution) + + + + + + + def forward( + self, + image: torch.Tensor, + image_mask: torch.Tensor = None, + camera_angle_x: torch.Tensor = None, + distance: torch.Tensor = None, + mesh_scale: torch.Tensor = None, + transform_matrix: torch.Tensor = None, + is_training: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass + + Args: + image: [B, C, H, W] + camera_angle_x: [B] + distance: [B] + mesh_scale: [B] + is_training: Training mode flag + + Returns: + z_global: [B, num_global, C] + z: [B, grid_resolution^3, C] + """ + image = self.transform(image) + + with torch.no_grad(): + z = self.encoder(image, is_training=True)['x_prenorm'] + z = F.layer_norm(z, z.shape[-1:]) + + # Split tokens + z_clstoken = z[:, 0:1] + z_regtokens = z[:, 1:self.encoder.num_register_tokens + 1] + z_patchtokens = z[:, 1 + self.encoder.num_register_tokens:] + z_patchtokens = z_patchtokens.reshape( + z_patchtokens.shape[0], + self.patch_number, + self.patch_number, + -1 + ) + + # Project to grid + z = self.proj_grid( + z_patchtokens, camera_angle_x, distance, mesh_scale + ) + + # Optional: upsample and fuse + if self.use_upsample: + z_patchtokens_permuted = z_patchtokens.permute(0, 3, 1, 2) + z_upsampled = self.upsampler( + image, z_patchtokens_permuted, output_size=(518, 518) + ) + z_upsampled = self.proj_grid( + z_upsampled, camera_angle_x, distance, mesh_scale, BHWC=False + ) + z = z + z_upsampled + + # Global tokens + z_global = torch.cat([z_clstoken, z_regtokens], dim=1) + z_global = z_global.expand(z.shape[0], -1, -1) + + # Classifier-free guidance: random drop + if is_training and random.random() < self.empty_embeds_ratio: + z_global = z_global * 0 + z = z * 0 + + return z_global, z + + +# ============================================================================= +# Multi-View Projection Encoder Helper Functions +# ============================================================================= + +def compute_calc_mat( + true_view_mat: torch.Tensor, + ext_true_view_mat: torch.Tensor, + fix_mat: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute calc_mat using matrix relative transformation + + Args: + true_view_mat: [B, 1, 4, 4], ground truth camera matrix + ext_true_view_mat: [B, N, 4, 4], extended ground truth camera matrices + fix_mat: [B, 1, 4, 4], fixed matrix + + Returns: + calc_mat: [B, N, 4, 4] + relative_transform: [B, N, 4, 4] + """ + B, N = ext_true_view_mat.shape[:2] + + # Expand to [B, N, 4, 4] + true_view_mat_exp = true_view_mat.expand(B, N, 4, 4) + fix_mat_exp = fix_mat.expand(B, N, 4, 4) + + # Flatten to [B*N, 4, 4] + true_view_mat_flat = true_view_mat_exp.reshape(B * N, 4, 4) + ext_true_view_mat_flat = ext_true_view_mat.reshape(B * N, 4, 4) + fix_mat_flat = fix_mat_exp.reshape(B * N, 4, 4) + + # Compute relative transformation (disable autocast for fp32 precision) + with torch.amp.autocast('cuda', enabled=False): + true_view_mat_flat = true_view_mat_flat.float() + ext_true_view_mat_flat = ext_true_view_mat_flat.float() + fix_mat_flat = fix_mat_flat.float() + + relative_transform_flat = torch.bmm( + torch.linalg.inv(true_view_mat_flat), + ext_true_view_mat_flat + ) + calc_mat_flat = torch.bmm(fix_mat_flat, relative_transform_flat) + + calc_mat = calc_mat_flat.view(B, N, 4, 4) + relative_transform = relative_transform_flat.view(B, N, 4, 4) + + return calc_mat, relative_transform + + +# ============================================================================= +# Multi-View DINOv2 Projection Encoder +# ============================================================================= + +@pixal3d.register("dinov2-encoder-proj-multi-view") +class DinoEncoderProjMultiView(BaseModule, ModelMixin): + """Multi-View DINOv2 Projection Encoder""" + + @dataclass + class Config(BaseModule.Config): + model: str = "facebookresearch/dinov2" + version: str = "dinov2_vitl14_reg" + size: int = 518 + empty_embeds_ratio: float = 0.1 + grid_resolution: int = 16 + use_upsample: bool = False + + cfg: Config + + def configure(self) -> None: + super().configure() + self.grid_resolution = self.cfg.grid_resolution + self.empty_embeds_ratio = self.cfg.empty_embeds_ratio + self.use_upsample = self.cfg.use_upsample + + # Load DINOv2 + dino_model = torch.hub.load( + self.cfg.model, self.cfg.version, pretrained=True + ) + + self.encoder = dino_model.eval() + + # Optional: upsampler + if self.use_upsample: + upsampler = torch.hub.load("valeoai/NAF", "naf", pretrained=True) + self.upsampler = upsampler.eval() + + # Image preprocessing + self.transform = transforms.Compose([ + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ]) + + self.patch_size = self.encoder.patch_size + self.patch_number = self.cfg.size // self.patch_size + self.proj_grid = ProjGrid(grid_resolution=self.cfg.grid_resolution) + + # Fixed transformation matrix + self.register_buffer("fix_transform_matrix", torch.tensor([ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, -1.0, -2.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ])) + + def forward( + self, + image: torch.Tensor, + image_mask: torch.Tensor = None, + camera_angle_x: torch.Tensor = None, + distance: torch.Tensor = None, + mesh_scale: torch.Tensor = None, + transform_matrix: torch.Tensor = None, + is_training: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass + + Args: + image: [B, num_views, C, H, W] + camera_angle_x: [B, num_views] + distance: [B, num_views] + mesh_scale: [B] + transform_matrix: [B, num_views, 4, 4] + + Returns: + z_global: [B, num_global, C] + z: [B, grid_resolution^3, C] + """ + B, num_views, C, H, W = image.shape + image = image.reshape(B * num_views, C, H, W) + image = self.transform(image) + + with torch.no_grad(): + z = self.encoder(image, is_training=True)['x_prenorm'] + z = F.layer_norm(z, z.shape[-1:]) + z_clstoken = z[:, 0:1] + z_regtokens = z[:, 1:self.encoder.num_register_tokens + 1] + z_patchtokens = z[:, 1 + self.encoder.num_register_tokens:] + z_patchtokens = z_patchtokens.reshape( + z_patchtokens.shape[0], + self.patch_number, + self.patch_number, + -1 + ) + + # Compute relative transformation + calc_mat, relative_transform = self.get_relative_transform( + transform_matrix, distance + ) + calc_mat = calc_mat.reshape(B * num_views, 4, 4) + + # Prepare parameters + init_mesh_scale = mesh_scale[:, None].expand(B, num_views).reshape(B * num_views) + camera_angle_x_flat = camera_angle_x.reshape(B * num_views) + distance_flat = distance.reshape(B * num_views) + + # Accumulate per-view (avoid OOM) + z_accumulated = None + z_patchtokens_permuted = z_patchtokens.permute(0, 3, 1, 2) if self.use_upsample else None + + with torch.no_grad(): + for view_idx in range(num_views): + indices = torch.arange( + view_idx, B * num_views, num_views, device=z_patchtokens.device + ) + + # Project current view + z_view = self.proj_grid( + z_patchtokens[indices], + camera_angle_x_flat[indices], + distance_flat[indices], + init_mesh_scale[indices], + calc_mat[indices] + ) + + # Optional: upsample + if self.use_upsample: + chunk_upsampled = self.upsampler( + image[indices], + z_patchtokens_permuted[indices], + output_size=(518, 518) + ) + chunk_proj = self.proj_grid( + chunk_upsampled, + camera_angle_x_flat[indices], + distance_flat[indices], + init_mesh_scale[indices], + calc_mat[indices], + BHWC=False + ) + z_view = z_view + chunk_proj + del chunk_upsampled, chunk_proj + + # Accumulate + if z_accumulated is None: + z_accumulated = z_view.clone() + else: + z_accumulated = z_accumulated + z_view + del z_view + + if z_patchtokens_permuted is not None: + del z_patchtokens_permuted + + # Average + z = z_accumulated / num_views + + # Average global tokens + z_global = torch.cat([z_clstoken, z_regtokens], dim=1) + z_global = z_global.reshape(B, num_views, z_global.shape[-2], z_global.shape[-1]) + z_global = z_global.mean(dim=1) + + # Classifier-free guidance + if is_training and random.random() < self.empty_embeds_ratio: + z_global = z_global * 0 + z = z * 0 + + return z_global, z + + def get_relative_transform( + self, + transform_matrix: torch.Tensor, + distance: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute relative transformation matrix + + Args: + transform_matrix: [B, num_views, 4, 4] + distance: [B, num_views] + + Returns: + calc_mat: [B, num_views, 4, 4] + relative_transform: [B, num_views, 4, 4] + """ + B, num_views, _, _ = transform_matrix.shape + init_transform_matrix = transform_matrix[:, 0:1] + + fix_transform_matrix = self.fix_transform_matrix.unsqueeze(0).expand(B, -1, -1).clone() + init_distance = distance[:, 0] + fix_transform_matrix[:, 1, 3] = -init_distance + fix_transform_matrix = fix_transform_matrix.unsqueeze(1) + + calc_mat, relative_transform = compute_calc_mat( + init_transform_matrix, transform_matrix, fix_transform_matrix + ) + return calc_mat, relative_transform diff --git a/pixal3d/models/transformers/__init__.py b/pixal3d/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7dd2693075655278cdb92f4b554fd4a2a55677 --- /dev/null +++ b/pixal3d/models/transformers/__init__.py @@ -0,0 +1,2 @@ +from . import sparse_dit +from . import dense_dit \ No newline at end of file diff --git a/pixal3d/models/transformers/__pycache__/__init__.cpython-310.pyc b/pixal3d/models/transformers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35326df602cb8e85580d0497d7b6e0c7004ae656 Binary files /dev/null and b/pixal3d/models/transformers/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/models/transformers/__pycache__/dense_dit.cpython-310.pyc b/pixal3d/models/transformers/__pycache__/dense_dit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9434b0193938cbd3176a13eb23bf22b3275a7c71 Binary files /dev/null and b/pixal3d/models/transformers/__pycache__/dense_dit.cpython-310.pyc differ diff --git a/pixal3d/models/transformers/__pycache__/sparse_dit.cpython-310.pyc b/pixal3d/models/transformers/__pycache__/sparse_dit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ecb257db1e6ec15e9be57174f27dfe78b289afe Binary files /dev/null and b/pixal3d/models/transformers/__pycache__/sparse_dit.cpython-310.pyc differ diff --git a/pixal3d/models/transformers/dense_dit.py b/pixal3d/models/transformers/dense_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..db908a52d0c6f33273bcf00feb5183f8a1f79c41 --- /dev/null +++ b/pixal3d/models/transformers/dense_dit.py @@ -0,0 +1,298 @@ +from typing import * +from dataclasses import dataclass +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ...modules.spatial import patchify, unpatchify +from ...utils.base import BaseModule +import pixal3d +from huggingface_hub import hf_hub_download +import os + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_freq = t_freq.to(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class DenseDiT(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + latent_shape: list = [8, 16, 16, 16], + image_attn_mode:str = "cross", + load_ckpt:bool = True, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.latent_shape = latent_shape + self.image_attn_mode = image_attn_mode + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + + self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + image_attn_mode = self.image_attn_mode, + + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + # self.blocks.apply(convert_module_to_f16) + self.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = patchify(x, self.patch_size) + h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() + h = self.input_layer(h) + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + h = h.type(self.dtype) + if self.image_attn_mode=='proj': + global_cond,proj_cond = cond + global_cond = global_cond.type(self.dtype) + proj_cond = proj_cond.type(self.dtype) + cond = (global_cond, proj_cond) + else: + cond = cond.type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + h = h.type(x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) + h = unpatchify(h, self.patch_size).contiguous() + + return h + + +# ===== Align to sparse_dit style: ModelOutput + Denoiser wrapper (Lightning-friendly) ===== + +@dataclass +class DenseDiTModelOutput: + sample: torch.Tensor + + +@pixal3d.register("dense-dit-denoiser") +class DenseDiTDenoiser(BaseModule): + @dataclass + class Config(BaseModule.Config): + # Mirror DenseDiT init signature with reasonable defaults + resolution: int = 64 + in_channels: int = 16 + model_channels: int = 1024 + cond_channels: int = 1024 + out_channels: int = 16 + num_blocks: int = 24 + num_heads: Optional[int] = None + num_head_channels: Optional[int] = 64 + mlp_ratio: float = 4.0 + patch_size: int = 2 + pe_mode: str = "ape" # "ape" | "rope" + use_fp16: bool = False + use_checkpoint: bool = False + share_mod: bool = False + qk_rms_norm: bool = False + qk_rms_norm_cross: bool = False + latent_shape: list = (8, 16, 16, 16) + image_attn_mode: str = "cross" + load_ckpt:bool = True + + cfg: Config + + def configure(self) -> None: + # Instantiate the underlying DenseDiT model + self.dit_model = DenseDiT( + resolution=self.cfg.resolution, + in_channels=self.cfg.in_channels, + model_channels=self.cfg.model_channels, + cond_channels=self.cfg.cond_channels, + out_channels=self.cfg.out_channels, + num_blocks=self.cfg.num_blocks, + num_heads=self.cfg.num_heads, + num_head_channels=self.cfg.num_head_channels, + mlp_ratio=self.cfg.mlp_ratio, + patch_size=self.cfg.patch_size, + pe_mode=self.cfg.pe_mode, + use_fp16=self.cfg.use_fp16, + use_checkpoint=self.cfg.use_checkpoint, + share_mod=self.cfg.share_mod, + qk_rms_norm=self.cfg.qk_rms_norm, + qk_rms_norm_cross=self.cfg.qk_rms_norm_cross, + latent_shape=list(self.cfg.latent_shape) if isinstance(self.cfg.latent_shape, (list, tuple)) else self.cfg.latent_shape, + image_attn_mode=self.cfg.image_attn_mode, + load_ckpt=self.cfg.load_ckpt, + ) + + # For a consistent external API (some systems may read out_channels) + self.out_channels = self.cfg.out_channels + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + cond: torch.Tensor, + **kwargs, + ) -> DenseDiTModelOutput: + """Forward wrapper returning a structured output like diffusers models. + + Args: + x: [B, C, D, H, W] dense latent tensor. + t: [B] or [1] timestep tensor. + cond: conditioning tensor matching the transformer blocks' expected dims. + """ + out = self.dit_model(x, t, cond) + return DenseDiTModelOutput(sample=out) diff --git a/pixal3d/models/transformers/sparse_dit.py b/pixal3d/models/transformers/sparse_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed653efd4c736d1e379b45c59aecda0cff4f011 --- /dev/null +++ b/pixal3d/models/transformers/sparse_dit.py @@ -0,0 +1,469 @@ +# Some parts of this file are adapted from the SparseDiT implementation +import os +from typing import Any, Dict, Optional, Union, Tuple, Literal +from dataclasses import dataclass +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging + +import pixal3d +from pixal3d.utils.base import BaseModule +from huggingface_hub import hf_hub_download + +# Import sparse operations + +from ...modules import sparse as sp +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer.modulated import ModulatedSparseTransformerCrossBlock +SPARSE_AVAILABLE = True +# except ImportError: + # print("Warning: sparse modules not found. Please ensure it's in your Python path.") + # sp = None + # convert_module_to_f16 = None + # convert_module_to_f32 = None + # AbsolutePositionEmbedder = None + # ModulatedSparseTransformerCrossBlock = None + # SPARSE_AVAILABLE = False + +logger = logging.get_logger(__name__) + + +@dataclass +class SparseDiTModelOutput: + sample: Any # Can be torch.FloatTensor or sp.SparseTensor + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_freq = t_freq.to(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + Sparse Diffusion Transformer model for 3D shape generation. + + This model processes sparse 3D data using sparse attention mechanisms. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + resolution: int = 64, + in_channels: int = 16, + model_channels: int = 1024, + cond_channels: int = 1024, + out_channels: int = 16, + num_blocks: int = 24, + num_heads: int = 32, + num_head_channels: int = 64, + num_kv_heads: int = 2, + compression_block_size: int = 4, + selection_block_size: int = 8, + topk: int = 32, + compression_version: str = 'v2', + mlp_ratio: float = 4.0, + pe_mode: str = "ape", + use_fp16: bool = True, + use_checkpoint: bool = True, + share_mod: bool = False, + qk_rms_norm: bool = True, + qk_rms_norm_cross: bool = False, + sparse_conditions: bool = True, + factor: float = 1.0, + window_size: int = 8, + use_shift: bool = True, + image_attn_mode:str='cross', + load_ckpt:bool=True, + version:Optional[str]='V10', + ): + super().__init__() + + if not SPARSE_AVAILABLE: + raise ImportError("sparse modules not found.") + + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self._dtype = torch.float16 if use_fp16 else torch.float32 + self.sparse_conditions = sparse_conditions + self.factor = factor + self.compression_block_size = compression_block_size + self.selection_block_size = selection_block_size + self.image_attn_mode = image_attn_mode + + # Timestep embedding + self.t_embedder = TimestepEmbedder(model_channels) + + # Shared modulation if enabled + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + # Condition processing for sparse conditions + if sparse_conditions: + self.cond_proj = sp.SparseLinear(cond_channels, cond_channels) + self.pos_embedder_cond = AbsolutePositionEmbedder(model_channels, in_channels=3) + + # Position embedding + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + # Input projection + self.input_layer = sp.SparseLinear(in_channels, model_channels) + + # Transformer blocks + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + num_kv_heads=num_kv_heads, + compression_block_size=compression_block_size, + selection_block_size=selection_block_size, + topk=topk, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + compression_version=compression_version, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + resolution=resolution, + window_size=window_size, + shift_window=window_size // 2 * (i % 2) if use_shift else window_size // 2, + image_attn_mode = image_attn_mode, + ) + for i in range(num_blocks) + ]) + + # Output projection + self.out_layer = sp.SparseLinear(model_channels, out_channels) + + # Initialize weights + self.initialize_weights() + + + self.gradient_checkpointing = False + + if use_fp16: + print("Converting model to float16 ============================") + self.convert_to_fp16() + # else: + # self.convert_to_fp32() + @property + def device(self) -> torch.device: + """Return the device of the model.""" + return next(self.parameters()).device + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def convert_to_fp16(self) -> None: + """Convert the model to float16.""" + self.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """Convert the model to float32.""" + self.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + """Initialize model weights.""" + # Initialize transformer layers + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + # if hasattr(block, 'adaLN_modulation'): + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward( + self, + hidden_states: Any, # sp.SparseTensor + timestep: torch.Tensor, + encoder_hidden_states: Optional[Any] = None, # torch.Tensor or sp.SparseTensor + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[SparseDiTModelOutput, Tuple]: + """ + Forward pass of the SparseDiT model. + + Args: + hidden_states: Input sparse tensor + timestep: Timestep tensor + encoder_hidden_states: Condition tensor (visual/text conditions) + attention_kwargs: Additional attention arguments + return_dict: Whether to return a dictionary + """ + # breakpoint() + # Process input + assert attention_kwargs is None, "attention_kwargs not supported in SparseDiT" + # breakpoint() + h = self.input_layer(hidden_states).type(self._dtype) + + # Process timestep + t_emb = self.t_embedder(timestep) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self._dtype) + + # Process conditions + + cond = encoder_hidden_states + if self.image_attn_mode=='proj': + global_cond,sparse_cond = cond + + if sparse_cond is not None: + sparse_cond = sparse_cond.type(self._dtype) + global_cond = global_cond.type(self._dtype) + # breakpoint() + if self.sparse_conditions and isinstance(sparse_cond, sp.SparseTensor): + # breakpoint() + sparse_cond = self.cond_proj(sparse_cond) + sparse_cond = sparse_cond + self.pos_embedder_cond(sparse_cond.coords[:, 1:]).type(self._dtype) + cond = (global_cond,sparse_cond) + else: + if self.sparse_conditions: + cond = self.cond_proj(cond) + cond = cond + self.pos_embedder_cond(cond.coords[:, 1:]).type(self.dtype) + + # Add positional embeddings + if self.pe_mode == "ape": + h = h + self.pos_embedder(h.coords[:, 1:], factor=self.factor).type(self._dtype) + + # Process through transformer blocks + for block in self.blocks: + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + h = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + h, t_emb, cond + ) + else: + h = block(h, t_emb, cond) + + # Final layer norm and output projection + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h.type(hidden_states.dtype)) + + if not return_dict: + return (h,) + + return SparseDiTModelOutput(sample=h) + + +@pixal3d.register("sparse-dit-denoiser") +class SparseDiTDenoiser(BaseModule): + """ + Sparse DiT Denoiser wrapper for pixal3d framework. + """ + + @dataclass + class Config(BaseModule.Config): + # Model architecture + resolution: int = 64 + in_channels: int = 16 + model_channels: int = 1024 + cond_channels: int = 1024 + out_channels: int = 16 + num_blocks: int = 24 + num_heads: int = 32 + num_kv_heads: int = 2 + compression_block_size: int = 4 + selection_block_size: int = 8 + topk: int = 32 + compression_version: str = 'v2' + mlp_ratio: float = 4.0 + pe_mode: str = "ape" + use_fp16: bool = True + use_checkpoint: bool = True + qk_rms_norm: bool = True + qk_rms_norm_cross: bool = False + sparse_conditions: bool = True + factor: float = 1.0 + window_size: int = 8 + use_shift: bool = True + + # Condition settings + use_visual_condition: bool = True + visual_condition_dim: int = 1024 + use_caption_condition: bool = False + caption_condition_dim: int = 1024 + use_label_condition: bool = False + label_condition_dim: int = 1024 + + # Training settings + pretrained_model_name_or_path: Optional[str] = None + + image_attn_mode:Optional[str]='cross' + load_ckpt:bool =True + version:Optional[str]='V10' + + cfg: Config + + def configure(self) -> None: + """Configure the SparseDiT model.""" + + # Create the core SparseDiT model + self.dit_model = SparseDiTModel( + resolution=self.cfg.resolution, + in_channels=self.cfg.in_channels, + model_channels=self.cfg.model_channels, + cond_channels=self.cfg.cond_channels, + out_channels=self.cfg.out_channels, + num_blocks=self.cfg.num_blocks, + num_heads=self.cfg.num_heads, + num_kv_heads=self.cfg.num_kv_heads, + compression_block_size=self.cfg.compression_block_size, + selection_block_size=self.cfg.selection_block_size, + topk=self.cfg.topk, + compression_version=self.cfg.compression_version, + mlp_ratio=self.cfg.mlp_ratio, + pe_mode=self.cfg.pe_mode, + use_fp16=self.cfg.use_fp16, + use_checkpoint=self.cfg.use_checkpoint, + sparse_conditions=self.cfg.sparse_conditions, + factor=self.cfg.factor, + window_size=self.cfg.window_size, + use_shift=self.cfg.use_shift, + image_attn_mode=self.cfg.image_attn_mode, + load_ckpt = self.cfg.load_ckpt, + version=self.cfg.version, + ) + + # Condition projectors + if self.cfg.use_visual_condition and self.cfg.visual_condition_dim != self.cfg.cond_channels: + self.proj_visual_condition = nn.Sequential( + nn.RMSNorm(self.cfg.visual_condition_dim), + nn.Linear(self.cfg.visual_condition_dim, self.cfg.cond_channels), + ) + + if self.cfg.use_caption_condition and self.cfg.caption_condition_dim != self.cfg.cond_channels: + self.proj_caption_condition = nn.Sequential( + nn.RMSNorm(self.cfg.caption_condition_dim), + nn.Linear(self.cfg.caption_condition_dim, self.cfg.cond_channels), + ) + + if self.cfg.use_label_condition and self.cfg.label_condition_dim != self.cfg.cond_channels: + self.proj_label_condition = nn.Sequential( + nn.RMSNorm(self.cfg.label_condition_dim), + nn.Linear(self.cfg.label_condition_dim, self.cfg.cond_channels), + ) + + # Load pretrained weights if specified + if self.cfg.pretrained_model_name_or_path: + print(f"Loading pretrained SparseDiT model from {self.cfg.pretrained_model_name_or_path}") + ckpt = torch.load( + self.cfg.pretrained_model_name_or_path, + map_location="cpu", + weights_only=True, + ) + if "state_dict" in ckpt.keys(): + ckpt = ckpt["state_dict"] + self.load_state_dict(ckpt, strict=True) + + def forward( + self, + x: Any, # sp.SparseTensor + t: torch.Tensor, + cond: Optional[Any] = None, + ): + """ + Forward pass of the denoiser. + + Args: + model_input: Input sparse tensor [SparseTensor with features] + timestep: Timestep tensor [batch_size,] + visual_condition: Visual condition tensor + caption_condition: Caption condition tensor + label_condition: Label condition tensor + attention_kwargs: Additional attention arguments + return_dict: Whether to return a dictionary + """ + + + output = self.dit_model( + hidden_states=x, + timestep=t, + encoder_hidden_states=cond, + ) + + return output + + diff --git a/pixal3d/modules/__pycache__/norm.cpython-310.pyc b/pixal3d/modules/__pycache__/norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee977b83807341277904c50db7bcd7f553d700b4 Binary files /dev/null and b/pixal3d/modules/__pycache__/norm.cpython-310.pyc differ diff --git a/pixal3d/modules/__pycache__/spatial.cpython-310.pyc b/pixal3d/modules/__pycache__/spatial.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7c5aca0bc18d310909545d4b21cd60171872e8e Binary files /dev/null and b/pixal3d/modules/__pycache__/spatial.cpython-310.pyc differ diff --git a/pixal3d/modules/__pycache__/utils.cpython-310.pyc b/pixal3d/modules/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4bbe8e72b9f392cfc892813867a1ba734397e52 Binary files /dev/null and b/pixal3d/modules/__pycache__/utils.cpython-310.pyc differ diff --git a/pixal3d/modules/attention/__init__.py b/pixal3d/modules/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d500ecd8ce72fd4e072ecdd9c008d2ae030e0629 --- /dev/null +++ b/pixal3d/modules/attention/__init__.py @@ -0,0 +1,35 @@ +from typing import * +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_sttn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_sttn_debug is not None: + DEBUG = env_sttn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +from .full_attn import * +from .modules import * diff --git a/pixal3d/modules/attention/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..815fee4f9b886cbe6de2e39894c32f47599e86ad Binary files /dev/null and b/pixal3d/modules/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/attention/__pycache__/full_attn.cpython-310.pyc b/pixal3d/modules/attention/__pycache__/full_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f36830206d7177ae5462ffe60b49ccb5c6385a3a Binary files /dev/null and b/pixal3d/modules/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/pixal3d/modules/attention/__pycache__/modules.cpython-310.pyc b/pixal3d/modules/attention/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e130285c4809429d7a589a38f3943cbba992cb6 Binary files /dev/null and b/pixal3d/modules/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/pixal3d/modules/attention/full_attn.py b/pixal3d/modules/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..f94cf46843412c4d349d2d5dcd7277fac938e507 --- /dev/null +++ b/pixal3d/modules/attention/full_attn.py @@ -0,0 +1,140 @@ +from typing import * +import torch +import math +from . import DEBUG, BACKEND + +if BACKEND == 'xformers': + import xformers.ops as xops +elif BACKEND == 'flash_attn': + import flash_attn +elif BACKEND == 'sdpa': + from torch.nn.functional import scaled_dot_product_attention as sdpa +elif BACKEND == 'naive': + pass +else: + raise ValueError(f"Unknown attention backend: {BACKEND}") + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif BACKEND == 'flash_attn': + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif BACKEND == 'sdpa': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {BACKEND}") + + return out diff --git a/pixal3d/modules/attention/modules.py b/pixal3d/modules/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc7ff63b1dbd1ea5c4eb31080759baf01073c97 --- /dev/null +++ b/pixal3d/modules/attention/modules.py @@ -0,0 +1,164 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class RotaryPositionEmbedder(nn.Module): + def __init__(self, hidden_size: int, in_channels: int = 3): + super().__init__() + assert hidden_size % 2 == 0, "Hidden size must be divisible by 2" + self.hidden_size = hidden_size + self.in_channels = in_channels + self.freq_dim = hidden_size // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (sp.SparseTensor): [..., N, D] tensor of queries + k (sp.SparseTensor): [..., N, D] tensor of keys + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + if indices is None: + indices = torch.arange(q.shape[-2], device=q.device) + if len(q.shape) > 2: + indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) + + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[1] < self.hidden_size // 2: + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), + torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) + )], dim=-1) + q_embed = self._rotary_embedding(q, phases) + k_embed = self._rotary_embedding(k, phases) + return q_embed, k_embed + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + if self.use_rope: + q, k, v = qkv.unbind(dim=2) + q, k = self.rope(q, k, indices) + qkv = torch.stack([q, k, v], dim=2) + if self.attn_mode == "full": + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=2) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h + + +class ProjectAttention(nn.Module): + def __init__(self,cross_attn_block: nn.Module): + super().__init__() + self.cross_attn_block = cross_attn_block + self.global_token_length = 5 + + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + + global_context = context[0] + proj_context = context[1] + global_context = self.cross_attn_block(x, global_context) + context = proj_context + global_context + return context + x + + \ No newline at end of file diff --git a/pixal3d/modules/norm.py b/pixal3d/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..09035726081fb7afda2c62504d5474cfa483c58f --- /dev/null +++ b/pixal3d/modules/norm.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/pixal3d/modules/sparse/__init__.py b/pixal3d/modules/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..036ff1998f29424abb720642cd83a83c6abf750f --- /dev/null +++ b/pixal3d/modules/sparse/__init__.py @@ -0,0 +1,105 @@ +from typing import * + +BACKEND = 'torchsparse' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global BACKEND + global DEBUG + global ATTN + + env_sparse_backend = os.environ.get('SPARSE_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn is None: + env_sparse_attn = os.environ.get('ATTN_BACKEND') + + if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: + BACKEND = env_sparse_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + ATTN = env_sparse_attn + + print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") + + +__from_env() + + +def set_backend(backend: Literal['spconv', 'torchsparse']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn(attn: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = attn + + +import importlib + +__attributes = { + 'SparseTensor': 'basic', + 'sparse_batch_broadcast': 'basic', + 'sparse_batch_op': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseSigmoid': 'nonlinearity', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseTanh': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'sparseconv3d_func': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide' : 'spatial' +} + +__submodules = ['transformer'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + import transformer diff --git a/pixal3d/modules/sparse/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/sparse/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88cc980ec8cf6d64cd1a2d26b40b545223d08f4d Binary files /dev/null and b/pixal3d/modules/sparse/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/__pycache__/basic.cpython-310.pyc b/pixal3d/modules/sparse/__pycache__/basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c92c1b4d9959cb4d495f8756b4d6579d30b3731b Binary files /dev/null and b/pixal3d/modules/sparse/__pycache__/basic.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/__pycache__/linear.cpython-310.pyc b/pixal3d/modules/sparse/__pycache__/linear.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df479976fd5011dec32253d389db23648365933 Binary files /dev/null and b/pixal3d/modules/sparse/__pycache__/linear.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc b/pixal3d/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..348efe409e37474d230457e59ecb34557c14c899 Binary files /dev/null and b/pixal3d/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/__pycache__/norm.cpython-310.pyc b/pixal3d/modules/sparse/__pycache__/norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f50516b9382c4e692814f917ad5cafd474ae01aa Binary files /dev/null and b/pixal3d/modules/sparse/__pycache__/norm.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/__pycache__/spatial.cpython-310.pyc b/pixal3d/modules/sparse/__pycache__/spatial.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb74fa00fe93a1be0710b17db77c1fa1a18501e9 Binary files /dev/null and b/pixal3d/modules/sparse/__pycache__/spatial.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/__init__.py b/pixal3d/modules/sparse/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2732f12d139e579fb27f224c523e27f1e8cefb --- /dev/null +++ b/pixal3d/modules/sparse/attention/__init__.py @@ -0,0 +1,5 @@ +from .full_attn import * +from .serialized_attn import * +from .windowed_attn import * +from .modules import * +from .spatial_sparse_attention.module.spatial_sparse_attention import SpatialSparseAttention diff --git a/pixal3d/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec82e5848f2ce389b996d7c4db14fcf9fcdee26f Binary files /dev/null and b/pixal3d/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc b/pixal3d/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1028e472143c8521e543ec88612e68bc58123395 Binary files /dev/null and b/pixal3d/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/__pycache__/modules.cpython-310.pyc b/pixal3d/modules/sparse/attention/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdec8b494a1056eca872c94f0e7c5d2c29891d8b Binary files /dev/null and b/pixal3d/modules/sparse/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc b/pixal3d/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df4ed7942962b42d4d17ad23d01b6e43fbaf7bf0 Binary files /dev/null and b/pixal3d/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc b/pixal3d/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4e094ef62005213c28a96ffafa80fb25358d987 Binary files /dev/null and b/pixal3d/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/full_attn.py b/pixal3d/modules/sparse/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e27aeb98419621f3f9999fd3b11eebf2b90a40 --- /dev/null +++ b/pixal3d/modules/sparse/attention/full_attn.py @@ -0,0 +1,215 @@ +from typing import * +import torch +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, SparseTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, SparseTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if DEBUG: + if s is not None: + for i in range(s.shape[0]): + assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + if num_all_args in [2, 3]: + assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" + if num_all_args == 3: + assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" + + if ATTN == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif ATTN == 'flash_attn': + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + else: + raise ValueError(f"Unknown attention module: {ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/pixal3d/modules/sparse/attention/modules.py b/pixal3d/modules/sparse/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9f313bf9de9ad702a949a54a656b0f6865a6ba76 --- /dev/null +++ b/pixal3d/modules/sparse/attention/modules.py @@ -0,0 +1,156 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. import SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from ...attention import RotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, SparseTensor): + x = x.replace(F.normalize(x.feats, dim=-1)) + else: + x = F.normalize(x, dim=-1) + return (x * self.gamma * self.scale).to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "serialized", "windowed"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + self.channels = channels + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_sequence = shift_sequence + self.shift_window = shift_window + self.serialize_mode = serialize_mode + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + @staticmethod + def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats + + def _rope(self, qkv: SparseTensor) -> SparseTensor: + q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] + q, k = self.rope(q, k, qkv.coords[:, 1:]) + qkv = qkv.replace(torch.stack([q, k, v], dim=1)) + return qkv + + def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.use_rope: + qkv = self._rope(qkv) + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=1) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "serialized": + h = sparse_serialized_scaled_dot_product_self_attention( + qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window + ) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=1) + k = self.k_rms_norm(k) + kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h + + + +class SparseProjAttention(nn.Module): + def __init__(self,cross_attn_block=None): + super().__init__() + self.cross_attn_block = cross_attn_block + self.global_token_length = 5 + + + def forward(self,x,context): + if self.cross_attn_block is not None: + global_context = context[0] + proj_context = context[1] + global_context = self.cross_attn_block(x, global_context) + context = proj_context + global_context + return x + context \ No newline at end of file diff --git a/pixal3d/modules/sparse/attention/serialized_attn.py b/pixal3d/modules/sparse/attention/serialized_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..5950b75b2f5a6d6e79ab6d472b8501aaa5ec4a26 --- /dev/null +++ b/pixal3d/modules/sparse/attention/serialized_attn.py @@ -0,0 +1,193 @@ +from typing import * +from enum import Enum +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_serialized_scaled_dot_product_self_attention', +] + + +class SerializeMode(Enum): + Z_ORDER = 0 + Z_ORDER_TRANSPOSED = 1 + HILBERT = 2 + HILBERT_TRANSPOSED = 3 + + +SerializeModes = [ + SerializeMode.Z_ORDER, + SerializeMode.Z_ORDER_TRANSPOSED, + SerializeMode.HILBERT, + SerializeMode.HILBERT_TRANSPOSED +] + + +def calc_serialization( + tensor: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (torch.Tensor, torch.Tensor): Forwards and backwards indices. + """ + fwd_indices = [] + bwd_indices = [] + seq_lens = [] + seq_batch_indices = [] + offsets = [0] + + if 'vox2seq' not in globals(): + import vox2seq + + # Serialize the input + serialize_coords = tensor.coords[:, 1:].clone() + serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) + if serialize_mode == SerializeMode.Z_ORDER: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) + elif serialize_mode == SerializeMode.HILBERT: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) + else: + raise ValueError(f"Unknown serialize mode: {serialize_mode}") + + for bi, s in enumerate(tensor.layout): + num_points = s.stop - s.start + num_windows = (num_points + window_size - 1) // window_size + valid_window_size = num_points / num_windows + to_ordered = torch.argsort(code[s.start:s.stop]) + if num_windows == 1: + fwd_indices.append(to_ordered) + bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) + fwd_indices[-1] += s.start + bwd_indices[-1] += offsets[-1] + seq_lens.append(num_points) + seq_batch_indices.append(bi) + offsets.append(offsets[-1] + seq_lens[-1]) + else: + # Partition the input + offset = 0 + mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] + split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] + bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) + for i in range(num_windows): + mid = mids[i] + valid_start = split[i] + valid_end = split[i + 1] + padded_start = math.floor(mid - 0.5 * window_size) + padded_end = padded_start + window_size + fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) + offset += valid_start - padded_start + bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) + offset += padded_end - valid_start + fwd_indices[-1] += s.start + seq_lens.extend([window_size] * num_windows) + seq_batch_indices.extend([bi] * num_windows) + bwd_indices.append(bwd_index + offsets[-1]) + offsets.append(offsets[-1] + num_windows * window_size) + + fwd_indices = torch.cat(fwd_indices) + bwd_indices = torch.cat(bwd_indices) + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_serialized_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply serialized scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/__init__.py b/pixal3d/modules/sparse/attention/spatial_sparse_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da4d954c0fc283fe943b4d2e39557bcb5e87c0e5 Binary files /dev/null and b/pixal3d/modules/sparse/attention/spatial_sparse_attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__init__.py b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__init__.py @@ -0,0 +1 @@ + diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f304040ca95e4a548a61a1fca8e9f2d2c734583e Binary files /dev/null and b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a81fd161458e40b85046e522139f7e475215367 Binary files /dev/null and b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/compression_block.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29bf6535896b3bda907529b8cc0a48822ca6118c Binary files /dev/null and b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/__pycache__/spatial_sparse_attention.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py new file mode 100644 index 0000000000000000000000000000000000000000..1b012427734480e18f5d8a0e333f3ccc624d64a5 --- /dev/null +++ b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/compression_block.py @@ -0,0 +1,65 @@ +import torch.nn as nn +import pixal3d.modules.sparse as sp + + +class SparseDownBlock3d_v1(nn.Module): + + def __init__( + self, + channels: int, + out_channels: int = None, + factor: int = 2, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseConv3d(self.out_channels, self.out_channels, 1, padding=0), + sp.SparseSiLU() + ) + self.down = sp.SparseDownsample(factor) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.act_layers(x) + h = self.down(h) + return h + +class SparseDownBlock3d_v2(nn.Module): + + def __init__( + self, + channels: int, + out_channels: int = None, + num_groups: int = 32, + factor: int = 2, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.down = sp.SparseDownsample(factor) + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, padding=1), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.act_layers(x) + h = self.down(h) + x = self.down(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h \ No newline at end of file diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8f92095853ed9284776fedfbd718720e4e23d087 --- /dev/null +++ b/pixal3d/modules/sparse/attention/spatial_sparse_attention/module/spatial_sparse_attention.py @@ -0,0 +1,155 @@ +import torch +import torch.nn as nn +from einops import rearrange +from flash_attn import flash_attn_varlen_func +from ..ops import ( + spatial_selection_attention, + get_block_score, + sparse_window_attention, +) +from .compression_block import SparseDownBlock3d_v1, SparseDownBlock3d_v2 +import pixal3d.modules.sparse as sp + + +class SpatialSparseAttention(torch.nn.Module): + def __init__( + self, + hidden_size: int, + num_q_heads: int, + num_kv_heads: int, + head_dim: int, + compression_block_size: int, + selection_block_size: int, + topk: int, + window_size: int, + shift_window: int, + resolution: int = 64, + compression_version: str = 'v2', + ): + super().__init__() + self.hidden_size = hidden_size + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.compression_block_size = compression_block_size + self.selection_block_size = selection_block_size + self.topk = topk + self.window_size = window_size + self.shift_window = shift_window + self.resolution = resolution + + # qkv proj and o proj + self.proj_q = sp.SparseLinear( + hidden_size, num_q_heads * head_dim, bias=False + ) + self.proj_k = sp.SparseLinear( + hidden_size, num_kv_heads * head_dim, bias=False + ) + self.proj_v = sp.SparseLinear( + hidden_size, num_kv_heads * head_dim, bias=False + ) + self.proj_o = torch.nn.Linear( + num_q_heads * head_dim, hidden_size, bias=False + ) + + # ssa parameteres + if compression_version == 'v1': + compression_block = SparseDownBlock3d_v1 + elif compression_version == 'v2': + compression_block = SparseDownBlock3d_v2 + else: + raise NotImplementedError('only support v1 or v2 compression block') + self.compression_key = compression_block( + num_kv_heads * head_dim, num_kv_heads * head_dim, factor=compression_block_size + ) + self.compression_value = compression_block( + num_kv_heads * head_dim, num_kv_heads * head_dim, factor=compression_block_size + ) + self.intra_block_pe = torch.nn.Parameter( + torch.zeros(compression_block_size, + compression_block_size, + compression_block_size, + num_kv_heads * head_dim, + ) + ) + + # gate function + self.gate = torch.nn.Sequential( + sp.SparseLinear(hidden_size, 3, bias=False), sp.SparseSigmoid(), + ) + + def sparse3d_compression(self, x, key=True): + _, num_heads, num_dim = x.feats.shape + x = x.replace(x.feats.view(-1, num_heads * num_dim)) + if key: + coords = x.coords + intra_block_coords = coords[..., 1:] % self.compression_block_size + intra_block_pos = self.intra_block_pe[intra_block_coords[:, 0], intra_block_coords[:, 1], intra_block_coords[:, 2]].to(x.dtype) + x = x.replace(x.feats + intra_block_pos) + y = self.compression_key(x) + else: + y = self.compression_value(x) + y = y.replace(y.feats.view(-1, num_heads, num_dim)) + return y + + def forward(self, x: sp.SparseTensor): + # dtype and shape check + assert x.shape[-1] == self.hidden_size + assert self.selection_block_size % self.compression_block_size == 0 + # qkv proj + q = x.replace(self.proj_q(x).feats.view(-1, self.num_q_heads, self.head_dim)) + k = x.replace(self.proj_k(x).feats.view(-1, self.num_kv_heads, self.head_dim)) + v = x.replace(self.proj_v(x).feats.view(-1, self.num_kv_heads, self.head_dim)) + + # compression attention + compressed_k = self.sparse3d_compression(k, key=True) + compressed_v = self.sparse3d_compression(v, key=False) + + compressed_cu_seqlens = torch.tensor([s.start for s in compressed_v.layout] + [s.stop for s in compressed_v.layout if s.stop not in [s.start for s in compressed_v.layout]]).to(compressed_v.device).to(torch.int32) + compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1] + + cu_seqlens = torch.tensor([s.start for s in x.layout] + [s.stop for s in x.layout if s.stop not in [s.start for s in x.layout]]).to(x.device).to(torch.int32) + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + compressed_attn_output, lse, _ = flash_attn_varlen_func( + q.feats, + compressed_k.feats, + compressed_v.feats, + cu_seqlens, + compressed_cu_seqlens, + seqlens.max().item(), + compressed_seqlens.max().item(), + causal=False, + return_attn_probs=True, + ) + + with torch.no_grad(): + block_topk, cu_seqblocks, cu_block_include_tokens = get_block_score( + q, compressed_k, lse, self.resolution, self.compression_block_size, + self.selection_block_size, self.topk, cu_seqlens, compressed_cu_seqlens, + seqlens, compressed_seqlens, None) + + # spatial selection attention + selection_attn_output = spatial_selection_attention( + q.feats, k.feats, v.feats, block_topk, cu_seqblocks, + cu_block_include_tokens, self.selection_block_size, cu_seqlens, None, + ) + + # window attention + window_attn_output = sparse_window_attention( + q, k, v, window_size=self.window_size, shift_window=self.shift_window, + ).feats + + # gate average + gate = self.gate(x).feats + attn_output = ( + gate[:, 0:1, None] * compressed_attn_output + + gate[:, 1:2, None] * selection_attn_output + + gate[:, 2:3, None] * window_attn_output + ) + + # rearrange and output proj + attn_output = rearrange(attn_output, "n h d -> n (h d)") + attn_output = self.proj_o(attn_output) + + return x.replace(attn_output) diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3bfe015052a8b72b5c9d1d7337621656bc7ce7 --- /dev/null +++ b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__init__.py @@ -0,0 +1,3 @@ +from .compressed_attention import get_block_score +from .selection_attention import spatial_selection_attention +from .window_attention import sparse_window_attention diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9300f0333753b298a8db46fdcc8397f557dbd8f Binary files /dev/null and b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbb6d9879244d042b458fe07adcb1737ebc1c2ac Binary files /dev/null and b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/compressed_attention.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0999881b4bc08e4804ca9e7cf8fd9bcd3ef0a7d2 Binary files /dev/null and b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/selection_attention.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d48039a46e7f767f5f9389ab73d33a5d277b0f5 Binary files /dev/null and b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/__pycache__/window_attention.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e06a2c8ad11f8fe59a47ecc5399435a73419113e --- /dev/null +++ b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/compressed_attention.py @@ -0,0 +1,275 @@ +# Copyright 2025 Xunhao Lai. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------- +# Copyright 2025 Shuang Wu +# adapted from https://github.com/XunhaoLai/native-sparse-attention-triton/blob/main/native_sparse_attention/ops/triton/compressed_attention.py + +import math +import torch +from copy import deepcopy +import triton +import triton.language as tl +import pixal3d.modules.sparse as sp + + +@triton.jit +def score_kernel( + q_ptr, + k_ptr, + lse_ptr, + s_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_lh, + stride_ln, + stride_sh, + stride_sq, + stride_sk, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_bkh = tl.program_id(0) + pid_b = pid_bkh // NUM_KV_HEADS + pid_kh = pid_bkh % NUM_KV_HEADS + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len: + return + # init k pointer and load k + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + # init score + s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + # loop over gqa heads + for h in range(NUM_SHARE_Q_HEADS): + pid_h = pid_kh * NUM_SHARE_Q_HEADS + h + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # load q and lse + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k) * qk_scale + # compute score + s += tl.exp2(qk - lse) + # save output + s_ptrs = tl.make_block_ptr( + base=s_ptr + pid_kh * stride_sh + q_start * stride_sq, + shape=(q_len, k_len), + strides=(stride_sq, stride_sk), + offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K), + order=(1, 0), + ) + tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_attention_score( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +) -> torch.Tensor: + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert lse.dtype == torch.float32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len > k_len + if sm_scale is None: + sm_scale = 1 / math.sqrt(head_dim) + # gqa + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # init score + score = torch.zeros( + num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device + ) + # launch kernel + grid = lambda META: ( + batch_size * num_k_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + score_kernel[grid]( + q, + k, + lse, + score, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + lse.stride(0), + lse.stride(1), + score.stride(0), + score.stride(1), + score.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return score + + +def get_block_score( + q: sp.SparseTensor, + compressed_k: sp.SparseTensor, + lse: sp.SparseTensor, + resolution: int, + kernel_stride: int, + block_size: int, + topk: int, + cu_seqlens: torch.Tensor, + compressed_cu_seqlens: torch.Tensor, + seqlens: torch.Tensor, + compressed_seqlens: torch.Tensor, + sm_scale: float = None, +) -> torch.Tensor: + attn_score = _get_attention_score( + q.feats, + compressed_k.feats, + lse.exp().log2(), + cu_seqlens, + compressed_cu_seqlens, + seqlens.max().item(), + compressed_seqlens.max().item(), + sm_scale, + ) + + batch_size = len(cu_seqlens) - 1 + num_kv_head = attn_score.shape[0] + block_res = resolution // block_size + seqblocks, block_include_tokens = [], [] + block_topk = torch.ones((num_kv_head, cu_seqlens[-1], topk), device=q.device, dtype=torch.int32) * -1 + + q_coords = deepcopy(q.coords) + for b in range(batch_size): + q_start, q_end, q_len = cu_seqlens[b], cu_seqlens[b + 1], seqlens[b] + + compressed_k_start, compressed_k_end = compressed_cu_seqlens[b], compressed_cu_seqlens[b + 1] + attn_score_b = attn_score[:, q_start: q_end, :(compressed_k_end-compressed_k_start)] + compressed_block_coords_b = deepcopy(compressed_k.coords[compressed_k_start: compressed_k_end]) + if block_size == kernel_stride: + score_block_b = attn_score_b + real_topk = min(topk, compressed_k_end - compressed_k_start) + block_topk_b = score_block_b.topk(real_topk, dim=-1).indices.sort(-1).values + block_topk[:, q_start: q_end, :real_topk] = block_topk_b + else: + compressed_block_coords_b[:, 1:] = compressed_block_coords_b[:, 1:] // (block_size//kernel_stride) + compressed_block_coords_flatten_b = compressed_block_coords_b[:, 1] * block_res**2 + compressed_block_coords_b[:, 2] * block_res + compressed_block_coords_b[:, 3] + score_block_b = torch.scatter_reduce( + torch.zeros((num_kv_head, q_len, block_res**3), device=attn_score_b.device, dtype=attn_score_b.dtype), + index=compressed_block_coords_flatten_b.long().unsqueeze(0).unsqueeze(0).expand_as(attn_score_b), + src=attn_score_b, + reduce="sum", + dim=2, + ) + compressed_block_coords_flatten_unique_b = compressed_block_coords_flatten_b.unique() + score_block_b = score_block_b[..., compressed_block_coords_flatten_unique_b] + real_topk = min(topk, len(compressed_block_coords_flatten_unique_b)) + block_topk_b = score_block_b.topk(real_topk, dim=-1).indices.sort(-1).values + block_topk[:, q_start: q_end, :real_topk] = block_topk_b + + block_coords_b = q_coords[q_start: q_end] + block_coords_b[:, 1:] = block_coords_b[:, 1:] // block_size + block_coords_flatten_b = block_coords_b[:, 1] * block_res**2 + block_coords_b[:, 2] * block_res + block_coords_b[:, 3] + block_bins_b = torch.histc(block_coords_flatten_b, bins=block_res**3, min=0, max=block_res**3-1) + block_include_tokens.append(block_bins_b[block_bins_b > 0]) + seqblocks.append(len(block_include_tokens[-1])) + seqblocks = torch.Tensor(seqblocks).to(attn_score.device) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(seqblocks, dim=0), + ], + dim=0, + ).to(torch.int32) + block_include_tokens = torch.cat(block_include_tokens) + cu_block_include_tokens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(block_include_tokens, dim=0), + ], + dim=0, + ).to(torch.int32) + return block_topk.to(torch.int32), cu_seqblocks, cu_block_include_tokens diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..62607d6f5a0a1c2b18180b5495a6413d9867402f --- /dev/null +++ b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/selection_attention.py @@ -0,0 +1,1256 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------- +# Copyright 2025 Shuang Wu +# adapted from https://github.com/XunhaoLai/native-sparse-attention-triton/blob/main/native_sparse_attention/ops/triton/topk_sparse_attention.py + +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # seqlens + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_block_include_tokens, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # q loop num + num_q_loop, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_h = pid_kh * NUM_SHARE_Q_HEADS + pid_q = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + block_start = tl.load(cu_seqblocks + pid_b) + block_len = tl.load(cu_seqblocks + pid_b + 1) - block_start + if pid_q * num_q_loop >= q_len: + return + num_q_loop_ = min(num_q_loop, q_len - pid_q * num_q_loop) + for j in range(num_q_loop_): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx < block_len), 1, 0), + axis=0, + ) + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_h = tl.arange(0, BLOCK_SIZE_H) + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32) + # sparse attention + for i in range(real_topk): + # get current block start index + cur_block_idx = tl.load(t_ptr_j).to(tl.int32) + cur_token_start = tl.load(cu_block_include_tokens + block_start + cur_block_idx).to(tl.int32) + cur_block_size = tl.load(cu_block_include_tokens + block_start + cur_block_idx + 1).to(tl.int32) - cur_token_start + c = cur_token_start - k_start + t_ptr_j = t_ptr_j + stride_tk + for b_j in range(0, cur_block_size, BLOCK_SIZE_K): + # load k + k = tl.load( + tl.advance(k_ptrs, (0, c + b_j)), + boundary_check=(1, 0), padding_option="zero", + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(((c + b_j + off_k < k_len) & (b_j + off_k < cur_block_size))[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load( + tl.advance(v_ptrs, (c + b_j, 0)), + boundary_check=(0, 1), padding_option="zero" + ) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) + # final scale + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_oh, stride_od), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + lse_ptrs = ( + lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh + ) + tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + + off_o[:, None] * stride_on + + pid_h * stride_oh + + off_d[None, :] * stride_od, + mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + + off_o[:, None] * stride_don + + pid_h * stride_doh + + off_d[None, :] * stride_dod, + mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store( + delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len + ) + + +@triton.jit +def count_kernel( + x_ptr, # [num_kv_heads, total_len, topk] + y_ptr, # [num_kv_heads, total_blocks] + cu_seqlens, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + topk, + stride_xh, + stride_xn, + stride_xk, + stride_yh, + stride_yn, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_R: tl.constexpr, +): + pid_h = tl.program_id(0) + pid_b = tl.program_id(1) + # get start and len after rmpad + seq_start = tl.load(cu_seqlens + pid_b) + seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start + blocks_start = tl.load(cu_seqblocks + pid_b) + num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start + # load x + off_k = tl.arange(0, BLOCK_SIZE_K) + off_n = tl.arange(0, BLOCK_SIZE_N) + x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn + x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk + # init y + y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32) + # loop + for i in range(0, seq_len, BLOCK_SIZE_N): + x = tl.load( + x_ptrs, + mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :], + other=-1, + ) + x = tl.ravel(x) + y += tl.histogram(x, BLOCK_SIZE_R) + x_ptrs += BLOCK_SIZE_N * stride_xn + # store result + off_r = tl.arange(0, BLOCK_SIZE_R) + y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn + y_ptrs = y_ptr + off_r * stride_yn + tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks) + + +def count_query( + topk_idx: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1] + batch_size = seqlens.shape[0] + BLOCK_SIZE_K = triton.next_power_of_2(topk) + BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K) + BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2) + active_query_count = torch.zeros( + num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device + ) + grid = (num_kv_heads, batch_size) + count_kernel[grid]( + topk_idx, + active_query_count, + cu_seqlens, + cu_seqblocks, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + active_query_count.stride(0), + active_query_count.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_R=BLOCK_SIZE_R, + num_warps=4, + num_stages=3, + ) + return active_query_count + + +@triton.jit +def pad_topk_idx_kernel( + t_ptr, + p_ptr, + cu_seqlens, + topk, + stride_th, + stride_tn, + stride_tk, + stride_pb, + stride_ph, + stride_pn, + stride_pk, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_n = tl.program_id(2) + # get q start and len after rmpad + q_start = tl.load(cu_seqlens + pid_b) + q_len = tl.load(cu_seqlens + pid_b + 1) - q_start + if BLOCK_SIZE_N * pid_n >= q_len: + return + # init prts + t_ptrs = tl.make_block_ptr( + base=t_ptr + pid_h * stride_th + q_start * stride_tn, + shape=(q_len, topk), + strides=(stride_tn, stride_tk), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), + order=(1, 0), + ) + p_ptrs = tl.make_block_ptr( + base=p_ptr + pid_b * stride_pb + pid_h * stride_ph, + shape=(q_len, topk), + strides=(stride_pn, stride_pk), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), + order=(1, 0), + ) + # load and save + idxs = tl.load(t_ptrs, boundary_check=(0, 1)) + tl.store(p_ptrs, idxs, boundary_check=(0, 1)) + + +@triton.jit +def save_topk_idx_kernel( + p_ptr, + t_ptr, + cu_seqblocks, + cu_topk_q_count, + n_len, + stride_pb, + stride_ph, + stride_pn, + stride_th, + stride_tn, + stride_ch, + stride_cn, + BLOCK_SIZE_N: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_n = tl.program_id(2) + # get q start and len after rmpad + q_block_start = tl.load(cu_seqblocks + pid_b) + q_block_end = tl.load(cu_seqblocks + pid_b + 1) + c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn) + c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn) + c_len = c_end - c_start + if c_len <= 0: + return + if pid_n * BLOCK_SIZE_N >= c_len: + return + # init ptrs + p_ptrs = tl.make_block_ptr( + base=p_ptr + + pid_b * stride_pb + + pid_h * stride_ph + + (n_len - c_len) * stride_pn, + shape=(c_len,), + strides=(stride_pn,), + offsets=(pid_n * BLOCK_SIZE_N,), + block_shape=(BLOCK_SIZE_N,), + order=(0,), + ) + t_ptrs = tl.make_block_ptr( + base=t_ptr + pid_h * stride_th + c_start * stride_tn, + shape=(c_len,), + strides=(stride_tn,), + offsets=(pid_n * BLOCK_SIZE_N,), + block_shape=(BLOCK_SIZE_N,), + order=(0,), + ) + # load and save + idxs = tl.load(p_ptrs, boundary_check=(0,)) + tl.store(t_ptrs, idxs, boundary_check=(0,)) + + +def reorder_topk_idx( + topk_idx: torch.Tensor, + cu_topk_q_count: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + batch_size = cu_seqlens.shape[0] - 1 + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk] + pad_topk_idx = torch.full( + (batch_size, num_kv_heads, max_seqlen, topk), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + BLOCK_SIZE_N = min( + triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T) + ) + grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N)) + pad_topk_idx_kernel[grid]( + topk_idx, + pad_topk_idx, + cu_seqlens, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + pad_topk_idx.stride(0), + pad_topk_idx.stride(1), + pad_topk_idx.stride(2), + pad_topk_idx.stride(3), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_T=BLOCK_SIZE_T, + ) + # argsort + pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk + pad_topk_q_idx = pad_topk_q_idx.to(torch.int32) + # save as remove pad version + topk_q_idx = torch.full( + (num_kv_heads, cu_topk_q_count[:, -1].max().item()), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + max_len = ( + ( + cu_topk_q_count[:, cu_seqblocks][:, 1:] + - cu_topk_q_count[:, cu_seqblocks][:, :-1] + ) + .max() + .item() + ) + + BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192) + grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N)) + save_topk_idx_kernel[grid]( + pad_topk_q_idx, + topk_q_idx, + cu_seqblocks, + cu_topk_q_count, + pad_topk_q_idx.shape[-1], + pad_topk_q_idx.stride(0), + pad_topk_q_idx.stride(1), + pad_topk_q_idx.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return topk_q_idx + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + tq_ptr, # topk_q_idx: kh x N + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + cu_block_include_tokens, # [total_blocks + 1] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + max_seqblocks, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_kb = tl.program_id(2) + pid_k = pid_kb % max_seqblocks + pid_block = pid_kb // max_seqblocks + + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + b_len = tl.load(cu_seqblocks + pid_b + 1) - b_start + + if pid_k >= b_len: + return + + cur_token_start = tl.load(cu_block_include_tokens + b_start + pid_k).to(tl.int32) + cur_block_size = tl.load(cu_block_include_tokens + b_start + pid_k + 1).to(tl.int32) - cur_token_start + cur_token_start_in_seq = cur_token_start - k_start + + if pid_block * BLOCK_SIZE_K >= cur_block_size: + return + + act_q_start = tl.load( + cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn + ) + act_q_end = tl.load( + cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn + ) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) #+ pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(cur_token_start_in_seq + pid_block * BLOCK_SIZE_K, 0), #(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = dk_ptr + (cur_token_start + pid_block * BLOCK_SIZE_K + off_k[:, None]) * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks + off_d[None, :] * stride_dkd + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(cur_token_start_in_seq + pid_block * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = dv_ptr + (cur_token_start + pid_block * BLOCK_SIZE_K + off_k[:, None]) * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs + off_d[None, :] * stride_dvd + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = ( + q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + ) + do_ptrs = ( + do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + ) + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to( + tl.int32 + ) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( + lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(((pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[None, :] & (off_q < act_q_len - i)[:, None]), float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), mask=(pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[:, None]) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), mask=(pid_block * BLOCK_SIZE_K + off_k < cur_block_size)[:, None]) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_block_include_tokens, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # q loop num + num_q_loop, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_q = tl.program_id(2) + pid_h = pid_kh * NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + block_start = tl.load(cu_seqblocks + pid_b) + block_len = tl.load(cu_seqblocks + pid_b + 1) - block_start + if pid_q * num_q_loop >= q_len: + return + num_q_loop_ = min(num_q_loop, q_len - pid_q * num_q_loop) + for j in range(num_q_loop_): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx < block_len), 1, 0), + axis=0, + ) + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_dqh, stride_dqd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_doh, stride_dod), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_dh, stride_dn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_lh, stride_ln), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + # offsets + off_k = tl.arange(0, BLOCK_SIZE_K) + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32) + # sparse + for i in range(real_topk): + # get current block start index + cur_block_idx = tl.load(t_ptr_j).to(tl.int32) + cur_token_start = tl.load(cu_block_include_tokens + block_start + cur_block_idx).to(tl.int32) + cur_block_size = tl.load(cu_block_include_tokens + block_start + cur_block_idx + 1).to(tl.int32) - cur_token_start + c = cur_token_start - k_start + t_ptr_j = t_ptr_j + stride_tk + + for b_j in range(0, cur_block_size, BLOCK_SIZE_K): + # load kv + k = tl.load( + tl.advance(k_ptrs, (c + b_j, 0)), boundary_check=(1, 0), padding_option="zero" + ) + v = tl.load( + tl.advance(v_ptrs, (c + b_j, 0)),boundary_check=(0, 1),padding_option="zero" + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((off_k + b_j < cur_block_size)[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, tl.trans(k)) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, tl.trans(v)) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _spatial_selection_attention_fwd( + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + cu_seqblocks: torch.Tensor, + cu_block_include_tokens: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len == k_len and k_len == v_len + topk = topk_idx.shape[-1] + assert topk_idx.shape[0] == num_k_heads + assert topk_idx.shape[1] == q_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device) + # launch kernel + num_q_loop = ( + cu_seqlens_q[-1].item() // 32768 + 1 + ) # calculate multiple querys in one kernel if seqlence length is too long + grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + forward_kernel[grid]( + q, + k, + v, + topk_idx, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_block_include_tokens, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _spatial_selection_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + cu_seqblocks: torch.Tensor, + cu_block_include_tokens: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size) # [num_kv_head, total_block] + cu_topk_q_count = torch.cat( + [ + torch.zeros( + topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device + ), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) # [num_kv_head, cu_total_block + 1] + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + # topk_q_idx[h, cu_topk_q_count[h, cu_seqblocks[b] + i]:cu_topk_q_count[h, cu_seqblocks[b] + i + 1]] + topk_q_idx = reorder_topk_idx( + topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size + ) + # compute dk dv + dk = torch.zeros( + num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype + ) + dv = torch.zeros( + num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype + ) + batch_size = cu_seqlens_q.shape[0] - 1 + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + max_include_block = (cu_block_include_tokens[..., 1:] - cu_block_include_tokens[..., :-1]).max().item() + BLOCK_SIZE_K = 64 + BLOCK_SIZE_Q = 128 if BLOCK_SIZE_K <= 64 else 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + max_seqblocks = (cu_seqblocks[1:] - cu_seqblocks[:-1]).max().item() + grid = (batch_size, num_q_heads, max_seqblocks * triton.cdiv(max_include_block, BLOCK_SIZE_K)) + + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + cu_block_include_tokens, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + max_seqblocks, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = ( + cu_seqlens_q[-1].item() // 32768 + 1 + ) # calculate multiple querys in one kernel if seqlence length is too long + grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + num_warps = 4 if head_dim <= 64 else 8 + num_stages = 3 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + backward_dq[grid]( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_block_include_tokens, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class SpatialSelectionAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + cu_seqblocks: torch.Tensor, # [batch_size + 1] + cu_block_include_tokens: torch.Tensor, # [total_block_len] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + o, lse = _spatial_selection_attention_fwd( + q, + k, + v, + topk_idx, + cu_seqblocks, + cu_block_include_tokens, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx, cu_seqblocks, cu_block_include_tokens) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + # return + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx, cu_seqblocks, cu_block_include_tokens = ctx.saved_tensors + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + + dq, dk, dv = _spatial_selection_attention_bwd( + o, + do, + lse, + q, + k, + v, + topk_idx, + cu_seqblocks, + cu_block_include_tokens, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def spatial_selection_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_topk: torch.Tensor, + cu_seqblocks: torch.Tensor, + cu_block_include_tokens: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Spatial selection attention implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + block_topk (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + cu_block_include_tokens (torch.Tensor) shape [total_block_len]: number of tokens within each block + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). + + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return SpatialSelectionAttention.apply( + q, + k, + v, + block_topk, + cu_seqblocks, + cu_block_include_tokens, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) diff --git a/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..65fd6bb3021b36e8bb9fb677ac2e89699e85734b --- /dev/null +++ b/pixal3d/modules/sparse/attention/spatial_sparse_attention/ops/window_attention.py @@ -0,0 +1,59 @@ +from typing import * +import torch +import flash_attn +from pixal3d.modules.sparse import SparseTensor +from pixal3d.modules.sparse.attention.windowed_attn import calc_window_partition + + +def sparse_window_attention( + q: SparseTensor, + k: SparseTensor, + v: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + q (SparseTensor): [N, *, H_q, C] sparse tensor containing query. + k (SparseTensor): [N, *, H_kv, C] sparse tensor containing key. + v (SparseTensor): [N, *, H_kv, C] sparse tensor containing value. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + + serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' + serialization_spatial_cache = q.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(q, window_size, shift_window) + q.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = q.feats.shape[0] + H = q.feats.shape[1] + H_kv = k.feats.shape[1] + C = q.feats.shape[2] + q_feats = q.feats[fwd_indices] # [M, H, C] + k_feats = k.feats[fwd_indices] + v_feats = v.feats[fwd_indices] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + q_feats = q_feats.reshape(B, N, H, C) + k_feats = k_feats.reshape(B, N, H_kv, C) + v_feats = v_feats.reshape(B, N, H_kv, C) + out = flash_attn.flash_attn_func(q_feats, k_feats, v_feats) + out = out.reshape(B * N, H, C) # [M, H, C] + else: + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(q.device).int() + out = flash_attn.flash_attn_varlen_func(q_feats, k_feats, v_feats, cu_seqlens, cu_seqlens, max(seq_lens), max(seq_lens)) + + out = out[bwd_indices] # [T, H, C] + + return q.replace(out) \ No newline at end of file diff --git a/pixal3d/modules/sparse/attention/windowed_attn.py b/pixal3d/modules/sparse/attention/windowed_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..2c450398c634ba314c0f2100bf4949207a40847b --- /dev/null +++ b/pixal3d/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,133 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0 +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (List[int]): Sequence lengths. + (List[int]): Sequence batch indices. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] + mask = seq_lens != 0 + seq_lens = seq_lens[mask].tolist() + seq_batch_indices = seq_batch_indices[mask].tolist() + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/pixal3d/modules/sparse/basic.py b/pixal3d/modules/sparse/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2a80b64a0239a5c3a344eb829db132c4f40acc --- /dev/null +++ b/pixal3d/modules/sparse/basic.py @@ -0,0 +1,471 @@ +from typing import * +import torch +import torch.nn as nn +from . import BACKEND, DEBUG +SparseTensorData = None # Lazy import + + +__all__ = [ + 'SparseTensor', + 'sparse_batch_broadcast', + 'sparse_batch_op', + 'sparse_cat', + 'sparse_unbind', +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + global SparseTensorData + if SparseTensorData is None: + import importlib + if BACKEND == 'torchsparse': + SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif BACKEND == 'spconv': + SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape, layout = args + (None,) * (4 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + if shape is None: + shape = self.__cal_shape(feats, coords) + if layout is None: + layout = self.__cal_layout(coords, shape[0]) + if BACKEND == 'torchsparse': + self.data = SparseTensorData(feats, coords, **kwargs) + elif BACKEND == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1)[1:] + self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) + self.data._features = feats + elif method_id == 1: + data, shape, layout = args + (None,) * (3 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + self.data = data + if shape is None: + shape = self.__cal_shape(self.feats, self.coords) + if layout is None: + layout = self.__cal_layout(self.coords, shape[0]) + + self._shape = shape + self._layout = layout + self._scale = kwargs.get('scale', (1, 1, 1)) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + @property + def shape(self) -> torch.Size: + return self._shape + + def dim(self) -> int: + return len(self.shape) + + @property + def layout(self) -> List[slice]: + return self._layout + + @property + def feats(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.F + elif BACKEND == 'spconv': + return self.data.features + + @feats.setter + def feats(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.F = value + elif BACKEND == 'spconv': + self.data.features = value + + @property + def coords(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.C + elif BACKEND == 'spconv': + return self.data.indices + + @coords.setter + def coords(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.C = value + elif BACKEND == 'spconv': + self.data.indices = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @overload + def to(self, dtype: torch.dtype) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + + new_feats = self.feats.to(device=device, dtype=dtype) + new_coords = self.coords.to(device=device) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def dense(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.dense() + elif BACKEND == 'spconv': + return self.data.dense() + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + new_shape = [self.shape[0]] + new_shape.extend(feats.shape[1:]) + if BACKEND == 'torchsparse': + new_data = SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif BACKEND == 'spconv': + new_data = SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) + return new_tensor + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __neg__(self) -> 'SparseTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': + # if isinstance(other, torch.Tensor): + # try: + # print(f"Broadcasting {other.shape} to {self.shape}") + # other = torch.broadcast_to(other, self.shape) + # other = sparse_batch_broadcast(self, other) + # print(other.shape) + # print("======") + # except Exception as e: + # breakpoint() + # print(f"Failed to broadcast {other.shape} to {self.shape}: {e}") + # pass + if isinstance(other,torch.Tensor): + if other.shape[0] == self.feats.shape[0]: + pass + else: + other = torch.broadcast_to(other, self.shape) + other = sparse_batch_broadcast(self, other) + # other = sparse_batch_broadcast(self, other) + if isinstance(other, SparseTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + coords = [] + feats = [] + for new_idx, old_idx in enumerate(idx): + coords.append(self.coords[self.layout[old_idx]].clone()) + coords[-1][:, 0] = new_idx + feats.append(self.feats[self.layout[old_idx]]) + coords = torch.cat(coords, dim=0).contiguous() + feats = torch.cat(feats, dim=0).contiguous() + return SparseTensor(feats=feats, coords=coords) + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + +def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + coords, feats = input.coords, input.feats + broadcasted = torch.zeros_like(feats) + for k in range(input.shape[0]): + broadcasted[input.layout[k]] = other[k] + return broadcasted + + +def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) + + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/pixal3d/modules/sparse/conv/__init__.py b/pixal3d/modules/sparse/conv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..340a87126a8de574ee0276feb96b49824a2ce234 --- /dev/null +++ b/pixal3d/modules/sparse/conv/__init__.py @@ -0,0 +1,21 @@ +from .. import BACKEND + + +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' + +def __from_env(): + import os + + global SPCONV_ALGO + env_spconv_algo = os.environ.get('SPCONV_ALGO') + if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: + SPCONV_ALGO = env_spconv_algo + print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") + + +__from_env() + +if BACKEND == 'torchsparse': + from .conv_torchsparse import * +elif BACKEND == 'spconv': + from .conv_spconv import * diff --git a/pixal3d/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e4c765efd0bcf45abb543e313b40ebd016ea87d Binary files /dev/null and b/pixal3d/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc b/pixal3d/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb1b1824c55c12a70ee3860a02bf0938d7ba2d84 Binary files /dev/null and b/pixal3d/modules/sparse/conv/__pycache__/conv_torchsparse.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/conv/conv_spconv.py b/pixal3d/modules/sparse/conv/conv_spconv.py new file mode 100644 index 0000000000000000000000000000000000000000..ff058302f033f8f340c9f75efc869006cbf5b993 --- /dev/null +++ b/pixal3d/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from .. import DEBUG +from . import SPCONV_ALGO + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + algo = None + if SPCONV_ALGO == 'native': + algo = spconv.ConvAlgo.Native + elif SPCONV_ALGO == 'implicit_gemm': + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + def forward(self, x: SparseTensor) -> SparseTensor: + + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features #[fwd] + sorted_coords = new_data.indices #[fwd] + unsorted_data = new_data + + indice_dict = new_data.indice_dict + + if 'spconv' not in globals(): + import spconv.pytorch as spconv + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size, indice_dict=indice_dict) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + if DEBUG: + assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/pixal3d/modules/sparse/conv/conv_torchsparse.py b/pixal3d/modules/sparse/conv/conv_torchsparse.py new file mode 100644 index 0000000000000000000000000000000000000000..57c749f6f8171f47de117381fd7230b7364e44bf --- /dev/null +++ b/pixal3d/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from torchsparse.utils import make_ntuple + + +def sparseconv3d_func(input: SparseTensor, weight: torch.Tensor, kernel_size: int, stride: int = 1, dilation: int = 1, padding: int = 0, bias: torch.Tensor = None, training: bool = True): + if 'torchsparse' not in globals(): + import torchsparse + stride = make_ntuple(stride, ndim=3) + kernel_size = make_ntuple(kernel_size, ndim=3) + _padding = make_ntuple(padding, 3) + padding = () + for i in range(3): + if kernel_size[i] % 2 == 1 and stride[i] == 1: + padding += ((kernel_size[i] - 1) // 2,) + else: + padding += (_padding[i],) + out = torchsparse.nn.functional.conv3d(input.data, weight, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, training=training) + spatial_range = out.spatial_range + new_shape = [input.shape[0], weight.shape[1]] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=input.layout if all(s == 1 for s in stride) else None) + out._spatial_cache = input._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(input._scale, stride)]) + out.data.spatial_range = spatial_range + return out + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) + + def forward(self, x: SparseTensor) -> SparseTensor: + + input_data = x.data + input_dtype = input_data.F.dtype + + if input_dtype == torch.bfloat16: + input_data.F = input_data.F.float() + + with torch.amp.autocast('cuda', enabled=False): + out = self.conv(input_data) + else: + out = self.conv(input_data) + + spatial_range = out.spatial_range + + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + + out.data.spatial_range = spatial_range + + if input_dtype == torch.bfloat16: + out.data.F = out.data.F.to(torch.bfloat16) + + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + def forward(self, x: SparseTensor) -> SparseTensor: + input_data = x.data + input_dtype = input_data.F.dtype + + if input_dtype == torch.bfloat16: + input_data.F = input_data.F.float() + + with torch.amp.autocast('cuda', enabled=False): + out = self.conv(input_data) + else: + out = self.conv(input_data) + + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) + + if input_dtype == torch.bfloat16: + out.data.F = out.data.F.to(torch.bfloat16) + + return out + + + diff --git a/pixal3d/modules/sparse/linear.py b/pixal3d/modules/sparse/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..a854e77ce87d1a190b9730d91f363a821ff250bd --- /dev/null +++ b/pixal3d/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) diff --git a/pixal3d/modules/sparse/nonlinearity.py b/pixal3d/modules/sparse/nonlinearity.py new file mode 100644 index 0000000000000000000000000000000000000000..2e6bfd855271d238fe34ab8bec2744bf9db58b94 --- /dev/null +++ b/pixal3d/modules/sparse/nonlinearity.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' + 'SparseTanh', + 'SparseSigmoid', +] + +class SparseSigmoid(nn.Sigmoid): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + +class SparseReLU(nn.ReLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + +class SparseGELU(nn.GELU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + +class SparseTanh(nn.Tanh): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(self.activation(input.feats)) + diff --git a/pixal3d/modules/sparse/norm.py b/pixal3d/modules/sparse/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..6b38a36682c098210000dc31d68ddc31ccd2929d --- /dev/null +++ b/pixal3d/modules/sparse/norm.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from . import SparseTensor +from . import DEBUG + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + if DEBUG: + assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) diff --git a/pixal3d/modules/sparse/spatial.py b/pixal3d/modules/sparse/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..329e254fc31c05a149d2a8272e96c422e9ed7959 --- /dev/null +++ b/pixal3d/modules/sparse/spatial.py @@ -0,0 +1,114 @@ +from typing import * +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', + 'SparseSubdivide' +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. + """ + def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], mode="mean"): + super(SparseDownsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + self.mode = mode + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' + + coord = list(input.coords.unbind(dim=-1)) + for i, f in enumerate(factor): + coord[i+1] = coord[i+1] // f + + MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + #### using fp16 could cause overflow when factor is large ###### + dtype = input.feats.dtype + new_feats = torch.scatter_reduce( + torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=torch.float64), + dim=0, + index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), + src=input.feats.double(), + reduce=self.mode, + ) + new_feats = new_feats.to(dtype) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + out = SparseTensor(new_feats, new_coords, input.shape,) + out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + + out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) + out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) + out.register_spatial_cache(f'upsample_{factor}_idx', idx) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): + super(SparseUpsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' + + new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') + new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') + idx = input.get_spatial_cache(f'upsample_{factor}_idx') + if any([x is None for x in [new_coords, new_layout, idx]]): + raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') + new_feats = input.feats[idx] + out = SparseTensor(new_feats, new_coords, input.shape, new_layout) + out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + return out + +class SparseSubdivide(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self): + super(SparseSubdivide, self).__init__() + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + # upsample scale=2^DIM + n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) + n_coords = torch.nonzero(n_cube) + n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) + factor = n_coords.shape[0] + assert factor == 2 ** DIM + new_coords = input.coords.clone() + new_coords[:, 1:] *= 2 + new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) + + new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) + out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) + out._scale = input._scale * 2 + out._spatial_cache = input._spatial_cache + return out + diff --git a/pixal3d/modules/sparse/transformer/__init__.py b/pixal3d/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/pixal3d/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/pixal3d/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6736305ce8837330a7c5522cc55113bf0b99fa97 Binary files /dev/null and b/pixal3d/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc b/pixal3d/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd0b96da4859d61e43a3aeea11d1a6ab3ab70e6f Binary files /dev/null and b/pixal3d/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc b/pixal3d/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..147cbcd28961ca118eb7459ee131cc9ab4fed8b1 Binary files /dev/null and b/pixal3d/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/pixal3d/modules/sparse/transformer/blocks.py b/pixal3d/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..9d037a49bf83e1c2dfb2f8c4b23d2e9d6c51e9f0 --- /dev/null +++ b/pixal3d/modules/sparse/transformer/blocks.py @@ -0,0 +1,151 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) diff --git a/pixal3d/modules/sparse/transformer/modulated.py b/pixal3d/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..dda6e58cc8257c6e2a769db94f6a7cb37d0b2b0c --- /dev/null +++ b/pixal3d/modules/sparse/transformer/modulated.py @@ -0,0 +1,228 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..attention import SparseMultiHeadAttention, SerializeMode, SpatialSparseAttention,SparseProjAttention +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + compression_version: str = "v2", + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + use_ssa: bool = True, + num_kv_heads: int = 2, + compression_block_size: int = 4, + selection_block_size: int = 8, + topk: int = 8, + resolution: int = 64, + image_attn_mode:str = 'cross', + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + if use_ssa: + self.self_attn = SpatialSparseAttention( + channels, + num_q_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=channels//num_heads, + compression_block_size=compression_block_size, + compression_version=compression_version, + selection_block_size=selection_block_size, + topk=topk, + window_size=window_size, + shift_window=shift_window, + resolution=resolution, + ) + else: + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.image_attn_mode = image_attn_mode + if self.image_attn_mode == 'cross': + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + else: + cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.cross_attn = SparseProjAttention(cross_attn_block=cross_attn) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + + feats_h = h.feats + layouts = h.layout + ada_r1 = [] + for i in range(len(layouts)): + ada_r1.append(feats_h[layouts[i]] * (1 + scale_msa[i:i+1]) + shift_msa[i:i+1]) + h = h.replace(torch.cat(ada_r1, dim=0)) + h = self.self_attn(h) + + feats_h = h.feats + layouts = h.layout + ada_r2 = [] + for i in range(len(layouts)): + ada_r2.append(feats_h[layouts[i]] * gate_msa[i:i+1]) + h = h.replace(torch.cat(ada_r2, dim=0)) + + x = x + h + h = x.replace(self.norm2(x.feats)) + # breakpoint() + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + + feats_h = h.feats + layouts = h.layout + ada_r3 = [] + for i in range(len(layouts)): + ada_r3.append(feats_h[layouts[i]] * (1 + scale_mlp[i:i+1]) + shift_mlp[i:i+1]) + h = h.replace(torch.cat(ada_r3, dim=0)) + h = self.mlp(h) + + feats_h = h.feats + layouts = h.layout + ada_r4 = [] + for i in range(len(layouts)): + ada_r4.append(feats_h[layouts[i]] * gate_mlp[i:i+1]) + h = h.replace(torch.cat(ada_r4, dim=0)) + + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/pixal3d/modules/spatial.py b/pixal3d/modules/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3 --- /dev/null +++ b/pixal3d/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/pixal3d/modules/transformer/__init__.py b/pixal3d/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/pixal3d/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/pixal3d/modules/transformer/__pycache__/__init__.cpython-310.pyc b/pixal3d/modules/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..125b55bef6273de912816bb5a99de0cd1034c4c2 Binary files /dev/null and b/pixal3d/modules/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/modules/transformer/__pycache__/blocks.cpython-310.pyc b/pixal3d/modules/transformer/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2695b59a4dfa747e647573ab8c1dad57725c3ed7 Binary files /dev/null and b/pixal3d/modules/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/pixal3d/modules/transformer/__pycache__/modulated.cpython-310.pyc b/pixal3d/modules/transformer/__pycache__/modulated.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44a81a6eea6fe49ebaec3c4eef216a1bc18e3819 Binary files /dev/null and b/pixal3d/modules/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/pixal3d/modules/transformer/blocks.py b/pixal3d/modules/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..605ae33bc44276f45f73789f547dc756ac3999da --- /dev/null +++ b/pixal3d/modules/transformer/blocks.py @@ -0,0 +1,184 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. + """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor, factor: float = None) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + if factor is not None: + x = x * factor + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) + \ No newline at end of file diff --git a/pixal3d/modules/transformer/modulated.py b/pixal3d/modules/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..e35d17aa171323cd68a3ae191db631211de71c61 --- /dev/null +++ b/pixal3d/modules/transformer/modulated.py @@ -0,0 +1,182 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention,ProjectAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + image_attn_mode:Literal["cross", "proj"] = "cross", + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + # self.norm4 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.image_attn_mode = image_attn_mode + if image_attn_mode == "cross": + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + elif image_attn_mode == "proj": + cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.cross_attn = ProjectAttention(cross_attn) + else: + raise ValueError(f"Unknown image attention mode: {image_attn_mode}") + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + + + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + + # h = self.norm2(x) + # h = self.cross_attn(h, context) + # x = x + h + + h = self.norm2(x) + h = self.cross_attn(h, context) # h = h + context + x = x + h + + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) + \ No newline at end of file diff --git a/pixal3d/modules/utils.py b/pixal3d/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39788c3d6ef61c8dff9b741fc6d03140c06f33db --- /dev/null +++ b/pixal3d/modules/utils.py @@ -0,0 +1,54 @@ +import torch.nn as nn +from ..modules import sparse as sp + +FP16_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + # if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/pixal3d/utils/__init__.py b/pixal3d/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab60c0bba4624f90bc4c16216c1b7a56d5a20617 --- /dev/null +++ b/pixal3d/utils/__init__.py @@ -0,0 +1,7 @@ +from . import base +from .sparse import * + +# Inference utilities +from .util import instantiate_from_config +from .mesh import normalize_mesh, mesh2index +from .fill_hole import postprocess_mesh \ No newline at end of file diff --git a/pixal3d/utils/__pycache__/__init__.cpython-310.pyc b/pixal3d/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94fb2308c5c9571c394daf5a578cfab4ff404049 Binary files /dev/null and b/pixal3d/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/pixal3d/utils/__pycache__/base.cpython-310.pyc b/pixal3d/utils/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7232b0c95dfee827d8a3cb20a47bd25bc99585ec Binary files /dev/null and b/pixal3d/utils/__pycache__/base.cpython-310.pyc differ diff --git a/pixal3d/utils/__pycache__/config.cpython-310.pyc b/pixal3d/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..974ed525a226a2a0d01c53e8309d16a7a2762703 Binary files /dev/null and b/pixal3d/utils/__pycache__/config.cpython-310.pyc differ diff --git a/pixal3d/utils/__pycache__/fill_hole.cpython-310.pyc b/pixal3d/utils/__pycache__/fill_hole.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75b72020137c9eefaa17fdf676202a948ecbd616 Binary files /dev/null and b/pixal3d/utils/__pycache__/fill_hole.cpython-310.pyc differ diff --git a/pixal3d/utils/__pycache__/mesh.cpython-310.pyc b/pixal3d/utils/__pycache__/mesh.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b53576e3e5c88623612bd9a28d24687ecdf1f47 Binary files /dev/null and b/pixal3d/utils/__pycache__/mesh.cpython-310.pyc differ diff --git a/pixal3d/utils/__pycache__/misc.cpython-310.pyc b/pixal3d/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa72394c6c112b403b15ca6fa8592879f1d20f3 Binary files /dev/null and b/pixal3d/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/pixal3d/utils/__pycache__/sparse.cpython-310.pyc b/pixal3d/utils/__pycache__/sparse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11ca46beef03bef868210c7bac014ef7b7d5d1f2 Binary files /dev/null and b/pixal3d/utils/__pycache__/sparse.cpython-310.pyc differ diff --git a/pixal3d/utils/__pycache__/typing.cpython-310.pyc b/pixal3d/utils/__pycache__/typing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd3c1b8df5cde3cd7a1f7560c7a0a35565523dae Binary files /dev/null and b/pixal3d/utils/__pycache__/typing.cpython-310.pyc differ diff --git a/pixal3d/utils/__pycache__/util.cpython-310.pyc b/pixal3d/utils/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d6b6d31af98cb13d25b67fb9bc00ea95ec4bda7 Binary files /dev/null and b/pixal3d/utils/__pycache__/util.cpython-310.pyc differ diff --git a/pixal3d/utils/base.py b/pixal3d/utils/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8fdbfd93f64047e9e88af862f4863b26969f7e3f --- /dev/null +++ b/pixal3d/utils/base.py @@ -0,0 +1,219 @@ +from dataclasses import dataclass + +import os +import copy +import json +from omegaconf import OmegaConf +import torch +import torch.nn as nn + +from diffusers.models.modeling_utils import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import ( + extract_commit_hash, +) + +from pixal3d.utils.config import parse_structured +from pixal3d.utils.misc import get_device, load_module_weights +from pixal3d.utils.typing import * + + + + + + +class Configurable: + @dataclass + class Config: + pass + + def __init__(self, cfg: Optional[dict] = None) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + + +class Updateable: + def do_update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step( + epoch, global_step, on_load_weights=on_load_weights + ) + self.update_step(epoch, global_step, on_load_weights=on_load_weights) + + def do_update_step_end(self, epoch: int, global_step: int): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + self.update_step_end(epoch, global_step) + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # override this method to implement custom update logic + # if on_load_weights is True, you should be careful doing things related to model evaluations, + # as the models and tensors are not guarenteed to be on the same device + pass + + def update_step_end(self, epoch: int, global_step: int): + pass + + +def update_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step(epoch, global_step) + + +def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + + +class BaseObject(Updateable): + @dataclass + class Config: + pass + + cfg: Config # add this to every subclass of BaseObject to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.device = get_device() + self.configure(*args, **kwargs) + + def configure(self, *args, **kwargs) -> None: + pass + + +class BaseModule(ModelMixin, Updateable, nn.Module): + @dataclass + class Config: + weights: Optional[str] = None + + cfg: Config # add this to every subclass of BaseModule to enable static type checking + config_name = "config.json" + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + # self.device = get_device() + self.configure(*args, **kwargs) + if self.cfg.weights is not None: + # format: path/to/weights:module_name + weights_path, module_name = self.cfg.weights.split(":") + state_dict, epoch, global_step = load_module_weights( + weights_path, module_name=module_name, map_location="cpu" + ) + self.load_state_dict(state_dict) + self.do_update_step( + epoch, global_step, on_load_weights=True + ) # restore states + # dummy tensor to indicate model state + self._dummy: Float[Tensor, "..."] + self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) + + def configure(self, *args, **kwargs) -> None: + pass + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, + ): + subfolder = kwargs.pop("subfolder", None) + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, cls.config_name) + ): + # Load from a PyTorch checkpoint + config_file = os.path.join( + pretrained_model_name_or_path, cls.config_name + ) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + raise ValueError + + config_dict = json.load(open(config_file, "r")) + commit_hash = extract_commit_hash(config_file) + + outputs = (config_dict,) + + if return_unused_kwargs: + outputs += (kwargs,) + + if return_commit_hash: + outputs += (commit_hash,) + + return outputs + + @classmethod + def from_config(cls, config: Dict[str, Any] = None, **kwargs): + model = cls(config) + return model + + def register_to_config(self, **kwargs): + pass + + def save_config(self, save_directory: Union[str, os.PathLike], **kwargs): + """ + Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file is saved (will be created if it does not exist). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + raise AssertionError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + config_dict = OmegaConf.to_container(self.cfg, resolve=True) + for k in copy.deepcopy(config_dict).keys(): + if k.startswith("pretrained"): + config_dict.pop(k) + config_dict.pop("weights") + with open(output_config_file, "w", encoding="utf-8") as f: + json.dump(config_dict, f, ensure_ascii=False, indent=4) + + print(f"Configuration saved in {output_config_file}") diff --git a/pixal3d/utils/config.py b/pixal3d/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1384429d90e75d346976b06351114fabf002e9a7 --- /dev/null +++ b/pixal3d/utils/config.py @@ -0,0 +1,128 @@ +import os +from dataclasses import dataclass, field +from datetime import datetime + +from omegaconf import OmegaConf + +import pixal3d +from pixal3d.utils.typing import * + +# ============ Register OmegaConf Recolvers ============= # +OmegaConf.register_new_resolver( + "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) +) +OmegaConf.register_new_resolver("add", lambda a, b: a + b) +OmegaConf.register_new_resolver("sub", lambda a, b: a - b) +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) +OmegaConf.register_new_resolver("div", lambda a, b: a / b) +OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) +OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) +OmegaConf.register_new_resolver("rmspace", lambda s, sub: str(s).replace(" ", sub)) +OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) +OmegaConf.register_new_resolver("gt0", lambda s: s > 0) +OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0) +OmegaConf.register_new_resolver("not", lambda s: not s) +OmegaConf.register_new_resolver( + "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0 +) +# ======================================================= # + + +def C_max(value: Any) -> float: + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError("Scalar specification only supports list, got", type(value)) + if len(value) >= 6: + max_value = value[2] + for i in range(4, len(value), 2): + max_value = max(max_value, value[i]) + value = [value[0], value[1], max_value, value[3]] + if len(value) == 3: + value = [0] + value + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + value = max(start_value, end_value) + return value + + +@dataclass +class ExperimentConfig: + name: str = "default" + description: str = "" + tag: str = "" + seed: int = 0 + use_timestamp: bool = True + timestamp: Optional[str] = None + exp_root_dir: str = "outputs" + + ### these shouldn't be set manually + exp_dir: str = "outputs/default" + trial_name: str = "exp" + trial_dir: str = "outputs/default/exp" + n_gpus: int = 1 + ### + + resume: Optional[str] = None + + data_type: str = "" + data: dict = field(default_factory=dict) + + system_type: str = "" + system: dict = field(default_factory=dict) + + # accept pytorch-lightning trainer parameters + # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api + trainer: dict = field(default_factory=dict) + + # accept pytorch-lightning checkpoint callback parameters + # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint + checkpoint: dict = field(default_factory=dict) + + def __post_init__(self): + if not self.tag and not self.use_timestamp: + raise ValueError("Either tag is specified or use_timestamp is True.") + self.trial_name = self.tag + # if resume from an existing config, self.timestamp should not be None + if self.timestamp is None: + self.timestamp = "" + if self.use_timestamp: + if self.n_gpus > 1: + pixal3d.warn( + "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." + ) + else: + self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") + self.trial_name += self.timestamp + self.exp_dir = os.path.join(self.exp_root_dir, self.name) + self.trial_dir = os.path.join(self.exp_dir, self.trial_name) + # os.makedirs(self.trial_dir, exist_ok=True) + + +def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: + if from_string: + yaml_confs = [OmegaConf.create(s) for s in yamls] + else: + yaml_confs = [OmegaConf.load(f) for f in yamls] + cli_conf = OmegaConf.from_cli(cli_args) + cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) + OmegaConf.resolve(cfg) + assert isinstance(cfg, DictConfig) + scfg = parse_structured(ExperimentConfig, cfg) + return scfg + + +def config_to_primitive(config, resolve: bool = True) -> Any: + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path: str, config) -> None: + with open(path, "w") as fp: + OmegaConf.save(config=config, f=fp) + + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + scfg = OmegaConf.structured(fields(**cfg)) + return scfg diff --git a/pixal3d/utils/fill_hole.py b/pixal3d/utils/fill_hole.py new file mode 100644 index 0000000000000000000000000000000000000000..b04a212448b01c6f1b13537aa037b220e8da462e --- /dev/null +++ b/pixal3d/utils/fill_hole.py @@ -0,0 +1,35 @@ +import numpy as np +from tqdm import tqdm +import pyvista as pv + + +def postprocess_mesh( + vertices: np.array, + faces: np.array, + simplify: bool = False, + simplify_ratio: float = 0.9, + verbose: bool = False, +): + """ + Postprocess a mesh by simplifying. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + simplify (bool): Whether to simplify the mesh, using quadric edge collapse. + simplify_ratio (float): Ratio of faces to keep after simplification. + verbose (bool): Whether to print progress. + """ + + if verbose: + tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Simplify + if simplify and simplify_ratio > 0: + mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)) + mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) + vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] + if verbose: + tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + return vertices, faces diff --git a/pixal3d/utils/grad_clip.py b/pixal3d/utils/grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..a0e2604593e48122d9af34e2ddbd679be24591a6 --- /dev/null +++ b/pixal3d/utils/grad_clip.py @@ -0,0 +1,150 @@ +""" +Adaptive Gradient Clipping utilities for training. +""" +import numpy as np +import torch +from typing import Optional, Iterable, Union + + +class AdaptiveGradClipper: + """ + Adaptive gradient clipping for training. + """ + def __init__( + self, + max_norm=None, + clip_percentile=95.0, + buffer_size=1000, + skip_mode=False, + max_skipped_steps=500, + use_buffer=True, + ): + self.max_norm = max_norm + self.clip_percentile = clip_percentile + self.buffer_size = buffer_size + self.skip_mode = skip_mode # 如果True,超过阈值时跳过更新;如果False,进行梯度裁剪 + + self._grad_norm = np.zeros(buffer_size, dtype=np.float32) + self._max_norm = max_norm + self._buffer_ptr = 0 + self._buffer_length = 0 + self._skipped_steps = 0 # 记录跳过的步数 + self._skipped_steps_list = np.zeros(buffer_size, dtype=np.int32) + self._skipped_steps_ptr = 0 + self.max_skipped_steps = max_skipped_steps + self.use_buffer = use_buffer + def __repr__(self): + mode_str = "skip" if self.skip_mode else "clip" + return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile}, mode={mode_str})' + + def state_dict(self): + return { + 'grad_norm': self._grad_norm, + 'max_norm': self._max_norm, + 'buffer_ptr': self._buffer_ptr, + 'buffer_length': self._buffer_length, + 'skipped_steps': self._skipped_steps, + 'skipped_steps_list': self._skipped_steps_list, + 'skipped_steps_ptr': self._skipped_steps_ptr, + } + + def load_state_dict(self, state_dict): + self._grad_norm = state_dict['grad_norm'] + self._max_norm = state_dict['max_norm'] + self._buffer_ptr = state_dict['buffer_ptr'] + self._buffer_length = state_dict['buffer_length'] + self._skipped_steps = state_dict.get('skipped_steps', 0) # 兼容旧版本 + # self._skipped_steps_list = state_dict.get('skipped_steps_list', np.zeros(self.buffer_size, dtype=np.int32)) + self._skipped_steps_ptr = state_dict.get('skipped_steps_ptr', 0) + + def log(self): + return { + 'max_norm': self._max_norm, + 'skipped_steps': self._skipped_steps, + 'skipped_steps_list': self._skipped_steps_list, + 'skipped_steps_ptr': self._skipped_steps_ptr, + } + + def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None,optimizer=None): + """Clip or skip gradients based on their norm with two-tier threshold system. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. + + Two-tier threshold logic: + 1. If grad_norm > initial_max_norm (constructor param): + - skip_mode=True: SKIP the update (zero gradients) + - skip_mode=False: CLIP to adaptive threshold + 2. If adaptive_max_norm < grad_norm <= initial_max_norm: + - Both modes: CLIP to adaptive threshold + 3. If grad_norm <= adaptive_max_norm: + - Both modes: No action (normal update) + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + tuple: (grad_norm, should_skip) - grad_norm is the original gradient norm, + should_skip indicates whether this step should be skipped + """ + # 使用初始max_norm作为skip阈值,自适应_max_norm作为clip阈值 + initial_max_norm = self.max_norm if self.max_norm is not None else float('inf') + adaptive_max_norm = self._max_norm if self._max_norm is not None else float('inf') + should_skip = False + + # 一次调用:获取原始梯度范数并裁剪到adaptive_max_norm + grad_norm = torch.nn.utils.clip_grad_norm_( + parameters, max_norm=adaptive_max_norm, norm_type=norm_type, + error_if_nonfinite=error_if_nonfinite, foreach=foreach + ) + if not self.use_buffer: + return grad_norm, should_skip + if torch.isfinite(grad_norm): + grad_norm_value = grad_norm.item() + + if self.skip_mode and grad_norm_value > initial_max_norm: + # Skip模式:如果原始梯度超过初始max_norm,跳过本次更新 + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is not None: + p.grad.zero_() + should_skip = True + self._skipped_steps += 1 + self._skipped_steps_list[self._skipped_steps_ptr] = 1 + self._skipped_steps_ptr = (self._skipped_steps_ptr + 1) % self.buffer_size + if optimizer is not None: + optimizer.zero_grad() + print(f"[AdaptiveGradClipper] Skipping step due to large gradient norm: {grad_norm_value:.6f} > {initial_max_norm:.6f} (initial_max_norm)") + # Skip时不更新缓冲区,因为异常梯度不应该影响自适应阈值计算 + + else: + # 正常情况:使用已经被裁剪到adaptive_max_norm的梯度,更新缓冲区 + self._grad_norm[self._buffer_ptr] = grad_norm_value + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._skipped_steps_list[self._skipped_steps_ptr] = 0 + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + self._skipped_steps_ptr = (self._skipped_steps_ptr + 1) % self.buffer_size + + if grad_norm_value > adaptive_max_norm: + if self.skip_mode: + print(f"[AdaptiveGradClipper] Clipping gradient norm: {grad_norm_value:.6f} -> {adaptive_max_norm:.6f} (adaptive_max_norm)") + + # 重新计算自适应阈值(只要不skip就可以更新阈值) + if not should_skip and self._buffer_length == self.buffer_size: + self._max_norm = np.percentile(self._grad_norm, self.clip_percentile) + self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm + if self._skipped_steps_list.sum() > self.max_skipped_steps: + raise Exception("Too many skipped steps, something is wrong") + return grad_norm, should_skip + diff --git a/pixal3d/utils/mesh.py b/pixal3d/utils/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..321172edfb7814e63222d2b44ff9f9ba2365323f --- /dev/null +++ b/pixal3d/utils/mesh.py @@ -0,0 +1,34 @@ +import torch +import numpy as np +import udf_ext + + +def compute_valid_udf(vertices, faces, dim=512, threshold=8.0): + if not faces.is_cuda or not vertices.is_cuda: + raise ValueError("Both maze and visited tensors must be CUDA tensors") + udf = torch.zeros(dim**3,device=vertices.device).int() + 10000000 + n_faces = faces.shape[0] + udf_ext.compute_valid_udf(vertices, faces, udf, n_faces, dim, threshold) + return udf.float()/10000000. + +def normalize_mesh(mesh, scale=0.95): + vertices = mesh.vertices + min_coords, max_coords = vertices.min(axis=0), vertices.max(axis=0) + dxyz = max_coords - min_coords + dist = max(dxyz) + mesh_scale = 2.0 * scale / dist + mesh_offset = -(min_coords + max_coords) / 2 + vertices = (vertices + mesh_offset) * mesh_scale + mesh.vertices = vertices + return mesh + +def mesh2index(mesh, size=1024, factor=8): + vertices = torch.Tensor(mesh.vertices).float().cuda() * 0.5 + faces = torch.Tensor(mesh.faces).int().cuda() + sdf = compute_valid_udf(vertices, faces, dim=size, threshold=4.0) + sdf = sdf.reshape(size, size, size).unsqueeze(0) + + sparse_index = (sdf < 4/size).nonzero() + sparse_index[..., 1:] = sparse_index[..., 1:] // factor + latent_index = torch.unique(sparse_index, dim=0) + return latent_index \ No newline at end of file diff --git a/pixal3d/utils/misc.py b/pixal3d/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..36f0d4cb5c783bd706428dbabc9f9a838b576456 --- /dev/null +++ b/pixal3d/utils/misc.py @@ -0,0 +1,165 @@ +import gc +import os +import re + +import torch +import torch.distributed as dist +from packaging import version + +from pixal3d.utils.config import config_to_primitive +from pixal3d.utils.typing import * + + +def parse_version(ver: str): + return version.parse(ver) + + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def get_world_size(): + world_size_keys = ("WORLD_SIZE", "SLURM_NTASKS", "JSM_NAMESPACE_SIZE") + for key in world_size_keys: + world_size = os.environ.get(key) + if world_size is not None: + return int(world_size) + return 1 + + +def get_device(): + return torch.device(f"cuda:{get_rank()}") + + +def load_module_weights( + path, module_name=None, ignore_modules=None, map_location=None +) -> Tuple[dict, int, int]: + if module_name is not None and ignore_modules is not None: + raise ValueError("module_name and ignore_modules cannot be both set") + if map_location is None: + map_location = get_device() + + ckpt = torch.load(path, map_location=map_location) + state_dict = ckpt["state_dict"] + state_dict_to_load = state_dict + + if ignore_modules is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + ignore = any( + [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] + ) + if ignore: + continue + state_dict_to_load[k] = v + + if module_name is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + m = re.match(rf"^{module_name}\.(.*)$", k) + if m is None: + continue + state_dict_to_load[m.group(1)] = v + + return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] + + +def C(value: Any, epoch: int, global_step: int) -> float: + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError("Scalar specification only supports list, got", type(value)) + if len(value) == 3: + value = [0] + value + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + if isinstance(end_step, int): + current_step = global_step + value = start_value + (end_value - start_value) * max( + min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 + ) + elif isinstance(end_step, float): + current_step = epoch + value = start_value + (end_value - start_value) * max( + min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 + ) + return value + + +def cleanup(): + gc.collect() + torch.cuda.empty_cache() + tcnn.free_temporary_memory() + + +def finish_with_cleanup(func: Callable): + def wrapper(*args, **kwargs): + out = func(*args, **kwargs) + cleanup() + return out + + return wrapper + + +def _distributed_available(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + + +def barrier(): + if not _distributed_available(): + return + else: + torch.distributed.barrier() + + +def broadcast(tensor, src=0): + if not _distributed_available(): + return tensor + else: + torch.distributed.broadcast(tensor, src=src) + return tensor + + +def enable_gradient(model, enabled: bool = True) -> None: + for param in model.parameters(): + param.requires_grad_(enabled) + + +def all_gather_batch(tensors): + """ + Performs all_gather operation on the provided tensors. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + if isinstance(tensors, list): + return tensors + return tensors + if not isinstance(tensors, list): + is_list = False + tensors = [tensors] + else: + is_list = True + output_tensor = [] + tensor_list = [] + for tensor in tensors: + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_all, tensor, async_op=False) # performance opt + + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + if not is_list: + return output_tensor[0] + return output_tensor diff --git a/pixal3d/utils/rembg.py b/pixal3d/utils/rembg.py new file mode 100644 index 0000000000000000000000000000000000000000..55ba0392743bc68741fe08b2ab1a6d45c130a7fb --- /dev/null +++ b/pixal3d/utils/rembg.py @@ -0,0 +1,40 @@ +import numpy as np +import torch +from torchvision import transforms + + +class BiRefNet(object): + def __init__(self, device): + from transformers import AutoModelForImageSegmentation + self.birefnet_model = AutoModelForImageSegmentation.from_pretrained( + 'briaai/RMBG-2.0', + trust_remote_code=True, + ).to(device) + print("loaded BiRefNet from briaai/RMBG-2.0") + self.birefnet_model.eval() + self.device = device + + def run(self, image, use_alpha=False): + if use_alpha: + if image.mode != 'RGBA': + image = image.convert('RGBA') + return np.array(image) + image = image.convert('RGB') + image_size = (1024, 1024) + transform_image = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + input_images = transform_image(image).unsqueeze(0).to(self.device) + + with torch.no_grad(): + preds = self.birefnet_model(input_images)[-1].sigmoid().cpu() + + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image.size) + mask = np.array(mask) + image = np.concatenate([np.array(image), mask[..., None]], axis=-1) + return image \ No newline at end of file diff --git a/pixal3d/utils/saving.py b/pixal3d/utils/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..074216db04885e6d0a0be27e5a9f3fae49e37da9 --- /dev/null +++ b/pixal3d/utils/saving.py @@ -0,0 +1,468 @@ +import json +import os +import re +import shutil + +import cv2 +import imageio +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchvision.utils as vutils +import trimesh +import wandb +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +from PIL import Image, ImageDraw +from pytorch_lightning.loggers import WandbLogger + +from pixal3d.utils.typing import * + + +class SaverMixin: + _save_dir: Optional[str] = None + _wandb_logger: Optional[WandbLogger] = None + + def set_save_dir(self, save_dir: str): + self._save_dir = save_dir + + def get_save_dir(self): + if self._save_dir is None: + raise ValueError("Save dir is not set") + return self._save_dir + + def convert_data(self, data): + if data is None: + return None + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + return data.detach().cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + def get_save_path(self, filename): + save_path = os.path.join(self.get_save_dir(), filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + def create_loggers(self, cfg_loggers: DictConfig, full_config: Optional[DictConfig] = None) -> None: + if "wandb" in cfg_loggers.keys() and cfg_loggers.wandb.enable: + from omegaconf import OmegaConf + # 将完整配置转换为字典,上传到 wandb 云端显示 + config_dict = OmegaConf.to_container(full_config, resolve=True) if full_config else None + self._wandb_logger = WandbLogger( + project=cfg_loggers.wandb.project, + name=cfg_loggers.wandb.name, + config=config_dict # 上传配置到 wandb + ) + + def get_loggers(self) -> List: + if self._wandb_logger: + return [self._wandb_logger] + else: + return [] + + DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)} + DEFAULT_UV_KWARGS = { + "data_format": "HWC", + "data_range": (0, 1), + "cmap": "checkerboard", + } + DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} + DEFAULT_GRID_KWARGS = {"align": "max"} + + def get_rgb_image_(self, img, data_format, data_range, rgba=False): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + if img.dtype != np.uint8: + img = img.clip(min=data_range[0], max=data_range[1]) + img = ( + (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0 + ).astype(np.uint8) + nc = 4 if rgba else 3 + imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)] + imgs = [ + ( + img_ + if img_.shape[-1] == nc + else np.concatenate( + [ + img_, + np.zeros( + (img_.shape[0], img_.shape[1], nc - img_.shape[2]), + dtype=img_.dtype, + ), + ], + axis=-1, + ) + ) + for img_ in imgs + ] + img = np.concatenate(imgs, axis=1) + if rgba: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_rgb_image( + self, + filename, + img, + data_format, + data_range, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Image(self.get_save_path(filename)), + "trainer/global_step": step, + } + ) + + def save_rgb_image( + self, + filename, + img, + data_format=DEFAULT_RGB_KWARGS["data_format"], + data_range=DEFAULT_RGB_KWARGS["data_range"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_rgb_image(save_path, img, data_format, data_range, name, step) + return save_path + + def get_uv_image_(self, img, data_format, data_range, cmap): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in ["checkerboard", "color"] + if cmap == "checkerboard": + n_grid = 64 + mask = (img * n_grid).astype(int) + mask = (mask[..., 0] + mask[..., 1]) % 2 == 0 + img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 + img[mask] = np.array([255, 0, 255], dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif cmap == "color": + img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) + img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) + img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) + img = img_ + return img + + def save_uv_image( + self, + filename, + img, + data_format=DEFAULT_UV_KWARGS["data_format"], + data_range=DEFAULT_UV_KWARGS["data_range"], + cmap=DEFAULT_UV_KWARGS["cmap"], + ) -> str: + save_path = self.get_save_path(filename) + img = self.get_uv_image_(img, data_format, data_range, cmap) + cv2.imwrite(save_path, img) + return save_path + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, "jet", "magma", "spectral"] + if cmap == None: + img = (img * 255.0).astype(np.uint8) + img = np.repeat(img[..., None], 3, axis=2) + elif cmap == "jet": + img = (img * 255.0).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == "magma": + img = 1.0 - img + base = cm.get_cmap("magma") + num_bins = 256 + colormap = LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins + )(np.linspace(0, 1, num_bins))[:, :3] + a = np.floor(img * 255.0) + b = (a + 1).clip(max=255.0) + f = img * 255.0 - a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] + img = (img * 255.0).astype(np.uint8) + elif cmap == "spectral": + colormap = plt.get_cmap("Spectral") + + def blend_rgba(image): + image = image[..., :3] * image[..., -1:] + ( + 1.0 - image[..., -1:] + ) # blend A to RGB + return image + + img = colormap(img) + img = blend_rgba(img) + img = (img * 255).astype(np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_grayscale_image( + self, + filename, + img, + data_range, + cmap, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Image(self.get_save_path(filename)), + "trainer/global_step": step, + } + ) + + def save_grayscale_image( + self, + filename, + img, + data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], + cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_grayscale_image(save_path, img, data_range, cmap, name, step) + return save_path + + def get_image_grid_(self, imgs, align): + if isinstance(imgs[0], list): + return np.concatenate( + [self.get_image_grid_(row, align) for row in imgs], axis=0 + ) + cols = [] + for col in imgs: + assert col["type"] in ["rgb", "uv", "grayscale"] + if col["type"] == "rgb": + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col["kwargs"]) + cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) + elif col["type"] == "uv": + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col["kwargs"]) + cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) + elif col["type"] == "grayscale": + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col["kwargs"]) + cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) + + if align == "max": + h = max([col.shape[0] for col in cols]) + w = max([col.shape[1] for col in cols]) + elif align == "min": + h = min([col.shape[0] for col in cols]) + w = min([col.shape[1] for col in cols]) + elif isinstance(align, int): + h = align + w = align + elif ( + isinstance(align, tuple) + and isinstance(align[0], int) + and isinstance(align[1], int) + ): + h, w = align + else: + raise ValueError( + f"Unsupported image grid align: {align}, should be min, max, int or (int, int)" + ) + + for i in range(len(cols)): + if cols[i].shape[0] != h or cols[i].shape[1] != w: + cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_LINEAR) + return np.concatenate(cols, axis=1) + + def save_image_grid( + self, + filename, + imgs, + align=DEFAULT_GRID_KWARGS["align"], + name: Optional[str] = None, + step: Optional[int] = None, + texts: Optional[List[float]] = None, + ): + save_path = self.get_save_path(filename) + img = self.get_image_grid_(imgs, align=align) + + if texts is not None: + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + black, white = (0, 0, 0), (255, 255, 255) + for i, text in enumerate(texts): + draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) + img = np.asarray(img) + + cv2.imwrite(save_path, img) + if name and self._wandb_logger: + wandb.log({name: wandb.Image(save_path), "trainer/global_step": step}) + return save_path + + def save_image(self, filename, img) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.dtype == np.uint8 or img.dtype == np.uint16 + if img.ndim == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.ndim == 3 and img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + cv2.imwrite(save_path, img) + return save_path + + def save_image_vutils(self, filename, img) -> str: + save_path = self.get_save_path(filename) + vutils.save_image(img, save_path) + return save_path + + def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] + + imgs_full = [] + for start in range(0, img.shape[-1], 3): + img_ = img[..., start : start + 3] + img_ = np.stack( + [ + self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba) + for i in range(img_.shape[0]) + ], + axis=0, + ) + size = img_.shape[1] + placeholder = np.zeros((size, size, 3), dtype=np.float32) + img_full = np.concatenate( + [ + np.concatenate( + [placeholder, img_[2], placeholder, placeholder], axis=1 + ), + np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), + np.concatenate( + [placeholder, img_[3], placeholder, placeholder], axis=1 + ), + ], + axis=0, + ) + imgs_full.append(img_full) + + imgs_full = np.concatenate(imgs_full, axis=1) + cv2.imwrite(save_path, imgs_full) + return save_path + + def save_data(self, filename, data) -> str: + data = self.convert_data(data) + if isinstance(data, dict): + if not filename.endswith(".npz"): + filename += ".npz" + save_path = self.get_save_path(filename) + np.savez(save_path, **data) + else: + if not filename.endswith(".npy"): + filename += ".npy" + save_path = self.get_save_path(filename) + np.save(save_path, data) + return save_path + + def save_state_dict(self, filename, data) -> str: + save_path = self.get_save_path(filename) + torch.save(data, save_path) + return save_path + + def save_img_sequence( + self, + filename, + img_dir, + matcher, + save_format="mp4", + fps=30, + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + assert save_format in ["gif", "mp4"] + if not filename.endswith(save_format): + filename += f".{save_format}" + save_path = self.get_save_path(filename) + matcher = re.compile(matcher) + img_dir = os.path.join(self.get_save_dir(), img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == "gif": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps, palettesize=256) + elif save_format == "mp4": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Video(save_path, format="mp4"), + "trainer/global_step": step, + } + ) + return save_path + + def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None) -> str: + save_path = self.get_save_path(filename) + v_pos = self.convert_data(v_pos) + t_pos_idx = self.convert_data(t_pos_idx) + mesh = trimesh.Trimesh(vertices=v_pos, faces=t_pos_idx) + mesh.export(save_path) + return save_path + + def save_file(self, filename, src_path) -> str: + save_path = self.get_save_path(filename) + shutil.copyfile(src_path, save_path) + return save_path + + def save_txt(self, filename, comment) -> str: + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(comment) + return save_path + + def save_json(self, filename, payload) -> str: + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(json.dumps(payload)) + return save_path diff --git a/pixal3d/utils/scheduler.py b/pixal3d/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..052f9185a4bc2b5d9d6746a8d7cc8609eb979ce9 --- /dev/null +++ b/pixal3d/utils/scheduler.py @@ -0,0 +1,108 @@ +import sys +import warnings +from bisect import bisect_right + +import torch +import torch.nn as nn +from torch.optim import lr_scheduler + +import pixal3d + + +def get_scheduler(name): + if hasattr(lr_scheduler, name): + return getattr(lr_scheduler, name) + else: + raise NotImplementedError + + +def getattr_recursive(m, attr): + for name in attr.split("."): + m = getattr(m, name) + return m + + +def get_parameters(model, name): + module = getattr_recursive(model, name) + if isinstance(module, nn.Module): + return module.parameters() + elif isinstance(module, nn.Parameter): + return module + return [] + + +def parse_optimizer(config, model): + if hasattr(config, "params"): + params = [ + {"params": get_parameters(model, name), "name": name, **args} + for name, args in config.params.items() + ] + pixal3d.debug(f"Specify optimizer params: {config.params}") + else: + if hasattr(config, "only_requires_grad") and config.only_requires_grad: + params = list(filter(lambda p: p.requires_grad, model.parameters())) + else: + params = model.parameters() + + if config.name in ["FusedAdam"]: + import apex + + optim = getattr(apex.optimizers, config.name)(params, **config.args) + elif config.name in ["Prodigy"]: + import prodigyopt + + optim = getattr(prodigyopt, config.name)(params, **config.args) + else: + optim = getattr(torch.optim, config.name)(params, **config.args) + return optim + + +def parse_scheduler_to_instance(config, optimizer): + if config.name == "ChainedScheduler": + schedulers = [ + parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers + ] + scheduler = lr_scheduler.ChainedScheduler(schedulers) + elif config.name == "Sequential": + schedulers = [ + parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers + ] + scheduler = lr_scheduler.SequentialLR( + optimizer, schedulers, milestones=config.milestones + ) + else: + scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) + return scheduler + + +def parse_scheduler(config, optimizer): + interval = config.get("interval", "epoch") + assert interval in ["epoch", "step"] + if config.name == "SequentialLR": + scheduler = { + "scheduler": lr_scheduler.SequentialLR( + optimizer, + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ], + milestones=config.milestones, + ), + "interval": interval, + } + elif config.name == "ChainedScheduler": + scheduler = { + "scheduler": lr_scheduler.ChainedScheduler( + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ] + ), + "interval": interval, + } + else: + scheduler = { + "scheduler": get_scheduler(config.name)(optimizer, **config.args), + "interval": interval, + } + return scheduler diff --git a/pixal3d/utils/sparse.py b/pixal3d/utils/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..9682ed73c7f281c6717185a1067ec0e05daafdaa --- /dev/null +++ b/pixal3d/utils/sparse.py @@ -0,0 +1,68 @@ +import torch +import numpy as np + +def sort_block(latent_index, block_size): + device = latent_index.device + latent_index_block = latent_index.cpu().numpy() + latent_index_block[..., 1:] = latent_index_block[..., 1:] // block_size + latent_index_inblock = latent_index.cpu().numpy() + latent_index_inblock[..., 1:] = latent_index_inblock[..., 1:] % block_size + sort_index = np.lexsort(( + latent_index_inblock[..., 3], + latent_index_inblock[..., 2], + latent_index_inblock[..., 1], + latent_index_block[..., 3], + latent_index_block[..., 2], + latent_index_block[..., 1]) + ) + sort_index = torch.from_numpy(sort_index).to(device) + return latent_index[sort_index] + +def extract_tokens_and_coords(conditions, token_mask, num_cls=1, num_reg=4): + device = conditions.device + B = conditions.size(0) + patch_size = token_mask.size(1) + + class_tokens = conditions[:, 0:num_cls, :] # [B, 1, 1024] + register_tokens = conditions[:, num_cls:num_cls+num_reg, :] # [B, 4, 1024] + patch_tokens = conditions[:, num_cls+num_reg:, :] # [B, 1369, 1024] + + selected_tokens_list = [] + coords_list = [] + + for batch_idx in range(B): + cls_tokens = class_tokens[batch_idx] # [1, 1024] + reg_tokens = register_tokens[batch_idx] # [4, 1024] + cls_reg_tokens = torch.cat([cls_tokens, reg_tokens], dim=0) # [5, 1024] + + cls_coord = torch.tensor([[batch_idx, 0, 0, 1]] * num_cls, device=device) + reg_coords = torch.tensor([[batch_idx, 0, 0, 1]] * num_reg, device=device) + cls_reg_coords = torch.cat([cls_coord, reg_coords], dim=0) + + mask = token_mask[batch_idx] + pos = mask.nonzero(as_tuple=False) + K = pos.size(0) + + if K > 0: + h, w = pos[:, 0], pos[:, 1] + indices = h * patch_size + w # + patches = patch_tokens[batch_idx][indices] + + batch_ids = torch.full((K, 1), batch_idx, device=device) + x = w.unsqueeze(1) + y = h.unsqueeze(1) + patch_coords = torch.cat([batch_ids, x, y, torch.zeros((K, 1), device=device)], dim=1) + + combined_tokens = torch.cat([cls_reg_tokens, patches], dim=0) + combined_coords = torch.cat([cls_reg_coords, patch_coords], dim=0) + else: + combined_tokens = cls_reg_tokens + combined_coords = cls_reg_coords + + selected_tokens_list.append(combined_tokens) + coords_list.append(combined_coords) + + selected_tokens = torch.cat(selected_tokens_list, dim=0) + coords = torch.cat(coords_list, dim=0) + + return selected_tokens, coords \ No newline at end of file diff --git a/pixal3d/utils/typing.py b/pixal3d/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..21b1bb2dda7aac2ccc9f26fb47242892b470e671 --- /dev/null +++ b/pixal3d/utils/typing.py @@ -0,0 +1,41 @@ +""" +This module contains type annotations for the project, using +1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects +2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors + +Two types of typing checking can be used: +1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) +2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) +""" + +# Basic types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, + Sequence, +) + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# Config type +from omegaconf import DictConfig + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker diff --git a/pixal3d/utils/util.py b/pixal3d/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..51bd0eaa45994f0a3afe18040854288064535e7a --- /dev/null +++ b/pixal3d/utils/util.py @@ -0,0 +1,19 @@ +import importlib + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) diff --git a/pixal3dpipeline.py b/pixal3dpipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0133956e61bfe0a8f14ef526def4d510784aea --- /dev/null +++ b/pixal3dpipeline.py @@ -0,0 +1,728 @@ +""" +Pixal3D Pipeline +""" + +import os +import shutil +import torch +import torch.nn as nn +import numpy as np +from typing import Optional, Union, List, Tuple +from PIL import Image +from tqdm import tqdm +import trimesh +import json +from omegaconf import OmegaConf +import torchvision.transforms.functional as TF +from torchvision import transforms +import pixal3d +import sys +from pixal3d.modules import sparse as sp +from pixal3d.utils import postprocess_mesh, normalize_mesh, mesh2index, instantiate_from_config +from pixal3d.utils.sparse import sort_block + + + +def preprocess_image(image, resolution=518, padding=20, bg="white"): + """ + Preprocess image for model input. Supports str path, PIL Image, or numpy array. + Returns tensor [4, H, W] for model input. + """ + # Handle different input types + if isinstance(image, str): + img = Image.open(image) + elif isinstance(image, np.ndarray): + img = Image.fromarray(image) + elif isinstance(image, Image.Image): + img = image + else: + raise TypeError(f"Unsupported image type: {type(image)}") + + if img.mode == 'RGB': + resized = img.resize((resolution, resolution), Image.Resampling.BICUBIC) + img_np = np.array(resized).astype(np.float32) / 255.0 + mask = np.ones((resolution, resolution, 1), dtype=np.float32) + img_rgba = np.concatenate([img_np, mask], axis=-1) + else: + + img = img.convert('RGBA') + bbox = img.getbbox() + + if bbox is None: + + bg_val = 255 if bg == 'white' else (128 if bg == 'gray' else np.random.randint(0, 256)) + img_np = np.ones((resolution, resolution, 3), dtype=np.float32) * (bg_val / 255.0) + mask = np.ones((resolution, resolution, 1), dtype=np.float32) + img_rgba = np.concatenate([img_np, mask], axis=-1) + else: + + cropped = img.crop(bbox) + + + if bg == 'white': + bg_color = (255, 255, 255, 255) + elif bg == 'gray': + bg_color = (128, 128, 128, 255) + elif bg == 'random': + bg_color = tuple(np.random.randint(0, 256, size=3).tolist()) + (255,) + else: + bg_color = (255, 255, 255, 255) + + + bg_layer = Image.new('RGBA', cropped.size, bg_color) + cropped_rgb = Image.alpha_composite(bg_layer, cropped).convert('RGB') + + + target_size = resolution - padding * 2 + w, h = cropped_rgb.size + scale = min(target_size / w, target_size / h) + new_w, new_h = int(w * scale), int(h * scale) + + + resized = cropped_rgb.resize((new_w, new_h), Image.Resampling.LANCZOS) + + + result = Image.new('RGB', (resolution, resolution), bg_color[:3]) + offset_x = (resolution - new_w) // 2 + offset_y = (resolution - new_h) // 2 + result.paste(resized, (offset_x, offset_y)) + + img_np = np.array(result).astype(np.float32) / 255.0 + mask = np.ones((resolution, resolution, 1), dtype=np.float32) + img_rgba = np.concatenate([img_np, mask], axis=-1) + + + tensor = torch.from_numpy(img_rgba).permute(2, 0, 1) + return tensor + + +def compute_f_pixels(camera_angle_x, resolution): + """ + Compute focal length in pixels + """ + focal_length = 16.0 / torch.tan(torch.tensor(camera_angle_x / 2.0)) # mm + f_pixels = focal_length * resolution / 32.0 # pixels + return float(f_pixels.item()) + + +def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution): + """ + Derive distance from FOV using analytical relationship. + Returns distance derived from X and Y axes and focal length in pixels. + """ + gp = grid_point.to(torch.float32) + xw, yw, zw = gp[0].item(), gp[1].item(), gp[2].item() + xt, yt = float(target_point[0].item()), float(target_point[1].item()) + + f_pixels = compute_f_pixels(camera_angle_x, image_resolution) + + x_ndc = xt - image_resolution / 2.0 + y_ndc = -(yt - image_resolution / 2.0) + + + eps = 1e-8 + if abs(x_ndc) < eps: + raise ValueError("x_ndc too small to stably derive distance from X coordinate") + if abs(y_ndc) < eps: + raise ValueError("y_ndc too small to stably derive distance from Y coordinate") + + + distance_x = f_pixels * xw / x_ndc - yw + + + distance_y = f_pixels * zw / y_ndc - yw + + return { + "distance_from_x": float(distance_x), + "distance_from_y": float(distance_y), + "f_pixels": float(f_pixels), + } + + +# ==================== Pixal3D Pipeline ==================== + +class Pixal3DPipeline: + """ + Pixal3D unified inference pipeline + + Self-contained pipeline integrating Dense and Sparse (512/1024) three-stage inference + """ + + def __init__( + self, + dense_visual_condition, + dense_denoiser_model, + dense_scheduler, + sparse_512_visual_condition, + sparse_512_denoiser_model, + sparse_512_scheduler, + sparse_1024_visual_condition, + sparse_1024_denoiser_model, + sparse_1024_scheduler, + dense_vae, + sparse_vae_512, + sparse_vae_1024, + dense_dtype: torch.dtype = torch.float16, + sparse_dtype: torch.dtype = torch.bfloat16, + ): + """ + Initialize Pixal3D Pipeline + + Args: + dense_visual_condition: Dense visual condition encoder + dense_denoiser_model: Dense denoising model + dense_scheduler: Dense scheduler + sparse_512_visual_condition: Sparse 512 visual condition encoder + sparse_512_denoiser_model: Sparse 512 denoising model + sparse_512_scheduler: Sparse 512 scheduler + sparse_1024_visual_condition: Sparse 1024 visual condition encoder + sparse_1024_denoiser_model: Sparse 1024 denoising model + sparse_1024_scheduler: Sparse 1024 scheduler + dense_vae: Dense VAE model + sparse_vae_512: Sparse VAE 512 model + sparse_vae_1024: Sparse VAE 1024 model + dense_dtype: Dense model dtype (default fp16) + sparse_dtype: Sparse model dtype (default bf16) + """ + self.dense_visual_condition = dense_visual_condition + self.dense_denoiser_model = dense_denoiser_model + self.dense_scheduler = dense_scheduler + + self.sparse_512_visual_condition = sparse_512_visual_condition + self.sparse_512_denoiser_model = sparse_512_denoiser_model + self.sparse_512_scheduler = sparse_512_scheduler + + self.sparse_1024_visual_condition = sparse_1024_visual_condition + self.sparse_1024_denoiser_model = sparse_1024_denoiser_model + self.sparse_1024_scheduler = sparse_1024_scheduler + + self.dense_vae = dense_vae + self.sparse_vae_512 = sparse_vae_512 + self.sparse_vae_1024 = sparse_vae_1024 + + self.device = "cuda" + self.dense_dtype = dense_dtype + self.sparse_dtype = sparse_dtype + + # Set evaluation mode + self._set_eval_mode() + + def _set_eval_mode(self): + """Set all models to evaluation mode""" + self.dense_visual_condition.eval() + self.dense_denoiser_model.eval() + self.sparse_512_visual_condition.eval() + self.sparse_512_denoiser_model.eval() + self.sparse_1024_visual_condition.eval() + self.sparse_1024_denoiser_model.eval() + self.dense_vae.eval() + self.sparse_vae_512.eval() + self.sparse_vae_1024.eval() + + def to(self, device): + """Move all models to specified device""" + self.device = device + self.dense_visual_condition.to(device) + self.dense_denoiser_model.to(device) + self.sparse_512_visual_condition.to(device) + self.sparse_512_denoiser_model.to(device) + self.sparse_1024_visual_condition.to(device) + self.sparse_1024_denoiser_model.to(device) + self.dense_vae.to(device) + self.sparse_vae_512.to(device) + self.sparse_vae_1024.to(device) + return self + + @classmethod + def from_pretrained( + cls, + ckpt_dir: str = "./ckpt", + repo_id: str = None, + dense_dtype: torch.dtype = torch.float16, + sparse_dtype: torch.dtype = torch.float16, + cache_dir: str = None, + ): + """ + Create Pixal3D Pipeline from local directory or HuggingFace Hub. + + Args: + ckpt_dir: Local directory containing converted checkpoints (used when repo_id is None) + repo_id: HuggingFace repo ID (e.g., "TencentARC/Pixal3D-D"). If provided, download from HF Hub. + dense_dtype: Data type for dense stage + sparse_dtype: Data type for sparse stages + cache_dir: Cache directory for downloaded models (default: ~/.cache/huggingface/hub) + + Usage: + # Load from local directory + pipeline = Pixal3DPipeline.from_ckpt("./ckpt") + + # Load from HuggingFace Hub + pipeline = Pixal3DPipeline.from_ckpt(repo_id="TencentARC/Pixal3D-D") + """ + import json + import importlib + from safetensors.torch import load_file + + # Determine source + if repo_id is not None: + # Load from HuggingFace Hub + from huggingface_hub import hf_hub_download, snapshot_download + use_hf_hub = True + print(f"Loading models from HuggingFace Hub: {repo_id}") + else: + # Load from local directory + use_hf_hub = False + print(f"Loading models from local directory: {ckpt_dir}") + + def get_component_path(stage: str, component: str) -> str: + """Get path to component directory.""" + if use_hf_hub: + return f"{stage}/{component}" + else: + return os.path.join(ckpt_dir, stage, component) + + def load_config_hf(repo_id: str, subfolder: str, cache_dir: str = None): + """Load config.json from HuggingFace Hub.""" + config_path = hf_hub_download( + repo_id=repo_id, + subfolder=subfolder, + filename="config.json", + cache_dir=cache_dir, + repo_type="model" + ) + with open(config_path, 'r') as f: + return json.load(f) + + def load_config_local(component_dir: str): + """Load config.json from local directory.""" + config_path = os.path.join(component_dir, "config.json") + with open(config_path, 'r') as f: + return json.load(f) + + def load_config(stage: str, component: str): + """Load config.json from appropriate source.""" + subfolder = f"{stage}/{component}" + if use_hf_hub: + return load_config_hf(repo_id, subfolder, cache_dir) + else: + return load_config_local(os.path.join(ckpt_dir, stage, component)) + + def load_model_hf(repo_id: str, subfolder: str, device="cuda", cache_dir: str = None): + """Load model from HuggingFace Hub.""" + config = load_config_hf(repo_id, subfolder, cache_dir) + model_class_path = config["model_class"] + + # Import class + module_path, class_name = model_class_path.rsplit('.', 1) + module = importlib.import_module(module_path) + model_class = getattr(module, class_name) + + # Build kwargs + if config.get("model_type") == "conditioner": + kwargs = config.get("config", {}) + else: + exclude_keys = {"model_type", "model_class", "scheduler_class", "scheduler_config", "config"} + kwargs = {k: v for k, v in config.items() if k not in exclude_keys} + + # Create model + if hasattr(model_class, 'Config'): + model = model_class(cfg=kwargs) + else: + model = model_class(**kwargs) + + # Load weights if exists (check config first) + if config.get("model_type") not in ["scheduler", "conditioner"]: + safetensors_path = hf_hub_download( + repo_id=repo_id, + subfolder=subfolder, + filename="model.safetensors", + cache_dir=cache_dir, + repo_type="model" + ) + state_dict = load_file(safetensors_path, device=device) + model.load_state_dict(state_dict, strict=True) + + return model + + def load_model_local(component_dir: str, device="cuda"): + """Load model from local directory.""" + config = load_config_local(component_dir) + model_class_path = config["model_class"] + + # Import class + module_path, class_name = model_class_path.rsplit('.', 1) + module = importlib.import_module(module_path) + model_class = getattr(module, class_name) + + # Build kwargs + if config.get("model_type") == "conditioner": + kwargs = config.get("config", {}) + else: + exclude_keys = {"model_type", "model_class", "scheduler_class", "scheduler_config", "config"} + kwargs = {k: v for k, v in config.items() if k not in exclude_keys} + + # Create model + if hasattr(model_class, 'Config'): + model = model_class(cfg=kwargs) + else: + model = model_class(**kwargs) + + # Load weights if exists + safetensors_path = os.path.join(component_dir, "model.safetensors") + if os.path.exists(safetensors_path): + state_dict = load_file(safetensors_path, device=device) + model.load_state_dict(state_dict, strict=True) + + return model + + def load_model(stage: str, component: str, device="cuda"): + """Load model from appropriate source.""" + subfolder = f"{stage}/{component}" + if use_hf_hub: + return load_model_hf(repo_id, subfolder, device, cache_dir) + else: + return load_model_local(os.path.join(ckpt_dir, stage, component), device) + + def load_scheduler(stage: str): + """Load scheduler from appropriate source.""" + config = load_config(stage, "scheduler") + scheduler_class_path = config["scheduler_class"] + scheduler_config = config.get("scheduler_config", {}) + + # Import class + module_path, class_name = scheduler_class_path.rsplit('.', 1) + module = importlib.import_module(module_path) + scheduler_class = getattr(module, class_name) + + return scheduler_class(**scheduler_config) + + def load_conditioner(stage: str): + """Load conditioner (no weights, just config).""" + config = load_config(stage, "conditioner") + conditioner_class_path = config["model_class"] + module_path, class_name = conditioner_class_path.rsplit('.', 1) + module = importlib.import_module(module_path) + conditioner_class = getattr(module, class_name) + visual_condition = conditioner_class(cfg=config.get("config", {})) + visual_condition.to("cuda") + visual_condition.requires_grad_(False) + return visual_condition + + # Load Dense stage + + dense_denoiser_model = load_model("dense", "dit") + dense_denoiser_model.to("cuda") + dense_vae = load_model("dense", "vae") + dense_vae.to("cuda") + dense_vae.eval() + dense_scheduler = load_scheduler("dense") + dense_visual_condition = load_conditioner("dense") + + # Load Sparse 512 stage + + sparse_512_denoiser_model = load_model("sparse512", "dit") + sparse_512_denoiser_model.to("cuda") + sparse_vae_512 = load_model("sparse512", "vae") + sparse_vae_512.to("cuda") + sparse_vae_512.eval() + sparse_512_scheduler = load_scheduler("sparse512") + sparse_512_visual_condition = load_conditioner("sparse512") + + # Load Sparse 1024 stage + + sparse_1024_denoiser_model = load_model("sparse1024", "dit") + sparse_1024_denoiser_model.to("cuda") + sparse_vae_1024 = load_model("sparse1024", "vae") + sparse_vae_1024.to("cuda") + sparse_vae_1024.eval() + sparse_1024_scheduler = load_scheduler("sparse1024") + sparse_1024_visual_condition = load_conditioner("sparse1024") + + print("All models loaded successfully!") + + return cls( + dense_visual_condition=dense_visual_condition, + dense_denoiser_model=dense_denoiser_model, + dense_scheduler=dense_scheduler, + sparse_512_visual_condition=sparse_512_visual_condition, + sparse_512_denoiser_model=sparse_512_denoiser_model, + sparse_512_scheduler=sparse_512_scheduler, + sparse_1024_visual_condition=sparse_1024_visual_condition, + sparse_1024_denoiser_model=sparse_1024_denoiser_model, + sparse_1024_scheduler=sparse_1024_scheduler, + dense_vae=dense_vae, + sparse_vae_512=sparse_vae_512, + sparse_vae_1024=sparse_vae_1024, + dense_dtype=dense_dtype, + sparse_dtype=sparse_dtype, + ) + + # ==================== Image Encoding ==================== + + def encode_image_dense(self, image, camera_angle_x, distance, mesh_scale): + + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=self.dense_dtype): + cond_global, cond_proj = self.dense_visual_condition( + image[:, :3], + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + ) + + + uncond_global = torch.zeros_like(cond_global) + uncond_proj = torch.zeros_like(cond_proj) + + return (cond_global, cond_proj), (uncond_global, uncond_proj) + + def encode_image_sparse(self, image, camera_angle_x, distance, mesh_scale, coords, visual_condition): + + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=self.sparse_dtype): + cond_global, cond_sparse = visual_condition( + image[:, :3], + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + ) + + + bs = cond_sparse.shape[0] + res = visual_condition.grid_resolution + cond_sparse = cond_sparse.reshape(bs, res, res, res, -1) + + batch_indices = coords[:, 0].long() + x_coords = coords[:, 1].long() + y_coords = coords[:, 2].long() + z_coords = coords[:, 3].long() + + cond_sparse = cond_sparse[batch_indices, x_coords, y_coords, z_coords] + + + uncond_global = torch.zeros_like(cond_global) + uncond_sparse = torch.zeros_like(cond_sparse) + + + cond_sparse = sp.SparseTensor(cond_sparse, coords.int()) + uncond_sparse = sp.SparseTensor(uncond_sparse, coords.int()) + + return (cond_global, cond_sparse), (uncond_global, uncond_sparse) + + + + @torch.no_grad() + def infer_dense(self, image, camera_angle_x, distance, mesh_scale, num_steps, guidance_scale, seed): + + batch_size = image.shape[0] + + # Encode conditions + do_cfg = guidance_scale > 0 + image = image.to(torch.float16) + cond, uncond = self.encode_image_dense(image, camera_angle_x, distance, mesh_scale) + + # Initialize latents + latent_shape = (batch_size, *self.dense_denoiser_model.dit_model.latent_shape) + generator = torch.Generator(device=self.device).manual_seed(seed) if seed is not None else None + latents = torch.randn(latent_shape, device=self.device, dtype=cond[0].dtype, generator=generator) + + # Setup scheduler + self.dense_scheduler.set_timesteps(num_steps, device=self.device) + timesteps = self.dense_scheduler.timesteps + + extra_step_kwargs = {'generator': generator} if generator is not None else {} + + # Denoising loop + for i, t in enumerate(tqdm(timesteps, desc="Dense Sampling")): + timestep_tensor = torch.tensor([t], dtype=latents.dtype, device=self.device) + + diffusion_inputs = {"x": latents, "t": timestep_tensor,"cond": cond,} + + with torch.cuda.amp.autocast(dtype=self.dense_dtype): + noise_pred_cond = self.dense_denoiser_model(**diffusion_inputs).sample + + if do_cfg: + diffusion_inputs["cond"] = uncond + noise_pred_uncond = self.dense_denoiser_model(**diffusion_inputs).sample + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.dense_scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + return latents + + @torch.no_grad() + def infer_sparse(self, image, camera_angle_x, distance, mesh_scale, index, num_steps, guidance_scale, seed, + visual_condition, denoiser_model, scheduler): + + batch_size = image.shape[0] + + # Encode conditions + do_cfg = guidance_scale > 0 + cond, uncond = self.encode_image_sparse(image, camera_angle_x, distance, mesh_scale, index, visual_condition) + + # Initialize latents + latent_shape = (index.shape[0], denoiser_model.dit_model.out_channels) + generator = torch.Generator(device=self.device).manual_seed(seed) if seed is not None else None + latents = torch.randn(latent_shape, device=self.device, dtype=cond[0].dtype, generator=generator) + + # Setup scheduler + scheduler.set_timesteps(num_steps, device=self.device) + timesteps = scheduler.timesteps + + extra_step_kwargs = {'generator': generator} if generator is not None else {} + + # Denoising loop + for i, t in enumerate(tqdm(timesteps, desc="Sparse Sampling")): + timestep_tensor = torch.tensor([t], dtype=latents.dtype, device=self.device) + + x_input = sp.SparseTensor(latents, index.int()) + + diffusion_inputs = { + "x": x_input, + "t": timestep_tensor, + "cond": cond, + } + + with torch.cuda.amp.autocast(dtype=self.sparse_dtype): + noise_pred_cond = denoiser_model(**diffusion_inputs).sample + noise_pred_cond = noise_pred_cond.feats + + if do_cfg: + diffusion_inputs["cond"] = uncond + noise_pred_uncond = denoiser_model(**diffusion_inputs).sample + noise_pred_uncond = noise_pred_uncond.feats + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + return sp.SparseTensor(latents, index.int()) + + # ==================== Main Inference Interface ==================== + + @torch.no_grad() + def infer( + self, + image: Union[str, Image.Image, np.ndarray], + dense_steps: int = 50, + dense_guidance_scale: float = 7.0, + dense_seed: int = 0, + sparse_512_steps: int = 30, + sparse_512_guidance_scale: float = 7.0, + sparse_1024_steps: int = 15, + sparse_1024_guidance_scale: float = 7.0, + sparse_seed: int = 0, + dense_threshold: float = 0.1, + mc_threshold: float = 0.2, + + extend_pixel: int = 20, + camera_angle_x: float = 0.2, + mesh_scale: float = 0.9, + ): + """ + Execute complete 1024 resolution inference pipeline, return simplified mesh + """ + # Image preprocessing (always executed) + image_tensor = preprocess_image(image, 518, padding=20).unsqueeze(0).to(self.device) + + # Compute camera distance + image_resolution = 518 + grid_points = torch.tensor([-1.0, 0, -1.0]) + grid_points = grid_points / mesh_scale / 2 + distance = distance_from_fov( + camera_angle_x, grid_points, torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel]), mesh_scale, image_resolution + )["distance_from_x"] + + print(f"[Pixal3D] camera_angle_x: {camera_angle_x}, distance: {distance}") + + + camera_angle_x_tensor = torch.tensor([camera_angle_x], device=self.device, dtype=torch.float32) + distance_tensor = torch.tensor([distance], device=self.device, dtype=torch.float32) + mesh_scale_tensor = torch.tensor([mesh_scale], device=self.device, dtype=torch.float32) + + + + + # ============ Step 1: Dense Inference ============ + print(f"[Pixal3D] Step 1: Dense Inference...") + dense_latents = self.infer_dense( + image_tensor, camera_angle_x_tensor, distance_tensor, mesh_scale_tensor, + dense_steps, dense_guidance_scale, dense_seed + ) + + # Decode dense latents to get index + with torch.autocast("cuda", dtype=torch.float16): + decoded_index = self.dense_vae.decode_mesh( + dense_latents, mc_threshold=dense_threshold, return_index=True + )[0] + + decoded_index = sort_block(decoded_index, 8) + print(f"[Pixal3D] decoded_index max: {decoded_index.max(0)}, min: {decoded_index.min(0)}, shape: {decoded_index.shape}") + + # ============ Step 2: Sparse 512 Inference ============ + print(f"[Pixal3D] Step 2: Sparse 512 Inference...") + sparse_latents_512 = self.infer_sparse( + image_tensor, camera_angle_x_tensor, distance_tensor, mesh_scale_tensor, decoded_index, + sparse_512_steps, sparse_512_guidance_scale, sparse_seed, + self.sparse_512_visual_condition, self.sparse_512_denoiser_model, self.sparse_512_scheduler + ) + + # Decode 512 mesh + with torch.autocast("cuda", dtype=torch.float16): + with torch.no_grad(): + decoded_meshs_512 = self.sparse_vae_512.decode_mesh(sparse_latents_512, voxel_resolution=512) + mesh_512 = decoded_meshs_512[0] + + # Clean up memory + del decoded_index, sparse_latents_512 + torch.cuda.empty_cache() + + # ============ Step 3: Prepare 1024 Latent Index ============ + print(f"[Pixal3D] Step 3: Prepare 1024 latent index...") + latent_index_1024 = mesh2index(mesh_512, size=1024, factor=8) + block_size_1024 = getattr(self.sparse_1024_denoiser_model.dit_model, 'selection_block_size', 8) + latent_index_1024 = sort_block(latent_index_1024, block_size_1024) + print(f"[Pixal3D] 1024 latent tokens: {len(latent_index_1024)}") + + # ============ Step 4: Sparse 1024 Inference ============ + print(f"[Pixal3D] Step 4: Sparse 1024 Inference...") + sparse_latents_1024 = self.infer_sparse( + image_tensor, camera_angle_x_tensor, distance_tensor, mesh_scale_tensor, latent_index_1024, + sparse_1024_steps, sparse_1024_guidance_scale, sparse_seed, + self.sparse_1024_visual_condition, self.sparse_1024_denoiser_model, self.sparse_1024_scheduler + ) + + # Decode 1024 mesh and postprocess + with torch.autocast("cuda", dtype=torch.float16): + with torch.no_grad(): + decoded_meshs_1024 = self.sparse_vae_1024.decode_mesh( + sparse_latents_1024, voxel_resolution=1024, mc_threshold=mc_threshold + ) + + # Postprocess mesh + mesh_v, mesh_f = postprocess_mesh( + decoded_meshs_1024[0].vertices, decoded_meshs_1024[0].faces, + simplify=True, verbose=True, + ) + mesh_1024 = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False) + mesh_1024.apply_scale(0.5 / mesh_scale) + + # Apply rotation matrix to align with Blender coordinate system + rotation_matrix = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ]) + mesh_1024.apply_transform(rotation_matrix) + + # Clean up memory + del latent_index_1024, sparse_latents_1024, decoded_meshs_1024 + torch.cuda.empty_cache() + + return mesh_1024 + + def infer_from_image(self, image_path: str, **kwargs): + return self.infer(image=image_path, **kwargs) diff --git a/pixal3dpipeline2stage.py b/pixal3dpipeline2stage.py new file mode 100644 index 0000000000000000000000000000000000000000..a11a85d40327262095288a80b447a18ef03be583 --- /dev/null +++ b/pixal3dpipeline2stage.py @@ -0,0 +1,611 @@ +""" +Pixal3D 2-Stage Pipeline + +Extended pipeline with MoGe FOV estimation and iterative mesh_scale optimization. +Compared to the base Pixal3DPipeline (fixed camera_angle_x=0.2, mesh_scale=0.9), +this pipeline: + 1. Uses MoGe model to estimate camera FOV from the input image + 2. Iteratively optimizes mesh_scale so decoded indices fit within the grid + 3. Optionally uses a separate dense_check model for mesh_scale optimization +""" + +import os +import math +import torch +import numpy as np +from typing import Union +from PIL import Image +import trimesh + +from pixal3d.utils import postprocess_mesh, mesh2index +from pixal3d.utils.sparse import sort_block + +from pixal3dpipeline import ( + Pixal3DPipeline, + preprocess_image, + compute_f_pixels, + distance_from_fov, +) + + +def load_moge_model(device: str = "cuda", model_name: str = "Ruicheng/moge-vitl"): + """Load MoGe model for FOV estimation.""" + print(f"[MoGe] Loading model {model_name}...") + from moge.model.v1 import MoGeModel + moge_model = MoGeModel.from_pretrained(model_name).to(device) + moge_model.eval() + print("[MoGe] Model loaded!") + return moge_model + + +def get_camera_angle_x_from_moge(image_path: str, moge_model, device: str = "cuda") -> float: + """ + Estimate camera_angle_x (horizontal FOV in radians) via MoGe inference. + + Args: + image_path: Input image path (must be square) + moge_model: MoGe model instance + device: Inference device + + Returns: + camera_angle_x in radians + """ + pil_image = Image.open(image_path).convert("RGB") + width, height = pil_image.size + assert width == height, f"Image must be square, but got {width}x{height}" + + image_np = np.array(pil_image).astype(np.float32) / 255.0 + image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).to(device) # [3, H, W] + + with torch.no_grad(): + output = moge_model.infer(image_tensor.unsqueeze(0)) + + intrinsics = output["intrinsics"].squeeze(0).cpu().numpy() # [3, 3] + fx = intrinsics[0, 0] * width + + camera_angle_x = 2 * math.atan(width / (2 * fx)) + print(f"[MoGe] fx={fx:.2f}, width={width}, camera_angle_x={camera_angle_x:.6f} rad ({math.degrees(camera_angle_x):.2f} deg)") + + return camera_angle_x + + +def compute_optimal_mesh_scale( + decoded_index: torch.Tensor, + original_mesh_scale: float, + grid_resolution: int = 64, + padding: int = 3, +) -> float: + """ + Compute optimal mesh_scale so decoded indices fill the grid with target padding. + + Args: + decoded_index: [N, 4] tensor, [:, 1:4] are xyz indices in 64^3 grid + original_mesh_scale: Current mesh scale factor + grid_resolution: Grid resolution (default 64) + padding: Target boundary distance in voxels (default 3) + + Returns: + optimal_mesh_scale + """ + xyz_index = decoded_index[:, 1:4].float() # [N, 3] + + center = (grid_resolution - 1) / 2.0 # 31.5 + offset = xyz_index - center + max_abs_offset = offset.abs().max().item() + target_max_offset = center - padding # 31.5 - 3 = 28.5 + + if max_abs_offset > 0: + scale_factor = target_max_offset / max_abs_offset + else: + scale_factor = 1.0 + + optimal_mesh_scale = original_mesh_scale * scale_factor + print(f"[compute_optimal_mesh_scale] max_abs_offset={max_abs_offset:.4f}, " + f"scale_factor={scale_factor:.4f}, optimal_mesh_scale={optimal_mesh_scale:.6f}") + + return optimal_mesh_scale + + +class Pixal3DPipeline2Stage(Pixal3DPipeline): + """ + 2-Stage Pixal3D Pipeline with MoGe FOV estimation and mesh_scale optimization. + + Inherits all model components and inference methods from Pixal3DPipeline. + Adds: + - MoGe-based camera FOV estimation + - Iterative mesh_scale optimization via dense inference loop + - Optional separate dense_check model for mesh_scale optimization + (uses a different dense checkpoint than the main dense model) + """ + + def __init__( + self, + *args, + moge_model=None, + dense_check_visual_condition=None, + dense_check_denoiser_model=None, + dense_check_scheduler=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.moge_model = moge_model + # Optional separate dense model for mesh_scale optimization + self.dense_check_visual_condition = dense_check_visual_condition + self.dense_check_denoiser_model = dense_check_denoiser_model + self.dense_check_scheduler = dense_check_scheduler + + @property + def has_dense_check(self): + """Whether a separate dense_check model is loaded.""" + return self.dense_check_denoiser_model is not None + + @classmethod + def from_pretrained( + cls, + ckpt_dir: str = "./ckpt", + repo_id: str = None, + dense_dtype: torch.dtype = torch.float16, + sparse_dtype: torch.dtype = torch.float16, + cache_dir: str = None, + use_moge: bool = True, + moge_model_name: str = "Ruicheng/moge-vitl", + use_dense_check: bool = True, + ): + """ + Create Pixal3D 2-Stage Pipeline. + + Same as Pixal3DPipeline.from_pretrained but additionally loads: + - MoGe model for FOV estimation + - Optional dense_check model (separate dense dit for mesh_scale optimization, + stored at dense/scale_init; scheduler & conditioner reuse dense/) + + Args: + ckpt_dir: Local checkpoint directory for main models + repo_id: HuggingFace repo ID for main models + dense_dtype: Dense model dtype + sparse_dtype: Sparse model dtype + cache_dir: HF cache directory + use_moge: Whether to load MoGe model for FOV estimation + moge_model_name: MoGe model name on HuggingFace + use_dense_check: Whether to load dense_check dit from dense/scale_init + """ + import json + import importlib + from safetensors.torch import load_file + + # Use parent class to load all pipeline components + base_pipeline = Pixal3DPipeline.from_pretrained( + ckpt_dir=ckpt_dir, + repo_id=repo_id, + dense_dtype=dense_dtype, + sparse_dtype=sparse_dtype, + cache_dir=cache_dir, + ) + + # Load MoGe model + moge_model = None + if use_moge: + moge_model = load_moge_model(device="cuda", model_name=moge_model_name) + + # Load dense_check dit (only the dit weights differ, scheduler & conditioner reuse dense/) + dense_check_denoiser_model = None + + if use_dense_check: + # Determine scale_init dit path + scale_init_dit_loaded = False + + if repo_id is not None: + from huggingface_hub import hf_hub_download + try: + config_path = hf_hub_download( + repo_id=repo_id, subfolder="dense/scale_init", + filename="config.json", cache_dir=cache_dir, repo_type="model" + ) + with open(config_path, 'r') as f: + config = json.load(f) + module_path, class_name = config["model_class"].rsplit('.', 1) + module = importlib.import_module(module_path) + model_class = getattr(module, class_name) + exclude_keys = {"model_type", "model_class", "scheduler_class", "scheduler_config", "config"} + kwargs = {k: v for k, v in config.items() if k not in exclude_keys} + if hasattr(model_class, 'Config'): + dense_check_denoiser_model = model_class(cfg=kwargs) + else: + dense_check_denoiser_model = model_class(**kwargs) + safetensors_path = hf_hub_download( + repo_id=repo_id, subfolder="dense/scale_init", + filename="model.safetensors", cache_dir=cache_dir, repo_type="model" + ) + state_dict = load_file(safetensors_path, device="cuda") + dense_check_denoiser_model.load_state_dict(state_dict, strict=True) + dense_check_denoiser_model.to("cuda").eval() + scale_init_dit_loaded = True + print("[2-Stage] dense_check dit loaded from HuggingFace (dense/scale_init)") + except Exception as e: + print(f"[2-Stage] dense/scale_init not found on HF: {e}, trying local...") + + if not scale_init_dit_loaded: + local_dit_dir = os.path.join(ckpt_dir, "dense", "scale_init") + config_file = os.path.join(local_dit_dir, "config.json") + safetensors_file = os.path.join(local_dit_dir, "model.safetensors") + if os.path.exists(config_file) and os.path.exists(safetensors_file): + with open(config_file, 'r') as f: + config = json.load(f) + module_path, class_name = config["model_class"].rsplit('.', 1) + module = importlib.import_module(module_path) + model_class = getattr(module, class_name) + exclude_keys = {"model_type", "model_class", "scheduler_class", "scheduler_config", "config"} + kwargs = {k: v for k, v in config.items() if k not in exclude_keys} + if hasattr(model_class, 'Config'): + dense_check_denoiser_model = model_class(cfg=kwargs) + else: + dense_check_denoiser_model = model_class(**kwargs) + state_dict = load_file(safetensors_file, device="cuda") + dense_check_denoiser_model.load_state_dict(state_dict, strict=True) + dense_check_denoiser_model.to("cuda").eval() + print(f"[2-Stage] dense_check dit loaded from local: {local_dit_dir}") + else: + print(f"[2-Stage] dense/scale_init not found locally, dense_check disabled") + + # Create 2-stage pipeline with all components from base + # dense_check reuses base dense scheduler & conditioner + pipeline = cls( + dense_visual_condition=base_pipeline.dense_visual_condition, + dense_denoiser_model=base_pipeline.dense_denoiser_model, + dense_scheduler=base_pipeline.dense_scheduler, + sparse_512_visual_condition=base_pipeline.sparse_512_visual_condition, + sparse_512_denoiser_model=base_pipeline.sparse_512_denoiser_model, + sparse_512_scheduler=base_pipeline.sparse_512_scheduler, + sparse_1024_visual_condition=base_pipeline.sparse_1024_visual_condition, + sparse_1024_denoiser_model=base_pipeline.sparse_1024_denoiser_model, + sparse_1024_scheduler=base_pipeline.sparse_1024_scheduler, + dense_vae=base_pipeline.dense_vae, + sparse_vae_512=base_pipeline.sparse_vae_512, + sparse_vae_1024=base_pipeline.sparse_vae_1024, + dense_dtype=dense_dtype, + sparse_dtype=sparse_dtype, + moge_model=moge_model, + dense_check_denoiser_model=dense_check_denoiser_model, + # scheduler & conditioner reuse main dense + dense_check_visual_condition=base_pipeline.dense_visual_condition if dense_check_denoiser_model else None, + dense_check_scheduler=base_pipeline.dense_scheduler if dense_check_denoiser_model else None, + ) + + return pipeline + + def estimate_fov(self, image_path: str) -> float: + """ + Estimate camera FOV from image using MoGe model. + + Args: + image_path: Path to the preprocessed square image + + Returns: + camera_angle_x in radians + """ + if self.moge_model is None: + raise ValueError("MoGe model not loaded. Set use_moge=True in from_pretrained().") + return get_camera_angle_x_from_moge(image_path, self.moge_model, device=self.device) + + def _infer_dense_check(self, image, camera_angle_x, distance, mesh_scale, num_steps, guidance_scale, seed): + """ + Run dense inference using the dense_check model (for mesh_scale optimization). + Falls back to the main dense model if no dense_check model is loaded. + """ + if not self.has_dense_check: + return self.infer_dense(image, camera_angle_x, distance, mesh_scale, num_steps, guidance_scale, seed) + + from tqdm import tqdm + + batch_size = image.shape[0] + do_cfg = guidance_scale > 0 + image = image.to(torch.float16) + + # Encode conditions using dense_check visual condition + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=self.dense_dtype): + cond_global, cond_proj = self.dense_check_visual_condition( + image[:, :3], + camera_angle_x=camera_angle_x, + distance=distance, + mesh_scale=mesh_scale, + ) + uncond_global = torch.zeros_like(cond_global) + uncond_proj = torch.zeros_like(cond_proj) + cond = (cond_global, cond_proj) + uncond = (uncond_global, uncond_proj) + + # Initialize latents + latent_shape = (batch_size, *self.dense_check_denoiser_model.dit_model.latent_shape) + generator = torch.Generator(device=self.device).manual_seed(seed) if seed is not None else None + latents = torch.randn(latent_shape, device=self.device, dtype=cond[0].dtype, generator=generator) + + # Setup scheduler + self.dense_check_scheduler.set_timesteps(num_steps, device=self.device) + timesteps = self.dense_check_scheduler.timesteps + extra_step_kwargs = {'generator': generator} if generator is not None else {} + + # Denoising loop + for i, t in enumerate(tqdm(timesteps, desc="Dense Check Sampling")): + timestep_tensor = torch.tensor([t], dtype=latents.dtype, device=self.device) + diffusion_inputs = {"x": latents, "t": timestep_tensor, "cond": cond} + + with torch.cuda.amp.autocast(dtype=self.dense_dtype): + noise_pred_cond = self.dense_check_denoiser_model(**diffusion_inputs).sample + if do_cfg: + diffusion_inputs["cond"] = uncond + noise_pred_uncond = self.dense_check_denoiser_model(**diffusion_inputs).sample + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.dense_check_scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + return latents + + def _optimize_mesh_scale( + self, + image_tensor: torch.Tensor, + camera_angle_x_tensor: torch.Tensor, + distance_tensor: torch.Tensor, + initial_mesh_scale: float, + dense_steps: int, + dense_guidance_scale: float, + dense_seed: int, + dense_threshold: float, + target_padding: int = 3, + padding_tolerance_min: int = 2, + padding_tolerance_max: int = 4, + max_iterations: int = 2, + ) -> tuple: + """ + Iteratively optimize mesh_scale so decoded dense indices stay within grid boundaries. + + Uses the dense_check model (if available) for the optimization loop, + then the main dense model is used for final inference in infer(). + + Returns: + optimized_mesh_scale (float) + """ + current_mesh_scale = initial_mesh_scale + best_mesh_scale = initial_mesh_scale + + use_check = self.has_dense_check + check_label = "dense_check" if use_check else "dense" + print(f"[mesh_scale optim] Using {check_label} model for optimization") + + # Initial dense inference with check model + mesh_scale_tensor = torch.tensor([current_mesh_scale], device=self.device, dtype=torch.float32) + dense_latents = self._infer_dense_check( + image_tensor, camera_angle_x_tensor, distance_tensor, mesh_scale_tensor, + dense_steps, dense_guidance_scale, dense_seed + ) + with torch.autocast("cuda", dtype=torch.float16): + decoded_index = self.dense_vae.decode_mesh( + dense_latents, mc_threshold=dense_threshold, return_index=True + )[0] + decoded_index = sort_block(decoded_index, 8) + + for iteration in range(max_iterations): + print(f"[mesh_scale optim] Iteration {iteration + 1}/{max_iterations}") + + optimal_mesh_scale = compute_optimal_mesh_scale( + decoded_index=decoded_index, + original_mesh_scale=current_mesh_scale, + grid_resolution=64, + padding=target_padding, + ) + + # Re-run dense inference with optimized mesh_scale (using check model) + mesh_scale_tensor = torch.tensor([optimal_mesh_scale], device=self.device, dtype=torch.float32) + dense_latents = self._infer_dense_check( + image_tensor, camera_angle_x_tensor, distance_tensor, mesh_scale_tensor, + dense_steps, dense_guidance_scale, dense_seed + ) + with torch.autocast("cuda", dtype=torch.float16): + opt_decoded_index = self.dense_vae.decode_mesh( + dense_latents, mc_threshold=dense_threshold, return_index=True + )[0] + opt_decoded_index = sort_block(opt_decoded_index, 8) + + # Check boundary + xyz_index = opt_decoded_index[:, 1:4] + min_padding = xyz_index.min(dim=0).values.min().item() + max_padding = 63 - xyz_index.max(dim=0).values.max().item() + actual_padding = min(min_padding, max_padding) + + print(f"[mesh_scale optim] mesh_scale={optimal_mesh_scale:.6f}, actual_padding={actual_padding}") + + if padding_tolerance_min <= actual_padding <= padding_tolerance_max: + print(f"[mesh_scale optim] Padding {actual_padding} within [{padding_tolerance_min}, {padding_tolerance_max}], done!") + best_mesh_scale = optimal_mesh_scale + break + elif actual_padding < padding_tolerance_min: + print(f"[mesh_scale optim] Padding {actual_padding} < {padding_tolerance_min}, object too large, reverting") + break + else: + print(f"[mesh_scale optim] Padding {actual_padding} > {padding_tolerance_max}, continuing...") + best_mesh_scale = optimal_mesh_scale + current_mesh_scale = optimal_mesh_scale + decoded_index = opt_decoded_index + else: + print(f"[mesh_scale optim] Reached max iterations {max_iterations}, using best result") + + return best_mesh_scale + + @torch.no_grad() + def infer( + self, + image: Union[str, Image.Image, np.ndarray], + dense_steps: int = 50, + dense_guidance_scale: float = 7.0, + dense_seed: int = 0, + sparse_512_steps: int = 30, + sparse_512_guidance_scale: float = 7.0, + sparse_1024_steps: int = 15, + sparse_1024_guidance_scale: float = 7.0, + sparse_seed: int = 0, + dense_threshold: float = 0.1, + mc_threshold: float = 0.2, + extend_pixel: int = 20, + # 2-stage specific parameters + mesh_scale: float = 0.5, + optimize_mesh_scale: bool = True, + target_padding: int = 3, + max_optim_iterations: int = 2, + ): + """ + Execute 2-stage inference with MoGe FOV estimation and mesh_scale optimization. + + Compared to Pixal3DPipeline.infer (fixed camera_angle_x=0.2, mesh_scale=0.9): + - Uses MoGe to estimate camera_angle_x from the input image + - Iteratively optimizes mesh_scale for better grid utilization + - Default mesh_scale=0.5 (vs 0.9 in base pipeline) + - When dense_check model is loaded, uses it for mesh_scale optimization, + then uses the main dense model for final inference + + Args: + image: Input image (path, PIL Image, or numpy array) + dense_steps: Dense inference steps + dense_guidance_scale: Dense CFG scale + dense_seed: Dense random seed + sparse_512_steps: Sparse 512 inference steps + sparse_512_guidance_scale: Sparse 512 CFG scale + sparse_1024_steps: Sparse 1024 inference steps + sparse_1024_guidance_scale: Sparse 1024 CFG scale + sparse_seed: Sparse random seed + dense_threshold: Dense decoding threshold + mc_threshold: Marching cubes threshold + extend_pixel: Pixel extension for distance computation + mesh_scale: Initial mesh scale (default 0.5) + optimize_mesh_scale: Whether to iteratively optimize mesh_scale + target_padding: Target boundary padding for optimization + max_optim_iterations: Max iterations for mesh_scale optimization + + Returns: + trimesh.Trimesh: Postprocessed 1024-resolution mesh + """ + # Image preprocessing + image_tensor = preprocess_image(image, 518, padding=20).unsqueeze(0).to(self.device) + + # Save preprocessed image for MoGe (MoGe needs the cropped/padded image, not the original) + import tempfile + img_np = (image_tensor[0, :3].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + Image.fromarray(img_np).save(tmp_img.name) + moge_image_path = tmp_img.name + + # Estimate camera FOV via MoGe + camera_angle_x = self.estimate_fov(moge_image_path) + + # Clean up temporary file + os.unlink(tmp_img.name) + + # Compute camera distance + image_resolution = 518 + grid_points = torch.tensor([-1.0, 0, -1.0]) + grid_points = grid_points / mesh_scale / 2 + distance = distance_from_fov( + camera_angle_x, grid_points, + torch.tensor([0 - extend_pixel, image_resolution + extend_pixel]), + mesh_scale, image_resolution + )["distance_from_x"] + + print(f"[Pixal3D-2Stage] camera_angle_x: {camera_angle_x:.6f}, distance: {distance:.4f}, mesh_scale: {mesh_scale}") + + camera_angle_x_tensor = torch.tensor([camera_angle_x], device=self.device, dtype=torch.float32) + distance_tensor = torch.tensor([distance], device=self.device, dtype=torch.float32) + + # ============ Step 1: Dense Inference + mesh_scale Optimization ============ + if optimize_mesh_scale: + print(f"[Pixal3D-2Stage] Step 1: mesh_scale optimization (using {'dense_check' if self.has_dense_check else 'dense'} model)...") + mesh_scale = self._optimize_mesh_scale( + image_tensor, camera_angle_x_tensor, distance_tensor, + initial_mesh_scale=mesh_scale, + dense_steps=dense_steps, + dense_guidance_scale=dense_guidance_scale, + dense_seed=dense_seed, + dense_threshold=dense_threshold, + target_padding=target_padding, + max_iterations=max_optim_iterations, + ) + print(f"[Pixal3D-2Stage] Optimized mesh_scale: {mesh_scale:.6f}") + + # Always run final dense inference with the MAIN dense model + print(f"[Pixal3D-2Stage] Step 1b: Final Dense Inference with main dense model...") + mesh_scale_tensor = torch.tensor([mesh_scale], device=self.device, dtype=torch.float32) + dense_latents = self.infer_dense( + image_tensor, camera_angle_x_tensor, distance_tensor, mesh_scale_tensor, + dense_steps, dense_guidance_scale, dense_seed + ) + with torch.autocast("cuda", dtype=torch.float16): + decoded_index = self.dense_vae.decode_mesh( + dense_latents, mc_threshold=dense_threshold, return_index=True + )[0] + decoded_index = sort_block(decoded_index, 8) + + print(f"[Pixal3D-2Stage] decoded_index shape: {decoded_index.shape}, " + f"max: {decoded_index.max(0).values.tolist()}, min: {decoded_index.min(0).values.tolist()}") + + # ============ Step 2: Sparse 512 Inference ============ + print(f"[Pixal3D-2Stage] Step 2: Sparse 512 Inference...") + sparse_latents_512 = self.infer_sparse( + image_tensor, camera_angle_x_tensor, distance_tensor, mesh_scale_tensor, decoded_index, + sparse_512_steps, sparse_512_guidance_scale, sparse_seed, + self.sparse_512_visual_condition, self.sparse_512_denoiser_model, self.sparse_512_scheduler + ) + + # Decode 512 mesh + with torch.autocast("cuda", dtype=torch.float16): + with torch.no_grad(): + decoded_meshs_512 = self.sparse_vae_512.decode_mesh(sparse_latents_512, voxel_resolution=512) + mesh_512 = decoded_meshs_512[0] + + del decoded_index, sparse_latents_512 + torch.cuda.empty_cache() + + # ============ Step 3: Prepare 1024 Latent Index ============ + print(f"[Pixal3D-2Stage] Step 3: Prepare 1024 latent index...") + latent_index_1024 = mesh2index(mesh_512, size=1024, factor=8) + block_size_1024 = getattr(self.sparse_1024_denoiser_model.dit_model, 'selection_block_size', 8) + latent_index_1024 = sort_block(latent_index_1024, block_size_1024) + print(f"[Pixal3D-2Stage] 1024 latent tokens: {len(latent_index_1024)}") + + # ============ Step 4: Sparse 1024 Inference ============ + print(f"[Pixal3D-2Stage] Step 4: Sparse 1024 Inference...") + sparse_latents_1024 = self.infer_sparse( + image_tensor, camera_angle_x_tensor, distance_tensor, mesh_scale_tensor, latent_index_1024, + sparse_1024_steps, sparse_1024_guidance_scale, sparse_seed, + self.sparse_1024_visual_condition, self.sparse_1024_denoiser_model, self.sparse_1024_scheduler + ) + + # Decode 1024 mesh and postprocess + with torch.autocast("cuda", dtype=torch.float16): + with torch.no_grad(): + decoded_meshs_1024 = self.sparse_vae_1024.decode_mesh( + sparse_latents_1024, voxel_resolution=1024, mc_threshold=mc_threshold + ) + + mesh_v, mesh_f = postprocess_mesh( + decoded_meshs_1024[0].vertices, decoded_meshs_1024[0].faces, + simplify=True, verbose=True, + ) + mesh_1024 = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False) + mesh_1024.apply_scale(0.5 / mesh_scale) + + # Apply rotation matrix to align with Blender coordinate system + rotation_matrix = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ]) + mesh_1024.apply_transform(rotation_matrix) + + del latent_index_1024, sparse_latents_1024, decoded_meshs_1024 + torch.cuda.empty_cache() + + return mesh_1024 + + def infer_from_image(self, image_path: str, **kwargs): + return self.infer(image=image_path, **kwargs) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9191cf6b71dde786f830aafb105c8288d1a26d20 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,24 @@ +numpy==1.26.4 +Pillow==11.3.0 +tqdm==4.67.1 +trimesh==4.10.1 +omegaconf==2.3.0 +diffusers==0.36.0 +huggingface_hub>=0.36.0 +jaxtyping==0.2.28 +typeguard==2.13.3 +packaging==24.2 +einops==0.8.1 +https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +triton==3.1.0 +scikit-image==0.25.2 +pyvista==0.46.4 +gradio>=6.2.0 +safetensors==0.7.0 +transformers==4.40.2 +kornia==0.8.2 +timm==1.0.24 +moge @ git+https://github.com/microsoft/MoGe.git@07444410f1e33f402353b99d6ccd26bd31e469e8 +https://github.com/LDYang694/Storages/releases/download/20260430/natten-0.21.0+torch2.5cu124-cp310-cp310-linux_x86_64.whl +https://github.com/LDYang694/Storages/releases/download/20260430/torchsparse-2.1.0+torch2.5cu124-cp310-cp310-linux_x86_64.whl +https://github.com/LDYang694/Storages/releases/download/20260430/udf_ext-0.0.0+torch2.5cu124-cp310-cp310-linux_x86_64.whl