| import torch |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader |
| from transformers import Trainer |
| from src.utils import batch_to_device |
| from src.classifier_utils import HomogeneousBatchSampler |
|
|
class EarlyExitTrainer(Trainer):
    """Trainer for an early-exit confidence classifier on a frozen backbone.

    The backbone encoder is held in eval mode and never updated; only the
    classifier passed as the regular Trainer ``model`` is trained.
    """

    def __init__(self, backbone_model, target_layer_idx, model_args, *args, **kwargs):
        # Grab our custom kwarg before the base Trainer sees it.
        self.max_length = kwargs.pop('max_length', 512)
        super().__init__(*args, **kwargs)

        # Frozen feature extractor, moved once to the training device.
        self.backbone = backbone_model.to(self.args.device)
        self.backbone.eval()

        # Which backbone hidden layer the early-exit head reads from.
        self.target_layer_idx = target_layer_idx
        self.model_args = model_args

        # One-shot flag for the gradient sanity probe in training_step.
        self._grad_check_done = False
|
|
| def get_train_dataloader(self) -> DataLoader: |
| if self.train_dataset is None: |
| raise ValueError("Trainer: training requires a train_dataset.") |
| train_sampler = HomogeneousBatchSampler( |
| self.train_dataset, |
| batch_size=self._train_batch_size, |
| drop_last=self.args.dataloader_drop_last |
| ) |
| return DataLoader( |
| self.train_dataset, |
| batch_sampler=train_sampler, |
| collate_fn=self.data_collator, |
| num_workers=self.args.dataloader_num_workers, |
| pin_memory=self.args.dataloader_pin_memory, |
| ) |
|
|
| def create_optimizer(self): |
| if self.optimizer is None: |
| print(f"\n[Debug Rank {self.args.local_rank}] Creating Optimizer...") |
| decay_parameters = [] |
| no_decay_parameters = [] |
| trainable_count = 0 |
| |
| |
| for name, param in self.model.named_parameters(): |
| if not param.requires_grad: |
| continue |
| trainable_count += 1 |
| if "bias" in name or "LayerNorm" in name or "BatchNorm" in name: |
| no_decay_parameters.append(param) |
| else: |
| decay_parameters.append(param) |
| |
| print(f"[Debug] Found {trainable_count} trainable parameters.") |
| |
| self.optimizer = torch.optim.AdamW( |
| [ |
| {"params": decay_parameters, "weight_decay": self.args.weight_decay}, |
| {"params": no_decay_parameters, "weight_decay": 0.0}, |
| ], |
| lr=self.args.learning_rate, |
| eps=self.args.adam_epsilon, |
| ) |
| return self.optimizer |
|
|
| def _perform_pooling(self, hidden_state, attention_mask): |
| pooling_method = self.model_args.pooling |
| batch_size = hidden_state.shape[0] |
| if pooling_method == 'last' or pooling_method == 'eos': |
| left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) |
| if left_padding: |
| reps = hidden_state[torch.arange(batch_size), -1, :] |
| else: |
| eos_indices = attention_mask.sum(dim=1) - 1 |
| reps = hidden_state[torch.arange(batch_size, device=hidden_state.device), eos_indices] |
| else: |
| reps = hidden_state[:, -1, :] |
| if self.model_args.normalize: |
| reps = F.normalize(reps, p=2, dim=-1) |
| return reps |
|
|
| def compute_loss(self, model, inputs, return_outputs=False, **kwargs): |
| """ |
| 覆盖 Trainer 的 compute_loss 方法 |
| 这是 Trainer 真正调用的损失计算入口 |
| """ |
| loss = self._compute_early_exit_loss(model, inputs) |
| return (loss, None) if return_outputs else loss |
|
|
| def _compute_early_exit_loss(self, model, inputs) -> torch.Tensor: |
| """ |
| 计算 Early Exit 分类器的损失 |
| """ |
| self.backbone.eval() |
| model.train() |
| |
| device = self.args.device |
| qry_inputs, tgt_inputs = inputs |
| qry_inputs = batch_to_device(qry_inputs, device) |
| tgt_inputs = batch_to_device(tgt_inputs, device) |
| |
| with torch.no_grad(): |
| |
| tgt_outputs = self.backbone.encoder(**tgt_inputs, return_dict=True, output_hidden_states=True) |
| tgt_reps = self._perform_pooling(tgt_outputs.hidden_states[-1], tgt_inputs['attention_mask']) |
| |
| qry_outputs = self.backbone.encoder(**qry_inputs, return_dict=True, output_hidden_states=True) |
| q_hidden_mid = qry_outputs.hidden_states[self.target_layer_idx] |
| qry_reps_mid = self._perform_pooling(q_hidden_mid, qry_inputs['attention_mask']) |
| |
| batch_size = qry_reps_mid.size(0) |
| |
| |
| backbone_ptr = self.backbone.module if hasattr(self.backbone, 'module') else self.backbone |
| temp = getattr(backbone_ptr, 'temperature', 0.02) |
| |
| sim_matrix = torch.matmul(qry_reps_mid, tgt_reps.T) / temp |
| diag_mask = torch.eye(batch_size, dtype=torch.bool, device=device) |
| sim_matrix_no_diag = sim_matrix.masked_fill(diag_mask, -1e9) |
| |
| all_topk_vals, all_topk_inds = torch.topk(sim_matrix, k=2, dim=1) |
| feat_s1 = torch.diag(sim_matrix) |
| |
| |
| |
| feat_margin = feat_s1 - all_topk_vals[:, 1] |
| |
| probs = torch.softmax(sim_matrix, dim=1) |
| feat_entropy = -(probs * torch.log(probs + 1e-6)).sum(dim=1) |
| |
| |
| left_padding = (qry_inputs['attention_mask'][:, -1].sum() == batch_size) |
| if left_padding: |
| q_raw_pooled = q_hidden_mid[:, -1, :] |
| else: |
| eos_indices = qry_inputs['attention_mask'].sum(dim=1) - 1 |
| q_raw_pooled = q_hidden_mid[torch.arange(batch_size, device=device), eos_indices] |
| feat_norm = torch.norm(q_raw_pooled, p=2, dim=1) |
| feat_var = torch.var(q_raw_pooled, dim=1) |
| |
| scalar_inputs = torch.stack([feat_s1, feat_margin, feat_entropy, feat_norm, feat_var], dim=1) |
| |
| |
| modality_idx = torch.zeros(batch_size, dtype=torch.long, device=device) |
| if 'pixel_values' in qry_inputs and qry_inputs['pixel_values'] is not None: |
| pv = qry_inputs['pixel_values'] |
| if isinstance(pv, list): |
| for i, item in enumerate(pv): |
| if item is not None: modality_idx[i] = 1 |
| elif isinstance(pv, torch.Tensor) and pv.numel() > 0: |
| modality_idx.fill_(1) |
|
|
| |
| |
| preds = torch.argmax(sim_matrix, dim=1) |
| ground_truth = torch.arange(batch_size, device=device) |
| is_correct = (preds == ground_truth).float() |
| |
| |
| diag_scores = torch.diag(sim_matrix) |
| max_scores = sim_matrix.max(dim=1)[0] |
| |
| |
| |
| relative_scores = torch.sigmoid((diag_scores - max_scores) * 2) |
| labels = torch.where( |
| is_correct.bool(), |
| torch.ones_like(is_correct), |
| relative_scores |
| ).unsqueeze(1) |
|
|
| |
| |
| |
| pred_probs = model(scalar_inputs, modality_idx) |
| loss = F.binary_cross_entropy(pred_probs, labels) |
|
|
| |
| |
| |
| if self.state.global_step < 10 and self.args.local_rank == 0: |
| print(f"\n[Probe Step {self.state.global_step}] Loss: {loss.item():.4f}") |
| print(f" - Loss has grad_fn: {loss.grad_fn is not None}") |
| print(f" - Pred Probs: Mean={pred_probs.mean().item():.4f}, Std={pred_probs.std().item():.4f}") |
| print(f" - Pred Probs Range: [{pred_probs.min().item():.4f}, {pred_probs.max().item():.4f}]") |
| print(f" - Labels: Mean={labels.mean().item():.4f}, Std={labels.std().item():.4f}") |
| print(f" - Labels Range: [{labels.min().item():.4f}, {labels.max().item():.4f}]") |
| print(f" - Correct Rate: {is_correct.mean().item():.4f}") |
| |
| |
| print(f" - Scalar Inputs: Mean={scalar_inputs.mean().item():.4f}, Std={scalar_inputs.std().item():.4f}") |
| print(f" - Modality: Text={((modality_idx==0).sum().item())}, Image={((modality_idx==1).sum().item())}") |
| |
|
|
| return loss |
| |
| def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor: |
| """ |
| 覆盖 Trainer 的 training_step 以添加梯度监控 |
| 注意:这个方法在新版 Transformers 中被调用,负责完整的前向+反向过程 |
| |
| Args: |
| model: 要训练的模型 |
| inputs: 输入数据 |
| num_items_in_batch: batch 中的样本数(新版 Transformers 会传入) |
| """ |
| model.train() |
| inputs = self._prepare_inputs(inputs) |
| |
| |
| with self.compute_loss_context_manager(): |
| loss = self.compute_loss(model, inputs) |
| |
| if self.args.n_gpu > 1: |
| loss = loss.mean() |
| |
| |
| self.accelerator.backward(loss) |
| |
| |
| if not self._grad_check_done and self.args.local_rank == 0: |
| print(f"\n[Gradient Check After Backward - Step {self.state.global_step}]") |
| inner_model = model.module if hasattr(model, 'module') else model |
| has_grad = False |
| total_grad_norm = 0.0 |
| for name, param in inner_model.named_parameters(): |
| if param.grad is not None: |
| has_grad = True |
| grad_norm = param.grad.norm().item() |
| total_grad_norm += grad_norm ** 2 |
| if self.state.global_step < 3: |
| print(f" - {name}: grad_norm={grad_norm:.6f}") |
| |
| total_grad_norm = total_grad_norm ** 0.5 |
| print(f" - Total Grad Norm: {total_grad_norm:.6f}") |
| print(f" - Has Gradient: {has_grad}") |
| |
| if self.state.global_step >= 2: |
| self._grad_check_done = True |
| |
| return loss.detach() / self.args.gradient_accumulation_steps |