| from .configuration_dna2vec import DNAEncoderConfig |
| from transformers import PreTrainedModel |
| import math |
| from typing import Literal, Optional |
| import torch |
| import torch.nn as nn |
|
|
class Encoder(nn.Module):
    """Transformer encoder over token ids with a fixed sinusoidal position table.

    Token embeddings are summed with a (non-learned) sinusoidal positional
    table, passed through dropout, and fed to a stack of batch-first,
    pre-LayerNorm (``norm_first=True``) ``nn.TransformerEncoderLayer``s.

    NOTE: ``nn.TransformerEncoder`` is built without a final ``norm``, so the
    output of the last pre-LN layer is returned without a closing LayerNorm.
    """

    def __init__(
        self,
        vocab_size: int = 4,
        embedding_dim: int = 384,
        dim_feedforward: int = 1536,
        num_heads: int = 12,
        num_layers: int = 6,
        dropout: float = 0.1,
        activation: Literal["relu", "gelu"] = "gelu",
        pos_embedding: Optional[str] = "SinusoidalPositionalEncoding",
        max_position_embeddings: int = 1024,
    ):
        """
        Default values taken from miniLM v6
        https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json

        Args:
            vocab_size: Number of distinct token ids accepted by the embedding.
            embedding_dim: Model width (``d_model``); must be divisible by
                ``num_heads`` (required by ``nn.MultiheadAttention``).
            dim_feedforward: Hidden width of each layer's feed-forward block.
            num_heads: Number of attention heads per layer.
            num_layers: Number of stacked encoder layers.
            dropout: Dropout probability for the embeddings and inside each layer.
            activation: Feed-forward activation, ``"relu"`` or ``"gelu"``.
            pos_embedding: Positional-encoding type; only
                ``"SinusoidalPositionalEncoding"`` is supported.
            max_position_embeddings: Longest sequence the position table covers.

        Raises:
            ValueError: If ``pos_embedding`` is not a supported type.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.emb_dropout = nn.Dropout(p=dropout)

        if pos_embedding == "SinusoidalPositionalEncoding":
            # Registered as a (persistent) buffer: saved in the state_dict and
            # moved with .to(device), but never updated by the optimizer.
            self.register_buffer(
                "positional_embedding",
                self._sinusoidal_table(max_position_embeddings, embedding_dim),
            )
        else:
            raise ValueError(f"Positional embedding {pos_embedding} not found")

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.trf_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=num_layers
        )

    @staticmethod
    def _sinusoidal_table(length: int, dim: int) -> torch.Tensor:
        """Return a ``(1, length, dim)`` sinusoidal position table.

        Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``
        with ``w_i = 10000 ** (-2i / dim)``. Works for odd ``dim`` as well.
        """
        position = torch.arange(length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe = torch.zeros(length, dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd ``dim`` there is one fewer cosine column than sine column,
        # so only the first ``dim // 2`` frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: dim // 2])
        return pe.unsqueeze(0)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Encode a batch of token ids.

        Args:
            input_ids: Integer token ids with the sequence along dim 1
                (the encoder is ``batch_first``), e.g. ``(batch, seq_len)``.
            attention_mask: Optional mask of the same shape as ``input_ids``;
                positions equal to ``0`` are treated as padding and ignored as
                attention keys.

        Returns:
            Hidden states of shape ``(batch, seq_len, embedding_dim)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table length
                (``max_position_embeddings``).
        """
        seq_len = input_ids.size(1)
        max_len = self.positional_embedding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds "
                f"max_position_embeddings={max_len}"
            )

        emb = self.emb_dropout(
            self.embedding(input_ids) + self.positional_embedding[:, :seq_len, :]
        )

        # PyTorch's src_key_padding_mask uses True for "ignore", the inverse of
        # the HF-style attention_mask convention (1 = attend, 0 = pad).
        padding_mask = None if attention_mask is None else attention_mask == 0
        return self.trf_encoder(emb, src_key_padding_mask=padding_mask)
|
|
class DNAEncoder(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``Encoder``.

    The encoder hyperparameters are copied one-to-one from the config
    attributes of the same name; the positional-embedding type is left at
    ``Encoder``'s default.
    """

    config_class = DNAEncoderConfig

    def __init__(self, config: DNAEncoderConfig):
        super().__init__(config)
        self.config = config
        # Each name is both a config attribute and an Encoder keyword argument.
        shared_fields = (
            "vocab_size",
            "embedding_dim",
            "dim_feedforward",
            "num_heads",
            "num_layers",
            "dropout",
            "activation",
            "max_position_embeddings",
        )
        encoder_kwargs = {field: getattr(config, field) for field in shared_fields}
        self.encoder = Encoder(**encoder_kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped encoder and return its hidden states.

        Any extra keyword arguments are accepted for API compatibility and
        ignored.
        """
        hidden_states = self.encoder(input_ids, attention_mask=attention_mask)
        return hidden_states