import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import Trainer

from src.utils import batch_to_device
from src.classifier_utils import HomogeneousBatchSampler


# Focal loss implemented by hand; no extra package is required.
def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Binary focal loss on logits:
        loss = -alpha_t * (1 - p_t)^gamma * log(p_t)
    where p_t is the predicted probability of the true class.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean()


class EarlyExitTrainer(Trainer):
    def __init__(self, backbone_model, target_layer_idx, model_args, *args, **kwargs):
        self.max_length = kwargs.pop("max_length", 512)
        super().__init__(*args, **kwargs)

        # Frozen backbone used only for feature extraction.
        self.backbone = backbone_model.to(self.args.device)
        self.backbone.eval()

        self.target_layer_idx = target_layer_idx
        self.model_args = model_args
        self._grad_check_done = False

    # ---------------- Dataloader ----------------
    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_sampler = HomogeneousBatchSampler(
            self.train_dataset,
            batch_size=self._train_batch_size,
            drop_last=self.args.dataloader_drop_last,
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=train_sampler,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    # ---------------- Optimizer ----------------
    def create_optimizer(self):
        if self.optimizer is None:
            print(f"\n[Debug Rank {self.args.local_rank}] Creating Optimizer...")
            decay_parameters = []
            no_decay_parameters = []
            trainable_count = 0

            for name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                trainable_count += 1
                if "bias" in name or "LayerNorm" in name or "BatchNorm" in name:
                    no_decay_parameters.append(param)
                else:
                    decay_parameters.append(param)

            print(f"[Debug] Found {trainable_count} trainable parameters.")

            self.optimizer = torch.optim.AdamW(
                [
                    {"params": decay_parameters, "weight_decay": self.args.weight_decay},
                    {"params": no_decay_parameters, "weight_decay": 0.0},
                ],
                lr=self.args.learning_rate,
                eps=self.args.adam_epsilon,
            )
        return self.optimizer

    # ---------------- Pooling ----------------
    def _perform_pooling(self, hidden_state, attention_mask):
        pooling_method = self.model_args.pooling
        batch_size = hidden_state.shape[0]

        if pooling_method in ("last", "eos"):
            left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
            if left_padding:
                reps = hidden_state[torch.arange(batch_size), -1, :]
            else:
                eos_indices = attention_mask.sum(dim=1) - 1
                reps = hidden_state[
                    torch.arange(batch_size, device=hidden_state.device),
                    eos_indices,
                ]
        else:
            reps = hidden_state[:, -1, :]

        if self.model_args.normalize:
            reps = F.normalize(reps, p=2, dim=-1)
        return reps

    # ---------------- Loss entry ----------------
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        loss = self._compute_early_exit_loss(model, inputs)
        return (loss, None) if return_outputs else loss

    # ---------------- Core loss ----------------
    def _compute_early_exit_loss(self, model, inputs) -> torch.Tensor:
        """
        Aligned with the offline gating setup:
        - Features: replicate the mid-side features from feat_cols_single (27 dims).
        - Labels:
            label = 1 -> need_last (mid wrong & last correct; the last layer must be run)
            label = 0 -> safe (any other case; early exit is allowed)
        """
        self.backbone.eval()
        model.train()

        device = self.args.device
        qry_inputs, tgt_inputs = inputs
        qry_inputs = batch_to_device(qry_inputs, device)
        tgt_inputs = batch_to_device(tgt_inputs, device)

        # === 1. Backbone feature extraction (no gradients needed) ===
        with torch.no_grad():
            # -------- 1) backbone forward --------
            tgt_outputs = self.backbone.encoder(
                **tgt_inputs, return_dict=True, output_hidden_states=True
            )
            # Pool the target's mid-layer and last-layer representations.
            tgt_hidden_mid = tgt_outputs.hidden_states[self.target_layer_idx]
            tgt_reps_mid = self._perform_pooling(
                tgt_hidden_mid, tgt_inputs["attention_mask"]
            )
            tgt_reps_last = self._perform_pooling(
                tgt_outputs.hidden_states[-1], tgt_inputs["attention_mask"]
            )

            qry_outputs = self.backbone.encoder(
                **qry_inputs, return_dict=True, output_hidden_states=True
            )
            q_hidden_mid = qry_outputs.hidden_states[self.target_layer_idx]
            q_hidden_last = qry_outputs.hidden_states[-1]
            qry_reps_mid = self._perform_pooling(
                q_hidden_mid, qry_inputs["attention_mask"]
            )
            qry_reps_last = self._perform_pooling(
                q_hidden_last, qry_inputs["attention_mask"]
            )

            batch_size = qry_reps_mid.size(0)
            backbone_ptr = (
                self.backbone.module
                if hasattr(self.backbone, "module")
                else self.backbone
            )
            temp = getattr(backbone_ptr, "temperature", 0.02)

            # -------- 2) Similarity matrices (cosine & temperature-scaled) --------
            cos_mid = torch.matmul(qry_reps_mid, tgt_reps_mid.T)    # [B, B]
            cos_last = torch.matmul(qry_reps_last, tgt_reps_last.T)  # [B, B]
            scores_mid = cos_mid / temp
            probs_mid = torch.softmax(scores_mid, dim=1)  # [B, B]

            # === Feature construction ===
            # s1_mid: top-1 score in each row
            diag_cos = cos_mid.max(dim=1)[0]  # [B]

            # s2_mid: second-largest score in each row
            sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
            s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
            margin_mid = diag_cos - s2_cos

            # z_margin_mid / mad_margin_mid
            margin_mean = margin_mid.mean()
            margin_std = margin_mid.std(unbiased=False) + 1e-6
            z_margin_mid = (margin_mid - margin_mean) / margin_std
            margin_median = margin_mid.median()
            mad = (margin_mid - margin_median).abs().median() + 1e-6
            mad_margin_mid = (margin_mid - margin_median) / mad

            # p1_mid, H_mid, gini_mid (computed from probs_mid)
            p1_mid = probs_mid.max(dim=1)[0]  # [B]
            H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
            gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)

            # top-k statistics
            TOPK = min(16, probs_mid.size(1))
            topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
            topk_mean = topk_vals.mean(dim=1)
            topk_std = topk_vals.std(dim=1, unbiased=False)
            topk_cv = topk_std / (topk_mean + 1e-6)
            centered = topk_vals - topk_mean.unsqueeze(1)
            var = (centered ** 2).mean(dim=1) + 1e-6
            m4 = (centered ** 4).mean(dim=1)
            topk_kurt = m4 / (var ** 2)
            topk_med = topk_vals.median(dim=1).values

            # s1_over_mean / s1_over_med (top-1 vs. row mean / row median)
            row_mean_cos = cos_mid.mean(dim=1)
            row_med_cos = cos_mid.median(dim=1).values
            s1_over_mean = diag_cos - row_mean_cos
            s1_over_med = diag_cos - row_med_cos

            # Shape features over the sorted probabilities.
            sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
            p1 = sorted_probs[:, 0]
            p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]
            shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
            shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)

            # slope: least-squares slope of log(prob) vs. rank over the first R = 10 ranks
            R = min(10, sorted_probs.size(1))
            x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
            x_centered = x - x.mean()
            denom = (x_centered ** 2).sum()
            y = torch.log(sorted_probs[:, :R] + 1e-6)
            slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom

            # z1: z-score of the top-1 prob within its row's prob distribution
            row_mean_p = probs_mid.mean(dim=1)
            row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
            z1 = (p1_mid - row_mean_p) / row_std_p

            # s1_over_sk: top-1 prob minus the row skewness of the probs
            center_p = probs_mid - row_mean_p.unsqueeze(1)
            m3 = (center_p ** 3).mean(dim=1)
            skew = m3 / (row_std_p ** 3 + 1e-6)
            s1_over_sk = p1_mid - skew

            # tail_mean / head5_mean
            TAIL_K = min(10, sorted_probs.size(1))
            tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
            HEAD_K = min(5, sorted_probs.size(1))
            head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)

            # Mask-related features: zeroed out for now.
            mask_ratio = torch.zeros_like(diag_cos)
            mask_len = torch.zeros_like(diag_cos)
            mask_runs = torch.zeros_like(diag_cos)

            # === Stack into the 27-dim scalar_inputs ===
            scalar_inputs = torch.stack(
                [
                    diag_cos,        # s1_mid
                    s2_cos,          # s2_mid
                    margin_mid,      # margin_mid
                    z_margin_mid,    # z_margin_mid
                    mad_margin_mid,  # mad_margin_mid
                    p1_mid,          # p1_mid
                    H_mid,           # H_mid
                    gini_mid,        # gini_mid
                    topk_mean,       # topk_mean
                    topk_std,        # topk_std
                    topk_cv,         # topk_cv
                    topk_kurt,       # topk_kurt
                    topk_med,        # topk_med
                    s1_over_mean,    # s1_over_mean
                    s1_over_med,     # s1_over_med
                    p1,              # p1
                    p2,              # p2
                    shape_H,         # shape_H
                    shape_gini,      # shape_gini
                    slope,           # slope
                    z1,              # z1
                    s1_over_sk,      # s1_over_sk
                    tail_mean,       # tail_mean
                    head5_mean,      # head5_mean
                    mask_ratio,      # mask_ratio
                    mask_len,        # mask_len
                    mask_runs,       # mask_runs
                ],
                dim=1,
            )

            # -------- 3) Modality flags --------
            modality_idx = torch.zeros(batch_size, dtype=torch.long, device=device)
            if "pixel_values" in qry_inputs and qry_inputs["pixel_values"] is not None:
                pv = qry_inputs["pixel_values"]
                if isinstance(pv, list):
                    for i, item in enumerate(pv):
                        if item is not None:
                            modality_idx[i] = 1
                elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                    modality_idx.fill_(1)

            # -------- 4) gating label --------
            gt = torch.arange(batch_size, device=device)
            mid_top1 = cos_mid.argmax(dim=1)
            last_top1 = cos_last.argmax(dim=1)
            mid_hit = mid_top1.eq(gt)
            last_hit = last_top1.eq(gt)

            # 1 = need_last (mid layer misses but last layer hits)
            need_last = (~mid_hit) & last_hit
            labels = need_last.float().unsqueeze(1)

            both_correct = mid_hit & last_hit
            both_wrong = (~mid_hit) & (~last_hit)

        # ==========================================================
        # NOTE: torch.no_grad() ends here. The classifier forward
        # below must stay outside the no_grad block, otherwise no
        # gradients would flow to the classifier weights.
        # ==========================================================

        # -------- 5) classifier forward (gradients enabled) --------
        # scalar_inputs and qry_reps_mid carry no gradients (they come from
        # the no_grad block), but the classifier weights do, so the logits
        # get a grad_fn.
        logits = model(scalar_inputs, modality_idx, qry_emb=qry_reps_mid)

        # -------- 6) Loss: focal-loss hyperparameters --------
        # Original: alpha = 0.25 (emphasises the negative class), gamma = 2.
        # Current: alpha = 0.75 (emphasises the positive need_last class), gamma = 3.
        # alpha > 0.5 gives the minority positive class more weight.
        loss = sigmoid_focal_loss(logits, labels, alpha=0.75, gamma=3.0)

        pred_probs = torch.sigmoid(logits)

        # -------- 7) Early-training diagnostics --------
        if self.state.global_step < 10 and self.args.local_rank == 0:
            pos_ratio = labels.mean().item()
            neg_ratio = 1.0 - pos_ratio
            print(f"\n[Probe Step {self.state.global_step}] Loss: {loss.item():.4f}")
            print(
                f" - Pred Probs (need_last=1): mean={pred_probs.mean().item():.4f}, "
                f"std={pred_probs.std().item():.4f}"
            )
            print(f" - Labels: need_last={pos_ratio:.4f}, safe={neg_ratio:.4f}")
            print(
                f" - mid_hit: {mid_hit.float().mean().item():.4f}, "
                f"last_hit: {last_hit.float().mean().item():.4f}"
            )
            print(
                f" - both_correct: {both_correct.float().mean().item():.4f}, "
                f"both_wrong: {both_wrong.float().mean().item():.4f}"
            )
            print(
                f" - Scalar Inputs: mean={scalar_inputs.mean().item():.4f}, "
                f"std={scalar_inputs.std().item():.4f}"
            )
            print(
                f" - Modality: Text={(modality_idx == 0).sum().item()}, "
                f"Image={(modality_idx == 1).sum().item()}"
            )

        return loss

    # ---------------- training_step with gradient monitoring ----------------
    def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor:
        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()

        self.accelerator.backward(loss)

        if not self._grad_check_done and self.args.local_rank == 0:
            print(
                f"\n[Gradient Check After Backward - Step {self.state.global_step}]"
            )
            inner_model = model.module if hasattr(model, "module") else model
            has_grad = False
            total_grad_norm = 0.0
            for name, param in inner_model.named_parameters():
                if param.grad is not None:
                    has_grad = True
                    grad_norm = param.grad.norm().item()
                    total_grad_norm += grad_norm ** 2
                    if self.state.global_step < 3:
                        print(f" - {name}: grad_norm={grad_norm:.6f}")
            total_grad_norm = total_grad_norm ** 0.5
            print(f" - Total Grad Norm: {total_grad_norm:.6f}")
            print(f" - Has Gradient: {has_grad}")
            if self.state.global_step >= 2:
                self._grad_check_done = True

        return loss.detach() / self.args.gradient_accumulation_steps
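

# ----------------------------------------------------------------------
# Illustrative sanity check (not part of the training pipeline): a minimal,
# self-contained sketch of how the focal-loss hyperparameters used in
# _compute_early_exit_loss reweight the rare positive ("need_last") class
# versus the frequent negative ("safe") class. All tensors below are
# synthetic and the printed numbers carry no experimental meaning.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    # Simulate a batch in which only ~10% of queries actually need the last layer.
    demo_logits = torch.randn(32, 1)
    demo_labels = (torch.rand(32, 1) < 0.1).float()

    # alpha = 0.25 weights positives by 0.25 and negatives by 0.75 ...
    loss_neg_focused = sigmoid_focal_loss(demo_logits, demo_labels, alpha=0.25, gamma=2.0)
    # ... while alpha = 0.75 / gamma = 3 (the setting used above) flips the
    # weighting toward the minority positive class.
    loss_pos_focused = sigmoid_focal_loss(demo_logits, demo_labels, alpha=0.75, gamma=3.0)

    print(f"focal(alpha=0.25, gamma=2): {loss_neg_focused.item():.4f}")
    print(f"focal(alpha=0.75, gamma=3): {loss_pos_focused.item():.4f}")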