m-ahmad-official committed on
Commit
5f90488
·
1 Parent(s): 812f6e6
Files changed (4) hide show
  1. agent.py +6 -21
  2. config.py +37 -0
  3. logging_config.py +30 -0
  4. retrieve.py +92 -0
agent.py CHANGED
@@ -35,27 +35,13 @@ third_party_model = OpenAIChatCompletionsModel(
35
 
36
  # Make backend package importable
37
  current_dir = os.path.dirname(os.path.abspath(__file__))
38
- backend_parent = os.path.dirname(current_dir)
39
- if backend_parent not in sys.path:
40
- sys.path.insert(0, backend_parent)
41
 
42
- # Import backend modules (support both module and script execution)
43
- try:
44
- from config import get_config
45
- from retrieve import search as retrieve_search
46
- from logging_config import setup_logging
47
- except ImportError as e:
48
- try:
49
- from .config import get_config
50
- from .retrieve import search as retrieve_search
51
- from .logging_config import setup_logging
52
- except ImportError as e2:
53
- try:
54
- from backend.config import get_config
55
- from backend.retrieve import search as retrieve_search
56
- from backend.logging_config import setup_logging
57
- except ImportError as e3:
58
- raise ImportError(f"Failed to import backend modules: {e3}")
59
 
60
  # Import OpenAI Agents SDK (must be installed separately)
61
  try:
@@ -212,7 +198,6 @@ def get_agent():
212
 
213
  def check_qdrant_health() -> str:
214
  try:
215
- from backend.config import get_config
216
  from qdrant_client import QdrantClient
217
 
218
  cfg = get_config()
 
35
 
36
  # Make backend package importable
37
  current_dir = os.path.dirname(os.path.abspath(__file__))
38
+ if current_dir not in sys.path:
39
+ sys.path.insert(0, current_dir)
 
40
 
41
+ # Import backend modules
42
+ from config import get_config
43
+ from retrieve import search as retrieve_search
44
+ from logging_config import setup_logging
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Import OpenAI Agents SDK (must be installed separately)
47
  try:
 
198
 
199
  def check_qdrant_health() -> str:
200
  try:
 
201
  from qdrant_client import QdrantClient
202
 
203
  cfg = get_config()
config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration module for RAG Book Chatbot.
3
+ Loads configuration from environment variables.
4
+ """
5
+
6
+ import os
7
+ from typing import Dict, Any
8
+
9
+
10
def get_config() -> Dict[str, Any]:
    """
    Get configuration from environment variables.

    Reads OPENAI_API_KEY, COHERE_API_KEY, QDRANT_URL, QDRANT_API_KEY and
    QDRANT_COLLECTION. Only QDRANT_COLLECTION has a fallback ("book-chunks").

    Returns:
        Dictionary with the keys "openai_api_key", "cohere_api_key",
        "qdrant_url", "qdrant_api_key" and "qdrant_collection"

    Raises:
        ValueError: If OPENAI_API_KEY, QDRANT_URL or QDRANT_API_KEY is unset
            or empty. The message lists the environment variable names
            that need to be set.
    """
    config = {
        "openai_api_key": os.getenv("OPENAI_API_KEY"),
        "cohere_api_key": os.getenv("COHERE_API_KEY"),
        "qdrant_url": os.getenv("QDRANT_URL"),
        "qdrant_api_key": os.getenv("QDRANT_API_KEY"),
        "qdrant_collection": os.getenv("QDRANT_COLLECTION", "book-chunks"),
    }

    # Required config key -> environment variable it is read from, so the
    # error names what the user actually has to export.
    required_env = {
        "openai_api_key": "OPENAI_API_KEY",
        "qdrant_url": "QDRANT_URL",
        "qdrant_api_key": "QDRANT_API_KEY",
    }
    missing_vars = [
        env_name for key, env_name in required_env.items() if not config[key]
    ]

    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )

    return config
logging_config.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Logging configuration module.
3
+ """
4
+
5
+ import logging
6
+ import sys
7
+
8
+
9
def setup_logging(name: str, level: int = logging.INFO) -> logging.Logger:
    """
    Set up basic logging for the application.

    Args:
        name: Logger name
        level: Logging level applied to the logger (default: logging.INFO)

    Returns:
        Configured logger instance. Its level is (re)applied on every call;
        a stdout handler is attached only when the logger has no handlers.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Guard against stacking a second handler when called again for the
    # same logger name, which would duplicate every emitted line.
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
retrieve.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Retrieval module for RAG Book Chatbot.
3
+ Handles vector search using Qdrant and Cohere embeddings.
4
+ """
5
+
6
+ import logging
7
+ from typing import List, Dict, Any, Optional
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def search(
    query_text: str,
    cohere_client: Any,
    qdrant_client: Any,
    collection_name: str,
    top_k: int = 5,
) -> List[Dict[str, Any]]:
    """
    Embed a query with Cohere and look up the nearest chunks in Qdrant.

    Args:
        query_text: User's question or search query
        cohere_client: Client object exposing ``embed(texts=, model=, input_type=)``
        qdrant_client: Client object exposing
            ``search(collection_name=, query_vector=, limit=)``
        collection_name: Name of the Qdrant collection to query
        top_k: Maximum number of hits requested from Qdrant (default: 5)

    Returns:
        One dict per hit with the keys "id", "score" and "payload", in the
        order returned by the Qdrant client

    Raises:
        Exception: Any error from either client is logged with its
            traceback and then re-raised unchanged.
    """
    try:
        # Step 1: turn the query into a vector
        logger.info(f"Generating embedding for query: {query_text[:100]}...")
        query_embedding = cohere_client.embed(
            texts=[query_text],
            model="embed-english-v3.0",
            input_type="search_query",
        ).embeddings[0]
        logger.debug(f"Generated embedding dimension: {len(query_embedding)}")

        # Step 2: nearest-neighbour lookup in the collection
        logger.info(f"Searching Qdrant collection: {collection_name}")
        hits = qdrant_client.search(
            collection_name=collection_name,
            query_vector=query_embedding,
            limit=top_k,
        )
        logger.info(f"Found {len(hits)} results")

        # Step 3: flatten client hit objects into plain dicts
        return [
            {"id": hit.id, "score": hit.score, "payload": hit.payload}
            for hit in hits
        ]
    except Exception as e:
        logger.error(f"Search failed: {e}", exc_info=True)
        raise
69
+
70
+
71
def validate_results(results: List[Dict[str, Any]]) -> float:
    """
    Measure how many results carry the required payload metadata.

    A result is valid when its "payload" contains all of "url",
    "chunk_index" and "text". A result whose payload is missing or None
    counts as invalid instead of raising.

    Args:
        results: List of search results (dicts with a "payload" entry)

    Returns:
        Fraction of results with complete metadata, between 0.0 and 1.0.
        An empty list returns 1.0.
    """
    if not results:
        return 1.0

    required_fields = {"url", "chunk_index", "text"}
    valid_count = 0

    for result in results:
        # ``.get("payload", {})`` would still hand back None when the key is
        # present with a None value; ``or {}`` covers both cases.
        payload = result.get("payload") or {}
        if all(field in payload for field in required_fields):
            valid_count += 1

    return valid_count / len(results)