TTI / Dev /vlac_service_multi_worker.py
JosephBai's picture
Upload folder using huggingface_hub
857c2e9 verified
#!/usr/bin/env python3
"""
VLAC Service (Multi-Worker) - HTTP API for Vision-Language-Action-Critic model
This variant of the service keeps the single-worker code path untouched so that
we can debug multi-process behaviour separately. Each uvicorn worker lazily
loads the VLAC model, pins itself to a GPU from the provided list, and limits
CUDA visibility to that device before model initialisation.
Usage:
python vlac_service_multi_worker.py --port 8111 --gpu-ids 0,1,2,3 \
--ckpt-path /path/to/VLAC/checkpoint --workers 4
"""
import argparse
import asyncio
import atexit
import base64
import json
import hashlib
import os
import subprocess
import sys
from io import BytesIO
from pathlib import Path
from typing import List, Optional
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
import uvicorn
# Add evo_vlac to path
sys.path.insert(0, str(Path(__file__).parent / "evo_vlac"))
from evo_vlac import GAC_model
# Global model instance
# Per-process model handle; set by init_model() and checked (``is None``) by
# every inference endpoint before serving a request.
model: Optional[GAC_model] = None
# Reported by /healthcheck and in the X-VLAC-Model-Tag response header.
model_tag = "1.0.0"
# Short checkpoint fingerprint from get_checkpoint_sha(); set by init_model()
# and reported in the X-VLAC-Checkpoint-SHA response header.
checkpoint_sha: Optional[str] = None
# Created lazily inside startup_event() and used to serialise model loading.
load_lock: Optional[asyncio.Lock] = None
# fcntl provides the advisory file lock (flock) that serialises GPU assignment
# across worker processes; when it is unavailable the lock helpers are no-ops.
try:
    import fcntl  # type: ignore[attr-defined]
except ImportError:  # pragma: no cover - platform without fcntl (e.g. Windows)
    fcntl = None  # type: ignore[assignment]
# Lock file guarding the shared assignment table, and the PID -> GPU table
# itself (JSON). Both locations can be overridden through environment variables.
GPU_LOCK_FILE = Path(os.environ.get("VLAC_GPU_LOCK_FILE", "/tmp/vlac_gpu_lock"))
GPU_ASSIGNMENT_FILE = Path(
    os.environ.get("VLAC_GPU_ASSIGNMENT_FILE", "/tmp/vlac_gpu_assignments.json")
)
# True once the atexit release hook has been installed in this process.
lock_release_registered = False
# GPU this process has claimed in the assignment table (None until startup
# assigns one, and again after release_gpu_assignment()).
assigned_gpu_id: Optional[int] = None
def _acquire_assignment_lock() -> Optional[int]:
    """Block until this process holds the exclusive GPU-assignment lock.

    Opens ``GPU_LOCK_FILE`` (creating it and its parent directory if needed,
    file mode 0o600 subject to the umask) and takes an exclusive ``flock``.

    Returns:
        The locked file descriptor, to be handed to ``_release_assignment_lock``;
        ``None`` when ``fcntl`` is unavailable, in which case no cross-process
        locking takes place.
    """
    if fcntl is None:
        return None
    GPU_LOCK_FILE.parent.mkdir(parents=True, exist_ok=True)
    fd = os.open(str(GPU_LOCK_FILE), os.O_RDWR | os.O_CREAT, 0o600)
    try:
        fcntl.flock(fd, fcntl.LOCK_EX)
    except BaseException:
        # Don't leak the descriptor if the blocking lock call fails or is
        # interrupted (e.g. KeyboardInterrupt while another worker holds it).
        os.close(fd)
        raise
    return fd
def _release_assignment_lock(fd: Optional[int]):
    """Unlock and close a descriptor returned by ``_acquire_assignment_lock``.

    Does nothing when ``fd`` is ``None`` or ``fcntl`` is unavailable, which
    mirrors the acquire side returning ``None`` in those situations.
    """
    if fd is not None and fcntl is not None:
        fcntl.flock(fd, fcntl.LOCK_UN)
        os.close(fd)
def _read_assignment_state() -> dict[int, dict[str, int]]:
    """Load the PID -> GPU assignment table from ``GPU_ASSIGNMENT_FILE``.

    The file is expected to hold a JSON object mapping stringified PIDs to
    ``{"gpu_id": <int>}``. A missing, unreadable, non-UTF-8, malformed or
    non-object file is treated as an empty table, and individual entries whose
    key or ``gpu_id`` cannot be converted to ``int`` are skipped.

    Returns:
        Mapping of PID to ``{"gpu_id": <int>}``.
    """
    try:
        with GPU_ASSIGNMENT_FILE.open("r", encoding="utf-8") as file:
            raw = json.load(file)
    except (OSError, ValueError):
        # OSError covers a missing file (FileNotFoundError); ValueError covers
        # json.JSONDecodeError *and* UnicodeDecodeError from a corrupt file.
        return {}
    if not isinstance(raw, dict):
        # e.g. a JSON list or scalar: there is no table to recover.
        return {}
    result: dict[int, dict[str, int]] = {}
    for key, value in raw.items():
        if not isinstance(value, dict) or "gpu_id" not in value:
            continue
        try:
            result[int(key)] = {"gpu_id": int(value["gpu_id"])}
        except (ValueError, TypeError, OverflowError):
            # OverflowError: json accepts ``Infinity``, which int() rejects.
            continue
    return result
def _write_assignment_state(state: dict[int, dict[str, int]]):
    """Persist the PID -> GPU assignment table to ``GPU_ASSIGNMENT_FILE``.

    PIDs are stringified because JSON object keys must be strings. The data is
    written to a sibling ``.tmp`` file which is then moved over the target with
    ``Path.replace`` (atomic on POSIX within one filesystem), so readers never
    see a partially written table.
    """
    target = GPU_ASSIGNMENT_FILE
    target.parent.mkdir(parents=True, exist_ok=True)
    serialisable = {}
    for pid, entry in state.items():
        serialisable[str(pid)] = entry
    staging = target.with_suffix(".tmp")
    with staging.open("w", encoding="utf-8") as handle:
        json.dump(serialisable, handle)
    staging.replace(target)
def _is_process_alive(pid: int) -> bool:
    """Return True if a process with ``pid`` currently exists.

    Probes with signal 0, which performs the permission/existence check without
    delivering a signal. Non-positive PIDs are rejected up front: ``os.kill``
    with 0 or a negative PID addresses a process *group*, so the probe would
    "succeed" and a bogus table entry would look alive forever.

    NOTE: on Windows ``os.kill`` with signal 0 terminates the target process
    instead of probing it, so this helper is only safe on POSIX.
    """
    if pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        # ESRCH: no such process.
        return False
    except PermissionError:
        # EPERM: the process exists but we may not signal it.
        return True
    except OSError:
        # Any other failure: keep the conservative "not alive" answer.
        return False
    return True
def _cleanup_assignment_state(state: dict[int, dict[str, int]], gpu_ids: List[int]) -> dict[int, dict[str, int]]:
    """Return a copy of ``state`` without stale claims.

    An entry is kept only when its GPU is still in ``gpu_ids`` and its owning
    process is still alive; the GPU check runs first so no liveness probe is
    issued for entries that are dropped anyway. ``state`` itself is not mutated.
    """
    return {
        pid: {"gpu_id": entry.get("gpu_id")}
        for pid, entry in state.items()
        if entry.get("gpu_id") in gpu_ids and _is_process_alive(pid)
    }
def _register_release_handler():
    """Arrange for this process's GPU claim to be released at interpreter exit.

    Idempotent: the ``atexit`` hook is installed at most once per process,
    tracked by the module-level ``lock_release_registered`` flag.
    """
    global lock_release_registered
    if not lock_release_registered:
        # The lambda resolves release_gpu_assignment when the hook fires.
        atexit.register(lambda: release_gpu_assignment())
        lock_release_registered = True
def release_gpu_assignment():
    """Remove this process's entry from the shared GPU assignment table.

    Installed as the exit hook by ``_register_release_handler``. It is a no-op
    if this process has not claimed a GPU. On success ``assigned_gpu_id`` is
    reset to ``None``.
    """
    global assigned_gpu_id
    if assigned_gpu_id is None:
        return
    # Startup records the claim whether or not fcntl exists (the lock helpers
    # simply become no-ops), so the release must run in both cases too;
    # otherwise the entry would linger until a liveness sweep removes it.
    fd = _acquire_assignment_lock()
    try:
        state = _read_assignment_state()
        state.pop(os.getpid(), None)
        _write_assignment_state(state)
    finally:
        _release_assignment_lock(fd)
    assigned_gpu_id = None
# Request/Response Models
# Body of POST /pairwise-critic: a task description plus two base64 images,
# which the handler scores as a two-frame trajectory (image_a -> image_b).
class PairwiseCriticRequest(BaseModel):
    task: str
    image_a: str = Field(..., description="Base-64 encoded RGB image")
    image_b: str = Field(..., description="Base-64 encoded RGB image")
    # Forwarded unchanged as ``rich=`` to model.get_trajectory_critic.
    rich: Optional[bool] = False
# Response of POST /pairwise-critic; the handler fills ``raw`` with str(critic).
class PairwiseCriticResponse(BaseModel):
    critic: float
    raw: str
# Body of POST /done. When ``reference`` is given the handler requires at
# least 2 images and uses only the first 8; null selects simple done detection.
class DoneRequest(BaseModel):
    task: str
    first_frame: str = Field(..., description="Base-64 encoded RGB image")
    prev_frame: str = Field(..., description="Base-64 encoded RGB image")
    curr_frame: str = Field(..., description="Base-64 encoded RGB image")
    reference: Optional[List[str]] = Field(None, description="List of base-64 encoded RGB images")
# Response of POST /done: ``prob`` is the model's done score and the handler
# sets ``done`` to ``prob > 0.75``.
class DoneResponse(BaseModel):
    done: bool
    prob: float
# Body of POST /trajectory-critic.
class TrajectoryCriticRequest(BaseModel):
    task: str
    frames: List[str] = Field(..., description="List of base-64 encoded RGB images")
    reference: Optional[List[str]] = Field(None, description="List of base-64 encoded RGB images")
    # Passed through as ``skip=`` to model.get_trajectory_critic.
    skip: int = 5
    # Must be >= 2 whenever ``reference`` is non-empty (checked by the handler).
    ref_num: int = 6
    # The handler caps this at 8 before passing it as ``batch_num``.
    batch_size: int = 10
    think: bool = False
    # Currently ignored: the handler always responds with video=None.
    return_video: bool = False
    # Only used to name debug images when VLAC_SAVE_INPUTS=1.
    demo_id: str = ""
# Response of POST /trajectory-critic. The handler currently always leaves
# ``done_list`` and ``video`` as None.
class TrajectoryCriticResponse(BaseModel):
    value_list: List[float]
    critic_list: List[float]
    done_list: Optional[List[float]] = None
    video: Optional[str] = None
# Response of the /healthcheck endpoint.
class HealthResponse(BaseModel):
    status: str
    model_tag: str
# Utility functions
def _parse_gpu_memory_csv(text: str) -> dict[int, int]:
    """Parse nvidia-smi ``index, memory.used`` rows (no header, no units).

    Rows that do not consist of exactly two integer fields (for instance the
    ``[N/A]`` placeholder some devices report) are skipped rather than aborting
    the whole query.
    """
    usage: dict[int, int] = {}
    for line in text.splitlines():
        fields = [part.strip() for part in line.split(",")]
        if len(fields) != 2:
            continue
        try:
            usage[int(fields[0])] = int(fields[1])
        except ValueError:
            continue
    return usage


def get_gpu_memory_usage(timeout: float = 10.0) -> dict:
    """Get GPU memory usage for all GPUs via nvidia-smi.

    Args:
        timeout: Seconds to wait for nvidia-smi before giving up.

    Returns:
        Mapping of nvidia-smi GPU index to used memory as reported with
        ``nounits`` (MiB by default - confirm for your driver). An empty dict
        if nvidia-smi is missing, fails, times out or reports nothing usable.

    NOTE(review): nvidia-smi enumerates GPUs in PCI-bus order, which can differ
    from CUDA's default device order unless CUDA_DEVICE_ORDER=PCI_BUS_ID.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
            timeout=timeout,
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # OSError includes FileNotFoundError (no nvidia-smi) and PermissionError.
        return {}
    return _parse_gpu_memory_csv(result.stdout)
def select_best_gpu(gpu_ids: List[int]) -> int:
    """Pick the GPU in ``gpu_ids`` with the least memory currently in use.

    GPUs missing from the nvidia-smi report rank after every reported one, and
    ties go to the earliest id in ``gpu_ids``. When nvidia-smi yields no data
    at all, the first entry of ``gpu_ids`` is returned.
    """
    usage = get_gpu_memory_usage()
    if usage:
        # An unknown GPU scores +inf so it is only chosen if nothing else exists.
        return min(gpu_ids, key=lambda gpu: usage.get(gpu, float("inf")))
    return gpu_ids[0]
def b64_to_pil(data: str) -> Image.Image:
    """Decode a base64 image payload into a 448x448 RGB PIL image.

    ``data`` may be bare base64 or a ``data:<mime>;base64,<payload>`` URI; in
    the latter case everything up to the first comma is dropped before
    decoding. (Previously the URI header was fed to the decoder and corrupted
    the payload.)

    Raises:
        HTTPException: 400 if the payload cannot be decoded into an image.
    """
    try:
        payload = data
        if payload.startswith("data:"):
            # ':' never occurs in bare base64, so this only matches data URIs.
            _, _, payload = payload.partition(",")
        img_data = base64.b64decode(payload)
        img = Image.open(BytesIO(img_data)).convert("RGB")
        # Resize to 448x448 as specified in contract
        img = img.resize((448, 448), Image.Resampling.LANCZOS)
        return img
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc
def save_debug_image(img: Image.Image, prefix: str):
    """Save image for debugging if VLAC_SAVE_INPUTS=1.

    Files are written to ``./vlac_debug`` as ``<prefix>_<digest>.png``.

    ``prefix`` can carry request data (the trajectory endpoint embeds the
    client-supplied ``demo_id``), so it is reduced to a safe filename component
    first; otherwise a value such as ``"../x"`` would write outside the debug
    directory. ``digest`` is the first 8 hex chars of a SHA-256 of the pixel
    data. Unlike the builtin ``hash``, which is salted per process, it gives
    identical images the same name in every worker.
    """
    if os.getenv("VLAC_SAVE_INPUTS") != "1":
        return
    debug_dir = Path("./vlac_debug")
    debug_dir.mkdir(exist_ok=True)
    safe_prefix = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in prefix)
    digest = hashlib.sha256(img.tobytes()).hexdigest()[:8]
    img.save(debug_dir / f"{safe_prefix}_{digest}.png")
def chunk_list(lst: List, chunk_size: int) -> List[List]:
    """Split ``lst`` into consecutive chunks of at most ``chunk_size`` items.

    The final chunk holds the remainder and may be shorter; an empty list
    yields ``[]``.

    Raises:
        ValueError: if ``chunk_size`` is not positive. Previously a negative
        size silently produced ``[]``, and zero failed inside ``range``.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
def get_checkpoint_sha(ckpt_path: str) -> str:
    """Return a short, cheap fingerprint of a checkpoint directory.

    Only top-level ``*.bin``, ``*.safetensors`` and ``config.json`` entries are
    considered, in sorted path order. Each one contributes its name plus either
    its full contents (files under 1 MiB) or just its byte size (larger files),
    so this is not a true content hash of the weights.

    Returns:
        The first 16 hex characters of the SHA-256 digest, or ``"unknown"`` if
        nothing matches or any error occurs.
    """
    small_file_limit = 1024 * 1024
    try:
        root = Path(ckpt_path)
        candidates = [
            match
            for pattern in ("*.bin", "*.safetensors", "config.json")
            for match in root.glob(pattern)
        ]
        if not candidates:
            return "unknown"
        digest = hashlib.sha256()
        for entry in sorted(candidates):
            digest.update(entry.name.encode())
            size = entry.stat().st_size
            digest.update(entry.read_bytes() if size < small_file_limit else str(size).encode())
        return digest.hexdigest()[:16]
    except Exception:
        return "unknown"
# FastAPI app
# Each uvicorn worker process imports this module and gets its own instance;
# main() hands uvicorn the "vlac_service_multi_worker:app" import string.
app = FastAPI(title="VLAC Service (Multi-Worker)", version="1.0.0")
def add_version_headers(response: JSONResponse):
    """Stamp ``response`` with the model tag and checkpoint fingerprint headers.

    The response is mutated in place and returned, so handlers can write
    ``return add_version_headers(JSONResponse(...))``. A missing fingerprint is
    reported as ``"unknown"``.
    """
    headers = response.headers
    headers["X-VLAC-Model-Tag"] = model_tag
    headers["X-VLAC-Checkpoint-SHA"] = checkpoint_sha if checkpoint_sha else "unknown"
    return response
@app.api_route("/healthcheck", methods=["GET", "POST"], response_model=HealthResponse)
async def healthcheck():
    """Liveness endpoint.

    Accepts POST (as before, for existing clients) and GET (the method load
    balancer and orchestrator probes conventionally use). Always reports
    ``status="ok"`` with the service's model tag; it does not inspect whether
    the model has been loaded.
    """
    return HealthResponse(status="ok", model_tag=model_tag)
@app.on_event("startup")
async def startup_event():
    """Claim a GPU for this worker process and load the VLAC model onto it.

    Configuration comes from the environment (as exported by ``main()``):
    ``VLAC_CKPT_PATH``, ``VLAC_GPU_IDS`` (comma-separated integers) and,
    optionally, ``UVICORN_WORKER_ID`` / ``VLAC_SELECTED_GPU``. The GPU is chosen
    while holding the cross-process assignment lock, so workers starting at the
    same time see each other's claims:

    1. a preferred GPU (``UVICORN_WORKER_ID`` modulo the GPU list, otherwise
       ``VLAC_SELECTED_GPU``) is used if no live process has claimed it;
    2. otherwise the GPU with the fewest live claims is used, with ties broken
       by the lowest nvidia-smi memory usage.

    The claim is then recorded under this PID, an exit hook that removes it is
    registered, ``CUDA_VISIBLE_DEVICES`` is narrowed to the chosen GPU, and
    ``init_model`` loads the checkpoint.

    Raises:
        RuntimeError: if ``VLAC_GPU_IDS`` contains a non-integer entry.
    """
    global model, load_lock, assigned_gpu_id
    if model is not None:
        return
    # Check, take the asyncio lock, check again: only one coroutine in this
    # process performs the load.
    if load_lock is None:
        load_lock = asyncio.Lock()
    async with load_lock:
        if model is not None:
            return
        ckpt_path = os.environ.get("VLAC_CKPT_PATH", "/home/zechen/SimpleVLA-RL/CKPT/VLAC")
        gpu_ids_env = os.environ.get("VLAC_GPU_IDS", "0")
        try:
            gpu_ids = [int(x.strip()) for x in gpu_ids_env.split(",") if x.strip()]
        except ValueError as exc:
            raise RuntimeError(f"Invalid VLAC_GPU_IDS value: {gpu_ids_env}") from exc
        if not gpu_ids:
            gpu_ids = [0]
        # Every read-modify-write of the shared assignment table happens while
        # holding the file lock (a no-op when fcntl is unavailable).
        fd = _acquire_assignment_lock()
        try:
            # Forget claims held by dead processes or by GPUs no longer configured.
            state = _cleanup_assignment_state(_read_assignment_state(), gpu_ids)
            # NOTE(review): stock uvicorn does not appear to export
            # UVICORN_WORKER_ID; when it is absent, VLAC_SELECTED_GPU or the
            # least-loaded fallback below decides - confirm with the launcher.
            worker_id_str = os.environ.get("UVICORN_WORKER_ID")
            preferred_gpu: Optional[int] = None
            if worker_id_str is not None and worker_id_str.isdigit():
                worker_idx = int(worker_id_str)
                preferred_gpu = gpu_ids[worker_idx % len(gpu_ids)]
            else:
                selected_gpu_env = os.environ.get("VLAC_SELECTED_GPU")
                preferred_gpu = int(selected_gpu_env) if selected_gpu_env and selected_gpu_env.isdigit() else None
            assigned_gpu = None
            if preferred_gpu is not None and preferred_gpu in gpu_ids:
                # Use preferred GPU if not already taken by another process
                if all(entry.get("gpu_id") != preferred_gpu for entry in state.values()):
                    assigned_gpu = preferred_gpu
            if assigned_gpu is None:
                # Find least used GPU (by count of processes assigned); tie-break using memory usage
                usage_count: dict[int, int] = {gpu: 0 for gpu in gpu_ids}
                for info in state.values():
                    gpu = info.get("gpu_id")
                    if gpu in usage_count:
                        usage_count[gpu] += 1
                min_count = min(usage_count.values())
                candidates = [gpu for gpu, count in usage_count.items() if count == min_count]
                if len(candidates) == 1:
                    assigned_gpu = candidates[0]
                else:
                    # Use GPU memory heuristic among candidates
                    best_gpu = select_best_gpu(candidates)
                    assigned_gpu = best_gpu
            # Record this PID's claim before the lock is released so the next
            # worker to acquire it counts this GPU as taken.
            state[os.getpid()] = {"gpu_id": assigned_gpu}
            _write_assignment_state(state)
            assigned_gpu_id = assigned_gpu
        finally:
            _release_assignment_lock(fd)
        # Remove this PID's claim from the table when the process exits.
        _register_release_handler()
        os.environ["VLAC_SELECTED_GPU"] = str(assigned_gpu_id)
        # Narrow CUDA to the claimed GPU (it then appears as cuda:0). This only
        # takes effect if CUDA has not been initialised in this process yet.
        # NOTE(review): this overwrites any CUDA_VISIBLE_DEVICES inherited from
        # the launcher and treats the id as a physical index - confirm intent.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(assigned_gpu_id)
        print(f"Worker {worker_id_str or 'main'} using GPU {assigned_gpu_id}")
        init_model(ckpt_path, assigned_gpu_id)
@app.post("/pairwise-critic", response_model=PairwiseCriticResponse)
async def pairwise_critic(req: PairwiseCriticRequest):
    """Score the change from ``image_a`` to ``image_b`` for ``req.task``.

    Both images are sent to ``model.get_trajectory_critic`` as a two-frame
    trajectory without reference frames. The last critic value is returned,
    or ``0.0`` if the model returns none.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if either image
        cannot be decoded, 500 for model or unexpected failures.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        images = {}
        for field_name in ("image_a", "image_b"):
            try:
                images[field_name] = b64_to_pil(getattr(req, field_name))
            except Exception as exc:
                raise HTTPException(status_code=400, detail=f"Invalid {field_name}: {exc}") from exc
        save_debug_image(images["image_a"], "pairwise_a")
        save_debug_image(images["image_b"], "pairwise_b")
        try:
            critic_list, _ = model.get_trajectory_critic(
                task=req.task,
                image_list=[images["image_a"], images["image_b"]],
                ref_image_list=None,
                batch_num=1,
                ref_num=0,
                rich=req.rich,
                skip=1,
                frame_skip=True,
            )
        except Exception as exc:
            raise HTTPException(status_code=500, detail=f"Model processing error: {exc}") from exc
        critic_val = 0.0
        if critic_list:
            critic_val = float(critic_list[-1])
        payload = PairwiseCriticResponse(critic=critic_val, raw=str(critic_val))
        return add_version_headers(JSONResponse(content=payload.dict()))
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Unexpected error: {exc}") from exc
@app.post("/done", response_model=DoneResponse)
async def done(req: DoneRequest):
    """Decide whether the trajectory should terminate at ``curr_frame``.

    With reference images (at least 2; only the first 8 are used), the model's
    in-context done detector is given the first, previous and current frames
    plus the references. Without references, only ``curr_frame`` is scored by
    ``model.get_trajectory_done``. ``done`` is ``prob > 0.75``.

    Raises:
        HTTPException: 503 if the model is not loaded; 422 if exactly one
        reference image is supplied; 400 if any image fails to decode; 500 for
        model or unexpected failures.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # A non-empty list shorter than 2 means exactly one reference image.
    if req.reference and len(req.reference) < 2:
        raise HTTPException(
            status_code=422,
            detail=(
                "When reference images are provided for done detection, at least 2 are "
                f"required (got {len(req.reference)}). Use reference=null for simple done detection."
            ),
        )
    try:
        # All three frames are decoded (and thereby validated) even though the
        # reference-free path only uses curr_frame.
        first_img = b64_to_pil(req.first_frame)
        prev_img = b64_to_pil(req.prev_frame)
        curr_img = b64_to_pil(req.curr_frame)
        ref_imgs = None
        if req.reference:
            ref_imgs = [b64_to_pil(ref) for ref in req.reference[:8]]
        save_debug_image(curr_img, "done_curr")
        if ref_imgs:
            done_list, _ = model.get_in_context_done(
                task=req.task,
                first_image=[first_img],
                n_pre_image=[prev_img],
                now_image=[curr_img],
                ref_image_list=ref_imgs,
                ref_num=len(ref_imgs),
                rich=True,
            )
            prob = float(done_list[0]) if done_list else 0.0
        else:
            done_list = model.get_trajectory_done(
                task=req.task,
                image_list=[curr_img],
                batch_num=1,
                rich=True,
            )
            prob = float(done_list[0]) if done_list else 0.0
        is_done = prob > 0.75
        response = JSONResponse(content=DoneResponse(done=is_done, prob=prob).dict())
        return add_version_headers(response)
    except HTTPException:
        # Keep the 400 that b64_to_pil raises for bad image data; without this
        # clause the generic handler below would report it as a 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc
@app.post("/trajectory-critic", response_model=TrajectoryCriticResponse)
async def trajectory_critic(req: TrajectoryCriticRequest):
    """Score a whole trajectory, returning per-step critic and value curves.

    ``frames`` and the optional ``reference`` images are decoded and passed to
    ``model.get_trajectory_critic``, with ``batch_size`` capped at 8. The
    response always has ``done_list`` and ``video`` set to None, and
    ``return_video`` is ignored.

    Raises:
        HTTPException: 503 if the model is not loaded; 422 if reference images
        are supplied with ``ref_num < 2``; 400 for an undecodable frame or
        reference image; 500 for model, conversion or unexpected failures.
    """
    import traceback

    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    print(f"Processing trajectory-critic request: {len(req.frames)} frames, task: {req.task[:50]}...")
    # A truthy reference list is necessarily non-empty.
    if req.reference and req.ref_num < 2:
        raise HTTPException(
            status_code=422,
            detail=(
                "When reference images are provided, ref_num must be >= 2 (got "
                f"{req.ref_num}). Use ref_num=0 and reference=null for no reference images."
            ),
        )

    def decode_all(encoded_images, log_template, detail_template):
        # Decode each image in order; log and report the index of the first failure.
        decoded = []
        for position, encoded in enumerate(encoded_images):
            try:
                decoded.append(b64_to_pil(encoded))
            except Exception as exc:
                print(log_template.format(idx=position, exc=exc))
                raise HTTPException(
                    status_code=400, detail=detail_template.format(idx=position, exc=exc)
                ) from exc
        return decoded

    try:
        print("Converting base64 images to PIL...")
        frames: List[Image.Image] = decode_all(
            req.frames,
            "Error converting frame {idx}: {exc}",
            "Invalid image data in frame {idx}: {exc}",
        )
        ref_imgs = None
        if req.reference:
            print(f"Converting {len(req.reference)} reference images...")
            ref_imgs = decode_all(
                req.reference,
                "Error converting reference {idx}: {exc}",
                "Invalid reference image {idx}: {exc}",
            )
        if frames:
            for image, label in ((frames[0], "first"), (frames[-1], "last")):
                save_debug_image(image, f"traj_{label}_{req.demo_id}")

        effective_batch_size = min(req.batch_size, 8)
        print(f"Using effective batch size: {effective_batch_size}")
        try:
            critic_list, value_list = model.get_trajectory_critic(
                task=req.task,
                image_list=frames,
                ref_image_list=ref_imgs,
                batch_num=effective_batch_size,
                ref_num=req.ref_num,
                think=req.think,
                skip=req.skip,
                rich=True,
                frame_skip=True,
            )
            print(f"Got results: {len(critic_list)} critics, {len(value_list)} values")
        except Exception as exc:
            print(f"Error in model.get_trajectory_critic: {exc}")
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=f"Model processing error: {exc}") from exc

        try:
            critic_floats = list(map(float, critic_list))
            value_floats = list(map(float, value_list))
            print(f"Converted to floats: {len(critic_floats)} critics, {len(value_floats)} values")
        except Exception as exc:
            print(f"Error converting to floats: {exc}")
            print(f"critic_list types: {[type(c) for c in critic_list[:3]]}")
            print(f"value_list types: {[type(v) for v in value_list[:3]]}")
            raise HTTPException(status_code=500, detail=f"Type conversion error: {exc}") from exc

        video_b64 = None
        print("Creating response...")
        payload = TrajectoryCriticResponse(
            value_list=value_floats,
            critic_list=critic_floats,
            done_list=None,
            video=video_b64,
        )
        return add_version_headers(JSONResponse(content=payload.dict()))
    except HTTPException:
        raise
    except Exception as exc:
        print(f"Unexpected error in trajectory_critic: {exc}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc
def init_model(ckpt_path: str, gpu_id: int):
    """Load the VLAC critic from ``ckpt_path`` and publish it as global ``model``.

    Device selection:
    * ``cpu`` when CUDA is unavailable;
    * ``cuda:<gpu_id>`` when ``CUDA_VISIBLE_DEVICES`` is unset or empty;
    * otherwise the ordinal of ``gpu_id`` within ``CUDA_VISIBLE_DEVICES``
      (visible devices are renumbered from 0), or ``cuda:0`` if the list does
      not name ``gpu_id`` (e.g. it holds UUIDs). When startup narrows the
      variable to the single assigned GPU, this also yields ``cuda:0``.

    The global ``model`` is assigned only after initialisation succeeds. A
    failed load therefore never leaves a half-initialised object behind that
    the endpoints' ``model is None`` checks would treat as ready.

    Raises:
        Exception: whatever model construction or loading raised (logged, then
        re-raised).
    """
    global model, checkpoint_sha
    print(f"Loading VLAC model from {ckpt_path} on GPU {gpu_id}")
    try:
        device_target: str
        if torch.cuda.is_available():
            visible = os.environ.get("CUDA_VISIBLE_DEVICES")
            if visible:
                visible_ids = [entry.strip() for entry in visible.split(",")]
                wanted = str(gpu_id)
                ordinal = visible_ids.index(wanted) if wanted in visible_ids else 0
                device_target = f"cuda:{ordinal}"
            else:
                device_target = f"cuda:{gpu_id}"
        else:
            device_target = "cpu"
        vlac = GAC_model(tag="critic")
        vlac.init_model(
            model_path=ckpt_path,
            model_type="internvl2",
            device_map=device_target,
        )
        vlac.temperature = 0.5
        vlac.top_k = 1
        vlac.set_config()
        vlac.set_system_prompt()
        # Publish only a fully configured model.
        model = vlac
        checkpoint_sha = get_checkpoint_sha(ckpt_path)
        print(f"Model loaded successfully. Checkpoint SHA: {checkpoint_sha}")
    except Exception as exc:
        print(f"Failed to load model: {exc}")
        raise
def main():
    """Parse CLI arguments, export worker configuration and start uvicorn.

    uvicorn receives the ``"vlac_service_multi_worker:app"`` import string, so
    each worker process imports this module itself. Workers read their settings
    back from the ``VLAC_CKPT_PATH``, ``VLAC_GPU_IDS`` and (single-worker mode
    only) ``VLAC_SELECTED_GPU`` environment variables exported here.

    Invalid ``--gpu-ids`` or ``--workers`` values are reported through argparse
    (usage message, exit status 2) rather than as a traceback.
    """
    parser = argparse.ArgumentParser(description="VLAC Service (multi-worker)")
    parser.add_argument("--port", type=int, default=8111, help="Port to run on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument(
        "--ckpt-path",
        default="/home/zechen/SimpleVLA-RL/CKPT/VLAC",
        help="Path to VLAC checkpoint",
    )
    parser.add_argument("--gpu-ids", default="0,1,2,3", help="Comma-separated GPU IDs")
    parser.add_argument("--workers", type=int, default=4, help="Number of uvicorn worker processes")
    args = parser.parse_args()
    if args.workers < 1:
        parser.error(f"--workers must be >= 1 (got {args.workers})")
    try:
        gpu_ids = [int(x.strip()) for x in args.gpu_ids.split(",") if x.strip()]
    except ValueError:
        parser.error(f"Invalid --gpu-ids value: {args.gpu_ids}")
    if not gpu_ids:
        # Empty --gpu-ids: use every CUDA device, or GPU 0 as a placeholder.
        if torch.cuda.is_available():
            gpu_ids = list(range(torch.cuda.device_count()))
        else:
            gpu_ids = [0]
    # Expose configuration for worker processes
    os.environ["VLAC_CKPT_PATH"] = args.ckpt_path
    os.environ["VLAC_GPU_IDS"] = ",".join(str(g) for g in gpu_ids)
    if args.workers == 1 or len(gpu_ids) == 1:
        # Pre-select a GPU so the worker's startup prefers it.
        selected_gpu = select_best_gpu(gpu_ids)
        os.environ["VLAC_SELECTED_GPU"] = str(selected_gpu)
        print(f"Single-worker mode: selected GPU {selected_gpu} from {gpu_ids}")
    else:
        print(f"Multi-worker mode: available GPUs {gpu_ids}, workers {args.workers}")
    print(f"Starting VLAC multi-worker service on {args.host}:{args.port}")
    uvicorn.run(
        "vlac_service_multi_worker:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        log_level="debug",
    )


if __name__ == "__main__":
    main()