File size: 24,310 Bytes
de7cd93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
import copy
from ray.experimental.channel.auto_transport_type import AutoTransportType
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import ray
from ray.dag.base import DAGNodeBase
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.util.annotations import DeveloperAPI

from itertools import chain

from typing import (
    Optional,
    Union,
    List,
    Tuple,
    Dict,
    Any,
    TypeVar,
    Callable,
)
import uuid
import asyncio

from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
from ray.experimental.channel import ChannelOutputType
from ray.experimental.channel.communicator import Communicator

# Generic result type of the callables passed to `apply_recursive`,
# `_apply_and_replace_all_child_nodes` and `traverse_and_apply`.
T = TypeVar("T")


@DeveloperAPI
class DAGNode(DAGNodeBase):
    """Abstract class for a node in a Ray task graph.

    A node has a type (e.g., FunctionNode), data (e.g., function options and
    body), arguments (Python values, DAGNodes, and DAGNodes nested within Python
    argument values) and options (Ray API .options() used for function, class
    or class method).

    Subclasses provide ``_execute_impl`` and ``_copy_impl``.
    """

    def __init__(
        self,
        args: Tuple[Any],
        kwargs: Dict[str, Any],
        options: Dict[str, Any],
        other_args_to_resolve: Dict[str, Any],
    ):
        """Initialize the node and link it to the DAG nodes it depends on.

        Args:
            args: Bound node arguments.
                ex: func_or_class.bind(1)
            kwargs: Bound node keyword arguments.
                ex: func_or_class.bind(a=1)
            options: Bound node options arguments.
                ex: func_or_class.options(num_cpus=2)
            other_args_to_resolve: Bound kwargs to resolve that are specific
                to a subclass implementation without exposing them as args in
                the base class, example: ClassMethodNode.
        """
        # Falsy inputs (e.g. None) are normalized to empty containers.
        self._bound_args: Tuple[Any] = args or []
        self._bound_kwargs: Dict[str, Any] = kwargs or {}
        self._bound_options: Dict[str, Any] = options or {}
        self._bound_other_args_to_resolve: Dict[str, Any] = (
            other_args_to_resolve or {}
        )

        # The list of nodes that use this DAG node as an argument.
        self._downstream_nodes: List["DAGNode"] = []

        # UUID that is not changed over copies of this node.
        self._stable_uuid = uuid.uuid4().hex

        # Indicates whether this DAG node contains nested DAG nodes.
        # Nested DAG nodes are allowed in traditional DAGs but not
        # in Ray Compiled Graphs, except for MultiOutputNode.
        # Must be initialized before `_collect_upstream_nodes`, which may set it.
        self._args_contain_nested_dag_node = False

        # The list of nodes that this DAG node uses as an argument.
        self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()

        # Cached values from last call to execute()
        self.cache_from_last_execute: Dict[str, Any] = {}

        self._type_hint: ChannelOutputType = ChannelOutputType()

        # If the type hint is an AutoTransportType when it gets replaced by the
        # resolved type (via the `type_hint` setter), the AutoTransportType is
        # kept here as additional debugging information. Otherwise, it is None.
        self._original_type_hint: Optional[ChannelOutputType] = None

        # Whether this node calls `experimental_compile`.
        self.is_cgraph_output_node = False

    def _collect_upstream_nodes(self) -> List["DAGNode"]:
        """
        Retrieve upstream nodes and update their downstream dependencies.

        Currently, the DAG assumes that all DAGNodes in `args`, `kwargs`, and
        `other_args_to_resolve` are upstream nodes. However, Ray Compiled Graphs
        builds the upstream/downstream relationship based only on args. Be cautious
        when persisting DAGNodes in `other_args_to_resolve` and kwargs in the future.

        Side effects: appends `self` to each upstream node's `_downstream_nodes`
        and sets `self._args_contain_nested_dag_node` to True if any positional
        arg hides a DAGNode inside a container.

        TODO (kevin85421): Currently, the upstream nodes and downstream nodes have
        circular references. Therefore, it relies on the garbage collector to clean
        them up instead of reference counting. We should consider using weak references
        to avoid circular references.
        """
        upstream_nodes: List["DAGNode"] = []

        # Ray Compiled Graphs do not allow nested DAG nodes in arguments.
        # Specifically, a DAGNode should not be placed inside any type of
        # container. However, we only know if this is a compiled graph
        # when calling `experimental_compile`. Therefore, we need to check
        # in advance if the arguments contain nested DAG nodes and raise
        # an error after compilation.
        assert hasattr(self._bound_args, "__iter__")
        for arg in self._bound_args:
            if isinstance(arg, DAGNode):
                upstream_nodes.append(arg)
            else:
                scanner = _PyObjScanner()
                dag_nodes = scanner.find_nodes(arg)
                upstream_nodes.extend(dag_nodes)
                scanner.clear()
                # The flag is sticky: a later arg without nested nodes must not
                # reset a detection made on an earlier arg.
                if dag_nodes:
                    self._args_contain_nested_dag_node = True

        scanner = _PyObjScanner()
        other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
            [
                self._bound_kwargs,
                self._bound_other_args_to_resolve,
            ]
        )
        upstream_nodes.extend(other_upstream_nodes)
        scanner.clear()
        # Update dependencies.
        for upstream_node in upstream_nodes:
            upstream_node._downstream_nodes.append(self)
        return upstream_nodes

    def with_tensor_transport(
        self,
        transport: Optional[Union[str, Communicator]] = "auto",
        _static_shape: bool = False,
        _direct_return: bool = False,
    ):
        """Set the type hint that describes how this node's tensors are sent.

        Args:
            transport: ``"auto"`` sets an ``AutoTransportType``; ``"nccl"`` or a
                ``Communicator`` instance sets a ``TorchTensorType`` using that
                transport.
            _static_shape: Forwarded to the type hint constructor.
            _direct_return: Forwarded to the type hint constructor.

        Returns:
            This node, so the call can be chained.

        Raises:
            ValueError: If ``transport`` is none of the accepted values.
        """
        if transport == "auto":
            self._type_hint = AutoTransportType(
                _static_shape=_static_shape,
                _direct_return=_direct_return,
            )
        elif transport == "nccl" or isinstance(transport, Communicator):
            self._type_hint = TorchTensorType(
                transport=transport,
                _static_shape=_static_shape,
                _direct_return=_direct_return,
            )
        else:
            raise ValueError(
                "transport must be 'auto', 'nccl' or a Communicator type"
            )
        return self

    @property
    def type_hint(self) -> ChannelOutputType:
        """The channel output type hint currently attached to this node."""
        return self._type_hint

    @type_hint.setter
    def type_hint(self, type_hint: ChannelOutputType) -> None:
        # Remember the AutoTransportType being replaced, for debugging.
        if isinstance(self._type_hint, AutoTransportType):
            self._original_type_hint = self._type_hint
        self._type_hint = type_hint

    def get_args(self) -> Tuple[Any]:
        """Return the tuple of arguments for this node."""

        return self._bound_args

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the dict of keyword arguments for this node."""

        return self._bound_kwargs.copy()

    def get_options(self) -> Dict[str, Any]:
        """Return the dict of options arguments for this node."""

        return self._bound_options.copy()

    def get_other_args_to_resolve(self) -> Dict[str, Any]:
        """Return the dict of other args to resolve arguments for this node."""
        return self._bound_other_args_to_resolve.copy()

    def get_stable_uuid(self) -> str:
        """Return stable uuid for this node.
        1) Generated only once at first instance creation
        2) Stable across pickling, replacement and JSON serialization.
        """
        return self._stable_uuid

    async def get_object_refs_from_last_execute(self) -> Dict[str, Any]:
        """Gets cached object refs from the last call to execute().

        After this DAG is executed through execute(), retrieves a map between node
        UUID to a reference to the return value of the default executor on that node.
        Cached ``asyncio.Task`` values are awaited; other values are returned as-is.
        """
        cache = {}
        for node_uuid, value in self.cache_from_last_execute.items():
            if isinstance(value, asyncio.Task):
                cache[node_uuid] = await value
            else:
                cache[node_uuid] = value

        return cache

    def clear_cache(self):
        """Drop the values cached by the last ``execute(_ray_cache_refs=True)``."""
        self.cache_from_last_execute = {}

    def experimental_compile(
        self,
        _submit_timeout: Optional[float] = None,
        _buffer_size_bytes: Optional[int] = None,
        enable_asyncio: bool = False,
        _max_inflight_executions: Optional[int] = None,
        _overlap_gpu_communication: Optional[bool] = None,
    ) -> "ray.dag.CompiledDAG":
        """Compile an accelerated execution path for this DAG.

        Args:
            _submit_timeout: The maximum time in seconds to wait for execute() calls.
                None means using default timeout, 0 means immediate timeout
                (immediate success or timeout without blocking), -1 means
                infinite timeout (block indefinitely).
            _buffer_size_bytes: The initial buffer size in bytes for messages
                that can be passed between tasks in the DAG. The buffers will
                be automatically resized if larger messages are written to the
                channel. None means using the value from the current DAGContext.
            enable_asyncio: Whether to enable asyncio for this DAG.
            _max_inflight_executions: The maximum number of in-flight executions that
                can be submitted via `execute` or `execute_async` before consuming
                the output using `ray.get()`. If the caller submits more executions,
                `RayCgraphCapacityExceeded` is raised.
            _overlap_gpu_communication: (experimental) Whether to overlap GPU
                communication with computation during DAG execution. If True, the
                communication and computation can be overlapped, which can improve
                the performance of the DAG execution. If None, the default value
                will be used.

        Returns:
            A compiled DAG.

        Raises:
            ValueError: If `experimental_compile` was already called on this node.
        """
        from ray.dag import DAGContext

        ctx = DAGContext.get_current()
        if _buffer_size_bytes is None:
            _buffer_size_bytes = ctx.buffer_size_bytes

        # Validate whether this DAG node has already been compiled.
        if self.is_cgraph_output_node:
            raise ValueError(
                "It is not allowed to call `experimental_compile` on the same DAG "
                "object multiple times no matter whether `teardown` is called or not. "
                "Please reuse the existing compiled DAG or create a new one."
            )
        # Whether this node is an output node in the DAG. We cannot determine
        # this in the constructor because the output node is determined when
        # `experimental_compile` is called.
        self.is_cgraph_output_node = True
        return build_compiled_dag_from_ray_dag(
            self,
            _submit_timeout,
            _buffer_size_bytes,
            enable_asyncio,
            _max_inflight_executions,
            _overlap_gpu_communication,
        )

    def execute(
        self, *args, _ray_cache_refs: bool = False, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        """Execute this DAG using the Ray default executor _execute_impl().

        Args:
            _ray_cache_refs: If true, stores the default executor's return values
                on each node in this DAG in a cache. These should be a mix of:
                - ray.ObjectRefs pointing to the outputs of method and function nodes
                - Serve handles for class nodes
                - resolved values representing user input at runtime

        Returns:
            The value returned by this (the root) node's ``_execute_impl``.
        """

        def executor(node):
            return node._execute_impl(*args, **kwargs)

        result = self.apply_recursive(executor)
        if _ray_cache_refs:
            # `apply_recursive` wraps `executor` and attaches its per-node
            # result cache to it as the `cache` attribute.
            self.cache_from_last_execute = executor.cache
        return result

    def _get_toplevel_child_nodes(self) -> List["DAGNode"]:
        """Return the list of nodes specified as top-level args.

        For example, in `f.remote(a, [b])`, only `a` is a top-level arg.

        This list of nodes are those that are typically resolved prior to
        task execution in Ray. This does not include nodes nested within args.
        For that, use ``_get_all_child_nodes()``.
        """

        # we use List instead of Set here because the hash key of the node
        # object changes each time we create it. So if using Set here, the
        # order of returned children can be different if we create the same
        # nodes and dag one more time.
        children = []
        for candidate in chain(
            self.get_args(),
            self.get_kwargs().values(),
            self.get_other_args_to_resolve().values(),
        ):
            if isinstance(candidate, DAGNode) and candidate not in children:
                children.append(candidate)
        return children

    def _get_all_child_nodes(self) -> List["DAGNode"]:
        """Return the list of nodes referenced by the args, kwargs, and
        args_to_resolve in current node, even they're deeply nested.

        Examples:
            f.remote(a, [b]) -> [a, b]
            f.remote(a, [b], key={"nested": [c]}) -> [a, b, c]
        """

        scanner = _PyObjScanner()
        # we use List instead of Set here, reason explained
        # in `_get_toplevel_child_nodes`.
        children = []
        for n in scanner.find_nodes(
            [
                self._bound_args,
                self._bound_kwargs,
                self._bound_other_args_to_resolve,
            ]
        ):
            if n not in children:
                children.append(n)
        scanner.clear()
        return children

    def _apply_and_replace_all_child_nodes(
        self, fn: "Callable[[DAGNode], T]"
    ) -> "DAGNode":
        """Apply and replace all immediate child nodes using a given function.

        This is a shallow replacement only. To recursively transform nodes in
        the DAG, use ``apply_recursive()``.

        Args:
            fn: Callable that will be applied once to each child of this node.

        Returns:
            New DAGNode after replacing all child nodes.
        """

        replace_table = {}
        # CloudPickler scanner object for current layer of DAGNode. Same
        # scanner should be use for a full find & replace cycle.
        scanner = _PyObjScanner()
        # Find all first-level nested DAGNode children in args.
        # Update replacement table and execute the replace.
        for node in scanner.find_nodes(
            [
                self._bound_args,
                self._bound_kwargs,
                self._bound_other_args_to_resolve,
            ]
        ):
            if node not in replace_table:
                replace_table[node] = fn(node)
        new_args, new_kwargs, new_other_args_to_resolve = scanner.replace_nodes(
            replace_table
        )
        scanner.clear()

        # Return updated copy of self.
        return self._copy(
            new_args, new_kwargs, self.get_options(), new_other_args_to_resolve
        )

    def apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
        """Apply callable on each node in this DAG in a bottom-up tree walk.

        Results are memoized per node stable UUID, so `fn` runs at most once per
        node even when the node is shared by several parents.

        Args:
            fn: Callable that will be applied once to each node in the
                DAG. It will be applied recursively bottom-up, so nodes can
                assume the fn has been applied to their args already.

        Returns:
            Return type of the fn after application to the tree.

        Raises:
            AssertionError: If more than one distinct InputNode is encountered.
        """

        if type(fn).__name__ != "_CachingFn":

            class _CachingFn:
                def __init__(self, fn):
                    self.cache = {}
                    self.fn = fn
                    # Expose the cache on the user's callable so callers such
                    # as `execute` can read it after the walk.
                    self.fn.cache = self.cache
                    self.input_node_uuid = None

                def __call__(self, node: "DAGNode"):
                    from ray.dag.input_node import InputNode

                    if node._stable_uuid not in self.cache:
                        self.cache[node._stable_uuid] = self.fn(node)
                    if isinstance(node, InputNode):
                        if not self.input_node_uuid:
                            self.input_node_uuid = node._stable_uuid
                        elif self.input_node_uuid != node._stable_uuid:
                            raise AssertionError(
                                "Each DAG should only have one unique InputNode."
                            )
                    return self.cache[node._stable_uuid]

            fn = _CachingFn(fn)
        else:
            if self._stable_uuid in fn.cache:
                return fn.cache[self._stable_uuid]

        return fn(
            self._apply_and_replace_all_child_nodes(
                lambda node: node.apply_recursive(fn)
            )
        )

    def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
        """
        Traverse all nodes in the connected component of the DAG that contains
        the `self` node, and apply the given function to each node.

        Nodes are visited breadth-first along both downstream and upstream
        edges; `fn` is applied at most once per node.

        Raises:
            ValueError: If a visited node has DAGNodes nested inside a container
                in its positional args, or if more than one node in the
                component has called `experimental_compile`.
        """
        from collections import deque

        visited = set()
        # deque gives O(1) pops from the front (list.pop(0) is O(n)).
        queue = deque([self])
        cgraph_output_node: Optional[DAGNode] = None

        while queue:
            node = queue.popleft()
            if node._args_contain_nested_dag_node:
                # Raise on `node` itself so the error message names the node
                # that actually holds the nested DAGNode, not the traversal root.
                node._raise_nested_dag_node_error(node._bound_args)

            if node in visited:
                continue

            if node.is_cgraph_output_node:
                # Validate whether there are multiple nodes that call
                # `experimental_compile`.
                if cgraph_output_node is not None:
                    raise ValueError(
                        "The DAG was compiled more than once. The following two "
                        "nodes call `experimental_compile`: "
                        f"(1) {cgraph_output_node}, (2) {node}"
                    )
                cgraph_output_node = node
            fn(node)
            visited.add(node)

            # Add all unseen downstream and upstream nodes to the queue.
            # This function should be called by the root of the DAG. However,
            # in some invalid cases, some nodes may not be descendants of the
            # root. Therefore, we also add upstream nodes to the queue so that
            # a meaningful error message can be raised when the DAG is compiled.
            #
            #     with InputNode() as inp:
            #         dag = MultiOutputNode([a1.inc.bind(inp), a2.inc.bind(1)])
            #
            # In the above example, `a2.inc` is not a descendant of inp. If we
            # only added downstream nodes to the queue, the `a2.inc` node would
            # not be visited, and the error message would be hard to understand,
            # such as a key error in the compiled DAG.
            for neighbor in chain(node._downstream_nodes, node._upstream_nodes):
                if neighbor not in visited:
                    queue.append(neighbor)

    def _raise_nested_dag_node_error(self, args):
        """
        Raise an error for nested DAGNodes in Ray Compiled Graphs.

        Always raises: a ValueError naming the first arg that hides DAGNodes
        inside a container, or an AssertionError if no such arg is found.

        Args:
            args: The arguments of the DAGNode.
        """
        for arg in args:
            if isinstance(arg, DAGNode):
                continue
            scanner = _PyObjScanner()
            dag_nodes = scanner.find_nodes([arg])
            scanner.clear()
            if len(dag_nodes) > 0:
                raise ValueError(
                    f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
                    f"in {self}. Please ensure that the argument is a "
                    "single DAGNode and that a DAGNode is not allowed to "
                    "be placed inside any type of container."
                )
        raise AssertionError(
            "A DAGNode's args should contain nested DAGNodes as args, "
            "but none were found during the compilation process. This is a "
            "Ray internal error. Please report this issue to the Ray team."
        )

    def _find_root(self) -> "DAGNode":
        """
        Return the root node of the DAG. The root node must be an InputNode.

        Walks upwards from `self`, always following the first upstream node,
        until an InputNode is reached.

        Raises:
            ValueError: If a node without upstream nodes is reached before any
                InputNode.
        """
        from ray.dag.input_node import InputNode

        node = self
        while not isinstance(node, InputNode):
            if len(node._upstream_nodes) == 0:
                raise ValueError(
                    "No InputNode found in the DAG: when traversing upwards, "
                    f"no upstream node was found for {node}."
                )
            node = node._upstream_nodes[0]
        return node

    def apply_functional(
        self,
        source_input_list: Any,
        predictate_fn: Callable,
        apply_fn: Callable,
    ):
        """
        Apply a given function to DAGNodes in source_input_list, and return
        the replaced inputs without mutating or copying any DAGNode.

        Args:
            source_input_list: Source inputs to extract and apply function on
                all children DAGNode instances.
            predictate_fn: Applied on each DAGNode instance found to determine
                if we should apply the function to it. Can be used to filter node
                types.
            apply_fn: Function to apply on the node on bound attributes. Example:
                apply_fn = lambda node: node._get_serve_deployment_handle(
                    node._deployment, node._bound_other_args_to_resolve
                )

        Returns:
            replaced_inputs: source_input_list with every DAGNode that passes
                predictate_fn replaced by the output of apply_fn on it.
        """
        replace_table = {}
        scanner = _PyObjScanner()
        for node in scanner.find_nodes(source_input_list):
            if predictate_fn(node) and node not in replace_table:
                replace_table[node] = apply_fn(node)

        replaced_inputs = scanner.replace_nodes(replace_table)
        scanner.clear()

        return replaced_inputs

    def _execute_impl(
        self, *args, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        """Execute this node, assuming args have been transformed already."""
        raise NotImplementedError

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args."""
        raise NotImplementedError

    def _copy(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args.

        The copy keeps this node's stable UUID and gets deep copies of its
        type hints.
        """
        instance = self._copy_impl(
            new_args, new_kwargs, new_options, new_other_args_to_resolve
        )
        instance._stable_uuid = self._stable_uuid
        instance._type_hint = copy.deepcopy(self._type_hint)
        instance._original_type_hint = copy.deepcopy(self._original_type_hint)
        return instance

    def __getstate__(self):
        """Required due to overriding `__getattr__` else pickling fails."""
        return self.__dict__

    def __setstate__(self, d: Dict[str, Any]):
        """Required due to overriding `__getattr__` else pickling fails."""
        self.__dict__.update(d)

    def __getattr__(self, attr: str):
        # Only reached when normal attribute lookup fails; gives clearer
        # errors for the common misuse of `.bind()` / `.remote()` on a node.
        if attr == "bind":
            raise AttributeError(f".bind() cannot be used again on {type(self)} ")
        elif attr == "remote":
            raise AttributeError(
                f".remote() cannot be used on {type(self)}. To execute the task "
                "graph for this node, use .execute()."
            )
        else:
            return self.__getattribute__(attr)