| import time |
| import os |
| import gradio as gr |
| from typing import List, Optional |
|
|
| import langchain_core.callbacks |
| import markdown_it.cli.parse |
| from langchain_huggingface import HuggingFaceEndpoint |
|
|
| from langchain.schema import BaseMessage |
| from langchain_core.chat_history import BaseChatMessageHistory |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_core.runnables import ( |
| ConfigurableFieldSpec, |
| ) |
| from langchain_core.runnables.history import RunnableWithMessageHistory |
|
|
| from pydantic import BaseModel, Field |
|
|
|
|
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """Chat message history kept in a plain in-process list (not persisted)."""

    # Conversation transcript, oldest message first.
    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Append each of *messages* to the stored transcript."""
        for message in messages:
            self.messages.append(message)

    def clear(self) -> None:
        """Discard the entire transcript."""
        self.messages = []
|
|
| |
# Session registry: maps (user_id, conversation_id) -> InMemoryHistory.
store = {}
# Chat chain with history; populated by init_llm() after the user clicks
# "Confirm" in the sidebar — None until then.
bot_llm:Optional[RunnableWithMessageHistory] = None
|
|
def get_session_history(
    user_id: str, conversation_id: str
) -> BaseChatMessageHistory:
    """Return the history for this (user, conversation), creating it on first use."""
    key = (user_id, conversation_id)
    if key not in store:
        store[key] = InMemoryHistory()
    return store[key]
|
|
|
|
def init_llm(k, p, t, model_id="mistralai/Mistral-7B-Instruct-v0.3"):
    """Build the streaming chat chain and unlock the chat UI.

    Args:
        k: top_k sampling cutoff. Gradio sliders deliver floats, so the
           value is coerced to int before being handed to the endpoint.
        p: top_p (nucleus sampling) threshold.
        t: sampling temperature.
        model_id: HuggingFace Hub repo id of the instruct model to query.

    Returns:
        Four gr.update objects: enable the textbox, Stop and Clear buttons,
        and collapse the configuration sidebar.

    Side effects:
        Sets the module-level ``bot_llm`` used by ``bot()``.
    """
    global bot_llm
    # Mistral-instruct style prompt; "history" is injected by
    # RunnableWithMessageHistory, "question" comes from the caller.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "[INST] You're an assistant who's good at everything"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question} [/INST]"),
    ])

    # Mirror generated tokens to stdout for server-side visibility.
    callbacks = [langchain_core.callbacks.StreamingStdOutCallbackHandler()]

    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=4096,
        temperature=t,
        top_p=p,
        # top_k is an integer parameter; the slider supplies a float.
        top_k=int(k),
        repetition_penalty=1.03,
        callbacks=callbacks,
        streaming=True,
        huggingfacehub_api_token=os.getenv('HF_TOKEN'),
    )

    chain = prompt | llm
    bot_llm = RunnableWithMessageHistory(
        chain,
        get_session_history=get_session_history,
        input_messages_key="question",
        history_messages_key="history",
        # Both ids are forwarded to get_session_history to pick the session.
        history_factory_config=[
            ConfigurableFieldSpec(
                id="user_id",
                annotation=str,
                name="User ID",
                description="Unique identifier for the user.",
                default="",
                is_shared=True,
            ),
            ConfigurableFieldSpec(
                id="conversation_id",
                annotation=str,
                name="Conversation ID",
                description="Unique identifier for the conversation.",
                default="",
                is_shared=True,
            ),
        ],
    )
    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(open=False)
|
|
with gr.Blocks() as demo:
    gr.HTML("<center><h1>Chat with a Smart Assistant</h1></center>")
    # "messages" format: history is a list of {"role", "content"} dicts.
    chatbot = gr.Chatbot(type="messages")
    # Chat controls start disabled; init_llm() enables them once the model
    # configuration is confirmed in the sidebar.
    msg = gr.Textbox(placeholder="Enter text and press enter", interactive=False)
    stop = gr.Button("Stop", interactive=False)
    clear = gr.Button("Clear",interactive=False)
|
| def user(user_message, history: list): |
| return "", history + [{"role": "user", "content": user_message}] |
|
|
| def bot(history: list): |
| question = history[-1]['content'] |
| answer = bot_llm.stream( |
| {"ability": "everything", "question": question}, |
| config={"configurable": {"user_id": "123", "conversation_id": "1"}} |
| ) |
| history.append({"role": "assistant", "content": ""}) |
| for character in answer: |
| history[-1]['content'] += character |
| time.sleep(0.05) |
| yield history |
|
|
    with gr.Sidebar() as s:
        gr.HTML("<h1>Model Configuration<h1>")
        # Sampling hyper-parameters, forwarded to init_llm() on "Confirm".
        k = gr.Slider(0.0, 100.0, label="top_k", value=50, interactive=True,
                      info="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)")
        p = gr.Slider(0.0, 1.0, label="top_p", value=0.9, interactive=True,
                      info=" Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)")
        t = gr.Slider(0.0, 1.0, label="temperature", value=0.4, interactive=True,
                      info="The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)")

    # Build the chain, enable msg/stop/clear, and collapse the sidebar.
    bnt1 = gr.Button("Confirm")
    bnt1.click(init_llm, inputs=[k, p, t], outputs=[msg, stop, clear, s])

    # Enter in the textbox: record the user turn, then stream the reply.
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
        bot, chatbot, chatbot
    )

    # Stop cancels the in-flight submit/stream; Clear empties the chat window.
    stop.click(None, None, None, cancels=[submit_event], queue=False)
    clear.click(lambda: None, None, chatbot, queue=True)
|
|
# Launch the Gradio app when run as a script.
if __name__ == "__main__":
    demo.launch()