File size: 1,363 Bytes
de7cd93 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | import ray
from typing import Any, Dict, List, Union, Tuple
from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class MultiOutputNode(DAGNode):
    """Ray DAG node used in the DAG building API to mark the endpoint of a DAG.

    It wraps several upstream nodes as one output node. Executing it returns
    its bound args, i.e. the list of wrapped nodes given to ``__init__``
    (presumably already resolved to their results by the time
    ``_execute_impl`` runs — confirm against ``DAGNode.execute``).
    """

    def __init__(
        self,
        args: Union[List[DAGNode], Tuple[DAGNode, ...]],
        other_args_to_resolve: Union[Dict[str, Any], None] = None,
    ):
        """Create an output node over ``args``.

        Args:
            args: The nodes whose results form the DAG's outputs. A tuple is
                converted to a list; any other non-list type is rejected.
            other_args_to_resolve: Extra arguments forwarded to ``DAGNode``.
                ``None`` (or any falsy value) is replaced by an empty dict.

        Raises:
            ValueError: If ``args`` is neither a list nor a tuple.
        """
        if isinstance(args, tuple):
            args = list(args)
        if not isinstance(args, list):
            raise ValueError(f"Invalid input type for `args`, {type(args)}.")
        # This node only carries positional args. The two empty dicts fill the
        # remaining positional DAGNode parameters (presumably bound kwargs and
        # options, mirroring _copy_impl's signature — confirm in DAGNode).
        super().__init__(
            args,
            {},
            {},
            other_args_to_resolve=other_args_to_resolve or {},
        )

    def _execute_impl(self, *args, **kwargs) -> List[Any]:
        """Return this node's bound args as the DAG's output.

        Any positional/keyword arguments are accepted for interface
        compatibility and ignored.

        NOTE(review): the ``List[Any]`` annotation assumes DAGNode stores the
        list passed to ``__init__`` unchanged as ``_bound_args``.
        """
        return self._bound_args

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args.

        ``new_kwargs`` and ``new_options`` are ignored. This node is always
        constructed with no kwargs or options.
        """
        return MultiOutputNode(new_args, new_other_args_to_resolve)

    def __str__(self) -> str:
        return get_dag_node_str(self, "__MultiOutputNode__")
|