"""This file implements a threaded stream controller to abstract a data stream back to the ray clientserver. """ import math import logging import queue import threading import warnings import grpc from collections import OrderedDict from typing import Any, Callable, Dict, TYPE_CHECKING, Optional, Union import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.util.client.common import ( INT32_MAX, OBJECT_TRANSFER_CHUNK_SIZE, OBJECT_TRANSFER_WARNING_SIZE, ) from ray.util.debug import log_once if TYPE_CHECKING: from ray.util.client.worker import Worker logger = logging.getLogger(__name__) ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]], None] # Send an acknowledge on every 32nd response received ACKNOWLEDGE_BATCH_SIZE = 32 def chunk_put(req: ray_client_pb2.DataRequest): """ Chunks a put request. Doing this lazily is important for large objects, since taking slices of bytes objects does a copy. This means if we immediately materialized every chunk of a large object and inserted them into the result_queue, we would effectively double the memory needed on the client to handle the put. """ # When accessing a protobuf field, deserialization is performed, which will # generate a copy. So we need to avoid accessing the `data` field multiple # times in the loop request_data = req.put.data total_size = len(request_data) assert total_size > 0, "Cannot chunk object with missing data" if total_size >= OBJECT_TRANSFER_WARNING_SIZE and log_once( "client_object_put_size_warning" ): size_gb = total_size / 2**30 warnings.warn( "Ray Client is attempting to send a " f"{size_gb:.2f} GiB object over the network, which may " "be slow. Consider serializing the object and using a remote " "URI to transfer via S3 or Google Cloud Storage instead. " "Documentation for doing this can be found here: " "https://docs.ray.io/en/latest/handling-dependencies.html#remote-uris", UserWarning, ) total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE) for chunk_id in range(0, total_chunks): start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE) chunk = ray_client_pb2.PutRequest( client_ref_id=req.put.client_ref_id, data=request_data[start:end], chunk_id=chunk_id, total_chunks=total_chunks, total_size=total_size, owner_id=req.put.owner_id, ) yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk) def chunk_task(req: ray_client_pb2.DataRequest): """ Chunks a client task. Doing this lazily is important with large arguments, since taking slices of bytes objects does a copy. This means if we immediately materialized every chunk of a large argument and inserted them into the result_queue, we would effectively double the memory needed on the client to handle the task. """ # When accessing a protobuf field, deserialization is performed, which will # generate a copy. So we need to avoid accessing the `data` field multiple # times in the loop request_data = req.task.data total_size = len(request_data) assert total_size > 0, "Cannot chunk object with missing data" total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE) for chunk_id in range(0, total_chunks): start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE) chunk = ray_client_pb2.ClientTask( type=req.task.type, name=req.task.name, payload_id=req.task.payload_id, client_id=req.task.client_id, options=req.task.options, baseline_options=req.task.baseline_options, namespace=req.task.namespace, data=request_data[start:end], chunk_id=chunk_id, total_chunks=total_chunks, ) yield ray_client_pb2.DataRequest(req_id=req.req_id, task=chunk) class ChunkCollector: """ This object collects chunks from async get requests via __call__, and calls the underlying callback when the object is fully received, or if an exception while retrieving the object occurs. This is not used in synchronous gets (synchronous gets interact with the raylet servicer directly, not through the datapath). __call__ returns true once the underlying call back has been called. """ def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest): # Bytearray containing data received so far self.data = bytearray() # The callback that will be called once all data is received self.callback = callback # The id of the last chunk we've received, or -1 if haven't seen any yet self.last_seen_chunk = -1 # The GetRequest that initiated the transfer. start_chunk_id will be # updated as chunks are received to avoid re-requesting chunks that # we've already received. self.request = request def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool: if isinstance(response, Exception): self.callback(response) return True get_resp = response.get if not get_resp.valid: self.callback(response) return True if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once( "client_object_transfer_size_warning" ): size_gb = get_resp.total_size / 2**30 warnings.warn( "Ray Client is attempting to retrieve a " f"{size_gb:.2f} GiB object over the network, which may " "be slow. Consider serializing the object to a file and " "using rsync or S3 instead.", UserWarning, ) chunk_data = get_resp.data chunk_id = get_resp.chunk_id if chunk_id == self.last_seen_chunk + 1: self.data.extend(chunk_data) self.last_seen_chunk = chunk_id # If we disconnect partway through, restart the get request # at the first chunk we haven't seen self.request.get.start_chunk_id = self.last_seen_chunk + 1 elif chunk_id > self.last_seen_chunk + 1: # A chunk was skipped. This shouldn't happen in practice since # grpc guarantees that chunks will arrive in order. msg = ( f"Received chunk {chunk_id} when we expected " f"{self.last_seen_chunk + 1} for request {response.req_id}" ) logger.warning(msg) self.callback(RuntimeError(msg)) return True else: # We received a chunk that've already seen before. Ignore, since # it should already be appended to self.data. logger.debug( f"Received a repeated chunk {chunk_id} " f"from request {response.req_id}." ) if get_resp.chunk_id == get_resp.total_chunks - 1: self.callback(self.data) return True else: # Not done yet return False class DataClient: def __init__(self, client_worker: "Worker", client_id: str, metadata: list): """Initializes a thread-safe datapath over a Ray Client gRPC channel. Args: client_worker: The Ray Client worker that manages this client client_id: the generated ID representing this client metadata: metadata to pass to gRPC requests """ self.client_worker = client_worker self._client_id = client_id self._metadata = metadata self.data_thread = self._start_datathread() # Track outstanding requests to resend in case of disconnection self.outstanding_requests: Dict[int, Any] = OrderedDict() # Serialize access to all mutable internal states: self.request_queue, # self.ready_data, self.asyncio_waiting_data, # self._in_shutdown, self._req_id, self.outstanding_requests and # calling self._next_id() self.lock = threading.Lock() # Waiting for response or shutdown. self.cv = threading.Condition(lock=self.lock) self.request_queue = self._create_queue() self.ready_data: Dict[int, Any] = {} # NOTE: Dictionary insertion is guaranteed to complete before lookup # and/or removal because of synchronization via the request_queue. self.asyncio_waiting_data: Dict[int, ResponseCallable] = {} self._in_shutdown = False self._req_id = 0 self._last_exception = None self._acknowledge_counter = 0 self.data_thread.start() # Must hold self.lock when calling this function. def _next_id(self) -> int: assert self.lock.locked() self._req_id += 1 if self._req_id > INT32_MAX: self._req_id = 1 # Responses that aren't tracked (like opportunistic releases) # have req_id=0, so make sure we never mint such an id. assert self._req_id != 0 return self._req_id def _start_datathread(self) -> threading.Thread: return threading.Thread( target=self._data_main, name="ray_client_streaming_rpc", args=(), daemon=True, ) # A helper that takes requests from queue. If the request wraps a PutRequest, # lazily chunks and yields the request. Otherwise, yields the request directly. def _requests(self): while True: req = self.request_queue.get() if req is None: # Stop when client signals shutdown. return req_type = req.WhichOneof("type") if req_type == "put": yield from chunk_put(req) elif req_type == "task": yield from chunk_task(req) else: yield req def _data_main(self) -> None: reconnecting = False try: while not self.client_worker._in_shutdown: stub = ray_client_pb2_grpc.RayletDataStreamerStub( self.client_worker.channel ) metadata = self._metadata + [("reconnecting", str(reconnecting))] resp_stream = stub.Datapath( self._requests(), metadata=metadata, wait_for_ready=True, ) try: for response in resp_stream: self._process_response(response) return except grpc.RpcError as e: reconnecting = self._can_reconnect(e) if not reconnecting: self._last_exception = e return self._reconnect_channel() except Exception as e: self._last_exception = e finally: logger.debug("Shutting down data channel.") self._shutdown() def _process_response(self, response: Any) -> None: """ Process responses from the data servicer. """ if response.req_id == 0: # This is not being waited for. logger.debug(f"Got unawaited response {response}") return if response.req_id in self.asyncio_waiting_data: can_remove = True try: callback = self.asyncio_waiting_data[response.req_id] if isinstance(callback, ChunkCollector): can_remove = callback(response) elif callback: callback(response) if can_remove: # NOTE: calling del self.asyncio_waiting_data results # in the destructor of ClientObjectRef running, which # calls ReleaseObject(). So self.asyncio_waiting_data # is accessed without holding self.lock. Holding the # lock shouldn't be necessary either. del self.asyncio_waiting_data[response.req_id] except Exception: logger.exception("Callback error:") with self.lock: # Update outstanding requests if response.req_id in self.outstanding_requests and can_remove: del self.outstanding_requests[response.req_id] # Acknowledge response self._acknowledge(response.req_id) else: with self.lock: self.ready_data[response.req_id] = response self.cv.notify_all() def _can_reconnect(self, e: grpc.RpcError) -> bool: """ Processes RPC errors that occur while reading from data stream. Returns True if the error can be recovered from, False otherwise. """ if not self.client_worker._can_reconnect(e): logger.error("Unrecoverable error in data channel.") logger.debug(e) return False logger.debug("Recoverable error in data channel.") logger.debug(e) return True def _shutdown(self) -> None: """ Shutdown the data channel """ with self.lock: self._in_shutdown = True self.cv.notify_all() callbacks = self.asyncio_waiting_data.values() self.asyncio_waiting_data = {} if self._last_exception: # Abort async requests with the error. err = ConnectionError( "Failed during this or a previous request. Exception that " f"broke the connection: {self._last_exception}" ) else: err = ConnectionError( "Request cannot be fulfilled because the data client has " "disconnected." ) for callback in callbacks: if callback: callback(err) # Since self._in_shutdown is set to True, no new item # will be added to self.asyncio_waiting_data def _acknowledge(self, req_id: int) -> None: """ Puts an acknowledge request on the request queue periodically. Lock should be held before calling this. Used when an async or blocking response is received. """ if not self.client_worker._reconnect_enabled: # Skip ACKs if reconnect isn't enabled return assert self.lock.locked() self._acknowledge_counter += 1 if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0: self.request_queue.put( ray_client_pb2.DataRequest( acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=req_id) ) ) def _reconnect_channel(self) -> None: """ Attempts to reconnect the gRPC channel and resend outstanding requests. First, the server is pinged to see if the current channel still works. If the ping fails, then the current channel is closed and replaced with a new one. Once a working channel is available, a new request queue is made and filled with any outstanding requests to be resent to the server. """ try: # Ping the server to see if the current channel is reuseable, for # example if gRPC reconnected the channel on its own or if the # RPC error was transient and the channel is still open ping_succeeded = self.client_worker.ping_server(timeout=5) except grpc.RpcError: ping_succeeded = False if not ping_succeeded: # Ping failed, try refreshing the data channel logger.warning( "Encountered connection issues in the data channel. " "Attempting to reconnect." ) try: self.client_worker._connect_channel(reconnecting=True) except ConnectionError: logger.warning("Failed to reconnect the data channel") raise logger.debug("Reconnection succeeded!") # Recreate the request queue, and resend outstanding requests with self.lock: self.request_queue = self._create_queue() for request in self.outstanding_requests.values(): # Resend outstanding requests self.request_queue.put(request) # Use SimpleQueue to avoid deadlocks when appending to queue from __del__() @staticmethod def _create_queue(): return queue.SimpleQueue() def close(self) -> None: thread = None with self.lock: self._in_shutdown = True # Notify blocking operations to fail. self.cv.notify_all() # Add sentinel to terminate streaming RPC. if self.request_queue is not None: # Intentional shutdown, tell server it can clean up the # connection immediately and ignore the reconnect grace period. cleanup_request = ray_client_pb2.DataRequest( connection_cleanup=ray_client_pb2.ConnectionCleanupRequest() ) self.request_queue.put(cleanup_request) self.request_queue.put(None) if self.data_thread is not None: thread = self.data_thread # Wait until streaming RPCs are done. if thread is not None: thread.join() def _blocking_send( self, req: ray_client_pb2.DataRequest ) -> ray_client_pb2.DataResponse: with self.lock: self._check_shutdown() req_id = self._next_id() req.req_id = req_id self.request_queue.put(req) self.outstanding_requests[req_id] = req self.cv.wait_for(lambda: req_id in self.ready_data or self._in_shutdown) self._check_shutdown() data = self.ready_data[req_id] del self.ready_data[req_id] del self.outstanding_requests[req_id] self._acknowledge(req_id) return data def _async_send( self, req: ray_client_pb2.DataRequest, callback: Optional[ResponseCallable] = None, ) -> None: with self.lock: self._check_shutdown() req_id = self._next_id() req.req_id = req_id self.asyncio_waiting_data[req_id] = callback self.outstanding_requests[req_id] = req self.request_queue.put(req) # Must hold self.lock when calling this function. def _check_shutdown(self): assert self.lock.locked() if not self._in_shutdown: return self.lock.release() # Do not try disconnect() or throw exceptions in self.data_thread. # Otherwise deadlock can occur. if threading.current_thread().ident == self.data_thread.ident: return from ray.util import disconnect disconnect() self.lock.acquire() if self._last_exception is not None: msg = ( "Request can't be sent because the Ray client has already " "been disconnected due to an error. Last exception: " f"{self._last_exception}" ) else: msg = ( "Request can't be sent because the Ray client has already " "been disconnected." ) raise ConnectionError(msg) def Init( self, request: ray_client_pb2.InitRequest, context=None ) -> ray_client_pb2.InitResponse: datareq = ray_client_pb2.DataRequest( init=request, ) resp = self._blocking_send(datareq) return resp.init def PrepRuntimeEnv( self, request: ray_client_pb2.PrepRuntimeEnvRequest, context=None ) -> ray_client_pb2.PrepRuntimeEnvResponse: datareq = ray_client_pb2.DataRequest( prep_runtime_env=request, ) resp = self._blocking_send(datareq) return resp.prep_runtime_env def ConnectionInfo(self, context=None) -> ray_client_pb2.ConnectionInfoResponse: datareq = ray_client_pb2.DataRequest( connection_info=ray_client_pb2.ConnectionInfoRequest() ) resp = self._blocking_send(datareq) return resp.connection_info def GetObject( self, request: ray_client_pb2.GetRequest, context=None ) -> ray_client_pb2.GetResponse: datareq = ray_client_pb2.DataRequest( get=request, ) resp = self._blocking_send(datareq) return resp.get def RegisterGetCallback( self, request: ray_client_pb2.GetRequest, callback: ResponseCallable ) -> None: if len(request.ids) != 1: raise ValueError( "RegisterGetCallback() must have exactly 1 Object ID. " f"Actual: {request}" ) datareq = ray_client_pb2.DataRequest( get=request, ) collector = ChunkCollector(callback=callback, request=datareq) self._async_send(datareq, collector) # TODO: convert PutObject to async def PutObject( self, request: ray_client_pb2.PutRequest, context=None ) -> ray_client_pb2.PutResponse: datareq = ray_client_pb2.DataRequest( put=request, ) resp = self._blocking_send(datareq) return resp.put def ReleaseObject( self, request: ray_client_pb2.ReleaseRequest, context=None ) -> None: datareq = ray_client_pb2.DataRequest( release=request, ) self._async_send(datareq) def Schedule(self, request: ray_client_pb2.ClientTask, callback: ResponseCallable): datareq = ray_client_pb2.DataRequest(task=request) self._async_send(datareq, callback) def Terminate( self, request: ray_client_pb2.TerminateRequest ) -> ray_client_pb2.TerminateResponse: req = ray_client_pb2.DataRequest( terminate=request, ) resp = self._blocking_send(req) return resp.terminate def ListNamedActors( self, request: ray_client_pb2.ClientListNamedActorsRequest ) -> ray_client_pb2.ClientListNamedActorsResponse: req = ray_client_pb2.DataRequest( list_named_actors=request, ) resp = self._blocking_send(req) return resp.list_named_actors