Spaces:
Paused
Paused
Pavol Liška committed on
Commit ·
0c3c7ed
1
Parent(s): b11dd45
async
Browse files- api.py +14 -11
- rag.py +8 -8
- rag_langchain.py +2 -2
api.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
from fastapi import FastAPI, Response, Body, Security
|
| 3 |
from fastapi.security import APIKeyHeader
|
| 4 |
from pydantic import BaseModel, model_validator
|
|
@@ -8,6 +10,9 @@ import json
|
|
| 8 |
from conversation.conversation_store import ConversationStore
|
| 9 |
from rag_langchain import LangChainRAG
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
api = FastAPI()
|
| 12 |
|
| 13 |
conversation_store = ConversationStore()
|
|
@@ -59,7 +64,6 @@ async def read_root():
|
|
| 59 |
|
| 60 |
@api.post("/qa", response_model=AModel)
|
| 61 |
async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
|
| 62 |
-
# Verify the API key
|
| 63 |
if not valid_api_key(api_key):
|
| 64 |
return Response(status_code=401)
|
| 65 |
|
|
@@ -73,7 +77,7 @@ async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
|
|
| 73 |
}
|
| 74 |
)
|
| 75 |
|
| 76 |
-
answer, check_result, sources = rag.rag_chain(data.q, data.llm)
|
| 77 |
|
| 78 |
oid = conversation_store.save_content(
|
| 79 |
q=data.q,
|
|
@@ -100,21 +104,20 @@ async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
|
|
| 100 |
|
| 101 |
@api.post("/emo")
|
| 102 |
async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
|
| 103 |
-
# Verify the API key
|
| 104 |
if not valid_api_key(api_key):
|
| 105 |
return Response(status_code=401)
|
| 106 |
|
| 107 |
-
|
| 108 |
-
new_params =
|
| 109 |
new_params["user_grading"] = str(json_body.helpfulness)
|
| 110 |
conversation_store.update(
|
| 111 |
oid=json_body["qid"],
|
| 112 |
-
q=
|
| 113 |
-
a=
|
| 114 |
-
sources=
|
| 115 |
params=new_params
|
| 116 |
)
|
| 117 |
|
| 118 |
|
| 119 |
-
def valid_api_key(api_key: str):
|
| 120 |
-
return api_key
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
from fastapi import FastAPI, Response, Body, Security
|
| 5 |
from fastapi.security import APIKeyHeader
|
| 6 |
from pydantic import BaseModel, model_validator
|
|
|
|
| 10 |
from conversation.conversation_store import ConversationStore
|
| 11 |
from rag_langchain import LangChainRAG
|
| 12 |
|
| 13 |
+
load_dotenv()
|
| 14 |
+
api_keys = [os.environ["API_API_KEY"]]
|
| 15 |
+
|
| 16 |
api = FastAPI()
|
| 17 |
|
| 18 |
conversation_store = ConversationStore()
|
|
|
|
| 64 |
|
| 65 |
@api.post("/qa", response_model=AModel)
|
| 66 |
async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
|
|
|
|
| 67 |
if not valid_api_key(api_key):
|
| 68 |
return Response(status_code=401)
|
| 69 |
|
|
|
|
| 77 |
}
|
| 78 |
)
|
| 79 |
|
| 80 |
+
answer, check_result, sources = await rag.rag_chain(data.q, data.llm)
|
| 81 |
|
| 82 |
oid = conversation_store.save_content(
|
| 83 |
q=data.q,
|
|
|
|
| 104 |
|
| 105 |
@api.post("/emo")
|
| 106 |
async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
|
|
|
|
| 107 |
if not valid_api_key(api_key):
|
| 108 |
return Response(status_code=401)
|
| 109 |
|
| 110 |
+
conversation = conversation_store.get(json_body.qid)
|
| 111 |
+
new_params = conversation.params
|
| 112 |
new_params["user_grading"] = str(json_body.helpfulness)
|
| 113 |
conversation_store.update(
|
| 114 |
oid=json_body["qid"],
|
| 115 |
+
q=conversation.conversation[0].q,
|
| 116 |
+
a=conversation.conversation[0].a,
|
| 117 |
+
sources=conversation.conversation[0].sources,
|
| 118 |
params=new_params
|
| 119 |
)
|
| 120 |
|
| 121 |
|
| 122 |
+
def valid_api_key(api_key: str) -> bool:
|
| 123 |
+
return api_key in api_keys
|
rag.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import datetime
|
| 2 |
import os
|
| 3 |
import traceback
|
| 4 |
-
from typing import Any
|
| 5 |
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
from langchain.chains import LLMChain
|
|
@@ -91,9 +91,9 @@ def rag_with_rerank_check_multi_query_retriever(agent: Agent, q: str, retrieve_d
|
|
| 91 |
return answer, check_result, context_doc
|
| 92 |
|
| 93 |
|
| 94 |
-
def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
|
| 95 |
-
|
| 96 |
-
result = create_retrieval_chain(
|
| 97 |
retriever=hyde_2_retrieval(agent, retrieve_document_count),
|
| 98 |
combine_docs_chain=create_stuff_documents_chain(
|
| 99 |
llm=agent.llm,
|
|
@@ -103,7 +103,7 @@ def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
|
|
| 103 |
),
|
| 104 |
document_prompt=PromptTemplate(input_variables=[], template="page_content")
|
| 105 |
)
|
| 106 |
-
).
|
| 107 |
input={
|
| 108 |
"question": q,
|
| 109 |
"input": q,
|
|
@@ -208,7 +208,7 @@ def hyde_2_retrieval(agent, retrieve_document_count):
|
|
| 208 |
llm=agent.llm,
|
| 209 |
retriever=agent.embedding.get_vector_store().as_retriever(
|
| 210 |
search_type="similarity",
|
| 211 |
-
search_kwargs={"k": min(retrieve_document_count * 10,
|
| 212 |
),
|
| 213 |
prompt=PromptTemplate(
|
| 214 |
input_variables=["question"],
|
|
@@ -224,7 +224,7 @@ def hyde_2_retrieval(agent, retrieve_document_count):
|
|
| 224 |
llm=agent.llm,
|
| 225 |
retriever=agent.embedding.get_vector_store().as_retriever(
|
| 226 |
search_type="similarity",
|
| 227 |
-
search_kwargs={"k": min(retrieve_document_count * 10,
|
| 228 |
),
|
| 229 |
prompt=PromptTemplate(
|
| 230 |
input_variables=["question"],
|
|
@@ -240,7 +240,7 @@ def hyde_2_retrieval(agent, retrieve_document_count):
|
|
| 240 |
llm=agent.llm,
|
| 241 |
retriever=agent.embedding.get_vector_store().as_retriever(
|
| 242 |
search_type="similarity",
|
| 243 |
-
search_kwargs={"k": min(retrieve_document_count * 10,
|
| 244 |
),
|
| 245 |
prompt=PromptTemplate(
|
| 246 |
input_variables=["question"],
|
|
|
|
| 1 |
import datetime
|
| 2 |
import os
|
| 3 |
import traceback
|
| 4 |
+
from typing import Any, Coroutine
|
| 5 |
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
from langchain.chains import LLMChain
|
|
|
|
| 91 |
return answer, check_result, context_doc
|
| 92 |
|
| 93 |
|
| 94 |
+
async def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
|
| 95 |
+
check_prompt: str):
|
| 96 |
+
result = await create_retrieval_chain(
|
| 97 |
retriever=hyde_2_retrieval(agent, retrieve_document_count),
|
| 98 |
combine_docs_chain=create_stuff_documents_chain(
|
| 99 |
llm=agent.llm,
|
|
|
|
| 103 |
),
|
| 104 |
document_prompt=PromptTemplate(input_variables=[], template="page_content")
|
| 105 |
)
|
| 106 |
+
).ainvoke(
|
| 107 |
input={
|
| 108 |
"question": q,
|
| 109 |
"input": q,
|
|
|
|
| 208 |
llm=agent.llm,
|
| 209 |
retriever=agent.embedding.get_vector_store().as_retriever(
|
| 210 |
search_type="similarity",
|
| 211 |
+
search_kwargs={"k": min(retrieve_document_count * 10, 300)}
|
| 212 |
),
|
| 213 |
prompt=PromptTemplate(
|
| 214 |
input_variables=["question"],
|
|
|
|
| 224 |
llm=agent.llm,
|
| 225 |
retriever=agent.embedding.get_vector_store().as_retriever(
|
| 226 |
search_type="similarity",
|
| 227 |
+
search_kwargs={"k": min(retrieve_document_count * 10, 300)}
|
| 228 |
),
|
| 229 |
prompt=PromptTemplate(
|
| 230 |
input_variables=["question"],
|
|
|
|
| 240 |
llm=agent.llm,
|
| 241 |
retriever=agent.embedding.get_vector_store().as_retriever(
|
| 242 |
search_type="similarity",
|
| 243 |
+
search_kwargs={"k": min(retrieve_document_count * 10, 300)}
|
| 244 |
),
|
| 245 |
prompt=PromptTemplate(
|
| 246 |
input_variables=["question"],
|
rag_langchain.py
CHANGED
|
@@ -102,13 +102,13 @@ class LangChainRAG:
|
|
| 102 |
def get_llms(self):
|
| 103 |
return self.llms.keys()
|
| 104 |
|
| 105 |
-
def rag_chain(self, query, llm_choice):
|
| 106 |
print("Using " + llm_choice)
|
| 107 |
|
| 108 |
# answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
|
| 109 |
# answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
|
| 110 |
# answer, check_result, context_doc = vanilla_rag_chain(
|
| 111 |
-
answer, check_result, context_doc = rag_chain(
|
| 112 |
Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
|
| 113 |
query,
|
| 114 |
self.config["retrieve_documents"],
|
|
|
|
| 102 |
def get_llms(self):
|
| 103 |
return self.llms.keys()
|
| 104 |
|
| 105 |
+
async def rag_chain(self, query, llm_choice):
|
| 106 |
print("Using " + llm_choice)
|
| 107 |
|
| 108 |
# answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
|
| 109 |
# answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
|
| 110 |
# answer, check_result, context_doc = vanilla_rag_chain(
|
| 111 |
+
answer, check_result, context_doc = await rag_chain(
|
| 112 |
Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
|
| 113 |
query,
|
| 114 |
self.config["retrieve_documents"],
|