Spaces:
Running
Running
Fix device mismatch, use remote RMBG client, improve progress tracking, translate comments to English
Browse files- app.py +13 -2
- autotune_cache.json +30 -0
- trellis2/datasets/components.py +4 -4
- trellis2/datasets/sparse_structure_latent.py +1 -1
- trellis2/datasets/structured_latent_shape.py +1 -1
- trellis2/pipelines/pixal3d_image_to_3d.py +7 -7
- trellis2/pipelines/trellis2_image_to_3d.py +1 -1
- trellis2/trainers/basic.py +20 -20
app.py
CHANGED
|
@@ -140,9 +140,15 @@ def init_models():
|
|
| 140 |
pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
|
| 141 |
pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
|
| 142 |
|
| 143 |
-
pipeline.cuda()
|
| 144 |
pipeline.rembg_model = None # Use remote BRIA-RMBG-2.0 instead
|
| 145 |
pipeline.low_vram = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
print("[NAF] Pre-loading NAF upsampler model...")
|
| 148 |
for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
|
|
@@ -328,6 +334,10 @@ class _TqdmProgressInterceptor(_original_tqdm):
|
|
| 328 |
self._stage_desc = kwargs.get('desc', 'Processing')
|
| 329 |
super().__init__(*args, **kwargs)
|
| 330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
def update(self, n=1):
|
| 332 |
super().update(n)
|
| 333 |
_update_progress(self._stage_desc, self.n, self.total or 0)
|
|
@@ -339,6 +349,8 @@ import trellis2.pipelines.samplers.flow_euler as _fe_module
|
|
| 339 |
_fe_module.tqdm = _TqdmProgressInterceptor
|
| 340 |
import trellis2.utils.render_utils as _ru_module
|
| 341 |
_ru_module.tqdm = _TqdmProgressInterceptor
|
|
|
|
|
|
|
| 342 |
|
| 343 |
# ============================================================================
|
| 344 |
# API Implementation
|
|
@@ -494,7 +506,6 @@ def extract_glb_api(state_path: str, decimation_target: int, texture_size: int,
|
|
| 494 |
mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
|
| 495 |
_update_progress("Decoding latent", 1, 1)
|
| 496 |
|
| 497 |
-
_update_progress("Extracting GLB mesh", 0, 1)
|
| 498 |
glb = o_voxel.postprocess.to_glb(
|
| 499 |
vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
|
| 500 |
coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
|
|
|
|
| 140 |
pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
|
| 141 |
pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
|
| 142 |
|
|
|
|
| 143 |
pipeline.rembg_model = None # Use remote BRIA-RMBG-2.0 instead
|
| 144 |
pipeline.low_vram = False
|
| 145 |
+
pipeline.cuda()
|
| 146 |
+
|
| 147 |
+
# Ensure image_cond_models are on GPU
|
| 148 |
+
pipeline.image_cond_model_ss.cuda()
|
| 149 |
+
pipeline.image_cond_model_shape_512.cuda()
|
| 150 |
+
pipeline.image_cond_model_shape_1024.cuda()
|
| 151 |
+
pipeline.image_cond_model_tex_1024.cuda()
|
| 152 |
|
| 153 |
print("[NAF] Pre-loading NAF upsampler model...")
|
| 154 |
for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
|
|
|
|
| 334 |
self._stage_desc = kwargs.get('desc', 'Processing')
|
| 335 |
super().__init__(*args, **kwargs)
|
| 336 |
|
| 337 |
+
def set_description(self, desc=None, refresh=True):
|
| 338 |
+
self._stage_desc = desc or 'Processing'
|
| 339 |
+
super().set_description(desc, refresh)
|
| 340 |
+
|
| 341 |
def update(self, n=1):
|
| 342 |
super().update(n)
|
| 343 |
_update_progress(self._stage_desc, self.n, self.total or 0)
|
|
|
|
| 349 |
_fe_module.tqdm = _TqdmProgressInterceptor
|
| 350 |
import trellis2.utils.render_utils as _ru_module
|
| 351 |
_ru_module.tqdm = _TqdmProgressInterceptor
|
| 352 |
+
import o_voxel.postprocess as _ovp_module
|
| 353 |
+
_ovp_module.tqdm = _TqdmProgressInterceptor
|
| 354 |
|
| 355 |
# ============================================================================
|
| 356 |
# API Implementation
|
|
|
|
| 506 |
mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
|
| 507 |
_update_progress("Decoding latent", 1, 1)
|
| 508 |
|
|
|
|
| 509 |
glb = o_voxel.postprocess.to_glb(
|
| 510 |
vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
|
| 511 |
coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
|
autotune_cache.json
CHANGED
|
@@ -24944,6 +24944,36 @@
|
|
| 24944 |
"reg_inc_consumer": 0,
|
| 24945 |
"maxnreg": null,
|
| 24946 |
"pre_hook": null
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24947 |
}
|
| 24948 |
},
|
| 24949 |
"flex_gemm.kernels.triton.spconv.sparse_submanifold_conv_bwd_implicit_gemm.sparse_submanifold_conv_bwd_input_implicit_gemm_kernel": {
|
|
|
|
| 24944 |
"reg_inc_consumer": 0,
|
| 24945 |
"maxnreg": null,
|
| 24946 |
"pre_hook": null
|
| 24947 |
+
},
|
| 24948 |
+
"(23, 7552645, 6, 8, 'torch.float32', 'torch.uint32', 'torch.float32', 'torch.float32')": {
|
| 24949 |
+
"kwargs": {
|
| 24950 |
+
"BM": 16,
|
| 24951 |
+
"BK": 8
|
| 24952 |
+
},
|
| 24953 |
+
"num_warps": 2,
|
| 24954 |
+
"num_ctas": 1,
|
| 24955 |
+
"num_stages": 2,
|
| 24956 |
+
"num_buffers_warp_spec": 0,
|
| 24957 |
+
"num_consumer_groups": 0,
|
| 24958 |
+
"reg_dec_producer": 0,
|
| 24959 |
+
"reg_inc_consumer": 0,
|
| 24960 |
+
"maxnreg": null,
|
| 24961 |
+
"pre_hook": null
|
| 24962 |
+
},
|
| 24963 |
+
"(22, 7813095, 6, 8, 'torch.float32', 'torch.uint32', 'torch.float32', 'torch.float32')": {
|
| 24964 |
+
"kwargs": {
|
| 24965 |
+
"BM": 16,
|
| 24966 |
+
"BK": 8
|
| 24967 |
+
},
|
| 24968 |
+
"num_warps": 2,
|
| 24969 |
+
"num_ctas": 1,
|
| 24970 |
+
"num_stages": 2,
|
| 24971 |
+
"num_buffers_warp_spec": 0,
|
| 24972 |
+
"num_consumer_groups": 0,
|
| 24973 |
+
"reg_dec_producer": 0,
|
| 24974 |
+
"reg_inc_consumer": 0,
|
| 24975 |
+
"maxnreg": null,
|
| 24976 |
+
"pre_hook": null
|
| 24977 |
}
|
| 24978 |
},
|
| 24979 |
"flex_gemm.kernels.triton.spconv.sparse_submanifold_conv_bwd_implicit_gemm.sparse_submanifold_conv_bwd_input_implicit_gemm_kernel": {
|
trellis2/datasets/components.py
CHANGED
|
@@ -57,16 +57,16 @@ class StandardDatasetBase(Dataset):
|
|
| 57 |
self._stats[key] = {}
|
| 58 |
metadata = pd.DataFrame(columns=['sha256']).set_index('sha256')
|
| 59 |
|
| 60 |
-
#
|
| 61 |
-
#
|
| 62 |
for sub_key, r in root.items():
|
| 63 |
if sub_key == 'base':
|
| 64 |
-
continue #
|
| 65 |
metadata_file = os.path.join(r, 'metadata.csv')
|
| 66 |
if os.path.exists(metadata_file):
|
| 67 |
metadata = metadata.combine_first(pd.read_csv(metadata_file).set_index('sha256'))
|
| 68 |
|
| 69 |
-
#
|
| 70 |
if 'base' in root:
|
| 71 |
base_metadata_file = os.path.join(root['base'], 'metadata.csv')
|
| 72 |
if os.path.exists(base_metadata_file):
|
|
|
|
| 57 |
self._stats[key] = {}
|
| 58 |
metadata = pd.DataFrame(columns=['sha256']).set_index('sha256')
|
| 59 |
|
| 60 |
+
# Only merge key fields from ss_latent and render_cond
|
| 61 |
+
# Exclude base, because cond_rendered=False in base/metadata.csv would incorrectly overwrite real values
|
| 62 |
for sub_key, r in root.items():
|
| 63 |
if sub_key == 'base':
|
| 64 |
+
continue # Skip base directory
|
| 65 |
metadata_file = os.path.join(r, 'metadata.csv')
|
| 66 |
if os.path.exists(metadata_file):
|
| 67 |
metadata = metadata.combine_first(pd.read_csv(metadata_file).set_index('sha256'))
|
| 68 |
|
| 69 |
+
# Read aesthetic_score separately from base (avoid reading other potentially conflicting columns)
|
| 70 |
if 'base' in root:
|
| 71 |
base_metadata_file = os.path.join(root['base'], 'metadata.csv')
|
| 72 |
if os.path.exists(base_metadata_file):
|
trellis2/datasets/sparse_structure_latent.py
CHANGED
|
@@ -349,7 +349,7 @@ class SparseStructureLatentView(SparseStructureLatentVisMixin, StandardDatasetBa
|
|
| 349 |
|
| 350 |
if existing_view_cols:
|
| 351 |
# Filter rows where all required views are encoded
|
| 352 |
-
#
|
| 353 |
has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
|
| 354 |
metadata = metadata[has_all_views]
|
| 355 |
stats[f'With {self.num_views} view latents'] = len(metadata)
|
|
|
|
| 349 |
|
| 350 |
if existing_view_cols:
|
| 351 |
# Filter rows where all required views are encoded
|
| 352 |
+
# Note: NaN should be treated as False, so use == True for explicit comparison
|
| 353 |
has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
|
| 354 |
metadata = metadata[has_all_views]
|
| 355 |
stats[f'With {self.num_views} view latents'] = len(metadata)
|
trellis2/datasets/structured_latent_shape.py
CHANGED
|
@@ -293,7 +293,7 @@ class SLatShapeView(SLatShapeVisMixin, SLat):
|
|
| 293 |
|
| 294 |
if existing_view_cols:
|
| 295 |
# Filter rows where all required views are encoded
|
| 296 |
-
#
|
| 297 |
has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
|
| 298 |
metadata = metadata[has_all_views]
|
| 299 |
stats[f'With {self.num_views} view latents'] = len(metadata)
|
|
|
|
| 293 |
|
| 294 |
if existing_view_cols:
|
| 295 |
# Filter rows where all required views are encoded
|
| 296 |
+
# Note: NaN should be treated as False, so use == True for explicit comparison
|
| 297 |
has_all_views = (metadata[existing_view_cols] == True).all(axis=1)
|
| 298 |
metadata = metadata[has_all_views]
|
| 299 |
stats[f'With {self.num_views} view latents'] = len(metadata)
|
trellis2/pipelines/pixal3d_image_to_3d.py
CHANGED
|
@@ -14,9 +14,9 @@ class Pixal3DImageTo3DPipeline(Pipeline):
|
|
| 14 |
"""
|
| 15 |
Pipeline for inferring Pixal3D (proj mode) image-to-3D models.
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
|
| 21 |
Args:
|
| 22 |
models (dict[str, nn.Module]): The models to use in the pipeline.
|
|
@@ -114,13 +114,13 @@ class Pixal3DImageTo3DPipeline(Pipeline):
|
|
| 114 |
pipeline.shape_slat_normalization = args['shape_slat_normalization']
|
| 115 |
pipeline.tex_slat_normalization = args['tex_slat_normalization']
|
| 116 |
|
| 117 |
-
# Proj mode: image_cond_models
|
| 118 |
pipeline.image_cond_model_ss = None
|
| 119 |
pipeline.image_cond_model_shape_512 = None
|
| 120 |
pipeline.image_cond_model_shape_1024 = None
|
| 121 |
pipeline.image_cond_model_tex_1024 = None
|
| 122 |
|
| 123 |
-
pipeline.rembg_model =
|
| 124 |
|
| 125 |
pipeline.low_vram = args.get('low_vram', True)
|
| 126 |
pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
|
|
@@ -186,7 +186,7 @@ class Pixal3DImageTo3DPipeline(Pipeline):
|
|
| 186 |
return output
|
| 187 |
|
| 188 |
# =========================================================================
|
| 189 |
-
# Proj
|
| 190 |
# =========================================================================
|
| 191 |
|
| 192 |
@torch.no_grad()
|
|
@@ -295,7 +295,7 @@ class Pixal3DImageTo3DPipeline(Pipeline):
|
|
| 295 |
}
|
| 296 |
|
| 297 |
# =========================================================================
|
| 298 |
-
# Sampling methods (
|
| 299 |
# =========================================================================
|
| 300 |
|
| 301 |
def sample_sparse_structure(
|
|
|
|
| 14 |
"""
|
| 15 |
Pipeline for inferring Pixal3D (proj mode) image-to-3D models.
|
| 16 |
|
| 17 |
+
Based on Trellis2 pipeline, using proj mode for inference.
|
| 18 |
+
Each stage (SS, Shape 512, Shape 1024, Tex 1024) has its own image_cond_model (DinoV3ProjFeatureExtractor).
|
| 19 |
+
Condition building uses camera-aware projection (requires camera_angle_x, distance, mesh_scale parameters).
|
| 20 |
|
| 21 |
Args:
|
| 22 |
models (dict[str, nn.Module]): The models to use in the pipeline.
|
|
|
|
| 114 |
pipeline.shape_slat_normalization = args['shape_slat_normalization']
|
| 115 |
pipeline.tex_slat_normalization = args['tex_slat_normalization']
|
| 116 |
|
| 117 |
+
# Proj mode: image_cond_models need to be loaded externally, set to None here
|
| 118 |
pipeline.image_cond_model_ss = None
|
| 119 |
pipeline.image_cond_model_shape_512 = None
|
| 120 |
pipeline.image_cond_model_shape_1024 = None
|
| 121 |
pipeline.image_cond_model_tex_1024 = None
|
| 122 |
|
| 123 |
+
pipeline.rembg_model = None # Skip local RMBG loading; use remote client instead
|
| 124 |
|
| 125 |
pipeline.low_vram = args.get('low_vram', True)
|
| 126 |
pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
|
|
|
|
| 186 |
return output
|
| 187 |
|
| 188 |
# =========================================================================
|
| 189 |
+
# Proj mode condition building
|
| 190 |
# =========================================================================
|
| 191 |
|
| 192 |
@torch.no_grad()
|
|
|
|
| 295 |
}
|
| 296 |
|
| 297 |
# =========================================================================
|
| 298 |
+
# Sampling methods (consistent with Trellis2)
|
| 299 |
# =========================================================================
|
| 300 |
|
| 301 |
def sample_sparse_structure(
|
trellis2/pipelines/trellis2_image_to_3d.py
CHANGED
|
@@ -101,7 +101,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 101 |
pipeline.shape_slat_normalization = args['shape_slat_normalization']
|
| 102 |
pipeline.tex_slat_normalization = args['tex_slat_normalization']
|
| 103 |
|
| 104 |
-
# HACK:
|
| 105 |
image_cond_args = args['image_cond_model']['args'].copy()
|
| 106 |
if image_cond_args.get('model_name') == 'facebook/dinov3-vitl16-pretrain-lvd1689m':
|
| 107 |
image_cond_args['model_name'] = 'camenduru/dinov3-vitl16-pretrain-lvd1689m'
|
|
|
|
| 101 |
pipeline.shape_slat_normalization = args['shape_slat_normalization']
|
| 102 |
pipeline.tex_slat_normalization = args['tex_slat_normalization']
|
| 103 |
|
| 104 |
+
# HACK: Replace dinov3 model source with camenduru mirror
|
| 105 |
image_cond_args = args['image_cond_model']['args'].copy()
|
| 106 |
if image_cond_args.get('model_name') == 'facebook/dinov3-vitl16-pretrain-lvd1689m':
|
| 107 |
image_cond_args['model_name'] = 'camenduru/dinov3-vitl16-pretrain-lvd1689m'
|
trellis2/trainers/basic.py
CHANGED
|
@@ -491,7 +491,7 @@ class BasicTrainer:
|
|
| 491 |
Finetune from a checkpoint.
|
| 492 |
Should be called by all processes.
|
| 493 |
"""
|
| 494 |
-
#
|
| 495 |
ALLOWED_MISSING_KEYS = {'rope_phases'}
|
| 496 |
|
| 497 |
if self.is_master:
|
|
@@ -508,7 +508,7 @@ class BasicTrainer:
|
|
| 508 |
# Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper)
|
| 509 |
model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict)
|
| 510 |
|
| 511 |
-
#
|
| 512 |
for k, v in model_ckpt.items():
|
| 513 |
if k not in model_state_dict:
|
| 514 |
if self.is_master:
|
|
@@ -520,7 +520,7 @@ class BasicTrainer:
|
|
| 520 |
model_ckpt[k] = model_state_dict[k]
|
| 521 |
model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None}
|
| 522 |
|
| 523 |
-
#
|
| 524 |
missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys())
|
| 525 |
unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS
|
| 526 |
if unexpected_missing and self.is_master:
|
|
@@ -529,7 +529,7 @@ class BasicTrainer:
|
|
| 529 |
if missing_keys & ALLOWED_MISSING_KEYS and self.is_master:
|
| 530 |
print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}')
|
| 531 |
|
| 532 |
-
#
|
| 533 |
for k in missing_keys:
|
| 534 |
model_ckpt[k] = model_state_dict[k]
|
| 535 |
|
|
@@ -903,16 +903,16 @@ class BasicTrainer:
|
|
| 903 |
|
| 904 |
def _verify_gradient_sync(self):
|
| 905 |
"""
|
| 906 |
-
|
| 907 |
-
DDP
|
| 908 |
|
| 909 |
-
|
| 910 |
-
1.
|
| 911 |
-
2.
|
| 912 |
-
3.
|
| 913 |
-
4.
|
| 914 |
"""
|
| 915 |
-
#
|
| 916 |
total_grad_norm_sq = 0.0
|
| 917 |
grad_count = 0
|
| 918 |
for p in self.model_params:
|
|
@@ -925,16 +925,16 @@ class BasicTrainer:
|
|
| 925 |
|
| 926 |
local_grad_norm = total_grad_norm_sq ** 0.5
|
| 927 |
|
| 928 |
-
#
|
| 929 |
dist.barrier()
|
| 930 |
|
| 931 |
-
#
|
| 932 |
grad_norm_tensor = torch.tensor([local_grad_norm], dtype=torch.float64, device=self.device)
|
| 933 |
all_grad_norms = [torch.zeros(1, dtype=torch.float64, device=self.device) for _ in range(self.world_size)]
|
| 934 |
dist.all_gather(all_grad_norms, grad_norm_tensor)
|
| 935 |
all_grad_norms = [g.item() for g in all_grad_norms]
|
| 936 |
|
| 937 |
-
#
|
| 938 |
ref_norm = all_grad_norms[0]
|
| 939 |
if ref_norm > 0:
|
| 940 |
is_synced = all(abs(g - ref_norm) / ref_norm < 1e-3 for g in all_grad_norms)
|
|
@@ -1010,7 +1010,7 @@ class BasicTrainer:
|
|
| 1010 |
loss, status = self.training_losses(**mb_data)
|
| 1011 |
l = loss['loss'] / len(data_list)
|
| 1012 |
|
| 1013 |
-
# DEBUG:
|
| 1014 |
if self.debug:
|
| 1015 |
print(f'[Rank {self.rank}/{self.world_size}] Step {self.step} batch {i}: loss={loss["loss"].item():.6f}')
|
| 1016 |
|
|
@@ -1029,10 +1029,10 @@ class BasicTrainer:
|
|
| 1029 |
elastic_controller_logs.append(self.elastic_controller.log())
|
| 1030 |
|
| 1031 |
# ============================================================
|
| 1032 |
-
# DEBUG:
|
| 1033 |
-
#
|
| 1034 |
-
# DDP
|
| 1035 |
-
#
|
| 1036 |
# ============================================================
|
| 1037 |
if self.debug and self.world_size > 1:
|
| 1038 |
self._verify_gradient_sync()
|
|
|
|
| 491 |
Finetune from a checkpoint.
|
| 492 |
Should be called by all processes.
|
| 493 |
"""
|
| 494 |
+
# Allow missing keys (e.g., register_buffer parameters)
|
| 495 |
ALLOWED_MISSING_KEYS = {'rope_phases'}
|
| 496 |
|
| 497 |
if self.is_master:
|
|
|
|
| 508 |
# Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper)
|
| 509 |
model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict)
|
| 510 |
|
| 511 |
+
# Check extra keys (in ckpt but not in model)
|
| 512 |
for k, v in model_ckpt.items():
|
| 513 |
if k not in model_state_dict:
|
| 514 |
if self.is_master:
|
|
|
|
| 520 |
model_ckpt[k] = model_state_dict[k]
|
| 521 |
model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None}
|
| 522 |
|
| 523 |
+
# Check missing keys (in model but not in ckpt)
|
| 524 |
missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys())
|
| 525 |
unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS
|
| 526 |
if unexpected_missing and self.is_master:
|
|
|
|
| 529 |
if missing_keys & ALLOWED_MISSING_KEYS and self.is_master:
|
| 530 |
print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}')
|
| 531 |
|
| 532 |
+
# Fill in missing keys (using model initialized values)
|
| 533 |
for k in missing_keys:
|
| 534 |
model_ckpt[k] = model_state_dict[k]
|
| 535 |
|
|
|
|
| 903 |
|
| 904 |
def _verify_gradient_sync(self):
|
| 905 |
"""
|
| 906 |
+
Verify that DDP gradient synchronization is working correctly.
|
| 907 |
+
DDP's backward automatically performs all_reduce on gradients; after sync all ranks should have identical gradients.
|
| 908 |
|
| 909 |
+
Verification method:
|
| 910 |
+
1. Compute total gradient norm across all parameters
|
| 911 |
+
2. Gather gradient norms from all ranks
|
| 912 |
+
3. If DDP sync is working, all ranks should have identical gradient norms
|
| 913 |
+
4. If not synced, gradient norms will differ (since each rank processes different data)
|
| 914 |
"""
|
| 915 |
+
# Compute total gradient norm on this rank
|
| 916 |
total_grad_norm_sq = 0.0
|
| 917 |
grad_count = 0
|
| 918 |
for p in self.model_params:
|
|
|
|
| 925 |
|
| 926 |
local_grad_norm = total_grad_norm_sq ** 0.5
|
| 927 |
|
| 928 |
+
# Ensure all processes reach the same point
|
| 929 |
dist.barrier()
|
| 930 |
|
| 931 |
+
# Gather gradient norms from all ranks
|
| 932 |
grad_norm_tensor = torch.tensor([local_grad_norm], dtype=torch.float64, device=self.device)
|
| 933 |
all_grad_norms = [torch.zeros(1, dtype=torch.float64, device=self.device) for _ in range(self.world_size)]
|
| 934 |
dist.all_gather(all_grad_norms, grad_norm_tensor)
|
| 935 |
all_grad_norms = [g.item() for g in all_grad_norms]
|
| 936 |
|
| 937 |
+
# Verify all ranks have the same gradient norm (relative error tolerance: 0.1%)
|
| 938 |
ref_norm = all_grad_norms[0]
|
| 939 |
if ref_norm > 0:
|
| 940 |
is_synced = all(abs(g - ref_norm) / ref_norm < 1e-3 for g in all_grad_norms)
|
|
|
|
| 1010 |
loss, status = self.training_losses(**mb_data)
|
| 1011 |
l = loss['loss'] / len(data_list)
|
| 1012 |
|
| 1013 |
+
# DEBUG: Print loss for each rank
|
| 1014 |
if self.debug:
|
| 1015 |
print(f'[Rank {self.rank}/{self.world_size}] Step {self.step} batch {i}: loss={loss["loss"].item():.6f}')
|
| 1016 |
|
|
|
|
| 1029 |
elastic_controller_logs.append(self.elastic_controller.log())
|
| 1030 |
|
| 1031 |
# ============================================================
|
| 1032 |
+
# DEBUG: Verify DDP gradient synchronization
|
| 1033 |
+
# Check if gradients are consistent across ranks after backward
|
| 1034 |
+
# DDP automatically all_reduces gradients during the last batch_split's backward
|
| 1035 |
+
# After sync, all ranks should have identical gradients
|
| 1036 |
# ============================================================
|
| 1037 |
if self.debug and self.world_size > 1:
|
| 1038 |
self._verify_gradient_sync()
|