"""
Retrieval pipeline for RAG validation.

This module provides functions to:
- Convert search queries to embeddings using Cohere
- Perform similarity search against Qdrant collection
- Format and return results with metadata
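
Example CLI usage (file name and query are illustrative):
    python retrieve.py --query "how to configure retries" --top-k 3 --validate-metadata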
"""

import argparse
import json
import sys
import time
import logging
from pathlib import Path
from typing import List, Dict, Any

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

import cohere
from qdrant_client import QdrantClient

# Import from existing modules
import config
import utils
from logging_config import setup_logging

# Initialize logger
logger = logging.getLogger(__name__)


# Custom exceptions
class ConfigurationError(Exception):
    """Raised when required configuration is missing."""

    pass


class CollectionNotFoundError(Exception):
    """Raised when Qdrant collection doesn't exist."""

    pass


class DimensionMismatchError(Exception):
    """Raised when embedding dimension doesn't match collection."""

    pass


class APIError(Exception):
    """Raised when Cohere or Qdrant API call fails after retries."""

    pass


def validate_config(cfg: dict) -> None:
    """Validate that all required config values are present."""
    required = ["cohere_api_key", "qdrant_url", "qdrant_api_key"]
    missing = [key for key in required if not cfg.get(key)]
    if missing:
        raise ConfigurationError(
            f"Missing required environment variables: {', '.join(missing)}"
        )


def init_clients(cfg: dict):
    """Initialize Cohere and Qdrant clients."""
    cohere_client = cohere.ClientV2(api_key=cfg["cohere_api_key"])
    qdrant_client = QdrantClient(url=cfg["qdrant_url"], api_key=cfg["qdrant_api_key"])
    return cohere_client, qdrant_client


def check_collection(
    qdrant_client: QdrantClient, collection_name: str
) -> Dict[str, Any]:
    """Verify collection exists and has correct vector size."""
    try:
        info = qdrant_client.get_collection(collection_name)
    except Exception as e:
        if "not found" in str(e).lower():
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' does not exist"
            )
        raise

    vector_size = info.config.params.vectors.size
    if vector_size != 1024:
        raise DimensionMismatchError(f"Expected vector size 1024 but got {vector_size}")

    return {
        "exists": True,
        "vector_size": vector_size,
        "points_count": info.points_count,
    }


def embed_query(text: str, cohere_client: cohere.ClientV2) -> List[float]:
    """Generate embedding for a search query using Cohere."""
    try:
        response = cohere_client.embed(
            texts=[text],
            model="embed-english-v3.0",
            input_type="search_query",
            embedding_types=["float"],  # required by the v2 embed API; matches the float_ access below
        )
        # Extract embedding from response.embeddings.float_
        embedding = response.embeddings.float_[0]
        return embedding
    except Exception as e:
        logger.error(f"Failed to generate embedding: {e}")
        raise APIError(f"Cohere embedding failed: {e}")


def validate_metadata_completeness(results: List[Dict[str, Any]]) -> float:
    """
    Check metadata completeness in search results.

    Returns:
        Percentage (0-100) of results with complete metadata:
        - url present and non-empty
        - text present with length ≥ 10
        - at least one of title or section non-empty
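
    Example (illustrative payloads):
        >>> validate_metadata_completeness([
        ...     {"payload": {"url": "https://example.com/a",
        ...                  "text": "long enough text", "title": "A"}},
        ...     {"payload": {"url": "", "text": "short", "title": ""}},
        ... ])
        50.0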
    """
    if not results:
        return 0.0

    complete = 0
    total = len(results)

    for result in results:
        payload = result.get("payload", {})
        url = payload.get("url", "")
        text = payload.get("text", "")
        title = payload.get("title", "")
        section = payload.get("section", "")

        # Check completeness criteria
        url_ok = bool(url and url.strip())
        text_ok = len(text or "") >= 10
        title_section_ok = bool(
            (title and title.strip()) or (section and section.strip())
        )

        if url_ok and text_ok and title_section_ok:
            complete += 1

    percentage = (complete / total) * 100
    logger.debug(f"Metadata completeness: {complete}/{total} = {percentage:.1f}%")
    return percentage


def validate_chunk_sequencing(results: List[Dict[str, Any]]) -> bool:
    """
    Verify that chunk_index values are properly assigned: integers >= 0 and unique per URL.

    Note: Since search may return only a subset of chunks for a URL, we cannot
    verify full sequential continuity (0,1,2,3...). Instead we check:
    - All chunk_index values are integers >= 0
    - No duplicate chunk_index for the same URL in the result set

    Args:
        results: List of search results

    Returns:
        True if chunk indices are valid, False otherwise
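
    Example (illustrative payloads):
        >>> validate_chunk_sequencing([
        ...     {"payload": {"url": "https://example.com/a", "chunk_index": 0}},
        ...     {"payload": {"url": "https://example.com/a", "chunk_index": 2}},
        ... ])
        True
        >>> validate_chunk_sequencing([
        ...     {"payload": {"url": "https://example.com/a", "chunk_index": -1}},
        ... ])
        False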
    """
    # Group by URL
    url_chunks = {}
    for result in results:
        payload = result.get("payload", {})
        url = payload.get("url", "")
        chunk_idx = payload.get("chunk_index")

        if url not in url_chunks:
            url_chunks[url] = []
        url_chunks[url].append(chunk_idx)

    # Check each URL's chunks are valid
    for url, indices in url_chunks.items():
        # All indices must be integers >= 0
        for idx in indices:
            if not isinstance(idx, int) or idx < 0:
                logger.debug(
                    f"Invalid chunk_index for {url}: {idx} (must be non-negative integer)"
                )
                return False

        # Check for duplicates (within this URL's results)
        if len(set(indices)) != len(indices):
            logger.debug(f"Duplicate chunk_index for {url}: {indices}")
            return False

    logger.debug(f"Chunk indexing valid for {len(url_chunks)} URLs")
    return True


def search(
    query_text: str,
    cohere_client: cohere.ClientV2,
    qdrant_client: QdrantClient,
    collection_name: str,
    top_k: int = 5,
) -> List[Dict[str, Any]]:
    """
    Convert query to embedding and retrieve top-K relevant chunks.

    Args:
        query_text: User's search query (non-empty, ≤1000 chars)
        top_k: Number of results to return (1-100)

    Returns:
        List of search results with id, score, and payload
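
    Example (assumes initialized clients; the collection name is illustrative):
        results = search("how are chunks indexed?", cohere_client,
                         qdrant_client, "docs_chunks", top_k=3)
        top_url = results[0]["payload"]["url"]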
    """
    # Validate inputs
    if not query_text or not query_text.strip():
        raise ValueError("Query text must be non-empty")
    query_text = query_text.strip()
    if len(query_text) > 1000:
        raise ValueError("Query text must be ≤ 1000 characters")
    if top_k < 1 or top_k > 100:
        raise ValueError("top_k must be between 1 and 100")

    logger.info(f"Embedding query: '{query_text[:100]}...' (top_k={top_k})")
    start_time = time.time()

    # Generate query embedding with retry
    try:
        embedding = utils.retry_with_backoff(
            lambda: embed_query(query_text, cohere_client),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        embed_time = time.time() - start_time
        logger.debug(
            f"Generated embedding in {embed_time:.2f}s, dimension: {len(embedding)}"
        )
    except Exception as e:
        logger.error(f"Failed to embed query: {e}")
        raise

    # Search Qdrant with retry
    try:
        search_start = time.time()
        response = utils.retry_with_backoff(
            lambda: qdrant_client.query_points(
                collection_name=collection_name,
                query=embedding,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
            ),
            max_retries=3,
            base_delay=1.0,
            max_delay=10.0,
        )
        results = response.points
        search_time = time.time() - search_start
        logger.info(
            f"Search completed in {search_time:.2f}s, returned {len(results)} results"
        )
    except Exception as e:
        logger.error(f"Search failed: {e}")
        raise APIError(f"Qdrant search failed: {e}")

    # Format results
    formatted = []
    for result in results:
        formatted.append(
            {
                "id": str(result.id),
                "score": float(result.score),
                "payload": result.payload,
            }
        )

    total_time = time.time() - start_time
    logger.info(f"Total query time: {total_time:.2f}s")

    return formatted


def format_results(
    results: List[Dict[str, Any]], query: str, latency_ms: int
) -> Dict[str, Any]:
    """Format search results into JSON output structure."""
    output = {
        "query": query,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "results": results,
        "metadata": {
            "total_results": len(results),
            "collection": None,  # Will be filled by main
            "latency_ms": latency_ms,
        },
    }
    return output


def main() -> int:
    """CLI entrypoint for retrieval."""
    parser = argparse.ArgumentParser(
        description="Retrieve relevant chunks from Qdrant using Cohere embeddings"
    )
    parser.add_argument("--query", type=str, help="Search query text")
    parser.add_argument(
        "--top-k", type=int, default=5, help="Number of results to return (default: 5)"
    )
    parser.add_argument("--output", type=str, help="Output file path (default: stdout)")
    parser.add_argument(
        "--config",
        type=str,
        default=".env",
        help="Path to .env config file (default: .env)",
    )
    parser.add_argument(
        "--validate-metadata",
        action="store_true",
        help="Run metadata validation on search results (requires --query)",
    )

    args = parser.parse_args()

    # Setup logging
    log_file = "retrieve.log"
    setup_logging(log_file=log_file, console_level="INFO")
    logger.info("=== Retrieval Pipeline Started ===")

    try:
        # Load config
        logger.info(f"Loading config from {args.config}")
        cfg = config.get_config()
        validate_config(cfg)

        # Initialize clients
        logger.info("Initializing Cohere and Qdrant clients")
        cohere_client, qdrant_client = init_clients(cfg)

        # Check collection
        collection_name = cfg["qdrant_collection"]
        logger.info(f"Checking collection '{collection_name}'")
        coll_info = check_collection(qdrant_client, collection_name)
        logger.info(
            f"Collection OK: vector_size={coll_info['vector_size']}, points={coll_info['points_count']}"
        )

        # Validate query argument
        if not args.query:
            parser.error("--query is required")

        # Perform search
        results = search(
            query_text=args.query,
            cohere_client=cohere_client,
            qdrant_client=qdrant_client,
            collection_name=collection_name,
            top_k=args.top_k,
        )

        # Perform metadata validation if requested
        metadata_validation = None
        if args.validate_metadata:
            completeness = validate_metadata_completeness(results)
            sequencing = validate_chunk_sequencing(results)
            metadata_validation = {
                "completeness_pct": round(completeness, 2),
                "sequencing_valid": sequencing,
                "pass": completeness >= 98.0 and sequencing,
            }
            logger.info(f"Metadata completeness: {completeness:.1f}%")
            logger.info(f"Chunk sequencing: {'VALID' if sequencing else 'INVALID'}")
            logger.info(
                f"Validation result: {'PASS' if metadata_validation['pass'] else 'FAIL'}"
            )

        # Format output
        output = {
            "query": args.query,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "results": results,
            "metadata": {
                "total_results": len(results),
                "collection": collection_name,
                "vector_size": coll_info["vector_size"],
                "points_count": coll_info["points_count"],
            },
        }

        if metadata_validation:
            output["metadata_validation"] = metadata_validation

        # Output JSON
        json_output = json.dumps(output, indent=2)
        if args.output:
            with open(args.output, "w") as f:
                f.write(json_output)
            logger.info(f"Results written to {args.output}")
        else:
            print(json_output)

        logger.info("=== Retrieval Pipeline Completed Successfully ===")
        return 0

    except ValueError as ve:
        logger.error(f"Validation error: {ve}")
        print(f"ERROR: {ve}", file=sys.stderr)
        return 2
    except ConfigurationError as ce:
        logger.error(f"Configuration error: {ce}")
        print(f"ERROR: {ce}", file=sys.stderr)
        return 1
    except CollectionNotFoundError as cnfe:
        logger.error(f"Collection error: {cnfe}")
        print(f"ERROR: {cnfe}", file=sys.stderr)
        return 1
    except DimensionMismatchError as dme:
        logger.error(f"Dimension error: {dme}")
        print(f"ERROR: {dme}", file=sys.stderr)
        return 1
    except APIError as api_err:
        logger.error(f"API error: {api_err}")
        print(f"ERROR: {api_err}", file=sys.stderr)
        return 1
    except Exception as e:
        logger.exception(f"Unexpected error: {e}")
        print(f"ERROR: Unexpected error: {e}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    sys.exit(main())