# code-compass / src/embeddings.py
# technophyle: Sync from GitHub via hub-sync (60b97da, verified)
import os
import time
from typing import Callable, List, Optional
import numpy as np
from openai import OpenAI
from sentence_transformers import SentenceTransformer
class EmbeddingGenerator:
    """Generate float32 text embeddings via OpenAI, Vertex AI, or a local model.

    The backend is chosen from the ``provider`` argument or the
    ``EMBEDDING_PROVIDER`` environment variable (default ``"auto"``):

    * ``"openai"``    -> OpenAI embeddings API (``OPENAI_API_KEY``).
    * ``"vertex_ai"`` -> Vertex AI through the ``google-genai`` client
      (``GOOGLE_CLOUD_PROJECT`` required, ``GOOGLE_CLOUD_LOCATION`` optional).
    * anything else   -> a local ``SentenceTransformer`` model.

    ``"auto"`` resolves to ``"local"`` inside a Hugging Face Space or a test
    run and to ``"vertex_ai"`` otherwise.

    Other environment variables read by this class: ``EMBEDDING_MODEL``,
    ``OPENAI_EMBEDDING_MODEL``, ``VERTEX_EMBEDDING_MODEL``,
    ``LOCAL_EMBEDDING_MODEL``, ``LIGHTWEIGHT_LOCAL_EMBEDDING_MODEL``,
    ``EMBEDDING_BATCH_SIZE``, ``EMBEDDING_DEVICE``, ``EMBEDDING_QUERY_PREFIX``,
    ``EMBEDDING_QUERY_PROMPT_NAME``, ``OPENAI_EMBEDDING_DIM``,
    ``VERTEX_EMBEDDING_DIM``, ``VERTEX_EMBEDDING_OUTPUT_DIMENSIONALITY``,
    ``VERTEX_EMBEDDING_TASK_TYPE_DOCUMENT`` and
    ``VERTEX_EMBEDDING_TASK_TYPE_QUERY``.
    """

    # The OpenAI embeddings endpoint documents a maximum of 2048 entries in
    # the ``input`` array, so larger workloads are split across requests.
    _OPENAI_MAX_INPUTS_PER_REQUEST = 2048

    def __init__(
        self,
        provider: Optional[str] = None,
        model_name: Optional[str] = None,
    ):
        """Resolve the configuration and initialise the selected backend.

        Args:
            provider: Backend name; falls back to ``EMBEDDING_PROVIDER`` and
                then ``"auto"``. Compared case-insensitively.
            model_name: Model identifier; when omitted it is resolved by
                ``_resolve_model_name``.

        Raises:
            RuntimeError: For Vertex AI when ``google-genai`` is not
                installed or ``GOOGLE_CLOUD_PROJECT`` is unset.
        """
        configured_provider = (provider or os.getenv("EMBEDDING_PROVIDER", "auto")).lower()
        self.provider = self._resolve_provider(configured_provider)
        self.model_name = model_name or self._resolve_model_name()
        self.batch_size = int(os.getenv("EMBEDDING_BATCH_SIZE", "8"))
        self.device = os.getenv("EMBEDDING_DEVICE")
        self.client = None
        self.model = None
        self.vertex_task_type_document = os.getenv(
            "VERTEX_EMBEDDING_TASK_TYPE_DOCUMENT", "RETRIEVAL_DOCUMENT"
        )
        self.vertex_task_type_query = os.getenv(
            "VERTEX_EMBEDDING_TASK_TYPE_QUERY", "RETRIEVAL_QUERY"
        )
        self.vertex_output_dimensionality = self._optional_int(
            os.getenv("VERTEX_EMBEDDING_OUTPUT_DIMENSIONALITY")
        )
        self.query_prefix = os.getenv("EMBEDDING_QUERY_PREFIX", "").strip()

        # A query prompt name is only configured for the two model families
        # matched here; every other model encodes queries without one.
        normalized_model_name = self.model_name.lower()
        self.query_prompt_name = (
            os.getenv("EMBEDDING_QUERY_PROMPT_NAME", "query")
            if "nomic-embed-code" in normalized_model_name
            or "coderankembed" in normalized_model_name
            else None
        )

        if self.provider == "openai":
            print(
                f"[embeddings] Initializing OpenAI embeddings with model={self.model_name}",
                flush=True,
            )
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            self.embedding_dim = int(os.getenv("OPENAI_EMBEDDING_DIM", "1536"))
        elif self.provider == "vertex_ai":
            print(
                f"[embeddings] Initializing Vertex AI embeddings with model={self.model_name}",
                flush=True,
            )
            # Imported lazily so the dependency is only needed for this backend.
            try:
                from google import genai
            except ImportError as exc:
                raise RuntimeError(
                    "Vertex AI embedding support requires the `google-genai` package."
                ) from exc
            project = os.getenv("GOOGLE_CLOUD_PROJECT")
            location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
            if not project:
                raise RuntimeError(
                    "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI embeddings."
                )
            self.client = genai.Client(
                vertexai=True,
                project=project,
                location=location,
            )
            # An explicit VERTEX_EMBEDDING_DIM wins; otherwise use the
            # requested output dimensionality, falling back to 3072.
            self.embedding_dim = int(
                os.getenv(
                    "VERTEX_EMBEDDING_DIM",
                    str(self.vertex_output_dimensionality or 3072),
                )
            )
        else:
            model_device = self.device or "cpu"
            print(
                f"[embeddings] Loading local embedding model={self.model_name} on device={model_device}",
                flush=True,
            )
            started_at = time.perf_counter()
            self.model = SentenceTransformer(
                self.model_name,
                trust_remote_code=True,
                device=model_device,
            )
            self.embedding_dim = self.model.get_sentence_embedding_dimension()
            elapsed = time.perf_counter() - started_at
            print(
                f"[embeddings] Model ready dim={self.embedding_dim} load_time={elapsed:.2f}s",
                flush=True,
            )

    def embed_text(self, text: str) -> np.ndarray:
        """Embed a single text and return its vector.

        OpenAI goes through ``embed_batch``; Vertex AI uses the query task
        type; the local model applies the optional query prefix and query
        prompt name.
        """
        if self.provider == "openai":
            return self.embed_batch([text])[0]
        if self.provider == "vertex_ai":
            return self._embed_with_vertex(
                [text],
                task_type=self.vertex_task_type_query,
            )[0]
        query_text = f"{self.query_prefix}: {text}" if self.query_prefix else text
        return self._encode_with_backoff([query_text], prompt_name=self.query_prompt_name)[0]

    def embed_batch(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Embed many texts and return one float32 row per text.

        Args:
            texts: Texts to embed.
            batch_size: Items per batch for the Vertex AI and local backends;
                defaults to ``self.batch_size``. Not used by the OpenAI
                backend, which only splits at the API's per-request cap.
            progress_callback: Called as ``callback(done, total)`` after each
                completed batch/request.

        Returns:
            Array of shape ``(len(texts), dim)``; for empty input an empty
            array of shape ``(0, self.embedding_dim)`` so that callers can
            stack or index it like any other result.
        """
        if not texts:
            return np.empty((0, self.embedding_dim), dtype="float32")
        if self.provider == "openai":
            return self._embed_batch_with_openai(texts, progress_callback)
        if self.provider == "vertex_ai":
            return self._embed_batch_with_vertex(
                texts=texts,
                batch_size=batch_size,
                progress_callback=progress_callback,
            )
        return self._embed_in_batches(
            texts,
            batch_size,
            progress_callback,
            # Each slice is encoded in a single pass; the backoff helper
            # shrinks the pass size on its own if memory runs out.
            encode=lambda batch: self._encode_with_backoff(batch, batch_size=len(batch)),
            start_label="Encoding batch",
            finish_label="Finished batch",
        )

    def _embed_batch_with_openai(
        self,
        texts: List[str],
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Embed ``texts`` with OpenAI, splitting at the per-request input cap.

        Inputs at or below the cap result in exactly one request and one
        progress callback.
        """
        total = len(texts)
        limit = self._OPENAI_MAX_INPUTS_PER_REQUEST
        vectors = []
        for start in range(0, total, limit):
            chunk = texts[start : start + limit]
            response = self.client.embeddings.create(
                model=self.model_name or "text-embedding-3-small",
                input=chunk,
            )
            vectors.extend(item.embedding for item in response.data)
            if progress_callback:
                progress_callback(min(start + len(chunk), total), total)
        return np.array(vectors, dtype="float32")

    def _embed_batch_with_vertex(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Embed ``texts`` with Vertex AI in batches using the document task type."""
        return self._embed_in_batches(
            texts,
            batch_size,
            progress_callback,
            encode=lambda batch: self._embed_with_vertex(
                batch,
                task_type=self.vertex_task_type_document,
            ),
            start_label="Vertex batch",
            finish_label="Finished Vertex batch",
        )

    def _embed_in_batches(
        self,
        texts: List[str],
        batch_size: Optional[int],
        progress_callback: Optional[Callable[[int, int], None]],
        encode: Callable[[List[str]], np.ndarray],
        start_label: str,
        finish_label: str,
    ) -> np.ndarray:
        """Run ``encode`` over consecutive slices of ``texts`` and stack the results.

        Logs a start and a finish line per batch (prefixed with the given
        labels) and reports progress after every batch.
        """
        effective_batch_size = max(1, batch_size or self.batch_size)
        total = len(texts)
        total_batches = (total + effective_batch_size - 1) // effective_batch_size
        all_embeddings = []
        batch_starts = range(0, total, effective_batch_size)
        for batch_number, start in enumerate(batch_starts, start=1):
            batch = texts[start : start + effective_batch_size]
            print(
                f"[embeddings] {start_label} {batch_number}/{total_batches} "
                f"items={len(batch)} progress={start}/{total}",
                flush=True,
            )
            started_at = time.perf_counter()
            all_embeddings.append(encode(batch))
            elapsed = time.perf_counter() - started_at
            completed = min(start + len(batch), total)
            print(
                f"[embeddings] {finish_label} {batch_number}/{total_batches} "
                f"elapsed={elapsed:.2f}s progress={completed}/{total}",
                flush=True,
            )
            if progress_callback:
                progress_callback(completed, total)
        return np.vstack(all_embeddings).astype("float32")

    def _embed_with_vertex(self, texts: List[str], task_type: str) -> np.ndarray:
        """Send one ``embed_content`` request and return the vectors as float32.

        Raises:
            RuntimeError: If the response has no embeddings or an item
                carries no vector in any of the recognised places.
        """
        config = {
            "task_type": task_type,
        }
        if self.vertex_output_dimensionality:
            config["output_dimensionality"] = self.vertex_output_dimensionality
        response = self.client.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=config,
        )
        embeddings = getattr(response, "embeddings", None)
        if not embeddings:
            raise RuntimeError("Vertex AI embeddings returned an empty response.")
        # Accept the vector as a `.values` attribute, a "values" dict key, or
        # an `.embedding` attribute, whichever the item exposes.
        values = []
        for item in embeddings:
            if hasattr(item, "values"):
                values.append(item.values)
            elif isinstance(item, dict):
                values.append(item.get("values"))
            else:
                values.append(getattr(item, "embedding", None))
        if not values or any(vector is None for vector in values):
            raise RuntimeError("Vertex AI embeddings response could not be parsed.")
        return np.array(values, dtype="float32")

    def _encode_with_backoff(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        prompt_name: Optional[str] = None,
    ) -> np.ndarray:
        """Encode with the local model, halving the batch size on memory errors.

        A ``RuntimeError`` whose message contains "out of memory" or "mps" is
        treated as memory pressure and retried with half the batch size. Any
        other ``RuntimeError``, or a failure at batch size 1, is re-raised.
        """
        effective_batch_size = max(1, batch_size or self.batch_size)
        while True:
            try:
                encode_kwargs = {
                    "sentences": texts,
                    "batch_size": effective_batch_size,
                    "show_progress_bar": len(texts) > effective_batch_size,
                    "convert_to_numpy": True,
                    "normalize_embeddings": True,
                }
                if prompt_name:
                    encode_kwargs["prompt_name"] = prompt_name
                embeddings = self.model.encode(
                    **encode_kwargs,
                )
                return embeddings.astype("float32")
            except RuntimeError as exc:
                message = str(exc).lower()
                is_memory_error = "out of memory" in message or "mps" in message
                if not is_memory_error or effective_batch_size == 1:
                    raise
                reduced_batch_size = max(1, effective_batch_size // 2)
                print(
                    f"[embeddings] Retrying batch with smaller size due to memory pressure: "
                    f"{effective_batch_size} -> {reduced_batch_size}",
                    flush=True,
                )
                effective_batch_size = reduced_batch_size

    def get_embedding_dim(self) -> int:
        """Return the embedding dimensionality determined in ``__init__``."""
        return self.embedding_dim

    def _resolve_provider(self, configured_provider: str) -> str:
        """Map ``"auto"`` to a concrete provider; pass other values through."""
        if configured_provider != "auto":
            return configured_provider
        if self._is_hf_space() or self._is_test_context():
            return "local"
        return "vertex_ai"

    def _resolve_model_name(self) -> str:
        """Pick the model name for ``self.provider`` from the environment.

        ``EMBEDDING_MODEL`` overrides everything. Otherwise each hosted
        provider has its own variable and default, and the local backend
        uses a lightweight model inside a Hugging Face Space or a test run.
        """
        explicit_model = os.getenv("EMBEDDING_MODEL")
        if explicit_model:
            return explicit_model
        if self.provider == "vertex_ai":
            return os.getenv("VERTEX_EMBEDDING_MODEL", "gemini-embedding-001")
        if self.provider == "openai":
            # Without this branch the OpenAI backend would be handed one of
            # the local model ids below, which is not an OpenAI model name.
            return os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
        if self._is_hf_space() or self._is_test_context():
            return os.getenv(
                "LIGHTWEIGHT_LOCAL_EMBEDDING_MODEL",
                "sentence-transformers/all-MiniLM-L6-v2",
            )
        return os.getenv("LOCAL_EMBEDDING_MODEL", "nomic-ai/CodeRankEmbed")

    def _is_hf_space(self) -> bool:
        """Return True when ``SPACE_ID`` or ``HF_SPACE_ID`` is set and non-empty."""
        return bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))

    def _is_test_context(self) -> bool:
        """Return True when the environment is "test" or pytest is running a test.

        ``APP_ENV`` takes precedence over ``ENVIRONMENT`` whenever it is set.
        """
        app_env = os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "")).lower()
        return app_env == "test" or bool(os.getenv("PYTEST_CURRENT_TEST"))

    def _optional_int(self, value: Optional[str]) -> Optional[int]:
        """Parse ``value`` as an int, returning None for None or blank input.

        Raises:
            ValueError: If ``value`` is non-blank but not a valid integer.
        """
        if value is None or not str(value).strip():
            return None
        return int(value)