| """ |
| Retriever module for Norwegian RAG chatbot. |
| Retrieves relevant document chunks based on query embeddings. |
| """ |
|
|
| import os |
| import json |
| import numpy as np |
| from typing import List, Dict, Any, Optional, Tuple, Union |
|
|
| from ..api.huggingface_api import HuggingFaceAPI |
| from ..api.config import MAX_CHUNKS_TO_RETRIEVE, SIMILARITY_THRESHOLD |
|
|
class Retriever:
    """
    Retrieves relevant document chunks based on query embeddings.
    Uses cosine similarity to find the most relevant chunks.
    """

    def __init__(
        self,
        api_client: Optional[HuggingFaceAPI] = None,
        processed_dir: str = "/home/ubuntu/chatbot_project/data/processed",
        max_chunks: int = MAX_CHUNKS_TO_RETRIEVE,
        similarity_threshold: float = SIMILARITY_THRESHOLD
    ):
        """
        Initialize the retriever.

        Args:
            api_client: HuggingFaceAPI client for generating embeddings.
                A default client is created when none is supplied.
            processed_dir: Directory containing processed documents
            max_chunks: Maximum number of chunks to retrieve
            similarity_threshold: Minimum similarity score for retrieval
        """
        self.api_client = api_client or HuggingFaceAPI()
        self.processed_dir = processed_dir
        self.max_chunks = max_chunks
        self.similarity_threshold = similarity_threshold

        # The index maps document IDs to metadata; the chunk/embedding data
        # lives in per-document JSON files in the same directory.
        self.document_index_path = os.path.join(self.processed_dir, "document_index.json")
        self.document_index = self._load_document_index()

    def retrieve(self, query: str) -> List[Dict[str, Any]]:
        """
        Retrieve relevant document chunks for a query.

        Args:
            query: User query

        Returns:
            List of chunk dicts (document_id, chunk_index, chunk_text,
            similarity, metadata), sorted by descending similarity,
            truncated to ``max_chunks`` and filtered by
            ``similarity_threshold``. Empty list when no embedding could
            be generated for the query.
        """
        # Guard against an empty embedding response — indexing [0] blindly
        # would raise IndexError.
        query_embeddings = self.api_client.generate_embeddings(query)
        if not query_embeddings:
            return []
        query_embedding = query_embeddings[0]

        all_results = []

        for doc_id in self.document_index:
            try:
                doc_results = self._retrieve_from_document(doc_id, query_embedding)
                all_results.extend(doc_results)
            except Exception as e:
                # Best-effort: one corrupt/unreadable document file must not
                # abort retrieval for the remaining documents.
                print(f"Error retrieving from document {doc_id}: {str(e)}")

        # Highest-scoring chunks first, across all documents.
        all_results.sort(key=lambda x: x["similarity"], reverse=True)

        # Since results are sorted descending, truncating before the
        # threshold filter yields the same set as filtering first.
        return [
            result for result in all_results[:self.max_chunks]
            if result["similarity"] >= self.similarity_threshold
        ]

    def _retrieve_from_document(
        self,
        document_id: str,
        query_embedding: List[float]
    ) -> List[Dict[str, Any]]:
        """
        Score every chunk of one document against the query embedding.

        Args:
            document_id: Document ID
            query_embedding: Query embedding vector

        Returns:
            Up to ``max_chunks`` chunk dicts for this document, sorted by
            descending cosine similarity. Empty list when the document file
            is missing or its chunks/embeddings are absent or inconsistent.
        """
        document_path = os.path.join(self.processed_dir, f"{document_id}.json")
        if not os.path.exists(document_path):
            return []

        with open(document_path, 'r', encoding='utf-8') as f:
            document_data = json.load(f)

        chunks = document_data.get("chunks", [])
        embeddings = document_data.get("embeddings", [])
        metadata = document_data.get("metadata", {})

        if not chunks or not embeddings or len(chunks) != len(embeddings):
            return []

        # Vectorized cosine similarity: one matrix-vector product instead of
        # a Python-level loop calling np.dot per chunk.
        emb_matrix = np.asarray(embeddings, dtype=float)
        query_vec = np.asarray(query_embedding, dtype=float)
        chunk_norms = np.linalg.norm(emb_matrix, axis=1)
        query_norm = np.linalg.norm(query_vec)
        denom = chunk_norms * query_norm
        # Zero-norm vectors get similarity 0.0, matching _cosine_similarity;
        # safe_denom avoids evaluating a division by zero inside np.where.
        safe_denom = np.where(denom == 0, 1.0, denom)
        similarities = np.where(denom == 0, 0.0, emb_matrix @ query_vec / safe_denom)

        results = [
            {
                "document_id": document_id,
                "chunk_index": i,
                "chunk_text": chunk,
                # Cast so callers get a plain Python float, not np.float64.
                "similarity": float(similarity),
                "metadata": metadata,
            }
            for i, (chunk, similarity) in enumerate(zip(chunks, similarities))
        ]

        results.sort(key=lambda x: x["similarity"], reverse=True)

        return results[:self.max_chunks]

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """
        Calculate cosine similarity between two vectors.

        Args:
            vec1: First vector
            vec2: Second vector

        Returns:
            Cosine similarity score as a Python float; 0.0 when either
            vector has zero norm.
        """
        vec1 = np.asarray(vec1, dtype=float)
        vec2 = np.asarray(vec2, dtype=float)

        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        # float() cast so the declared return type holds — np.dot yields a
        # numpy scalar, not a Python float.
        return float(np.dot(vec1, vec2) / (norm1 * norm2))

    def _load_document_index(self) -> Dict[str, Dict[str, Any]]:
        """
        Load the document index from disk.

        Returns:
            Dictionary of document IDs to metadata; empty dict when the
            index file is missing or unreadable.
        """
        if os.path.exists(self.document_index_path):
            try:
                with open(self.document_index_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                # A broken index is recoverable: the retriever simply sees
                # no documents rather than crashing at construction time.
                print(f"Error loading document index: {str(e)}")

        return {}
|
|