# code_SAS_VLM2Vec/src/trainer_early_exit.py
# Uploaded by MgGladys using the upload-large-folder tool (commit 0a937d7, verified).
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import Trainer
from src.utils import batch_to_device
from src.classifier_utils import HomogeneousBatchSampler
class EarlyExitTrainer(Trainer):
    """Trainer for a small early-exit classifier head.

    A frozen ``backbone_model`` encodes queries and targets. Scalar
    retrieval-quality features are extracted from an intermediate layer
    (``target_layer_idx``) of the query encoding and fed to the trained
    classifier (``model``), which predicts whether exiting early at that
    layer would already retrieve the correct target.
    """

    def __init__(self, backbone_model, target_layer_idx, model_args, *args, **kwargs):
        # Pop our custom kwarg before the base Trainer sees it.
        self.max_length = kwargs.pop('max_length', 512)
        super().__init__(*args, **kwargs)
        # Frozen feature extractor; it is never trained here.
        self.backbone = backbone_model.to(self.args.device)
        self.backbone.eval()
        self.target_layer_idx = target_layer_idx
        self.model_args = model_args
        # One-shot flag for the gradient probe in training_step.
        self._grad_check_done = False

    def get_train_dataloader(self) -> DataLoader:
        """Build the training DataLoader with a homogeneous batch sampler.

        Raises:
            ValueError: if no ``train_dataset`` was provided.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = HomogeneousBatchSampler(
            self.train_dataset,
            batch_size=self._train_batch_size,
            drop_last=self.args.dataloader_drop_last,
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=train_sampler,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer(self):
        """Create AdamW over trainable params; no weight decay on bias/norm weights."""
        if self.optimizer is None:
            print(f"\n[Debug Rank {self.args.local_rank}] Creating Optimizer...")
            decay_parameters = []
            no_decay_parameters = []
            trainable_count = 0
            # NOTE: self.model may be DDP-wrapped, so iterate self.model.named_parameters().
            for name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                trainable_count += 1
                if "bias" in name or "LayerNorm" in name or "BatchNorm" in name:
                    no_decay_parameters.append(param)
                else:
                    decay_parameters.append(param)
            print(f"[Debug] Found {trainable_count} trainable parameters.")
            self.optimizer = torch.optim.AdamW(
                [
                    {"params": decay_parameters, "weight_decay": self.args.weight_decay},
                    {"params": no_decay_parameters, "weight_decay": 0.0},
                ],
                lr=self.args.learning_rate,
                eps=self.args.adam_epsilon,
            )
        return self.optimizer

    def _perform_pooling(self, hidden_state, attention_mask):
        """Pool per-token states into one vector per sequence.

        ``'last'``/``'eos'`` pooling picks the final non-padded token,
        handling both left- and right-padded batches; any other setting
        falls back to the last position. Optionally L2-normalizes.
        """
        pooling_method = self.model_args.pooling
        batch_size = hidden_state.shape[0]
        if pooling_method == 'last' or pooling_method == 'eos':
            # If every sequence attends at the last position, the batch is left-padded.
            left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
            if left_padding:
                reps = hidden_state[torch.arange(batch_size), -1, :]
            else:
                eos_indices = attention_mask.sum(dim=1) - 1
                reps = hidden_state[torch.arange(batch_size, device=hidden_state.device), eos_indices]
        else:
            reps = hidden_state[:, -1, :]
        if self.model_args.normalize:
            reps = F.normalize(reps, p=2, dim=-1)
        return reps

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Override of Trainer.compute_loss — the loss entry point Trainer actually calls."""
        loss = self._compute_early_exit_loss(model, inputs)
        return (loss, None) if return_outputs else loss

    def _compute_early_exit_loss(self, model, inputs) -> torch.Tensor:
        """Compute the BCE loss for the early-exit classifier.

        Feature extraction runs under ``torch.no_grad()`` (frozen backbone);
        only the classifier forward + loss carry gradients.
        """
        self.backbone.eval()
        model.train()
        device = self.args.device
        qry_inputs, tgt_inputs = inputs
        qry_inputs = batch_to_device(qry_inputs, device)
        tgt_inputs = batch_to_device(tgt_inputs, device)
        with torch.no_grad():
            # Backbone forward: targets use the final layer, queries the early-exit layer.
            tgt_outputs = self.backbone.encoder(**tgt_inputs, return_dict=True, output_hidden_states=True)
            tgt_reps = self._perform_pooling(tgt_outputs.hidden_states[-1], tgt_inputs['attention_mask'])
            qry_outputs = self.backbone.encoder(**qry_inputs, return_dict=True, output_hidden_states=True)
            q_hidden_mid = qry_outputs.hidden_states[self.target_layer_idx]
            qry_reps_mid = self._perform_pooling(q_hidden_mid, qry_inputs['attention_mask'])
            batch_size = qry_reps_mid.size(0)
            # --- Feature engineering -------------------------------------
            backbone_ptr = self.backbone.module if hasattr(self.backbone, 'module') else self.backbone
            temp = getattr(backbone_ptr, 'temperature', 0.02)
            sim_matrix = torch.matmul(qry_reps_mid, tgt_reps.T) / temp
            diag_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
            sim_matrix_no_diag = sim_matrix.masked_fill(diag_mask, -1e9)
            feat_s1 = torch.diag(sim_matrix)  # positive (diagonal) score
            # [Fix] Margin = positive score minus the best *negative* score.
            # The previous top-k over the full matrix could pick the diagonal
            # itself as the second value when the positive was not rank-1,
            # producing a spurious zero margin. Masking the diagonal gives
            # the true hard negative for every row.
            feat_margin = feat_s1 - sim_matrix_no_diag.max(dim=1).values
            probs = torch.softmax(sim_matrix, dim=1)
            feat_entropy = -(probs * torch.log(probs + 1e-6)).sum(dim=1)
            # Norm/variance of the raw (un-normalized) pooled hidden state.
            left_padding = (qry_inputs['attention_mask'][:, -1].sum() == batch_size)
            if left_padding:
                q_raw_pooled = q_hidden_mid[:, -1, :]
            else:
                eos_indices = qry_inputs['attention_mask'].sum(dim=1) - 1
                q_raw_pooled = q_hidden_mid[torch.arange(batch_size, device=device), eos_indices]
            feat_norm = torch.norm(q_raw_pooled, p=2, dim=1)
            feat_var = torch.var(q_raw_pooled, dim=1)
            scalar_inputs = torch.stack([feat_s1, feat_margin, feat_entropy, feat_norm, feat_var], dim=1)
            # Modality flag: 1 when the query carries pixel values (image), else 0 (text).
            modality_idx = torch.zeros(batch_size, dtype=torch.long, device=device)
            if 'pixel_values' in qry_inputs and qry_inputs['pixel_values'] is not None:
                pv = qry_inputs['pixel_values']
                if isinstance(pv, list):
                    for i, item in enumerate(pv):
                        if item is not None:
                            modality_idx[i] = 1
                elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                    modality_idx.fill_(1)
            # --- Label generation: retrieval quality at the intermediate layer ---
            # (1) Hard signal: is the diagonal the row argmax?
            preds = torch.argmax(sim_matrix, dim=1)
            ground_truth = torch.arange(batch_size, device=device)
            is_correct = (preds == ground_truth).float()
            # (2) Soft signal: sigmoid of the (amplified) gap between the
            # positive score and the row maximum.
            diag_scores = torch.diag(sim_matrix)  # positive-pair scores
            max_scores = sim_matrix.max(dim=1)[0]  # best score per row
            relative_scores = torch.sigmoid((diag_scores - max_scores) * 2)
            # Blend: 1.0 when correct, otherwise the relative score — a richer
            # supervision signal than a pure 0/1 label.
            labels = torch.where(
                is_correct.bool(),
                torch.ones_like(is_correct),
                relative_scores,
            ).unsqueeze(1)
        # Classifier forward: scalar_inputs were built under no_grad (detached);
        # gradients flow only through the classifier parameters.
        pred_probs = model(scalar_inputs, modality_idx)
        loss = F.binary_cross_entropy(pred_probs, labels)
        # ========================================================
        # [Gradient probe] Only inspect the first 10 steps.
        # ========================================================
        if self.state.global_step < 10 and self.args.local_rank == 0:
            print(f"\n[Probe Step {self.state.global_step}] Loss: {loss.item():.4f}")
            print(f" - Loss has grad_fn: {loss.grad_fn is not None}")
            print(f" - Pred Probs: Mean={pred_probs.mean().item():.4f}, Std={pred_probs.std().item():.4f}")
            print(f" - Pred Probs Range: [{pred_probs.min().item():.4f}, {pred_probs.max().item():.4f}]")
            print(f" - Labels: Mean={labels.mean().item():.4f}, Std={labels.std().item():.4f}")
            print(f" - Labels Range: [{labels.min().item():.4f}, {labels.max().item():.4f}]")
            print(f" - Correct Rate: {is_correct.mean().item():.4f}")
            # Input-feature statistics.
            print(f" - Scalar Inputs: Mean={scalar_inputs.mean().item():.4f}, Std={scalar_inputs.std().item():.4f}")
            print(f" - Modality: Text={((modality_idx==0).sum().item())}, Image={((modality_idx==1).sum().item())}")
        # ========================================================
        return loss

    def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor:
        """Override of Trainer.training_step to add gradient monitoring.

        In recent Transformers versions this method performs the full
        forward + backward pass.

        Args:
            model: the model being trained (possibly DDP-wrapped).
            inputs: the batch as produced by the data collator.
            num_items_in_batch: sample count (passed by newer Transformers).
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        # Forward + loss.
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        if self.args.n_gpu > 1:
            loss = loss.mean()
        # Backward via the Accelerator.
        self.accelerator.backward(loss)
        # Check gradient flow during the first few steps only.
        if not self._grad_check_done and self.args.local_rank == 0:
            print(f"\n[Gradient Check After Backward - Step {self.state.global_step}]")
            inner_model = model.module if hasattr(model, 'module') else model
            has_grad = False
            total_grad_norm = 0.0
            for name, param in inner_model.named_parameters():
                if param.grad is not None:
                    has_grad = True
                    grad_norm = param.grad.norm().item()
                    total_grad_norm += grad_norm ** 2
                    if self.state.global_step < 3:
                        print(f" - {name}: grad_norm={grad_norm:.6f}")
            total_grad_norm = total_grad_norm ** 0.5
            print(f" - Total Grad Norm: {total_grad_norm:.6f}")
            print(f" - Has Gradient: {has_grad}")
            if self.state.global_step >= 2:
                self._grad_check_done = True
        return loss.detach() / self.args.gradient_accumulation_steps