"""
Retrieval pipeline for RAG validation.
This module provides functions to:
- Convert search queries to embeddings using Cohere
- Perform similarity search against Qdrant collection
- Format and return results with metadata
"""
import argparse
import json
import sys
import time
import logging
from pathlib import Path
from typing import List, Dict, Any
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
import cohere
from qdrant_client import QdrantClient
# Import from existing modules
import config
import utils
from logging_config import setup_logging
# Initialize logger
logger = logging.getLogger(__name__)
# Custom exceptions
class ConfigurationError(Exception):
    """Signals that one or more required configuration values are absent."""
class CollectionNotFoundError(Exception):
    """Signals that the target Qdrant collection is missing."""
class DimensionMismatchError(Exception):
    """Signals that the query embedding size differs from the collection's."""
class APIError(Exception):
    """Signals that a Cohere or Qdrant call still failed after all retries."""
def validate_config(cfg: dict) -> None:
    """Ensure every required setting is present in *cfg*.

    Args:
        cfg: Configuration mapping loaded from the environment.

    Raises:
        ConfigurationError: If any required key is absent or falsy.
    """
    required = ["cohere_api_key", "qdrant_url", "qdrant_api_key"]
    missing = []
    for key in required:
        if not cfg.get(key):
            missing.append(key)
    if missing:
        joined = ", ".join(missing)
        raise ConfigurationError(
            f"Missing required environment variables: {joined}"
        )
def init_clients(cfg: dict):
    """Construct the Cohere and Qdrant API clients from configuration.

    Args:
        cfg: Mapping holding "cohere_api_key", "qdrant_url", and
            "qdrant_api_key" (already validated by validate_config).

    Returns:
        Tuple of (cohere.ClientV2, QdrantClient).
    """
    embedder = cohere.ClientV2(api_key=cfg["cohere_api_key"])
    vector_store = QdrantClient(
        url=cfg["qdrant_url"],
        api_key=cfg["qdrant_api_key"],
    )
    return embedder, vector_store
def check_collection(
    qdrant_client: QdrantClient,
    collection_name: str,
    expected_dim: int = 1024,
) -> Dict[str, Any]:
    """Verify the collection exists and uses the expected vector size.

    Args:
        qdrant_client: Connected Qdrant client.
        collection_name: Name of the collection to inspect.
        expected_dim: Required embedding dimension. Defaults to 1024, the
            output size of Cohere embed-english-v3.0 used by embed_query.

    Returns:
        Dict with "exists", "vector_size", and "points_count" keys.

    Raises:
        CollectionNotFoundError: If the collection does not exist.
        DimensionMismatchError: If the configured vector size differs from
            ``expected_dim``.
    """
    try:
        info = qdrant_client.get_collection(collection_name)
    except Exception as e:
        # Qdrant reports a missing collection via the error text; map it to
        # our domain exception and chain the original cause.
        if "not found" in str(e).lower():
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist"
            ) from e
        raise
    vector_size = info.config.params.vectors.size
    if vector_size != expected_dim:
        raise DimensionMismatchError(
            f"Expected vector size {expected_dim} but got {vector_size}"
        )
    return {
        "exists": True,
        "vector_size": vector_size,
        "points_count": info.points_count,
    }
def embed_query(text: str, cohere_client: cohere.ClientV2) -> List[float]:
    """Generate a search-query embedding for *text* via Cohere.

    Args:
        text: Query string to embed.
        cohere_client: Initialized Cohere v2 client.

    Returns:
        The embedding vector as a list of floats.

    Raises:
        APIError: If the Cohere call (or response extraction) fails; the
            original exception is chained as the cause.
    """
    try:
        response = cohere_client.embed(
            # input_type="search_query" produces embeddings tuned for
            # retrieval against documents embedded as "search_document".
            texts=[text], model="embed-english-v3.0", input_type="search_query"
        )
        # The v2 SDK exposes float embeddings under response.embeddings.float_.
        embedding = response.embeddings.float_[0]
        return embedding
    except Exception as e:
        logger.error(f"Failed to generate embedding: {e}")
        # Chain the cause so the underlying Cohere error stays in tracebacks.
        raise APIError(f"Cohere embedding failed: {e}") from e
def validate_metadata_completeness(results: List[Dict[str, Any]]) -> float:
    """
    Measure how many search results carry complete metadata.

    A result counts as complete when its payload has:
    - a non-blank "url"
    - a "text" value at least 10 characters long
    - a non-blank "title" or a non-blank "section"

    Args:
        results: Search results as returned by the retrieval pipeline.

    Returns:
        Percentage (0-100) of results meeting all three criteria; 0.0 for
        an empty result list.
    """
    if not results:
        return 0.0

    def _is_complete(payload: Dict[str, Any]) -> bool:
        # Whitespace-only strings do not count as present.
        has_url = bool((payload.get("url", "") or "").strip())
        has_text = len(payload.get("text", "") or "") >= 10
        has_label = bool(
            (payload.get("title", "") or "").strip()
            or (payload.get("section", "") or "").strip()
        )
        return has_url and has_text and has_label

    total = len(results)
    complete = sum(1 for r in results if _is_complete(r.get("payload", {})))
    percentage = (complete / total) * 100
    logger.debug(f"Metadata completeness: {complete}/{total} = {percentage:.1f}%")
    return percentage
def validate_chunk_sequencing(results: List[Dict[str, Any]]) -> bool:
    """
    Verify that chunk_index values are properly assigned: integers >= 0 and
    unique per URL.

    Note: Since search may return only a subset of chunks for a URL, we cannot
    verify full sequential continuity (0,1,2,3...). Instead we check:
    - All chunk_index values are integers >= 0 (bool is rejected explicitly,
      since isinstance(True, int) is True but True/False are never valid
      chunk indices)
    - No duplicate chunk_index for the same URL in the result set

    Args:
        results: List of search results

    Returns:
        True if chunk indices are valid, False otherwise
    """
    # Group chunk indices by URL; a missing chunk_index appends None, which
    # fails the integer check below.
    url_chunks: Dict[str, List[Any]] = {}
    for result in results:
        payload = result.get("payload", {})
        url = payload.get("url", "")
        url_chunks.setdefault(url, []).append(payload.get("chunk_index"))
    # Check each URL's chunks are valid
    for url, indices in url_chunks.items():
        for idx in indices:
            if isinstance(idx, bool) or not isinstance(idx, int) or idx < 0:
                logger.debug(
                    f"Invalid chunk_index for {url}: {idx} (must be non-negative integer)"
                )
                return False
        # Check for duplicates (within this URL's results)
        if len(set(indices)) != len(indices):
            logger.debug(f"Duplicate chunk_index for {url}: {indices}")
            return False
    logger.debug(f"Chunk indexing valid for {len(url_chunks)} URLs")
    return True
def search(
    query_text: str,
    cohere_client: cohere.ClientV2,
    qdrant_client: QdrantClient,
    collection_name: str,
    top_k: int = 5,
) -> List[Dict[str, Any]]:
    """
    Convert query to embedding and retrieve top-K relevant chunks.

    Args:
        query_text: User's search query (non-empty, ≤1000 chars after strip)
        cohere_client: Initialized Cohere v2 client used for embedding
        qdrant_client: Connected Qdrant client used for similarity search
        collection_name: Name of the Qdrant collection to query
        top_k: Number of results to return (1-100)

    Returns:
        List of dicts with "id" (str), "score" (float), and "payload" keys

    Raises:
        ValueError: If query_text is empty/too long or top_k is out of range
        APIError: If the Qdrant search fails after retries (cause chained);
            embedding failures propagate from embed_query
    """
    # Validate inputs before doing any network work.
    if not query_text or not query_text.strip():
        raise ValueError("Query text must be non-empty")
    query_text = query_text.strip()
    if len(query_text) > 1000:
        raise ValueError("Query text must be ≤ 1000 characters")
    if top_k < 1 or top_k > 100:
        raise ValueError("top_k must be between 1 and 100")
    logger.info(f"Embedding query: '{query_text[:100]}...' (top_k={top_k})")
    start_time = time.time()
    # Generate query embedding, retrying transient Cohere failures.
    try:
        embedding = utils.retry_with_backoff(
            lambda: embed_query(query_text, cohere_client),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        embed_time = time.time() - start_time
        logger.debug(
            f"Generated embedding in {embed_time:.2f}s, dimension: {len(embedding)}"
        )
    except Exception as e:
        # embed_query already wraps Cohere failures in APIError; re-raise as-is.
        logger.error(f"Failed to embed query: {e}")
        raise
    # Query Qdrant with the same retry policy.
    try:
        search_start = time.time()
        response = utils.retry_with_backoff(
            lambda: qdrant_client.query_points(
                collection_name=collection_name,
                query=embedding,
                limit=top_k,
                with_payload=True,
                with_vectors=False,  # payloads only; vectors aren't needed downstream
            ),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        results = response.points
        search_time = time.time() - search_start
        logger.info(
            f"Search completed in {search_time:.2f}s, returned {len(results)} results"
        )
    except Exception as e:
        logger.error(f"Search failed: {e}")
        # Chain the cause so the underlying Qdrant error stays in tracebacks.
        raise APIError(f"Qdrant search failed: {e}") from e
    # Normalize Qdrant points into plain JSON-serializable dicts.
    formatted = [
        {
            "id": str(result.id),
            "score": float(result.score),
            "payload": result.payload,
        }
        for result in results
    ]
    total_time = time.time() - start_time
    logger.info(f"Total query time: {total_time:.2f}s")
    return formatted
def format_results(
    results: List[Dict[str, Any]], query: str, latency_ms: int
) -> Dict[str, Any]:
    """Assemble the JSON-serializable output envelope for a search run.

    Args:
        results: Formatted search hits.
        query: The original query string.
        latency_ms: End-to-end latency in milliseconds.

    Returns:
        Dict with "query", "timestamp" (UTC, ISO-8601 with trailing Z),
        "results", and a "metadata" section whose "collection" entry is
        left as None for the caller to fill in.
    """
    metadata = {
        "total_results": len(results),
        "collection": None,  # Will be filled by main
        "latency_ms": latency_ms,
    }
    return {
        "query": query,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "results": results,
        "metadata": metadata,
    }
def main() -> int:
    """CLI entrypoint for retrieval.

    Parses command-line arguments, loads configuration, initializes the
    Cohere and Qdrant clients, verifies the target collection, runs the
    search, optionally validates result metadata, and writes a JSON report
    to stdout or a file.

    Returns:
        Process exit code: 0 on success, 2 for input validation errors,
        1 for configuration/collection/dimension/API/unexpected errors.
    """
    parser = argparse.ArgumentParser(
        description="Retrieve relevant chunks from Qdrant using Cohere embeddings"
    )
    parser.add_argument("--query", type=str, help="Search query text")
    parser.add_argument(
        "--top-k", type=int, default=5, help="Number of results to return (default: 5)"
    )
    parser.add_argument("--output", type=str, help="Output file path (default: stdout)")
    parser.add_argument(
        "--config",
        type=str,
        default=".env",
        help="Path to .env config file (default: .env)",
    )
    parser.add_argument(
        "--validate-metadata",
        action="store_true",
        help="Run metadata validation on search results (requires --query)",
    )
    args = parser.parse_args()
    # Setup logging
    log_file = "retrieve.log"
    setup_logging(log_file=log_file, console_level="INFO")
    logger.info("=== Retrieval Pipeline Started ===")
    try:
        # Load config
        # NOTE(review): args.config is logged but never passed to
        # config.get_config() — the --config flag appears to be ignored;
        # confirm whether get_config() should receive the path.
        logger.info(f"Loading config from {args.config}")
        cfg = config.get_config()
        validate_config(cfg)
        # Initialize clients
        logger.info("Initializing Cohere and Qdrant clients")
        cohere_client, qdrant_client = init_clients(cfg)
        # Check collection
        collection_name = cfg["qdrant_collection"]
        logger.info(f"Checking collection '{collection_name}'")
        coll_info = check_collection(qdrant_client, collection_name)
        logger.info(
            f"Collection OK: vector_size={coll_info['vector_size']}, points={coll_info['points_count']}"
        )
        # Validate query argument
        # NOTE(review): this check runs after clients are initialized and the
        # collection is verified; moving it before the network setup would
        # fail faster on a missing --query. parser.error() exits with code 2.
        if not args.query:
            parser.error("--query is required")
        # Perform search
        results = search(
            query_text=args.query,
            cohere_client=cohere_client,
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            top_k=args.top_k,
        )
        # Perform metadata validation if requested
        metadata_validation = None
        if args.validate_metadata:
            completeness = validate_metadata_completeness(results)
            sequencing = validate_chunk_sequencing(results)
            # Pass requires ≥98% metadata completeness AND valid sequencing.
            metadata_validation = {
                "completeness_pct": round(completeness, 2),
                "sequencing_valid": sequencing,
                "pass": completeness >= 98.0 and sequencing,
            }
            logger.info(f"Metadata completeness: {completeness:.1f}%")
            logger.info(f"Chunk sequencing: {'VALID' if sequencing else 'INVALID'}")
            logger.info(
                f"Validation result: {'PASS' if metadata_validation['pass'] else 'FAIL'}"
            )
        # Format output
        # NOTE(review): this envelope is built inline rather than via
        # format_results(); the two shapes differ slightly (no latency_ms,
        # extra vector_size/points_count here).
        output = {
            "query": args.query,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "results": results,
            "metadata": {
                "total_results": len(results),
                "collection": collection_name,
                "vector_size": coll_info["vector_size"],
                "points_count": coll_info["points_count"],
            },
        }
        if metadata_validation:
            output["metadata_validation"] = metadata_validation
        # Output JSON
        json_output = json.dumps(output, indent=2)
        if args.output:
            with open(args.output, "w") as f:
                f.write(json_output)
            logger.info(f"Results written to {args.output}")
        else:
            print(json_output)
        logger.info("=== Retrieval Pipeline Completed Successfully ===")
        return 0
    except ValueError as ve:
        # Input validation problems (query/top_k) — distinct exit code 2.
        logger.error(f"Validation error: {ve}")
        print(f"ERROR: {ve}", file=sys.stderr)
        return 2
    except ConfigurationError as ce:
        logger.error(f"Configuration error: {ce}")
        print(f"ERROR: {ce}", file=sys.stderr)
        return 1
    except CollectionNotFoundError as cnfe:
        logger.error(f"Collection error: {cnfe}")
        print(f"ERROR: {cnfe}", file=sys.stderr)
        return 1
    except DimensionMismatchError as dme:
        logger.error(f"Dimension error: {dme}")
        print(f"ERROR: {dme}", file=sys.stderr)
        return 1
    except APIError as api_err:
        logger.error(f"API error: {api_err}")
        print(f"ERROR: {api_err}", file=sys.stderr)
        return 1
    except Exception as e:
        # Top-level boundary: log full traceback for anything unexpected.
        logger.exception(f"Unexpected error: {e}")
        print(f"ERROR: Unexpected error: {e}", file=sys.stderr)
        return 1
# Script entrypoint: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())