| from collections.abc import Generator, Iterable |
| from dataclasses import dataclass |
| from enum import StrEnum |
| from itertools import chain |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from transformers import ( |
| AutoConfig, |
| AutoModel, |
| AutoTokenizer, |
| ModernBertModel, |
| PreTrainedConfig, |
| PreTrainedModel, |
| ) |
| from transformers.modeling_outputs import TokenClassifierOutput |
|
|
# Number of concept definitions encoded per base-model forward pass in
# DisamBert.init_classifier.
BATCH_SIZE = 64
|
|
|
|
class ModelURI(StrEnum):
    """Hugging Face Hub identifiers of the supported ModernBERT checkpoints."""

    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"
|
|
|
|
@dataclass(slots=True, frozen=True)
class LexicalExample:
    """One ontology entry: a concept identifier plus its textual definition."""

    concept: str  # concept identifier; becomes one entry of DisamBert.entities
    definition: str  # natural-language definition text encoded by init_classifier
|
|
|
|
@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """A right-padded token batch as produced by DisamBert.pad."""

    input_ids: torch.Tensor  # (batch, max_len) token ids, padded with pad_token_id
    attention_mask: torch.Tensor  # (batch, max_len) 1.0 for real tokens, 0.0 for padding
|
|
|
|
class DisamBert(PreTrainedModel):
    """Span-level entity-disambiguation model.

    A ModernBERT encoder plus a linear classifier head whose rows are the
    encoded [CLS] vectors of concept definitions (see ``init_classifier``).
    ``forward`` scores every text span against the ontology, restricted to a
    per-span candidate set.
    """

    def __init__(self, config: PreTrainedConfig):
        # NOTE(review): the canonical transformers export is ``PretrainedConfig``
        # (lower-case "t"); ``PreTrainedConfig`` may not exist in all releases —
        # confirm against the installed transformers version.
        super().__init__(config)
        if config.init_basemodel:
            # Fresh start from a pretrained checkpoint: classifier parameters
            # stay lazy (uninitialised) until init_classifier is called.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
            self.classifier_head = nn.UninitializedParameter()
            self.bias = nn.UninitializedParameter()
            self.__entities = None
        else:
            # Reload path: parameter shapes and the entity vocabulary come
            # from the (previously saved) config.
            self.BaseModel = ModernBertModel(config)
            self.classifier_head = nn.Parameter(
                torch.empty((config.ontology_size, config.hidden_size))
            )
            self.bias = nn.Parameter(torch.empty((config.ontology_size, 1)))
            self.__entities = pd.Series(config.entities)
        # Any config serialised after construction reloads via the branch
        # above that rebuilds parameters from config sizes.
        config.init_basemodel = False
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI) -> "DisamBert":
        """Alternate constructor: wrap a pretrained ModernBERT checkpoint.

        The returned model has lazy (uninitialised) classifier parameters;
        call ``init_classifier`` before ``forward``.
        """
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True  # route __init__ through the from-pretrained branch
        config.tokenizer_path = base_id
        return cls(config)

    def init_classifier(self, entities: Generator[LexicalExample]) -> None:
        """Initialise the classifier head from encoded concept definitions.

        Definitions are tokenized and encoded in batches of ``BATCH_SIZE``;
        the position-0 vector of the last hidden state (the [CLS] position)
        becomes that concept's row of ``classifier_head``. Also records the
        entity ids on the instance and in the config (ontology row order ==
        iteration order of ``entities``).

        NOTE(review): encoding runs without ``torch.no_grad()``; only
        ``detach()`` is applied afterwards, so autograd state is built and
        discarded — consider wrapping in no_grad (memory only, same values).
        """
        entity_ids = []
        vectors = []
        batch = []
        n = 0
        # torch.device as context manager (torch >= 2.0): tensors created
        # inside default to the base model's device.
        with self.BaseModel.device:
            for entity in entities:
                entity_ids.append(entity.concept)
                batch.append(entity.definition)
                n += 1
                if n == BATCH_SIZE:
                    tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
                    encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
                    # Keep only the position-0 ([CLS]) vector of each sequence.
                    vectors.append(encoding.last_hidden_state.detach()[:, 0])
                    n = 0
                    batch = []
            if n > 0:  # flush the final partial batch
                tokens = self.tokenizer(batch, padding=True, return_tensors="pt")
                encoding = self.BaseModel(tokens["input_ids"], tokens["attention_mask"])
                vectors.append(encoding.last_hidden_state.detach()[:, 0])
        self.__entities = pd.Series(entity_ids)
        self.config.entities = entity_ids
        self.config.ontology_size = len(entity_ids)
        self.classifier_head = nn.Parameter(torch.cat(vectors, dim=0))
        # Bias drawn with std scaled to be comparable to a row dot product:
        # weight std times sqrt(hidden_size).
        self.bias = nn.Parameter(
            torch.nn.init.normal_(
                torch.empty((self.config.ontology_size, 1)),
                std=self.classifier_head.std().item() * np.sqrt(self.config.hidden_size)
            )
        )

    @property
    def entities(self) -> pd.Series:
        """Concept identifiers indexed by classifier-head row position.

        Lazily rehydrated from ``config.entities`` when the model was reloaded
        without ``init_classifier`` running in this process.
        """
        if self.__entities is None and hasattr(self.config, "entities"):
            self.__entities = pd.Series(self.config.entities)
        return self.__entities

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        lengths: list[list[int]],
        candidates: list[list[list[int]]],
        labels: Iterable[list[int]] | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score every text span against the ontology.

        Args:
            input_ids: (batch, seq_len) padded token ids.
            attention_mask: (batch, seq_len) matching mask.
            lengths: per example, the token count of each span (``tokenize``'s
                ``lengths`` output).
            candidates: per example, per span, the admissible ontology row
                indices; all other rows are masked to zero.
            labels: gold ontology indices per span. Annotated as an iterable
                of lists, but it is fed straight to CrossEntropyLoss —
                presumably a (batch, max_spans) tensor from ``pad_labels``
                at call time; verify against the training loop.
            output_hidden_states: forwarded to the base model.
            output_attentions: forwarded to the base model.

        Returns:
            TokenClassifierOutput with logits of shape
            (batch, ontology_size, max_spans_in_batch).

        Raises:
            AssertionError: if ``init_classifier`` has not been run
                (classifier parameters are still lazy).
        """
        assert not nn.parameter.is_lazy(self.classifier_head), (
            "Run init_classifier to initialise weights"
        )
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # One vector per span: sum the token vectors inside each span, then
        # stack all spans of the whole batch into one (total_spans, hidden)
        # matrix (padding chunks are dropped by self.split).
        span_vectors = torch.cat(
            [
                torch.vstack(
                    [
                        torch.sum(chunk, dim=0)
                        for chunk in self.split(token_vectors[i], sentence_indices)
                    ]
                )
                for (i, sentence_indices) in enumerate(lengths)
            ]
        )
        # (total_spans, hidden) x (ontology, hidden) -> (ontology, total_spans).
        logits = torch.einsum("ij,kj->ki", span_vectors, self.classifier_head) + self.bias
        # Shift to non-negative so masked-out entries (zeroed below) can never
        # exceed a candidate's score.
        logits1 = logits - logits.min()
        mask = torch.zeros_like(logits)
        # Column i is the i-th span across the flattened batch; keep only the
        # rows listed as candidates for that span.
        for i, concepts in enumerate(chain.from_iterable(candidates)):
            mask[concepts, i] = torch.tensor(1.0)
        logits2 = logits1 * mask
        # Regroup the flat span axis into per-example blocks, right-pad each
        # block with zero columns to the batch maximum, and stack.
        sentence_lengths = [len(sentence_indices) for sentence_indices in lengths]
        maxlen = max(sentence_lengths)
        split_logits = torch.split(logits2, sentence_lengths, dim=1)
        logits3 = torch.stack(
            [
                self.extend_to_max_length(sentence, length, maxlen)
                for (sentence, length) in zip(split_logits, sentence_lengths, strict=True)
            ]
        )
        return TokenClassifierOutput(
            logits=logits3,
            loss=self.loss(logits3, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def split(self, vectors: torch.Tensor, lengths: list[int]) -> tuple[torch.Tensor, ...]:
        """Split one sequence's token vectors into per-span chunks.

        ``vectors`` is (seq_len, hidden) and may carry right padding; when it
        does, the padding is split off as one extra trailing chunk and dropped.
        """
        maxlen = vectors.shape[0]
        total_length = sum(lengths)
        is_padded = total_length < maxlen
        chunks = vectors.split((lengths + [maxlen - total_length]) if is_padded else lengths)
        return chunks[:-1] if is_padded else chunks

    def pad(self, tokens: Iterable[list[int]]) -> PaddedBatch:
        """Right-pad token-id lists to a rectangle and build the attention mask.

        Padding uses ``config.pad_token_id``; the mask is 1.0 over real tokens
        and 0.0 over padding.
        """
        lengths = [len(sentence) for sentence in tokens]
        maxlen = max(lengths)
        input_ids = torch.tensor(
            [
                sentence + [self.config.pad_token_id] * (maxlen - length)
                for (sentence, length) in zip(tokens, lengths, strict=True)
            ]
        )
        attention_mask = torch.vstack(
            [torch.cat((torch.ones(length), torch.zeros(maxlen - length))) for length in lengths]
        )
        return PaddedBatch(input_ids, attention_mask)

    def extend_to_max_length(
        self, sentence: torch.Tensor, length: int, maxlength: int
    ) -> torch.Tensor:
        """Right-pad an (ontology_size, length) logit block with zero columns
        up to ``maxlength`` columns; returned unchanged when already full width.
        """
        # Device context so the zero padding is created on the model's device.
        with self.BaseModel.device:
            return (
                torch.cat(
                    [
                        sentence,
                        torch.zeros((self.config.ontology_size, maxlength - length)),
                    ],
                    dim=1,
                )
                if length < maxlength
                else sentence
            )

    def pad_labels(self, labels: list[list[int]]) -> torch.Tensor:
        """Right-pad label sequences into a (batch, max_spans) tensor.

        Padding uses the LAST ontology index — presumably a dedicated
        unknown/null concept appended to ``config.entities``; confirm, since
        CrossEntropyLoss has no ignore_index here and padded positions do
        contribute to the loss.
        """
        unk = len(self.config.entities) - 1
        lengths = [len(seq) for seq in labels]
        maxlen = max(lengths)
        with self.BaseModel.device:
            return torch.tensor(
                [
                    seq + [unk] * (maxlen - length)
                    for (seq, length) in zip(labels, lengths, strict=True)
                ]
            )

    def tokenize(
        self, batch: list[dict[str, str | list[int]]]
    ) -> dict[str, torch.Tensor | list[list[int]]]:
        """Turn pre-segmented examples into a model-ready batch.

        Each example holds ``text`` plus ``indices`` — character offsets
        delimiting consecutive spans, including the end offset. Spans are
        tokenized one at a time so per-span token counts are known; the
        tokenizer's leading special token is kept only on the first span and
        its trailing special token only on the last, so concatenation yields
        one well-formed sequence.

        Returns a dict with ``input_ids``, ``attention_mask``, ``lengths``
        (per-span token counts per example), ``candidates`` (passed through
        from the examples), and — when the examples carry them — padded
        ``labels``.
        """
        all_indices = []
        all_tokens = []
        with self.BaseModel.device:
            for example in batch:
                text = example["text"]
                span_indices = example["indices"]
                indices = []
                tokens = []
                last_span = len(span_indices) - 2  # index of the final span
                for i, position in enumerate(span_indices[:-1]):
                    span = text[position : span_indices[i + 1]]
                    span_tokens = self.tokenizer([span], padding=False)["input_ids"][0]
                    if i > 0:
                        span_tokens = span_tokens[1:]  # drop the duplicated leading special token
                    if i < last_span:
                        span_tokens = span_tokens[:-1]  # drop the intermediate trailing special token
                    indices.append(len(span_tokens))
                    tokens.extend(span_tokens)
                all_indices.append(indices)
                all_tokens.append(tokens)
            padded = self.pad(all_tokens)
            result = {
                "input_ids": padded.input_ids,
                "attention_mask": padded.attention_mask,
                "lengths": all_indices,
                "candidates": [example["candidates"] for example in batch],
            }
            if "labels" in batch[0]:
                result["labels"] = self.pad_labels([example["labels"] for example in batch])
            return result
|
|