# VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_rgb_backward_compat.py
# Uploaded by lsnu ("Add files using upload-large-folder tool"), commit b14c4b7 (verified).
from train.trainer import build_policy
def test_rgb_backward_compat(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Run a planning forward pass through both policy variants and check output shapes.

    Builds an ``interaction_state`` policy and an ``elastic_reveal`` policy from
    the same tiny policy config, then feeds each the same batch with
    ``plan=True``. The elastic policy is additionally called with
    ``use_depth=False``; judging by the test name this is meant to exercise the
    RGB-only (no depth) path -- TODO confirm against the policy implementation.

    Checked:
      * ``action_mean`` has a trailing axis of size 14 for both policies.
      * the interaction policy's ``candidate_chunks`` axis 1 equals
        ``config.decoder.num_candidates``.
      * the elastic policy's ``planned_chunk`` has the same shape as its
        ``action_mean``.
    """
    # Expected size of the last axis of "action_mean"; presumably the action
    # dimension of the tiny setup -- TODO confirm where 14 comes from.
    action_dim = 14

    policy_config = tiny_policy_config()
    sample = tiny_batch(chunk_size=policy_config.decoder.chunk_size)

    # Both policies receive exactly the same batch entries under the same
    # keyword names, so collect them once and unpack into each call.
    shared_keys = (
        "images",
        "proprio",
        "texts",
        "history_images",
        "history_proprio",
        "history_actions",
    )
    forward_inputs = {key: sample[key] for key in shared_keys}

    interaction = build_policy(
        policy_config, tiny_trainer_config(policy_type="interaction_state")
    )
    interaction_out = interaction(**forward_inputs, plan=True)
    assert interaction_out["action_mean"].shape[-1] == action_dim
    assert (
        interaction_out["candidate_chunks"].shape[1]
        == policy_config.decoder.num_candidates
    )

    elastic = build_policy(
        policy_config, tiny_trainer_config(policy_type="elastic_reveal")
    )
    # Depth explicitly disabled for the elastic policy.
    elastic_out = elastic(**forward_inputs, plan=True, use_depth=False)
    assert elastic_out["action_mean"].shape[-1] == action_dim
    assert elastic_out["planned_chunk"].shape == elastic_out["action_mean"].shape