# Copyright 2016 gRPC authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Service-side implementation of gRPC Python.""" from __future__ import annotations import abc import collections from concurrent import futures import contextvars import enum import logging import threading import time import traceback from typing import ( Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Union, ) import grpc # pytype: disable=pyi-error from grpc import _common # pytype: disable=pyi-error from grpc import _compression # pytype: disable=pyi-error from grpc import _interceptor # pytype: disable=pyi-error from grpc._cython import cygrpc from grpc._typing import ArityAgnosticMethodHandler from grpc._typing import ChannelArgumentType from grpc._typing import DeserializingFunction from grpc._typing import MetadataType from grpc._typing import NullaryCallbackType from grpc._typing import ResponseType from grpc._typing import SerializingFunction from grpc._typing import ServerCallbackTag from grpc._typing import ServerTagCallbackType _LOGGER = logging.getLogger(__name__) _SHUTDOWN_TAG = "shutdown" _REQUEST_CALL_TAG = "request_call" _RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server" _SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata" _RECEIVE_MESSAGE_TOKEN = "receive_message" _SEND_MESSAGE_TOKEN = "send_message" _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = ( "send_initial_metadata * send_message" ) _SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server" _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = ( "send_initial_metadata * send_status_from_server" ) _OPEN = "open" _CLOSED = "closed" _CANCELLED = "cancelled" _EMPTY_FLAGS = 0 _DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0 _INF_TIMEOUT = 1e9 def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes: return request_event.batch_operations[0].message() def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode: cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code) return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code def _completion_code(state: _RPCState) -> cygrpc.StatusCode: if state.code is None: return cygrpc.StatusCode.ok else: return _application_code(state.code) def _abortion_code( state: _RPCState, code: cygrpc.StatusCode ) -> cygrpc.StatusCode: if state.code is None: return code else: return _application_code(state.code) def _details(state: _RPCState) -> bytes: return b"" if state.details is None else state.details class _HandlerCallDetails( collections.namedtuple( "_HandlerCallDetails", ( "method", "invocation_metadata", ), ), grpc.HandlerCallDetails, ): pass class _Method(abc.ABC): @abc.abstractmethod def name(self) -> Optional[str]: raise NotImplementedError() @abc.abstractmethod def handler( self, handler_call_details: _HandlerCallDetails ) -> Optional[grpc.RpcMethodHandler]: raise NotImplementedError() class _RegisteredMethod(_Method): def __init__( self, name: str, registered_handler: Optional[grpc.RpcMethodHandler], ): self._name = name self._registered_handler = registered_handler def name(self) -> Optional[str]: return self._name def handler( self, handler_call_details: _HandlerCallDetails ) -> Optional[grpc.RpcMethodHandler]: return self._registered_handler class _GenericMethod(_Method): def __init__( self, generic_handlers: List[grpc.GenericRpcHandler], ): self._generic_handlers = generic_handlers def name(self) -> Optional[str]: return None def handler( self, handler_call_details: _HandlerCallDetails ) -> Optional[grpc.RpcMethodHandler]: # If the same method have both generic and registered handler, # registered handler will take precedence. for generic_handler in self._generic_handlers: method_handler = generic_handler.service(handler_call_details) if method_handler is not None: return method_handler return None class _RPCState(object): context: contextvars.Context condition: threading.Condition due = Set[str] request: Any client: str initial_metadata_allowed: bool compression_algorithm: Optional[grpc.Compression] disable_next_compression: bool trailing_metadata: Optional[MetadataType] code: Optional[grpc.StatusCode] details: Optional[bytes] statused: bool rpc_errors: List[Exception] callbacks: Optional[List[NullaryCallbackType]] aborted: bool def __init__(self): self.context = contextvars.Context() self.condition = threading.Condition() self.due = set() self.request = None self.client = _OPEN self.initial_metadata_allowed = True self.compression_algorithm = None self.disable_next_compression = False self.trailing_metadata = None self.code = None self.details = None self.statused = False self.rpc_errors = [] self.callbacks = [] self.aborted = False def _raise_rpc_error(state: _RPCState) -> None: rpc_error = grpc.RpcError() state.rpc_errors.append(rpc_error) raise rpc_error def _possibly_finish_call( state: _RPCState, token: str ) -> ServerTagCallbackType: state.due.remove(token) if not _is_rpc_state_active(state) and not state.due: callbacks = state.callbacks state.callbacks = None return state, callbacks else: return None, () def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag: def send_status_from_server(unused_send_status_from_server_event): with state.condition: return _possibly_finish_call(state, token) return send_status_from_server def _get_initial_metadata( state: _RPCState, metadata: Optional[MetadataType] ) -> Optional[MetadataType]: with state.condition: if state.compression_algorithm: compression_metadata = ( _compression.compression_algorithm_to_metadata( state.compression_algorithm ), ) if metadata is None: return compression_metadata else: return compression_metadata + tuple(metadata) else: return metadata def _get_initial_metadata_operation( state: _RPCState, metadata: Optional[MetadataType] ) -> cygrpc.Operation: operation = cygrpc.SendInitialMetadataOperation( _get_initial_metadata(state, metadata), _EMPTY_FLAGS ) return operation def _abort( state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes ) -> None: if state.client is not _CANCELLED: effective_code = _abortion_code(state, code) effective_details = details if state.details is None else state.details if state.initial_metadata_allowed: operations = ( _get_initial_metadata_operation(state, None), cygrpc.SendStatusFromServerOperation( state.trailing_metadata, effective_code, effective_details, _EMPTY_FLAGS, ), ) token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN else: operations = ( cygrpc.SendStatusFromServerOperation( state.trailing_metadata, effective_code, effective_details, _EMPTY_FLAGS, ), ) token = _SEND_STATUS_FROM_SERVER_TOKEN call.start_server_batch( operations, _send_status_from_server(state, token) ) state.statused = True state.due.add(token) def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag: def receive_close_on_server(receive_close_on_server_event): with state.condition: if receive_close_on_server_event.batch_operations[0].cancelled(): state.client = _CANCELLED elif state.client is _OPEN: state.client = _CLOSED state.condition.notify_all() return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN) return receive_close_on_server def _receive_message( state: _RPCState, call: cygrpc.Call, request_deserializer: Optional[DeserializingFunction], ) -> ServerCallbackTag: def receive_message(receive_message_event): serialized_request = _serialized_request(receive_message_event) if serialized_request is None: with state.condition: if state.client is _OPEN: state.client = _CLOSED state.condition.notify_all() return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) else: request = _common.deserialize( serialized_request, request_deserializer ) with state.condition: if request is None: _abort( state, call, cygrpc.StatusCode.internal, b"Exception deserializing request!", ) else: state.request = request state.condition.notify_all() return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) return receive_message def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag: def send_initial_metadata(unused_send_initial_metadata_event): with state.condition: return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN) return send_initial_metadata def _send_message(state: _RPCState, token: str) -> ServerCallbackTag: def send_message(unused_send_message_event): with state.condition: state.condition.notify_all() return _possibly_finish_call(state, token) return send_message class _Context(grpc.ServicerContext): _rpc_event: cygrpc.BaseEvent _state: _RPCState request_deserializer: Optional[DeserializingFunction] def __init__( self, rpc_event: cygrpc.BaseEvent, state: _RPCState, request_deserializer: Optional[DeserializingFunction], ): self._rpc_event = rpc_event self._state = state self._request_deserializer = request_deserializer def is_active(self) -> bool: with self._state.condition: return _is_rpc_state_active(self._state) def time_remaining(self) -> float: return max(self._rpc_event.call_details.deadline - time.time(), 0) def cancel(self) -> None: self._rpc_event.call.cancel() def add_callback(self, callback: NullaryCallbackType) -> bool: with self._state.condition: if self._state.callbacks is None: return False else: self._state.callbacks.append(callback) return True def disable_next_message_compression(self) -> None: with self._state.condition: self._state.disable_next_compression = True def invocation_metadata(self) -> Optional[MetadataType]: return self._rpc_event.invocation_metadata def peer(self) -> str: return _common.decode(self._rpc_event.call.peer()) def peer_identities(self) -> Optional[Sequence[bytes]]: return cygrpc.peer_identities(self._rpc_event.call) def peer_identity_key(self) -> Optional[str]: id_key = cygrpc.peer_identity_key(self._rpc_event.call) return id_key if id_key is None else _common.decode(id_key) def auth_context(self) -> Mapping[str, Sequence[bytes]]: auth_context = cygrpc.auth_context(self._rpc_event.call) auth_context_dict = {} if auth_context is None else auth_context return { _common.decode(key): value for key, value in auth_context_dict.items() } def set_compression(self, compression: grpc.Compression) -> None: with self._state.condition: self._state.compression_algorithm = compression def send_initial_metadata(self, initial_metadata: MetadataType) -> None: with self._state.condition: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) else: if self._state.initial_metadata_allowed: operation = _get_initial_metadata_operation( self._state, initial_metadata ) self._rpc_event.call.start_server_batch( (operation,), _send_initial_metadata(self._state) ) self._state.initial_metadata_allowed = False self._state.due.add(_SEND_INITIAL_METADATA_TOKEN) else: raise ValueError("Initial metadata no longer allowed!") def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None: with self._state.condition: self._state.trailing_metadata = trailing_metadata def trailing_metadata(self) -> Optional[MetadataType]: return self._state.trailing_metadata def abort(self, code: grpc.StatusCode, details: str) -> None: # treat OK like other invalid arguments: fail the RPC if code == grpc.StatusCode.OK: _LOGGER.error( "abort() called with StatusCode.OK; returning UNKNOWN" ) code = grpc.StatusCode.UNKNOWN details = "" with self._state.condition: self._state.code = code self._state.details = _common.encode(details) self._state.aborted = True raise Exception() def abort_with_status(self, status: grpc.Status) -> None: self._state.trailing_metadata = status.trailing_metadata self.abort(status.code, status.details) def set_code(self, code: grpc.StatusCode) -> None: with self._state.condition: self._state.code = code def code(self) -> grpc.StatusCode: return self._state.code def set_details(self, details: str) -> None: with self._state.condition: self._state.details = _common.encode(details) def details(self) -> bytes: return self._state.details def _finalize_state(self) -> None: pass class _RequestIterator(object): _state: _RPCState _call: cygrpc.Call _request_deserializer: Optional[DeserializingFunction] def __init__( self, state: _RPCState, call: cygrpc.Call, request_deserializer: Optional[DeserializingFunction], ): self._state = state self._call = call self._request_deserializer = request_deserializer def _raise_or_start_receive_message(self) -> None: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) elif not _is_rpc_state_active(self._state): raise StopIteration() else: self._call.start_server_batch( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), _receive_message( self._state, self._call, self._request_deserializer ), ) self._state.due.add(_RECEIVE_MESSAGE_TOKEN) def _look_for_request(self) -> Any: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) elif ( self._state.request is None and _RECEIVE_MESSAGE_TOKEN not in self._state.due ): raise StopIteration() else: request = self._state.request self._state.request = None return request raise AssertionError() # should never run def _next(self) -> Any: with self._state.condition: self._raise_or_start_receive_message() while True: self._state.condition.wait() request = self._look_for_request() if request is not None: return request def __iter__(self) -> _RequestIterator: return self def __next__(self) -> Any: return self._next() def next(self) -> Any: return self._next() def _unary_request( rpc_event: cygrpc.BaseEvent, state: _RPCState, request_deserializer: Optional[DeserializingFunction], ) -> Callable[[], Any]: def unary_request(): with state.condition: if not _is_rpc_state_active(state): return None else: rpc_event.call.start_server_batch( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), _receive_message( state, rpc_event.call, request_deserializer ), ) state.due.add(_RECEIVE_MESSAGE_TOKEN) while True: state.condition.wait() if state.request is None: if state.client is _CLOSED: details = '"{}" requires exactly one request message.'.format( rpc_event.call_details.method ) _abort( state, rpc_event.call, cygrpc.StatusCode.unimplemented, _common.encode(details), ) return None elif state.client is _CANCELLED: return None else: request = state.request state.request = None return request return unary_request def _call_behavior( rpc_event: cygrpc.BaseEvent, state: _RPCState, behavior: ArityAgnosticMethodHandler, argument: Any, request_deserializer: Optional[DeserializingFunction], send_response_callback: Optional[Callable[[ResponseType], None]] = None, ) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]: from grpc import _create_servicer_context # pytype: disable=pyi-error with _create_servicer_context( rpc_event, state, request_deserializer ) as context: try: response_or_iterator = None if send_response_callback is not None: response_or_iterator = behavior( argument, context, send_response_callback ) else: response_or_iterator = behavior(argument, context) return response_or_iterator, True except Exception as exception: # pylint: disable=broad-except with state.condition: if state.aborted: _abort( state, rpc_event.call, cygrpc.StatusCode.unknown, b"RPC Aborted", ) elif exception not in state.rpc_errors: try: details = "Exception calling application: {}".format( exception ) except Exception: # pylint: disable=broad-except details = ( "Calling application raised unprintable Exception!" ) _LOGGER.exception( traceback.format_exception( type(exception), exception, exception.__traceback__, ) ) traceback.print_exc() _LOGGER.exception(details) _abort( state, rpc_event.call, cygrpc.StatusCode.unknown, _common.encode(details), ) return None, False def _take_response_from_response_iterator( rpc_event: cygrpc.BaseEvent, state: _RPCState, response_iterator: Iterator[ResponseType], ) -> Tuple[ResponseType, bool]: try: return next(response_iterator), True except StopIteration: return None, True except Exception as exception: # pylint: disable=broad-except with state.condition: if state.aborted: _abort( state, rpc_event.call, cygrpc.StatusCode.unknown, b"RPC Aborted", ) elif exception not in state.rpc_errors: details = "Exception iterating responses: {}".format(exception) _LOGGER.exception(details) _abort( state, rpc_event.call, cygrpc.StatusCode.unknown, _common.encode(details), ) return None, False def _serialize_response( rpc_event: cygrpc.BaseEvent, state: _RPCState, response: Any, response_serializer: Optional[SerializingFunction], ) -> Optional[bytes]: serialized_response = _common.serialize(response, response_serializer) if serialized_response is None: with state.condition: _abort( state, rpc_event.call, cygrpc.StatusCode.internal, b"Failed to serialize response!", ) return None else: return serialized_response def _get_send_message_op_flags_from_state( state: _RPCState, ) -> Union[int, cygrpc.WriteFlag]: if state.disable_next_compression: return cygrpc.WriteFlag.no_compress else: return _EMPTY_FLAGS def _reset_per_message_state(state: _RPCState) -> None: with state.condition: state.disable_next_compression = False def _send_response( rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes ) -> bool: with state.condition: if not _is_rpc_state_active(state): return False else: if state.initial_metadata_allowed: operations = ( _get_initial_metadata_operation(state, None), cygrpc.SendMessageOperation( serialized_response, _get_send_message_op_flags_from_state(state), ), ) state.initial_metadata_allowed = False token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN else: operations = ( cygrpc.SendMessageOperation( serialized_response, _get_send_message_op_flags_from_state(state), ), ) token = _SEND_MESSAGE_TOKEN rpc_event.call.start_server_batch( operations, _send_message(state, token) ) state.due.add(token) _reset_per_message_state(state) while True: state.condition.wait() if token not in state.due: return _is_rpc_state_active(state) def _status( rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: Optional[bytes], ) -> None: with state.condition: if state.client is not _CANCELLED: code = _completion_code(state) details = _details(state) operations = [ cygrpc.SendStatusFromServerOperation( state.trailing_metadata, code, details, _EMPTY_FLAGS ), ] if state.initial_metadata_allowed: operations.append(_get_initial_metadata_operation(state, None)) if serialized_response is not None: operations.append( cygrpc.SendMessageOperation( serialized_response, _get_send_message_op_flags_from_state(state), ) ) rpc_event.call.start_server_batch( operations, _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN), ) state.statused = True _reset_per_message_state(state) state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) def _unary_response_in_pool( rpc_event: cygrpc.BaseEvent, state: _RPCState, behavior: ArityAgnosticMethodHandler, argument_thunk: Callable[[], Any], request_deserializer: Optional[SerializingFunction], response_serializer: Optional[SerializingFunction], ) -> None: cygrpc.install_context_from_request_call_event(rpc_event) try: argument = argument_thunk() if argument is not None: response, proceed = _call_behavior( rpc_event, state, behavior, argument, request_deserializer ) if proceed: serialized_response = _serialize_response( rpc_event, state, response, response_serializer ) if serialized_response is not None: _status(rpc_event, state, serialized_response) except Exception: # pylint: disable=broad-except traceback.print_exc() finally: cygrpc.uninstall_context() def _stream_response_in_pool( rpc_event: cygrpc.BaseEvent, state: _RPCState, behavior: ArityAgnosticMethodHandler, argument_thunk: Callable[[], Any], request_deserializer: Optional[DeserializingFunction], response_serializer: Optional[SerializingFunction], ) -> None: cygrpc.install_context_from_request_call_event(rpc_event) def send_response(response: Any) -> None: if response is None: _status(rpc_event, state, None) else: serialized_response = _serialize_response( rpc_event, state, response, response_serializer ) if serialized_response is not None: _send_response(rpc_event, state, serialized_response) try: argument = argument_thunk() if argument is not None: if ( hasattr(behavior, "experimental_non_blocking") and behavior.experimental_non_blocking ): _call_behavior( rpc_event, state, behavior, argument, request_deserializer, send_response_callback=send_response, ) else: response_iterator, proceed = _call_behavior( rpc_event, state, behavior, argument, request_deserializer ) if proceed: _send_message_callback_to_blocking_iterator_adapter( rpc_event, state, send_response, response_iterator ) except Exception: # pylint: disable=broad-except traceback.print_exc() finally: cygrpc.uninstall_context() def _is_rpc_state_active(state: _RPCState) -> bool: return state.client is not _CANCELLED and not state.statused def _send_message_callback_to_blocking_iterator_adapter( rpc_event: cygrpc.BaseEvent, state: _RPCState, send_response_callback: Callable[[ResponseType], None], response_iterator: Iterator[ResponseType], ) -> None: while True: response, proceed = _take_response_from_response_iterator( rpc_event, state, response_iterator ) if proceed: send_response_callback(response) if not _is_rpc_state_active(state): break else: break def _select_thread_pool_for_behavior( behavior: ArityAgnosticMethodHandler, default_thread_pool: futures.ThreadPoolExecutor, ) -> futures.ThreadPoolExecutor: if hasattr(behavior, "experimental_thread_pool") and isinstance( behavior.experimental_thread_pool, futures.ThreadPoolExecutor ): return behavior.experimental_thread_pool else: return default_thread_pool def _handle_unary_unary( rpc_event: cygrpc.BaseEvent, state: _RPCState, method_handler: grpc.RpcMethodHandler, default_thread_pool: futures.ThreadPoolExecutor, ) -> futures.Future: unary_request = _unary_request( rpc_event, state, method_handler.request_deserializer ) thread_pool = _select_thread_pool_for_behavior( method_handler.unary_unary, default_thread_pool ) return thread_pool.submit( state.context.run, _unary_response_in_pool, rpc_event, state, method_handler.unary_unary, unary_request, method_handler.request_deserializer, method_handler.response_serializer, ) def _handle_unary_stream( rpc_event: cygrpc.BaseEvent, state: _RPCState, method_handler: grpc.RpcMethodHandler, default_thread_pool: futures.ThreadPoolExecutor, ) -> futures.Future: unary_request = _unary_request( rpc_event, state, method_handler.request_deserializer ) thread_pool = _select_thread_pool_for_behavior( method_handler.unary_stream, default_thread_pool ) return thread_pool.submit( state.context.run, _stream_response_in_pool, rpc_event, state, method_handler.unary_stream, unary_request, method_handler.request_deserializer, method_handler.response_serializer, ) def _handle_stream_unary( rpc_event: cygrpc.BaseEvent, state: _RPCState, method_handler: grpc.RpcMethodHandler, default_thread_pool: futures.ThreadPoolExecutor, ) -> futures.Future: request_iterator = _RequestIterator( state, rpc_event.call, method_handler.request_deserializer ) thread_pool = _select_thread_pool_for_behavior( method_handler.stream_unary, default_thread_pool ) return thread_pool.submit( state.context.run, _unary_response_in_pool, rpc_event, state, method_handler.stream_unary, lambda: request_iterator, method_handler.request_deserializer, method_handler.response_serializer, ) def _handle_stream_stream( rpc_event: cygrpc.BaseEvent, state: _RPCState, method_handler: grpc.RpcMethodHandler, default_thread_pool: futures.ThreadPoolExecutor, ) -> futures.Future: request_iterator = _RequestIterator( state, rpc_event.call, method_handler.request_deserializer ) thread_pool = _select_thread_pool_for_behavior( method_handler.stream_stream, default_thread_pool ) return thread_pool.submit( state.context.run, _stream_response_in_pool, rpc_event, state, method_handler.stream_stream, lambda: request_iterator, method_handler.request_deserializer, method_handler.response_serializer, ) def _find_method_handler( rpc_event: cygrpc.BaseEvent, state: _RPCState, method_with_handler: _Method, interceptor_pipeline: Optional[_interceptor._ServicePipeline], ) -> Optional[grpc.RpcMethodHandler]: def query_handlers( handler_call_details: _HandlerCallDetails, ) -> Optional[grpc.RpcMethodHandler]: return method_with_handler.handler(handler_call_details) method_name = method_with_handler.name() if not method_name: method_name = _common.decode(rpc_event.call_details.method) handler_call_details = _HandlerCallDetails( method_name, rpc_event.invocation_metadata, ) if interceptor_pipeline is not None: return state.context.run( interceptor_pipeline.execute, query_handlers, handler_call_details ) else: return state.context.run(query_handlers, handler_call_details) def _reject_rpc( rpc_event: cygrpc.BaseEvent, rpc_state: _RPCState, status: cygrpc.StatusCode, details: bytes, ): operations = ( _get_initial_metadata_operation(rpc_state, None), cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), cygrpc.SendStatusFromServerOperation( None, status, details, _EMPTY_FLAGS ), ) rpc_event.call.start_server_batch( operations, lambda ignored_event: ( rpc_state, (), ), ) def _handle_with_method_handler( rpc_event: cygrpc.BaseEvent, state: _RPCState, method_handler: grpc.RpcMethodHandler, thread_pool: futures.ThreadPoolExecutor, ) -> futures.Future: with state.condition: rpc_event.call.start_server_batch( (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),), _receive_close_on_server(state), ) state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN) if method_handler.request_streaming: if method_handler.response_streaming: return _handle_stream_stream( rpc_event, state, method_handler, thread_pool ) else: return _handle_stream_unary( rpc_event, state, method_handler, thread_pool ) else: if method_handler.response_streaming: return _handle_unary_stream( rpc_event, state, method_handler, thread_pool ) else: return _handle_unary_unary( rpc_event, state, method_handler, thread_pool ) def _handle_call( rpc_event: cygrpc.BaseEvent, method_with_handler: _Method, interceptor_pipeline: Optional[_interceptor._ServicePipeline], thread_pool: futures.ThreadPoolExecutor, concurrency_exceeded: bool, ) -> Tuple[Optional[_RPCState], Optional[futures.Future]]: """Handles RPC based on provided handlers. When receiving a call event from Core, registered method will have it's name as tag, we pass the tag as registered_method_name to this method, then we can find the handler in registered_method_handlers based on the method name. For call event with unregistered method, the method name will be included in rpc_event.call_details.method and we need to query the generics handlers to find the actual handler. """ if not rpc_event.success: return None, None if rpc_event.call_details.method or method_with_handler.name(): rpc_state = _RPCState() try: method_handler = _find_method_handler( rpc_event, rpc_state, method_with_handler, interceptor_pipeline, ) except Exception as exception: # pylint: disable=broad-except details = "Exception servicing handler: {}".format(exception) _LOGGER.exception(details) _reject_rpc( rpc_event, rpc_state, cygrpc.StatusCode.unknown, b"Error in service handler!", ) return rpc_state, None if method_handler is None: _reject_rpc( rpc_event, rpc_state, cygrpc.StatusCode.unimplemented, b"Method not found!", ) return rpc_state, None elif concurrency_exceeded: _reject_rpc( rpc_event, rpc_state, cygrpc.StatusCode.resource_exhausted, b"Concurrent RPC limit exceeded!", ) return rpc_state, None else: return ( rpc_state, _handle_with_method_handler( rpc_event, rpc_state, method_handler, thread_pool ), ) else: return None, None @enum.unique class _ServerStage(enum.Enum): STOPPED = "stopped" STARTED = "started" GRACE = "grace" class _ServerState(object): lock: threading.RLock completion_queue: cygrpc.CompletionQueue server: cygrpc.Server generic_handlers: List[grpc.GenericRpcHandler] registered_method_handlers: Dict[str, grpc.RpcMethodHandler] interceptor_pipeline: Optional[_interceptor._ServicePipeline] thread_pool: futures.ThreadPoolExecutor stage: _ServerStage termination_event: threading.Event shutdown_events: List[threading.Event] maximum_concurrent_rpcs: Optional[int] active_rpc_count: int rpc_states: Set[_RPCState] due: Set[str] server_deallocated: bool # pylint: disable=too-many-arguments def __init__( self, completion_queue: cygrpc.CompletionQueue, server: cygrpc.Server, generic_handlers: Sequence[grpc.GenericRpcHandler], interceptor_pipeline: Optional[_interceptor._ServicePipeline], thread_pool: futures.ThreadPoolExecutor, maximum_concurrent_rpcs: Optional[int], ): self.lock = threading.RLock() self.completion_queue = completion_queue self.server = server self.generic_handlers = list(generic_handlers) self.interceptor_pipeline = interceptor_pipeline self.thread_pool = thread_pool self.stage = _ServerStage.STOPPED self.termination_event = threading.Event() self.shutdown_events = [self.termination_event] self.maximum_concurrent_rpcs = maximum_concurrent_rpcs self.active_rpc_count = 0 self.registered_method_handlers = {} # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields. self.rpc_states = set() self.due = set() # A "volatile" flag to interrupt the daemon serving thread self.server_deallocated = False def _add_generic_handlers( state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler] ) -> None: with state.lock: state.generic_handlers.extend(generic_handlers) def _add_registered_method_handlers( state: _ServerState, method_handlers: Dict[str, grpc.RpcMethodHandler] ) -> None: with state.lock: state.registered_method_handlers.update(method_handlers) def _add_insecure_port(state: _ServerState, address: bytes) -> int: with state.lock: return state.server.add_http2_port(address) def _add_secure_port( state: _ServerState, address: bytes, server_credentials: grpc.ServerCredentials, ) -> int: with state.lock: return state.server.add_http2_port( address, server_credentials._credentials ) def _request_call(state: _ServerState) -> None: state.server.request_call( state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG ) state.due.add(_REQUEST_CALL_TAG) def _request_registered_call(state: _ServerState, method: str) -> None: registered_call_tag = method state.server.request_registered_call( state.completion_queue, state.completion_queue, method, registered_call_tag, ) state.due.add(registered_call_tag) # TODO(https://github.com/grpc/grpc/issues/6597): delete this function. def _stop_serving(state: _ServerState) -> bool: if not state.rpc_states and not state.due: state.server.destroy() for shutdown_event in state.shutdown_events: shutdown_event.set() state.stage = _ServerStage.STOPPED return True else: return False def _on_call_completed(state: _ServerState) -> None: with state.lock: state.active_rpc_count -= 1 # pylint: disable=too-many-branches def _process_event_and_continue( state: _ServerState, event: cygrpc.BaseEvent ) -> bool: should_continue = True if event.tag is _SHUTDOWN_TAG: with state.lock: state.due.remove(_SHUTDOWN_TAG) if _stop_serving(state): should_continue = False elif ( event.tag is _REQUEST_CALL_TAG or event.tag in state.registered_method_handlers.keys() ): registered_method_name = None if event.tag in state.registered_method_handlers.keys(): registered_method_name = event.tag method_with_handler = _RegisteredMethod( registered_method_name, state.registered_method_handlers.get( registered_method_name, None ), ) else: method_with_handler = _GenericMethod( state.generic_handlers, ) with state.lock: state.due.remove(event.tag) concurrency_exceeded = ( state.maximum_concurrent_rpcs is not None and state.active_rpc_count >= state.maximum_concurrent_rpcs ) rpc_state, rpc_future = _handle_call( event, method_with_handler, state.interceptor_pipeline, state.thread_pool, concurrency_exceeded, ) if rpc_state is not None: state.rpc_states.add(rpc_state) if rpc_future is not None: state.active_rpc_count += 1 rpc_future.add_done_callback( lambda unused_future: _on_call_completed(state) ) if state.stage is _ServerStage.STARTED: if ( registered_method_name in state.registered_method_handlers.keys() ): _request_registered_call(state, registered_method_name) else: _request_call(state) elif _stop_serving(state): should_continue = False else: rpc_state, callbacks = event.tag(event) for callback in callbacks: try: callback() except Exception: # pylint: disable=broad-except _LOGGER.exception("Exception calling callback!") if rpc_state is not None: with state.lock: state.rpc_states.remove(rpc_state) if _stop_serving(state): should_continue = False return should_continue def _serve(state: _ServerState) -> None: while True: timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S event = state.completion_queue.poll(timeout) if state.server_deallocated: _begin_shutdown_once(state) if event.completion_type != cygrpc.CompletionType.queue_timeout: if not _process_event_and_continue(state, event): return # We want to force the deletion of the previous event # ~before~ we poll again; if the event has a reference # to a shutdown Call object, this can induce spinlock. event = None def _begin_shutdown_once(state: _ServerState) -> None: with state.lock: if state.stage is _ServerStage.STARTED: state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) state.stage = _ServerStage.GRACE state.due.add(_SHUTDOWN_TAG) def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event: with state.lock: if state.stage is _ServerStage.STOPPED: shutdown_event = threading.Event() shutdown_event.set() return shutdown_event else: _begin_shutdown_once(state) shutdown_event = threading.Event() state.shutdown_events.append(shutdown_event) if grace is None: state.server.cancel_all_calls() else: def cancel_all_calls_after_grace(): shutdown_event.wait(timeout=grace) with state.lock: state.server.cancel_all_calls() thread = threading.Thread(target=cancel_all_calls_after_grace) thread.start() return shutdown_event shutdown_event.wait() return shutdown_event def _start(state: _ServerState) -> None: with state.lock: if state.stage is not _ServerStage.STOPPED: raise ValueError("Cannot start already-started server!") state.server.start() state.stage = _ServerStage.STARTED # Request a call for each registered method so we can handle any of them. for method in state.registered_method_handlers.keys(): _request_registered_call(state, method) # Also request a call for non-registered method. _request_call(state) thread = threading.Thread(target=_serve, args=(state,)) thread.daemon = True thread.start() def _validate_generic_rpc_handlers( generic_rpc_handlers: Iterable[grpc.GenericRpcHandler], ) -> None: for generic_rpc_handler in generic_rpc_handlers: service_attribute = getattr(generic_rpc_handler, "service", None) if service_attribute is None: raise AttributeError( '"{}" must conform to grpc.GenericRpcHandler type but does ' 'not have "service" method!'.format(generic_rpc_handler) ) def _augment_options( base_options: Sequence[ChannelArgumentType], compression: Optional[grpc.Compression], ) -> Sequence[ChannelArgumentType]: compression_option = _compression.create_channel_option(compression) return tuple(base_options) + compression_option class _Server(grpc.Server): _state: _ServerState # pylint: disable=too-many-arguments def __init__( self, thread_pool: futures.ThreadPoolExecutor, generic_handlers: Sequence[grpc.GenericRpcHandler], interceptors: Sequence[grpc.ServerInterceptor], options: Sequence[ChannelArgumentType], maximum_concurrent_rpcs: Optional[int], compression: Optional[grpc.Compression], xds: bool, ): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server(_augment_options(options, compression), xds) server.register_completion_queue(completion_queue) self._state = _ServerState( completion_queue, server, generic_handlers, _interceptor.service_pipeline(interceptors), thread_pool, maximum_concurrent_rpcs, ) self._cy_server = server def add_generic_rpc_handlers( self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler] ) -> None: _validate_generic_rpc_handlers(generic_rpc_handlers) _add_generic_handlers(self._state, generic_rpc_handlers) def add_registered_method_handlers( self, service_name: str, method_handlers: Dict[str, grpc.RpcMethodHandler], ) -> None: # Can't register method once server started. with self._state.lock: if self._state.stage is _ServerStage.STARTED: return # TODO(xuanwn): We should validate method_handlers first. method_to_handlers = { _common.fully_qualified_method(service_name, method): method_handler for method, method_handler in method_handlers.items() } for fully_qualified_method in method_to_handlers.keys(): self._cy_server.register_method(fully_qualified_method) _add_registered_method_handlers(self._state, method_to_handlers) def add_insecure_port(self, address: str) -> int: return _common.validate_port_binding_result( address, _add_insecure_port(self._state, _common.encode(address)) ) def add_secure_port( self, address: str, server_credentials: grpc.ServerCredentials ) -> int: return _common.validate_port_binding_result( address, _add_secure_port( self._state, _common.encode(address), server_credentials ), ) def start(self) -> None: _start(self._state) def wait_for_termination(self, timeout: Optional[float] = None) -> bool: # NOTE(https://bugs.python.org/issue35935) # Remove this workaround once threading.Event.wait() is working with # CTRL+C across platforms. return _common.wait( self._state.termination_event.wait, self._state.termination_event.is_set, timeout=timeout, ) def stop(self, grace: Optional[float]) -> threading.Event: return _stop(self._state, grace) def __del__(self): if hasattr(self, "_state"): # We can not grab a lock in __del__(), so set a flag to signal the # serving daemon thread (if it exists) to initiate shutdown. self._state.server_deallocated = True def create_server( thread_pool: futures.ThreadPoolExecutor, generic_rpc_handlers: Sequence[grpc.GenericRpcHandler], interceptors: Sequence[grpc.ServerInterceptor], options: Sequence[ChannelArgumentType], maximum_concurrent_rpcs: Optional[int], compression: Optional[grpc.Compression], xds: bool, ) -> _Server: _validate_generic_rpc_handlers(generic_rpc_handlers) return _Server( thread_pool, generic_rpc_handlers, interceptors, options, maximum_concurrent_rpcs, compression, xds, )