| import argparse |
| import os |
| from dotenv import load_dotenv |
|
|
| from langchain.globals import set_debug |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain_core.output_parsers import StrOutputParser |
|
|
| from lib.repository import download_github_repo |
| from lib.loader import load_files |
| from lib.chain import create_retriever, create_qa_chain |
| from lib.utils import read_prompt, load_LLM, select_model |
| from lib.models import MODELS_MAP |
|
|
| import time |
| import gradio as gr |
|
|
def slow_echo(message, history, delay=0.05):
    """Yield successively longer prefixes of *message* to simulate a streamed reply.

    Args:
        message: Text to echo back.
        history: Chat history supplied by the caller; not used.
        delay: Seconds to sleep before each prefix is yielded (default 0.05).

    Yields:
        ``message[:1]``, ``message[:2]``, ..., ``message``; nothing for an
        empty message.
    """
    for end in range(1, len(message) + 1):
        time.sleep(delay)
        yield message[:end]
|
|
| |
|
|
def build():
    """Build the Gradio page for chatting with a repository and launch it.

    The page offers two actions:

    * "Submit Repo URL" passes the URL to ``main`` and writes main's return
      value back into the "Repo URL" textbox.
    * "Submit Question" sends the question to the module-level ``qa_chain``
      and shows the result in the "Chat Output" textbox.

    ``demo.launch()`` runs after the ``with gr.Blocks()`` context has closed,
    so the layout is complete before the server starts.
    """
    with gr.Blocks() as demo:
        # Components are listed in display order.
        repo_url = gr.Textbox(label="Repo URL", placeholder="Enter the repository URL here...")
        submit_btn = gr.Button("Submit Repo URL")

        user_input = gr.Textbox(label="User Input", placeholder="Enter your question here...")
        chat_output = gr.Textbox(label="Chat Output", placeholder="The answer will appear here...")
        user_input_submit_btn = gr.Button("Submit Question")

        def update_repo_url(new_url):
            """Index *new_url* via ``main`` and return main's result for the textbox."""
            return main(new_url)

        def generate_answer(question):
            """Answer *question* with the current QA chain, or ask for a repo first."""
            # qa_chain is a module global assigned by main(). Before the first
            # repository is submitted it does not exist, and referencing it
            # directly would raise NameError inside the Gradio callback.
            chain = globals().get("qa_chain")
            if chain is None:
                return "Please submit a repository URL first."
            answer = chain.invoke(question)
            print(f"Answer: {answer}")
            # NOTE(review): assumes the chain returns a mapping with an
            # "output" key — confirm against lib.chain.create_qa_chain.
            return answer['output']

        submit_btn.click(update_repo_url, inputs=repo_url, outputs=repo_url)
        user_input_submit_btn.click(generate_answer, inputs=user_input, outputs=chat_output)

    demo.launch()
|
|
def _repo_name_from_url(repo_url):
    """Return the local directory name for *repo_url*: its last path segment minus ".git".

    Raises:
        ValueError: if no usable name can be derived, i.e. the segment is
            empty (empty URL, URL ending in "/.git") or is "." / "..", which
            would make the checkout directory escape the data directory.
    """
    name = repo_url.rstrip("/").rsplit("/", 1)[-1]
    # Strip only a trailing ".git"; str.replace would also mangle names that
    # merely contain ".git" (e.g. "my.github-tool").
    if name.endswith(".git"):
        name = name[: -len(".git")]
    if name in ("", ".", ".."):
        raise ValueError(f"Cannot derive a repository name from URL: {repo_url!r}")
    return name


def main(repo_url):
    """Download a repository, split its files into chunks and build the QA chain.

    Args:
        repo_url: URL of the repository to download. Surrounding whitespace is
            ignored, and a trailing "/" or ".git" is tolerated when deriving
            the local directory name.

    Returns:
        The whitespace-stripped repository URL, so the Gradio callback can
        write it back into the "Repo URL" textbox.

    Raises:
        ValueError: if no repository name can be derived from *repo_url*.
        KeyError: if the model name returned by ``select_model()`` has no
            entry in ``MODELS_MAP``.

    Side effects:
        Downloads into ``data/<repo name>`` next to this file and rebinds the
        module-level ``qa_chain``.
    """
    global qa_chain

    repo_url = (repo_url or "").strip()
    # Validate the URL before model selection and the (slow) download.
    repo_name = _repo_name_from_url(repo_url)

    model_name = select_model()
    # Fail fast on a model name that has no MODELS_MAP entry.
    if model_name not in MODELS_MAP:
        raise KeyError(f"Unknown model: {model_name!r}")

    base_dir = os.path.dirname(os.path.abspath(__file__))
    repo_dir = os.path.join(base_dir, "data", repo_name)
    # NOTE(review): the same db_dir is reused for every repository; confirm
    # that create_retriever does not mix in chunks from previously indexed repos.
    db_dir = os.path.join(base_dir, "data", "db")
    prompt_templates_dir = os.path.join(base_dir, "prompt_templates")

    print(f"Downloading repository from {repo_url}...")
    download_github_repo(repo_url, repo_dir)

    prompts_text = {
        "initial_prompt": read_prompt(os.path.join(prompt_templates_dir, 'initial_prompt.txt')),
        "evaluation_prompt": read_prompt(os.path.join(prompt_templates_dir, 'evaluation_prompt.txt')),
    }

    print(f"Loading documents from {repo_dir}...")
    document_chunks = load_files(repository_path=repo_dir)
    print(f"Created chunks length is: {len(document_chunks)}")

    print(f"Creating retrieval QA chain using {model_name}...")
    llm = load_LLM(model_name)
    retriever = create_retriever(model_name, db_dir, document_chunks)
    qa_chain = create_qa_chain(llm, retriever, prompts_text)
    print("Ready to chat!")
    return repo_url
|
|
if __name__ == "__main__":
    # Populate os.environ from a local .env file, if one exists, before the UI
    # starts. Presumably load_LLM / select_model read API keys or settings
    # from the environment — confirm in lib.utils. Existing variables are
    # not overridden by default.
    load_dotenv()
    build()