koichi12's picture
Add files using upload-large-folder tool
6f8c8ab verified
raw
history blame
16.2 kB
import asyncio
import concurrent.futures
import logging
import multiprocessing
import queue
import threading
from dataclasses import dataclass
from typing import Awaitable, Optional

from ray.dashboard.optional_deps import aiohttp
from ray.dashboard.subprocesses.message import (
    ChildBoundMessage,
    ErrorMessage,
    ParentBoundMessage,
    RequestMessage,
    UnaryResponseMessage,
    StreamResponseDataMessage,
    StreamResponseEndMessage,
    StreamResponseStartMessage,
)
from ray.dashboard.subprocesses.module import (
    SubprocessModule,
    SubprocessModuleConfig,
    run_module,
)
from ray.dashboard.subprocesses.utils import (
    assert_not_in_asyncio_loop,
    ThreadSafeDict,
    module_logging_filename,
)
"""
This file contains code run in the parent process. It can start a subprocess and send
messages to it. Requires non-minimal Ray.
"""
# NOTE: this string follows the imports, so Python does not treat it as the module
# docstring (__doc__); it only serves as a comment for readers.
# Module-level logger used by SubprocessModuleHandle to report health check failures,
# message dispatch errors and dispatch-thread exit.
logger = logging.getLogger(__name__)
class SubprocessModuleHandle:
    """
    A handle to a module created as a subprocess. Can send messages to the module and
    receive responses. destroy_module() terminates the subprocess; the subprocess is
    also started with daemon=True.

    Lifecycle:
    1. __init__ only records the event loop, module class and config. No subprocess,
       queue, thread or task is created yet.
    2. User must call SubprocessModuleHandle.start_module() before it can handle parent
       bound messages. It creates the 2 queues, starts the subprocess, the thread that
       dispatches parent bound messages, and the periodic health check task.
    3. SubprocessRouteTable.bind(handle)
    4. app.add_routes(routes=SubprocessRouteTable.bound_routes())
    5. Run the app.

    Health check (_do_periodic_health_check):
    Every 1s, do a health check by _do_once_health_check. If the module is
    unhealthy:
    1. log the exception, pointing to the module's log file
    2. fail all active requests
    3. restart the module
    TODO: also log the last N lines of the module's log file.

    TODO(ryw): define policy for health check:
    - check period (Now: 1s)
    - define unhealthy. (Now: process exits. TODO: check_health() for event loop hang)
    - check number of failures in a row before we deem it unhealthy (Now: N/A)
    - "max number of restarts"? (Now: infinite)
    """

    @dataclass
    class ActiveRequest:
        # The incoming aiohttp request. None for internal requests (e.g. the
        # health check sent by health_check()).
        request: Optional[aiohttp.web.Request]
        # Future to a Response as the result of an aiohttp handler. It can be a
        # Response for a unary request, or a StreamResponse for a streaming request.
        # Created on, and only resolved from, the event loop thread.
        response_fut: asyncio.Future[aiohttp.web.StreamResponse]
        # Only exists when the module decides this is a streaming response.
        # To keep the data sent in order, we use future to synchronize. This assumes
        # the Messages received from the Queue are in order.
        # StreamResponseStartMessage expects this to be None. It creates the future,
        # and in async, prepares a StreamResponse and resolves the future.
        # StreamResponseDataMessage expects a future. It *replaces* the future with a
        # new future by a coroutine that awaits the previous future, writes the data and
        # resolves the new future.
        # StreamResponseEndMessage expects a future. It pops the active request and
        # schedules a coroutine that awaits the future, writes EOF and resolves
        # response_fut.
        stream_response: Optional[
            concurrent.futures.Future[aiohttp.web.StreamResponse]
        ] = None

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        module_cls: type[SubprocessModule],
        config: SubprocessModuleConfig,
    ):
        self.loop = loop
        self.module_cls = module_cls
        self.config = config

        # Increment this when the module is restarted.
        self.incarnation = 0
        # Runtime states, set by start_module(), reset by destroy_module().
        self.next_request_id = None
        self.child_bound_queue = None
        self.parent_bound_queue = None
        self.active_requests = ThreadSafeDict[
            int, SubprocessModuleHandle.ActiveRequest
        ]()
        self.process = None
        self.dispatch_parent_bound_messages_thread = None
        self.health_check_task = None

    def str_for_state(self, incarnation: int, pid: Optional[int]) -> str:
        """Human-readable description for a given incarnation and process id."""
        return (
            f"SubprocessModuleHandle(module_cls={self.module_cls.__name__}, "
            f"incarnation={incarnation}, pid={pid})"
        )

    def __str__(self):
        return self.str_for_state(
            self.incarnation, self.process.pid if self.process else None
        )

    def start_module(self, start_dispatch_parent_bound_messages_thread: bool = True):
        """
        Start (or restart) the module subprocess.

        Creates fresh child/parent bound queues, clears active requests, starts the
        subprocess running `run_module`, optionally the dispatch thread, and schedules
        the periodic health check on self.loop.

        Params:
        - start_dispatch_parent_bound_messages_thread: used for testing.
        """
        self.next_request_id = 0
        self.child_bound_queue = multiprocessing.Queue()
        self.parent_bound_queue = multiprocessing.Queue()
        self.active_requests.pop_all()
        self.process = multiprocessing.Process(
            target=run_module,
            args=(
                self.child_bound_queue,
                self.parent_bound_queue,
                self.module_cls,
                self.config,
            ),
            daemon=True,
        )
        self.process.start()

        if start_dispatch_parent_bound_messages_thread:
            self.dispatch_parent_bound_messages_thread = threading.Thread(
                name=f"{self.module_cls.__name__}-dispatch_parent_bound_messages_thread",
                target=self.dispatch_parent_bound_messages,
                daemon=True,
            )
            self.dispatch_parent_bound_messages_thread.start()

        self.health_check_task = self.loop.create_task(self._do_periodic_health_check())

    async def destroy_module(self, reason: Exception):
        """
        Destroy the module. This is called when the module is unhealthy.
        async because we need to set exceptions to the futures.

        Params:
        - reason: the exception that caused the module to be destroyed. Propagated to
            active requests so they can be failed.
        """
        # Bumping the incarnation makes the current dispatch thread exit its loop.
        self.incarnation += 1
        self.next_request_id = 0
        self.process.terminate()
        self.process = None

        for active_request in self.active_requests.pop_all().values():
            # A future may already be done (e.g. cancelled because the client went
            # away). Unconditionally calling set_exception would raise
            # InvalidStateError and abort the cleanup below.
            self._set_exception_if_pending(active_request.response_fut, reason)

        self.parent_bound_queue.close()
        self.parent_bound_queue = None

        self.child_bound_queue.close()
        self.child_bound_queue = None

        # dispatch_parent_bound_messages_thread is daemon so we don't need to join it.
        self.dispatch_parent_bound_messages_thread = None

        self.health_check_task.cancel()
        self.health_check_task = None

    async def health_check(self) -> aiohttp.web.Response:
        """
        Do internal health check. The module should respond immediately with a 200 OK.
        This can be used to measure module responsiveness in RTT, it also indicates
        subprocess event loop lag.

        Currently you get a 200 OK with body = b'ok!'. Later if we want we can add more
        observability payloads.
        """
        return await self.send_request("_internal_health_check", request=None)

    async def _do_once_health_check(self):
        """
        Do a health check once. We check for:
        1. if the process exits, it's considered died.

        Raises RuntimeError if the subprocess has exited.

        # TODO(ryw): also do `await self.health_check()` and define a policy to
        # determine if the process is dead.
        """
        if self.process.exitcode is not None:
            raise RuntimeError(f"Process exited with code {self.process.exitcode}")

    async def _do_periodic_health_check(self):
        """
        Every 1s, do a health check. If the module is unhealthy:
        1. log the exception, pointing to the module's log file
        2. fail all active requests
        3. restart the module (which schedules a new health check task)
        """
        while True:
            try:
                await self._do_once_health_check()
            except Exception as e:
                filename = module_logging_filename(
                    self.module_cls.__name__, self.config.logging_filename
                )
                logger.exception(
                    f"Module {self.module_cls.__name__} is unhealthy. Please refer to "
                    f"{self.config.log_dir}/{filename} "
                    "for more details. Failing all active requests."
                )
                # destroy_module() cancels self.health_check_task, i.e. this very
                # task. Nothing is awaited after it, so the restart below still runs;
                # start_module() schedules a fresh health check task.
                await self.destroy_module(e)
                self.start_module()
                return
            await asyncio.sleep(1)

    async def send_request(
        self, method_name: str, request: Optional[aiohttp.web.Request]
    ) -> aiohttp.web.StreamResponse:
        """
        Sends a new request to the module and waits for its response.

        The request body is read *before* a request id is allocated and the request
        is bookkept in self.active_requests. This way a failing read does not leave
        a dangling entry, and a module restart that happens while the body is being
        read cannot cause an id of the previous incarnation to be sent to the new
        module, where it could collide with a fresh request using the same id.

        Params:
        - method_name: name of the module method to invoke.
        - request: the incoming aiohttp request, or None for internal requests
            (an empty body is sent in that case).

        Returns the response produced for the request: a Response for a unary
        request, or a StreamResponse for a streaming request. Raises whatever
        exception the module (or destroy_module) sets on the request's future.
        """
        if request is None:
            body = b""
        else:
            body = await request.read()

        # No await between allocating the id and sending the message, so the id
        # always belongs to the current incarnation.
        request_id = self.next_request_id
        self.next_request_id += 1
        new_active_request = SubprocessModuleHandle.ActiveRequest(
            request=request, response_fut=self.loop.create_future()
        )
        self.active_requests.put_new(request_id, new_active_request)
        try:
            self._send_message(
                RequestMessage(request_id=request_id, method_name=method_name, body=body)
            )
        except Exception:
            # The module never received the request, so nothing will resolve it.
            self.active_requests.pop_or_raise(request_id)
            raise
        return await new_active_request.response_fut

    def _send_message(self, message: ChildBoundMessage):
        self.child_bound_queue.put(message)

    @staticmethod
    def _set_result_if_pending(fut: asyncio.Future, result) -> None:
        """Resolve `fut` with `result` unless it is already done (e.g. cancelled).

        Must run on the event loop thread: asyncio futures are not thread safe.
        """
        if not fut.done():
            fut.set_result(result)

    @staticmethod
    def _set_exception_if_pending(fut: asyncio.Future, exception: BaseException) -> None:
        """Fail `fut` with `exception` unless it is already done (e.g. cancelled).

        Must run on the event loop thread: asyncio futures are not thread safe.
        """
        if not fut.done():
            fut.set_exception(exception)

    @staticmethod
    async def handle_stream_response_start(
        request: aiohttp.web.Request, first_data: bytes
    ) -> aiohttp.web.StreamResponse:
        """Prepare a text/plain StreamResponse for `request` and write `first_data`."""
        # TODO: error handling
        response = aiohttp.web.StreamResponse()
        response.content_type = "text/plain"
        await response.prepare(request)
        await response.write(first_data)
        return response

    @staticmethod
    async def handle_stream_response_data(
        prev_fut: concurrent.futures.Future[aiohttp.web.StreamResponse], data: bytes
    ) -> aiohttp.web.StreamResponse:
        """Wait for the previous stream step, then append `data` to the stream."""
        # TODO: error handling
        response = await asyncio.wrap_future(prev_fut)
        await response.write(data)
        return response

    @staticmethod
    async def handle_stream_response_end(
        prev_fut: concurrent.futures.Future[aiohttp.web.StreamResponse],
        response_fut: asyncio.Future[aiohttp.web.StreamResponse],
    ) -> None:
        """
        Wait for the previous stream step, write EOF and resolve `response_fut` with
        the stream. Any error along the chain is propagated to `response_fut`.
        """
        try:
            response = await asyncio.wrap_future(prev_fut)
            await response.write_eof()
            SubprocessModuleHandle._set_result_if_pending(response_fut, response)
        except Exception as e:
            SubprocessModuleHandle._set_exception_if_pending(response_fut, e)

    @staticmethod
    async def handle_stream_response_error(
        prev_fut: concurrent.futures.Future[aiohttp.web.StreamResponse],
        exception: Exception,
        response_fut: asyncio.Future[aiohttp.web.StreamResponse],
    ) -> None:
        """
        When the async iterator in the module raises an error, we need to propagate it
        to the client and close the stream. However, we already sent a 200 OK to the
        client and can't change that to a 500. We can't just raise an exception here to
        aiohttp because that causes it to abruptly close the connection and the client
        will raise a ClientPayloadError(TransferEncodingError).

        Instead, we write exception to the stream and close the stream.
        """
        try:
            response = await asyncio.wrap_future(prev_fut)
            await response.write(str(exception).encode())
            await response.write_eof()
            SubprocessModuleHandle._set_result_if_pending(response_fut, response)
        except Exception as e:
            SubprocessModuleHandle._set_exception_if_pending(response_fut, e)

    def handle_parent_bound_message(self, message: ParentBoundMessage):
        """Handles a message from the parent bound queue. This function must run on a
        dedicated thread, called by dispatch_parent_bound_messages.

        Raises ValueError for an unknown message type; lookups of unknown request ids
        raise from the ThreadSafeDict accessors.
        """
        loop = self.loop
        if isinstance(message, UnaryResponseMessage):
            active_request = self.active_requests.pop_or_raise(message.request_id)
            response = aiohttp.web.Response(
                status=message.status,
                body=message.body,
            )
            # set_result is not thread safe: resolve on the loop thread.
            loop.call_soon_threadsafe(
                SubprocessModuleHandle._set_result_if_pending,
                active_request.response_fut,
                response,
            )
        elif isinstance(message, StreamResponseStartMessage):
            active_request = self.active_requests.get_or_raise(message.request_id)
            assert active_request.stream_response is None
            # This assignment is thread safe, because a next read will come from another
            # handle_parent_bound_message call for a Stream.*Message, which will run on
            # the same thread and hence will happen-after this assignment.
            active_request.stream_response = asyncio.run_coroutine_threadsafe(
                SubprocessModuleHandle.handle_stream_response_start(
                    active_request.request, message.body
                ),
                loop,
            )
        elif isinstance(message, StreamResponseDataMessage):
            active_request = self.active_requests.get_or_raise(message.request_id)
            assert active_request.stream_response is not None
            active_request.stream_response = asyncio.run_coroutine_threadsafe(
                SubprocessModuleHandle.handle_stream_response_data(
                    active_request.stream_response, message.body
                ),
                loop,
            )
        elif isinstance(message, StreamResponseEndMessage):
            active_request = self.active_requests.pop_or_raise(message.request_id)
            assert active_request.stream_response is not None
            asyncio.run_coroutine_threadsafe(
                SubprocessModuleHandle.handle_stream_response_end(
                    active_request.stream_response,
                    active_request.response_fut,
                ),
                loop,
            )
        elif isinstance(message, ErrorMessage):
            # Propagate the error to aiohttp.
            active_request = self.active_requests.pop_or_raise(message.request_id)
            if active_request.stream_response is not None:
                asyncio.run_coroutine_threadsafe(
                    SubprocessModuleHandle.handle_stream_response_error(
                        active_request.stream_response,
                        message.error,
                        active_request.response_fut,
                    ),
                    loop,
                )
            else:
                loop.call_soon_threadsafe(
                    SubprocessModuleHandle._set_exception_if_pending,
                    active_request.response_fut,
                    message.error,
                )
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

    def dispatch_parent_bound_messages(self):
        """
        Dispatch Messages from the module. This function should be run in a separate thread
        from the asyncio loop of the parent process.

        The loop exits once self.incarnation changes (module restarted) or the queue
        is closed.
        """
        assert_not_in_asyncio_loop()
        incarnation = self.incarnation
        pid = self.process.pid if self.process else None
        self_str = self.str_for_state(incarnation, pid)
        # Keep a reference to this incarnation's queue: destroy_module() resets the
        # attribute to None.
        parent_bound_queue = self.parent_bound_queue
        # Exit if the module has restarted.
        while incarnation == self.incarnation:
            try:
                message = parent_bound_queue.get(timeout=1)
            except queue.Empty:
                # Empty is normal.
                continue
            except ValueError:
                # queue is closed.
                break
            except Exception:
                logger.exception(
                    f"Error unpickling parent bound message from {self_str}."
                    " This may result in a http request never being responded to."
                )
                continue
            try:
                self.handle_parent_bound_message(message)
            except Exception:
                logger.exception(
                    f"Error handling parent bound message from {self_str}."
                    " This may result in a http request never being responded to."
                )
        logger.info(f"dispatch_parent_bound_messages thread for {self_str} is exiting")