3RNN/Lib/site-packages/grpc/aio/_channel.py
2024-05-26 19:49:15 +02:00

625 lines
21 KiB
Python

# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
import sys
from typing import Any, Iterable, List, Optional, Sequence
import grpc
from grpc import _common
from grpc import _compression
from grpc import _grpcio_metadata
from grpc._cython import cygrpc
from . import _base_call
from . import _base_channel
from ._call import StreamStreamCall
from ._call import StreamUnaryCall
from ._call import UnaryStreamCall
from ._call import UnaryUnaryCall
from ._interceptor import ClientInterceptor
from ._interceptor import InterceptedStreamStreamCall
from ._interceptor import InterceptedStreamUnaryCall
from ._interceptor import InterceptedUnaryStreamCall
from ._interceptor import InterceptedUnaryUnaryCall
from ._interceptor import StreamStreamClientInterceptor
from ._interceptor import StreamUnaryClientInterceptor
from ._interceptor import UnaryStreamClientInterceptor
from ._interceptor import UnaryUnaryClientInterceptor
from ._metadata import Metadata
from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline
# Primary user-agent string attached to every channel created in this module;
# it embeds the installed grpcio package version.
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"
if sys.version_info[1] < 7:
def _all_tasks() -> Iterable[asyncio.Task]:
return asyncio.Task.all_tasks() # pylint: disable=no-member
else:
def _all_tasks() -> Iterable[asyncio.Task]:
return asyncio.all_tasks()
def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Append gRPC-internal channel options to the caller's options.

    Args:
      base_options: The channel options supplied by the caller.
      compression: The channel-wide compression setting, or None. It is turned
        into channel option(s) by ``_compression.create_channel_option``.

    Returns:
      A tuple made of ``base_options``, then the compression option(s), then
      the primary user-agent option carrying ``_USER_AGENT``.
    """
    compression_options = _compression.create_channel_option(compression)
    user_agent_option = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    augmented = tuple(base_options)
    augmented += compression_options
    augmented += (user_agent_option,)
    return augmented
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """

    # Each attribute is declared exactly once (``_loop`` used to be duplicated).
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store everything needed to start calls for one RPC method.

        Args:
          channel: The Cython channel the calls are issued on.
          method: The fully-qualified method name, already encoded to bytes.
          request_serializer: Serializer applied to outgoing requests.
          response_deserializer: Deserializer applied to incoming responses.
          interceptors: Interceptors for this call kind; empty or None means
            plain (non-intercepted) calls are created.
          references: Objects this multi-callable keeps a reference to. The
            Channel passes ``[self]`` here, presumably so the Channel (whose
            ``__del__`` closes the Cython channel) stays alive while this
            multi-callable is in use.
          loop: The event loop the calls run on.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Build the final metadata for a call from <metadata> and <compression>.

        Args:
          metadata: Caller-supplied metadata. A falsy value is replaced by an
            empty ``Metadata``, and a plain tuple is converted with
            ``Metadata.from_tuple``. Any other value is used as given.
          compression: If truthy, the metadata is augmented with the
            compression entry via ``_compression.augment_metadata``.

        Returns:
          The metadata to send with the call.
        """
        metadata = metadata or Metadata()
        if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
            metadata = Metadata.from_tuple(metadata)
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata
class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Multi-callable that starts unary-request / unary-response RPCs."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Start a unary-unary RPC and return its call object.

        Without interceptors a ``UnaryUnaryCall`` is built and ``timeout`` is
        passed through ``_timeout_to_deadline`` first. With interceptors an
        ``InterceptedUnaryUnaryCall`` is built and receives ``timeout`` as is.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing constructor arguments shared by both call flavours.
        shared_tail = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, *shared_tail
            )
        return UnaryUnaryCall(
            request, _timeout_to_deadline(timeout), *shared_tail
        )
class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Multi-callable that starts unary-request / streaming-response RPCs."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Start a unary-stream RPC and return its call object.

        Without interceptors a ``UnaryStreamCall`` is built and ``timeout`` is
        passed through ``_timeout_to_deadline`` first. With interceptors an
        ``InterceptedUnaryStreamCall`` is built and receives ``timeout`` as is.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing constructor arguments shared by both call flavours.
        shared_tail = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, *shared_tail
            )
        return UnaryStreamCall(
            request, _timeout_to_deadline(timeout), *shared_tail
        )
class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Multi-callable that starts streaming-request / unary-response RPCs."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Start a stream-unary RPC and return its call object.

        Without interceptors a ``StreamUnaryCall`` is built and ``timeout`` is
        passed through ``_timeout_to_deadline`` first. With interceptors an
        ``InterceptedStreamUnaryCall`` is built and receives ``timeout`` as is.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing constructor arguments shared by both call flavours.
        shared_tail = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, *shared_tail
            )
        return StreamUnaryCall(
            request_iterator, _timeout_to_deadline(timeout), *shared_tail
        )
class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Multi-callable that starts streaming-request / streaming-response RPCs."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Start a stream-stream RPC and return its call object.

        Without interceptors a ``StreamStreamCall`` is built and ``timeout`` is
        passed through ``_timeout_to_deadline`` first. With interceptors an
        ``InterceptedStreamStreamCall`` is built and receives ``timeout`` as is.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Trailing constructor arguments shared by both call flavours.
        shared_tail = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, *shared_tail
            )
        return StreamStreamCall(
            request_iterator, _timeout_to_deadline(timeout), *shared_tail
        )
class Channel(_base_channel.Channel):
    """Asyncio client channel backed by a ``cygrpc.AioChannel``.

    Interceptors given to the constructor are sorted by kind and handed to the
    multi-callables returned by ``unary_unary``, ``unary_stream``,
    ``stream_unary`` and ``stream_stream``.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            client interceptor base classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # NOTE: this is an elif chain, so an interceptor that subclasses
                # several of these bases is registered only for the first match.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        """Enter the async context manager and return this channel."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the channel without a grace period when leaving the context."""
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Close the channel, optionally waiting for in-flight calls first.

        Args:
          grace: If truthy, the number of seconds (passed as ``timeout`` to
            ``asyncio.wait``) to wait for tasks running calls on this channel
            before they are cancelled. A falsy value cancels them immediately.

        Raises:
          cygrpc.InternalError: If a task is running a ``_base_call.Call``
            that has neither ``_channel`` nor ``_cython_call``.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks and collect the ones whose first stack
        # frame belongs to an aio call object issued on this channel.
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f"Unrecognized call object: {candidate}"
                        )

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel; calls still running after ``grace`` are cancelled."""
        await self._close(grace)

    def __del__(self):
        """Close the Cython channel if it exists and is not already closed."""
        # ``_channel`` is missing when ``__init__`` raised before assigning it.
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Return the channel's current connectivity state.

        Args:
          try_to_connect: Forwarded to ``check_connectivity_state``.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Await a connectivity-state watch starting from ``last_observed_state``.

        The watch is issued with no deadline (``None``); its result is expected
        to be truthy.
        """
        # The await must run outside the ``assert`` statement: with ``python -O``
        # asserts are stripped, which would skip the wait entirely and make
        # ``channel_ready`` spin in a busy loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until the channel reports READY, asking it to connect each time."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        """Placeholder; currently returns None despite the ``int`` annotation."""
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        """Create a multi-callable for a unary-unary ``method``.

        ``_registered_method`` is accepted but currently ignored.
        """
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        """Create a multi-callable for a unary-stream ``method``.

        ``_registered_method`` is accepted but currently ignored.
        """
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        """Create a multi-callable for a stream-unary ``method``.

        ``_registered_method`` is accepted but currently ignored.
        """
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        """Create a multi-callable for a stream-stream ``method``.

        ``_registered_method`` is accepted but currently ignored.
        """
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )
def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel.
    """
    channel_options = options if options is not None else ()
    no_credentials = None  # insecure: no channel credentials are supplied
    return Channel(
        target,
        channel_options,
        no_credentials,
        compression,
        interceptors,
    )
def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance. Its underlying
        ``_credentials`` object is what the Channel receives.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    native_credentials = credentials._credentials
    return Channel(
        target,
        channel_options,
        native_credentials,
        compression,
        interceptors,
    )