| import torch |
| import torch.nn as nn |
| from typing import Optional, Tuple, Union |
|
|
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import logging |
|
|
| from configuration_pdeeppp import PDeepPPConfig |
|
|
| logger = logging.get_logger(__name__) |
|
|
class SelfAttentionGlobalFeatures(nn.Module):
    """Self-attention over the sequence followed by a two-layer projection.

    The module runs batch-first multi-head self-attention, adds the result
    back onto its input, layer-normalises the sum and then maps every position
    from ``config.input_size`` to ``config.output_size`` through a hidden layer
    of width ``config.hidden_size``. Dropout sits between the two linear
    layers; no non-linearity is applied between them.

    Args:
        config: Object exposing ``input_size``, ``num_heads``, ``hidden_size``,
            ``output_size`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        in_dim = config.input_size
        mid_dim = config.hidden_size

        self.self_attention = nn.MultiheadAttention(
            in_dim, config.num_heads, batch_first=True
        )
        self.fc1 = nn.Linear(in_dim, mid_dim)
        self.fc2 = nn.Linear(mid_dim, config.output_size)
        self.layer_norm = nn.LayerNorm(in_dim)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return per-position features of width ``config.output_size``.

        Args:
            x: Tensor whose last dimension is ``config.input_size``; the
                attention layer is built with ``batch_first=True``.
        """
        # Query, key and value are all the input itself; the attention
        # weights returned alongside the output are discarded.
        attended = self.self_attention(x, x, x)[0]
        normed = self.layer_norm(x + attended)
        projected = self.dropout(self.fc1(normed))
        return self.fc2(projected)
|
|
class TransConv1d(nn.Module):
    """Attention front-end, Transformer encoder stack and a residual MLP head.

    The input is first projected to ``config.output_size`` features by
    ``SelfAttentionGlobalFeatures``, then passed through
    ``config.num_transformer_layers`` encoder layers, and finally through two
    linear layers where the output of the first is added back onto the output
    of the second before layer normalisation.

    Args:
        config: Object exposing the fields used by
            ``SelfAttentionGlobalFeatures`` plus ``num_transformer_layers``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attention_global_features = SelfAttentionGlobalFeatures(config)
        # Template layer handed to nn.TransformerEncoder below. It stays
        # assigned as an attribute (and is therefore a registered submodule
        # with its own state-dict entries); removing it would change the
        # checkpoint keys of this model.
        self.transformer_encoder = nn.TransformerEncoderLayer(
            d_model=config.output_size,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_size * 2,
            dropout=config.dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            self.transformer_encoder,
            num_layers=config.num_transformer_layers,
        )
        self.fc1 = nn.Linear(config.output_size, config.output_size)
        self.fc2 = nn.Linear(config.output_size, config.output_size)
        self.layer_norm = nn.LayerNorm(config.output_size)

    def forward(self, x):
        """Return encoded features with last dimension ``config.output_size``.

        Args:
            x: Batch-first tensor accepted by ``SelfAttentionGlobalFeatures``.
        """
        x = self.self_attention_global_features(x)
        x = self.transformer(x)
        # The skip connection wraps fc2 only: the fc1 output is what gets
        # added back before normalisation.
        hidden = self.fc1(x)
        return self.layer_norm(self.fc2(hidden) + hidden)
|
|
class PosCNN(nn.Module):
    """Convolutional branch that pools a sequence into one feature vector.

    A width-3 convolution maps ``config.input_size`` channels to 64, ReLU is
    applied, an optional learnable positional table is added, and the result
    is average-pooled over the sequence axis and projected to
    ``config.output_size``.

    Args:
        config: Object exposing ``input_size`` and ``output_size``.
        use_position_encoding: When true, a learnable ``(64, input_size)``
            table is created and added to the convolution output.
    """

    def __init__(self, config, use_position_encoding=True):
        super().__init__()
        self.use_position_encoding = use_position_encoding
        self.conv1d = nn.Conv1d(config.input_size, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.global_pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(64, config.output_size)

        if use_position_encoding:
            # The table has ``config.input_size`` columns and is sliced to the
            # sequence length in forward(), so the addition only lines up when
            # the sequence is no longer than ``config.input_size``.
            self.position_encoding = nn.Parameter(
                torch.zeros(64, config.input_size)
            )

    def forward(self, x):
        """Return one ``config.output_size`` vector per batch element.

        Args:
            x: 3-D tensor laid out as (batch, sequence, ``config.input_size``);
                it is permuted to channels-first before the convolution.
        """
        feats = self.relu(self.conv1d(x.permute(0, 2, 1)))

        if self.use_position_encoding:
            n_positions = feats.size(2)
            # Slice the table to the current length and broadcast over batch.
            feats = feats + self.position_encoding[:, :n_positions].unsqueeze(0)

        pooled = self.global_pooling(feats).squeeze(-1)
        return self.fc(pooled)
|
|
class PDeepPPPreTrainedModel(PreTrainedModel):
    """Abstract base class holding what every PDeepPP model needs.

    It binds the configuration class, the base-model prefix and the weight
    initialisation hook.
    """

    config_class = PDeepPPConfig
    base_model_prefix = "PDeepPP"
    # NOTE(review): no module visible in this file defines a gradient
    # checkpointing switch - confirm this flag is intended.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise the weights of ``module`` in place.

        ``nn.Linear`` weights are drawn from N(0, 0.02) and their biases
        zeroed; ``nn.LayerNorm`` is reset to unit weight and zero bias. Any
        other module type is left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, nn.Linear):
            return

        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
|
|
class PDeepPPModel(PDeepPPPreTrainedModel):
    """Binary classifier combining a Transformer branch and a CNN branch.

    ``TransConv1d`` produces per-position features and ``PosCNN`` produces one
    pooled vector per sequence. The pooled vector is repeated along the
    sequence axis, concatenated with the per-position features and fed to a
    small convolutional head that emits a single logit per sequence.

    Args:
        config: Model configuration; besides the fields read by the two
            branches, ``dropout``, ``output_size``, ``lambda_`` and
            ``use_return_dict`` are read here.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.transformer = TransConv1d(config)
        self.cnn = PosCNN(config)
        # Classification head over the concatenated branch features
        # (channels = 2 * output_size).
        # NOTE(review): the first AdaptiveMaxPool1d(1) collapses the sequence
        # axis to length 1, so the second convolution only ever sees a single
        # position. Left as is to stay compatible with existing checkpoints.
        self.cnn_layers = nn.Sequential(
            nn.Conv1d(config.output_size * 2, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Dropout(config.dropout / 2),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Dropout(config.dropout / 2),
            nn.Flatten(),
            nn.Linear(64, 1),
        )

        self.post_init()

    def _compute_loss(self, logits, labels):
        """Blend BCE-with-logits and an entropy regulariser via ``config.lambda_``."""
        lambda_ = self.config.lambda_
        bce = nn.BCEWithLogitsLoss()(logits, labels.float())

        probs = torch.sigmoid(logits)
        # The 1e-12 offset keeps log() finite when a probability saturates.
        ent = -(
            probs * torch.log(probs + 1e-12)
            + (1 - probs) * torch.log(1 - probs + 1e-12)
        ).mean()
        cond_ent = -(probs * torch.log(probs + 1e-12)).mean()
        reg_loss = lambda_ * ent - lambda_ * cond_ent

        return lambda_ * bce + (1 - lambda_) * reg_loss

    def forward(
        self,
        input_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[dict, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        r"""Run both branches and return one logit per sequence.

        Args:
            input_embeds (`torch.Tensor` of shape `(batch_size, seq_len, input_size)`):
                Pre-computed embeddings. Required, despite the ``None``
                default kept for signature compatibility.
            labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the classification loss.
            return_dict (`bool`, *optional*):
                Overrides ``config.use_return_dict`` when given.

        Returns:
            dict or tuple: with ``return_dict`` a dict holding ``"loss"``
            (``None`` without labels) and ``"logits"`` of shape
            ``(batch_size,)``; otherwise ``(loss, logits)`` when a loss was
            computed, else just ``logits``.

        Raises:
            ValueError: If ``input_embeds`` is not provided.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_embeds is None:
            # Fail early with a clear message instead of an opaque error from
            # inside the attention layer.
            raise ValueError("You have to specify input_embeds")

        transformer_output = self.transformer(input_embeds)
        cnn_output = self.cnn(input_embeds)
        # Repeat the pooled CNN vector at every sequence position so it can be
        # concatenated with the per-position Transformer features.
        cnn_output = cnn_output.unsqueeze(1).expand(-1, transformer_output.size(1), -1)
        combined = torch.cat([transformer_output, cnn_output], dim=2)
        combined = combined.permute(0, 2, 1)  # channels-first for Conv1d
        logits = self.cnn_layers(combined).squeeze(1)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if return_dict:
            return {
                "loss": loss,
                "logits": logits,
            }
        return (loss, logits) if loss is not None else logits


# Register the class for the "AutoModel" auto class.
PDeepPPModel.register_for_auto_class("AutoModel")