| import asyncio |
| import concurrent.futures |
| import logging |
| import multiprocessing |
| import threading |
| from dataclasses import dataclass |
| from typing import Awaitable, Optional |
|
|
| from ray.dashboard.optional_deps import aiohttp |
|
|
| from ray.dashboard.subprocesses.message import ( |
| ChildBoundMessage, |
| ErrorMessage, |
| ParentBoundMessage, |
| RequestMessage, |
| UnaryResponseMessage, |
| StreamResponseDataMessage, |
| StreamResponseEndMessage, |
| StreamResponseStartMessage, |
| ) |
| from ray.dashboard.subprocesses.module import ( |
| SubprocessModule, |
| SubprocessModuleConfig, |
| run_module, |
| ) |
| from ray.dashboard.subprocesses.utils import ( |
| assert_not_in_asyncio_loop, |
| ThreadSafeDict, |
| module_logging_filename, |
| ) |
|
|
| """ |
| This file contains code run in the parent process. It can start a subprocess and send |
| messages to it. Requires non-minimal Ray. |
| """ |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class SubprocessModuleHandle:
    """
    A handle to a module created as a subprocess. Can send messages to the module and
    receive responses. destroy_module() terminates the subprocess; the subprocess is
    also started as a daemon process.

    Lifecycle:
    1. SubprocessModuleHandle creation only initializes bookkeeping state; no
       subprocess or queues exist yet.
    2. User must call SubprocessModuleHandle.start_module() to start the subprocess,
       create the 2 queues and begin handling parent bound messages.
    3. SubprocessRouteTable.bind(handle)
    4. app.add_routes(routes=SubprocessRouteTable.bound_routes())
    5. Run the app.

    Health check (_do_periodic_health_check):
    Every 1s, do a health check by _do_once_health_check. If the module is
    unhealthy:
    1. log the exception, pointing at the module's log file
    2. fail all active requests
    3. restart the module

    TODO(ryw): define policy for health check:
    - check period (Now: 1s)
    - define unhealthy. (Now: process exits. TODO: check_health() for event loop hang)
    - check number of failures in a row before we deem it unhealthy (Now: N/A)
    - "max number of restarts"? (Now: infinite)
    """

    @dataclass
    class ActiveRequest:
        # The incoming aiohttp request; None for internal requests such as
        # health_check().
        request: Optional[aiohttp.web.Request]
        # Future created on the handle's event loop. It is resolved with the final
        # response, or failed with an exception.
        response_fut: Awaitable[aiohttp.web.StreamResponse]
        # Streaming responses only: the concurrent future of the most recently
        # scheduled stream handler coroutine. Each later stream message chains on it,
        # so writes happen in message order. None until a StreamResponseStartMessage
        # arrives.
        stream_response: Optional[
            concurrent.futures.Future[aiohttp.web.StreamResponse]
        ] = None

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        module_cls: type[SubprocessModule],
        config: SubprocessModuleConfig,
    ):
        """
        Params:
        - loop: the parent's asyncio event loop. Response futures and health check
          tasks are created on it.
        - module_cls: the SubprocessModule subclass to run in the subprocess.
        - config: passed to run_module in the subprocess.
        """
        self.loop = loop
        self.module_cls = module_cls
        self.config = config

        # Incremented by destroy_module(). Each dispatch thread captures the value it
        # started with and exits once it changes.
        self.incarnation = 0
        # The fields below are (re)initialized by start_module() and cleared by
        # destroy_module().
        self.next_request_id = None
        self.child_bound_queue = None
        self.parent_bound_queue = None
        self.active_requests = ThreadSafeDict[
            int, SubprocessModuleHandle.ActiveRequest
        ]()
        self.process = None
        self.dispatch_parent_bound_messages_thread = None
        self.health_check_task = None

    def str_for_state(self, incarnation: int, pid: Optional[int]):
        """Describes this handle for the given incarnation and pid (used in logs)."""
        return f"SubprocessModuleHandle(module_cls={self.module_cls.__name__}, incarnation={incarnation}, pid={pid})"

    def __str__(self):
        return self.str_for_state(
            self.incarnation, self.process.pid if self.process else None
        )

    def start_module(self, start_dispatch_parent_bound_messages_thread: bool = True):
        """
        Starts (or restarts) the subprocess. Creates fresh queues, the dispatch thread
        and the periodic health check task.

        Params:
        - start_dispatch_parent_bound_messages_thread: used for testing.
        """
        self.next_request_id = 0
        self.child_bound_queue = multiprocessing.Queue()
        self.parent_bound_queue = multiprocessing.Queue()
        self.active_requests.pop_all()
        self.process = multiprocessing.Process(
            target=run_module,
            args=(
                self.child_bound_queue,
                self.parent_bound_queue,
                self.module_cls,
                self.config,
            ),
            daemon=True,
        )
        self.process.start()

        if start_dispatch_parent_bound_messages_thread:
            self.dispatch_parent_bound_messages_thread = threading.Thread(
                name=f"{self.module_cls.__name__}-dispatch_parent_bound_messages_thread",
                target=self.dispatch_parent_bound_messages,
                daemon=True,
            )
            self.dispatch_parent_bound_messages_thread.start()

        self.health_check_task = self.loop.create_task(self._do_periodic_health_check())

    async def destroy_module(self, reason: Exception):
        """
        Destroy the module. This is called when the module is unhealthy.

        async because we need to set exceptions to the futures. It must run on
        self.loop.

        Params:
        - reason: the exception that caused the module to be destroyed. Propagated to
          active requests so they can be failed.
        """
        # Bump the incarnation first so the current dispatch thread leaves its loop.
        self.incarnation += 1
        self.next_request_id = 0
        self.process.terminate()
        self.process = None

        for active_request in self.active_requests.pop_all().values():
            # A future may already be done, e.g. cancelled because the client went
            # away. Calling set_exception() on it would raise InvalidStateError and
            # abort the teardown below, and then the module would never restart.
            SubprocessModuleHandle._fail_future(active_request.response_fut, reason)
        self.parent_bound_queue.close()
        self.parent_bound_queue = None

        self.child_bound_queue.close()
        self.child_bound_queue = None

        # Not joined. The dispatch thread exits by itself once it sees the new
        # incarnation; its queue.get() uses a 1s timeout.
        self.dispatch_parent_bound_messages_thread = None

        # When called from _do_periodic_health_check, this cancels the running task
        # itself. That is harmless: the caller restarts the module and returns
        # without awaiting again.
        self.health_check_task.cancel()
        self.health_check_task = None

    async def health_check(self) -> aiohttp.web.Response:
        """
        Do internal health check. The module should respond immediately with a 200 OK.
        This can be used to measure module responsiveness in RTT, it also indicates
        subprocess event loop lag.

        Currently you get a 200 OK with body = b'ok!'. Later if we want we can add more
        observability payloads.
        """
        return await self.send_request("_internal_health_check", request=None)

    async def _do_once_health_check(self):
        """
        Do a health check once. We check for:
        1. if the process exits, it's considered died.

        # TODO(ryw): also do `await self.health_check()` and define a policy to
        # determine if the process is dead.

        Raises RuntimeError if the module is unhealthy.
        """
        if self.process.exitcode is not None:
            raise RuntimeError(f"Process exited with code {self.process.exitcode}")

    async def _do_periodic_health_check(self):
        """
        Every 1s, do a health check. If the module is unhealthy:
        1. log the exception, pointing at the module's log file
        2. fail all active requests
        3. restart the module

        After a restart this task returns. start_module() creates a new health check
        task for the new incarnation.
        """
        while True:
            try:
                await self._do_once_health_check()
            except Exception as e:
                filename = module_logging_filename(
                    self.module_cls.__name__, self.config.logging_filename
                )
                logger.exception(
                    f"Module {self.module_cls.__name__} is unhealthy. Please refer to "
                    f"{self.config.log_dir}/{filename} "
                    "for more details. Failing all active requests."
                )
                await self.destroy_module(e)
                self.start_module()
                return
            await asyncio.sleep(1)

    async def send_request(
        self, method_name: str, request: Optional[aiohttp.web.Request]
    ) -> aiohttp.web.StreamResponse:
        """
        Sends a new request to the module and waits for its response.

        The request body is read BEFORE a request id is allocated. That way nothing
        awaits between allocating the id, bookkeeping the request in
        self.active_requests and putting it on the child bound queue. Otherwise a
        module restart during the read could send a stale id to the new incarnation,
        where ids restart at 0 and may collide with a newer request.

        Params:
        - method_name: name of the module method to invoke.
        - request: the incoming aiohttp request, or None for internal requests (the
          body is then empty).

        Returns the response produced for this request.

        Raises the error reported by the module (non-streaming case), or the reason
        passed to destroy_module() if the module is destroyed while the request is in
        flight.
        """
        if request is None:
            body = b""
        else:
            body = await request.read()

        request_id = self.next_request_id
        self.next_request_id += 1
        new_active_request = SubprocessModuleHandle.ActiveRequest(
            request=request, response_fut=self.loop.create_future()
        )
        self.active_requests.put_new(request_id, new_active_request)
        try:
            self._send_message(
                RequestMessage(request_id=request_id, method_name=method_name, body=body)
            )
        except Exception:
            # The module never received this request, so no response will ever
            # arrive for it. Don't leave it dangling in active_requests.
            self.active_requests.pop_or_raise(request_id)
            raise
        return await new_active_request.response_fut

    def _send_message(self, message: ChildBoundMessage):
        self.child_bound_queue.put(message)

    @staticmethod
    def _resolve_future(fut: asyncio.Future, result: object) -> None:
        """Sets `result` on `fut` unless it is already done (e.g. cancelled).

        Must run on the loop that owns `fut`.
        """
        if not fut.done():
            fut.set_result(result)

    @staticmethod
    def _fail_future(fut: asyncio.Future, exception: BaseException) -> None:
        """Sets `exception` on `fut` unless it is already done (e.g. cancelled).

        Must run on the loop that owns `fut`.
        """
        if not fut.done():
            fut.set_exception(exception)

    @staticmethod
    async def handle_stream_response_start(
        request: aiohttp.web.Request, first_data: bytes
    ) -> aiohttp.web.StreamResponse:
        """
        Prepares a text/plain streaming response for `request` and writes the first
        chunk. Runs on the event loop. Returns the response so later stream steps can
        chain on it.
        """
        response = aiohttp.web.StreamResponse()
        response.content_type = "text/plain"
        await response.prepare(request)
        await response.write(first_data)
        return response

    @staticmethod
    async def handle_stream_response_data(
        prev_fut: "concurrent.futures.Future[aiohttp.web.StreamResponse]", data: bytes
    ) -> aiohttp.web.StreamResponse:
        """
        Waits for the previous stream step, then writes `data`. Returns the same
        response so the next step can chain on it. An exception from any earlier
        step propagates through the chain.
        """
        response = await asyncio.wrap_future(prev_fut)
        await response.write(data)
        return response

    @staticmethod
    async def handle_stream_response_end(
        prev_fut: "concurrent.futures.Future[aiohttp.web.StreamResponse]",
        response_fut: Awaitable[aiohttp.web.StreamResponse],
    ) -> None:
        """
        Waits for the previous stream step, then closes the stream. Resolves
        `response_fut` with the response, or fails it with the exception raised by
        any step of the stream.
        """
        try:
            response = await asyncio.wrap_future(prev_fut)
            await response.write_eof()
        except Exception as e:
            SubprocessModuleHandle._fail_future(response_fut, e)
        else:
            SubprocessModuleHandle._resolve_future(response_fut, response)

    @staticmethod
    async def handle_stream_response_error(
        prev_fut: "concurrent.futures.Future[aiohttp.web.StreamResponse]",
        exception: Exception,
        response_fut: Awaitable[aiohttp.web.StreamResponse],
    ) -> None:
        """
        When the async iterator in the module raises an error, we need to propagate it
        to the client and close the stream. However, we already sent a 200 OK to the
        client and can't change that to a 500. We can't just raise an exception here to
        aiohttp because that causes it to abruptly close the connection and the client
        will raise a ClientPayloadError(TransferEncodingError).

        Instead, we write exception to the stream and close the stream.
        """
        try:
            response = await asyncio.wrap_future(prev_fut)
            await response.write(str(exception).encode())
            await response.write_eof()
        except Exception as e:
            SubprocessModuleHandle._fail_future(response_fut, e)
        else:
            SubprocessModuleHandle._resolve_future(response_fut, response)

    def handle_parent_bound_message(self, message: ParentBoundMessage):
        """
        Handles a message from the parent bound queue. This function must run on a
        dedicated thread, called by dispatch_parent_bound_messages.

        Futures and aiohttp responses belong to self.loop. Every interaction with them
        is therefore scheduled onto the loop with call_soon_threadsafe /
        run_coroutine_threadsafe rather than done on this thread.

        Raises ValueError for an unknown message type.
        """
        loop = self.loop
        if isinstance(message, UnaryResponseMessage):
            # A unary response completes the request.
            active_request = self.active_requests.pop_or_raise(message.request_id)
            loop.call_soon_threadsafe(
                SubprocessModuleHandle._resolve_future,
                active_request.response_fut,
                aiohttp.web.Response(
                    status=message.status,
                    body=message.body,
                ),
            )
        elif isinstance(message, StreamResponseStartMessage):
            # The request stays active: data/end/error messages follow.
            active_request = self.active_requests.get_or_raise(message.request_id)
            assert active_request.stream_response is None
            active_request.stream_response = asyncio.run_coroutine_threadsafe(
                SubprocessModuleHandle.handle_stream_response_start(
                    active_request.request, message.body
                ),
                loop,
            )
        elif isinstance(message, StreamResponseDataMessage):
            active_request = self.active_requests.get_or_raise(message.request_id)
            assert active_request.stream_response is not None
            # Chain on the previous step so chunks are written in order.
            active_request.stream_response = asyncio.run_coroutine_threadsafe(
                SubprocessModuleHandle.handle_stream_response_data(
                    active_request.stream_response, message.body
                ),
                loop,
            )
        elif isinstance(message, StreamResponseEndMessage):
            active_request = self.active_requests.pop_or_raise(message.request_id)
            assert active_request.stream_response is not None
            asyncio.run_coroutine_threadsafe(
                SubprocessModuleHandle.handle_stream_response_end(
                    active_request.stream_response,
                    active_request.response_fut,
                ),
                loop,
            )
        elif isinstance(message, ErrorMessage):
            # An error ends the request whether or not it was streaming.
            active_request = self.active_requests.pop_or_raise(message.request_id)
            if active_request.stream_response is not None:
                # A stream was started, so report the error inside the stream body.
                asyncio.run_coroutine_threadsafe(
                    SubprocessModuleHandle.handle_stream_response_error(
                        active_request.stream_response,
                        message.error,
                        active_request.response_fut,
                    ),
                    loop,
                )
            else:
                loop.call_soon_threadsafe(
                    SubprocessModuleHandle._fail_future,
                    active_request.response_fut,
                    message.error,
                )
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

    def dispatch_parent_bound_messages(self):
        """
        Dispatch Messages from the module. This function should be run in a separate thread
        from the asyncio loop of the parent process.

        Loops until destroy_module() changes self.incarnation, or until the queue
        reports it is closed. Errors are logged rather than raised, so one bad message
        does not stop dispatching.
        """
        assert_not_in_asyncio_loop()
        incarnation = self.incarnation
        pid = self.process.pid if self.process else None
        self_str = self.str_for_state(incarnation, pid)

        # Keep a local reference: destroy_module() sets self.parent_bound_queue to None.
        queue = self.parent_bound_queue

        while incarnation == self.incarnation:
            message = None
            try:
                message = queue.get(timeout=1)
            except multiprocessing.queues.Empty:
                # Timed out; loop around to re-check the incarnation.
                continue
            except ValueError:
                # multiprocessing.Queue.get raises ValueError once the queue is closed.
                break
            except Exception:
                logger.exception(
                    f"Error unpickling parent bound message from {self_str}."
                    " This may result in a http request never being responded to."
                )
                continue
            try:
                self.handle_parent_bound_message(message)
            except Exception:
                logger.exception(
                    f"Error handling parent bound message from {self_str}."
                    " This may result in a http request never being responded to."
                )

        logger.info(f"dispatch_parent_bound_messages thread for {self_str} is exiting")
|
|