Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| from typing import Callable, List, Optional | |
| import numpy as np | |
| from openai import OpenAI | |
| from sentence_transformers import SentenceTransformer | |
class EmbeddingGenerator:
    """Turn text into dense ``float32`` embedding vectors.

    Three back-ends are supported, selected once at construction time:

    * ``"openai"``    -- OpenAI embeddings API (key read from ``OPENAI_API_KEY``).
    * ``"vertex_ai"`` -- Vertex AI through the ``google-genai`` package
      (``GOOGLE_CLOUD_PROJECT`` required, ``GOOGLE_CLOUD_LOCATION`` optional).
    * anything else   -- a local ``SentenceTransformer`` model.

    ``"auto"`` (the default when neither the ``provider`` argument nor
    ``EMBEDDING_PROVIDER`` is given) resolves to the local back-end on
    Hugging Face Spaces and in test runs, and to Vertex AI otherwise.
    ``"vertex"``, ``"vertexai"`` and ``"vertex-ai"`` are accepted as
    spellings of ``"vertex_ai"``.
    """

    # The OpenAI embeddings endpoint rejects ``input`` arrays with more than
    # 2048 entries, so larger jobs are split into requests of at most this size.
    OPENAI_MAX_INPUTS_PER_REQUEST = 2048

    # Alternative provider spellings normalised by ``_resolve_provider``.
    # Without this, e.g. "vertex" silently fell through to the local model.
    _PROVIDER_ALIASES = {"vertex": "vertex_ai", "vertexai": "vertex_ai"}

    def __init__(self, provider: Optional[str] = None, model_name: Optional[str] = None):
        """Resolve configuration from arguments/environment and set up the back-end.

        Args:
            provider: Back-end name. Falls back to ``EMBEDDING_PROVIDER``,
                then ``"auto"``.
            model_name: Model identifier. Falls back to the provider-specific
                default chosen by ``_resolve_model_name``.

        Raises:
            RuntimeError: For Vertex AI, when ``google-genai`` is not
                installed or ``GOOGLE_CLOUD_PROJECT`` is unset.
        """
        configured_provider = (provider or os.getenv("EMBEDDING_PROVIDER", "auto")).lower()
        self.provider = self._resolve_provider(configured_provider)
        self.model_name = model_name or self._resolve_model_name()
        self.batch_size = int(os.getenv("EMBEDDING_BATCH_SIZE", "8"))
        self.device = os.getenv("EMBEDDING_DEVICE")
        self.client = None
        self.model = None
        self.vertex_task_type_document = os.getenv(
            "VERTEX_EMBEDDING_TASK_TYPE_DOCUMENT", "RETRIEVAL_DOCUMENT"
        )
        self.vertex_task_type_query = os.getenv(
            "VERTEX_EMBEDDING_TASK_TYPE_QUERY", "RETRIEVAL_QUERY"
        )
        self.vertex_output_dimensionality = self._optional_int(
            os.getenv("VERTEX_EMBEDDING_OUTPUT_DIMENSIONALITY")
        )
        self.query_prefix = os.getenv("EMBEDDING_QUERY_PREFIX", "").strip()
        normalized_model_name = self.model_name.lower()
        # Only these code-search models get a SentenceTransformer prompt name
        # for queries. Presumably their model config defines a "query" prompt
        # -- NOTE(review): not verifiable from this file.
        self.query_prompt_name = (
            os.getenv("EMBEDDING_QUERY_PROMPT_NAME", "query")
            if "nomic-embed-code" in normalized_model_name
            or "coderankembed" in normalized_model_name
            else None
        )
        if self.provider == "openai":
            self._init_openai()
        elif self.provider == "vertex_ai":
            self._init_vertex()
        else:
            self._init_local()

    def _init_openai(self) -> None:
        """Create the OpenAI client and read the configured vector size."""
        print(
            f"[embeddings] Initializing OpenAI embeddings with model={self.model_name}",
            flush=True,
        )
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.embedding_dim = int(os.getenv("OPENAI_EMBEDDING_DIM", "1536"))

    def _init_vertex(self) -> None:
        """Create a Vertex-AI-backed ``google-genai`` client.

        Raises:
            RuntimeError: If ``google-genai`` is missing or
                ``GOOGLE_CLOUD_PROJECT`` is unset.
        """
        print(
            f"[embeddings] Initializing Vertex AI embeddings with model={self.model_name}",
            flush=True,
        )
        # Imported lazily so the package is only required when Vertex is used.
        try:
            from google import genai
        except ImportError as exc:
            raise RuntimeError(
                "Vertex AI embedding support requires the `google-genai` package."
            ) from exc
        project = os.getenv("GOOGLE_CLOUD_PROJECT")
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        if not project:
            raise RuntimeError(
                "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI embeddings."
            )
        self.client = genai.Client(
            vertexai=True,
            project=project,
            location=location,
        )
        # An explicit VERTEX_EMBEDDING_DIM wins; otherwise use the requested
        # output dimensionality, else 3072.
        self.embedding_dim = int(
            os.getenv(
                "VERTEX_EMBEDDING_DIM",
                str(self.vertex_output_dimensionality or 3072),
            )
        )

    def _init_local(self) -> None:
        """Load the local SentenceTransformer model and record its dimension."""
        model_device = self.device or "cpu"
        print(
            f"[embeddings] Loading local embedding model={self.model_name} on device={model_device}",
            flush=True,
        )
        started_at = time.perf_counter()
        self.model = SentenceTransformer(
            self.model_name,
            trust_remote_code=True,
            device=model_device,
        )
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        elapsed = time.perf_counter() - started_at
        print(
            f"[embeddings] Model ready dim={self.embedding_dim} load_time={elapsed:.2f}s",
            flush=True,
        )

    def embed_text(self, text: str) -> np.ndarray:
        """Embed a single *query* string and return a 1-D ``float32`` vector.

        Vertex AI uses the query task type. The local back-end applies
        ``EMBEDDING_QUERY_PREFIX`` and the query prompt name when configured.
        """
        if self.provider == "openai":
            return self._embed_with_openai([text])[0]
        if self.provider == "vertex_ai":
            return self._embed_with_vertex(
                [text],
                task_type=self.vertex_task_type_query,
            )[0]
        query_text = f"{self.query_prefix}: {text}" if self.query_prefix else text
        return self._encode_with_backoff([query_text], prompt_name=self.query_prompt_name)[0]

    def embed_batch(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Embed many *document* strings; return an ``(len(texts), dim)`` array.

        Args:
            texts: Strings to embed. An empty list yields a ``(0, dim)`` array.
            batch_size: Items per request/encode call. Defaults to
                ``self.batch_size`` for Vertex/local. For OpenAI it defaults
                to the API's per-request maximum and is capped by it.
            progress_callback: Called as ``callback(done, total)`` after each batch.
        """
        if not texts:
            # Keep the result 2-D so callers can stack/index it like a
            # non-empty result.
            return np.empty((0, self.embedding_dim or 0), dtype="float32")
        if self.provider == "openai":
            return self._embed_batch_with_openai(
                texts=texts,
                batch_size=batch_size,
                progress_callback=progress_callback,
            )
        if self.provider == "vertex_ai":
            return self._embed_batch_with_vertex(
                texts=texts,
                batch_size=batch_size,
                progress_callback=progress_callback,
            )
        return self._run_in_batches(
            texts,
            batch_size or self.batch_size,
            # A slice is never longer than the effective batch size, so
            # len(batch) equals min(effective_batch_size, len(batch)).
            lambda batch: self._encode_with_backoff(batch, batch_size=len(batch)),
            "Encoding batch",
            "Finished batch",
            progress_callback,
        )

    def _run_in_batches(
        self,
        texts: List[str],
        batch_size: int,
        encode: Callable[[List[str]], np.ndarray],
        start_label: str,
        finish_label: str,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Split *texts* into slices, embed each with *encode*, log and report progress.

        Args:
            texts: Non-empty list of strings.
            batch_size: Slice size; values below 1 are treated as 1.
            encode: Maps one slice of texts to a 2-D array of embeddings.
            start_label: Log label printed before each slice.
            finish_label: Log label printed after each slice.
            progress_callback: Optional ``callback(done, total)``.

        Returns:
            All slice results stacked vertically as ``float32``.
        """
        effective_batch_size = max(1, batch_size)
        total = len(texts)
        total_batches = (total + effective_batch_size - 1) // effective_batch_size
        all_embeddings = []
        for start in range(0, total, effective_batch_size):
            batch = texts[start : start + effective_batch_size]
            batch_number = (start // effective_batch_size) + 1
            done = min(start + len(batch), total)
            print(
                f"[embeddings] {start_label} {batch_number}/{total_batches} "
                f"items={len(batch)} progress={start}/{total}",
                flush=True,
            )
            started_at = time.perf_counter()
            all_embeddings.append(encode(batch))
            elapsed = time.perf_counter() - started_at
            print(
                f"[embeddings] {finish_label} {batch_number}/{total_batches} "
                f"elapsed={elapsed:.2f}s progress={done}/{total}",
                flush=True,
            )
            if progress_callback:
                progress_callback(done, total)
        return np.vstack(all_embeddings).astype("float32")

    def _embed_batch_with_openai(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Embed *texts* with OpenAI in requests of at most 2048 inputs each.

        ``self.batch_size`` is deliberately not used as the default here: it
        is tuned for local models, and many tiny HTTP requests would be slow.
        """
        limit = self.OPENAI_MAX_INPUTS_PER_REQUEST
        chunk_size = min(batch_size or limit, limit)
        return self._run_in_batches(
            texts,
            chunk_size,
            self._embed_with_openai,
            "OpenAI batch",
            "Finished OpenAI batch",
            progress_callback,
        )

    def _embed_with_openai(self, texts: List[str]) -> np.ndarray:
        """Make one OpenAI embeddings request; return one row per input text.

        Raises:
            RuntimeError: If the response does not contain one vector per input.
        """
        response = self.client.embeddings.create(
            model=self.model_name or "text-embedding-3-small",
            input=texts,
        )
        # Each returned item carries its input position in ``index``; sort on
        # it so rows are guaranteed to line up with ``texts``.
        data = sorted(response.data, key=lambda item: item.index)
        if len(data) != len(texts):
            raise RuntimeError(
                f"OpenAI embeddings returned {len(data)} vectors for {len(texts)} inputs."
            )
        return np.array([item.embedding for item in data], dtype="float32")

    def _embed_batch_with_vertex(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Embed *texts* with Vertex AI using the document task type, in batches."""
        return self._run_in_batches(
            texts,
            batch_size or self.batch_size,
            lambda batch: self._embed_with_vertex(
                batch,
                task_type=self.vertex_task_type_document,
            ),
            "Vertex batch",
            "Finished Vertex batch",
            progress_callback,
        )

    def _embed_with_vertex(self, texts: List[str], task_type: str) -> np.ndarray:
        """Make one Vertex AI ``embed_content`` call; return one row per input text.

        Raises:
            RuntimeError: If the response is empty, unparseable, or has a
                different number of vectors than inputs.
        """
        config = {
            "task_type": task_type,
        }
        if self.vertex_output_dimensionality:
            config["output_dimensionality"] = self.vertex_output_dimensionality
        response = self.client.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=config,
        )
        embeddings = getattr(response, "embeddings", None)
        if not embeddings:
            raise RuntimeError("Vertex AI embeddings returned an empty response.")
        values = [self._vertex_vector(item) for item in embeddings]
        if any(vector is None for vector in values):
            raise RuntimeError("Vertex AI embeddings response could not be parsed.")
        if len(values) != len(texts):
            raise RuntimeError(
                f"Vertex AI embeddings returned {len(values)} vectors for {len(texts)} inputs."
            )
        return np.array(values, dtype="float32")

    @staticmethod
    def _vertex_vector(item) -> Optional[list]:
        """Extract the numeric vector from one Vertex embedding item, or ``None``."""
        # The dict check must come first: every dict has a ``.values`` *method*,
        # so a ``hasattr(item, "values")`` test would return that bound method.
        if isinstance(item, dict):
            return item.get("values")
        if hasattr(item, "values"):
            return item.values
        return getattr(item, "embedding", None)

    def _encode_with_backoff(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        prompt_name: Optional[str] = None,
    ) -> np.ndarray:
        """Encode *texts* locally, halving the batch size on memory errors.

        Embeddings are L2-normalised. A ``RuntimeError`` whose message
        mentions "out of memory" or "mps" triggers a retry at half the batch
        size. Any other error, or one raised at batch size 1, is re-raised.
        """
        effective_batch_size = max(1, batch_size or self.batch_size)
        while True:
            try:
                encode_kwargs = {
                    "sentences": texts,
                    "batch_size": effective_batch_size,
                    "show_progress_bar": len(texts) > effective_batch_size,
                    "convert_to_numpy": True,
                    "normalize_embeddings": True,
                }
                if prompt_name:
                    encode_kwargs["prompt_name"] = prompt_name
                embeddings = self.model.encode(
                    **encode_kwargs,
                )
                return embeddings.astype("float32")
            except RuntimeError as exc:
                message = str(exc).lower()
                # Broad on purpose: some MPS memory failures do not say
                # "out of memory". Non-memory MPS errors eventually re-raise
                # once the batch size reaches 1.
                is_memory_error = "out of memory" in message or "mps" in message
                if not is_memory_error or effective_batch_size == 1:
                    raise
                print(
                    f"[embeddings] Retrying batch with smaller size due to memory pressure: "
                    f"{effective_batch_size} -> {max(1, effective_batch_size // 2)}",
                    flush=True,
                )
                effective_batch_size = max(1, effective_batch_size // 2)

    def get_embedding_dim(self) -> int:
        """Return the dimensionality of the vectors this generator produces."""
        return self.embedding_dim

    def _resolve_provider(self, configured_provider: str) -> str:
        """Normalise the configured provider and resolve ``"auto"``.

        ``"auto"`` becomes ``"local"`` on Hugging Face Spaces or in tests, and
        ``"vertex_ai"`` otherwise.
        """
        normalized = configured_provider.strip().lower().replace("-", "_")
        normalized = self._PROVIDER_ALIASES.get(normalized, normalized)
        if normalized != "auto":
            return normalized
        if self._is_hf_space() or self._is_test_context():
            return "local"
        return "vertex_ai"

    def _resolve_model_name(self) -> str:
        """Pick the model name: ``EMBEDDING_MODEL`` first, then a provider default."""
        explicit_model = os.getenv("EMBEDDING_MODEL")
        if explicit_model:
            return explicit_model
        if self.provider == "openai":
            # Previously OpenAI fell through to a local Hugging Face model
            # name, which the OpenAI API cannot serve.
            return os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
        if self.provider == "vertex_ai":
            return os.getenv("VERTEX_EMBEDDING_MODEL", "gemini-embedding-001")
        if self._is_hf_space() or self._is_test_context():
            return os.getenv(
                "LIGHTWEIGHT_LOCAL_EMBEDDING_MODEL",
                "sentence-transformers/all-MiniLM-L6-v2",
            )
        return os.getenv("LOCAL_EMBEDDING_MODEL", "nomic-ai/CodeRankEmbed")

    def _is_hf_space(self) -> bool:
        """True when ``SPACE_ID`` or ``HF_SPACE_ID`` is set (Hugging Face Space)."""
        return bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))

    def _is_test_context(self) -> bool:
        """True when ``APP_ENV``/``ENVIRONMENT`` is "test" or pytest is running a test."""
        app_env = os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "")).lower()
        return app_env == "test" or bool(os.getenv("PYTEST_CURRENT_TEST"))

    def _optional_int(self, value: Optional[str]) -> Optional[int]:
        """Parse *value* as ``int``; ``None`` or blank strings yield ``None``."""
        if value is None or not str(value).strip():
            return None
        return int(value)