| import torch |
|
|
| from train.trainer import build_policy |
|
|
|
|
def test_policy_topk_cascade(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that a planning forward pass produces top-k sized planner outputs.

    Builds an ``"elastic_reveal"`` policy configured with four candidates and
    ``top_k=2``, calls it with ``plan=True``, and verifies that:

    * dimension 1 of ``planner_topk_indices`` equals ``config.planner.top_k``;
    * dimension 1 of the planned rollout's ``target_belief_field`` equals
      ``config.planner.top_k``;
    * every entry of ``best_candidate_indices`` is strictly below
      ``config.decoder.num_candidates``.
    """
    config = tiny_policy_config(num_candidates=4, top_k=2)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    trainer_config = tiny_trainer_config(policy_type="elastic_reveal")
    policy = build_policy(config, trainer_config)

    # Batch entries forwarded verbatim to the policy as keyword arguments.
    input_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    forward_kwargs = {key: batch[key] for key in input_keys}
    output = policy(**forward_kwargs, plan=True)

    expected_k = config.planner.top_k
    assert output["planner_topk_indices"].shape[1] == expected_k
    assert output["planned_rollout"]["target_belief_field"].shape[1] == expected_k

    within_candidate_range = output["best_candidate_indices"] < config.decoder.num_candidates
    assert within_candidate_range.all()
|
|
|
|
def test_policy_null_rollout_ablation_keeps_planner_interface(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Check the planner outputs when the world model is disabled.

    Calls an ``"elastic_reveal"`` policy with ``plan=True``,
    ``use_world_model=False`` and ``use_planner=True`` and verifies that:

    * ``rollout_source`` is reported as ``"identity"``;
    * ``planner_topk_indices`` and the rollout's ``target_belief_field`` still
      have ``config.planner.top_k`` entries along dimension 1;
    * ``utility_total`` has shape ``(batch, top_k)``, with the batch size taken
      from ``batch["images"].shape[0]``;
    * the rollout's ``target_belief_field`` and ``phase_logits`` match the
      current ``interaction_state`` tensors once those are expanded to the
      rollout's shape.
    """
    config = tiny_policy_config(num_candidates=4, top_k=2)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    trainer_config = tiny_trainer_config(policy_type="elastic_reveal")
    policy = build_policy(config, trainer_config)

    # Batch entries forwarded verbatim to the policy as keyword arguments.
    input_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    forward_kwargs = {key: batch[key] for key in input_keys}
    output = policy(
        **forward_kwargs,
        plan=True,
        use_world_model=False,
        use_planner=True,
    )

    planned = output["planned_rollout"]
    state_now = output["interaction_state"]

    assert output["rollout_source"] == "identity"
    assert output["planner_topk_indices"].shape[1] == config.planner.top_k
    assert planned["target_belief_field"].shape[1] == config.planner.top_k

    # Insert two singleton dims after the leading one, then expand the current
    # state to the rollout's shape (the meaning of those dims is not visible
    # here; presumably candidate and rollout-step axes -- TODO confirm).
    belief_now = state_now["target_belief_field"].detach().unsqueeze(1).unsqueeze(2)
    expected_belief = belief_now.expand_as(planned["target_belief_field"])
    phase_now = state_now["phase_logits"].detach().unsqueeze(1).unsqueeze(2)
    expected_phase = phase_now.expand_as(planned["phase_logits"])

    utility_shape = output["utility_total"].shape
    expected_utility_shape = (batch["images"].shape[0], config.planner.top_k)
    assert utility_shape == expected_utility_shape

    assert torch.allclose(planned["target_belief_field"], expected_belief)
    assert torch.allclose(planned["phase_logits"], expected_phase)
|
|