File size: 24,310 Bytes
de7cd93 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
import asyncio
import copy
import uuid
from collections import deque
from itertools import chain
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

# NOTE(review): the ray imports below keep their original relative order in
# case that order matters for avoiding circular imports.
from ray.experimental.channel.auto_transport_type import AutoTransportType
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import ray
from ray.dag.base import DAGNodeBase
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.util.annotations import DeveloperAPI
from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
from ray.experimental.channel import ChannelOutputType
from ray.experimental.channel.communicator import Communicator
# Generic result type of the per-node callables taken by DAGNode methods
# such as `apply_recursive` (typed `Callable[[DAGNode], T]`).
T = TypeVar("T")
@DeveloperAPI
class DAGNode(DAGNodeBase):
    """Abstract class for a node in a Ray task graph.

    A node has a type (e.g., FunctionNode), data (e.g., function options and
    body), arguments (Python values, DAGNodes, and DAGNodes nested within Python
    argument values) and options (Ray API .options() used for function, class
    or class method).
    """

    def __init__(
        self,
        args: Tuple[Any],
        kwargs: Dict[str, Any],
        options: Dict[str, Any],
        other_args_to_resolve: Dict[str, Any],
    ):
        """
        Args:
            args (Tuple[Any]): Bound node arguments.
                ex: func_or_class.bind(1)
            kwargs (Dict[str, Any]): Bound node keyword arguments.
                ex: func_or_class.bind(a=1)
            options (Dict[str, Any]): Bound node options arguments.
                ex: func_or_class.options(num_cpus=2)
            other_args_to_resolve (Dict[str, Any]): Bound kwargs to resolve
                that's specific to subclass implementation without exposing
                as args in base class, example: ClassMethodNode
        """
        self._bound_args: Tuple[Any] = args or []
        self._bound_kwargs: Dict[str, Any] = kwargs or {}
        self._bound_options: Dict[str, Any] = options or {}
        self._bound_other_args_to_resolve: Optional[Dict[str, Any]] = (
            other_args_to_resolve or {}
        )

        # The list of nodes that use this DAG node as an argument.
        self._downstream_nodes: List["DAGNode"] = []

        # UUID that is not changed over copies of this node.
        self._stable_uuid = uuid.uuid4().hex
        # Indicates whether this DAG node contains nested DAG nodes.
        # Nested DAG nodes are allowed in traditional DAGs but not
        # in Ray Compiled Graphs, except for MultiOutputNode.
        # Must be initialized before `_collect_upstream_nodes()`, which may
        # set it to True.
        self._args_contain_nested_dag_node = False
        # The list of nodes that this DAG node uses as an argument. Collecting
        # them also registers `self` as a downstream node of each of them.
        self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()

        # Cached values from last call to execute()
        self.cache_from_last_execute = {}

        self._type_hint: ChannelOutputType = ChannelOutputType()
        # If the original type hint is an AutoTransportType, the `type_hint`
        # setter keeps it here when it is resolved to the actual type, as
        # additional debugging information. Otherwise, it is None.
        self._original_type_hint: Optional[ChannelOutputType] = None

        # Whether this node calls `experimental_compile`.
        self.is_cgraph_output_node = False

    def _collect_upstream_nodes(self) -> List["DAGNode"]:
        """
        Retrieve upstream nodes and update their downstream dependencies.

        Currently, the DAG assumes that all DAGNodes in `args`, `kwargs`, and
        `other_args_to_resolve` are upstream nodes. However, Ray Compiled Graphs
        builds the upstream/downstream relationship based only on args. Be cautious
        when persisting DAGNodes in `other_args_to_resolve` and kwargs in the future.

        As a side effect, sets `self._args_contain_nested_dag_node` to True if
        any positional arg has DAGNodes nested inside it.

        TODO (kevin85421): Currently, the upstream nodes and downstream nodes have
        circular references. Therefore, it relies on the garbage collector to clean
        them up instead of reference counting. We should consider using weak references
        to avoid circular references.
        """
        upstream_nodes: List["DAGNode"] = []

        # Ray Compiled Graphs do not allow nested DAG nodes in arguments.
        # Specifically, a DAGNode should not be placed inside any type of
        # container. However, we only know if this is a compiled graph
        # when calling `experimental_compile`. Therefore, we need to check
        # in advance if the arguments contain nested DAG nodes and raise
        # an error after compilation.
        assert hasattr(self._bound_args, "__iter__")
        for arg in self._bound_args:
            if isinstance(arg, DAGNode):
                upstream_nodes.append(arg)
            else:
                scanner = _PyObjScanner()
                dag_nodes = scanner.find_nodes(arg)
                upstream_nodes.extend(dag_nodes)
                scanner.clear()
                # Accumulate instead of overwriting: a later arg without
                # nested DAGNodes must not clear the flag set by an earlier one.
                if dag_nodes:
                    self._args_contain_nested_dag_node = True

        scanner = _PyObjScanner()
        other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
            [
                self._bound_kwargs,
                self._bound_other_args_to_resolve,
            ]
        )
        upstream_nodes.extend(other_upstream_nodes)
        scanner.clear()
        # Update dependencies.
        for upstream_node in upstream_nodes:
            upstream_node._downstream_nodes.append(self)
        return upstream_nodes

    def with_tensor_transport(
        self,
        transport: Optional[Union[str, Communicator]] = "auto",
        _static_shape: bool = False,
        _direct_return: bool = False,
    ):
        """Set the type hint describing how this node's output is transported.

        Args:
            transport: "auto" selects an ``AutoTransportType``. "nccl" or a
                ``Communicator`` instance selects a ``TorchTensorType`` that
                uses that transport.
            _static_shape: Forwarded to the type hint constructor.
            _direct_return: Forwarded to the type hint constructor.

        Returns:
            This node, so calls can be chained.

        Raises:
            ValueError: If ``transport`` is not "auto", "nccl", or a
                ``Communicator`` instance.
        """
        if transport == "auto":
            self._type_hint = AutoTransportType(
                _static_shape=_static_shape,
                _direct_return=_direct_return,
            )
        elif transport == "nccl" or isinstance(transport, Communicator):
            self._type_hint = TorchTensorType(
                transport=transport,
                _static_shape=_static_shape,
                _direct_return=_direct_return,
            )
        else:
            raise ValueError(
                "transport must be 'auto', 'nccl' or a Communicator type"
            )
        return self

    @property
    def type_hint(self) -> ChannelOutputType:
        """The channel output type hint for this node."""
        return self._type_hint

    @type_hint.setter
    def type_hint(self, type_hint: ChannelOutputType) -> None:
        # Keep the unresolved AutoTransportType around for debugging.
        if isinstance(self._type_hint, AutoTransportType):
            self._original_type_hint = self._type_hint
        self._type_hint = type_hint

    def get_args(self) -> Tuple[Any]:
        """Return the tuple of arguments for this node."""

        return self._bound_args

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the dict of keyword arguments for this node."""

        return self._bound_kwargs.copy()

    def get_options(self) -> Dict[str, Any]:
        """Return the dict of options arguments for this node."""

        return self._bound_options.copy()

    def get_other_args_to_resolve(self) -> Dict[str, Any]:
        """Return the dict of other args to resolve arguments for this node."""
        return self._bound_other_args_to_resolve.copy()

    def get_stable_uuid(self) -> str:
        """Return stable uuid for this node.
        1) Generated only once at first instance creation
        2) Stable across pickling, replacement and JSON serialization.
        """
        return self._stable_uuid

    async def get_object_refs_from_last_execute(self) -> Dict[str, Any]:
        """Gets cached object refs from the last call to execute().

        After this DAG is executed through execute(), retrieves a map between node
        UUID to a reference to the return value of the default executor on that node.
        Cached values that are ``asyncio.Task`` objects are awaited first.
        """
        cache = {}
        for node_uuid, value in self.cache_from_last_execute.items():
            if isinstance(value, asyncio.Task):
                cache[node_uuid] = await value
            else:
                cache[node_uuid] = value

        return cache

    def clear_cache(self):
        """Drop the values cached by the last ``execute(_ray_cache_refs=True)``."""
        self.cache_from_last_execute = {}

    def experimental_compile(
        self,
        _submit_timeout: Optional[float] = None,
        _buffer_size_bytes: Optional[int] = None,
        enable_asyncio: bool = False,
        _max_inflight_executions: Optional[int] = None,
        _overlap_gpu_communication: Optional[bool] = None,
    ) -> "ray.dag.CompiledDAG":
        """Compile an accelerated execution path for this DAG.

        Args:
            _submit_timeout: The maximum time in seconds to wait for execute() calls.
                None means using default timeout, 0 means immediate timeout
                (immediate success or timeout without blocking), -1 means
                infinite timeout (block indefinitely).
            _buffer_size_bytes: The initial buffer size in bytes for messages
                that can be passed between tasks in the DAG. The buffers will
                be automatically resized if larger messages are written to the
                channel.
            enable_asyncio: Whether to enable asyncio for this DAG.
            _max_inflight_executions: The maximum number of in-flight executions that
                can be submitted via `execute` or `execute_async` before consuming
                the output using `ray.get()`. If the caller submits more executions,
                `RayCgraphCapacityExceeded` is raised.
            _overlap_gpu_communication: (experimental) Whether to overlap GPU
                communication with computation during DAG execution. If True, the
                communication and computation can be overlapped, which can improve
                the performance of the DAG execution. If None, the default value
                will be used.

        Returns:
            A compiled DAG.

        Raises:
            ValueError: If `experimental_compile` was already called on this node.
        """
        from ray.dag import DAGContext

        ctx = DAGContext.get_current()
        if _buffer_size_bytes is None:
            _buffer_size_bytes = ctx.buffer_size_bytes

        # Validate whether this DAG node has already been compiled.
        if self.is_cgraph_output_node:
            raise ValueError(
                "It is not allowed to call `experimental_compile` on the same DAG "
                "object multiple times no matter whether `teardown` is called or not. "
                "Please reuse the existing compiled DAG or create a new one."
            )
        # Whether this node is an output node in the DAG. We cannot determine
        # this in the constructor because the output node is determined when
        # `experimental_compile` is called.
        self.is_cgraph_output_node = True

        return build_compiled_dag_from_ray_dag(
            self,
            _submit_timeout,
            _buffer_size_bytes,
            enable_asyncio,
            _max_inflight_executions,
            _overlap_gpu_communication,
        )

    def execute(
        self, *args, _ray_cache_refs: bool = False, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        """Execute this DAG using the Ray default executor _execute_impl().

        Args:
            _ray_cache_refs: If true, stores the default executor's return values
                on each node in this DAG in a cache. These should be a mix of:
                - ray.ObjectRefs pointing to the outputs of method and function nodes
                - Serve handles for class nodes
                - resolved values representing user input at runtime
        """

        def executor(node):
            return node._execute_impl(*args, **kwargs)

        result = self.apply_recursive(executor)
        if _ray_cache_refs:
            # `apply_recursive` attaches its per-node result cache (keyed by
            # stable UUID) to the wrapped callable as `executor.cache`.
            self.cache_from_last_execute = executor.cache
        return result

    def _get_toplevel_child_nodes(self) -> List["DAGNode"]:
        """Return the list of nodes specified as top-level args.

        For example, in `f.remote(a, [b])`, only `a` is a top-level arg.

        This list of nodes are those that are typically resolved prior to
        task execution in Ray. This does not include nodes nested within args.
        For that, use ``_get_all_child_nodes()``.
        """

        # we use List instead of Set here because the hash key of the node
        # object changes each time we create it. So if using Set here, the
        # order of returned children can be different if we create the same
        # nodes and dag one more time.
        children = []
        for a in self.get_args():
            if isinstance(a, DAGNode):
                if a not in children:
                    children.append(a)
        for a in self.get_kwargs().values():
            if isinstance(a, DAGNode):
                if a not in children:
                    children.append(a)
        for a in self.get_other_args_to_resolve().values():
            if isinstance(a, DAGNode):
                if a not in children:
                    children.append(a)
        return children

    def _get_all_child_nodes(self) -> List["DAGNode"]:
        """Return the list of nodes referenced by the args, kwargs, and
        args_to_resolve in current node, even they're deeply nested.

        Examples:
            f.remote(a, [b]) -> [a, b]
            f.remote(a, [b], key={"nested": [c]}) -> [a, b, c]
        """

        scanner = _PyObjScanner()
        # we use List instead of Set here, reason explained
        # in `_get_toplevel_child_nodes`.
        children = []
        for n in scanner.find_nodes(
            [
                self._bound_args,
                self._bound_kwargs,
                self._bound_other_args_to_resolve,
            ]
        ):
            if n not in children:
                children.append(n)
        scanner.clear()
        return children

    def _apply_and_replace_all_child_nodes(
        self, fn: "Callable[[DAGNode], T]"
    ) -> "DAGNode":
        """Apply and replace all immediate child nodes using a given function.

        This is a shallow replacement only. To recursively transform nodes in
        the DAG, use ``apply_recursive()``.

        Args:
            fn: Callable that will be applied once to each child of this node.

        Returns:
            New DAGNode after replacing all child nodes.
        """

        replace_table = {}
        # CloudPickler scanner object for current layer of DAGNode. Same
        # scanner should be use for a full find & replace cycle.
        scanner = _PyObjScanner()
        # Find all first-level nested DAGNode children in args.
        # Update replacement table and execute the replace.
        for node in scanner.find_nodes(
            [
                self._bound_args,
                self._bound_kwargs,
                self._bound_other_args_to_resolve,
            ]
        ):
            if node not in replace_table:
                replace_table[node] = fn(node)
        new_args, new_kwargs, new_other_args_to_resolve = scanner.replace_nodes(
            replace_table
        )
        scanner.clear()

        # Return updated copy of self.
        return self._copy(
            new_args, new_kwargs, self.get_options(), new_other_args_to_resolve
        )

    def apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
        """Apply callable on each node in this DAG in a bottom-up tree walk.

        Args:
            fn: Callable that will be applied once to each node in the
                DAG. It will be applied recursively bottom-up, so nodes can
                assume the fn has been applied to their args already.

        Returns:
            Return type of the fn after application to the tree.

        Raises:
            AssertionError: If more than one distinct InputNode is found.
        """

        if not type(fn).__name__ == "_CachingFn":
            # Top-level call: wrap `fn` so each node (by stable UUID) is
            # evaluated only once, and expose the cache as `fn.cache`.

            class _CachingFn:
                def __init__(self, fn):
                    self.cache = {}
                    self.fn = fn
                    self.fn.cache = self.cache
                    self.input_node_uuid = None

                def __call__(self, node: "DAGNode"):
                    from ray.dag.input_node import InputNode

                    if node._stable_uuid not in self.cache:
                        self.cache[node._stable_uuid] = self.fn(node)
                    if isinstance(node, InputNode):
                        if not self.input_node_uuid:
                            self.input_node_uuid = node._stable_uuid
                        elif self.input_node_uuid != node._stable_uuid:
                            raise AssertionError(
                                "Each DAG should only have one unique InputNode."
                            )
                    return self.cache[node._stable_uuid]

            fn = _CachingFn(fn)
        else:
            # Recursive call: reuse an already computed result if present.
            if self._stable_uuid in fn.cache:
                return fn.cache[self._stable_uuid]

        return fn(
            self._apply_and_replace_all_child_nodes(
                lambda node: node.apply_recursive(fn)
            )
        )

    def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
        """
        Traverse all nodes in the connected component of the DAG that contains
        the `self` node, and apply the given function to each node.

        Nodes are visited breadth-first from `self`, following both downstream
        and upstream edges, and `fn` is applied once per node.

        Raises:
            ValueError: If a visited node has DAGNodes nested inside a
                container in its args, or if more than one node in the
                component has called `experimental_compile`.
        """
        visited = set()
        queue = deque([self])
        cgraph_output_node: Optional[DAGNode] = None

        while queue:
            node = queue.popleft()
            if node._args_contain_nested_dag_node:
                # Report the node that actually holds the nested DAGNodes,
                # not the node the traversal started from.
                node._raise_nested_dag_node_error(node._bound_args)

            if node not in visited:
                if node.is_cgraph_output_node:
                    # Validate whether there are multiple nodes that call
                    # `experimental_compile`.
                    if cgraph_output_node is not None:
                        raise ValueError(
                            "The DAG was compiled more than once. The following two "
                            "nodes call `experimental_compile`: "
                            f"(1) {cgraph_output_node}, (2) {node}"
                        )
                    cgraph_output_node = node
                fn(node)
                visited.add(node)

                # Add all unseen downstream and upstream nodes to the queue.
                # This function should be called by the root of the DAG.
                # However, in some invalid cases, some nodes may not be
                # descendants of the root. Therefore, we also add upstream
                # nodes to the queue so that a meaningful error message can
                # be raised when the DAG is compiled. For example:
                #
                #     with InputNode() as inp:
                #         dag = MultiOutputNode([a1.inc.bind(inp), a2.inc.bind(1)])
                #
                # Here `a2.inc` is not a descendant of `inp`. If we only added
                # downstream nodes, `a2.inc` would never be visited and the
                # resulting error (e.g. a key error in the compiled DAG) would
                # be hard to understand.
                for neighbor in chain(node._downstream_nodes, node._upstream_nodes):
                    if neighbor not in visited:
                        queue.append(neighbor)

    def _raise_nested_dag_node_error(self, args):
        """
        Raise an error for nested DAGNodes in Ray Compiled Graphs.

        Args:
            args: The arguments of the DAGNode.

        Raises:
            ValueError: For the first arg that has DAGNodes nested inside it.
            AssertionError: If no such arg exists (internal inconsistency).
        """
        for arg in args:
            if isinstance(arg, DAGNode):
                continue
            else:
                scanner = _PyObjScanner()
                dag_nodes = scanner.find_nodes([arg])
                scanner.clear()
                if len(dag_nodes) > 0:
                    raise ValueError(
                        f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
                        f"in {self}. Please ensure that the argument is a "
                        "single DAGNode and that a DAGNode is not allowed to "
                        "be placed inside any type of container."
                    )
        raise AssertionError(
            "A DAGNode's args should contain nested DAGNodes as args, "
            "but none were found during the compilation process. This is a "
            "Ray internal error. Please report this issue to the Ray team."
        )

    def _find_root(self) -> "DAGNode":
        """
        Return the root node of the DAG. The root node must be an InputNode.

        Walks upward by always following the first upstream node.

        Raises:
            ValueError: If a node with no upstream nodes is reached before
                an InputNode is found.
        """
        from ray.dag.input_node import InputNode

        node = self
        while not isinstance(node, InputNode):
            if len(node._upstream_nodes) == 0:
                raise ValueError(
                    "No InputNode found in the DAG: when traversing upwards, "
                    f"no upstream node was found for {node}."
                )
            node = node._upstream_nodes[0]
        return node

    def apply_functional(
        self,
        source_input_list: Any,
        predictate_fn: Callable,
        apply_fn: Callable,
    ):
        """
        Apply a given function to DAGNodes in source_input_list, and return
        the replaced inputs without mutating or copying any DAGNode.

        Args:
            source_input_list: Source inputs to extract and apply function on
                all children DAGNode instances.
            predictate_fn: Applied on each DAGNode instance found and determine
                if we should apply function to it. Can be used to filter node
                types.
            apply_fn: Function to apply on the node on bound attributes. Example:
                apply_fn = lambda node: node._get_serve_deployment_handle(
                    node._deployment, node._bound_other_args_to_resolve
                )

        Returns:
            replaced_inputs: Outputs of apply_fn on DAGNodes in
                source_input_list that passes predictate_fn.
        """
        replace_table = {}
        scanner = _PyObjScanner()
        for node in scanner.find_nodes(source_input_list):
            if predictate_fn(node) and node not in replace_table:
                replace_table[node] = apply_fn(node)

        replaced_inputs = scanner.replace_nodes(replace_table)
        scanner.clear()

        return replaced_inputs

    def _execute_impl(
        self, *args, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        """Execute this node, assuming args have been transformed already."""
        raise NotImplementedError

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args."""
        raise NotImplementedError

    def _copy(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args.

        The copy keeps this node's stable UUID and gets deep copies of its
        type hints.
        """
        instance = self._copy_impl(
            new_args, new_kwargs, new_options, new_other_args_to_resolve
        )
        instance._stable_uuid = self._stable_uuid
        instance._type_hint = copy.deepcopy(self._type_hint)
        instance._original_type_hint = copy.deepcopy(self._original_type_hint)
        return instance

    def __getstate__(self):
        """Required due to overriding `__getattr__` else pickling fails."""
        return self.__dict__

    def __setstate__(self, d: Dict[str, Any]):
        """Required due to overriding `__getattr__` else pickling fails."""
        self.__dict__.update(d)

    def __getattr__(self, attr: str):
        # Give a clearer error when users call `.bind()` / `.remote()` on an
        # already-bound node; otherwise defer to normal attribute lookup.
        if attr == "bind":
            raise AttributeError(f".bind() cannot be used again on {type(self)} ")
        elif attr == "remote":
            raise AttributeError(
                f".remote() cannot be used on {type(self)}. To execute the task "
                "graph for this node, use .execute()."
            )
        else:
            return self.__getattribute__(attr)
|