Spaces:
Sleeping
Sleeping
| import os | |
| from typing import List, Optional, Tuple | |
| from uuid import uuid4 | |
| import numpy as np | |
| from qdrant_client import QdrantClient, models | |
class QdrantVectorStore:
    """Vector store backed by a single Qdrant collection.

    Configuration is read from environment variables:

    * ``QDRANT_COLLECTION`` -- collection name (default ``repo_qa_chunks``).
    * ``QDRANT_UPSERT_BATCH_SIZE`` -- points sent per upsert call
      (default 64, clamped to a minimum of 1).
    * ``QDRANT_URL`` / ``QDRANT_API_KEY`` -- remote server to connect to;
      when no URL is set an in-memory client is created instead.
    * ``QDRANT_TIMEOUT_SECONDS`` -- timeout passed to the remote client
      (default 120).

    Variables that are unset, empty, or whitespace-only fall back to the
    defaults above.
    """

    def __init__(self, embedding_dim: int, index_path: Optional[str] = None, persist: bool = False):
        """Create the client and make sure the collection exists.

        Args:
            embedding_dim: Vector size used when the collection is created.
            index_path: Accepted but not used by this class.
            persist: Accepted but not used by this class.
        """
        self.embedding_dim = embedding_dim
        self.collection_name = self._clean_env("QDRANT_COLLECTION") or "repo_qa_chunks"
        batch_size = int(self._clean_env("QDRANT_UPSERT_BATCH_SIZE") or "64")
        self.upsert_batch_size = max(1, batch_size)
        self.client = self._create_client()
        self._ensure_collection()

    def _create_client(self):
        """Return a remote client when ``QDRANT_URL`` is set, else an in-memory one."""
        url = self._clean_env("QDRANT_URL")
        if not url:
            return QdrantClient(":memory:")
        return QdrantClient(
            url=url,
            api_key=self._clean_env("QDRANT_API_KEY"),
            timeout=int(self._clean_env("QDRANT_TIMEOUT_SECONDS") or "120"),
            check_compatibility=False,
        )

    @staticmethod
    def _clean_env(name: str) -> Optional[str]:
        """Return the stripped value of environment variable *name*.

        Returns ``None`` when the variable is unset or contains only
        whitespace, so callers can treat "blank" and "missing" alike.
        """
        value = os.getenv(name)
        if value is None:
            return None
        return value.strip() or None

    @staticmethod
    def _repository_filter(repo_id: int):
        """Build a filter matching points whose ``repository_id`` equals *repo_id*."""
        return models.Filter(
            must=[
                models.FieldCondition(
                    key="repository_id",
                    match=models.MatchValue(value=repo_id),
                )
            ]
        )

    def _ensure_collection(self):
        """Create the collection if it is missing and ensure its payload index."""
        if not self.client.collection_exists(self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=self.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )
        # Run on every call so a collection that already existed (and may
        # lack the index) is covered as well as a freshly created one.
        self._ensure_payload_indexes()

    def _ensure_payload_indexes(self):
        """Create the integer payload index on ``repository_id``.

        That field is the one ``search`` and ``remove_repository`` filter on.
        """
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="repository_id",
            field_schema=models.PayloadSchemaType.INTEGER,
            wait=True,
        )

    def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[str]:
        """Upsert one point per embedding row and return the generated ids.

        Args:
            embeddings: A single vector (1-D) or one vector per row (2-D).
            metadata: One payload dict per vector. Each dict is copied and the
                generated id is stored in the copy under the ``"id"`` key.

        Returns:
            The ``uuid4`` hex string ids, in input order. Empty when
            *embeddings* has no elements.

        Raises:
            ValueError: If the number of vectors and metadata entries differ.
        """
        vectors = np.asarray(embeddings, dtype="float32")
        if vectors.size == 0:
            return []
        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)
        if len(vectors) != len(metadata):
            # zip() would silently drop the surplus and return ids for points
            # that were never stored, so fail loudly instead.
            raise ValueError(
                f"embeddings/metadata length mismatch: "
                f"{len(vectors)} vectors vs {len(metadata)} metadata entries"
            )

        ids = [uuid4().hex for _ in metadata]
        points = [
            models.PointStruct(
                id=point_id,
                vector=vector.tolist(),
                payload={**meta, "id": point_id},
            )
            for point_id, meta, vector in zip(ids, metadata, vectors)
        ]

        batch_size = self.upsert_batch_size
        total_points = len(points)
        total_batches = (total_points + batch_size - 1) // batch_size
        for batch_number, start in enumerate(range(0, total_points, batch_size), start=1):
            batch = points[start : start + batch_size]
            print(
                f"[qdrant] Upserting batch {batch_number}/{total_batches} "
                f"points={len(batch)} progress={start}/{total_points}",
                flush=True,
            )
            self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=batch,
            )
        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        repo_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Return up to *k* nearest points as ``(score, payload)`` tuples.

        Args:
            query_embedding: Query vector; for a 2-D array only the first row
                is used.
            k: Maximum number of hits to return.
            repo_filter: When given, restrict hits to points whose
                ``repository_id`` payload equals this value.
        """
        query = np.asarray(query_embedding, dtype="float32")
        if query.ndim == 1:
            query = query.reshape(1, -1)
        query_vector = query[0].tolist()

        query_filter = None
        if repo_filter is not None:
            query_filter = self._repository_filter(repo_filter)

        # `search` is deprecated in recent qdrant-client releases in favour of
        # `query_points`; keep the old call as a fallback for older clients.
        if hasattr(self.client, "query_points"):
            hits = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                query_filter=query_filter,
                limit=k,
            ).points
        else:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                query_filter=query_filter,
                limit=k,
            )
        return [(float(hit.score), dict(hit.payload or {})) for hit in hits]

    def remove_repository(self, repo_id: int):
        """Delete every point whose ``repository_id`` payload equals *repo_id*."""
        self.client.delete(
            collection_name=self.collection_name,
            wait=True,
            points_selector=models.FilterSelector(
                filter=self._repository_filter(repo_id),
            ),
        )

    def clear(self):
        """Drop the collection if it exists, then recreate it empty."""
        if self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)
        self._ensure_collection()

    def save(self):
        """Do nothing and return ``None``."""
        return None

    def load(self):
        """Make sure the collection and its payload index exist."""
        self._ensure_collection()

    def get_stats(self) -> dict:
        """Return the point count, embedding dimension and collection name."""
        info = self.client.get_collection(self.collection_name)
        return {
            "total_vectors": info.points_count or 0,
            "embedding_dim": self.embedding_dim,
            "collection_name": self.collection_name,
        }