# backend/retrieve.py — last updated in commit 09f70c3 (m-ahmad-official)
"""
Retrieval pipeline for RAG validation.
This module provides functions to:
- Convert search queries to embeddings using Cohere
- Perform similarity search against Qdrant collection
- Format and return results with metadata
"""
import argparse
import json
import sys
import time
import logging
from pathlib import Path
from typing import List, Dict, Any
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
import cohere
from qdrant_client import QdrantClient
# Import from existing project-local modules
import config
import utils
from logging_config import setup_logging
# Initialize logger
logger = logging.getLogger(__name__)
# Custom exceptions
class ConfigurationError(Exception):
    """Raised when required configuration is missing."""


class CollectionNotFoundError(Exception):
    """Raised when Qdrant collection doesn't exist."""


class DimensionMismatchError(Exception):
    """Raised when embedding dimension doesn't match collection."""


class APIError(Exception):
    """Raised when Cohere or Qdrant API call fails after retries."""
def validate_config(cfg: dict) -> None:
    """Ensure every required config key is present and truthy.

    Args:
        cfg: Configuration mapping loaded from the environment.

    Raises:
        ConfigurationError: Listing every missing/empty required key.
    """
    required = ("cohere_api_key", "qdrant_url", "qdrant_api_key")
    missing = []
    for key in required:
        # Empty strings count as missing, not just absent keys.
        if not cfg.get(key):
            missing.append(key)
    if missing:
        raise ConfigurationError(
            f"Missing required environment variables: {', '.join(missing)}"
        )
def init_clients(cfg: dict):
    """Build the Cohere and Qdrant API clients from validated config.

    Args:
        cfg: Config dict holding cohere_api_key, qdrant_url and qdrant_api_key.

    Returns:
        Tuple of (cohere.ClientV2, QdrantClient).
    """
    embedder = cohere.ClientV2(api_key=cfg["cohere_api_key"])
    vector_store = QdrantClient(
        url=cfg["qdrant_url"],
        api_key=cfg["qdrant_api_key"],
    )
    return embedder, vector_store
def check_collection(
    qdrant_client: QdrantClient,
    collection_name: str,
    expected_vector_size: int = 1024,
) -> Dict[str, Any]:
    """Verify that a Qdrant collection exists and uses the expected vector size.

    Args:
        qdrant_client: Connected Qdrant client.
        collection_name: Name of the collection to inspect.
        expected_vector_size: Required embedding dimension. Defaults to 1024
            (the size this pipeline's collection was validated against).

    Returns:
        Dict with keys ``exists`` (True), ``vector_size`` and ``points_count``.

    Raises:
        CollectionNotFoundError: If the collection does not exist.
        DimensionMismatchError: If the collection's vector size differs from
            ``expected_vector_size``.
    """
    try:
        info = qdrant_client.get_collection(collection_name)
    except Exception as e:
        # Qdrant raises a generic error; detect "not found" by message text.
        if "not found" in str(e).lower():
            # Chain the original exception so the root cause is preserved.
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist"
            ) from e
        raise
    # NOTE(review): assumes a single unnamed vector config — with named
    # vectors `info.config.params.vectors` is a dict; confirm collection setup.
    vector_size = info.config.params.vectors.size
    if vector_size != expected_vector_size:
        raise DimensionMismatchError(
            f"Expected vector size {expected_vector_size} but got {vector_size}"
        )
    return {
        "exists": True,
        "vector_size": vector_size,
        "points_count": info.points_count,
    }
def embed_query(text: str, cohere_client: cohere.ClientV2) -> List[float]:
    """Generate an embedding vector for a search query using Cohere.

    Args:
        text: Query text to embed.
        cohere_client: Initialized Cohere v2 client.

    Returns:
        The embedding as a list of floats (model embed-english-v3.0,
        input_type="search_query").

    Raises:
        APIError: If the Cohere embed call fails for any reason.
    """
    try:
        response = cohere_client.embed(
            texts=[text], model="embed-english-v3.0", input_type="search_query"
        )
        # One input text -> one embedding; `.float_` holds the float variant.
        return response.embeddings.float_[0]
    except Exception as e:
        logger.error(f"Failed to generate embedding: {e}")
        # Chain the original exception so the root cause is preserved.
        raise APIError(f"Cohere embedding failed: {e}") from e
def validate_metadata_completeness(results: List[Dict[str, Any]]) -> float:
    """
    Check metadata completeness in search results.

    A result counts as complete when its payload has:
      - a non-empty ``url``,
      - a ``text`` of length >= 10, and
      - at least one of ``title`` or ``section`` non-empty.

    Returns:
        Percentage (0-100) of results with complete metadata.
    """
    if not results:
        return 0.0

    def _is_complete(payload: Dict[str, Any]) -> bool:
        # Normalize missing/None fields to "" before testing each criterion.
        has_url = bool((payload.get("url") or "").strip())
        has_text = len(payload.get("text") or "") >= 10
        has_label = bool(
            (payload.get("title") or "").strip()
            or (payload.get("section") or "").strip()
        )
        return has_url and has_text and has_label

    total = len(results)
    complete = sum(1 for item in results if _is_complete(item.get("payload", {})))
    percentage = (complete / total) * 100
    logger.debug(f"Metadata completeness: {complete}/{total} = {percentage:.1f}%")
    return percentage
def validate_chunk_sequencing(results: List[Dict[str, Any]]) -> bool:
    """
    Verify that chunk_index values are properly assigned: integers >= 0 and
    unique per URL.

    Search may return only a subset of a URL's chunks, so full sequential
    continuity (0,1,2,3...) cannot be verified here. Instead this checks:
      - every chunk_index is a non-negative integer, and
      - no chunk_index repeats within the same URL's results.

    Args:
        results: List of search results

    Returns:
        True if chunk indices are valid, False otherwise
    """
    # Collect the chunk indices seen for each URL.
    indices_by_url: Dict[str, List[Any]] = {}
    for item in results:
        payload = item.get("payload", {})
        indices_by_url.setdefault(payload.get("url", ""), []).append(
            payload.get("chunk_index")
        )
    for url, indices in indices_by_url.items():
        for idx in indices:
            # Guard clause: anything that isn't a non-negative int fails fast.
            if isinstance(idx, int) and idx >= 0:
                continue
            logger.debug(
                f"Invalid chunk_index for {url}: {idx} (must be non-negative integer)"
            )
            return False
        # A set shrinking means at least one duplicate index for this URL.
        if len(set(indices)) != len(indices):
            logger.debug(f"Duplicate chunk_index for {url}: {indices}")
            return False
    logger.debug(f"Chunk indexing valid for {len(indices_by_url)} URLs")
    return True
def search(
    query_text: str,
    cohere_client: cohere.ClientV2,
    qdrant_client: QdrantClient,
    collection_name: str,
    top_k: int = 5,
) -> List[Dict[str, Any]]:
    """
    Convert query to embedding and retrieve top-K relevant chunks.

    Args:
        query_text: User's search query (non-empty, ≤1000 chars)
        cohere_client: Initialized Cohere v2 client (used for the embedding).
        qdrant_client: Initialized Qdrant client (used for the search).
        collection_name: Qdrant collection to query.
        top_k: Number of results to return (1-100)

    Returns:
        List of search results with id, score, and payload

    Raises:
        ValueError: If query_text or top_k fail validation.
        APIError: If the Qdrant search fails after all retries.
    """
    # Validate inputs before spending any API calls.
    if not query_text or not query_text.strip():
        raise ValueError("Query text must be non-empty")
    query_text = query_text.strip()
    if len(query_text) > 1000:
        raise ValueError("Query text must be ≤ 1000 characters")
    if top_k < 1 or top_k > 100:
        raise ValueError("top_k must be between 1 and 100")
    logger.info(f"Embedding query: '{query_text[:100]}...' (top_k={top_k})")
    start_time = time.time()
    # Generate query embedding with retry/backoff; embed_query raises APIError.
    try:
        embedding = utils.retry_with_backoff(
            lambda: embed_query(query_text, cohere_client),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        embed_time = time.time() - start_time
        logger.debug(
            f"Generated embedding in {embed_time:.2f}s, dimension: {len(embedding)}"
        )
    except Exception as e:
        logger.error(f"Failed to embed query: {e}")
        raise
    # Search Qdrant with retry/backoff; payloads included, vectors omitted
    # since callers only need id/score/payload.
    try:
        search_start = time.time()
        response = utils.retry_with_backoff(
            lambda: qdrant_client.query_points(
                collection_name=collection_name,
                query=embedding,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
            ),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        results = response.points
        search_time = time.time() - search_start
        logger.info(
            f"Search completed in {search_time:.2f}s, returned {len(results)} results"
        )
    except Exception as e:
        logger.error(f"Search failed: {e}")
        # Chain the original exception so the root cause is preserved.
        raise APIError(f"Qdrant search failed: {e}") from e
    # Normalize scored points into plain JSON-serializable dicts.
    formatted = [
        {"id": str(point.id), "score": float(point.score), "payload": point.payload}
        for point in results
    ]
    total_time = time.time() - start_time
    logger.info(f"Total query time: {total_time:.2f}s")
    return formatted
def format_results(
    results: List[Dict[str, Any]], query: str, latency_ms: int
) -> Dict[str, Any]:
    """Assemble the JSON-serializable output envelope for a search run.

    Args:
        results: Formatted search hits (id / score / payload dicts).
        query: The original query text.
        latency_ms: End-to-end latency in milliseconds.

    Returns:
        Dict with the query, a UTC ISO-8601 timestamp, the results, and a
        metadata section whose ``collection`` is left None for the caller.
    """
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    metadata = {
        "total_results": len(results),
        "collection": None,  # Will be filled by main
        "latency_ms": latency_ms,
    }
    return {
        "query": query,
        "timestamp": timestamp,
        "results": results,
        "metadata": metadata,
    }
def main() -> int:
    """CLI entrypoint for retrieval.

    Parses arguments, runs a similarity search against the configured Qdrant
    collection, optionally validates result metadata, and writes JSON output
    to stdout or --output.

    Exit codes:
        0 -- success
        2 -- query/top_k validation error (also argparse errors)
        1 -- configuration, collection, API, or unexpected error
    """
    parser = argparse.ArgumentParser(
        description="Retrieve relevant chunks from Qdrant using Cohere embeddings"
    )
    parser.add_argument("--query", type=str, help="Search query text")
    parser.add_argument(
        "--top-k", type=int, default=5, help="Number of results to return (default: 5)"
    )
    parser.add_argument("--output", type=str, help="Output file path (default: stdout)")
    parser.add_argument(
        "--config",
        type=str,
        default=".env",
        help="Path to .env config file (default: .env)",
    )
    parser.add_argument(
        "--validate-metadata",
        action="store_true",
        help="Run metadata validation on search results (requires --query)",
    )
    args = parser.parse_args()
    # Fail fast: require --query up front instead of after config loading,
    # client construction and the collection check have already run.
    if not args.query:
        parser.error("--query is required")
    # Setup logging
    log_file = "retrieve.log"
    setup_logging(log_file=log_file, console_level="INFO")
    logger.info("=== Retrieval Pipeline Started ===")
    try:
        # Load config
        # NOTE(review): args.config is logged but never passed to
        # config.get_config() -- confirm whether get_config should honor it.
        logger.info(f"Loading config from {args.config}")
        cfg = config.get_config()
        validate_config(cfg)
        # Initialize clients
        logger.info("Initializing Cohere and Qdrant clients")
        cohere_client, qdrant_client = init_clients(cfg)
        # Check collection exists and has the expected vector size.
        # NOTE(review): "qdrant_collection" is not checked by validate_config;
        # a missing key raises KeyError, caught by the generic handler below.
        collection_name = cfg["qdrant_collection"]
        logger.info(f"Checking collection '{collection_name}'")
        coll_info = check_collection(qdrant_client, collection_name)
        logger.info(
            f"Collection OK: vector_size={coll_info['vector_size']}, points={coll_info['points_count']}"
        )
        # Perform search
        results = search(
            query_text=args.query,
            cohere_client=cohere_client,
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            top_k=args.top_k,
        )
        # Perform metadata validation if requested
        metadata_validation = None
        if args.validate_metadata:
            completeness = validate_metadata_completeness(results)
            sequencing = validate_chunk_sequencing(results)
            metadata_validation = {
                "completeness_pct": round(completeness, 2),
                "sequencing_valid": sequencing,
                # Pass requires near-complete metadata AND valid sequencing.
                "pass": completeness >= 98.0 and sequencing,
            }
            logger.info(f"Metadata completeness: {completeness:.1f}%")
            logger.info(f"Chunk sequencing: {'VALID' if sequencing else 'INVALID'}")
            logger.info(
                f"Validation result: {'PASS' if metadata_validation['pass'] else 'FAIL'}"
            )
        # Format output
        output = {
            "query": args.query,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "results": results,
            "metadata": {
                "total_results": len(results),
                "collection": collection_name,
                "vector_size": coll_info["vector_size"],
                "points_count": coll_info["points_count"],
            },
        }
        # `is not None` rather than truthiness: an all-False validation dict
        # is still a result that must be reported.
        if metadata_validation is not None:
            output["metadata_validation"] = metadata_validation
        # Output JSON
        json_output = json.dumps(output, indent=2)
        if args.output:
            # Explicit UTF-8 avoids platform-default encoding surprises.
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(json_output)
            logger.info(f"Results written to {args.output}")
        else:
            print(json_output)
        logger.info("=== Retrieval Pipeline Completed Successfully ===")
        return 0
    except ValueError as ve:
        logger.error(f"Validation error: {ve}")
        print(f"ERROR: {ve}", file=sys.stderr)
        return 2
    except ConfigurationError as ce:
        logger.error(f"Configuration error: {ce}")
        print(f"ERROR: {ce}", file=sys.stderr)
        return 1
    except CollectionNotFoundError as cnfe:
        logger.error(f"Collection error: {cnfe}")
        print(f"ERROR: {cnfe}", file=sys.stderr)
        return 1
    except DimensionMismatchError as dme:
        logger.error(f"Dimension error: {dme}")
        print(f"ERROR: {dme}", file=sys.stderr)
        return 1
    except APIError as api_err:
        logger.error(f"API error: {api_err}")
        print(f"ERROR: {api_err}", file=sys.stderr)
        return 1
    except Exception as e:
        logger.exception(f"Unexpected error: {e}")
        print(f"ERROR: Unexpected error: {e}", file=sys.stderr)
        return 1
# Script entry point: run the CLI and propagate its exit code to the shell.
if __name__ == "__main__":
    sys.exit(main())