"""
High-Performance Chat Interface for LM Studio

This script creates a robust and efficient chat interface using Gradio,
facilitating seamless interactions with the LM Studio API. It leverages
GPU capabilities for accelerated processing and adheres to best practices
in modern Python programming. Comprehensive logging and error handling
ensure reliability and ease of maintenance.

Author: Your Name
Date: YYYY-MM-DD
"""
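
# Requires a running LM Studio server exposing an OpenAI-compatible API
# (default: http://localhost:1234/v1) with a chat model and an embedding
# model loaded.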

import json
import logging
import os

import gradio as gr
import httpx
import numpy as np
import torch

# Logging configuration
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# LM Studio API base URL; override with the LMSTUDIO_API_BASE_URL environment variable.
BASE_URL = os.getenv("LMSTUDIO_API_BASE_URL", "http://localhost:1234/v1")

# Use the GPU when one is available.
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_GPU else "cpu")
logger.info(f"Using device: {DEVICE}")

# Token budget for the model's context window.
MODEL_MAX_TOKENS = 32768        # total context size assumed for the model
AVERAGE_CHARS_PER_TOKEN = 4     # rough heuristic for estimating token counts
BUFFER_TOKENS = 2000            # reserved for system prompts and overhead
MIN_OUTPUT_TOKENS = 1000        # always leave room for a meaningful reply

# Maximum number of embeddings (and matching history entries) kept in memory.
MAX_EMBEDDINGS = 100

# Timeout in seconds for HTTP requests to LM Studio.
HTTPX_TIMEOUT = 300


def calculate_max_tokens(message, model_max_tokens=MODEL_MAX_TOKENS,
                         buffer=BUFFER_TOKENS,
                         avg_chars_per_token=AVERAGE_CHARS_PER_TOKEN,
                         min_tokens=MIN_OUTPUT_TOKENS):
    """
    Calculate the maximum number of tokens for the output based on the input message length.

    Args:
        message (str): The input message from the user.
        model_max_tokens (int): The total token capacity of the model.
        buffer (int): Reserved tokens for system prompts and overhead.
        avg_chars_per_token (int): Approximate number of characters per token.
        min_tokens (int): Minimum number of tokens to ensure a meaningful response.

    Returns:
        int: The calculated maximum tokens for the output.
    """
    input_length = len(message)
    input_tokens = input_length / avg_chars_per_token
    max_tokens = model_max_tokens - int(input_tokens) - buffer
    calculated_max = max(max_tokens, min_tokens)
    logger.debug(f"Input length (chars): {input_length}, "
                 f"Estimated input tokens: {input_tokens}, "
                 f"Max tokens for output: {calculated_max}")
    return calculated_max
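
# Example: an 8,000-character message is estimated at 8000 / 4 = 2,000 input tokens,
# so the response budget is 32768 - 2000 - 2000 = 28768 tokens; even a very long
# input is still granted at least MIN_OUTPUT_TOKENS (1,000) tokens.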


async def get_embeddings(text):
    """
    Retrieve embeddings for the given text from the LM Studio API.

    Args:
        text (str): The input text to generate embeddings for.

    Returns:
        list or None: The embedding vector as a list if successful, else None.
    """
    url = f"{BASE_URL}/embeddings"
    # The model name must match an embedding model loaded in LM Studio.
    payload = {"model": "nomad_embed_text_v1_5_Q8_0", "input": text}
    logger.info(f"Requesting embeddings for input: {text[:100]}...")
    async with httpx.AsyncClient(timeout=HTTPX_TIMEOUT) as client:
        try:
            response = await client.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            logger.info(f"Embeddings response status code: {response.status_code}")
            response.raise_for_status()
            data = response.json()
            logger.debug(f"Embeddings response data: {data}")
            if "data" in data and len(data["data"]) > 0:
                embedding = np.array(data["data"][0]["embedding"])
                if USE_GPU:
                    embedding = torch.tensor(embedding, device=DEVICE).tolist()
                return embedding
            else:
                logger.error("Invalid response structure for embeddings.")
                return None
        except httpx.RequestError as e:
            logger.error(f"Failed to retrieve embeddings: {e}")
            return None
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error while retrieving embeddings: {e}")
            return None
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error: {e}")
            return None


def calculate_similarity(vec1, vec2):
    """
    Calculate the cosine similarity between two vectors using GPU acceleration.

    Args:
        vec1 (list or torch.Tensor): The first embedding vector.
        vec2 (list or torch.Tensor): The second embedding vector.

    Returns:
        float: The cosine similarity score.
    """
    if vec1 is None or vec2 is None:
        logger.warning("One or both vectors for similarity calculation are None.")
        return 0.0
    logger.debug("Calculating similarity between vectors.")
    vec1_tensor = torch.tensor(vec1, device=DEVICE) if not isinstance(vec1, torch.Tensor) else vec1.to(DEVICE)
    vec2_tensor = torch.tensor(vec2, device=DEVICE) if not isinstance(vec2, torch.Tensor) else vec2.to(DEVICE)
    similarity = torch.nn.functional.cosine_similarity(vec1_tensor.unsqueeze(0), vec2_tensor.unsqueeze(0)).item()
    logger.debug(f"Calculated similarity: {similarity}")
    return similarity
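
# Example: identical vectors score 1.0, while orthogonal vectors such as
# [1.0, 0.0] and [0.0, 1.0] score 0.0.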


async def chat_with_lmstudio(messages, max_tokens):
    """
    Handle chat completions with the LM Studio API using streaming.

    Args:
        messages (list): A list of message dictionaries following OpenAI's format.
        max_tokens (int): The maximum number of tokens to generate in the response.

    Yields:
        str: Chunks of the generated response.
    """
    url = f"{BASE_URL}/chat/completions"
    payload = {
        "model": "Qwen2.5-Coder-32B-Instruct",
        "messages": messages,
        "temperature": 0.7,
        "max_tokens": max_tokens,
        "stream": True,
    }
    logger.info(f"Sending request to chat/completions with max_tokens: {max_tokens}")
    async with httpx.AsyncClient(timeout=HTTPX_TIMEOUT) as client:
        try:
            async with client.stream(
                "POST",
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
            ) as response:
                logger.info(f"chat/completions response status code: {response.status_code}")
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    decoded_line = line.strip()
                    if not decoded_line.startswith("data: "):
                        continue
                    # LM Studio streams OpenAI-style server-sent events; the stream
                    # ends with a "data: [DONE]" line that is not JSON.
                    if decoded_line[6:].strip() == "[DONE]":
                        break
                    try:
                        data = json.loads(decoded_line[6:])
                        logger.debug(f"Received chunk: {data}")
                        content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                        if content:
                            yield content
                    except json.JSONDecodeError as e:
                        logger.error(f"JSON decode error: {e}")
        except httpx.RequestError as e:
            logger.error(f"LM Studio chat/completions request failed: {e}")
            yield "An error occurred while generating a response."
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error during chat/completions: {e}")
            yield "An HTTP error occurred while generating a response."
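
# A streamed line from the OpenAI-compatible endpoint typically looks like (abridged):
#   data: {"choices": [{"index": 0, "delta": {"content": "Hello"}}]}
# with a final "data: [DONE]" line terminating the stream.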


def gradio_chat_interface():
    """
    Create and launch the Gradio Blocks interface for the chat application.
    """
    with gr.Blocks() as interface:
        gr.Markdown("# 🚀 High-Performance Chat Interface for LM Studio")

        chatbot = gr.Chatbot(label="Conversation", type="messages")

        user_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            lines=2,
            interactive=True,
        )

        file_input = gr.File(
            label="Upload Context File (.txt)",
            type="binary",
            interactive=True,
        )

        context_display = gr.Textbox(
            label="Relevant Context",
            interactive=False,
        )

        # Per-session state: stored embeddings and the matching message history.
        embeddings_state = gr.State({"embeddings": [], "messages_history": []})

        async def chat_handler(message, file, state, chat_history):
            """
            Handle user input, process embeddings, retrieve context, and generate responses.

            Args:
                message (str): The user's input message.
                file (bytes or file-like): The uploaded context file, if any.
                state (dict): The current state containing embeddings and message history.
                chat_history (list): The conversation currently shown in the chatbot.

            Yields:
                list: Updated chatbot messages, new state, and context display text.
            """
            embeddings = state.get("embeddings", [])
            messages_history = state.get("messages_history", [])
            chat_history = chat_history or []

            # Append the uploaded file's text to the message. With type="binary",
            # Gradio passes the upload as raw bytes; a file-like object is handled
            # as well, just in case.
            if file:
                try:
                    raw = file if isinstance(file, bytes) else file.read()
                    file_content = raw.decode("utf-8")
                    message += f"\n[File Content]:\n{file_content}"
                    logger.info("Successfully processed uploaded file.")
                except Exception as e:
                    error_msg = f"Error reading file: {e}"
                    logger.error(error_msg)
                    yield [
                        chat_history + [{"role": "assistant", "content": error_msg}],
                        state,
                        "",
                    ]
                    return

            # Embed the (possibly file-augmented) message and record it in the
            # retrieval history.
            user_embedding = await get_embeddings(message)
            if user_embedding is not None:
                embeddings.append(user_embedding)
                messages_history.append({"role": "user", "content": message})
                logger.info("Embeddings generated and appended to state.")
            else:
                error_msg = "Failed to generate embeddings."
                logger.error(error_msg)
                yield [
                    chat_history + [{"role": "assistant", "content": error_msg}],
                    state,
                    "",
                ]
                return

            # Keep only the most recent MAX_EMBEDDINGS entries.
            if len(embeddings) > MAX_EMBEDDINGS:
                embeddings = embeddings[-MAX_EMBEDDINGS:]
                messages_history = messages_history[-MAX_EMBEDDINGS:]

            # Build the prompt: the current user message plus up to three of the
            # most similar earlier messages, inserted as system-role context.
            history = [{"role": "user", "content": message}]
            context_text = ""
            if len(embeddings) > 1:
                similarities = [
                    (calculate_similarity(user_embedding, emb), idx)
                    for idx, emb in enumerate(embeddings[:-1])
                ]
                similarities.sort(reverse=True, key=lambda x: x[0])
                top_context = similarities[:3]
                for similarity, idx in top_context:
                    context_message = messages_history[idx]
                    history.insert(0, {"role": "system", "content": context_message["content"]})
                    context_text += f"Context: {context_message['content'][:100]}...\n"
                logger.info("Relevant context retrieved based on similarity.")

            # Size the response to the remaining context-window budget.
            max_tokens = calculate_max_tokens(message)
            logger.info(f"Calculated max_tokens for output: {max_tokens}")

            # Stream the response, yielding partial updates so the chat display
            # refreshes as chunks arrive.
            response = ""
            try:
                async for chunk in chat_with_lmstudio(history, max_tokens):
                    response += chunk
                    updated_chat = chat_history + [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": response},
                    ]
                    yield [
                        updated_chat,
                        {"embeddings": embeddings, "messages_history": messages_history},
                        context_text,
                    ]

                if not response.strip():
                    response = "Sorry, I couldn't process your request."
                logger.info("Response generation completed.")
            except Exception as e:
                error_msg = f"An error occurred while generating a response: {e}"
                logger.error(error_msg)
                yield [
                    chat_history + [{"role": "assistant", "content": error_msg}],
                    state,
                    "",
                ]
                return

            # Record the assistant's reply. Embedding it as well keeps `embeddings`
            # aligned index-for-index with `messages_history`, so later similarity
            # lookups retrieve the intended messages.
            assistant_embedding = await get_embeddings(response)
            if assistant_embedding is not None:
                embeddings.append(assistant_embedding)
                messages_history.append({"role": "assistant", "content": response})
            else:
                logger.warning("Could not embed assistant reply; omitting it from the retrieval history.")
            new_state = {"embeddings": embeddings, "messages_history": messages_history}
            updated_chat = chat_history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response},
            ]

            try:
                logger.debug(f"Final Updated Chat: {updated_chat}")
                yield [updated_chat, new_state, context_text]
            except Exception as e:
                error_msg = f"Error updating chatbot: {e}"
                logger.error(error_msg)
                yield [
                    chat_history + [{"role": "assistant", "content": "An error occurred while updating the chat."}],
                    state,
                    "",
                ]

        send_button = gr.Button("Send")
        # The chatbot's current value is passed as an input so the handler can
        # append to the existing conversation rather than reading the component's
        # initial value.
        send_button.click(
            chat_handler,
            inputs=[user_input, file_input, embeddings_state, chatbot],
            outputs=[chatbot, embeddings_state, context_display],
            show_progress=True,
        )

    # share=True creates a public Gradio share link; server_name="0.0.0.0" also
    # exposes the app on the local network.
    interface.launch(share=True, server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # gradio_chat_interface() is a regular function (Blocks.launch blocks on its
    # own), so it is called directly rather than through asyncio.run().
    gradio_chat_interface()