| |
| """ |
| VLAC Service (Multi-Worker) - HTTP API for Vision-Language-Action-Critic model |
| |
| This variant of the service keeps the single-worker code path untouched so that |
| we can debug multi-process behaviour separately. Each uvicorn worker lazily |
| loads the VLAC model, pins itself to a GPU from the provided list, and limits |
| CUDA visibility to that device before model initialisation. |
| |
| Usage: |
| python vlac_service_multi_worker.py --port 8111 --gpu-ids 0,1,2,3 \ |
| --ckpt-path /path/to/VLAC/checkpoint --workers 4 |
| """ |
|
|
| import argparse |
| import asyncio |
| import atexit |
| import base64 |
| import json |
| import hashlib |
| import os |
| import subprocess |
| import sys |
| from io import BytesIO |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import torch |
| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import JSONResponse |
| from PIL import Image |
| from pydantic import BaseModel, Field |
| import uvicorn |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent / "evo_vlac")) |
| from evo_vlac import GAC_model |
|
|
| |
# ---------------------------------------------------------------------------
# Per-process (per-worker) state. Each uvicorn worker gets its own copies.
# ---------------------------------------------------------------------------

# Lazily-loaded VLAC model instance; None until startup_event() runs.
model: Optional[GAC_model] = None
# Version string reported via X-VLAC-Model-Tag header and /healthcheck.
model_tag = "1.0.0"
# Short fingerprint of the loaded checkpoint; None until the model loads.
checkpoint_sha: Optional[str] = None
# Guards one-time model loading inside a worker (created on first startup).
load_lock: Optional[asyncio.Lock] = None

# fcntl is POSIX-only; when unavailable, cross-process GPU-assignment
# locking is disabled and helpers below degrade to no-ops.
try:
    import fcntl
except ImportError:
    fcntl = None

# Files used to coordinate GPU assignment across worker processes.
GPU_LOCK_FILE = Path(os.environ.get("VLAC_GPU_LOCK_FILE", "/tmp/vlac_gpu_lock"))
GPU_ASSIGNMENT_FILE = Path(
    os.environ.get("VLAC_GPU_ASSIGNMENT_FILE", "/tmp/vlac_gpu_assignments.json")
)

# True once the atexit handler that releases this worker's GPU is registered.
lock_release_registered = False
# GPU id claimed by this worker process; None until assignment succeeds.
assigned_gpu_id: Optional[int] = None
|
|
|
|
def _acquire_assignment_lock() -> Optional[int]:
    """Take the exclusive cross-process GPU-assignment file lock.

    Returns the open file descriptor holding the lock, or None when
    ``fcntl`` is unavailable (non-POSIX platforms, locking disabled).
    """
    if fcntl is None:
        return None
    GPU_LOCK_FILE.parent.mkdir(parents=True, exist_ok=True)
    lock_fd = os.open(str(GPU_LOCK_FILE), os.O_RDWR | os.O_CREAT, 0o600)
    fcntl.flock(lock_fd, fcntl.LOCK_EX)
    return lock_fd
|
|
|
|
| def _release_assignment_lock(fd: Optional[int]): |
| if fd is None or fcntl is None: |
| return |
| fcntl.flock(fd, fcntl.LOCK_UN) |
| os.close(fd) |
|
|
|
|
def _read_assignment_state() -> dict[int, dict[str, int]]:
    """Load the pid -> {"gpu_id": int} map from the shared assignment file.

    Best-effort: a missing or unreadable file yields an empty map, and
    entries with non-integer keys or malformed values are skipped.
    """
    if not GPU_ASSIGNMENT_FILE.exists():
        return {}

    try:
        raw = json.loads(GPU_ASSIGNMENT_FILE.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        return {}

    state: dict[int, dict[str, int]] = {}
    for pid_str, entry in raw.items():
        try:
            if isinstance(entry, dict) and "gpu_id" in entry:
                state[int(pid_str)] = {"gpu_id": int(entry["gpu_id"])}
        except (ValueError, TypeError):
            continue
    return state
|
|
|
|
def _write_assignment_state(state: dict[int, dict[str, int]]):
    """Atomically persist the pid -> gpu assignment map as JSON.

    Writes to a sibling ``.tmp`` file and renames it into place so that
    concurrent readers never observe a partially written file.
    """
    GPU_ASSIGNMENT_FILE.parent.mkdir(parents=True, exist_ok=True)
    serialisable = {str(pid): entry for pid, entry in state.items()}
    tmp_path = GPU_ASSIGNMENT_FILE.with_suffix(".tmp")
    tmp_path.write_text(json.dumps(serialisable), encoding="utf-8")
    tmp_path.replace(GPU_ASSIGNMENT_FILE)
|
|
|
|
| def _is_process_alive(pid: int) -> bool: |
| try: |
| os.kill(pid, 0) |
| except OSError: |
| return False |
| return True |
|
|
|
|
def _cleanup_assignment_state(state: dict[int, dict[str, int]], gpu_ids: List[int]) -> dict[int, dict[str, int]]:
    """Drop entries whose GPU is not in *gpu_ids* or whose owner has exited."""
    return {
        pid: {"gpu_id": entry.get("gpu_id")}
        for pid, entry in state.items()
        if entry.get("gpu_id") in gpu_ids and _is_process_alive(pid)
    }
|
|
|
|
def _register_release_handler():
    """Install an atexit hook (once per process) that frees this worker's GPU."""
    global lock_release_registered
    if not lock_release_registered:
        atexit.register(release_gpu_assignment)
        lock_release_registered = True
|
|
|
|
def release_gpu_assignment():
    """Remove this process's entry from the shared GPU-assignment file.

    No-op when this worker never claimed a GPU or when file locking is
    unavailable. Clears the module-level ``assigned_gpu_id`` on success.
    """
    global assigned_gpu_id
    if assigned_gpu_id is None or fcntl is None:
        return

    lock_fd = _acquire_assignment_lock()
    try:
        current = _read_assignment_state()
        current.pop(os.getpid(), None)
        _write_assignment_state(current)
    finally:
        _release_assignment_lock(lock_fd)
    assigned_gpu_id = None
|
|
|
|
| |
class PairwiseCriticRequest(BaseModel):
    """Request body for POST /pairwise-critic."""

    # Natural-language task description the critic scores against.
    task: str
    image_a: str = Field(..., description="Base-64 encoded RGB image")
    image_b: str = Field(..., description="Base-64 encoded RGB image")
    # Forwarded to the model's `rich` parameter; semantics defined by the model.
    rich: Optional[bool] = False
|
|
|
|
class PairwiseCriticResponse(BaseModel):
    """Response body for POST /pairwise-critic."""

    # Critic score for the image_a -> image_b transition.
    critic: float
    # String form of the same score.
    raw: str
|
|
|
|
class DoneRequest(BaseModel):
    """Request body for POST /done (task-completion detection)."""

    task: str
    first_frame: str = Field(..., description="Base-64 encoded RGB image")
    prev_frame: str = Field(..., description="Base-64 encoded RGB image")
    curr_frame: str = Field(..., description="Base-64 encoded RGB image")
    # Optional in-context reference frames; when provided, the endpoint
    # requires at least 2 and uses at most the first 8.
    reference: Optional[List[str]] = Field(None, description="List of base-64 encoded RGB images")
|
|
|
|
class DoneResponse(BaseModel):
    """Response body for POST /done."""

    # True when the model's done probability exceeds the 0.75 threshold.
    done: bool
    # Raw done probability reported by the model.
    prob: float
|
|
|
|
class TrajectoryCriticRequest(BaseModel):
    """Request body for POST /trajectory-critic."""

    task: str
    frames: List[str] = Field(..., description="List of base-64 encoded RGB images")
    reference: Optional[List[str]] = Field(None, description="List of base-64 encoded RGB images")
    # Frame-skip stride forwarded to the model.
    skip: int = 5
    # Number of reference frames; must be >= 2 when `reference` is provided.
    ref_num: int = 6
    # Requested batch size (the endpoint caps the effective value at 8).
    batch_size: int = 10
    # Forwarded to the model's `think` parameter.
    think: bool = False
    # Accepted for API compatibility; the endpoint currently never returns video.
    return_video: bool = False
    # Identifier used only to tag debug image dumps.
    demo_id: str = ""
|
|
|
|
class TrajectoryCriticResponse(BaseModel):
    """Response body for POST /trajectory-critic."""

    # Value curve as returned by model.get_trajectory_critic.
    value_list: List[float]
    # Critic curve as returned by model.get_trajectory_critic.
    critic_list: List[float]
    # Always None in the current implementation.
    done_list: Optional[List[float]] = None
    # Always None in the current implementation (return_video unimplemented).
    video: Optional[str] = None
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for /healthcheck."""

    # "ok" whenever the HTTP layer responds (the model may still be loading).
    status: str
    # Service version tag (module-level `model_tag`).
    model_tag: str
|
|
|
|
| |
def get_gpu_memory_usage() -> dict:
    """Query per-GPU used memory (MiB) via ``nvidia-smi``.

    Returns:
        Mapping of GPU index -> used memory in MiB. Empty dict when
        nvidia-smi is missing, exits non-zero, or emits unparseable output.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError, OSError):
        return {}

    gpu_usage: dict[int, int] = {}
    for line in result.stdout.strip().splitlines():
        line = line.strip()
        if not line:
            continue
        # Tolerate spacing differences and malformed rows: the original
        # split on ", " exactly and let ValueError escape on odd output,
        # crashing callers instead of degrading to "no data".
        parts = [part.strip() for part in line.split(",")]
        if len(parts) != 2:
            continue
        try:
            gpu_usage[int(parts[0])] = int(parts[1])
        except ValueError:
            continue
    return gpu_usage
|
|
|
|
def select_best_gpu(gpu_ids: List[int]) -> int:
    """Pick the GPU in *gpu_ids* with the least memory currently in use.

    Falls back to the first id when nvidia-smi reports nothing.
    """
    usage = get_gpu_memory_usage()
    if not usage:
        return gpu_ids[0]
    # GPUs absent from the report sort last (treated as fully used).
    return min(gpu_ids, key=lambda gpu: usage.get(gpu, float("inf")))
|
|
|
|
def b64_to_pil(data: str) -> Image.Image:
    """Decode a base-64 string into a 448x448 RGB PIL image.

    Raises:
        HTTPException: 400 on any decode or parse failure.
    """
    try:
        decoded = base64.b64decode(data)
        image = Image.open(BytesIO(decoded)).convert("RGB")
        # The model expects a fixed 448x448 input resolution.
        return image.resize((448, 448), Image.Resampling.LANCZOS)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc
|
|
|
|
def save_debug_image(img: Image.Image, prefix: str):
    """Persist *img* under ./vlac_debug when VLAC_SAVE_INPUTS=1 (else no-op).

    The filename embeds a digest of the pixel data so identical inputs map
    to the same file across runs.
    """
    if os.getenv("VLAC_SAVE_INPUTS") != "1":
        return
    debug_dir = Path("./vlac_debug")
    debug_dir.mkdir(exist_ok=True)
    # hashlib gives a stable, filename-safe digest. The previous built-in
    # hash() is randomized per interpreter run (PYTHONHASHSEED) and could
    # yield a leading "-", so debug filenames were not reproducible.
    digest = hashlib.sha256(img.tobytes()).hexdigest()[:8]
    img.save(debug_dir / f"{prefix}_{digest}.png")
|
|
|
|
def chunk_list(lst: List, chunk_size: int) -> List[List]:
    """Partition *lst* into consecutive sublists of at most *chunk_size* items."""
    chunks: List[List] = []
    for start in range(0, len(lst), chunk_size):
        chunks.append(lst[start : start + chunk_size])
    return chunks
|
|
|
|
def get_checkpoint_sha(ckpt_path: str) -> str:
    """Fingerprint a checkpoint directory with a truncated SHA-256.

    Hashes each candidate file's name, its full bytes when under 1 MiB,
    and only its size otherwise (a cheap proxy for huge weight files).
    Returns "unknown" when no candidate files exist or anything fails.
    """
    try:
        ckpt_dir = Path(ckpt_path)
        candidates: list[Path] = []
        for pattern in ["*.bin", "*.safetensors", "config.json"]:
            candidates.extend(ckpt_dir.glob(pattern))

        if not candidates:
            return "unknown"

        hasher = hashlib.sha256()
        for path in sorted(candidates):
            hasher.update(path.name.encode())
            size = path.stat().st_size
            if size < 1024 * 1024:
                hasher.update(path.read_bytes())
            else:
                hasher.update(str(size).encode())
        return hasher.hexdigest()[:16]
    except Exception:
        # Best-effort fingerprint: never let hashing break model startup.
        return "unknown"
|
|
|
|
| |
# FastAPI application object; uvicorn imports this as
# "vlac_service_multi_worker:app" (one instance per worker process).
app = FastAPI(title="VLAC Service (Multi-Worker)", version="1.0.0")
|
|
|
|
def add_version_headers(response: JSONResponse):
    """Stamp model/checkpoint version headers onto *response* and return it."""
    version_headers = {
        "X-VLAC-Model-Tag": model_tag,
        "X-VLAC-Checkpoint-SHA": checkpoint_sha or "unknown",
    }
    for header, value in version_headers.items():
        response.headers[header] = value
    return response
|
|
|
|
@app.get("/healthcheck", response_model=HealthResponse)
@app.post("/healthcheck", response_model=HealthResponse)
async def healthcheck():
    """Liveness probe.

    Registered for GET in addition to the original POST (kept for backward
    compatibility): load balancers and orchestration probes typically issue
    GET. Responds immediately even while the model is still loading.
    """
    return HealthResponse(status="ok", model_tag=model_tag)
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Lazy-load VLAC model when the server process starts.

    Runs once per uvicorn worker. Claims a GPU through a file-locked shared
    assignment map so workers spread across the configured devices, then
    restricts CUDA visibility to that device and loads the model.
    """
    global model, load_lock, assigned_gpu_id
    if model is not None:
        return

    # Created here (not at import time) so the lock binds to this worker's
    # running event loop.
    if load_lock is None:
        load_lock = asyncio.Lock()

    async with load_lock:
        # Double-checked after acquiring the lock.
        if model is not None:
            return

        ckpt_path = os.environ.get("VLAC_CKPT_PATH", "/home/zechen/SimpleVLA-RL/CKPT/VLAC")
        gpu_ids_env = os.environ.get("VLAC_GPU_IDS", "0")
        try:
            gpu_ids = [int(x.strip()) for x in gpu_ids_env.split(",") if x.strip()]
        except ValueError as exc:
            raise RuntimeError(f"Invalid VLAC_GPU_IDS value: {gpu_ids_env}") from exc

        if not gpu_ids:
            gpu_ids = [0]

        # Claim a GPU under the cross-process file lock.
        fd = _acquire_assignment_lock()
        try:
            # Drop entries for dead workers / unconfigured GPUs first.
            state = _cleanup_assignment_state(_read_assignment_state(), gpu_ids)

            # NOTE(review): uvicorn does not export UVICORN_WORKER_ID by
            # itself — presumably an external launcher sets it; confirm.
            # Without it we fall back to VLAC_SELECTED_GPU, which main()
            # sets in single-worker mode.
            worker_id_str = os.environ.get("UVICORN_WORKER_ID")
            preferred_gpu: Optional[int] = None
            if worker_id_str is not None and worker_id_str.isdigit():
                worker_idx = int(worker_id_str)
                preferred_gpu = gpu_ids[worker_idx % len(gpu_ids)]
            else:
                selected_gpu_env = os.environ.get("VLAC_SELECTED_GPU")
                preferred_gpu = int(selected_gpu_env) if selected_gpu_env and selected_gpu_env.isdigit() else None

            assigned_gpu = None
            if preferred_gpu is not None and preferred_gpu in gpu_ids:
                # Honour the preference only when no live worker holds it.
                if all(entry.get("gpu_id") != preferred_gpu for entry in state.values()):
                    assigned_gpu = preferred_gpu

            if assigned_gpu is None:
                # Otherwise choose the GPU with the fewest assigned workers;
                # break ties by current memory usage from nvidia-smi.
                usage_count: dict[int, int] = {gpu: 0 for gpu in gpu_ids}
                for info in state.values():
                    gpu = info.get("gpu_id")
                    if gpu in usage_count:
                        usage_count[gpu] += 1

                min_count = min(usage_count.values())
                candidates = [gpu for gpu, count in usage_count.items() if count == min_count]
                if len(candidates) == 1:
                    assigned_gpu = candidates[0]
                else:
                    best_gpu = select_best_gpu(candidates)
                    assigned_gpu = best_gpu

            # Record the claim before releasing the file lock.
            state[os.getpid()] = {"gpu_id": assigned_gpu}
            _write_assignment_state(state)
            assigned_gpu_id = assigned_gpu
        finally:
            _release_assignment_lock(fd)

        # Ensure the claim is released when this process exits.
        _register_release_handler()

        # Restrict CUDA visibility before the model touches the GPU.
        # NOTE(review): torch is imported at module level; this relies on
        # CUDA not having been initialised yet in this process — confirm.
        os.environ["VLAC_SELECTED_GPU"] = str(assigned_gpu_id)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(assigned_gpu_id)
        print(f"Worker {worker_id_str or 'main'} using GPU {assigned_gpu_id}")

        init_model(ckpt_path, assigned_gpu_id)
|
|
|
|
@app.post("/pairwise-critic", response_model=PairwiseCriticResponse)
async def pairwise_critic(req: PairwiseCriticRequest):
    """Compare two frames for a task and return the critic score.

    Responses: 400 for undecodable images, 500 for model failures, 503
    while the model is still loading.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # b64_to_pil already raises HTTPException(400). Annotate which image
        # failed using the original detail text; the previous handler caught
        # bare Exception and stringified the HTTPException, producing nested
        # messages like "Invalid image_a: 400: Invalid image data: ...".
        try:
            img_a = b64_to_pil(req.image_a)
        except HTTPException as exc:
            raise HTTPException(status_code=400, detail=f"Invalid image_a: {exc.detail}") from exc

        try:
            img_b = b64_to_pil(req.image_b)
        except HTTPException as exc:
            raise HTTPException(status_code=400, detail=f"Invalid image_b: {exc.detail}") from exc

        save_debug_image(img_a, "pairwise_a")
        save_debug_image(img_b, "pairwise_b")

        try:
            critic_list, _ = model.get_trajectory_critic(
                task=req.task,
                image_list=[img_a, img_b],
                ref_image_list=None,
                batch_num=1,
                ref_num=0,
                rich=req.rich,
                skip=1,
                frame_skip=True,
            )
        except Exception as exc:
            raise HTTPException(status_code=500, detail=f"Model processing error: {exc}") from exc

        # The score for the a -> b transition is the last entry of the curve.
        critic_val = float(critic_list[-1]) if critic_list else 0.0
        response = JSONResponse(
            content=PairwiseCriticResponse(critic=critic_val, raw=str(critic_val)).dict()
        )
        return add_version_headers(response)

    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected error: {exc}") from exc
|
|
|
|
@app.post("/done", response_model=DoneResponse)
async def done(req: DoneRequest):
    """Decide whether the task looks completed in the current frame.

    With reference images, uses in-context done detection; otherwise a
    single-frame check. `done` is True when the model's probability
    exceeds 0.75. Responses: 400 bad image, 422 too few references,
    500 model failure, 503 model not loaded.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if req.reference and len(req.reference) > 0 and len(req.reference) < 2:
        raise HTTPException(
            status_code=422,
            detail=(
                "When reference images are provided for done detection, at least 2 are "
                f"required (got {len(req.reference)}). Use reference=null for simple done detection."
            ),
        )

    try:
        first_img = b64_to_pil(req.first_frame)
        prev_img = b64_to_pil(req.prev_frame)
        curr_img = b64_to_pil(req.curr_frame)

        ref_imgs = None
        if req.reference:
            # Cap the in-context reference set at 8 frames.
            ref_imgs = [b64_to_pil(ref) for ref in req.reference[:8]]

        save_debug_image(curr_img, "done_curr")

        if ref_imgs:
            done_list, _ = model.get_in_context_done(
                task=req.task,
                first_image=[first_img],
                n_pre_image=[prev_img],
                now_image=[curr_img],
                ref_image_list=ref_imgs,
                ref_num=len(ref_imgs),
                rich=True,
            )
            prob = float(done_list[0]) if done_list else 0.0
        else:
            done_list = model.get_trajectory_done(
                task=req.task,
                image_list=[curr_img],
                batch_num=1,
                rich=True,
            )
            prob = float(done_list[0]) if done_list else 0.0

        is_done = prob > 0.75
        response = JSONResponse(content=DoneResponse(done=is_done, prob=prob).dict())
        return add_version_headers(response)

    except HTTPException:
        # BUG FIX: the generic handler below previously swallowed the
        # HTTPException(400) raised by b64_to_pil and re-reported it as a
        # 500. Re-raise to preserve the intended status code, matching the
        # other endpoints in this service.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc
|
|
|
|
@app.post("/trajectory-critic", response_model=TrajectoryCriticResponse)
async def trajectory_critic(req: TrajectoryCriticRequest):
    """Compute critic and value curves for full trajectory.

    Decodes every frame (and optional reference frames), runs the model's
    trajectory critic, and returns the value/critic lists. Responses:
    400 bad image, 422 invalid ref_num, 500 model failure, 503 model not
    loaded. `done_list` and `video` are always None in this implementation.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    print(f"Processing trajectory-critic request: {len(req.frames)} frames, task: {req.task[:50]}...")

    # With reference images the model requires at least 2 in-context examples.
    if req.reference and len(req.reference) > 0 and req.ref_num < 2:
        raise HTTPException(
            status_code=422,
            detail=(
                "When reference images are provided, ref_num must be >= 2 (got "
                f"{req.ref_num}). Use ref_num=0 and reference=null for no reference images."
            ),
        )

    try:
        print("Converting base64 images to PIL...")
        frames: List[Image.Image] = []
        for idx, frame in enumerate(req.frames):
            try:
                frames.append(b64_to_pil(frame))
            except Exception as exc:
                print(f"Error converting frame {idx}: {exc}")
                raise HTTPException(status_code=400, detail=f"Invalid image data in frame {idx}: {exc}") from exc

        ref_imgs = None
        if req.reference:
            print(f"Converting {len(req.reference)} reference images...")
            ref_imgs = []
            for idx, ref in enumerate(req.reference):
                try:
                    ref_imgs.append(b64_to_pil(ref))
                except Exception as exc:
                    print(f"Error converting reference {idx}: {exc}")
                    raise HTTPException(status_code=400, detail=f"Invalid reference image {idx}: {exc}") from exc

        # Dump first/last frame when VLAC_SAVE_INPUTS=1 for offline debugging.
        if frames:
            save_debug_image(frames[0], f"traj_first_{req.demo_id}")
            save_debug_image(frames[-1], f"traj_last_{req.demo_id}")

        # Cap the batch size to bound per-request GPU memory use.
        effective_batch_size = min(req.batch_size, 8)
        print(f"Using effective batch size: {effective_batch_size}")

        try:
            critic_list, value_list = model.get_trajectory_critic(
                task=req.task,
                image_list=frames,
                ref_image_list=ref_imgs,
                batch_num=effective_batch_size,
                ref_num=req.ref_num,
                think=req.think,
                skip=req.skip,
                rich=True,
                frame_skip=True,
            )
            print(f"Got results: {len(critic_list)} critics, {len(value_list)} values")
        except Exception as exc:
            print(f"Error in model.get_trajectory_critic: {exc}")
            import traceback

            traceback.print_exc()
            raise HTTPException(status_code=500, detail=f"Model processing error: {exc}") from exc

        # Normalise results to plain floats so they JSON-serialise
        # (presumably the model may return tensor/array scalars — the
        # diagnostic prints below exist to surface unexpected types).
        try:
            critic_floats = [float(c) for c in critic_list]
            value_floats = [float(v) for v in value_list]
            print(f"Converted to floats: {len(critic_floats)} critics, {len(value_floats)} values")
        except Exception as exc:
            print(f"Error converting to floats: {exc}")
            print(f"critic_list types: {[type(c) for c in critic_list[:3]]}")
            print(f"value_list types: {[type(v) for v in value_list[:3]]}")
            raise HTTPException(status_code=500, detail=f"Type conversion error: {exc}") from exc

        # Video rendering is not implemented; req.return_video is ignored.
        video_b64 = None

        print("Creating response...")
        response = JSONResponse(
            content=TrajectoryCriticResponse(
                value_list=value_floats,
                critic_list=critic_floats,
                done_list=None,
                video=video_b64,
            ).dict()
        )
        return add_version_headers(response)

    except HTTPException:
        raise
    except Exception as exc:
        print(f"Unexpected error in trajectory_critic: {exc}")
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc
|
|
|
|
def init_model(ckpt_path: str, gpu_id: int):
    """Load the VLAC (GAC) critic model onto the selected device.

    Sets the module-level ``model`` and ``checkpoint_sha`` globals.
    Any loading failure is logged and re-raised.
    """
    global model, checkpoint_sha

    print(f"Loading VLAC model from {ckpt_path} on GPU {gpu_id}")

    try:
        # Resolve the target device. Once CUDA_VISIBLE_DEVICES narrows
        # visibility to a single card, torch addresses it as cuda:0.
        if not torch.cuda.is_available():
            device_target = "cpu"
        elif os.environ.get("CUDA_VISIBLE_DEVICES"):
            device_target = "cuda:0"
        else:
            device_target = f"cuda:{gpu_id}"

        model = GAC_model(tag="critic")
        model.init_model(
            model_path=ckpt_path,
            model_type="internvl2",
            device_map=device_target,
        )
        model.temperature = 0.5
        model.top_k = 1
        model.set_config()
        model.set_system_prompt()

        checkpoint_sha = get_checkpoint_sha(ckpt_path)
        print(f"Model loaded successfully. Checkpoint SHA: {checkpoint_sha}")

    except Exception as exc:
        print(f"Failed to load model: {exc}")
        raise
|
|
|
|
def main():
    """CLI entry point: parse arguments, export worker config, start uvicorn."""
    parser = argparse.ArgumentParser(description="VLAC Service (multi-worker)")
    parser.add_argument("--port", type=int, default=8111, help="Port to run on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument(
        "--ckpt-path",
        default="/home/zechen/SimpleVLA-RL/CKPT/VLAC",
        help="Path to VLAC checkpoint",
    )
    parser.add_argument("--gpu-ids", default="0,1,2,3", help="Comma-separated GPU IDs")
    parser.add_argument("--workers", type=int, default=4, help="Number of uvicorn worker processes")
    args = parser.parse_args()

    try:
        gpu_ids = [int(part.strip()) for part in args.gpu_ids.split(",") if part.strip()]
    except ValueError as exc:
        raise RuntimeError(f"Invalid --gpu-ids value: {args.gpu_ids}") from exc

    if not gpu_ids:
        # Fall back to every visible device, or a nominal GPU 0 on CPU hosts.
        gpu_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else [0]

    # Workers are separate processes; they read these in startup_event().
    os.environ["VLAC_CKPT_PATH"] = args.ckpt_path
    os.environ["VLAC_GPU_IDS"] = ",".join(str(gpu) for gpu in gpu_ids)

    if args.workers == 1 or len(gpu_ids) == 1:
        # Single worker (or single GPU): pin the device choice up-front.
        selected_gpu = select_best_gpu(gpu_ids)
        os.environ["VLAC_SELECTED_GPU"] = str(selected_gpu)
        print(f"Single-worker mode: selected GPU {selected_gpu} from {gpu_ids}")
    else:
        print(f"Multi-worker mode: available GPUs {gpu_ids}, workers {args.workers}")

    print(f"Starting VLAC multi-worker service on {args.host}:{args.port}")
    # uvicorn requires the import-string form (not the app object) to spawn
    # multiple worker processes.
    uvicorn.run(
        "vlac_service_multi_worker:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        log_level="debug",
    )


if __name__ == "__main__":
    main()
|
|
|
|