import argparse

from rag_pipelines.embeddings.dense import DenseEmbeddings
from rag_pipelines.embeddings.sparse import SparseEmbeddings
from rag_pipelines.llms.groq import ChatGroqGenerator
from rag_pipelines.pipelines.crag import CorrectiveRAGPipeline
from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
from rag_pipelines.vectordb.pinecone_hybrid_retriever import PineconeHybridRetriever


def main():
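    """Parse CLI arguments, assemble the Corrective RAG components, and run one query."""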
    parser = argparse.ArgumentParser(description="Run the Corrective RAG pipeline.")
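
    # Dense embedding model arguments.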
    parser.add_argument(
        "--dense_model_name",
        type=str,
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Dense embedding model name.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run the dense embedding model on (e.g. 'cpu' or 'cuda').",
    )
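
    # Sparse embedding model arguments.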
    parser.add_argument(
        "--sparse_max_seq_length",
        type=int,
        default=512,
        help="Maximum sequence length for sparse embeddings.",
    )
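
    # Pinecone index arguments.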
| parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.") |
| parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.") |
| parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.") |
| parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.") |
| parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.") |
| parser.add_argument( |
| "--namespace", |
| type=str, |
| default="edgar-all", |
| help="Namespace for Pinecone retriever.", |
| ) |
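
    # Hybrid retriever arguments.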
| parser.add_argument("--alpha", type=float, default=0.5, help="Alpha parameter for hybrid retriever.") |
| parser.add_argument("--top_k", type=int, default=5, help="Number of top documents to retrieve.") |
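
    # Language model arguments.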
    parser.add_argument(
        "--llm_model",
        type=str,
        default="llama-3.2-90b-vision-preview",
        help="Language model name.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for the language model.",
    )
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")
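
    # Document grading arguments.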
    parser.add_argument(
        "--relevance_threshold",
        type=float,
        default=0.7,
        help="Relevance threshold for document grading.",
    )
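
    # Query to answer.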
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Query to run through the Corrective RAG pipeline.",
    )

    args = parser.parse_args()
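
    # Initialize the dense and sparse embedding models.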
    dense_embeddings = DenseEmbeddings(
        model_name=args.dense_model_name,
        model_kwargs={"device": args.device},
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )
    sparse_embeddings = SparseEmbeddings(
        model_kwargs={"max_seq_length": args.sparse_max_seq_length}
    )

    # Unwrap the underlying model objects for use by the retriever.
    dense_embedding_model = dense_embeddings.embedding_model
    sparse_embedding_model = sparse_embeddings.sparse_embedding_model
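
    # Connect to the Pinecone hybrid index.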
    pinecone_vector_db = PineconeHybridVectorDB(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
    )
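
    # Build the hybrid (dense + sparse) retriever over the index.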
    pinecone_retriever = PineconeHybridRetriever(
        index=pinecone_vector_db.index,
        dense_embedding_model=dense_embedding_model,
        sparse_embedding_model=sparse_embedding_model,
        alpha=args.alpha,
        top_k=args.top_k,
        namespace=args.namespace,
    )
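
    # Set up the LLM-based retrieval evaluator and the document grader that
    # scores retrieved documents against the relevance threshold.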
    retrieval_evaluator = RetrievalEvaluator(
        llm_model=args.llm_model,
        llm_api_key=args.llm_api_key,
        temperature=args.temperature,
    )
    document_grader = DocumentGrader(
        evaluator=retrieval_evaluator,
        threshold=args.relevance_threshold,
    )
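
    # Initialize the Groq chat model used for answer generation.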
    generator = ChatGroqGenerator(
        model=args.llm_model,
        api_key=args.llm_api_key,
        llm_params={"temperature": args.temperature},
    )
    llm = generator.llm
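
    # Assemble the Corrective RAG pipeline from the retriever, grader, and LLM.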
    corrective_rag = CorrectiveRAGPipeline(
        retriever=pinecone_retriever.hybrid_retriever,
        prompt=retrieval_evaluator.prompt_template,
        llm=llm,
        document_grader=document_grader,
        tracing_project_name="sec_corrective_rag",
    )
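
    # Run the query through the pipeline and print the generated answer.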
    output = corrective_rag.run(args.query)
    print(output)


if __name__ == "__main__":
    main()