m-ahmad-official commited on
Commit
e88ac9f
·
1 Parent(s): 09b3403
Files changed (1) hide show
  1. retrieve.py +396 -70
retrieve.py CHANGED
@@ -1,104 +1,430 @@
1
  """
2
- Retrieval module for RAG Book Chatbot.
3
- Handles vector search using Qdrant and Cohere embeddings.
 
 
 
 
4
  """
5
 
 
 
 
 
6
  import logging
7
- from typing import List, Dict, Any, Optional
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
9
  logger = logging.getLogger(__name__)
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def search(
13
  query_text: str,
14
- cohere_client: Any,
15
- qdrant_client: Any,
16
  collection_name: str,
17
  top_k: int = 5,
18
  ) -> List[Dict[str, Any]]:
19
  """
20
- Search for relevant chunks in Qdrant using Cohere embeddings.
21
 
22
  Args:
23
- query_text: User's question or search query
24
- cohere_client: Initialized Cohere client
25
- qdrant_client: Initialized Qdrant client
26
- collection_name: Name of the Qdrant collection
27
- top_k: Number of results to return (default: 5)
28
 
29
  Returns:
30
- List of search results with scores and metadata
31
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  try:
33
- # Generate embedding for the query
34
- logger.info(f"Generating embedding for query: {query_text[:100]}...")
35
- embedding_response = cohere_client.embed(
36
- texts=[query_text],
37
- model="embed-english-v3.0",
38
- input_type="search_query",
39
  )
 
 
 
 
 
 
 
40
 
41
- # In Cohere V2, embeddings are returned as a list-like object directly
42
- # The response.embeddings is iterable, and we want the first element
43
- try:
44
- # First try: if embeddings is directly a list of embeddings
45
- if isinstance(embedding_response.embeddings, list):
46
- query_embedding = list(embedding_response.embeddings[0])
47
- else:
48
- # Convert to list if it's an iterable object
49
- embeddings_list = [e for e in embedding_response.embeddings]
50
- query_embedding = list(embeddings_list[0])
51
-
52
- # Search in Qdrant
53
- logger.info(f"Searching Qdrant collection: {collection_name}")
54
- search_results = qdrant_client.query_points(
55
- collection_name=collection_name,
56
- query=query_embedding,
57
- limit=top_k,
58
  )
 
 
 
 
 
 
 
 
59
 
60
- logger.info(f"Found {len(search_results.points)} results")
61
-
62
- # Format results
63
- results = []
64
- for hit in search_results.points:
65
- results.append(
66
- {
67
- "id": hit.id,
68
- "score": hit.score,
69
- "payload": hit.payload,
70
- }
71
- )
72
 
73
- return results
 
74
 
75
- except Exception as e:
76
- logger.error(
77
- f"Search failed for query '{query_text[:100]}...': {type(e).__name__}: {e}",
78
- exc_info=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  )
80
- raise
81
 
 
 
 
82
 
83
- def validate_results(results: List[Dict[str, Any]]) -> float:
84
- """
85
- Validate that results have required metadata.
 
 
 
 
 
86
 
87
- Args:
88
- results: List of search results
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- Returns:
91
- Percentage of results with complete metadata (0-1)
92
- """
93
- if not results:
94
- return 1.0
 
 
 
 
 
 
 
95
 
96
- required_fields = {"url", "chunk_index", "text"}
97
- valid_count = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- for result in results:
100
- payload = result.get("payload", {})
101
- if all(field in payload for field in required_fields):
102
- valid_count += 1
103
 
104
- return valid_count / len(results)
 
 
1
  """
2
+ Retrieval pipeline for RAG validation.
3
+
4
+ This module provides functions to:
5
+ - Convert search queries to embeddings using Cohere
6
+ - Perform similarity search against Qdrant collection
7
+ - Format and return results with metadata
8
  """
9
 
10
+ import argparse
11
+ import json
12
+ import sys
13
+ import time
14
  import logging
15
+ from pathlib import Path
16
+ from typing import List, Dict, Any
17
+
18
+ # Add parent directory to path for imports
19
+ sys.path.insert(0, str(Path(__file__).parent))
20
+
21
+ import cohere
22
+ from qdrant_client import QdrantClient
23
+
24
+ # Import from existing modules
25
+ import config
26
+ import utils
27
+ from logging_config import setup_logging
28
 
29
+ # Initialize logger
30
  logger = logging.getLogger(__name__)
31
 
32
 
33
+ # Custom exceptions
34
class ConfigurationError(Exception):
    """Raised when required configuration is missing."""
38
+
39
+
40
class CollectionNotFoundError(Exception):
    """Raised when the requested Qdrant collection doesn't exist."""
44
+
45
+
46
class DimensionMismatchError(Exception):
    """Raised when the embedding dimension doesn't match the collection."""
50
+
51
+
52
class APIError(Exception):
    """Raised when a Cohere or Qdrant API call fails after retries."""
56
+
57
+
58
def validate_config(cfg: dict) -> None:
    """Validate that all required config values are present.

    A value counts as missing when its key is absent, when it is falsy
    (``None``, ``""``), or when it is a string made only of whitespace.

    Args:
        cfg: Mapping of configuration keys to their values.

    Raises:
        ConfigurationError: If any of ``cohere_api_key``, ``qdrant_url`` or
            ``qdrant_api_key`` is missing or blank. The message lists every
            offending key.
    """
    required = ("cohere_api_key", "qdrant_url", "qdrant_api_key")

    missing = []
    for key in required:
        value = cfg.get(key)
        # A whitespace-only string is truthy, so strip it first; otherwise
        # a blank value would be accepted as a real credential/URL.
        if isinstance(value, str):
            value = value.strip()
        if not value:
            missing.append(key)

    if missing:
        raise ConfigurationError(
            f"Missing required environment variables: {', '.join(missing)}"
        )
66
+
67
+
68
def init_clients(cfg: dict):
    """Build the Cohere and Qdrant clients from configuration values.

    Args:
        cfg: Configuration mapping; reads ``cohere_api_key``, ``qdrant_url``
            and ``qdrant_api_key``.

    Returns:
        A ``(cohere_client, qdrant_client)`` tuple, in that order.
    """
    return (
        cohere.ClientV2(api_key=cfg["cohere_api_key"]),
        QdrantClient(url=cfg["qdrant_url"], api_key=cfg["qdrant_api_key"]),
    )
73
+
74
+
75
def check_collection(
    qdrant_client: QdrantClient,
    collection_name: str,
    expected_vector_size: int = 1024,
) -> Dict[str, Any]:
    """Verify that a collection exists and has the expected vector size.

    Args:
        qdrant_client: Client used to look the collection up.
        collection_name: Name of the collection to inspect.
        expected_vector_size: Dimension the collection's vectors must have.
            Defaults to 1024, the value this check previously hard-coded.

    Returns:
        Dict with ``exists`` (always True), ``vector_size`` and
        ``points_count`` taken from the collection info.

    Raises:
        CollectionNotFoundError: If the lookup fails with an error whose
            text contains "not found" (case-insensitive).
        DimensionMismatchError: If the vector size differs from
            ``expected_vector_size``.
        Exception: Any other lookup failure is re-raised unchanged.
    """
    try:
        info = qdrant_client.get_collection(collection_name)
    except Exception as e:
        # The error text is the only signal used to recognise a missing
        # collection; anything else propagates as-is.
        if "not found" in str(e).lower():
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist"
            ) from e
        raise

    # NOTE(review): assumes a single unnamed vector configuration exposing
    # `.size`; confirm behaviour for collections with named vectors.
    vector_size = info.config.params.vectors.size
    if vector_size != expected_vector_size:
        raise DimensionMismatchError(
            f"Expected vector size {expected_vector_size} but got {vector_size}"
        )

    return {
        "exists": True,
        "vector_size": vector_size,
        "points_count": info.points_count,
    }
97
+
98
+
99
def embed_query(text: str, cohere_client: cohere.ClientV2) -> List[float]:
    """Generate an embedding for a search query using Cohere.

    Args:
        text: Query text to embed.
        cohere_client: Client whose ``embed`` method is called with the
            ``embed-english-v3.0`` model and ``search_query`` input type.

    Returns:
        The first (and only requested) float embedding from the response.

    Raises:
        APIError: If the call fails or the response carries no float
            embeddings. The original exception is attached as the cause.
    """
    try:
        response = cohere_client.embed(
            texts=[text],
            model="embed-english-v3.0",
            input_type="search_query",
            # Ask for float vectors explicitly, since that is the field
            # read from the response below.
            embedding_types=["float"],
        )
        vectors = response.embeddings.float_
        if not vectors:
            # Give a clear reason instead of an opaque indexing TypeError.
            raise ValueError("response contained no float embeddings")
        return vectors[0]
    except Exception as e:
        logger.error(f"Failed to generate embedding: {e}")
        raise APIError(f"Cohere embedding failed: {e}") from e
111
+
112
+
113
def validate_metadata_completeness(results: List[Dict[str, Any]]) -> float:
    """
    Check metadata completeness in search results.

    A result is complete when its payload has:
    - url present and non-empty (after stripping whitespace)
    - text present with length ≥ 10
    - at least one of title or section non-empty (after stripping)

    A missing/None payload, or a field that is not a string, counts as
    incomplete rather than raising.

    Args:
        results: Search results, each a dict with an optional "payload" dict.

    Returns:
        Percentage (0-100) of results with complete metadata; 0.0 for an
        empty result list.
    """
    if not results:
        return 0.0

    fields = ("url", "text", "title", "section")
    complete = 0
    total = len(results)

    for result in results:
        # The "payload" key may be present with a None value, so fall back
        # to an empty dict instead of calling .get() on None.
        payload = result.get("payload") or {}

        # Non-string values cannot satisfy the string criteria; normalise
        # them to "" so .strip() never fails on them.
        url, text, title, section = (
            value if isinstance(value, str) else ""
            for value in (payload.get(name) for name in fields)
        )

        url_ok = bool(url.strip())
        text_ok = len(text) >= 10
        title_section_ok = bool(title.strip() or section.strip())

        if url_ok and text_ok and title_section_ok:
            complete += 1

    percentage = (complete / total) * 100
    logger.debug(f"Metadata completeness: {complete}/{total} = {percentage:.1f}%")
    return percentage
149
+
150
+
151
def validate_chunk_sequencing(results: List[Dict[str, Any]]) -> bool:
    """
    Verify that chunk_index values are properly assigned: integers >= 0 and unique per URL.

    Note: Since search may return only a subset of chunks for a URL, we cannot
    verify full sequential continuity (0,1,2,3...). Instead we check:
    - All chunk_index values are integers >= 0 (booleans are rejected)
    - No duplicate chunk_index for the same URL in the result set

    Results without a payload, or without a chunk_index, make the check fail
    instead of raising.

    Args:
        results: List of search results

    Returns:
        True if chunk indices are valid (also for an empty list), False otherwise
    """
    # Group chunk indices by URL; results lacking a URL share the "" bucket.
    url_chunks: Dict[Any, List[Any]] = {}
    for result in results:
        # "payload" can be present but None; treat it as empty.
        payload = result.get("payload") or {}
        url_chunks.setdefault(payload.get("url", ""), []).append(
            payload.get("chunk_index")
        )

    for url, indices in url_chunks.items():
        for idx in indices:
            # bool is a subclass of int, so True/False must be excluded
            # explicitly or they would pass as indices 1/0.
            if isinstance(idx, bool) or not isinstance(idx, int) or idx < 0:
                logger.debug(
                    f"Invalid chunk_index for {url}: {idx} (must be non-negative integer)"
                )
                return False

        # Check for duplicates (within this URL's results)
        if len(set(indices)) != len(indices):
            logger.debug(f"Duplicate chunk_index for {url}: {indices}")
            return False

    logger.debug(f"Chunk indexing valid for {len(url_chunks)} URLs")
    return True
194
+
195
+
196
def search(
    query_text: str,
    cohere_client: cohere.ClientV2,
    qdrant_client: QdrantClient,
    collection_name: str,
    top_k: int = 5,
) -> List[Dict[str, Any]]:
    """
    Convert query to embedding and retrieve top-K relevant chunks.

    Args:
        query_text: User's search query (non-empty, ≤1000 chars after stripping)
        cohere_client: Cohere client passed to embed_query
        qdrant_client: Qdrant client whose query_points method is called
        collection_name: Name of the Qdrant collection to search
        top_k: Number of results to return (1-100)

    Returns:
        List of search results with id (as str), score (as float), and payload

    Raises:
        ValueError: If query_text is empty or too long, or top_k is out of range
        APIError: If the Qdrant query (run through utils.retry_with_backoff) fails
        Exception: Any failure of the embedding step is logged and re-raised as-is
    """
    # Validate inputs before any network call is attempted
    if not query_text or not query_text.strip():
        raise ValueError("Query text must be non-empty")
    query_text = query_text.strip()
    if len(query_text) > 1000:
        raise ValueError("Query text must be ≤ 1000 characters")
    if top_k < 1 or top_k > 100:
        raise ValueError("top_k must be between 1 and 100")

    # Only show an ellipsis when the preview really is truncated
    preview = query_text[:100] + ("..." if len(query_text) > 100 else "")
    logger.info(f"Embedding query: '{preview}' (top_k={top_k})")

    # perf_counter is monotonic, so durations cannot go negative or jump
    # when the wall clock is adjusted
    start_time = time.perf_counter()

    # Generate query embedding with retry
    try:
        embedding = utils.retry_with_backoff(
            lambda: embed_query(query_text, cohere_client),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        embed_time = time.perf_counter() - start_time
        logger.debug(
            f"Generated embedding in {embed_time:.2f}s, dimension: {len(embedding)}"
        )
    except Exception as e:
        logger.error(f"Failed to embed query: {e}")
        raise

    # Search Qdrant with retry
    try:
        search_start = time.perf_counter()
        response = utils.retry_with_backoff(
            lambda: qdrant_client.query_points(
                collection_name=collection_name,
                query=embedding,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
            ),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        results = response.points
        search_time = time.perf_counter() - search_start
        logger.info(
            f"Search completed in {search_time:.2f}s, returned {len(results)} results"
        )
    except Exception as e:
        logger.error(f"Search failed: {e}")
        # Keep the underlying error attached for debugging
        raise APIError(f"Qdrant search failed: {e}") from e

    # Normalise id/score to plain str/float; payload is passed through untouched
    formatted = [
        {
            "id": str(point.id),
            "score": float(point.score),
            "payload": point.payload,
        }
        for point in results
    ]

    total_time = time.perf_counter() - start_time
    logger.info(f"Total query time: {total_time:.2f}s")

    return formatted
280
+
281
+
282
def format_results(
    results: List[Dict[str, Any]], query: str, latency_ms: int
) -> Dict[str, Any]:
    """Assemble the JSON-ready output structure for a finished search.

    Args:
        results: Formatted search hits, embedded as-is under "results".
        query: The query text the hits belong to.
        latency_ms: Query latency to report, in milliseconds.

    Returns:
        Dict with "query", a UTC "timestamp" (``YYYY-MM-DDTHH:MM:SSZ``),
        "results", and a "metadata" dict holding the result count, a
        ``None`` "collection" placeholder and the latency.
    """
    generated_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    summary = {
        "total_results": len(results),
        # Placeholder only: this function does not know the collection name,
        # so the caller is expected to fill it in.
        "collection": None,
        "latency_ms": latency_ms,
    }
    return {
        "query": query,
        "timestamp": generated_at,
        "results": results,
        "metadata": summary,
    }
297
+
298
+
299
def main(argv: "List[str] | None" = None) -> int:
    """CLI entrypoint for retrieval.

    Parses the command line, validates configuration and the target
    collection, runs the search, optionally validates result metadata, and
    emits the result as JSON to stdout or to ``--output``.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            them from ``sys.argv``, which is what the script entry uses.

    Returns:
        Exit code: 0 on success, 2 when a ValueError is raised (e.g. invalid
        query or top_k), 1 for configuration, collection, API or unexpected
        errors.

    Raises:
        SystemExit: Raised by argparse (status 2) for unparseable arguments
            or when ``--query`` is missing.
    """
    parser = argparse.ArgumentParser(
        description="Retrieve relevant chunks from Qdrant using Cohere embeddings"
    )
    parser.add_argument("--query", type=str, help="Search query text")
    parser.add_argument(
        "--top-k", type=int, default=5, help="Number of results to return (default: 5)"
    )
    parser.add_argument("--output", type=str, help="Output file path (default: stdout)")
    parser.add_argument(
        "--config",
        type=str,
        default=".env",
        help="Path to .env config file (default: .env)",
    )
    parser.add_argument(
        "--validate-metadata",
        action="store_true",
        help="Run metadata validation on search results (requires --query)",
    )

    args = parser.parse_args(argv)

    # Fail fast: reject a missing query before loading config or touching
    # the network, instead of after the collection check.
    if not args.query:
        parser.error("--query is required")

    # Setup logging
    log_file = "retrieve.log"
    setup_logging(log_file=log_file, console_level="INFO")
    logger.info("=== Retrieval Pipeline Started ===")

    try:
        # Load config
        # NOTE(review): args.config is only logged; it is not passed to
        # config.get_config(), so --config currently has no effect. Confirm
        # how get_config() locates its file before wiring it through.
        logger.info(f"Loading config from {args.config}")
        cfg = config.get_config()
        validate_config(cfg)

        # Initialize clients
        logger.info("Initializing Cohere and Qdrant clients")
        cohere_client, qdrant_client = init_clients(cfg)

        # Check collection. validate_config does not cover this key, so
        # report its absence as a configuration problem rather than letting
        # a KeyError surface as an "Unexpected error".
        collection_name = cfg.get("qdrant_collection")
        if not collection_name:
            raise ConfigurationError(
                "Missing required environment variables: qdrant_collection"
            )
        logger.info(f"Checking collection '{collection_name}'")
        coll_info = check_collection(qdrant_client, collection_name)
        logger.info(
            f"Collection OK: vector_size={coll_info['vector_size']}, points={coll_info['points_count']}"
        )

        # Perform search
        results = search(
            query_text=args.query,
            cohere_client=cohere_client,
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            top_k=args.top_k,
        )

        # Perform metadata validation if requested
        metadata_validation = None
        if args.validate_metadata:
            completeness = validate_metadata_completeness(results)
            sequencing = validate_chunk_sequencing(results)
            metadata_validation = {
                "completeness_pct": round(completeness, 2),
                "sequencing_valid": sequencing,
                "pass": completeness >= 98.0 and sequencing,
            }
            logger.info(f"Metadata completeness: {completeness:.1f}%")
            logger.info(f"Chunk sequencing: {'VALID' if sequencing else 'INVALID'}")
            logger.info(
                f"Validation result: {'PASS' if metadata_validation['pass'] else 'FAIL'}"
            )

        # Format output
        output = {
            "query": args.query,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "results": results,
            "metadata": {
                "total_results": len(results),
                "collection": collection_name,
                "vector_size": coll_info["vector_size"],
                "points_count": coll_info["points_count"],
            },
        }

        if metadata_validation:
            output["metadata_validation"] = metadata_validation

        # Output JSON
        json_output = json.dumps(output, indent=2)
        if args.output:
            # Explicit encoding keeps the file identical across platforms
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(json_output)
            logger.info(f"Results written to {args.output}")
        else:
            print(json_output)

        logger.info("=== Retrieval Pipeline Completed Successfully ===")
        return 0

    except ValueError as ve:
        logger.error(f"Validation error: {ve}")
        print(f"ERROR: {ve}", file=sys.stderr)
        return 2
    except ConfigurationError as ce:
        logger.error(f"Configuration error: {ce}")
        print(f"ERROR: {ce}", file=sys.stderr)
        return 1
    except CollectionNotFoundError as cnfe:
        logger.error(f"Collection error: {cnfe}")
        print(f"ERROR: {cnfe}", file=sys.stderr)
        return 1
    except DimensionMismatchError as dme:
        logger.error(f"Dimension error: {dme}")
        print(f"ERROR: {dme}", file=sys.stderr)
        return 1
    except APIError as api_err:
        logger.error(f"API error: {api_err}")
        print(f"ERROR: {api_err}", file=sys.stderr)
        return 1
    except Exception as e:
        # Top-level boundary: log the traceback and convert to an exit code
        logger.exception(f"Unexpected error: {e}")
        print(f"ERROR: Unexpected error: {e}", file=sys.stderr)
        return 1
427
 
 
 
 
 
428
 
429
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())