# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ``verl.protocol.DataProto`` and its pad/unpad helpers."""

import os
import tempfile
from typing import Any, Optional

import numpy as np
import pytest
import torch

from verl.protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto


def _get_data_proto(
    tensors: Optional[dict[str, Any]] = None,
    non_tensors: Optional[dict[str, Any]] = None,
    meta_info: Optional[dict[str, Any]] = None,
) -> DataProto:
    """Build a small ``DataProto`` fixture.

    Args:
        tensors: Mapping of key -> list or ``torch.Tensor``. Lists are converted
            with ``torch.tensor``; tensors are used as-is.
        non_tensors: Mapping of key -> list or ``np.ndarray``. Lists are
            converted to ``dtype=object`` arrays; arrays are used as-is.
        meta_info: Metadata dict; falls back to ``{"info": "test_info"}`` when
            falsy (``None`` or empty).

    When both ``tensors`` and ``non_tensors`` are ``None``, a default six-row
    fixture is used: ``obs = [1, ..., 6]`` and ``labels = ["a", ..., "f"]``.
    If only one of them is given, the other stays ``None``.
    """
    if tensors is None and non_tensors is None:
        tensors = {"obs": [1, 2, 3, 4, 5, 6]}
        non_tensors = {"labels": ["a", "b", "c", "d", "e", "f"]}

    if tensors is not None:
        tensors = {k: torch.tensor(v) if not isinstance(v, torch.Tensor) else v for k, v in tensors.items()}
    if non_tensors is not None:
        non_tensors = {
            k: np.array(v, dtype=object) if not isinstance(v, np.ndarray) else v for k, v in non_tensors.items()
        }
    meta_info = meta_info or {"info": "test_info"}
    return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)


def _assert_equal(data1: DataProto, data2: Optional[DataProto] = None) -> None:
    """Assert that two ``DataProto`` objects have identical contents.

    Compares the tensor batch, the non-tensor batch and ``meta_info``. When
    ``data2`` is omitted, ``data1`` is compared against the default fixture
    returned by ``_get_data_proto()``.
    """
    # Explicit None check: DataProto defines __len__, so an empty DataProto is
    # falsy and `data2 or ...` would silently substitute the default fixture.
    if data2 is None:
        data2 = _get_data_proto()

    if data1.batch is not None:
        assert data2.batch is not None
        assert data1.batch.keys() == data2.batch.keys()
        for key in data1.batch.keys():
            # torch.equal also requires matching shapes; `torch.all(a == b)` would
            # accept broadcast-compatible tensors of different lengths.
            assert torch.equal(data1.batch[key], data2.batch[key]), key
    else:
        assert data2.batch is None

    if data1.non_tensor_batch is not None:
        assert data2.non_tensor_batch is not None
        assert data1.non_tensor_batch.keys() == data2.non_tensor_batch.keys()
        for key in data1.non_tensor_batch.keys():
            # np.array_equal checks shape before comparing elements (no broadcasting).
            assert np.array_equal(data1.non_tensor_batch[key], data2.non_tensor_batch[key]), key
    else:
        assert data2.non_tensor_batch is None

    assert data1.meta_info == data2.meta_info


def test_tensor_dict_constructor():
    """`from_dict` reports the batch length; num_batch_dims=2/3 raise AssertionError here."""
    obs = torch.randn(100, 10)
    act = torch.randn(100, 10, 3)
    data = DataProto.from_dict(tensors={"obs": obs, "act": act})
    assert len(data) == 100

    with pytest.raises(AssertionError):
        DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=2)

    with pytest.raises(AssertionError):
        DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=3)

    # A DataProto with only non-tensor data still reports its length.
    labels = np.array(["a", "b", "c"], dtype=object)
    data = DataProto.from_dict(non_tensors={"labels": labels})
    assert len(data) == 3


def test_getitem():
    """Indexing by int, slice, list and tensor."""
    data = _get_data_proto()
    assert data[0].batch["obs"] == torch.tensor(1)
    assert data[0].non_tensor_batch["labels"] == "a"
    _assert_equal(data[1:3], _get_data_proto({"obs": [2, 3]}, {"labels": ["b", "c"]}))
    _assert_equal(data[[0, 2]], _get_data_proto({"obs": [1, 3]}, {"labels": ["a", "c"]}))
    _assert_equal(data[torch.tensor([1])], _get_data_proto({"obs": [2]}, {"labels": ["b"]}))


def test_select_pop():
    """`select` leaves the source untouched; `pop` removes the selected keys from it."""
    obs = torch.randn(100, 10)
    act = torch.randn(100, 3)
    dataset = _get_data_proto(tensors={"obs": obs, "act": act}, meta_info={"p": 1, "q": 2})
    selected_dataset = dataset.select(batch_keys=["obs"], meta_info_keys=["p"])

    assert selected_dataset.batch.keys() == {"obs"}
    assert selected_dataset.meta_info.keys() == {"p"}
    assert dataset.batch.keys() == {"obs", "act"}
    assert dataset.meta_info.keys() == {"p", "q"}

    popped_dataset = dataset.pop(batch_keys=["obs"], meta_info_keys=["p"])
    assert popped_dataset.batch.keys() == {"obs"}
    assert popped_dataset.meta_info.keys() == {"p"}
    assert dataset.batch.keys() == {"act"}
    assert dataset.meta_info.keys() == {"q"}


def test_chunk_concat_split():
    """`chunk` requires an even division; `concat` inverts `chunk`; `split` uses a fixed size."""
    data = _get_data_proto()
    with pytest.raises(AssertionError):
        data.chunk(5)  # 6 rows cannot be chunked into 5 pieces

    chunked_data = data.chunk(2)
    assert len(chunked_data) == 2
    expected_data = _get_data_proto({"obs": [1, 2, 3]}, {"labels": ["a", "b", "c"]})
    _assert_equal(chunked_data[0], expected_data)

    concat_data = DataProto.concat(chunked_data)
    _assert_equal(concat_data, data)

    splitted_data = data.split(2)
    assert len(splitted_data) == 3
    expected_data = _get_data_proto({"obs": [1, 2]}, {"labels": ["a", "b"]})
    _assert_equal(splitted_data[0], expected_data)


def test_reorder():
    """`reorder` permutes rows of both tensor and non-tensor data in place."""
    data = _get_data_proto()
    data.reorder(torch.tensor([3, 4, 2, 0, 1, 5]))
    expected_data = _get_data_proto({"obs": [4, 5, 3, 1, 2, 6]}, {"labels": ["d", "e", "c", "a", "b", "f"]})
    _assert_equal(data, expected_data)


@pytest.mark.parametrize("interleave", [True, False])
def test_repeat(interleave: bool):
    """`repeat` either interleaves copies of each row or tiles the whole batch."""
    data = _get_data_proto({"obs": [1, 2]}, {"labels": ["a", "b"]})
    repeated_data = data.repeat(repeat_times=2, interleave=interleave)
    expected_tensors = {"obs": [1, 1, 2, 2] if interleave else [1, 2, 1, 2]}
    expected_non_tensors = {"labels": ["a", "a", "b", "b"] if interleave else ["a", "b", "a", "b"]}
    _assert_equal(repeated_data, _get_data_proto(expected_tensors, expected_non_tensors))


@pytest.mark.parametrize("size_divisor", [2, 3])
def test_dataproto_pad_unpad(size_divisor: int):
    """Padding to a divisor and unpadding round-trips a 3-row DataProto."""
    data = _get_data_proto({"obs": [1, 2, 3]}, {"labels": ["a", "b", "c"]})
    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=size_divisor)
    unpadded_data = unpad_dataproto(padded_data, pad_size=pad_size)

    assert len(padded_data) % size_divisor == 0
    if size_divisor == 2:
        # 3 rows -> 4 rows; the padding row repeats the first row.
        assert pad_size == 1
        expected_tensors = {"obs": [1, 2, 3, 1]}
        expected_non_tensors = {"labels": ["a", "b", "c", "a"]}
        expected_data = _get_data_proto(expected_tensors, expected_non_tensors)
    else:
        # 3 rows are already divisible by 3, so nothing is padded.
        assert pad_size == 0
        expected_data = data
    _assert_equal(padded_data, expected_data)
    _assert_equal(unpadded_data, data)


def test_data_proto_save_load():
    """`save_to_disk` followed by `load_from_disk` preserves all contents."""
    data = _get_data_proto()
    # A private temporary directory keeps the working directory clean, guarantees
    # cleanup even if loading fails, and avoids collisions between parallel runs.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test_data.pt")
        data.save_to_disk(path)
        loaded_data = DataProto.load_from_disk(path)
    _assert_equal(data, loaded_data)


def test_union_tensor_dict():
    """`union` merges disjoint keys and raises ValueError on conflicting shared keys."""
    obs = torch.randn(100, 10)
    rew = torch.randn(100)
    data1 = _get_data_proto({"obs": obs, "act": torch.randn(100, 3)})
    data2 = _get_data_proto({"obs": obs, "rew": rew})
    merged = data1.union(data2)
    assert merged.batch.keys() == {"obs", "act", "rew"}
    assert torch.equal(merged.batch["rew"], rew)

    # Shared key "obs" holds different values, so the union must be rejected.
    data1 = _get_data_proto({"obs": obs, "act": torch.randn(100, 3)})
    data2 = _get_data_proto({"obs": obs + 1, "rew": torch.randn(100)})
    with pytest.raises(ValueError):
        data1.union(data2)