import os
import json
from typing import *
import numpy as np
import torch
import utils3d
from .. import models
from .components import ImageConditionedMixin, ViewImageConditionedMixin
from ..modules.sparse import SparseTensor
from .structured_latent import SLatVisMixin, SLat
from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics
from ..utils.data_utils import load_balanced_group_indices
class SLatShapeVisMixin(SLatVisMixin):
    """
    Visualization mixin for shape SLat datasets: decodes latents to mesh
    representations and renders a fixed 2x2 multi-view grid of normals,
    plus an optional GT-camera view when camera parameters are provided.
    """

    def _loading_slat_dec(self):
        """
        Lazily instantiate the SLat decoder (idempotent).

        Loads either from an explicit checkpoint directory (`slat_dec_path`,
        which overrides the pretrained name) or from `pretrained_slat_dec`,
        then moves the decoder to CUDA in eval mode.
        """
        if self.slat_dec is not None:
            return
        if self.slat_dec_path is not None:
            # Explicit checkpoint dir: read its config to build the decoder class.
            # Use a context manager so the config file handle is closed promptly
            # (the original relied on GC to close it).
            with open(os.path.join(self.slat_dec_path, 'config.json'), 'r') as f:
                cfg = json.load(f)
            decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
            ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
            decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
        else:
            decoder = models.from_pretrained(self.pretrained_slat_dec)
        decoder.set_resolution(self.resolution)
        self.slat_dec = decoder.cuda().eval()

    @torch.no_grad()
    def visualize_sample(
        self,
        x_0: Union[SparseTensor, dict],
        camera_angle_x: Optional[torch.Tensor] = None,
        camera_distance: Optional[torch.Tensor] = None,
        mesh_scale: Optional[torch.Tensor] = None,
    ):
        """
        Visualize shape samples.

        Args:
            x_0: SparseTensor or dict containing 'x_0'
            camera_angle_x: Optional [B] camera FOV angle in radians
            camera_distance: Optional [B] camera distance for GT view rendering
            mesh_scale: Optional [B] mesh scale factor for coordinate alignment

        Returns:
            dict with:
                'multiview': [B, 3, 1024, 1024] - 4 fixed views rendered in 2x2 grid (normal)
                'gt_view': [B, 3, 512, 512] - GT camera view (if camera params provided)
        """
        # Hoisted out of the per-sample loop: the import is needed at most once per call.
        from ..representations import Mesh

        x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
        reps = self.decode_latent(x_0.cuda())
        # build fixed camera views (4 views: 0°, 90°, 180°, 270°)
        yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
        yaw_offset = -16 / 180 * np.pi  # small offset so views are not exactly axis-aligned
        yaw = [y + yaw_offset for y in yaw]
        pitch = [20 / 180 * np.pi for _ in range(4)]
        fixed_exts, fixed_ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)
        # GT-view rendering requires all three camera parameters.
        has_gt_camera = (
            camera_angle_x is not None and
            camera_distance is not None and
            mesh_scale is not None
        )
        # render
        renderer = get_renderer(reps[0])
        multiview_images = []
        gt_view_images = []
        for i, representation in enumerate(reps):
            # Render 4 fixed views (2x2 grid); zeros serve as the fallback canvas.
            image = torch.zeros(3, 1024, 1024).cuda()
            tile = [2, 2]
            # Validate mesh data before rasterization — a degenerate mesh would
            # crash the renderer. The three checks share one skip path below so
            # that BOTH output lists stay batch-aligned.
            verts = representation.vertices
            faces = representation.faces
            skip_msg = None
            if verts.shape[0] == 0 or faces.shape[0] == 0:
                skip_msg = f"[visualize_sample] Warning: sample {i} has empty mesh, skipping"
            elif faces.max() >= verts.shape[0]:
                skip_msg = (f"[visualize_sample] Warning: sample {i} has out-of-bound face indices "
                            f"(max face idx={faces.max().item()}, num verts={verts.shape[0]}), skipping")
            elif torch.isnan(verts).any() or torch.isinf(verts).any():
                skip_msg = f"[visualize_sample] Warning: sample {i} has NaN/Inf vertices, skipping"
            if skip_msg is not None:
                print(skip_msg)
                multiview_images.append(image)
                # BUGFIX: the original `continue` skipped the GT-view append,
                # so torch.stack(gt_view_images) had a smaller batch than
                # 'multiview' whenever a sample was skipped. Emit a gray
                # placeholder to keep the two outputs aligned.
                if has_gt_camera:
                    gt_view_images.append(torch.full((3, 512, 512), 0.5, device=verts.device))
                continue
            try:
                for j, (ext, intr) in enumerate(zip(fixed_exts, fixed_ints)):
                    res = renderer.render(representation, ext, intr)
                    image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['normal']
            except RuntimeError as e:
                print(f"[visualize_sample] Warning: render failed for sample {i}: {e}")
                # Discard any partially-filled grid; fall back to a blank canvas.
                image = torch.zeros(3, 1024, 1024).cuda()
            multiview_images.append(image)
            # Render GT camera view using the fixed front view (same as sparse_structure_latent.py)
            if has_gt_camera:
                # The GT view should match exactly how ProjGrid projects 3D points to 2D.
                #
                # In image_conditioned_proj.py (ProjGrid.forward):
                #   1. grid_points are in [-1, 1]^3 (from torch.linspace(-1, 1, res))
                #   2. grid_points are rotated by rotation_matrix (Y-Z swap): x'=x, y'=-z, z'=y
                #   3. grid_points are scaled: grid_points / mesh_scale / 2
                #   4. Points are projected using front_view_transform_matrix with distance
                #
                # Mesh vertices are in [-0.5, 0.5]^3. To match ProjGrid's coordinate space,
                # we need to scale them: vertices / mesh_scale -> [-0.5/s, 0.5/s]^3
                # This is equivalent to ProjGrid's: [-1,1]^3 / scale / 2 -> [-0.5/s, 0.5/s]^3
                #
                # Camera position: ProjGrid camera is at (0, -distance, 0) in Blender coords (Z-up).
                # After inverse rotation to mesh space, camera is at (0, 0, distance).
                scale = mesh_scale[i].item()
                distance = camera_distance[i].item()
                fov = camera_angle_x[i].item()
                device = representation.vertices.device
                # Scale mesh vertices to match ProjGrid's projection space
                scaled_rep = Mesh(
                    vertices=representation.vertices / scale,
                    faces=representation.faces,
                )
                cam_pos = torch.tensor([0.0, 0.0, distance], device=device)
                look_at = torch.tensor([0.0, 0.0, 0.0], device=device)
                cam_up = torch.tensor([0.0, 1.0, 0.0], device=device)
                gt_ext = utils3d.torch.extrinsics_look_at(cam_pos, look_at, cam_up)
                gt_int = utils3d.torch.intrinsics_from_fov_xy(
                    torch.tensor(fov, device=device),
                    torch.tensor(fov, device=device)
                )
                gt_ext = gt_ext.to(device)
                gt_int = gt_int.to(device)
                # Use scaled mesh renderer with appropriate near/far for smaller mesh
                mesh_half_size = 0.5 / scale
                renderer.rendering_options.near = max(0.01, distance - mesh_half_size - 0.5)
                renderer.rendering_options.far = distance + mesh_half_size + 0.5
                try:
                    gt_res = renderer.render(scaled_rep, gt_ext, gt_int)
                    gt_view_images.append(gt_res['normal'])
                except RuntimeError as e:
                    print(f"[visualize_sample] Warning: GT view render failed for sample {i}: {e}")
                    gt_view_images.append(torch.full((3, 512, 512), 0.5, device=device))
        result = {
            'multiview': torch.stack(multiview_images),
        }
        if has_gt_camera and len(gt_view_images) > 0:
            result['gt_view'] = torch.stack(gt_view_images)
        return result
class SLatShape(SLatShapeVisMixin, SLat):
    """
    Structured latent dataset for shape generation.

    Args:
        roots (str): path to the dataset
        resolution (int): resolution of the shape
        min_aesthetic_score (float): minimum aesthetic score
        max_tokens (int): maximum number of tokens
        latent_key (str): key of the latent to be used
        normalization (dict): normalization stats
        pretrained_slat_dec (str): name of the pretrained slat decoder
        slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
        slat_dec_ckpt (str): name of the slat decoder checkpoint
        skip_list (str, optional): path to a file containing sha256 hashes to skip
        skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check
    """
    def __init__(self,
        roots: str,
        *,
        resolution: int,
        min_aesthetic_score: float = 5.0,
        max_tokens: int = 32768,
        normalization: Optional[dict] = None,
        pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
        slat_dec_path: Optional[str] = None,
        slat_dec_ckpt: Optional[str] = None,
        skip_list: Optional[str] = None,
        skip_aesthetic_score_datasets: Optional[list] = None,
    ):
        # Delegate everything except `resolution` to the SLat base; the latent
        # key is fixed to 'shape_latent' for this dataset type.
        base_kwargs = dict(
            min_aesthetic_score=min_aesthetic_score,
            max_tokens=max_tokens,
            latent_key='shape_latent',
            normalization=normalization,
            pretrained_slat_dec=pretrained_slat_dec,
            slat_dec_path=slat_dec_path,
            slat_dec_ckpt=slat_dec_ckpt,
            skip_list=skip_list,
            skip_aesthetic_score_datasets=skip_aesthetic_score_datasets,
        )
        super().__init__(roots, **base_kwargs)
        self.resolution = resolution
class ImageConditionedSLatShape(ImageConditionedMixin, SLatShape):
    """
    Image-conditioned structured latent for shape generation.

    Pure mixin composition: all behavior comes from ImageConditionedMixin
    (conditioning images) and SLatShape (shape latents); no overrides needed.
    """
    pass
class SLatShapeView(SLatShapeVisMixin, SLat):
    """
    View-based structured latent for shape generation.

    Data format: {sha256}/view{XX}.npz where each npz contains 'coords' and 'feats' keys.

    Args:
        roots (str): path to the dataset
        resolution (int): resolution of the shape
        min_aesthetic_score (float): minimum aesthetic score
        max_tokens (int): maximum number of tokens
        num_views (int): Number of views to use (0 to num_views-1). Default is 2.
        normalization (dict): normalization stats
        pretrained_slat_dec (str): name of the pretrained slat decoder
        slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
        slat_dec_ckpt (str): name of the slat decoder checkpoint
        skip_list (str, optional): path to a file containing sha256 hashes to skip
        skip_aesthetic_score_datasets (list, optional): list of dataset names to skip aesthetic score check
    """
    def __init__(self,
        roots: str,
        *,
        resolution: int,
        min_aesthetic_score: float = 5.0,
        max_tokens: int = 32768,
        num_views: int = 2,
        normalization: Optional[dict] = None,
        pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16',
        slat_dec_path: Optional[str] = None,
        slat_dec_ckpt: Optional[str] = None,
        skip_list: Optional[str] = None,
        skip_aesthetic_score_datasets: Optional[list] = None,
    ):
        self.normalization = normalization
        self.min_aesthetic_score = min_aesthetic_score
        self.max_tokens = max_tokens
        self.num_views = num_views
        self.latent_key = 'shape_latent'
        self.value_range = (0, 1)
        # NOTE: deliberately bypasses SLat.__init__ — the vis mixin and the
        # dataset base are initialized directly because the view-based data
        # layout differs from the single-latent layout SLat assumes.
        from .components import StandardDatasetBase
        SLatVisMixin.__init__(
            self,
            roots,
            pretrained_slat_dec=pretrained_slat_dec,
            slat_dec_path=slat_dec_path,
            slat_dec_ckpt=slat_dec_ckpt,
        )
        StandardDatasetBase.__init__(self, roots, skip_list=skip_list, skip_aesthetic_score_datasets=skip_aesthetic_score_datasets)
        self.resolution = resolution
        # Calculate per-instance token loads for load-balanced batching;
        # fall back to the worst case (max_tokens) when unknown.
        self.loads = []
        for _, sha256, _ in self.instances:
            if f'{self.latent_key}_tokens' in self.metadata.columns:
                try:
                    self.loads.append(self.metadata.loc[sha256, f'{self.latent_key}_tokens'])
                except KeyError:
                    # sha256 missing from the metadata index — was a bare
                    # `except:` that silently swallowed every error.
                    self.loads.append(self.max_tokens)
            else:
                self.loads.append(self.max_tokens)
        if self.normalization is not None:
            # Stored as [1, C] so they broadcast over [N, C] feature tensors.
            self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
            self.std = torch.tensor(self.normalization['std']).reshape(1, -1)

    def filter_metadata(self, metadata, dataset_name=None):
        """
        Filter the metadata frame to usable instances.

        Returns:
            (filtered_metadata, stats) where stats maps a human-readable
            filter description to the row count after that filter.
        """
        stats = {}
        # View-based shape_latent uses columns like shape_latent_view00_encoded, shape_latent_view01_encoded, etc.
        required_view_cols = [f'shape_latent_view{i:02d}_encoded' for i in range(self.num_views)]
        existing_view_cols = [col for col in required_view_cols if col in metadata.columns]
        if existing_view_cols:
            # Filter rows where all required views are encoded
            # Note: NaN should be treated as False, so use == True for explicit comparison
            has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
            metadata = metadata[has_all_views]
            stats[f'With {self.num_views} view latents'] = len(metadata)
        else:
            # Fallback: check shape_latent_encoded column
            if f'{self.latent_key}_encoded' in metadata.columns:
                metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True]
                stats['With latent'] = len(metadata)
        # Skip aesthetic score check for specified datasets (e.g., texverse) or if column doesn't exist
        skip_aesthetic = (
            (dataset_name and dataset_name.lower() in [d.lower() for d in self.skip_aesthetic_score_datasets]) or
            ('aesthetic_score' not in metadata.columns)
        )
        if skip_aesthetic:
            stats[f'Aesthetic score check skipped'] = len(metadata)
        else:
            metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
            stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        # Filter by max_tokens if column exists
        tokens_col = f'{self.latent_key}_tokens'
        if tokens_col in metadata.columns:
            metadata = metadata[metadata[tokens_col] <= self.max_tokens]
            stats[f'Num tokens <= {self.max_tokens}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """
        Load one randomly-selected view's sparse latent for `instance`.

        Returns:
            dict with 'coords' (int tensor), 'feats' (float tensor, normalized
            if normalization stats were provided), and 'view_idx' (int).
        """
        # View-based format: directory with view{XX}.npz files
        latent_dir = os.path.join(root[self.latent_key], instance)
        # Randomly select a view from the configured range
        view_idx = np.random.randint(0, self.num_views)
        view_file = f'view{view_idx:02d}.npz'
        # Store view info for ViewImageConditionedMixin
        self._current_view_idx = view_idx
        self._current_latent_dir = latent_dir
        # Context manager closes the npz file handle promptly (the original
        # left it for GC, leaking a file descriptor per sample).
        with np.load(os.path.join(latent_dir, view_file)) as data:
            coords = torch.tensor(data['coords']).int()
            feats = torch.tensor(data['feats']).float()
        if self.normalization is not None:
            feats = (feats - self.mean) / self.std
        return {
            'coords': coords,
            'feats': feats,
            'view_idx': view_idx,
        }

    @staticmethod
    def collate_fn(batch, split_size=None):
        """
        Collate sparse samples into one or more packed SparseTensor batches.

        Args:
            batch: list of dicts from get_instance.
            split_size: if given, samples are load-balanced into groups of
                this size and a list of packs is returned; otherwise a single
                pack covering the whole batch is returned.
        """
        if split_size is None:
            group_idx = [list(range(len(batch)))]
        else:
            group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
        packs = []
        for group in group_idx:
            sub_batch = [batch[i] for i in group]
            pack = {}
            coords = []
            feats = []
            layout = []
            start = 0
            for i, b in enumerate(sub_batch):
                # Prepend the batch index as the first coordinate column.
                coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
                feats.append(b['feats'])
                layout.append(slice(start, start + b['coords'].shape[0]))
                start += b['coords'].shape[0]
            coords = torch.cat(coords)
            feats = torch.cat(feats)
            pack['x_0'] = SparseTensor(
                coords=coords,
                feats=feats,
            )
            pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
            pack['x_0'].register_spatial_cache('layout', layout)
            # collate other data
            keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
            for k in keys:
                if isinstance(sub_batch[0][k], torch.Tensor):
                    pack[k] = torch.stack([b[k] for b in sub_batch])
                elif isinstance(sub_batch[0][k], list):
                    pack[k] = sum([b[k] for b in sub_batch], [])
                else:
                    pack[k] = [b[k] for b in sub_batch]
            packs.append(pack)
        if split_size is None:
            return packs[0]
        return packs
class ViewImageConditionedSLatShapeView(ViewImageConditionedMixin, SLatShapeView):
    """
    Image-conditioned view-based structured latent for shape generation.

    Loads shape_latent from {sha256}/view{XX}.npz format and pairs with
    corresponding view from render_cond.
    Uses ViewImageConditionedMixin which reads mesh_scale from view{XX}_scale.json.

    Pure mixin composition; no overrides needed.
    """
    pass
|