| |
|
|
| import ast |
| import json |
| import re |
| from collections.abc import Sequence |
| from typing import Union |
|
|
| import partial_json_parser |
| from partial_json_parser.core.options import Allow |
|
|
| from vllm.entrypoints.openai.protocol import ( |
| ChatCompletionRequest, |
| DeltaFunctionCall, DeltaMessage, |
| DeltaToolCall, |
| ExtractedToolCallInformation, |
| FunctionCall, |
| ToolCall, |
| ) |
| from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
| ToolParser, |
| ToolParserManager, |
| ) |
| from vllm.logger import init_logger |
| from vllm.transformers_utils.tokenizer import AnyTokenizer |
| from vllm.utils import random_uuid |
|
|
| logger = init_logger(__name__) |
|
|
|
|
@ToolParserManager.register_module("llama_nemotron_xml")
class LlamaNemotronXMLToolParser(ToolParser):
    """Extract XML-style tool calls from a model's text output.

    The regexes compiled in ``__init__`` recognise blocks of the form::

        <tool_call>
        <tool>function_name</tool>
        <param_name>value</param_name>
        </tool_call>

    Inside a block, every ``<x>...</x>`` pair other than ``<tool>`` is
    treated as one argument named ``x``.
    """

    # Spellings that denote a keyword literal when no schema type is known.
    # ``ast.literal_eval`` only accepts the capitalised Python forms, so the
    # lowercase spellings are mapped explicitly.
    _KEYWORD_LITERALS = {"true": True, "false": False, "none": None}

    # (opening, closing) delimiters that make an untyped value look like a
    # Python/JSON literal worth evaluating.
    _LITERAL_DELIMITERS = (("'", "'"), ('"', '"'), ("[", "]"), ("{", "}"))

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming bookkeeping. No method defined in this class reads these
        # fields; streaming extraction raises NotImplementedError.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        # Body of one <tool_call>...</tool_call> block (non-greedy; DOTALL so
        # the body may span several lines).
        self.tool_call_block_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        # Tool name inside a block.
        self.name_regex = re.compile(r"<tool>(.*?)</tool>", re.DOTALL)
        # Any <tag>value</tag> pair; the backreference forces the closing tag
        # to match the opening one.
        self.param_regex = re.compile(r"<([^/>\s]+)>(.*?)</\1>", re.DOTALL)

    @staticmethod
    def _schema_properties(request: ChatCompletionRequest,
                           tool_name: str) -> dict:
        """Return the ``properties`` dict of the first tool named *tool_name*.

        An empty dict is returned when the request has no tools, no tool has
        that name, or the matching tool's parameters do not contain a dict
        under ``"properties"``.
        """
        for tool_def in request.tools or []:
            if tool_def.function.name != tool_name:
                continue
            parameters = tool_def.function.parameters
            if isinstance(parameters, dict):
                properties = parameters.get("properties")
                if isinstance(properties, dict):
                    return properties
            # Only the first tool with a matching name is consulted.
            return {}
        return {}

    @staticmethod
    def _convert_typed_value(param_name: str, raw: str,
                             target_type: object) -> object:
        """Convert *raw* to the JSON-schema *target_type*.

        Falls back to the raw string (with a warning) when conversion fails.
        Unknown types, including ``"string"``, return *raw* unchanged.
        """
        try:
            if target_type == "integer":
                return int(raw)
            if target_type == "number":
                return float(raw)
            if target_type == "boolean":
                return raw.lower() == "true"
            if target_type in ("object", "array"):
                try:
                    return json.loads(raw)
                except json.JSONDecodeError:
                    # Accept Python-literal syntax (single quotes, True/None).
                    return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, RecursionError) as e:
            # TypeError / RecursionError can come out of ast.literal_eval;
            # handling them here keeps one bad parameter from discarding
            # every tool call in the response.
            logger.warning(
                "Could not convert param '%s' with value '%s' "
                "to type '%s'. Error: %s. Using string value.",
                param_name, raw, target_type, e)
        return raw

    def _infer_untyped_value(self, raw: str) -> object:
        """Guess a Python value for *raw* when the schema gives no type.

        Keyword spellings (any casing of true/false/none), quoted strings,
        bracketed/braced literals and plain decimal numbers are converted;
        everything else is returned as the original string.
        """
        lowered = raw.lower()
        if lowered in self._KEYWORD_LITERALS:
            return self._KEYWORD_LITERALS[lowered]

        delimited = any(
            raw.startswith(opening) and raw.endswith(closing)
            for opening, closing in self._LITERAL_DELIMITERS)
        unsigned = raw[1:] if raw.startswith("-") else raw
        numeric = unsigned.replace(".", "", 1).isdigit()
        if not (delimited or numeric):
            return raw

        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, RecursionError):
            # Looked like a literal but is not one (e.g. "007"); keep text.
            return raw

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every complete ``<tool_call>`` block in *model_output*.

        Args:
            model_output: Full text produced by the model.
            request: The originating request; its ``tools`` schemas, when
                present, drive argument type conversion.

        Returns:
            Extraction result. Text before the first start tag (stripped)
            becomes ``content`` when at least one tool call was parsed. If no
            start tag is present, no block could be parsed, or an unexpected
            error occurs, ``tools_called`` is False and ``content`` is the
            unmodified *model_output*.
        """
        tool_call_start_index = model_output.find(self.tool_call_start_token)
        if tool_call_start_index == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        content = model_output[:tool_call_start_index].strip()
        tool_calls_str_content = model_output[tool_call_start_index:]

        try:
            parsed_tool_calls = []
            blocks = self.tool_call_block_regex.findall(
                tool_calls_str_content)
            for tool_content_str in blocks:
                name_match = self.name_regex.search(tool_content_str)
                if not name_match:
                    logger.warning(
                        "Could not find tool name in XML block: %s",
                        tool_content_str)
                    continue
                tool_name = name_match.group(1).strip()

                # Resolve the schema once per tool call rather than once per
                # parameter.
                properties = self._schema_properties(request, tool_name)

                parsed_arguments = {}
                for match in self.param_regex.finditer(tool_content_str):
                    param_name = match.group(1).strip()
                    if param_name == "tool":
                        # The name tag also matches param_regex; skip it.
                        continue
                    param_value_str = match.group(2).strip()

                    param_schema = properties.get(param_name)
                    target_type = (param_schema.get("type") if isinstance(
                        param_schema, dict) else None)

                    if target_type:
                        typed_param_value = self._convert_typed_value(
                            param_name, param_value_str, target_type)
                    else:
                        typed_param_value = self._infer_untyped_value(
                            param_value_str)
                    parsed_arguments[param_name] = typed_param_value

                parsed_tool_calls.append(
                    ToolCall(
                        id=f"call_{random_uuid()}",
                        type="function",
                        function=FunctionCall(
                            name=tool_name,
                            arguments=json.dumps(parsed_arguments,
                                                 ensure_ascii=False),
                        ),
                    ))

            if not parsed_tool_calls:
                # A start tag was present but nothing usable was parsed
                # (e.g. unterminated block). Hand the text back untouched
                # instead of discarding everything after the tag.
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=parsed_tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting XML tool call from response. "
                "Response: %s", model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is not implemented for this parser.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Tool calling is not supported in streaming mode!")
|
|
|
|
@ToolParserManager.register_module("llama_nemotron_json")
class LlamaNemotronJSONToolParser(ToolParser):
    """Extract JSON tool calls wrapped in ``<TOOLCALL>...</TOOLCALL>``.

    The payload of the first block is parsed as a JSON array of objects, each
    carrying a ``"name"`` and an ``"arguments"`` entry. A payload that lacks
    the surrounding ``[`` / ``]`` has them added before parsing.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming bookkeeping. No method defined in this class reads these
        # fields; streaming extraction raises NotImplementedError.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        # Payload of one <TOOLCALL>...</TOOLCALL> block (non-greedy; DOTALL
        # so the payload may span several lines).
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>",
                                          re.DOTALL)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse the first ``<TOOLCALL>`` block in *model_output*.

        Args:
            model_output: Full text produced by the model.
            request: The originating request (not consulted by this parser).

        Returns:
            Extraction result. When at least one entry is converted, the text
            before the first start tag becomes ``content``. If there is no
            start tag, no complete block, no usable entry, or parsing fails,
            ``tools_called`` is False and ``content`` is the unmodified
            *model_output*.
        """
        no_tool_calls = ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )

        start_index = model_output.find(self.tool_call_start_token)
        if start_index == -1:
            return no_tool_calls

        try:
            payloads = self.tool_call_regex.findall(model_output)
            if not payloads:
                # Start tag without a matching end tag: nothing to parse.
                logger.warning(
                    "Found %s without a matching %s; returning output as "
                    "plain content.", self.tool_call_start_token,
                    self.tool_call_end_token)
                return no_tool_calls

            str_tool_calls = payloads[0].strip()
            # Tolerate a bare object (or comma-separated objects) by wrapping
            # the payload into a JSON array.
            if not str_tool_calls.startswith("["):
                str_tool_calls = "[" + str_tool_calls
            if not str_tool_calls.endswith("]"):
                str_tool_calls = str_tool_calls + "]"
            json_tool_calls = json.loads(str_tool_calls)

            tool_calls = []
            for tool_call in json_tool_calls:
                try:
                    arguments = tool_call["arguments"]
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=tool_call["name"],
                                # Dict arguments are serialised; anything
                                # else is passed through unchanged.
                                arguments=json.dumps(arguments,
                                                     ensure_ascii=False)
                                if isinstance(arguments, dict) else arguments,
                            ),
                        ))
                except Exception as e:
                    # Best effort: one malformed entry (missing key, wrong
                    # shape, rejected by the model classes) must not discard
                    # the remaining entries. Unlike a bare ``except:`` this
                    # lets KeyboardInterrupt / SystemExit propagate.
                    logger.warning(
                        "Skipping malformed tool call entry %r: %s",
                        tool_call, e)
                    continue

            if not tool_calls:
                # Every entry was unusable; do not claim a tool was called.
                return no_tool_calls

            # Only the first block is parsed, so content is everything that
            # precedes that same (first) start tag.
            content = model_output[:start_index]

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting tool call from response. Response: %s",
                model_output)
            return no_tool_calls

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is not implemented for this parser.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Tool calling is not supported in streaming mode!")
|
|
|
|
@ToolParserManager.register_module("llama_nemotron_pythonic")
class LlamaNemotronPythonicToolParser(ToolParser):
    """Extract Python-call-style tool calls from ``<TOOLCALL>`` blocks.

    The payload of the first ``<TOOLCALL>...</TOOLCALL>`` block is split into
    non-empty lines; each line is expected to look like
    ``function_name(arg=value, ...)``.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming bookkeeping. No method defined in this class reads these
        # fields; streaming extraction raises NotImplementedError.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        # Payload of one <TOOLCALL>...</TOOLCALL> block (non-greedy; DOTALL
        # so the payload may span several lines).
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>",
                                          re.DOTALL)
        # ``name(args)`` where the closing parenthesis ends the string; the
        # ``$`` anchor makes the lazy group run up to the final ``)``.
        self.function_call_regex = re.compile(r"(\w+)\((.*?)\)$", re.DOTALL)

    @staticmethod
    def _node_to_value(node: ast.AST) -> object:
        """Turn one argument AST node into a plain Python value.

        Literals are evaluated. A bare identifier becomes its name as a
        string, and any other expression becomes its source text.
        """
        try:
            return ast.literal_eval(node)
        except (ValueError, TypeError):
            if isinstance(node, ast.Name):
                return node.id
            # No ast.Constant case is needed here: literal_eval never fails
            # on a constant node.
            return ast.unparse(node)

    def parse_function_arguments(self, args_str: str) -> dict:
        """Parse a pythonic argument string into a dictionary.

        Args:
            args_str: The text between the parentheses of a call, e.g.
                ``"city='Paris', days=3"``.

        Returns:
            Mapping of argument name to value. Keyword arguments keep their
            names; positional arguments are stored as ``arg_0``, ``arg_1``,
            ... and are written after the keywords. ``**kwargs`` expansions
            are ignored. An empty dict is returned for blank or unparsable
            input.
        """
        if not args_str.strip():
            return {}

        try:
            # Wrap the arguments in a dummy call so the Python parser does
            # the tokenising (nested brackets, quoted commas, ...).
            call_node = ast.parse(f"dummy_func({args_str})", mode='eval').body
            if not isinstance(call_node, ast.Call):
                return {}

            arguments = {}
            for keyword in call_node.keywords:
                if keyword.arg is None:
                    # ``**mapping`` has no argument name to store under.
                    continue
                arguments[keyword.arg] = self._node_to_value(keyword.value)

            for i, arg in enumerate(call_node.args):
                arguments[f"arg_{i}"] = self._node_to_value(arg)

            return arguments

        except (SyntaxError, ValueError, RecursionError) as e:
            logger.warning("Failed to parse function arguments '%s': %s",
                           args_str, e)
            return {}

    @staticmethod
    def _schema_properties(request: ChatCompletionRequest,
                           function_name: str) -> dict:
        """Return the ``properties`` dict of the first matching tool.

        An empty dict is returned when the request has no tools, no tool is
        named *function_name*, or the matching tool's parameters do not
        contain a dict under ``"properties"``.
        """
        for tool_def in request.tools or []:
            if tool_def.function.name != function_name:
                continue
            parameters = tool_def.function.parameters
            if isinstance(parameters, dict):
                properties = parameters.get("properties")
                if isinstance(properties, dict):
                    return properties
            # Only the first tool with a matching name is consulted.
            return {}
        return {}

    @staticmethod
    def _coerce_to_schema(parsed_arguments: dict, properties: dict) -> None:
        """Coerce values in *parsed_arguments* in place to the schema types.

        Arguments without a dict schema entry are left alone. A value whose
        conversion fails is left unchanged and a warning is logged.
        """
        # Snapshot the items: entries are reassigned during the loop.
        for arg_name, arg_value in list(parsed_arguments.items()):
            param_info = properties.get(arg_name)
            if not isinstance(param_info, dict):
                # Non-dict schema entries (e.g. ``true``) carry no type.
                continue
            target_type = param_info.get("type")

            try:
                if target_type == "string":
                    if not isinstance(arg_value, str):
                        parsed_arguments[arg_name] = str(arg_value)
                elif target_type == "integer":
                    if not isinstance(arg_value, int):
                        parsed_arguments[arg_name] = int(arg_value)
                elif target_type == "number":
                    if not isinstance(arg_value, (int, float)):
                        parsed_arguments[arg_name] = float(arg_value)
                elif target_type == "boolean":
                    if isinstance(arg_value, str):
                        parsed_arguments[arg_name] = arg_value.lower() in (
                            'true', '1', 'yes')
                    elif not isinstance(arg_value, bool):
                        parsed_arguments[arg_name] = bool(arg_value)
                elif target_type in ("object", "array"):
                    if isinstance(arg_value, str):
                        try:
                            parsed_arguments[arg_name] = json.loads(arg_value)
                        except json.JSONDecodeError:
                            # Not JSON after all; keep the string as given.
                            pass
            except (ValueError, TypeError, OverflowError) as e:
                # OverflowError: int() of an infinite float.
                logger.warning("Type conversion failed for %s: %s", arg_name,
                               e)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse the first ``<TOOLCALL>`` block in *model_output*.

        Args:
            model_output: Full text produced by the model.
            request: The originating request; its ``tools`` schemas, when
                present, drive argument type coercion.

        Returns:
            Extraction result. Text before the first start tag (stripped)
            becomes ``content`` when at least one call was parsed. If there
            is no start tag, no complete block, no parsable line, or an
            unexpected error occurs, ``tools_called`` is False and
            ``content`` is the unmodified *model_output*.
        """
        no_tool_calls = ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )

        tool_call_start_index = model_output.find(self.tool_call_start_token)
        if tool_call_start_index == -1:
            return no_tool_calls
        content = model_output[:tool_call_start_index].strip()

        try:
            tool_call_matches = self.tool_call_regex.findall(model_output)
            if not tool_call_matches:
                return no_tool_calls

            # One function call per non-empty line of the first block.
            function_lines = [
                line.strip()
                for line in tool_call_matches[0].strip().split('\n')
                if line.strip()
            ]

            parsed_tool_calls = []
            for func_line in function_lines:
                match = self.function_call_regex.match(func_line)
                if not match:
                    logger.warning("Could not parse function call: %s",
                                   func_line)
                    continue

                function_name = match.group(1)
                parsed_arguments = self.parse_function_arguments(
                    match.group(2))

                properties = self._schema_properties(request, function_name)
                if properties:
                    self._coerce_to_schema(parsed_arguments, properties)

                parsed_tool_calls.append(
                    ToolCall(
                        id=f"call_{random_uuid()}",
                        type="function",
                        function=FunctionCall(
                            name=function_name,
                            arguments=json.dumps(parsed_arguments,
                                                 ensure_ascii=False),
                        ),
                    ))

            if not parsed_tool_calls:
                # A block existed but no line was a valid call; hand the
                # text back untouched instead of discarding it.
                return no_tool_calls

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=parsed_tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting pythonic tool call from response. "
                "Response: %s", model_output)
            return no_tool_calls

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is not implemented for this parser.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Tool calling is not supported in streaming mode!")
|
|