# Source: injected_thinking / Inference / inference_demo_drivelmm.py
# (Hugging Face Hub upload metadata: uploaded via huggingface_hub, revision 7134ce7, verified)
# Copyright (c) Kangan Qian. All rights reserved.
# Authors: Kangan Qian (Tsinghua University, Xiaomi Corporation)
# Description: Tool integration with Qwen2.5-VL model for autonomous driving inference
import base64
import io
import json
import os
import sys
import time
from typing import Any, Callable, Optional

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

from scripts.tools.gen_tool_result_data import AgentThink
from scripts.tools.tool_libraries import FuncAgent
def pil_to_base64(pil_image: "Image.Image") -> str:
    """
    Convert a PIL image to a base64-encoded PNG string.

    Args:
        pil_image (Image.Image): PIL image object to convert

    Returns:
        str: Base64-encoded string of the image serialized as PNG

    Raises:
        RuntimeError: If conversion fails (the original exception is chained)
    """
    try:
        binary_stream = io.BytesIO()
        # Always serialize as PNG so the output is deterministic and lossless.
        pil_image.save(binary_stream, format="PNG")
        binary_data = binary_stream.getvalue()
        return base64.b64encode(binary_data).decode('utf-8')
    except Exception as e:
        # Chain the original exception so the root cause survives in tracebacks.
        raise RuntimeError(f"Image to base64 conversion failed: {e}") from e
def inference_with_retry(
    inference_func: Callable,
    *args: Any,
    max_retries: int = 3,
    retry_delay: int = 3,
    **kwargs: Any
) -> str:
    """
    Execute an inference function with automatic retries on failure.

    Args:
        inference_func (Callable): Inference function to call
        *args: Positional arguments for the inference function
        max_retries (int): Maximum number of attempts before giving up
        retry_delay (int): Delay between retry attempts in seconds
        **kwargs: Keyword arguments for the inference function

    Returns:
        str: Output from the inference function

    Raises:
        RuntimeError: If all attempts fail; the last underlying exception is
            chained as the cause.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return inference_func(*args, **kwargs)
        except Exception as e:
            last_error = e
            print(f"Inference error: {e}. Retry {attempt}/{max_retries}...")
            # Don't sleep after the final failed attempt — nothing follows it.
            if attempt < max_retries:
                time.sleep(retry_delay)
    raise RuntimeError(f"Inference failed after {max_retries} retries") from last_error
class Qwen2_5VLInterface:
    """Thin wrapper around a Qwen2.5-VL checkpoint for single-image chat inference."""

    def __init__(self, model_path: str) -> None:
        """
        Initialize the Qwen2.5-VL model interface.

        Args:
            model_path (str): Path to pretrained model checkpoint
        """
        # bfloat16 + flash-attention-2 reduce memory; device_map="auto" lets
        # transformers place shards across the available GPUs.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_path)

    def inference(self, pil_image: Image.Image, prompt: str, max_tokens: int = 4096) -> str:
        """
        Run one image + text prompt through the model and return the generated text.

        Args:
            pil_image (Image.Image): Input image
            prompt (str): Text prompt for the model
            max_tokens (int): Maximum number of new tokens to generate

        Returns:
            str: Model output text (prompt tokens stripped)
        """
        # Embed the image as a base64 data URI. pil_to_base64 serializes as
        # PNG, so declare the matching MIME subtype (RFC 2397).
        image_base64 = pil_to_base64(pil_image)
        image_url = f"data:image/png;base64,{image_base64}"
        # Single-turn user message: image followed by the text prompt.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_url},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Render the chat template and extract vision inputs for the processor.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # NOTE(review): assumes the input embedding shard lives on cuda:0 under
        # device_map="auto" — confirm for multi-GPU layouts.
        inputs = inputs.to("cuda")
        # Generate, then slice off the echoed prompt tokens before decoding.
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]
def execute_tool_call(
    func_agent: "FuncAgent",
    tool_name: str,
    tool_args: dict,
    verbose: bool = True
) -> Optional[dict]:
    """
    Execute a tool call using the function agent.

    Args:
        func_agent (FuncAgent): Function agent instance exposing tools as methods
        tool_name (str): Name of the tool (method) to execute
        tool_args (dict): Keyword arguments for the tool
        verbose (bool): Whether to print tool execution details

    Returns:
        Optional[dict]: {"name", "args", "prompt"} on success, or None when the
        tool is missing, is not callable, or raises during execution.
    """
    try:
        tool_function = getattr(func_agent, tool_name)
    except AttributeError:
        print(f"Error: Tool '{tool_name}' not found")
        return None
    if not callable(tool_function):
        print(f"Error: '{tool_name}' is not a callable function")
        return None
    try:
        # Tools return a (prompt, result_data) pair; only the prompt is used here.
        tool_prompt, tool_result_data = tool_function(**tool_args)
    except Exception as e:
        print(f"Error executing tool '{tool_name}': {e}")
        return None
    # Normalize a None prompt to "" so callers can concatenate safely.
    if tool_prompt is None:
        tool_prompt = ""
    tool_response = {
        "name": tool_name,
        "args": tool_args,
        "prompt": tool_prompt,
    }
    if verbose:
        print(f"Tool: {tool_name}")
        print(f"Arguments: {tool_args}")
        print(f"Prompt: {tool_prompt}")
    return tool_response
# Cache of loaded model interfaces, keyed by checkpoint path. Loading a
# checkpoint is expensive; this lets repeated calls (one per dataset sample)
# reuse the same weights instead of reloading them every time.
_MODEL_INTERFACE_CACHE = {}


def run_model_inference(
    image_path: str,
    prompt: str,
    model_path: str = "./pretrained_model"
) -> str:
    """
    Run inference using the chat model, reusing a cached model instance.

    Args:
        image_path (str): Path to input image file
        prompt (str): Text prompt for the model
        model_path (str): Path to model checkpoint

    Returns:
        str: Model output text
    """
    model_interface = _MODEL_INTERFACE_CACHE.get(model_path)
    if model_interface is None:
        model_interface = Qwen2_5VLInterface(model_path)
        _MODEL_INTERFACE_CACHE[model_path] = model_interface
    # Context manager guarantees the image file handle is closed afterwards.
    with Image.open(image_path) as image:
        return inference_with_retry(
            model_interface.inference,
            image,
            prompt,
            max_retries=3,
            retry_delay=3
        )
def main():
    """Process the demo JSON: run each sample's tool chain, then model inference."""
    # Path configuration
    json_file = "./Inference/inference_demo_data_drivelmm.json"
    tool_data_path = "./data/tool_results"
    image_base_path = "./data/image2concat"
    model_path = "./pretrained_model"
    agentthink_model = os.path.join(model_path, 'AgentThink')

    # Load JSON data
    with open(json_file, "r", encoding="utf-8") as file:
        json_data = json.load(file)

    # Load the checkpoint once, outside the sample loop — reloading the model
    # for every sample would dominate the runtime.
    model_interface = Qwen2_5VLInterface(agentthink_model)

    # Process each sample in the JSON data
    for sample in json_data:
        sample_idx = sample['idx']
        # idx format: "<scene_token>_<frame_token>_<question_id>"; only the
        # frame token is needed by the agent.
        _, frame_token, _ = sample_idx.split('_', 2)

        # Initialize agent for tool execution
        agent = AgentThink(
            token=frame_token,
            split='val',
            data_path=tool_data_path,
            model_name='Qwen2.5-VL'
        )

        # Image filename is the idx without its trailing question-id component.
        filename = sample_idx.rsplit('_', 1)[0] + '.png'
        image_path = os.path.join(image_base_path, filename)

        # Execute the sample's tool chain and accumulate the tool prompts.
        tool_prompt = ""
        for tool_node in sample['tool_result']:
            if tool_node is None:
                continue
            tool_response = execute_tool_call(
                agent.func_agent,
                tool_node['name'],
                tool_node['args']
            )
            if tool_response:
                tool_prompt += tool_response['prompt']

        # Construct the full prompt for model inference.
        full_prompt = f"{sample['system_prompts']}\n{sample['question']}\nTool results:{tool_prompt}"

        # Run model inference on this sample's image, retrying on failure.
        with Image.open(image_path) as image:
            model_output = inference_with_retry(
                model_interface.inference,
                image,
                full_prompt,
                max_retries=3,
                retry_delay=3
            )
        print(f"Sample {sample_idx} output: {model_output}")


if __name__ == "__main__":
    main()