| import gradio as gr |
| from langchain.document_loaders import PDFMinerLoader, PyMuPDFLoader |
| from langchain.text_splitter import CharacterTextSplitter |
| import chromadb |
| import chromadb.config |
| from chromadb.config import Settings |
| from transformers import T5ForConditionalGeneration, AutoTokenizer |
| import torch |
| import uuid |
| from sentence_transformers import SentenceTransformer |
| import os |
|
|
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import transformers |
| import torch |
|
|
| |
# ---- Model & vector-store setup (runs once at import time) ----

# Seq2seq LLM used to answer questions over the retrieved context.
model_name = 'google/flan-t5-base'
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sentence embedding model used for both document chunks and queries.
ST_name = 'sentence-transformers/sentence-t5-base'
st_model = SentenceTransformer(ST_name)

# In-memory Chroma client; `get_or_create_collection` avoids the
# "collection already exists" error that `create_collection` raises
# when the script is re-run in the same process (e.g. Gradio reload).
client = chromadb.Client()
collection = client.get_or_create_collection("my_db")
|
|
|
|
def get_context(query_text):
    '''
    Embed the raw query text with the sentence-transformer model,
    search the Chroma collection for the nearest chunks, and return
    the single best-matching chunk as whitespace-normalized text.
    '''
    query_emb = st_model.encode(query_text)
    query_response = collection.query(query_embeddings=query_emb.tolist(), n_results=4)
    # Chroma returns results per query; take the top document of the
    # first (only) query.
    context = query_response['documents'][0][0]
    # Collapse newlines and ALL runs of whitespace to single spaces.
    # (The previous chained .replace('  ', ' ') made one pass only,
    # leaving residual double spaces after 3+ space runs.)
    context = ' '.join(context.split())
    return context
|
|
|
|
|
|
def local_query(query, context):
    '''
    Build a context-grounded prompt from the user query and retrieved
    context, run it through the flan-T5 model, and return the decoded
    generation as a list of strings (batch_decode output).
    '''
    t5query = """Please answer the question based on the given context. 
    If you are not sure about your response, say I am not sure. 
    Context: {}
    Question: {}
    """.format(context, query)

    # Truncate to the tokenizer's model max length so an oversized
    # retrieved context cannot blow past the model's input limit.
    inputs = tokenizer(t5query, return_tensors="pt", truncation=True)

    # Inference only — disable gradient tracking to save memory.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=20)

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|
|
|
|
|
| |
| |
|
|
def run_query(history, query):
    '''
    Gradio chat handler.

    Retrieve the document chunk most related to the user's question via
    Chroma search, ask the LLM with that context prepended, and record
    the (question, answer) pair in the chat history.

    Returns the updated history plus "" so the input textbox is cleared.
    '''
    retrieved = get_context(query)
    answer = local_query(query, retrieved)
    reply = str(answer[0])
    history.append((query, reply))
    return history, ""
|
|
|
|
|
|
def upload_pdf(file):
    '''
    Ingest an uploaded PDF into the module-level Chroma collection.

    Loads the PDF, splits it into ~1000-character chunks, embeds each
    chunk with the sentence-transformer model, and adds the chunks to
    the collection under fresh random IDs.

    Returns a status string for the Gradio output textbox; any failure
    is reported as text rather than raised (this is a UI boundary).
    '''
    try:
        # Guard clause: Gradio passes None when nothing was uploaded.
        if file is None:
            return "No file uploaded."

        loader = PDFMinerLoader(file.name)
        doc = loader.load()

        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
        chunks = [d.page_content for d in text_splitter.split_documents(doc)]

        doc_emb = st_model.encode(chunks).tolist()

        # uuid4 is purely random; uuid1 would embed the host MAC
        # address and a timestamp in every stored ID.
        ids = [str(uuid.uuid4()) for _ in doc_emb]

        collection.add(
            embeddings=doc_emb,
            documents=chunks,
            ids=ids
        )

        return 'Successfully uploaded!'

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"An error occurred: {e}"
|
|
|
|
|
|
| |
|
|
# ---- Gradio frontend: upload button + status box + chat interface ----
with gr.Blocks() as demo:
    '''
    Frontend for our tool
    '''

    # Upload control feeds upload_pdf(); its status string lands in `output`.
    btn = gr.UploadButton("Upload a PDF", file_types=[".pdf"])
    output = gr.Textbox(label="Output Box")
    chatbot = gr.Chatbot(height=240)

    with gr.Row():
        # NOTE(review): fractional `scale` is accepted by older Gradio
        # releases; newer ones expect an int — confirm against the
        # installed version.
        with gr.Column(scale=0.70):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Type a question",
            )

    # Event wiring:
    #  - PDF upload -> ingest into Chroma, show status in `output`.
    #  - Enter in textbox -> run_query updates the chatbot history and
    #    clears the textbox (run_query returns (history, "")).
    btn.upload(fn=upload_pdf, inputs=[btn], outputs=[output])
    txt.submit(run_query, [chatbot, txt], [chatbot, txt])


# Close any previously launched Gradio apps, then serve this one with a
# request queue enabled (needed for long-running model calls).
gr.close_all()
demo.queue().launch()
|
|