# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wraps a user-supplied grpc.AuthMetadataPlugin as grpc.CallCredentials."""

import collections
import logging
import threading
from typing import Callable, Optional, Type

import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc._typing import MetadataType

_LOGGER = logging.getLogger(__name__)


class _AuthMetadataContext(
        collections.namedtuple('AuthMetadataContext', (
            'service_url',
            'method_name',
        )), grpc.AuthMetadataContext):
    """Concrete grpc.AuthMetadataContext handed to metadata plugins.

    Attributes:
      service_url: The URL of the service being called, already decoded.
      method_name: The name of the method being called, already decoded.
    """


class _CallbackState(object):
    """State shared between one plugin invocation and its callback."""

    def __init__(self):
        # Guards `called` and `exception` below.
        self.lock = threading.Lock()
        # True once the AuthMetadataPluginCallback has delivered a result.
        self.called = False
        # The exception raised by the plugin invocation, if it raised.
        self.exception: Optional[BaseException] = None


class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """One-shot callback given to a metadata plugin.

    Forwards the plugin's result to the underlying callback at most once,
    and refuses to forward anything if the plugin invocation already raised
    (in which case an error has been reported on its behalf).
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        self._state = state
        self._callback = callback

    def __call__(self, metadata: MetadataType,
                 error: Optional[BaseException]):
        """Report the plugin's result.

        Args:
          metadata: Metadata to attach to the call; ignored when `error` is
            not None.
          error: An exception describing why metadata could not be produced,
            or None on success.

        Raises:
          RuntimeError: If the plugin invocation already raised, or if this
            callback has already been invoked.
        """
        with self._state.lock:
            # The plugin raised before calling us; an error was (or is being)
            # reported in its place, so a second report must not happen.
            if self._state.exception is not None:
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        self._state.exception))
            if self._state.called:
                raise RuntimeError(
                    'AuthMetadataPluginCallback invoked more than once!')
            self._state.called = True
        # Invoke the underlying callback outside the lock.
        if error is None:
            self._callback(metadata, cygrpc.StatusCode.ok, None)
        else:
            self._callback(None, cygrpc.StatusCode.internal,
                           _common.encode(str(error)))


class _Plugin(object):
    """Callable that runs a user metadata plugin for a single RPC.

    Converts the raw arguments it receives into an _AuthMetadataContext and
    ensures `callback` receives at most one result, even if the wrapped
    plugin raises.
    """

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        self._metadata_plugin = metadata_plugin
        # Not read within this module. NOTE(review): presumably consumed by
        # the native layer that dispatches this plugin; confirm before
        # renaming or removing.
        self._stored_ctx = None

        try:
            import contextvars  # pylint: disable=wrong-import-position

            # The plugin may be invoked on a thread created by Core, which will not
            # have the context propagated. This context is stored and installed in
            # the thread invoking the plugin.
            self._stored_ctx = contextvars.copy_context()
        except ImportError:
            # Support versions predating contextvars.
            pass

    def __call__(self, service_url: str, method_name: str, callback: Callable):
        """Invoke the wrapped plugin for one RPC.

        Args:
          service_url: The service URL as supplied by the caller; decoded
            with `_common.decode` (presumably bytes from the native layer).
          method_name: The method name, decoded the same way.
          callback: Receives (metadata, status code, encoded error details or
            None) exactly as forwarded by _AuthMetadataPluginCallback.
        """
        context = _AuthMetadataContext(_common.decode(service_url),
                                       _common.decode(method_name))
        callback_state = _CallbackState()
        try:
            self._metadata_plugin(
                context, _AuthMetadataPluginCallback(callback_state, callback))
        except Exception as exception:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin)
            with callback_state.lock:
                callback_state.exception = exception
                # The plugin already delivered a result before raising; do
                # not deliver a second one.
                if callback_state.called:
                    return
            callback(None, cygrpc.StatusCode.internal,
                     _common.encode(str(exception)))


def metadata_plugin_call_credentials(
        metadata_plugin: grpc.AuthMetadataPlugin,
        name: Optional[str] = None) -> grpc.CallCredentials:
    """Create CallCredentials backed by a metadata plugin.

    Args:
      metadata_plugin: The plugin that supplies per-call metadata.
      name: Name for the credentials. When None, the plugin's `__name__` is
        used, falling back to the name of the plugin's class.

    Returns:
      A grpc.CallCredentials wrapping the plugin.
    """
    if name is not None:
        effective_name = name
    else:
        try:
            effective_name = metadata_plugin.__name__
        except AttributeError:
            effective_name = metadata_plugin.__class__.__name__
    return grpc.CallCredentials(
        cygrpc.MetadataPluginCallCredentials(_Plugin(metadata_plugin),
                                             _common.encode(effective_name)))