# src/trainer_early_exit_V3.py
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import Trainer
from src.utils import batch_to_device
from src.classifier_utils import HomogeneousBatchSampler
# Focal loss implemented by hand, so no extra package is required.
def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
"""
Loss = -alpha * (1 - p)^gamma * log(p)
"""
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: probability assigned to the true class
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        # alpha_t: class-balancing weight (alpha for positives, 1 - alpha for negatives)
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()
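# A minimal smoke test for the loss above (shapes are illustrative only; in the
# trainer below the labels are [B, 1] and the classifier logits are expected to
# match):
#   logits = torch.randn(8, 1)
#   targets = torch.randint(0, 2, (8, 1)).float()
#   loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0)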
class EarlyExitTrainer(Trainer):
def __init__(self, backbone_model, target_layer_idx, model_args, *args, **kwargs):
        # Pop the custom kwarg before the rest is forwarded to transformers.Trainer.
        self.max_length = kwargs.pop("max_length", 512)
        super().__init__(*args, **kwargs)
        self.backbone = backbone_model.to(self.args.device)
        self.backbone.eval()  # the backbone stays frozen; only the gating classifier trains
self.target_layer_idx = target_layer_idx
self.model_args = model_args
self._grad_check_done = False
# ---------------- Dataloader ----------------
def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_sampler = HomogeneousBatchSampler(
self.train_dataset,
batch_size=self._train_batch_size,
drop_last=self.args.dataloader_drop_last,
)
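        # Note: HomogeneousBatchSampler (from src.classifier_utils) is assumed to
        # yield batches drawn from one sub-dataset at a time, so the in-batch
        # similarity matrix in the loss below compares like with like.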
return DataLoader(
self.train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
# ---------------- Optimizer ----------------
def create_optimizer(self):
if self.optimizer is None:
print(f"\n[Debug Rank {self.args.local_rank}] Creating Optimizer...")
decay_parameters = []
no_decay_parameters = []
trainable_count = 0
for name, param in self.model.named_parameters():
if not param.requires_grad:
continue
trainable_count += 1
if "bias" in name or "LayerNorm" in name or "BatchNorm" in name:
no_decay_parameters.append(param)
else:
decay_parameters.append(param)
print(f"[Debug] Found {trainable_count} trainable parameters.")
self.optimizer = torch.optim.AdamW(
[
{"params": decay_parameters, "weight_decay": self.args.weight_decay},
{"params": no_decay_parameters, "weight_decay": 0.0},
],
lr=self.args.learning_rate,
eps=self.args.adam_epsilon,
)
return self.optimizer
# ---------------- Pooling ----------------
def _perform_pooling(self, hidden_state, attention_mask):
pooling_method = self.model_args.pooling
batch_size = hidden_state.shape[0]
        if pooling_method in ("last", "eos"):
            # Left padding: every row attends at the final position, so the
            # sequence-final token is simply the last column.
            left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
            if left_padding:
                reps = hidden_state[:, -1, :]
            else:
                # Right padding: the last real token sits at index (length - 1).
                eos_indices = attention_mask.sum(dim=1) - 1
                reps = hidden_state[
                    torch.arange(batch_size, device=hidden_state.device),
                    eos_indices,
                ]
        else:
            # Fallback: any other pooling setting also reads the final position.
            reps = hidden_state[:, -1, :]
if self.model_args.normalize:
reps = F.normalize(reps, p=2, dim=-1)
return reps
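    # Worked example (right padding, pooling="last"): for an attention_mask row
    # of [1, 1, 1, 0, 0], eos_indices = 3 - 1 = 2, so the representation is read
    # from position 2 rather than from the padded tail.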
# ---------------- Loss entry ----------------
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
loss = self._compute_early_exit_loss(model, inputs)
return (loss, None) if return_outputs else loss
# ---------------- Core loss ----------------
def _compute_early_exit_loss(self, model, inputs) -> torch.Tensor:
"""
与 offline gating 对齐:
- 特征:复刻 feat_cols_single 中的 mid 侧特征(27 维)
- 标签:
label = 1 → need_last(mid✗ & last✓,必须跑 last)
label = 0 → safe(其它情况,可以早停)
"""
self.backbone.eval()
model.train()
device = self.args.device
qry_inputs, tgt_inputs = inputs
qry_inputs = batch_to_device(qry_inputs, device)
tgt_inputs = batch_to_device(tgt_inputs, device)
        # === 1. Backbone feature extraction (no gradients needed) ===
with torch.no_grad():
# -------- 1) backbone forward --------
tgt_outputs = self.backbone.encoder(
**tgt_inputs, return_dict=True, output_hidden_states=True
)
            # Extract the target's mid-layer and last-layer representations
tgt_hidden_mid = tgt_outputs.hidden_states[self.target_layer_idx]
tgt_reps_mid = self._perform_pooling(
tgt_hidden_mid, tgt_inputs["attention_mask"]
)
tgt_reps_last = self._perform_pooling(
tgt_outputs.hidden_states[-1], tgt_inputs["attention_mask"]
)
qry_outputs = self.backbone.encoder(
**qry_inputs, return_dict=True, output_hidden_states=True
)
q_hidden_mid = qry_outputs.hidden_states[self.target_layer_idx]
q_hidden_last = qry_outputs.hidden_states[-1]
qry_reps_mid = self._perform_pooling(
q_hidden_mid, qry_inputs["attention_mask"]
)
qry_reps_last = self._perform_pooling(
q_hidden_last, qry_inputs["attention_mask"]
)
batch_size = qry_reps_mid.size(0)
            # Unwrap a DDP-wrapped backbone if needed to read the learned
            # contrastive temperature; fall back to 0.02 if it is not exposed.
            backbone_ptr = (
                self.backbone.module if hasattr(self.backbone, "module") else self.backbone
            )
            temp = getattr(backbone_ptr, "temperature", 0.02)
            # -------- 2) Similarity matrices (cosine & temperature-scaled) --------
cos_mid = torch.matmul(qry_reps_mid, tgt_reps_mid.T) # [B,B]
cos_last = torch.matmul(qry_reps_last, tgt_reps_last.T) # [B,B]
scores_mid = cos_mid / temp
probs_mid = torch.softmax(scores_mid, dim=1) # [B,B]
            # === Feature construction ===
            # s1_mid: per-row top-1 score (despite the name, this is the row
            # maximum, not the diagonal)
            diag_cos = cos_mid.max(dim=1)[0]  # [B]
            # s2_mid: per-row second-largest score
sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
margin_mid = diag_cos - s2_cos
# z_margin_mid / mad_margin_mid
margin_mean = margin_mid.mean()
margin_std = margin_mid.std(unbiased=False) + 1e-6
z_margin_mid = (margin_mid - margin_mean) / margin_std
margin_median = margin_mid.median()
mad = (margin_mid - margin_median).abs().median() + 1e-6
mad_margin_mid = (margin_mid - margin_median) / mad
            # p1_mid, H_mid, gini_mid (computed from probs_mid):
            #   H = -sum_j p_j * log(p_j),  gini = 1 - sum_j p_j^2
p1_mid = probs_mid.max(dim=1)[0] # [B]
H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
            # Top-k statistics
TOPK = min(16, probs_mid.size(1))
topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
topk_mean = topk_vals.mean(dim=1)
topk_std = topk_vals.std(dim=1, unbiased=False)
topk_cv = topk_std / (topk_mean + 1e-6)
centered = topk_vals - topk_mean.unsqueeze(1)
var = (centered ** 2).mean(dim=1) + 1e-6
m4 = (centered ** 4).mean(dim=1)
topk_kurt = m4 / (var ** 2)
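            # (raw kurtosis: the fourth standardized moment m4 / var^2; a
            # Gaussian-shaped top-k profile would land near 3)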
topk_med = topk_vals.median(dim=1).values
            # s1_over_mean / s1_over_med: top-1 minus the row mean / row median
row_mean_cos = cos_mid.mean(dim=1)
row_med_cos = cos_mid.median(dim=1).values
s1_over_mean = diag_cos - row_mean_cos
s1_over_med = diag_cos - row_med_cos
            # Shape features over the sorted probability distribution
sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
p1 = sorted_probs[:, 0]
p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]
shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
            # slope: least-squares slope of log(prob) vs. rank over the top R=10,
            # i.e. slope = sum((x - x_bar) * y) / sum((x - x_bar)^2)
R = min(10, sorted_probs.size(1))
x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
x_centered = x - x.mean()
denom = (x_centered ** 2).sum()
y = torch.log(sorted_probs[:, :R] + 1e-6)
slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
            # z1: z-score of the top-1 prob within its row's prob distribution
row_mean_p = probs_mid.mean(dim=1)
row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
z1 = (p1_mid - row_mean_p) / row_std_p
            # s1_over_sk: top-1 prob minus the skewness of the row's prob distribution
center_p = probs_mid - row_mean_p.unsqueeze(1)
m3 = (center_p ** 3).mean(dim=1)
skew = m3 / (row_std_p ** 3 + 1e-6)
s1_over_sk = p1_mid - skew
            # tail_mean / head5_mean: means of the lowest TAIL_K and highest HEAD_K sorted probs
TAIL_K = min(10, sorted_probs.size(1))
tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
HEAD_K = min(5, sorted_probs.size(1))
head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
            # Mask-related features: zeroed placeholders for now
mask_ratio = torch.zeros_like(diag_cos)
mask_len = torch.zeros_like(diag_cos)
mask_runs = torch.zeros_like(diag_cos)
            # === Stack into the 27-dim scalar_inputs, shape [B, 27] ===
scalar_inputs = torch.stack(
[
diag_cos, # s1_mid
s2_cos, # s2_mid
margin_mid, # margin_mid
z_margin_mid, # z_margin_mid
mad_margin_mid, # mad_margin_mid
p1_mid, # p1_mid
H_mid, # H_mid
gini_mid, # gini_mid
topk_mean, # topk_mean
topk_std, # topk_std
topk_cv, # topk_cv
topk_kurt, # topk_kurt
topk_med, # topk_med
s1_over_mean, # s1_over_mean
s1_over_med, # s1_over_med
p1, # p1
p2, # p2
shape_H, # shape_H
shape_gini, # shape_gini
slope, # slope
z1, # z1
s1_over_sk, # s1_over_sk
tail_mean, # tail_mean
head5_mean, # head5_mean
mask_ratio, # mask_ratio
mask_len, # mask_len
mask_runs, # mask_runs
],
dim=1,
)
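            # Sanity check: the 27 features above must stay in sync with
            # feat_cols_single on the offline-gating side.
            assert scalar_inputs.shape == (batch_size, 27)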
            # -------- 3) Modality flags (0 = text-only query, 1 = includes an image) --------
modality_idx = torch.zeros(batch_size, dtype=torch.long, device=device)
if "pixel_values" in qry_inputs and qry_inputs["pixel_values"] is not None:
pv = qry_inputs["pixel_values"]
if isinstance(pv, list):
for i, item in enumerate(pv):
if item is not None:
modality_idx[i] = 1
elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
modality_idx.fill_(1)
# -------- 4) gating label --------
gt = torch.arange(batch_size, device=device)
mid_top1 = cos_mid.argmax(dim=1)
last_top1 = cos_last.argmax(dim=1)
mid_hit = mid_top1.eq(gt)
last_hit = last_top1.eq(gt)
# 1 = need_last
need_last = (~mid_hit) & last_hit
labels = need_last.float().unsqueeze(1)
both_correct = mid_hit & last_hit
both_wrong = (~mid_hit) & (~last_hit)
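            # Truth table, (mid_hit, last_hit) -> label:
            #   (1, 1) -> 0  safe: mid already retrieves the right target
            #   (1, 0) -> 0  safe: running last would only make things worse
            #   (0, 1) -> 1  need_last: only the last layer gets it right
            #   (0, 0) -> 0  safe: running last would not help either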
            # ==========================================================
            # [Key fix] Note: the `with torch.no_grad():` block ends HERE.
            # Everything below must be dedented back to the method level,
            # otherwise no gradients can flow into the classifier!
            # ==========================================================
        # -------- 5) Classifier forward (gradients enabled) --------
        # scalar_inputs and qry_reps_mid carry no gradients (they come from the
        # no_grad block), but the model's weights do, so logits has a grad_fn.
logits = model(scalar_inputs, modality_idx, qry_emb=qry_reps_mid)
        # -------- 6) Loss: focal loss on the need_last label --------
loss = sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
pred_probs = torch.sigmoid(logits)
        # -------- 7) Early-training diagnostics --------
        if self.state.global_step < 10 and self.args.local_rank in (-1, 0):
pos_ratio = labels.mean().item()
neg_ratio = 1.0 - pos_ratio
print(f"\n[Probe Step {self.state.global_step}] Loss: {loss.item():.4f}")
print(
f" - Pred Probs (need_last=1): mean={pred_probs.mean().item():.4f}, "
f"std={pred_probs.std().item():.4f}"
)
print(
f" - Labels: need_last={pos_ratio:.4f}, safe={neg_ratio:.4f}"
)
print(
f" - mid_hit: {mid_hit.float().mean().item():.4f}, "
f"last_hit: {last_hit.float().mean().item():.4f}"
)
print(
f" - both_correct: {both_correct.float().mean().item():.4f}, "
f"both_wrong: {both_wrong.float().mean().item():.4f}"
)
print(
f" - Scalar Inputs: mean={scalar_inputs.mean().item():.4f}, "
f"std={scalar_inputs.std().item():.4f}"
)
print(
f" - Modality: Text={((modality_idx == 0).sum().item())}, "
f"Image={((modality_idx == 1).sum().item())}"
)
return loss
    # ---------------- training_step: adds gradient monitoring ----------------
def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor:
model.train()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
loss = loss.mean()
self.accelerator.backward(loss)
        if not self._grad_check_done and self.args.local_rank in (-1, 0):
print(
f"\n[Gradient Check After Backward - Step {self.state.global_step}]"
)
inner_model = model.module if hasattr(model, "module") else model
has_grad = False
total_grad_norm = 0.0
for name, param in inner_model.named_parameters():
if param.grad is not None:
has_grad = True
grad_norm = param.grad.norm().item()
total_grad_norm += grad_norm ** 2
if self.state.global_step < 3:
print(f" - {name}: grad_norm={grad_norm:.6f}")
total_grad_norm = total_grad_norm ** 0.5
print(f" - Total Grad Norm: {total_grad_norm:.6f}")
print(f" - Has Gradient: {has_grad}")
if self.state.global_step >= 2:
self._grad_check_done = True
return loss.detach() / self.args.gradient_accumulation_steps
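# Minimal usage sketch (hypothetical wiring; `backbone`, `gating_head`,
# `training_args`, `train_dataset`, and `collator` are assumptions, not defined
# in this file, and target_layer_idx=16 is just an example value):
#
#   trainer = EarlyExitTrainer(
#       backbone_model=backbone,      # frozen encoder wrapper (must expose .encoder)
#       target_layer_idx=16,          # which hidden layer counts as "mid"
#       model_args=model_args,        # must expose .pooling and .normalize
#       model=gating_head,            # the trainable early-exit classifier
#       args=training_args,
#       train_dataset=train_dataset,
#       data_collator=collator,
#   )
#   trainer.train()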