File size: 5,319 Bytes
60b97da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee4d71c
 
60b97da
 
 
 
 
 
 
 
 
 
ee4d71c
 
 
 
 
 
 
 
60b97da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
from typing import List, Optional, Tuple
from uuid import uuid4

import numpy as np
from qdrant_client import QdrantClient, models


class QdrantVectorStore:
    """Vector store backed by a single Qdrant collection.

    Configuration is read from environment variables:

    * ``QDRANT_COLLECTION`` -- collection name (default ``"repo_qa_chunks"``).
    * ``QDRANT_UPSERT_BATCH_SIZE`` -- points sent per upsert request
      (default 64, clamped to at least 1).
    * ``QDRANT_URL`` / ``QDRANT_API_KEY`` -- remote server to connect to. When
      ``QDRANT_URL`` is unset or blank, an in-process ``":memory:"`` client is
      used instead.
    * ``QDRANT_TIMEOUT_SECONDS`` -- timeout passed to the remote client
      (default 120).
    """

    def __init__(
        self,
        embedding_dim: int,
        index_path: Optional[str] = None,
        persist: bool = False,
    ):
        """Connect to Qdrant and make sure the collection exists.

        Args:
            embedding_dim: Size of the vectors stored in the collection.
            index_path: Accepted but not used by this class.
            persist: Accepted but not used by this class.

        Raises:
            ValueError: If ``QDRANT_UPSERT_BATCH_SIZE`` or
                ``QDRANT_TIMEOUT_SECONDS`` is set to a non-integer value.
        """
        self.embedding_dim = embedding_dim
        self.collection_name = os.getenv("QDRANT_COLLECTION", "repo_qa_chunks")
        self.upsert_batch_size = max(1, int(os.getenv("QDRANT_UPSERT_BATCH_SIZE", "64")))
        self.client = self._create_client()
        self._ensure_collection()

    def _create_client(self):
        """Build a remote client when ``QDRANT_URL`` is set, else an in-memory one."""
        url = self._clean_env("QDRANT_URL")
        api_key = self._clean_env("QDRANT_API_KEY")
        timeout = int(os.getenv("QDRANT_TIMEOUT_SECONDS", "120"))
        if url:
            return QdrantClient(
                url=url,
                api_key=api_key,
                timeout=timeout,
                check_compatibility=False,
            )
        return QdrantClient(":memory:")

    @staticmethod
    def _clean_env(name: str) -> Optional[str]:
        """Return the stripped value of env var *name*, or ``None`` if unset or blank."""
        value = os.getenv(name)
        if value is None:
            return None
        cleaned = value.strip()
        return cleaned or None

    def _ensure_collection(self):
        """Create the collection (cosine distance) if missing, then its payload indexes."""
        if not self.client.collection_exists(self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=self.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )
        self._ensure_payload_indexes()

    def _ensure_payload_indexes(self):
        """Create an integer payload index on ``repository_id``.

        ``search`` and ``remove_repository`` both filter on this field.
        """
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="repository_id",
            field_schema=models.PayloadSchemaType.INTEGER,
            wait=True,
        )

    def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[str]:
        """Upsert one point per embedding/metadata pair.

        Each point gets a fresh ``uuid4().hex`` id, which is also written into
        the stored payload under the ``"id"`` key (overriding any ``"id"``
        already present in the metadata entry).

        Args:
            embeddings: Array of vectors, one row per point. A 1-D array is
                treated as a single vector.
            metadata: One payload dict per row of ``embeddings``.

        Returns:
            The generated point ids, in input order. Empty when ``embeddings``
            has no elements.

        Raises:
            ValueError: If the number of vectors differs from the number of
                metadata entries.
        """
        embeddings = np.asarray(embeddings, dtype="float32")
        if embeddings.size == 0:
            return []
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        # zip() below would silently truncate to the shorter input and hand
        # back ids for points that were never stored, so fail loudly instead.
        if len(embeddings) != len(metadata):
            raise ValueError(
                f"embeddings/metadata length mismatch: "
                f"{len(embeddings)} vectors vs {len(metadata)} metadata entries"
            )

        ids = [uuid4().hex for _ in metadata]
        points = []
        for point_id, meta, vector in zip(ids, metadata, embeddings):
            payload = dict(meta)  # copy so the caller's dict is not mutated
            payload["id"] = point_id
            points.append(
                models.PointStruct(
                    id=point_id,
                    vector=vector.tolist(),
                    payload=payload,
                )
            )

        batch_size = self.upsert_batch_size
        total_points = len(points)
        total_batches = (total_points + batch_size - 1) // batch_size  # ceil division
        for batch_number, start in enumerate(range(0, total_points, batch_size), start=1):
            batch = points[start : start + batch_size]
            print(
                f"[qdrant] Upserting batch {batch_number}/{total_batches} "
                f"points={len(batch)} progress={start}/{total_points}",
                flush=True,
            )
            self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=batch,
            )

        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        repo_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Return up to ``k`` nearest points to ``query_embedding``.

        Args:
            query_embedding: Query vector. If a 2-D array is given, only its
                first row is used.
            k: Maximum number of hits. Non-positive values yield ``[]``.
            repo_filter: When given, only points whose ``repository_id``
                payload equals this value are considered.

        Returns:
            ``(score, payload)`` tuples in the order returned by Qdrant.
        """
        if k <= 0:
            return []

        query_embedding = np.asarray(query_embedding, dtype="float32")
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)
        query_vector = query_embedding[0].tolist()

        query_filter = None
        if repo_filter is not None:
            query_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="repository_id",
                        match=models.MatchValue(value=repo_filter),
                    )
                ]
            )

        # ``search`` is deprecated in qdrant-client in favour of
        # ``query_points``; prefer the new API and fall back on older clients.
        query_points = getattr(self.client, "query_points", None)
        if query_points is not None:
            hits = query_points(
                collection_name=self.collection_name,
                query=query_vector,
                query_filter=query_filter,
                limit=k,
            ).points
        else:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                query_filter=query_filter,
                limit=k,
            )

        return [(float(hit.score), dict(hit.payload or {})) for hit in hits]

    def remove_repository(self, repo_id: int):
        """Delete every point whose ``repository_id`` payload equals ``repo_id``."""
        self.client.delete(
            collection_name=self.collection_name,
            wait=True,
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="repository_id",
                            match=models.MatchValue(value=repo_id),
                        )
                    ]
                )
            ),
        )

    def clear(self):
        """Drop the collection if it exists and recreate it empty."""
        if self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)
        self._ensure_collection()

    def save(self):
        """No-op: this class writes nothing itself; always returns ``None``."""
        return None

    def load(self):
        """Make sure the collection and its payload indexes exist."""
        self._ensure_collection()

    def get_stats(self) -> dict:
        """Return the point count, embedding dimension and collection name."""
        info = self.client.get_collection(self.collection_name)
        return {
            "total_vectors": info.points_count or 0,
            "embedding_dim": self.embedding_dim,
            "collection_name": self.collection_name,
        }