| """ | |
| Retrieval pipeline for RAG validation. | |
| This module provides functions to: | |
| - Convert search queries to embeddings using Cohere | |
| - Perform similarity search against Qdrant collection | |
| - Format and return results with metadata | |
| """ | |

import argparse
import json
import sys
import time
import logging
from pathlib import Path
from typing import List, Dict, Any

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

import cohere
from qdrant_client import QdrantClient

# Import from existing modules
import config
import utils
from logging_config import setup_logging

# Initialize logger
logger = logging.getLogger(__name__)


# Custom exceptions
class ConfigurationError(Exception):
    """Raised when required configuration is missing."""


class CollectionNotFoundError(Exception):
    """Raised when the Qdrant collection doesn't exist."""


class DimensionMismatchError(Exception):
    """Raised when the embedding dimension doesn't match the collection."""


class APIError(Exception):
    """Raised when a Cohere or Qdrant API call fails after retries."""


def validate_config(cfg: dict) -> None:
    """Validate that all required config values are present."""
    required = ["cohere_api_key", "qdrant_url", "qdrant_api_key"]
    missing = [key for key in required if not cfg.get(key)]
    if missing:
        raise ConfigurationError(
            f"Missing required environment variables: {', '.join(missing)}"
        )
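
# Example (sketch, hypothetical values): with only the Cohere key set, both
# missing Qdrant keys are reported together:
#
#   validate_config({"cohere_api_key": "co-..."})
#   # ConfigurationError: Missing required environment variables: qdrant_url, qdrant_api_key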


def init_clients(cfg: dict):
    """Initialize Cohere and Qdrant clients."""
    cohere_client = cohere.ClientV2(api_key=cfg["cohere_api_key"])
    qdrant_client = QdrantClient(url=cfg["qdrant_url"], api_key=cfg["qdrant_api_key"])
    return cohere_client, qdrant_client


def check_collection(
    qdrant_client: QdrantClient, collection_name: str
) -> Dict[str, Any]:
    """Verify collection exists and has correct vector size."""
    try:
        info = qdrant_client.get_collection(collection_name)
    except Exception as e:
        if "not found" in str(e).lower():
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist"
            )
        raise
    vector_size = info.config.params.vectors.size
    if vector_size != 1024:
        raise DimensionMismatchError(f"Expected vector size 1024 but got {vector_size}")
    return {
        "exists": True,
        "vector_size": vector_size,
        "points_count": info.points_count,
    }
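
# Note: `info.config.params.vectors.size` assumes the collection uses a single
# unnamed vector configuration; with named vectors, qdrant-client returns a dict
# of VectorParams here instead. A healthy check might return, for example
# (point count is hypothetical):
#
#   {"exists": True, "vector_size": 1024, "points_count": 5321}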


def embed_query(text: str, cohere_client: cohere.ClientV2) -> List[float]:
    """Generate embedding for a search query using Cohere."""
    try:
        response = cohere_client.embed(
            texts=[text],
            model="embed-english-v3.0",
            input_type="search_query",
            embedding_types=["float"],  # request float embeddings (required by the v2 embed API)
        )
        # Extract embedding from response.embeddings.float_
        embedding = response.embeddings.float_[0]
        return embedding
    except Exception as e:
        logger.error(f"Failed to generate embedding: {e}")
        raise APIError(f"Cohere embedding failed: {e}")
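
# Example (sketch): embed_query("reset password", cohere_client) returns a
# 1024-dimensional list of floats for embed-english-v3.0, matching the
# collection size enforced by check_collection.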


def validate_metadata_completeness(results: List[Dict[str, Any]]) -> float:
    """
    Check metadata completeness in search results.

    Returns:
        Percentage (0-100) of results with complete metadata:
        - url present and non-empty
        - text present with length ≥ 10
        - at least one of title or section non-empty
    """
    if not results:
        return 0.0

    complete = 0
    total = len(results)
    for result in results:
        payload = result.get("payload", {})
        url = payload.get("url", "")
        text = payload.get("text", "")
        title = payload.get("title", "")
        section = payload.get("section", "")

        # Check completeness criteria
        url_ok = bool(url and url.strip())
        text_ok = len(text or "") >= 10
        title_section_ok = bool(
            (title and title.strip()) or (section and section.strip())
        )
        if url_ok and text_ok and title_section_ok:
            complete += 1

    percentage = (complete / total) * 100
    logger.debug(f"Metadata completeness: {complete}/{total} = {percentage:.1f}%")
    return percentage
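
# Example (sketch, hypothetical payload): a result with a non-empty url, text of
# at least 10 characters, and a title counts as complete:
#
#   validate_metadata_completeness(
#       [{"payload": {"url": "https://example.com/docs", "text": "ten chars!!", "title": "Docs"}}]
#   )  # -> 100.0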


def validate_chunk_sequencing(results: List[Dict[str, Any]]) -> bool:
    """
    Verify that chunk_index values are properly assigned: integers >= 0 and unique per URL.

    Note: Since search may return only a subset of chunks for a URL, we cannot
    verify full sequential continuity (0, 1, 2, 3, ...). Instead we check:
    - All chunk_index values are integers >= 0
    - No duplicate chunk_index for the same URL in the result set

    Args:
        results: List of search results

    Returns:
        True if chunk indices are valid, False otherwise
    """
    # Group by URL
    url_chunks = {}
    for result in results:
        payload = result.get("payload", {})
        url = payload.get("url", "")
        chunk_idx = payload.get("chunk_index")
        if url not in url_chunks:
            url_chunks[url] = []
        url_chunks[url].append(chunk_idx)

    # Check each URL's chunks are valid
    for url, indices in url_chunks.items():
        # All indices must be integers >= 0
        for idx in indices:
            if not isinstance(idx, int) or idx < 0:
                logger.debug(
                    f"Invalid chunk_index for {url}: {idx} (must be non-negative integer)"
                )
                return False
        # Check for duplicates (within this URL's results)
        if len(set(indices)) != len(indices):
            logger.debug(f"Duplicate chunk_index for {url}: {indices}")
            return False

    logger.debug(f"Chunk indexing valid for {len(url_chunks)} URLs")
    return True
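
# Example (sketch, hypothetical payloads): a duplicate chunk_index for the same
# URL fails validation:
#
#   validate_chunk_sequencing(
#       [
#           {"payload": {"url": "https://example.com/a", "chunk_index": 2}},
#           {"payload": {"url": "https://example.com/a", "chunk_index": 2}},
#       ]
#   )  # -> False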


def search(
    query_text: str,
    cohere_client: cohere.ClientV2,
    qdrant_client: QdrantClient,
    collection_name: str,
    top_k: int = 5,
) -> List[Dict[str, Any]]:
    """
    Convert query to embedding and retrieve top-K relevant chunks.

    Args:
        query_text: User's search query (non-empty, ≤1000 chars)
        top_k: Number of results to return (1-100)

    Returns:
        List of search results with id, score, and payload
    """
    # Validate inputs
    if not query_text or not query_text.strip():
        raise ValueError("Query text must be non-empty")
    query_text = query_text.strip()
    if len(query_text) > 1000:
        raise ValueError("Query text must be ≤ 1000 characters")
    if top_k < 1 or top_k > 100:
        raise ValueError("top_k must be between 1 and 100")

    logger.info(f"Embedding query: '{query_text[:100]}...' (top_k={top_k})")
    start_time = time.time()
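
    # Note: utils.retry_with_backoff is part of this project's utils module. From
    # its call sites here it is assumed to take a zero-argument callable plus
    # max_retries/base_delay/max_delay keyword arguments, retrying on exception
    # with exponential backoff and returning the callable's result.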
    # Generate query embedding with retry
    try:
        embedding = utils.retry_with_backoff(
            lambda: embed_query(query_text, cohere_client),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        embed_time = time.time() - start_time
        logger.debug(
            f"Generated embedding in {embed_time:.2f}s, dimension: {len(embedding)}"
        )
    except Exception as e:
        logger.error(f"Failed to embed query: {e}")
        raise

    # Search Qdrant with retry
    try:
        search_start = time.time()
        response = utils.retry_with_backoff(
            lambda: qdrant_client.query_points(
                collection_name=collection_name,
                query=embedding,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
            ),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        results = response.points
        search_time = time.time() - search_start
        logger.info(
            f"Search completed in {search_time:.2f}s, returned {len(results)} results"
        )
    except Exception as e:
        logger.error(f"Search failed: {e}")
        raise APIError(f"Qdrant search failed: {e}")

    # Format results
    formatted = []
    for result in results:
        formatted.append(
            {
                "id": str(result.id),
                "score": float(result.score),
                "payload": result.payload,
            }
        )

    total_time = time.time() - start_time
    logger.info(f"Total query time: {total_time:.2f}s")
    return formatted
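
# Example (sketch, hypothetical clients and collection name): a typical call
# site looks like:
#
#   results = search("how do I rotate API keys", cohere_client, qdrant_client,
#                    collection_name="docs_chunks", top_k=3)
#   # -> [{"id": "...", "score": 0.83, "payload": {...}}, ...]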


def format_results(
    results: List[Dict[str, Any]], query: str, latency_ms: int
) -> Dict[str, Any]:
    """Format search results into JSON output structure."""
    output = {
        "query": query,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "results": results,
        "metadata": {
            "total_results": len(results),
            "collection": None,  # Will be filled by main
            "latency_ms": latency_ms,
        },
    }
    return output


def main() -> int:
    """CLI entrypoint for retrieval."""
    parser = argparse.ArgumentParser(
        description="Retrieve relevant chunks from Qdrant using Cohere embeddings"
    )
    parser.add_argument("--query", type=str, help="Search query text")
    parser.add_argument(
        "--top-k", type=int, default=5, help="Number of results to return (default: 5)"
    )
    parser.add_argument("--output", type=str, help="Output file path (default: stdout)")
    parser.add_argument(
        "--config",
        type=str,
        default=".env",
        help="Path to .env config file (default: .env)",
    )
    parser.add_argument(
        "--validate-metadata",
        action="store_true",
        help="Run metadata validation on search results (requires --query)",
    )
    args = parser.parse_args()

    # Validate query argument before any network setup
    if not args.query:
        parser.error("--query is required")

    # Setup logging
    log_file = "retrieve.log"
    setup_logging(log_file=log_file, console_level="INFO")
    logger.info("=== Retrieval Pipeline Started ===")

    try:
        # Load config
        logger.info(f"Loading config from {args.config}")
        cfg = config.get_config()
        validate_config(cfg)

        # Initialize clients
        logger.info("Initializing Cohere and Qdrant clients")
        cohere_client, qdrant_client = init_clients(cfg)

        # Check collection
        collection_name = cfg["qdrant_collection"]
        logger.info(f"Checking collection '{collection_name}'")
        coll_info = check_collection(qdrant_client, collection_name)
        logger.info(
            f"Collection OK: vector_size={coll_info['vector_size']}, points={coll_info['points_count']}"
        )

        # Perform search
        results = search(
            query_text=args.query,
            cohere_client=cohere_client,
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            top_k=args.top_k,
        )

        # Perform metadata validation if requested
        metadata_validation = None
        if args.validate_metadata:
            completeness = validate_metadata_completeness(results)
            sequencing = validate_chunk_sequencing(results)
            metadata_validation = {
                "completeness_pct": round(completeness, 2),
                "sequencing_valid": sequencing,
                "pass": completeness >= 98.0 and sequencing,
            }
            logger.info(f"Metadata completeness: {completeness:.1f}%")
            logger.info(f"Chunk sequencing: {'VALID' if sequencing else 'INVALID'}")
            logger.info(
                f"Validation result: {'PASS' if metadata_validation['pass'] else 'FAIL'}"
            )

        # Format output
        output = {
            "query": args.query,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "results": results,
            "metadata": {
                "total_results": len(results),
                "collection": collection_name,
                "vector_size": coll_info["vector_size"],
                "points_count": coll_info["points_count"],
            },
        }
        if metadata_validation:
            output["metadata_validation"] = metadata_validation

        # Output JSON
        json_output = json.dumps(output, indent=2)
        if args.output:
            with open(args.output, "w") as f:
                f.write(json_output)
            logger.info(f"Results written to {args.output}")
        else:
            print(json_output)

        logger.info("=== Retrieval Pipeline Completed Successfully ===")
        return 0
    except ValueError as ve:
        logger.error(f"Validation error: {ve}")
        print(f"ERROR: {ve}", file=sys.stderr)
        return 2
    except ConfigurationError as ce:
        logger.error(f"Configuration error: {ce}")
        print(f"ERROR: {ce}", file=sys.stderr)
        return 1
    except CollectionNotFoundError as cnfe:
        logger.error(f"Collection error: {cnfe}")
        print(f"ERROR: {cnfe}", file=sys.stderr)
        return 1
    except DimensionMismatchError as dme:
        logger.error(f"Dimension error: {dme}")
        print(f"ERROR: {dme}", file=sys.stderr)
        return 1
    except APIError as api_err:
        logger.error(f"API error: {api_err}")
        print(f"ERROR: {api_err}", file=sys.stderr)
        return 1
    except Exception as e:
        logger.exception(f"Unexpected error: {e}")
        print(f"ERROR: Unexpected error: {e}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    sys.exit(main())