from abc import abstractmethod import os import time import json import copy import threading from functools import partial from contextlib import nullcontext import torch import torch.distributed as dist from torch.utils.data import DataLoader from torch.nn.parallel import DistributedDataParallel as DDP import numpy as np from torchvision import utils try: import wandb WANDB_AVAILABLE = True except ImportError: WANDB_AVAILABLE = False from .utils import * from ..utils.general_utils import * from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler from ..utils.dist_utils import * from ..utils import grad_clip_utils, elastic_utils class BasicTrainer: """ Trainer for basic training loop. Args: models (dict[str, nn.Module]): Models to train. dataset (torch.utils.data.Dataset): Dataset. output_dir (str): Output directory. load_dir (str): Load directory. step (int): Step to load. batch_size (int): Batch size. batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. batch_split (int): Split batch with gradient accumulation. max_steps (int): Max steps. optimizer (dict): Optimizer config. lr_scheduler (dict): Learning rate scheduler config. elastic (dict): Elastic memory management config. grad_clip (float or dict): Gradient clip config. ema_rate (float or list): Exponential moving average rates. mix_precision_mode (str): - None: No mixed precision. - 'inflat_all': Hold a inflated fp32 master param for all params. - 'amp': Automatic mixed precision. mix_precision_dtype (str): Mixed precision dtype. fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. parallel_mode (str): Parallel mode. Options are 'ddp'. finetune_ckpt (dict): Finetune checkpoint. log_param_stats (bool): Log parameter stats. i_print (int): Print interval. i_log (int): Log interval. i_sample (int): Sample interval. i_save (int): Save interval. i_ddpcheck (int): DDP check interval. """ def __init__(self, models, dataset, *, output_dir, load_dir, step, max_steps, batch_size=None, batch_size_per_gpu=None, batch_split=None, optimizer={}, lr_scheduler=None, elastic=None, grad_clip=None, ema_rate=0.9999, fp16_mode=None, mix_precision_mode='inflat_all', mix_precision_dtype='float16', fp16_scale_growth=1e-3, parallel_mode='ddp', finetune_ckpt=None, log_param_stats=False, prefetch_data=True, snapshot_batch_size=4, snapshot_num_samples=64, num_workers=None, debug=False, i_print=1000, i_log=500, i_sample=10000, i_save=10000, i_ddpcheck=10000, wandb_run=None, # wandb run object **kwargs ): assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.' self.models = models self.dataset = dataset self.batch_split = batch_split if batch_split is not None else 1 self.max_steps = max_steps self.debug = debug self.optimizer_config = optimizer self.lr_scheduler_config = lr_scheduler self.elastic_controller_config = elastic self.grad_clip = grad_clip self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate if fp16_mode is not None: mix_precision_dtype = 'float16' mix_precision_mode = fp16_mode self.mix_precision_mode = mix_precision_mode self.mix_precision_dtype = str_to_dtype(mix_precision_dtype) self.fp16_scale_growth = fp16_scale_growth self.parallel_mode = parallel_mode self.log_param_stats = log_param_stats self.prefetch_data = prefetch_data self.snapshot_batch_size = snapshot_batch_size self.snapshot_num_samples = snapshot_num_samples self.num_workers = num_workers self.log = [] if self.prefetch_data: self._data_prefetched = None self.output_dir = output_dir from datetime import datetime self._log_file = os.path.join(self.output_dir, f'log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt') self.i_print = i_print self.i_log = i_log self.i_sample = i_sample self.i_save = i_save self.i_ddpcheck = i_ddpcheck if dist.is_initialized(): # Multi-GPU params self.world_size = dist.get_world_size() self.rank = dist.get_rank() self.local_rank = dist.get_rank() % torch.cuda.device_count() self.is_master = self.rank == 0 else: # Single-GPU params self.world_size = 1 self.rank = 0 self.local_rank = 0 self.is_master = True self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.' assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.' self.init_models_and_more(**kwargs) self.prepare_dataloader(**kwargs) # Load checkpoint self.step = 0 if load_dir is not None and step is not None: self.load(load_dir, step) elif finetune_ckpt is not None: self.finetune_from(finetune_ckpt) if self.is_master: os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True) os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True) self.writer = None # TensorBoard disabled (S3 FUSE does not support append) # Initialize wandb self.wandb_run = wandb_run if self.wandb_run is not None: print(f'Wandb logging enabled: {self.wandb_run.url}') if self.parallel_mode == 'ddp' and self.world_size > 1: self.check_ddp() if self.is_master: print('\n\nTrainer initialized.') print(self) def __str__(self): lines = [] lines.append(self.__class__.__name__) lines.append(f' - Models:') for name, model in self.models.items(): lines.append(f' - {name}: {model.__class__.__name__}') lines.append(f' - Dataset: {indent(str(self.dataset), 2)}') lines.append(f' - Dataloader:') lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}') lines.append(f' - Num workers: {self.dataloader.num_workers}') lines.append(f' - Number of steps: {self.max_steps}') lines.append(f' - Number of GPUs: {self.world_size}') lines.append(f' - Batch size: {self.batch_size}') lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}') lines.append(f' - Batch split: {self.batch_split}') lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}') lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}') if self.lr_scheduler_config is not None: lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}') if self.elastic_controller_config is not None: lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}') if self.grad_clip is not None: lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}') lines.append(f' - EMA rate: {self.ema_rate}') lines.append(f' - Mixed precision dtype: {self.mix_precision_dtype}') lines.append(f' - Mixed precision mode: {self.mix_precision_mode}') if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: lines.append(f' - FP16 scale growth: {self.fp16_scale_growth}') lines.append(f' - Parallel mode: {self.parallel_mode}') return '\n'.join(lines) @property def device(self): for _, model in self.models.items(): if hasattr(model, 'device'): return model.device return next(list(self.models.values())[0].parameters()).device def init_models_and_more(self, **kwargs): """ Initialize models and more. """ if self.world_size > 1: # Prepare distributed data parallel self.training_models = { name: DDP( model, device_ids=[self.local_rank], output_device=self.local_rank, bucket_cap_mb=128, find_unused_parameters=False ) for name, model in self.models.items() } else: self.training_models = self.models # Build master params self.model_params = sum( [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()] , []) if self.mix_precision_mode == 'amp': self.master_params = self.model_params if self.mix_precision_dtype == torch.float16: self.scaler = torch.GradScaler() elif self.mix_precision_mode == 'inflat_all': self.master_params = make_master_params(self.model_params) if self.mix_precision_dtype == torch.float16: self.log_scale = 20.0 elif self.mix_precision_mode is None: self.master_params = self.model_params else: raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.') # Build EMA params if self.is_master: self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate] # Initialize optimizer if hasattr(torch.optim, self.optimizer_config['name']): self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args']) else: self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args']) # Initalize learning rate scheduler if self.lr_scheduler_config is not None: if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']): self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args']) else: self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args']) # Initialize elastic memory controller if self.elastic_controller_config is not None: assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \ 'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin' self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args']) for model in self.models.values(): if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)): model.register_memory_controller(self.elastic_controller) # Initialize gradient clipper if self.grad_clip is not None: if isinstance(self.grad_clip, (float, int)): self.grad_clip = float(self.grad_clip) else: self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args']) def prepare_dataloader(self, **kwargs): """ Prepare dataloader. """ self.data_sampler = ResumableSampler( self.dataset, shuffle=True, ) if self.num_workers is None or self.num_workers == -1: num_workers = max(1, int(np.ceil((os.cpu_count() - 16) / torch.cuda.device_count()))) else: num_workers = self.num_workers self.dataloader = DataLoader( self.dataset, batch_size=self.batch_size_per_gpu, num_workers=num_workers, pin_memory=True, drop_last=True, persistent_workers=True, collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, sampler=self.data_sampler, ) self.data_iterator = cycle(self.dataloader) def _master_params_to_state_dicts(self, master_params): """ Convert master params to dict of state_dicts. """ if self.mix_precision_mode == 'inflat_all': master_params = unflatten_master_params(self.model_params, master_params) state_dicts = {name: model.state_dict() for name, model in self.models.items()} master_params_names = sum( [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] , []) for i, (model_name, param_name) in enumerate(master_params_names): state_dicts[model_name][param_name] = master_params[i] return state_dicts def _state_dicts_to_master_params(self, master_params, state_dicts): """ Convert a state_dict to master params. """ master_params_names = sum( [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] , []) params = [state_dicts[name][param_name] for name, param_name in master_params_names] if self.mix_precision_mode == 'inflat_all': model_params_to_master_params(params, master_params) else: for i, param in enumerate(params): master_params[i].data.copy_(param.data) def load(self, load_dir, step=0): """ Load a checkpoint. Should be called by all processes. """ if self.is_master: print(f'\nLoading checkpoint from step {step}...', end='') model_ckpts = {} for name, model in self.models.items(): model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True) model_ckpts[name] = model_ckpt model.load_state_dict(model_ckpt) self._state_dicts_to_master_params(self.master_params, model_ckpts) del model_ckpts if self.is_master: for i, ema_rate in enumerate(self.ema_rate): ema_ckpts = {} for name, model in self.models.items(): ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True) ema_ckpts[name] = ema_ckpt self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts) del ema_ckpts misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False) self.optimizer.load_state_dict(misc_ckpt['optimizer']) self.step = misc_ckpt['step'] self.data_sampler.load_state_dict(misc_ckpt['data_sampler']) if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: self.scaler.load_state_dict(misc_ckpt['scaler']) elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: self.log_scale = misc_ckpt['log_scale'] if self.lr_scheduler_config is not None: self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler']) if self.elastic_controller_config is not None: self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller']) if self.grad_clip is not None and not isinstance(self.grad_clip, float): self.grad_clip.load_state_dict(misc_ckpt['grad_clip']) del misc_ckpt if self.world_size > 1: dist.barrier() if self.is_master: print(' Done.') if self.world_size > 1: self.check_ddp() def save(self, non_blocking=True): """ Save a checkpoint. Should be called only by the rank 0 process. """ assert self.is_master, 'save() should be called only by the rank 0 process.' print(f'\nSaving checkpoint at step {self.step}...', end='') model_ckpts = self._master_params_to_state_dicts(self.master_params) for name, model_ckpt in model_ckpts.items(): model_ckpt = {k: v.cpu() for k, v in model_ckpt.items()} # Move to CPU for saving if non_blocking: threading.Thread( target=torch.save, args=(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')), ).start() else: torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')) for i, ema_rate in enumerate(self.ema_rate): ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i]) for name, ema_ckpt in ema_ckpts.items(): ema_ckpt = {k: v.cpu() for k, v in ema_ckpt.items()} # Move to CPU for saving if non_blocking: threading.Thread( target=torch.save, args=(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')), ).start() else: torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')) misc_ckpt = { 'optimizer': self.optimizer.state_dict(), 'step': self.step, 'data_sampler': self.data_sampler.state_dict(), } if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: misc_ckpt['scaler'] = self.scaler.state_dict() elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: misc_ckpt['log_scale'] = self.log_scale if self.lr_scheduler_config is not None: misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict() if self.elastic_controller_config is not None: misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict() if self.grad_clip is not None and not isinstance(self.grad_clip, float): misc_ckpt['grad_clip'] = self.grad_clip.state_dict() if non_blocking: threading.Thread( target=torch.save, args=(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')), ).start() else: torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')) print(' Done.') def _remap_checkpoint_keys(self, model_ckpt, model_state_dict): """ Remap checkpoint keys to match model state dict. Handles structural changes like: - cross_attn.xxx -> cross_attn.cross_attn_block.xxx (for ProjectAttention wrapper) Args: model_ckpt: Checkpoint state dict model_state_dict: Model state dict Returns: Remapped checkpoint dict """ remapped_ckpt = {} remapped_count = 0 for ckpt_key, ckpt_value in model_ckpt.items(): # Check if key exists directly if ckpt_key in model_state_dict: remapped_ckpt[ckpt_key] = ckpt_value continue # Try remapping: cross_attn.xxx -> cross_attn.cross_attn_block.xxx # This handles the case when cross_attn is wrapped by ProjectAttention if '.cross_attn.' in ckpt_key: # Split at .cross_attn. parts = ckpt_key.split('.cross_attn.') if len(parts) == 2: new_key = f'{parts[0]}.cross_attn.cross_attn_block.{parts[1]}' if new_key in model_state_dict: remapped_ckpt[new_key] = ckpt_value remapped_count += 1 continue # Key not remapped, keep original (will be handled by missing key logic) remapped_ckpt[ckpt_key] = ckpt_value if remapped_count > 0 and self.is_master: print(f'Info: Remapped {remapped_count} cross_attn keys to cross_attn.cross_attn_block') return remapped_ckpt def finetune_from(self, finetune_ckpt): """ Finetune from a checkpoint. Should be called by all processes. """ # Allow missing keys (e.g., register_buffer parameters) ALLOWED_MISSING_KEYS = {'rope_phases'} if self.is_master: print('\nFinetuning from:') for name, path in finetune_ckpt.items(): print(f' - {name}: {path}') model_ckpts = {} for name, model in self.models.items(): model_state_dict = model.state_dict() if name in finetune_ckpt: model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True) # Remap checkpoint keys to handle structural changes (e.g., ProjectAttention wrapper) model_ckpt = self._remap_checkpoint_keys(model_ckpt, model_state_dict) # Check extra keys (in ckpt but not in model) for k, v in model_ckpt.items(): if k not in model_state_dict: if self.is_master: print(f'Warning: {k} not found in model_state_dict, skipped.') model_ckpt[k] = None elif model_ckpt[k].shape != model_state_dict[k].shape: if self.is_master: print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.') model_ckpt[k] = model_state_dict[k] model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None} # Check missing keys (in model but not in ckpt) missing_keys = set(model_state_dict.keys()) - set(model_ckpt.keys()) unexpected_missing = missing_keys - ALLOWED_MISSING_KEYS if unexpected_missing and self.is_master: print(f'Error: Missing keys in checkpoint: {unexpected_missing}') raise RuntimeError(f'Missing keys in checkpoint: {unexpected_missing}') if missing_keys & ALLOWED_MISSING_KEYS and self.is_master: print(f'Info: Using model initialized values for: {missing_keys & ALLOWED_MISSING_KEYS}') # Fill in missing keys (using model initialized values) for k in missing_keys: model_ckpt[k] = model_state_dict[k] model_ckpts[name] = model_ckpt model.load_state_dict(model_ckpt) else: if self.is_master: print(f'Warning: {name} not found in finetune_ckpt, skipped.') model_ckpts[name] = model_state_dict self._state_dicts_to_master_params(self.master_params, model_ckpts) if self.is_master: for i, ema_rate in enumerate(self.ema_rate): self._state_dicts_to_master_params(self.ema_params[i], model_ckpts) del model_ckpts if self.world_size > 1: dist.barrier() if self.is_master: print('Done.') if self.world_size > 1: self.check_ddp() @abstractmethod def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs): """ Run a snapshot of the model. """ pass @torch.no_grad() def visualize_sample(self, sample): """ Convert a sample to an image. """ if hasattr(self.dataset, 'visualize_sample'): return self.dataset.visualize_sample(sample) else: return sample @torch.no_grad() def snapshot_dataset(self, num_samples=100, batch_size=4): """ Sample images from the dataset. """ dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=batch_size, num_workers=0, shuffle=True, collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, ) save_cfg = {} for i in range(0, num_samples, batch_size): data = next(iter(dataloader)) data = {k: v[:min(num_samples - i, batch_size)] for k, v in data.items()} data = recursive_to_device(data, self.device) try: vis = self.visualize_sample(data) except (RuntimeError, Exception) as e: print(f'\033[93m[WARN] snapshot_dataset visualize_sample failed (batch {i}), skipping: {e}\033[0m') torch.cuda.empty_cache() continue if isinstance(vis, dict): for k, v in vis.items(): if f'dataset_{k}' not in save_cfg: save_cfg[f'dataset_{k}'] = [] save_cfg[f'dataset_{k}'].append(v) else: if 'dataset' not in save_cfg: save_cfg['dataset'] = [] save_cfg['dataset'].append(vis) for name, image in save_cfg.items(): utils.save_image( torch.cat(image, dim=0), os.path.join(self.output_dir, 'samples', f'{name}.jpg'), nrow=int(np.sqrt(num_samples)), normalize=True, value_range=self.dataset.value_range, ) @torch.no_grad() def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False): """ Sample images from the model. NOTE: When num_samples >= 4, this function should be called by all processes. When num_samples < 4, only master runs snapshot (other ranks skip via barrier). """ # Free cached GPU memory before snapshot to avoid OOM / illegal address errors import gc gc.collect() torch.cuda.empty_cache() if self.is_master: print(f'\nSampling {num_samples} images...', end='') if suffix is None: suffix = f'step{self.step:07d}' # When num_samples < 4, only master runs snapshot to avoid multi-rank gather issues master_only = num_samples < 4 sample_metadata = None # Will hold list of "dataset_name/sha256" strings if master_only and self.world_size > 1: if not self.is_master: # Non-master ranks just wait at barrier dist.barrier() return # Master runs snapshot alone amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext with amp_context(): samples = self.run_snapshot(num_samples, batch_size=batch_size, verbose=verbose) # Extract metadata before preprocessing sample_metadata = samples.pop('_metadata', None) # Free GPU memory after sampling, before decode + render torch.cuda.empty_cache() # Preprocess images for key in list(samples.keys()): if samples[key]['type'] == 'sample': try: vis = self.visualize_sample(samples[key]['value']) except RuntimeError as e: print(f"[Snapshot] WARNING: visualize_sample failed for '{key}': {e}") # Reset CUDA error state and skip this sample try: torch.cuda.synchronize() except RuntimeError: pass torch.cuda.empty_cache() del samples[key] continue if isinstance(vis, dict): for k, v in vis.items(): samples[f'{key}_{k}'] = {'value': v, 'type': 'image'} del samples[key] else: samples[key] = {'value': vis, 'type': 'image'} # No gather needed, master already has all samples dist.barrier() else: # Distribute sampling across all ranks num_samples_per_process = int(np.ceil(num_samples / self.world_size)) amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext with amp_context(): samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose) # Extract metadata before preprocessing local_metadata = samples.pop('_metadata', None) # Free GPU memory after sampling, before decode + render torch.cuda.empty_cache() # Preprocess images for key in list(samples.keys()): if samples[key]['type'] == 'sample': try: vis = self.visualize_sample(samples[key]['value']) except RuntimeError as e: print(f"[Snapshot] WARNING: visualize_sample failed for '{key}': {e}") torch.cuda.synchronize() del samples[key] continue if isinstance(vis, dict): for k, v in vis.items(): samples[f'{key}_{k}'] = {'value': v, 'type': 'image'} del samples[key] else: samples[key] = {'value': vis, 'type': 'image'} # Gather results if self.world_size > 1: for key in samples.keys(): samples[key]['value'] = samples[key]['value'].contiguous() if self.is_master: all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)] else: all_images = [] dist.gather(samples[key]['value'], all_images, dst=0) if self.is_master: samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples] # Gather metadata across ranks if local_metadata is not None: all_metadata = [None] * self.world_size dist.all_gather_object(all_metadata, local_metadata) if self.is_master: sample_metadata = sum(all_metadata, [])[:num_samples] else: sample_metadata = None else: sample_metadata = local_metadata # Save images if self.is_master: os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True) wandb_images = {} # Collect images for wandb logging nrow = int(np.sqrt(num_samples)) vr = self.dataset.value_range # Build metadata caption string for wandb metadata_caption = '' if sample_metadata: metadata_caption = '\n' + ' | '.join(sample_metadata) # Also save metadata to file with open(os.path.join(self.output_dir, 'samples', suffix, 'metadata.txt'), 'w') as f: for i, m in enumerate(sample_metadata): f.write(f'{i}: {m}\n') # Helper: make a normalized grid tensor from a batch of images def _make_grid(tensor): return utils.make_grid(tensor, nrow=nrow, normalize=True, value_range=vr) # Helper: resize grid to target height (keep aspect ratio) def _resize_to_height(grid, target_h): import torch.nn.functional as F _, h, w = grid.shape if h == target_h: return grid target_w = int(round(w * target_h / h)) return F.interpolate(grid.unsqueeze(0), size=(target_h, target_w), mode='bilinear', align_corners=False).squeeze(0) # --- Save individual images (original behavior) --- for key in samples.keys(): if samples[key]['type'] == 'image': image_path = os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg') utils.save_image( samples[key]['value'], image_path, nrow=nrow, normalize=True, value_range=vr, ) # Collect for wandb if self.wandb_run is not None: grid = _make_grid(samples[key]['value']) grid_np = grid.permute(1, 2, 0).cpu().numpy() grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8) wandb_images[f'samples/{key}'] = wandb.Image(grid_np, caption=f'{key} at step {self.step}{metadata_caption}') elif samples[key]['type'] == 'number': val_min = samples[key]['value'].min() val_max = samples[key]['value'].max() images = (samples[key]['value'] - val_min) / (val_max - val_min) images = utils.make_grid( images, nrow=nrow, normalize=False, ) save_image_with_notes( images, os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'), notes=f'{key} min: {val_min}, max: {val_max}', ) # --- Save combined images --- sample_keys = set(samples.keys()) # Combined 1: image + sample_gt_view + sample_gt_gt_view (shape) # image + sample_gt_view_{attr} + sample_gt_gt_view_{attr} (tex, per attribute) # Detect gt_view attribute suffixes from sample keys gt_view_attrs = set() for k in sample_keys: if k.startswith('sample_gt_view_'): attr = k[len('sample_gt_view_'):] gt_view_attrs.add(attr) if gt_view_attrs: # Tex mode: generate combined view for each PBR attribute for attr in sorted(gt_view_attrs): combo1_keys = ['image', f'sample_gt_view_{attr}', f'sample_gt_gt_view_{attr}'] combo1_present = [k for k in combo1_keys if k in sample_keys and samples[k]['type'] == 'image'] if len(combo1_present) >= 2: grids = [_make_grid(samples[k]['value']) for k in combo1_present] target_h = max(g.shape[1] for g in grids) grids = [_resize_to_height(g, target_h) for g in grids] combined = torch.cat(grids, dim=2) combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_views_{attr}_{suffix}.jpg') utils.save_image(combined, combined_path, normalize=False) if self.wandb_run is not None: grid_np = combined.permute(1, 2, 0).cpu().numpy() grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8) label = ' | '.join(combo1_present) wandb_images[f'samples/combined_views_{attr}'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}') else: # Shape mode: single gt_view combo1_keys = ['image', 'sample_gt_view', 'sample_gt_gt_view'] combo1_present = [k for k in combo1_keys if k in sample_keys and samples[k]['type'] == 'image'] if len(combo1_present) >= 2: grids = [_make_grid(samples[k]['value']) for k in combo1_present] target_h = max(g.shape[1] for g in grids) grids = [_resize_to_height(g, target_h) for g in grids] combined = torch.cat(grids, dim=2) combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_views_{suffix}.jpg') utils.save_image(combined, combined_path, normalize=False) if self.wandb_run is not None: grid_np = combined.permute(1, 2, 0).cpu().numpy() grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8) label = ' | '.join(combo1_present) wandb_images[f'samples/combined_views'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}') # Combined 2: sample_multiview + sample_gt_multiview combo2_keys = ['sample_multiview', 'sample_gt_multiview'] combo2_present = [k for k in combo2_keys if k in sample_keys and samples[k]['type'] == 'image'] if len(combo2_present) >= 2: grids = [_make_grid(samples[k]['value']) for k in combo2_present] target_h = max(g.shape[1] for g in grids) grids = [_resize_to_height(g, target_h) for g in grids] combined = torch.cat(grids, dim=2) # concatenate along width combined_path = os.path.join(self.output_dir, 'samples', suffix, f'combined_multiview_{suffix}.jpg') utils.save_image(combined, combined_path, normalize=False) if self.wandb_run is not None: grid_np = combined.permute(1, 2, 0).cpu().numpy() grid_np = (grid_np * 255).clip(0, 255).astype(np.uint8) label = ' | '.join(combo2_present) wandb_images[f'samples/combined_multiview'] = wandb.Image(grid_np, caption=f'{label} at step {self.step}{metadata_caption}') # Log images to wandb if self.wandb_run is not None and wandb_images: self.wandb_run.log(wandb_images, step=self.step) if self.is_master: print(' Done.') def update_ema(self): """ Update exponential moving average. Should only be called by the rank 0 process. """ assert self.is_master, 'update_ema() should be called only by the rank 0 process.' for i, ema_rate in enumerate(self.ema_rate): for master_param, ema_param in zip(self.master_params, self.ema_params[i]): ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate) def check_ddp(self): """ Check if DDP is working properly. Should be called by all process. """ if self.is_master: print('\nPerforming DDP check...') if self.is_master: print('Checking if parameters are consistent across processes...') dist.barrier() try: for p in self.master_params: # split to avoid OOM for i in range(0, p.numel(), 10000000): sub_size = min(10000000, p.numel() - i) sub_p = p.detach().view(-1)[i:i+sub_size] # gather from all processes sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)] dist.all_gather(sub_p_gather, sub_p) # check if equal assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes' except AssertionError as e: if self.is_master: print(f'\n\033[91mError: {e}\033[0m') print('DDP check failed.') raise e dist.barrier() if self.is_master: print('Done.') def _verify_gradient_sync(self): """ Verify that DDP gradient synchronization is working correctly. DDP's backward automatically performs all_reduce on gradients; after sync all ranks should have identical gradients. Verification method: 1. Compute total gradient norm across all parameters 2. Gather gradient norms from all ranks 3. If DDP sync is working, all ranks should have identical gradient norms 4. If not synced, gradient norms will differ (since each rank processes different data) """ # Compute total gradient norm on this rank total_grad_norm_sq = 0.0 grad_count = 0 for p in self.model_params: if p.grad is not None: total_grad_norm_sq += p.grad.detach().float().norm().item() ** 2 grad_count += 1 if grad_count == 0: return local_grad_norm = total_grad_norm_sq ** 0.5 # Ensure all processes reach the same point dist.barrier() # Gather gradient norms from all ranks grad_norm_tensor = torch.tensor([local_grad_norm], dtype=torch.float64, device=self.device) all_grad_norms = [torch.zeros(1, dtype=torch.float64, device=self.device) for _ in range(self.world_size)] dist.all_gather(all_grad_norms, grad_norm_tensor) all_grad_norms = [g.item() for g in all_grad_norms] # Verify all ranks have the same gradient norm (relative error tolerance: 0.1%) ref_norm = all_grad_norms[0] if ref_norm > 0: is_synced = all(abs(g - ref_norm) / ref_norm < 1e-3 for g in all_grad_norms) else: is_synced = all(abs(g) < 1e-10 for g in all_grad_norms) if self.is_master: print(f'\n{"="*60}') print(f'[Step {self.step}] DDP Gradient Sync Verification:') for i, g in enumerate(all_grad_norms): print(f' Rank {i} grad_norm: {g:.8f}') if is_synced: print(f' \033[92m✓ PASS: All gradients are synchronized!\033[0m') else: max_diff = max(abs(g - ref_norm) for g in all_grad_norms) print(f' \033[91m✗ FAIL: Gradients are NOT synchronized! Max diff: {max_diff:.8f}\033[0m') print(f'{"="*60}\n') @abstractmethod def training_losses(**mb_data): """ Compute training losses. """ pass def load_data(self): """ Load data. """ if self.prefetch_data: if self._data_prefetched is None: self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) data = self._data_prefetched self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) else: data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu if isinstance(data, dict): if self.batch_split == 1: data_list = [data] else: batch_size = list(data.values())[0].shape[0] data_list = [ {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()} for i in range(self.batch_split) ] elif isinstance(data, list): data_list = data else: raise ValueError('Data must be a dict or a list of dicts.') return data_list def run_step(self, data_list): """ Run a training step. """ step_log = {'loss': {}, 'status': {}} amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext # Train losses = [] statuses = [] elastic_controller_logs = [] zero_grad(self.model_params) for i, mb_data in enumerate(data_list): ## sync at the end of each batch split sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext] with nested_contexts(*sync_contexts), elastic_controller_context(): with amp_context(): loss, status = self.training_losses(**mb_data) l = loss['loss'] / len(data_list) # DEBUG: Print loss for each rank if self.debug: print(f'[Rank {self.rank}/{self.world_size}] Step {self.step} batch {i}: loss={loss["loss"].item():.6f}') ## backward if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: self.scaler.scale(l).backward() elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: scaled_l = l * (2 ** self.log_scale) scaled_l.backward() else: l.backward() ## log losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) if self.elastic_controller_config is not None: elastic_controller_logs.append(self.elastic_controller.log()) # ============================================================ # DEBUG: Verify DDP gradient synchronization # Check if gradients are consistent across ranks after backward # DDP automatically all_reduces gradients during the last batch_split's backward # After sync, all ranks should have identical gradients # ============================================================ if self.debug and self.world_size > 1: self._verify_gradient_sync() ## gradient clip if self.grad_clip is not None: if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: self.scaler.unscale_(self.optimizer) elif self.mix_precision_mode == 'inflat_all': model_grads_to_master_grads(self.model_params, self.master_params) if self.mix_precision_dtype == torch.float16: self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) if isinstance(self.grad_clip, float): grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip) else: grad_norm = self.grad_clip(self.master_params) if torch.isfinite(grad_norm): statuses[-1]['grad_norm'] = grad_norm.item() ## step if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: prev_scale = self.scaler.get_scale() self.scaler.step(self.optimizer) self.scaler.update() elif self.mix_precision_mode == 'inflat_all': if self.mix_precision_dtype == torch.float16: prev_scale = 2 ** self.log_scale if not any(not p.grad.isfinite().all() for p in self.model_params): if self.grad_clip is None: model_grads_to_master_grads(self.model_params, self.master_params) self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) self.optimizer.step() master_params_to_model_params(self.model_params, self.master_params) self.log_scale += self.fp16_scale_growth else: self.log_scale -= 1 else: prev_scale = 1.0 if self.grad_clip is None: model_grads_to_master_grads(self.model_params, self.master_params) if not any(not p.grad.isfinite().all() for p in self.master_params): self.optimizer.step() master_params_to_model_params(self.model_params, self.master_params) else: print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') else: prev_scale = 1.0 if not any(not p.grad.isfinite().all() for p in self.model_params): self.optimizer.step() else: print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') ## adjust learning rate if self.lr_scheduler_config is not None: statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0] self.lr_scheduler.step() # Logs step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x)) step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)}) if self.elastic_controller_config is not None: step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x)) if self.grad_clip is not None: step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log() # Check grad and norm of each param if self.log_param_stats: param_norms = {} param_grads = {} for model_name, model in self.models.items(): for name, param in model.named_parameters(): if param.requires_grad: param_norms[f'{model_name}.{name}'] = param.norm().item() if param.grad is not None and torch.isfinite(param.grad).all(): param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / prev_scale step_log['param_norms'] = param_norms step_log['param_grads'] = param_grads # Update exponential moving average if self.is_master: self.update_ema() return step_log def save_logs(self): log_str = '\n'.join([ f'{step}: {json.dumps(dict_foreach(log, lambda x: float(x)))}' for step, log in self.log ]) # Accumulate logs in memory and overwrite file each time (S3 FUSE does not support append) if not hasattr(self, '_log_buffer'): self._log_buffer = [] self._log_buffer.append(log_str) try: with open(self._log_file, 'w') as log_file: log_file.write('\n'.join(self._log_buffer) + '\n') except Exception as e: print(f'\033[93m[WARN] Failed to write log file: {e}\033[0m') # show with mlflow log_show = [l for _, l in self.log if not dict_any(l, lambda x: np.isnan(x))] log_show = dict_reduce(log_show, lambda x: np.mean(x)) log_show = dict_flatten(log_show, sep='/') if self.writer is not None: for key, value in log_show.items(): self.writer.add_scalar(key, value, self.step) # Log to wandb if self.wandb_run is not None: wandb_log = {key: value for key, value in log_show.items()} wandb_log['step'] = self.step self.wandb_run.log(wandb_log, step=self.step) self.log = [] def check_abort(self): """ Check if training should be aborted due to certain conditions. """ # 1. If log_scale in inflat_all mode is less than 0 if self.mix_precision_dtype == torch.float16 and \ self.mix_precision_mode == 'inflat_all' and \ self.log_scale < 0: if self.is_master: print ('\n\n\033[91m') print (f'ABORT: log_scale in inflat_all mode is less than 0 at step {self.step}.') print ('This indicates that the model is diverging. You should look into the model and the data.') print ('\033[0m') self.save(non_blocking=False) self.save_logs() if self.world_size > 1: dist.barrier() raise ValueError('ABORT: log_scale in inflat_all mode is less than 0.') def run(self): """ Run training. """ if self.is_master: print('\nStarting training...') if self.i_sample != -1: try: self.snapshot_dataset(num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) except (RuntimeError, Exception) as e: print(f'\033[93m[WARN] snapshot_dataset failed, skipping: {e}\033[0m') torch.cuda.empty_cache() else: print('[INFO] i_sample=-1, all snapshots disabled.') if self.i_sample != -1: if self.step == 0: try: self.snapshot(suffix='init', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) except (RuntimeError, Exception) as e: print(f'\033[93m[WARN] snapshot (init) failed, skipping: {e}\033[0m') torch.cuda.empty_cache() else: # resume try: self.snapshot(suffix=f'resume_step{self.step:07d}', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) except (RuntimeError, Exception) as e: print(f'\033[93m[WARN] snapshot (resume) failed, skipping: {e}\033[0m') torch.cuda.empty_cache() time_last_print = 0.0 time_elapsed = 0.0 while self.step < self.max_steps: time_start = time.time() data_list = self.load_data() step_log = self.run_step(data_list) time_end = time.time() time_elapsed += time_end - time_start self.step += 1 # Print progress if self.is_master and self.step % self.i_print == 0: speed = self.i_print / (time_elapsed - time_last_print) * 3600 columns = [ f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)', f'Elapsed: {time_elapsed / 3600:.2f} h', f'Speed: {speed:.2f} steps/h', f'ETA: {(self.max_steps - self.step) / speed:.2f} h', ] print(' | '.join([c.ljust(25) for c in columns]), flush=True) time_last_print = time_elapsed # Check ddp if self.parallel_mode == 'ddp' and self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0: self.check_ddp() # Sample images if self.i_sample != -1 and self.step % self.i_sample == 0: try: self.snapshot(num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) except (RuntimeError, Exception) as e: if self.is_master: print(f'\033[93m[WARN] snapshot at step {self.step} failed, skipping: {e}\033[0m') try: torch.cuda.empty_cache() except Exception: pass if self.is_master: self.log.append((self.step, {})) # Log time self.log[-1][1]['time'] = { 'step': time_end - time_start, 'elapsed': time_elapsed, } # Log losses if step_log is not None: self.log[-1][1].update(step_log) # Log scale if self.mix_precision_dtype == torch.float16: if self.mix_precision_mode == 'amp': self.log[-1][1]['scale'] = self.scaler.get_scale() elif self.mix_precision_mode == 'inflat_all': self.log[-1][1]['log_scale'] = self.log_scale # Save log if self.step % self.i_log == 0: self.save_logs() # Save checkpoint if self.step % self.i_save == 0: self.save() # Check abort self.check_abort() if self.i_sample != -1: try: self.snapshot(suffix='final', num_samples=self.snapshot_num_samples, batch_size=self.snapshot_batch_size) except (RuntimeError, Exception) as e: if self.is_master: print(f'\033[93m[WARN] snapshot (final) failed, skipping: {e}\033[0m') torch.cuda.empty_cache() if self.world_size > 1: dist.barrier() if self.is_master: self.writer.close() print('Training finished.') def profile(self, wait=2, warmup=3, active=5): """ Profile the training loop. """ with torch.profiler.profile( schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')), profile_memory=True, with_stack=True, ) as prof: for _ in range(wait + warmup + active): self.run_step() prof.step()