import torch
import torch.nn as nn
import torch.nn.functional as F


class EarlyExitClassifier(nn.Module):
    """
    V5 classifier: integrates lightweight TTA (LayerNorm) and a log1p feature
    transform of the scalar inputs.
    """
    def __init__(self, input_dim=27, hidden_dim=128, embedding_dim=0, dropout_prob=0.2):
        super().__init__()

        # LayerNorm over the log1p-transformed scalar features.
        self.scalar_ln = nn.LayerNorm(input_dim)

        # Learned 4-dim embedding for the two modalities (index 0 or 1).
        self.modality_emb = nn.Embedding(2, 4)
        # Optional projection for an external query embedding.
        self.use_embedding = embedding_dim > 0
        if self.use_embedding:
            self.emb_proj = nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.ReLU()
            )
            total_input_dim = input_dim + 4 + (hidden_dim // 2)
        else:
            total_input_dim = input_dim + 4
        # Two-layer MLP head that produces a single logit.
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim // 2, 1),
        )

        self._init_weights()
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=0.02)
    def forward(self, scalar_feats, modality_idx, qry_emb=None):
        # Signed log1p transform: compresses heavy-tailed scalar features
        # while preserving their sign, then normalizes with LayerNorm.
        scalar_feats_log = torch.sign(scalar_feats) * torch.log1p(torch.abs(scalar_feats))
        s_feat = self.scalar_ln(scalar_feats_log)

        # Look up the 4-dim modality embedding.
        m_feat = self.modality_emb(modality_idx)

        features = [s_feat, m_feat]
        if self.use_embedding:
            if qry_emb is None:
                raise ValueError("Classifier was initialized with embedding_dim > 0 but forward received qry_emb=None")
            # Project the query embedding before concatenation.
            if qry_emb.dtype != torch.float32:
                qry_emb = qry_emb.float()
            e_feat = self.emb_proj(qry_emb)
            features.append(e_feat)

        # Concatenate all feature groups along the feature dimension.
        x = torch.cat(features, dim=1)

        logits = self.mlp(x)
        return logits
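

# A minimal smoke test (sketch): the batch size, random feature values, and the
# choice of embedding_dim=0 below are illustrative assumptions, not values taken
# from the original training pipeline.
if __name__ == "__main__":
    model = EarlyExitClassifier(input_dim=27, hidden_dim=128, embedding_dim=0)
    scalar_feats = torch.randn(8, 27)          # 8 samples, 27 scalar features
    modality_idx = torch.randint(0, 2, (8,))   # modality index in {0, 1}
    logits = model(scalar_feats, modality_idx)
    print(logits.shape)  # expected: torch.Size([8, 1])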