| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import pytest |
| import torch |
| from PIL.Image import Image |
|
|
| from verl.utils.dataset import RLHFDataset |
| from verl.utils.tokenizer import get_processor, get_tokenizer |
|
|
|
|
@pytest.mark.parametrize("use_fast", [True, False])
def test_image_dataset(use_fast: bool):
    """Check the first sample RLHFDataset yields for an image prompt cut to 16 tokens.

    The dataset is built with ``max_prompt_length=16`` and ``truncation="right"``,
    and the first item is compared field by field against fixed expectations.

    Args:
        use_fast: Passed as ``use_fast`` to both ``get_tokenizer`` and
            ``get_processor``.
    """
    model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
    max_prompt_length = 16

    tokenizer = get_tokenizer(model_path, use_fast=use_fast)
    processor = get_processor(model_path, use_fast=use_fast)
    dataset = RLHFDataset(
        data_path="hiyouga/geometry3k@test",
        tokenizer=tokenizer,
        processor=processor,
        prompt_key="problem",
        answer_key="answer",
        image_key="images",
        max_prompt_length=max_prompt_length,
        truncation="right",
        filter_overlong_prompts=False,
    )
    # Expected token ids for the first prompt; exactly max_prompt_length entries.
    token_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655]

    # Index the dataset once: each dataset[0] goes through __getitem__ again,
    # so re-indexing for every assertion would rebuild the same sample repeatedly.
    sample = dataset[0]

    assert set(sample.keys()) == {
        "input_ids",
        "attention_mask",
        "position_ids",
        "raw_prompt_ids",
        "ground_truth",
        "multi_modal_data",
    }
    assert torch.all(sample["input_ids"] == torch.tensor(token_ids))
    assert torch.all(sample["attention_mask"] == torch.ones(max_prompt_length))

    # Check the shape before the values: a wrong leading dimension would
    # otherwise fail inside the broadcasted comparison with a RuntimeError
    # instead of a readable assertion failure.
    position_ids = sample["position_ids"]
    assert list(position_ids.size()) == [4, max_prompt_length]
    expected_position_ids = torch.arange(max_prompt_length).unsqueeze(0).expand(4, -1)
    assert torch.all(position_ids == expected_position_ids)

    assert sample["raw_prompt_ids"] == token_ids
    assert sample["ground_truth"] == "48"
    assert isinstance(sample["multi_modal_data"]["images"][0], Image)
|
|
|
|
if __name__ == "__main__":
    # test_image_dataset requires ``use_fast``; calling it with no argument
    # raises TypeError. Run both values that the pytest parametrization covers.
    for use_fast in (True, False):
        test_image_dataset(use_fast)
|
|