from torch import nn, Tensor
import torch.distributed as dist
import torch
import torch.nn.functional as F
import os  # used for the CRD_DEBUG_RET_COMPONENTS env flag


class SimpleContrastiveLoss:
    """In-batch contrastive (InfoNCE) loss with optional multi-view weighting."""

    def __init__(self, temperature: float = 0.02, alpha: float = 0.05, weights=None):
        """
        Args:
            temperature: logits are divided by this value before cross-entropy.
            alpha: when ``weights`` is None and there are exactly two views,
                per-view weights default to ``[alpha, 1 - alpha]``.
            weights: optional list of per-view weights used to combine the
                per-view cross-entropies; its length must equal the number of
                views K and it is re-normalized (after clipping negatives to 0)
                at call time. When None and K > 2, uniform weights are used.
        """
        self.temperature = temperature
        self.alpha = alpha
        self.weights = weights  # e.g. [0.1, 0.2, 0.7]

    def _get_weights(self, K: int, device) -> Tensor:
        """Return a length-K non-negative weight vector on ``device`` summing to 1."""
        if self.weights is not None:
            assert len(self.weights) == K, f"weights length {len(self.weights)} != K={K}"
            w = torch.tensor(self.weights, dtype=torch.float32, device=device)
            w = torch.clamp(w, min=0)
            s = w.sum().item()
            if s <= 0:
                # Every weight was clipped away: fall back to uniform.
                w = torch.ones(K, device=device) / K
            else:
                w = w / s
            return w
        if K == 2:
            w = torch.tensor([self.alpha, 1.0 - self.alpha], dtype=torch.float32, device=device)
            # Clamp BEFORE summing so a negative alpha cannot skew the
            # normalizer (the original divided by the unclamped sum).
            w = torch.clamp(w, min=0)
            return w / max(w.sum().item(), 1e-8)
        # Default: uniform weights over the K views.
        return torch.ones(K, dtype=torch.float32, device=device) / K

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        """
        Supported shapes:
            - x=[B, D],    y=[B, D]    -> single view
            - x=[B, K, D], y=[B, D]    -> K query views vs. one candidate view
            - x=[B, D],    y=[B, K, D] -> one query view vs. K candidate views
            - x=[B, K, D], y=[B, K, D] -> per-view pairing (k <-> k), weighted

        Returns the (weighted) cross-entropy over in-batch similarities.
        """
        B = x.size(0)
        if target is None:
            # Positive for query i sits at column i * target_per_qry; this
            # supports y carrying several candidates per query (row-major).
            target_per_qry = y.size(0) // B
            target = torch.arange(0, B * target_per_qry, target_per_qry,
                                  device=x.device, dtype=torch.long)
        # Single view on both sides.
        if x.dim() == 2 and y.dim() == 2:
            logits = torch.matmul(x, y.transpose(0, 1)) / self.temperature
            return F.cross_entropy(logits, target, reduction=reduction)
        # Multi-view query, single-view candidates.
        if x.dim() == 3 and y.dim() == 2:
            K = x.size(1)
            w = self._get_weights(K, x.device)
            loss = 0.0
            for k in range(K):
                logits_k = torch.matmul(x[:, k, :], y.transpose(0, 1)) / self.temperature
                loss_k = F.cross_entropy(logits_k, target, reduction=reduction)
                loss = loss + w[k] * loss_k
            return loss
        # Single-view query, multi-view candidates.
        if x.dim() == 2 and y.dim() == 3:
            K = y.size(1)
            w = self._get_weights(K, x.device)
            loss = 0.0
            for k in range(K):
                logits_k = torch.matmul(x, y[:, k, :].transpose(0, 1)) / self.temperature
                loss_k = F.cross_entropy(logits_k, target, reduction=reduction)
                loss = loss + w[k] * loss_k
            return loss
        # Multi-view pairing (view k on the query side vs. view k on the
        # candidate side); both sides must expose the same number of views.
        if x.dim() == 3 and y.dim() == 3:
            Kx, Ky = x.size(1), y.size(1)
            assert Kx == Ky, f"view mismatch: {Kx} vs {Ky}"
            K = Kx
            w = self._get_weights(K, x.device)
            loss = 0.0
            for k in range(K):
                logits_k = torch.matmul(x[:, k, :], y[:, k, :].transpose(0, 1)) / self.temperature
                loss_k = F.cross_entropy(logits_k, target, reduction=reduction)
                loss = loss + w[k] * loss_k
            return loss
        raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}")


class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """SimpleContrastiveLoss over representations all-gathered across ranks."""

    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02,
                 alpha: float = 0.05, weights=None):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature, alpha=alpha, weights=weights)
        # NOTE: attribute kept as ``word_size`` (sic) for backward compatibility.
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            # Gradient averaging across ranks shrinks the effective loss by
            # 1/world_size; multiply back to match single-process training.
            loss = loss * self.word_size
        return loss

    def gather_tensor(self, t):
        """All-gather ``t`` along dim 0, keeping this rank's slice differentiable."""
        gathered = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(gathered, t)
        gathered[self.rank] = t  # all_gather output carries no grad; restore local slice
        return torch.cat(gathered, dim=0)


class InExampleContrastiveLoss:
    """
    Contract unchanged: x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim].
    Positive candidate is index 0 for every example.
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0,
                 ndim: int = None, *args, **kwargs):
        self.target_per_qry = n_hard_negatives + 1
        self.temperature = temperature
        # Optional truncation of the feature dimension before scoring.
        self.ndim = ndim

    @staticmethod
    def _gather(t: Tensor) -> Tensor:
        """Gradient-preserving all-gather along dim 0.

        Replaces the undefined ``dist_utils.dist_gather`` the original code
        referenced (``dist_utils`` was never imported -> NameError).
        """
        gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, t)
        gathered[dist.get_rank()] = t  # restore the differentiable local slice
        return torch.cat(gathered, dim=0)

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        if torch.distributed.is_initialized():
            x = self._gather(x)
            y = self._gather(y)
        bsz, ndim = x.size(0), x.size(1)
        # Positive candidate is always at index 0 within each example.
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            x = x[:, :ndim]
            # Use ``...`` so the FEATURE dim is sliced for 3-D y (the original
            # ``y[:, :ndim]`` truncated the num_label axis); identical for 2-D y.
            y = y[..., :ndim]
        # NOTE(review): logits are MULTIPLIED by temperature here, while the
        # other losses in this file divide — preserved as-is; confirm intended.
        logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail


class MultiLayerCRDLoss(nn.Module):
    """Per-layer retrieval loss plus CRD distillation from the last layer.

    x, y are [B, K, D] stacks of per-layer embeddings; the last view (index
    K-1) acts as the teacher for the earlier (student) views.
    """

    def __init__(self,
                 temperature: float = 0.02,
                 weights=None,            # length K, aligned with the supervised layers
                 crd_weight: float = 0.2,  # beta: weight of the CRD term
                 crd_temperature: float = 0.07,
                 crd_layers=None,         # None -> all intermediate layers (excludes the last)
                 detach_teacher: bool = True,
                 crd_side: str = "both",  # "both" | "qry" | "tgt"
                 queue_size: int = 0      # teacher memory-queue length (0 = disabled)
                 ):
        super().__init__()
        self.temperature = temperature
        self.weights = weights
        self.crd_weight = crd_weight
        self.crd_temperature = crd_temperature
        self.crd_layers = crd_layers
        self.detach_teacher = detach_teacher
        self.crd_side = str(crd_side).lower()
        assert self.crd_side in ("both", "qry", "tgt")
        self.queue_size = int(queue_size) if queue_size is not None else 0
        self.ce = nn.CrossEntropyLoss(reduction='mean')
        # Teacher memory queues, lazily initialized on first forward once the
        # embedding dim D and device are known.
        self.tq_queue = None
        self.tp_queue = None
        self.tq_ptr = 0
        self.tp_ptr = 0
        self.tq_filled = 0
        self.tp_filled = 0

    @staticmethod
    def _ensure_3d(x: Tensor) -> Tensor:
        """Promote [B, D] to [B, 1, D]; pass [B, K, D] through unchanged."""
        return x if x.dim() == 3 else x.unsqueeze(1)

    def _norm_weights(self, K: int, device, w_list=None) -> Tensor:
        """Length-K weight vector, clipped to >=0 and normalized to sum to 1."""
        if w_list is None:
            return torch.ones(K, device=device) / max(1, K)
        assert len(w_list) == K
        w = torch.tensor(w_list, dtype=torch.float32, device=device)
        w = torch.clamp(w, min=0)
        s = w.sum().item()
        return w / (s if s > 0 else 1.0)

    def _crd_indices(self, K: int) -> list[int]:
        """Student layer indices (always excluding the teacher layer K-1)."""
        if K <= 1:
            return []
        if self.crd_layers is None:
            return list(range(0, K - 1))
        idxs = []
        for idx in self.crd_layers:
            if idx < 0:
                idx = K + idx  # support negative (from-the-end) indices
            if 0 <= idx < K - 1:
                idxs.append(idx)
        return sorted(set(idxs))

    def _maybe_init_queue(self, D: int, device):
        """(Re)allocate the teacher queues if missing or mismatched in D/device."""
        if self.queue_size <= 0:
            return
        if (self.tq_queue is None) or (self.tq_queue.device != device) \
                or (self.tq_queue.size(-1) != D):
            self.tq_queue = torch.zeros(self.queue_size, D, device=device, dtype=torch.float32)
            self.tp_queue = torch.zeros(self.queue_size, D, device=device, dtype=torch.float32)
            self.tq_ptr = self.tp_ptr = 0
            self.tq_filled = self.tp_filled = 0

    @torch.no_grad()
    def _enqueue_teacher(self, tq: Tensor, tp: Tensor):
        """Ring-buffer insert of detached teacher embeddings, tq/tp: [B, D]."""
        if self.queue_size <= 0:
            return
        B = tq.size(0)
        # Query-side queue (if B > queue_size only the first queue_size rows fit).
        num = min(B, self.queue_size)
        end = self.tq_ptr + num
        if end <= self.queue_size:
            self.tq_queue[self.tq_ptr:end].copy_(tq[:num])
        else:
            part1 = self.queue_size - self.tq_ptr
            self.tq_queue[self.tq_ptr:].copy_(tq[:part1])
            self.tq_queue[:num - part1].copy_(tq[part1:num])
        self.tq_ptr = (self.tq_ptr + num) % self.queue_size
        self.tq_filled = min(self.queue_size, self.tq_filled + num)
        # Target-side queue.
        end = self.tp_ptr + num
        if end <= self.queue_size:
            self.tp_queue[self.tp_ptr:end].copy_(tp[:num])
        else:
            part1 = self.queue_size - self.tp_ptr
            self.tp_queue[self.tp_ptr:].copy_(tp[:part1])
            self.tp_queue[:num - part1].copy_(tp[part1:num])
        self.tp_ptr = (self.tp_ptr + num) % self.queue_size
        self.tp_filled = min(self.queue_size, self.tp_filled + num)

    def forward(self, x: Tensor, y: Tensor, target: Tensor = None) -> Tensor:
        """
        x, y: [B, K, D]. Positives pair student view k of example i with the
        teacher view (layer K-1) of the same i; negatives are the teachers of
        all j != i, optionally extended with the teacher memory queue.
        """
        # Shapes and device.
        x = self._ensure_3d(x)
        y = self._ensure_3d(y)
        Bx, Kx, D = x.shape
        By, Ky, _ = y.shape
        assert Bx == By, f"batch mismatch: {Bx} vs {By}"
        B = Bx
        K = min(Kx, Ky)
        x = x[:, :K, :]
        y = y[:, :K, :]
        device = x.device
        if target is None:
            target = torch.arange(B, device=device)
        # 1) Per-layer retrieval (k <-> k) — the primary objective.
        w_ret = self._norm_weights(K, device, self.weights)
        L_ret = 0.0
        for k in range(K):
            logits = torch.matmul(x[:, k, :], y[:, k, :].transpose(0, 1)) / self.temperature
            L_ret = L_ret + w_ret[k] * self.ce(logits, target)
        # 2) CRD (students = intermediate layers, teacher = last layer).
        if K <= 1:
            return L_ret
        crd_idxs = self._crd_indices(K)
        if len(crd_idxs) == 0:
            return L_ret
        # Re-normalize the retrieval weights restricted to the student layers.
        w_crd_list = [w_ret[k].item() for k in crd_idxs]
        w_crd = self._norm_weights(len(crd_idxs), device, w_crd_list)
        tq = x[:, K - 1, :]
        tp = y[:, K - 1, :]
        if self.detach_teacher:
            tq = tq.detach()
            tp = tp.detach()
        # Teacher memory bank: extend the teacher pool to
        # [current-batch teachers | queued historical teachers].
        self._maybe_init_queue(D, device)
        if self.queue_size > 0 and (self.tq_queue is not None):
            if self.tq_filled > 0:
                tq_bank = torch.cat([tq, self.tq_queue[:self.tq_filled]], dim=0)  # [B+Q, D]
            else:
                tq_bank = tq
            if self.tp_filled > 0:
                tp_bank = torch.cat([tp, self.tp_queue[:self.tp_filled]], dim=0)
            else:
                tp_bank = tp
        else:
            tq_bank, tp_bank = tq, tp
        L_crd_q = 0.0
        L_crd_p = 0.0
        # Query-side CRD: student x[:, k, :] vs. teacher bank; positive j == i.
        if self.crd_side in ("qry", "both"):
            for j, k in enumerate(crd_idxs):
                logits_q = torch.matmul(x[:, k, :], tq_bank.transpose(0, 1)) / self.crd_temperature  # [B, B+Q]
                # Labels stay [0..B-1]: positives occupy the first B columns
                # of tq_bank (current-batch teachers).
                L_crd_q = L_crd_q + w_crd[j] * self.ce(logits_q, target)
        # Target-side CRD: student y[:, k, :] vs. teacher bank tp_bank.
        if self.crd_side in ("tgt", "both"):
            for j, k in enumerate(crd_idxs):
                logits_p = torch.matmul(y[:, k, :], tp_bank.transpose(0, 1)) / self.crd_temperature
                L_crd_p = L_crd_p + w_crd[j] * self.ce(logits_p, target)
        # Update the teacher queues AFTER computing the losses (no backprop).
        if self.queue_size > 0 and (self.tq_queue is not None):
            self._enqueue_teacher(tq, tp)
        # Component debugging hook for the trainer.
        if os.environ.get("CRD_DEBUG_RET_COMPONENTS", "0") == "1":
            return L_ret, L_crd_q, L_crd_p
        # ``runtime_beta`` (if set by the trainer) overrides the static beta.
        beta = getattr(self, "runtime_beta", self.crd_weight)
        return L_ret + beta * (L_crd_q + L_crd_p)


class DistributedMultiLayerCRDLoss(MultiLayerCRDLoss):
    """MultiLayerCRDLoss computed over embeddings gathered across all ranks."""

    def __init__(self, *args, scale_loss: bool = True, **kwargs):
        # Distributed variant: the local queue defaults to off (the gathered
        # batch already supplies plenty of negatives); callers may re-enable.
        queue_size = kwargs.pop("queue_size", 0)
        super().__init__(*args, queue_size=queue_size, **kwargs)
        assert dist.is_initialized()
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def _gather(self, t: Tensor) -> Tensor:
        """All-gather along dim 0, keeping this rank's slice differentiable."""
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)

    def forward(self, x: Tensor, y: Tensor, target: Tensor = None) -> Tensor:
        x = self._ensure_3d(x)
        y = self._ensure_3d(y)
        x_all = self._gather(x)
        y_all = self._gather(y)
        # After all_gather the labels must be the GLOBAL diagonal.
        B = x_all.size(0)
        device = x_all.device
        loss = super().forward(x_all, y_all, target=torch.arange(B, device=device))
        if self.scale_loss:
            # Compensate for cross-rank gradient averaging.
            loss = loss * self.world_size
        return loss