from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All, LlamaCpp
import os
|
|
# Configuration is read from environment variables, optionally supplied via a
# local .env file. The variable names are inferred from how the values are
# used further down in this script.
load_dotenv()

embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get("PERSIST_DIRECTORY")

model_type = os.environ.get("MODEL_TYPE")
model_path = os.environ.get("MODEL_PATH")
model_n_ctx = int(os.environ.get("MODEL_N_CTX", "1000"))  # context size; the default is an assumption
|
|
# CHROMA_SETTINGS (the Chroma client configuration) is expected to live in a
# sibling constants.py module.
from constants import CHROMA_SETTINGS
|
|
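# Interactive question answering over previously ingested documents: each
# query is embedded, similar chunks are fetched from the persisted Chroma
# index, and a local LLM answers using those chunks as context.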
def main():
    # The embedding model must match the one used when the index was built.
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever()  # default similarity search over the index

    # Stream tokens to stdout as the model generates them.
    callbacks = [StreamingStdOutCallbackHandler()]
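
    # Instantiate the local LLM backend named by MODEL_TYPE; n_ctx is the
    # model's context window size.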
    match model_type:
        case "LlamaCpp":
            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
        case "GPT4All":
            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
        case _:
            print(f"Model {model_type} not supported!")
            exit(1)

    # "stuff" packs all retrieved chunks into a single prompt for the model.
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
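
    # Prompt loop: read questions from stdin until the user types "exit".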
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break

        # Run the chain: retrieval plus generation. The result dict holds the
        # answer text and the source documents stuffed into the prompt.
        res = qa(query)
        answer, docs = res['result'], res['source_documents']

        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)

        # Show which ingested files each supporting chunk came from.
        for document in docs:
            print("\n> " + document.metadata["source"] + ":")
            print(document.page_content)
|
|
if __name__ == "__main__":
    main()
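

# Sample .env consumed by the os.environ lookups above; every value here is
# an illustrative assumption, not a requirement:
#
#   PERSIST_DIRECTORY=db
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   MODEL_TYPE=GPT4All
#   MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
#   MODEL_N_CTX=1000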
|
|