# code_SAS_VLM2Vec/src/trainer_early_exit_V2.py
# Uploaded by MgGladys via the upload-large-folder tool (commit 0a937d7, verified).
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import Trainer
from src.utils import batch_to_device
from src.classifier_utils import HomogeneousBatchSampler
class EarlyExitTrainer(Trainer):
def __init__(self, backbone_model, target_layer_idx, model_args, *args, **kwargs):
    """Trainer that fits an early-exit gating classifier on top of a frozen backbone.

    Args:
        backbone_model: frozen encoder that supplies mid- and last-layer features.
        target_layer_idx: index into ``hidden_states`` used as the "mid" layer.
        model_args: config object; at least ``.pooling`` and ``.normalize`` are read.
        *args, **kwargs: forwarded to ``transformers.Trainer``; the custom
            ``max_length`` kwarg (default 512) is popped off first.
    """
    # Remove our custom kwarg before handing the rest to transformers.Trainer.
    self.max_length = kwargs.pop("max_length", 512)
    super().__init__(*args, **kwargs)
    self.target_layer_idx = target_layer_idx
    self.model_args = model_args
    # The backbone only produces features; keep it frozen in eval mode on the
    # trainer's device.
    self.backbone = backbone_model.to(self.args.device)
    self.backbone.eval()
    # One-shot flag consumed by the gradient sanity check in training_step.
    self._grad_check_done = False
# ---------------- Dataloader ----------------
def get_train_dataloader(self) -> DataLoader:
    """Build the training DataLoader using a homogeneous batch sampler.

    Batches are formed by ``HomogeneousBatchSampler`` so each batch contains
    samples of one kind; collation and worker settings come from TrainingArguments.

    Raises:
        ValueError: if no ``train_dataset`` was provided.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")
    sampler = HomogeneousBatchSampler(
        self.train_dataset,
        batch_size=self._train_batch_size,
        drop_last=self.args.dataloader_drop_last,
    )
    loader = DataLoader(
        self.train_dataset,
        batch_sampler=sampler,
        collate_fn=self.data_collator,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
    )
    return loader
# ---------------- Optimizer ----------------
def create_optimizer(self):
    """Create (or return the cached) AdamW optimizer over trainable parameters.

    Parameters are split into two groups: biases and normalization-layer
    weights get no weight decay; everything else uses ``args.weight_decay``.
    Honors the learning rate, Adam betas and epsilon from TrainingArguments.

    Returns:
        The AdamW optimizer instance.
    """
    if self.optimizer is None:
        print(f"\n[Debug Rank {self.args.local_rank}] Creating Optimizer...")
        decay_parameters = []
        no_decay_parameters = []
        trainable_count = 0
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            trainable_count += 1
            # Standard practice: exclude biases and norm-layer params from decay.
            if "bias" in name or "LayerNorm" in name or "BatchNorm" in name:
                no_decay_parameters.append(param)
            else:
                decay_parameters.append(param)
        print(f"[Debug] Found {trainable_count} trainable parameters.")
        self.optimizer = torch.optim.AdamW(
            [
                {"params": decay_parameters, "weight_decay": self.args.weight_decay},
                {"params": no_decay_parameters, "weight_decay": 0.0},
            ],
            lr=self.args.learning_rate,
            # Fix: the configured Adam betas were previously ignored, silently
            # falling back to AdamW's defaults even when the user set them.
            betas=(self.args.adam_beta1, self.args.adam_beta2),
            eps=self.args.adam_epsilon,
        )
    return self.optimizer
# ---------------- Pooling ----------------
def _perform_pooling(self, hidden_state, attention_mask):
    """Pool per-token hidden states into one vector per sequence.

    For ``pooling in ("last", "eos")`` the final non-padded token is selected,
    handling both left- and right-padded batches; any other configured method
    currently falls back to the last position. Optionally L2-normalizes.

    Args:
        hidden_state: ``[B, T, D]`` float tensor of token representations.
        attention_mask: ``[B, T]`` mask (1 = real token, 0 = padding).

    Returns:
        ``[B, D]`` pooled (and optionally normalized) representations.
    """
    pooling_method = self.model_args.pooling
    batch_size = hidden_state.shape[0]
    # Fix: create the row-index tensor on the hidden state's device; the
    # left-padding branch previously used a CPU arange, which is inconsistent
    # with the right-padding branch and forces a device transfer on GPU.
    row_idx = torch.arange(batch_size, device=hidden_state.device)
    if pooling_method in ("last", "eos"):
        # If every sequence's final position is a real token, the batch is
        # left-padded, so the last column is always the EOS token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = hidden_state[row_idx, -1, :]
        else:
            # Right padding: last real token sits at (#real tokens) - 1.
            eos_indices = attention_mask.sum(dim=1) - 1
            reps = hidden_state[row_idx, eos_indices]
    else:
        # NOTE(review): any other pooling method silently takes the last
        # position — confirm this is intended for e.g. "mean" pooling.
        reps = hidden_state[:, -1, :]
    if self.model_args.normalize:
        reps = F.normalize(reps, p=2, dim=-1)
    return reps
# ---------------- Loss entry ----------------
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    """Delegate to the early-exit gating loss; no model outputs are exposed."""
    loss = self._compute_early_exit_loss(model, inputs)
    if return_outputs:
        return loss, None
    return loss
# ---------------- Core loss ----------------
def _compute_early_exit_loss(self, model, inputs) -> torch.Tensor:
    """Compute the BCE gating loss for the early-exit classifier.

    Aligned with the offline gating pipeline:
      - Features: replicates the 27 mid-layer features of ``feat_cols_single``,
        built without using labels (same as inference-time
        ``run_early_exit_queries``).
      - Labels:
          label = 1 -> need_last (mid retrieval wrong but last-layer retrieval
                       right; the full forward pass is required)
          label = 0 -> safe (all other cases; early exit is acceptable)

    The frozen backbone runs under ``no_grad``; only the gating classifier
    (``model``) receives gradients.

    Args:
        model: the gating classifier; called as ``model(scalar_inputs, modality_idx)``.
        inputs: ``(qry_inputs, tgt_inputs)`` batch dicts for the backbone encoder.

    Returns:
        Scalar BCE-with-logits loss (positive class weighted).
    """
    self.backbone.eval()
    model.train()
    device = self.args.device
    qry_inputs, tgt_inputs = inputs
    qry_inputs = batch_to_device(qry_inputs, device)
    tgt_inputs = batch_to_device(tgt_inputs, device)
    with torch.no_grad():
        # -------- 1) Frozen backbone forward --------
        tgt_outputs = self.backbone.encoder(
            **tgt_inputs, return_dict=True, output_hidden_states=True
        )
        # Pool the target's mid-layer and last-layer representations.
        tgt_hidden_mid = tgt_outputs.hidden_states[self.target_layer_idx]
        tgt_reps_mid = self._perform_pooling(
            tgt_hidden_mid, tgt_inputs["attention_mask"]
        )
        tgt_reps_last = self._perform_pooling(
            tgt_outputs.hidden_states[-1], tgt_inputs["attention_mask"]
        )
        qry_outputs = self.backbone.encoder(
            **qry_inputs, return_dict=True, output_hidden_states=True
        )
        q_hidden_mid = qry_outputs.hidden_states[self.target_layer_idx]
        q_hidden_last = qry_outputs.hidden_states[-1]
        qry_reps_mid = self._perform_pooling(
            q_hidden_mid, qry_inputs["attention_mask"]
        )
        qry_reps_last = self._perform_pooling(
            q_hidden_last, qry_inputs["attention_mask"]
        )
        batch_size = qry_reps_mid.size(0)
        # Unwrap DDP if present to reach the backbone's temperature attribute.
        backbone_ptr = (
            self.backbone.module if hasattr(self.backbone, "module") else self.backbone
        )
        temp = getattr(backbone_ptr, "temperature", 0.02)
        # -------- 2) Similarity matrices (cosine & temperature-scaled) --------
        cos_mid = torch.matmul(qry_reps_mid, tgt_reps_mid.T)  # [B, B]
        cos_last = torch.matmul(qry_reps_last, tgt_reps_last.T)  # [B, B]
        scores_mid = cos_mid / temp
        probs_mid = torch.softmax(scores_mid, dim=1)  # [B, B]
        # === Feature construction: label-free, consistent with inference-time
        # run_early_exit_queries ===
        # s1_mid: per-row top-1 score. NOTE: despite the name, this is the row
        # maximum, not the diagonal entry (matching the inference-side feature).
        diag_cos = cos_mid.max(dim=1)[0]  # [B]
        # s2_mid: per-row second-largest score (degenerates to top-1 when B == 1).
        sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
        s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
        margin_mid = diag_cos - s2_cos
        # z_margin_mid / mad_margin_mid: batch-level standardizations of the margin.
        margin_mean = margin_mid.mean()
        margin_std = margin_mid.std(unbiased=False) + 1e-6
        z_margin_mid = (margin_mid - margin_mean) / margin_std
        margin_median = margin_mid.median()
        mad = (margin_mid - margin_median).abs().median() + 1e-6
        mad_margin_mid = (margin_mid - margin_median) / mad
        # p1_mid, H_mid, gini_mid: confidence / dispersion of each probs_mid row.
        p1_mid = probs_mid.max(dim=1)[0]  # [B]
        H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
        gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
        # Top-k statistics over the probability rows.
        TOPK = min(16, probs_mid.size(1))
        topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
        topk_mean = topk_vals.mean(dim=1)
        topk_std = topk_vals.std(dim=1, unbiased=False)
        topk_cv = topk_std / (topk_mean + 1e-6)
        centered = topk_vals - topk_mean.unsqueeze(1)
        var = (centered ** 2).mean(dim=1) + 1e-6
        m4 = (centered ** 4).mean(dim=1)
        topk_kurt = m4 / (var ** 2)
        topk_med = topk_vals.median(dim=1).values
        # s1_over_mean / s1_over_med: top-1 relative to row mean / median (cos space).
        row_mean_cos = cos_mid.mean(dim=1)
        row_med_cos = cos_mid.median(dim=1).values
        s1_over_mean = diag_cos - row_mean_cos
        s1_over_med = diag_cos - row_med_cos
        # Shape features over the sorted probability rows.
        sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
        p1 = sorted_probs[:, 0]
        p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]
        shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
        shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
        # slope: least-squares slope of log(prob) vs rank over the top R = 10.
        R = min(10, sorted_probs.size(1))
        x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
        x_centered = x - x.mean()
        denom = (x_centered ** 2).sum()
        y = torch.log(sorted_probs[:, :R] + 1e-6)
        slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
        # z1: z-score of the top-1 prob within its own row's prob distribution.
        row_mean_p = probs_mid.mean(dim=1)
        row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
        z1 = (p1_mid - row_mean_p) / row_std_p
        # s1_over_sk: top-1 prob minus the row's probability skewness.
        center_p = probs_mid - row_mean_p.unsqueeze(1)
        m3 = (center_p ** 3).mean(dim=1)
        skew = m3 / (row_std_p ** 3 + 1e-6)
        s1_over_sk = p1_mid - skew
        # tail_mean / head5_mean: means over the smallest-10 / largest-5 probs.
        TAIL_K = min(10, sorted_probs.size(1))
        tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
        HEAD_K = min(5, sorted_probs.size(1))
        head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
        # Mask features cannot be reconstructed online; zero placeholders,
        # matching the inference side.
        mask_ratio = torch.zeros_like(diag_cos)
        mask_len = torch.zeros_like(diag_cos)
        mask_runs = torch.zeros_like(diag_cos)
        # === Stack the 27-dim scalar_inputs (order must match inference) ===
        scalar_inputs = torch.stack(
            [
                diag_cos,       # s1_mid
                s2_cos,         # s2_mid
                margin_mid,     # margin_mid
                z_margin_mid,   # z_margin_mid
                mad_margin_mid, # mad_margin_mid
                p1_mid,         # p1_mid
                H_mid,          # H_mid
                gini_mid,       # gini_mid
                topk_mean,      # topk_mean
                topk_std,       # topk_std
                topk_cv,        # topk_cv
                topk_kurt,      # topk_kurt
                topk_med,       # topk_med
                s1_over_mean,   # s1_over_mean
                s1_over_med,    # s1_over_med
                p1,             # p1
                p2,             # p2
                shape_H,        # shape_H
                shape_gini,     # shape_gini
                slope,          # slope
                z1,             # z1
                s1_over_sk,     # s1_over_sk
                tail_mean,      # tail_mean
                head5_mean,     # head5_mean
                mask_ratio,     # mask_ratio
                mask_len,       # mask_len
                mask_runs,      # mask_runs
            ],
            dim=1,
        )
        # -------- 3) Modality flags (0 = text-only query, 1 = has image) --------
        modality_idx = torch.zeros(batch_size, dtype=torch.long, device=device)
        if "pixel_values" in qry_inputs and qry_inputs["pixel_values"] is not None:
            pv = qry_inputs["pixel_values"]
            if isinstance(pv, list):
                for i, item in enumerate(pv):
                    if item is not None:
                        modality_idx[i] = 1
            elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                modality_idx.fill_(1)
        # -------- 4) Gating label: need_last is the positive class --------
        # In-batch retrieval ground truth: query i matches target i.
        gt = torch.arange(batch_size, device=device)
        mid_top1 = cos_mid.argmax(dim=1)
        last_top1 = cos_last.argmax(dim=1)
        mid_hit = mid_top1.eq(gt)
        last_hit = last_top1.eq(gt)
        # 1 = need_last: mid layer misses but last layer hits.
        need_last = (~mid_hit) & last_hit
        safe_to_exit = ~need_last
        labels = need_last.float().unsqueeze(1)  # [B, 1]
        both_correct = mid_hit & last_hit
        both_wrong = (~mid_hit) & (~last_hit)
        # On the very first step, print alignment diagnostics.
        if self.state.global_step == 0 and self.args.local_rank == 0:
            diag_cos_mid = torch.diagonal(cos_mid)
            diag_cos_last = torch.diagonal(cos_last)
            print(f"\n[Debug] Diagonal cos_mid: mean={diag_cos_mid.mean():.4f}, "
                  f"min={diag_cos_mid.min():.4f}, max={diag_cos_mid.max():.4f}")
            print(f"[Debug] Diagonal cos_last: mean={diag_cos_last.mean():.4f}, "
                  f"min={diag_cos_last.min():.4f}, max={diag_cos_last.max():.4f}")
            print(f"[Debug] cos_mid row max: mean={cos_mid.max(dim=1)[0].mean():.4f}")
            print(f"[Debug] cos_last row max: mean={cos_last.max(dim=1)[0].mean():.4f}")
            print(f"[Debug] mid_hit rate: {mid_hit.float().mean():.2%}")
            print(f"[Debug] last_hit rate: {last_hit.float().mean():.2%}")
            print(f"[Debug] both_correct: {both_correct.float().mean():.2%}")
            print(f"[Debug] need_last: {need_last.float().mean():.2%}")
            print(f"[Debug] both_wrong: {both_wrong.float().mean():.2%}")
            print(f"[Debug] safe_to_exit rate: {safe_to_exit.float().mean():.2%}")
    # -------- 5) Classifier forward + loss (gradients flow only here) --------
    logits = model(scalar_inputs, modality_idx)  # [B, 1]
    # Positive class = need_last; "safe" is the majority class, so the positive
    # class gets a fixed up-weight of roughly safe/need ≈ 4-5. (A dynamic
    # per-batch pos_weight was tried and removed in favor of this constant.)
    POS_WEIGHT = 5.0
    pos_weight = torch.tensor([POS_WEIGHT], device=device)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    loss = criterion(logits, labels)
    pred_probs = torch.sigmoid(logits)  # interpreted as p(need_last = 1)
    # -------- 6) Early-training diagnostics --------
    if self.state.global_step < 10 and self.args.local_rank == 0:
        pos_ratio = labels.mean().item()
        neg_ratio = 1.0 - pos_ratio
        print(f"\n[Probe Step {self.state.global_step}] Loss: {loss.item():.4f}")
        print(
            f" - Pred Probs (need_last=1): mean={pred_probs.mean().item():.4f}, "
            f"std={pred_probs.std().item():.4f}"
        )
        print(
            f" - Labels: need_last={pos_ratio:.4f}, safe={neg_ratio:.4f}"
        )
        print(
            f" - mid_hit: {mid_hit.float().mean().item():.4f}, "
            f"last_hit: {last_hit.float().mean().item():.4f}"
        )
        print(
            f" - both_correct: {both_correct.float().mean().item():.4f}, "
            f"both_wrong: {both_wrong.float().mean().item():.4f}"
        )
        print(
            f" - Scalar Inputs: mean={scalar_inputs.mean().item():.4f}, "
            f"std={scalar_inputs.std().item():.4f}"
        )
        print(
            f" - Modality: Text={((modality_idx == 0).sum().item())}, "
            f"Image={((modality_idx == 1).sum().item())}"
        )
    return loss
# ---------------- training_step:加梯度监控 ----------------
def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor:
    """Run one optimization step, with a one-time post-backward gradient check.

    Computes the gating loss, backpropagates through the accelerator, and on
    rank 0 prints gradient norms for the first few steps until the check is
    marked done. Returns the detached loss scaled for gradient accumulation.
    """
    model.train()
    inputs = self._prepare_inputs(inputs)
    with self.compute_loss_context_manager():
        loss = self.compute_loss(model, inputs)
    if self.args.n_gpu > 1:
        loss = loss.mean()  # average across data-parallel replicas
    self.accelerator.backward(loss)
    # One-shot gradient sanity check on the main process.
    if not self._grad_check_done and self.args.local_rank == 0:
        print(
            f"\n[Gradient Check After Backward - Step {self.state.global_step}]"
        )
        inner_model = model.module if hasattr(model, "module") else model
        has_grad = False
        sq_norm_acc = 0.0
        for name, param in inner_model.named_parameters():
            if param.grad is None:
                continue
            has_grad = True
            grad_norm = param.grad.norm().item()
            sq_norm_acc += grad_norm ** 2
            # Per-parameter detail only for the first few steps.
            if self.state.global_step < 3:
                print(f" - {name}: grad_norm={grad_norm:.6f}")
        total_grad_norm = sq_norm_acc ** 0.5
        print(f" - Total Grad Norm: {total_grad_norm:.6f}")
        print(f" - Has Gradient: {has_grad}")
        if self.state.global_step >= 2:
            self._grad_check_done = True
    return loss.detach() / self.args.gradient_accumulation_steps