| import os |
| import json |
| import base64 |
| import argparse |
| import time |
| import re |
| import traceback |
| from datetime import datetime |
| from functools import partial |
| import requests |
| from openai import AzureOpenAI, OpenAI |
| from volcenginesdkarkruntime import Ark |
| import concurrent.futures |
| from tqdm import tqdm |
|
|
| |
| |
# System prompt for the imagine-and-reason agent. It instructs the evaluated
# model to answer a multiple-choice video question from a limited set of
# sampled frames, to call the `imagine_frame` tool when visual evidence is
# insufficient, and to finish with a ```json {"answer": "X"}``` block (parsed
# downstream by extract_json_from_response).
IMAGINE_AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specializing in answering video question-answering problems through reasoning and imagination.
Your task is to answer a multiple-choice question based on an initial, limited set of video frames.

You will receive a few uniformly sampled frames to get a basic understanding of the video.
These frames may not contain all the visual evidence needed to directly answer the question.

If the provided frame information is insufficient, you must use the `imagine_frame` tool to generate new, imagined frames to fill in the visual gaps and aid your reasoning.
You can call this tool multiple times to construct a sequence of imagined events.

Your strategy should be:
1. Analyze the initial frames and the user's question.
2. Form a hypothesis about the missing content.
3. If you need more visual information, call the `imagine_frame` tool. Provide a text `prompt` describing the scene you want to imagine, and select a `reference_image_id` from existing frames. The `reference_image_id` MUST be one of the IDs explicitly provided to you in the conversation history (e.g., "Frame ID: X" or "New Frame ID: Y"). Do not invent or assume frame IDs.
4. Analyze the newly generated frame in conjunction with the existing ones.
5. Continue this process of reasoning and imagination until you are confident in your answer. Please ensure you have found or created the relevant visual cues before answering the question.
6. Each tool call can only generate one frame.

IMPORTANT: Your text `prompt` for image generation must be safe and general. Avoid descriptions that could be interpreted as sensitive, harmful, or explicit to prevent generation failures.

After your reasoning, provide the final answer in a JSON code block. The JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'.

Your output must strictly follow this format:
<Your step-by-step reasoning process here, including why you chose to imagine a certain frame>
```json
{"answer": "X"}
```
Do not include any other text after the JSON code block.
"""
|
|
| |
| |
# OpenAI-style function-tool schema advertised to the model. The agent loop
# dispatches any call whose name is "imagine_frame" to the imagine_frame()
# implementation below; both parameters are required.
IMAGINE_FRAME_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "imagine_frame",
        "description": "When visual evidence is insufficient, generates a new image based on a text prompt and a reference image to help answer the question. Use it to imagine what might have happened between the provided frames.",
        "parameters": {
            "type": "object",
            "properties": {
                "reference_image_id": {
                    "type": "integer",
                    "description": "The ID of an existing frame to use as a style and content reference. It can be one of the original frames or a previously generated one.",
                },
                "prompt": {
                    "type": "string",
                    "description": "A detailed text description of the frame you want to imagine and generate.",
                },
            },
            "required": ["reference_image_id", "prompt"],
        },
    },
}
|
|
|
|
| |
def imagine_frame(
    reference_image_id: int,
    prompt: str,
    all_frame_paths: dict,
    output_dir: str,
    generation_count: int,
):
    """
    Tool implementation: Calls an image generation model to create a new frame.

    Args:
        reference_image_id (int): The ID of the reference frame.
        prompt (str): The text prompt for image generation.
        all_frame_paths (dict): A dictionary containing IDs and paths of all
            currently available frames (original + generated).
        output_dir (str): The directory to save the generated image.
        generation_count (int): The current generation count, used for naming the file.

    Returns:
        str or None: The path of the newly generated image on success, otherwise None.

    Raises:
        ValueError: If the ARK_API_KEY environment variable is not set.
        FileNotFoundError: If reference_image_id does not map to an existing file.
    """
    print(f"\n[Tool Call] Imagining new frame with prompt: '{prompt}'")
    ark_api_key = os.environ.get("ARK_API_KEY")
    if not ark_api_key:
        raise ValueError("Error: Environment variable ARK_API_KEY is not set.")

    client = Ark(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=ark_api_key,
    )

    # Validate the model-supplied reference id before any network call.
    ref_image_path = all_frame_paths.get(reference_image_id)
    if not ref_image_path or not os.path.exists(ref_image_path):
        raise FileNotFoundError(f"Reference image ID not found: {reference_image_id}")

    try:
        ref_image_b64 = encode_image(ref_image_path)
        ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"

        images_response = client.images.generate(
            model="doubao-seedream-4-0-250828",
            prompt=prompt,
            image=ref_image_data_uri,
            size="1024x1024",
            response_format="url",
            watermark=False,
        )

        image_url = images_response.data[0].url

        # Download the generated image. The timeout prevents a hung connection
        # from stalling a worker process forever (the original call had none).
        response = requests.get(image_url, timeout=60)
        response.raise_for_status()

        new_frame_filename = (
            f"generated_frame_{generation_count}_ref_{reference_image_id}.jpg"
        )
        new_frame_path = os.path.join(output_dir, new_frame_filename)

        with open(new_frame_path, "wb") as f:
            f.write(response.content)

        print(f"[Tool Success] Generated frame saved to: {new_frame_path}")
        return new_frame_path

    except Exception as e:
        # Best-effort tool: the agent loop treats None as a failed generation
        # and reports the failure back to the model instead of crashing.
        print(f"An error occurred during image generation or download: {e}")
        traceback.print_exc()
        return None
|
|
|
|
def parse_arguments():
    """Build the CLI parser for the evaluation run and return the parsed args."""
    cli = argparse.ArgumentParser(
        description="Video QA Evaluation Framework with Imagine-and-Reason Agent"
    )
    # Required: model under test and data locations.
    cli.add_argument("--target-model", "-tm", type=str, required=True,
                     help="The model to be evaluated (e.g., gpt-4o)")
    cli.add_argument("--frames-path", "-fp", type=str, required=True,
                     help="Absolute path to the root directory containing video frames.")
    # Optional: where generated images and result files go.
    cli.add_argument("--output-path", "-op", type=str, default="./generated_outputs",
                     help="Path to store generated images and results.")
    cli.add_argument("--data-file", "-df", type=str, required=True,
                     help="Absolute path to the evaluation dataset JSON file.")
    # Tuning knobs with defaults.
    cli.add_argument("--initial-frames-num", "-ifn", type=int, default=8,
                     help="Number of initial uniformly sampled frames.")
    cli.add_argument("--max-retry-times", "-mr", type=int, default=10,
                     help="Maximum number of retries for failed API calls.")
    cli.add_argument("--pool-processes", "-pp", type=int, default=10,
                     help="Number of parallel processes.")
    # Required: endpoint credentials for the target model service.
    cli.add_argument("--base_url", type=str, required=True,
                     help="API Endpoint URL for the target model service.")
    cli.add_argument("--api_key", type=str, required=True,
                     help="API Key for the target model service.")
    return cli.parse_args()
|
|
|
|
def save_json_file(data, output_file):
    """Write `data` to `output_file` as pretty-printed, non-ASCII-escaped JSON."""
    serialized = json.dumps(data, indent=4, ensure_ascii=False)
    with open(output_file, "w", encoding="utf-8") as fp:
        fp.write(serialized)
|
|
|
|
def extract_json_from_response(response):
    """Pull the fenced ```json ...``` answer object out of a model reply.

    Returns the parsed dict, or None when the response is empty, no fenced
    JSON block is present, or its payload fails to parse.
    """
    if not response:
        return None
    fenced = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if not fenced:
        return None
    try:
        return json.loads(fenced.group(1))
    except (json.JSONDecodeError, IndexError):
        return None
|
|
|
|
def calculate_metrics(results):
    """Summarize accuracy statistics over the per-item result dicts.

    Items carrying an "error" key (worker failures) are excluded entirely.
    Accuracy is computed over answered samples only, 0.0 when none answered.
    """
    valid = [item for item in results if "error" not in item]
    if not valid:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }
    answered = len([item for item in valid if item.get("model_answer") is not None])
    correct = len([item for item in valid if item.get("is_correct")])
    return {
        "total_samples": len(valid),
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": correct / answered if answered > 0 else 0.0,
    }
|
|
|
|
def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Issue one chat-completion request, retrying up to `max_retry_times` times.

    Returns the first choice's message on success. The last exception is
    re-raised once the retry budget is exhausted; each failed attempt waits
    5 seconds before retrying. Returns None if max_retry_times <= 0.
    """
    request = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        request["tools"] = tools
        request["tool_choice"] = "auto"

    for attempt in range(1, max_retry_times + 1):
        try:
            completion = client.chat.completions.create(**request)
            return completion.choices[0].message
        except Exception as exc:
            print(
                f"API call error (Item {item_id}): {exc}. Retrying ({attempt}/{max_retry_times})..."
            )
            if attempt == max_retry_times:
                raise
            time.sleep(5)
|
|
|
|
def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample up to `num_frames` frames from `frames_dir` and encode them.

    Frames are files named like "frame_<n>.jpg", ordered by their integer index.

    Args:
        frames_dir (str): Directory containing the extracted video frames.
        num_frames (int): Maximum number of frames to sample.

    Returns:
        tuple[list, dict]: (chat-content parts alternating a "Frame ID" text
        entry with a base64 image_url entry, mapping of frame id -> file path).
        Both are empty when the directory is missing or holds no usable frames.
    """
    if not os.path.isdir(frames_dir):
        return [], {}

    frame_pattern = re.compile(r"frame_(\d+)\.jpg")
    # Only keep .jpg files that actually match the expected naming scheme.
    # (The previous sort key called .group(1) on a possibly-None match and
    # crashed on any stray .jpg such as "thumb.jpg".)
    indexed_frames = []
    for filename in os.listdir(frames_dir):
        match = frame_pattern.search(filename)
        if filename.endswith(".jpg") and match:
            indexed_frames.append((int(match.group(1)), filename))
    indexed_frames.sort()

    total_frames = len(indexed_frames)
    if total_frames == 0:
        return [], {}

    if total_frames > num_frames:
        # Evenly spaced indices across the full range.
        indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
        sampled = [indexed_frames[i] for i in indices]
    else:
        sampled = indexed_frames

    frame_path_map = {}
    encoded_frames = []
    for frame_id, filename in sampled:
        path = os.path.join(frames_dir, filename)
        b64_image = encode_image(path)
        # Interleave an explicit ID label before each image so the model can
        # reference frames by id in imagine_frame tool calls.
        encoded_frames.append({"type": "text", "text": f"This is Frame ID: {frame_id}"})
        encoded_frames.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
            }
        )
        frame_path_map[frame_id] = path

    return encoded_frames, frame_path_map
|
|
|
|
def evaluate_single_item_agentic_imagination(
    data_item,
    initial_frames,
    initial_frame_paths,
    generated_images_dir,
    target_model,
    api_key,
    base_url,
    max_retry_times,
):
    """
    Core logic for evaluating a single data item using the Imagine-and-Reason Agent.

    Drives a tool-use conversation with the target model: in up to 5 rounds
    the model either calls the `imagine_frame` tool (generating a new frame
    that is fed back into the conversation) or answers directly. The final
    answer is parsed from a ```json {"answer": "X"}``` block and scored
    against data_item["answer"].

    Args:
        data_item (dict): Sample; must contain "key", "question" and "answer".
        initial_frames (list): Chat-content parts of the initially sampled frames.
        initial_frame_paths (dict): Mapping frame id -> file path for those frames.
        generated_images_dir (str): Directory where imagined frames are written.
        target_model (str): Name of the model under evaluation.
        api_key (str): Credential for the model service.
        base_url (str): Endpoint URL for the model service.
        max_retry_times (int): Per-request retry budget for API calls.

    Returns:
        dict or None: data_item merged with the conversation, raw final
        response, parsed answer and is_correct flag; None if an API call
        produced no message.
    """
    # Pick a client implementation from the endpoint URL.
    # NOTE(review): substring matching ("ark"/"aliyun"/"127.0.0.1") is fragile
    # — confirm these are the only endpoints ever passed in.
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [IMAGINE_FRAME_TOOL_SCHEMA]

    # Grows as frames are imagined; copied so the caller's dict is untouched.
    available_frame_paths = initial_frame_paths.copy()

    initial_prompt_content = [
        {
            "type": "text",
            "text": "Here are the initial sampled video frames provided to you:",
        },
        *initial_frames,
        {
            "type": "text",
            "text": f"Please answer the following question:\n{data_item['question']}",
        },
    ]

    messages = [
        {"role": "system", "content": IMAGINE_AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt_content},
    ]

    response_content = None
    max_tool_calls = 5  # hard cap on agent rounds (tool calls and/or final answer)
    generation_count = 0

    for i in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        # Record the assistant turn (including any tool_calls) in the history.
        messages.append(response_message.model_dump(exclude_none=True))

        if response_message.tool_calls:
            # Only the first tool call of a turn is honored.
            tool_call = response_message.tool_calls[0]
            function_name = tool_call.function.name

            if function_name == "imagine_frame":
                generation_count += 1
                function_args = json.loads(tool_call.function.arguments)
                new_frame_path = imagine_frame(
                    **function_args,
                    all_frame_paths=available_frame_paths,
                    output_dir=generated_images_dir,
                    generation_count=generation_count,
                )

                if new_frame_path:
                    # Next id after the current maximum (1 when the map is empty).
                    new_frame_id = (
                        max(available_frame_paths.keys())
                        if available_frame_paths
                        else 0
                    ) + 1
                    available_frame_paths[new_frame_id] = new_frame_path

                    b64_image = encode_image(new_frame_path)
                    tool_response_content = [
                        {
                            "type": "text",
                            "text": f"Here is the frame you requested to imagine (New Frame ID: {new_frame_id}). Please use it to continue your reasoning.",
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                        },
                    ]

                    # The tool message carries only a JSON status; the image is
                    # delivered in a follow-up user message — presumably because
                    # tool messages cannot carry image parts (TODO confirm).
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {"status": "success", "new_frame_id": new_frame_id}
                            ),
                        }
                    )
                    messages.append({"role": "user", "content": tool_response_content})
                else:
                    # Generation failed: report the error status so the model
                    # can retry or answer without the imagined frame.
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "error",
                                    "message": "Failed to generate image.",
                                }
                            ),
                        }
                    )
        else:
            # No tool call: treat this turn as the final answer and stop.
            response_content = response_message.content
            break

    # Loop ended while still tool-calling: ask once more for a final answer,
    # this time without offering tools.
    if response_content is None and response_message:
        final_prompt = "You have reached the maximum number of tool calls. Please provide a final answer in the specified JSON format based on the information you have gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response_message = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        if final_response_message:
            messages.append(final_response_message.model_dump(exclude_none=True))
            response_content = final_response_message.content

    # Score: compare the parsed answer letter against the gold answer.
    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        gold_answer = data_item["answer"].strip().upper()
        if model_answer_cleaned == gold_answer:
            is_correct = True

    return {
        **data_item,
        "agent_conversation": messages,
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
        "generated_images_path": generated_images_dir,
    }
|
|
|
|
def encode_image(image_path):
    """Read the file at `image_path` and return its Base64 encoding as a str."""
    with open(image_path, "rb") as fp:
        payload = fp.read()
    return base64.b64encode(payload).decode("utf-8")
|
|
|
|
def process_single_data(data_item, args):
    """Worker entry point: evaluate one dataset item and return its result dict.

    Any exception is captured and returned as a {"key", "uid", "error",
    "traceback"} record so a single bad item cannot kill the worker pool.
    """
    item_key = data_item["key"]
    try:
        # Each item gets its own sub-directory for imagined frames.
        generated_images_dir = os.path.join(
            args.output_path, "generated_images", item_key
        )
        os.makedirs(generated_images_dir, exist_ok=True)

        frames_dir = os.path.join(args.frames_path, item_key)
        initial_frames, initial_frame_paths = uniformly_sample_frames_and_encode(
            frames_dir, args.initial_frames_num
        )
        if not initial_frames:
            raise FileNotFoundError(f"Initial frames not found for item '{item_key}'")

        return evaluate_single_item_agentic_imagination(
            data_item,
            initial_frames,
            initial_frame_paths,
            generated_images_dir,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )

    except Exception as exc:
        print(f"\nA critical error occurred while processing item {item_key}: {exc}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(exc),
            "traceback": traceback.format_exc(),
        }
|
|
|
|
def load_test_data(json_file):
    """Load the evaluation dataset from a JSON file.

    Args:
        json_file (str): Path to the dataset JSON file.

    Returns:
        The parsed JSON content (a list of sample dicts).

    Raises:
        SystemExit: With code 1, after printing a diagnostic, when the file is
            missing or contains malformed JSON.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        # raise SystemExit rather than the site-provided exit() builtin, which
        # is not guaranteed to exist outside interactive/site environments.
        raise SystemExit(1)
    except json.JSONDecodeError:
        print(f"Error: JSON file is malformed: {json_file}")
        raise SystemExit(1)
|
|
|
|
def main():
    """Entry point: parse args, fan items out to worker processes, save results."""
    args = parse_arguments()

    # Startup banner.
    for line in (
        "--- Video QA Imagine-and-Reason Agent Framework ---",
        f"Evaluating Model: {args.target_model}",
        f"Output Path: {args.output_path}",
        f"Dataset: {args.data_file}",
        "---------------------------------",
    ):
        print(line)

    os.makedirs(args.output_path, exist_ok=True)

    # Output file names are derived from the model and dataset identifiers.
    safe_model_name = args.target_model.replace("/", "_")
    dataset_stem = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{safe_model_name}_{dataset_stem}_imagine_agent"
    results_output_file = os.path.join(
        args.output_path, f"{output_prefix}_results.json"
    )
    metrics_output_file = os.path.join(
        args.output_path, f"{output_prefix}_metrics.json"
    )
    error_log_file = os.path.join(args.output_path, f"{output_prefix}_errors.log")

    tasks_to_process = load_test_data(args.data_file)

    all_results = []
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.pool_processes
    ) as executor:
        worker = partial(process_single_data, args=args)
        progress = tqdm(
            executor.map(worker, tasks_to_process),
            total=len(tasks_to_process),
            desc="Processing Videos",
        )
        for result in progress:
            if not result:
                continue
            if "error" in result:
                # Failed items are appended to a human-readable error log.
                with open(error_log_file, "a", encoding="utf-8") as f:
                    f.write(
                        f"Error on item {result.get('key', 'N/A')}:\n Error: {result['error']}\n---\n"
                    )
            all_results.append(result)
            # Periodic checkpoint so partial progress survives a crash.
            if len(all_results) % 10 == 0:
                save_json_file(all_results, results_output_file)

    print("\n\nProcessing complete.")
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nEvaluation metrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))
|
|
|
|
# Script entry point: only run the evaluation when executed directly (also
# required so ProcessPoolExecutor workers can re-import this module safely).
if __name__ == "__main__":
    main()
|
|