2268 lines
79 KiB
Python
2268 lines
79 KiB
Python
# Copyright 2016 gRPC authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Invocation-side implementation of gRPC Python."""
|
|
|
|
import copy
|
|
import functools
|
|
import logging
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
import types
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Iterator,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
|
|
import grpc # pytype: disable=pyi-error
|
|
from grpc import _common # pytype: disable=pyi-error
|
|
from grpc import _compression # pytype: disable=pyi-error
|
|
from grpc import _grpcio_metadata # pytype: disable=pyi-error
|
|
from grpc import _observability # pytype: disable=pyi-error
|
|
from grpc._cython import cygrpc
|
|
from grpc._typing import ChannelArgumentType
|
|
from grpc._typing import DeserializingFunction
|
|
from grpc._typing import IntegratedCallFactory
|
|
from grpc._typing import MetadataType
|
|
from grpc._typing import NullaryCallbackType
|
|
from grpc._typing import ResponseType
|
|
from grpc._typing import SerializingFunction
|
|
from grpc._typing import UserTag
|
|
import grpc.experimental # pytype: disable=pyi-error
|
|
|
|
# Logger used by the invocation-side code in this module.
_LOGGER = logging.getLogger(__name__)

# User-agent string derived from the grpcio package version.
_USER_AGENT = "grpc-python/{}".format(_grpcio_metadata.__version__)

# Flag value passed to operations that take no special flags.
_EMPTY_FLAGS = 0

# NOTE(rbellevi): No guarantees are given about the maintenance of this
# environment variable.
_DEFAULT_SINGLE_THREADED_UNARY_STREAM = (
    "GRPC_SINGLE_THREADED_UNARY_STREAM" in os.environ
)

# Operation types initially recorded as due for each RPC arity. The
# request-streaming arities omit `send_message` and `send_close_from_client`;
# the response-streaming arities omit `receive_message`.
_UNARY_UNARY_INITIAL_DUE = (
    cygrpc.OperationType.send_initial_metadata,
    cygrpc.OperationType.send_message,
    cygrpc.OperationType.send_close_from_client,
    cygrpc.OperationType.receive_initial_metadata,
    cygrpc.OperationType.receive_message,
    cygrpc.OperationType.receive_status_on_client,
)
_UNARY_STREAM_INITIAL_DUE = (
    cygrpc.OperationType.send_initial_metadata,
    cygrpc.OperationType.send_message,
    cygrpc.OperationType.send_close_from_client,
    cygrpc.OperationType.receive_initial_metadata,
    cygrpc.OperationType.receive_status_on_client,
)
_STREAM_UNARY_INITIAL_DUE = (
    cygrpc.OperationType.send_initial_metadata,
    cygrpc.OperationType.receive_initial_metadata,
    cygrpc.OperationType.receive_message,
    cygrpc.OperationType.receive_status_on_client,
)
_STREAM_STREAM_INITIAL_DUE = (
    cygrpc.OperationType.send_initial_metadata,
    cygrpc.OperationType.receive_initial_metadata,
    cygrpc.OperationType.receive_status_on_client,
)

# Message for failures raised by channel subscription callbacks.
_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
    "Exception calling channel subscription callback!"
)

# Template for RPCs that ended OK: class name, status, details.
_OK_RENDEZVOUS_REPR_FORMAT = (
    '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>'
)

# Template for RPCs that ended with a non-OK status: class name, status,
# details, debug error string.
_NON_OK_RENDEZVOUS_REPR_FORMAT = (
    "<{} of RPC that terminated with:\n"
    "\tstatus = {}\n"
    '\tdetails = "{}"\n'
    '\tdebug_error_string = "{}"\n'
    ">"
)
|
|
|
|
|
|
def _deadline(timeout: Optional[float]) -> Optional[float]:
    """Turn a relative timeout into an absolute wall-clock deadline.

    Args:
      timeout: Seconds from now, or None when no deadline applies.

    Returns:
      `time.time() + timeout`, or None when `timeout` is None.
    """
    if timeout is None:
        return None
    return time.time() + timeout
|
|
|
|
|
|
def _unknown_code_details(
    unknown_cygrpc_code: Optional[grpc.StatusCode], details: Optional[str]
) -> str:
    """Build the details text for a status code that cannot be mapped.

    Args:
      unknown_cygrpc_code: The code value that was not recognised.
      details: The details that accompanied that code, if any.

    Returns:
      A message embedding both values.
    """
    template = 'Server sent unknown code {} and details "{}"'
    return template.format(unknown_cygrpc_code, details)
|
|
|
|
|
|
class _RPCState(object):
    """Mutable bookkeeping for one RPC on the invocation side.

    Attributes:
      condition: Guards all other members; `notify_all` is called on it when
        the state of the RPC has changed.
      due: Operation types whose completion events are still expected.
      initial_metadata: Initial metadata of the RPC, or None if not yet set.
      response: The most recently deserialized response, or None.
      trailing_metadata: Trailing metadata of the RPC, or None if not yet set.
      code: Status code of the RPC, or None while no status is recorded.
      details: Status details of the RPC, or None.
      debug_error_string: Debug error string of the RPC, or None.
      cancelled: Whether cancellation was requested prior to termination.
      callbacks: Callbacks to run when the RPC terminates.
      fork_epoch: Value of `cygrpc.get_fork_epoch()` at construction time.
      rpc_start_time: Observability only; in relative seconds.
      rpc_end_time: Observability only; in relative seconds.
      method: Observability only.
      target: Observability only.
    """

    condition: threading.Condition
    due: Set[cygrpc.OperationType]
    initial_metadata: Optional[MetadataType]
    response: Any
    trailing_metadata: Optional[MetadataType]
    code: Optional[grpc.StatusCode]
    details: Optional[str]
    debug_error_string: Optional[str]
    cancelled: bool
    callbacks: List[NullaryCallbackType]
    fork_epoch: Optional[int]
    rpc_start_time: Optional[float]  # In relative seconds
    rpc_end_time: Optional[float]  # In relative seconds
    method: Optional[str]
    target: Optional[str]

    def __init__(
        self,
        due: Sequence[cygrpc.OperationType],
        initial_metadata: Optional[MetadataType],
        trailing_metadata: Optional[MetadataType],
        code: Optional[grpc.StatusCode],
        details: Optional[str],
    ):
        """Create the state for an RPC.

        Args:
          due: Operation types initially expected from the completion queue.
          initial_metadata: Initial metadata already known, or None.
          trailing_metadata: Trailing metadata already known, or None.
          code: Status code already known, or None.
          details: Status details already known, or None.
        """
        # `condition` guards all members of _RPCState. `notify_all` is called
        # on `condition` when the state of the RPC has changed.
        self.condition = threading.Condition()

        # The cygrpc.OperationType objects representing events due from the
        # RPC's completion queue. If an operation is in `due`, it is
        # guaranteed that `operate()` has been called on a corresponding
        # operation. But the converse is not true. That is, in the case of
        # failed `operate()` calls, there may briefly be events in `due` that
        # do not correspond to operations submitted to Core.
        self.due = set(due)

        # Result of the RPC as far as it is currently known.
        self.initial_metadata = initial_metadata
        self.response = None
        self.trailing_metadata = trailing_metadata
        self.code = code
        self.details = details
        self.debug_error_string = None

        # The following four fields are used for observability.
        # Updates to those fields do not trigger self.condition.
        self.rpc_start_time = None
        self.rpc_end_time = None
        self.method = None
        self.target = None

        # The semantics of grpc.Future.cancel and grpc.Future.cancelled are
        # slightly wonky, so they have to be tracked separately from the rest
        # of the result of the RPC. This field tracks whether cancellation was
        # requested prior to termination of the RPC.
        self.cancelled = False
        self.callbacks = []
        self.fork_epoch = cygrpc.get_fork_epoch()

    def reset_postfork_child(self):
        """Replace `condition` with a brand-new `threading.Condition`."""
        self.condition = threading.Condition()
|
|
|
|
|
|
def _abort(state: _RPCState, code: grpc.StatusCode, details: str) -> None:
    """Record a locally generated terminal status, unless one already exists.

    Does nothing when `state.code` is already set. Otherwise stores `code` and
    `details`, sets the initial metadata to an empty tuple if it is still
    None, and replaces the trailing metadata with an empty tuple.

    Args:
      state: The RPC state to update in place.
      code: The status code to record.
      details: The status details to record.
    """
    if state.code is not None:
        # A status is already recorded; it is never overwritten.
        return
    state.code = code
    state.details = details
    if state.initial_metadata is None:
        state.initial_metadata = ()
    state.trailing_metadata = ()
|
|
|
|
|
|
def _handle_event(
    event: cygrpc.BaseEvent,
    state: _RPCState,
    response_deserializer: Optional[DeserializingFunction],
) -> List[NullaryCallbackType]:
    """Fold one completion-queue event into the RPC state.

    Each batch operation in the event is removed from `state.due` and, for
    the receive operations, its payload is copied into `state`. Callers in
    this file invoke this while holding `state.condition`.

    Args:
      event: The event whose `batch_operations` are to be applied.
      state: The RPC state to update in place.
      response_deserializer: Passed to `_common.deserialize` for received
        messages.

    Returns:
      The callbacks that were registered on `state` when the status
      operation was processed; empty if the event carried no status.
    """
    callbacks = []
    for batch_operation in event.batch_operations:
        operation_type = batch_operation.type()
        state.due.remove(operation_type)
        if operation_type == cygrpc.OperationType.receive_initial_metadata:
            state.initial_metadata = batch_operation.initial_metadata()
        elif operation_type == cygrpc.OperationType.receive_message:
            serialized_response = batch_operation.message()
            if serialized_response is not None:
                response = _common.deserialize(
                    serialized_response, response_deserializer
                )
                if response is None:
                    # A None result is treated as a deserialization failure.
                    details = "Exception deserializing response!"
                    _abort(state, grpc.StatusCode.INTERNAL, details)
                else:
                    state.response = response
        elif operation_type == cygrpc.OperationType.receive_status_on_client:
            state.trailing_metadata = batch_operation.trailing_metadata()
            if state.code is None:
                cygrpc_code = batch_operation.code()
                code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE.get(
                    cygrpc_code
                )
                if code is None:
                    state.code = grpc.StatusCode.UNKNOWN
                    # Report the raw code that was received. The mapped
                    # `code` is always None on this branch and would hide the
                    # offending value.
                    state.details = _unknown_code_details(
                        cygrpc_code, batch_operation.details()
                    )
                else:
                    state.code = code
                    state.details = batch_operation.details()
                    state.debug_error_string = batch_operation.error_string()
            state.rpc_end_time = time.perf_counter()
            _observability.maybe_record_rpc_latency(state)
            callbacks.extend(state.callbacks)
            # `add_callback` treats None as "no further callbacks accepted".
            state.callbacks = None
    return callbacks
|
|
|
|
|
|
def _event_handler(
    state: _RPCState, response_deserializer: Optional[DeserializingFunction]
) -> UserTag:
    """Build the tag callable that applies completion events to `state`.

    Args:
      state: The RPC state that events will be folded into.
      response_deserializer: Forwarded to `_handle_event`.

    Returns:
      A callable taking an event. It updates `state` under `state.condition`,
      notifies waiters, runs any callbacks returned by `_handle_event`
      outside the lock, and returns True only when no operations remain due
      and `state.fork_epoch` is not older than the current fork epoch.
    """

    def handle_event(event):
        with state.condition:
            callbacks = _handle_event(event, state, response_deserializer)
            state.condition.notify_all()
            done = not state.due
        for callback in callbacks:
            try:
                callback()
            except Exception as e:  # pylint: disable=broad-except
                # NOTE(rbellevi): We suppress but log errors here so as not to
                # kill the channel spin thread.
                # Only functools.partial callbacks have `.func`; plain
                # callables registered through `add_callback` do not, so fall
                # back to the callback itself rather than raising from here.
                target = getattr(callback, "func", callback)
                _LOGGER.error(
                    "Exception in callback %s: %s", repr(target), repr(e)
                )
        return done and state.fork_epoch >= cygrpc.get_fork_epoch()

    return handle_event
|
|
|
|
|
|
# TODO(xuanwn): Create a base class for IntegratedCall and SegregatedCall.
|
|
# pylint: disable=too-many-statements
|
|
def _consume_request_iterator(
    request_iterator: Iterator,
    state: _RPCState,
    call: Union[cygrpc.IntegratedCall, cygrpc.SegregatedCall],
    request_serializer: SerializingFunction,
    event_handler: Optional[UserTag],
) -> None:
    """Consume a request supplied by the user.

    Starts a daemon `cygrpc.ForkManagedThread` that pulls requests from
    `request_iterator`, serializes each one and submits it on `call`, waiting
    for every send to complete before fetching the next request. When the
    iterator is exhausted and the RPC has no status yet, a close-from-client
    operation is submitted.

    Args:
      request_iterator: The user-supplied iterator of request values.
      state: The RPC state shared with the rest of the invocation.
      call: The call on which operations are submitted and cancelled.
      request_serializer: Passed to `_common.serialize` for each request.
      event_handler: Tag passed along with every submitted operation batch.
    """
    send_message = cygrpc.OperationType.send_message
    send_close = cygrpc.OperationType.send_close_from_client

    def _send_finished():
        # True once the RPC has a status or the pending send is no longer due.
        return state.code is not None or send_message not in state.due

    def consume_request_iterator():  # pylint: disable=too-many-branches
        # Iterate over the request iterator until it is exhausted or an error
        # condition is encountered.
        while True:
            left_user_generator = False
            try:
                # The thread may die in user-code. Do not block fork for this.
                cygrpc.enter_user_request_generator()
                request = next(request_iterator)
            except StopIteration:
                break
            except Exception:  # pylint: disable=broad-except
                cygrpc.return_from_user_request_generator()
                left_user_generator = True
                code = grpc.StatusCode.UNKNOWN
                details = "Exception iterating requests!"
                _LOGGER.exception(details)
                call.cancel(
                    _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details
                )
                _abort(state, code, details)
                return
            finally:
                # Pair every enter with exactly one return.
                if not left_user_generator:
                    cygrpc.return_from_user_request_generator()

            serialized_request = _common.serialize(request, request_serializer)
            with state.condition:
                if state.code is not None or state.cancelled:
                    # The RPC already terminated or was cancelled.
                    return
                if serialized_request is None:
                    code = grpc.StatusCode.INTERNAL
                    details = "Exception serializing request!"
                    call.cancel(
                        _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code],
                        details,
                    )
                    _abort(state, code, details)
                    return

                # Mark the send as due before submitting it, and take it back
                # if the submission is refused.
                state.due.add(send_message)
                send_batch = (
                    cygrpc.SendMessageOperation(
                        serialized_request, _EMPTY_FLAGS
                    ),
                )
                if not call.operate(send_batch, event_handler):
                    state.due.remove(send_message)
                    return

                _common.wait(
                    state.condition.wait,
                    _send_finished,
                    spin_cb=functools.partial(
                        cygrpc.block_if_fork_in_progress, state
                    ),
                )
                if state.code is not None:
                    return

        # The iterator is exhausted: half-close if the RPC is still open.
        with state.condition:
            if state.code is None:
                state.due.add(send_close)
                close_batch = (
                    cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                )
                if not call.operate(close_batch, event_handler):
                    state.due.remove(send_close)

    consumption_thread = cygrpc.ForkManagedThread(
        target=consume_request_iterator
    )
    consumption_thread.setDaemon(True)
    consumption_thread.start()
|
|
|
|
|
|
def _rpc_state_string(class_name: str, rpc_state: _RPCState) -> str:
    """Calculates error string for RPC.

    Args:
      class_name: Name of the class to embed in the text.
      rpc_state: The state to describe; read under its `condition`.

    Returns:
      A short placeholder while the RPC has no status, otherwise a
      multi-line description of its status and details (plus the debug
      error string when the status is not OK).
    """
    with rpc_state.condition:
        code = rpc_state.code
        if code is None:
            return "<{} object>".format(class_name)
        if code is grpc.StatusCode.OK:
            return _OK_RENDEZVOUS_REPR_FORMAT.format(
                class_name, code, rpc_state.details
            )
        return _NON_OK_RENDEZVOUS_REPR_FORMAT.format(
            class_name,
            code,
            rpc_state.details,
            rpc_state.debug_error_string,
        )
|
|
|
|
|
|
class _InactiveRpcError(grpc.RpcError, grpc.Call, grpc.Future):
    """An RPC error not tied to the execution of a particular RPC.

    The RPC represented by the state object must not be in-progress or
    cancelled.

    Attributes:
      _state: An instance of _RPCState.
    """

    _state: _RPCState

    def __init__(self, state: _RPCState):
        """Snapshot `state` into a private _RPCState with nothing due.

        Args:
          state: The state to copy; it is read while holding its condition.
        """
        with state.condition:
            # Metadata and details are deep-copied; the response and debug
            # error string are shallow-copied.
            initial_metadata = copy.deepcopy(state.initial_metadata)
            trailing_metadata = copy.deepcopy(state.trailing_metadata)
            details = copy.deepcopy(state.details)
            self._state = _RPCState(
                (), initial_metadata, trailing_metadata, state.code, details
            )
            self._state.response = copy.copy(state.response)
            self._state.debug_error_string = copy.copy(state.debug_error_string)

    def initial_metadata(self) -> Optional[MetadataType]:
        """Return the copied initial metadata."""
        return self._state.initial_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return the copied trailing metadata."""
        return self._state.trailing_metadata

    def code(self) -> Optional[grpc.StatusCode]:
        """Return the status code captured at construction."""
        return self._state.code

    def details(self) -> Optional[str]:
        """Return the details passed through `_common.decode`."""
        return _common.decode(self._state.details)

    def debug_error_string(self) -> Optional[str]:
        """Return the debug error string passed through `_common.decode`."""
        return _common.decode(self._state.debug_error_string)

    def _repr(self) -> str:
        """Describe this error using `_rpc_state_string`."""
        return _rpc_state_string(self.__class__.__name__, self._state)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()

    def cancel(self) -> bool:
        """See grpc.Future.cancel."""
        return False

    def cancelled(self) -> bool:
        """See grpc.Future.cancelled."""
        return False

    def running(self) -> bool:
        """See grpc.Future.running."""
        return False

    def done(self) -> bool:
        """See grpc.Future.done."""
        return True

    def result(
        self, timeout: Optional[float] = None
    ) -> Any:  # pylint: disable=unused-argument
        """See grpc.Future.result."""
        raise self

    def exception(
        self, timeout: Optional[float] = None  # pylint: disable=unused-argument
    ) -> Optional[Exception]:
        """See grpc.Future.exception."""
        return self

    def traceback(
        self, timeout: Optional[float] = None  # pylint: disable=unused-argument
    ) -> Optional[types.TracebackType]:
        """See grpc.Future.traceback."""
        # Raise and immediately catch this error to obtain a traceback.
        try:
            raise self
        except grpc.RpcError:
            _, _, captured_traceback = sys.exc_info()
            return captured_traceback

    def add_done_callback(
        self,
        fn: Callable[[grpc.Future], None],
        timeout: Optional[float] = None,  # pylint: disable=unused-argument
    ) -> None:
        """See grpc.Future.add_done_callback."""
        # This future is always done, so the callback runs immediately.
        fn(self)
|
|
|
|
|
|
class _Rendezvous(grpc.RpcError, grpc.RpcContext):
    """An RPC iterator.

    Attributes:
      _state: An instance of _RPCState.
      _call: An instance of SegregatedCall or IntegratedCall.
        In either case, the _call object is expected to have operate, cancel,
        and next_event methods.
      _response_deserializer: A callable taking bytes and return a Python
        object.
      _deadline: A float representing the deadline of the RPC in seconds. Or
        possibly None, to represent an RPC with no deadline at all.
    """

    _state: _RPCState
    _call: Union[cygrpc.SegregatedCall, cygrpc.IntegratedCall]
    _response_deserializer: Optional[DeserializingFunction]
    _deadline: Optional[float]

    def __init__(
        self,
        state: _RPCState,
        call: Union[cygrpc.SegregatedCall, cygrpc.IntegratedCall],
        response_deserializer: Optional[DeserializingFunction],
        deadline: Optional[float],
    ):
        """Store the collaborators of this RPC; see the class attributes."""
        super().__init__()
        self._state = state
        self._call = call
        self._response_deserializer = response_deserializer
        self._deadline = deadline

    def is_active(self) -> bool:
        """See grpc.RpcContext.is_active"""
        with self._state.condition:
            return self._state.code is None

    def time_remaining(self) -> Optional[float]:
        """See grpc.RpcContext.time_remaining"""
        with self._state.condition:
            if self._deadline is None:
                return None
            # Never report a negative amount of time.
            return max(self._deadline - time.time(), 0)

    def cancel(self) -> bool:
        """See grpc.RpcContext.cancel"""
        with self._state.condition:
            if self._state.code is not None:
                # The RPC already has a status; nothing to cancel.
                return False
            code = grpc.StatusCode.CANCELLED
            details = "Locally cancelled by application!"
            self._call.cancel(
                _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details
            )
            self._state.cancelled = True
            _abort(self._state, code, details)
            self._state.condition.notify_all()
            return True

    def add_callback(self, callback: NullaryCallbackType) -> bool:
        """See grpc.RpcContext.add_callback"""
        with self._state.condition:
            registered = self._state.callbacks
            if registered is None:
                # The callback list has been cleared; registration refused.
                return False
            registered.append(callback)
            return True

    def __iter__(self):
        return self

    def next(self):
        return self._next()

    def __next__(self):
        return self._next()

    def _next(self):
        """Produce the next response; implemented by subclasses."""
        raise NotImplementedError()

    def debug_error_string(self) -> Optional[str]:
        """Return the debug error string; implemented by subclasses."""
        raise NotImplementedError()

    def _repr(self) -> str:
        """Describe this RPC using `_rpc_state_string`."""
        return _rpc_state_string(self.__class__.__name__, self._state)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()

    def __del__(self) -> None:
        """Cancel the call if it still has no status at collection time."""
        state = self._state
        with state.condition:
            if state.code is None:
                state.code = grpc.StatusCode.CANCELLED
                state.details = "Cancelled upon garbage collection!"
                state.cancelled = True
                self._call.cancel(
                    _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[state.code],
                    state.details,
                )
                state.condition.notify_all()
|
|
|
|
|
|
class _SingleThreadedRendezvous(
    _Rendezvous, grpc.Call, grpc.Future
):  # pylint: disable=too-many-ancestors
    """An RPC iterator operating entirely on a single thread.

    The __next__ method of _SingleThreadedRendezvous does not depend on the
    existence of any other thread, including the "channel spin thread".
    However, this means that its interface is entirely synchronous. So this
    class cannot completely fulfill the grpc.Future interface. The result,
    exception, and traceback methods will never block and will instead raise
    an exception if calling the method would result in blocking.

    This means that these methods are safe to call from add_done_callback
    handlers.
    """

    _state: _RPCState

    def _is_complete(self) -> bool:
        """Return True once a status code has been recorded."""
        return self._state.code is not None

    def cancelled(self) -> bool:
        """Return whether cancellation was requested for this RPC."""
        with self._state.condition:
            return self._state.cancelled

    def running(self) -> bool:
        """Return True while the RPC has no status code."""
        with self._state.condition:
            return self._state.code is None

    def done(self) -> bool:
        """Return True once the RPC has a status code."""
        with self._state.condition:
            return self._state.code is not None

    def result(self, timeout: Optional[float] = None) -> Any:
        """Returns the result of the computation or raises its exception.

        This method will never block. Instead, it will raise an exception
        if calling this method would otherwise result in blocking.

        Since this method will never block, any `timeout` argument passed will
        be ignored.
        """
        del timeout
        with self._state.condition:
            if not self._is_complete():
                raise grpc.experimental.UsageError(
                    "_SingleThreadedRendezvous only supports result() when the"
                    " RPC is complete."
                )
            if self._state.code is grpc.StatusCode.OK:
                return self._state.response
            if self._state.cancelled:
                raise grpc.FutureCancelledError()
            raise self

    def exception(self, timeout: Optional[float] = None) -> Optional[Exception]:
        """Return the exception raised by the computation.

        This method will never block. Instead, it will raise an exception
        if calling this method would otherwise result in blocking.

        Since this method will never block, any `timeout` argument passed will
        be ignored.
        """
        del timeout
        with self._state.condition:
            if not self._is_complete():
                raise grpc.experimental.UsageError(
                    "_SingleThreadedRendezvous only supports exception() when"
                    " the RPC is complete."
                )
            if self._state.code is grpc.StatusCode.OK:
                return None
            if self._state.cancelled:
                raise grpc.FutureCancelledError()
            return self

    def traceback(
        self, timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        """Access the traceback of the exception raised by the computation.

        This method will never block. Instead, it will raise an exception
        if calling this method would otherwise result in blocking.

        Since this method will never block, any `timeout` argument passed will
        be ignored.
        """
        del timeout
        with self._state.condition:
            if not self._is_complete():
                raise grpc.experimental.UsageError(
                    "_SingleThreadedRendezvous only supports traceback() when"
                    " the RPC is complete."
                )
            if self._state.code is grpc.StatusCode.OK:
                return None
            if self._state.cancelled:
                raise grpc.FutureCancelledError()
            # Raise and immediately catch this error to obtain a traceback.
            try:
                raise self
            except grpc.RpcError:
                _, _, captured_traceback = sys.exc_info()
                return captured_traceback

    def add_done_callback(self, fn: Callable[[grpc.Future], None]) -> None:
        """Register `fn`, or call it right away if the RPC already has a status."""
        with self._state.condition:
            still_running = self._state.code is None
            if still_running:
                self._state.callbacks.append(functools.partial(fn, self))
        if still_running:
            return

        # Invoked outside the lock, matching the deferred-callback path.
        fn(self)

    def initial_metadata(self) -> Optional[MetadataType]:
        """See grpc.Call.initial_metadata"""
        with self._state.condition:
            # NOTE(gnossen): Based on our initial call batch, we are guaranteed
            # to receive initial metadata before any messages.
            while self._state.initial_metadata is None:
                self._consume_next_event()
            return self._state.initial_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        """See grpc.Call.trailing_metadata"""
        with self._state.condition:
            metadata = self._state.trailing_metadata
            if metadata is None:
                raise grpc.experimental.UsageError(
                    "Cannot get trailing metadata until RPC is completed."
                )
            return metadata

    def code(self) -> Optional[grpc.StatusCode]:
        """See grpc.Call.code"""
        with self._state.condition:
            status_code = self._state.code
            if status_code is None:
                raise grpc.experimental.UsageError(
                    "Cannot get code until RPC is completed."
                )
            return status_code

    def details(self) -> Optional[str]:
        """See grpc.Call.details"""
        with self._state.condition:
            raw_details = self._state.details
            if raw_details is None:
                raise grpc.experimental.UsageError(
                    "Cannot get details until RPC is completed."
                )
            return _common.decode(raw_details)

    def _consume_next_event(self) -> Optional[cygrpc.BaseEvent]:
        """Fetch one event from the call and apply it to the state.

        Returns:
          The event obtained from `self._call.next_event()`.
        """
        event = self._call.next_event()
        with self._state.condition:
            pending_callbacks = _handle_event(
                event, self._state, self._response_deserializer
            )
            for pending_callback in pending_callbacks:
                # NOTE(gnossen): We intentionally allow exceptions to bubble up
                # to the user when running on a single thread.
                pending_callback()
        return event

    def _next_response(self) -> Any:
        """Consume events until a response or a terminal status is available.

        Returns:
          The next response, which is cleared from the state once taken.

        Raises:
          StopIteration: No receive is due and the status is OK.
          _SingleThreadedRendezvous: No receive is due and the status is
            non-OK (this object itself is raised).
        """
        receive_message = cygrpc.OperationType.receive_message
        while True:
            self._consume_next_event()
            with self._state.condition:
                if self._state.response is not None:
                    response, self._state.response = self._state.response, None
                    return response
                if receive_message in self._state.due:
                    # The receive is still outstanding; keep consuming.
                    continue
                if self._state.code is grpc.StatusCode.OK:
                    raise StopIteration()
                if self._state.code is not None:
                    raise self

    def _next(self) -> Any:
        """Request one more message and return it once it arrives."""
        receive_message = cygrpc.OperationType.receive_message
        with self._state.condition:
            if self._state.code is grpc.StatusCode.OK:
                raise StopIteration()
            if self._state.code is not None:
                raise self
            # We tentatively add the operation as expected and remove
            # it if the enqueue operation fails. This allows us to guarantee
            # that if an event has been submitted to the core completion
            # queue, it is in `due`. If we waited until after a successful
            # enqueue operation then a signal could interrupt this
            # thread between the enqueue operation and the addition of the
            # operation to `due`. This would cause an exception on the
            # channel spin thread when the operation completes and no
            # corresponding operation would be present in state.due.
            # Note that, since `condition` is held through this block, there
            # is no data race on `due`.
            self._state.due.add(receive_message)
            submitted = self._call.operate(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), None
            )
            if not submitted:
                self._state.due.remove(receive_message)
        return self._next_response()

    def debug_error_string(self) -> Optional[str]:
        """Return the decoded debug error string once it is available."""
        with self._state.condition:
            raw_debug_string = self._state.debug_error_string
            if raw_debug_string is None:
                raise grpc.experimental.UsageError(
                    "Cannot get debug error string until RPC is completed."
                )
            return _common.decode(raw_debug_string)
|
|
|
|
|
|
class _MultiThreadedRendezvous(
    _Rendezvous, grpc.Call, grpc.Future
):  # pylint: disable=too-many-ancestors
    """An RPC iterator that depends on a channel spin thread.

    This iterator relies upon a per-channel thread running in the background,
    dequeueing events from the completion queue, and notifying threads waiting
    on the threading.Condition object in the _RPCState object.

    This extra thread allows _MultiThreadedRendezvous to fulfill the
    grpc.Future interface and to mediate a bidirectional streaming RPC.
    """

    _state: _RPCState

    def initial_metadata(self) -> Optional[MetadataType]:
        """See grpc.Call.initial_metadata"""
        with self._state.condition:
            _common.wait(
                self._state.condition.wait,
                lambda: self._state.initial_metadata is not None,
            )
            return self._state.initial_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        """See grpc.Call.trailing_metadata"""
        with self._state.condition:
            _common.wait(
                self._state.condition.wait,
                lambda: self._state.trailing_metadata is not None,
            )
            return self._state.trailing_metadata

    def code(self) -> Optional[grpc.StatusCode]:
        """See grpc.Call.code"""
        with self._state.condition:
            _common.wait(
                self._state.condition.wait,
                lambda: self._state.code is not None,
            )
            return self._state.code

    def details(self) -> Optional[str]:
        """See grpc.Call.details"""
        with self._state.condition:
            _common.wait(
                self._state.condition.wait,
                lambda: self._state.details is not None,
            )
            return _common.decode(self._state.details)

    def debug_error_string(self) -> Optional[str]:
        """Wait for the debug error string to be set and return it decoded."""
        with self._state.condition:
            _common.wait(
                self._state.condition.wait,
                lambda: self._state.debug_error_string is not None,
            )
            return _common.decode(self._state.debug_error_string)

    def cancelled(self) -> bool:
        """Return whether cancellation was requested for this RPC."""
        with self._state.condition:
            return self._state.cancelled

    def running(self) -> bool:
        """Return True while the RPC has no status code."""
        with self._state.condition:
            return self._state.code is None

    def done(self) -> bool:
        """Return True once the RPC has a status code."""
        with self._state.condition:
            return self._state.code is not None

    def _is_complete(self) -> bool:
        """Return True once a status code has been recorded."""
        return self._state.code is not None

    def result(self, timeout: Optional[float] = None) -> Any:
        """Returns the result of the computation or raises its exception.

        See grpc.Future.result for the full API contract.
        """
        with self._state.condition:
            timed_out = _common.wait(
                self._state.condition.wait, self._is_complete, timeout=timeout
            )
            if timed_out:
                raise grpc.FutureTimeoutError()
            if self._state.code is grpc.StatusCode.OK:
                return self._state.response
            if self._state.cancelled:
                raise grpc.FutureCancelledError()
            raise self

    def exception(self, timeout: Optional[float] = None) -> Optional[Exception]:
        """Return the exception raised by the computation.

        See grpc.Future.exception for the full API contract.
        """
        with self._state.condition:
            timed_out = _common.wait(
                self._state.condition.wait, self._is_complete, timeout=timeout
            )
            if timed_out:
                raise grpc.FutureTimeoutError()
            if self._state.code is grpc.StatusCode.OK:
                return None
            if self._state.cancelled:
                raise grpc.FutureCancelledError()
            return self

    def traceback(
        self, timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        """Access the traceback of the exception raised by the computation.

        See grpc.Future.traceback for the full API contract.
        """
        with self._state.condition:
            timed_out = _common.wait(
                self._state.condition.wait, self._is_complete, timeout=timeout
            )
            if timed_out:
                raise grpc.FutureTimeoutError()
            if self._state.code is grpc.StatusCode.OK:
                return None
            if self._state.cancelled:
                raise grpc.FutureCancelledError()
            # Raise and immediately catch this error to obtain a traceback.
            try:
                raise self
            except grpc.RpcError:
                _, _, captured_traceback = sys.exc_info()
                return captured_traceback

    def add_done_callback(self, fn: Callable[[grpc.Future], None]) -> None:
        """Register `fn`, or call it right away if the RPC already has a status."""
        with self._state.condition:
            still_running = self._state.code is None
            if still_running:
                self._state.callbacks.append(functools.partial(fn, self))
        if still_running:
            return

        # Invoked outside the lock, matching the deferred-callback path.
        fn(self)

    def _next(self) -> Any:
        """Request one more message and wait for it or for termination.

        Returns:
          The next response, which is cleared from the state once taken.

        Raises:
          StopIteration: The RPC finished with an OK status.
          _MultiThreadedRendezvous: The RPC finished with a non-OK status
            (this object itself is raised).
        """
        receive_message = cygrpc.OperationType.receive_message
        with self._state.condition:
            if self._state.code is grpc.StatusCode.OK:
                raise StopIteration()
            if self._state.code is not None:
                raise self

            # Mark the receive as due before submitting it, and take it back
            # if the submission is refused.
            event_handler = _event_handler(
                self._state, self._response_deserializer
            )
            self._state.due.add(receive_message)
            submitted = self._call.operate(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                event_handler,
            )
            if not submitted:
                self._state.due.remove(receive_message)

            def _response_ready():
                # A response is waiting, or the receive is no longer due and
                # the RPC has a status.
                if self._state.response is not None:
                    return True
                return (
                    receive_message not in self._state.due
                    and self._state.code is not None
                )

            _common.wait(self._state.condition.wait, _response_ready)
            if self._state.response is not None:
                response, self._state.response = self._state.response, None
                return response
            if receive_message not in self._state.due:
                if self._state.code is grpc.StatusCode.OK:
                    raise StopIteration()
                if self._state.code is not None:
                    raise self
|
|
|
|
|
|
def _start_unary_request(
    request: Any,
    timeout: Optional[float],
    request_serializer: SerializingFunction,
) -> Tuple[Optional[float], Optional[bytes], Optional[grpc.RpcError]]:
    """Compute the deadline and serialize the request of a unary-request RPC.

    Args:
      request: The request value supplied by the caller.
      timeout: Seconds allowed for the RPC, or None.
      request_serializer: Passed to `_common.serialize`.

    Returns:
      A `(deadline, serialized_request, error)` tuple. Exactly one of the
      last two items is None: on success the error is None; when
      serialization yields None the serialized request is None and the error
      is an `_InactiveRpcError` carrying an INTERNAL status.
    """
    deadline = _deadline(timeout)
    serialized_request = _common.serialize(request, request_serializer)
    if serialized_request is not None:
        return deadline, serialized_request, None

    failed_state = _RPCState(
        (),
        (),
        (),
        grpc.StatusCode.INTERNAL,
        "Exception serializing request!",
    )
    return deadline, None, _InactiveRpcError(failed_state)
|
|
|
|
|
|
def _end_unary_response_blocking(
    state: _RPCState,
    call: cygrpc.SegregatedCall,
    with_call: bool,
    deadline: Optional[float],
) -> Union[ResponseType, Tuple[ResponseType, grpc.Call]]:
    """Turn the final state of a blocking unary-response RPC into its result.

    Args:
      state: The state of the RPC.
      call: The call the RPC ran on.
      with_call: Whether to also return a call object alongside the response.
      deadline: The deadline handed to the returned call object.

    Returns:
      `state.response`, or a `(response, rendezvous)` tuple when `with_call`
      is true.

    Raises:
      _InactiveRpcError: The status code is anything other than OK.
    """
    if state.code is not grpc.StatusCode.OK:
        raise _InactiveRpcError(state)  # pytype: disable=not-instantiable
    if not with_call:
        return state.response
    rendezvous = _MultiThreadedRendezvous(state, call, None, deadline)
    return state.response, rendezvous
|
|
|
|
|
|
def _stream_unary_invocation_operations(
    metadata: Optional[MetadataType], initial_metadata_flags: int
) -> Sequence[Sequence[cygrpc.Operation]]:
    """Builds the two operation batches that open a stream-unary RPC.

    Args:
      metadata: The metadata placed in the send-initial-metadata operation.
      initial_metadata_flags: The flags for that same operation.

    Returns:
      A pair of batches: the first sends initial metadata and receives the
        message and the status; the second receives the initial metadata.
    """
    send_and_receive_batch = (
        cygrpc.SendInitialMetadataOperation(metadata, initial_metadata_flags),
        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
    )
    receive_initial_metadata_batch = (
        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
    )
    return send_and_receive_batch, receive_initial_metadata_batch
|
|
|
|
|
|
def _stream_unary_invocation_operations_and_tags(
    metadata: Optional[MetadataType], initial_metadata_flags: int
) -> Sequence[Tuple[Sequence[cygrpc.Operation], Optional[UserTag]]]:
    """Pairs each stream-unary invocation batch with a None tag.

    Args:
      metadata: Forwarded to _stream_unary_invocation_operations.
      initial_metadata_flags: Forwarded to
        _stream_unary_invocation_operations.

    Returns:
      A tuple of (operations, None) pairs, one per batch.
    """
    batches = _stream_unary_invocation_operations(
        metadata, initial_metadata_flags
    )
    return tuple((batch, None) for batch in batches)
|
|
|
|
|
|
def _determine_deadline(user_deadline: Optional[float]) -> Optional[float]:
    """Combines the caller's deadline with the one found in the context.

    Args:
      user_deadline: The deadline requested by the caller, or None.

    Returns:
      None when neither deadline is set, the only one that is set, or the
        smaller of the two when both are set.
    """
    parent_deadline = cygrpc.get_deadline_from_context()
    if parent_deadline is None:
        # Covers both "only the user deadline" and "no deadline at all".
        return user_deadline
    if user_deadline is None:
        return parent_deadline
    return min(parent_deadline, user_deadline)
|
|
|
|
|
|
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """A grpc.UnaryUnaryMultiCallable backed by a cygrpc.Channel.

    The blocking entry points (__call__ and with_call) conduct the RPC on a
    segregated call and wait for its single completion event; future() hands
    the operations to the managed-call factory and returns right away.
    """

    _channel: cygrpc.Channel
    _managed_call: IntegratedCallFactory
    _method: bytes
    _target: bytes
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _context: Any
    _registered_call_handle: Optional[int]

    __slots__ = [
        "_channel",
        "_managed_call",
        "_method",
        "_target",
        "_request_serializer",
        "_response_deserializer",
        "_context",
        # Annotated above and assigned in __init__, so it gets a slot like
        # every other attribute of this class.
        "_registered_call_handle",
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.Channel,
        managed_call: IntegratedCallFactory,
        method: bytes,
        target: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        _registered_call_handle: Optional[int],
    ):
        """Constructor.

        Args:
          channel: The cygrpc.Channel used for blocking invocations.
          managed_call: The factory used by future() to create calls.
          method: The RPC method, as bytes.
          target: The channel target, as bytes.
          request_serializer: The serializer for requests, or None.
          response_deserializer: The deserializer for responses, or None.
          _registered_call_handle: The call handle forwarded to every call
            this object creates, or None.
        """
        self._channel = channel
        self._managed_call = managed_call
        self._method = method
        self._target = target
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._context = cygrpc.build_census_context()
        self._registered_call_handle = _registered_call_handle

    def _prepare(
        self,
        request: Any,
        timeout: Optional[float],
        metadata: Optional[MetadataType],
        wait_for_ready: Optional[bool],
        compression: Optional[grpc.Compression],
    ) -> Tuple[
        Optional[_RPCState],
        Optional[Sequence[cygrpc.Operation]],
        Optional[float],
        Optional[grpc.RpcError],
    ]:
        """Serializes the request and builds the single operation batch.

        Args:
          request: The request message.
          timeout: The timeout for the RPC, or None.
          metadata: The caller-supplied metadata, or None.
          wait_for_ready: The wait-for-ready setting, or None if unset.
          compression: The per-call compression setting, or None.

        Returns:
          A (state, operations, deadline, error) tuple. If the request could
            not be serialized the first three items are None and error is the
            exception to raise; otherwise error is None.
        """
        deadline, serialized_request, error = _start_unary_request(
            request, timeout, self._request_serializer
        )
        initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
            wait_for_ready
        )
        augmented_metadata = _compression.augment_metadata(
            metadata, compression
        )
        if serialized_request is None:
            return None, None, None, error
        else:
            state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
            # The whole RPC fits in one batch: send everything, then receive
            # the initial metadata, the message and the status.
            operations = (
                cygrpc.SendInitialMetadataOperation(
                    augmented_metadata, initial_metadata_flags
                ),
                cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
                cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
            )
            return state, operations, deadline, None

    def _blocking(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[_RPCState, cygrpc.SegregatedCall]:
        """Conducts the RPC on a segregated call and waits for its event.

        Returns:
          The (state, call) pair after the call's event has been handled.

        Raises:
          grpc.RpcError: If the request could not be serialized.
        """
        state, operations, deadline, error = self._prepare(
            request, timeout, metadata, wait_for_ready, compression
        )
        if state is None:
            raise error  # pylint: disable-msg=raising-bad-type
        else:
            state.rpc_start_time = time.perf_counter()
            state.method = _common.decode(self._method)
            state.target = _common.decode(self._target)
            call = self._channel.segregated_call(
                cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
                self._method,
                None,
                _determine_deadline(deadline),
                metadata,
                None if credentials is None else credentials._credentials,
                (
                    (
                        operations,
                        None,
                    ),
                ),
                self._context,
                self._registered_call_handle,
            )
            # Only one batch was started, so one event completes the RPC.
            event = call.next_event()
            _handle_event(event, state, self._response_deserializer)
            return state, call

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Synchronously invokes the RPC and returns its response.

        Args:
          request: The request message.
          timeout: The timeout for the RPC, or None.
          metadata: The caller-supplied metadata, or None.
          credentials: The grpc.CallCredentials for the RPC, or None.
          wait_for_ready: The wait-for-ready setting, or None if unset.
          compression: The per-call compression setting, or None.

        Returns:
          The response of the RPC.

        Raises:
          grpc.RpcError: If the request could not be serialized or the RPC
            terminated with a non-OK status.
        """
        (
            state,
            call,
        ) = self._blocking(
            request, timeout, metadata, credentials, wait_for_ready, compression
        )
        return _end_unary_response_blocking(state, call, False, None)

    def with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Synchronously invokes the RPC and returns response and call.

        Takes the same arguments as __call__.

        Returns:
          A (response, grpc.Call) pair.

        Raises:
          grpc.RpcError: If the request could not be serialized or the RPC
            terminated with a non-OK status.
        """
        (
            state,
            call,
        ) = self._blocking(
            request, timeout, metadata, credentials, wait_for_ready, compression
        )
        return _end_unary_response_blocking(state, call, True, None)

    def future(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _MultiThreadedRendezvous:
        """Starts the RPC on a managed call without waiting for it.

        Takes the same arguments as __call__.

        Returns:
          A _MultiThreadedRendezvous for the started RPC.

        Raises:
          grpc.RpcError: If the request could not be serialized.
        """
        state, operations, deadline, error = self._prepare(
            request, timeout, metadata, wait_for_ready, compression
        )
        if state is None:
            raise error  # pylint: disable-msg=raising-bad-type
        else:
            event_handler = _event_handler(state, self._response_deserializer)
            state.rpc_start_time = time.perf_counter()
            state.method = _common.decode(self._method)
            state.target = _common.decode(self._target)
            call = self._managed_call(
                cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
                self._method,
                None,
                deadline,
                metadata,
                None if credentials is None else credentials._credentials,
                (operations,),
                event_handler,
                self._context,
                self._registered_call_handle,
            )
            return _MultiThreadedRendezvous(
                state, call, self._response_deserializer, deadline
            )
|
|
|
|
|
|
class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """A grpc.UnaryStreamMultiCallable conducted on a segregated call.

    Each invocation returns a _SingleThreadedRendezvous wrapping the
    segregated call created for it.
    """

    _channel: cygrpc.Channel
    _method: bytes
    _target: bytes
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _context: Any
    _registered_call_handle: Optional[int]

    __slots__ = [
        "_channel",
        "_method",
        "_target",
        "_request_serializer",
        "_response_deserializer",
        "_context",
        # Annotated above and assigned in __init__, so it gets a slot like
        # every other attribute of this class.
        "_registered_call_handle",
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.Channel,
        method: bytes,
        target: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        _registered_call_handle: Optional[int],
    ):
        """Constructor.

        Args:
          channel: The cygrpc.Channel on which calls are created.
          method: The RPC method, as bytes.
          target: The channel target, as bytes.
          request_serializer: The serializer for requests.
          response_deserializer: The deserializer for responses.
          _registered_call_handle: The call handle forwarded to every call
            this object creates, or None.
        """
        self._channel = channel
        self._method = method
        self._target = target
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._context = cygrpc.build_census_context()
        self._registered_call_handle = _registered_call_handle

    def __call__(  # pylint: disable=too-many-locals
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _SingleThreadedRendezvous:
        """Starts the RPC and returns a rendezvous over its responses.

        Args:
          request: The request message.
          timeout: The timeout for the RPC, or None.
          metadata: The caller-supplied metadata, or None.
          credentials: The grpc.CallCredentials for the RPC, or None.
          wait_for_ready: The wait-for-ready setting, or None if unset.
          compression: The per-call compression setting, or None.

        Returns:
          A _SingleThreadedRendezvous for the started RPC.

        Raises:
          _InactiveRpcError: If the request could not be serialized.
        """
        deadline = _deadline(timeout)
        serialized_request = _common.serialize(
            request, self._request_serializer
        )
        if serialized_request is None:
            state = _RPCState(
                (),
                (),
                (),
                grpc.StatusCode.INTERNAL,
                "Exception serializing request!",
            )
            raise _InactiveRpcError(state)

        state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
        call_credentials = (
            None if credentials is None else credentials._credentials
        )
        initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
            wait_for_ready
        )
        augmented_metadata = _compression.augment_metadata(
            metadata, compression
        )
        # Three separate batches: everything the client sends, the status,
        # and the initial metadata.
        operations = (
            (
                cygrpc.SendInitialMetadataOperation(
                    augmented_metadata, initial_metadata_flags
                ),
                cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
                cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            ),
            (cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),),
            (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
        )
        operations_and_tags = tuple((ops, None) for ops in operations)
        state.rpc_start_time = time.perf_counter()
        state.method = _common.decode(self._method)
        state.target = _common.decode(self._target)
        call = self._channel.segregated_call(
            cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
            self._method,
            None,
            _determine_deadline(deadline),
            metadata,
            call_credentials,
            operations_and_tags,
            self._context,
            self._registered_call_handle,
        )
        return _SingleThreadedRendezvous(
            state, call, self._response_deserializer, deadline
        )
|
|
|
|
|
|
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """A grpc.UnaryStreamMultiCallable conducted on a managed call.

    Each invocation returns a _MultiThreadedRendezvous wrapping the call
    created by the managed-call factory.
    """

    _channel: cygrpc.Channel
    _managed_call: IntegratedCallFactory
    _method: bytes
    _target: bytes
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _context: Any
    _registered_call_handle: Optional[int]

    __slots__ = [
        "_channel",
        "_managed_call",
        "_method",
        "_target",
        "_request_serializer",
        "_response_deserializer",
        "_context",
        # Annotated above and assigned in __init__, so it gets a slot like
        # every other attribute of this class.
        "_registered_call_handle",
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.Channel,
        managed_call: IntegratedCallFactory,
        method: bytes,
        target: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        _registered_call_handle: Optional[int],
    ):
        """Constructor.

        Args:
          channel: The cygrpc.Channel this multi-callable belongs to.
          managed_call: The factory used to create calls.
          method: The RPC method, as bytes.
          target: The channel target, as bytes.
          request_serializer: The serializer for requests.
          response_deserializer: The deserializer for responses.
          _registered_call_handle: The call handle forwarded to every call
            this object creates, or None.
        """
        self._channel = channel
        self._managed_call = managed_call
        self._method = method
        self._target = target
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._context = cygrpc.build_census_context()
        self._registered_call_handle = _registered_call_handle

    def __call__(  # pylint: disable=too-many-locals
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _MultiThreadedRendezvous:
        """Starts the RPC and returns a rendezvous over its responses.

        Args:
          request: The request message.
          timeout: The timeout for the RPC, or None.
          metadata: The caller-supplied metadata, or None.
          credentials: The grpc.CallCredentials for the RPC, or None.
          wait_for_ready: The wait-for-ready setting, or None if unset.
          compression: The per-call compression setting, or None.

        Returns:
          A _MultiThreadedRendezvous for the started RPC.

        Raises:
          grpc.RpcError: If the request could not be serialized.
        """
        deadline, serialized_request, error = _start_unary_request(
            request, timeout, self._request_serializer
        )
        initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
            wait_for_ready
        )
        if serialized_request is None:
            raise error  # pylint: disable-msg=raising-bad-type
        else:
            augmented_metadata = _compression.augment_metadata(
                metadata, compression
            )
            state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
            # Two batches: everything the client sends plus the status, and
            # the initial metadata on its own.
            operations = (
                (
                    cygrpc.SendInitialMetadataOperation(
                        augmented_metadata, initial_metadata_flags
                    ),
                    cygrpc.SendMessageOperation(
                        serialized_request, _EMPTY_FLAGS
                    ),
                    cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                    cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                ),
                (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
            )
            state.rpc_start_time = time.perf_counter()
            state.method = _common.decode(self._method)
            state.target = _common.decode(self._target)
            call = self._managed_call(
                cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
                self._method,
                None,
                _determine_deadline(deadline),
                metadata,
                None if credentials is None else credentials._credentials,
                operations,
                _event_handler(state, self._response_deserializer),
                self._context,
                self._registered_call_handle,
            )
            return _MultiThreadedRendezvous(
                state, call, self._response_deserializer, deadline
            )
|
|
|
|
|
|
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """A grpc.StreamUnaryMultiCallable backed by a cygrpc.Channel.

    The blocking entry points (__call__ and with_call) conduct the RPC on a
    segregated call and process its events until nothing is due; future()
    uses the managed-call factory and returns right away.
    """

    _channel: cygrpc.Channel
    _managed_call: IntegratedCallFactory
    _method: bytes
    _target: bytes
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _context: Any
    _registered_call_handle: Optional[int]

    __slots__ = [
        "_channel",
        "_managed_call",
        "_method",
        "_target",
        "_request_serializer",
        "_response_deserializer",
        "_context",
        # Annotated above and assigned in __init__, so it gets a slot like
        # every other attribute of this class.
        "_registered_call_handle",
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.Channel,
        managed_call: IntegratedCallFactory,
        method: bytes,
        target: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        _registered_call_handle: Optional[int],
    ):
        """Constructor.

        Args:
          channel: The cygrpc.Channel used for blocking invocations.
          managed_call: The factory used by future() to create calls.
          method: The RPC method, as bytes.
          target: The channel target, as bytes.
          request_serializer: The serializer for requests, or None.
          response_deserializer: The deserializer for responses, or None.
          _registered_call_handle: The call handle forwarded to every call
            this object creates, or None.
        """
        self._channel = channel
        self._managed_call = managed_call
        self._method = method
        self._target = target
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._context = cygrpc.build_census_context()
        self._registered_call_handle = _registered_call_handle

    def _blocking(
        self,
        request_iterator: Iterator,
        timeout: Optional[float],
        metadata: Optional[MetadataType],
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        compression: Optional[grpc.Compression],
    ) -> Tuple[_RPCState, cygrpc.SegregatedCall]:
        """Conducts the RPC on a segregated call until no operation is due.

        Returns:
          The (state, call) pair once state.due is empty.
        """
        deadline = _deadline(timeout)
        state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
        initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
            wait_for_ready
        )
        augmented_metadata = _compression.augment_metadata(
            metadata, compression
        )
        state.rpc_start_time = time.perf_counter()
        state.method = _common.decode(self._method)
        state.target = _common.decode(self._target)
        call = self._channel.segregated_call(
            cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
            self._method,
            None,
            _determine_deadline(deadline),
            augmented_metadata,
            None if credentials is None else credentials._credentials,
            _stream_unary_invocation_operations_and_tags(
                augmented_metadata, initial_metadata_flags
            ),
            self._context,
            self._registered_call_handle,
        )
        _consume_request_iterator(
            request_iterator, state, call, self._request_serializer, None
        )
        while True:
            event = call.next_event()
            with state.condition:
                _handle_event(event, state, self._response_deserializer)
                # Wake any thread waiting on the state's condition.
                state.condition.notify_all()
                if not state.due:
                    break
        return state, call

    def __call__(
        self,
        request_iterator: Iterator,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Synchronously invokes the RPC and returns its response.

        Args:
          request_iterator: An iterator yielding the request messages.
          timeout: The timeout for the RPC, or None.
          metadata: The caller-supplied metadata, or None.
          credentials: The grpc.CallCredentials for the RPC, or None.
          wait_for_ready: The wait-for-ready setting, or None if unset.
          compression: The per-call compression setting, or None.

        Returns:
          The response of the RPC.

        Raises:
          grpc.RpcError: If the RPC terminated with a non-OK status.
        """
        (
            state,
            call,
        ) = self._blocking(
            request_iterator,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )
        return _end_unary_response_blocking(state, call, False, None)

    def with_call(
        self,
        request_iterator: Iterator,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Synchronously invokes the RPC and returns response and call.

        Takes the same arguments as __call__.

        Returns:
          A (response, grpc.Call) pair.

        Raises:
          grpc.RpcError: If the RPC terminated with a non-OK status.
        """
        (
            state,
            call,
        ) = self._blocking(
            request_iterator,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )
        return _end_unary_response_blocking(state, call, True, None)

    def future(
        self,
        request_iterator: Iterator,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _MultiThreadedRendezvous:
        """Starts the RPC on a managed call without waiting for it.

        Takes the same arguments as __call__.

        Returns:
          A _MultiThreadedRendezvous for the started RPC.
        """
        deadline = _deadline(timeout)
        state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
        event_handler = _event_handler(state, self._response_deserializer)
        initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
            wait_for_ready
        )
        augmented_metadata = _compression.augment_metadata(
            metadata, compression
        )
        state.rpc_start_time = time.perf_counter()
        state.method = _common.decode(self._method)
        state.target = _common.decode(self._target)
        call = self._managed_call(
            cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
            self._method,
            None,
            deadline,
            augmented_metadata,
            None if credentials is None else credentials._credentials,
            # Build the send-initial-metadata operation from the augmented
            # metadata, exactly as _blocking does; using the plain metadata
            # here would discard whatever augment_metadata added for the
            # `compression` argument.
            _stream_unary_invocation_operations(
                augmented_metadata, initial_metadata_flags
            ),
            event_handler,
            self._context,
            self._registered_call_handle,
        )
        _consume_request_iterator(
            request_iterator,
            state,
            call,
            self._request_serializer,
            event_handler,
        )
        return _MultiThreadedRendezvous(
            state, call, self._response_deserializer, deadline
        )
|
|
|
|
|
|
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """A grpc.StreamStreamMultiCallable conducted on a managed call.

    Each invocation returns a _MultiThreadedRendezvous wrapping the call
    created by the managed-call factory.
    """

    _channel: cygrpc.Channel
    _managed_call: IntegratedCallFactory
    _method: bytes
    _target: bytes
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _context: Any
    _registered_call_handle: Optional[int]

    __slots__ = [
        "_channel",
        "_managed_call",
        "_method",
        "_target",
        "_request_serializer",
        "_response_deserializer",
        "_context",
        # Annotated above and assigned in __init__, so it gets a slot like
        # every other attribute of this class.
        "_registered_call_handle",
    ]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.Channel,
        managed_call: IntegratedCallFactory,
        method: bytes,
        target: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        _registered_call_handle: Optional[int],
    ):
        """Constructor.

        Args:
          channel: The cygrpc.Channel this multi-callable belongs to.
          managed_call: The factory used to create calls.
          method: The RPC method, as bytes.
          target: The channel target, as bytes.
          request_serializer: The serializer for requests, or None.
          response_deserializer: The deserializer for responses, or None.
          _registered_call_handle: The call handle forwarded to every call
            this object creates, or None.
        """
        self._channel = channel
        self._managed_call = managed_call
        self._method = method
        self._target = target
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._context = cygrpc.build_census_context()
        self._registered_call_handle = _registered_call_handle

    def __call__(
        self,
        request_iterator: Iterator,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _MultiThreadedRendezvous:
        """Starts the RPC and returns a rendezvous over its responses.

        Args:
          request_iterator: An iterator yielding the request messages.
          timeout: The timeout for the RPC, or None.
          metadata: The caller-supplied metadata, or None.
          credentials: The grpc.CallCredentials for the RPC, or None.
          wait_for_ready: The wait-for-ready setting, or None if unset.
          compression: The per-call compression setting, or None.

        Returns:
          A _MultiThreadedRendezvous for the started RPC.
        """
        deadline = _deadline(timeout)
        state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None)
        initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
            wait_for_ready
        )
        augmented_metadata = _compression.augment_metadata(
            metadata, compression
        )
        # Two batches: initial metadata out plus status in, and initial
        # metadata in. Messages are handled outside these batches.
        operations = (
            (
                cygrpc.SendInitialMetadataOperation(
                    augmented_metadata, initial_metadata_flags
                ),
                cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
            ),
            (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
        )
        event_handler = _event_handler(state, self._response_deserializer)
        state.rpc_start_time = time.perf_counter()
        state.method = _common.decode(self._method)
        state.target = _common.decode(self._target)
        call = self._managed_call(
            cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
            self._method,
            None,
            _determine_deadline(deadline),
            augmented_metadata,
            None if credentials is None else credentials._credentials,
            operations,
            event_handler,
            self._context,
            self._registered_call_handle,
        )
        _consume_request_iterator(
            request_iterator,
            state,
            call,
            self._request_serializer,
            event_handler,
        )
        return _MultiThreadedRendezvous(
            state, call, self._response_deserializer, deadline
        )
|
|
|
|
|
|
class _InitialMetadataFlags(int):
    """An immutable integer holding initial metadata flags."""

    def __new__(cls, value: int = _EMPTY_FLAGS):
        # Keep only the bits covered by the used mask.
        masked_value = value & cygrpc.InitialMetadataFlags.used_mask
        return super().__new__(cls, masked_value)

    def with_wait_for_ready(self, wait_for_ready: Optional[bool]) -> int:
        """Returns flags updated with an explicit wait-for-ready choice.

        Args:
          wait_for_ready: True or False to set or clear the wait-for-ready
            bit, or None to leave the flags untouched.

        Returns:
          This object when wait_for_ready is None; otherwise a new instance
            with the wait-for-ready bit set or cleared and the
            explicitly-set bit turned on.
        """
        if wait_for_ready is None:
            return self
        flag_bits = cygrpc.InitialMetadataFlags
        if wait_for_ready:
            updated = (
                self
                | flag_bits.wait_for_ready
                | flag_bits.wait_for_ready_explicitly_set
            )
        else:
            updated = (
                self & ~flag_bits.wait_for_ready
            ) | flag_bits.wait_for_ready_explicitly_set
        return self.__class__(updated)
|
|
|
|
|
|
class _ChannelCallState:
    """Bookkeeping for the managed calls of one cygrpc.Channel."""

    channel: cygrpc.Channel
    managed_calls: int
    threading: bool

    def __init__(self, channel: cygrpc.Channel):
        self.channel = channel
        self.lock = threading.Lock()
        self.threading = False
        self.reset_postfork_child()

    def reset_postfork_child(self) -> None:
        """Resets the count of managed calls to zero."""
        self.managed_calls = 0

    def __del__(self):
        # Closing is best effort: TypeError and AttributeError are ignored
        # (presumably raised while the interpreter tears down - confirm).
        try:
            self.channel.close(
                cygrpc.StatusCode.cancelled, "Channel deallocated!"
            )
        except (TypeError, AttributeError):
            pass
|
|
|
|
|
|
def _run_channel_spin_thread(state: _ChannelCallState) -> None:
    """Starts a daemon thread that processes the channel's call events.

    The thread exits once the count of managed calls drops to zero.

    Args:
      state: The _ChannelCallState whose channel is polled.
    """

    def channel_spin():
        while True:
            cygrpc.block_if_fork_in_progress(state)
            event = state.channel.next_call_event()
            timed_out = (
                event.completion_type == cygrpc.CompletionType.queue_timeout
            )
            # The tag is only invoked for real events; a false result means
            # the call it belongs to has not completed yet.
            if timed_out or not event.tag(event):
                continue
            with state.lock:
                state.managed_calls -= 1
                if state.managed_calls == 0:
                    return

    spin_thread = cygrpc.ForkManagedThread(target=channel_spin)
    spin_thread.setDaemon(True)
    spin_thread.start()
|
|
|
|
|
|
def _channel_managed_call_management(state: _ChannelCallState):
    """Returns a factory of integrated calls that tracks them in `state`."""

    # pylint: disable=too-many-arguments
    def create(
        flags: int,
        method: bytes,
        host: Optional[str],
        deadline: Optional[float],
        metadata: Optional[MetadataType],
        credentials: Optional[cygrpc.CallCredentials],
        operations: Sequence[Sequence[cygrpc.Operation]],
        event_handler: UserTag,
        context: Any,
        _registered_call_handle: Optional[int],
    ) -> cygrpc.IntegratedCall:
        """Creates a cygrpc.IntegratedCall.

        Args:
          flags: An integer bitfield of call flags.
          method: The RPC method.
          host: A host string for the created call.
          deadline: A float to be the deadline of the created call or None if
            the call is to have an infinite deadline.
          metadata: The metadata for the call or None.
          credentials: A cygrpc.CallCredentials or None.
          operations: A sequence of sequences of cygrpc.Operations to be
            started on the call.
          event_handler: A behavior to call to handle the events resultant from
            the operations on the call.
          context: Context object for distributed tracing.
          _registered_call_handle: An int representing the call handle of the
            method, or None if the method is not registered.

        Returns:
          A cygrpc.IntegratedCall with which to conduct an RPC.
        """
        # Every batch is tagged with the same event handler.
        tagged_batches = tuple((batch, event_handler) for batch in operations)
        with state.lock:
            call = state.channel.integrated_call(
                flags,
                method,
                host,
                deadline,
                metadata,
                credentials,
                tagged_batches,
                context,
                _registered_call_handle,
            )
            state.managed_calls += 1
            # The first managed call is the one that starts the spin thread.
            if state.managed_calls == 1:
                _run_channel_spin_thread(state)
            return call

    return create
|
|
|
|
|
|
class _ChannelConnectivityState:
    """Connectivity-subscription bookkeeping for one channel."""

    lock: threading.RLock
    channel: grpc.Channel
    polling: bool
    connectivity: grpc.ChannelConnectivity
    try_to_connect: bool
    # TODO(xuanwn): Refactor this: https://github.com/grpc/grpc/issues/31704
    callbacks_and_connectivities: List[
        Sequence[
            Union[
                Callable[[grpc.ChannelConnectivity], None],
                Optional[grpc.ChannelConnectivity],
            ]
        ]
    ]
    delivering: bool

    def __init__(self, channel: grpc.Channel):
        self.lock = threading.RLock()
        self.channel = channel
        # A new instance starts from the same values a reset restores.
        self.reset_postfork_child()

    def reset_postfork_child(self) -> None:
        """Restores every field except the lock and channel to its initial value."""
        self.polling = False
        self.connectivity = None
        self.try_to_connect = False
        self.callbacks_and_connectivities = []
        self.delivering = False
|
|
|
|
|
|
def _deliveries(
    state: _ChannelConnectivityState,
) -> List[Callable[[grpc.ChannelConnectivity], None]]:
    """Collects the callbacks that have not seen the current connectivity.

    Each collected subscription is marked as having seen state.connectivity.

    Args:
      state: The connectivity state holding the subscriptions.

    Returns:
      The callbacks whose recorded connectivity differed from
        state.connectivity, in subscription order.
    """
    current = state.connectivity
    stale_subscriptions = [
        subscription
        for subscription in state.callbacks_and_connectivities
        if subscription[1] is not current
    ]
    for subscription in stale_subscriptions:
        subscription[1] = current
    return [subscription[0] for subscription in stale_subscriptions]
|
|
|
|
|
|
def _deliver(
    state: _ChannelConnectivityState,
    initial_connectivity: grpc.ChannelConnectivity,
    initial_callbacks: Sequence[Callable[[grpc.ChannelConnectivity], None]],
) -> None:
    """Invokes callbacks until none is behind the current connectivity.

    Exceptions raised by a callback are logged and do not stop delivery.

    Args:
      state: The connectivity state holding the subscriptions.
      initial_connectivity: The connectivity passed to the first round.
      initial_callbacks: The callbacks invoked in the first round.
    """
    pending_callbacks = initial_callbacks
    reported_connectivity = initial_connectivity
    while True:
        for pending_callback in pending_callbacks:
            cygrpc.block_if_fork_in_progress(state)
            try:
                pending_callback(reported_connectivity)
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception(
                    _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE
                )
        with state.lock:
            pending_callbacks = _deliveries(state)
            if not pending_callbacks:
                # Everyone is up to date; this delivery run is over.
                state.delivering = False
                return
            reported_connectivity = state.connectivity
|
|
|
|
|
|
def _spawn_delivery(
    state: _ChannelConnectivityState,
    callbacks: Sequence[Callable[[grpc.ChannelConnectivity], None]],
) -> None:
    """Starts a daemon thread running _deliver and marks state as delivering.

    Args:
      state: The connectivity state holding the subscriptions.
      callbacks: The callbacks to invoke with state.connectivity.
    """
    delivery_args = (state, state.connectivity, callbacks)
    delivery_thread = cygrpc.ForkManagedThread(
        target=_deliver, args=delivery_args
    )
    delivery_thread.setDaemon(True)
    delivery_thread.start()
    state.delivering = True
|
|
|
|
|
|
# NOTE(https://github.com/grpc/grpc/issues/3064): We'd rather not poll.
|
|
def _poll_connectivity(
    state: _ChannelConnectivityState,
    channel: grpc.Channel,
    initial_try_to_connect: bool,
) -> None:
    """Polls the channel's connectivity and dispatches subscriber callbacks.

    Returns once there are no subscriptions left and no connection attempt
    has been requested.

    Args:
      state: The connectivity state holding the subscriptions.
      channel: The channel whose connectivity is checked and watched.
      initial_try_to_connect: Passed to the first connectivity check.
    """
    to_channel_connectivity = (
        _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY
    )
    connectivity = channel.check_connectivity_state(initial_try_to_connect)
    with state.lock:
        state.connectivity = to_channel_connectivity[connectivity]
        # Every current subscriber is told about the first observation.
        subscribers = []
        for subscription in state.callbacks_and_connectivities:
            subscribers.append(subscription[0])
            subscription[1] = state.connectivity
        if subscribers:
            _spawn_delivery(state, tuple(subscribers))
    while True:
        event = channel.watch_connectivity_state(
            connectivity, time.time() + 0.2
        )
        cygrpc.block_if_fork_in_progress(state)
        with state.lock:
            if not (
                state.callbacks_and_connectivities or state.try_to_connect
            ):
                state.polling = False
                state.connectivity = None
                return
            # Consume the pending connection request, if any.
            try_to_connect = state.try_to_connect
            state.try_to_connect = False
        if not (event.success or try_to_connect):
            continue
        connectivity = channel.check_connectivity_state(try_to_connect)
        with state.lock:
            state.connectivity = to_channel_connectivity[connectivity]
            if state.delivering:
                # The running delivery picks up the new value by itself.
                continue
            stale_callbacks = _deliveries(state)
            if stale_callbacks:
                _spawn_delivery(state, stale_callbacks)
|
|
|
|
|
|
def _subscribe(
    state: _ChannelConnectivityState,
    callback: Callable[[grpc.ChannelConnectivity], None],
    try_to_connect: bool,
) -> None:
    """Registers a connectivity callback, starting the poller if needed.

    Args:
      state: The connectivity state holding the subscriptions.
      callback: The callback to register.
      try_to_connect: Whether a connection attempt is requested.
    """
    wants_connection = bool(try_to_connect)
    with state.lock:
        if not (state.callbacks_and_connectivities or state.polling):
            # First subscriber and nobody polling: start the polling thread.
            polling_thread = cygrpc.ForkManagedThread(
                target=_poll_connectivity,
                args=(state, state.channel, wants_connection),
            )
            polling_thread.setDaemon(True)
            polling_thread.start()
            state.polling = True
            last_delivered = None
        else:
            if not state.delivering and state.connectivity is not None:
                # A connectivity is already known and no delivery is running,
                # so hand it to the new callback right away.
                _spawn_delivery(state, (callback,))
                last_delivered = state.connectivity
            else:
                last_delivered = None
            state.try_to_connect |= wants_connection
        state.callbacks_and_connectivities.append([callback, last_delivered])
|
|
|
|
|
|
def _unsubscribe(
    state: _ChannelConnectivityState,
    callback: Callable[[grpc.ChannelConnectivity], None],
) -> None:
    """Removes the first subscription whose callback equals `callback`.

    Does nothing when no subscription matches.

    Args:
      state: The connectivity state holding the subscriptions.
      callback: The callback to remove.
    """
    with state.lock:
        subscriptions = state.callbacks_and_connectivities
        for position, subscription in enumerate(subscriptions):
            if callback == subscription[0]:
                del subscriptions[position]
                break
|
|
|
|
|
|
def _augment_options(
    base_options: Sequence[ChannelArgumentType],
    compression: Optional[grpc.Compression],
) -> Sequence[ChannelArgumentType]:
    """Appends the compression and user-agent options to the base options.

    Args:
      base_options: The channel options to start from.
      compression: The compression setting handed to
        _compression.create_channel_option, or None.

    Returns:
      A tuple of the base options, followed by the compression option(s),
        followed by the primary user agent option.
    """
    compression_option = _compression.create_channel_option(compression)
    augmented = tuple(base_options) + compression_option
    user_agent_option = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return augmented + (user_agent_option,)
|
|
|
|
|
|
def _separate_channel_options(
    options: Sequence[ChannelArgumentType],
) -> Tuple[Sequence[ChannelArgumentType], Sequence[ChannelArgumentType]]:
    """Separates core channel options from Python channel options.

    Args:
      options: The (key, value) channel options to split.

    Returns:
      A (python_options, core_options) pair of lists; only options keyed by
        SingleThreadedUnaryStream count as Python options.
    """
    python_only_key = grpc.experimental.ChannelOptions.SingleThreadedUnaryStream
    python_options = []
    core_options = []
    for option in options:
        destination = (
            python_options if option[0] == python_only_key else core_options
        )
        destination.append(option)
    return python_options, core_options
|
|
|
|
|
|
class Channel(grpc.Channel):
    """A cygrpc.Channel-backed implementation of grpc.Channel."""

    _single_threaded_unary_stream: bool
    _channel: cygrpc.Channel
    _call_state: _ChannelCallState
    _connectivity_state: _ChannelConnectivityState
    _target: str
    _registered_call_handles: Dict[str, int]

    def __init__(
        self,
        target: str,
        options: Sequence[ChannelArgumentType],
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to
            be used over the lifetime of the channel.
        """
        python_options, core_options = _separate_channel_options(options)

        # Start from the module default; Python-only options may override it.
        self._single_threaded_unary_stream = (
            _DEFAULT_SINGLE_THREADED_UNARY_STREAM
        )
        self._process_python_options(python_options)

        encoded_target = _common.encode(target)
        channel_arguments = _augment_options(core_options, compression)
        self._channel = cygrpc.Channel(
            encoded_target, channel_arguments, credentials
        )
        self._target = target
        self._call_state = _ChannelCallState(self._channel)
        self._connectivity_state = _ChannelConnectivityState(self._channel)

        cygrpc.fork_register_channel(self)
        if cygrpc.g_gevent_activated:
            cygrpc.gevent_increment_channel_count()

    def _get_registered_call_handle(self, method: str) -> int:
        """
        Get the registered call handle for a method.

        This is a semi-private method. It is intended for use only by gRPC
        generated code.

        This method is not thread-safe.

        Args:
          method: Required, the method name for the RPC.

        Returns:
          The registered call handle pointer in the form of a Python Long.
        """
        encoded_method = _common.encode(method)
        return self._channel.get_registered_call_handle(encoded_method)

    def _process_python_options(
        self, python_options: Sequence[ChannelArgumentType]
    ) -> None:
        """Sets channel attributes according to python-only channel options.

        The option's value is not inspected: the mere presence of the
        SingleThreadedUnaryStream key turns the single-threaded path on.

        Args:
          python_options: The Python-only (key, value) channel options.
        """
        single_threaded_key = (
            grpc.experimental.ChannelOptions.SingleThreadedUnaryStream
        )
        if any(option[0] == single_threaded_key for option in python_options):
            self._single_threaded_unary_stream = True

    def subscribe(
        self,
        callback: Callable[[grpc.ChannelConnectivity], None],
        try_to_connect: Optional[bool] = None,
    ) -> None:
        """Subscribes ``callback`` on this channel's connectivity state.

        Args:
          callback: The callable to register.
          try_to_connect: Forwarded unchanged to ``_subscribe``.
        """
        _subscribe(self._connectivity_state, callback, try_to_connect)

    def unsubscribe(
        self, callback: Callable[[grpc.ChannelConnectivity], None]
    ) -> None:
        """Unsubscribes ``callback`` from this channel's connectivity state.

        Args:
          callback: A callable previously passed to ``subscribe``.
        """
        _unsubscribe(self._connectivity_state, callback)

    # pylint: disable=arguments-differ
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryUnaryMultiCallable:
        """Creates a unary-unary multi-callable for ``method``.

        Args:
          method: The name of the RPC method.
          request_serializer: Optional serializer for the request message.
          response_deserializer: Optional deserializer for the response
            message.
          _registered_method: When truthy, a registered call handle is looked
            up for ``method`` and passed to the multi-callable.

        Returns:
          A _UnaryUnaryMultiCallable bound to this channel.
        """
        call_handle = (
            self._get_registered_call_handle(method)
            if _registered_method
            else None
        )
        managed_call = _channel_managed_call_management(self._call_state)
        return _UnaryUnaryMultiCallable(
            self._channel,
            managed_call,
            _common.encode(method),
            _common.encode(self._target),
            request_serializer,
            response_deserializer,
            call_handle,
        )

    # pylint: disable=arguments-differ
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryStreamMultiCallable:
        """Creates a unary-stream multi-callable for ``method``.

        Args:
          method: The name of the RPC method.
          request_serializer: Optional serializer for the request message.
          response_deserializer: Optional deserializer for the response
            messages.
          _registered_method: When truthy, a registered call handle is looked
            up for ``method`` and passed to the multi-callable.

        Returns:
          A _SingleThreadedUnaryStreamMultiCallable when the channel was
          configured for single-threaded unary-stream RPCs, otherwise a
          _UnaryStreamMultiCallable.
        """
        call_handle = (
            self._get_registered_call_handle(method)
            if _registered_method
            else None
        )
        # NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC
        # on a single Python thread results in an appreciable speed-up. However,
        # due to slight differences in capability, the multi-threaded variant
        # remains the default.
        if self._single_threaded_unary_stream:
            return _SingleThreadedUnaryStreamMultiCallable(
                self._channel,
                _common.encode(method),
                _common.encode(self._target),
                request_serializer,
                response_deserializer,
                call_handle,
            )
        managed_call = _channel_managed_call_management(self._call_state)
        return _UnaryStreamMultiCallable(
            self._channel,
            managed_call,
            _common.encode(method),
            _common.encode(self._target),
            request_serializer,
            response_deserializer,
            call_handle,
        )

    # pylint: disable=arguments-differ
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamUnaryMultiCallable:
        """Creates a stream-unary multi-callable for ``method``.

        Args:
          method: The name of the RPC method.
          request_serializer: Optional serializer for the request messages.
          response_deserializer: Optional deserializer for the response
            message.
          _registered_method: When truthy, a registered call handle is looked
            up for ``method`` and passed to the multi-callable.

        Returns:
          A _StreamUnaryMultiCallable bound to this channel.
        """
        call_handle = (
            self._get_registered_call_handle(method)
            if _registered_method
            else None
        )
        managed_call = _channel_managed_call_management(self._call_state)
        return _StreamUnaryMultiCallable(
            self._channel,
            managed_call,
            _common.encode(method),
            _common.encode(self._target),
            request_serializer,
            response_deserializer,
            call_handle,
        )

    # pylint: disable=arguments-differ
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamStreamMultiCallable:
        """Creates a stream-stream multi-callable for ``method``.

        Args:
          method: The name of the RPC method.
          request_serializer: Optional serializer for the request messages.
          response_deserializer: Optional deserializer for the response
            messages.
          _registered_method: When truthy, a registered call handle is looked
            up for ``method`` and passed to the multi-callable.

        Returns:
          A _StreamStreamMultiCallable bound to this channel.
        """
        call_handle = (
            self._get_registered_call_handle(method)
            if _registered_method
            else None
        )
        managed_call = _channel_managed_call_management(self._call_state)
        return _StreamStreamMultiCallable(
            self._channel,
            managed_call,
            _common.encode(method),
            _common.encode(self._target),
            request_serializer,
            response_deserializer,
            call_handle,
        )

    def _unsubscribe_all(self) -> None:
        """Drops every connectivity subscription held by this channel."""
        state = self._connectivity_state
        if not state:
            return
        with state.lock:
            state.callbacks_and_connectivities.clear()

    def _close(self) -> None:
        """Closes the underlying channel and undoes constructor registrations.

        Subscriptions are dropped, the cygrpc channel is closed with a
        cancelled status, the channel is unregistered from fork handling and,
        when gevent is activated, the gevent channel count is decremented.
        """
        self._unsubscribe_all()
        self._channel.close(cygrpc.StatusCode.cancelled, "Channel closed!")
        cygrpc.fork_unregister_channel(self)
        if cygrpc.g_gevent_activated:
            cygrpc.gevent_decrement_channel_count()

    def _close_on_fork(self) -> None:
        """Drops subscriptions and closes the underlying channel for a fork."""
        self._unsubscribe_all()
        self._channel.close_on_fork(
            cygrpc.StatusCode.cancelled, "Channel closed due to fork"
        )

    def __enter__(self):
        """Returns this channel for use as a context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Closes the channel; returns False so exceptions propagate."""
        self._close()
        return False

    def close(self) -> None:
        """Closes this channel."""
        self._close()

    def __del__(self):
        # TODO(https://github.com/grpc/grpc/issues/12531): Several releases
        # after 1.12 (1.16 or thereabouts?) add a "self._channel.close" call
        # here (or more likely, call self._close() here). We don't do this today
        # because many valid use cases today allow the channel to be deleted
        # immediately after stubs are created. After a sufficient period of time
        # has passed for all users to be trusted to hold on to their channels
        # for as long as they are in use and to close them after using them,
        # then deletion of this grpc._channel.Channel instance can be made to
        # effect closure of the underlying cygrpc.Channel instance.
        try:
            self._unsubscribe_all()
        except:  # pylint: disable=bare-except
            # Exceptions in __del__ are ignored by Python anyway, but they can
            # keep spamming logs.  Just silence them.
            pass
|