Spaces:
Runtime error
Runtime error
| from typing import * | |
| import random | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 | |
| from ...modules import sparse as sp | |
| from .base import SparseTransformerBase | |
class SparseSubdivideBlock3d(nn.Module):
    """
    Sparse 3D block that doubles the spatial resolution of a sparse tensor.

    Pipeline: 3x3x3 sparse conv + SiLU -> voxel subdivision (2x upsampling)
    -> 3x3x3 sparse conv + SiLU.

    Args:
        channels: number of input feature channels.
        out_channels: number of output feature channels; defaults to
            ``channels`` when not given.
        use_checkpoint: if True, run the forward pass under activation
            checkpointing to trade compute for memory.
    """

    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.act_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
            sp.SparseSiLU()
        )
        self.sub = sp.SparseSubdivide()
        self.out_layers = nn.Sequential(
            sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
            sp.SparseSiLU(),
        )

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = self.act_layers(x)
        h = self.sub(h)
        h = self.out_layers(h)
        return h

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        # Annotation fixed: the input/output are sparse tensors (see
        # _forward), not dense torch.Tensor.
        if self.use_checkpoint:
            # Non-reentrant checkpointing recomputes activations in backward.
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
class SparseSDFDecoder(SparseTransformerBase):
    """
    Sparse transformer decoder that maps latent sparse features to SDF values
    on an 8x-upsampled sparse grid.

    The transformer torso (inherited from SparseTransformerBase) is followed
    by three SparseSubdivideBlock3d stages (2**3 == 8x spatial upsampling), a
    linear output head, and a tanh activation.

    When ``chunk_size > 1`` the upsampling runs piecewise to bound memory:
    at eval time the latent volume is split into ``chunk_size**3`` padded
    sub-volumes that are upsampled independently and stitched back together;
    at training time a single random padded sub-volume is processed.
    """

    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
        out_channels: int = 1,
        chunk_size: int = 1,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self.out_channels = out_channels
        self.chunk_size = chunk_size
        # Three subdivide stages: 8x total upsampling, channels tapering
        # model_channels -> //4 -> //8 -> //16.
        self.upsample = nn.ModuleList([
            SparseSubdivideBlock3d(
                channels=model_channels,
                out_channels=model_channels // 4,
                use_checkpoint=use_checkpoint,
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                out_channels=model_channels // 8,
                use_checkpoint=use_checkpoint,
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 8,
                out_channels=model_channels // 16,
                use_checkpoint=use_checkpoint,
            )
        ])
        self.out_layer = sp.SparseLinear(model_channels // 16, self.out_channels)
        self.out_active = sp.SparseTanh()

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers so the decoder initially predicts 0.
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        super().convert_to_fp16()
        self.upsample.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)

    @staticmethod
    def _box_mask(coords: torch.Tensor,
                  x0: int, x1: int,
                  y0: int, y1: int,
                  z0: int, z1: int) -> torch.Tensor:
        """
        Boolean mask of ``coords`` rows inside the half-open box
        [x0, x1) x [y0, y1) x [z0, z1).

        ``coords`` rows are laid out as (batch, x, y, z), matching
        sp.SparseTensor.coords usage throughout this class.
        """
        return (
            (coords[:, 1] >= x0) & (coords[:, 1] < x1)
            & (coords[:, 2] >= y0) & (coords[:, 2] < y1)
            & (coords[:, 3] >= z0) & (coords[:, 3] < z1)
        )

    def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4):
        """
        Split the latent sparse tensor into ``chunk_size**3`` padded
        sub-volumes for piecewise upsampling (eval path).

        Each returned chunk has local coordinates and carries a ``bounds``
        attribute with:
          - 'original': unpadded bounds in upsampled-output coordinates, used
            to crop away the padded halo after upsampling;
          - 'offsets': padded-start offsets in output coordinates, used to
            restore global coordinates.
        Empty chunks are skipped.
        """
        sub_resolution = self.resolution // chunk_size
        upsample_ratio = 8  # hard-coded: matches the three 2x subdivide stages
        assert sub_resolution % padding == 0
        out = []
        for i in range(chunk_size):
            for j in range(chunk_size):
                for k in range(chunk_size):
                    # Padded boundaries, clamped to the volume.
                    start_x = max(0, i * sub_resolution - padding)
                    end_x = min((i + 1) * sub_resolution + padding, self.resolution)
                    start_y = max(0, j * sub_resolution - padding)
                    end_y = min((j + 1) * sub_resolution + padding, self.resolution)
                    start_z = max(0, k * sub_resolution - padding)
                    end_z = min((k + 1) * sub_resolution + padding, self.resolution)

                    # Original (unpadded) boundaries, kept for later cropping.
                    orig_start_x = i * sub_resolution
                    orig_end_x = (i + 1) * sub_resolution
                    orig_start_y = j * sub_resolution
                    orig_end_y = (j + 1) * sub_resolution
                    orig_start_z = k * sub_resolution
                    orig_end_z = (k + 1) * sub_resolution

                    mask = self._box_mask(x.coords,
                                          start_x, end_x,
                                          start_y, end_y,
                                          start_z, end_z)
                    if mask.sum() > 0:
                        # Shift retained coordinates into the chunk-local frame.
                        coords = x.coords[mask].clone()
                        coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
                                                                     device=coords.device).view(1, 3)
                        chunk_tensor = sp.SparseTensor(x.feats[mask], coords)
                        # NOTE(review): the 'original' end bounds extend
                        # upsample_ratio - 1 voxels past the unpadded
                        # sub-volume, so neighbouring chunks overlap slightly
                        # after cropping — presumably intentional for
                        # watertight meshing; confirm before changing.
                        chunk_tensor.bounds = {
                            'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1),
                                         orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1),
                                         orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
                            'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio)
                        }
                        out.append(chunk_tensor)
                    del mask
                    torch.cuda.empty_cache()
        return out

    def split_single_chunk(self, x: sp.SparseTensor, chunk_size=4, padding=4):
        """
        Pick a random sub-volume (with padded halo) containing at least one
        active voxel and return its bounds metadata (training path).

        Returns a dict with 'original' (unpadded bounds in output
        coordinates), 'start' (padded bounds in latent coordinates) and
        'offsets' (padded-start offsets in output coordinates).
        """
        sub_resolution = self.resolution // chunk_size
        upsample_ratio = 8  # hard-coded: matches the three 2x subdivide stages
        assert sub_resolution % padding == 0
        mask_sum = -1
        # Re-sample until the unpadded window contains at least one voxel.
        # NOTE(review): loops forever if x has no active voxels — assumes a
        # non-empty latent; confirm against callers.
        while mask_sum < 1:
            orig_start_x = random.randint(0, self.resolution - sub_resolution)
            orig_end_x = orig_start_x + sub_resolution
            orig_start_y = random.randint(0, self.resolution - sub_resolution)
            orig_end_y = orig_start_y + sub_resolution
            orig_start_z = random.randint(0, self.resolution - sub_resolution)
            orig_end_z = orig_start_z + sub_resolution

            start_x = max(0, orig_start_x - padding)
            end_x = min(orig_end_x + padding, self.resolution)
            start_y = max(0, orig_start_y - padding)
            end_y = min(orig_end_y + padding, self.resolution)
            start_z = max(0, orig_start_z - padding)
            end_z = min(orig_end_z + padding, self.resolution)

            mask_sum = self._box_mask(x.coords,
                                      orig_start_x, orig_end_x,
                                      orig_start_y, orig_end_y,
                                      orig_start_z, orig_end_z).sum()

        bounds = {
            'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1),
                         orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1),
                         orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
            'start': (start_x, end_x, start_y, end_y, start_z, end_z),
            'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio)
        }
        return bounds

    def forward_single_chunk(self, x: sp.SparseTensor, padding=4):
        """
        Upsample a single random padded sub-volume of ``x`` and crop the
        result back to the unpadded bounds (training path).
        """
        bounds = self.split_single_chunk(x, self.chunk_size, padding=padding)
        start_x, end_x, start_y, end_y, start_z, end_z = bounds['start']
        mask = self._box_mask(x.coords, start_x, end_x, start_y, end_y, start_z, end_z)
        # Shift to chunk-local coordinates.
        coords = x.coords.clone()
        coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
                                                     device=coords.device).view(1, 3)
        chunk = sp.SparseTensor(x.feats[mask], coords[mask])
        chunk_result = self.upsamples(chunk)
        coords = chunk_result.coords.clone()
        # Restore global (output-space) coordinates.
        offsets = torch.tensor(bounds['offsets'],
                               device=coords.device).view(1, 3)
        coords[:, 1:] = coords[:, 1:] + offsets
        # Crop away the padded halo: keep only points within original bounds.
        o = bounds['original']
        within_bounds = self._box_mask(coords, o[0], o[1], o[2], o[3], o[4], o[5])
        final_coords = coords[within_bounds]
        final_feats = chunk_result.feats[within_bounds]
        return sp.SparseTensor(final_feats, final_coords)

    def upsamples(self, x, return_feat: bool = False):
        """
        Run the subdivide stack and the output head on ``x``.

        Returns the tanh-activated SDF tensor, plus the pre-head features
        when ``return_feat`` is True.
        """
        dtype = x.dtype
        for block in self.upsample:
            x = block(x)
        # Cast back: blocks may run in fp16 (see convert_to_fp16).
        x = x.type(dtype)
        output = self.out_active(self.out_layer(x))
        if return_feat:
            return output, x
        else:
            return output

    def forward(self, x: sp.SparseTensor, factor: float = None, return_feat: bool = False):
        """
        Decode latent sparse tensor ``x`` to SDF values.

        ``factor`` is forwarded to the transformer torso. ``return_feat``
        additionally returns pre-head features (monolithic path only).
        """
        h = super().forward(x, factor)
        if self.chunk_size <= 1:
            # Monolithic path: upsample the whole volume at once.
            for block in self.upsample:
                h = block(h)
            h = h.type(x.dtype)
            if return_feat:
                return self.out_active(self.out_layer(h)), h
            h = self.out_layer(h)
            h = self.out_active(h)
            return h
        else:
            if self.training:
                return self.forward_single_chunk(h)
            else:
                batch_size = x.shape[0]
                chunks = self.split_for_meshing(h, chunk_size=self.chunk_size)
                all_coords, all_feats = [], []
                for chunk in chunks:
                    chunk_result = self.upsamples(chunk)
                    for b in range(batch_size):
                        sel = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1)
                        if sel.numel() > 0:
                            coords = chunk_result.coords[sel].clone()
                            # Restore global (output-space) coordinates.
                            offsets = torch.tensor(chunk.bounds['offsets'],
                                                   device=coords.device).view(1, 3)
                            coords[:, 1:] = coords[:, 1:] + offsets
                            # Crop away the padded halo.
                            o = chunk.bounds['original']
                            within_bounds = self._box_mask(coords, o[0], o[1], o[2], o[3], o[4], o[5])
                            if within_bounds.any():
                                all_coords.append(coords[within_bounds])
                                all_feats.append(chunk_result.feats[sel][within_bounds])
                    if not self.training:
                        torch.cuda.empty_cache()
                final_coords = torch.cat(all_coords)
                final_feats = torch.cat(all_feats)
                return sp.SparseTensor(final_feats, final_coords)