# code-compass / src/vector_store.py
# Synced from GitHub via hub-sync (commit ee4d71c, verified).
import os
from typing import List, Optional, Tuple
from uuid import uuid4
import numpy as np
from qdrant_client import QdrantClient, models
class QdrantVectorStore:
    """Vector store backed by Qdrant.

    Connects to a remote Qdrant instance when ``QDRANT_URL`` is set,
    otherwise falls back to an in-memory client (data is lost on exit).

    Environment variables:
        QDRANT_URL: server URL; in-memory mode when unset or blank.
        QDRANT_API_KEY: API key for the remote server (optional).
        QDRANT_COLLECTION: collection name (default ``repo_qa_chunks``).
        QDRANT_UPSERT_BATCH_SIZE: points per upsert call (default 64).
        QDRANT_TIMEOUT_SECONDS: client timeout in seconds (default 120).
    """

    def __init__(self, embedding_dim: int, index_path: Optional[str] = None, persist: bool = False):
        """Create the client and ensure the collection exists.

        Args:
            embedding_dim: dimensionality of the stored vectors.
            index_path: unused; kept for interface compatibility with
                file-based vector stores.
            persist: unused; kept for interface compatibility.
        """
        self.embedding_dim = embedding_dim
        self.collection_name = os.getenv("QDRANT_COLLECTION", "repo_qa_chunks")
        # Clamp to at least 1 so a misconfigured "0" cannot stall upserts.
        self.upsert_batch_size = max(1, int(os.getenv("QDRANT_UPSERT_BATCH_SIZE", "64")))
        self.client = self._create_client()
        self._ensure_collection()

    def _create_client(self):
        """Return a remote client when QDRANT_URL is set, else in-memory."""
        url = self._clean_env("QDRANT_URL")
        api_key = self._clean_env("QDRANT_API_KEY")
        timeout = int(os.getenv("QDRANT_TIMEOUT_SECONDS", "120"))
        if url:
            return QdrantClient(
                url=url,
                api_key=api_key,
                timeout=timeout,
                # Skip the client/server version handshake so slightly
                # mismatched server versions are still usable.
                check_compatibility=False,
            )
        return QdrantClient(":memory:")

    @staticmethod
    def _clean_env(name: str) -> Optional[str]:
        """Return the stripped value of env var *name*, or None if unset/blank."""
        value = os.getenv(name)
        if value is None:
            return None
        cleaned = value.strip()
        return cleaned or None

    def _ensure_collection(self):
        """Create the collection (cosine distance) if missing, then index it."""
        if not self.client.collection_exists(self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=self.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )
        self._ensure_payload_indexes()

    def _ensure_payload_indexes(self):
        """Index ``repository_id`` so per-repository filters stay fast."""
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="repository_id",
            field_schema=models.PayloadSchemaType.INTEGER,
            wait=True,
        )

    def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[str]:
        """Upsert embeddings with their payloads in batches.

        Args:
            embeddings: array of shape (n, dim), or (dim,) for one vector.
            metadata: one payload dict per embedding row; copied, the
                caller's dicts are not mutated.

        Returns:
            The generated point IDs (UUID hex strings), one per embedding.
            (Fixed: this was previously annotated ``List[int]``.)

        Raises:
            ValueError: if the embedding and metadata counts differ.
        """
        if embeddings.size == 0:
            return []
        embeddings = embeddings.astype("float32")
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)
        # Fail loudly instead of letting zip() silently drop rows and
        # pair payloads with the wrong vectors.
        if len(metadata) != embeddings.shape[0]:
            raise ValueError(
                f"metadata count ({len(metadata)}) does not match "
                f"embedding count ({embeddings.shape[0]})"
            )
        ids = [uuid4().hex for _ in metadata]
        points = []
        for idx, meta, embedding in zip(ids, metadata, embeddings):
            payload = dict(meta)
            # Mirror the point ID into the payload for round-tripping.
            payload["id"] = idx
            points.append(
                models.PointStruct(
                    id=idx,
                    vector=embedding.tolist(),
                    payload=payload,
                )
            )
        total_points = len(points)
        # Loop-invariant: compute the batch total once, not per iteration.
        total_batches = (total_points + self.upsert_batch_size - 1) // self.upsert_batch_size
        for start in range(0, total_points, self.upsert_batch_size):
            batch = points[start : start + self.upsert_batch_size]
            batch_number = (start // self.upsert_batch_size) + 1
            print(
                f"[qdrant] Upserting batch {batch_number}/{total_batches} "
                f"points={len(batch)} progress={start}/{total_points}",
                flush=True,
            )
            self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=batch,
            )
        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        repo_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Return the top-k nearest points as (score, payload) pairs.

        Args:
            query_embedding: query vector of shape (dim,) or (1, dim).
            k: maximum number of hits to return.
            repo_filter: when given, restrict hits to this ``repository_id``.
        """
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)
        query_embedding = query_embedding.astype("float32")
        query_filter = None
        if repo_filter is not None:
            query_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="repository_id",
                        match=models.MatchValue(value=repo_filter),
                    )
                ]
            )
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding[0].tolist(),
            query_filter=query_filter,
            limit=k,
        )
        # hit.payload can be None; normalize to a plain dict.
        return [(float(hit.score), dict(hit.payload or {})) for hit in hits]

    def remove_repository(self, repo_id: int):
        """Delete every point whose payload ``repository_id`` equals *repo_id*."""
        self.client.delete(
            collection_name=self.collection_name,
            wait=True,
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="repository_id",
                            match=models.MatchValue(value=repo_id),
                        )
                    ]
                )
            ),
        )

    def clear(self):
        """Drop and recreate the collection, removing all points."""
        if self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)
        self._ensure_collection()

    def save(self):
        """No-op: the Qdrant server persists data itself."""
        return None

    def load(self):
        """Ensure the collection exists; there is no local state to load."""
        self._ensure_collection()

    def get_stats(self) -> dict:
        """Return point count, embedding dimension, and collection name."""
        info = self.client.get_collection(self.collection_name)
        return {
            "total_vectors": info.points_count or 0,
            "embedding_dim": self.embedding_dim,
            "collection_name": self.collection_name,
        }