493 lines
20 KiB
Python
493 lines
20 KiB
Python
# Copyright 2019 gRPC authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Invocation-side implementation of gRPC Asyncio Python."""
|
|
|
|
import asyncio
|
|
import sys
|
|
from typing import Any, Iterable, List, Optional, Sequence
|
|
|
|
import grpc
|
|
from grpc import _common
|
|
from grpc import _compression
|
|
from grpc import _grpcio_metadata
|
|
from grpc._cython import cygrpc
|
|
|
|
from . import _base_call
|
|
from . import _base_channel
|
|
from ._call import StreamStreamCall
|
|
from ._call import StreamUnaryCall
|
|
from ._call import UnaryStreamCall
|
|
from ._call import UnaryUnaryCall
|
|
from ._interceptor import ClientInterceptor
|
|
from ._interceptor import InterceptedStreamStreamCall
|
|
from ._interceptor import InterceptedStreamUnaryCall
|
|
from ._interceptor import InterceptedUnaryStreamCall
|
|
from ._interceptor import InterceptedUnaryUnaryCall
|
|
from ._interceptor import StreamStreamClientInterceptor
|
|
from ._interceptor import StreamUnaryClientInterceptor
|
|
from ._interceptor import UnaryStreamClientInterceptor
|
|
from ._interceptor import UnaryUnaryClientInterceptor
|
|
from ._metadata import Metadata
|
|
from ._typing import ChannelArgumentType
|
|
from ._typing import DeserializingFunction
|
|
from ._typing import RequestIterableType
|
|
from ._typing import SerializingFunction
|
|
from ._utils import _timeout_to_deadline
|
|
|
|
# Primary user-agent string attached to every channel created by this module
# (see `_augment_channel_arguments`).
_USER_AGENT = f'grpc-python-asyncio/{_grpcio_metadata.__version__}'
|
|
|
|
# `asyncio.all_tasks()` only exists from Python 3.7 on; older interpreters
# expose the equivalent as the classmethod `asyncio.Task.all_tasks()`.
# Compare the full version tuple rather than just the minor number so the
# check stays correct regardless of the major version.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the tasks of the current event loop (pre-3.7 API)."""
        return asyncio.Task.all_tasks()

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the not-yet-finished tasks of the running event loop."""
        return asyncio.all_tasks()
|
|
|
|
|
|
def _augment_channel_arguments(base_options: ChannelArgumentType,
                               compression: Optional[grpc.Compression]):
    """Extend user channel options with the module's mandatory arguments.

    Args:
      base_options: The caller-supplied channel arguments (any iterable of
        key/value pairs; it is materialised with ``tuple()``).
      compression: Channel-wide compression choice, translated into channel
        options by ``_compression.create_channel_option``.

    Returns:
      A tuple made of ``base_options``, then the compression option(s), then
      the primary user-agent argument carrying ``_USER_AGENT``.
    """
    extra_compression = _compression.create_channel_option(compression)
    extra_user_agent = (
        (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),
    )

    combined = tuple(base_options)
    combined += extra_compression
    combined += extra_user_agent
    return combined
|
|
|
|
|
|
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    # The Channel factory methods forward serializers that default to None.
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _interceptors: Optional[Sequence[ClientInterceptor]]
    # Objects kept alive for as long as this multi-callable lives; the
    # Channel factory methods pass ``[self]`` here.
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store everything a subclass needs to start calls.

        Args:
          channel: The Cython channel calls are issued on.
          method: The fully-qualified method name, already encoded to bytes.
          request_serializer: Optional request serializer, or None.
          response_deserializer: Optional response deserializer, or None.
          interceptors: Interceptors for this RPC shape; an empty or None
            value means calls are created without interception.
          references: Objects to hold strong references to.
          loop: The event loop calls are bound to.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
            metadata: Optional[Metadata] = None,
            compression: Optional[grpc.Compression] = None) -> Metadata:
        """Build the final metadata to be sent with a single call.

        Args:
          metadata: Caller-supplied metadata; a falsy value (None or empty)
            is replaced by a fresh, empty ``Metadata``.
          compression: Optional per-call compression; when truthy it is merged
            into the metadata via ``_compression.augment_metadata`` and a new
            ``Metadata`` object is built from the result.

        Returns:
          The metadata to use for the call. Without compression, the
          (possibly defaulted) input object is returned as-is.
        """
        metadata = metadata or Metadata()
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression))
        return metadata
|
|
|
|
|
|
class UnaryUnaryMultiCallable(_BaseMultiCallable,
                              _base_channel.UnaryUnaryMultiCallable):
    """Starts single-request / single-response RPCs for one method."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryUnaryCall:
        """Start a unary-unary RPC.

        Args:
          request: The request message.
          timeout: Optional relative timeout, in seconds.
          metadata: Optional metadata to send.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional per-call compression.

        Returns:
          An ``InterceptedUnaryUnaryCall`` when interceptors are configured,
          otherwise a plain ``UnaryUnaryCall``.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            # The intercepted call receives the caller's relative timeout.
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, call_metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        # The plain call receives the value produced by
        # `_timeout_to_deadline` instead of the raw timeout.
        return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
                              call_metadata, credentials, wait_for_ready,
                              self._channel, self._method,
                              self._request_serializer,
                              self._response_deserializer, self._loop)
|
|
|
|
|
|
class UnaryStreamMultiCallable(_BaseMultiCallable,
                               _base_channel.UnaryStreamMultiCallable):
    """Starts single-request / streaming-response RPCs for one method."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryStreamCall:
        """Start a unary-stream RPC.

        Args:
          request: The request message.
          timeout: Optional relative timeout, in seconds.
          metadata: Optional metadata to send.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional per-call compression.

        Returns:
          An ``InterceptedUnaryStreamCall`` when interceptors are configured,
          otherwise a plain ``UnaryStreamCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = UnaryStreamCall(request, _timeout_to_deadline(timeout),
                                   metadata, credentials, wait_for_ready,
                                   self._channel, self._method,
                                   self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            # Pass the caller's relative `timeout`, exactly as
            # UnaryUnaryMultiCallable does for its intercepted call. The
            # Intercepted*Call constructors share this positional layout.
            # Handing them the output of `_timeout_to_deadline` would expose
            # an already-converted value to interceptors as if it were a
            # timeout.
            call = InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, metadata, credentials,
                wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
|
|
|
|
|
|
class StreamUnaryMultiCallable(_BaseMultiCallable,
                               _base_channel.StreamUnaryMultiCallable):
    """Starts streaming-request / single-response RPCs for one method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamUnaryCall:
        """Start a stream-unary RPC.

        Args:
          request_iterator: Optional source of request messages; None is
            forwarded unchanged to the call object.
          timeout: Optional relative timeout, in seconds.
          metadata: Optional metadata to send.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional per-call compression.

        Returns:
          An ``InterceptedStreamUnaryCall`` when interceptors are configured,
          otherwise a plain ``StreamUnaryCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = StreamUnaryCall(request_iterator,
                                   _timeout_to_deadline(timeout), metadata,
                                   credentials, wait_for_ready, self._channel,
                                   self._method, self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            # Pass the caller's relative `timeout`, exactly as
            # UnaryUnaryMultiCallable does for its intercepted call. Handing
            # the output of `_timeout_to_deadline` here would expose an
            # already-converted value to interceptors as if it were a timeout.
            call = InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
|
|
|
|
|
|
class StreamStreamMultiCallable(_BaseMultiCallable,
                                _base_channel.StreamStreamMultiCallable):
    """Starts streaming-request / streaming-response RPCs for one method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamStreamCall:
        """Start a stream-stream RPC.

        Args:
          request_iterator: Optional source of request messages; None is
            forwarded unchanged to the call object.
          timeout: Optional relative timeout, in seconds.
          metadata: Optional metadata to send.
          credentials: Optional per-call credentials.
          wait_for_ready: Optional wait-for-ready flag.
          compression: Optional per-call compression.

        Returns:
          An ``InterceptedStreamStreamCall`` when interceptors are configured,
          otherwise a plain ``StreamStreamCall``.
        """
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            call = StreamStreamCall(request_iterator,
                                    _timeout_to_deadline(timeout), metadata,
                                    credentials, wait_for_ready, self._channel,
                                    self._method, self._request_serializer,
                                    self._response_deserializer, self._loop)
        else:
            # Pass the caller's relative `timeout`, exactly as
            # UnaryUnaryMultiCallable does for its intercepted call. Handing
            # the output of `_timeout_to_deadline` here would expose an
            # already-converted value to interceptors as if it were a timeout.
            call = InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call
|
|
|
|
|
|
class Channel(_base_channel.Channel):
    """Asyncio gRPC channel backed by a Cython ``AioChannel``.

    Creates multi-callables for the four RPC shapes, each wired with the
    interceptors of the matching type, and cancels this channel's in-flight
    calls when it is closed.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(self, target: str, options: ChannelArgumentType,
                 credentials: Optional[grpc.ChannelCredentials],
                 compression: Optional[grpc.Compression],
                 interceptors: Optional[Sequence[ClientInterceptor]]):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            client interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # The checks run in this order, so an object implementing
                # several interceptor interfaces is only registered under the
                # first one that matches.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor) +
                        "{} or ".format(UnaryUnaryClientInterceptor.__name__) +
                        "{} or ".format(UnaryStreamClientInterceptor.__name__) +
                        "{} or ".format(StreamUnaryClientInterceptor.__name__) +
                        "{}. ".format(StreamStreamClientInterceptor.__name__))

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression), credentials,
            self._loop)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving an `async with` block closes immediately, with no grace.
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Close the channel, optionally waiting ``grace`` seconds for calls.

        Args:
          grace: Seconds to wait for this channel's running call tasks before
            cancelling them; a falsy value skips the wait.

        Raises:
          cygrpc.InternalError: If a running task belongs to a ``Call`` object
            that exposes neither ``_channel`` nor ``_cython_call``.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if 'frame' in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`: the innermost frame of such a
            # task is a method whose `self` is the Call object.
            frame = stack[0]
            candidate = frame.f_locals.get('self')
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, '_channel'):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, '_cython_call'):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f'Unrecognized call object: {candidate}')

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel, cancelling calls still running after ``grace``.

        Args:
          grace: Optional number of seconds to wait for this channel's running
            calls before they are cancelled.
        """
        await self._close(grace)

    def __del__(self):
        # `__init__` may have raised (e.g. on a bad interceptor) before
        # `_channel` was assigned, hence the hasattr guard.
        if hasattr(self, '_channel'):
            if not self._channel.closed():
                self._channel.close()

    def get_state(self,
                  try_to_connect: bool = False) -> grpc.ChannelConnectivity:
        """Return the channel's current connectivity state.

        Args:
          try_to_connect: Forwarded to the Cython channel's connectivity check.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait until the connectivity state differs from the one given.

        Args:
          last_observed_state: The state the caller last observed; its
            ``value[0]`` is handed to the Cython watcher.
        """
        # The await must not live inside the `assert`: `python -O` strips
        # assert statements entirely, which would skip the watch. This method
        # would then return without ever suspending, and `channel_ready` would
        # spin forever without yielding to the event loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None)
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until the channel reports READY, asking it to connect."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # Each factory passes `[self]` as the multi-callable's references,
    # presumably so the Channel (whose __del__ closes the Cython channel)
    # stays alive while its multi-callables are in use.

    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryUnaryMultiCallable:
        return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
                                       request_serializer,
                                       response_deserializer,
                                       self._unary_unary_interceptors, [self],
                                       self._loop)

    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryStreamMultiCallable:
        return UnaryStreamMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._unary_stream_interceptors, [self],
                                        self._loop)

    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamUnaryMultiCallable:
        return StreamUnaryMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._stream_unary_interceptors, [self],
                                        self._loop)

    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamStreamMultiCallable:
        return StreamStreamMultiCallable(self._channel, _common.encode(method),
                                         request_serializer,
                                         response_deserializer,
                                         self._stream_stream_interceptors,
                                         [self], self._loop)
|
|
|
|
|
|
def insecure_channel(
        target: str,
        options: Optional[ChannelArgumentType] = None,
        compression: Optional[grpc.Compression] = None,
        interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel.
    """
    channel_options = options if options is not None else ()
    return Channel(
        target,
        channel_options,
        None,  # no channel credentials: this is the insecure variant
        compression,
        interceptors,
    )
|
|
|
|
|
|
def secure_channel(target: str,
                   credentials: grpc.ChannelCredentials,
                   options: Optional[ChannelArgumentType] = None,
                   compression: Optional[grpc.Compression] = None,
                   interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance; the object stored in its
        ``_credentials`` attribute is what the Channel actually receives.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = () if options is None else options
    core_credentials = credentials._credentials  # pylint: disable=protected-access
    return Channel(target, channel_options, core_credentials, compression,
                   interceptors)
|