import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_exon_classifier import Evo2ExonConfig


class Evo2ExonModel(PreTrainedModel):
    """Per-position exon classifier head on top of precomputed Evo2 embeddings."""

    config_class = Evo2ExonConfig
    base_model_prefix = "evo2_exon_classifier"
|
|
    def __init__(self, config: Evo2ExonConfig):
        super().__init__(config)

        # MLP head: embedding_dim -> hidden_dim (ReLU), then num_hidden_layers - 1
        # further hidden layers, then a single output logit per position.
        layers = [nn.Linear(config.embedding_dim, config.hidden_dim), nn.ReLU()]
        for _ in range(config.num_hidden_layers - 1):
            layers += [nn.Linear(config.hidden_dim, config.hidden_dim), nn.ReLU()]
        layers += [nn.Linear(config.hidden_dim, 1)]

        self.fc_layers = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()
|
|
    def forward(self, inputs_embeds, labels=None, **kwargs):
        """
        inputs_embeds : (batch, seq_len, embedding_dim)
        labels        : (batch, seq_len) optional, 0/1 floats or ints
        """
        bsz, seq_len, _ = inputs_embeds.shape

        # Score each position independently: flatten to (batch * seq_len, dim),
        # run the MLP head, and restore the (batch, seq_len) layout. reshape()
        # (rather than view()) also accepts non-contiguous inputs.
        logits = self.fc_layers(inputs_embeds.reshape(-1, inputs_embeds.size(-1)))
        logits = logits.view(bsz, seq_len)
        probs = self.sigmoid(logits)

        if labels is not None:
            # BCEWithLogitsLoss fuses the sigmoid into the loss, which is more
            # numerically stable than BCELoss on already-squashed probabilities.
            loss = nn.BCEWithLogitsLoss()(logits, labels.float())
            return {"loss": loss, "logits": logits, "probs": probs}

        return {"logits": logits, "probs": probs}