Spaces:
Running
Running
| """ | |
| BERT Embedding Service — computes text embeddings using a pretrained BERT model. | |
| Loaded once at startup, reused for all requests. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from transformers import BertModel, BertTokenizer | |
| from config import BERT_MODEL_NAME, BERT_MAX_LENGTH | |
class BertEmbeddingService:
    """Compute fixed-size text embeddings with a pretrained BERT model.

    The model and tokenizer are not loaded in the constructor; call
    :meth:`load` once (e.g. at application startup) and then reuse the same
    instance for every request.
    """

    def __init__(self):
        # Everything stays unset until load() runs, so constructing the
        # service performs no I/O.
        self.model = None
        self.tokenizer = None
        self.device = None
        self._loaded = False

    def load(self):
        """Load the BERT model and tokenizer. Call once at app startup.

        Picks CUDA when available (CPU otherwise), moves the model to that
        device and switches it to eval mode. Calling this again after a
        successful load is a no-op.
        """
        if self._loaded:
            return

        print(f"[BertService] Loading BERT model: {BERT_MODEL_NAME}...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()
        self._loaded = True
        print(f"[BertService] Model loaded on {self.device}")

    def _require_loaded(self):
        """Raise ``RuntimeError`` if :meth:`load` has not completed yet."""
        if not self._loaded:
            raise RuntimeError("BertService not loaded. Call load() first.")

    def _pool_summary(self, last_hidden_states, pool_op="max"):
        """Pool the BERT output into a single vector per input.

        Args:
            last_hidden_states: 3-D tensor whose dim 1 is pooled away;
                callers pass the model's first output, laid out as
                ``(batch, seq_len, hidden)``.
            pool_op: ``"max"`` selects max pooling; any other value selects
                average pooling.

        Returns:
            Tensor of shape ``(batch, hidden)``.
        """
        # NOTE(review): pooling spans every position along dim 1. Because the
        # callers tokenize with padding="max_length", that includes [PAD]
        # positions. Masking them out would change every embedding produced,
        # so it is deliberately left as-is here — confirm whether that is
        # intended before changing it.
        seq_len = last_hidden_states.size()[1]
        # The 1-D pooling functions reduce the last axis, so move the
        # sequence axis there: (batch, seq, hidden) -> (batch, hidden, seq).
        hidden_p = last_hidden_states.permute(0, 2, 1)
        pool_fn = F.max_pool1d if pool_op == "max" else F.avg_pool1d
        # A kernel as wide as the whole sequence collapses it to length 1.
        return pool_fn(hidden_p, kernel_size=seq_len).squeeze(-1)

    def compute_embedding(self, text: str) -> np.ndarray:
        """Compute a single BERT embedding for the given text.

        Delegates to :meth:`compute_embeddings_batch` with a one-element
        batch so both entry points share one code path.

        Returns:
            1-D numpy array of length ``hidden_size`` (768 for BERT-base
            sized models).

        Raises:
            RuntimeError: if :meth:`load` has not been called.
        """
        self._require_loaded()
        return self.compute_embeddings_batch([text])[0]

    def compute_embeddings_batch(self, texts: list[str]) -> np.ndarray:
        """Compute BERT embeddings for a batch of texts.

        Each text is truncated/padded to ``BERT_MAX_LENGTH`` tokens, run
        through the model without gradient tracking, and max-pooled over the
        sequence axis.

        Returns:
            numpy array of shape ``(N, hidden_size)`` (hidden size is 768 for
            BERT-base sized models). An empty list/tuple yields an array of
            shape ``(0, hidden_size)`` without invoking the model.

        Raises:
            RuntimeError: if :meth:`load` has not been called.
        """
        self._require_loaded()

        # An empty batch cannot be tokenized into a well-formed 2-D tensor;
        # short-circuit with a correctly shaped empty result instead.
        if isinstance(texts, (list, tuple)) and not texts:
            # float32 assumed to match the model's default output dtype.
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        tokens = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=BERT_MAX_LENGTH,
            return_attention_mask=True,
            return_tensors="pt",
        )

        # Move exactly the tensors the model consumes onto its device.
        inputs = {
            key: tokens[key].to(self.device)
            for key in ("input_ids", "attention_mask", "token_type_ids")
        }

        with torch.no_grad():
            output = self.model(**inputs)

        # output[0] holds the per-token hidden states of the last layer.
        embeddings = self._pool_summary(output[0])
        return embeddings.detach().cpu().numpy()
# Module-level shared instance of the service (global singleton).
# Construction performs no loading; load() must still be called before use.
bert_service = BertEmbeddingService()