# VLAarchtests3 / code / VLAarchtests2_code / VLAarchtests / tests / test_lightweight_transition_contract.py
# Uploaded by lsnu ("Add files using upload-large-folder tool"), commit 31ade1f (verified).
import torch
from models.world_model import LightweightRevealStateTransitionModel
def test_lightweight_transition_contract(tiny_policy_config, tiny_state):
    """Check the output-shape contract of LightweightRevealStateTransitionModel.

    Builds a tiny config and interaction state via the ``tiny_policy_config``
    and ``tiny_state`` fixtures, rolls the model out over a random action chunk
    for several candidates, and asserts the leading dimensions of the
    ``visibility_summary``, ``access_field`` and ``clearance_field`` outputs.
    """
    # Single source of truth for the sizes used below, so the fixture
    # arguments, the input tensors and the expected shapes cannot drift apart.
    batch_size = 2
    num_candidates = 4

    config = tiny_policy_config(num_candidates=num_candidates, chunk_size=2)
    model = LightweightRevealStateTransitionModel(config.world_model)
    state = tiny_state(batch_size=batch_size, field_size=config.reveal_head.field_size)

    chunk_size = config.decoder.chunk_size
    action_chunk = torch.rand(
        batch_size, num_candidates, chunk_size, config.decoder.action_dim
    )
    # Give each candidate a distinct mode id 0..num_candidates-1, shared across
    # the batch; expand() broadcasts the single row without copying.
    proposal_mode_ids = (
        torch.arange(num_candidates, dtype=torch.long)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )

    rollout = model(
        interaction_state=state,
        action_chunk=action_chunk,
        proposal_mode_ids=proposal_mode_ids,
    )

    assert rollout["visibility_summary"].shape == (batch_size, num_candidates, chunk_size)
    # Only the leading four dims are pinned here; the trailing dims presumably
    # depend on the reveal field size — NOTE(review): confirm and tighten.
    assert rollout["access_field"].shape[:4] == (
        batch_size,
        num_candidates,
        chunk_size,
        config.world_model.num_support_modes,
    )
    # NOTE(review): the trailing (2, 1, 1) dims are kept exactly as the original
    # test pinned them; their meaning is not visible from this file.
    assert rollout["clearance_field"].shape == (
        batch_size,
        num_candidates,
        chunk_size,
        2,
        1,
        1,
    )