import re
import json
import uuid
import httpx
import secrets
import string
import sys
import traceback
import time
import random
import threading
import logging
from typing import List, Dict, Any, Optional, Literal, Union
from collections import OrderedDict

from fastapi import FastAPI, Request, Header, HTTPException, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError

from config_loader import config_loader

logger = logging.getLogger(__name__)


def generate_random_trigger_signal() -> str:
    """Generate a random, self-closing trigger signal like <Function_AB1c_Start/>."""
    chars = string.ascii_letters + string.digits
    random_str = ''.join(secrets.choice(chars) for _ in range(4))
    return f"<Function_{random_str}_Start/>"

try:
    app_config = config_loader.load_config()

    log_level_str = app_config.features.log_level
    if log_level_str == "DISABLED":
        log_level = logging.CRITICAL + 1
    else:
        log_level = getattr(logging, log_level_str, logging.INFO)

    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    logger.info(f"✅ Configuration loaded successfully: {config_loader.config_path}")
    logger.info(f"🔗 Configured {len(app_config.upstream_services)} upstream services")
    logger.info(f"🔑 Configured {len(app_config.client_authentication.allowed_keys)} client keys")

    MODEL_TO_SERVICE_MAPPING, ALIAS_MAPPING = config_loader.get_model_to_service_mapping()
    DEFAULT_SERVICE = config_loader.get_default_service()
    ALLOWED_CLIENT_KEYS = config_loader.get_allowed_client_keys()
    GLOBAL_TRIGGER_SIGNAL = generate_random_trigger_signal()

    logger.info(f"🎯 Configured {len(MODEL_TO_SERVICE_MAPPING)} model mappings")
    if ALIAS_MAPPING:
        logger.info(f"🔀 Configured {len(ALIAS_MAPPING)} model aliases: {list(ALIAS_MAPPING.keys())}")
    logger.info(f"📌 Default service: {DEFAULT_SERVICE['name']}")

except Exception as e:
    logger.error(f"❌ Configuration loading failed: {type(e).__name__}")
    logger.error(f"❌ Error details: {str(e)}")
    logger.error("💡 Please ensure the config.yaml file exists and is properly formatted")
    sys.exit(1)


class ToolCallMappingManager:
    """
    Tool call mapping manager with TTL (time to live) and a size limit.

    Features:
    1. Automatic expiration - entries are deleted after the configured TTL.
    2. Size limit - prevents unbounded memory growth.
    3. LRU eviction - removes the least recently used entries at the size limit.
    4. Thread safety - supports concurrent access.
    5. Periodic cleanup - a background thread regularly removes expired entries.
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 3600, cleanup_interval: int = 300):
        """
        Initialize the mapping manager.

        Args:
            max_size: Maximum number of stored entries
            ttl_seconds: Entry time to live (seconds)
            cleanup_interval: Cleanup interval (seconds)
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cleanup_interval = cleanup_interval

        self._data: OrderedDict[str, Dict[str, Any]] = OrderedDict()
        self._timestamps: Dict[str, float] = {}
        self._lock = threading.RLock()

        self._cleanup_thread = threading.Thread(target=self._periodic_cleanup, daemon=True)
        self._cleanup_thread.start()

        logger.debug(f"🔧 [INIT] Tool call mapping manager started - Max entries: {max_size}, TTL: {ttl_seconds}s, Cleanup interval: {cleanup_interval}s")

    def store(self, tool_call_id: str, name: str, args: dict, description: str = "") -> None:
        """Store a tool call mapping."""
        with self._lock:
            current_time = time.time()

            if tool_call_id in self._data:
                del self._data[tool_call_id]
                del self._timestamps[tool_call_id]

            # Evict the least recently used entries until there is room.
            while len(self._data) >= self.max_size:
                oldest_key = next(iter(self._data))
                del self._data[oldest_key]
                del self._timestamps[oldest_key]
                logger.debug(f"🔧 [CLEANUP] Removed oldest entry due to size limit: {oldest_key}")

            self._data[tool_call_id] = {
                "name": name,
                "args": args,
                "description": description,
                "created_at": current_time
            }
            self._timestamps[tool_call_id] = current_time

            logger.debug(f"🔧 Stored tool call mapping: {tool_call_id} -> {name}")
            logger.debug(f"🔧 Current mapping table size: {len(self._data)}")

    def get(self, tool_call_id: str) -> Optional[Dict[str, Any]]:
        """Get a tool call mapping (updates LRU order)."""
        with self._lock:
            current_time = time.time()

            if tool_call_id not in self._data:
                logger.debug(f"🔧 Tool call mapping not found: {tool_call_id}")
                logger.debug(f"🔧 All IDs in current mapping table: {list(self._data.keys())}")
                return None

            if current_time - self._timestamps[tool_call_id] > self.ttl_seconds:
                logger.debug(f"🔧 Tool call mapping expired: {tool_call_id}")
                del self._data[tool_call_id]
                del self._timestamps[tool_call_id]
                return None

            result = self._data[tool_call_id]
            self._data.move_to_end(tool_call_id)

            logger.debug(f"🔧 Found tool call mapping: {tool_call_id} -> {result['name']}")
            return result

    def cleanup_expired(self) -> int:
        """Clean up expired entries and return the number removed."""
        with self._lock:
            current_time = time.time()
            expired_keys = []

            for key, timestamp in self._timestamps.items():
                if current_time - timestamp > self.ttl_seconds:
                    expired_keys.append(key)

            for key in expired_keys:
                del self._data[key]
                del self._timestamps[key]

            if expired_keys:
                logger.debug(f"🔧 [CLEANUP] Cleaned up {len(expired_keys)} expired entries")

            return len(expired_keys)

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics."""
        with self._lock:
            current_time = time.time()
            expired_count = sum(1 for ts in self._timestamps.values()
                                if current_time - ts > self.ttl_seconds)

            return {
                "total_entries": len(self._data),
                "expired_entries": expired_count,
                "active_entries": len(self._data) - expired_count,
                "max_size": self.max_size,
                "ttl_seconds": self.ttl_seconds,
                "memory_usage_ratio": len(self._data) / self.max_size
            }

    def _periodic_cleanup(self) -> None:
        """Background periodic cleanup thread."""
        while True:
            try:
                time.sleep(self.cleanup_interval)
                self.cleanup_expired()

                stats = self.get_stats()
                if stats["total_entries"] > 0:
                    logger.debug(f"🔧 [STATS] Mapping table status: Total={stats['total_entries']}, "
                                 f"Active={stats['active_entries']}, Memory usage={stats['memory_usage_ratio']:.1%}")

            except Exception as e:
                logger.error(f"❌ Background cleanup thread exception: {e}")


TOOL_CALL_MAPPING_MANAGER = ToolCallMappingManager(
    max_size=1000,
    ttl_seconds=3600,
    cleanup_interval=300
)


def store_tool_call_mapping(tool_call_id: str, name: str, args: dict, description: str = ""):
    """Store the mapping between a tool call ID and its call content."""
    TOOL_CALL_MAPPING_MANAGER.store(tool_call_id, name, args, description)


def get_tool_call_mapping(tool_call_id: str) -> Optional[Dict[str, Any]]:
    """Get the call content corresponding to a tool call ID."""
    return TOOL_CALL_MAPPING_MANAGER.get(tool_call_id)
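
# Minimal usage sketch (IDs and arguments illustrative):
#   store_tool_call_mapping("call_abc123", "search", {"keywords": ["python"]})
#   get_tool_call_mapping("call_abc123")
#   # -> {"name": "search", "args": {"keywords": ["python"]}, "description": "", "created_at": ...}
# Entries expire after ttl_seconds, and the least recently used entries are
# evicted once max_size is reached.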


def format_tool_result_for_ai(tool_call_id: str, result_content: str) -> str:
    """Format tool call results for AI understanding with English prompts and XML structure."""
    logger.debug(f"🔧 Formatting tool call result: tool_call_id={tool_call_id}")
    tool_info = get_tool_call_mapping(tool_call_id)
    if not tool_info:
        logger.debug("🔧 Tool call mapping not found, using default format")
        return f"Tool execution result:\n<tool_result>\n{result_content}\n</tool_result>"

    formatted_text = f"""Tool execution result:
- Tool name: {tool_info['name']}
- Execution result:
<tool_result>
{result_content}
</tool_result>"""

    logger.debug(f"🔧 Formatting completed, tool name: {tool_info['name']}")
    return formatted_text
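
# Illustrative output for a mapped call (tool name hypothetical):
#   Tool execution result:
#   - Tool name: search
#   - Execution result:
#   <tool_result>
#   ...raw tool output...
#   </tool_result>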


def format_assistant_tool_calls_for_ai(tool_calls: List[Dict[str, Any]], trigger_signal: str) -> str:
    """Format assistant tool calls into the AI-readable string format."""
    logger.debug(f"🔧 Formatting assistant tool calls. Count: {len(tool_calls)}")

    xml_calls_parts = []
    for tool_call in tool_calls:
        function_info = tool_call.get("function", {})
        name = function_info.get("name", "")
        arguments_json = function_info.get("arguments", "{}")

        try:
            args_dict = json.loads(arguments_json)
        except (json.JSONDecodeError, TypeError):
            # Fall back to passing the raw string through if it is not valid JSON.
            args_dict = {"raw_arguments": arguments_json}

        args_parts = []
        for key, value in args_dict.items():
            json_value = json.dumps(value, ensure_ascii=False)
            args_parts.append(f"<{key}>{json_value}</{key}>")

        args_content = "\n".join(args_parts)

        xml_call = f"<function_call>\n<tool>{name}</tool>\n<args>\n{args_content}\n</args>\n</function_call>"
        xml_calls_parts.append(xml_call)

    all_calls = "\n".join(xml_calls_parts)
    final_str = f"{trigger_signal}\n<function_calls>\n{all_calls}\n</function_calls>"

    logger.debug("🔧 Assistant tool calls formatted successfully.")
    return final_str
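
# Illustrative result for one call named "search" with arguments {"q": "python"}
# (trigger signal value hypothetical):
#   <Function_aB3x_Start/>
#   <function_calls>
#   <function_call>
#   <tool>search</tool>
#   <args>
#   <q>"python"</q>
#   </args>
#   </function_call>
#   </function_calls>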


def get_function_call_prompt_template(trigger_signal: str) -> str:
    """Generate the prompt template based on the dynamic trigger signal."""
    custom_template = app_config.features.prompt_template
    if custom_template:
        logger.info("🔧 Using custom prompt template from configuration")
        return custom_template.format(
            trigger_signal=trigger_signal,
            tools_list="{tools_list}"
        )

    return f"""
You have access to the following available tools to help solve problems:

{{tools_list}}

**IMPORTANT CONTEXT NOTES:**
1. You can call MULTIPLE tools in a single response if needed.
2. The conversation context may already contain tool execution results from previous function calls. Review the conversation history carefully to avoid unnecessary duplicate tool calls.
3. When tool execution results are present in the context, they will be formatted with XML tags like <tool_result>...</tool_result> for easy identification.
4. This is the ONLY format you can use for tool calls, and any deviation will result in failure.

When you need to use tools, you **MUST** strictly follow this format. Do NOT include any extra text, explanations, or dialogue on the first and second lines of the tool call syntax:

1. When starting tool calls, begin on a new line with exactly:
{trigger_signal}
No leading or trailing spaces, output exactly as shown above. The trigger signal MUST be on its own line and appear only once.

2. Starting from the second line, **immediately** follow with the complete <function_calls> XML block.

3. For multiple tool calls, include multiple <function_call> blocks within the same <function_calls> wrapper.

4. Do not add any text or explanation after the closing </function_calls> tag.

STRICT ARGUMENT KEY RULES:
- You MUST use parameter keys EXACTLY as defined (case- and punctuation-sensitive). Do NOT rename, add, or remove characters.
- If a key starts with a hyphen (e.g., -i, -C), you MUST keep the hyphen in the tag name. Example: <-i>true</-i>, <-C>2</-C>.
- Never convert "-i" to "i" or "-C" to "C". Do not pluralize, translate, or alias parameter keys.
- The <tool> tag must contain the exact name of a tool from the list. Any other tool name is invalid.
- The <args> must contain all required arguments for that tool.

CORRECT Example (multiple tool calls, including hyphenated keys):
...response content (optional)...
{trigger_signal}
<function_calls>
<function_call>
<tool>Grep</tool>
<args>
<-i>true</-i>
<-C>2</-C>
<path>.</path>
</args>
</function_call>
<function_call>
<tool>search</tool>
<args>
<keywords>["Python Document", "how to use python"]</keywords>
</args>
</function_call>
</function_calls>

INCORRECT Example (extra text + wrong key names - DO NOT DO THIS):
...response content (optional)...
{trigger_signal}
I will call the tools for you.
<function_calls>
<function_call>
<tool>Grep</tool>
<args>
<i>true</i>
<C>2</C>
<path>.</path>
</args>
</function_call>
</function_calls>

Now please be ready to strictly follow the above specifications.
"""


class ToolFunction(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Dict[str, Any]


class Tool(BaseModel):
    type: Literal["function"]
    function: ToolFunction


class Message(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None

    class Config:
        extra = "allow"


class ToolChoice(BaseModel):
    type: Literal["function"]
    function: Dict[str, str]


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, Any]]
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, ToolChoice]] = None
    stream: Optional[bool] = False
    stream_options: Optional[Dict[str, Any]] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    n: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None

    class Config:
        extra = "allow"


def generate_function_prompt(tools: List[Tool], trigger_signal: str) -> tuple[str, str]:
    """
    Generate the injected system prompt based on the tools definition in the client request.
    Returns: (prompt_content, trigger_signal)
    """
    tools_list_str = []
    for i, tool in enumerate(tools):
        func = tool.function
        name = func.name
        description = func.description or ""

        schema: Dict[str, Any] = func.parameters or {}
        props: Dict[str, Any] = schema.get("properties", {}) or {}
        required_list: List[str] = schema.get("required", []) or []

        params_summary = ", ".join([
            f"{p_name} ({(p_info or {}).get('type', 'any')})" for p_name, p_info in props.items()
        ]) or "None"

        detail_lines: List[str] = []
        for p_name, p_info in props.items():
            p_info = p_info or {}
            p_type = p_info.get("type", "any")
            is_required = "Yes" if p_name in required_list else "No"
            p_desc = p_info.get("description")
            enum_vals = p_info.get("enum")
            default_val = p_info.get("default")
            examples_val = p_info.get("examples") or p_info.get("example")

            # Collect common JSON Schema constraints if present.
            constraints: Dict[str, Any] = {}
            for key in [
                "minimum", "maximum", "exclusiveMinimum", "exclusiveMaximum",
                "minLength", "maxLength", "pattern", "format",
                "minItems", "maxItems", "uniqueItems"
            ]:
                if key in p_info:
                    constraints[key] = p_info.get(key)

            if p_type == "array":
                items = p_info.get("items") or {}
                if isinstance(items, dict):
                    itype = items.get("type")
                    if itype:
                        constraints["items.type"] = itype

            detail_lines.append(f"- {p_name}:")
            detail_lines.append(f" - type: {p_type}")
            detail_lines.append(f" - required: {is_required}")
            if p_desc:
                detail_lines.append(f" - description: {p_desc}")
            if enum_vals is not None:
                try:
                    detail_lines.append(f" - enum: {json.dumps(enum_vals, ensure_ascii=False)}")
                except Exception:
                    detail_lines.append(f" - enum: {enum_vals}")
            if default_val is not None:
                try:
                    detail_lines.append(f" - default: {json.dumps(default_val, ensure_ascii=False)}")
                except Exception:
                    detail_lines.append(f" - default: {default_val}")
            if examples_val is not None:
                try:
                    detail_lines.append(f" - examples: {json.dumps(examples_val, ensure_ascii=False)}")
                except Exception:
                    detail_lines.append(f" - examples: {examples_val}")
            if constraints:
                try:
                    detail_lines.append(f" - constraints: {json.dumps(constraints, ensure_ascii=False)}")
                except Exception:
                    detail_lines.append(f" - constraints: {constraints}")

        detail_block = "\n".join(detail_lines) if detail_lines else "(no parameter details)"

        desc_block = f"```\n{description}\n```" if description else "None"

        tools_list_str.append(
            f"{i + 1}. <tool name=\"{name}\">\n"
            f" Description:\n{desc_block}\n"
            f" Parameters summary: {params_summary}\n"
            f" Required parameters: {', '.join(required_list) if required_list else 'None'}\n"
            f" Parameter details:\n{detail_block}"
        )

    prompt_template = get_function_call_prompt_template(trigger_signal)
    prompt_content = prompt_template.replace("{tools_list}", "\n\n".join(tools_list_str))

    return prompt_content, trigger_signal
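
# Illustrative: for a tool named "search" with one required string parameter
# "q" (names hypothetical), the generated list entry looks roughly like
#   1. <tool name="search">
#    Description: ...
#    Parameters summary: q (string)
#    Required parameters: q
#    Parameter details:
#   - q: (type/required/description lines)
# and the joined entries are substituted for {tools_list} in the template.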


def remove_think_blocks(text: str) -> str:
    """
    Temporarily remove all <think>...</think> blocks for XML parsing.
    Supports nested think tags.
    Note: this is only used for transient parsing and does not affect the
    original content returned to the user.
    """
    while '<think>' in text and '</think>' in text:
        start_pos = text.find('<think>')
        if start_pos == -1:
            break

        pos = start_pos + 7  # len('<think>')
        depth = 1

        while pos < len(text) and depth > 0:
            if text[pos:pos+7] == '<think>':
                depth += 1
                pos += 7
            elif text[pos:pos+8] == '</think>':
                depth -= 1
                pos += 8
            else:
                pos += 1

        if depth == 0:
            text = text[:start_pos] + text[pos:]
        else:
            break

    return text
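
# Illustrative behavior, including nesting:
#   remove_think_blocks("a<think>x<think>y</think>z</think>b")  # -> "ab"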


class StreamingFunctionCallDetector:
    """Streaming function call detector with dynamic trigger-signal support that
    avoids false positives inside <think> tags.

    Core features:
    1. Never triggers tool call detection inside <think> blocks.
    2. Streams <think> block content to the user as normal output.
    3. Supports nested think tags.
    """

    def __init__(self, trigger_signal: str):
        self.trigger_signal = trigger_signal
        self.reset()

    def reset(self):
        self.content_buffer = ""
        self.state = "detecting"
        self.in_think_block = False
        self.think_depth = 0
        self.signal = self.trigger_signal
        self.signal_len = len(self.signal)

    def process_chunk(self, delta_content: str) -> tuple[bool, str]:
        """
        Process a streaming content chunk.
        Returns: (is_tool_call_detected, content_to_yield)
        """
        if not delta_content:
            return False, ""

        self.content_buffer += delta_content
        content_to_yield = ""

        if self.state == "tool_parsing":
            return False, ""

        logger.debug(f"🔧 Processing chunk: {repr(delta_content[:50])}{'...' if len(delta_content) > 50 else ''}, buffer length: {len(self.content_buffer)}, think state: {self.in_think_block}")

        i = 0
        while i < len(self.content_buffer):
            # Think tags are passed through verbatim while the state is updated.
            skip_chars = self._update_think_state(i)
            if skip_chars > 0:
                for j in range(skip_chars):
                    if i + j < len(self.content_buffer):
                        content_to_yield += self.content_buffer[i + j]
                i += skip_chars
                continue

            if not self.in_think_block and self._can_detect_signal_at(i):
                if self.content_buffer[i:i+self.signal_len] == self.signal:
                    logger.debug(f"🔧 Detector: trigger signal found outside think block! Signal: {self.signal[:20]}...")
                    logger.debug(f"🔧 Trigger signal position: {i}, think state: {self.in_think_block}, think depth: {self.think_depth}")
                    self.state = "tool_parsing"
                    self.content_buffer = self.content_buffer[i:]
                    return True, content_to_yield

            # Keep a tail in the buffer in case a signal or think tag spans chunks.
            remaining_len = len(self.content_buffer) - i
            if remaining_len < self.signal_len or remaining_len < 8:
                break

            content_to_yield += self.content_buffer[i]
            i += 1

        self.content_buffer = self.content_buffer[i:]
        return False, content_to_yield

    def _update_think_state(self, pos: int):
        """Update the think tag state; supports nesting."""
        remaining = self.content_buffer[pos:]

        if remaining.startswith('<think>'):
            self.think_depth += 1
            self.in_think_block = True
            logger.debug(f"🔧 Entering think block, depth: {self.think_depth}")
            return 7

        elif remaining.startswith('</think>'):
            self.think_depth = max(0, self.think_depth - 1)
            self.in_think_block = self.think_depth > 0
            logger.debug(f"🔧 Exiting think block, depth: {self.think_depth}")
            return 8

        return 0

    def _can_detect_signal_at(self, pos: int) -> bool:
        """Check whether a full signal could be matched at the given position."""
        return (pos + self.signal_len <= len(self.content_buffer) and
                not self.in_think_block)

    def finalize(self) -> Optional[List[Dict[str, Any]]]:
        """Final processing when the stream ends."""
        if self.state == "tool_parsing":
            return parse_function_calls_xml(self.content_buffer, self.trigger_signal)
        return None
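
# Illustrative flow (signal value hypothetical):
#   detector = StreamingFunctionCallDetector("<Function_aB3x_Start/>")
#   for delta in upstream_deltas:
#       detected, text = detector.process_chunk(delta)
#       # forward `text`; once `detected` is True, stop forwarding and keep buffering
#   parsed = detector.finalize()  # list of tool calls, or None if parsing failed
# process_chunk deliberately holds back a short tail of the buffer so that a
# signal or think tag split across chunks is not missed.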


def parse_function_calls_xml(xml_string: str, trigger_signal: str) -> Optional[List[Dict[str, Any]]]:
    """
    XML parsing with dynamic trigger-signal support:
    1. <think>...</think> blocks are kept in the content returned to the user;
       they are removed only from a temporary copy so they cannot interfere
       with parsing.
    2. The last occurrence of the trigger signal is located.
    3. The <function_calls> block is parsed starting from that last signal.
    """
    logger.debug(f"🔧 Parser starting, input length: {len(xml_string) if xml_string else 0}")
    logger.debug(f"🔧 Using trigger signal: {trigger_signal[:20]}...")

    if not xml_string or trigger_signal not in xml_string:
        logger.debug("🔧 Input is empty or does not contain the trigger signal")
        return None

    cleaned_content = remove_think_blocks(xml_string)
    logger.debug(f"🔧 Content length after temporarily removing think blocks: {len(cleaned_content)}")

    signal_positions = []
    start_pos = 0
    while True:
        pos = cleaned_content.find(trigger_signal, start_pos)
        if pos == -1:
            break
        signal_positions.append(pos)
        start_pos = pos + 1

    if not signal_positions:
        logger.debug("🔧 No trigger signal found in cleaned content")
        return None

    logger.debug(f"🔧 Found {len(signal_positions)} trigger signal positions: {signal_positions}")

    last_signal_pos = signal_positions[-1]
    content_after_signal = cleaned_content[last_signal_pos:]
    logger.debug(f"🔧 Content starting from last trigger signal: {repr(content_after_signal[:100])}")

    calls_content_match = re.search(r"<function_calls>([\s\S]*?)</function_calls>", content_after_signal)
    if not calls_content_match:
        logger.debug("🔧 No function_calls tag found")
        return None

    calls_content = calls_content_match.group(1)
    logger.debug(f"🔧 function_calls content: {repr(calls_content)}")

    results = []
    call_blocks = re.findall(r"<function_call>([\s\S]*?)</function_call>", calls_content)
    logger.debug(f"🔧 Found {len(call_blocks)} function_call blocks")

    for i, block in enumerate(call_blocks):
        logger.debug(f"🔧 Processing function_call #{i+1}: {repr(block)}")

        tool_match = re.search(r"<tool>(.*?)</tool>", block)
        if not tool_match:
            logger.debug(f"🔧 No tool tag found in block #{i+1}")
            continue

        name = tool_match.group(1).strip()
        args = {}

        args_block_match = re.search(r"<args>([\s\S]*?)</args>", block)
        if args_block_match:
            args_content = args_block_match.group(1)

            # The backreference \1 pairs each opening <key> with its own </key>,
            # including hyphenated keys such as <-i>.
            arg_matches = re.findall(r"<([^\s>/]+)>([\s\S]*?)</\1>", args_content)

            def _coerce_value(v: str):
                # Prefer JSON types (numbers, booleans, lists); fall back to the raw string.
                try:
                    return json.loads(v)
                except Exception:
                    pass
                return v

            for k, v in arg_matches:
                args[k] = _coerce_value(v)

        result = {"name": name, "args": args}
        results.append(result)
        logger.debug(f"🔧 Added tool call: {result}")

    logger.debug(f"🔧 Final parsing result: {results}")
    return results if results else None
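
# Illustrative parse (trigger value hypothetical):
#   parse_function_calls_xml(
#       '<Function_aB3x_Start/>\n<function_calls>\n<function_call>\n'
#       '<tool>search</tool>\n<args>\n<q>"python"</q>\n</args>\n'
#       '</function_call>\n</function_calls>',
#       "<Function_aB3x_Start/>",
#   )  # -> [{"name": "search", "args": {"q": "python"}}]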


def find_upstream(model_name: str) -> tuple[Dict[str, Any], str]:
    """Find the upstream configuration for a model name, handling aliases and passthrough mode."""
    if app_config.features.model_passthrough:
        logger.info("🔀 Model passthrough mode is active. Forwarding to the 'openai' service.")
        openai_service = None
        for service in app_config.upstream_services:
            if service.name == "openai":
                openai_service = service.model_dump()
                break

        if openai_service:
            if not openai_service.get("api_key"):
                raise HTTPException(status_code=500, detail="Configuration error: API key not found for the 'openai' service in model passthrough mode.")

            return openai_service, model_name
        else:
            raise HTTPException(status_code=500, detail="Configuration error: 'model_passthrough' is enabled, but no upstream service named 'openai' was found.")

    chosen_model_entry = model_name

    if model_name in ALIAS_MAPPING:
        chosen_model_entry = random.choice(ALIAS_MAPPING[model_name])
        logger.info(f"🔀 Model alias '{model_name}' detected. Randomly selected '{chosen_model_entry}' for this request.")

    service = MODEL_TO_SERVICE_MAPPING.get(chosen_model_entry)

    if service:
        if not service.get("api_key"):
            raise HTTPException(status_code=500, detail=f"Model configuration error: API key not found for service '{service.get('name')}'.")
    else:
        logger.warning(f"⚠️ Model '{model_name}' not found in configuration, using default service")
        service = DEFAULT_SERVICE
        if not service.get("api_key"):
            raise HTTPException(status_code=500, detail="Service configuration error: Default API key not found.")

    # Mapping entries may be namespaced as "<prefix>:<model>"; strip the prefix
    # before forwarding the model name upstream.
    actual_model_name = chosen_model_entry
    if ':' in chosen_model_entry:
        parts = chosen_model_entry.split(':', 1)
        if len(parts) == 2:
            _, actual_model_name = parts

    return service, actual_model_name


app = FastAPI()
http_client = httpx.AsyncClient()


@app.middleware("http")
async def debug_middleware(request: Request, call_next):
    """Middleware for debugging validation errors; does not log conversation content."""
    response = await call_next(request)

    if response.status_code == 422:
        logger.debug(f"🔍 Validation error detected for {request.method} {request.url.path}")
        logger.debug("🔍 Response status code: 422 (Pydantic validation failure)")

    return response


@app.exception_handler(ValidationError)
async def validation_exception_handler(request: Request, exc: ValidationError):
    """Handle Pydantic validation errors with detailed error information."""
    logger.error(f"❌ Pydantic validation error: {exc}")
    logger.error(f"❌ Request URL: {request.url}")
    logger.error(f"❌ Error details: {exc.errors()}")

    for error in exc.errors():
        logger.error(f"❌ Validation error location: {error.get('loc')}")
        logger.error(f"❌ Validation error message: {error.get('msg')}")
        logger.error(f"❌ Validation error type: {error.get('type')}")

    return JSONResponse(
        status_code=422,
        content={
            "error": {
                "message": "Invalid request format",
                "type": "invalid_request_error",
                "code": "invalid_request"
            }
        }
    )


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Handle all uncaught exceptions."""
    logger.error(f"❌ Unhandled exception: {exc}")
    logger.error(f"❌ Request URL: {request.url}")
    logger.error(f"❌ Exception type: {type(exc).__name__}")
    logger.error(f"❌ Error stack: {traceback.format_exc()}")

    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "message": "Internal server error",
                "type": "server_error",
                "code": "internal_error"
            }
        }
    )


async def verify_api_key(authorization: str = Header(...)):
    """Dependency: verify the client API key."""
    client_key = authorization.replace("Bearer ", "")
    if app_config.features.key_passthrough:
        # In passthrough mode the client key is forwarded upstream unverified.
        return client_key
    if client_key not in ALLOWED_CLIENT_KEYS:
        raise HTTPException(status_code=401, detail="Unauthorized")
    return client_key
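
# Illustrative client call (key and port hypothetical; the key must appear in
# client_authentication.allowed_keys unless key_passthrough is enabled):
#   curl -H "Authorization: Bearer sk-example" http://localhost:8000/v1/models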


def preprocess_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Preprocess messages: convert tool-role messages into a format the AI understands.
    Returns a list of plain dictionaries to avoid Pydantic validation issues."""
    processed_messages = []

    for message in messages:
        if isinstance(message, dict):
            if message.get("role") == "tool":
                tool_call_id = message.get("tool_call_id")
                content = message.get("content")

                if tool_call_id and content:
                    formatted_content = format_tool_result_for_ai(tool_call_id, content)
                    processed_message = {
                        "role": "user",
                        "content": formatted_content
                    }
                    processed_messages.append(processed_message)
                    logger.debug(f"🔧 Converted tool message to user message: tool_call_id={tool_call_id}")
                else:
                    logger.debug(f"🔧 Skipped invalid tool message: tool_call_id={tool_call_id}, content={bool(content)}")
            elif message.get("role") == "assistant" and "tool_calls" in message and message["tool_calls"]:
                tool_calls = message.get("tool_calls", [])
                formatted_tool_calls_str = format_assistant_tool_calls_for_ai(tool_calls, GLOBAL_TRIGGER_SIGNAL)

                original_content = message.get("content") or ""
                final_content = f"{original_content}\n{formatted_tool_calls_str}".strip()

                processed_message = {
                    "role": "assistant",
                    "content": final_content
                }

                for key, value in message.items():
                    if key not in ["role", "content", "tool_calls"]:
                        processed_message[key] = value

                processed_messages.append(processed_message)
                logger.debug("🔧 Converted assistant tool_calls to content.")

            elif message.get("role") == "developer":
                if app_config.features.convert_developer_to_system:
                    processed_message = message.copy()
                    processed_message["role"] = "system"
                    processed_messages.append(processed_message)
                    logger.debug("🔧 Converted developer message to system message for better upstream compatibility")
                else:
                    processed_messages.append(message)
                    logger.debug("🔧 Keeping developer role unchanged (based on configuration)")
            else:
                processed_messages.append(message)
        else:
            processed_messages.append(message)

    return processed_messages
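
# Illustrative transformation (IDs hypothetical). A tool result message such as
#   {"role": "tool", "tool_call_id": "call_abc", "content": "42"}
# is rewritten as a user message whose content wraps "42" in <tool_result>
# tags (with tool metadata when call_abc is still in the mapping table), so
# upstreams without native tool-message support can still read the result.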


@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    body: ChatCompletionRequest,
    _api_key: str = Depends(verify_api_key)
):
    """Main chat completion endpoint: proxy requests and inject function calling capabilities."""
    try:
        logger.debug(f"🔧 Received request, model: {body.model}")
        logger.debug(f"🔧 Number of messages: {len(body.messages)}")
        logger.debug(f"🔧 Number of tools: {len(body.tools) if body.tools else 0}")
        logger.debug(f"🔧 Streaming: {body.stream}")

        upstream, actual_model = find_upstream(body.model)
        upstream_url = f"{upstream['base_url']}/chat/completions"

        logger.debug(f"🔧 Starting message preprocessing, original message count: {len(body.messages)}")
        processed_messages = preprocess_messages(body.messages)
        logger.debug(f"🔧 Preprocessing completed, processed message count: {len(processed_messages)}")

        if not validate_message_structure(processed_messages):
            logger.error("❌ Message structure validation failed, but continuing processing")

        request_body_dict = body.model_dump(exclude_unset=True)
        request_body_dict["model"] = actual_model
        request_body_dict["messages"] = processed_messages
        is_fc_enabled = app_config.features.enable_function_calling
        has_tools_in_request = bool(body.tools)
        has_function_call = is_fc_enabled and has_tools_in_request

        logger.debug(f"🔧 Request body constructed, message count: {len(processed_messages)}")

    except Exception as e:
        logger.error(f"❌ Request preprocessing failed: {str(e)}")
        logger.error(f"❌ Error type: {type(e).__name__}")
        if hasattr(app_config, 'debug') and app_config.debug:
            logger.error(f"❌ Error stack: {traceback.format_exc()}")

        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "message": "Invalid request format",
                    "type": "invalid_request_error",
                    "code": "invalid_request"
                }
            }
        )

    if has_function_call:
        logger.debug(f"🔧 Using global trigger signal for this request: {GLOBAL_TRIGGER_SIGNAL}")

        function_prompt, _ = generate_function_prompt(body.tools, GLOBAL_TRIGGER_SIGNAL)

        tool_choice_prompt = safe_process_tool_choice(body.tool_choice)
        if tool_choice_prompt:
            function_prompt += tool_choice_prompt

        system_message = {"role": "system", "content": function_prompt}
        request_body_dict["messages"].insert(0, system_message)

        # The upstream never sees the native tools/tool_choice fields; the
        # injected system prompt replaces them.
        if "tools" in request_body_dict:
            del request_body_dict["tools"]
        if "tool_choice" in request_body_dict:
            del request_body_dict["tool_choice"]

    elif has_tools_in_request and not is_fc_enabled:
        logger.info("🔧 Function calling is disabled by configuration, ignoring 'tools' and 'tool_choice' in request.")
        if "tools" in request_body_dict:
            del request_body_dict["tools"]
        if "tool_choice" in request_body_dict:
            del request_body_dict["tool_choice"]

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {_api_key}" if app_config.features.key_passthrough else f"Bearer {upstream['api_key']}",
        "Accept": "application/json" if not body.stream else "text/event-stream"
    }

    logger.info(f"🔀 Forwarding request to upstream: {upstream['name']}")
    logger.info(f"🔀 Model: {request_body_dict.get('model', 'unknown')}, Messages: {len(request_body_dict.get('messages', []))}")

    if not body.stream:
        try:
            logger.debug(f"🔧 Sending upstream request to: {upstream_url}")
            logger.debug(f"🔧 has_function_call: {has_function_call}")
            logger.debug(f"🔧 Request body contains tools: {bool(body.tools)}")

            upstream_response = await http_client.post(
                upstream_url, json=request_body_dict, headers=headers, timeout=app_config.server.timeout
            )
            upstream_response.raise_for_status()

            response_json = upstream_response.json()
            logger.debug(f"🔧 Upstream response status code: {upstream_response.status_code}")

            if has_function_call:
                content = response_json["choices"][0]["message"]["content"]
                logger.debug(f"🔧 Complete response content: {repr(content)}")

                parsed_tools = parse_function_calls_xml(content, GLOBAL_TRIGGER_SIGNAL)
                logger.debug(f"🔧 XML parsing result: {parsed_tools}")

                if parsed_tools:
                    logger.debug(f"🔧 Successfully parsed {len(parsed_tools)} tool calls")
                    tool_calls = []
                    for tool in parsed_tools:
                        tool_call_id = f"call_{uuid.uuid4().hex}"
                        store_tool_call_mapping(
                            tool_call_id,
                            tool["name"],
                            tool["args"],
                            f"Calling tool {tool['name']}"
                        )
                        tool_calls.append({
                            "id": tool_call_id,
                            "type": "function",
                            "function": {
                                "name": tool["name"],
                                "arguments": json.dumps(tool["args"])
                            }
                        })
                    logger.debug(f"🔧 Converted tool_calls: {tool_calls}")

                    response_json["choices"][0]["message"] = {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": tool_calls,
                    }
                    response_json["choices"][0]["finish_reason"] = "tool_calls"
                    logger.debug("🔧 Function call conversion completed")
                else:
                    logger.debug("🔧 No tool calls detected, returning original content (including think blocks)")
            else:
                logger.debug("🔧 Function calling not requested, no conversion attempted")

            return JSONResponse(content=response_json)

        except httpx.HTTPStatusError as e:
            logger.error(f"❌ Upstream service response error: status_code={e.response.status_code}")
            logger.error(f"❌ Upstream error details: {e.response.text}")

            # Map upstream status codes to OpenAI-style error payloads.
            if e.response.status_code == 400:
                error_response = {
                    "error": {
                        "message": "Invalid request parameters",
                        "type": "invalid_request_error",
                        "code": "bad_request"
                    }
                }
            elif e.response.status_code == 401:
                error_response = {
                    "error": {
                        "message": "Authentication failed",
                        "type": "authentication_error",
                        "code": "unauthorized"
                    }
                }
            elif e.response.status_code == 403:
                error_response = {
                    "error": {
                        "message": "Access forbidden",
                        "type": "permission_error",
                        "code": "forbidden"
                    }
                }
            elif e.response.status_code == 429:
                error_response = {
                    "error": {
                        "message": "Rate limit exceeded",
                        "type": "rate_limit_error",
                        "code": "rate_limit_exceeded"
                    }
                }
            elif e.response.status_code >= 500:
                error_response = {
                    "error": {
                        "message": "Upstream service temporarily unavailable",
                        "type": "service_error",
                        "code": "upstream_error"
                    }
                }
            else:
                error_response = {
                    "error": {
                        "message": "Request processing failed",
                        "type": "api_error",
                        "code": "unknown_error"
                    }
                }

            return JSONResponse(content=error_response, status_code=e.response.status_code)

    else:
        return StreamingResponse(
            stream_proxy_with_fc_transform(upstream_url, request_body_dict, headers, body.model, has_function_call, GLOBAL_TRIGGER_SIGNAL),
            media_type="text/event-stream"
        )


async def stream_proxy_with_fc_transform(url: str, body: dict, headers: dict, model: str, has_fc: bool, trigger_signal: str):
    """
    Streaming proxy with dynamic trigger-signal support; avoids false positives
    inside think tags.
    """
    logger.info(f"🔀 Starting streaming response from: {url}")
    logger.info(f"🔀 Function calling enabled: {has_fc}")

    if not has_fc or not trigger_signal:
        # No transformation needed: pass the upstream bytes straight through.
        try:
            async with http_client.stream("POST", url, json=body, headers=headers, timeout=app_config.server.timeout) as response:
                async for chunk in response.aiter_bytes():
                    yield chunk
        except httpx.RemoteProtocolError:
            logger.debug("🔧 Upstream closed the connection prematurely, ending stream response")
            return
        return

    detector = StreamingFunctionCallDetector(trigger_signal)

    def _prepare_tool_calls(parsed_tools: List[Dict[str, Any]]):
        tool_calls = []
        for i, tool in enumerate(parsed_tools):
            tool_call_id = f"call_{uuid.uuid4().hex}"
            store_tool_call_mapping(
                tool_call_id,
                tool["name"],
                tool["args"],
                f"Calling tool {tool['name']}"
            )
            tool_calls.append({
                "index": i, "id": tool_call_id, "type": "function",
                "function": {"name": tool["name"], "arguments": json.dumps(tool["args"])}
            })
        return tool_calls

    def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: str) -> List[str]:
        tool_calls = _prepare_tool_calls(parsed_tools)
        chunks: List[str] = []

        initial_chunk = {
            "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
            "created": int(time.time()), "model": model_id,
            "choices": [{"index": 0, "delta": {"role": "assistant", "content": None, "tool_calls": tool_calls}, "finish_reason": None}],
        }
        chunks.append(f"data: {json.dumps(initial_chunk)}\n\n")

        final_chunk = {
            "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
            "created": int(time.time()), "model": model_id,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}],
        }
        chunks.append(f"data: {json.dumps(final_chunk)}\n\n")
        chunks.append("data: [DONE]\n\n")
        return chunks

    try:
        async with http_client.stream("POST", url, json=body, headers=headers, timeout=app_config.server.timeout) as response:
            if response.status_code != 200:
                error_content = await response.aread()
                logger.error(f"❌ Upstream service stream response error: status_code={response.status_code}")
                logger.error(f"❌ Upstream error details: {error_content.decode('utf-8', errors='ignore')}")

                if response.status_code == 401:
                    error_message = "Authentication failed"
                elif response.status_code == 403:
                    error_message = "Access forbidden"
                elif response.status_code == 429:
                    error_message = "Rate limit exceeded"
                elif response.status_code >= 500:
                    error_message = "Upstream service temporarily unavailable"
                else:
                    error_message = "Request processing failed"

                error_chunk = {"error": {"message": error_message, "type": "upstream_error"}}
                yield f"data: {json.dumps(error_chunk)}\n\n"
                yield "data: [DONE]\n\n"
                return

            async for line in response.aiter_lines():
                if detector.state == "tool_parsing":
                    # Already inside a tool call: keep buffering instead of forwarding.
                    if line.startswith("data:"):
                        line_data = line[len("data:"):].strip()
                        if line_data and line_data != "[DONE]":
                            try:
                                chunk_json = json.loads(line_data)
                                delta_content = chunk_json.get("choices", [{}])[0].get("delta", {}).get("content", "") or ""
                                detector.content_buffer += delta_content

                                if "</function_calls>" in detector.content_buffer:
                                    logger.debug("🔧 Detected </function_calls> in stream, finalizing early...")
                                    parsed_tools = detector.finalize()
                                    if parsed_tools:
                                        logger.debug(f"🔧 Early finalize: parsed {len(parsed_tools)} tool calls")
                                        for sse in _build_tool_call_sse_chunks(parsed_tools, model):
                                            yield sse
                                        return
                                    else:
                                        logger.error("❌ Early finalize failed to parse tool calls")
                                        error_content = "Error: Detected tool use signal but failed to parse function call format"
                                        error_chunk = {"id": "error-chunk", "choices": [{"delta": {"content": error_content}}]}
                                        yield f"data: {json.dumps(error_chunk)}\n\n"
                                        yield "data: [DONE]\n\n"
                                        return
                            except (json.JSONDecodeError, IndexError):
                                pass
                    continue

                if line.startswith("data:"):
                    line_data = line[len("data:"):].strip()
                    if not line_data or line_data == "[DONE]":
                        continue

                    try:
                        chunk_json = json.loads(line_data)
                        delta_content = chunk_json.get("choices", [{}])[0].get("delta", {}).get("content", "") or ""

                        if delta_content:
                            is_detected, content_to_yield = detector.process_chunk(delta_content)

                            if content_to_yield:
                                yield_chunk = {
                                    "id": f"chatcmpl-passthrough-{uuid.uuid4().hex}",
                                    "object": "chat.completion.chunk",
                                    "created": int(time.time()),
                                    "model": model,
                                    "choices": [{"index": 0, "delta": {"content": content_to_yield}}]
                                }
                                yield f"data: {json.dumps(yield_chunk)}\n\n"

                            if is_detected:
                                # Trigger signal seen: stop forwarding and buffer the rest.
                                continue

                    except (json.JSONDecodeError, IndexError):
                        yield line + "\n\n"

    except httpx.RequestError as e:
        logger.error(f"❌ Failed to connect to upstream service: {e}")
        logger.error(f"❌ Error type: {type(e).__name__}")

        error_message = "Failed to connect to upstream service"
        error_chunk = {"error": {"message": error_message, "type": "connection_error"}}
        yield f"data: {json.dumps(error_chunk)}\n\n"
        yield "data: [DONE]\n\n"
        return

    if detector.state == "tool_parsing":
        logger.debug("🔧 Stream ended, starting to parse tool call XML...")
        parsed_tools = detector.finalize()
        if parsed_tools:
            logger.debug(f"🔧 Streaming processing: successfully parsed {len(parsed_tools)} tool calls")
            for sse in _build_tool_call_sse_chunks(parsed_tools, model):
                yield sse
            return
        else:
            logger.error(f"❌ Detected tool call signal but XML parsing failed, buffer content: {detector.content_buffer}")
            error_content = "Error: Detected tool use signal but failed to parse function call format"
            error_chunk = {"id": "error-chunk", "choices": [{"delta": {"content": error_content}}]}
            yield f"data: {json.dumps(error_chunk)}\n\n"

    elif detector.state == "detecting" and detector.content_buffer:
        # Flush whatever tail the detector was still holding back.
        final_yield_chunk = {
            "id": f"chatcmpl-finalflush-{uuid.uuid4().hex}", "object": "chat.completion.chunk",
            "created": int(time.time()), "model": model,
            "choices": [{"index": 0, "delta": {"content": detector.content_buffer}}]
        }
        yield f"data: {json.dumps(final_yield_chunk)}\n\n"

    yield "data: [DONE]\n\n"


@app.get("/")
def read_root():
    return {
        "status": "OpenAI Function Call Middleware is running",
        "config": {
            "upstream_services_count": len(app_config.upstream_services),
            "client_keys_count": len(app_config.client_authentication.allowed_keys),
            "models_count": len(MODEL_TO_SERVICE_MAPPING),
            "features": {
                "function_calling": app_config.features.enable_function_calling,
                "log_level": app_config.features.log_level,
                "convert_developer_to_system": app_config.features.convert_developer_to_system,
                "random_trigger": True
            }
        }
    }


@app.get("/v1/models")
async def list_models(_api_key: str = Depends(verify_api_key)):
    """List all available models."""
    visible_models = set()
    for model_name in MODEL_TO_SERVICE_MAPPING.keys():
        # Entries namespaced as "<alias>:<model>" are exposed by alias only.
        if ':' in model_name:
            alias, _ = model_name.split(':', 1)
            visible_models.add(alias)
        else:
            visible_models.add(model_name)

    models = []
    for model_id in sorted(visible_models):
        models.append({
            "id": model_id,
            "object": "model",
            "created": 1677610602,
            "owned_by": "openai",
            "permission": [],
            "root": model_id,
            "parent": None
        })

    return {
        "object": "list",
        "data": models
    }


def validate_message_structure(messages: List[Dict[str, Any]]) -> bool:
    """Validate whether the message structure meets requirements."""
    try:
        valid_roles = ["system", "user", "assistant", "tool"]
        if not app_config.features.convert_developer_to_system:
            valid_roles.append("developer")

        for i, msg in enumerate(messages):
            if "role" not in msg:
                logger.error(f"❌ Message {i} missing role field")
                return False

            if msg["role"] not in valid_roles:
                logger.error(f"❌ Invalid role value for message {i}: {msg['role']}")
                return False

            if msg["role"] == "tool":
                if "tool_call_id" not in msg:
                    logger.error(f"❌ Tool message {i} missing tool_call_id field")
                    return False

            content = msg.get("content")
            content_info = ""
            if content:
                if isinstance(content, str):
                    content_info = f", content=text({len(content)} chars)"
                elif isinstance(content, list):
                    text_parts = [item for item in content if isinstance(item, dict) and item.get('type') == 'text']
                    image_parts = [item for item in content if isinstance(item, dict) and item.get('type') == 'image_url']
                    content_info = f", content=multimodal(text={len(text_parts)}, images={len(image_parts)})"
                else:
                    content_info = f", content={type(content).__name__}"
            else:
                content_info = ", content=empty"

            logger.debug(f"✅ Message {i} validation passed: role={msg['role']}{content_info}")

        logger.debug(f"✅ All messages validated successfully, total {len(messages)} messages")
        return True
    except Exception as e:
        logger.error(f"❌ Message validation exception: {e}")
        return False


def safe_process_tool_choice(tool_choice) -> str:
    """Safely process the tool_choice field to avoid type errors."""
    try:
        if tool_choice is None:
            return ""

        if isinstance(tool_choice, str):
            if tool_choice == "none":
                return "\n\n**IMPORTANT:** You are prohibited from using any tools in this round. Please respond like a normal chat assistant and answer the user's question directly."
            else:
                logger.debug(f"🔧 Unknown tool_choice string value: {tool_choice}")
                return ""

        elif hasattr(tool_choice, 'function'):
            # ToolChoice validates `function` as a plain dict, so support both
            # dict-style and attribute-style access when extracting the name.
            function_info = tool_choice.function
            required_tool_name = (function_info.get("name") if isinstance(function_info, dict)
                                  else getattr(function_info, "name", None))
            if required_tool_name:
                return f"\n\n**IMPORTANT:** In this round, you must use ONLY the tool named `{required_tool_name}`. Generate the necessary parameters and output in the specified XML format."
            logger.debug(f"🔧 tool_choice has no usable tool name: {tool_choice}")
            return ""

        else:
            logger.debug(f"🔧 Unsupported tool_choice type: {type(tool_choice)}")
            return ""

    except Exception as e:
        logger.error(f"❌ Error processing tool_choice: {e}")
        return ""


if __name__ == "__main__":
    import uvicorn
    logger.info(f"🚀 Starting server on {app_config.server.host}:{app_config.server.port}")
    logger.info(f"⏱️ Request timeout: {app_config.server.timeout} seconds")

    uvicorn.run(
        app,
        host=app_config.server.host,
        port=app_config.server.port,
        log_level=app_config.features.log_level.lower() if app_config.features.log_level != "DISABLED" else "critical"
    )