# VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_lightweight_transition_contract.py
| import torch | |
| from models.world_model import LightweightRevealStateTransitionModel | |
def test_lightweight_transition_contract(tiny_policy_config, tiny_state) -> None:
    """Check the output shapes of ``LightweightRevealStateTransitionModel``.

    Builds a small config and interaction state from the ``tiny_policy_config``
    and ``tiny_state`` fixtures, runs the model on a random action chunk, and
    asserts the shapes of three entries of the returned mapping:
    ``visibility_summary``, ``access_field`` and ``clearance_field``.
    """
    # Named once so the inputs and the expected shapes cannot drift apart.
    batch_size = 2
    num_candidates = 4

    config = tiny_policy_config(num_candidates=num_candidates, chunk_size=2)
    model = LightweightRevealStateTransitionModel(config.world_model)
    state = tiny_state(
        batch_size=batch_size,
        field_size=config.reveal_head.field_size,
    )

    chunk_size = config.decoder.chunk_size
    action_chunk = torch.rand(
        batch_size,
        num_candidates,
        chunk_size,
        config.decoder.action_dim,
    )
    # One row of ids, viewed once per batch element (expand does not copy).
    proposal_mode_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long).expand(
        batch_size, -1
    )

    rollout = model(
        interaction_state=state,
        action_chunk=action_chunk,
        proposal_mode_ids=proposal_mode_ids,
    )

    leading_shape = (batch_size, num_candidates, chunk_size)

    assert rollout["visibility_summary"].shape == leading_shape
    # Only the first four axes are pinned; any trailing axes are unchecked.
    assert rollout["access_field"].shape[:4] == (
        *leading_shape,
        config.world_model.num_support_modes,
    )
    # NOTE(review): the trailing (2, 1, 1) is hard-coded; presumably it follows
    # from the tiny fixture configuration -- confirm against the model.
    assert rollout["clearance_field"].shape == (*leading_shape, 2, 1, 1)