| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import collections |
| import logging |
| import threading |
| from typing import Callable, Optional, Type |
|
|
| import grpc |
| from grpc import _common |
| from grpc._cython import cygrpc |
| from grpc._typing import MetadataType |
|
|
# Module-level logger, named after this module.
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class _AuthMetadataContext(
    collections.namedtuple(
        "AuthMetadataContext", ("service_url", "method_name")
    ),
    grpc.AuthMetadataContext,
):
    """Two-field record passed to an auth metadata plugin.

    Combines a ``(service_url, method_name)`` namedtuple with the
    ``grpc.AuthMetadataContext`` base class; it adds no behaviour of its own.
    """
|
|
|
|
class _CallbackState:
    """Mutable bookkeeping shared by a plugin invocation and its callback.

    Instance fields:
      lock: A ``threading.Lock`` for callers to hold while reading or
        updating the other two fields.
      called: Starts as ``False``.
      exception: Starts as ``None``.
    """

    def __init__(self) -> None:
        # Everything starts in the "nothing has happened yet" state.
        self.exception = None
        self.called = False
        self.lock = threading.Lock()
|
|
|
|
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """Callback object handed to a metadata plugin to report its result.

    Forwards the outcome to the wrapped ``callback`` and consults the shared
    ``_CallbackState`` to reject a second invocation, or an invocation made
    after an exception has been recorded on that state.
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        """Remember the shared state and the downstream callback."""
        self._callback = callback
        self._state = state

    def __call__(
        self, metadata: MetadataType, error: Optional[Type[BaseException]]
    ):
        """Report the plugin's result to the wrapped callback.

        Args:
          metadata: Forwarded with an OK status when ``error`` is ``None``.
          error: ``None`` on success; otherwise ``str(error)`` is encoded and
            forwarded with an INTERNAL status and no metadata.

        Raises:
          RuntimeError: If an exception is already recorded on the shared
            state, or if this callback has already been invoked.
        """
        state = self._state
        with state.lock:
            recorded = state.exception
            if recorded is not None:
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        recorded
                    )
                )
            if state.called:
                raise RuntimeError(
                    "AuthMetadataPluginCallback invoked more than once!"
                )
            state.called = True

        # The downstream callback is invoked after the lock is released.
        if error is not None:
            self._callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(error))
            )
            return
        self._callback(metadata, cygrpc.StatusCode.ok, None)
|
|
|
|
class _Plugin:
    """Adapter that invokes a ``grpc.AuthMetadataPlugin`` through a
    ``(service_url, method_name, callback)`` call signature."""

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        """Wrap ``metadata_plugin`` and snapshot the current context.

        ``_stored_ctx`` ends up holding ``contextvars.copy_context()`` when
        the ``contextvars`` module can be imported, and ``None`` otherwise.
        """
        self._metadata_plugin = metadata_plugin
        self._stored_ctx = None

        # NOTE(review): ``_stored_ctx`` is not read anywhere in the visible
        # part of this file; presumably the code that invokes this object
        # uses it -- confirm against the caller.
        try:
            import contextvars
        except ImportError:
            # No contextvars available; keep ``_stored_ctx`` as None.
            pass
        else:
            self._stored_ctx = contextvars.copy_context()

    def __call__(self, service_url: str, method_name: str, callback: Callable):
        """Run the wrapped plugin once and make sure ``callback`` gets a result.

        Args:
          service_url: Decoded with ``_common.decode`` before being handed to
            the plugin.
          method_name: Decoded with ``_common.decode`` before being handed to
            the plugin.
          callback: Called as ``callback(metadata, status, error_details)``,
            either through the ``_AuthMetadataPluginCallback`` given to the
            plugin, or directly from here with an INTERNAL status when the
            plugin raises before having invoked its callback.
        """
        auth_context = _AuthMetadataContext(
            _common.decode(service_url), _common.decode(method_name)
        )
        state = _CallbackState()
        try:
            plugin_callback = _AuthMetadataPluginCallback(state, callback)
            self._metadata_plugin(auth_context, plugin_callback)
        except Exception as exception:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin,
            )
            # Record the failure under the lock so a late callback invocation
            # is rejected, and note whether a result was already delivered.
            with state.lock:
                state.exception = exception
                already_reported = state.called
            if not already_reported:
                callback(
                    None,
                    cygrpc.StatusCode.internal,
                    _common.encode(str(exception)),
                )
|
|
|
|
def metadata_plugin_call_credentials(
    metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str]
) -> grpc.CallCredentials:
    """Build ``grpc.CallCredentials`` backed by ``metadata_plugin``.

    Args:
      metadata_plugin: The plugin to wrap in a ``_Plugin`` adapter.
      name: Name given to the credentials. When ``None``, the plugin's
        ``__name__`` is used, falling back to its class name if the plugin
        has no ``__name__`` attribute.

    Returns:
      A ``grpc.CallCredentials`` wrapping a
      ``cygrpc.MetadataPluginCallCredentials``.
    """
    if name is not None:
        effective_name = name
    else:
        effective_name = getattr(
            metadata_plugin, "__name__", metadata_plugin.__class__.__name__
        )
    wrapped = cygrpc.MetadataPluginCallCredentials(
        _Plugin(metadata_plugin), _common.encode(effective_name)
    )
    return grpc.CallCredentials(wrapped)
|
|