File size: 11,243 Bytes
de7cd93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
from typing import Any, Dict, List, Union, Optional

from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.experimental.gradio_utils import type_to_string
from ray.util.annotations import DeveloperAPI

# Key under which InputNode records that it was entered as a context manager:
# InputNode.__enter__ stores True under this key via set_context(), and
# InputNode._in_context_manager() reads it back at execution time.
IN_CONTEXT_MANAGER = "__in_context_manager__"


@DeveloperAPI
class InputNode(DAGNode):
    r"""Ray dag node used in DAG building API to mark entrypoints of a DAG.

    Should only be function or class method. A DAG can have multiple
    entrypoints, but only one instance of InputNode exists per DAG, shared
    among all DAGNodes.

    Example:

    .. code-block::

                m1.forward
                /       \
        dag_input     ensemble -> dag_output
                \       /
                m2.forward

    In this pipeline, each user input is broadcasted to both m1.forward and
    m2.forward as first stop of the DAG, and authored like

    .. code-block:: python

        import ray

        @ray.remote
        class Model:
            def __init__(self, val):
                self.val = val
            def forward(self, input):
                return self.val * input

        @ray.remote
        def combine(a, b):
            return a + b

        with InputNode() as dag_input:
            m1 = Model.bind(1)
            m2 = Model.bind(2)
            m1_output = m1.forward.bind(dag_input[0])
            m2_output = m2.forward.bind(dag_input.x)
            ray_dag = combine.bind(m1_output, m2_output)

        # Pass mix of args and kwargs as input.
        ray_dag.execute(1, x=2) # 1 sent to m1, 2 sent to m2

        # Alternatively user can also pass single data object, list or dict
        # and access them via list index, object attribute or dict key str.
        ray_dag.execute(UserDataObject(m1=1, m2=2))
            # dag_input.m1, dag_input.m2
        ray_dag.execute([1, 2])
            # dag_input[0], dag_input[1]
        ray_dag.execute({"m1": 1, "m2": 2})
            # dag_input["m1"], dag_input["m2"]
    """

    def __init__(
        self,
        *args,
        input_type: Optional[Union[type, Dict[Union[int, str], type]]] = None,
        _other_args_to_resolve=None,
        **kwargs,
    ):
        """InputNode should only take attributes of validating and converting
        input data rather than the input data itself. User input should be
        provided via `ray_dag.execute(user_input)`.

        Args:
            input_type: Describes the data type of inputs user will be giving.
                - if given through singular InputNode: type of InputNode
                - if given through InputAttributeNodes: map of key -> type
                Used when deciding what Gradio block to represent the input nodes with.
            _other_args_to_resolve: Internal only to keep InputNode's execution
                context throughout pickling, replacement and serialization.
                User should not use or pass this field.

        Raises:
            ValueError: If any positional argument or any keyword argument
                other than the two named above is passed.
        """
        if args or kwargs:
            raise ValueError("InputNode should not take any args or kwargs.")

        # Cache of InputAttributeNode objects created through __getattr__ /
        # __getitem__, keyed by the accessed key.
        self.input_attribute_nodes = {}

        self.input_type = input_type
        if isinstance(input_type, type):
            # Record the type string on a copy so that a dict supplied by the
            # caller is not modified as a side effect of construction.
            _other_args_to_resolve = dict(_other_args_to_resolve or {})
            _other_args_to_resolve["result_type_string"] = type_to_string(input_type)

        super().__init__([], {}, {}, other_args_to_resolve=_other_args_to_resolve)

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Return a new InputNode built only from ``new_other_args_to_resolve``.

        ``new_args``, ``new_kwargs`` and ``new_options`` are accepted to match
        the signature but are not used.
        """
        # NOTE(review): ``self.input_type`` is not forwarded to the copy, so a
        # dict-valued ``input_type`` is absent on the new node — confirm that
        # this is intended.
        return InputNode(_other_args_to_resolve=new_other_args_to_resolve)

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputNode.

        Returns:
            The single positional argument itself when exactly one positional
            argument and no keyword arguments are given; otherwise a
            ``DAGInputData`` wrapping all args and kwargs.
        """
        # Catch and assert singleton context at dag execution time.
        assert self._in_context_manager(), (
            "InputNode is a singleton instance that should be only used in "
            "context manager for dag building and execution. See the docstring "
            "of class InputNode for examples."
        )
        # If user only passed in one value, for simplicity we just return it.
        if len(args) == 1 and len(kwargs) == 0:
            return args[0]

        return DAGInputData(*args, **kwargs)

    def _in_context_manager(self) -> bool:
        """Return if InputNode is created in context manager."""
        resolved = self._bound_other_args_to_resolve
        if not resolved or IN_CONTEXT_MANAGER not in resolved:
            return False
        return resolved[IN_CONTEXT_MANAGER]

    def set_context(self, key: str, val: Any):
        """Set field in parent DAGNode attribute that can be resolved in both
        pickle and JSON serialization
        """
        self._bound_other_args_to_resolve[key] = val

    def __str__(self) -> str:
        return get_dag_node_str(self, "__InputNode__")

    def __getattr__(self, key: str):
        """Return the cached InputAttributeNode for attribute-style access.

        Python only calls this hook after normal attribute lookup has failed,
        so every otherwise-unknown attribute name is turned into an
        InputAttributeNode using the ``"__getattr__"`` accessor.

        Raises:
            AttributeError: If the node cache itself has not been set on this
                instance yet.
        """
        if key == "input_attribute_nodes":
            # Reaching here means the cache attribute is missing (e.g. the
            # instance was created without running __init__). Reading
            # ``self.input_attribute_nodes`` below would re-enter this method
            # and recurse without bound, so fail with a normal AttributeError.
            raise AttributeError(key)
        assert isinstance(
            key, str
        ), "Please only access dag input attributes with str key."
        # NOTE(review): the cache is shared with __getitem__ and keyed only by
        # ``key``, so ``dag_input.x`` followed by ``dag_input["x"]`` returns the
        # node created with the "__getattr__" accessor — confirm this is OK.
        if key not in self.input_attribute_nodes:
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getattr__"
            )
        return self.input_attribute_nodes[key]

    def __getitem__(self, key: Union[int, str]) -> Any:
        """Return the cached InputAttributeNode for index/key-style access.

        When ``input_type`` was given as a dict that contains ``key``, the
        matching type is converted with ``type_to_string`` and attached to a
        newly created node.
        """
        assert isinstance(key, (str, int)), (
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )

        input_type = None
        # ``input_type`` may also be a plain ``type`` (see __init__); a
        # membership test on a type raises TypeError, so only consult it when
        # it is the per-key mapping form.
        if isinstance(self.input_type, dict) and key in self.input_type:
            input_type = type_to_string(self.input_type[key])

        if key not in self.input_attribute_nodes:
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getitem__", input_type
            )
        return self.input_attribute_nodes[key]

    def __enter__(self):
        """Mark this node as used inside a context manager and return it."""
        self.set_context(IN_CONTEXT_MANAGER, True)
        return self

    def __exit__(self, *args):
        """Do nothing on exit; the context flag set by __enter__ is kept."""
        pass

    def get_result_type(self) -> Optional[str]:
        """Get type of the output of this DAGNode.

        Generated by ray.experimental.gradio_utils.type_to_string().

        Returns:
            The stored ``"result_type_string"`` value, or ``None`` when no
            such entry has been recorded for this node.
        """
        if "result_type_string" in self._bound_other_args_to_resolve:
            return self._bound_other_args_to_resolve["result_type_string"]
        return None


@DeveloperAPI
class InputAttributeNode(DAGNode):
    """DAG node standing for one piece of the user input, selected by an
    integer index, an object attribute name or a dict key.

    Examples:

        .. code-block:: python

            with InputNode() as dag_input:
                a = dag_input[0]
                b = dag_input.x
                ray_dag = add.bind(a, b)

            # This makes a = 1 and b = 2
            ray_dag.execute(1, x=2)

            with InputNode() as dag_input:
                a = dag_input[0]
                b = dag_input[1]
                ray_dag = add.bind(a, b)

            # This makes a = 2 and b = 3
            ray_dag.execute(2, 3)

            # Alternatively, you can input a single object
            # and the inputs are automatically indexed from the object:
            # This makes a = 2 and b = 3
            ray_dag.execute([2, 3])
    """

    def __init__(
        self,
        dag_input_node: InputNode,
        key: Union[int, str],
        accessor_method: str,
        input_type: str = None,
    ):
        """Remember which part of the input this node selects.

        Args:
            dag_input_node: The node (or value) the key is applied to.
            key: Index, attribute name or dict key to look up.
            accessor_method: Either ``"__getitem__"`` or ``"__getattr__"``;
                decides how a str key is applied at execution time.
            input_type: Type string stored as ``"result_type_string"``.
        """
        self._dag_input_node = dag_input_node
        self._key = key
        self._accessor_method = accessor_method

        # Everything needed to rebuild this node in _copy_impl is handed to
        # the base class in one mapping.
        resolvable_fields = {
            "dag_input_node": dag_input_node,
            "key": key,
            "accessor_method": accessor_method,
            # Type of the input tied to this node. Used by
            # gradio_visualize_graph.GraphVisualizer to determine which Gradio
            # component should be used for this node.
            "result_type_string": input_type,
        }
        super().__init__([], {}, {}, resolvable_fields)

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Rebuild the node from the four entries of
        ``new_other_args_to_resolve``; the other parameters are unused.
        """
        source_node = new_other_args_to_resolve["dag_input_node"]
        lookup_key = new_other_args_to_resolve["key"]
        accessor = new_other_args_to_resolve["accessor_method"]
        type_string = new_other_args_to_resolve["result_type_string"]
        return InputAttributeNode(source_node, lookup_key, accessor, type_string)

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputAttributeNode.

        ``args`` and ``kwargs`` exist only to match the base class signature
        and are not read here; the lookup uses the values stored on the node.

        Raises:
            ValueError: If the input is not a ``DAGInputData`` and the key is
                neither a str nor an int.
        """
        resolved_input = self._dag_input_node
        lookup_key = self._key

        # Multiple args/kwargs were passed to execute(): the wrapper handles
        # both int and str keys itself.
        if isinstance(resolved_input, DAGInputData):
            return resolved_input[lookup_key]

        # Otherwise a single user object was passed and is accessed directly.
        if isinstance(lookup_key, str):
            accessor = self._accessor_method
            if accessor == "__getitem__":
                return resolved_input[lookup_key]
            if accessor == "__getattr__":
                return getattr(resolved_input, lookup_key)
            # Any other accessor name yields no value.
            return None

        if isinstance(lookup_key, int):
            return resolved_input[lookup_key]

        raise ValueError(
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )

    def __str__(self) -> str:
        key_label = f'["{self._key}"]'
        return get_dag_node_str(self, key_label)

    def get_result_type(self) -> str:
        """Return the stored ``"result_type_string"`` entry of this node.

        The value is produced by
        ray.experimental.gradio_utils.type_to_string(); ``None`` is returned
        when the entry is absent.
        """
        resolved_fields = self._bound_other_args_to_resolve
        if "result_type_string" not in resolved_fields:
            return None
        return resolved_fields["result_type_string"]

    @property
    def key(self) -> Union[int, str]:
        """The index, attribute name or dict key this node selects."""
        return self._key


@DeveloperAPI
class DAGInputData:
    """Bundle of everything passed to dag.execute() when more than a single
    value is given: positional values are looked up by int index, keyword
    values by their str name.
    """

    def __init__(self, *args, **kwargs):
        # Positional inputs are kept as a list, keyword inputs as the dict.
        self._kwargs = kwargs
        self._args = [*args]

    def __getitem__(self, key: Union[int, str]) -> Any:
        """Look ``key`` up in the positional (int) or keyword (str) inputs.

        Raises:
            ValueError: If ``key`` is neither an int nor a str.
        """
        # Choose the container first, then index it once.
        if isinstance(key, int):
            container = self._args
        elif isinstance(key, str):
            container = self._kwargs
        else:
            raise ValueError(
                "Please only use int index or str as first-level key to "
                "access fields of dag input."
            )
        return container[key]