# VLM2Vec/src/trainer_add_CRD.py
import collections
import contextlib
import functools
import shutil
import sys
import time
from datetime import timedelta

from packaging import version
from accelerate import skip_first_batches, DistributedType, InitProcessGroupKwargs
from transformers import PretrainedConfig
from transformers.trainer import Trainer, TRAINING_ARGS_NAME, TRAINER_STATE_NAME
import torch.distributed as dist
from typing import Optional
import os
import torch
import math
import torch.nn as nn

from src.data.collator.train_collator import split_vlm_inputs, get_dense_rep, split_and_process_vlm_inputs
from src.model.model_add_CRD import MMEBModel
from src.loss_add_CRD import SimpleContrastiveLoss, DistributedContrastiveLoss, MultiLayerCRDLoss, DistributedMultiLayerCRDLoss
from src.grad_cache.grad_cache import GradCache
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.trainer_callback import (
    ExportableState,
    TrainerState,
)
from transformers.trainer_utils import (
    TrainOutput,
    has_length,
    speed_metrics,
    seed_worker,
)
from transformers.trainer_pt_utils import (
    get_model_param_count,
)
from transformers.trainer import FSDP_MODEL_NAME
from transformers.utils import (
    XLA_FSDPV2_MIN_VERSION,
    is_accelerate_available,
    is_apex_available,
    is_torch_xla_available,
    logging,
    is_sagemaker_mp_enabled,
    CONFIG_NAME,
    WEIGHTS_NAME,
    SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
)
from src.utils import batch_to_device
from src.utils import print_master, print_rank

if is_apex_available():
    from apex import amp

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    from torch_xla import __version__ as XLA_VERSION

    IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
else:
    IS_XLA_FSDPV2_POST_2_2 = False

logger = logging.get_logger(__name__)


# =============== Helper: locate last LM block for grad diagnostics ===============
def _locate_lm_layers_modulelist(encoder):
    """Return the ModuleList of LM blocks of a Qwen/LLaMA-like encoder, or ``None``.

    Tries a fixed list of common attribute paths and returns the first one that
    resolves to a non-empty ``torch.nn.ModuleList``.
    """
    candidates = [
        ("model", "language_model", "layers"),
        ("model", "model", "layers"),
        ("model", "layers"),
        ("language_model", "layers"),
        ("transformer", "layers"),
    ]
    for path in candidates:
        obj = encoder
        ok = True
        for p in path:
            if hasattr(obj, p):
                obj = getattr(obj, p)
            else:
                ok = False
                break
        if ok and isinstance(obj, torch.nn.ModuleList) and len(obj) > 0:
            return obj
    return None


def _grad_norm(params):
    """Return the global L2 norm of the ``.grad`` tensors of ``params``.

    Parameters without a gradient are skipped; returns 0.0 if none has one.
    """
    tot = 0.0
    for p in params:
        if p.grad is not None:
            g = p.grad.detach().float()
            tot += (g * g).sum().item()
    return math.sqrt(tot) if tot > 0 else 0.0
# ================================================================================


class MMEBTrainer(Trainer):
    """HF ``Trainer`` subclass for (query, target) embedding training.

    Each batch is a ``(qry_inputs, tgt_inputs)`` pair and the model is called as
    ``model(qry=..., tgt=...)``. Overrides the train dataloader (bypassing
    ``accelerator.prepare``), checkpoint loading/saving, and carries a local copy
    of HF's ``_inner_training_loop``.
    """

    def __init__(self, *args, **kwargs):
        super(MMEBTrainer, self).__init__(*args, **kwargs)
        ws = dist.get_world_size() if dist.is_initialized() else 1
        self.is_ddp = dist.is_initialized() and ws > 1
        # GradCacheLateProcessTrainer.training_step divides its returned loss by this factor.
        self._dist_loss_scale_factor = ws if self.is_ddp else 1
        self.processor = self.processing_class

    def get_batch_samples(self, epoch_iterator, num_batches):
        """Pull up to ``num_batches`` batches from ``epoch_iterator``.

        Returns ``(batch_samples, num_items_in_batch)``; the item count is only
        computed when the batches carry a ``"labels"`` entry, otherwise ``None``.
        """
        batch_samples = []
        num_items_in_batch = None
        for _ in range(num_batches):
            try:
                batch_samples += [next(epoch_iterator)]
            except StopIteration:
                break
        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
            # For now we don't support object detection
            try:
                num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
            except (TypeError, AttributeError):
                pass

        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()

        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.item()

        return batch_samples, num_items_in_batch

    def compute_loss(self, model, inputs, *args, **kwargs):
        """Run the model on a ``(qry_inputs, tgt_inputs)`` pair and return its output."""
        qry_inputs, tgt_inputs = inputs
        return model(qry=qry_inputs, tgt=tgt_inputs)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the wrapped encoder (keys stripped of the ``encoder.`` prefix) and training args.

        Falls back to ``self.args.output_dir`` when ``output_dir`` is ``None``.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        if state_dict is None:
            state_dict = self.model.state_dict()
        prefix = 'encoder.'
        assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        self.model.encoder.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        # override original trainer's method
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None
        return RandomSampler(self.train_dataset)

    def get_train_dataloader(self) -> DataLoader:
        """
        override original trainer's method to disable self.accelerator.prepare since it will wrap DataLoaderDispatcher and lead to
        (1) `RuntimeError: You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`.`
        (2) all outputs of dataloader must be tensors
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_dataset = self.train_dataset
        data_collator = self.data_collator
        train_dataset = self._remove_unused_columns(train_dataset, description="training")
        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
        else:
            dataloader_params["sampler"] = None
            dataloader_params["shuffle"] = False
            dataloader_params["drop_last"] = True
            # Setting both prefetch_factor and persistent_workers causes a hang at epoch 2.
            dataloader_params["prefetch_factor"] = None
        return DataLoader(train_dataset, **dataloader_params)

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """Reload the model via ``MMEBModel.load``.

        Relies on ``self.model_args``, which is set by ``GradCacheLateProcessTrainer.__init__``.
        """
        self.model_args.checkpoint_path = resume_from_checkpoint
        logger.info(f"Loading checkpoint from {resume_from_checkpoint}")
        self.model = MMEBModel.load(self.model_args)
        self.model_wrapped = self.model

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        """Local copy of HF ``Trainer._inner_training_loop``.

        Differences from upstream visible here: the optional teacher-gradient
        diagnostic (``_maybe_log_teacher_grad``, defined by subclasses) runs right
        before ``optimizer.step()``, and the end-of-training log message is shortened.
        """
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        if self.args.auto_find_batch_size:
            if self.state.train_batch_size != self._train_batch_size:
                from accelerate.utils import release_memory

                (self.model_wrapped,) = release_memory(self.model_wrapped)
                self.model_wrapped = self.model

                # Check for DeepSpeed *after* the initial pass and modify the config
                if self.is_deepspeed_enabled:
                    # Temporarily unset `self.args.train_batch_size`
                    original_bs = self.args.per_device_train_batch_size
                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
                    self.propagate_args_to_deepspeed(True)
                    self.args.per_device_train_batch_size = original_bs
            self.state.train_batch_size = self._train_batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        num_train_tokens = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
                if args.include_tokens_per_second:
                    num_train_tokens = (
                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
                    )
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
                if args.include_tokens_per_second:
                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
            if args.include_tokens_per_second:
                num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

        # We need to reset the scheduler, as its parameters may be different on subsequent calls
        if self._created_lr_scheduler:
            self.lr_scheduler = None
            self._created_lr_scheduler = False

        # When creation is delayed (FSDP / SageMaker MP), the optimizer must be built *after*
        # the model is wrapped below; building it here would bind it to pre-wrap parameters.
        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState(
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]
        )
        self.state.is_hyper_param_search = trial is not None
        self.state.train_batch_size = self._train_batch_size

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

        model = self._wrap_model(self.model_wrapped)

        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = model is self.model

        if delay_optimizer_creation:
            if use_accelerator_prepare:
                self._fsdp_qlora_plugin_updates()
                self.model = self.accelerator.prepare(self.model)
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # prepare using `accelerator` prepare
        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )
        elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            # In this case we are in DDP + LOMO, which should be supported
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {num_examples:,}")
        logger.info(f" Num Epochs = {num_train_epochs:,}")
        logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        if self.args.per_device_train_batch_size != self._train_batch_size:
            logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f" Total optimization steps = {max_steps:,}")
        logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None
        # @ruimeng use steps_trained_in_current_epoch to skip batches for finding buggy data
        # steps_trained_in_current_epoch = 42

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            self.compare_trainer_and_checkpoint_args(self.args, self.state)
            self._load_callback_state()
            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f" Continuing training from epoch {epochs_trained}")
            logger.info(f" Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f" Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader

        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()
        grad_norm: Optional[float] = None
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        if args.eval_on_start:
            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_dataloader = train_dataloader
            if hasattr(epoch_dataloader.dataset, "set_epoch"):
                epoch_dataloader.dataset.set_epoch(epoch)

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_dataloader)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            rng_to_sync = False
            steps_skipped = 0
            if steps_trained_in_current_epoch > 0:
                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

            step = -1
            epoch_iterator = iter(epoch_dataloader)
            # The epoch is consumed in chunks of `gradient_accumulation_steps` batches; the last
            # chunk holds the leftover *batches*, so the remainder must be taken over the batch
            # count (steps_in_epoch), not over the number of examples.
            remainder = steps_in_epoch % args.gradient_accumulation_steps
            num_items_in_batch = None
            if remainder == 0:
                remainder = args.gradient_accumulation_steps
            update_step = -1
            total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
            for _ in range(total_updates):
                update_step += 1
                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                for i, inputs in enumerate(batch_samples):
                    step += 1
                    total_batched_samples += 1
                    is_last_step_and_steps_less_than_grad_acc = (
                        steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
                    )
                    do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
                        total_batched_samples % args.gradient_accumulation_steps == 0
                    )
                    # Since we perform prefetching, we need to manually set sync_gradients
                    if not do_sync_step:
                        self.accelerator.gradient_state._set_sync_gradients(False)
                    else:
                        self.accelerator.gradient_state._set_sync_gradients(True)

                    if self.args.include_num_input_tokens_seen:
                        main_input_name = getattr(self.model, "main_input_name", "input_ids")
                        if main_input_name not in inputs:
                            logger.warning(
                                "Tried to track the number of tokens seen, however the current model is not configured properly."
                            )
                        else:
                            input_tokens = inputs[main_input_name].numel()
                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                            self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
                    if rng_to_sync:
                        self._load_rng_state(resume_from_checkpoint)
                        rng_to_sync = False

                    # Skip past any already trained steps if resuming training
                    if steps_trained_in_current_epoch > 0:
                        steps_trained_in_current_epoch -= 1
                        if steps_trained_progress_bar is not None:
                            steps_trained_progress_bar.update(1)
                        if steps_trained_in_current_epoch == 0:
                            self._load_rng_state(resume_from_checkpoint)
                        continue
                    elif steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.close()
                        steps_trained_progress_bar = None

                    if step % args.gradient_accumulation_steps == 0:
                        self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                    # Only the last micro-batch of an accumulation chunk runs without no_sync,
                    # so DDP all-reduces gradients once per optimizer step.
                    context = (
                        functools.partial(self.accelerator.no_sync, model=model)
                        if i != len(batch_samples) - 1
                        else contextlib.nullcontext
                    )
                    with context():
                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

                    if (
                        args.logging_nan_inf_filter
                        and not is_torch_xla_available()
                        and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                    ):
                        # if loss is nan or inf simply add the average of previous logged losses
                        tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                    else:
                        if tr_loss.device != tr_loss_step.device:
                            raise ValueError(
                                f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
                            )
                        tr_loss = tr_loss + tr_loss_step

                    self.current_flos += float(self.floating_point_ops(inputs))

                    if do_sync_step:
                        # Since we perform prefetching, we need to manually set sync_gradients to True
                        self.accelerator.gradient_state._set_sync_gradients(True)

                        # Gradient clipping. Named `clip_norm` so it does not shadow the
                        # module-level `_grad_norm` helper.
                        if args.max_grad_norm is not None and args.max_grad_norm > 0:
                            if self.use_apex:
                                clip_norm = torch.nn.utils.clip_grad_norm_(
                                    amp.master_params(self.optimizer), args.max_grad_norm
                                )
                            else:
                                clip_norm = self.accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                            if is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                                grad_norm = model.get_global_grad_norm()
                                # In some cases the grad norm may not return a float
                                if hasattr(grad_norm, "item"):
                                    grad_norm = grad_norm.item()
                            else:
                                grad_norm = clip_norm

                        self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

                        # Optional teacher-gradient diagnostic. Only subclasses define it; looking it up
                        # avoids a spurious AttributeError warning on every step for plain MMEBTrainer.
                        maybe_log_teacher_grad = getattr(self, "_maybe_log_teacher_grad", None)
                        if maybe_log_teacher_grad is not None:
                            try:
                                maybe_log_teacher_grad(model)
                            except Exception as e:
                                logger.warning(f"teacher grad log failed (ignored): {e}")

                        self.optimizer.step()

                        self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                        if optimizer_was_run:
                            # Delay optimizer scheduling until metrics are generated
                            if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                                self.lr_scheduler.step()

                        model.zero_grad()
                        self.state.global_step += 1
                        self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                        self._maybe_log_save_evaluate(
                            tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, time.time()
                        )
                    else:
                        self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                    # PyTorch/XLA relies on the data loader to insert the mark_step for each step.
                    # Since we are breaking the loop early, we need to manually insert it here.
                    if self.control.should_epoch_stop or self.control.should_training_stop:
                        if is_torch_xla_available():
                            xm.mark_step()
                        break
                # We also need to break out of the nested loop
                if self.control.should_epoch_stop or self.control.should_training_stop:
                    if is_torch_xla_available():
                        xm.mark_step()
                    break
            if step < 0:
                logger.warning("There seems not to be a single sample in your epoch_iterator, stopping training.")
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, time.time())

            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed.\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_xla_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()

            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
        train_loss = self._total_loss_scalar / effective_global_step

        metrics = speed_metrics(
            "train",
            start_time,
            num_samples=num_train_samples,
            num_steps=self.state.max_steps,
            num_tokens=num_train_tokens,
        )
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint
        # and the process is allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint, ignore_errors=True)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        # Wait for the checkpoint to be uploaded.
        self._finish_current_push()

        # After training we make sure to retrieve back the original forward pass method
        # for the embedding layer by removing the forward post hook.
        if self.neftune_noise_alpha is not None:
            self._deactivate_neftune(self.model)

        return TrainOutput(self.state.global_step, train_loss, metrics)


class GradCacheLateProcessTrainer(MMEBTrainer):
    """
    Adapted from gradcache repo.

    Multi-GPU: trains through ``GradCache`` with a (Distributed)MultiLayerCRDLoss.
    Single GPU: uses a hand-written two-phase chunked step (see
    ``_single_gpu_chunked_step``) that combines a per-layer retrieval loss with a
    CRD term distilling the last layer ("teacher") into selected earlier layers.
    """

    def __init__(self, *args, **kwargs):
        # Strip the extra kwargs before they reach HF Trainer.__init__.
        self.max_length = kwargs.pop("max_length", 512)
        self.model_args = kwargs.pop("model_args", None)
        # MMEBTrainer.__init__ sets is_ddp and _dist_loss_scale_factor.
        super(GradCacheLateProcessTrainer, self).__init__(*args, **kwargs)

        loss_fn_cls = DistributedMultiLayerCRDLoss if self.is_ddp else MultiLayerCRDLoss

        # crd_layers may be a comma-separated string ("0,2,-3") or a list/tuple of ints,
        # matching what _crd_indices accepts; anything else means "use the loss default".
        crd_layers = getattr(self.args, "crd_layers", None)
        if isinstance(crd_layers, str) and len(crd_layers.strip()) > 0:
            crd_layers = [int(x.strip()) for x in crd_layers.split(",") if x.strip() != ""]
        elif isinstance(crd_layers, (list, tuple)):
            crd_layers = [int(x) for x in crd_layers]
        else:
            crd_layers = None

        # Allow detach_teacher to be controlled from the training args.
        detach_teacher = getattr(self.args, "crd_detach_teacher", True)

        # Optional switches read from the environment (could also be moved to TrainingArguments).
        crd_side = os.getenv("CRD_SIDE", "both")  # "both" | "qry" | "tgt"
        queue_size = int(os.getenv("CRD_QUEUE_SIZE", "0") or 0)

        self.loss_fn = loss_fn_cls(
            temperature=self.model.temperature,
            weights=getattr(self.model, "supervise_weights", None),
            crd_weight=getattr(self.args, "crd_weight", 0.2),
            crd_temperature=getattr(self.args, "crd_temperature", 0.07),
            crd_layers=crd_layers,
            detach_teacher=detach_teacher,
            crd_side=crd_side,
            queue_size=queue_size,
        )

        self.gc = GradCache(
            models=[self.model, self.model],
            chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size],
            loss_fn=self.loss_fn,
            split_input_fn=split_and_process_vlm_inputs,
            get_rep_fn=get_dense_rep,  # presumably returns [B, K, D] multi-layer reps — confirm in collator
            fp16=self.args.fp16,
            scaler=self.scaler if self.args.fp16 else None,
        )

        # Cache of last-block parameters used by the gradient diagnostics.
        self._last_block_params_cache = None
        # global_step at which _debug_teacher_grad last ran (it runs at most once per step).
        self._last_crd_debug_step = None

    def _infer_device(self, model):
        """Return ``model.device`` if set, else the device of its first parameter, else ``args.device``."""
        if hasattr(model, "device") and model.device is not None:
            return model.device
        try:
            return next(model.parameters()).device
        except StopIteration:
            return self.args.device

    def _batch_size(self, batch: dict) -> int:
        """Return the leading size of the first tensor (dim > 0) or list value in ``batch``."""
        for k, v in batch.items():
            if torch.is_tensor(v) and v.dim() > 0:
                return v.size(0)
            if isinstance(v, list):
                return len(v)
        raise ValueError("Cannot infer batch size from batch keys.")

    def _slice_batch(self, batch: dict, size: int, offset: int = 0):
        """Slice rows ``[offset, offset + size)`` out of every tensor/list value long enough.

        Values that are too short, 0-dim, or of another type are passed through unchanged.
        """
        out = {}
        end = offset + size
        for k, v in batch.items():
            if torch.is_tensor(v) and v.dim() > 0 and v.size(0) >= end:
                out[k] = v[offset:end]
            elif isinstance(v, list) and len(v) >= end:
                out[k] = v[offset:end]
            else:
                out[k] = v
        return out

    def _maybe_log_teacher_grad(self, model):
        """
        Log the gradient norm of the last LM block, to check whether the teacher
        (last layer) is being updated. Set LOG_TEACHER_GRAD to anything other than
        "1"/"true"/"True" to disable.
        """
        if os.getenv("LOG_TEACHER_GRAD", "1") not in ("1", "true", "True"):
            return
        try:
            params_last = self._get_last_block_params()
            tgn = _grad_norm(params_last)  # already a Python float
            # 1) console
            print_master(f"[teacher_grad] step={self.state.global_step} norm={tgn:.6f}")
            # 2) HF metrics (recorded through Trainer.log alongside loss / grad_norm)
            self.log({"teacher_grad_norm": tgn})
        except Exception as e:
            logger.warning(f"teacher grad log failed: {e}")

    def _forward_reps_in_chunks(self, model, batch: dict, side: str, chunk_size: int, device):
        """Encode ``batch`` in chunks and return the detached multi-layer reps, concatenated on dim 0.

        ``side`` selects the query (``"qry"``) or target (any other value) branch of the model.
        A falsy ``chunk_size`` means a single chunk of the whole batch.
        """
        B = self._batch_size(batch)
        outs = []
        use_bf16 = getattr(self.args, "bf16", False)
        dev_type = "cuda" if "cuda" in str(device) else "cpu"
        # Deliberately no eval() switch and no no_grad(); each chunk's output is detached right away.
        for s in range(0, B, max(1, chunk_size or B)):
            bs = min(chunk_size or B, B - s)
            sub = self._slice_batch(batch, size=bs, offset=s)
            with torch.autocast(device_type=dev_type, dtype=torch.bfloat16, enabled=use_bf16):
                if side == "qry":
                    o = model(qry=sub)
                    reps = o["qry_reps"]  # [bs,K,D]
                else:
                    o = model(tgt=sub)
                    reps = o["tgt_reps"]  # [bs,K,D]
            outs.append(reps.detach())  # key: detach immediately so the chunk graph can be freed
            del o, reps
            # If memory is extremely tight, the cache could be cleared occasionally (slow if frequent):
            # if s % (4 * (chunk_size or B)) == 0: torch.cuda.empty_cache()
        return torch.cat(outs, dim=0)  # [B,K,D]

    def _norm_weights(self, weights_list, K, device):
        """Return non-negative weights normalized to sum to 1 (uniform 1/K when ``weights_list`` is None)."""
        if weights_list is None:
            return torch.ones(K, device=device) / K
        w = torch.tensor(list(weights_list), dtype=torch.float32, device=device)
        w = torch.clamp(w, min=0)
        s = float(w.sum())
        return w / (s if s > 0 else 1.0)

    def _crd_indices(self, K: int):
        """Return the sorted, de-duplicated layer indices that receive the CRD loss.

        Uses ``args.crd_layers`` (0-based indices into the supervised layers; negatives count
        from the end). ``None``/empty string means every layer except the last. The last
        layer (K-1, the teacher) and out-of-range indices are always dropped.
        """
        cl = getattr(self.args, "crd_layers", None)
        if cl is None or (isinstance(cl, str) and cl.strip() == ""):
            return list(range(0, max(0, K - 1)))
        if isinstance(cl, str):
            idxs = [int(x.strip()) for x in cl.split(",") if x.strip() != ""]
        else:
            idxs = list(cl)
        # Filter out the last layer K-1.
        out = []
        for i in idxs:
            if i < 0:
                i = K + i
            if 0 <= i < K - 1:
                out.append(i)
        return sorted(set(out))

    def _single_gpu_chunked_step(self, model, queries: dict, targets: dict, device):
        """
        Hand-written chunked step with two gradient phases:
          A) chunk over queries and backprop into the query branch (targets / teacher are constants);
          B) chunk over targets and backprop into the target branch (queries / teacher are constants).
        Gradients are accumulated in place; returns a gradient-free scalar loss for logging.
        Assumes queries[i] is paired with targets[i] (labels are the diagonal).
        """
        # Pre-compute detached full-batch multi-layer reps used as constant banks.
        q_all = self._forward_reps_in_chunks(model, queries, side="qry", chunk_size=self.args.gc_q_chunk_size, device=device)  # [B,K,D]
        p_all = self._forward_reps_in_chunks(model, targets, side="tgt", chunk_size=self.args.gc_p_chunk_size, device=device)  # [B,K,D]
        B, K, D = q_all.shape

        w_ret = self._norm_weights(getattr(self.model, "supervise_weights", None), K, device)
        crd_idxs = self._crd_indices(K)
        w_crd = self._norm_weights(
            [w_ret[i].item() for i in crd_idxs] if len(crd_idxs) > 0 else [1.0], max(1, len(crd_idxs)), device
        )
        temp = float(getattr(self.model, "temperature", 0.02))
        # runtime_beta (set by _apply_crd_warmup) takes precedence over the static CRD weight.
        beta = float(getattr(self.loss_fn, "runtime_beta", getattr(self.loss_fn, "crd_weight", getattr(self.args, "crd_weight", 0.2))))
        crd_temp = float(getattr(self.loss_fn, "crd_temperature", getattr(self.args, "crd_temperature", 0.07)))
        detach_teacher = bool(getattr(self.args, "crd_detach_teacher", True))

        use_bf16 = getattr(self.args, "bf16", False)
        dev_type = "cuda" if "cuda" in str(device) else "cpu"

        # Teacher (last layer) constants. q_all / p_all are already detached, so the
        # teacher carries no gradient here regardless of detach_teacher.
        tq_all = q_all[:, K - 1, :].detach() if detach_teacher else q_all[:, K - 1, :]
        tp_all = p_all[:, K - 1, :].detach() if detach_teacher else p_all[:, K - 1, :]

        total_loss_scalar = 0.0

        # Phase A: chunk over queries, update the query branch.
        q_chunk = max(1, self.args.gc_q_chunk_size or B)
        for s in range(0, B, q_chunk):
            bs = min(q_chunk, B - s)
            sub_q = self._slice_batch(queries, size=bs, offset=s)
            labels = torch.arange(s, s + bs, device=device, dtype=torch.long)

            with torch.autocast(device_type=dev_type, dtype=torch.bfloat16, enabled=use_bf16):
                o_q = model(qry=sub_q)  # [bs,K,D], with gradient
                qk = o_q["qry_reps"]

                # L_ret: q_k vs p_all_k for every supervised layer k
                L_ret_q = 0.0
                for k_idx in range(K):
                    logits = torch.matmul(qk[:, k_idx, :], p_all[:, k_idx, :].transpose(0, 1)) / temp  # [bs,B]
                    Lk = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
                    L_ret_q = L_ret_q + w_ret[k_idx] * Lk

                # L_crd_q: q_k vs teacher tq_all for the selected CRD layers
                L_crd_q = 0.0
                for j, k_idx in enumerate(crd_idxs):
                    logits = torch.matmul(qk[:, k_idx, :], tq_all.transpose(0, 1)) / crd_temp
                    Lk = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
                    L_crd_q = L_crd_q + w_crd[j] * Lk

                L_q = L_ret_q + beta * L_crd_q

            # Backprop only reaches the query branch (everything else is constant).
            L_q.backward()
            total_loss_scalar += float(L_q.detach()) * (bs / B)
            del qk, o_q, L_q, L_ret_q, L_crd_q

        # Phase B: chunk over targets, update the target branch.
        p_chunk = max(1, self.args.gc_p_chunk_size or B)
        for s in range(0, B, p_chunk):
            bs = min(p_chunk, B - s)
            sub_p = self._slice_batch(targets, size=bs, offset=s)
            labels = torch.arange(s, s + bs, device=device, dtype=torch.long)

            with torch.autocast(device_type=dev_type, dtype=torch.bfloat16, enabled=use_bf16):
                o_p = model(tgt=sub_p)  # [bs,K,D], with gradient
                pk = o_p["tgt_reps"]

                # L_ret: p_k vs q_all_k
                L_ret_p = 0.0
                for k_idx in range(K):
                    logits = torch.matmul(pk[:, k_idx, :], q_all[:, k_idx, :].transpose(0, 1)) / temp
                    Lk = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
                    L_ret_p = L_ret_p + w_ret[k_idx] * Lk

                # L_crd_p: p_k vs teacher tp_all
                L_crd_p = 0.0
                for j, k_idx in enumerate(crd_idxs):
                    logits = torch.matmul(pk[:, k_idx, :], tp_all.transpose(0, 1)) / crd_temp
                    Lk = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
                    L_crd_p = L_crd_p + w_crd[j] * Lk

                L_p = L_ret_p + beta * L_crd_p

            L_p.backward()
            total_loss_scalar += float(L_p.detach()) * (bs / B)
            del pk, o_p, L_p, L_ret_p, L_crd_p

        # Return a constant tensor for logging (no graph, so nothing is backpropagated again).
        return torch.tensor(total_loss_scalar, device=device, dtype=torch.float32, requires_grad=False)

    def _apply_crd_warmup(self):
        """Linearly ramp the CRD weight from 0 to ``args.crd_weight`` over ``args.crd_warmup_steps``.

        Writes both ``loss_fn.runtime_beta`` and (if present) ``loss_fn.crd_weight``.
        """
        target_beta = getattr(self.args, "crd_weight", 0.2)
        warm_steps = getattr(self.args, "crd_warmup_steps", 0)
        # During accumulation, global_step still holds the previous step's value; that is good enough.
        step = max(0, getattr(self.state, "global_step", 0))
        if warm_steps and step < warm_steps:
            beta = target_beta * float(step + 1) / float(warm_steps)
        else:
            beta = target_beta
        # Write both runtime_beta and crd_weight to stay compatible with different loss implementations.
        setattr(self.loss_fn, "runtime_beta", beta)
        if hasattr(self.loss_fn, "crd_weight"):
            self.loss_fn.crd_weight = beta

    def _get_last_block_params(self):
        """Return (and cache) the parameters of the encoder's last LM block.

        Falls back to all encoder parameters when the block list cannot be located.
        """
        if self._last_block_params_cache is not None:
            return self._last_block_params_cache
        layers = _locate_lm_layers_modulelist(self.model.encoder)
        if layers is None or len(layers) == 0:
            logger.warning("Could not locate LM layers for grad debug; will use encoder parameters as fallback.")
            params = list(self.model.encoder.parameters())
        else:
            params = list(layers[-1].parameters())
        self._last_block_params_cache = params
        return params

    def _debug_teacher_grad(self, queries, targets, model):
        """Compare last-block grad norms with and without the CRD term (DDP only).

        Runs GradCache twice (CRD weight 0, then the current weight) every
        ``args.crd_debug_every`` optimizer steps, prints both norms, and leaves the
        model gradients cleared and the CRD weight restored.
        """
        # Disabled on a single GPU: two extra full-batch passes would risk OOM.
        if not self.is_ddp:
            return
        dbg_every = int(getattr(self.args, "crd_debug_every", 0) or 0)
        if dbg_every <= 0 or (self.state.global_step % dbg_every) != 0:
            return
        # training_step runs once per micro-batch while global_step stays constant inside an
        # accumulation window. Only the first call of the window (right after model.zero_grad())
        # may clear gradients; later calls would wipe gradients already accumulated.
        if getattr(self, "_last_crd_debug_step", None) == self.state.global_step:
            return
        self._last_crd_debug_step = self.state.global_step

        last_params = self._get_last_block_params()
        beta_saved = getattr(self.loss_fn, "crd_weight", 0.0)
        runtime_beta_saved = getattr(self.loss_fn, "runtime_beta", None)
        try:
            # Pass 1: retrieval loss only. Zero both knobs; _apply_crd_warmup writes both.
            self.model.zero_grad(set_to_none=True)
            self.loss_fn.crd_weight = 0.0
            if runtime_beta_saved is not None:
                self.loss_fn.runtime_beta = 0.0
            _ = self.gc(queries, targets, no_sync_except_last=True)
            gn_ret = _grad_norm(last_params)

            # Pass 2: retrieval + CRD with the current weight.
            self.model.zero_grad(set_to_none=True)
            self.loss_fn.crd_weight = beta_saved
            if runtime_beta_saved is not None:
                self.loss_fn.runtime_beta = runtime_beta_saved
            _ = self.gc(queries, targets, no_sync_except_last=True)
            gn_all = _grad_norm(last_params)
        finally:
            # Always restore the CRD weight and drop the diagnostic gradients, even on failure,
            # so the real training step that follows is unaffected.
            self.loss_fn.crd_weight = beta_saved
            if runtime_beta_saved is not None:
                self.loss_fn.runtime_beta = runtime_beta_saved
            self.model.zero_grad(set_to_none=True)

        print_master(
            f"[CRD-Debug] step={self.state.global_step} grad_norm(last-block): "
            f"RET={gn_ret:.6f}, RET+CRD={gn_all:.6f}, delta={max(0.0, gn_all-gn_ret):.6f}"
        )

    def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
        """Run forward + backward for one ``(queries, targets)`` batch.

        Gradients are accumulated inside this call; the returned tensor is used for logging only.
        """
        model.train()
        queries, targets = inputs
        device = self._infer_device(model)
        queries = batch_to_device(queries, device)
        targets = batch_to_device(targets, device)
        queries, targets = {'qry': queries}, {'tgt': targets}

        # Dynamic CRD warmup
        self._apply_crd_warmup()

        # Optional gradient diagnostics
        try:
            self._debug_teacher_grad(queries, targets, model)
        except Exception as e:
            logger.warning(f"CRD grad debug failed (ignored): {e}")

        if self.is_ddp:
            # Multi-GPU: GradCache (expects the DDP-wrapped model, which HF produces in
            # _wrap_model + accelerator.prepare).
            self.gc.models = [model, model]
            loss = self.gc(queries, targets, no_sync_except_last=True)
        else:
            # Single GPU: hand-written two-phase chunked step to avoid a full-batch forward OOM.
            loss = self._single_gpu_chunked_step(model, queries["qry"], targets["tgt"], device)

        return loss / self._dist_loss_scale_factor

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the encoder (``encoder.`` prefix stripped), tokenizer, training args and config.

        Falls back to ``self.args.output_dir`` when ``output_dir`` is ``None``.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        print_master(f"Saving model to {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

        if state_dict is None:
            state_dict = self.model.state_dict()
        prefix = 'encoder.'
        assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        self.model.encoder.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        self.model.encoder.config.to_json_file(os.path.join(output_dir, 'config.json'))