m-ahmad-official committed
Commit 993bc66 · 1 Parent(s): a724f5f
Files changed (1)
  1. retrieve.py +1 -237
retrieve.py CHANGED
@@ -85,240 +85,4 @@ def search(
             f"Search failed for query '{query_text[:100]}...': {type(e).__name__}: {e}",
             exc_info=True,
         )
-        raise
-
-def search(
-    query_text: str,
-    cohere_client: cohere.ClientV2,
-    qdrant_client: QdrantClient,
-    collection_name: str,
-    top_k: int = 5,
-) -> List[Dict[str, Any]]:
-    """
-    Convert query to embedding and retrieve top-K relevant chunks.
-
-    Args:
-        query_text: User's search query (non-empty, ≤1000 chars)
-        top_k: Number of results to return (1-100)
-
-    Returns:
-        List of search results with id, score, and payload
-    """
-    # Validate inputs
-    if not query_text or not query_text.strip():
-        raise ValueError("Query text must be non-empty")
-    query_text = query_text.strip()
-    if len(query_text) > 1000:
-        raise ValueError("Query text must be ≤ 1000 characters")
-    if top_k < 1 or top_k > 100:
-        raise ValueError("top_k must be between 1 and 100")
-
-    logger.info(f"Embedding query: '{query_text[:100]}...' (top_k={top_k})")
-    start_time = time.time()
-
-    # Generate query embedding with retry
-    try:
-        embedding = utils.retry_with_backoff(
-            lambda: embed_query(query_text, cohere_client),
-            max_retries=3,
-            base_delay=1.0,
-            max_delay=10.0,
-        )
-        embed_time = time.time() - start_time
-        logger.debug(
-            f"Generated embedding in {embed_time:.2f}s, dimension: {len(embedding)}"
-        )
-    except Exception as e:
-        logger.error(f"Failed to embed query: {e}")
-        raise
-
-    # Search Qdrant with retry
-    try:
-        search_start = time.time()
-        response = utils.retry_with_backoff(
-            lambda: qdrant_client.query_points(
-                collection_name=collection_name,
-                query=embedding,
-                limit=top_k,
-                with_payload=True,
-                with_vectors=False,
-            ),
-            max_retries=3,
-            base_delay=1.0,
-            max_delay=10.0,
-        )
-        results = response.points
-        search_time = time.time() - search_start
-        logger.info(
-            f"Search completed in {search_time:.2f}s, returned {len(results)} results"
-        )
-    except Exception as e:
-        logger.error(f"Search failed: {e}")
-        raise APIError(f"Qdrant search failed: {e}")
-
-    # Format results
-    formatted = []
-    for result in results:
-        formatted.append(
-            {
-                "id": str(result.id),
-                "score": float(result.score),
-                "payload": result.payload,
-            }
-        )
-
-    total_time = time.time() - start_time
-    logger.info(f"Total query time: {total_time:.2f}s")
-
-    return formatted
-
-
-def format_results(
-    results: List[Dict[str, Any]], query: str, latency_ms: int
-) -> Dict[str, Any]:
-    """Format search results into JSON output structure."""
-    output = {
-        "query": query,
-        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
-        "results": results,
-        "metadata": {
-            "total_results": len(results),
-            "collection": None,  # Will be filled by main
-            "latency_ms": latency_ms,
-        },
-    }
-    return output
-
-
-def main() -> int:
-    """CLI entrypoint for retrieval."""
-    parser = argparse.ArgumentParser(
-        description="Retrieve relevant chunks from Qdrant using Cohere embeddings"
-    )
-    parser.add_argument("--query", type=str, help="Search query text")
-    parser.add_argument(
-        "--top-k", type=int, default=5, help="Number of results to return (default: 5)"
-    )
-    parser.add_argument("--output", type=str, help="Output file path (default: stdout)")
-    parser.add_argument(
-        "--config",
-        type=str,
-        default=".env",
-        help="Path to .env config file (default: .env)",
-    )
-    parser.add_argument(
-        "--validate-metadata",
-        action="store_true",
-        help="Run metadata validation on search results (requires --query)",
-    )
-
-    args = parser.parse_args()
-
-    # Setup logging
-    log_file = "retrieve.log"
-    setup_logging(log_file=log_file, console_level="INFO")
-    logger.info("=== Retrieval Pipeline Started ===")
-
-    try:
-        # Load config
-        logger.info(f"Loading config from {args.config}")
-        cfg = config.get_config()
-        validate_config(cfg)
-
-        # Initialize clients
-        logger.info("Initializing Cohere and Qdrant clients")
-        cohere_client, qdrant_client = init_clients(cfg)
-
-        # Check collection
-        collection_name = cfg["qdrant_collection"]
-        logger.info(f"Checking collection '{collection_name}'")
-        coll_info = check_collection(qdrant_client, collection_name)
-        logger.info(
-            f"Collection OK: vector_size={coll_info['vector_size']}, points={coll_info['points_count']}"
-        )
-
-        # Validate query argument
-        if not args.query:
-            parser.error("--query is required")
-
-        # Perform search
-        results = search(
-            query_text=args.query,
-            cohere_client=cohere_client,
-            qdrant_client=qdrant_client,
-            collection_name=collection_name,
-            top_k=args.top_k,
-        )
-
-        # Perform metadata validation if requested
-        metadata_validation = None
-        if args.validate_metadata:
-            completeness = validate_metadata_completeness(results)
-            sequencing = validate_chunk_sequencing(results)
-            metadata_validation = {
-                "completeness_pct": round(completeness, 2),
-                "sequencing_valid": sequencing,
-                "pass": completeness >= 98.0 and sequencing,
-            }
-            logger.info(f"Metadata completeness: {completeness:.1f}%")
-            logger.info(f"Chunk sequencing: {'VALID' if sequencing else 'INVALID'}")
-            logger.info(
-                f"Validation result: {'PASS' if metadata_validation['pass'] else 'FAIL'}"
-            )
-
-        # Format output
-        output = {
-            "query": args.query,
-            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
-            "results": results,
-            "metadata": {
-                "total_results": len(results),
-                "collection": collection_name,
-                "vector_size": coll_info["vector_size"],
-                "points_count": coll_info["points_count"],
-            },
-        }
-
-        if metadata_validation:
-            output["metadata_validation"] = metadata_validation
-
-        # Output JSON
-        json_output = json.dumps(output, indent=2)
-        if args.output:
-            with open(args.output, "w") as f:
-                f.write(json_output)
-            logger.info(f"Results written to {args.output}")
-        else:
-            print(json_output)
-
-        logger.info("=== Retrieval Pipeline Completed Successfully ===")
-        return 0
-
-    except ValueError as ve:
-        logger.error(f"Validation error: {ve}")
-        print(f"ERROR: {ve}", file=sys.stderr)
-        return 2
-    except ConfigurationError as ce:
-        logger.error(f"Configuration error: {ce}")
-        print(f"ERROR: {ce}", file=sys.stderr)
-        return 1
-    except CollectionNotFoundError as cnfe:
-        logger.error(f"Collection error: {cnfe}")
-        print(f"ERROR: {cnfe}", file=sys.stderr)
-        return 1
-    except DimensionMismatchError as dme:
-        logger.error(f"Dimension error: {dme}")
-        print(f"ERROR: {dme}", file=sys.stderr)
-        return 1
-    except APIError as api_err:
-        logger.error(f"API error: {api_err}")
-        print(f"ERROR: {api_err}", file=sys.stderr)
-        return 1
-    except Exception as e:
-        logger.exception(f"Unexpected error: {e}")
-        print(f"ERROR: Unexpected error: {e}", file=sys.stderr)
-        return 1
-
-
-if __name__ == "__main__":
-    sys.exit(main())
+        raise
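
The block deleted by this commit was a duplicated query path: a second search(), a format_results() helper, and the argparse-based main() CLI. For reference, below is a minimal sketch of calling the retained search() in retrieve.py directly. It assumes the surviving function keeps the same signature as the duplicate removed above; the client setup, API key, Qdrant URL, and collection name are placeholders for illustration only and are not part of this commit.

# Illustrative sketch, not part of the commit.
import cohere
from qdrant_client import QdrantClient

from retrieve import search  # assumes the retained search() has the same signature

cohere_client = cohere.ClientV2(api_key="YOUR_COHERE_API_KEY")  # placeholder key
qdrant_client = QdrantClient(url="http://localhost:6333")  # placeholder endpoint

hits = search(
    query_text="How does the retrieval pipeline format results?",
    cohere_client=cohere_client,
    qdrant_client=qdrant_client,
    collection_name="docs_chunks",  # hypothetical collection name
    top_k=5,
)
for hit in hits:
    print(hit["score"], hit["id"], hit["payload"])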