| import gradio as gr |
| from transformers import pipeline |
| from huggingface_hub import InferenceClient, login, snapshot_download |
| from langchain_community.vectorstores import FAISS, DistanceStrategy |
| from langchain_huggingface import HuggingFaceEmbeddings |
| import os |
| import pandas as pd |
| from datetime import datetime |
|
|
| from smolagents import Tool, HfApiModel, ToolCallingAgent |
| from langchain_core.vectorstores import VectorStore |
|
|
|
|
| class RetrieverTool(Tool): |
| name = "retriever" |
| description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query." |
| inputs = { |
| "query": { |
| "type": "string", |
| "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.", |
| } |
| } |
| output_type = "string" |
|
|
| def __init__(self, vectordb: VectorStore, **kwargs): |
| super().__init__(**kwargs) |
| self.vectordb = vectordb |
|
|
| def forward(self, query: str) -> str: |
| assert isinstance(query, str), "Your search query must be a string" |
|
|
| docs = self.vectordb.similarity_search( |
| query, |
| k=7, |
| ) |
|
|
| spacer = " \n" |
| context = "" |
| nb_char = 100 |
| |
| for doc in docs: |
| case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0] |
| index = case_text.find(doc.page_content) |
| start = max(0, index - nb_char) |
| end = min(len(case_text), index + len(doc.page_content) + nb_char) |
| case_text_summary = case_text[start:end] |
| |
| context += "#######" + spacer |
| context += "# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer |
| context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer |
| context += "# Case date: " + doc.metadata["case_date"] + spacer |
| context += "# Case url: " + doc.metadata["case_url"] + spacer |
| |
| context += "# Case extract: " + case_text_summary + spacer |
|
|
|
|
| return "\nRetrieved documents:\n" + context |
|
|
|
|
| """ |
| For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference |
| """ |
| HF_TOKEN=os.getenv('TOKEN') |
| login(HF_TOKEN) |
|
|
| model = "meta-llama/Meta-Llama-3-8B-Instruct" |
| |
|
|
| client = InferenceClient(model) |
|
|
| folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd()) |
|
|
| embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") |
|
|
| vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE) |
|
|
| df = pd.read_csv("bger_cedh_db 1954-2024.csv") |
|
|
| retriever_tool = RetrieverTool(vector_db) |
| agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model)) |
|
|
| def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score,): |
|
|
| print(datetime.now()) |
| context = retriever_tool(message) |
| |
| print(message) |
|
|
| |
| |
| |
| |
|
|
| |
|
|
| if True: |
| prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question. |
| Respond only to the question asked, response should be relevant to the question and in the same language as the question. |
| Provide the number of the source document when relevant, as well as the link to the document. |
| If you cannot find information, do not give up and try calling your retriever again with different arguments! |
| Always give url of the sources at the end and only answer in the language the question is asked. |
| |
| Question: |
| {message} |
| |
| {context} |
| """ |
| else: |
| prompt = f"""A user wrote the following message, please answer him to best of your knowledge in the language of his message: |
| {message}""" |
| |
| messages = [{"role": "system", "content": system_message}] |
|
|
| for val in history: |
| if val[0]: |
| messages.append({"role": "user", "content": val[0]}) |
| if val[1]: |
| messages.append({"role": "assistant", "content": val[1]}) |
|
|
| messages.append({"role": "user", "content": prompt}) |
|
|
| response = "" |
|
|
|
|
| for message in client.chat_completion( |
| messages, |
| max_tokens=max_tokens, |
| stream=True, |
| temperature=temperature, |
| top_p=top_p, |
| ): |
| token = message.choices[0].delta.content |
| |
| response += token |
| yield response |
|
|
|
|
| """ |
| For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
| """ |
| demo = gr.ChatInterface( |
| respond, |
| additional_inputs=[ |
| gr.Textbox(value="You are assisting a jurist or a layer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"), |
| gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"), |
| gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"), |
| gr.Slider( |
| minimum=0.1, |
| maximum=1.0, |
| value=0.95, |
| step=0.05, |
| label="Top-p (nucleus sampling)", |
| ), |
| gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"), |
| ], |
| description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence", |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| print("Ready!") |
| demo.launch(debug=True) |