import argparse
from collections.abc import Sequence
from typing import Optional

from dataloaders.langchain import FinanceBenchDataloader
from langchain_huggingface import HuggingFaceEmbeddings
from pymilvus import CollectionSchema, DataType, FieldSchema

from rag_pipelines.embeddings import SparseEmbeddingsMilvus as SparseEmbeddings
from rag_pipelines.unstructured import UnstructuredChunker, UnstructuredDocumentLoader
from rag_pipelines.utils import dict_type
from rag_pipelines.vectordb import MilvusVectorDB
|
|
|
|
def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments for the FinanceBench indexing pipeline.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which preserves the original CLI behavior.
            Passing an explicit sequence lets tests or other scripts drive the
            pipeline configuration without patching ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed command-line arguments. Options declared
        with ``type=dict_type`` have JSON-string defaults; argparse runs string
        defaults through ``type``, so those attributes are converted by
        ``dict_type`` exactly like user-supplied values.
    """
    parser = argparse.ArgumentParser(
        description="Run the FinanceBench pipeline to load, process, chunk, embed, and index documents."
    )

    # --- Dataset -------------------------------------------------------------
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="PatronusAI/financebench",
        help="HuggingFace dataset name.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Dataset split to use (e.g., 'train').",
    )

    # --- PDF location --------------------------------------------------------
    parser.add_argument(
        "--pdf_dir",
        type=str,
        default="pdfs/",
        help="Directory path containing PDF files.",
    )

    # --- Unstructured loader -------------------------------------------------
    parser.add_argument(
        "--strategy",
        type=str,
        default="fast",
        help="Processing strategy for the unstructured document loader.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="elements",
        help="Extraction mode for the unstructured document loader.",
    )

    # --- Milvus connection ---------------------------------------------------
    # NOTE(review): --milvus_uri / --milvus_token have no default and are not
    # required, so None is forwarded to MilvusVectorDB when omitted; confirm it
    # handles that. A token passed on the command line is also visible in shell
    # history and process listings.
    parser.add_argument(
        "--milvus_uri",
        type=str,
        help="URI for the Milvus server.",
    )
    parser.add_argument(
        "--milvus_token",
        type=str,
        help="Authentication token for Milvus.",
    )
    parser.add_argument(
        "--collection_name",
        type=str,
        default="financebench",
        help="Name of the Milvus collection to create/use.",
    )

    # --- Dense embeddings ----------------------------------------------------
    parser.add_argument(
        "--dense_embedding_model",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Model name for dense embeddings.",
    )
    parser.add_argument(
        "--dense_model_kwargs",
        type=dict_type,
        default='{"device": "cpu", "trust_remote_code": true}',
        help="Keyword arguments for dense embeddings model initialization.",
    )
    parser.add_argument(
        "--dense_encode_kwargs",
        type=dict_type,
        default='{"normalize_embeddings": true}',
        help="Keyword arguments for dense embeddings encoding.",
    )

    # --- Sparse embeddings ---------------------------------------------------
    parser.add_argument(
        "--sparse_embedding_model",
        type=str,
        default="Splade_PP_en_v1",
        help="Model name for sparse embeddings.",
    )

    # --- Collection field names ----------------------------------------------
    parser.add_argument(
        "--pk_field",
        type=str,
        default="doc_id",
        help="Name of the primary key field.",
    )
    parser.add_argument(
        "--dense_field",
        type=str,
        default="dense_vector",
        help="Name of the dense vector field.",
    )
    parser.add_argument(
        "--sparse_field",
        type=str,
        default="sparse_vector",
        help="Name of the sparse vector field.",
    )
    parser.add_argument(
        "--text_field",
        type=str,
        default="text",
        help="Name of the text field.",
    )
    parser.add_argument(
        "--metadata_field",
        type=str,
        default="metadata",
        help="Name of the metadata field.",
    )

    # --- Field sizes ---------------------------------------------------------
    # --dense_dim must equal the output size of --dense_embedding_model.
    parser.add_argument(
        "--dense_dim",
        type=int,
        default=768,
        help="Dimension of dense embeddings.",
    )
    parser.add_argument(
        "--pk_max_length",
        type=int,
        default=100,
        help="Max length for the primary key field.",
    )
    parser.add_argument(
        "--text_max_length",
        type=int,
        default=65535,
        help="Max length for the text field.",
    )

    # --- Index parameters ----------------------------------------------------
    parser.add_argument(
        "--dense_index_params",
        type=dict_type,
        default='{"index_type": "FLAT", "metric_type": "IP"}',
        help="JSON string specifying dense index parameters.",
    )
    parser.add_argument(
        "--sparse_index_params",
        type=dict_type,
        default='{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}',
        help="JSON string specifying sparse index parameters.",
    )

    # --- Collection lifecycle ------------------------------------------------
    parser.add_argument(
        "--create_new_collection",
        action="store_true",
        help="Create a new collection or use existing. Defaults to False.",
    )

    # parse_args(None) falls back to sys.argv[1:], matching the original call.
    return parser.parse_args(argv)
|
|
|
|
def _build_collection_schema(args: argparse.Namespace) -> CollectionSchema:
    """Build the Milvus collection schema from the parsed field-name and size options."""
    fields = [
        FieldSchema(
            name=args.pk_field,
            dtype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=args.pk_max_length,
        ),
        FieldSchema(name=args.dense_field, dtype=DataType.FLOAT_VECTOR, dim=args.dense_dim),
        FieldSchema(name=args.sparse_field, dtype=DataType.SPARSE_FLOAT_VECTOR),
        FieldSchema(name=args.text_field, dtype=DataType.VARCHAR, max_length=args.text_max_length),
        FieldSchema(name=args.metadata_field, dtype=DataType.JSON),
    ]
    # enable_dynamic_field=False: only the fields declared above are accepted.
    return CollectionSchema(fields=fields, enable_dynamic_field=False)


def _check_dense_dim(dense_embeddings: HuggingFaceEmbeddings, expected_dim: int) -> None:
    """Raise ValueError if the dense model's vector size differs from ``expected_dim``."""
    # A single short query is enough to learn the model's output size.
    actual_dim = len(dense_embeddings.embed_query("dimension probe"))
    if actual_dim != expected_dim:
        raise ValueError(
            f"Dense embedding model produces {actual_dim}-dimensional vectors, but "
            f"--dense_dim is {expected_dim}. Pass --dense_dim {actual_dim} so the "
            "collection schema matches the model."
        )


def main() -> None:
    """Run the FinanceBench document processing pipeline.

    Steps, in execution order:
    1. Parse CLI arguments and construct the FinanceBench dataloader, the
       unstructured document loader and chunker, and the dense and sparse
       embedding models.
    2. Verify that the dense model's output size equals ``--dense_dim``. This
       fails fast, before the Milvus collection is set up or any PDF is parsed.
    3. Build the collection schema and the ``MilvusVectorDB`` wrapper.
    4. Call ``dataloader.get_corpus_pdfs()``, load the PDFs found in
       ``--pdf_dir``, chunk them, and index the chunks with both embedding
       models.

    Raises:
        ValueError: If the dense embedding dimension does not match
            ``--dense_dim``.
    """
    args = parse_arguments()

    dataloader = FinanceBenchDataloader(
        dataset_name=args.dataset_name,
        split=args.split,
    )
    unstructured_document_loader = UnstructuredDocumentLoader(
        strategy=args.strategy,
        mode=args.mode,
    )
    chunker = UnstructuredChunker()

    dense_embeddings = HuggingFaceEmbeddings(
        model_name=args.dense_embedding_model,
        model_kwargs=args.dense_model_kwargs,
        encode_kwargs=args.dense_encode_kwargs,
    )
    sparse_embeddings = SparseEmbeddings(
        model_name=args.sparse_embedding_model,
    )

    # A schema/model dimension mismatch would otherwise only surface at insert
    # time, after every PDF has been parsed and chunked.
    # NOTE(review): when reusing an existing collection, its stored dimension
    # must also match; --dense_dim should be set to that same value.
    _check_dense_dim(dense_embeddings, args.dense_dim)

    milvus_vector_db = MilvusVectorDB(
        uri=args.milvus_uri,
        token=args.milvus_token,
        collection_name=args.collection_name,
        collection_schema=_build_collection_schema(args),
        dense_field=args.dense_field,
        sparse_field=args.sparse_field,
        text_field=args.text_field,
        metadata_field=args.metadata_field,
        dense_index_params=args.dense_index_params,
        sparse_index_params=args.sparse_index_params,
        create_new_collection=args.create_new_collection,
    )

    # NOTE(review): get_corpus_pdfs() receives no directory and its return value
    # is discarded; confirm it places the PDFs where --pdf_dir points.
    dataloader.get_corpus_pdfs()
    documents = unstructured_document_loader.transform_documents(args.pdf_dir)
    chunked_documents = chunker.transform_documents(documents)
    milvus_vector_db.add_documents(
        documents=chunked_documents,
        dense_embedding_model=dense_embeddings,
        sparse_embedding_model=sparse_embeddings,
    )
|
|
|
|
if __name__ == "__main__":
    # Run the pipeline only when this file is executed as a script, not when imported.
    main()
|
|