| """Retriever that generates and executes structured queries over its own data source. |
| |
| NOTE: This code is adapted from the original implementation in the LangChain repo, |
| but has been modified to work with the KTH QA system. |
| |
| """ |
|
|
| from langchain.vectorstores import Pinecone, VectorStore |
| from langchain.schema import BaseRetriever, Document |
| from langchain.retrievers.self_query.pinecone import PineconeTranslator |
| from langchain.chains.query_constructor.schema import AttributeInfo |
| from langchain.chains.query_constructor.ir import StructuredQuery, Visitor |
| from langchain.chains.query_constructor.base import load_query_constructor_chain |
| from langchain.base_language import BaseLanguageModel |
| from langchain import LLMChain |
| from pydantic import BaseModel, Field, root_validator |
| import re |
| from typing import Any, Dict, List, Optional, Type, cast |
| import logging |
# Module-scoped logger (named after this module) rather than the root logger,
# so applications can filter or configure this module's output separately.
logger = logging.getLogger(__name__)


# Regex for a course-code-like token: 2-3 ASCII letters, 3-4 digits, then an
# optional single word character (e.g. "DD1337", "sf1624x").
# presumably meant for KTH course codes -- TODO confirm against the catalogue.
# NOTE: there are no word boundaries, so it also matches inside longer tokens
# (e.g. "abc1234" within "xabc1234").
COURSE_PATTERN = r"[a-zA-Z]{2,3}\d{3,4}\w?"
|
|
|
|
def make_uppercase(matchobj: "re.Match") -> str:
    """Return the whole matched text of ``matchobj`` in upper case.

    Designed to be passed as the ``repl`` callable of ``re.sub``.
    """
    whole_match = matchobj.group(0)
    return whole_match.upper()
|
|
|
|
def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
    """Return a new translator instance for the given vector store class.

    The class's MRO is searched from most to least specific, so a subclass
    of a supported store (e.g. a project-specific ``Pinecone`` subclass)
    resolves to the same translator as its base class. An exact match
    always wins over a match on a base class.

    Args:
        vectorstore_cls: Class of the vector store to translate queries for.

    Returns:
        A freshly constructed translator (a ``Visitor``).

    Raises:
        ValueError: If neither the class nor any of its bases has a
            registered translator.
    """
    builtin_translators: Dict[Type[VectorStore], Type[Visitor]] = {
        Pinecone: PineconeTranslator,
    }
    # Fall back to the class itself if it somehow has no __mro__, so an
    # unsupported input still ends in the ValueError below.
    for candidate_cls in getattr(vectorstore_cls, "__mro__", (vectorstore_cls,)):
        translator_cls = builtin_translators.get(candidate_cls)
        if translator_cls is not None:
            return translator_cls()
    raise ValueError(
        f"Self query retriever with Vector Store type {vectorstore_cls}"
        " not supported."
    )
|
|
|
|
class SelfQueryRetriever(BaseRetriever, BaseModel):
    """Retriever that wraps a vector store and uses an LLM to build search filters.

    If a query contains a token matching ``COURSE_PATTERN``, those tokens are
    upper-cased and an LLM chain turns the query into a ``StructuredQuery``.
    The translator then converts it into extra vector-store search kwargs.
    Queries without such a token are searched with ``search_kwargs`` only.
    """

    vectorstore: VectorStore
    """The underlying vector store from which documents will be retrieved."""
    llm_chain: LLMChain
    """The LLMChain for generating the vector store queries."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""
    structured_query_translator: Visitor
    """Translator for turning internal query language into vectorstore search params."""
    # When True, the structured query generated for course-code queries is logged.
    verbose: bool = False

    class Config:
        """Configuration for this pydantic object."""

        # VectorStore, LLMChain and Visitor are not pydantic types.
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_translator(cls, values: Dict) -> Dict:
        """Default ``structured_query_translator`` from the vector store's class.

        If no translator was supplied, one is picked by
        ``_get_builtin_translator``, which raises ``ValueError`` for
        unsupported stores. If ``vectorstore`` itself is missing, the values
        pass through untouched. Pydantic then reports the missing required
        field instead of this validator raising a bare ``KeyError``.
        """
        if "structured_query_translator" not in values and "vectorstore" in values:
            vectorstore_cls = values["vectorstore"].__class__
            values["structured_query_translator"] = _get_builtin_translator(
                vectorstore_cls
            )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        If the query contains a ``COURSE_PATTERN`` match, the matches are
        upper-cased. The LLM chain then builds a structured query, and its
        translated kwargs are merged over ``self.search_kwargs``, with the
        translated values taking precedence. Otherwise ``self.search_kwargs``
        is used as-is.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        search_kwargs = self.search_kwargs
        # Only the presence of a match matters, so stop at the first one.
        if re.search(COURSE_PATTERN, query):
            # Normalize course-code tokens to upper case before prompting the LLM.
            query = re.sub(COURSE_PATTERN, make_uppercase, query)
            inputs = self.llm_chain.prep_inputs(query)
            structured_query = cast(
                StructuredQuery,
                self.llm_chain.predict_and_parse(callbacks=None, **inputs),
            )
            if self.verbose:
                logger.info("Found course pattern in query, using structured query:")
                logger.info("%s", structured_query)
            # NOTE(review): the translator's rewritten query text is discarded
            # and the search below uses the normalized user query. Upstream
            # LangChain searches with the rewritten query -- confirm this is
            # intentional.
            _new_query, new_kwargs = (
                self.structured_query_translator.visit_structured_query(
                    structured_query
                )
            )
            search_kwargs = {**self.search_kwargs, **new_kwargs}
        return self.vectorstore.search(query, self.search_type, **search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "SelfQueryRetriever does not support async retrieval."
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        document_contents: str,
        metadata_field_info: List[AttributeInfo],
        structured_query_translator: Optional[Visitor] = None,
        chain_kwargs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> "SelfQueryRetriever":
        """Build a retriever whose query-constructor chain is driven by ``llm``.

        Args:
            llm: Language model used by the query-constructor chain.
            vectorstore: Vector store to retrieve documents from.
            document_contents: Description of the documents' contents,
                passed to ``load_query_constructor_chain``.
            metadata_field_info: Metadata attributes passed to
                ``load_query_constructor_chain``.
            structured_query_translator: Translator for structured queries;
                defaults to the built-in one for ``vectorstore``'s class.
            chain_kwargs: Extra keyword arguments for
                ``load_query_constructor_chain``. If absent,
                ``allowed_comparators`` and ``allowed_operators`` are taken
                from the translator. The caller's dict is never modified.
            **kwargs: Forwarded to the ``SelfQueryRetriever`` constructor.

        Returns:
            A configured ``SelfQueryRetriever``.

        Raises:
            ValueError: If no translator is given and ``vectorstore``'s class
                has no built-in translator.
        """
        if structured_query_translator is None:
            structured_query_translator = _get_builtin_translator(
                vectorstore.__class__
            )
        # Work on a copy so the defaults filled in below never leak into the
        # dict the caller passed (or reused across calls).
        chain_kwargs = dict(chain_kwargs or {})
        if "allowed_comparators" not in chain_kwargs:
            chain_kwargs["allowed_comparators"] = (
                structured_query_translator.allowed_comparators
            )
        if "allowed_operators" not in chain_kwargs:
            chain_kwargs["allowed_operators"] = (
                structured_query_translator.allowed_operators
            )
        llm_chain = load_query_constructor_chain(
            llm, document_contents, metadata_field_info, **chain_kwargs
        )
        return cls(
            llm_chain=llm_chain,
            vectorstore=vectorstore,
            structured_query_translator=structured_query_translator,
            **kwargs,
        )
|
|