756 lines
25 KiB
Python
756 lines
25 KiB
Python
|
# Copyright 2019 gRPC authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Invocation-side implementation of gRPC Asyncio Python."""
|
||
|
|
||
|
import asyncio
|
||
|
import enum
|
||
|
from functools import partial
|
||
|
import inspect
|
||
|
import logging
|
||
|
import traceback
|
||
|
from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple
|
||
|
|
||
|
import grpc
|
||
|
from grpc import _common
|
||
|
from grpc._cython import cygrpc
|
||
|
|
||
|
from . import _base_call
|
||
|
from ._metadata import Metadata
|
||
|
from ._typing import DeserializingFunction
|
||
|
from ._typing import DoneCallbackType
|
||
|
from ._typing import MetadatumType
|
||
|
from ._typing import RequestIterableType
|
||
|
from ._typing import RequestType
|
||
|
from ._typing import ResponseType
|
||
|
from ._typing import SerializingFunction
|
||
|
|
||
|
__all__ = ("AioRpcError", "Call", "UnaryUnaryCall", "UnaryStreamCall")

# Reasons attached to RPCs that are cancelled from the client side.
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
_GC_CANCELLATION_DETAILS = "Cancelled upon garbage collection!"

# Messages used when an operation is attempted in an invalid call state.
_RPC_ALREADY_FINISHED_DETAILS = "RPC already finished."
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
_API_STYLE_ERROR = "The iterator and read/write APIs may not be mixed on a single RPC."

# Templates for rendering a terminated RPC. Placeholders, in order: class
# name, status, details and (non-OK template only) the debug error string.
_OK_CALL_REPRESENTATION = (
    "<{} of RPC that terminated with:\n"
    "\tstatus = {}\n"
    '\tdetails = "{}"\n'
    ">"
)

_NON_OK_CALL_REPRESENTATION = '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n\tdebug_error_string = "{}"\n>'

_LOGGER = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class AioRpcError(grpc.RpcError):
    """An implementation of RpcError to be used by the asynchronous API.

    A raised AioRpcError is a snapshot of the final status of the RPC; all of
    its values are already determined. Hence, its methods do not need to be
    coroutines.
    """

    _code: grpc.StatusCode
    _details: Optional[str]
    _initial_metadata: Optional[Metadata]
    _trailing_metadata: Optional[Metadata]
    _debug_error_string: Optional[str]

    def __init__(
        self,
        code: grpc.StatusCode,
        initial_metadata: Optional[Metadata],
        trailing_metadata: Optional[Metadata],
        details: Optional[str] = None,
        debug_error_string: Optional[str] = None,
    ) -> None:
        """Constructor.

        Args:
          code: The status code with which the RPC has been finalized.
          initial_metadata: Optional initial metadata that could be sent by the
            Server.
          trailing_metadata: Optional metadata that could be sent by the Server.
          details: Optional details explaining the reason of the error.
          debug_error_string: Optional debug error string associated with the
            error.
        """

        super().__init__()
        self._code = code
        self._details = details
        self._initial_metadata = initial_metadata
        self._trailing_metadata = trailing_metadata
        self._debug_error_string = debug_error_string

    def code(self) -> grpc.StatusCode:
        """Accesses the status code sent by the server.

        Returns:
          The `grpc.StatusCode` status code.
        """
        return self._code

    def details(self) -> Optional[str]:
        """Accesses the details sent by the server.

        Returns:
          The description of the error, or None if none was provided.
        """
        return self._details

    def initial_metadata(self) -> Optional[Metadata]:
        """Accesses the initial metadata sent by the server.

        Returns:
          The initial metadata received, or None if none was provided.
        """
        return self._initial_metadata

    def trailing_metadata(self) -> Optional[Metadata]:
        """Accesses the trailing metadata sent by the server.

        Returns:
          The trailing metadata received, or None if none was provided.
        """
        return self._trailing_metadata

    def debug_error_string(self) -> Optional[str]:
        """Accesses the debug error string sent by the server.

        Returns:
          The debug error string received, or None if none was provided.
        """
        return self._debug_error_string

    def _repr(self) -> str:
        """Assembles the error string for the RPC error."""
        return _NON_OK_CALL_REPRESENTATION.format(
            self.__class__.__name__,
            self._code,
            self._details,
            self._debug_error_string,
        )

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()

    def __reduce__(self):
        # Supports pickling by re-invoking the constructor with the stored
        # fields, in the constructor's positional order.
        return (
            type(self),
            (
                self._code,
                self._initial_metadata,
                self._trailing_metadata,
                self._details,
                self._debug_error_string,
            ),
        )
|
||
|
|
||
|
|
||
|
def _create_rpc_error(
    initial_metadata: Metadata, status: cygrpc.AioRpcStatus
) -> AioRpcError:
    """Builds an AioRpcError snapshot from initial metadata and a final status.

    Args:
      initial_metadata: The initial metadata; it is normalized through
        `Metadata.from_tuple`.
      status: The final status reported by the Cython layer.

    Returns:
      An AioRpcError carrying the translated status code, the initial and
      trailing metadata, the details and the debug error string.
    """
    grpc_status_code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
    return AioRpcError(
        grpc_status_code,
        Metadata.from_tuple(initial_metadata),
        Metadata.from_tuple(status.trailing_metadata()),
        details=status.details(),
        debug_error_string=status.debug_error_string(),
    )
|
||
|
|
||
|
|
||
|
class Call:
    """Base implementation of client RPC Call object.

    Implements logic around final status, metadata and cancellation by
    delegating every query to the wrapped Cython call object.
    """

    _loop: asyncio.AbstractEventLoop
    _code: grpc.StatusCode
    _cython_call: cygrpc._AioCall
    _metadata: Tuple[MetadatumType, ...]
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction

    def __init__(
        self,
        cython_call: cygrpc._AioCall,
        metadata: Metadata,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._cython_call = cython_call
        # Keep an immutable tuple snapshot of the outgoing metadata.
        self._metadata = tuple(metadata)
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __del__(self) -> None:
        # The '_cython_call' attribute may be missing by the time this Call
        # object is destructed, so check before touching it.
        if not hasattr(self, "_cython_call"):
            return
        if not self._cython_call.done():
            self._cancel(_GC_CANCELLATION_DETAILS)

    def cancelled(self) -> bool:
        return self._cython_call.cancelled()

    def _cancel(self, details: str) -> bool:
        """Forwards the application cancellation reasoning.

        Returns True if a cancellation was issued, False if the RPC had
        already finished.
        """
        if self._cython_call.done():
            return False
        self._cython_call.cancel(details)
        return True

    def cancel(self) -> bool:
        return self._cancel(_LOCAL_CANCELLATION_DETAILS)

    def done(self) -> bool:
        return self._cython_call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        # The callback receives this Call object as its only argument.
        self._cython_call.add_done_callback(partial(callback, self))

    def time_remaining(self) -> Optional[float]:
        return self._cython_call.time_remaining()

    async def initial_metadata(self) -> Metadata:
        raw_metadata_tuple = await self._cython_call.initial_metadata()
        return Metadata.from_tuple(raw_metadata_tuple)

    async def trailing_metadata(self) -> Metadata:
        final_status = await self._cython_call.status()
        return Metadata.from_tuple(final_status.trailing_metadata())

    async def code(self) -> grpc.StatusCode:
        final_status = await self._cython_call.status()
        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[final_status.code()]

    async def details(self) -> str:
        final_status = await self._cython_call.status()
        return final_status.details()

    async def debug_error_string(self) -> str:
        final_status = await self._cython_call.status()
        return final_status.debug_error_string()

    async def _raise_for_status(self) -> None:
        """Raises if the RPC did not end successfully.

        Raises:
          asyncio.CancelledError: If the call was cancelled locally.
          AioRpcError: If the final status code is not OK.
        """
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        if await self.code() == grpc.StatusCode.OK:
            return
        raise _create_rpc_error(
            await self.initial_metadata(), await self._cython_call.status()
        )

    def _repr(self) -> str:
        return repr(self._cython_call)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()
|
||
|
|
||
|
|
||
|
class _APIStyle(enum.IntEnum):
    """Records which API drives a stream of an RPC.

    A stream is either consumed/produced via async iteration or via explicit
    read/write calls; the two styles may not be mixed on one RPC.
    """

    # No style has been chosen yet.
    UNKNOWN = 0
    # Driven through async iteration (``async for`` / a request iterator).
    ASYNC_GENERATOR = 1
    # Driven through explicit read()/write() calls.
    READER_WRITER = 2
|
||
|
|
||
|
|
||
|
class _UnaryResponseMixin(Call, Generic[ResponseType]):
    """Mixin for calls whose single response is produced by an asyncio Task.

    Awaiting the call waits on that task. A `cygrpc.EOF` result is translated
    into `asyncio.CancelledError` (local cancellation) or an `AioRpcError`.
    """

    _call_response: asyncio.Task

    def _init_unary_response_mixin(self, response_task: asyncio.Task):
        self._call_response = response_task

    def cancel(self) -> bool:
        if not super().cancel():
            return False
        self._call_response.cancel()
        return True

    def __await__(self) -> Generator[Any, None, ResponseType]:
        """Wait till the ongoing RPC request finishes."""
        try:
            response = yield from self._call_response
        except asyncio.CancelledError:
            # Corner case not covered elsewhere: the application cancelled
            # immediately after the Call object was created, so the
            # cancellation surfaces here as a `CancelledError`.
            if not self.cancelled():
                self.cancel()
            raise

        if response is not cygrpc.EOF:
            return response

        # NOTE(lidiz) If we raise RpcError in the task, and users doesn't
        # 'await' on it. AsyncIO will log 'Task exception was never retrieved'.
        # Instead, if we move the exception raising here, the spam stops.
        # Unfortunately, there can only be one 'yield from' in '__await__'. So,
        # we need to access the private instance variable.
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        raise _create_rpc_error(
            self._cython_call._initial_metadata,
            self._cython_call._status,
        )
|
||
|
|
||
|
|
||
|
class _StreamResponseMixin(Call):
    """Mixin implementing the response side of server-streaming calls.

    Responses are consumed either with `async for` (ASYNC_GENERATOR style) or
    with explicit `read()` calls (READER_WRITER style); mixing the two on the
    same RPC raises `cygrpc.UsageError`.
    """

    # Lazily created by the first `__aiter__` call; None until then.
    _message_aiter: Optional[AsyncIterator[ResponseType]]
    # Task that must finish before any response message can be received.
    _preparation: asyncio.Task
    _response_style: _APIStyle

    def _init_stream_response_mixin(self, preparation: asyncio.Task):
        self._message_aiter = None
        self._preparation = preparation
        self._response_style = _APIStyle.UNKNOWN

    def _update_response_style(self, style: _APIStyle):
        """Fixes the response API style on first use.

        Raises:
          cygrpc.UsageError: If a different style was already chosen.
        """
        if self._response_style is _APIStyle.UNKNOWN:
            self._response_style = style
        elif self._response_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        if super().cancel():
            self._preparation.cancel()
            return True
        else:
            return False

    async def _fetch_stream_responses(self) -> AsyncIterator[ResponseType]:
        """Yields deserialized responses until EOF, then checks the status."""
        message = await self._read()
        while message is not cygrpc.EOF:
            yield message
            message = await self._read()

        # If the read operation failed, Core should explain why.
        await self._raise_for_status()

    def __aiter__(self) -> AsyncIterator[ResponseType]:
        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
        if self._message_aiter is None:
            self._message_aiter = self._fetch_stream_responses()
        return self._message_aiter

    async def _read(self) -> ResponseType:
        """Receives one message; returns `cygrpc.EOF` at end of stream."""
        # Wait for the request being sent
        await self._preparation

        # Reads response message from Core
        try:
            raw_response = await self._cython_call.receive_serialized_message()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

        if raw_response is cygrpc.EOF:
            return cygrpc.EOF
        else:
            return _common.deserialize(
                raw_response, self._response_deserializer
            )

    async def read(self) -> ResponseType:
        """Reads one response message via the READER_WRITER API.

        Returns:
          The next deserialized response, or `cygrpc.EOF` once the stream is
          exhausted (or the call is already done).

        Raises:
          asyncio.CancelledError: If the call was cancelled locally.
          AioRpcError: If the RPC finished with a non-OK status.
          cygrpc.UsageError: If the async-iterator API was already used.
        """
        if self.done():
            await self._raise_for_status()
            return cygrpc.EOF
        self._update_response_style(_APIStyle.READER_WRITER)

        response_message = await self._read()

        if response_message is cygrpc.EOF:
            # If the read operation failed, Core should explain why.
            await self._raise_for_status()
        return response_message
|
||
|
|
||
|
|
||
|
class _StreamRequestMixin(Call):
    """Mixin implementing the request side of client-streaming calls.

    Requests are either pumped from a user-supplied iterable by a background
    task (ASYNC_GENERATOR style) or written explicitly through `write()` and
    `done_writing()` (READER_WRITER style). The style is fixed when the mixin
    is initialized.
    """

    # Set by `_metadata_sent_observer` once the initial metadata is sent.
    _metadata_sent: asyncio.Event
    # True once a half-close has been requested.
    _done_writing_flag: bool
    # Background task consuming the request iterator, if one was supplied.
    _async_request_poller: Optional[asyncio.Task]
    _request_style: _APIStyle

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ):
        self._metadata_sent = asyncio.Event()
        self._done_writing_flag = False

        # If the user passes in a request iterable (sync or async), create a
        # consumer Task.
        if request_iterator is not None:
            self._async_request_poller = self._loop.create_task(
                self._consume_request_iterator(request_iterator)
            )
            self._request_style = _APIStyle.ASYNC_GENERATOR
        else:
            self._async_request_poller = None
            self._request_style = _APIStyle.READER_WRITER

    def _raise_for_different_style(self, style: _APIStyle):
        """Raises `cygrpc.UsageError` if `style` is not the request style."""
        if self._request_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        if super().cancel():
            if self._async_request_poller is not None:
                self._async_request_poller.cancel()
            return True
        else:
            return False

    def _metadata_sent_observer(self):
        self._metadata_sent.set()

    async def _write_from_iterator(self, request: RequestType) -> bool:
        """Writes one request produced by the request iterator.

        Returns:
          True if the request was written; False if writing failed with an
          AioRpcError (logged at debug level), meaning consumption must stop.
        """
        try:
            await self._write(request)
        except AioRpcError as rpc_error:
            _LOGGER.debug(
                "Exception while consuming the request_iterator: %s",
                rpc_error,
            )
            return False
        return True

    async def _consume_request_iterator(
        self, request_iterator: RequestIterableType
    ) -> None:
        """Sends every request from `request_iterator`, then half-closes.

        Stops early, without half-closing, if a write fails with an
        AioRpcError. Any other exception is logged and cancels the RPC.
        """
        try:
            if inspect.isasyncgen(request_iterator) or hasattr(
                request_iterator, "__aiter__"
            ):
                async for request in request_iterator:
                    if not await self._write_from_iterator(request):
                        return
            else:
                for request in request_iterator:
                    if not await self._write_from_iterator(request):
                        return

            await self._done_writing()
        except:  # pylint: disable=bare-except
            # Client iterators can raise exceptions, which we should handle by
            # cancelling the RPC and logging the client's error. No exceptions
            # should escape this function.
            _LOGGER.debug(
                "Client request_iterator raised exception:\n%s",
                traceback.format_exc(),
            )
            self.cancel()

    async def _write(self, request: RequestType) -> None:
        """Serializes and sends one request message.

        Raises:
          asyncio.InvalidStateError: If the RPC is finished or half-closed.
          asyncio.CancelledError: If the call was cancelled.
          AioRpcError: If the RPC finished with a non-OK status.
        """
        if self.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        if self._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
        if not self._metadata_sent.is_set():
            await self._metadata_sent.wait()
            if self.done():
                await self._raise_for_status()

        serialized_request = _common.serialize(
            request, self._request_serializer
        )
        try:
            await self._cython_call.send_serialized_message(serialized_request)
        except cygrpc.InternalError as err:
            self._cython_call.set_internal_error(str(err))
            await self._raise_for_status()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def _done_writing(self) -> None:
        """Half-closes the request stream at most once; no-op when done."""
        if self.done():
            # If the RPC is finished, do nothing.
            return
        if not self._done_writing_flag:
            # If the done writing is not sent before, try to send it.
            self._done_writing_flag = True
            try:
                await self._cython_call.send_receive_close()
            except asyncio.CancelledError:
                if not self.cancelled():
                    self.cancel()
                raise

    async def write(self, request: RequestType) -> None:
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._write(request)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._done_writing()

    async def wait_for_connection(self) -> None:
        await self._metadata_sent.wait()
        if self.done():
            await self._raise_for_status()
|
||
|
|
||
|
|
||
|
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
    """Object for managing unary-unary RPC calls.

    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
    """

    _request: RequestType
    _invocation_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._context = cygrpc.build_census_context()
        self._invocation_task = loop.create_task(self._invoke())
        self._init_unary_response_mixin(self._invocation_task)

    async def _invoke(self) -> ResponseType:
        """Performs the RPC.

        Returns:
          The deserialized response if the call ended OK, otherwise
          `cygrpc.EOF`; the awaiting side turns EOF into the proper error.
        """
        serialized_request = _common.serialize(
            self._request, self._request_serializer
        )

        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
        # because the asyncio.Task class do not cache the exception object.
        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
        try:
            serialized_response = await self._cython_call.unary_unary(
                serialized_request, self._metadata, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            if self._cython_call.is_ok():
                # No response was received, yet the call reports OK (e.g. the
                # cancellation raced with RPC completion). There is nothing
                # to deserialize, so propagate the cancellation instead of
                # reading the unassigned `serialized_response` below.
                raise

        if self._cython_call.is_ok():
            return _common.deserialize(
                serialized_response, self._response_deserializer
            )
        else:
            return cygrpc.EOF

    async def wait_for_connection(self) -> None:
        await self._invocation_task
        if self.done():
            await self._raise_for_status()
|
||
|
|
||
|
|
||
|
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
    """Object for managing unary-stream RPC calls.

    Returned when an instance of `UnaryStreamMultiCallable` object is called.
    """

    _request: RequestType
    _send_unary_request_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._context = cygrpc.build_census_context()
        self._send_unary_request_task = loop.create_task(
            self._send_unary_request()
        )
        # Response reads wait for the request-sending task to finish first.
        self._init_stream_response_mixin(self._send_unary_request_task)

    async def _send_unary_request(self) -> None:
        """Serializes the single request and initiates the streaming RPC.

        Raises:
          asyncio.CancelledError: Re-raised after cancelling the call.
        """
        serialized_request = _common.serialize(
            self._request, self._request_serializer
        )
        try:
            await self._cython_call.initiate_unary_stream(
                serialized_request, self._metadata, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def wait_for_connection(self) -> None:
        await self._send_unary_request_task
        if self.done():
            await self._raise_for_status()
|
||
|
|
||
|
|
||
|
# pylint: disable=too-many-ancestors
class StreamUnaryCall(
    _StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall
):
    """Object for managing stream-unary RPC calls.

    Returned when an instance of `StreamUnaryMultiCallable` object is called.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        underlying_call = channel.call(
            method, deadline, credentials, wait_for_ready
        )
        super().__init__(
            underlying_call,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )

        self._context = cygrpc.build_census_context()
        # The request consumer task (if any) is created before the RPC task.
        self._init_stream_request_mixin(request_iterator)
        rpc_task = loop.create_task(self._conduct_rpc())
        self._init_unary_response_mixin(rpc_task)

    async def _conduct_rpc(self) -> ResponseType:
        """Runs the RPC; returns the deserialized response or `cygrpc.EOF`."""
        try:
            serialized_response = await self._cython_call.stream_unary(
                self._metadata, self._metadata_sent_observer, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

        if not self._cython_call.is_ok():
            return cygrpc.EOF
        return _common.deserialize(
            serialized_response, self._response_deserializer
        )
|
||
|
|
||
|
|
||
|
class StreamStreamCall(
    _StreamRequestMixin, _StreamResponseMixin, Call, _base_call.StreamStreamCall
):
    """Object for managing stream-stream RPC calls.

    Returned when an instance of `StreamStreamMultiCallable` object is called.
    """

    _initializer: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        underlying_call = channel.call(
            method, deadline, credentials, wait_for_ready
        )
        super().__init__(
            underlying_call,
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._context = cygrpc.build_census_context()
        # The preparation task is created before the request consumer task
        # that _init_stream_request_mixin may start.
        self._initializer = self._loop.create_task(self._prepare_rpc())
        self._init_stream_request_mixin(request_iterator)
        # Response reads await the preparation task before receiving.
        self._init_stream_response_mixin(self._initializer)

    async def _prepare_rpc(self):
        """Prepares the RPC for receiving/sending messages.

        Every other operation on the stream should only happen once this
        coroutine has completed.
        """
        try:
            await self._cython_call.initiate_stream_stream(
                self._metadata, self._metadata_sent_observer, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            # Deliberately not re-raised: no RpcError needs to be raised here,
            # because no one will `await` this task for it.
|