# NOTE(review): scraped Hugging Face Spaces page header — build status was "Runtime error".
| from haystack import Document, Pipeline | |
| from haystack.document_stores.in_memory import InMemoryDocumentStore | |
| from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder | |
| from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever | |
| from haystack.components.builders import PromptBuilder | |
| from datasets import load_dataset | |
| from haystack.dataclasses import ChatMessage | |
| from typing import Optional, List, Union, Dict | |
| from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG | |
class RAGPipeline:
    """Retrieval-augmented generation (RAG) pipeline built on Haystack.

    Embeds a document corpus with a sentence-transformers model into an
    in-memory document store; at query time it embeds the question,
    retrieves the most similar documents, and renders them into a prompt
    template whose text is returned to the caller (no LLM call happens here).
    """

    def __init__(
        self,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None
    ):
        """
        Initialize the RAG Pipeline.

        Args:
            dataset_config: Either a string key from DATASET_CONFIGS or a
                DatasetConfig object.
            documents: Optional list of documents to use instead of loading
                from a dataset.
            embedding_model: Optional override for the embedding model name;
                falls back to MODEL_CONFIG["embedding_model"].

        Raises:
            ValueError: If dataset_config is a string key not present in
                DATASET_CONFIGS.
        """
        # Resolve configuration: accept either a preset key or a config object.
        if isinstance(dataset_config, str):
            if dataset_config not in DATASET_CONFIGS:
                raise ValueError(f"Dataset config '{dataset_config}' not found. Available configs: {list(DATASET_CONFIGS.keys())}")
            self.config = DATASET_CONFIGS[dataset_config]
        else:
            self.config = dataset_config

        # Load documents either from the provided list or from the dataset.
        if documents is not None:
            self.documents = documents
        else:
            self.documents = self._load_documents_from_dataset()

        # Initialize components. Hoist the model-name fallback so both
        # embedders are guaranteed to use the same model.
        model_name = embedding_model or MODEL_CONFIG["embedding_model"]
        self.document_store = InMemoryDocumentStore()
        self.doc_embedder = SentenceTransformersDocumentEmbedder(
            model=model_name,
            progress_bar=False
        )
        # BUGFIX: the text embedder was previously instantiated twice in a
        # row; the second instance silently discarded the first. One is enough.
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=model_name,
            progress_bar=False
        )
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)

        # Warm up the embedders (loads model weights before first use).
        self.doc_embedder.warm_up()
        self.text_embedder.warm_up()

        # Initialize the prompt template; a template supplied by the config
        # takes precedence over the built-in default.
        self.prompt_builder = PromptBuilder(template=self.config.prompt_template or """
Given the following context, please answer the question.
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Question: {{question}}
Answer:
""")
        # Index documents, then wire the components into a Pipeline.
        self._index_documents(self.documents)
        self.pipeline = self._build_pipeline()

    @classmethod
    def from_preset(cls, preset_name: str) -> "RAGPipeline":
        """
        Create a pipeline from a preset configuration.

        BUGFIX: this alternate constructor takes ``cls`` but was missing the
        ``@classmethod`` decorator, so calling it on the class bound the
        preset name to ``cls`` and raised TypeError.

        Args:
            preset_name: Name of the preset configuration to use.
        """
        return cls(dataset_config=preset_name)

    def _load_documents_from_dataset(self) -> List[Document]:
        """Load the configured dataset and convert each row to a Document."""
        dataset = load_dataset(self.config.name, split=self.config.split)
        docs: List[Document] = []
        for row in dataset:
            # Build the metadata dict from the configured field mapping,
            # skipping any mapped field absent from this row.
            meta: Dict = {}
            if self.config.fields:
                for meta_key, dataset_field in self.config.fields.items():
                    if dataset_field in row:
                        meta[meta_key] = row[dataset_field]
            docs.append(Document(content=row[self.config.content_field], meta=meta))
        return docs

    def _index_documents(self, documents) -> None:
        """Embed the documents and write them to the in-memory store."""
        docs_with_embeddings = self.doc_embedder.run(documents)
        self.document_store.write_documents(docs_with_embeddings["documents"])

    def _build_pipeline(self) -> Pipeline:
        """Wire text embedder -> retriever -> prompt builder into a Pipeline."""
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)
        # Query embedding feeds retrieval; retrieved documents feed the prompt.
        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        pipeline.connect("retriever", "prompt_builder")
        return pipeline

    def answer_question(self, question: str) -> str:
        """Run the RAG pipeline to answer a question.

        NOTE(review): runs the components directly rather than via
        ``self.pipeline``; the returned string is the rendered prompt,
        not a model-generated answer.
        """
        # First, embed the question and retrieve relevant documents.
        embedded_question = self.text_embedder.run(text=question)
        retrieved_docs = self.retriever.run(query_embedding=embedded_question["embedding"])
        # Then, build the prompt with the retrieved documents.
        prompt_result = self.prompt_builder.run(
            question=question,
            documents=retrieved_docs["documents"]
        )
        # Return the formatted prompt (this will be processed by the main AI).
        return prompt_result["prompt"]