Spaces:
Runtime error
Runtime error
| # models/embedder.py | |
| from typing import Any, Dict, List | |
| import torch | |
| from FlagEmbedding import BGEM3FlagModel | |
| from pydantic import BaseModel | |
| from core.exceptions import ModelLoadError | |
| from core.logger import setup_logger | |
| logger = setup_logger("embedder") | |
| # Data structure for return (Type Hinting) | |
| class EmbedderResult(BaseModel): | |
| dense_vector: List[float] | |
| sparse_indices: List[int] | |
| sparse_values: List[float] | |
| class TextEmbedder: | |
| """ | |
| Converts the input text into Dense Vectors and Sparse Vectors (Lexical Weights) using the BGE-M3 model. | |
| """ | |
| def __init__(self, model_name: str = "BAAI/bge-m3", use_fp16: bool = False): | |
| self.model_name = model_name | |
| self.device = self._get_device() | |
| try: | |
| logger.info(f"⏳ Loading Embedder Model: {self.model_name} on {self.device}") | |
| self.model = BGEM3FlagModel( | |
| self.model_name, | |
| use_fp16=(use_fp16 and self.device.startswith("cuda")) | |
| ) | |
| self._warmup() | |
| logger.info("✅ Embedder Model loaded successfully.") | |
| except Exception as e: | |
| logger.critical(f"❌ Failed to load Embedder Model: {e}", exc_info=True) | |
| raise ModelLoadError(f"Embedder initialization failed: {e}") | |
| def _get_device(self) -> str: | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| elif torch.backends.mps.is_available(): | |
| return "mps" # Apple Silicon | |
| return "cpu" | |
| def _warmup(self): | |
| logger.info("Warming up embedder model with a dummy input.") | |
| self.encode_query("Hello world") | |
| def encode_query(self, text: str) -> EmbedderResult: | |
| """ | |
| Converts a single query text into Qdrant hybrid search format. | |
| """ | |
| try: | |
| # 1. model inference to get dense vector and sparse lexical weights | |
| output = self.model.encode( | |
| text, | |
| return_dense=True, | |
| return_sparse=True, | |
| return_colbert_vecs=False | |
| ) | |
| dense_vec = output['dense_vecs'].tolist() | |
| lexical_weights: Dict[str, float] = output['lexical_weights'] | |
| # 2. Sparse Vector Transformation (Qdrant specifications: token_id array, weight array) | |
| sparse_indices = [] | |
| sparse_values = [] | |
| # Convert text tokens into unique IDs (integers) using the BGE-M3 tokenizer | |
| for token_str, weight in lexical_weights.items(): | |
| # Get the ID of the string token through the tokenizer (vocab index) | |
| token_id = self.model.tokenizer.convert_tokens_to_ids(token_str) | |
| if token_id is not None: | |
| sparse_indices.append(token_id) | |
| sparse_values.append(float(weight)) | |
| return EmbedderResult( | |
| dense_vector=dense_vec, | |
| sparse_indices=sparse_indices, | |
| sparse_values=sparse_values | |
| ) | |
| except Exception as e: | |
| logger.error(f"Failed to encode query '{text}': {e}") | |
| raise RuntimeError(f"Embedding generation failed: {e}") | |
| # options for batch encoding if needed in the future | |
| def encode_documents(self, texts: List[str], batch_size: int = 12) -> Dict[str, Any]: | |
| return self.model.encode( | |
| texts, | |
| batch_size=batch_size, | |
| max_length=8192, # BGE-M3's max token length | |
| return_dense=True, | |
| return_sparse=True, | |
| return_colbert_vecs=False | |
| ) |