VLAarchtests3 / code /VLAarchtests2_code /VLAarchtests /tests /test_policy_topk_cascade.py
lsnu's picture
Add files using upload-large-folder tool
31ade1f verified
import torch
from train.trainer import build_policy
def test_policy_topk_cascade(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that a planning forward pass sizes its outputs by the planner's top-k.

    A policy of type ``"elastic_reveal"`` is built from a config requesting
    4 candidates and ``top_k=2``, then called with ``plan=True``. The test
    verifies that:

    * ``planner_topk_indices`` has ``config.planner.top_k`` entries on dim 1,
    * the planned rollout's ``target_belief_field`` has the same size on dim 1,
    * every entry of ``best_candidate_indices`` is strictly below
      ``config.decoder.num_candidates``.
    """
    config = tiny_policy_config(num_candidates=4, top_k=2)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    trainer_config = tiny_trainer_config(policy_type="elastic_reveal")
    policy = build_policy(config, trainer_config)

    # Batch entries forwarded verbatim to the policy as keyword arguments.
    input_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    policy_inputs = {key: batch[key] for key in input_keys}
    output = policy(**policy_inputs, plan=True)

    top_k = config.planner.top_k
    assert output["planner_topk_indices"].shape[1] == top_k
    assert output["planned_rollout"]["target_belief_field"].shape[1] == top_k

    within_candidate_range = output["best_candidate_indices"] < config.decoder.num_candidates
    assert within_candidate_range.all()
def test_policy_null_rollout_ablation_keeps_planner_interface(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Check the planner outputs when the world model is switched off.

    The policy is called with ``plan=True``, ``use_world_model=False`` and
    ``use_planner=True``. The test expects:

    * ``rollout_source`` to be reported as ``"identity"``,
    * the top-k sized outputs (``planner_topk_indices``, the rollout's
      ``target_belief_field``, ``utility_total``) to keep their shapes,
    * the rollout's ``target_belief_field`` and ``phase_logits`` to equal the
      current ``interaction_state`` values repeated along dims 1 and 2
      (presumably the top-k and rollout-step axes; confirm against the policy).
    """
    config = tiny_policy_config(num_candidates=4, top_k=2)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    trainer_config = tiny_trainer_config(policy_type="elastic_reveal")
    policy = build_policy(config, trainer_config)

    # Batch entries forwarded verbatim to the policy as keyword arguments.
    input_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    policy_inputs = {key: batch[key] for key in input_keys}
    output = policy(
        **policy_inputs,
        plan=True,
        use_world_model=False,
        use_planner=True,
    )

    rollout = output["planned_rollout"]
    current_state = output["interaction_state"]
    top_k = config.planner.top_k

    assert output["rollout_source"] == "identity"
    assert output["planner_topk_indices"].shape[1] == top_k
    assert rollout["target_belief_field"].shape[1] == top_k

    def _repeat_current(field):
        # Insert singleton dims 1 and 2 into the detached current-state tensor
        # and broadcast it to the shape of the matching rollout tensor.
        detached = current_state[field].detach()
        with_extra_dims = detached.unsqueeze(1).unsqueeze(2)
        return with_extra_dims.expand_as(rollout[field])

    repeated_belief = _repeat_current("target_belief_field")
    repeated_phase = _repeat_current("phase_logits")

    batch_size = batch["images"].shape[0]
    assert output["utility_total"].shape == (batch_size, top_k)
    assert torch.allclose(rollout["target_belief_field"], repeated_belief)
    assert torch.allclose(rollout["phase_logits"], repeated_phase)