import json
import os
import re

import torch
from math_verify import parse, verify
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
# --- Run configuration -------------------------------------------------------
# Checkpoint location handed to `from_pretrained` for both model and processor.
MODEL_PATH = "<MODEL_PATH>"
# Number of prompts sent through `generate` per batch.
BSZ = 50
# Destination of the JSON report (accuracy plus per-example results).
OUTPUT_PATH = "<OUTPUT_LOG>"
# JSONL file holding one evaluation example per line.
PROMPT_PATH = "./prompts/geoqa_test_prompts.jsonl"
# Load the vision-language model in bfloat16 with FlashAttention-2 and let
# `device_map="auto"` decide where the weights are placed.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# The processor comes from the same checkpoint; below it is used for chat
# templating, input tensor preparation and decoding of generated ids.
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Read the evaluation set: one JSON object per line. Later code reads the
# 'image_path', 'question' and 'ground_truth' keys of every record.
# Blank lines (e.g. a trailing newline at end of file) are skipped instead of
# crashing `json.loads`, and the encoding is fixed so the result does not
# depend on the platform's locale.
with open(PROMPT_PATH, "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]
# Instruction wrapped around every question; asks for the reasoning inside
# <think> tags and the numeric final answer inside <answer> tags.
QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."

# One single-turn chat conversation per example: the image (referenced by a
# file:// URI built from 'image_path') followed by the templated question.
# `messages[k]` corresponds to `data[k]`.
messages = [
    [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"file://{example['image_path']}",
                },
                {
                    "type": "text",
                    "text": QUESTION_TEMPLATE.format(Question=example['question']),
                },
            ],
        }
    ]
    for example in data
]
# Decoded model completions, in the same order as `messages`.
all_outputs = []

# Batched generation with a decoder-only model needs LEFT padding: with right
# padding the pad tokens end up between a short prompt and its continuation,
# which degrades the generated text for every non-longest prompt in a batch.
processor.tokenizer.padding_side = "left"

# Loop-invariant, only used for the progress message.
num_batches = (len(messages) + BSZ - 1) // BSZ

for batch_idx, start in enumerate(tqdm(range(0, len(messages), BSZ)), start=1):
    batch_messages = messages[start:start + BSZ]

    # Render each conversation to its prompt string, ending with the
    # generation prompt so the model answers as the assistant.
    text = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in batch_messages
    ]

    # Collect the image / video inputs referenced by the conversations.
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device rather than assuming a literal "cuda",
    # since placement is delegated to device_map="auto".
    inputs = inputs.to(model.device)

    # Greedy decoding (do_sample=False), at most 1024 new tokens per example.
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False)

    # Each output row starts with the (padded) prompt ids; keep only what
    # comes after them.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    all_outputs.extend(batch_output_text)
    print(f"Processed batch {batch_idx}/{num_batches}")
# Per-example result records and the running count of correct answers.
final_output = []
correct_number = 0

for input_example, model_output in zip(data, all_outputs):
    ground_truth = input_example['ground_truth']

    # The whole completion is handed to Math-Verify's `parse`. The result is
    # indexed with [0] below, so it is treated as a sequence of candidates;
    # an empty (or None) result means nothing could be extracted.
    model_answer = parse(model_output)

    # Math-Verify documents the call as verify(gold, answer) and states that
    # the order matters, so the reference goes first and the prediction second.
    is_correct = bool(model_answer) and float(verify(parse(ground_truth), model_answer)) > 0
    if is_correct:
        correct_number += 1

    if model_answer:
        extracted_answer = str(model_answer[0])
    else:
        # Explicit check instead of catching the IndexError from `[0]`.
        print("no answer parsed", model_answer)
        extracted_answer = None

    final_output.append({
        'question': input_example,
        'ground_truth': ground_truth,
        'model_output': model_output,
        'extracted_answer': extracted_answer,
        'is_correct': is_correct,
    })
# Accuracy in percent over the whole dataset; an empty dataset reports 0.0
# instead of raising ZeroDivisionError.
accuracy = correct_number / len(data) * 100 if data else 0.0
print(f"\nAccuracy: {accuracy:.2f}%")

output_path = OUTPUT_PATH

# Create the target directory up front so the (expensive) results are not lost
# to a missing folder. A bare file name has an empty dirname and needs nothing.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# ensure_ascii=False writes non-ASCII characters verbatim, so the file encoding
# is pinned to UTF-8 rather than left to the platform default.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump({
        'accuracy': accuracy,
        'results': final_output
    }, f, indent=2, ensure_ascii=False)

print(f"Results saved to {output_path}")