| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from transformers import Trainer |
|
|
| from src.utils import batch_to_device |
| from src.classifier_utils import HomogeneousBatchSampler |
|
|
|
|
class EarlyExitTrainer(Trainer):
    """HF ``Trainer`` variant that trains a small early-exit gating classifier.

    A frozen ``backbone_model`` supplies mid-layer and last-layer embeddings
    (see ``_compute_early_exit_loss``); the trained ``self.model`` is a binary
    classifier predicting whether the backbone's last layers are still needed.
    """

    def __init__(self, backbone_model, target_layer_idx, model_args, *args, **kwargs):
        # Pop before Trainer.__init__ so the key is not forwarded upstream.
        # NOTE(review): max_length is not read anywhere else in this file —
        # confirm it is consumed by a caller/collator.
        self.max_length = kwargs.pop("max_length", 512)
        super().__init__(*args, **kwargs)

        # The backbone is a frozen feature extractor: kept in eval mode and
        # only run under torch.no_grad() in _compute_early_exit_loss.
        self.backbone = backbone_model.to(self.args.device)
        self.backbone.eval()

        self.target_layer_idx = target_layer_idx  # hidden_states index of the "mid" exit
        self.model_args = model_args  # carries pooling / normalize settings

        # One-shot latch for the gradient sanity check in training_step.
        self._grad_check_done = False
|
| |
| def get_train_dataloader(self) -> DataLoader: |
| if self.train_dataset is None: |
| raise ValueError("Trainer: training requires a train_dataset.") |
|
|
| train_sampler = HomogeneousBatchSampler( |
| self.train_dataset, |
| batch_size=self._train_batch_size, |
| drop_last=self.args.dataloader_drop_last, |
| ) |
|
|
| return DataLoader( |
| self.train_dataset, |
| batch_sampler=train_sampler, |
| collate_fn=self.data_collator, |
| num_workers=self.args.dataloader_num_workers, |
| pin_memory=self.args.dataloader_pin_memory, |
| ) |
|
|
| |
| def create_optimizer(self): |
| if self.optimizer is None: |
| print(f"\n[Debug Rank {self.args.local_rank}] Creating Optimizer...") |
| decay_parameters = [] |
| no_decay_parameters = [] |
| trainable_count = 0 |
|
|
| for name, param in self.model.named_parameters(): |
| if not param.requires_grad: |
| continue |
| trainable_count += 1 |
| if "bias" in name or "LayerNorm" in name or "BatchNorm" in name: |
| no_decay_parameters.append(param) |
| else: |
| decay_parameters.append(param) |
|
|
| print(f"[Debug] Found {trainable_count} trainable parameters.") |
|
|
| self.optimizer = torch.optim.AdamW( |
| [ |
| {"params": decay_parameters, "weight_decay": self.args.weight_decay}, |
| {"params": no_decay_parameters, "weight_decay": 0.0}, |
| ], |
| lr=self.args.learning_rate, |
| eps=self.args.adam_epsilon, |
| ) |
| return self.optimizer |
|
|
| |
| def _perform_pooling(self, hidden_state, attention_mask): |
| pooling_method = self.model_args.pooling |
| batch_size = hidden_state.shape[0] |
|
|
| if pooling_method in ("last", "eos"): |
| left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) |
| if left_padding: |
| reps = hidden_state[torch.arange(batch_size), -1, :] |
| else: |
| eos_indices = attention_mask.sum(dim=1) - 1 |
| reps = hidden_state[ |
| torch.arange(batch_size, device=hidden_state.device), |
| eos_indices, |
| ] |
| else: |
| reps = hidden_state[:, -1, :] |
|
|
| if self.model_args.normalize: |
| reps = F.normalize(reps, p=2, dim=-1) |
| return reps |
|
|
| |
| def compute_loss(self, model, inputs, return_outputs=False, **kwargs): |
| loss = self._compute_early_exit_loss(model, inputs) |
| return (loss, None) if return_outputs else loss |
|
|
| |
    def _compute_early_exit_loss(self, model, inputs) -> torch.Tensor:
        """Compute the BCE loss for the early-exit gating classifier.

        Aligned with the offline gating pipeline:
        - Features: replicates the mid-layer side of ``feat_cols_single``
          (27 scalar features per query; see ``scalar_inputs`` below).
        - Labels:
            label = 1 -> need_last (mid layer misses, last layer hits;
                         the remaining layers must still be run)
            label = 0 -> safe (every other case; early exit is acceptable)

        Args:
            model: gating classifier, called as
                ``model(scalar_inputs, modality_idx)``; expected to return one
                logit per sample (shape compatible with ``labels`` of (B, 1)).
            inputs: ``(qry_inputs, tgt_inputs)`` pair of tokenized batches
                (query side and target/candidate side).

        Returns:
            Scalar ``BCEWithLogitsLoss`` with ``pos_weight=5`` over the batch.
        """
        self.backbone.eval()
        model.train()

        device = self.args.device
        qry_inputs, tgt_inputs = inputs
        qry_inputs = batch_to_device(qry_inputs, device)
        tgt_inputs = batch_to_device(tgt_inputs, device)

        # The frozen backbone and all derived features carry no gradient;
        # gradients flow only through the gating classifier call below.
        with torch.no_grad():
            # --- Target-side embeddings at the mid exit and the last layer ---
            tgt_outputs = self.backbone.encoder(
                **tgt_inputs, return_dict=True, output_hidden_states=True
            )
            tgt_hidden_mid = tgt_outputs.hidden_states[self.target_layer_idx]
            tgt_reps_mid = self._perform_pooling(
                tgt_hidden_mid, tgt_inputs["attention_mask"]
            )
            tgt_reps_last = self._perform_pooling(
                tgt_outputs.hidden_states[-1], tgt_inputs["attention_mask"]
            )

            # --- Query-side embeddings, same two exits ---
            qry_outputs = self.backbone.encoder(
                **qry_inputs, return_dict=True, output_hidden_states=True
            )
            q_hidden_mid = qry_outputs.hidden_states[self.target_layer_idx]
            q_hidden_last = qry_outputs.hidden_states[-1]

            qry_reps_mid = self._perform_pooling(
                q_hidden_mid, qry_inputs["attention_mask"]
            )
            qry_reps_last = self._perform_pooling(
                q_hidden_last, qry_inputs["attention_mask"]
            )

            batch_size = qry_reps_mid.size(0)
            # Unwrap a possible DDP/DataParallel wrapper to read the
            # (optional) learned temperature; defaults to 0.02.
            backbone_ptr = (
                self.backbone.module if hasattr(self.backbone, "module") else self.backbone
            )
            temp = getattr(backbone_ptr, "temperature", 0.02)

            # In-batch similarity matrices: row i = query i vs all targets.
            cos_mid = torch.matmul(qry_reps_mid, tgt_reps_mid.T)
            cos_last = torch.matmul(qry_reps_last, tgt_reps_last.T)

            scores_mid = cos_mid / temp
            probs_mid = torch.softmax(scores_mid, dim=1)

            # --- Features 1-5: raw-similarity margin statistics ---
            # NOTE: despite the name, diag_cos is the per-row top-1
            # similarity (row max), not the matrix diagonal.
            diag_cos = cos_mid.max(dim=1)[0]
            sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
            # Second-best score; degenerates to top-1 when batch size is 1.
            s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
            margin_mid = diag_cos - s2_cos

            # Batch z-score of the margin (biased std, eps-stabilized).
            margin_mean = margin_mid.mean()
            margin_std = margin_mid.std(unbiased=False) + 1e-6
            z_margin_mid = (margin_mid - margin_mean) / margin_std

            # Robust (median/MAD) normalization of the same margin.
            margin_median = margin_mid.median()
            mad = (margin_mid - margin_median).abs().median() + 1e-6
            mad_margin_mid = (margin_mid - margin_median) / mad

            # --- Features 6-8: softmax confidence / dispersion ---
            p1_mid = probs_mid.max(dim=1)[0]
            H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
            gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)

            # --- Features 9-13: top-K probability shape statistics ---
            TOPK = min(16, probs_mid.size(1))
            topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
            topk_mean = topk_vals.mean(dim=1)
            topk_std = topk_vals.std(dim=1, unbiased=False)
            topk_cv = topk_std / (topk_mean + 1e-6)

            # Kurtosis (4th moment / variance^2) and median of the top-K.
            centered = topk_vals - topk_mean.unsqueeze(1)
            var = (centered ** 2).mean(dim=1) + 1e-6
            m4 = (centered ** 4).mean(dim=1)
            topk_kurt = m4 / (var ** 2)
            topk_med = topk_vals.median(dim=1).values

            # --- Features 14-15: top-1 vs row mean/median similarity ---
            row_mean_cos = cos_mid.mean(dim=1)
            row_med_cos = cos_mid.median(dim=1).values
            s1_over_mean = diag_cos - row_mean_cos
            s1_over_med = diag_cos - row_med_cos

            # --- Features 16-20: sorted-probability shape ---
            sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
            p1 = sorted_probs[:, 0]
            p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]

            shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
            shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)

            # Least-squares slope of log-prob over ranks 0..R-1 (decay rate
            # of the probability head).
            R = min(10, sorted_probs.size(1))
            x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
            x_centered = x - x.mean()
            denom = (x_centered ** 2).sum()
            y = torch.log(sorted_probs[:, :R] + 1e-6)
            slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom

            # --- Features 21-22: row-wise z-score and skew-adjusted top-1 ---
            row_mean_p = probs_mid.mean(dim=1)
            row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
            z1 = (p1_mid - row_mean_p) / row_std_p

            # Row skewness (3rd standardized moment) of the probabilities.
            center_p = probs_mid - row_mean_p.unsqueeze(1)
            m3 = (center_p ** 3).mean(dim=1)
            skew = m3 / (row_std_p ** 3 + 1e-6)
            s1_over_sk = p1_mid - skew

            # --- Features 23-24: tail / head mass of the sorted probs ---
            TAIL_K = min(10, sorted_probs.size(1))
            tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
            HEAD_K = min(5, sorted_probs.size(1))
            head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)

            # --- Features 25-27: zero placeholders keeping the 27-dim
            # layout aligned with the offline feature columns.
            # NOTE(review): presumably mask statistics in the offline
            # pipeline — confirm against feat_cols_single.
            mask_ratio = torch.zeros_like(diag_cos)
            mask_len = torch.zeros_like(diag_cos)
            mask_runs = torch.zeros_like(diag_cos)

            # (B, 27) feature matrix; order must match the offline columns.
            scalar_inputs = torch.stack(
                [
                    diag_cos,
                    s2_cos,
                    margin_mid,
                    z_margin_mid,
                    mad_margin_mid,
                    p1_mid,
                    H_mid,
                    gini_mid,
                    topk_mean,
                    topk_std,
                    topk_cv,
                    topk_kurt,
                    topk_med,
                    s1_over_mean,
                    s1_over_med,
                    p1,
                    p2,
                    shape_H,
                    shape_gini,
                    slope,
                    z1,
                    s1_over_sk,
                    tail_mean,
                    head5_mean,
                    mask_ratio,
                    mask_len,
                    mask_runs,
                ],
                dim=1,
            )

            # Modality flag per query: 0 = text-only, 1 = has pixel values.
            modality_idx = torch.zeros(batch_size, dtype=torch.long, device=device)
            if "pixel_values" in qry_inputs and qry_inputs["pixel_values"] is not None:
                pv = qry_inputs["pixel_values"]
                if isinstance(pv, list):
                    # Mixed batch: flag only the rows that carry an image.
                    for i, item in enumerate(pv):
                        if item is not None:
                            modality_idx[i] = 1
                elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                    # Dense tensor: the whole batch is image-bearing.
                    modality_idx.fill_(1)

            # --- Labels: in-batch retrieval with query i matching target i ---
            gt = torch.arange(batch_size, device=device)
            mid_top1 = cos_mid.argmax(dim=1)
            last_top1 = cos_last.argmax(dim=1)

            mid_hit = mid_top1.eq(gt)
            last_hit = last_top1.eq(gt)

            # Positive class: only the last layer recovers the correct target.
            need_last = (~mid_hit) & last_hit
            safe_to_exit = ~need_last  # only used in the diagnostics below

            labels = need_last.float().unsqueeze(1)

            both_correct = mid_hit & last_hit
            both_wrong = (~mid_hit) & (~last_hit)

            # One-time (step 0, rank 0) diagnostics of similarity and label stats.
            if self.state.global_step == 0 and self.args.local_rank == 0:
                # Here the true matrix diagonals (query i vs its own target).
                diag_cos_mid = torch.diagonal(cos_mid)
                diag_cos_last = torch.diagonal(cos_last)
                print(f"\n[Debug] Diagonal cos_mid: mean={diag_cos_mid.mean():.4f}, "
                      f"min={diag_cos_mid.min():.4f}, max={diag_cos_mid.max():.4f}")
                print(f"[Debug] Diagonal cos_last: mean={diag_cos_last.mean():.4f}, "
                      f"min={diag_cos_last.min():.4f}, max={diag_cos_last.max():.4f}")
                print(f"[Debug] cos_mid row max: mean={cos_mid.max(dim=1)[0].mean():.4f}")
                print(f"[Debug] cos_last row max: mean={cos_last.max(dim=1)[0].mean():.4f}")
                print(f"[Debug] mid_hit rate: {mid_hit.float().mean():.2%}")
                print(f"[Debug] last_hit rate: {last_hit.float().mean():.2%}")
                print(f"[Debug] both_correct: {both_correct.float().mean():.2%}")
                print(f"[Debug] need_last: {need_last.float().mean():.2%}")
                print(f"[Debug] both_wrong: {both_wrong.float().mean():.2%}")
                print(f"[Debug] safe_to_exit rate: {safe_to_exit.float().mean():.2%}")

        # Gating classifier forward pass — the only differentiable part.
        logits = model(scalar_inputs, modality_idx)

        # Up-weight the (rarer) positive "need_last" class.
        POS_WEIGHT = 5.0
        pos_weight = torch.tensor([POS_WEIGHT], device=device)

        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        loss = criterion(logits, labels)

        pred_probs = torch.sigmoid(logits)

        # Early-training probe (first 10 steps, rank 0 only).
        if self.state.global_step < 10 and self.args.local_rank == 0:
            pos_ratio = labels.mean().item()
            neg_ratio = 1.0 - pos_ratio
            print(f"\n[Probe Step {self.state.global_step}] Loss: {loss.item():.4f}")
            print(
                f" - Pred Probs (need_last=1): mean={pred_probs.mean().item():.4f}, "
                f"std={pred_probs.std().item():.4f}"
            )
            print(
                f" - Labels: need_last={pos_ratio:.4f}, safe={neg_ratio:.4f}"
            )
            print(
                f" - mid_hit: {mid_hit.float().mean().item():.4f}, "
                f"last_hit: {last_hit.float().mean().item():.4f}"
            )
            print(
                f" - both_correct: {both_correct.float().mean().item():.4f}, "
                f"both_wrong: {both_wrong.float().mean().item():.4f}"
            )
            print(
                f" - Scalar Inputs: mean={scalar_inputs.mean().item():.4f}, "
                f"std={scalar_inputs.std().item():.4f}"
            )
            print(
                f" - Modality: Text={((modality_idx == 0).sum().item())}, "
                f"Image={((modality_idx == 1).sum().item())}"
            )

        return loss
|
|
| |
| def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor: |
| model.train() |
| inputs = self._prepare_inputs(inputs) |
|
|
| with self.compute_loss_context_manager(): |
| loss = self.compute_loss(model, inputs) |
|
|
| if self.args.n_gpu > 1: |
| loss = loss.mean() |
|
|
| self.accelerator.backward(loss) |
|
|
| if not self._grad_check_done and self.args.local_rank == 0: |
| print( |
| f"\n[Gradient Check After Backward - Step {self.state.global_step}]" |
| ) |
| inner_model = model.module if hasattr(model, "module") else model |
| has_grad = False |
| total_grad_norm = 0.0 |
| for name, param in inner_model.named_parameters(): |
| if param.grad is not None: |
| has_grad = True |
| grad_norm = param.grad.norm().item() |
| total_grad_norm += grad_norm ** 2 |
| if self.state.global_step < 3: |
| print(f" - {name}: grad_norm={grad_norm:.6f}") |
|
|
| total_grad_norm = total_grad_norm ** 0.5 |
| print(f" - Total Grad Norm: {total_grad_norm:.6f}") |
| print(f" - Has Gradient: {has_grad}") |
|
|
| if self.state.global_step >= 2: |
| self._grad_check_done = True |
|
|
| return loss.detach() / self.args.gradient_accumulation_steps |