Spaces:
Build error
Build error
1st
Browse files
- Document_QA.py +57 -46
- app.py +9 -17
- requirements.txt +1 -0
Document_QA.py
CHANGED
|
@@ -6,40 +6,43 @@ import pickle
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
import argparse
|
| 8 |
import os
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"""Create embeddings for the provided input."""
|
| 12 |
# input = ['ddd','aaa','ccccccccccccccc','ddddd']
|
| 13 |
result = []
|
| 14 |
-
# limit about 1000 tokens per request
|
| 15 |
-
# 记录文章每行的长度
|
| 16 |
-
# 0 [100]
|
| 17 |
-
# 1 [200]
|
| 18 |
-
# 2 [4100]
|
| 19 |
-
# 3 [999]
|
| 20 |
-
lens = [len(text) for text in input]
|
| 21 |
-
query_len = 0
|
| 22 |
-
start_index = 0
|
| 23 |
tokens = 0
|
| 24 |
|
| 25 |
def get_embedding(input_slice):
|
|
|
|
| 26 |
embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
|
| 27 |
-
#返回了(文字,embedding)和文字的token
|
| 28 |
return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens
|
| 29 |
-
|
| 30 |
-
for index, l in enumerate(lens):
|
| 31 |
-
|
| 32 |
-
if query_len > 4096:
|
| 33 |
-
ebd, tk = get_embedding(input[start_index:index + 1])
|
| 34 |
-
query_len = 0
|
| 35 |
-
start_index = index + 1
|
| 36 |
-
tokens += tk
|
| 37 |
-
result.extend(ebd)
|
| 38 |
-
|
| 39 |
-
if query_len > 0:
|
| 40 |
-
ebd, tk = get_embedding(input[start_index:])
|
| 41 |
tokens += tk
|
| 42 |
result.extend(ebd)
|
|
|
|
| 43 |
return result, tokens
|
| 44 |
|
| 45 |
def create_embedding(text):
|
|
@@ -58,33 +61,35 @@ class QA():
|
|
| 58 |
self.index = index
|
| 59 |
#所有文字
|
| 60 |
self.data = data
|
|
|
|
|
|
|
| 61 |
def __call__(self, query):
|
| 62 |
embedding = create_embedding(query)
|
| 63 |
#输出与用户的问题相关的文字
|
| 64 |
-
context = self.get_texts(embedding[1], limit)
|
| 65 |
#将用户的问题和涉及的文字告诉gpt,并将答案返回
|
| 66 |
answer = self.completion(query,context)
|
| 67 |
return answer,context
|
| 68 |
-
def get_texts(self,embeding,limit):
|
| 69 |
_,text_index = self.index.search(np.array([embeding]),limit)
|
| 70 |
context = []
|
| 71 |
for i in list(text_index[0]):
|
| 72 |
-
context.extend(self.data[i:i+2])
|
| 73 |
# context = [self.data[i] for i in list(text_index[0])]
|
| 74 |
#输出与用户的问题相关的文字
|
| 75 |
return context
|
| 76 |
|
| 77 |
def completion(self,query, context):
|
| 78 |
"""Create a completion."""
|
| 79 |
-
lens = [len(text) for text in context]
|
| 80 |
|
| 81 |
-
maximum = 3000
|
| 82 |
-
for index, l in enumerate(lens):
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
|
| 89 |
text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
|
| 90 |
response = openai.ChatCompletion.create(
|
|
@@ -100,24 +105,30 @@ class QA():
|
|
| 100 |
|
| 101 |
if __name__ == '__main__':
|
| 102 |
parser = argparse.ArgumentParser(description="Document QA")
|
| 103 |
-
parser.add_argument("--input_file", default="input.txt", dest="input_file", type=str,help="输入文件路径")
|
| 104 |
-
parser.add_argument("--file_embeding", default="input_embed.pkl", dest="file_embeding", type=str,help="文件embeding文件路径")
|
| 105 |
parser.add_argument("--print_context", action='store_true',help="是否打印上下文")
|
| 106 |
|
| 107 |
|
| 108 |
args = parser.parse_args()
|
| 109 |
|
| 110 |
-
if os.path.isfile(args.file_embeding):
|
| 111 |
-
|
| 112 |
-
else:
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
|
|
|
|
|
|
| 121 |
qa =QA(data_embe)
|
| 122 |
|
| 123 |
limit = 10
|
|
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
import argparse
|
| 8 |
import os
|
| 9 |
+
from PyPDF2 import PdfReader
|
| 10 |
|
| 11 |
+
class Paper(object):
|
| 12 |
+
|
| 13 |
+
def __init__(self, pdf_path) -> None:
|
| 14 |
+
self._pdf_obj = PdfReader(pdf_path)
|
| 15 |
+
self._paper_meta = self._pdf_obj.metadata
|
| 16 |
+
self.texts = []
|
| 17 |
+
|
| 18 |
+
def iter_pages(self, iter_text_len: int = 1000):
|
| 19 |
+
page_idx = 0
|
| 20 |
+
for page in self._pdf_obj.pages:
|
| 21 |
+
txt = page.extract_text()
|
| 22 |
+
for i in range((len(txt) // iter_text_len) + 1):
|
| 23 |
+
yield page_idx, i, txt[i * iter_text_len:(i + 1) * iter_text_len]
|
| 24 |
+
page_idx += 1
|
| 25 |
+
def get_texts(self):
|
| 26 |
+
for (page_idx, part_idx, text) in self.iter_pages():
|
| 27 |
+
self.texts.append(text.strip())
|
| 28 |
+
return self.texts
|
| 29 |
+
|
| 30 |
+
def create_embeddings(inputs):
|
| 31 |
"""Create embeddings for the provided input."""
|
| 32 |
# input = ['ddd','aaa','ccccccccccccccc','ddddd']
|
| 33 |
result = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
tokens = 0
|
| 35 |
|
| 36 |
def get_embedding(input_slice):
|
| 37 |
+
input_slice = [input_slice]
|
| 38 |
embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
|
|
|
|
| 39 |
return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens
|
| 40 |
+
|
| 41 |
+
for i in range(0,len(inputs)):
|
| 42 |
+
ebd, tk = get_embedding(inputs[i])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
tokens += tk
|
| 44 |
result.extend(ebd)
|
| 45 |
+
|
| 46 |
return result, tokens
|
| 47 |
|
| 48 |
def create_embedding(text):
|
|
|
|
| 61 |
self.index = index
|
| 62 |
#所有文字
|
| 63 |
self.data = data
|
| 64 |
+
print("now all data is:\n",self.data)
|
| 65 |
+
|
| 66 |
def __call__(self, query):
|
| 67 |
embedding = create_embedding(query)
|
| 68 |
#输出与用户的问题相关的文字
|
| 69 |
+
context = self.get_texts(embedding[1])
|
| 70 |
#将用户的问题和涉及的文字告诉gpt,并将答案返回
|
| 71 |
answer = self.completion(query,context)
|
| 72 |
return answer,context
|
| 73 |
+
def get_texts(self,embeding,limit=5):
|
| 74 |
_,text_index = self.index.search(np.array([embeding]),limit)
|
| 75 |
context = []
|
| 76 |
for i in list(text_index[0]):
|
| 77 |
+
context.extend(self.data[i:i+2])
|
| 78 |
# context = [self.data[i] for i in list(text_index[0])]
|
| 79 |
#输出与用户的问题相关的文字
|
| 80 |
return context
|
| 81 |
|
| 82 |
def completion(self,query, context):
|
| 83 |
"""Create a completion."""
|
| 84 |
+
# lens = [len(text) for text in context]
|
| 85 |
|
| 86 |
+
# maximum = 3000
|
| 87 |
+
# for index, l in enumerate(lens):
|
| 88 |
+
# maximum -= l
|
| 89 |
+
# if maximum < 0:
|
| 90 |
+
# context = context[:index + 1]
|
| 91 |
+
# print("超过最大长度,截断到前", index + 1, "个片段")
|
| 92 |
+
# break
|
| 93 |
|
| 94 |
text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
|
| 95 |
response = openai.ChatCompletion.create(
|
|
|
|
| 105 |
|
| 106 |
if __name__ == '__main__':
|
| 107 |
parser = argparse.ArgumentParser(description="Document QA")
|
| 108 |
+
parser.add_argument("--input_file", default="slimming-pages-1.pdf", dest="input_file", type=str,help="输入文件路径")
|
| 109 |
+
# parser.add_argument("--file_embeding", default="input_embed.pkl", dest="file_embeding", type=str,help="文件embeding文件路径")
|
| 110 |
parser.add_argument("--print_context", action='store_true',help="是否打印上下文")
|
| 111 |
|
| 112 |
|
| 113 |
args = parser.parse_args()
|
| 114 |
|
| 115 |
+
# if os.path.isfile(args.file_embeding):
|
| 116 |
+
# data_embe = pickle.load(open(args.file_embeding,'rb'))
|
| 117 |
+
# else:
|
| 118 |
+
# with open(args.input_file,'r',encoding='utf-8') as f:
|
| 119 |
+
# texts = f.readlines()
|
| 120 |
+
# #按照行对文章进行切割
|
| 121 |
+
# texts = [text.strip() for text in texts if text.strip()]
|
| 122 |
+
# data_embe,tokens = create_embeddings(texts)
|
| 123 |
+
# pickle.dump(data_embe,open(args.file_embeding,'wb'))
|
| 124 |
+
# print("文本消耗 {} tokens".format(tokens))
|
| 125 |
+
|
| 126 |
+
paper = Paper(args.input_file)
|
| 127 |
+
all_texts = paper.get_texts()
|
| 128 |
+
|
| 129 |
|
| 130 |
+
data_embe, tokens = create_embeddings(all_texts)
|
| 131 |
+
print("全部文本消耗 {} tokens".format(tokens))
|
| 132 |
qa =QA(data_embe)
|
| 133 |
|
| 134 |
limit = 10
|
app.py
CHANGED
|
@@ -4,6 +4,8 @@ import openai
|
|
| 4 |
# from gpt_reader.prompt import BASE_POINTS
|
| 5 |
from Document_QA import QA
|
| 6 |
from Document_QA import create_embeddings
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class GUI:
|
| 9 |
def __init__(self):
|
|
@@ -14,27 +16,17 @@ class GUI:
|
|
| 14 |
#load pdf and create all embedings
|
| 15 |
def pdf_init(self, api_key, pdf_path):
|
| 16 |
openai.api_key = api_key
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
def get_answer(self, question):
|
| 24 |
qa = QA(self.all_embedding)
|
| 25 |
answer,context = qa(question)
|
| 26 |
return answer.strip()
|
| 27 |
|
| 28 |
-
# def analyse(self, api_key, pdf_file):
|
| 29 |
-
# self.session = PaperReader(api_key, points_to_focus=BASE_POINTS)
|
| 30 |
-
# return self.session.read_pdf_and_summarize(pdf_file)
|
| 31 |
-
|
| 32 |
-
# def ask_question(self, question):
|
| 33 |
-
# if self.session == "":
|
| 34 |
-
# return "Please upload PDF file first!"
|
| 35 |
-
# return self.session.question(question)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
with gr.Blocks() as demo:
|
| 39 |
gr.Markdown(
|
| 40 |
"""
|
|
@@ -57,4 +49,4 @@ with gr.Blocks() as demo:
|
|
| 57 |
|
| 58 |
if __name__ == "__main__":
|
| 59 |
demo.title = "CHATGPT-PAPER-READER"
|
| 60 |
-
demo.launch(
|
|
|
|
| 4 |
# from gpt_reader.prompt import BASE_POINTS
|
| 5 |
from Document_QA import QA
|
| 6 |
from Document_QA import create_embeddings
|
| 7 |
+
from Document_QA import Paper
|
| 8 |
+
from PyPDF2 import PdfReader
|
| 9 |
|
| 10 |
class GUI:
|
| 11 |
def __init__(self):
|
|
|
|
| 16 |
#load pdf and create all embedings
|
| 17 |
def pdf_init(self, api_key, pdf_path):
|
| 18 |
openai.api_key = api_key
|
| 19 |
+
pdf_reader = PdfReader(pdf_path)
|
| 20 |
+
paper = Paper(pdf_reader)
|
| 21 |
+
all_texts = paper.get_texts()
|
| 22 |
+
self.all_embedding, self.tokens = create_embeddings(all_texts)
|
| 23 |
+
print("全部文本消耗 {} tokens".format(self.tokens))
|
| 24 |
+
|
| 25 |
def get_answer(self, question):
|
| 26 |
qa = QA(self.all_embedding)
|
| 27 |
answer,context = qa(question)
|
| 28 |
return answer.strip()
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
with gr.Blocks() as demo:
|
| 31 |
gr.Markdown(
|
| 32 |
"""
|
|
|
|
| 49 |
|
| 50 |
if __name__ == "__main__":
|
| 51 |
demo.title = "CHATGPT-PAPER-READER"
|
| 52 |
+
demo.launch() # add "share=True" to share CHATGPT-PAPER-READER app on Internet.
|
requirements.txt
CHANGED
|
@@ -2,3 +2,4 @@ numpy
|
|
| 2 |
faiss-cpu
|
| 3 |
tqdm
|
| 4 |
openai
|
|
|
|
|
|
| 2 |
faiss-cpu
|
| 3 |
tqdm
|
| 4 |
openai
|
| 5 |
+
PyPDF2
|