""" Retrieval pipeline for RAG validation. This module provides functions to: - Convert search queries to embeddings using Cohere - Perform similarity search against Qdrant collection - Format and return results with metadata """ import argparse import json import sys import time import logging from pathlib import Path from typing import List, Dict, Any # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) import cohere from qdrant_client import QdrantClient # Importfrom existing modules import config import utils from logging_config import setup_logging # Initialize logger logger = logging.getLogger(__name__) # Custom exceptions class ConfigurationError(Exception): """Raised when required configuration is missing.""" pass class CollectionNotFoundError(Exception): """Raised when Qdrant collection doesn't exist.""" pass class DimensionMismatchError(Exception): """Raised when embedding dimension doesn't match collection.""" pass class APIError(Exception): """Raised when Cohere or Qdrant API call fails after retries.""" pass def validate_config(cfg: dict) -> None: """Validate that all required config values are present.""" required = ["cohere_api_key", "qdrant_url", "qdrant_api_key"] missing = [key for key in required if not cfg.get(key)] if missing: raise ConfigurationError( f"Missing required environment variables: {', '.join(missing)}" ) def init_clients(cfg: dict): """Initialize Cohere and Qdrant clients.""" cohere_client = cohere.ClientV2(api_key=cfg["cohere_api_key"]) qdrant_client = QdrantClient(url=cfg["qdrant_url"], api_key=cfg["qdrant_api_key"]) return cohere_client, qdrant_client def check_collection( qdrant_client: QdrantClient, collection_name: str ) -> Dict[str, Any]: """Verify collection exists and has correct vector size.""" try: info = qdrant_client.get_collection(collection_name) except Exception as e: if "not found" in str(e).lower(): raise CollectionNotFoundError( f"Collection '{collection_name}' does not exist" ) raise vector_size = info.config.params.vectors.size if vector_size != 1024: raise DimensionMismatchError(f"Expected vector size 1024 but got {vector_size}") return { "exists": True, "vector_size": vector_size, "points_count": info.points_count, } def embed_query(text: str, cohere_client: cohere.ClientV2) -> List[float]: """Generate embedding for a search query using Cohere.""" try: response = cohere_client.embed( texts=[text], model="embed-english-v3.0", input_type="search_query" ) # Extract embedding from response.embeddings.float_ embedding = response.embeddings.float_[0] return embedding except Exception as e: logger.error(f"Failed to generate embedding: {e}") raise APIError(f"Cohere embedding failed: {e}") def validate_metadata_completeness(results: List[Dict[str, Any]]) -> float: """ Check metadata completeness in search results. Returns: Percentage (0-100) of results with complete metadata: - url present and non-empty - text present with length ≥ 10 - at least one of title or section non-empty """ if not results: return 0.0 complete = 0 total = len(results) for result in results: payload = result.get("payload", {}) url = payload.get("url", "") text = payload.get("text", "") title = payload.get("title", "") section = payload.get("section", "") # Check completeness criteria url_ok = bool(url and url.strip()) text_ok = len(text or "") >= 10 title_section_ok = bool( (title and title.strip()) or (section and section.strip()) ) if url_ok and text_ok and title_section_ok: complete += 1 percentage = (complete / total) * 100 logger.debug(f"Metadata completeness: {complete}/{total} = {percentage:.1f}%") return percentage def validate_chunk_sequencing(results: List[Dict[str, Any]]) -> bool: """ Verify that chunk_index values are properly assigned: integers >= 0 and unique per URL. Note: Since search may return only a subset of chunks for a URL, we cannot verify full sequential continuity (0,1,2,3...). Instead we check: - All chunk_index values are integers >= 0 - No duplicate chunk_index for the same URL in the result set Args: results: List of search results Returns: True if chunk indices are valid, False otherwise """ # Group by URL url_chunks = {} for result in results: payload = result.get("payload", {}) url = payload.get("url", "") chunk_idx = payload.get("chunk_index") if url not in url_chunks: url_chunks[url] = [] url_chunks[url].append(chunk_idx) # Check each URL's chunks are valid for url, indices in url_chunks.items(): # All indices must be integers >= 0 for idx in indices: if not isinstance(idx, int) or idx < 0: logger.debug( f"Invalid chunk_index for {url}: {idx} (must be non-negative integer)" ) return False # Check for duplicates (within this URL's results) if len(set(indices)) != len(indices): logger.debug(f"Duplicate chunk_index for {url}: {indices}") return False logger.debug(f"Chunk indexing valid for {len(url_chunks)} URLs") return True def search( query_text: str, cohere_client: cohere.ClientV2, qdrant_client: QdrantClient, collection_name: str, top_k: int = 5, ) -> List[Dict[str, Any]]: """ Convert query to embedding and retrieve top-K relevant chunks. Args: query_text: User's search query (non-empty, ≤1000 chars) top_k: Number of results to return (1-100) Returns: List of search results with id, score, and payload """ # Validate inputs if not query_text or not query_text.strip(): raise ValueError("Query text must be non-empty") query_text = query_text.strip() if len(query_text) > 1000: raise ValueError("Query text must be ≤ 1000 characters") if top_k < 1 or top_k > 100: raise ValueError("top_k must be between 1 and 100") logger.info(f"Embedding query: '{query_text[:100]}...' (top_k={top_k})") start_time = time.time() # Generate query embedding with retry try: embedding = utils.retry_with_backoff( lambda: embed_query(query_text, cohere_client), max_retries=3, base_delay=1.0, max_delay=10.0, ) embed_time = time.time() - start_time logger.debug( f"Generated embedding in {embed_time:.2f}s, dimension: {len(embedding)}" ) except Exception as e: logger.error(f"Failed to embed query: {e}") raise # Search Qdrant with retry try: search_start = time.time() response = utils.retry_with_backoff( lambda: qdrant_client.query_points( collection_name=collection_name, query=embedding, limit=top_k, with_payload=True, with_vectors=False, ), max_retries=3, base_delay=1.0, max_delay=10.0, ) results = response.points search_time = time.time() - search_start logger.info( f"Search completed in {search_time:.2f}s, returned {len(results)} results" ) except Exception as e: logger.error(f"Search failed: {e}") raise APIError(f"Qdrant search failed: {e}") # Format results formatted = [] for result in results: formatted.append( { "id": str(result.id), "score": float(result.score), "payload": result.payload, } ) total_time = time.time() - start_time logger.info(f"Total query time: {total_time:.2f}s") return formatted def format_results( results: List[Dict[str, Any]], query: str, latency_ms: int ) -> Dict[str, Any]: """Format search results into JSON output structure.""" output = { "query": query, "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "results": results, "metadata": { "total_results": len(results), "collection": None, # Will be filled by main "latency_ms": latency_ms, }, } return output def main() -> int: """CLI entrypoint for retrieval.""" parser = argparse.ArgumentParser( description="Retrieve relevant chunks from Qdrant using Cohere embeddings" ) parser.add_argument("--query", type=str, help="Search query text") parser.add_argument( "--top-k", type=int, default=5, help="Number of results to return (default: 5)" ) parser.add_argument("--output", type=str, help="Output file path (default: stdout)") parser.add_argument( "--config", type=str, default=".env", help="Path to .env config file (default: .env)", ) parser.add_argument( "--validate-metadata", action="store_true", help="Run metadata validation on search results (requires --query)", ) args = parser.parse_args() # Setup logging log_file = "retrieve.log" setup_logging(log_file=log_file, console_level="INFO") logger.info("=== Retrieval Pipeline Started ===") try: # Load config logger.info(f"Loading config from {args.config}") cfg = config.get_config() validate_config(cfg) # Initialize clients logger.info("Initializing Cohere and Qdrant clients") cohere_client, qdrant_client = init_clients(cfg) # Check collection collection_name = cfg["qdrant_collection"] logger.info(f"Checking collection '{collection_name}'") coll_info = check_collection(qdrant_client, collection_name) logger.info( f"Collection OK: vector_size={coll_info['vector_size']}, points={coll_info['points_count']}" ) # Validate query argument if not args.query: parser.error("--query is required") # Perform search results = search( query_text=args.query, cohere_client=cohere_client, qdrant_client=qdrant_client, collection_name=collection_name, top_k=args.top_k, ) # Perform metadata validation if requested metadata_validation = None if args.validate_metadata: completeness = validate_metadata_completeness(results) sequencing = validate_chunk_sequencing(results) metadata_validation = { "completeness_pct": round(completeness, 2), "sequencing_valid": sequencing, "pass": completeness >= 98.0 and sequencing, } logger.info(f"Metadata completeness: {completeness:.1f}%") logger.info(f"Chunk sequencing: {'VALID' if sequencing else 'INVALID'}") logger.info( f"Validation result: {'PASS' if metadata_validation['pass'] else 'FAIL'}" ) # Format output output = { "query": args.query, "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "results": results, "metadata": { "total_results": len(results), "collection": collection_name, "vector_size": coll_info["vector_size"], "points_count": coll_info["points_count"], }, } if metadata_validation: output["metadata_validation"] = metadata_validation # Output JSON json_output = json.dumps(output, indent=2) if args.output: with open(args.output, "w") as f: f.write(json_output) logger.info(f"Results written to {args.output}") else: print(json_output) logger.info("=== Retrieval Pipeline Completed Successfully ===") return 0 except ValueError as ve: logger.error(f"Validation error: {ve}") print(f"ERROR: {ve}", file=sys.stderr) return 2 except ConfigurationError as ce: logger.error(f"Configuration error: {ce}") print(f"ERROR: {ce}", file=sys.stderr) return 1 except CollectionNotFoundError as cnfe: logger.error(f"Collection error: {cnfe}") print(f"ERROR: {cnfe}", file=sys.stderr) return 1 except DimensionMismatchError as dme: logger.error(f"Dimension error: {dme}") print(f"ERROR: {dme}", file=sys.stderr) return 1 except APIError as api_err: logger.error(f"API error: {api_err}") print(f"ERROR: {api_err}", file=sys.stderr) return 1 except Exception as e: logger.exception(f"Unexpected error: {e}") print(f"ERROR: Unexpected error: {e}", file=sys.stderr) return 1 if __name__ == "__main__": sys.exit(main())