| from sim_reveal.dataset import RGBD_PROXY_DATASET_VERSION, collect_teacher_dataset, dataset_from_bundle | |
def test_dataset_v6_keys():
    """Check that a dataset item exposes every expected key.

    Collects a small teacher bundle with ``RGBD_PROXY_DATASET_VERSION``,
    wraps it with ``dataset_from_bundle`` and inspects the first item.
    """
    expected_keys = (
        "images",
        "depths",
        "depth_valid",
        "belief_map",
        "visibility_map",
        "clearance_map",
        "support_stability",
        "reocclusion_target",
        "candidate_rollout_belief_map",
    )
    teacher_bundle = collect_teacher_dataset(
        episodes_per_proxy=1,
        resolution=16,
        history_steps=2,
        planner_candidates=3,
        dataset_version=RGBD_PROXY_DATASET_VERSION,
    )
    first_item = dataset_from_bundle(teacher_bundle, resolution=16)[0]
    # Short-circuits on the first missing key, like a per-key assert loop.
    assert all(name in first_item for name in expected_keys)
def test_phase_dataset_version_keeps_rgbd_path():
    """Check the ``_phase`` dataset version on the ``bag_proxy`` proxy.

    The first item must have a strictly positive ``depth_valid`` sum and
    must contain the phase-related keys.
    """
    phase_version = RGBD_PROXY_DATASET_VERSION + "_phase"
    teacher_bundle = collect_teacher_dataset(
        proxy_names=["bag_proxy"],
        episodes_per_proxy=1,
        resolution=16,
        history_steps=2,
        planner_candidates=3,
        dataset_version=phase_version,
    )
    first_item = dataset_from_bundle(teacher_bundle, resolution=16)[0]

    valid_depth_total = float(first_item["depth_valid"].sum())
    assert valid_depth_total > 0.0

    # Checked in the same order as before: phase, rollout, candidate rollout.
    for phase_key in ("phase", "rollout_phase", "candidate_rollout_phase"):
        assert phase_key in first_item
def test_dataset_proposal_target_keys_roundtrip():
    """Check that proposal-target keys from the builder reach dataset items.

    A ``proposal_target_builder`` that copies the sample's candidate fields
    under ``proposal_target_*`` names is passed to the collector; the first
    dataset item must then contain each of those names.
    """

    def _copy_candidates_as_targets(env, observation, sample):
        # Only ``sample`` is read; ``env`` and ``observation`` are unused.
        return {
            "proposal_target_action_chunks": sample["candidate_action_chunks"].copy(),
            "proposal_target_retrieval_success": sample["candidate_retrieval_success"].copy(),
            "proposal_target_risk": sample["candidate_risk"].copy(),
            "proposal_target_utility": sample["candidate_utility"].copy(),
        }

    teacher_bundle = collect_teacher_dataset(
        proxy_names=["cloth_proxy"],
        episodes_per_proxy=1,
        resolution=16,
        history_steps=2,
        planner_candidates=3,
        dataset_version=RGBD_PROXY_DATASET_VERSION + "_phase",
        proposal_target_builder=_copy_candidates_as_targets,
    )
    first_item = dataset_from_bundle(teacher_bundle, resolution=16)[0]

    target_keys = (
        "proposal_target_action_chunks",
        "proposal_target_retrieval_success",
        "proposal_target_risk",
        "proposal_target_utility",
    )
    for target_key in target_keys:
        assert target_key in first_item