210 lines
6.9 KiB
Python
210 lines
6.9 KiB
Python
"""Stubs for aiohttp HTTP clients"""
|
|
from __future__ import absolute_import
|
|
|
|
import asyncio
|
|
import functools
|
|
import logging
|
|
import json
|
|
|
|
from aiohttp import ClientConnectionError, ClientResponse, RequestInfo, streams
|
|
from multidict import CIMultiDict, CIMultiDictProxy
|
|
from yarl import URL
|
|
|
|
from vcr.request import Request
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin):
    """In-memory stream reader handed out by ``MockClientResponse.content``.

    Combines ``asyncio.StreamReader`` with aiohttp's
    ``AsyncStreamReaderMixin``; it adds no behaviour of its own.
    """
|
|
|
|
|
|
class MockClientResponse(ClientResponse):
    """``ClientResponse`` stand-in whose payload is served from ``self._body``.

    The body accessors below read the in-memory ``_body`` bytes that the
    code building the response assigns after construction.
    """

    def __init__(self, method, url, request_info=None):
        """Initialise the parent response without any live connection objects.

        ``writer``, ``continue100``, ``timer``, ``traces`` and ``session`` are
        all passed as ``None``; the loop is whatever
        ``asyncio.get_event_loop()`` returns at construction time.
        """
        super().__init__(
            method=method,
            url=url,
            request_info=request_info,
            loop=asyncio.get_event_loop(),
            writer=None,
            continue100=None,
            timer=None,
            traces=None,
            session=None,
        )

    async def json(self, *, encoding="utf-8", loads=json.loads, **kwargs):  # NOQA: E999
        """Decode the stored body with ``encoding`` and parse it via ``loads``.

        Returns ``None`` when the body is empty after stripping whitespace.
        Extra keyword arguments are accepted and ignored.
        """
        payload = self._body.strip()
        return loads(payload.decode(encoding)) if payload else None

    async def text(self, encoding="utf-8", errors="strict"):
        """Return the stored body decoded to ``str``."""
        raw = self._body
        return raw.decode(encoding, errors=errors)

    async def read(self):
        """Return the stored body bytes unchanged."""
        return self._body

    def release(self):
        """Do nothing; overrides the parent's ``release``."""

    @property
    def content(self):
        """Return a new ``MockStream`` pre-filled with the body and marked EOF.

        A fresh stream is created on every access.
        """
        stream = MockStream()
        stream.feed_data(self._body)
        stream.feed_eof()
        return stream
|
|
|
|
|
|
def build_response(vcr_request, vcr_response, history):
    """Turn a recorded request/response pair into a ``MockClientResponse``.

    :param vcr_request: recorded request; its ``url``, ``method`` and
        ``headers`` populate the response's ``request_info``.
    :param vcr_response: recorded response mapping with ``status``
        (``code``/``message``), ``headers``, ``body`` (optional ``string``
        key) and ``url`` entries.
    :param history: iterable of earlier responses; a tuple snapshot of it is
        stored on the new response.
    :return: the populated response, already closed.
    """
    info = RequestInfo(
        url=URL(vcr_request.url),
        method=vcr_request.method,
        headers=CIMultiDictProxy(CIMultiDict(vcr_request.headers)),
        real_url=URL(vcr_request.url),
    )
    replayed = MockClientResponse(
        vcr_request.method,
        URL(vcr_response.get("url")),
        request_info=info,
    )

    status = vcr_response["status"]
    replayed.status = status["code"]
    # A cassette entry recorded without a body has no "string" key.
    replayed._body = vcr_response["body"].get("string", b"")
    replayed.reason = status["message"]
    replayed._headers = CIMultiDictProxy(CIMultiDict(vcr_response["headers"]))
    replayed._history = tuple(history)

    replayed.close()
    return replayed
|
|
|
|
|
|
def _serialize_headers(headers):
|
|
"""Serialize CIMultiDictProxy to a pickle-able dict because proxy
|
|
objects forbid pickling:
|
|
|
|
https://github.com/aio-libs/multidict/issues/340
|
|
"""
|
|
# Mark strings as keys so 'istr' types don't show up in
|
|
# the cassettes as comments.
|
|
return {str(k): v for k, v in headers.items()}
|
|
|
|
|
|
def play_responses(cassette, vcr_request):
    """Replay the recorded response for ``vcr_request``, following redirects.

    :param cassette: cassette providing ``play_response`` and
        ``find_requests_with_most_matches``.
    :param vcr_request: the request to look up in the cassette.
    :return: the final replayed response; every redirect response replayed
        on the way is available through its history.
    """
    history = []
    vcr_response = cassette.play_response(vcr_request)
    response = build_response(vcr_request, vcr_response, history)

    # If we're following redirects, continue playing until we reach
    # our final destination.
    while 300 <= response.status <= 399:
        # Not every 3xx response carries a Location header (e.g. 304 Not
        # Modified); there is nothing to follow, so return it as-is
        # instead of failing with a KeyError.
        location = response.headers.get("location")
        if location is None:
            break

        # Resolve the Location reference against the current URL. It may be
        # an absolute URL or include a query string, neither of which
        # ``with_path`` handles.
        next_url = URL(response.url).join(URL(location))

        # Make a stub VCR request that we can then use to look up the recorded
        # VCR request saved to the cassette. This feels a little hacky and
        # may have edge cases based on the headers we're providing (e.g. if
        # there's a matcher that is used to filter by headers).
        vcr_request = Request("GET", str(next_url), None, _serialize_headers(response.request_info.headers))
        vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0]

        # Tack on the response we saw from the redirect into the history
        # list that is added on to the final response.
        history.append(response)
        vcr_response = cassette.play_response(vcr_request)
        response = build_response(vcr_request, vcr_response, history)

    return response
|
|
|
|
|
|
async def record_response(cassette, vcr_request, response):
    """Append one request/response pair to the cassette.

    The stored response mapping has ``status``, ``headers``, ``body`` and
    ``url`` entries; ``body`` is left empty when the response body cannot
    be read.
    """
    try:
        payload = await response.read()  # NOQA: E999
    except ClientConnectionError:
        # aiohttp raises a ClientConnectionError on reads when there is no
        # body, which tells us not to write one.
        body = {}
    else:
        body = {"string": payload}

    status = {"code": response.status, "message": response.reason}
    cassette.append(
        vcr_request,
        {
            "status": status,
            "headers": _serialize_headers(response.headers),
            "body": body,
            "url": str(response.url),
        },
    )
|
|
|
|
|
|
async def record_responses(cassette, vcr_request, response):
    """Record ``response`` plus every redirect hop in its history.

    aiohttp follows redirects by default, so each intermediate
    request/response pair in ``response.history`` is written to the cassette
    before the final one. When redirects were followed, the final response is
    recorded against a request rebuilt from its own ``request_info`` rather
    than against the ``vcr_request`` passed in.
    """

    def request_from(info):
        # No data because it's following a redirect.
        return Request(
            info.method,
            str(info.url),
            None,
            _serialize_headers(info.headers),
        )

    for hop in response.history:
        await record_response(cassette, request_from(hop.request_info), hop)

    # If we're following redirects, then the last request-response
    # we record is the one attached to the `response`.
    if response.history:
        vcr_request = request_from(response.request_info)

    await record_response(cassette, vcr_request, response)
|
|
|
|
|
|
def vcr_request(cassette, real_request):
    """Wrap ``real_request`` so calls are replayed from or recorded to ``cassette``.

    The returned coroutine function has the same call shape as the wrapped
    one: ``(self, method, url, **kwargs)``.
    """

    @functools.wraps(real_request)
    async def new_request(self, method, url, **kwargs):
        """Replay a matching recorded response, else hit the real server."""
        headers = self._prepare_headers(kwargs.get("headers"))
        auth = kwargs.get("auth")
        if auth is not None:
            headers["AUTHORIZATION"] = auth.encode()

        data = kwargs.get("data", kwargs.get("json"))
        params = kwargs.get("params")

        request_url = URL(url)
        if params:
            # Values are stringified in place, so the same (mutated) mapping
            # is what the real request receives further down.
            for key, value in params.items():
                params[key] = str(value)
            request_url = request_url.with_query(params)

        recorded = Request(method, str(request_url), data, headers)

        if cassette.can_play_response_for(recorded):
            return play_responses(cassette, recorded)

        if cassette.write_protected and cassette.filter_request(recorded):
            # Nothing to replay and recording is not allowed: hand back a
            # synthetic 599 response whose body explains the problem.
            response = MockClientResponse(method, URL(url))
            response.status = 599
            template = (
                "No match for the request {!r} was found. Can't overwrite "
                "existing cassette {!r} in your current record mode {!r}."
            )
            explanation = template.format(recorded, cassette._path, cassette.record_mode)
            response._body = explanation.encode()
            response.close()
            return response

        log.info("%s not in cassette, sending to real server", recorded)

        live_response = await real_request(self, method, url, **kwargs)  # NOQA: E999
        await record_responses(cassette, recorded, live_response)
        return live_response

    return new_request
|