| |
| """ |
| VLAC Integration Test |
| |
| Simple test to verify VLAC integration with SimpleVLA-RL rollout works correctly. |
| This test creates a minimal rollout configuration and runs a few steps. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import subprocess |
| import requests |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from verl.utils.vlac_client import VLACClient |
| import numpy as np |
|
|
|
|
class MockConfig:
    """Lightweight stand-in for a rollout config used by the VLAC checks.

    Only the attributes this script reads are provided. ``val_only``,
    ``task_suite_name`` and ``max_steps`` start at fixed values; callers
    override them by assigning to the instance after construction.
    """

    def __init__(self, use_vlac=True, vlac_service_url="http://localhost:8111"):
        suite = "libero_10"
        # VLAC switches are the only values supplied by the caller.
        self.use_vlac, self.vlac_service_url = use_vlac, vlac_service_url
        self.val_only = False
        self.task_suite_name = suite
        # Step budget keyed by suite name; a fresh dict for every instance.
        self.max_steps = {suite: 512}
|
|
|
|
def test_vlac_client(service_url="http://localhost:8111", timeout=30):
    """Smoke-test the VLAC client against a running VLAC service.

    Creates a ``VLACClient`` and three random 256x256x3 ``uint8`` frames.
    It then calls ``check_done``, ``compute_trajectory_values`` and
    ``pairwise_critic`` and prints each result. Results are printed only,
    not compared against expected values.

    Args:
        service_url: Base URL of the VLAC service. The default matches the
            URL used elsewhere in this script.
        timeout: Forwarded unchanged to ``VLACClient``.

    Returns:
        True if every call completed. False if any call raised, or if
        ``compute_trajectory_values`` returned no values. On failure the
        error (and, for exceptions, the traceback) is printed.

    NOTE(review): pytest collects this function because of its ``test_``
    prefix. Because it returns a bool instead of raising, pytest reports it
    as passing (with a warning) even when the service is unreachable.
    """
    import traceback  # local: only needed to report failures

    task = "Pick up the red bowl and place it in the white box"
    height, width = 256, 256

    try:
        client = VLACClient(service_url=service_url, timeout=timeout)

        # randint's upper bound is exclusive, so 256 covers every uint8 value.
        first_frame, prev_frame, curr_frame = (
            np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
            for _ in range(3)
        )

        print("Testing done detection...")
        done, prob = client.check_done(
            task=task,
            first_frame=first_frame,
            prev_frame=prev_frame,
            curr_frame=curr_frame,
        )
        print(f"Done check result: done={done}, prob={prob:.3f}")

        print("Testing trajectory value computation...")
        frames = [first_frame, prev_frame, curr_frame]
        value_list, critic_list = client.compute_trajectory_values(
            task=task,
            frames=frames,
            skip=1,
        )
        print(f"Trajectory values: {len(value_list)} values, {len(critic_list)} critics")
        if not value_list:
            # Report an empty result explicitly rather than letting
            # value_list[-1] fail with an uninformative IndexError.
            print("❌ VLAC client test failed: compute_trajectory_values returned no values")
            return False
        print(f"Final value: {value_list[-1]:.3f}")

        print("Testing pairwise critic...")
        critic_score = client.pairwise_critic(
            task=task,
            image_a=prev_frame,
            image_b=curr_frame,
        )
        print(f"Pairwise critic score: {critic_score:.3f}")

        print("✅ VLAC client tests passed!")
        return True

    except Exception as e:
        # Top-level boundary of this smoke test: report the failure and
        # signal it to the caller instead of aborting the whole script.
        print(f"❌ VLAC client test failed: {e}")
        traceback.print_exc()
        return False
|
|
|
|
def test_config_integration():
    """Check that ``MockConfig`` can express training and evaluation setups.

    Training: VLAC enabled and ``val_only`` left at its constructor value.
    Evaluation: VLAC disabled and ``val_only`` set to True after construction.

    Returns:
        True when all checks pass. A failed check raises ``AssertionError``.
        The checks use ``assert``, so they are skipped under ``python -O``.
    """
    print("Testing configuration integration...")

    train_config = MockConfig(use_vlac=True)
    # ``is`` pins the exact bool value instead of equality (ruff E712).
    assert train_config.use_vlac is True
    assert train_config.val_only is False
    print("✅ Training configuration correct")

    eval_config = MockConfig(use_vlac=False)
    eval_config.val_only = True
    assert eval_config.use_vlac is False
    assert eval_config.val_only is True
    print("✅ Evaluation configuration correct")

    return True
|
|
|
|
def check_service_health(service_url="http://localhost:8111", timeout=10):
    """Check whether the VLAC service is reachable and reports healthy.

    Sends a POST request to ``<service_url>/healthcheck``. The service is
    considered healthy when it answers HTTP 200 with a JSON body.

    Args:
        service_url: Base URL of the VLAC service. A trailing slash is
            tolerated.
        timeout: Request timeout in seconds, passed to ``requests.post``.

    Returns:
        True if the service answered 200 with a JSON body, otherwise False.
        Each outcome prints a status line.
    """
    url = service_url.rstrip("/") + "/healthcheck"

    try:
        response = requests.post(url, timeout=timeout)
    except requests.RequestException as e:
        # Connection refused, DNS failure, timeout, etc.
        print(f"❌ VLAC service not accessible: {e}")
        return False

    if response.status_code != 200:
        print(f"❌ VLAC service unhealthy: {response.status_code}")
        return False

    try:
        result = response.json()
    except ValueError as e:
        # The server answered, but the body is not JSON. Don't misreport
        # this as "not accessible". (requests' JSONDecodeError subclasses
        # ValueError.)
        print(f"❌ VLAC service returned a non-JSON health response: {e}")
        return False

    print(f"✅ VLAC service healthy: {result}")
    return True
|
|
|
|
def main():
    """Run the VLAC integration checks in order and return an exit code.

    The steps are the service health check, the configuration check, and
    the client smoke test. Execution stops at the first step that reports
    failure.

    Returns:
        0 if every step passed, 1 otherwise (suitable for ``sys.exit``).
    """
    print("VLAC Integration Test Suite")
    print("=" * 50)

    print("1. Checking VLAC service health...")
    if not check_service_health():
        # The client test needs a live service, so stop here with a hint.
        print("\nPlease start VLAC service first:")
        print("python vlac_service.py --port 8111 --gpu-ids 0")
        return 1

    print("\n2. Testing configuration integration...")
    if not test_config_integration():
        return 1

    print("\n3. Testing VLAC client...")
    if not test_vlac_client():
        return 1

    print("\n" + "=" * 50)
    print("🎉 All VLAC integration tests passed!")
    print("\nYou can now run training with VLAC:")
    print("bash examples/run_openvla_oft_rl_vlac.sh")

    return 0
|
|
|
|
if __name__ == "__main__":
    # Use sys.exit rather than the site-provided ``exit`` helper. That helper
    # is meant for the interactive shell and is absent under ``python -S``.
    sys.exit(main())
|
|