from typing import Any, Dict, List

import ray
from ray.dag.dag_node import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
class FunctionNode(DAGNode):
    """Represents a bound task node in a Ray task DAG.

    The wrapped Python callable is kept on ``self._body``; the bound
    positional/keyword arguments, options and any extra arguments to resolve
    are handed to the ``DAGNode`` base initializer.
    """

    def __init__(
        self,
        func_body,
        func_args,
        func_kwargs,
        func_options,
        other_args_to_resolve=None,
    ):
        # Store the callable before the base initializer runs, preserving the
        # original ordering (NOTE(review): unclear from here whether DAGNode's
        # __init__ relies on it — keep this order).
        self._body = func_body
        super().__init__(
            func_args,
            func_kwargs,
            func_options,
            other_args_to_resolve=other_args_to_resolve,
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Create a new node that wraps the same callable with replacement inputs.

        Args:
            new_args: Positional arguments to bind on the copy.
            new_kwargs: Keyword arguments to bind on the copy.
            new_options: Options to bind on the copy.
            new_other_args_to_resolve: Extra arguments to resolve on the copy.

        Returns:
            A new ``FunctionNode`` (always this exact class) sharing
            ``self._body``.
        """
        return FunctionNode(
            func_body=self._body,
            func_args=new_args,
            func_kwargs=new_kwargs,
            func_options=new_options,
            other_args_to_resolve=new_other_args_to_resolve,
        )

    def _execute_impl(self, *args, **kwargs):
        """Submit the wrapped callable as a Ray task via ``ray.remote()``.

        ``args`` and ``kwargs`` exist only to match the base-class signature
        and are not used here. By the time this runs, all arguments are
        expected to have been resolved and replaced with values in
        ``self._bound_args`` / ``self._bound_kwargs`` by the bottom-up
        recursion over the DAG.

        Returns:
            Whatever ``.remote(...)`` returns for the submitted task.
        """
        remote_function = ray.remote(self._body)
        configured_function = remote_function.options(**self._bound_options)
        return configured_function.remote(*self._bound_args, **self._bound_kwargs)

    def __str__(self) -> str:
        body_label = str(self._body)
        return get_dag_node_str(self, body_label)