import logging

import openai
import pinecone
from pydantic import BaseSettings

from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMCheckerChain, RetrievalQAWithSourcesChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import Pinecone

from magic.prompts import PROMPT, EXAMPLE_PROMPT
from magic.self_query_retriever import SelfQueryRetriever
from utils import get_courses

logger = logging.getLogger()


class Settings(BaseSettings):
    OPENAI_API_KEY: str = 'OPENAI_API_KEY'
    OPENAI_CHAT_MODEL: str = 'gpt-3.5-turbo'
    PINECONE_API_KEY: str = 'PINECONE_API_KEY'
    PINECONE_INDEX_NAME: str = 'kth-qa'
    PINECONE_ENV: str = 'us-west1-gcp-free'

    class Config:
        env_file = '.env'
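
# Example .env consumed by Settings above (values are illustrative; real keys
# must be supplied):
#   OPENAI_API_KEY=sk-...
#   PINECONE_API_KEY=...
#   PINECONE_INDEX_NAME=kth-qa
#   PINECONE_ENV=us-west1-gcp-free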
|
|
def set_openai_key(key):
    """Sets OpenAI key."""
    openai.api_key = key
|
|
class State:
    settings: Settings
    store: Pinecone
    chain: RetrievalQAWithSourcesChain
    courses: list

    def __init__(self):
        self.settings = Settings()
        self.courses = get_courses()

        set_openai_key(self.settings.OPENAI_API_KEY)

        # Connect to the existing Pinecone index; queries are embedded with
        # OpenAI embeddings and document text is read from the "text" field.
        embeddings = OpenAIEmbeddings()
        pinecone.init(api_key=self.settings.PINECONE_API_KEY, environment=self.settings.PINECONE_ENV)
        self.store = Pinecone.from_existing_index(self.settings.PINECONE_INDEX_NAME, embeddings, "text")
        logger.info("Pinecone store initialized")

        # Build the combine-documents chain and wrap it in a retrieval QA chain
        # backed by a self-querying retriever.
        doc_chain = self._load_doc_chain()
        self.chain = self._load_qa_chain(doc_chain, self_query=True)
|
|
    def _load_seq_chain(self, chains):
        """Compose the given chains sequentially, mapping a "question" input to an "answer" output."""
        sequential_chain = SequentialChain(
            chains=chains,
            input_variables=["question"],
            output_variables=["answer"],
            verbose=True)
        return sequential_chain
|
|
    def _load_checker_chain(self):
        """Build an LLMCheckerChain that double-checks an "answer" and emits the verified "result"."""
        llm = OpenAI(temperature=0)
        checker_chain = LLMCheckerChain(llm=llm, verbose=True, input_key="answer", output_key="result")
        return checker_chain
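
    # Note: _load_seq_chain and _load_checker_chain are not wired into __init__;
    # they are optional helpers for composing an answer-verification pipeline
    # (e.g. the QA chain followed by the checker chain in a SequentialChain).
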
    def _load_doc_chain(self):
        """Build the "stuff" combine-documents chain that answers from retrieved context with sources."""
        doc_chain = load_qa_with_sources_chain(
            ChatOpenAI(temperature=0, max_tokens=256, model=self.settings.OPENAI_CHAT_MODEL, request_timeout=120),
            chain_type="stuff",
            document_variable_name="context",
            prompt=PROMPT,
            document_prompt=EXAMPLE_PROMPT
        )
        return doc_chain

    def _load_qa_chain(self, doc_chain, self_query=False):
        """Load the QA chain with a retriever.

        If self_query is True, the retriever is a SelfQueryRetriever, which
        extracts a metadata filter from the question and adds it to the
        vectorstore query.
        """
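        # Illustrative (assumed) behaviour: a question like "What does DD1337 cover?"
        # is turned into a semantic query plus a metadata filter such as
        # course == "DD1337"; the exact filter syntax depends on the vectorstore's
        # query translator.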
        if self_query:
            metadata_field_info = [
                AttributeInfo(
                    name="course",
                    description="A course code for a course",
                    type="string"
                )
            ]
            document_content_description = "Brief description of a course"
            llm = OpenAI(temperature=0, model_name='text-davinci-002')
            retriever = SelfQueryRetriever.from_llm(llm, self.store, document_content_description,
                                                    metadata_field_info, verbose=True)
            qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
                                                   retriever=retriever,
                                                   return_source_documents=False)
        else:
            qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
                                                   retriever=self.store.as_retriever(),
                                                   return_source_documents=False)
        return qa_chain

    def course_exists(self, course: str):
        course = course.upper()
        exists = course in self.courses
        if exists:
            logger.info(f'Course {course} exists')
            return True
        else:
            logger.info(f'Course {course} does not exist')
            return False


if __name__ == '__main__':
    state = State()
    print(state.settings)
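
    # Example query (sketch): RetrievalQAWithSourcesChain takes a "question" key
    # and returns "answer" and "sources". Left commented out to avoid live API
    # calls; the course code below is illustrative.
    # result = state.chain({"question": "What is DD1337 about?"})
    # print(result["answer"])
    # print(result["sources"])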
|
|