Pavol Liška committed
Commit 0c3c7ed · 1 Parent(s): b11dd45
Files changed (3)
  1. api.py +14 -11
  2. rag.py +8 -8
  3. rag_langchain.py +2 -2
api.py CHANGED
@@ -1,4 +1,6 @@
-from bson import ObjectId
+import os
+
+from dotenv import load_dotenv
 from fastapi import FastAPI, Response, Body, Security
 from fastapi.security import APIKeyHeader
 from pydantic import BaseModel, model_validator
@@ -8,6 +10,9 @@ import json
 from conversation.conversation_store import ConversationStore
 from rag_langchain import LangChainRAG
 
+load_dotenv()
+api_keys = [os.environ["API_API_KEY"]]
+
 api = FastAPI()
 
 conversation_store = ConversationStore()
@@ -59,7 +64,6 @@ async def read_root():
 
 @api.post("/qa", response_model=AModel)
 async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
-    # Verify the API key
     if not valid_api_key(api_key):
         return Response(status_code=401)
 
@@ -73,7 +77,7 @@ async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
         }
     )
 
-    answer, check_result, sources = rag.rag_chain(data.q, data.llm)
+    answer, check_result, sources = await rag.rag_chain(data.q, data.llm)
 
     oid = conversation_store.save_content(
         q=data.q,
@@ -100,21 +104,20 @@ async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
 
 @api.post("/emo")
 async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
-    # Verify the API key
     if not valid_api_key(api_key):
         return Response(status_code=401)
 
-    qa = conversation_store.get(json_body.qid)
-    new_params = qa.params
+    conversation = conversation_store.get(json_body.qid)
+    new_params = conversation.params
     new_params["user_grading"] = str(json_body.helpfulness)
     conversation_store.update(
         oid=json_body["qid"],
-        q=qa.conversation[0].q,
-        a=qa.conversation[0].a,
-        sources=qa.conversation[0].sources,
+        q=conversation.conversation[0].q,
+        a=conversation.conversation[0].a,
+        sources=conversation.conversation[0].sources,
         params=new_params
     )
 
 
-def valid_api_key(api_key: str):
-    return api_key == "your_secret_api_key"
+def valid_api_key(api_key: str) -> bool:
+    return api_key in api_keys
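
Taken together, the api.py hunks replace the hard-coded placeholder key with a list of accepted keys loaded from the environment via python-dotenv. A minimal sketch of the new key flow, assuming a .env file next to api.py that defines API_API_KEY (the file name and location are assumptions, not part of the diff):

# Sketch of the key handling introduced above; .env is assumed to contain
# a line such as API_API_KEY=<secret value>.
import os

from dotenv import load_dotenv

load_dotenv()                            # copy .env entries into os.environ
api_keys = [os.environ["API_API_KEY"]]   # raises KeyError at startup if the variable is missing


def valid_api_key(api_key: str) -> bool:
    # Membership test instead of string equality, so further keys can be appended later.
    return api_key in api_keys

The /qa handler also now awaits rag.rag_chain(), which only works because rag_langchain.py turns that method into a coroutine in the same commit.
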
rag.py CHANGED
@@ -1,7 +1,7 @@
 import datetime
 import os
 import traceback
-from typing import Any
+from typing import Any, Coroutine
 
 from dotenv import load_dotenv
 from langchain.chains import LLMChain
@@ -91,9 +91,9 @@ def rag_with_rerank_check_multi_query_retriever(agent: Agent, q: str, retrieve_d
     return answer, check_result, context_doc
 
 
-def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
-              check_prompt: str):
-    result = create_retrieval_chain(
+async def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
+                    check_prompt: str):
+    result = await create_retrieval_chain(
         retriever=hyde_2_retrieval(agent, retrieve_document_count),
         combine_docs_chain=create_stuff_documents_chain(
             llm=agent.llm,
@@ -103,7 +103,7 @@ def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
             ),
             document_prompt=PromptTemplate(input_variables=[], template="page_content")
         )
-    ).invoke(
+    ).ainvoke(
         input={
             "question": q,
             "input": q,
@@ -208,7 +208,7 @@ def hyde_2_retrieval(agent, retrieve_document_count):
         llm=agent.llm,
         retriever=agent.embedding.get_vector_store().as_retriever(
             search_type="similarity",
-            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
+            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
         ),
         prompt=PromptTemplate(
             input_variables=["question"],
@@ -224,7 +224,7 @@ def hyde_2_retrieval(agent, retrieve_document_count):
         llm=agent.llm,
         retriever=agent.embedding.get_vector_store().as_retriever(
             search_type="similarity",
-            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
+            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
         ),
         prompt=PromptTemplate(
             input_variables=["question"],
@@ -240,7 +240,7 @@ def hyde_2_retrieval(agent, retrieve_document_count):
         llm=agent.llm,
         retriever=agent.embedding.get_vector_store().as_retriever(
             search_type="similarity",
-            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
+            search_kwargs={"k": min(retrieve_document_count * 10, 300)}
         ),
         prompt=PromptTemplate(
             input_variables=["question"],
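
The core rag.py change converts rag_chain() into a coroutine by swapping the chain's synchronous invoke() for the awaitable ainvoke(), and lowers the retriever's k ceiling from 500 to 300 (for example, retrieve_document_count = 40 now yields k = min(400, 300) = 300 rather than 400). Below is a minimal, self-contained sketch of the invoke-to-ainvoke pattern; RunnableLambda merely stands in for the retrieval chain built in rag_chain() and is not part of the commit:

# Toy stand-in for the chain assembled in rag_chain(), showing the sync -> async switch.
import asyncio

from langchain_core.runnables import RunnableLambda

chain = RunnableLambda(lambda x: {"answer": f"echo: {x['input']}"})


async def ask(question: str) -> str:
    # ainvoke() is the awaitable counterpart of invoke(); awaiting it keeps the
    # event loop free while the real chain waits on retriever and LLM I/O.
    result = await chain.ainvoke({"input": question, "question": question})
    return result["answer"]


print(asyncio.run(ask("What does this commit change?")))
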
rag_langchain.py CHANGED
@@ -102,13 +102,13 @@ class LangChainRAG:
     def get_llms(self):
         return self.llms.keys()
 
-    def rag_chain(self, query, llm_choice):
+    async def rag_chain(self, query, llm_choice):
         print("Using " + llm_choice)
 
         # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
         # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
         # answer, check_result, context_doc = vanilla_rag_chain(
-        answer, check_result, context_doc = rag_chain(
+        answer, check_result, context_doc = await rag_chain(
             Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
             query,
             self.config["retrieve_documents"],
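
Since rag.rag_chain() is now a coroutine, the async/await pair has to propagate up through LangChainRAG.rag_chain() to the FastAPI handler in api.py, which is exactly the chain of edits this commit makes. A toy illustration of that propagation; the names below are placeholders, not project code:

# Placeholder classes and functions only: mirrors qa() -> LangChainRAG.rag_chain() -> rag.rag_chain().
import asyncio


async def low_level_chain(q: str) -> str:
    await asyncio.sleep(0)            # stands in for retriever / LLM I/O
    return f"answer to: {q}"


class FakeRAG:
    async def rag_chain(self, query: str) -> str:
        # Each layer simply awaits the one below it, keeping the call non-blocking end to end.
        return await low_level_chain(query)


async def endpoint(query: str) -> str:
    rag = FakeRAG()
    return await rag.rag_chain(query)


print(asyncio.run(endpoint("What changed?")))
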