| """ |
| Vector database module using ChromaDB |
| Stores and retrieves document chunks with embeddings |
| """ |
|
|
| import logging |
| from pathlib import Path |
| from typing import List, Dict, Any, Optional |
|
|
| try: |
| import chromadb |
| from chromadb.config import Settings |
| CHROMADB_AVAILABLE = True |
| except ImportError: |
| CHROMADB_AVAILABLE = False |
|
|
| from .config import VECTOR_DB_DIR, DEFAULT_RETRIEVAL_K |
|
|
# Module-scoped logger, named after this module so records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class LegalVectorDB:
    """ChromaDB vector database for Nepal legal document chunks.

    Wraps one persistent ChromaDB collection (``nepal_legal_docs``) stored
    under ``persist_directory``. Provides helpers to add pre-embedded chunks,
    run similarity queries, count or peek at entries, and delete the
    collection.

    NOTE(review): the collection is created without an explicit
    ``embedding_function``, so ChromaDB embeds the text passed to
    :meth:`query` with its built-in default model. :meth:`add_chunks`, by
    contrast, stores vectors computed by the caller. Unless those vectors
    come from the same model as ChromaDB's default, text queries will fail
    with a dimension mismatch or return poor matches. In that case use
    :meth:`query_with_embedding` with a vector from the same embedder.
    TODO confirm which model the ingestion pipeline uses.
    """

    def __init__(self, persist_directory: Path = VECTOR_DB_DIR):
        """
        Initialize ChromaDB with persistent storage.

        Args:
            persist_directory: Directory to store the database. Anything
                ``pathlib.Path`` accepts works. The directory (and any
                missing parents) is created if it does not exist.

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
        """
        if not CHROMADB_AVAILABLE:
            raise ImportError(
                "chromadb not installed. "
                "Install with: pip install chromadb"
            )

        self.persist_directory = Path(persist_directory)
        self.persist_directory.mkdir(parents=True, exist_ok=True)

        logger.info("Initializing ChromaDB at %s", self.persist_directory)

        self.client = chromadb.PersistentClient(path=str(self.persist_directory))

        self.collection_name = "nepal_legal_docs"
        # Set to None by delete_collection(); _active_collection() reopens it.
        self.collection = self._get_or_create_collection()

        logger.info(
            "Collection '%s' ready. Current document count: %d",
            self.collection_name,
            self.collection.count(),
        )

    def _get_or_create_collection(self):
        """Open the collection named ``self.collection_name``, creating it if missing."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"description": "Nepal legal documents for RAG-based law explanation"},
        )

    def _active_collection(self):
        """Return a usable collection handle, reopening it after a deletion."""
        if self.collection is None:
            self.collection = self._get_or_create_collection()
        return self.collection

    @staticmethod
    def _clean_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Convert one chunk's metadata into values ChromaDB accepts.

        - ``None`` values and empty lists/tuples are dropped.
        - ``str``/``int``/``float``/``bool`` values are kept as-is.
        - Non-empty lists/tuples become a comma-separated string.
        - Anything else becomes ``str(value)``.

        Returns ``None`` instead of an empty dict, because ChromaDB rejects
        empty metadata dicts.
        """
        cleaned: Dict[str, Any] = {}
        for key, value in (metadata or {}).items():
            if value is None:
                continue
            if isinstance(value, (str, int, float, bool)):
                cleaned[key] = value
            elif isinstance(value, (list, tuple)):
                # ChromaDB metadata values must be scalars, so sequences are flattened.
                if value:
                    cleaned[key] = ', '.join(str(item) for item in value)
            else:
                cleaned[key] = str(value)
        return cleaned or None

    def add_chunks(
        self,
        chunks: List[Dict[str, Any]],
        embeddings: List[List[float]],
        batch_size: Optional[int] = None,
    ) -> None:
        """
        Add chunks with embeddings to the database.

        Args:
            chunks: List of chunk dictionaries with 'chunk_id' and 'text'
                keys and an optional 'metadata' dict.
            embeddings: List of embedding vectors (as lists), aligned with
                ``chunks`` by position.
            batch_size: If given, split the insert into calls of at most
                this many chunks. This keeps large ingests under ChromaDB's
                per-call size limit. ``None`` (default) sends everything in
                a single call.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length,
                or ``batch_size`` is less than 1.
            KeyError: If a chunk lacks 'chunk_id' or 'text'.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(f"Number of chunks ({len(chunks)}) must match number of embeddings ({len(embeddings)})")
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if not chunks:
            # ChromaDB rejects an add() with an empty id list, so an empty input is a no-op.
            logger.info("No chunks to add to vector database")
            return

        ids = [chunk['chunk_id'] for chunk in chunks]
        documents = [chunk['text'] for chunk in chunks]
        metadatas = [self._clean_metadata(chunk.get('metadata')) for chunk in chunks]

        logger.info("Adding %d chunks to vector database", len(chunks))

        collection = self._active_collection()
        step = len(ids) if batch_size is None else batch_size
        for start in range(0, len(ids), step):
            end = start + step
            collection.add(
                ids=ids[start:end],
                documents=documents[start:end],
                embeddings=embeddings[start:end],
                metadatas=metadatas[start:end],
            )

        logger.info(
            "Successfully added chunks. Total documents in database: %d",
            collection.count(),
        )

    def query(
        self,
        query_text: str,
        n_results: int = DEFAULT_RETRIEVAL_K,
        where: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Query the database with a text query.

        The text is embedded by the collection's embedding function (see
        the class-level note about model mismatch).

        Args:
            query_text: Query string.
            n_results: Number of results to return.
            where: Optional metadata filter. An empty dict is treated as
                "no filter".

        Returns:
            ChromaDB query result with 'ids', 'documents', 'metadatas' and
            'distances'. Each value holds one inner list per query; here
            there is a single query.
        """
        preview = query_text if len(query_text) <= 50 else f"{query_text[:50]}..."
        logger.info("Querying database with: '%s' (n_results=%s)", preview, n_results)

        return self._active_collection().query(
            query_texts=[query_text],
            n_results=n_results,
            where=where or None,
        )

    def query_with_embedding(
        self,
        query_embedding: List[float],
        n_results: int = DEFAULT_RETRIEVAL_K,
        where: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Query with a pre-computed embedding.

        Args:
            query_embedding: Query embedding vector. It should come from the
                same model that produced the stored embeddings.
            n_results: Number of results to return.
            where: Optional metadata filter. An empty dict is treated as
                "no filter".

        Returns:
            ChromaDB query result with 'ids', 'documents', 'metadatas' and
            'distances'. Each value holds one inner list per query; here
            there is a single query.
        """
        logger.info("Querying database with embedding (n_results=%s)", n_results)

        return self._active_collection().query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where or None,
        )

    def get_count(self) -> int:
        """Return the number of entries stored in the collection."""
        return self._active_collection().count()

    def delete_collection(self) -> None:
        """Delete the entire collection (use with caution!).

        After deletion the stale handle is discarded. The next operation on
        this instance recreates an empty collection with the same name
        instead of failing on the deleted one.
        """
        logger.warning("Deleting collection '%s'", self.collection_name)
        self.client.delete_collection(name=self.collection_name)
        self.collection = None
        logger.info("Collection deleted")

    def peek(self, limit: int = 5) -> Dict[str, Any]:
        """
        Peek at some documents in the database.

        Args:
            limit: Number of documents to return.

        Returns:
            Dictionary with sample documents, as returned by ChromaDB's
            ``Collection.peek``.
        """
        return self._active_collection().peek(limit=limit)
|
|