from train.trainer import build_policy


def test_rgb_backward_compat(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Both policy types still accept RGB-only inputs (no depth stream required).

    Builds an ``interaction_state`` and an ``elastic_reveal`` policy from the
    same tiny config, runs a planning forward pass on an identical batch, and
    checks the action-dimension and planning outputs each policy exposes.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)

    # Observation kwargs are identical for both invocations; build them once.
    observations = dict(
        images=batch["images"],
        proprio=batch["proprio"],
        texts=batch["texts"],
        history_images=batch["history_images"],
        history_proprio=batch["history_proprio"],
        history_actions=batch["history_actions"],
    )

    interaction_policy = build_policy(
        config, tiny_trainer_config(policy_type="interaction_state")
    )
    interaction_output = interaction_policy(**observations, plan=True)
    assert interaction_output["action_mean"].shape[-1] == 14
    assert (
        interaction_output["candidate_chunks"].shape[1]
        == config.decoder.num_candidates
    )

    elastic_policy = build_policy(
        config, tiny_trainer_config(policy_type="elastic_reveal")
    )
    # elastic_reveal additionally takes a use_depth flag; RGB-only means False.
    elastic_output = elastic_policy(**observations, plan=True, use_depth=False)
    assert elastic_output["action_mean"].shape[-1] == 14
    assert elastic_output["planned_chunk"].shape == elastic_output["action_mean"].shape