#!/usr/bin/env python3
"""
VLAC Service (Multi-Worker) - HTTP API for Vision-Language-Action-Critic model

This variant of the service keeps the single-worker code path untouched so that
we can debug multi-process behaviour separately. Each uvicorn worker lazily
loads the VLAC model, pins itself to a GPU from the provided list, and limits
CUDA visibility to that device before model initialisation.

Usage:
    python vlac_service_multi_worker.py --port 8111 --gpu-ids 0,1,2,3 \\
        --ckpt-path /path/to/VLAC/checkpoint --workers 4
"""

import argparse
import asyncio
import atexit
import base64
import hashlib
import json
import os
import subprocess
import sys
import traceback
from io import BytesIO
from pathlib import Path
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
import uvicorn

# Add evo_vlac to path
sys.path.insert(0, str(Path(__file__).parent / "evo_vlac"))
from evo_vlac import GAC_model  # noqa: E402

# Global model instance (one per worker process)
model: Optional[GAC_model] = None
model_tag = "1.0.0"
checkpoint_sha: Optional[str] = None
load_lock: Optional[asyncio.Lock] = None

try:
    import fcntl  # type: ignore[attr-defined]
except ImportError:  # pragma: no cover - platform without fcntl (e.g. Windows)
    fcntl = None  # type: ignore[assignment]

GPU_LOCK_FILE = Path(os.environ.get("VLAC_GPU_LOCK_FILE", "/tmp/vlac_gpu_lock"))
GPU_ASSIGNMENT_FILE = Path(
    os.environ.get("VLAC_GPU_ASSIGNMENT_FILE", "/tmp/vlac_gpu_assignments.json")
)
lock_release_registered = False
assigned_gpu_id: Optional[int] = None


# ---------------------------------------------------------------------------
# Cross-process GPU assignment bookkeeping
# ---------------------------------------------------------------------------


def _acquire_assignment_lock() -> Optional[int]:
    """Take an exclusive cross-process lock guarding the GPU assignment file.

    Returns the locked file descriptor, or ``None`` when ``fcntl`` is not
    available (in which case no locking is performed at all).
    """
    if fcntl is None:
        return None
    GPU_LOCK_FILE.parent.mkdir(parents=True, exist_ok=True)
    fd = os.open(str(GPU_LOCK_FILE), os.O_RDWR | os.O_CREAT, 0o600)
    try:
        fcntl.flock(fd, fcntl.LOCK_EX)
    except BaseException:
        # Do not leak the descriptor if locking fails or is interrupted.
        os.close(fd)
        raise
    return fd


def _release_assignment_lock(fd: Optional[int]):
    """Unlock and close a descriptor returned by ``_acquire_assignment_lock``."""
    if fd is None or fcntl is None:
        return
    try:
        fcntl.flock(fd, fcntl.LOCK_UN)
    finally:
        os.close(fd)


def _read_assignment_state() -> dict[int, dict[str, int]]:
    """Load the ``{pid: {"gpu_id": int}}`` map from the assignment file.

    A missing, unreadable, or malformed file is treated as an empty map;
    individual malformed entries are skipped.
    """
    if not GPU_ASSIGNMENT_FILE.exists():
        return {}
    try:
        with GPU_ASSIGNMENT_FILE.open("r", encoding="utf-8") as file:
            raw = json.load(file)
    except (ValueError, OSError):  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        return {}
    if not isinstance(raw, dict):
        # A foreign or corrupted file (e.g. a JSON list) would break .items().
        return {}

    result: dict[int, dict[str, int]] = {}
    for key, value in raw.items():
        try:
            pid = int(key)
            if isinstance(value, dict) and "gpu_id" in value:
                result[pid] = {"gpu_id": int(value["gpu_id"])}
        except (ValueError, TypeError):
            continue
    return result


def _write_assignment_state(state: dict[int, dict[str, int]]):
    """Persist ``state`` via write-to-temp-then-rename so readers never see a partial file."""
    GPU_ASSIGNMENT_FILE.parent.mkdir(parents=True, exist_ok=True)
    tmp_file = GPU_ASSIGNMENT_FILE.with_suffix(".tmp")
    with tmp_file.open("w", encoding="utf-8") as file:
        json.dump({str(pid): data for pid, data in state.items()}, file)
    tmp_file.replace(GPU_ASSIGNMENT_FILE)


def _is_process_alive(pid: int) -> bool:
    """Return True if a process with ``pid`` appears to exist."""
    if os.name == "nt":
        # On Windows os.kill(pid, 0) terminates the target instead of probing
        # it, so never probe there; conservatively assume the entry is live.
        return True
    try:
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists but is owned by another user.
        return True
    except OSError:
        return False
    return True


def _cleanup_assignment_state(
    state: dict[int, dict[str, int]], gpu_ids: List[int]
) -> dict[int, dict[str, int]]:
    """Keep only entries whose process is alive and whose GPU is in ``gpu_ids``."""
    cleaned: dict[int, dict[str, int]] = {}
    for pid, data in state.items():
        gpu_id = data.get("gpu_id")
        if gpu_id in gpu_ids and _is_process_alive(pid):
            cleaned[pid] = {"gpu_id": gpu_id}
    return cleaned


def _register_release_handler():
    """Register (once per process) an atexit hook that frees this process's GPU entry."""
    global lock_release_registered
    if lock_release_registered:
        return

    def _release_on_exit():
        release_gpu_assignment()

    atexit.register(_release_on_exit)
    lock_release_registered = True


def release_gpu_assignment():
    """Remove this process's entry from the shared GPU assignment file.

    The entry is removed even when ``fcntl`` is unavailable (it was written
    without a lock in that case too), so no stale entry is left behind.
    """
    global assigned_gpu_id
    if assigned_gpu_id is None:
        return
    fd = _acquire_assignment_lock()
    try:
        state = _read_assignment_state()
        state.pop(os.getpid(), None)
        _write_assignment_state(state)
    finally:
        _release_assignment_lock(fd)
    assigned_gpu_id = None


# Request/Response Models
class PairwiseCriticRequest(BaseModel):
    task: str
    image_a: str = Field(..., description="Base-64 encoded RGB image")
    image_b: str = Field(..., description="Base-64 encoded RGB image")
    rich: Optional[bool] = False


class PairwiseCriticResponse(BaseModel):
    critic: float
    raw: str


class DoneRequest(BaseModel):
    task: str
    first_frame: str = Field(..., description="Base-64 encoded RGB image")
    prev_frame: str = Field(..., description="Base-64 encoded RGB image")
    curr_frame: str = Field(..., description="Base-64 encoded RGB image")
    reference: Optional[List[str]] = Field(None, description="List of base-64 encoded RGB images")


class DoneResponse(BaseModel):
    done: bool
    prob: float


class TrajectoryCriticRequest(BaseModel):
    task: str
    frames: List[str] = Field(..., description="List of base-64 encoded RGB images")
    reference: Optional[List[str]] = Field(None, description="List of base-64 encoded RGB images")
    skip: int = 5
    ref_num: int = 6
    batch_size: int = 10
    think: bool = False
    return_video: bool = False
    demo_id: str = ""


class TrajectoryCriticResponse(BaseModel):
    value_list: List[float]
    critic_list: List[float]
    done_list: Optional[List[float]] = None
    video: Optional[str] = None


class HealthResponse(BaseModel):
    status: str
    model_tag: str


# Utility functions
def get_gpu_memory_usage() -> dict:
    """Get GPU memory usage for all GPUs as ``{gpu_index: memory_used}``.

    Values are whatever ``nvidia-smi --query-gpu=memory.used`` reports with
    ``nounits`` (MiB). Returns an empty dict if nvidia-smi is missing, fails,
    or prints output that cannot be parsed.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):  # OSError covers FileNotFoundError
        return {}

    gpu_usage: dict[int, int] = {}
    try:
        for line in result.stdout.strip().splitlines():
            if not line.strip():
                continue
            gpu_id, memory_used = (part.strip() for part in line.split(",", 1))
            gpu_usage[int(gpu_id)] = int(memory_used)
    except ValueError:
        # Unexpected format (e.g. "[N/A]"): report "no information" rather than
        # crashing worker startup.
        return {}
    return gpu_usage


def select_best_gpu(gpu_ids: List[int]) -> int:
    """Select GPU with lowest memory usage from the given (non-empty) list."""
    gpu_usage = get_gpu_memory_usage()
    if not gpu_usage:
        # Fallback to first GPU if nvidia-smi not available
        return gpu_ids[0]
    # Find GPU with minimum usage; GPUs missing from the report sort last.
    return min(gpu_ids, key=lambda x: gpu_usage.get(x, float("inf")))


def b64_to_pil(data: str) -> Image.Image:
    """Convert base64 string to an RGB PIL Image resized to 448x448.

    Raises:
        HTTPException: status 400 if the data cannot be decoded as an image.
    """
    try:
        img_data = base64.b64decode(data)
        img = Image.open(BytesIO(img_data)).convert("RGB")
        # Resize to 448x448 as specified in contract
        img = img.resize((448, 448), Image.Resampling.LANCZOS)
        return img
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc


def save_debug_image(img: Image.Image, prefix: str):
    """Save image for debugging if VLAC_SAVE_INPUTS=1.

    The file is named ``<prefix>_<digest>.png`` in ``./vlac_debug``. The digest
    is a content hash, stable across worker processes (built-in ``hash()`` is
    salted per process). ``prefix`` may contain request data (e.g. demo_id), so
    it is reduced to a safe character set to prevent path traversal.
    """
    if os.getenv("VLAC_SAVE_INPUTS") != "1":
        return
    debug_dir = Path("./vlac_debug")
    debug_dir.mkdir(exist_ok=True)
    safe_prefix = "".join(c if c.isalnum() or c in "-_." else "_" for c in prefix)
    digest = hashlib.sha256(img.tobytes()).hexdigest()[:8]
    img.save(debug_dir / f"{safe_prefix}_{digest}.png")


def chunk_list(lst: List, chunk_size: int) -> List[List]:
    """Split list into chunks of specified size."""
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_checkpoint_sha(ckpt_path: str) -> str:
    """Compute a short SHA256 fingerprint of a checkpoint directory (simplified).

    Hashes the names of ``*.bin``, ``*.safetensors`` and ``config.json`` files.
    Files under 1 MiB are hashed by content; larger ones only by size. Returns
    the first 16 hex chars, or "unknown" if nothing matched or an error occurred.
    """
    try:
        # Simple hash of main model files
        files_to_hash: list[Path] = []
        ckpt_dir = Path(ckpt_path)
        for pattern in ["*.bin", "*.safetensors", "config.json"]:
            files_to_hash.extend(ckpt_dir.glob(pattern))
        if not files_to_hash:
            return "unknown"

        hasher = hashlib.sha256()
        for file_path in sorted(files_to_hash):
            hasher.update(file_path.name.encode())
            size = file_path.stat().st_size
            if size < 1024 * 1024:
                # Only hash small files fully
                hasher.update(file_path.read_bytes())
            else:
                hasher.update(str(size).encode())
        return hasher.hexdigest()[:16]
    except Exception:
        return "unknown"


def _parse_gpu_ids(raw: str) -> List[int]:
    """Parse a comma-separated GPU id list such as ``"0, 1,3"``.

    Empty items are ignored. Raises ``ValueError`` for non-integer items.
    """
    return [int(part.strip()) for part in raw.split(",") if part.strip()]


def _choose_gpu(
    state: dict[int, dict[str, int]], gpu_ids: List[int], preferred_gpu: Optional[int]
) -> int:
    """Pick a GPU for this process given the live assignments in ``state``.

    The preferred GPU is used if it is in ``gpu_ids`` and no process in
    ``state`` holds it. Otherwise the GPU with the fewest assigned processes
    is chosen, with ties broken by lowest current memory usage.
    """
    taken = {entry.get("gpu_id") for entry in state.values()}
    if preferred_gpu is not None and preferred_gpu in gpu_ids and preferred_gpu not in taken:
        return preferred_gpu

    usage_count: dict[int, int] = {gpu: 0 for gpu in gpu_ids}
    for info in state.values():
        gpu = info.get("gpu_id")
        if gpu in usage_count:
            usage_count[gpu] += 1
    min_count = min(usage_count.values())
    candidates = [gpu for gpu, count in usage_count.items() if count == min_count]
    if len(candidates) == 1:
        return candidates[0]
    # Use GPU memory heuristic among candidates
    return select_best_gpu(candidates)


# FastAPI app
app = FastAPI(title="VLAC Service (Multi-Worker)", version="1.0.0")


def add_version_headers(response: JSONResponse):
    """Add versioning headers to response."""
    response.headers["X-VLAC-Model-Tag"] = model_tag
    response.headers["X-VLAC-Checkpoint-SHA"] = checkpoint_sha or "unknown"
    return response


@app.post("/healthcheck", response_model=HealthResponse)
async def healthcheck():
    """Health check endpoint."""
    return HealthResponse(status="ok", model_tag=model_tag)


@app.on_event("startup")
async def startup_event():
    """Lazy-load VLAC model when the server process starts.

    The worker reserves a GPU in the shared assignment file, under the
    cross-process file lock when available. It then restricts
    CUDA_VISIBLE_DEVICES to that GPU and loads the model synchronously.
    """
    global model, load_lock, assigned_gpu_id
    if model is not None:
        return
    if load_lock is None:
        load_lock = asyncio.Lock()
    async with load_lock:
        if model is not None:
            return

        ckpt_path = os.environ.get("VLAC_CKPT_PATH", "/home/zechen/SimpleVLA-RL/CKPT/VLAC")
        gpu_ids_env = os.environ.get("VLAC_GPU_IDS", "0")
        try:
            gpu_ids = _parse_gpu_ids(gpu_ids_env)
        except ValueError as exc:
            raise RuntimeError(f"Invalid VLAC_GPU_IDS value: {gpu_ids_env}") from exc
        if not gpu_ids:
            gpu_ids = [0]

        # NOTE(review): confirm something actually exports UVICORN_WORKER_ID;
        # otherwise VLAC_SELECTED_GPU (set by main() in single-worker mode) or
        # the least-used heuristic decides.
        worker_id_str = os.environ.get("UVICORN_WORKER_ID")
        preferred_gpu: Optional[int]
        if worker_id_str is not None and worker_id_str.isdigit():
            preferred_gpu = gpu_ids[int(worker_id_str) % len(gpu_ids)]
        else:
            selected_gpu_env = os.environ.get("VLAC_SELECTED_GPU")
            preferred_gpu = (
                int(selected_gpu_env) if selected_gpu_env and selected_gpu_env.isdigit() else None
            )

        fd = _acquire_assignment_lock()
        try:
            state = _cleanup_assignment_state(_read_assignment_state(), gpu_ids)
            # A stale entry under our own (reused) pid must not count against
            # the GPU we are about to pick; it is overwritten below anyway.
            state.pop(os.getpid(), None)
            assigned_gpu = _choose_gpu(state, gpu_ids, preferred_gpu)
            state[os.getpid()] = {"gpu_id": assigned_gpu}
            _write_assignment_state(state)
            assigned_gpu_id = assigned_gpu
        finally:
            _release_assignment_lock(fd)

        _register_release_handler()
        os.environ["VLAC_SELECTED_GPU"] = str(assigned_gpu_id)
        # Must be set before CUDA is initialised in this process to take effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(assigned_gpu_id)
        print(f"Worker {worker_id_str or 'main'} using GPU {assigned_gpu_id}")
        init_model(ckpt_path, assigned_gpu_id)


@app.post("/pairwise-critic", response_model=PairwiseCriticResponse)
async def pairwise_critic(req: PairwiseCriticRequest):
    """Compare two images and return critic score."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        try:
            img_a = b64_to_pil(req.image_a)
        except Exception as exc:
            raise HTTPException(status_code=400, detail=f"Invalid image_a: {exc}") from exc
        try:
            img_b = b64_to_pil(req.image_b)
        except Exception as exc:
            raise HTTPException(status_code=400, detail=f"Invalid image_b: {exc}") from exc

        save_debug_image(img_a, "pairwise_a")
        save_debug_image(img_b, "pairwise_b")

        # NOTE(review): model calls are synchronous and block this worker's
        # event loop for the duration of inference.
        try:
            critic_list, _ = model.get_trajectory_critic(
                task=req.task,
                image_list=[img_a, img_b],
                ref_image_list=None,
                batch_num=1,
                ref_num=0,
                rich=req.rich,
                skip=1,
                frame_skip=True,
            )
        except Exception as exc:
            raise HTTPException(status_code=500, detail=f"Model processing error: {exc}") from exc

        critic_val = float(critic_list[-1]) if critic_list else 0.0
        response = JSONResponse(
            content=PairwiseCriticResponse(critic=critic_val, raw=str(critic_val)).dict()
        )
        return add_version_headers(response)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected error: {exc}") from exc


@app.post("/done", response_model=DoneResponse)
async def done(req: DoneRequest):
    """Determine if trajectory should terminate (done when prob > 0.75)."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if req.reference and len(req.reference) < 2:
        raise HTTPException(
            status_code=422,
            detail=(
                "When reference images are provided for done detection, at least 2 are "
                f"required (got {len(req.reference)}). Use reference=null for simple done detection."
            ),
        )
    try:
        first_img = b64_to_pil(req.first_frame)
        prev_img = b64_to_pil(req.prev_frame)
        curr_img = b64_to_pil(req.curr_frame)
        ref_imgs = None
        if req.reference:
            # At most 8 reference images are used.
            ref_imgs = [b64_to_pil(ref) for ref in req.reference[:8]]

        save_debug_image(curr_img, "done_curr")

        if ref_imgs:
            done_list, _ = model.get_in_context_done(
                task=req.task,
                first_image=[first_img],
                n_pre_image=[prev_img],
                now_image=[curr_img],
                ref_image_list=ref_imgs,
                ref_num=len(ref_imgs),
                rich=True,
            )
        else:
            done_list = model.get_trajectory_done(
                task=req.task,
                image_list=[curr_img],
                batch_num=1,
                rich=True,
            )
        prob = float(done_list[0]) if done_list else 0.0

        is_done = prob > 0.75
        response = JSONResponse(content=DoneResponse(done=is_done, prob=prob).dict())
        return add_version_headers(response)
    except HTTPException:
        # Preserve the 400 from b64_to_pil instead of turning it into a 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc


@app.post("/trajectory-critic", response_model=TrajectoryCriticResponse)
async def trajectory_critic(req: TrajectoryCriticRequest):
    """Compute critic and value curves for full trajectory."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    print(f"Processing trajectory-critic request: {len(req.frames)} frames, task: {req.task[:50]}...")

    if req.reference and req.ref_num < 2:
        raise HTTPException(
            status_code=422,
            detail=(
                "When reference images are provided, ref_num must be >= 2 (got "
                f"{req.ref_num}). Use ref_num=0 and reference=null for no reference images."
            ),
        )

    try:
        print("Converting base64 images to PIL...")
        frames: List[Image.Image] = []
        for idx, frame in enumerate(req.frames):
            try:
                frames.append(b64_to_pil(frame))
            except Exception as exc:
                print(f"Error converting frame {idx}: {exc}")
                raise HTTPException(status_code=400, detail=f"Invalid image data in frame {idx}: {exc}") from exc

        ref_imgs = None
        if req.reference:
            print(f"Converting {len(req.reference)} reference images...")
            ref_imgs = []
            for idx, ref in enumerate(req.reference):
                try:
                    ref_imgs.append(b64_to_pil(ref))
                except Exception as exc:
                    print(f"Error converting reference {idx}: {exc}")
                    raise HTTPException(status_code=400, detail=f"Invalid reference image {idx}: {exc}") from exc

        if frames:
            save_debug_image(frames[0], f"traj_first_{req.demo_id}")
            save_debug_image(frames[-1], f"traj_last_{req.demo_id}")

        # Cap at 8 and never pass a non-positive batch size to the model.
        effective_batch_size = max(1, min(req.batch_size, 8))
        print(f"Using effective batch size: {effective_batch_size}")

        try:
            critic_list, value_list = model.get_trajectory_critic(
                task=req.task,
                image_list=frames,
                ref_image_list=ref_imgs,
                batch_num=effective_batch_size,
                ref_num=req.ref_num,
                think=req.think,
                skip=req.skip,
                rich=True,
                frame_skip=True,
            )
            print(f"Got results: {len(critic_list)} critics, {len(value_list)} values")
        except Exception as exc:
            print(f"Error in model.get_trajectory_critic: {exc}")
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=f"Model processing error: {exc}") from exc

        try:
            critic_floats = [float(c) for c in critic_list]
            value_floats = [float(v) for v in value_list]
            print(f"Converted to floats: {len(critic_floats)} critics, {len(value_floats)} values")
        except Exception as exc:
            print(f"Error converting to floats: {exc}")
            print(f"critic_list types: {[type(c) for c in critic_list[:3]]}")
            print(f"value_list types: {[type(v) for v in value_list[:3]]}")
            raise HTTPException(status_code=500, detail=f"Type conversion error: {exc}") from exc

        # Video output is not implemented; return_video is currently ignored.
        video_b64 = None

        print("Creating response...")
        response = JSONResponse(
            content=TrajectoryCriticResponse(
                value_list=value_floats,
                critic_list=critic_floats,
                done_list=None,
                video=video_b64,
            ).dict()
        )
        return add_version_headers(response)
    except HTTPException:
        raise
    except Exception as exc:
        print(f"Unexpected error in trajectory_critic: {exc}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc


def init_model(ckpt_path: str, gpu_id: int):
    """Initialize the VLAC model into the module-level ``model`` global.

    When CUDA_VISIBLE_DEVICES is set, the pinned GPU is the only visible one,
    so the model goes to ``cuda:0``; otherwise ``cuda:<gpu_id>``; CPU if CUDA
    is unavailable. Re-raises any loading error after printing it.
    """
    global model, checkpoint_sha

    print(f"Loading VLAC model from {ckpt_path} on GPU {gpu_id}")

    try:
        device_target: str
        if torch.cuda.is_available():
            if os.environ.get("CUDA_VISIBLE_DEVICES"):
                device_target = "cuda:0"
            else:
                device_target = f"cuda:{gpu_id}"
        else:
            device_target = "cpu"

        model = GAC_model(tag="critic")
        model.init_model(
            model_path=ckpt_path,
            model_type="internvl2",
            device_map=device_target,
        )
        model.temperature = 0.5
        model.top_k = 1
        model.set_config()
        model.set_system_prompt()

        checkpoint_sha = get_checkpoint_sha(ckpt_path)
        print(f"Model loaded successfully. Checkpoint SHA: {checkpoint_sha}")
    except Exception as exc:
        print(f"Failed to load model: {exc}")
        raise


def main():
    """Parse CLI arguments, export worker configuration, and start uvicorn."""
    parser = argparse.ArgumentParser(description="VLAC Service (multi-worker)")
    parser.add_argument("--port", type=int, default=8111, help="Port to run on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument(
        "--ckpt-path",
        default="/home/zechen/SimpleVLA-RL/CKPT/VLAC",
        help="Path to VLAC checkpoint",
    )
    parser.add_argument("--gpu-ids", default="0,1,2,3", help="Comma-separated GPU IDs")
    parser.add_argument("--workers", type=int, default=4, help="Number of uvicorn worker processes")
    args = parser.parse_args()

    try:
        gpu_ids = _parse_gpu_ids(args.gpu_ids)
    except ValueError as exc:
        raise RuntimeError(f"Invalid --gpu-ids value: {args.gpu_ids}") from exc

    if not gpu_ids:
        if torch.cuda.is_available():
            gpu_ids = list(range(torch.cuda.device_count()))
        else:
            gpu_ids = [0]

    # Expose configuration for worker processes
    os.environ["VLAC_CKPT_PATH"] = args.ckpt_path
    os.environ["VLAC_GPU_IDS"] = ",".join(str(g) for g in gpu_ids)

    if args.workers == 1 or len(gpu_ids) == 1:
        selected_gpu = select_best_gpu(gpu_ids)
        os.environ["VLAC_SELECTED_GPU"] = str(selected_gpu)
        print(f"Single-worker mode: selected GPU {selected_gpu} from {gpu_ids}")
    else:
        print(f"Multi-worker mode: available GPUs {gpu_ids}, workers {args.workers}")

    print(f"Starting VLAC multi-worker service on {args.host}:{args.port}")
    uvicorn.run(
        "vlac_service_multi_worker:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        log_level="debug",
    )


if __name__ == "__main__":
    main()