import asyncio
import os
import json
import datetime
import logging
import traceback

from flask import Blueprint, request, Response
from pymongo import MongoClient
from bson.objectid import ObjectId
from transformers import GPT2TokenizerFast

from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.error import bad_request

logger = logging.getLogger(__name__)

mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
prompts_collection = db["prompts"]

answer = Blueprint('answer', __name__)

if settings.LLM_NAME == "gpt4":
    gpt_model = 'gpt-4'
elif settings.LLM_NAME == "anthropic":
    gpt_model = 'claude-2'
else:
    gpt_model = 'gpt-3.5-turbo'
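
# Prompt templates are read once, at import time, from the "prompts" directory
# three levels above this file.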
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
    chat_combine_template = f.read()

with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
    chat_reduce_template = f.read()

with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
    chat_combine_creative = f.read()

with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
    chat_combine_strict = f.read()

api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
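

# Helpers: async_generate()/run_async_chain() drive an async chain from
# synchronous code, and count_tokens() approximates message length with the
# GPT-2 tokenizer.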
async def async_generate(chain, question, chat_history):
    result = await chain.arun({"question": question, "chat_history": chat_history})
    return result


def count_tokens(string):
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    return len(tokenizer(string)['input_ids'])


def run_async_chain(chain, question, chat_history):
    # Run the chain on a dedicated event loop so it can be called from sync code.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    result = {}
    try:
        answer = loop.run_until_complete(async_generate(chain, question, chat_history))
    finally:
        loop.close()
    result["answer"] = answer
    return result


def get_vectorstore(data):
    """Resolve the requested "active_docs" value to an on-disk vector store path.

    "local/..." sources live under "application/indexes/", other named sources
    under "application/vectors/", and "default" falls back to the default store.
    """
    if "active_docs" in data:
        if data["active_docs"].split("/")[0] == "default":
            vectorstore = ""
        elif data["active_docs"].split("/")[0] == "local":
            vectorstore = "indexes/" + data["active_docs"]
        else:
            vectorstore = "vectors/" + data["active_docs"]
        if data["active_docs"] == "default":
            vectorstore = ""
    else:
        vectorstore = ""
    vectorstore = os.path.join("application", vectorstore)
    return vectorstore


def is_azure_configured():
    return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
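

# complete_stream() streams the answer as Server-Sent Events and saves the
# exchange to MongoDB once the stream has finished.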
def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id):
    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)

    if prompt_id == 'default':
        prompt = chat_combine_template
    elif prompt_id == 'creative':
        prompt = chat_combine_creative
    elif prompt_id == 'strict':
        prompt = chat_combine_strict
    else:
        prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]

    docs = docsearch.search(question, k=2)
    if settings.LLM_NAME == "llama.cpp":
        docs = [docs[0]]

    # Build the system prompt from the retrieved documents and keep a source log.
    docs_together = "\n".join([doc.page_content for doc in docs])
    p_chat_combine = prompt.replace("{summaries}", docs_together)
    messages_combine = [{"role": "system", "content": p_chat_combine}]
    source_log_docs = []
    for doc in docs:
        if doc.metadata:
            source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
        else:
            source_log_docs.append({"title": doc.page_content, "text": doc.page_content})

    if len(chat_history) > 1:
        tokens_current_history = 0
        # Add as much history as fits within TOKENS_MAX_HISTORY, newest first.
        chat_history.reverse()
        for i in chat_history:
            if "prompt" in i and "response" in i:
                tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
                if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                    tokens_current_history += tokens_batch
                    messages_combine.append({"role": "user", "content": i["prompt"]})
                    messages_combine.append({"role": "system", "content": i["response"]})
    messages_combine.append({"role": "user", "content": question})
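
    # Stream the completion back chunk by chunk as Server-Sent Events.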
    response_full = ""
    completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                                messages=messages_combine)
    for line in completion:
        data = json.dumps({"answer": str(line)})
        response_full += str(line)
        yield f"data: {data}\n\n"

    # Persist the exchange once the full response is known.
    if conversation_id is not None:
        conversations_collection.update_one(
            {"_id": ObjectId(conversation_id)},
            {"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
        )
    else:
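        # First exchange of a new conversation: ask the LLM for a short title
        # and create the conversation document.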
        messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 "
                                                             "words, respond ONLY with the summary, use the same "
                                                             "language as the system \n\nUser: " + question + "\n\n" +
                                                             "AI: " +
                                                             response_full},
                            {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
                                                        "respond ONLY with the summary, use the same language as the "
                                                        "system"}]

        completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                             messages=messages_summary, max_tokens=30)
        conversation_id = conversations_collection.insert_one(
            {"user": "local",
             "date": datetime.datetime.utcnow(),
             "name": completion,
             "queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
        ).inserted_id

    # Send the conversation id and an end-of-stream marker to the client.
    data = json.dumps({"type": "id", "id": str(conversation_id)})
    yield f"data: {data}\n\n"
    data = json.dumps({"type": "end"})
    yield f"data: {data}\n\n"

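
# Streaming chat endpoint. Example request body (illustrative values):
#   {"question": "How do I install this?", "history": "[]",
#    "conversation_id": null, "prompt_id": "default", "active_docs": "local/my-docs"}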
@answer.route("/stream", methods=["POST"])
def stream():
    data = request.get_json()

    question = data["question"]
    history = data["history"]
    # history arrives as a JSON-encoded string of prompt/response pairs.
    history = json.loads(history)
    conversation_id = data["conversation_id"]
    if 'prompt_id' in data:
        prompt_id = data["prompt_id"]
    else:
        prompt_id = 'default'

    # Fall back to keys supplied in the request when none are configured server-side.
    if not api_key_set:
        api_key = data["api_key"]
    else:
        api_key = settings.API_KEY
    if not embeddings_key_set:
        embeddings_key = data["embeddings_key"]
    else:
        embeddings_key = settings.EMBEDDINGS_KEY
    if "active_docs" in data:
        vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
    else:
        vectorstore = ""
    docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)

    return Response(
        complete_stream(question, docsearch,
                        chat_history=history, api_key=api_key,
                        prompt_id=prompt_id,
                        conversation_id=conversation_id), mimetype="text/event-stream"
    )

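
# Non-streaming variant: accepts the same fields as /stream (with "history"
# already a list of {"prompt", "response"} entries) and returns the answer,
# sources and conversation id as a single JSON response.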
@answer.route("/api/answer", methods=["POST"])
def api_answer():
    data = request.get_json()
    question = data["question"]
    history = data["history"]
    if "conversation_id" not in data:
        conversation_id = None
    else:
        conversation_id = data["conversation_id"]
    print("-" * 5)
    if not api_key_set:
        api_key = data["api_key"]
    else:
        api_key = settings.API_KEY
    if not embeddings_key_set:
        embeddings_key = data["embeddings_key"]
    else:
        embeddings_key = settings.EMBEDDINGS_KEY
    if 'prompt_id' in data:
        prompt_id = data["prompt_id"]
    else:
        prompt_id = 'default'

    if prompt_id == 'default':
        prompt = chat_combine_template
    elif prompt_id == 'creative':
        prompt = chat_combine_creative
    elif prompt_id == 'strict':
        prompt = chat_combine_strict
    else:
        prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]

    try:
        # Retrieve the most relevant chunks for the question.
        vectorstore = get_vectorstore(data)
        docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)

        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)

        docs = docsearch.search(question, k=2)

        # Build the system prompt from the retrieved documents and keep a source log.
        docs_together = "\n".join([doc.page_content for doc in docs])
        p_chat_combine = prompt.replace("{summaries}", docs_together)
        messages_combine = [{"role": "system", "content": p_chat_combine}]
        source_log_docs = []
        for doc in docs:
            if doc.metadata:
                source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
            else:
                source_log_docs.append({"title": doc.page_content, "text": doc.page_content})

        if len(history) > 1:
            tokens_current_history = 0
            # Add as much history as fits within TOKENS_MAX_HISTORY, newest first.
            history.reverse()
            for i in history:
                if "prompt" in i and "response" in i:
                    tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
                    if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                        tokens_current_history += tokens_batch
                        messages_combine.append({"role": "user", "content": i["prompt"]})
                        messages_combine.append({"role": "system", "content": i["response"]})
        messages_combine.append({"role": "user", "content": question})

        completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                             messages=messages_combine)

        result = {"answer": completion, "sources": source_log_docs}
        logger.debug(result)

        # Persist the exchange to the conversation history.
        if conversation_id is not None:
            conversations_collection.update_one(
                {"_id": ObjectId(conversation_id)},
                {"$push": {"queries": {"prompt": question,
                                       "response": result["answer"], "sources": result['sources']}}},
            )
        else:
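            # First message of a new conversation: ask the LLM for a short title
            # and create the conversation document.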
            messages_summary = [
                {"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
                                                 "respond ONLY with the summary, use the same language as the system \n\n"
                                                 "User: " + question + "\n\n" + "AI: " + result["answer"]},
                {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
                                            "respond ONLY with the summary, use the same language as the system"}
            ]

            completion = llm.gen(
                model=gpt_model,
                engine=settings.AZURE_DEPLOYMENT_NAME,
                messages=messages_summary,
                max_tokens=30
            )
            conversation_id = conversations_collection.insert_one(
                {"user": "local",
                 "date": datetime.datetime.utcnow(),
                 "name": completion,
                 "queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
            ).inserted_id

        result["conversation_id"] = str(conversation_id)

        return result
    except Exception as e:
        traceback.print_exc()
        print(str(e))
        return bad_request(500, str(e))

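
# Retrieval-only endpoint: returns the matching source chunks for a question
# without calling the LLM.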
@answer.route("/api/search", methods=["POST"])
def api_search():
    data = request.get_json()

    question = data["question"]

    if not embeddings_key_set:
        embeddings_key = data["embeddings_key"]
    else:
        embeddings_key = settings.EMBEDDINGS_KEY
    if "active_docs" in data:
        vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
    else:
        vectorstore = ""
    docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)

    docs = docsearch.search(question, k=2)

    source_log_docs = []
    for doc in docs:
        if doc.metadata:
            source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
        else:
            source_log_docs.append({"title": doc.page_content, "text": doc.page_content})

    return source_log_docs