| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Replace DataProto with raw TensorDict |
| """ |
|
|
| import copy |
| import random |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
| from verl.utils import tensordict_utils as tu |
|
|
|
|
def test_union_tensor_dict():
    """Union semantics: shared tensors must be the same object, shared non-tensors must match."""
    shared_obs = torch.randn(100, 10)

    sampling_a = {"top_p": 0.8}
    sampling_b = {"top_p": 0.9}

    left = tu.get_tensordict(
        tensor_dict={"obs": shared_obs, "act": torch.randn(100, 3), "data_sources": ["gsm8k"] * 100}
    )
    right = tu.get_tensordict(
        tensor_dict={
            "obs": shared_obs,
            "next_obs": torch.randn(100, 10),
            "rew": torch.randn(100),
            "data_sources": ["gsm8k"] * 100,
        }
    )
    cloned_obs_td = tu.get_tensordict(
        {"obs": shared_obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}
    )

    # Same underlying "obs" tensor on both sides: union succeeds.
    tu.union_tensor_dict(left, right)
    # A cloned "obs" is a different tensor object, so the union must be rejected.
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, cloned_obs_td)

    # Non-tensor metadata present on only one side is fine.
    left = tu.assign_non_tensor_dict(left, sampling_a)
    tu.union_tensor_dict(left, right)

    # Conflicting non-tensor values for "top_p" must raise.
    right = tu.assign_non_tensor_dict(right, sampling_b)
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, right)

    left.pop("top_p")
    right.pop("top_p")

    # Mismatched per-sample non-tensor entries must also raise.
    right["data_sources"][0] = "math"
    with pytest.raises(AssertionError):
        tu.union_tensor_dict(left, right)
|
|
|
|
def test_tensor_dict_constructor():
    """Construction with tensors, per-sample non-tensor lists, and dict-level metadata."""
    obs = torch.ones(100, 10)
    act = torch.zeros(100, 10, 3)
    data_source = ["gsm8k"] * 100
    non_tensor_dict = {"name": "abdce"}

    td = tu.get_tensordict(
        tensor_dict={"obs": obs, "act": act, "data_source": data_source}, non_tensor_dict=non_tensor_dict
    )

    assert td.batch_size == torch.Size([100])

    # Single-sample indexing drops the batch dimension.
    first = td[0]
    assert torch.all(torch.eq(first["obs"], torch.ones(10))).item()
    assert torch.all(torch.eq(first["act"], torch.zeros(10, 3))).item()
    assert first["data_source"] == "gsm8k"

    # Slicing keeps the batch dimension.
    head = td[0:2]
    assert torch.all(torch.eq(head["obs"], torch.ones(2, 10))).item()
    assert torch.all(torch.eq(head["act"], torch.zeros(2, 10, 3))).item()
    assert head["data_source"] == ["gsm8k"] * 2

    # Dict-level metadata is batch-independent and returned as-is.
    assert td["name"] == "abdce"
|
|
|
|
def test_index_select_tensor_dict():
    """index_select must gather both jagged (nested) and regular tensors."""
    vocab_size = 128
    seq_lens = (11, 13, 12, 15)
    rows = [torch.randint(low=0, high=vocab_size, size=(n,)) for n in seq_lens]
    input_ids = torch.nested.as_nested_tensor(rows, layout=torch.jagged)

    padded_tensor = torch.randn(4, 10)
    non_tensor_dict = {"global_batch_size": "4"}

    td = tu.get_tensordict(
        tensor_dict={"input_ids": input_ids, "padded_tensor": padded_tensor},
        non_tensor_dict=non_tensor_dict,
    )
    assert td.batch_size == torch.Size([4])

    # Select two of the four samples.
    indices = torch.tensor([1, 3])
    picked = tu.index_select_tensor_dict(td, indices)
    assert picked.batch_size == torch.Size([2])

    # Build the expected result by hand and compare structurally.
    expected = tu.get_tensordict(
        tensor_dict={
            "input_ids": torch.nested.as_nested_tensor([input_ids[i] for i in indices], layout=torch.jagged),
            "padded_tensor": padded_tensor[indices],
        },
        non_tensor_dict=non_tensor_dict,
    )
    tu.assert_tensordict_eq(picked, expected)
|
|
|
|
def test_tensordict_with_images():
    """Variable-count, variable-size images ride along as per-sample non-tensor data."""
    vocab_size = 128

    def _rand_image(side):
        # Random uint8 image as a numpy array of shape (3, side, side).
        return torch.randint(low=0, high=255, size=(3, side, side), dtype=torch.uint8).numpy()

    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)

    # Each sample may carry a different number of images with different sizes.
    a_images = [_rand_image(256), _rand_image(128)]
    b_images = [_rand_image(256), _rand_image(128), _rand_image(64)]

    td = tu.get_tensordict({"input_ids": input_ids, "images": [a_images, b_images]})

    assert np.all(np.equal(td[0]["images"][0], a_images[0]))
    assert torch.all(torch.eq(td[0]["input_ids"], a))
|
|
|
|
def test_tensordict_with_packing():
    """Jagged input_ids expose packing offsets (cu_seqlens) and per-sample views."""
    vocab_size = 128
    first = torch.randint(low=0, high=vocab_size, size=(11,))
    second = torch.randint(low=0, high=vocab_size, size=(13,))
    td = tu.get_tensordict({"input_ids": torch.nested.as_nested_tensor([first, second], layout=torch.jagged)})

    # Offsets of the jagged layout are exactly the cumulative sequence lengths.
    assert torch.all(torch.eq(torch.tensor([0, 11, 24]), td["input_ids"].offsets()))

    # Row access through the nested tensor ...
    assert torch.all(torch.eq(td["input_ids"][0], first))
    assert torch.all(torch.eq(td["input_ids"][1], second))
    # ... and through tensordict indexing agree.
    assert torch.all(torch.eq(td[0]["input_ids"], first))
    assert torch.all(torch.eq(td[1]["input_ids"], second))

    # Chunking splits along the batch dimension: one sample per chunk here.
    halves = td.chunk(2)
    assert torch.all(torch.eq(halves[0]["input_ids"][0], first))
    assert torch.all(torch.eq(halves[1]["input_ids"][0], second))
|
|
|
|
def test_tensordict_eq():
    """assert_tensordict_eq must compare tensors, per-sample lists, metadata, and jagged tensors."""
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
    reference = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)

    # An independently built but value-identical tensordict compares equal.
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
    twin = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
    tu.assert_tensordict_eq(reference, twin)

    # A tensor-value difference must be detected ...
    variant = copy.deepcopy(twin)
    variant["obs"][0] += 1
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(reference, variant)

    # ... so must a per-sample non-tensor difference ...
    variant = copy.deepcopy(twin)
    variant["data_sources"][0] = "math"
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(reference, variant)

    # ... and a nested-metadata difference.
    variant = copy.deepcopy(twin)
    variant["train_sample_kwargs"]["top_p"] = 0.9
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(reference, variant)

    # Jagged (nested) tensors: equal values compare equal.
    tensor_list = [
        torch.tensor([1, 2, 3, 3, 2]),
        torch.tensor([4, 5]),
        torch.tensor([7, 8, 10, 14]),
        torch.tensor([10, 11, 12]),
        torch.tensor([13, 14, 15, 18]),
        torch.tensor([16, 17]),
    ]
    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
    jagged_ref = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)

    # Same values in a freshly allocated row: still equal.
    tensor_list[0] = torch.tensor([1, 2, 3, 3, 2])
    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
    jagged_twin = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
    tu.assert_tensordict_eq(jagged_ref, jagged_twin)

    # A row with different values (and length) must fail.
    tensor_list[0] = torch.tensor([1, 2, 4])
    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
    jagged_diff = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(jagged_ref, jagged_diff)

    # Swapping the first two rows changes the per-row split, so equality must fail.
    tensor_list[0] = torch.tensor([4, 5])
    tensor_list[1] = torch.tensor([1, 2, 3, 3, 2])
    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
    jagged_swapped = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
    with pytest.raises(AssertionError):
        tu.assert_tensordict_eq(jagged_ref, jagged_swapped)
|
|
|
|
def test_tensor_dict_make_iterator():
    """make_iterator yields mini-batches; the same seed must reproduce the same order.

    Improvements over the previous version: the manual ``i`` index counter is
    replaced with ``zip(..., strict=True)`` (which additionally verifies the
    iterator yields exactly the expected number of batches), and the manual
    append loops are replaced with ``list(...)``.
    """
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
    dataset = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)

    # Without shuffling, two epochs replay the dataset in order, 2 samples at a time.
    dataloader = tu.make_iterator(
        dataset, mini_batch_size=2, epochs=2, seed=0, dataloader_kwargs={"shuffle": False, "drop_last": False}
    )
    expected_tensor_dict = [dataset[0:2], dataset[2:4], dataset[4:6], dataset[0:2], dataset[2:4], dataset[4:6]]
    for batch, expected in zip(dataloader, expected_tensor_dict, strict=True):
        tu.assert_tensordict_eq(batch, expected)

    # With shuffling, identical seeds must produce identical batch streams.
    data_list_1 = list(
        tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True})
    )
    data_list_2 = list(
        tu.make_iterator(dataset, mini_batch_size=3, epochs=1, seed=1, dataloader_kwargs={"shuffle": True})
    )
    for data1, data2 in zip(data_list_1, data_list_2, strict=True):
        tu.assert_tensordict_eq(data1, data2)
|
|
|
|
def test_reorder():
    """Integer-tensor indexing reorders tensors and per-sample lists in lockstep."""
    td = tu.get_tensordict(
        tensor_dict={"obs": torch.tensor([1, 2, 3, 4, 5, 6]), "labels": ["a", "b", "c", "d", "e", "f"]},
        non_tensor_dict={"name": "abdce"},
    )
    td = td[torch.tensor([3, 4, 2, 0, 1, 5])]

    assert torch.all(torch.eq(td["obs"], torch.tensor([4, 5, 3, 1, 2, 6])))
    assert np.all(td["labels"] == np.array(["d", "e", "c", "a", "b", "f"]))
    # Dict-level metadata is untouched by reordering.
    assert td["name"] == "abdce"
|
|
|
|
def test_chunk_concat():
    """tensor_split/chunk partition the batch; torch.cat restores it."""
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    data = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"name": "abcde"})

    # tensor_split with 5 sections over 6 samples: the first section takes the remainder.
    pieces = data.tensor_split(indices_or_sections=5, dim=0)
    for piece, idx in zip(pieces, [[0, 1], [2], [3], [4], [5]], strict=False):
        tu.assert_tensordict_eq(piece, data[idx])

    # chunk(2) halves the batch; metadata is carried into every chunk.
    halves = data.chunk(2)
    assert len(halves) == 2
    assert torch.all(torch.eq(halves[0]["obs"], torch.tensor([1, 2, 3])))
    assert np.all(halves[0]["labels"] == np.array(["a", "b", "c"]))
    assert halves[0]["name"] == "abcde"
    assert torch.all(torch.eq(halves[1]["obs"], torch.tensor([4, 5, 6])))
    assert np.all(halves[1]["labels"] == np.array(["d", "e", "f"]))
    assert halves[1]["name"] == "abcde"

    # Concatenation along dim 0 round-trips the original data.
    rejoined = torch.cat(halves, dim=0)
    assert torch.all(torch.eq(rejoined["obs"], data["obs"]))
    assert np.all(rejoined["labels"] == data["labels"])
    assert rejoined["name"] == data["name"]
|
|
|
|
def test_pop():
    """tu.pop removes the requested keys in place and returns them as a new tensordict."""
    dataset = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)}, non_tensor_dict={"2": 2, "1": 1}
    )

    popped = tu.pop(dataset, keys=["obs", "2"])

    # The popped tensordict keeps the batch size and carries exactly the requested keys.
    assert popped.batch_size[0] == 100
    assert popped.keys() == {"obs", "2"}
    # The source loses them.
    assert dataset.keys() == {"act", "1"}
|
|
|
|
def test_repeat():
    """repeat_interleave duplicates samples adjacently; repeat tiles the whole batch."""
    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # Interleaved: [x0, x0, x1, x1, x2, x2]
    interleaved = data.repeat_interleave(repeats=2)
    assert torch.all(torch.eq(interleaved["obs"], torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])))
    assert interleaved["labels"] == ["a", "a", "b", "b", "c", "c"]
    assert interleaved["info"] == "test_info"

    # Tiled: [x0, x1, x2, x0, x1, x2]
    tiled = data.repeat(2)
    assert torch.all(torch.eq(tiled["obs"], torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])))
    assert tiled["labels"] == ["a", "b", "c", "a", "b", "c"]
    assert tiled["info"] == "test_info"
|
|
|
|
def test_dataproto_pad_unpad():
    """pad_to_divisor cycles samples from the front; unpad removes them again."""
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})

    # (size_divisor, expected pad amount, expected source indices after padding)
    cases = [
        (2, 1, [0, 1, 2, 0]),
        (3, 0, [0, 1, 2]),
        (7, 4, [0, 1, 2, 0, 1, 2, 0]),
    ]
    for divisor, expected_pad, idx in cases:
        padded_data, pad_size = tu.pad_to_divisor(data, size_divisor=divisor)
        assert pad_size == expected_pad

        # Padded content is the original samples repeated cyclically.
        assert torch.all(torch.eq(padded_data["obs"], obs[idx]))
        assert padded_data["labels"] == [labels[i] for i in idx]
        assert padded_data["info"] == "test_info"

        # Unpadding restores the original contents exactly.
        restored = tu.unpad(padded_data, pad_size=pad_size)
        assert torch.all(torch.eq(restored["obs"], obs))
        assert restored["labels"] == labels
        assert restored["info"] == "test_info"
|
|
|
|
def test_torch_save_data_proto():
    """A tensordict survives a torch.save/torch.load round trip.

    Uses a temporary directory so the checkpoint never pollutes the working
    directory — the previous version wrote ``test_data.pt`` to CWD and only
    removed it after the assertions, leaking the file whenever one failed.
    """
    import os
    import tempfile

    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, "test_data.pt")
        torch.save(data, filename)
        # weights_only=False: the payload contains non-tensor Python objects.
        loaded_data = torch.load(filename, weights_only=False)

    assert torch.all(torch.eq(loaded_data["obs"], data["obs"]))
    assert loaded_data["labels"] == data["labels"]
    assert loaded_data["info"] == data["info"]
|
|
|
|
def test_len():
    """len() follows the batch dimension, including tensordicts without tensors."""
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]

    td = tu.get_tensordict({"obs": obs, "labels": labels}, non_tensor_dict={"info": "test_info"})
    assert len(td) == 3

    # Batch size is inferred from non-tensor data when no tensors are present.
    td = tu.get_tensordict({"labels": labels}, non_tensor_dict={"info": "test_info"})
    assert len(td) == 3

    # A single item has no batch dimension.
    assert len(td[0]) == 0

    # An empty tensor_dict yields a zero-length tensordict.
    td = tu.get_tensordict({}, non_tensor_dict={"info": "test_info"})
    assert len(td) == 0
|
|
|
|
def test_dataproto_index():
    """Fancy indexing (numpy/torch/list, int and bool) selects matching rows and labels."""
    data_len = 100
    idx_num = 10

    obs = torch.randn(data_len, 10)
    labels = [random.choice(["abc", "cde"]) for _ in range(data_len)]
    data = tu.get_tensordict({"obs": obs, "labels": labels})
    labels_np = np.array(labels)

    def _verify(result, expected_obs, expected_labels, expected_count):
        # Selection must keep all keys and pick exactly the expected rows.
        assert result.keys() == data.keys()
        assert result["obs"].shape[0] == expected_count
        assert len(result["labels"]) == expected_count
        assert np.array_equal(result["obs"].cpu().numpy(), expected_obs)
        assert np.array_equal(result["labels"], expected_labels)

    # numpy integer index
    idx_np_int = np.random.randint(0, data_len, size=(idx_num,))
    _verify(data[idx_np_int], obs[idx_np_int].numpy(), labels_np[idx_np_int], idx_num)

    # torch integer index
    idx_torch_int = torch.randint(0, data_len, size=(idx_num,))
    _verify(data[idx_torch_int], obs[idx_torch_int].cpu().numpy(), labels_np[idx_torch_int.cpu().numpy()], idx_num)

    # python list of ints
    idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]
    _verify(data[idx_list_int], obs[idx_list_int].cpu().numpy(), labels_np[idx_list_int], idx_num)

    # torch boolean mask (selected count depends on the random mask)
    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)
    n_selected = idx_torch_bool.sum().item()
    _verify(data[idx_torch_bool], obs[idx_torch_bool].cpu().numpy(), labels_np[idx_torch_bool], n_selected)
|
|
|
def test_select():
    """select() builds a sub-tensordict with the named keys and drops the rest."""
    dataset = tu.get_tensordict(
        {"obs": torch.randn(100, 10), "act": torch.randn(100, 3)}, non_tensor_dict={"2": 2, "1": 1}
    )

    subset = dataset.select("obs", "2")

    # Selected keys keep their values; unselected keys are absent.
    assert torch.all(torch.eq(subset["obs"], dataset["obs"]))
    assert subset["2"] == dataset["2"]
    assert "act" not in subset.keys()
    assert "1" not in subset.keys()
|
|
|
|
def test_dataproto_no_batch():
    """select/pop work on tensordicts that carry only non-tensor data."""
    labels = ["a", "b", "c"]
    data = tu.get_tensordict(tensor_dict={"labels": labels}, non_tensor_dict={"info": "test_info"})

    assert data.select("labels")["labels"] == labels

    # pop returns the removed entries and deletes them from the source.
    pop_data = tu.pop(data, keys=["labels"])
    assert pop_data["labels"] == labels
    assert "labels" not in data
|
|
|
|
def test_sample_level_repeat():
    """repeat_interleave with a per-sample repeat tensor expands each row its own number of times."""
    data = tu.get_tensordict(
        {"obs": torch.tensor([[1, 2], [3, 4], [5, 6]]), "labels": ["a", "b", "c"]},
        non_tensor_dict={"info": "test_info"},
    )

    # repeats [3, 1, 2] -> [x0, x0, x0, x1, x2, x2]
    expanded = data.repeat_interleave(repeats=torch.tensor([3, 1, 2]))
    assert torch.all(torch.eq(expanded["obs"], torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]])))
    assert expanded["labels"] == ["a", "a", "a", "b", "c", "c"]
    assert expanded["info"] == "test_info"

    # repeats [1, 2, 3] -> [x0, x1, x1, x2, x2, x2]
    expanded = data.repeat_interleave(repeats=torch.tensor([1, 2, 3]))
    assert torch.all(torch.eq(expanded["obs"], torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]])))
    assert expanded["labels"] == ["a", "b", "b", "c", "c", "c"]
    assert expanded["info"] == "test_info"
|
|
|
|
def test_dataproto_chunk_after_index():
    """Indexing by any supported mask type must return a well-formed integer batch_size."""
    data_len = 4
    data = tu.get_tensordict(
        tensor_dict={"obs": torch.randn(data_len, 4), "labels": [f"label_{i}" for i in range(data_len)]},
        non_tensor_dict={"name": "abc"},
    )

    # Every supported index flavour; torch variants appear twice (with fresh tensor
    # objects) to keep the same coverage as before.
    index_variants = [
        torch.tensor([True, False, True, False]),
        torch.tensor([0, 2]),
        [True, False, True, False],
        [0, 2],
        torch.tensor([True, False, True, False]),
        torch.tensor([0, 2]),
    ]
    for index in index_variants:
        selected = data[index]
        assert isinstance(selected.batch_size, torch.Size)
        assert all(isinstance(dim, int) for dim in selected.batch_size)
|
|