| import os |
| import json |
| import base64 |
| import argparse |
| import time |
| import re |
| import traceback |
| import uuid |
| import multiprocessing |
| import concurrent.futures |
| from datetime import datetime |
| from functools import partial |
|
|
| import requests |
| import torch |
| from PIL import Image |
| from tqdm import tqdm |
| from openai import AzureOpenAI, OpenAI |
| from volcenginesdkarkruntime import Ark |
| from transformers import AutoModel, AutoProcessor |
| from torch.nn.functional import cosine_similarity |
|
|
| |
|
|
| |
# Local path to the pre-trained SigLIP checkpoint loaded by the embedding
# server process for image-to-frame similarity retrieval.
SIGLIP_MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"

# Number of most-similar video frames returned per retrieval query.
TOP_K_FRAMES = 8


# System prompt for Step 1: instructs the VLM to *plan* keyframe-generation
# requests (a reference frame ID plus a text prompt, as a JSON array) rather
# than answer the question directly.
STEP_1_PLANNING_PROMPT = """
You are a professional video analyst. Your task is to analyze a question and a few initial video sample frames, then plan what keyframes you need to see to answer the question.

Do not answer the question directly. Your output must be a JSON array, where each object represents a keyframe you wish to generate.
Each object must contain the following two keys:
1. `reference_image_id`: An integer representing the ID of a frame already provided to you that you wish to use as a generation reference. This ID must be one of the IDs provided by the user.
2. `prompt`: A detailed text description to tell the image generation model what kind of scene to draw.

For example, if the question is "Where did the man in the red shirt eventually go?", you might generate the following JSON:
```json
[
{
"reference_image_id": 120,
"prompt": "A man in a red shirt is walking towards an open door, with a background similar to the reference image."
},
{
"reference_image_id": 120,
"prompt": "A man in a red shirt has already walked out the door, and the door is closing, with a background similar to the reference image."
}
]
```
Your output must strictly adhere to this JSON format.
"""


# System prompt for Step 3: asks the VLM for step-by-step reasoning over the
# retrieved keyframes, followed by a fenced JSON block with the final A/B/C/D
# answer (parsed by extract_json_from_response).
STEP_3_FINAL_ANSWER_PROMPT = """
You are an AI video question-answering assistant.
The user will provide you with a series of keyframes retrieved from a video and a question.

First, please provide a step-by-step reasoning process, analyzing these keyframes and deriving your conclusion.
After your reasoning, provide the final answer. The answer must be in a JSON code block, and the JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'.

Your output format must be strictly as follows:
<Your step-by-step reasoning process>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""
|
|
|
|
def parse_arguments():
    """Build the CLI parser and return the parsed arguments namespace."""
    cli = argparse.ArgumentParser(
        description="Image Retrieval-based Video QA Workflow"
    )

    # Model selection.
    cli.add_argument(
        "--target-model", "-tm", type=str, required=True,
        help="VLM model for inference (e.g., gpt-4o)",
    )

    # Input / output locations.
    cli.add_argument(
        "--frames-path", "-fp", type=str, required=True,
        help="Root directory containing video frame folders",
    )
    cli.add_argument(
        "--data-file", "-df", type=str, required=True,
        help="JSON data file containing evaluation questions",
    )
    cli.add_argument(
        "--embeddings-path", "-ep", type=str, required=True,
        help="Directory containing pre-computed embeddings for all video frames",
    )
    cli.add_argument(
        "--output-path", "-op", type=str, default="./results_image_retrieval",
        help="Directory to store all outputs and generated images",
    )

    # Sampling, retry, and parallelism knobs.
    cli.add_argument(
        "--initial-frames-num", "-ifn", type=int, default=8,
        help="Number of initial uniformly sampled frames for Step 1",
    )
    cli.add_argument(
        "--max-retry-times", "-mr", type=int, default=10,
        help="Maximum number of retries for API calls",
    )
    cli.add_argument(
        "--pool-processes", "-pp", type=int, default=10,
        help="Number of parallel processes",
    )

    # VLM endpoint credentials.
    cli.add_argument(
        "--base_url", type=str, required=True,
        help="API Endpoint URL for the VLM model",
    )
    cli.add_argument(
        "--api_key", type=str, required=True,
        help="API Key for the VLM model",
    )
    return cli.parse_args()
|
|
|
|
def save_json_file(data, output_file):
    """Serialize *data* to *output_file* as pretty-printed UTF-8 JSON.

    Parent directories are created on demand.

    Fix: the original unconditionally called
    ``os.makedirs(os.path.dirname(output_file), exist_ok=True)``, which
    raises FileNotFoundError when *output_file* has no directory component
    (``dirname`` returns ``""``).  The directory is now only created when
    one is actually present in the path.
    """
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
|
|
|
|
def extract_json_from_response(response, is_list=False):
    """Extract a JSON object or list from the model's response text.

    A ```json fenced block is tried first.  Fix/generalization: if no fence
    is present, the (stripped) response is parsed as bare JSON, since models
    sometimes omit the code fence — the original returned None in that case.

    Args:
        response: Raw model output text (may be None/empty).
        is_list: Informational hint only; callers validate the parsed type
            themselves.  Kept for backward compatibility (it was already
            unused in the original).

    Returns:
        The parsed JSON value, or None when nothing parses.
    """
    if not response:
        return None

    pattern = r"```json\s*([\{\[].*?[\]\}])\s*```"
    match = re.search(pattern, response, re.DOTALL)
    if match:
        json_str = match.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            print(f"JSON parsing failed: {json_str}")
            return None

    # Fallback: the model may have emitted bare JSON without a code fence.
    candidate = response.strip()
    if candidate.startswith(("{", "[")):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            return None
    return None
|
|
|
|
def calculate_metrics(results):
    """Compute aggregate accuracy statistics from per-video results.

    Entries carrying an "error" key are excluded.  Accuracy is the fraction
    of *answered* samples that were correct (0.0 when nothing was answered).
    """
    valid = [entry for entry in results if "error" not in entry]
    if not valid:
        return {"accuracy": 0.0}

    answered_count = sum(entry.get("model_answer") is not None for entry in valid)
    correct_count = sum(bool(entry.get("is_correct")) for entry in valid)

    return {
        "total_samples": len(valid),
        "answered_samples": answered_count,
        "correct_answers": correct_count,
        "accuracy": correct_count / answered_count if answered_count else 0.0,
    }
|
|
|
|
def call_vlm_api(client, messages, model, item_id, max_retry_times, json_schema=None):
    """Send a chat-completion request to the VLM, retrying on failure.

    When *json_schema* is provided, a structured-output response_format is
    attached to the request.  Returns the content of the first choice; the
    last exception is re-raised once all retries are exhausted.  Sleeps 5s
    between attempts.
    """
    request = {"model": model, "messages": messages, "max_tokens": 4096}
    if json_schema:
        request["response_format"] = {"type": "json_object", "schema": json_schema}

    last_attempt = max_retry_times - 1
    for attempt in range(max_retry_times):
        try:
            completion = client.chat.completions.create(**request)
        except Exception as e:
            print(f"API Error (item {item_id}): {e}. Retrying ({attempt + 1}/{max_retry_times})...")
            if attempt == last_attempt:
                raise e
            time.sleep(5)
        else:
            return completion.choices[0].message.content
|
|
|
|
def generate_image(reference_image_id, prompt, all_frame_paths, output_dir, generation_idx):
    """Call the Ark image-generation API to synthesize a hypothetical keyframe.

    Args:
        reference_image_id: ID of an initially sampled frame used as the
            visual reference for generation.
        prompt: Text description of the scene to generate.
        all_frame_paths: Mapping of frame ID -> local file path for the
            initially sampled frames.
        output_dir: Directory where the generated image is written.
        generation_idx: 1-based index used in the output filename.

    Returns:
        Path of the saved image on success, or None when generation or the
        download fails (the caller falls back to the reference frame).

    Raises:
        ValueError: if the ARK_API_KEY environment variable is not set.
        FileNotFoundError: if the reference frame does not exist on disk.
    """
    print(f"\n[Image Generation] Using Prompt: '{prompt}'")
    ark_api_key = os.environ.get("ARK_API_KEY")
    if not ark_api_key:
        raise ValueError("Environment variable ARK_API_KEY is not set.")

    client = Ark(base_url="https://ark.cn-beijing.volces.com/api/v3", api_key=ark_api_key)

    ref_image_path = all_frame_paths.get(reference_image_id)
    if not ref_image_path or not os.path.exists(ref_image_path):
        raise FileNotFoundError(f"Reference image ID {reference_image_id} not found.")

    try:
        ref_image_b64 = encode_image(ref_image_path)
        ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"

        response = client.images.generate(
            model="doubao-seedream-4-0-250828",
            prompt=prompt,
            image=ref_image_data_uri,
            size="1024x1024",
            response_format="url",
            watermark=False,
        )
        image_url = response.data[0].url

        # Fix: fail fast on HTTP errors instead of silently writing an error
        # page's bytes to disk as if it were the generated image.  The raise
        # is caught below and converted to the usual None fallback.
        download = requests.get(image_url, timeout=60)
        download.raise_for_status()
        image_content = download.content

        new_frame_filename = f"generated_frame_{generation_idx}_ref_{reference_image_id}.jpg"
        new_frame_path = os.path.join(output_dir, new_frame_filename)

        with open(new_frame_path, "wb") as f:
            f.write(image_content)

        print(f"[Image Generation Success] Image saved to: {new_frame_path}")
        return new_frame_path
    except Exception as e:
        print(f"Image generation or download failed: {e}")
        traceback.print_exc()
        return None
|
|
def retrieve_frames_by_image_embedding(
    image_path, video_embeddings_data, request_queue, results_dict, k
):
    """Return the paths of the top-k video frames most similar to *image_path*.

    The query embedding is computed by the dedicated embedding-server
    process: a (request_id, image_path) ticket is pushed onto *request_queue*
    and the resulting feature tensor is polled out of the shared
    *results_dict*.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    filenames = video_embeddings_data["filenames"]
    corpus = video_embeddings_data["embeddings"].to(device)

    # Ask the embedding server for this image's feature vector.
    ticket = str(uuid.uuid4())
    request_queue.put((ticket, image_path))

    # Busy-wait until the server publishes the result under our ticket.
    while ticket not in results_dict:
        time.sleep(0.05)
    query = results_dict.pop(ticket).to(device)

    with torch.no_grad():
        scores = cosine_similarity(query, corpus)
        limit = min(k, len(filenames))
        best = torch.topk(scores, k=limit, dim=-1).indices.cpu()

    # NOTE(review): filenames appear to already be absolute paths at call
    # time, in which case os.path.join simply returns the second argument.
    base_dir = os.path.dirname(filenames[0])
    return [os.path.join(base_dir, video_embeddings_data["filenames"][idx]) for idx in best]
|
|
def embedding_server_process(model_id, device, request_queue, results_dict):
    """
    An independent server process that loads the SigLIP model and handles image embedding requests from worker processes.

    Pulls (request_id, image_path) tuples off *request_queue*, computes the
    image feature with the SigLIP model, and publishes the CPU tensor into
    the shared *results_dict* keyed by request_id.  Runs until a "STOP"
    sentinel (sent by main()) arrives in the image_path slot.
    """
    print(f"Embedding server started (PID: {os.getpid()})...")
    # The model lives only in this process so worker processes stay light.
    model = AutoModel.from_pretrained(model_id).to(device).eval()
    processor = AutoProcessor.from_pretrained(model_id)
    print("SigLIP model loaded in the embedding server.")

    while True:
        try:
            request_id, image_path = request_queue.get()
            # Sentinel from main(): (None, "STOP") shuts the server down.
            if image_path == "STOP":
                print("Embedding server received stop signal, shutting down.")
                break

            with torch.no_grad():
                image = Image.open(image_path).convert("RGB")
                inputs = processor(images=[image], return_tensors="pt").to(device)
                image_features = model.get_image_features(**inputs)
                # Move to CPU before crossing the process boundary.
                results_dict[request_id] = image_features.cpu()

        except Exception as e:
            # Keep serving after a failed request.  NOTE(review): the worker
            # waiting on this request_id will poll forever, since no error
            # marker is published — consider publishing one.
            print(f"Error in embedding server: {e}")
            traceback.print_exc()
|
|
|
|
def encode_image(image_path):
    """Read the file at *image_path* and return its contents Base64-encoded."""
    with open(image_path, "rb") as image_file:
        raw = image_file.read()
    return base64.b64encode(raw).decode("utf-8")
|
|
|
|
def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample frames from *frames_dir* and Base64-encode them.

    Returns a tuple of:
      - a flat list of chat-content parts, alternating a "Frame ID" text
        part and the corresponding base64 image part, ready to splice into
        a VLM message;
      - a dict mapping frame ID (int) -> frame path.

    Returns ([], {}) when the directory or matching frames are missing.

    Fix: the original crashed with AttributeError on any ``.jpg`` file whose
    name did not match ``frame_<N>.jpg`` (``re.search`` returned None inside
    the sort key); such files are now skipped.  The pattern is also compiled
    once and each name is matched once, instead of re-searching per use.
    """
    if not os.path.isdir(frames_dir):
        return [], {}

    frame_re = re.compile(r"frame_(\d+)\.jpg$")
    # (frame_id, filename) pairs for every file matching the naming scheme,
    # sorted numerically by frame ID like the original.
    id_and_file = []
    for fname in os.listdir(frames_dir):
        m = frame_re.search(fname)
        if m:
            id_and_file.append((int(m.group(1)), fname))
    id_and_file.sort()
    if not id_and_file:
        return [], {}

    # Uniform sampling by index; duplicates are possible when num_frames
    # exceeds the number of available frames (matches original behavior).
    indices = [int(i * len(id_and_file) / num_frames) for i in range(num_frames)]

    frame_path_map, encoded_frames = {}, []
    for idx in indices:
        frame_id, fname = id_and_file[idx]
        path = os.path.join(frames_dir, fname)
        encoded_frames.extend([
            {"type": "text", "text": f"This is Frame ID: {frame_id}"},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}},
        ])
        frame_path_map[frame_id] = path
    return encoded_frames, frame_path_map
|
|
|
|
def run_workflow_for_item(
    data_item, args, request_queue, results_dict
):
    """Execute the complete three-step workflow for a single data item.

    Step 1: uniformly sample frames and ask the VLM to plan keyframe
            generation requests (reference frame ID + text prompt).
    Step 2: generate hypothetical keyframes with the Ark image model, then
            retrieve the most similar real frames via SigLIP embeddings.
    Step 3: feed the union of retrieved keyframes back to the VLM for the
            final multiple-choice answer.

    Returns the input item augmented with per-step artifacts, the parsed
    model answer, and a correctness flag.  Raises on unrecoverable errors;
    process_single_data_wrapper converts those into error records.
    """
    item_key = data_item["key"]
    print(f"\n--- Starting processing for video: {item_key} ---")

    generated_images_dir = os.path.join(args.output_path, "generated_images", item_key)
    os.makedirs(generated_images_dir, exist_ok=True)

    # Pick the API client flavor from the endpoint URL.  NOTE(review):
    # substring matching on the URL is fragile — confirm expected endpoints.
    if "ark" in args.base_url:
        client = Ark(base_url=args.base_url, api_key=args.api_key)
    elif "aliyun" in args.base_url or "127.0.0.1" in args.base_url:
        client = OpenAI(api_key=args.api_key, base_url=args.base_url)
    else:
        client = AzureOpenAI(api_version="2023-05-15", api_key=args.api_key, azure_endpoint=args.base_url)

    # --- Step 1: plan which keyframes to generate. ---
    print(f"[{item_key}] Step 1: Uniformly sampling and generating keyframe creation requests...")
    video_frames_path = os.path.join(args.frames_path, item_key)
    initial_frames_encoded, initial_frame_paths = uniformly_sample_frames_and_encode(
        video_frames_path, args.initial_frames_num
    )
    if not initial_frames_encoded:
        raise FileNotFoundError(f"Initial frames not found for video {item_key}.")

    planning_messages = [
        {"role": "system", "content": STEP_1_PLANNING_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Here are the initial sample frames and the question:"},
            *initial_frames_encoded,
            {"type": "text", "text": f"Question: {data_item['question']}"}
        ]}
    ]

    # JSON schema describing the expected planning output.  NOTE(review):
    # planning_schema is built but never passed to call_vlm_api below —
    # confirm whether structured output was intended here.
    planning_schema = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "reference_image_id": {"type": "integer"},
                "prompt": {"type": "string"}
            },
            "required": ["reference_image_id", "prompt"]
        }
    }

    raw_planning_response = call_vlm_api(client, planning_messages, args.target_model, item_key, args.max_retry_times)
    image_generation_requests = extract_json_from_response(raw_planning_response, is_list=True)

    if not image_generation_requests or not isinstance(image_generation_requests, list):
        raise ValueError(f"Step 1 failed to generate valid JSON-formatted image generation requests. Response: {raw_planning_response}")

    print(f"[{item_key}] Successfully generated {len(image_generation_requests)} keyframe generation requests.")

    # Snap hallucinated reference IDs to the nearest real initial-frame ID.
    valid_ids = list(initial_frame_paths.keys())
    if not valid_ids:
        raise ValueError(f"No valid initial frame IDs found for video {item_key}.")

    for req in image_generation_requests:
        original_id = req.get("reference_image_id")
        if original_id not in valid_ids:
            # NOTE(review): assumes original_id is numeric; a missing or
            # non-integer value would raise TypeError in abs() below.
            closest_id = min(valid_ids, key=lambda valid_id: abs(valid_id - original_id))
            print(f"Warning: Model generated a non-existent reference_image_id: {original_id}. Substituting with the closest valid ID: {closest_id}.")
            req["reference_image_id"] = closest_id

    # --- Step 2: generate images and retrieve similar real frames. ---
    print(f"[{item_key}] Step 2: Generating images and retrieving similar frames...")
    all_retrieved_frame_paths = set()
    generated_image_paths = []
    video_embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")
    if not os.path.exists(video_embedding_file):
        raise FileNotFoundError(f"Embedding file for video {item_key} not found: {video_embedding_file}")
    video_embeddings_data = torch.load(video_embedding_file, map_location="cpu")

    # Rebase stored filenames onto the local frames directory so retrieval
    # returns paths valid on this machine.
    video_frame_dir_for_embeddings = os.path.join(args.frames_path, item_key)
    video_embeddings_data['filenames'] = [os.path.join(video_frame_dir_for_embeddings, os.path.basename(f)) for f in video_embeddings_data['filenames']]

    for i, req in enumerate(image_generation_requests):
        generated_path = generate_image(
            reference_image_id=req["reference_image_id"],
            prompt=req["prompt"],
            all_frame_paths=initial_frame_paths,
            output_dir=generated_images_dir,
            generation_idx=i + 1,
        )

        # Fall back to the reference frame itself when generation fails,
        # so retrieval can still proceed for this request.
        path_for_retrieval = None
        if generated_path:
            generated_image_paths.append(generated_path)
            path_for_retrieval = generated_path
        else:
            print(f"Warning: Generation failed for image {i+1}. Using its reference image (ID: {req['reference_image_id']}) for retrieval instead.")
            path_for_retrieval = initial_frame_paths.get(req["reference_image_id"])

        if not path_for_retrieval:
            print(f"Error: Could not find a path for retrieval for request {i+1}. Skipping.")
            continue

        retrieved_paths = retrieve_frames_by_image_embedding(
            path_for_retrieval, video_embeddings_data, request_queue, results_dict, k=TOP_K_FRAMES
        )
        all_retrieved_frame_paths.update(retrieved_paths)
        print(f"[{item_key}] Retrieval {i+1}/{len(image_generation_requests)} complete, found {len(retrieved_paths)} frames.")

    if not all_retrieved_frame_paths:
        raise ValueError(f"Failed to retrieve any frames for video {item_key}.")

    print(f"[{item_key}] Step 2 complete. Retrieved a total of {len(all_retrieved_frame_paths)} unique keyframes.")

    # --- Step 3: final reasoning over the union of retrieved keyframes. ---
    print(f"[{item_key}] Step 3: Consolidating keyframes for final reasoning...")
    final_frames_encoded = []
    # Frames are ordered lexicographically by path, which matches temporal
    # order only when frame numbers are zero-padded — NOTE(review): confirm.
    for path in sorted(list(all_retrieved_frame_paths)):
        final_frames_encoded.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}})

    final_messages = [
        {"role": "system", "content": STEP_3_FINAL_ANSWER_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Here are all the keyframes retrieved for you. Please answer the question based on them."},
            *final_frames_encoded,
            {"type": "text", "text": f"Question: {data_item['question']}"}
        ]}
    ]

    final_response_text = call_vlm_api(client, final_messages, args.target_model, item_key, args.max_retry_times)

    # Parse the ```json {"answer": ...}``` block and grade against ground truth.
    parsed_answer = extract_json_from_response(final_response_text)
    model_answer = parsed_answer.get("answer", "").strip().upper() if parsed_answer else None
    is_correct = (model_answer == data_item["answer"].strip().upper()) if model_answer else False

    result = {
        **data_item,
        "workflow_steps": {
            "step1_planning_requests": image_generation_requests,
            "step2_generated_images": generated_image_paths,
            "step2_retrieved_frame_paths": sorted(list(all_retrieved_frame_paths)),
            "step3_final_reasoning_and_answer": final_response_text,
        },
        "model_answer": model_answer,
        "is_correct": is_correct,
    }
    return result
|
|
|
|
def process_single_data_wrapper(data_item, args, request_queue, results_dict):
    """Run the full workflow for one item, converting any exception into an
    error record so one failing video cannot kill the whole batch."""
    try:
        return run_workflow_for_item(data_item, args, request_queue, results_dict)
    except Exception as exc:
        print(f"\nA critical error occurred while processing video {data_item['key']}: {exc}")
        traceback.print_exc()
        error_record = {
            "key": data_item["key"],
            "uid": data_item.get("uid"),
            "error": str(exc),
            "traceback": traceback.format_exc(),
        }
        return error_record
|
|
def load_test_data(data_file):
    """Load evaluation items from *data_file* (a JSON array of question dicts).

    Fix: main() called load_test_data(), but no such function was defined
    anywhere in this module, so the script crashed with NameError before
    processing anything.  NOTE(review): assumes the data file is a plain
    JSON array of items each carrying "key"/"question"/"answer" — confirm
    against the dataset format.
    """
    with open(data_file, "r", encoding="utf-8") as f:
        return json.load(f)


def main():
    """Orchestrate the entire evaluation workflow.

    Spawns the SigLIP embedding server, fans the test items out across a
    process pool, checkpoints results every 10 items, and finally writes
    detailed results plus aggregate metrics to the output directory.
    """
    args = parse_arguments()
    print("--- Image Retrieval-based Video QA Workflow Starting ---")
    print(f"Evaluating Model: {args.target_model}, Dataset: {args.data_file}")

    # "spawn" avoids forking CUDA/model state into child processes.
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

    os.makedirs(args.output_path, exist_ok=True)

    # Output filenames encode model, dataset, and sampling configuration.
    model_safe_name = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_safe_name}_{data_filename_base}_image_retrieval_{args.initial_frames_num}frames"

    results_file = os.path.join(args.output_path, f"{output_prefix}_results.json")
    metrics_file = os.path.join(args.output_path, f"{output_prefix}_metrics.json")

    test_data = load_test_data(args.data_file)
    all_results = []

    with multiprocessing.Manager() as manager:
        request_queue = manager.Queue()
        results_dict = manager.dict()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        embedding_server = multiprocessing.Process(
            target=embedding_server_process,
            args=(SIGLIP_MODEL_ID, device, request_queue, results_dict),
        )
        embedding_server.start()

        # Give the server a head start to load the SigLIP weights.
        # NOTE(review): a fixed sleep is fragile; a readiness event would
        # be more robust.
        time.sleep(15)

        with concurrent.futures.ProcessPoolExecutor(max_workers=args.pool_processes) as executor:
            func = partial(
                process_single_data_wrapper,
                args=args,
                request_queue=request_queue,
                results_dict=results_dict
            )

            results_iterator = executor.map(func, test_data)

            for result in tqdm(results_iterator, total=len(test_data), desc="Processing Videos"):
                if result:
                    all_results.append(result)

                # Checkpoint partial results every 10 completed items.
                if len(all_results) % 10 == 0:
                    save_json_file(all_results, results_file)

        # Sentinel matches the "STOP" check in embedding_server_process.
        print("All tasks completed. Shutting down the embedding server...")
        request_queue.put((None, "STOP"))
        embedding_server.join()

    print("\n--- All Videos Processed ---")
    save_json_file(all_results, results_file)
    print(f"Detailed results saved to: {results_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_file)
    print(f"Final evaluation metrics saved to: {metrics_file}")
    print(json.dumps(final_metrics, indent=4))
|
|
|
|
if __name__ == "__main__":
    # Entry point: only run when executed as a script.  The guard is also
    # required for multiprocessing's "spawn" start method, which re-imports
    # this module in child processes.
    main()
|
|
|
|