File size: 14,719 Bytes
80c179b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
import importlib
import inspect
import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import ray.util.client_connect
from ray._private.ray_constants import (
    RAY_ADDRESS_ENVIRONMENT_VARIABLE,
    RAY_NAMESPACE_ENVIRONMENT_VARIABLE,
    RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE,
)
from ray._private.utils import check_ray_client_dependencies_installed, split_address
from ray._private.worker import BaseContext
from ray._private.worker import init as ray_driver_init
from ray.job_config import JobConfig
from ray.util.annotations import Deprecated, PublicAPI

logger = logging.getLogger(name=__name__)

# Ray Client documentation page, referenced by the `ray.client` deprecation
# warning emitted from `ClientBuilder._client_deprecation_warn`.
CLIENT_DOCS_URL = "".join(
    (
        "https://docs.ray.io/en/latest/cluster/running-applications/",
        "job-submission/ray-client.html",
    )
)


@dataclass
@PublicAPI
class ClientContext(BaseContext):
    """
    Context returned by ``ClientBuilder.connect()``.

    Entering it makes this connection's client context the active one (when
    there is one to swap in). Exiting it disconnects, except for Ray Client
    connections created with ``allow_multiple``, and swaps the previously
    active context back in.

    `protocol_version` is no longer used.
    """

    dashboard_url: Optional[str]
    python_version: str
    ray_version: str
    ray_commit: str
    _num_clients: int
    # Client context exchanged with the active one by `_swap_context`. It is
    # None for connections that have nothing to swap (e.g. those created by
    # `_LocalClientBuilder`).
    _context_to_restore: Optional[ray.util.client.RayAPIStub]
    protocol_version: Optional[str] = None  # Deprecated

    def __enter__(self) -> "ClientContext":
        self._swap_context()
        return self

    def __exit__(self, *exc) -> None:
        # Non-forced: an allow_multiple client connection is left open.
        self._disconnect_with_context(False)
        self._swap_context()

    def disconnect(self) -> None:
        """Disconnect this connection unconditionally."""
        # Activate this connection's context only for the disconnect call,
        # then hand the caller's context back.
        self._swap_context()
        self._disconnect_with_context(True)
        self._swap_context()

    def _swap_context(self):
        """Exchange the stored client context with the currently active one."""
        if self._context_to_restore is None:
            return
        stored_context = self._context_to_restore
        self._context_to_restore = ray.util.client.ray.set_context(stored_context)

    def _disconnect_with_context(self, force_disconnect: bool) -> None:
        """
        Disconnect Ray. If it's a ray client and created with `allow_multiple`,
        it will do nothing. For other cases this either disconnects from the
        remote Client Server or shuts the current driver down.

        Args:
            force_disconnect: Disconnect a Ray Client connection even when it
                is not the default one.
        """
        client_api = ray.util.client.ray
        if client_api.is_connected():
            # Only the default (sole) connection is torn down, unless forced.
            if client_api.is_default() or force_disconnect:
                ray.util.client_connect.disconnect()
            return

        driver_node = ray._private.worker.global_worker.node
        if driver_node is None:
            # Already disconnected.
            return
        if driver_node.is_head():
            logger.debug(
                "The current Ray Cluster is scoped to this process. "
                "Disconnecting is not possible as it will shutdown the "
                "cluster."
            )
            return
        # This is only a driver connected to an existing cluster.
        ray.shutdown()


@Deprecated
class ClientBuilder:
    """
    Builder for a Ray Client connection. This class can be subclassed by
    custom builder classes to modify connection behavior to include additional
    features or altered semantics. One example is the ``_LocalClientBuilder``.
    """

    def __init__(self, address: Optional[str]) -> None:
        """
        Args:
            address: Address passed in via ``ray.client(...)``. For Ray Client
                this is the inner address with the ``ray://`` scheme already
                stripped (see ``_get_builder_from_address``).

        Raises:
            ValueError: If the ``ray[client]`` extra is not installed.
        """
        if not check_ray_client_dependencies_installed():
            raise ValueError(
                "Ray Client requires pip package `ray[client]`. "
                "If you installed the minimal Ray (e.g. `pip install ray`), "
                "please reinstall by executing `pip install ray[client]`."
            )
        self.address = address
        self._job_config = JobConfig()
        # Extra kwargs forwarded to ray.init() on the server (see _init_args).
        self._remote_init_kwargs = {}
        # Whether to allow connections to multiple clusters
        # (allow_multiple=True).
        self._allow_multiple_connections = False
        self._credentials = None
        self._metadata = None
        # Set to False if ClientBuilder is being constructed by internal
        # methods
        self._deprecation_warn_enabled = True

    def env(self, env: Dict[str, Any]) -> "ClientBuilder":
        """
        Set an environment for the session.

        Args:
            env: A runtime environment to use for this connection. See
                :ref:`runtime-environments` for what values are accepted in
                this dict.

        Returns:
            This builder, so calls can be chained.
        """
        self._job_config.set_runtime_env(env)
        return self

    def namespace(self, namespace: str) -> "ClientBuilder":
        """
        Sets the namespace for the session.

        Args:
            namespace: Namespace to use.

        Returns:
            This builder, so calls can be chained.
        """
        self._job_config.set_ray_namespace(namespace)
        return self

    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...).

        Returns:
            ClientContext: Dataclass with information about the setting. This
                includes the server's version of Python & Ray as well as the
                dashboard_url.

        Raises:
            ValueError: If a connection was already made with
                ``allow_multiple=True`` and this builder does not set
                ``allow_multiple``.
        """
        if self._deprecation_warn_enabled:
            self._client_deprecation_warn()
        # Fill runtime env/namespace from environment if not already set.
        # Should be done *after* the deprecation warning, since warning will
        # check if those values are already set.
        self._fill_defaults_from_env()

        # If it has already connected to the cluster with allow_multiple=True,
        # connect to the default one is not allowed.
        # But if it has connected to the default one, connect to other clients
        # with allow_multiple=True is allowed
        default_cli_connected = ray.util.client.ray.is_connected()
        has_cli_connected = ray.util.client.num_connected_contexts() > 0
        if (
            not self._allow_multiple_connections
            and not default_cli_connected
            and has_cli_connected
        ):
            raise ValueError(
                "The client has already connected to the cluster "
                "with allow_multiple=True. Please set allow_multiple=True"
                " to proceed"
            )

        old_ray_cxt = None
        if self._allow_multiple_connections:
            # Detach the currently active context (set_context returns it) so
            # the new connection does not replace it; restored below.
            old_ray_cxt = ray.util.client.ray.set_context(None)

        try:
            client_info_dict = ray.util.client_connect.connect(
                self.address,
                job_config=self._job_config,
                _credentials=self._credentials,
                ray_init_kwargs=self._remote_init_kwargs,
                metadata=self._metadata,
            )

            dashboard_url = ray.util.client.ray._get_dashboard_url()

            cxt = ClientContext(
                dashboard_url=dashboard_url,
                python_version=client_info_dict["python_version"],
                ray_version=client_info_dict["ray_version"],
                ray_commit=client_info_dict["ray_commit"],
                _num_clients=client_info_dict["num_clients"],
                _context_to_restore=ray.util.client.ray.get_context(),
            )
        finally:
            # Restore the previously active context even when connecting
            # fails, so a failed allow_multiple connection cannot leave the
            # caller's context detached.
            if self._allow_multiple_connections:
                ray.util.client.ray.set_context(old_ray_cxt)
        return cxt

    def _fill_defaults_from_env(self):
        """
        Fill namespace and runtime env from environment variables, but only
        for values that have not been set explicitly on this builder.
        """
        namespace_env_var = os.environ.get(RAY_NAMESPACE_ENVIRONMENT_VARIABLE)
        if namespace_env_var and self._job_config.ray_namespace is None:
            self.namespace(namespace_env_var)

        runtime_env_var = os.environ.get(RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE)
        if runtime_env_var and self._job_config.runtime_env is None:
            # The runtime env variable holds a JSON-encoded dict.
            self.env(json.loads(runtime_env_var))

    def _init_args(self, **kwargs) -> "ClientBuilder":
        """
        When a client builder is constructed through ray.init, for example
        `ray.init(ray://..., namespace=...)`, all of the
        arguments passed into ray.init with non-default values are passed
        again into this method. Custom client builders can override this method
        to do their own handling/validation of arguments.

        Returns:
            This builder.

        Raises:
            RuntimeError: If a remaining kwarg is not a parameter of
                ``ray.init``.
        """
        # Use namespace and runtime_env from ray.init call
        if kwargs.get("namespace") is not None:
            self.namespace(kwargs.pop("namespace"))
        if kwargs.get("runtime_env") is not None:
            self.env(kwargs.pop("runtime_env"))

        # allow_multiple is a client-side option: always consume it so that an
        # explicit allow_multiple=False is neither forwarded to the server nor
        # subjected to the signature check below.
        if kwargs.pop("allow_multiple", False):
            self._allow_multiple_connections = True

        if "_credentials" in kwargs:
            self._credentials = kwargs.pop("_credentials")

        if "_metadata" in kwargs:
            self._metadata = kwargs.pop("_metadata")

        if kwargs:
            expected_sig = inspect.signature(ray_driver_init)
            extra_args = set(kwargs).difference(expected_sig.parameters)
            if extra_args:
                # Sorted so the error message is deterministic.
                raise RuntimeError(
                    "Got unexpected kwargs: {}".format(", ".join(sorted(extra_args)))
                )
            self._remote_init_kwargs = kwargs
            unknown = ", ".join(kwargs)
            logger.info(
                "Passing the following kwargs to ray.init() "
                f"on the server: {unknown}"
            )
        return self

    def _client_deprecation_warn(self) -> None:
        """
        Generates a warning for users if this ClientBuilder instance was
        created directly or through ray.client, instead of relying on
        internal methods (ray.init, or auto init)
        """
        namespace = self._job_config.ray_namespace
        runtime_env = self._job_config.runtime_env
        replacement_args = []
        if self.address:
            if isinstance(self, _LocalClientBuilder):
                # Address might be set for LocalClientBuilder if ray.client()
                # is called while ray_current_cluster is set
                # (see _get_builder_from_address). In this case,
                # leave off the ray:// so the user attaches the driver directly
                replacement_args.append(f'"{self.address}"')
            else:
                replacement_args.append(f'"ray://{self.address}"')
        if namespace:
            replacement_args.append(f'namespace="{namespace}"')
        if runtime_env:
            # Use a placeholder here, since the real runtime_env would be
            # difficult to read if formatted in directly
            replacement_args.append("runtime_env=<your_runtime_env>")
        args_str = ", ".join(replacement_args)
        replacement_call = f"ray.init({args_str})"

        # Note: stack level is set to 3 since we want the warning to reach the
        # call to ray.client(...).connect(). The intervening frames are
        # connect() -> client_deprecation_warn() -> warnings.warn()
        # https://docs.python.org/3/library/warnings.html#available-functions
        warnings.warn(
            "Starting a connection through `ray.client` will be deprecated "
            "in future ray versions in favor of `ray.init`. See the docs for "
            f"more details: {CLIENT_DOCS_URL}. You can replace your call to "
            "`ray.client().connect()` with the following:\n"
            f"      {replacement_call}\n",
            DeprecationWarning,
            stacklevel=3,
        )


class _LocalClientBuilder(ClientBuilder):
    """
    Builder that attaches this process as a driver with ``ray.init`` instead
    of connecting through a Ray Client server.
    """

    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...).

        Returns:
            ClientContext describing the local driver connection. Its
            ``_context_to_restore`` is None, so entering/exiting it does not
            swap client contexts.
        """
        if self._deprecation_warn_enabled:
            self._client_deprecation_warn()
        # Environment defaults are applied only after warning, because the
        # warning inspects which values were set explicitly.
        self._fill_defaults_from_env()

        init_info = ray.init(address=self.address, job_config=self._job_config)
        major, minor, micro = sys.version_info[:3]
        return ClientContext(
            dashboard_url=init_info["webui_url"],
            python_version=f"{major}.{minor}.{micro}",
            ray_version=ray.__version__,
            ray_commit=ray.__commit__,
            _num_clients=1,
            _context_to_restore=None,
        )


def _split_address(address: str) -> Tuple[str, str]:
    """
    Split ``address`` into ``(module_string, inner_address)``.

    An address without an explicit scheme is treated as a Ray Client address:
    ``"ray://"`` is prepended before splitting.
    """
    has_scheme = "://" in address
    return split_address(address if has_scheme else f"ray://{address}")


def _get_builder_from_address(address: Optional[str]) -> ClientBuilder:
    """
    Choose and construct the ClientBuilder for ``address``.

    * ``"local"``: ``_LocalClientBuilder("local")``.
    * ``None``: a ``_LocalClientBuilder`` for the address returned by
      ``ray._private.services.canonicalize_bootstrap_address(None)``.
    * anything else: split into ``module://inner_address`` (``ray://`` is
      assumed when no scheme is present), import ``module`` and return
      ``module.ClientBuilder(inner_address)``.

    Raises:
        RuntimeError: If the module named by the scheme cannot be imported.
        AssertionError: If that module does not define ``ClientBuilder``.
    """
    if address == "local":
        return _LocalClientBuilder("local")
    if address is None:
        # NOTE: This is not placed in `Node::get_temp_dir_path`, because
        # this file is accessed before the `Node` object is created.
        # NOTE(review): presumably auto-detects the address of an already
        # running cluster when given None — confirm against
        # canonicalize_bootstrap_address.
        address = ray._private.services.canonicalize_bootstrap_address(address)
        return _LocalClientBuilder(address)
    module_string, inner_address = _split_address(address)
    try:
        module = importlib.import_module(module_string)
    except Exception as e:
        raise RuntimeError(
            f"Module: {module_string} does not exist.\n"
            f"This module was parsed from Address: {address}"
        ) from e
    # Explicit check rather than `assert`, so validation still happens under
    # `python -O`; AssertionError is kept for compatibility with callers.
    if not hasattr(module, "ClientBuilder"):
        raise AssertionError(
            f"Module: {module_string} does not have ClientBuilder."
        )
    return module.ClientBuilder(inner_address)


@Deprecated
def client(
    address: Optional[str] = None, _deprecation_warn_enabled: bool = True
) -> ClientBuilder:
    """
    Creates a ClientBuilder based on the provided address. The address can be
    of the following forms:

        * None: Connects to or creates a local cluster and connects to it.
        * ``"local"``: Creates a new cluster locally and connects to it.
        * ``"IP:Port"``: Connects to a Ray Client Server at the given address.
        * ``"module://inner_address"``: load module.ClientBuilder & pass
            inner_address

    When ``address`` is None and the environment variable named by
    ``RAY_ADDRESS_ENVIRONMENT_VARIABLE`` is set to a non-empty value, that
    value is used as the address instead of auto-detection.

    The _deprecation_warn_enabled flag enables deprecation warnings, and is
    for internal use only. Set it to False to suppress client deprecation
    warnings.
    """
    if address is None:
        env_address = os.environ.get(RAY_ADDRESS_ENVIRONMENT_VARIABLE)
        if env_address:
            logger.debug(
                f"Using address ({env_address}) instead of auto-detection "
                f"because {RAY_ADDRESS_ENVIRONMENT_VARIABLE} is set."
            )
            address = env_address

    builder = _get_builder_from_address(address)
    # Internal callers (e.g. ray.init) pass False to silence the warning.
    builder._deprecation_warn_enabled = _deprecation_warn_enabled
    return builder