| import os |
| import json |
| import base64 |
| import argparse |
| import time |
| import re |
| import traceback |
| from datetime import datetime |
| from functools import partial |
| from openai import AzureOpenAI, OpenAI |
| from volcenginesdkarkruntime import Ark |
| import concurrent.futures |
| from tqdm import tqdm |
| import torch |
| from transformers import AutoModel, AutoProcessor |
| from torch.nn.functional import cosine_similarity |
| |
| import multiprocessing |
| import uuid |
|
|
| |
| |
# Checkpoint handed to the embedding server, which passes it to
# `AutoModel.from_pretrained` / `AutoProcessor.from_pretrained`.
# The default is the original cluster path; set the SIGLIP_MODEL_ID environment
# variable to use a different checkpoint without editing this file.
SIGLIP_MODEL_ID = (
    os.environ.get("SIGLIP_MODEL_ID")
    or "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
)
|
|
| |
# System prompt used as the first message of every agent conversation.
# It describes the two frame-retrieval tools exposed to the model and the
# required final-answer format: free-form reasoning followed by a fenced
# json block containing an "answer" key, which is what
# `extract_json_from_response` looks for.
AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specialized in video question answering.
Your task is to answer a multiple-choice question based on a video by strategically retrieving and analyzing its frames.

You have two tools to retrieve frames. Both return images directly.

1. `get_frames_by_id(frame_ids)`: Retrieves frames using their specific numerical IDs. Use this when the question provides direct temporal clues or when you need to view specific frames identified by another tool.
* **Example Use Case:** For a question like "What happens at the 1 minute 30 second mark?", you can calculate the approximate frame ID and use this tool to see the visual.
* **Example Use Case:** For "Describe the action in frame 550.", you would call this tool with `frame_ids=[550]`.

2. `get_frames_by_similarity(query)`: Searches the entire video for frames that visually match a text description and returns the top 5 most relevant frames directly. Use this for content-based questions where the timing is unknown.
* **Example Use Case:** For a question like "What color is the main character's car?", you would use this tool with a query like "the main character's car".
* **Example Use Case:** For "Find the scene where a band is playing on stage", you would use the query "a band playing on stage".

Your strategy must be efficient:
1. **Analyze the Query:** First, determine if the question is temporal/logical (better for `get_frames_by_id`) or content-based (requires `get_frames_by_similarity`).
2. **Retrieve & Analyze:** Call the most appropriate tool. Analyze the returned frames to form a hypothesis.
3. **Iterate:** If you need more information, refine your search query for the similarity tool or calculate new frame IDs for the ID tool and call again.
4. **Final Answer:** Once you have gathered enough visual evidence, provide your step-by-step reasoning and then the final answer in the specified JSON format. Do not guess.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "X"}
```
Do not include any other text after the JSON block.
"""
|
|
| |
# Tool definition advertised to the model for `get_frames_by_id`.
# Built bottom-up: the single `frame_ids` argument, then the function
# description, then the outer "function"-typed tool wrapper.
_FRAME_IDS_PROPERTY = {
    "type": "array",
    "items": {"type": "integer"},
    "description": "A list of up to 10 frame numbers to retrieve.",
}

_GET_FRAMES_BY_ID_FUNCTION = {
    "name": "get_frames_by_id",
    "description": "Retrieves specific video frames by their numerical IDs to get visual information.",
    "parameters": {
        "type": "object",
        "properties": {"frame_ids": _FRAME_IDS_PROPERTY},
        "required": ["frame_ids"],
    },
}

GET_FRAMES_BY_ID_TOOL_SCHEMA = {
    "type": "function",
    "function": _GET_FRAMES_BY_ID_FUNCTION,
}
|
|
# Tool definition advertised to the model for `get_frames_by_similarity`.
# Built bottom-up: the single `query` argument, then the function
# description, then the outer "function"-typed tool wrapper.
_SIMILARITY_QUERY_PROPERTY = {
    "type": "string",
    "description": "A concise text description of the visual content to search for (e.g., 'a person playing piano').",
}

_GET_FRAMES_BY_SIMILARITY_FUNCTION = {
    "name": "get_frames_by_similarity",
    "description": "Searches for and retrieves the top 5 most visually relevant frames for a given text query. Use this to locate visual content when frame numbers are unknown.",
    "parameters": {
        "type": "object",
        "properties": {"query": _SIMILARITY_QUERY_PROPERTY},
        "required": ["query"],
    },
}

GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA = {
    "type": "function",
    "function": _GET_FRAMES_BY_SIMILARITY_FUNCTION,
}
|
|
|
|
def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``target_model``, ``frames_path``,
        ``data_file``, ``embeddings_path``, ``max_retry_times``,
        ``pool_processes``, ``base_url`` and ``api_key``.
    """
    cli = argparse.ArgumentParser(
        description="Agentic Video QA with Hybrid Frame Retrieval"
    )

    # Required string options that have both a long and a short flag.
    required_paths = (
        ("--target-model", "-tm", "Model to evaluate."),
        ("--frames-path", "-fp", "Base directory for video frames."),
        ("--data-file", "-df", "Path to the evaluation dataset."),
        ("--embeddings-path", "-ep", "Directory with pre-computed frame embeddings."),
    )
    for long_flag, short_flag, help_text in required_paths:
        cli.add_argument(
            long_flag, short_flag, type=str, required=True, help=help_text
        )

    # Optional integer tuning knobs with defaults.
    int_options = (
        ("--max-retry-times", "-mr", 10, "Max retries for API calls."),
        ("--pool-processes", "-pp", 20, "Number of parallel processes."),
    )
    for long_flag, short_flag, default_value, help_text in int_options:
        cli.add_argument(
            long_flag, short_flag, type=int, default=default_value, help=help_text
        )

    # Required API connection settings (long flag only).
    api_options = (
        ("--base_url", "API endpoint URL."),
        ("--api_key", "API key."),
    )
    for flag, help_text in api_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    return cli.parse_args()
|
|
|
|
def save_json_file(data, output_file):
    """Serialize ``data`` as indented JSON to ``output_file``, atomically.

    The JSON is first written to a sibling ``<output_file>.tmp`` and then moved
    over the destination with ``os.replace``. A crash or a serialization error
    in the middle of a dump therefore never leaves a truncated file behind,
    which matters because the results file is re-read to resume a run.

    Args:
        data: Any JSON-serializable object.
        output_file: Destination path.

    Raises:
        TypeError / ValueError: If ``data`` cannot be serialized; the previous
            contents of ``output_file`` are left untouched.
        OSError: If the file cannot be written or replaced.
    """
    tmp_file = f"{output_file}.tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_file, output_file)
    finally:
        # After a successful replace the temp file no longer exists; on any
        # failure this removes the partial dump.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
|
|
|
|
def extract_json_from_response(response):
    """Return the first fenced ```json object found in ``response``.

    Args:
        response: Raw model output; may be ``None`` or empty.

    Returns:
        The decoded JSON object, or ``None`` when the response is empty, has no
        fenced json block, or the block is not valid JSON.
    """
    if not response:
        return None

    fenced = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if fenced is None:
        return None

    try:
        payload = json.loads(fenced.group(1))
    except (json.JSONDecodeError, IndexError):
        payload = None
    return payload
|
|
|
|
def calculate_metrics(results):
    """Summarize evaluation results in a single pass.

    Entries containing an ``"error"`` key are ignored. Accuracy is computed
    over the entries that produced an answer (``model_answer`` is not
    ``None``), not over all valid entries.

    Args:
        results: Iterable of per-item result dicts.

    Returns:
        Dict with ``total_samples``, ``answered_samples``, ``correct_answers``
        and ``accuracy`` (``0.0`` when nothing was answered).
    """
    total = 0
    answered = 0
    correct = 0
    for record in results:
        if "error" in record:
            continue
        total += 1
        if record.get("model_answer") is not None:
            answered += 1
        if record.get("is_correct"):
            correct += 1

    if answered > 0:
        accuracy = correct / answered
    else:
        accuracy = 0.0

    return {
        "total_samples": total,
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": accuracy,
    }
|
|
|
|
def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Send one chat-completion request, retrying on any exception.

    Args:
        client: Object exposing ``chat.completions.create(**kwargs)``.
        messages: Conversation so far.
        model: Model name passed to the API.
        item_id: Identifier used only in the retry log line.
        max_retry_times: Number of attempts; 5 seconds are slept between them.
        tools: Optional tool schemas; when given, ``tool_choice`` is ``"auto"``.

    Returns:
        The ``message`` of the first choice, or ``None`` if
        ``max_retry_times`` allows no attempt at all.

    Raises:
        Exception: Whatever the last attempt raised.
    """
    request = dict(model=model, messages=messages, max_tokens=4096)
    if tools:
        request.update(tools=tools, tool_choice="auto")

    final_attempt = max_retry_times - 1
    for attempt in range(max_retry_times):
        try:
            return client.chat.completions.create(**request).choices[0].message
        except Exception as exc:
            print(
                f"API Error for item {item_id}: {str(exc)}. Retrying ({attempt + 1}/{max_retry_times})..."
            )
            if attempt == final_attempt:
                raise exc
            time.sleep(5)
    return None
|
|
|
|
def get_frames_by_id(frame_ids: list, all_frame_paths: list):
    """Tool implementation: load the requested frames as base64 image parts.

    Frame IDs come straight from the model's tool-call arguments, so they are
    normalized defensively: a bare int or string is treated as a one-element
    list, numeric strings such as ``"12"`` are accepted, and values that cannot
    be converted to an integer are skipped instead of raising.

    Args:
        frame_ids: Frame numbers to retrieve (ideally a list of ints).
        all_frame_paths: Paths of all frames of the video; only files named
            like ``frame_<N>.jpg`` are addressable.

    Returns:
        A list of ``{"type": "image_url", ...}`` content parts, one per
        requested frame that exists on disk, in request order. Unknown or
        missing frames are silently omitted.
    """
    # Match once per path, against the basename only.
    frame_pattern = re.compile(r"frame_(\d+)\.jpg")
    frame_map = {}
    for path in all_frame_paths:
        match = frame_pattern.search(os.path.basename(path))
        if match:
            frame_map[int(match.group(1))] = path

    if frame_ids is None:
        frame_ids = []
    elif isinstance(frame_ids, (int, str)):
        frame_ids = [frame_ids]

    retrieved_frames = []
    for fid in frame_ids:
        try:
            frame_number = int(fid)
        except (TypeError, ValueError):
            continue  # not a usable frame number
        path = frame_map.get(frame_number)
        if path and os.path.exists(path):
            b64_image = encode_image(path)
            retrieved_frames.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                }
            )
    return retrieved_frames
|
|
|
|
| |
def get_frames_by_similarity(
    query: str,
    all_frame_paths: list,
    precomputed_data: dict,
    request_queue: multiprocessing.Queue,
    results_dict: dict,
    k: int = 5,
    timeout: float = 600.0,
):
    """Tool implementation: return the top-k frames most similar to ``query``.

    The text embedding is requested from the embedding server process through
    ``request_queue``; the reply is polled from ``results_dict`` under a unique
    request ID. Cosine similarity against the pre-computed frame embeddings
    selects the frames, which are then loaded through ``get_frames_by_id``.

    Args:
        query: Text description to search for.
        all_frame_paths: Paths of all frames of the video.
        precomputed_data: Dict with ``"filenames"`` (frame file names) and
            ``"embeddings"`` (tensor of frame embeddings, row-aligned with the
            file names -- presumably shape (num_frames, dim); confirm against
            the embedding pre-computation script).
        request_queue: Queue consumed by the embedding server.
        results_dict: Shared mapping the server writes replies into.
        k: Maximum number of frames to return.
        timeout: Seconds to wait for the server's reply before giving up.

    Returns:
        List of image content parts as produced by ``get_frames_by_id``.

    Raises:
        TimeoutError: If no reply arrives within ``timeout`` seconds (for
            example because the embedding server died).
        RuntimeError: If the reply is not a tensor, which is treated as a
            server-side failure report.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frame_filenames = precomputed_data["filenames"]
    frame_embeddings = precomputed_data["embeddings"].to(device)

    request_id = str(uuid.uuid4())
    request_queue.put((request_id, query))

    # Poll with a deadline so a dead server cannot hang this worker forever.
    # A single pop-with-default costs one round-trip per poll.
    deadline = time.monotonic() + timeout
    while True:
        reply = results_dict.pop(request_id, None)
        if reply is not None:
            break
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"No embedding received for query {query!r} within {timeout} seconds"
            )
        time.sleep(0.05)

    if not torch.is_tensor(reply):
        raise RuntimeError(f"Embedding server failed for query {query!r}: {reply}")
    query_embedding = reply.to(device)

    with torch.no_grad():
        similarities = cosine_similarity(query_embedding, frame_embeddings)

    num_frames_to_select = min(k, len(frame_filenames))
    top_k_indices = (
        torch.topk(similarities, k=num_frames_to_select, dim=-1)
        .indices.cpu()
        .flatten()
        .tolist()
    )

    top_k_filenames = [frame_filenames[i] for i in top_k_indices]
    top_k_frame_ids = [
        int(re.search(r"frame_(\d+)\.jpg", f).group(1)) for f in top_k_filenames
    ]

    return get_frames_by_id(top_k_frame_ids, all_frame_paths)
|
|
|
|
def _run_tool_call(tool_call, available_functions):
    """Execute one model-requested tool call; return ``(frames, error)``.

    ``error`` is ``None`` on success, otherwise a short message for the model.
    Only problems caused by the model's request (unknown tool, malformed or
    mismatched arguments) are reported this way; other exceptions propagate.
    """
    function_name = tool_call.function.name
    handler = available_functions.get(function_name)
    if handler is None:
        return [], f"Unknown tool `{function_name}`."
    try:
        function_args = json.loads(tool_call.function.arguments)
    except json.JSONDecodeError as exc:
        return [], f"Tool arguments are not valid JSON: {exc}"
    if not isinstance(function_args, dict):
        return [], "Tool arguments must be a JSON object."
    try:
        return handler(**function_args), None
    except TypeError as exc:
        # Typically an unexpected or missing keyword argument.
        return [], f"Invalid tool arguments: {exc}"


def evaluate_single_item_agentic(
    data_item,
    all_frame_paths,
    embeddings_data,
    target_model,
    api_key,
    base_url,
    max_retry_times,
    request_queue,
    results_dict,
):
    """Answer one question with an agentic tool-calling loop.

    The model may call ``get_frames_by_id`` / ``get_frames_by_similarity`` for
    up to 10 rounds. After each round every tool call gets its own ``tool``
    reply first, and the retrieved frames follow in one ``user`` message, so
    no other message sits between an assistant tool-call message and its tool
    replies. Requests the model got wrong (unknown tool, bad arguments) are
    reported back as tool errors instead of aborting the item.

    Args:
        data_item: Dataset entry; reads ``key``, ``question``, ``answer`` and
            optionally ``video_info.duration_minutes``.
        all_frame_paths: Sorted frame paths of the video.
        embeddings_data: Pre-computed frame embeddings for the similarity tool.
        target_model: Model name for the chat API.
        api_key: API key.
        base_url: Endpoint URL; also selects the client class.
        max_retry_times: Retry budget per API call.
        request_queue: Queue to the embedding server.
        results_dict: Shared reply mapping of the embedding server.

    Returns:
        ``data_item`` extended with ``agent_conversation``,
        ``model_reasoning_and_answer``, ``model_answer`` and ``is_correct``,
        or ``None`` if the API helper returned no message.
    """
    # The endpoint URL decides which SDK client is used.
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [GET_FRAMES_BY_ID_TOOL_SCHEMA, GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA]

    # Bind per-video context so the model only supplies its own arguments.
    available_functions = {
        "get_frames_by_id": partial(get_frames_by_id, all_frame_paths=all_frame_paths),
        "get_frames_by_similarity": partial(
            get_frames_by_similarity,
            all_frame_paths=all_frame_paths,
            precomputed_data=embeddings_data,
            request_queue=request_queue,
            results_dict=results_dict,
        ),
    }

    total_frames = len(all_frame_paths)
    # Tolerate a missing or null video_info / duration_minutes.
    video_info = data_item.get("video_info") or {}
    duration = (video_info.get("duration_minutes") or 0) * 60
    initial_prompt = (
        f"The video has {total_frames} frames (ID 1 to {total_frames}) and is {duration:.0f} seconds long. "
        f"Please answer this question:\n{data_item['question']}"
    )

    messages = [
        {"role": "system", "content": AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt},
    ]
    response_content = None
    max_tool_calls = 10

    for _ in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        messages.append(response_message)

        if not response_message.tool_calls:
            response_content = response_message.content
            break

        # First answer every tool call, then send the frames in one user
        # message after all tool replies.
        frame_parts = []
        for tool_call in response_message.tool_calls:
            function_name = tool_call.function.name
            frames, error = _run_tool_call(tool_call, available_functions)
            if error is None:
                tool_payload = {
                    "status": "success",
                    "retrieved_frame_count": len(frames),
                }
                frame_parts.append(
                    {
                        "type": "text",
                        "text": f"Here are the {len(frames)} frames from your call to `{function_name}`.",
                    }
                )
                frame_parts.extend(frames)
            else:
                tool_payload = {"status": "error", "message": error}

            messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": json.dumps(tool_payload),
                }
            )

        if frame_parts:
            messages.append({"role": "user", "content": frame_parts})

    if response_content is None:
        # Tool budget exhausted (or empty reply): ask once more without tools.
        final_prompt = "You have reached the maximum number of tool calls. Provide a final answer based on the information gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        response_content = (
            final_response.content
            if final_response
            else "Could not determine an answer after max tool calls."
        )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        if model_answer_cleaned == str(data_item["answer"]).strip().upper():
            is_correct = True

    return {
        **data_item,
        # SDK message objects are converted to plain dicts for JSON output.
        "agent_conversation": [
            msg if isinstance(msg, dict) else msg.model_dump() for msg in messages
        ],
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }
|
|
|
|
def encode_image(image_path):
    """Read ``image_path`` as bytes and return its base64 text (UTF-8 str)."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
|
|
|
| |
def process_single_data(data_item, args, request_queue, results_dict):
    """Worker entry point: evaluate one dataset item end to end.

    Locates the item's frame directory and embedding file, collects the frames
    named ``frame_<N>.jpg`` in numeric order, loads the embeddings and runs the
    agentic evaluation. Other ``.jpg`` files in the directory are ignored,
    because they cannot be addressed by frame ID and would otherwise crash the
    numeric sort.

    Args:
        data_item: Dataset entry; must contain ``key``.
        args: Parsed CLI arguments (``frames_path``, ``embeddings_path``,
            ``target_model``, ``api_key``, ``base_url``, ``max_retry_times``).
        request_queue: Queue to the embedding server.
        results_dict: Shared reply mapping of the embedding server.

    Returns:
        The result of ``evaluate_single_item_agentic``, or, on any exception,
        a dict with ``key``, ``uid``, ``error`` and ``traceback``.
    """
    item_key = data_item["key"]
    try:
        specific_frames_path = os.path.join(args.frames_path, item_key)
        embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")

        if not os.path.isdir(specific_frames_path):
            raise FileNotFoundError(
                f"Frame directory not found: {specific_frames_path}"
            )
        if not os.path.exists(embedding_file):
            raise FileNotFoundError(f"Embedding file not found: {embedding_file}")

        # Match on the file name only, so digits in directory names cannot
        # interfere, and keep the frame number for sorting.
        frame_pattern = re.compile(r"frame_(\d+)\.jpg")
        numbered_frames = []
        for file_name in os.listdir(specific_frames_path):
            if not file_name.endswith(".jpg"):
                continue
            match = frame_pattern.search(file_name)
            if match:
                numbered_frames.append(
                    (int(match.group(1)), os.path.join(specific_frames_path, file_name))
                )
        numbered_frames.sort(key=lambda pair: pair[0])
        all_frame_paths = [path for _, path in numbered_frames]
        if not all_frame_paths:
            raise FileNotFoundError(f"No frames found for key '{item_key}'")

        # NOTE(review): torch.load unpickles the file; only load embedding
        # files from trusted sources (consider weights_only=True if the
        # installed torch version supports it for this payload).
        embeddings_data = torch.load(embedding_file, map_location="cpu")

        return evaluate_single_item_agentic(
            data_item,
            all_frame_paths,
            embeddings_data,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
            request_queue,
            results_dict,
        )

    except Exception as e:
        # Worker boundary: report the failure as data so the pool keeps going.
        print(f"\nCRITICAL ERROR on key {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
|
|
|
|
def load_test_data(json_file):
    """Load the evaluation dataset from ``json_file``.

    Args:
        json_file: Path to a UTF-8 JSON file.

    Returns:
        The decoded JSON content.

    Raises:
        SystemExit: With exit code 1 (after printing a message) when the file
            does not exist or is not valid JSON.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
    except json.JSONDecodeError:
        print(f"Error: Malformed JSON in {json_file}")
    # `exit()` is injected by the `site` module for interactive use and is not
    # guaranteed to exist; raising SystemExit directly is the portable form.
    raise SystemExit(1)
|
|
|
|
| |
| |
def embedding_server_process(model_id, device, request_queue, results_dict):
    """Serve text-embedding requests from a dedicated process.

    Loads the SigLIP model once, then loops: take ``(request_id, text_query)``
    from ``request_queue``, compute the text embedding and store it (on CPU)
    in ``results_dict[request_id]``.

    Shutdown is signalled by the tuple ``(None, "STOP")``. The ``None`` request
    ID is part of the sentinel on purpose: a search query whose text happens to
    be "STOP" must not terminate the server.

    If a request fails, a ``RuntimeError`` instance describing the failure is
    stored under its request ID so the waiting worker gets a reply instead of
    polling forever.

    Args:
        model_id: Checkpoint path or ID for ``from_pretrained``.
        device: Torch device string the model runs on.
        request_queue: Queue of ``(request_id, text_query)`` tuples.
        results_dict: Shared mapping for replies, keyed by request ID.
    """
    print(f"Embedding server started on PID {os.getpid()}...")
    print("Loading SigLIP model in the embedding server process...")
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    print("SigLIP model loaded in server.")

    model.to(device)
    model.eval()

    while True:
        request_id = None
        try:
            request_id, text_query = request_queue.get()
            if request_id is None and text_query == "STOP":
                print("Embedding server received stop signal. Shutting down.")
                break

            with torch.no_grad():
                text_inputs = processor(
                    text=[text_query],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(device)
                query_embedding = model.get_text_features(**text_inputs)

            results_dict[request_id] = query_embedding.cpu()
        except (EOFError, ConnectionError):
            # The shared queue/dict is gone, so no further request can be
            # served; exit instead of spinning on the same error.
            print("Embedding server lost its connection. Shutting down.")
            traceback.print_exc()
            break
        except Exception as e:
            print(f"Error in embedding server: {e}")
            traceback.print_exc()
            if request_id is not None:
                try:
                    results_dict[request_id] = RuntimeError(
                        f"Embedding failed: {e}"
                    )
                except Exception:
                    traceback.print_exc()
|
|
|
|
| |
def main():
    """Run the whole evaluation: resume, process in parallel, save, score.

    Steps:
      1. Parse CLI arguments and derive output file names from the model and
         data file names.
      2. Load the dataset and any previous results file; items whose ``uid``
         already appears there are skipped.
      3. Start the embedding server process and a process pool, evaluate the
         remaining items, log errors and checkpoint every 10 results.
      4. Save all results and the aggregate metrics.

    The embedding server is stopped and the collected results are written even
    when processing is interrupted or raises, so a later run can resume.
    """
    args = parse_arguments()
    print("--- Agentic Video QA with Hybrid Retrieval ---")
    print(
        f"Model: {args.target_model}, Data: {args.data_file}, Embeddings: {args.embeddings_path}"
    )

    try:
        multiprocessing.set_start_method("spawn", force=True)
        print("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        print("Start method already set.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_name_safe}_{data_filename_base}_agent_hybrid"
    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"
    error_log_file = f"{output_prefix}_errors.log"

    with open(error_log_file, "a", encoding="utf-8") as f:
        f.write(
            f"\n=== Log Session Started at {datetime.now()} for {args.target_model} ===\n"
        )

    all_test_data = load_test_data(args.data_file)
    existing_results = []
    completed_ids = set()
    if os.path.exists(results_output_file):
        try:
            with open(results_output_file, "r", encoding="utf-8") as f:
                existing_results = json.load(f)
            if isinstance(existing_results, list):
                # NOTE(review): entries that failed earlier also carry a
                # "uid", so they count as completed and are not retried.
                completed_ids = {
                    item["uid"] for item in existing_results if "uid" in item
                }
                print(f"Found {len(completed_ids)} completed tasks. Resuming...")
            else:
                existing_results = []
        except (json.JSONDecodeError, IOError):
            existing_results = []

    tasks_to_process = [
        item for item in all_test_data if item.get("uid") not in completed_ids
    ]
    if not tasks_to_process:
        print("All tasks are already completed. Calculating final metrics.")
    else:
        print(
            f"Total: {len(all_test_data)}. Completed: {len(completed_ids)}. To process: {len(tasks_to_process)}."
        )

    all_results = list(existing_results)

    try:
        if tasks_to_process:
            with multiprocessing.Manager() as manager:
                request_queue = manager.Queue()
                results_dict = manager.dict()

                embedding_server = multiprocessing.Process(
                    target=embedding_server_process,
                    args=(
                        SIGLIP_MODEL_ID,
                        device,
                        request_queue,
                        results_dict,
                    ),
                )
                embedding_server.start()

                try:
                    with concurrent.futures.ProcessPoolExecutor(
                        max_workers=args.pool_processes
                    ) as executor:
                        func = partial(
                            process_single_data,
                            args=args,
                            request_queue=request_queue,
                            results_dict=results_dict,
                        )
                        results_iterator = executor.map(func, tasks_to_process)
                        for result in tqdm(
                            results_iterator,
                            total=len(tasks_to_process),
                            desc="Processing Videos",
                        ):
                            if result:
                                if "error" in result:
                                    with open(
                                        error_log_file, "a", encoding="utf-8"
                                    ) as f:
                                        f.write(
                                            f"Error on key {result.get('key', 'N/A')}:\n Error: {result['error']}\n Traceback: {result['traceback']}\n---\n"
                                        )
                                all_results.append(result)
                                if len(all_results) % 10 == 0:
                                    save_json_file(all_results, results_output_file)

                    print(
                        "All tasks processed. Sending stop signal to embedding server."
                    )
                finally:
                    # Always stop the server; a non-daemon child left running
                    # would keep the interpreter from exiting.
                    try:
                        request_queue.put((None, "STOP"))
                    except Exception:
                        traceback.print_exc()
                        embedding_server.terminate()
                    embedding_server.join()
    finally:
        # Persist whatever was collected, also on interruption or failure.
        save_json_file(all_results, results_output_file)

    print("\n\nProcessing complete.")
    print(f"Detailed results saved to: {results_output_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))
|
|
|
|
# Script entry point. The guard is required with the "spawn" start method:
# child processes re-import this module and must not re-run main().
if __name__ == "__main__":
    main()
|
|
|
|