RKP64 committed on
Commit
627f7b8
·
1 Parent(s): e3b6978

Upload question.py

Browse files
Files changed (1) hide show
  1. question.py +81 -0
question.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import anthropic
2
+ import streamlit as st
3
+ from streamlit.logger import get_logger
4
+ from langchain.chains import ConversationalRetrievalChain
5
+ from langchain.memory import ConversationBufferMemory
6
+ from langchain.llms import OpenAI
7
+ from langchain.chat_models import ChatAnthropic
8
+ from langchain.vectorstores import SupabaseVectorStore
9
+ from stats import add_usage
10
+
11
# Module-level conversation memory shared by every chat session in this
# Streamlit process; keyed so ConversationalRetrievalChain can read the
# running chat history as messages.
memory = ConversationBufferMemory(
    memory_key="chat_history", return_messages=True)
# API keys come from Streamlit's secrets store (.streamlit/secrets.toml).
# NOTE(review): st.secrets raises at import time if a key is missing —
# presumably both are always configured for this deployment; confirm.
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
# Streamlit-aware logger for this module.
logger = get_logger(__name__)
16
+
17
+
18
def count_tokens(question, model):
    """Return a human-readable size summary for *question*.

    Always reports the whitespace-delimited word count; when *model* is an
    Anthropic Claude model, a token count from the anthropic SDK is appended.

    Args:
        question: The text to measure.
        model: Model identifier; only checked for a ``"claude"`` prefix.

    Returns:
        A string such as ``"Words: 3"`` or ``"Words: 3 | Tokens: 5"``.
    """
    parts = [f'Words: {len(question.split())}']
    if model.startswith("claude"):
        # Only Claude models get a tokenizer-accurate count; other models
        # fall back to the plain word count above.
        parts.append(f'Tokens: {anthropic.count_tokens(question)}')
    return ' | '.join(parts)
23
+
24
+
25
def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
    """Render the Streamlit chat UI and answer questions against the document store.

    Draws a question box plus Ask / Count Tokens / Clear History buttons,
    and on "Ask" runs a ConversationalRetrievalChain over *vector_store*
    using either an OpenAI or Anthropic LLM depending on *model*'s prefix.
    Chat history is kept in st.session_state['chat_history'] as
    (speaker, text) tuples and re-rendered after each answer.

    Args:
        model: Model identifier; dispatched on a "gpt" or "claude" prefix.
        vector_store: Supabase-backed vector store used as the retriever.
        stats_db: Handle passed through to add_usage for usage accounting.
    """

    # First render in this session: start with an empty history.
    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    question = st.text_area("## Ask a question")
    # Three side-by-side action buttons.
    columns = st.columns(3)
    with columns[0]:
        button = st.button("Ask")
    with columns[1]:
        count_button = st.button("Count Tokens", type='secondary')
    with columns[2]:
        clear_history = st.button("Clear History", type='secondary')

    if clear_history:
        # Clear memory in Langchain
        memory.clear()
        # Also drop the display history, then rerun so the cleared state
        # is reflected immediately in the UI.
        st.session_state['chat_history'] = []
        st.experimental_rerun()

    if button:
        qa = None
        # 'overused' is a quota flag set elsewhere in the app — TODO confirm
        # where it is initialized; this will KeyError if it was never set.
        if not st.session_state["overused"]:
            # NOTE(review): "prompt" + question concatenates with no
            # separator — presumably intentional for the stats key, but
            # verify against add_usage's expected format.
            add_usage(stats_db, "chat", "prompt" + question, {"model": model, "temperature": st.session_state['temperature']})
            if model.startswith("gpt"):
                logger.info('Using OpenAI model %s', model)
                qa = ConversationalRetrievalChain.from_llm(
                    OpenAI(
                        model_name=st.session_state['model'], openai_api_key=openai_api_key, temperature=st.session_state['temperature'], max_tokens=st.session_state['max_tokens']), vector_store.as_retriever(), memory=memory, verbose=True)
            elif anthropic_api_key and model.startswith("claude"):
                logger.info('Using Anthropics model %s', model)
                # max_tokens_limit caps how much retrieved context is
                # stuffed into the Claude prompt.
                qa = ConversationalRetrievalChain.from_llm(
                    ChatAnthropic(
                        model=st.session_state['model'], anthropic_api_key=anthropic_api_key, temperature=st.session_state['temperature'], max_tokens_to_sample=st.session_state['max_tokens']), vector_store.as_retriever(), memory=memory, verbose=True, max_tokens_limit=102400)

            st.session_state['chat_history'].append(("You", question))

            # Generate model's response and add it to chat history
            # NOTE(review): if model matched neither prefix, qa is still
            # None here and this call raises TypeError — confirm callers
            # only pass gpt*/claude* model names.
            model_response = qa({"question": question})
            logger.info('Result: %s', model_response)

            st.session_state['chat_history'].append(("KPMG GPT", model_response["answer"]))

            # Display chat history
            st.empty()
            for speaker, text in st.session_state['chat_history']:
                st.markdown(f"**{speaker}:** {text}")
        else:
            st.error("You have used all your free credits. Please try again later or self host.")

    if count_button:
        # Token/word counting works even when the quota is exhausted.
        st.write(count_tokens(question, model))