Yang2001 committed on
Commit
551545a
·
1 Parent(s): 53ad659

Fix device mismatch, use remote RMBG client, improve progress tracking, translate comments to English

Browse files
app.py CHANGED
@@ -140,9 +140,15 @@ def init_models():
140
  pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
141
  pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
142
 
143
- pipeline.cuda()
144
  pipeline.rembg_model = None # Use remote BRIA-RMBG-2.0 instead
145
  pipeline.low_vram = False
 
 
 
 
 
 
 
146
 
147
  print("[NAF] Pre-loading NAF upsampler model...")
148
  for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
@@ -328,6 +334,10 @@ class _TqdmProgressInterceptor(_original_tqdm):
328
  self._stage_desc = kwargs.get('desc', 'Processing')
329
  super().__init__(*args, **kwargs)
330
 
 
 
 
 
331
  def update(self, n=1):
332
  super().update(n)
333
  _update_progress(self._stage_desc, self.n, self.total or 0)
@@ -339,6 +349,8 @@ import trellis2.pipelines.samplers.flow_euler as _fe_module
339
  _fe_module.tqdm = _TqdmProgressInterceptor
340
  import trellis2.utils.render_utils as _ru_module
341
  _ru_module.tqdm = _TqdmProgressInterceptor
 
 
342
 
343
  # ============================================================================
344
  # API Implementation
@@ -494,7 +506,6 @@ def extract_glb_api(state_path: str, decimation_target: int, texture_size: int,
494
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
495
  _update_progress("Decoding latent", 1, 1)
496
 
497
- _update_progress("Extracting GLB mesh", 0, 1)
498
  glb = o_voxel.postprocess.to_glb(
499
  vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
500
  coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
 
140
  pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
141
  pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
142
 
 
143
  pipeline.rembg_model = None # Use remote BRIA-RMBG-2.0 instead
144
  pipeline.low_vram = False
145
+ pipeline.cuda()
146
+
147
+ # Ensure image_cond_models are on GPU
148
+ pipeline.image_cond_model_ss.cuda()
149
+ pipeline.image_cond_model_shape_512.cuda()
150
+ pipeline.image_cond_model_shape_1024.cuda()
151
+ pipeline.image_cond_model_tex_1024.cuda()
152
 
153
  print("[NAF] Pre-loading NAF upsampler model...")
154
  for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
 
334
  self._stage_desc = kwargs.get('desc', 'Processing')
335
  super().__init__(*args, **kwargs)
336
 
337
+ def set_description(self, desc=None, refresh=True):
338
+ self._stage_desc = desc or 'Processing'
339
+ super().set_description(desc, refresh)
340
+
341
  def update(self, n=1):
342
  super().update(n)
343
  _update_progress(self._stage_desc, self.n, self.total or 0)
 
349
  _fe_module.tqdm = _TqdmProgressInterceptor
350
  import trellis2.utils.render_utils as _ru_module
351
  _ru_module.tqdm = _TqdmProgressInterceptor
352
+ import o_voxel.postprocess as _ovp_module
353
+ _ovp_module.tqdm = _TqdmProgressInterceptor
354
 
355
  # ============================================================================
356
  # API Implementation
 
506
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
507
  _update_progress("Decoding latent", 1, 1)
508
 
 
509
  glb = o_voxel.postprocess.to_glb(
510
  vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
511
  coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
autotune_cache.json CHANGED
@@ -24944,6 +24944,36 @@
24944
  "reg_inc_consumer": 0,
24945
  "maxnreg": null,
24946
  "pre_hook": null
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24947
  }
24948
  },
24949
  "flex_gemm.kernels.triton.spconv.sparse_submanifold_conv_bwd_implicit_gemm.sparse_submanifold_conv_bwd_input_implicit_gemm_kernel": {
 
24944
  "reg_inc_consumer": 0,
24945
  "maxnreg": null,
24946
  "pre_hook": null
24947
+ },
24948
+ "(23, 7552645, 6, 8, 'torch.float32', 'torch.uint32', 'torch.float32', 'torch.float32')": {
24949
+ "kwargs": {
24950
+ "BM": 16,
24951
+ "BK": 8
24952
+ },
24953
+ "num_warps": 2,
24954
+ "num_ctas": 1,
24955
+ "num_stages": 2,
24956
+ "num_buffers_warp_spec": 0,
24957
+ "num_consumer_groups": 0,
24958
+ "reg_dec_producer": 0,
24959
+ "reg_inc_consumer": 0,
24960
+ "maxnreg": null,
24961
+ "pre_hook": null
24962
+ },
24963
+ "(22, 7813095, 6, 8, 'torch.float32', 'torch.uint32', 'torch.float32', 'torch.float32')": {
24964
+ "kwargs": {
24965
+ "BM": 16,
24966
+ "BK": 8
24967
+ },
24968
+ "num_warps": 2,
24969
+ "num_ctas": 1,
24970
+ "num_stages": 2,
24971
+ "num_buffers_warp_spec": 0,
24972
+ "num_consumer_groups": 0,
24973
+ "reg_dec_producer": 0,
24974
+ "reg_inc_consumer": 0,
24975
+ "maxnreg": null,
24976
+ "pre_hook": null
24977
  }
24978
  },
24979
  "flex_gemm.kernels.triton.spconv.sparse_submanifold_conv_bwd_implicit_gemm.sparse_submanifold_conv_bwd_input_implicit_gemm_kernel": {
trellis2/datasets/components.py CHANGED
@@ -57,16 +57,16 @@ class StandardDatasetBase(Dataset):
57
  self._stats[key] = {}
58
  metadata = pd.DataFrame(columns=['sha256']).set_index('sha256')
59
 
60
- # 只从 ss_latent render_cond 合并关键字段
61
- # 不包含 base,因为 base/metadata.csv 中的 cond_rendered=False 会错误覆盖真实值
62
  for sub_key, r in root.items():
63
  if sub_key == 'base':
64
- continue # 跳过 base 目录
65
  metadata_file = os.path.join(r, 'metadata.csv')
66
  if os.path.exists(metadata_file):
67
  metadata = metadata.combine_first(pd.read_csv(metadata_file).set_index('sha256'))
68
 
69
- # base 单独读取 aesthetic_score(不读取其他可能冲突的列)
70
  if 'base' in root:
71
  base_metadata_file = os.path.join(root['base'], 'metadata.csv')
72
  if os.path.exists(base_metadata_file):
 
57
  self._stats[key] = {}
58
  metadata = pd.DataFrame(columns=['sha256']).set_index('sha256')
59
 
60
+ # Only merge key fields from ss_latent and render_cond
61
+ # Exclude base, because cond_rendered=False in base/metadata.csv would incorrectly overwrite real values
62
  for sub_key, r in root.items():
63
  if sub_key == 'base':
64
+ continue # Skip base directory
65
  metadata_file = os.path.join(r, 'metadata.csv')
66
  if os.path.exists(metadata_file):
67
  metadata = metadata.combine_first(pd.read_csv(metadata_file).set_index('sha256'))
68
 
69
+ # Read aesthetic_score separately from base (avoid reading other potentially conflicting columns)
70
  if 'base' in root:
71
  base_metadata_file = os.path.join(root['base'], 'metadata.csv')
72
  if os.path.exists(base_metadata_file):
trellis2/datasets/sparse_structure_latent.py CHANGED
@@ -349,7 +349,7 @@ class SparseStructureLatentView(SparseStructureLatentVisMixin, StandardDatasetBa
349
 
350
  if existing_view_cols:
351
  # Filter rows where all required views are encoded
352
- # 注意:NaN 需要被视为 False,所以用 == True 显式比较
353
  has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
354
  metadata = metadata[has_all_views]
355
  stats[f'With {self.num_views} view latents'] = len(metadata)
 
349
 
350
  if existing_view_cols:
351
  # Filter rows where all required views are encoded
352
+ # Note: NaN should be treated as False, so use == True for explicit comparison
353
  has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
354
  metadata = metadata[has_all_views]
355
  stats[f'With {self.num_views} view latents'] = len(metadata)
trellis2/datasets/structured_latent_shape.py CHANGED
@@ -293,7 +293,7 @@ class SLatShapeView(SLatShapeVisMixin, SLat):
293
 
294
  if existing_view_cols:
295
  # Filter rows where all required views are encoded
296
- # 注意:NaN 需要被视为 False,所以用 == True 显式比较
297
  has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
298
  metadata = metadata[has_all_views]
299
  stats[f'With {self.num_views} view latents'] = len(metadata)
 
293
 
294
  if existing_view_cols:
295
  # Filter rows where all required views are encoded
296
+ # Note: NaN should be treated as False, so use == True for explicit comparison
297
  has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
298
  metadata = metadata[has_all_views]
299
  stats[f'With {self.num_views} view latents'] = len(metadata)
trellis2/pipelines/pixal3d_image_to_3d.py CHANGED
@@ -14,9 +14,9 @@ class Pixal3DImageTo3DPipeline(Pipeline):
14
  """
15
  Pipeline for inferring Pixal3D (proj mode) image-to-3D models.
16
 
17
- 基于 Trellis2 pipeline,使用 proj 模式进行推理。
18
- 每个 stage (SS, Shape 512, Shape 1024, Tex 1024) 有独立的 image_cond_model (DinoV3ProjFeatureExtractor)
19
- 条件构建使用 camera-aware projection(需要 camera_angle_x, distance, mesh_scale 参数)。
20
 
21
  Args:
22
  models (dict[str, nn.Module]): The models to use in the pipeline.
@@ -114,13 +114,13 @@ class Pixal3DImageTo3DPipeline(Pipeline):
114
  pipeline.shape_slat_normalization = args['shape_slat_normalization']
115
  pipeline.tex_slat_normalization = args['tex_slat_normalization']
116
 
117
- # Proj mode: image_cond_models 需要外部加载后设置,这里先置为 None
118
  pipeline.image_cond_model_ss = None
119
  pipeline.image_cond_model_shape_512 = None
120
  pipeline.image_cond_model_shape_1024 = None
121
  pipeline.image_cond_model_tex_1024 = None
122
 
123
- pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
124
 
125
  pipeline.low_vram = args.get('low_vram', True)
126
  pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
@@ -186,7 +186,7 @@ class Pixal3DImageTo3DPipeline(Pipeline):
186
  return output
187
 
188
  # =========================================================================
189
- # Proj 模式条件构建
190
  # =========================================================================
191
 
192
  @torch.no_grad()
@@ -295,7 +295,7 @@ class Pixal3DImageTo3DPipeline(Pipeline):
295
  }
296
 
297
  # =========================================================================
298
- # Sampling methods (保持与 Trellis2 一致)
299
  # =========================================================================
300
 
301
  def sample_sparse_structure(
 
14
  """
15
  Pipeline for inferring Pixal3D (proj mode) image-to-3D models.
16
 
17
+ Based on Trellis2 pipeline, using proj mode for inference.
18
+ Each stage (SS, Shape 512, Shape 1024, Tex 1024) has its own image_cond_model (DinoV3ProjFeatureExtractor).
19
+ Condition building uses camera-aware projection (requires camera_angle_x, distance, mesh_scale parameters).
20
 
21
  Args:
22
  models (dict[str, nn.Module]): The models to use in the pipeline.
 
114
  pipeline.shape_slat_normalization = args['shape_slat_normalization']
115
  pipeline.tex_slat_normalization = args['tex_slat_normalization']
116
 
117
+ # Proj mode: image_cond_models need to be loaded externally, set to None here
118
  pipeline.image_cond_model_ss = None
119
  pipeline.image_cond_model_shape_512 = None
120
  pipeline.image_cond_model_shape_1024 = None
121
  pipeline.image_cond_model_tex_1024 = None
122
 
123
+ pipeline.rembg_model = None # Skip local RMBG loading; use remote client instead
124
 
125
  pipeline.low_vram = args.get('low_vram', True)
126
  pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
 
186
  return output
187
 
188
  # =========================================================================
189
+ # Proj mode condition building
190
  # =========================================================================
191
 
192
  @torch.no_grad()
 
295
  }
296
 
297
  # =========================================================================
298
+ # Sampling methods (consistent with Trellis2)
299
  # =========================================================================
300
 
301
  def sample_sparse_structure(
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -101,7 +101,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
101
  pipeline.shape_slat_normalization = args['shape_slat_normalization']
102
  pipeline.tex_slat_normalization = args['tex_slat_normalization']
103
 
104
- # HACK: 替换 dinov3 模型源为 camenduru 镜像
105
  image_cond_args = args['image_cond_model']['args'].copy()
106
  if image_cond_args.get('model_name') == 'facebook/dinov3-vitl16-pretrain-lvd1689m':
107
  image_cond_args['model_name'] = 'camenduru/dinov3-vitl16-pretrain-lvd1689m'
 
101
  pipeline.shape_slat_normalization = args['shape_slat_normalization']
102
  pipeline.tex_slat_normalization = args['tex_slat_normalization']
103
 
104
+ # HACK: Replace dinov3 model source with camenduru mirror
105
  image_cond_args = args['image_cond_model']['args'].copy()
106
  if image_cond_args.get('model_name') == 'facebook/dinov3-vitl16-pretrain-lvd1689m':
107
  image_cond_args['model_name'] = 'camenduru/dinov3-vitl16-pretrain-lvd1689m'
trellis2/trainers/basic.py CHANGED
@@ -491,7 +491,7 @@ class BasicTrainer:
491
  Finetune from a checkpoint.
492
  Should be called by all processes.
493
  """
494
- # 允许缺失的 keys(如 register_buffer 的参数)
495
  ALLOWED_MISSING_KEYS = {'rope_phases'}
496
 
497
  if self.is_master:
@@ -508,7 +508,7 @@ class BasicTrainer:
508
  # Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper)
509
  model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict)
510
 
511
- # 检查多余的 keys(在 ckpt 中但不在 model 中)
512
  for k, v in model_ckpt.items():
513
  if k not in model_state_dict:
514
  if self.is_master:
@@ -520,7 +520,7 @@ class BasicTrainer:
520
  model_ckpt[k] = model_state_dict[k]
521
  model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None}
522
 
523
- # 检查缺失的 keys(在 model 中但不在 ckpt 中)
524
  missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys())
525
  unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS
526
  if unexpected_missing and self.is_master:
@@ -529,7 +529,7 @@ class BasicTrainer:
529
  if missing_keys & ALLOWED_MISSING_KEYS and self.is_master:
530
  print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}')
531
 
532
- # 补充缺失的 keys(使用模型初始化值)
533
  for k in missing_keys:
534
  model_ckpt[k] = model_state_dict[k]
535
 
@@ -903,16 +903,16 @@ class BasicTrainer:
903
 
904
  def _verify_gradient_sync(self):
905
  """
906
- 验证 DDP 梯度同步是否真正生效。
907
- DDP backward 会自动对梯度进行 all_reduce,同步后所有卡的梯度应该完全相同。
908
 
909
- 验证方法:
910
- 1. 计算所有参数的总梯度 norm
911
- 2. 收集各卡的梯度 norm
912
- 3. 如果 DDP 同步正常,所有卡的梯度 norm 应该完全相同
913
- 4. 如果没有同步,各卡梯度 norm 会不同(因为各卡处理的数据不同)
914
  """
915
- # 计算本卡所有参数的总梯度 norm
916
  total_grad_norm_sq = 0.0
917
  grad_count = 0
918
  for p in self.model_params:
@@ -925,16 +925,16 @@ class BasicTrainer:
925
 
926
  local_grad_norm = total_grad_norm_sq ** 0.5
927
 
928
- # 确保所有进程到达同一点
929
  dist.barrier()
930
 
931
- # 收集所有卡的梯度 norm
932
  grad_norm_tensor = torch.tensor([local_grad_norm], dtype=torch.float64, device=self.device)
933
  all_grad_norms = [torch.zeros(1, dtype=torch.float64, device=self.device) for _ in range(self.world_size)]
934
  dist.all_gather(all_grad_norms, grad_norm_tensor)
935
  all_grad_norms = [g.item() for g in all_grad_norms]
936
 
937
- # 验证所有卡的梯度 norm 是否相同(使用相对误差,容忍 0.1%
938
  ref_norm = all_grad_norms[0]
939
  if ref_norm > 0:
940
  is_synced = all(abs(g - ref_norm) / ref_norm < 1e-3 for g in all_grad_norms)
@@ -1010,7 +1010,7 @@ class BasicTrainer:
1010
  loss, status = self.training_losses(**mb_data)
1011
  l = loss['loss'] / len(data_list)
1012
 
1013
- # DEBUG: 打印每个 rank loss
1014
  if self.debug:
1015
  print(f'[Rank {self.rank}/{self.world_size}] Step {self.step} batch {i}: loss={loss["loss"].item():.6f}')
1016
 
@@ -1029,10 +1029,10 @@ class BasicTrainer:
1029
  elastic_controller_logs.append(self.elastic_controller.log())
1030
 
1031
  # ============================================================
1032
- # DEBUG: 验证 DDP 梯度同步
1033
- # 检查 backward 后各卡梯度是否一致
1034
- # DDP 在最后一个 batch_split backward 时会自动 all_reduce 梯度
1035
- # 同步后所有卡的梯度应该完全相同
1036
  # ============================================================
1037
  if self.debug and self.world_size > 1:
1038
  self._verify_gradient_sync()
 
491
  Finetune from a checkpoint.
492
  Should be called by all processes.
493
  """
494
+ # Allow missing keys (e.g., register_buffer parameters)
495
  ALLOWED_MISSING_KEYS = {'rope_phases'}
496
 
497
  if self.is_master:
 
508
  # Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper)
509
  model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict)
510
 
511
+ # Check extra keys (in ckpt but not in model)
512
  for k, v in model_ckpt.items():
513
  if k not in model_state_dict:
514
  if self.is_master:
 
520
  model_ckpt[k] = model_state_dict[k]
521
  model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None}
522
 
523
+ # Check missing keys (in model but not in ckpt)
524
  missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys())
525
  unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS
526
  if unexpected_missing and self.is_master:
 
529
  if missing_keys & ALLOWED_MISSING_KEYS and self.is_master:
530
  print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}')
531
 
532
+ # Fill in missing keys (using model initialized values)
533
  for k in missing_keys:
534
  model_ckpt[k] = model_state_dict[k]
535
 
 
903
 
904
  def _verify_gradient_sync(self):
905
  """
906
+ Verify that DDP gradient synchronization is working correctly.
907
+ DDP's backward automatically performs all_reduce on gradients; after sync all ranks should have identical gradients.
908
 
909
+ Verification method:
910
+ 1. Compute total gradient norm across all parameters
911
+ 2. Gather gradient norms from all ranks
912
+ 3. If DDP sync is working, all ranks should have identical gradient norms
913
+ 4. If not synced, gradient norms will differ (since each rank processes different data)
914
  """
915
+ # Compute total gradient norm on this rank
916
  total_grad_norm_sq = 0.0
917
  grad_count = 0
918
  for p in self.model_params:
 
925
 
926
  local_grad_norm = total_grad_norm_sq ** 0.5
927
 
928
+ # Ensure all processes reach the same point
929
  dist.barrier()
930
 
931
+ # Gather gradient norms from all ranks
932
  grad_norm_tensor = torch.tensor([local_grad_norm], dtype=torch.float64, device=self.device)
933
  all_grad_norms = [torch.zeros(1, dtype=torch.float64, device=self.device) for _ in range(self.world_size)]
934
  dist.all_gather(all_grad_norms, grad_norm_tensor)
935
  all_grad_norms = [g.item() for g in all_grad_norms]
936
 
937
+ # Verify all ranks have the same gradient norm (relative error tolerance: 0.1%)
938
  ref_norm = all_grad_norms[0]
939
  if ref_norm > 0:
940
  is_synced = all(abs(g - ref_norm) / ref_norm < 1e-3 for g in all_grad_norms)
 
1010
  loss, status = self.training_losses(**mb_data)
1011
  l = loss['loss'] / len(data_list)
1012
 
1013
+ # DEBUG: Print loss for each rank
1014
  if self.debug:
1015
  print(f'[Rank {self.rank}/{self.world_size}] Step {self.step} batch {i}: loss={loss["loss"].item():.6f}')
1016
 
 
1029
  elastic_controller_logs.append(self.elastic_controller.log())
1030
 
1031
  # ============================================================
1032
+ # DEBUG: Verify DDP gradient synchronization
1033
+ # Check if gradients are consistent across ranks after backward
1034
+ # DDP automatically all_reduces gradients during the last batch_split's backward
1035
+ # After sync, all ranks should have identical gradients
1036
  # ============================================================
1037
  if self.debug and self.world_size > 1:
1038
  self._verify_gradient_sync()