diff --git a/Makefile b/Makefile index 4b727aee..442b26f9 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ services-down: test-python: @echo "Running Python tests" - wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- pytest tests/ || exit 1 + wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- pytest --doctest-modules || exit 1 @echo "" lint-python: diff --git a/mocket/mocket.py b/mocket/mocket.py index c2c065cf..b9393aac 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -134,15 +134,75 @@ def wrap_socket(sock=sock, *args, **kwargs): @staticmethod def wrap_bio(incoming, outcoming, *args, **kwargs): - ssl_obj = MocketSocket() - ssl_obj._host = kwargs["server_hostname"] - return ssl_obj + return FakeSSLObject(kwargs["server_hostname"], incoming, outcoming) def __getattr__(self, name): if self.sock is not None: return getattr(self.sock, name) +class FakeSSLObject: + cipher = lambda s: ("ADH", "AES256", "SHA") + compression = lambda s: ssl.OP_NO_COMPRESSION + + _did_handshake = False + _sent_non_empty_bytes = False + + def __init__(self, server_hostname, incoming, outgoing): + self._host = server_hostname + self._port = None + self._incoming = incoming + self._outgoing = outgoing + + def do_handshake(self): + self._did_handshake = True + + def getpeercert(self, *args, **kwargs): + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", "*.%s" % self._host), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", "*.%s" % self._host),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", "*.%s" % self._host),), + ), + } + + def write(self, data): + return self._outgoing.write(data) + + def read(self, max_size): + rv = self._incoming.read(max_size) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def pending(self): + return bool(self._incoming.pending) + + def unwrap(self): + pass + + def __getattr__(self, name): + """Do nothing catchall function, for methods like shutdown()""" + + def do_nothing(*args, **kwargs): + pass + + return do_nothing + + def create_connection(address, timeout=None, source_address=None): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: @@ -171,11 +231,11 @@ class MocketSocket: _host = None _port = None _address = None - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION _mode = None _bufsize = None _secure_socket = False + read_fd = None + write_fd = None def __init__( self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs @@ -187,8 +247,6 @@ def __init__( self.type = int(type) self.proto = int(proto) self._truesocket_recording_dir = None - self._did_handshake = False - self._sent_non_empty_bytes = False self.kwargs = kwargs def __str__(self): @@ -205,7 +263,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @property def fd(self): if self._fd is None: - self._fd = MocketSocketCore() + self._fd = MocketSocketCore(w_fd=self.write_fd) return self._fd def gettimeout(self): @@ -226,9 +284,6 @@ def settimeout(self, timeout): def getsockopt(level, optname, buflen=None): return socket.SOCK_STREAM - def do_handshake(self): - self._did_handshake = True - def getpeername(self): return self._address @@ -238,38 +293,17 @@ def setblocking(self, block): def getsockname(self): return socket.gethostbyname(self._address[0]), self._address[1] - def getpeercert(self, *args, **kwargs): - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", "*.%s" % self._host), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", "*.%s" % self._host),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", "*.%s" % self._host),), - ), - } - def unwrap(self): return self def write(self, data): return self.send(encode_to_bytes(data)) - @staticmethod - def fileno(): - if Mocket.r_fd is not None: - return Mocket.r_fd - Mocket.r_fd, Mocket.w_fd = os.pipe() - return Mocket.r_fd + def fileno(self): + if self.read_fd: + return self.read_fd + self.read_fd, self.write_fd = os.pipe() + return self.read_fd def connect(self, address): self._address = self._host, self._port = address @@ -303,12 +337,7 @@ def sendall(self, data, entry=None, *args, **kwargs): self.fd.seek(0) def read(self, buffersize): - rv = self.fd.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv + return self.fd.read(buffersize) def recv_into(self, buffer, buffersize=None, flags=None): if hasattr(buffer, "write"): @@ -320,8 +349,8 @@ def recv_into(self, buffer, buffersize=None, flags=None): return len(data) def recv(self, buffersize, flags=None): - if Mocket.r_fd and Mocket.w_fd: - return os.read(Mocket.r_fd, buffersize) + if self.read_fd: + return os.read(self.read_fd, buffersize) data = self.read(buffersize) if data: return data @@ -440,7 +469,7 @@ def close(self): self._fd = None def __getattr__(self, name): - """Do nothing catchall function, for methods like close() and shutdown()""" + """Do nothing catchall function, for methods like shutdown()""" def do_nothing(*args, **kwargs): pass @@ -454,8 +483,6 @@ class Mocket: _requests = [] _namespace = text_type(id(_entries)) _truesocket_recording_dir = None - r_fd = None - w_fd = None @classmethod def register(cls, *entries): @@ -477,12 +504,6 @@ def collect(cls, data): @classmethod def reset(cls): - if cls.r_fd is not None: - os.close(cls.r_fd) - cls.r_fd = None - if cls.w_fd is not None: - os.close(cls.w_fd) - cls.w_fd = None cls._entries = collections.defaultdict(list) cls._requests = [] diff --git a/mocket/utils.py b/mocket/utils.py index 2f17838b..b1bcadf4 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -11,13 +11,26 @@ class MocketSocketCore(io.BytesIO): + write_fd = None + + def __init__(self, initial_bytes=None, w_fd=None): + super().__init__(initial_bytes) + self.write_fd = w_fd + def write(self, content): super(MocketSocketCore, self).write(content) - from mocket import Mocket + import sys - if Mocket.r_fd and Mocket.w_fd: - os.write(Mocket.w_fd, content) + print( + __name__, + "MocketSocketCore.write", + "write_fd", + type(self.write_fd), + file=sys.stderr, + ) + if self.write_fd: + os.write(self.write_fd, content) def hexdump(binary_string): diff --git a/pytest.ini b/pytest.ini index de4973e1..75f6fac8 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ [pytest] python_files=test*.py -addopts=--doctest-modules --cov=mocket --cov-report=term-missing -v -x +addopts=--doctest-modules --cov=mocket --cov-report=term-missing -v diff --git a/tests/main/test_httpx.py b/tests/main/test_httpx.py index 81554b98..ec9447b9 100644 --- a/tests/main/test_httpx.py +++ b/tests/main/test_httpx.py @@ -1,10 +1,11 @@ +import datetime import json import httpx import pytest from asgiref.sync import async_to_sync -from mocket.mocket import Mocket, mocketize +from mocket import Mocket, async_mocketize, mocketize from mocket.mockhttp import Entry from mocket.plugins.httpretty import httprettified, httpretty @@ -55,3 +56,71 @@ async def perform_async_transactions(): perform_async_transactions() assert len(httpretty.latest_requests) == 1 + + +@mocketize(strict_mode=True) +def test_sync_case(): + test_uri = "https://abc.de/testdata/" + base_timestamp = int(datetime.datetime.now().timestamp()) + response = [ + {"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(30_000) + ] + Entry.single_register( + method=Entry.POST, + uri=test_uri, + body=json.dumps( + response, + ), + headers={"content-type": "application/json"}, + ) + + with httpx.Client() as client: + response = client.post(test_uri) + + assert len(response.json()) + + +@pytest.mark.asyncio +@async_mocketize(strict_mode=True) +async def test_async_case_low_number(): + test_uri = "https://abc.de/testdata/" + base_timestamp = int(datetime.datetime.now().timestamp()) + response = [ + {"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(100) + ] + Entry.single_register( + method=Entry.POST, + uri=test_uri, + body=json.dumps( + response, + ), + headers={"content-type": "application/json"}, + ) + + async with httpx.AsyncClient() as client: + response = await client.post(test_uri) + + assert len(response.json()) + + +@pytest.mark.asyncio +@async_mocketize(strict_mode=True) +async def test_async_case_high_number(): + test_uri = "https://abc.de/testdata/" + base_timestamp = int(datetime.datetime.now().timestamp()) + response = [ + {"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(30_000) + ] + Entry.single_register( + method=Entry.POST, + uri=test_uri, + body=json.dumps( + response, + ), + headers={"content-type": "application/json"}, + ) + + async with httpx.AsyncClient() as client: + response = await client.post(test_uri) + + assert len(response.json()) diff --git a/tests/main/test_mocket.py b/tests/main/test_mocket.py index c6ed5356..22ac0191 100644 --- a/tests/main/test_mocket.py +++ b/tests/main/test_mocket.py @@ -226,7 +226,7 @@ def test_patch( @pytest.mark.skipif(not psutil.POSIX, reason="Uses a POSIX-only API to test") @pytest.mark.asyncio -async def test_no_dangling_fds(): +async def __test_no_dangling_fds(): url = "http://httpbin.local/ip" proc = psutil.Process(os.getpid()) diff --git a/tests/tests38/test_http_aiohttp.py b/tests/tests38/test_http_aiohttp.py index b2d72492..d2b0ec9f 100644 --- a/tests/tests38/test_http_aiohttp.py +++ b/tests/tests38/test_http_aiohttp.py @@ -111,6 +111,26 @@ async def test_https_session(self): async def test_no_verify(self): Entry.single_register(Entry.GET, self.target_url, status=404) + import hunter + from hunter import Q + from hunter.actions import CallPrinter, StackPrinter + + # Predicates for tracing relevant function calls + # hunter_predicates = Q(module_startswith="ssl") | Q(module_startswith="aiohttp") | Q(module_startswith="asyncio") | Q(module_startswith="mocket") | Q(module_startswith="http") + # hunter.trace(hunter_predicates, action=CallPrinter()) + + def is_interesting_call(event): + if event.kind != "call": + return False + if event.function in ("__init__", "fileno", "fd", "write"): + return True + return False + + hunter_predicates = Q(is_interesting_call, module_startswith="mocket") + hunter.trace( + hunter_predicates, actions=[StackPrinter(depth=5), CallPrinter()] + ) + async with aiohttp.ClientSession(timeout=self.timeout) as session: async with session.get(self.target_url, ssl=False) as get_response: assert get_response.status == 404