#!/usr/bin/env python3
"""
VLAC Service - HTTP API for Vision-Language-Action-Critic model

A minimal FastAPI service that exposes the VLAC model from evo_vlac
for use in SimpleVLA-RL training.

Usage:
    python vlac_service.py --port 8111 --gpu-ids 0,1,2,3
"""

import argparse
import base64
import hashlib
import os
import re
import subprocess
import sys
import traceback
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
import uvicorn

# Add evo_vlac to path
sys.path.insert(0, str(Path(__file__).parent / "evo_vlac"))
from evo_vlac import GAC_model

# Global model instance; populated by init_model() before the server starts.
model: Optional[GAC_model] = None
model_tag = "1.0.0"
checkpoint_sha: Optional[str] = None

# Every decoded input image is resized to IMAGE_SIZE x IMAGE_SIZE (448x448 per contract).
IMAGE_SIZE = 448
# /done reports done=True when the model's probability is strictly above this value.
DONE_PROB_THRESHOLD = 0.75
# At most this many reference images are forwarded to in-context done detection.
# NOTE(review): an earlier inline comment said "Max 11 as per contract" while the
# code sliced to 8; the code's value is kept here — confirm against the contract.
MAX_DONE_REFERENCE_IMAGES = 8
# Upper bound on the batch size forwarded to model.get_trajectory_critic.
MAX_TRAJECTORY_BATCH_SIZE = 8
# Seconds to wait for nvidia-smi before giving up on GPU selection.
NVIDIA_SMI_TIMEOUT_S = 10

# Characters allowed in debug-image filenames; anything else becomes "_".
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]")


# Request/Response Models
class PairwiseCriticRequest(BaseModel):
    task: str
    image_a: str = Field(..., description="Base-64 encoded RGB image")
    image_b: str = Field(..., description="Base-64 encoded RGB image")
    rich: Optional[bool] = False


class PairwiseCriticResponse(BaseModel):
    critic: float
    raw: str


class DoneRequest(BaseModel):
    task: str
    first_frame: str = Field(..., description="Base-64 encoded RGB image")
    prev_frame: str = Field(..., description="Base-64 encoded RGB image")
    curr_frame: str = Field(..., description="Base-64 encoded RGB image")
    reference: Optional[List[str]] = Field(None, description="List of base-64 encoded RGB images")


class DoneResponse(BaseModel):
    done: bool
    prob: float


class TrajectoryCriticRequest(BaseModel):
    task: str
    frames: List[str] = Field(..., description="List of base-64 encoded RGB images")
    reference: Optional[List[str]] = Field(None, description="List of base-64 encoded RGB images")
    skip: int = 5
    ref_num: int = 6
    batch_size: int = 10
    think: bool = False
    return_video: bool = False
    demo_id: str = ""


class TrajectoryCriticResponse(BaseModel):
    value_list: List[float]
    critic_list: List[float]
    done_list: Optional[List[float]] = None
    video: Optional[str] = None


class HealthResponse(BaseModel):
    status: str
    model_tag: str


# Utility functions
def get_gpu_memory_usage() -> Dict[int, int]:
    """Return ``{gpu_index: memory_used}`` as reported by ``nvidia-smi``.

    Values are the raw ``memory.used`` numbers from
    ``nvidia-smi --format=csv,noheader,nounits``.

    Returns an empty dict when nvidia-smi is missing, fails, or times out.
    Lines that cannot be parsed (for example ``[N/A]`` values) are skipped
    rather than aborting GPU selection.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
            timeout=NVIDIA_SMI_TIMEOUT_S,
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # OSError covers FileNotFoundError (nvidia-smi not installed).
        return {}

    gpu_usage: Dict[int, int] = {}
    for line in result.stdout.splitlines():
        parts = [part.strip() for part in line.split(",")]
        if len(parts) != 2:
            continue
        try:
            gpu_usage[int(parts[0])] = int(parts[1])
        except ValueError:
            continue
    return gpu_usage


def select_best_gpu(gpu_ids: List[int]) -> int:
    """Select the GPU with the lowest reported memory usage from ``gpu_ids``.

    Falls back to ``gpu_ids[0]`` when usage cannot be queried. GPUs missing
    from the nvidia-smi report are treated as fully used.

    Raises:
        ValueError: if ``gpu_ids`` is empty.
    """
    if not gpu_ids:
        raise ValueError("gpu_ids must contain at least one GPU id")

    gpu_usage = get_gpu_memory_usage()
    if not gpu_usage:
        # Fallback to first GPU if nvidia-smi not available
        return gpu_ids[0]

    return min(gpu_ids, key=lambda gpu: gpu_usage.get(gpu, float('inf')))


def _decode_b64_image(data: str) -> Image.Image:
    """Decode a base-64 string into an RGB image resized to IMAGE_SIZE x IMAGE_SIZE.

    Propagates whatever base64/PIL raise on malformed input.
    """
    raw = base64.b64decode(data)
    img = Image.open(BytesIO(raw)).convert("RGB")
    # Resize to 448x448 as specified in contract
    return img.resize((IMAGE_SIZE, IMAGE_SIZE), Image.Resampling.LANCZOS)


def b64_to_pil(data: str) -> Image.Image:
    """Convert a base64 string to a resized RGB PIL Image.

    Raises:
        HTTPException: status 400 with detail ``"Invalid image data: <error>"``.
    """
    try:
        return _decode_b64_image(data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}") from e


def _decode_image_or_400(data: str, label: str) -> Image.Image:
    """Decode like ``b64_to_pil``, but with a 400 detail of ``"<label>: <error>"``.

    This avoids nesting another HTTPException's ``"400: ..."`` text inside the detail.
    """
    try:
        return _decode_b64_image(data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"{label}: {str(e)}") from e


def save_debug_image(img: Image.Image, prefix: str):
    """Save ``img`` under ./vlac_debug when the env var VLAC_SAVE_INPUTS=1.

    The filename is ``<sanitized prefix>_<8 hex chars of sha256(pixels)>.png``.
    ``prefix`` may embed the client-supplied ``demo_id``, so it is restricted to
    ``[A-Za-z0-9._-]`` to keep writes inside the debug directory.
    """
    if os.getenv("VLAC_SAVE_INPUTS") != "1":
        return
    debug_dir = Path("./vlac_debug")
    debug_dir.mkdir(exist_ok=True)
    safe_prefix = _UNSAFE_FILENAME_CHARS.sub("_", prefix)
    # Content digest is stable across processes, unlike the built-in hash().
    digest = hashlib.sha256(img.tobytes()).hexdigest()[:8]
    img.save(debug_dir / f"{safe_prefix}_{digest}.png")


def chunk_list(lst: List, chunk_size: int) -> List[List]:
    """Split list into consecutive chunks of at most ``chunk_size`` items."""
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_checkpoint_sha(ckpt_path: str) -> str:
    """Return a short (16 hex chars) SHA-256 fingerprint of a checkpoint directory.

    The fingerprint is cheap and approximate. Only top-level ``*.bin``,
    ``*.safetensors`` and ``config.json`` files are considered. Files under
    1 MiB contribute their full contents; larger files contribute only their
    size. Returns ``"unknown"`` if no such files exist or on any error.
    """
    try:
        ckpt_dir = Path(ckpt_path)
        files_to_hash: List[Path] = []
        for pattern in ["*.bin", "*.safetensors", "config.json"]:
            files_to_hash.extend(ckpt_dir.glob(pattern))

        if not files_to_hash:
            return "unknown"

        hasher = hashlib.sha256()
        for file_path in sorted(files_to_hash):
            hasher.update(file_path.name.encode())
            size = file_path.stat().st_size
            if size < 1024 * 1024:  # Only hash small files fully
                hasher.update(file_path.read_bytes())
            else:
                hasher.update(str(size).encode())

        return hasher.hexdigest()[:16]
    except Exception:
        return "unknown"


def _model_to_dict(obj: BaseModel) -> dict:
    """Serialize a pydantic model with ``model_dump`` (v2), falling back to ``dict`` (v1)."""
    dump = getattr(obj, "model_dump", None)
    return dump() if dump is not None else obj.dict()


# FastAPI app
app = FastAPI(title="VLAC Service", version="1.0.0")

# NOTE: the endpoints below are `async def` but call the model synchronously.
# The model call therefore blocks the event loop, and requests are served one at
# a time within a worker.


def add_version_headers(response: JSONResponse):
    """Add X-VLAC-Model-Tag / X-VLAC-Checkpoint-SHA headers to ``response`` and return it."""
    response.headers["X-VLAC-Model-Tag"] = model_tag
    response.headers["X-VLAC-Checkpoint-SHA"] = checkpoint_sha or "unknown"
    return response


@app.api_route("/healthcheck", methods=["GET", "POST"], response_model=HealthResponse)
async def healthcheck():
    """Health check endpoint (GET or POST). Reports "ok" even if the model is not loaded."""
    return HealthResponse(status="ok", model_tag=model_tag)


@app.post("/pairwise-critic", response_model=PairwiseCriticResponse)
async def pairwise_critic(req: PairwiseCriticRequest):
    """Compare two images and return the last critic score from the model.

    Errors: 503 if the model is not loaded, 400 for undecodable images,
    and 500 for model or unexpected failures.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        img_a = _decode_image_or_400(req.image_a, "Invalid image_a")
        img_b = _decode_image_or_400(req.image_b, "Invalid image_b")

        # Debug save
        save_debug_image(img_a, "pairwise_a")
        save_debug_image(img_b, "pairwise_b")

        try:
            critic_list, _ = model.get_trajectory_critic(
                task=req.task,
                image_list=[img_a, img_b],
                ref_image_list=None,
                batch_num=1,
                ref_num=0,
                rich=req.rich,
                skip=1,
                frame_skip=True,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Model processing error: {str(e)}") from e

        critic_val = float(critic_list[-1]) if critic_list else 0.0

        response = JSONResponse(
            content=_model_to_dict(PairwiseCriticResponse(critic=critic_val, raw=str(critic_val)))
        )
        return add_version_headers(response)

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") from e


# NOTE(review): this endpoint used to be preceded by a "temporarily disable done
# detection" comment, yet it is registered and active — confirm intent.
@app.post("/done", response_model=DoneResponse)
async def done(req: DoneRequest):
    """Determine whether the trajectory should terminate.

    With ``reference`` images (at least 2 required; at most
    MAX_DONE_REFERENCE_IMAGES are used), this calls
    ``model.get_in_context_done``. Otherwise it calls
    ``model.get_trajectory_done`` on the current frame. ``done`` is
    ``prob > DONE_PROB_THRESHOLD``.

    Errors: 503 if the model is not loaded, 422 for a single reference image,
    400 for undecodable images, and 500 for processing failures.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # In-context detection needs at least 2 reference images (same rule as trajectory-critic).
    if req.reference and len(req.reference) < 2:
        raise HTTPException(
            status_code=422,
            detail=f"When reference images are provided for done detection, at least 2 are required (got {len(req.reference)}). Use reference=null for simple done detection.",
        )

    try:
        first_img = _decode_image_or_400(req.first_frame, "Invalid first_frame")
        prev_img = _decode_image_or_400(req.prev_frame, "Invalid prev_frame")
        curr_img = _decode_image_or_400(req.curr_frame, "Invalid curr_frame")

        ref_imgs = None
        if req.reference:
            ref_imgs = [
                _decode_image_or_400(ref, f"Invalid reference image {i}")
                for i, ref in enumerate(req.reference[:MAX_DONE_REFERENCE_IMAGES])
            ]

        # Debug save
        save_debug_image(curr_img, "done_curr")

        if ref_imgs:
            done_list, _ = model.get_in_context_done(
                task=req.task,
                first_image=[first_img],
                n_pre_image=[prev_img],
                now_image=[curr_img],
                ref_image_list=ref_imgs,
                ref_num=len(ref_imgs),
                rich=True,
            )
        else:
            # Simple done detection on current frame
            done_list = model.get_trajectory_done(
                task=req.task,
                image_list=[curr_img],
                batch_num=1,
                rich=True,
            )

        prob = float(done_list[0]) if done_list else 0.0
        is_done = prob > DONE_PROB_THRESHOLD

        response = JSONResponse(content=_model_to_dict(DoneResponse(done=is_done, prob=prob)))
        return add_version_headers(response)

    except HTTPException:
        # Keep 400s from image decoding instead of turning them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e


@app.post("/trajectory-critic", response_model=TrajectoryCriticResponse)
async def trajectory_critic(req: TrajectoryCriticRequest):
    """Compute critic and value curves for a full trajectory.

    ``batch_size`` is clamped to [1, MAX_TRAJECTORY_BATCH_SIZE].
    ``done_list`` and ``video`` are always null; ``return_video`` is ignored.

    Errors: 503 if the model is not loaded, 422 for empty frames or
    ``ref_num < 2`` with references, 400 for undecodable images, and 500 for
    model or conversion failures.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    print(f"Processing trajectory-critic request: {len(req.frames)} frames, task: {req.task[:50]}...")

    if not req.frames:
        raise HTTPException(status_code=422, detail="frames must contain at least one image")

    # Validate reference image constraints
    if req.reference and req.ref_num < 2:
        raise HTTPException(
            status_code=422,
            detail=f"When reference images are provided, ref_num must be >= 2 (got {req.ref_num}). Use ref_num=0 and reference=null for no reference images.",
        )

    try:
        print("Converting base64 images to PIL...")
        frames = []
        for i, frame in enumerate(req.frames):
            try:
                frames.append(_decode_b64_image(frame))
            except Exception as e:
                print(f"Error converting frame {i}: {e}")
                raise HTTPException(status_code=400, detail=f"Invalid image data in frame {i}: {str(e)}") from e

        ref_imgs = None
        if req.reference:
            print(f"Converting {len(req.reference)} reference images...")
            ref_imgs = []
            for i, ref in enumerate(req.reference):
                try:
                    ref_imgs.append(_decode_b64_image(ref))
                except Exception as e:
                    print(f"Error converting reference {i}: {e}")
                    raise HTTPException(status_code=400, detail=f"Invalid reference image {i}: {str(e)}") from e

        # Debug save first and last frames; demo_id is sanitized inside save_debug_image.
        save_debug_image(frames[0], f"traj_first_{req.demo_id}")
        save_debug_image(frames[-1], f"traj_last_{req.demo_id}")

        # Cap (no chunking happens here) the batch size passed to the model, and
        # guard against non-positive values from the client.
        effective_batch_size = max(1, min(req.batch_size, MAX_TRAJECTORY_BATCH_SIZE))
        print(f"Using effective batch size: {effective_batch_size}")

        print("Calling model.get_trajectory_critic...")
        try:
            critic_list, value_list = model.get_trajectory_critic(
                task=req.task,
                image_list=frames,
                ref_image_list=ref_imgs,
                batch_num=effective_batch_size,
                ref_num=req.ref_num,
                think=req.think,
                skip=req.skip,
                rich=True,
                frame_skip=True,
            )
            print(f"Got results: {len(critic_list)} critics, {len(value_list)} values")
        except Exception as e:
            print(f"Error in model.get_trajectory_critic: {e}")
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=f"Model processing error: {str(e)}") from e

        try:
            critic_floats = [float(c) for c in critic_list]
            value_floats = [float(v) for v in value_list]
            print(f"Converted to floats: {len(critic_floats)} critics, {len(value_floats)} values")
            print(f"critic_floats: {critic_floats}")
            print(f"value_floats: {value_floats}")
        except Exception as e:
            print(f"Error converting to floats: {e}")
            print(f"critic_list types: {[type(c) for c in critic_list[:3]]}")
            print(f"value_list types: {[type(v) for v in value_list[:3]]}")
            raise HTTPException(status_code=500, detail=f"Type conversion error: {str(e)}") from e

        # TODO: req.return_video is currently ignored. Video generation through
        # model.web_trajectory_critic(...) was disabled while isolating an issue.
        video_b64 = None

        print("Creating response...")
        response = JSONResponse(
            content=_model_to_dict(
                TrajectoryCriticResponse(
                    value_list=value_floats,
                    critic_list=critic_floats,
                    done_list=None,  # Not implemented in this endpoint
                    video=video_b64,
                )
            )
        )
        return add_version_headers(response)

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        print(f"Unexpected error in trajectory_critic: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e


def init_model(ckpt_path: str, gpu_id: int):
    """Load the VLAC (GAC_model, tag 'critic') checkpoint onto ``cuda:<gpu_id>``.

    Sets the module globals ``model`` and ``checkpoint_sha``. The process exits
    with status 1 if loading fails.
    """
    global model, checkpoint_sha

    print(f"Loading VLAC model from {ckpt_path} on GPU {gpu_id}")

    try:
        model = GAC_model(tag='critic')
        model.init_model(
            model_path=ckpt_path,
            model_type='internvl2',
            device_map=f'cuda:{gpu_id}',
        )
        model.temperature = 0.5
        model.top_k = 1
        model.set_config()
        model.set_system_prompt()

        checkpoint_sha = get_checkpoint_sha(ckpt_path)
        print(f"Model loaded successfully. Checkpoint SHA: {checkpoint_sha}")

    except Exception as e:
        print(f"Failed to load model: {e}")
        sys.exit(1)


def main():
    """Parse CLI args, load the model on the least-loaded GPU, and serve with uvicorn."""
    parser = argparse.ArgumentParser(description="VLAC Service")
    parser.add_argument("--port", type=int, default=8111, help="Port to run on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--ckpt-path", default="/home/zechen/SimpleVLA-RL/CKPT/VLAC",
                        help="Path to VLAC checkpoint")
    parser.add_argument("--gpu-ids", default="0", help="Comma-separated GPU IDs")
    parser.add_argument("--workers", type=int, default=1, help="Number of workers")

    args = parser.parse_args()

    # Parse GPU IDs, ignoring empty entries such as a trailing comma.
    try:
        gpu_ids = [int(x.strip()) for x in args.gpu_ids.split(",") if x.strip()]
    except ValueError:
        parser.error(f"--gpu-ids must be comma-separated integers (got {args.gpu_ids!r})")
    if not gpu_ids:
        parser.error("--gpu-ids must list at least one GPU id")

    selected_gpu = select_best_gpu(gpu_ids)
    print(f"Selected GPU {selected_gpu} from available: {gpu_ids}")

    init_model(args.ckpt_path, selected_gpu)

    # uvicorn only supports workers > 1 when the app is passed as an import
    # string (it exits otherwise). Extra worker processes would also lack the
    # model loaded above, so only a single worker is run.
    workers = args.workers
    if workers != 1:
        print(f"Warning: --workers={workers} is not supported by this service; running 1 worker.")
        workers = 1

    print(f"Starting VLAC service on {args.host}:{args.port}")
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        workers=workers,
        log_level="info",
    )


if __name__ == "__main__":
    main()

# Example server command:
#   python vlac_service.py --port 8111 --gpu-ids 0,1,2,3 \
#       --ckpt-path /mnt/bn/vgfm2/test_dit/zechen/cosmos-reason1/VLAC/model/ckpt