| """ |
| RAG Chain Module |
| Orchestrates retrieval and generation for legal explanations |
| """ |
|
|
import logging
import re
from typing import Dict, Any, List, Optional

from .embeddings import EmbeddingGenerator
from .llm_client import MistralClient
from .prompts import format_rag_prompt, LEGAL_SYSTEM_PROMPT
from .config import DEFAULT_RETRIEVAL_K, PINECONE_API_KEY
|
|
# Pinecone is an optional dependency: keep the module importable without it
# and record availability so __init__ can fail with a helpful message.
PINECONE_AVAILABLE = False
PineconeLegalVectorDB = None
try:
    from .pinecone_vector_db import PineconeLegalVectorDB
    PINECONE_AVAILABLE = True
except ImportError:
    pass

logger = logging.getLogger(__name__)
|
|
| |
| def _setup_rag_logging(): |
| """Ensure RAG chain logs are written to file""" |
| try: |
| from .logging_setup import setup_logging |
| setup_logging("module_a.rag_chain") |
| except Exception: |
| pass |
|
|
| _setup_rag_logging() |
|
|
|
|
class LegalRAGChain:
    """
    Retrieval-Augmented Generation Chain for Legal Explanations
    Combines Vector DB retrieval with Mistral LLM generation

    NOTE: This RAG chain uses Pinecone only. ChromaDB integration has been removed.
    Make sure PINECONE_API_KEY is set before initializing.
    """

    # Matches citations such as "Article 12" or "Article 12A" near the start
    # of a chunk. Compiled once at class creation instead of re-importing `re`
    # and re-scanning inside the per-chunk loop (the original did both).
    _ARTICLE_PATTERN = re.compile(r'Article\s+(\d+[A-Za-z]?)')

    def __init__(self):
        """
        Initialize the RAG chain components.

        Raises:
            ImportError: if the pinecone client library is not installed.
            ValueError: if PINECONE_API_KEY is not configured.
            RuntimeError: if the Pinecone index cannot be initialized.
        """
        logger.info("Initializing Legal RAG Chain...")

        if not PINECONE_AVAILABLE:
            raise ImportError(
                "Pinecone client not installed. "
                "Install with: pip install pinecone-client[grpc]>=3.0.0"
            )

        if not PINECONE_API_KEY:
            raise ValueError(
                "PINECONE_API_KEY must be set to use the RAG chain. "
                "Set it as an environment variable or in a .env file. "
                "Get your API key from: https://app.pinecone.io/"
            )

        self.embedder = EmbeddingGenerator()

        logger.info("Initializing Pinecone vector database...")
        try:
            self.vector_db = PineconeLegalVectorDB()
            logger.info("✓ Using Pinecone cloud vector database")
        except Exception as e:
            logger.error("Failed to initialize Pinecone: %s", e)
            # Chain the original exception (`from e`) so the root cause stays
            # visible in tracebacks; the original raise dropped it.
            raise RuntimeError(
                f"Pinecone initialization failed: {e}. "
                "Please check your API key and network connection. "
                "See module_a/PINECONE_SETUP.md for setup instructions."
            ) from e

        self.llm = MistralClient()

        logger.info("RAG Chain initialized successfully with Pinecone")

    def get_vector_db_info(self) -> Dict[str, Any]:
        """
        Get information about the Pinecone vector database

        Returns:
            Dictionary with database type, name, and other info
        """
        return {
            "type": "Pinecone",
            "class_name": type(self.vector_db).__name__,
            "is_pinecone": True,
            "index_name": getattr(self.vector_db, "index_name", "unknown"),
            "vector_count": self.vector_db.get_count()
        }

    def run(
        self,
        query: str,
        k: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Run the full RAG pipeline: retrieve, generate, cite.

        Args:
            query: User's question
            k: Number of chunks to retrieve; defaults to DEFAULT_RETRIEVAL_K
               (resolved at call time rather than bound at import time)

        Returns:
            Dictionary with 'query', 'explanation', and 'sources'
        """
        if k is None:
            k = DEFAULT_RETRIEVAL_K

        logger.info("Processing query: %s", query)

        logger.info("Step 1: Retrieving relevant laws...")
        context_chunks = self._retrieve_context(query, k)
        logger.info("Retrieved %d relevant chunks", len(context_chunks))

        logger.info("Step 2: Generating explanation...")
        explanation = self._generate_explanation(query, context_chunks)

        sources = self._build_sources(context_chunks)
        logger.info("Returning %d sources", len(sources))

        return {
            'query': query,
            'explanation': explanation,
            'sources': sources
        }

    def _retrieve_context(self, query: str, k: int) -> List[Dict[str, Any]]:
        """Embed the query and fetch the top-k chunks from the vector DB."""
        query_embedding = self.embedder.generate_embedding(query)
        retrieval_results = self.vector_db.query_with_embedding(
            query_embedding.tolist(),
            n_results=k
        )

        # Results use the Chroma-style nested layout: documents/metadatas/
        # distances are lists of per-query lists (we issue a single query).
        context_chunks: List[Dict[str, Any]] = []
        if retrieval_results['documents'][0]:
            for doc, metadata, distance in zip(
                retrieval_results['documents'][0],
                retrieval_results['metadatas'][0],
                retrieval_results['distances'][0]
            ):
                context_chunks.append({
                    'text': doc,
                    'metadata': metadata,
                    'distance': distance
                })
        return context_chunks

    def _generate_explanation(self, query: str, context_chunks: List[Dict[str, Any]]) -> str:
        """Format the RAG prompt and call the LLM; degrade to an apology on failure."""
        prompt = format_rag_prompt(query, context_chunks)
        try:
            return self.llm.generate_response(
                prompt=prompt,
                system_prompt=LEGAL_SYSTEM_PROMPT
            )
        except Exception as e:
            logger.error("Generation failed: %s", e)
            return "I apologize, but I encountered an error while generating the explanation. Please try again later."

    def _build_sources(self, context_chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Turn retrieved chunks into citable source entries for the response."""
        sources = []
        for i, chunk in enumerate(context_chunks):
            metadata = chunk['metadata']
            source_file = metadata.get('source_file', 'Legal Document')
            article_section = metadata.get('article_section')

            # Metadata may lack a section label; try to recover an
            # "Article N" citation from the start of the chunk text.
            if not article_section:
                match = self._ARTICLE_PATTERN.search(chunk['text'][:200])
                if match:
                    article_section = f"Article {match.group(1)}"

            sources.append({
                'file': source_file,
                'section': article_section or f"Section {i+1}",
                # NOTE(review): assumes the DB returns distances in [0, 1]
                # (e.g. cosine distance) so 1 - distance reads as similarity
                # — confirm against PineconeLegalVectorDB's metric.
                'relevance_score': 1.0 - chunk['distance']
            })
        return sources
|
|