from __future__ import annotations


import warnings


import torch
import torch.nn.functional as F


from transformers import LlamaConfig, LlamaModel, PreTrainedTokenizer
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class DramaModel(LlamaModel):
    """
    DramaModel is a modified version of the LlamaModel that supports bi-directional
    attention and provides query and document encoding functionalities.
    """


    def __init__(self, config: LlamaConfig):
        """
        Initializes the DramaModel by disabling causal masking in self-attention layers.
        """
        super().__init__(config)
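        # Disable the causal mask in every self-attention layer so tokens
        # attend bidirectionally.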
        for layer in self.layers:
            layer.self_attn.is_causal = False

        self.query_prefix = "Query: "
        self.max_seq_len = 8192
        self.hidden_size = config.hidden_size


    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens=None,
        output_attentions=False,
    ):
        """
        Builds the attention mask without a causal component, so attention stays
        bidirectional; only padding positions are masked.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask

        # Expand the 2D padding mask to a 4D additive mask with no causal component.
        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )


    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the average pooled representation of the last hidden states.
        """
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: int | None = None,
        use_nested: bool = False,
    ):
        """
        Tokenizes input texts, truncating to max_seq_len (defaults to
        self.max_seq_len) and padding to the longest sequence in the batch.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
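        # use_nested is accepted for signature compatibility only; standard
        # padded tensors are always returned here.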
        tokenized = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_seq_len,
            return_tensors="pt",
        ).to(self.device)
        return tokenized


    def encode(self, input_ids, attention_mask, dim=None, *args, **kwargs):
        """
        Pass through the model and compute normalized embeddings.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int, optional): If given, keep only the first dim hidden
                dimensions of the output; otherwise the full hidden size is used.

        Returns:
            torch.Tensor: L2-normalized output embeddings.
        """
        outputs = self.forward(input_ids, attention_mask, *args, **kwargs)
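        # Slicing the hidden dimension lets callers request truncated embeddings;
        # dim=None keeps the full hidden size.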
        embeddings = self._average_pool(
            outputs.last_hidden_state[:, :, :dim], attention_mask
        )

        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings


    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
        use_nested: bool = False,
    ):
        """
        Encodes a list of queries into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            queries (list[str]): List of query texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.
            use_nested (bool, optional): Unsupported; a warning is issued if True.

        Returns:
            torch.Tensor: Encoded query embeddings of shape (num_queries, dim),
            or (num_queries, hidden_size) if dim is None.
        """
        if not queries:
            raise ValueError("queries must not be empty.")
        if not isinstance(queries, list) or not all(
            isinstance(q, str) for q in queries
        ):
            raise ValueError("queries must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
        if use_nested:
            warnings.warn(
                "use_nested is not supported due to package import versions.",
                UserWarning,
            )
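        # Queries get the "Query: " prefix; documents are encoded without one.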
        queries = [self.query_prefix + query for query in queries]
        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
        embeddings = self.encode(**tokenized_queries, dim=dim)
        return embeddings


    def encode_documents(
        self,
        tokenizer: PreTrainedTokenizer,
        documents: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
        use_nested: bool = False,
    ):
        """
        Encodes a list of documents into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            documents (list[str]): List of document texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.
            use_nested (bool, optional): Unsupported; a warning is issued if True.

        Returns:
            torch.Tensor: Encoded document embeddings of shape (num_documents, dim),
            or (num_documents, hidden_size) if dim is None.
        """
        if not documents:
            raise ValueError("documents must not be empty.")
        if not isinstance(documents, list) or not all(
            isinstance(d, str) for d in documents
        ):
            raise ValueError("documents must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
        if use_nested:
            warnings.warn(
                "use_nested is not supported due to package import versions.",
                UserWarning,
            )
        tokenized_documents = self._tokenize(
            tokenizer, documents, max_seq_len, use_nested
        )
        embeddings = self.encode(**tokenized_documents, dim=dim)
        return embeddings
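

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model definition. The checkpoint id
    # below is an assumption; substitute the DRAMA checkpoint you actually use.
    from transformers import AutoTokenizer

    model_name = "facebook/drama-base"  # assumed checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = DramaModel.from_pretrained(model_name).eval()

    queries = ["What is bidirectional attention?"]
    documents = ["Bidirectional attention lets every token attend to all other tokens."]

    with torch.no_grad():
        # dim=256 assumes the model's hidden size is at least 256.
        query_emb = model.encode_queries(tokenizer, queries, dim=256)
        doc_emb = model.encode_documents(tokenizer, documents, dim=256)

    # Embeddings are L2-normalized, so cosine similarity reduces to a dot product.
    print(query_emb @ doc_emb.T)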