diff --git a/.coveralls.yml b/.coveralls.yml deleted file mode 100644 index 22c66215..00000000 --- a/.coveralls.yml +++ /dev/null @@ -1 +0,0 @@ -repo_token: 3yI8EwDqrGZaPCnfih1fSDizXbjwwL623 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3ff623d0..b5b8cff8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,10 +19,13 @@ concurrency: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] # , 'pypy3.10' + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', '3.14', 'pypy3.11'] + env: + # Configure a constant location for the uv cache + UV_CACHE_DIR: /tmp/.uv-cache steps: - uses: actions/checkout@v4 @@ -30,10 +33,15 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - pyproject.toml - - uses: hoverkraft-tech/compose-action@v2.0.0 + - name: Restore uv cache + uses: actions/cache@v4 + with: + path: /tmp/.uv-cache + key: uv-${{ runner.os }}-${{ hashFiles('uv.lock') }} + restore-keys: | + uv-${{ runner.os }}-${{ hashFiles('uv.lock') }} + uv-${{ runner.os }} + - uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: "./docker-compose.yml" down-flags: "--remove-orphans" @@ -44,9 +52,14 @@ jobs: make services-up - name: Test run: | - make test - make services-down - - name: Push Coveralls - run: | - pip install -q coveralls coveralls[yaml] - coveralls + if [[ "${{ matrix.python-version }}" == pypy* ]]; then + SKIP_MYPY=1 make test + else + make test + fi + - name: Minimize uv cache + run: uv cache prune --ci + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 564b8ce6..9bacc469 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ shippable .vscode/ Pipfile.lock requirements.txt +coverage.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eae15d12..b2cd5de5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: forbid-crlf - id: remove-crlf - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -15,8 +15,12 @@ repos: exclude: helm/ args: [ --unsafe ] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.4.4" + rev: "v0.14.0" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - id: ruff-format + - repo: https://github.com/rstcheck/rstcheck + rev: v6.2.5 + hooks: + - id: rstcheck diff --git a/LICENSE b/LICENSE index 2788c4b4..45cf27c5 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2017-2024 Giorgio Salluzzo and individual contributors. All rights reserved. +Copyright (c) 2017-2026 Giorgio Salluzzo and individual contributors. All rights reserved. Copyright (c) 2013-2017 Andrea de Marco, Giorgio Salluzzo and individual contributors. All rights reserved. diff --git a/Makefile b/Makefile index 675a0fa0..e35b0a57 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,7 @@ #!/usr/bin/make -f +VENV_PATH = .venv/bin + install-dev-requirements: curl -LsSf https://astral.sh/uv/install.sh | sh uv venv && uv pip install hatch @@ -24,25 +26,31 @@ setup: develop develop: install-dev-requirements install-test-requirements types: - @echo "Type checking Python files" - .venv/bin/mypy --pretty + @if [ -n "$$SKIP_MYPY" ]; then \ + echo "Skipping mypy types check because SKIP_MYPY is set"; \ + else \ + echo "Type checking Python files"; \ + $(VENV_PATH)/mypy --pretty; \ + fi @echo "" test: types @echo "Running Python tests" - export VIRTUAL_ENV=.venv; .venv/bin/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- .venv/bin/pytest + uv pip uninstall pook || true + $(VENV_PATH)/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- $(VENV_PATH)/pytest + uv pip install pook && $(VENV_PATH)/pytest tests/test_pook.py && uv pip uninstall pook @echo "" safetest: - export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; make test + SKIP_TRUE_REDIS=1 SKIP_TRUE_HTTP=1 $(VENV_PATH)/pytest -publish: install-test-requirements - python -m build --sdist . - twine upload --repository mocket dist/mocket-$(shell python -c 'import mocket; print(mocket.__version__)').tar.gz +publish: clean install-test-requirements + uv build --package mocket --sdist --wheel + uv publish clean: - rm -rf *.egg-info dist/ requirements.txt Pipfile.lock - find . -type d -name __pycache__ -exec rm -rf {} \; + rm -rf *.egg-info dist/ requirements.txt uv.lock coverage.xml || true + find . -type d -name __pycache__ -exec rm -rf {} \; || true .PHONY: clean publish safetest test setup develop lint-python test-python _services-up .PHONY: prepare-hosts services-up services-down install-test-requirements install-dev-requirements diff --git a/README.rst b/README.rst index 17a7801c..a6c662d4 100644 --- a/README.rst +++ b/README.rst @@ -1,12 +1,12 @@ -=============== -mocket /mɔˈkɛt/ -=============== +================ +mocket /ˈmɔ.kɛt/ +================ .. image:: https://github.com/mindflayer/python-mocket/actions/workflows/main.yml/badge.svg?branch=main :target: https://github.com/mindflayer/python-mocket/actions?query=workflow%3A%22Mocket%27s+CI%22 -.. image:: https://coveralls.io/repos/github/mindflayer/python-mocket/badge.svg?branch=main - :target: https://coveralls.io/github/mindflayer/python-mocket?branch=main +.. image:: https://codecov.io/github/mindflayer/python-mocket/graph/badge.svg?token=htRySebRBt + :target: https://codecov.io/github/mindflayer/python-mocket .. image:: https://app.codacy.com/project/badge/Grade/6327640518ce42adaf59368217028f14 :target: https://www.codacy.com/gh/mindflayer/python-mocket/dashboard @@ -14,6 +14,15 @@ mocket /mɔˈkɛt/ .. image:: https://img.shields.io/pypi/dm/mocket :target: https://pypistats.org/packages/mocket +.. image:: https://deepwiki.com/badge.svg + :target: https://deepwiki.com/mindflayer/python-mocket + +.. image:: https://raw.githubusercontent.com/mindflayer/python-mocket/main/mocket.png + :height: 256px + :width: 256px + :alt: Mocket logo + :align: center + A socket mock framework ------------------------- @@ -21,16 +30,30 @@ A socket mock framework ...and then MicroPython's *urequests* (*mocket >= 3.9.1*) +What is it about? +================= + +In a nutshell, **Mocket** is *monkey-patching on steroids* for the ``socket`` and ``ssl`` modules. + +It’s designed to serve two main purposes: + +- As a **low-level framework** — for example, if you're building a client for a new database or protocol. +- As a **ready-to-use mock** — perfect for testing HTTP or HTTPS calls from any client library. + +To demonstrate that Mocket is more than just a web client mocking tool, it even includes a simple Redis mock. + +The main goal of Mocket is to make it easier to test Python clients that communicate using the ``socket`` protocol. + Outside GitHub ============== -Mocket packages are available for `Arch Linux`_, `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, and of course you can **pip install** it from `PyPI`_. +Mocket packages are available for `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, `AUR Arch Linux`_, and of course from `PyPI`_. -.. _`Arch Linux`: https://archlinux.org/packages/extra/any/python-mocket/ .. _`openSUSE`: https://software.opensuse.org/search?baseproject=ALL&q=mocket .. _`NixOS`: https://search.nixos.org/packages?query=mocket .. _`ALT Linux`: https://packages.altlinux.org/en/sisyphus/srpms/python3-module-mocket/ .. _`NetBSD`: https://cdn.netbsd.org/pub/pkgsrc/current/pkgsrc/devel/py-mocket/index.html +.. _`AUR Arch Linux`: https://aur.archlinux.org/packages/python-mocket .. _`PyPI`: https://pypi.org/project/mocket/ @@ -64,12 +87,12 @@ The starting point to understand how to use *Mocket* to write a custom mock is t As next step, you are invited to have a look at the implementation of both the mocks it provides: -- HTTP mock (similar to HTTPretty) - https://github.com/mindflayer/python-mocket/blob/master/mocket/mockhttp.py -- Redis mock (basic implementation) - https://github.com/mindflayer/python-mocket/blob/master/mocket/mockredis.py +- HTTP mock (similar to HTTPretty) - https://github.com/mindflayer/python-mocket/blob/main/mocket/mocks/mockhttp.py +- Redis mock (basic implementation) - https://github.com/mindflayer/python-mocket/blob/main/mocket/mocks/mockredis.py Please also have a look at the huge test suite: -- Tests module at https://github.com/mindflayer/python-mocket/tree/master/tests +- Tests module at https://github.com/mindflayer/python-mocket/tree/main/tests Installation ============ @@ -102,7 +125,7 @@ As second step, we create an `example.py` file as the following one: import json from mocket import mocketize - from mocket.mockhttp import Entry + from mocket.mocks.mockhttp import Entry import requests import pytest @@ -206,6 +229,37 @@ It's very important that we test non-happy paths. with self.assertRaises(requests.exceptions.ConnectionError): requests.get(url) +Example of how to mock a call with a custom request matching logic +================================================================== +.. code-block:: python + + import json + + from mocket import mocketize + from mocket.mocks.mockhttp import Entry + import requests + + @mocketize + def test_can_handle(): + Entry.single_register( + Entry.GET, + url, + body=json.dumps({"message": "Nope... not this time!"}), + headers={"content-type": "application/json"}, + can_handle_fun=lambda path, qs_dict: path == "/ip" and qs_dict, + ) + Entry.single_register( + Entry.GET, + url, + body=json.dumps({"message": "There you go!"}), + headers={"content-type": "application/json"}, + can_handle_fun=lambda path, qs_dict: path == "/ip" and not qs_dict, + ) + resp = requests.get("https://httpbin.org/ip") + assert resp.status_code == 200 + assert resp.json() == {"message": "There you go!"} + + Example of how to record real socket traffic ============================================ @@ -232,10 +286,12 @@ You probably know what *VCRpy* is capable of, that's the *mocket*'s way of achie HTTPretty compatibility layer ============================= -Mocket HTTP mock can work as *HTTPretty* replacement for many different use cases. Two main features are missing: +Mocket HTTP mock can work as *HTTPretty* replacement for many different use cases. Two main features are missing, or better said, are implemented differently: -- URL entries containing regular expressions; -- response body from functions (used mostly to fake errors, *mocket* doesn't need to do it this way). +- URL entries containing regular expressions, *Mocket* implements `can_handle_fun` which is way simpler to use and more powerful; +- response body from functions (used mostly to fake errors, *Mocket* accepts an `exception` instead). + +Both features are documented above. Two features which are against the Zen of Python, at least imho (*mindflayer*), but of course I am open to call it into question. @@ -284,52 +340,44 @@ Example: .. code-block:: python - class AioHttpEntryTestCase(TestCase): - @mocketize - def test_http_session(self): - url = 'http://httpbin.org/ip' - body = "asd" * 100 - Entry.single_register(Entry.GET, url, body=body, status=404) - Entry.single_register(Entry.POST, url, body=body*2, status=201) + # `aiohttp` creates SSLContext instances at import-time + # that's why Mocket would get stuck when dealing with HTTPS + # Importing the module while Mocket is in control (inside a + # decorated test function or using its context manager would + # be enough for making it work), the alternative is using a + # custom TCPConnector which always returns a FakeSSLContext + # from Mocket like this example is showing. + import aiohttp + import pytest - async def main(l): - async with aiohttp.ClientSession( - loop=l, timeout=aiohttp.ClientTimeout(total=3) - ) as session: - async with session.get(url) as get_response: - assert get_response.status == 404 - assert await get_response.text() == body + from mocket import async_mocketize + from mocket.mocks.mockhttp import Entry + from mocket.plugins.aiohttp_connector import MocketTCPConnector - async with session.post(url, data=body * 6) as post_response: - assert post_response.status == 201 - assert await post_response.text() == body * 2 - loop = asyncio.new_event_loop() - loop.run_until_complete(main(loop)) + @pytest.mark.asyncio + @async_mocketize + async def test_aiohttp(): + """ + The alternative to using the custom `connector` would be importing + `aiohttp` when Mocket is already in control (inside the decorated test). + """ + + url = "https://bar.foo/" + data = {"message": "Hello"} + + Entry.single_register( + Entry.GET, + url, + body=json.dumps(data), + headers={"content-type": "application/json"}, + ) - # or again with a unittest.IsolatedAsyncioTestCase - from mocket.async_mocket import async_mocketize - - class AioHttpEntryTestCase(IsolatedAsyncioTestCase): - @async_mocketize - async def test_http_session(self): - url = 'http://httpbin.org/ip' - body = "asd" * 100 - Entry.single_register(Entry.GET, url, body=body, status=404) - Entry.single_register(Entry.POST, url, body=body * 2, status=201) - - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=3) - ) as session: - async with session.get(url) as get_response: - assert get_response.status == 404 - assert await get_response.text() == body - - async with session.post(url, data=body * 6) as post_response: - assert post_response.status == 201 - assert await post_response.text() == body * 2 - assert Mocket.last_request().method == 'POST' - assert Mocket.last_request().body == body * 6 + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=3), connector=MocketTCPConnector() + ) as session, session.get(url) as response: + response = await response.json() + assert response == data Works well with others diff --git a/mocket.png b/mocket.png new file mode 100644 index 00000000..08498b4c Binary files /dev/null and b/mocket.png differ diff --git a/mocket/__init__.py b/mocket/__init__.py index 2a5c3f40..27ffad16 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,6 +1,36 @@ -from .async_mocket import async_mocketize -from .mocket import Mocket, MocketEntry, Mocketizer, mocketize +"""Mocket - socket mocking library for Python.""" -__all__ = ("async_mocketize", "mocketize", "Mocket", "MocketEntry", "Mocketizer") +import importlib +import sys -__version__ = "3.12.8" +from mocket.decorators.async_mocket import async_mocketize +from mocket.decorators.mocketizer import Mocketizer, mocketize +from mocket.entry import MocketEntry +from mocket.mocket import Mocket +from mocket.ssl.context import MocketSSLContext + +# NOTE the following lines are here for backwards-compatibility, +# to keep old import-paths working +from mocket.ssl.context import MocketSSLContext as FakeSSLContext + +sys.modules["mocket.mockhttp"] = importlib.import_module("mocket.mocks.mockhttp") +sys.modules["mocket.mockredis"] = importlib.import_module("mocket.mocks.mockredis") +sys.modules["mocket.async_mocket"] = importlib.import_module( + "mocket.decorators.async_mocket" +) +sys.modules["mocket.mocketizer"] = importlib.import_module( + "mocket.decorators.mocketizer" +) + + +__all__ = ( + "async_mocketize", + "mocketize", + "Mocket", + "MocketEntry", + "Mocketizer", + "MocketSSLContext", + "FakeSSLContext", +) + +__version__ = "3.14.1" diff --git a/mocket/async_mocket.py b/mocket/async_mocket.py deleted file mode 100644 index 2970e0f4..00000000 --- a/mocket/async_mocket.py +++ /dev/null @@ -1,22 +0,0 @@ -from .mocket import Mocketizer -from .utils import get_mocketize - - -async def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): - async with Mocketizer.factory( - test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args - ): - return await test(*args, **kwargs) - - -async_mocketize = get_mocketize(wrapper_=wrapper) - - -__all__ = ("async_mocketize",) diff --git a/mocket/compat.py b/mocket/compat.py index 8457c274..a8e726f6 100644 --- a/mocket/compat.py +++ b/mocket/compat.py @@ -5,36 +5,65 @@ import shlex from typing import Final -ENCODING: Final[str] = os.getenv("MOCKET_ENCODING", "utf-8") +import puremagic -text_type = str -byte_type = bytes -basestring = (str,) +ENCODING: Final[str] = os.getenv("MOCKET_ENCODING", "utf-8") def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes: - if isinstance(s, text_type): + """Encode a string or bytes to bytes. + + Args: + s: String or bytes to encode + encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var) + + Returns: + Encoded bytes + """ + if isinstance(s, str): s = s.encode(encoding) - return byte_type(s) + return bytes(s) def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str: - if isinstance(s, byte_type): + """Decode bytes or string to string. + + Args: + s: String or bytes to decode + encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var) + + Returns: + Decoded string + """ + if isinstance(s, bytes): s = codecs.decode(s, encoding, "ignore") - return text_type(s) + return str(s) def shsplit(s: str | bytes) -> list[str]: + """Split a shell command string into arguments. + + Args: + s: Shell command string or bytes + + Returns: + List of shell command arguments + """ s = decode_from_bytes(s) return shlex.split(s) -def do_the_magic(lib_magic, body): # pragma: no cover - if hasattr(lib_magic, "from_buffer"): - # PyPI python-magic - return lib_magic.from_buffer(body, mime=True) - # file's builtin python wrapper - # used by https://www.archlinux.org/packages/community/any/python-mocket/ - _magic = lib_magic.open(lib_magic.MAGIC_MIME_TYPE) - _magic.load() - return _magic.buffer(body) +def do_the_magic(body: bytes) -> str: + """Detect MIME type of binary data using puremagic. + + Args: + body: Binary data to analyze + + Returns: + MIME type string + """ + try: + magic = puremagic.magic_string(body) + except puremagic.PureError: + magic = [] + return magic[0].mime_type if len(magic) else "application/octet-stream" diff --git a/tests/main/__init__.py b/mocket/decorators/__init__.py similarity index 100% rename from tests/main/__init__.py rename to mocket/decorators/__init__.py diff --git a/mocket/decorators/async_mocket.py b/mocket/decorators/async_mocket.py new file mode 100644 index 00000000..53b966c0 --- /dev/null +++ b/mocket/decorators/async_mocket.py @@ -0,0 +1,41 @@ +"""Async version of Mocket decorator.""" + +from __future__ import annotations + +from typing import Any, Callable + +from mocket.decorators.mocketizer import Mocketizer +from mocket.utils import get_mocketize + + +async def wrapper( + test: Callable, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + *args: Any, + **kwargs: Any, +) -> Any: + """Async wrapper function for @async_mocketize decorator. + + Args: + test: Async test function to wrap + truesocket_recording_dir: Directory for recording true socket calls + strict_mode: Enable STRICT mode to forbid real socket calls + strict_mode_allowed: List of allowed hosts in STRICT mode + *args: Test arguments + **kwargs: Test keyword arguments + + Returns: + Result of the test function + """ + async with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): + return await test(*args, **kwargs) + + +async_mocketize = get_mocketize(wrapper) + + +__all__ = ("async_mocketize",) diff --git a/mocket/decorators/mocketizer.py b/mocket/decorators/mocketizer.py new file mode 100644 index 00000000..b067ffdf --- /dev/null +++ b/mocket/decorators/mocketizer.py @@ -0,0 +1,173 @@ +"""Mocketizer decorator for managing Mocket lifecycle in tests.""" + +from __future__ import annotations + +from typing import Any, Callable + +from mocket.mocket import Mocket +from mocket.mode import MocketMode +from mocket.utils import get_mocketize + + +class Mocketizer: + """Context manager and decorator for managing Mocket lifecycle in tests.""" + + def __init__( + self, + instance: Any | None = None, + namespace: str | None = None, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + ) -> None: + """Initialize the Mocketizer. + + Args: + instance: Test instance (optional) + namespace: Namespace for recordings + truesocket_recording_dir: Directory for recording true socket calls + strict_mode: Enable STRICT mode to forbid real socket calls + strict_mode_allowed: List of allowed hosts in STRICT mode + """ + self.instance = instance + self.truesocket_recording_dir = truesocket_recording_dir + self.namespace = namespace or str(id(self)) + MocketMode.STRICT = strict_mode + if strict_mode: + MocketMode.STRICT_ALLOWED = strict_mode_allowed or [] + elif strict_mode_allowed: + raise ValueError( + "Allowed locations are only accepted when STRICT mode is active." + ) + + def enter(self) -> None: + """Enter the Mocketizer context (enable Mocket).""" + Mocket.enable( + namespace=self.namespace, + truesocket_recording_dir=self.truesocket_recording_dir, + ) + if self.instance: + self.check_and_call("mocketize_setup") + + def __enter__(self) -> Mocketizer: + """Enter context manager. + + Returns: + Self for use in `with` statements + """ + self.enter() + return self + + def exit(self) -> None: + """Exit the Mocketizer context (disable Mocket).""" + if self.instance: + self.check_and_call("mocketize_teardown") + + Mocket.disable() + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + """Exit context manager. + + Args: + type: Exception type + value: Exception value + tb: Traceback + """ + self.exit() + + async def __aenter__(self, *args: Any, **kwargs: Any) -> Mocketizer: + """Enter async context manager. + + Returns: + Self for use in `async with` statements + """ + self.enter() + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit async context manager. + + Args: + *args: Exception arguments + **kwargs: Exception keyword arguments + """ + self.exit() + + def check_and_call(self, method_name: str) -> None: + """Check if instance has a method and call it. + + Args: + method_name: Name of method to check and call + """ + method = getattr(self.instance, method_name, None) + if callable(method): + method() + + @staticmethod + def factory( + test: Callable, + truesocket_recording_dir: str | None, + strict_mode: bool, + strict_mode_allowed: list | None, + args: tuple, + ) -> Mocketizer: + """Create a Mocketizer instance for a test function. + + Args: + test: Test function being decorated + truesocket_recording_dir: Recording directory + strict_mode: Enable STRICT mode + strict_mode_allowed: Allowed hosts in STRICT mode + args: Positional arguments to test + + Returns: + Configured Mocketizer instance + """ + instance = args[0] if args else None + namespace = None + if truesocket_recording_dir: + namespace = ".".join( + ( + instance.__class__.__module__, + instance.__class__.__name__, + test.__name__, + ) + ) + + return Mocketizer( + instance, + namespace=namespace, + truesocket_recording_dir=truesocket_recording_dir, + strict_mode=strict_mode, + strict_mode_allowed=strict_mode_allowed, + ) + + +def wrapper( + test: Callable, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + *args: Any, + **kwargs: Any, +) -> Any: + """Wrapper function for @mocketize decorator. + + Args: + test: Test function to wrap + truesocket_recording_dir: Recording directory + strict_mode: Enable STRICT mode + strict_mode_allowed: Allowed hosts in STRICT mode + *args: Test arguments + **kwargs: Test keyword arguments + + Returns: + Result of the test function + """ + with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): + return test(*args, **kwargs) + + +mocketize = get_mocketize(wrapper) diff --git a/mocket/entry.py b/mocket/entry.py new file mode 100644 index 00000000..2d618472 --- /dev/null +++ b/mocket/entry.py @@ -0,0 +1,96 @@ +"""Mocket entry base class for registering mock responses.""" + +from __future__ import annotations + +import collections.abc +from typing import Any + +from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket + + +class MocketEntry: + """Base class for Mocket entries that match requests and return responses.""" + + class Response(bytes): + """Response wrapper class that extends bytes.""" + + @property + def data(self) -> bytes: + """Get the response data.""" + return self + + response_index: int = 0 + request_cls: type = bytes + response_cls: type = Response + responses: list | None = None + _served: bool | None = None + + def __init__(self, location: tuple, responses: Any) -> None: + """Initialize a Mocket entry. + + Args: + location: Tuple of (host, port) + responses: Single response or list of responses to cycle through + """ + self._served = False + self.location = location + + if not isinstance(responses, collections.abc.Iterable): + responses = [responses] + + if not responses: + self.responses = [self.response_cls(encode_to_bytes(""))] + else: + self.responses = [] + for r in responses: + if not isinstance(r, BaseException) and not getattr(r, "data", False): + if isinstance(r, str): + r = encode_to_bytes(r) + r = self.response_cls(r) + self.responses.append(r) + + def __repr__(self) -> str: + """Return a string representation of the entry.""" + return f"{self.__class__.__name__}(location={self.location})" + + @staticmethod + def can_handle(data: bytes) -> bool: + """Check if this entry can handle the given request data. + + Args: + data: Request data to check + + Returns: + True if this entry can handle the request, False otherwise + """ + return True + + def collect(self, data: bytes) -> None: + """Collect the request data in the Mocket singleton. + + Args: + data: Request data to collect + """ + req = self.request_cls(data) + Mocket.collect(req) + + def get_response(self) -> bytes: + """Get the next response to send. + + Returns: + Response bytes to send to the client + + Raises: + BaseException: If a response is an exception, it will be raised + """ + response = self.responses[self.response_index] + if self.response_index < len(self.responses) - 1: + self.response_index += 1 + + self._served = True + + if isinstance(response, BaseException): + raise response + + return response.data diff --git a/mocket/exceptions.py b/mocket/exceptions.py index f5537568..db78dbf5 100644 --- a/mocket/exceptions.py +++ b/mocket/exceptions.py @@ -1,6 +1,13 @@ +"""Mocket exception classes.""" + + class MocketException(Exception): + """Base exception class for Mocket errors.""" + pass class StrictMocketException(MocketException): + """Exception raised when a socket operation is not allowed in STRICT mode.""" + pass diff --git a/mocket/inject.py b/mocket/inject.py new file mode 100644 index 00000000..e788a929 --- /dev/null +++ b/mocket/inject.py @@ -0,0 +1,97 @@ +"""Socket patching and restoration for Mocket injection.""" + +from __future__ import annotations + +import contextlib +import socket +import ssl +from types import ModuleType +from typing import Any + +import urllib3 + +_patches_restore: dict[tuple[ModuleType, str], Any] = {} + + +def _patch(module: ModuleType, name: str, patched_value: Any) -> None: + """Patch a module with a new value and store the original. + + Args: + module: Module to patch + name: Attribute name to patch + patched_value: New value to set + """ + with contextlib.suppress(KeyError): + original_value, module.__dict__[name] = module.__dict__[name], patched_value + _patches_restore[(module, name)] = original_value + + +def _restore(module: ModuleType, name: str) -> None: + """Restore a module's original attribute value. + + Args: + module: Module to restore + name: Attribute name to restore + """ + if original_value := _patches_restore.pop((module, name)): + module.__dict__[name] = original_value + + +def enable() -> None: + """Enable Mocket by patching socket, ssl, and urllib3 modules.""" + from mocket.socket import ( + MocketSocket, + mock_create_connection, + mock_getaddrinfo, + mock_gethostbyname, + mock_gethostname, + mock_inet_pton, + mock_socketpair, + ) + from mocket.ssl.context import MocketSSLContext, mock_wrap_socket + from mocket.urllib3 import ( + mock_match_hostname as mock_urllib3_match_hostname, + ) + from mocket.urllib3 import ( + mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket, + ) + + patches = { + # stdlib: socket + (socket, "socket"): MocketSocket, + (socket, "create_connection"): mock_create_connection, + (socket, "getaddrinfo"): mock_getaddrinfo, + (socket, "gethostbyname"): mock_gethostbyname, + (socket, "gethostname"): mock_gethostname, + (socket, "inet_pton"): mock_inet_pton, + (socket, "SocketType"): MocketSocket, + (socket, "socketpair"): mock_socketpair, + # stdlib: ssl + (ssl, "SSLContext"): MocketSSLContext, + (ssl, "wrap_socket"): mock_wrap_socket, # python < 3.12.0 + # urllib3 + (urllib3.connection, "match_hostname"): mock_urllib3_match_hostname, + (urllib3.connection, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util.ssl_, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util.ssl_, "wrap_socket"): mock_urllib3_ssl_wrap_socket, # urllib3 < 2 + } + + for (module, name), new_value in patches.items(): + _patch(module, name, new_value) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import extract_from_urllib3 + + extract_from_urllib3() + + +def disable() -> None: + """Disable Mocket by restoring all patched modules.""" + for module, name in list(_patches_restore.keys()): + _restore(module, name) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import inject_into_urllib3 + + inject_into_urllib3() diff --git a/mocket/io.py b/mocket/io.py new file mode 100644 index 00000000..e815e0ec --- /dev/null +++ b/mocket/io.py @@ -0,0 +1,37 @@ +"""Mocket socket I/O implementation.""" + +from __future__ import annotations + +import io +import os + +from mocket.mocket import Mocket + + +class MocketSocketIO(io.BytesIO): + """A BytesIO wrapper that integrates with Mocket's pipe-based I/O.""" + + def __init__(self, address: tuple) -> None: + """Initialize the socket I/O with a socket address. + + Args: + address: Tuple of (host, port) + """ + self._address = address + super().__init__() + + def write(self, content: bytes) -> int: + """Write content to the buffer and the pipe if available. + + Args: + content: Bytes to write + + Returns: + Number of bytes written + """ + super().write(content) + + _, w_fd = Mocket.get_pair(self._address) + if w_fd: + os.write(w_fd, content) + return len(content) diff --git a/mocket/mocket.py b/mocket/mocket.py index cca0a4cd..75ae6285 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,768 +1,203 @@ +"""Core Mocket singleton for socket mocking management.""" + +from __future__ import annotations + import collections -import collections.abc as collections_abc -import contextlib -import errno -import hashlib import itertools -import json import os -import select -import socket -import ssl -from datetime import datetime, timedelta -from json.decoder import JSONDecodeError -from typing import Optional, Tuple - -import urllib3 -from urllib3.connection import match_hostname as urllib3_match_hostname -from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket - -try: - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket -except ImportError: - urllib3_wrap_socket = None - - -from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type -from .utils import ( - SSL_PROTOCOL, - MocketMode, - MocketSocketCore, - get_mocketize, - hexdump, - hexload, -) - -xxh32 = None -try: - from xxhash import xxh32 -except ImportError: # pragma: no cover - with contextlib.suppress(ImportError): - from xxhash_cffi import xxh32 -hasher = xxh32 or hashlib.md5 - -try: # pragma: no cover - from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 - - pyopenssl_override = True -except ImportError: - pyopenssl_override = False - -try: # pragma: no cover - from aiohttp import TCPConnector - - aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear -except (ImportError, AttributeError): - aiohttp_make_ssl_context_cache_clear = None - - -true_socket = socket.socket -true_create_connection = socket.create_connection -true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_getaddrinfo = socket.getaddrinfo -true_socketpair = socket.socketpair -true_ssl_wrap_socket = getattr( - ssl, "wrap_socket", None -) # in Py3.12 it's only under SSLContext -true_ssl_socket = ssl.SSLSocket -true_ssl_context = ssl.SSLContext -true_inet_pton = socket.inet_pton -true_urllib3_wrap_socket = urllib3_wrap_socket -true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket -true_urllib3_match_hostname = urllib3_match_hostname - - -class SuperFakeSSLContext: - """For Python 3.6 and newer.""" - - class FakeSetter(int): - def __set__(self, *args): - pass - - minimum_version = FakeSetter() - options = FakeSetter() - verify_mode = FakeSetter() - - -class FakeSSLContext(SuperFakeSSLContext): - DUMMY_METHODS = ( - "load_default_certs", - "load_verify_locations", - "set_alpn_protocols", - "set_ciphers", - "set_default_verify_paths", - ) - sock = None - post_handshake_auth = None - _check_hostname = False - - @property - def check_hostname(self): - return self._check_hostname - - @check_hostname.setter - def check_hostname(self, *args): - self._check_hostname = False - - def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs): - self._set_dummy_methods() - - if isinstance(sock, MocketSocket): - self.sock = sock - self.sock._host = server_hostname - self.sock.true_socket = true_ssl_socket( - sock=self.sock.true_socket, - server_hostname=server_hostname, - _context=true_ssl_context(protocol=SSL_PROTOCOL), - ) - elif isinstance(sock, int) and true_ssl_context: - self.context = true_ssl_context(sock) - - def _set_dummy_methods(self): - def dummy_method(*args, **kwargs): - pass - - for m in self.DUMMY_METHODS: - setattr(self, m, dummy_method) - - @staticmethod - def wrap_socket(sock=sock, *args, **kwargs): - sock.kwargs = kwargs - sock._secure_socket = True - return sock - - @staticmethod - def wrap_bio(incoming, outcoming, *args, **kwargs): - ssl_obj = MocketSocket() - ssl_obj._host = kwargs["server_hostname"] - return ssl_obj - - def __getattr__(self, name): - if self.sock is not None: - return getattr(self.sock, name) - - -def create_connection(address, timeout=None, source_address=None): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) - if timeout: - s.settimeout(timeout) - s.connect(address) - return s - - -def socketpair(*args, **kwargs): - """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" - import _socket - - return _socket.socketpair(*args, **kwargs) - - -def _hash_request(h, req): - return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() - - -class MocketSocket: - timeout = None - _fd = None - family = None - type = None - proto = None - _host = None - _port = None - _address = None - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION - _mode = None - _bufsize = None - _secure_socket = False - _did_handshake = False - _sent_non_empty_bytes = False - _io = None - - def __init__( - self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs - ): - self.true_socket = true_socket(family, type, proto) - self._buflen = 65536 - self._entry = None - self.family = int(family) - self.type = int(type) - self.proto = int(proto) - self._truesocket_recording_dir = None - self.kwargs = kwargs - - def __str__(self): - return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - @property - def io(self): - if self._io is None: - self._io = MocketSocketCore((self._host, self._port)) - return self._io - - def fileno(self): - address = (self._host, self._port) - r_fd, _ = Mocket.get_pair(address) - if not r_fd: - r_fd, w_fd = os.pipe() - Mocket.set_pair(address, (r_fd, w_fd)) - return r_fd - - def gettimeout(self): - return self.timeout - - def setsockopt(self, family, type, proto): - self.family = family - self.type = type - self.proto = proto - - if self.true_socket: - self.true_socket.setsockopt(family, type, proto) - - def settimeout(self, timeout): - self.timeout = timeout - - @staticmethod - def getsockopt(level, optname, buflen=None): - return socket.SOCK_STREAM - - def do_handshake(self): - self._did_handshake = True - - def getpeername(self): - return self._address - - def setblocking(self, block): - self.settimeout(None) if block else self.settimeout(0.0) - - def getblocking(self): - return self.gettimeout() is None - - def getsockname(self): - return socket.gethostbyname(self._address[0]), self._address[1] - - def getpeercert(self, *args, **kwargs): - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", f"*.{self._host}"), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", f"*.{self._host}"),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", f"*.{self._host}"),), - ), - } - - def unwrap(self): - return self - - def write(self, data): - return self.send(encode_to_bytes(data)) - - def connect(self, address): - self._address = self._host, self._port = address - Mocket._address = address - - def makefile(self, mode="r", bufsize=-1): - self._mode = mode - self._bufsize = bufsize - return self.io - - def get_entry(self, data): - return Mocket.get_entry(self._host, self._port, data) - - def sendall(self, data, entry=None, *args, **kwargs): - if entry is None: - entry = self.get_entry(data) - - if entry: - consume_response = entry.collect(data) - response = entry.get_response() if consume_response is not False else None - else: - response = self.true_sendall(data, *args, **kwargs) - - if response is not None: - self.io.seek(0) - self.io.write(response) - self.io.truncate() - self.io.seek(0) - - def read(self, buffersize): - rv = self.io.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv - - def recv_into(self, buffer, buffersize=None, flags=None): - if hasattr(buffer, "write"): - return buffer.write(self.read(buffersize)) - # buffer is a memoryview - data = self.read(buffersize) - if data: - buffer[: len(data)] = data - return len(data) - - def recv(self, buffersize, flags=None): - r_fd, _ = Mocket.get_pair((self._host, self._port)) - if r_fd: - return os.read(r_fd, buffersize) - data = self.read(buffersize) - if data: - return data - # used by Redis mock - exc = BlockingIOError() - exc.errno = errno.EWOULDBLOCK - exc.args = (0,) - raise exc - - def true_sendall(self, data, *args, **kwargs): - if not MocketMode().is_allowed((self._host, self._port)): - MocketMode.raise_not_allowed() - - req = decode_from_bytes(data) - # make request unique again - req_signature = _hash_request(hasher, req) - # port should be always a string - port = text_type(self._port) - - # prepare responses dictionary - responses = {} - - if Mocket.get_truesocket_recording_dir(): - path = os.path.join( - Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" - ) - # check if there's already a recorded session dumped to a JSON file - try: - with open(path) as f: - responses = json.load(f) - # if not, create a new dictionary - except (FileNotFoundError, JSONDecodeError): - pass - - try: - try: - response_dict = responses[self._host][port][req_signature] - except KeyError: - if hasher is not hashlib.md5: - # Fallback for backwards compatibility - req_signature = _hash_request(hashlib.md5, req) - response_dict = responses[self._host][port][req_signature] - else: - raise - except KeyError: - # preventing next KeyError exceptions - responses.setdefault(self._host, {}) - responses[self._host].setdefault(port, {}) - responses[self._host][port].setdefault(req_signature, {}) - response_dict = responses[self._host][port][req_signature] - - # try to get the response from the dictionary - try: - encoded_response = hexload(response_dict["response"]) - # if not available, call the real sendall - except KeyError: - host, port = self._host, self._port - host = true_gethostbyname(host) - - if isinstance(self.true_socket, true_socket) and self._secure_socket: - self.true_socket = true_urllib3_ssl_wrap_socket( - self.true_socket, - **self.kwargs, - ) - - with contextlib.suppress(OSError, ValueError): - # already connected - self.true_socket.connect((host, port)) - self.true_socket.sendall(data, *args, **kwargs) - encoded_response = b"" - # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L13 - while True: - if ( - not select.select([self.true_socket], [], [], 0.1)[0] - and encoded_response - ): - break - recv = self.true_socket.recv(self._buflen) - - if not recv and encoded_response: - break - encoded_response += recv - - # dump the resulting dictionary to a JSON file - if Mocket.get_truesocket_recording_dir(): - # update the dictionary with request and response lines - response_dict["request"] = req - response_dict["response"] = hexdump(encoded_response) - - with open(path, mode="w") as f: - f.write( - decode_from_bytes( - json.dumps(responses, indent=4, sort_keys=True) - ) - ) - - # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO - return encoded_response - - def send(self, data, *args, **kwargs): # pragma: no cover - entry = self.get_entry(data) - if not entry or (entry and self._entry != entry): - kwargs["entry"] = entry - self.sendall(data, *args, **kwargs) - else: - req = Mocket.last_request() - if hasattr(req, "add_data"): - req.add_data(data) - self._entry = entry - return len(data) - - def close(self): - if self.true_socket and not self.true_socket._closed: - self.true_socket.close() - self._fd = None - - def __getattr__(self, name): - """Do nothing catchall function, for methods like shutdown()""" - - def do_nothing(*args, **kwargs): - pass - - return do_nothing +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar + +import mocket.inject +from mocket.recording import MocketRecordStorage + +# NOTE this is here for backwards-compat to keep old import-paths working +# from mocket.socket import MocketSocket as MocketSocket + +if TYPE_CHECKING: + from mocket.entry import MocketEntry + from mocket.types import Address class Mocket: - _socket_pairs = {} - _address = (None, None) - _entries = collections.defaultdict(list) - _requests = [] - _namespace = text_type(id(_entries)) - _truesocket_recording_dir = None + """Singleton class managing all mock socket operations and entries.""" + + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} + _address: ClassVar[Address | tuple[None, None]] = (None, None) + _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) + _requests: ClassVar[list] = [] + _record_storage: ClassVar[MocketRecordStorage | None] = None @classmethod - def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: + def enable( + cls, + namespace: str | None = None, + truesocket_recording_dir: str | None = None, + ) -> None: + """Enable Mocket socket mocking. + + Args: + namespace: Namespace for recording storage (defaults to id of _entries) + truesocket_recording_dir: Directory to store recorded requests/responses """ + if namespace is None: + namespace = str(id(cls._entries)) + + if truesocket_recording_dir is not None: + recording_dir = Path(truesocket_recording_dir) + + assert recording_dir.is_dir(), f"Not a directory: {recording_dir}" + + cls._record_storage = MocketRecordStorage( + directory=recording_dir, + namespace=namespace, + ) + + mocket.inject.enable() + + @classmethod + def disable(cls) -> None: + """Disable Mocket socket mocking and clean up resources.""" + cls.reset() + + mocket.inject.disable() + + @classmethod + def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: + """Get the file descriptor pair for a socket address. + Given the id() of the caller, return a pair of file descriptors as a tuple of two integers: (, ) + + Args: + address: (host, port) tuple + + Returns: + Tuple of (read_fd, write_fd) or (None, None) if not found """ return cls._socket_pairs.get(address, (None, None)) @classmethod - def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: - """ - Store a pair of file descriptors under the key `id_` + def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: + """Store a file descriptor pair for a socket address. + + Store a pair of file descriptors under the key `address` as a tuple of two integers: (, ) + + Args: + address: (host, port) tuple + pair: Tuple of (read_fd, write_fd) """ cls._socket_pairs[address] = pair @classmethod - def register(cls, *entries): + def register(cls, *entries: MocketEntry) -> None: + """Register mock entries with Mocket. + + Args: + *entries: Variable number of MocketEntry instances to register + """ for entry in entries: cls._entries[entry.location].append(entry) @classmethod - def get_entry(cls, host, port, data): - host = host or Mocket._address[0] - port = port or Mocket._address[1] + def get_entry(cls, host: str, port: int, data: Any) -> MocketEntry | None: + """Get a matching entry for the given request data. + + Args: + host: Hostname + port: Port number + data: Request data + + Returns: + Matching MocketEntry or None + """ + host = host or cls._address[0] + port = port or cls._address[1] entries = cls._entries.get((host, port), []) for entry in entries: if entry.can_handle(data): return entry + return None @classmethod - def collect(cls, data): - cls.request_list().append(data) + def collect(cls, data: Any) -> None: + """Collect a request in the list of all requests. + + Args: + data: Request data to collect + """ + cls._requests.append(data) @classmethod - def reset(cls): + def reset(cls) -> None: + """Reset all Mocket state and clean up file descriptors.""" for r_fd, w_fd in cls._socket_pairs.values(): os.close(r_fd) os.close(w_fd) cls._socket_pairs = {} cls._entries = collections.defaultdict(list) cls._requests = [] + cls._record_storage = None @classmethod - def last_request(cls): + def last_request(cls) -> Any: + """Get the last request made. + + Returns: + Last request data or None if no requests + """ if cls.has_requests(): - return cls.request_list()[-1] + return cls._requests[-1] @classmethod - def request_list(cls): + def request_list(cls) -> list[Any]: + """Get the list of all requests. + + Returns: + List of all collected requests + """ return cls._requests @classmethod - def remove_last_request(cls): + def remove_last_request(cls) -> None: + """Remove the last request from the request list.""" if cls.has_requests(): del cls._requests[-1] @classmethod - def has_requests(cls): - return bool(cls.request_list()) + def has_requests(cls) -> bool: + """Check if any requests have been made. - @staticmethod - def enable(namespace=None, truesocket_recording_dir=None): - Mocket._namespace = namespace - Mocket._truesocket_recording_dir = truesocket_recording_dir - - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): - # JSON dumps will be saved here - raise AssertionError - - socket.socket = socket.__dict__["socket"] = MocketSocket - socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket - socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = ( - create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" - socket.gethostbyname = socket.__dict__["gethostbyname"] = ( - lambda host: "127.0.0.1" - ) - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( - lambda host, port, family=None, socktype=None, proto=None, flags=None: [ - (2, 1, 6, "", (host, port)) - ] - ) - socket.socketpair = socket.__dict__["socketpair"] = socketpair - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type( - "\x7f\x00\x00\x01", "utf-8" - ) - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - FakeSSLContext.wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - FakeSSLContext.wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = lambda *args: None - if pyopenssl_override: # pragma: no cover - # Take out the pyopenssl version - use the default implementation - extract_from_urllib3() - if aiohttp_make_ssl_context_cache_clear: # pragma: no cover - aiohttp_make_ssl_context_cache_clear() - - @staticmethod - def disable(): - socket.socket = socket.__dict__["socket"] = true_socket - socket._socketobject = socket.__dict__["_socketobject"] = true_socket - socket.SocketType = socket.__dict__["SocketType"] = true_socket - socket.create_connection = socket.__dict__["create_connection"] = ( - true_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = true_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = true_socketpair - if true_ssl_wrap_socket: - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context - socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - true_urllib3_wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - true_urllib3_ssl_wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = true_urllib3_match_hostname - Mocket.reset() - if pyopenssl_override: # pragma: no cover - # Put the pyopenssl version back in place - inject_into_urllib3() - if aiohttp_make_ssl_context_cache_clear: # pragma: no cover - aiohttp_make_ssl_context_cache_clear() + Returns: + True if there are requests, False otherwise + """ + return bool(cls.request_list()) @classmethod - def get_namespace(cls): - return cls._namespace + def get_namespace(cls) -> str | None: + """Get the recording namespace. - @classmethod - def get_truesocket_recording_dir(cls): - return cls._truesocket_recording_dir + Returns: + Namespace string or None if recording is not enabled + """ + return cls._record_storage.namespace if cls._record_storage else None @classmethod - def assert_fail_if_entries_not_served(cls): - """Mocket checks that all entries have been served at least once.""" - if not all(entry._served for entry in itertools.chain(*cls._entries.values())): - raise AssertionError("Some Mocket entries have not been served") - + def get_truesocket_recording_dir(cls) -> str | None: + """Get the true socket recording directory. -class MocketEntry: - class Response(byte_type): - @property - def data(self): - return self - - response_index = 0 - request_cls = byte_type - response_cls = Response - responses = None - _served = None - - def __init__(self, location, responses): - self._served = False - self.location = location - - if not isinstance(responses, collections_abc.Iterable) or isinstance( - responses, basestring - ): - responses = [responses] - - if not responses: - self.responses = [self.response_cls(encode_to_bytes(""))] - else: - self.responses = [] - for r in responses: - if not isinstance(r, BaseException) and not getattr(r, "data", False): - if isinstance(r, text_type): - r = encode_to_bytes(r) - r = self.response_cls(r) - self.responses.append(r) - - def __repr__(self): - return f"{self.__class__.__name__}(location={self.location})" - - @staticmethod - def can_handle(data): - return True - - def collect(self, data): - req = self.request_cls(data) - Mocket.collect(req) - - def get_response(self): - response = self.responses[self.response_index] - if self.response_index < len(self.responses) - 1: - self.response_index += 1 - - self._served = True - - if isinstance(response, BaseException): - raise response - - return response.data - - -class Mocketizer: - def __init__( - self, - instance=None, - namespace=None, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - ): - self.instance = instance - self.truesocket_recording_dir = truesocket_recording_dir - self.namespace = namespace or text_type(id(self)) - MocketMode().STRICT = strict_mode - if strict_mode: - MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] - elif strict_mode_allowed: - raise ValueError( - "Allowed locations are only accepted when STRICT mode is active." - ) + Returns: + Directory path as string or None if recording is not enabled + """ + return str(cls._record_storage.directory) if cls._record_storage else None - def enter(self): - Mocket.enable( - namespace=self.namespace, - truesocket_recording_dir=self.truesocket_recording_dir, - ) - if self.instance: - self.check_and_call("mocketize_setup") - - def __enter__(self): - self.enter() - return self - - def exit(self): - if self.instance: - self.check_and_call("mocketize_teardown") - Mocket.disable() - - def __exit__(self, type, value, tb): - self.exit() - - async def __aenter__(self, *args, **kwargs): - self.enter() - return self - - async def __aexit__(self, *args, **kwargs): - self.exit() - - def check_and_call(self, method_name): - method = getattr(self.instance, method_name, None) - if callable(method): - method() - - @staticmethod - def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): - instance = args[0] if args else None - namespace = None - if truesocket_recording_dir: - namespace = ".".join( - ( - instance.__class__.__module__, - instance.__class__.__name__, - test.__name__, - ) - ) + @classmethod + def assert_fail_if_entries_not_served(cls) -> None: + """Assert that all registered entries have been served at least once. - return Mocketizer( - instance, - namespace=namespace, - truesocket_recording_dir=truesocket_recording_dir, - strict_mode=strict_mode, - strict_mode_allowed=strict_mode_allowed, - ) - - -def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): - with Mocketizer.factory( - test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args - ): - return test(*args, **kwargs) - - -mocketize = get_mocketize(wrapper_=wrapper) + Raises: + AssertionError: If any entries have not been served + """ + if not all(entry._served for entry in itertools.chain(*cls._entries.values())): + raise AssertionError("Some Mocket entries have not been served") diff --git a/mocket/mockhttp.py b/mocket/mockhttp.py deleted file mode 100644 index 25540915..00000000 --- a/mocket/mockhttp.py +++ /dev/null @@ -1,278 +0,0 @@ -import re -import time -from http.server import BaseHTTPRequestHandler -from urllib.parse import parse_qs, unquote, urlsplit - -from httptools.parser import HttpRequestParser - -from .compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes -from .mocket import Mocket, MocketEntry - -try: - import magic -except ImportError: - magic = None - - -STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} -CRLF = "\r\n" -ASCII = "ascii" - - -class Protocol: - def __init__(self): - self.url = None - self.body = None - self.headers = {} - - def on_header(self, name: bytes, value: bytes): - self.headers[name.decode(ASCII)] = value.decode(ASCII) - - def on_body(self, body: bytes): - try: - self.body = body.decode(ENCODING) - except UnicodeDecodeError: - self.body = body - - def on_url(self, url: bytes): - self.url = url.decode(ASCII) - - -class Request: - _protocol = None - _parser = None - - def __init__(self, data): - self._protocol = Protocol() - self._parser = HttpRequestParser(self._protocol) - self.add_data(data) - - def add_data(self, data): - self._parser.feed_data(data) - - @property - def method(self): - return self._parser.get_method().decode(ASCII) - - @property - def path(self): - return self._protocol.url - - @property - def headers(self): - return self._protocol.headers - - @property - def querystring(self): - parts = self._protocol.url.split("?", 1) - return ( - parse_qs(unquote(parts[1]), keep_blank_values=True) - if len(parts) == 2 - else {} - ) - - @property - def body(self): - return self._protocol.body - - def __str__(self): - return f"{self.method} - {self.path} - {self.headers}" - - -class Response: - headers = None - is_file_object = False - - def __init__(self, body="", status=200, headers=None, lib_magic=magic): - # needed for testing libmagic import failure - self.magic = lib_magic - - headers = headers or {} - try: - # File Objects - self.body = body.read() - self.is_file_object = True - except AttributeError: - self.body = encode_to_bytes(body) - self.status = status - - self.set_base_headers() - - if headers is not None: - self.set_extra_headers(headers) - - self.data = self.get_protocol_data() + self.body - - def get_protocol_data(self, str_format_fun_name="capitalize"): - status_line = f"HTTP/1.1 {self.status} {STATUS[self.status]}" - header_lines = CRLF.join( - ( - f"{getattr(k, str_format_fun_name)()}: {v}" - for k, v in self.headers.items() - ) - ) - return f"{status_line}\r\n{header_lines}\r\n\r\n".encode(ENCODING) - - def set_base_headers(self): - self.headers = { - "Status": str(self.status), - "Date": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()), - "Server": "Python/Mocket", - "Connection": "close", - "Content-Length": str(len(self.body)), - } - if not self.is_file_object: - self.headers["Content-Type"] = f"text/plain; charset={ENCODING}" - elif self.magic: - self.headers["Content-Type"] = do_the_magic(self.magic, self.body) - - def set_extra_headers(self, headers): - r""" - >>> r = Response(body="") - >>> len(r.headers.keys()) - 6 - >>> r.set_extra_headers({"foo-bar": "Foobar"}) - >>> len(r.headers.keys()) - 7 - >>> encode_to_bytes(r.headers.get("Foo-Bar")) == encode_to_bytes("Foobar") - True - """ - for k, v in headers.items(): - self.headers["-".join(token.capitalize() for token in k.split("-"))] = v - - -class Entry(MocketEntry): - CONNECT = "CONNECT" - DELETE = "DELETE" - GET = "GET" - HEAD = "HEAD" - OPTIONS = "OPTIONS" - PATCH = "PATCH" - POST = "POST" - PUT = "PUT" - TRACE = "TRACE" - - METHODS = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE) - - request_cls = Request - response_cls = Response - - def __init__(self, uri, method, responses, match_querystring=True): - uri = urlsplit(uri) - - port = uri.port - if not port: - port = 443 if uri.scheme == "https" else 80 - - super().__init__((uri.hostname, port), responses) - self.schema = uri.scheme - self.path = uri.path - self.query = uri.query - self.method = method.upper() - self._sent_data = b"" - self._match_querystring = match_querystring - - def __repr__(self): - return f"{self.__class__.__name__}(method={self.method!r}, schema={self.schema!r}, location={self.location!r}, path={self.path!r}, query={self.query!r})" - - def collect(self, data): - consume_response = True - - decoded_data = decode_from_bytes(data) - if not decoded_data.startswith(Entry.METHODS): - Mocket.remove_last_request() - self._sent_data += data - consume_response = False - else: - self._sent_data = data - - super().collect(self._sent_data) - - return consume_response - - def can_handle(self, data): - r""" - >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),)) - >>> e.can_handle(b'GET /?bar=foo HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n') - False - >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),)) - >>> e.can_handle(b'GET /?bar=foo&foobar HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n') - True - """ - try: - requestline, _ = decode_from_bytes(data).split(CRLF, 1) - method, path, _ = self._parse_requestline(requestline) - except ValueError: - return self is getattr(Mocket, "_last_entry", None) - - uri = urlsplit(path) - can_handle = uri.path == self.path and method == self.method - if self._match_querystring: - kw = dict(keep_blank_values=True) - can_handle = can_handle and parse_qs(uri.query, **kw) == parse_qs( - self.query, **kw - ) - if can_handle: - Mocket._last_entry = self - return can_handle - - @staticmethod - def _parse_requestline(line): - """ - http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5 - - >>> Entry._parse_requestline('GET / HTTP/1.0') == ('GET', '/', '1.0') - True - >>> Entry._parse_requestline('post /testurl htTP/1.1') == ('POST', '/testurl', '1.1') - True - >>> Entry._parse_requestline('Im not a RequestLine') - Traceback (most recent call last): - ... - ValueError: Not a Request-Line - """ - m = re.match( - r"({})\s+(.*)\s+HTTP/(1.[0|1])".format("|".join(Entry.METHODS)), line, re.I - ) - if m: - return m.group(1).upper(), m.group(2), m.group(3) - raise ValueError("Not a Request-Line") - - @classmethod - def register(cls, method, uri, *responses, **config): - if "body" in config or "status" in config: - raise AttributeError("Did you mean `Entry.single_register(...)`?") - - default_config = dict(match_querystring=True, add_trailing_slash=True) - default_config.update(config) - config = default_config - - if config["add_trailing_slash"] and not urlsplit(uri).path: - uri += "/" - - Mocket.register( - cls(uri, method, responses, match_querystring=config["match_querystring"]) - ) - - @classmethod - def single_register( - cls, - method, - uri, - body="", - status=200, - headers=None, - match_querystring=True, - exception=None, - ): - response = ( - exception - if exception - else cls.response_cls(body=body, status=status, headers=headers) - ) - - cls.register( - method, - uri, - response, - match_querystring=match_querystring, - ) diff --git a/mocket/mockredis.py b/mocket/mockredis.py deleted file mode 100644 index 1a0c51e2..00000000 --- a/mocket/mockredis.py +++ /dev/null @@ -1,86 +0,0 @@ -from itertools import chain - -from .compat import byte_type, decode_from_bytes, encode_to_bytes, shsplit, text_type -from .mocket import Mocket, MocketEntry - - -class Request: - def __init__(self, data): - self.data = data - - -class Response: - def __init__(self, data=None): - self.data = Redisizer.redisize(data or OK) - - -class Redisizer(byte_type): - @staticmethod - def tokens(iterable): - iterable = [encode_to_bytes(x) for x in iterable] - return [f"*{len(iterable)}".encode()] + list( - chain(*zip([f"${len(x)}".encode() for x in iterable], iterable)) - ) - - @staticmethod - def redisize(data): - def get_conversion(t): - return { - dict: lambda x: b"\r\n".join( - Redisizer.tokens(list(chain(*tuple(x.items())))) - ), - int: lambda x: f":{x}".encode(), - text_type: lambda x: "${}\r\n{}".format( - len(x.encode("utf-8")), x - ).encode("utf-8"), - list: lambda x: b"\r\n".join(Redisizer.tokens(x)), - }[t] - - if isinstance(data, Redisizer): - return data - if isinstance(data, byte_type): - data = decode_from_bytes(data) - return Redisizer(get_conversion(data.__class__)(data) + b"\r\n") - - @staticmethod - def command(description, _type="+"): - return Redisizer("{}{}{}".format(_type, description, "\r\n").encode("utf-8")) - - @staticmethod - def error(description): - return Redisizer.command(description, _type="-") - - -OK = Redisizer.command("OK") -QUEUED = Redisizer.command("QUEUED") -ERROR = Redisizer.error - - -class Entry(MocketEntry): - request_cls = Request - response_cls = Response - - def __init__(self, addr, command, responses): - super().__init__(addr or ("localhost", 6379), responses) - d = shsplit(command) - d[0] = d[0].upper() - self.command = Redisizer.tokens(d) - - def can_handle(self, data): - return data.splitlines() == self.command - - @classmethod - def register(cls, addr, command, *responses): - responses = [ - r if isinstance(r, BaseException) else cls.response_cls(r) - for r in responses - ] - Mocket.register(cls(addr, command, responses)) - - @classmethod - def register_response(cls, command, response, addr=None): - cls.register(addr, command, response) - - @classmethod - def register_responses(cls, command, responses, addr=None): - cls.register(addr, command, *responses) diff --git a/tests/tests38/__init__.py b/mocket/mocks/__init__.py similarity index 100% rename from tests/tests38/__init__.py rename to mocket/mocks/__init__.py diff --git a/mocket/mocks/mockhttp.py b/mocket/mocks/mockhttp.py new file mode 100644 index 00000000..5ec14a62 --- /dev/null +++ b/mocket/mocks/mockhttp.py @@ -0,0 +1,437 @@ +"""HTTP mocking implementation for Mocket.""" + +from __future__ import annotations + +import re +import time +from functools import cached_property +from http.server import BaseHTTPRequestHandler +from typing import Any, Callable +from urllib.parse import parse_qs, unquote, urlsplit + +from h11 import SERVER, Connection, Data +from h11 import Request as H11Request + +from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes +from mocket.entry import MocketEntry +from mocket.mocket import Mocket + +STATUS: dict = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} +CRLF: str = "\r\n" +ASCII: str = "ascii" + + +class Request: + """HTTP request parser using h11.""" + + _parser: Connection | None = None + _event: Any | None = None + + def __init__(self, data: bytes) -> None: + """Initialize the request parser. + + Args: + data: Raw HTTP request data + """ + self._parser = Connection(SERVER) + self.add_data(data) + + def add_data(self, data: bytes) -> None: + """Add more data to the request. + + Args: + data: Additional raw request data + """ + self._parser.receive_data(data) + + @property + def event(self) -> Any: + """Get the parsed request event. + + Returns: + The h11 request event + """ + if not self._event: + self._event = self._parser.next_event() + return self._event + + @cached_property + def method(self) -> str: + """Get the HTTP method. + + Returns: + HTTP method (GET, POST, etc.) + """ + return self.event.method.decode(ASCII) + + @cached_property + def path(self) -> str: + """Get the request path. + + Returns: + Request path with query string + """ + return self.event.target.decode(ASCII) + + @cached_property + def headers(self) -> dict: + """Get the request headers. + + Returns: + Dictionary of header names to values + """ + return {k.decode(ASCII): v.decode(ASCII) for k, v in self.event.headers} + + @cached_property + def querystring(self) -> dict: + """Get the parsed query string. + + Returns: + Dictionary of query parameter names to lists of values + """ + parts = self.path.split("?", 1) + return ( + parse_qs(unquote(parts[1]), keep_blank_values=True) + if len(parts) == 2 + else {} + ) + + @cached_property + def body(self) -> str: + """Get the request body. + + Returns: + Decoded request body string + """ + while True: + event = self._parser.next_event() + if isinstance(event, H11Request): + self._event = event + elif isinstance(event, Data): + return event.data.decode(ENCODING) + + def __str__(self) -> str: + """Get string representation of request. + + Returns: + Formatted request string + """ + return f"{self.method} - {self.path} - {self.headers}" + + +class Response: + """HTTP response builder.""" + + headers: dict | None = None + is_file_object: bool = False + + def __init__( + self, body: Any = "", status: int = 200, headers: dict | None = None + ) -> None: + """Initialize an HTTP response. + + Args: + body: Response body (string, bytes, or file-like object) + status: HTTP status code + headers: Dictionary of response headers + """ + headers = headers or {} + try: + # File Objects + self.body = body.read() + self.is_file_object = True + except AttributeError: + self.body = encode_to_bytes(body) + self.status = status + + self.set_base_headers() + self.set_extra_headers(headers) + + self.data = self.get_protocol_data() + self.body + + def get_protocol_data(self, str_format_fun_name: str = "capitalize") -> bytes: + """Get the HTTP protocol headers and status line. + + Args: + str_format_fun_name: Name of string formatting method to use + + Returns: + Bytes of protocol headers (status line and headers) + """ + status_line = f"HTTP/1.1 {self.status} {STATUS[self.status]}" + header_lines = CRLF.join( + ( + f"{getattr(k, str_format_fun_name)()}: {v}" + for k, v in self.headers.items() + ) + ) + return f"{status_line}\r\n{header_lines}\r\n\r\n".encode(ENCODING) + + def set_base_headers(self) -> None: + """Set the base response headers.""" + self.headers = { + "Status": str(self.status), + "Date": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()), + "Server": "Python/Mocket", + "Connection": "close", + "Content-Length": str(len(self.body)), + } + if not self.is_file_object: + self.headers["Content-Type"] = f"text/plain; charset={ENCODING}" + else: + self.headers["Content-Type"] = do_the_magic(self.body) + + def set_extra_headers(self, headers: dict) -> None: + r"""Add extra headers to the response. + + Args: + headers: Dictionary of additional headers + + >>> r = Response(body="") + >>> len(r.headers.keys()) + 6 + >>> r.set_extra_headers({"foo-bar": "Foobar"}) + >>> len(r.headers.keys()) + 7 + >>> encode_to_bytes(r.headers.get("Foo-Bar")) == encode_to_bytes("Foobar") + True + """ + for k, v in headers.items(): + self.headers["-".join(token.capitalize() for token in k.split("-"))] = v + + +class Entry(MocketEntry): + """HTTP entry for matching and responding to HTTP requests.""" + + CONNECT = "CONNECT" + DELETE = "DELETE" + GET = "GET" + HEAD = "HEAD" + OPTIONS = "OPTIONS" + PATCH = "PATCH" + POST = "POST" + PUT = "PUT" + TRACE = "TRACE" + + METHODS: tuple = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE) + + request_cls: type = Request + response_cls: type = Response + + default_config: dict = {"match_querystring": True, "can_handle_fun": None} + _can_handle_fun: Callable | None = None + + def __init__( + self, + uri: str, + method: str, + responses: Any, + match_querystring: bool = True, + can_handle_fun: Callable | None = None, + ) -> None: + """Initialize an HTTP entry. + + Args: + uri: URI to match (http://host:port/path?query) + method: HTTP method (GET, POST, etc.) + responses: Response(s) to return + match_querystring: Whether to match query strings + can_handle_fun: Custom matching function + """ + self._can_handle_fun = can_handle_fun if can_handle_fun else self._can_handle + + uri = urlsplit(uri) + + port = uri.port + if not port: + port = 443 if uri.scheme == "https" else 80 + + super().__init__((uri.hostname, port), responses) + self.schema = uri.scheme + self.path = uri.path or "/" + self.query = uri.query + self.method = method.upper() + self._sent_data = b"" + self._match_querystring = match_querystring + + def __repr__(self) -> str: + """Get string representation of the entry. + + Returns: + String representation + """ + return f"{self.__class__.__name__}(method={self.method!r}, schema={self.schema!r}, location={self.location!r}, path={self.path!r}, query={self.query!r})" + + def collect(self, data: bytes) -> bool: + """Collect the request data. + + Args: + data: Request data + + Returns: + Whether to consume the response + """ + consume_response = True + + decoded_data = decode_from_bytes(data) + if not decoded_data.startswith(Entry.METHODS): + Mocket.remove_last_request() + self._sent_data += data + consume_response = False + else: + self._sent_data = data + + super().collect(self._sent_data) + + return consume_response + + def _can_handle(self, path: str, qs_dict: dict) -> bool: + """Default can_handle function checking path and query string. + + Args: + path: Request path + qs_dict: Parsed query string parameters + + Returns: + True if this entry can handle the request + """ + can_handle = path == self.path + if self._match_querystring: + can_handle = can_handle and qs_dict == parse_qs( + self.query, keep_blank_values=True + ) + return can_handle + + def can_handle(self, data: bytes) -> bool: + r"""Check if this entry can handle the given request data. + + Args: + data: Request data + + Returns: + True if this entry can handle the request + + >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),)) + >>> e.can_handle(b'GET /?bar=foo HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n') + False + >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),)) + >>> e.can_handle(b'GET /?bar=foo&foobar HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n') + True + """ + try: + requestline, _ = decode_from_bytes(data).split(CRLF, 1) + method, path, _ = self._parse_requestline(requestline) + except ValueError: + return self is getattr(Mocket, "_last_entry", None) + + _request = urlsplit(path) + + can_handle = method == self.method and self._can_handle_fun( + _request.path, parse_qs(_request.query, keep_blank_values=True) + ) + + if can_handle: + Mocket._last_entry = self + return can_handle + + @staticmethod + def _parse_requestline(line: str) -> tuple: + """Parse an HTTP request line. + + http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5 + + Args: + line: HTTP request line string + + Returns: + Tuple of (method, path, version) + + Raises: + ValueError: If line is not a valid request line + + >>> Entry._parse_requestline('GET / HTTP/1.0') == ('GET', '/', '1.0') + True + >>> Entry._parse_requestline('post /testurl htTP/1.1') == ('POST', '/testurl', '1.1') + True + >>> Entry._parse_requestline('Im not a RequestLine') + Traceback (most recent call last): + ... + ValueError: Not a Request-Line + """ + m = re.match( + r"({})\s+(.*)\s+HTTP/(1.[0|1])".format("|".join(Entry.METHODS)), line, re.I + ) + if m: + return m.group(1).upper(), m.group(2), m.group(3) + raise ValueError("Not a Request-Line") + + @classmethod + def register(cls, method: str, uri: str, *responses: Any, **config: Any) -> None: + """Register an HTTP entry for multiple responses. + + Args: + method: HTTP method (GET, POST, etc.) + uri: URI to match + *responses: Response(s) to cycle through + **config: Configuration options (match_querystring, can_handle_fun) + + Raises: + AttributeError: If using body/status params (use single_register instead) + KeyError: If invalid config keys provided + """ + if "body" in config or "status" in config: + raise AttributeError("Did you mean `Entry.single_register(...)`?") + + if config.keys() - cls.default_config.keys(): + raise KeyError( + f"Invalid config keys: {config.keys() - cls.default_config.keys()}" + ) + + _config = cls.default_config.copy() + _config.update({k: v for k, v in config.items() if k in _config}) + + Mocket.register(cls(uri, method, responses, **_config)) + + @classmethod + def single_register( + cls, + method: str, + uri: str, + body: Any = "", + status: int = 200, + headers: dict | None = None, + exception: Exception | None = None, + match_querystring: bool = True, + can_handle_fun: Callable | None = None, + **config: Any, + ) -> None: + """Register a single HTTP response for a URI and method. + + This is a convenience method that creates a single Response object + instead of requiring a list. + + Args: + method: HTTP method (GET, POST, etc.) + uri: URI to match + body: Response body content + status: HTTP status code + headers: Dictionary of response headers + exception: Exception to raise instead of returning response + match_querystring: Whether to match query strings + can_handle_fun: Custom matching function + **config: Additional configuration options + """ + response = ( + exception + if exception + else cls.response_cls(body=body, status=status, headers=headers) + ) + + cls.register( + method, + uri, + response, + match_querystring=match_querystring, + can_handle_fun=can_handle_fun, + **config, + ) diff --git a/mocket/mocks/mockredis.py b/mocket/mocks/mockredis.py new file mode 100644 index 00000000..eee2d6c8 --- /dev/null +++ b/mocket/mocks/mockredis.py @@ -0,0 +1,191 @@ +"""Redis mocking implementation for Mocket.""" + +from __future__ import annotations + +from itertools import chain +from typing import Any + +from mocket.compat import ( + decode_from_bytes, + encode_to_bytes, + shsplit, +) +from mocket.entry import MocketEntry +from mocket.mocket import Mocket +from mocket.types import Address + + +class Request: + """Redis request wrapper.""" + + def __init__(self, data: bytes) -> None: + """Initialize a Redis request. + + Args: + data: Raw Redis command data + """ + self.data = data + + +class Response: + """Redis response wrapper.""" + + def __init__(self, data: Any = None) -> None: + """Initialize a Redis response. + + Args: + data: Response data (will be "redisize"d) + """ + self.data = Redisizer.redisize(data or OK) + + +class Redisizer(bytes): + """Convert Python types to Redis protocol format.""" + + @staticmethod + def tokens(iterable: list[Any]) -> list[bytes]: + """Convert an iterable to Redis tokens. + + Args: + iterable: List of items to convert + + Returns: + List of Redis protocol bytes + """ + iterable = [encode_to_bytes(x) for x in iterable] + return [f"*{len(iterable)}".encode()] + list( + chain(*zip([f"${len(x)}".encode() for x in iterable], iterable)) + ) + + @staticmethod + def redisize(data: Any) -> Redisizer: + """Convert Python data to Redis protocol format. + + Args: + data: Python data to convert + + Returns: + Redisizer bytes + """ + + def get_conversion(t: type) -> Any: + return { + dict: lambda x: b"\r\n".join( + Redisizer.tokens(list(chain(*tuple(x.items())))) + ), + int: lambda x: f":{x}".encode(), + str: lambda x: "${}\r\n{}".format(len(x.encode("utf-8")), x).encode( + "utf-8" + ), + list: lambda x: b"\r\n".join(Redisizer.tokens(x)), + }[t] + + if isinstance(data, Redisizer): + return data + if isinstance(data, bytes): + data = decode_from_bytes(data) + return Redisizer(get_conversion(data.__class__)(data) + b"\r\n") + + @staticmethod + def command(description: str, _type: str = "+") -> Redisizer: + """Create a Redis command response. + + Args: + description: Response description + _type: Response type prefix (+, -, :, $, *) + + Returns: + Formatted Redis response + """ + return Redisizer("{}{}{}".format(_type, description, "\r\n").encode("utf-8")) + + @staticmethod + def error(description: str) -> Redisizer: + """Create a Redis error response. + + Args: + description: Error description + + Returns: + Formatted Redis error response + """ + return Redisizer.command(description, _type="-") + + +OK = Redisizer.command("OK") +QUEUED = Redisizer.command("QUEUED") +ERROR = Redisizer.error + + +class Entry(MocketEntry): + """Redis entry for matching and responding to Redis commands.""" + + request_cls = Request + response_cls = Response + + def __init__( + self, addr: Address | None, command: str, responses: list[Any] + ) -> None: + """Initialize a Redis entry. + + Args: + addr: (host, port) tuple or None for default + command: Redis command string to match + responses: List of responses to cycle through + """ + super().__init__(addr or ("localhost", 6379), responses) + d = shsplit(command) + d[0] = d[0].upper() + self.command = Redisizer.tokens(d) + + def can_handle(self, data: bytes) -> bool: + """Check if this entry can handle the given command. + + Args: + data: Raw Redis command data + + Returns: + True if this entry matches the command + """ + return data.splitlines() == self.command + + @classmethod + def register(cls, addr: Address | None, command: str, *responses: Any) -> None: + """Register a Redis entry. + + Args: + addr: (host, port) tuple or None for default + command: Redis command to match + *responses: Responses to cycle through + """ + responses = [ + r if isinstance(r, BaseException) else cls.response_cls(r) + for r in responses + ] + Mocket.register(cls(addr, command, responses)) + + @classmethod + def register_response( + cls, command: str, response: Any, addr: Address | None = None + ) -> None: + """Register a single response for a command. + + Args: + command: Redis command to match + response: Response to return + addr: (host, port) tuple or None for default + """ + cls.register(addr, command, response) + + @classmethod + def register_responses( + cls, command: str, responses: list[Any], addr: Address | None = None + ) -> None: + """Register multiple responses for a command. + + Args: + command: Redis command to match + responses: List of responses to cycle through + addr: (host, port) tuple or None for default + """ + cls.register(addr, command, *responses) diff --git a/mocket/mode.py b/mocket/mode.py new file mode 100644 index 00000000..ffb23a44 --- /dev/null +++ b/mocket/mode.py @@ -0,0 +1,82 @@ +"""Mocket mode management for strict socket enforcement.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +from mocket.exceptions import StrictMocketException +from mocket.mocket import Mocket + +if TYPE_CHECKING: # pragma: no cover + from typing import NoReturn + + +class _MocketMode: + """Singleton class for managing Mocket's strict mode enforcement.""" + + __shared_state: ClassVar[dict[str, Any]] = {} + STRICT: ClassVar = None + STRICT_ALLOWED: ClassVar = None + + def __init__(self) -> None: + """Initialize the MocketMode singleton with shared state.""" + self.__dict__ = self.__shared_state + + def is_allowed(self, location: str | tuple[str, int]) -> bool: + """Check if a location is allowed to perform real socket calls. + + Checks if (`host`, `port`) or at least `host` + are allowed locations to perform real `socket` calls + + Args: + location: Hostname string or (host, port) tuple + + Returns: + True if the location is allowed, False if in STRICT mode and not allowed + """ + if not self.STRICT: + return True + + host_allowed = False + if isinstance(location, tuple): + host_allowed = location[0] in self.STRICT_ALLOWED + return host_allowed or location in self.STRICT_ALLOWED + + @staticmethod + def raise_not_allowed( + address: tuple[str, int] | None = None, + data: bytes | None = None, + ) -> NoReturn: + """Raise an exception when a socket operation is not allowed in STRICT mode. + + Args: + address: The (host, port) tuple that was attempted + data: The request data that was sent + + Raises: + StrictMocketException: Always raised with detailed context + """ + current_entries = [ + (location, "\n ".join(map(str, entries))) + for location, entries in Mocket._entries.items() + ] + formatted_entries = "\n".join( + [f" {location}:\n {entries}" for location, entries in current_entries] + ) + msg = ( + "Mocket tried to use the real `socket` module while STRICT mode was active." + ) + if address: + host, port = address + msg += f"\nAttempted address: {host}:{port}" + if data: + from mocket.compat import decode_from_bytes + + preview = decode_from_bytes(data).split("\r\n", 1)[0][:200] + msg += f"\nSent data: {preview}" + + msg += f"\nRegistered entries:\n{formatted_entries}" + raise StrictMocketException(msg) + + +MocketMode = _MocketMode() diff --git a/mocket/plugins/aiohttp_connector.py b/mocket/plugins/aiohttp_connector.py new file mode 100644 index 00000000..cde5019a --- /dev/null +++ b/mocket/plugins/aiohttp_connector.py @@ -0,0 +1,18 @@ +import contextlib + +from mocket import MocketSSLContext + +with contextlib.suppress(ModuleNotFoundError): + from aiohttp import ClientRequest + from aiohttp.connector import TCPConnector + + class MocketTCPConnector(TCPConnector): + """ + `aiohttp` reuses SSLContext instances created at import-time, + making it more difficult for Mocket to do its job. + This is an attempt to make things smoother, at the cost of + slightly patching the `ClientSession` while testing. + """ + + def _get_ssl_context(self, req: ClientRequest) -> MocketSSLContext: + return MocketSSLContext() diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index 9d61ae2e..fb40c0c5 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -1,38 +1,43 @@ -from mocket import Mocket, mocketize +from typing import Any, Dict, Optional + +from mocket import mocketize from mocket.async_mocket import async_mocketize -from mocket.compat import ENCODING, byte_type, text_type +from mocket.compat import ENCODING +from mocket.mocket import Mocket from mocket.mockhttp import Entry as MocketHttpEntry from mocket.mockhttp import Request as MocketHttpRequest from mocket.mockhttp import Response as MocketHttpResponse -def httprettifier_headers(headers): +def httprettifier_headers(headers: Dict[str, str]) -> Dict[str, str]: return {k.lower().replace("_", "-"): v for k, v in headers.items()} class Request(MocketHttpRequest): @property - def body(self): - return super().body.encode(ENCODING) + def body(self) -> bytes: + return super().body.encode(ENCODING) # type: ignore[no-any-return] @property - def headers(self): + def headers(self) -> Dict[str, str]: return httprettifier_headers(super().headers) class Response(MocketHttpResponse): - def get_protocol_data(self, str_format_fun_name="lower"): + headers: Dict[str, str] + + def get_protocol_data(self, str_format_fun_name: str = "lower") -> bytes: if "server" in self.headers and self.headers["server"] == "Python/Mocket": self.headers["server"] = "Python/HTTPretty" - return super().get_protocol_data(str_format_fun_name=str_format_fun_name) + return super().get_protocol_data(str_format_fun_name=str_format_fun_name) # type: ignore[no-any-return] - def set_base_headers(self): + def set_base_headers(self) -> None: super().set_base_headers() self.headers = httprettifier_headers(self.headers) original_set_base_headers = set_base_headers - def set_extra_headers(self, headers): + def set_extra_headers(self, headers: Dict[str, str]) -> None: self.headers.update(headers) @@ -59,17 +64,17 @@ class Entry(MocketHttpEntry): def register_uri( - method, - uri, - body="HTTPretty :)", - adding_headers=None, - forcing_headers=None, - status=200, - responses=None, - match_querystring=False, - priority=0, - **headers, -): + method: str, + uri: str, + body: str = "HTTPretty :)", + adding_headers: Optional[Dict[str, str]] = None, + forcing_headers: Optional[Dict[str, str]] = None, + status: int = 200, + responses: Any = None, + match_querystring: bool = False, + priority: int = 0, + **headers: str, +) -> None: headers = httprettifier_headers(headers) if adding_headers is not None: @@ -80,12 +85,17 @@ def register_uri( def force_headers(self): self.headers = httprettifier_headers(forcing_headers) - Response.set_base_headers = force_headers + Response.set_base_headers = force_headers # type: ignore[method-assign] else: - Response.set_base_headers = Response.original_set_base_headers + Response.set_base_headers = Response.original_set_base_headers # type: ignore[method-assign] if responses: - Entry.register(method, uri, *responses) + Entry.register( + method, + uri, + *responses, + match_querystring=match_querystring, + ) else: Entry.single_register( method, @@ -109,7 +119,7 @@ def __getattr__(self, name): HTTPretty = MocketHTTPretty() -HTTPretty.register_uri = register_uri +HTTPretty.register_uri = register_uri # type: ignore[attr-defined] httpretty = HTTPretty __all__ = ( @@ -129,6 +139,4 @@ def __getattr__(self, name): "HEAD", "PATCH", "register_uri", - "text_type", - "byte_type", ) diff --git a/mocket/plugins/pook_mock_engine.py b/mocket/plugins/pook_mock_engine.py index 99cb07ec..549f5509 100644 --- a/mocket/plugins/pook_mock_engine.py +++ b/mocket/plugins/pook_mock_engine.py @@ -1,5 +1,7 @@ -from pook.engine import MockEngine -from pook.interceptors.base import BaseInterceptor +try: + from pook.engine import MockEngine +except ModuleNotFoundError: + MockEngine = object from mocket.mocket import Mocket from mocket.mockhttp import Entry, Response @@ -37,17 +39,6 @@ def single_register( return entry -class MocketInterceptor(BaseInterceptor): - @staticmethod - def activate(): - Mocket.disable() - Mocket.enable() - - @staticmethod - def disable(): - Mocket.disable() - - class MocketEngine(MockEngine): def __init__(self, engine): def mocket_mock_fun(*args, **kwargs): @@ -68,6 +59,18 @@ def mocket_mock_fun(*args, **kwargs): return mock + from pook.interceptors.base import BaseInterceptor + + class MocketInterceptor(BaseInterceptor): + @staticmethod + def activate(): + Mocket.disable() + Mocket.enable() + + @staticmethod + def disable(): + Mocket.disable() + # Store plugins engine self.engine = engine # Store HTTP client interceptors diff --git a/mocket/recording.py b/mocket/recording.py new file mode 100644 index 00000000..95faf126 --- /dev/null +++ b/mocket/recording.py @@ -0,0 +1,225 @@ +"""Request/response recording for playback during tests.""" + +from __future__ import annotations + +import contextlib +import hashlib +import json +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.types import Address +from mocket.utils import hexdump, hexload + +hash_function: Any = hashlib.md5 + +with contextlib.suppress(ImportError): + from xxhash_cffi import xxh32 as xxhash_cffi_xxh32 + + hash_function = xxhash_cffi_xxh32 + +with contextlib.suppress(ImportError): + from xxhash import xxh32 as xxhash_xxh32 + + hash_function = xxhash_xxh32 + + +def _hash_prepare_request(data: bytes) -> bytes: + """Prepare request data for hashing by sorting headers. + + Args: + data: Raw request data + + Returns: + Prepared bytes for hashing + """ + _data = decode_from_bytes(data) + return encode_to_bytes("".join(sorted(_data.split("\r\n")))) + + +def _hash_request(data: bytes) -> str: + """Hash a request using the best available hash function. + + Args: + data: Raw request data + + Returns: + Hex digest of the hash + """ + _data = _hash_prepare_request(data) + return hash_function(_data).hexdigest() + + +def _hash_request_fallback(data: bytes) -> str: + """Hash a request using MD5 as fallback. + + Args: + data: Raw request data + + Returns: + Hex digest of the MD5 hash + """ + _data = _hash_prepare_request(data) + return hashlib.md5(_data).hexdigest() + + +@dataclass +class MocketRecord: + """A record of a request and its corresponding response.""" + + host: str + port: int + request: bytes + response: bytes + + +class MocketRecordStorage: + """Storage for recording and retrieving request/response pairs.""" + + def __init__(self, directory: Path, namespace: str) -> None: + """Initialize the record storage. + + Args: + directory: Path to directory for storing recordings + namespace: Namespace for grouping records + """ + self._directory = directory + self._namespace = namespace + self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = ( + defaultdict(defaultdict) + ) + + self._load() + + @property + def directory(self) -> Path: + """Get the recording directory. + + Returns: + Path to recording directory + """ + return self._directory + + @property + def namespace(self) -> str: + """Get the recording namespace. + + Returns: + Namespace string + """ + return self._namespace + + @property + def file(self) -> Path: + """Get the path to the namespace's JSON file. + + Returns: + Path to JSON recording file + """ + return self._directory / f"{self._namespace}.json" + + def _load(self) -> None: + """Load recordings from disk.""" + if not self.file.exists(): + return + + json_data = self.file.read_text() + records = json.loads(json_data) + for host, port_signature_record in records.items(): + for port, signature_record in port_signature_record.items(): + for signature, record in signature_record.items(): + # NOTE backward-compat + try: + request_data = hexload(record["request"]) + except ValueError: + request_data = record["request"] + + self._records[(host, int(port))][signature] = MocketRecord( + host=host, + port=port, + request=request_data, + response=hexload(record["response"]), + ) + + def _save(self) -> None: + """Save recordings to disk.""" + data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict( + lambda: defaultdict(defaultdict) + ) + for address, signature_record in self._records.items(): + host, port = address + for signature, record in signature_record.items(): + data[host][str(port)][signature] = dict( + request=decode_from_bytes(record.request), + response=hexdump(record.response), + ) + + json_data = json.dumps(data, indent=4, sort_keys=True) + self.file.parent.mkdir(exist_ok=True) + self.file.write_text(json_data) + + def get_records(self, address: Address) -> list[MocketRecord]: + """Get all records for an address. + + Args: + address: (host, port) tuple + + Returns: + List of MocketRecord instances + """ + return list(self._records[address].values()) + + def get_record(self, address: Address, request: bytes) -> MocketRecord | None: + """Get a specific record matching the request. + + Args: + address: (host, port) tuple + request: Request bytes + + Returns: + Matching MocketRecord or None + """ + # NOTE for backward-compat + request_signature_fallback = _hash_request_fallback(request) + if request_signature_fallback in self._records[address]: + return self._records[address].get(request_signature_fallback) + + request_signature = _hash_request(request) + if request_signature in self._records[address]: + return self._records[address][request_signature] + + return None + + def put_record( + self, + address: Address, + request: bytes, + response: bytes, + ) -> None: + """Store a new record. + + Args: + address: (host, port) tuple + request: Request bytes + response: Response bytes + """ + host, port = address + record = MocketRecord( + host=host, + port=port, + request=request, + response=response, + ) + + # NOTE for backward-compat + request_signature_fallback = _hash_request_fallback(request) + if request_signature_fallback in self._records[address]: + self._records[address][request_signature_fallback] = record + return + + request_signature = _hash_request(request) + self._records[address][request_signature] = record + self._save() diff --git a/mocket/socket.py b/mocket/socket.py new file mode 100644 index 00000000..bd79528c --- /dev/null +++ b/mocket/socket.py @@ -0,0 +1,668 @@ +"""Mock socket implementation for Mocket.""" + +from __future__ import annotations + +import contextlib +import errno +import os +import select +import socket +from types import TracebackType +from typing import Any, Type + +from typing_extensions import Self + +from mocket.entry import MocketEntry +from mocket.io import MocketSocketIO +from mocket.mocket import Mocket +from mocket.mode import MocketMode +from mocket.types import ( + Address, + ReadableBuffer, + WriteableBuffer, + _RetAddress, +) + +true_gethostbyname = socket.gethostbyname +true_socket = socket.socket + + +def mock_create_connection( + address: Address, + timeout: float | None = None, + source_address: Address | None = None, +) -> socket.socket: + """Create a mock socket connection. + + Args: + address: (host, port) tuple + timeout: Connection timeout in seconds + source_address: Source address for binding (unused) + + Returns: + MocketSocket instance + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) + if timeout: + s.settimeout(timeout) + s.connect(address) + return s + + +def mock_getaddrinfo( + host: str, + port: int, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[tuple[int, int, int, str, tuple[str, int]]]: + """Mock socket.getaddrinfo function. + + Args: + host: Hostname + port: Port number + family: Address family (ignored) + type: Socket type (ignored) + proto: Protocol (ignored) + flags: Flags (ignored) + + Returns: + List of address info tuples + """ + return [(2, 1, 6, "", (host, port))] + + +def mock_gethostbyname(hostname: str) -> str: + """Mock socket.gethostbyname function. + + Args: + hostname: Hostname to resolve (unused) + + Returns: + Localhost IP address + """ + return "127.0.0.1" + + +def mock_gethostname() -> str: + """Mock socket.gethostname function. + + Returns: + Localhost hostname + """ + return "localhost" + + +def mock_inet_pton(address_family: int, ip_string: str) -> bytes: + """Mock socket.inet_pton function. + + Args: + address_family: Address family (unused) + ip_string: IP string (unused) + + Returns: + Localhost as bytes + """ + return bytes("\x7f\x00\x00\x01", "utf-8") + + +def mock_socketpair( + *args: Any, + **kwargs: Any, +) -> tuple[socket.socket, socket.socket]: + """Mock socket.socketpair function. + + Returns a real socketpair() used by asyncio loop for supporting + calls made by fastapi and similar services. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Tuple of two connected sockets + """ + import _socket + + return _socket.socketpair(*args, **kwargs) + + +class MocketSocket: + """Mock socket implementation for Mocket.""" + + def __init__( + self, + family: socket.AddressFamily | int = socket.AF_INET, + type: socket.SocketKind | int = socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, + **kwargs: Any, + ) -> None: + """Initialize a Mocket socket. + + Args: + family: Address family + type: Socket type + proto: Protocol number + fileno: File descriptor (unused) + **kwargs: Additional keyword arguments + """ + self._family = family + self._type = type + self._proto = proto + + self._kwargs = kwargs + self._true_socket = true_socket(family, type, proto) + + self._buflen = 65536 + self._timeout: float | None = None + + self._host = None + self._port = None + self._address = None + + self._io = None + self._entry = None + + def __str__(self) -> str: + """Return a string representation of the socket.""" + return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" + + def __enter__(self) -> Self: + """Enter context manager.""" + return self + + def __exit__( + self, + type_: Type[BaseException] | None, # noqa: UP006 + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """Exit context manager and close socket.""" + self.close() + + @property + def family(self) -> int: + """Get the address family.""" + return self._family + + @property + def type(self) -> int: + """Get the socket type.""" + return self._type + + @property + def proto(self) -> int: + """Get the protocol number.""" + return self._proto + + @property + def io(self) -> MocketSocketIO: + """Get or create the socket I/O object.""" + if self._io is None: + self._io = MocketSocketIO((self._host, self._port)) + return self._io + + def fileno(self) -> int: + """Get the file descriptor for reading. + + Returns: + File descriptor number + """ + address = (self._host, self._port) + r_fd, _ = Mocket.get_pair(address) + if not r_fd: + r_fd, w_fd = os.pipe() + Mocket.set_pair(address, (r_fd, w_fd)) + return r_fd + + def gettimeout(self) -> float | None: + """Get the socket timeout. + + Returns: + Timeout in seconds or None + """ + return self._timeout + + def setsockopt( + self, + level: int, + optname: int, + value: int | bytes | None, + optlen: int | None = None, + ) -> None: + """Set socket option. + + Args: + level: Socket option level (e.g., socket.SOL_SOCKET) + optname: Socket option name (e.g., socket.SO_REUSEADDR) + value: Option value as an integer or bytes, or None when optlen is provided + optlen: Option length (used when value is None) + """ + if self._true_socket: + if optlen is not None: + self._true_socket.setsockopt(level, optname, value, optlen) + else: + self._true_socket.setsockopt(level, optname, value) + + def settimeout(self, timeout: float | None) -> None: + """Set the socket timeout. + + Args: + timeout: Timeout in seconds or None + """ + self._timeout = timeout + + @staticmethod + def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: + """Get socket option (mock implementation). + + Args: + level: Socket option level + optname: Socket option name + buflen: Buffer length (unused) + + Returns: + SOCK_STREAM constant + """ + return socket.SOCK_STREAM + + def getpeername(self) -> _RetAddress: + """Get the remote socket address. + + Returns: + Address of the remote socket + """ + return self._address + + def setblocking(self, block: bool) -> None: + """Set the socket to blocking or non-blocking mode. + + Args: + block: True for blocking, False for non-blocking + """ + self.settimeout(None) if block else self.settimeout(0.0) + + def getblocking(self) -> bool: + """Check if the socket is in blocking mode. + + Returns: + True if blocking, False otherwise + """ + return self.gettimeout() is None + + def getsockname(self) -> _RetAddress: + """Get the local socket address. + + Returns: + Local socket address + """ + return socket.gethostbyname(self._address[0]), self._address[1] + + def connect(self, address: Address) -> None: + """Connect the socket to a remote address. + + Args: + address: (host, port) tuple + """ + self._address = self._host, self._port = address + Mocket._address = address + + def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO: + """Create a file object for the socket. + + Args: + mode: Mode string (unused) + bufsize: Buffer size (unused) + + Returns: + MocketSocketIO object + """ + return self.io + + def get_entry(self, data: bytes) -> MocketEntry | None: + """Get a matching entry for the given data. + + Args: + data: Request data + + Returns: + Matching MocketEntry or None + """ + return Mocket.get_entry(self._host, self._port, data) + + def sendto( + self, + data: ReadableBuffer, + address: Address | None = None, + ) -> int: + """Send data to a specific address (UDP-like). + + Args: + data: Data to send + address: Destination address + + Returns: + Number of bytes sent + """ + self.connect(address) + self.sendall(data) + return len(data) + + def sendall( + self, + data: ReadableBuffer, + entry: MocketEntry | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """Send all data through the socket. + + Args: + data: Data to send + entry: Pre-matched entry (optional) + *args: Additional arguments + **kwargs: Additional keyword arguments + """ + if entry is None: + entry = self.get_entry(data) + + if entry: + consume_response = entry.collect(data) + response = entry.get_response() if consume_response is not False else None + else: + response = self.true_sendall(data, *args, **kwargs) + + if response is not None: + self.io.seek(0) + self.io.write(response) + self.io.truncate() + self.io.seek(0) + + def sendmsg( + self, + buffers: list[ReadableBuffer], + ancdata: list[tuple[int, bytes]] | None = None, + flags: int = 0, + address: Address | None = None, + ) -> int: + """Send a message through multiple buffers. + + Args: + buffers: List of buffers to send + ancdata: Ancillary data (unused) + flags: Flags (unused) + address: Destination address (unused) + + Returns: + Number of bytes sent + """ + if not buffers: + return 0 + + data = b"".join(bytes(b) for b in buffers) + self.sendall(data) + return len(data) + + def recvmsg( + self, + buffersize: int | None = None, + ancbufsize: int | None = None, + flags: int = 0, + ) -> tuple[bytes, list[tuple[int, bytes]]]: + """Receive a message from the socket. + + This is a mock implementation that reads from the MocketSocketIO. + + Args: + buffersize: Size of buffer to receive + ancbufsize: Ancillary buffer size (unused) + flags: Flags (unused) + + Returns: + Tuple of (data, ancillary_data) + """ + try: + data = self.recv(buffersize) + except BlockingIOError: + return b"", [] + + return data, [] + + def recvmsg_into( + self, + buffers: list[ReadableBuffer], + ancbufsize: int | None = None, + flags: int = 0, + address: Address | None = None, + ) -> int: + """Receive a message into multiple buffers. + + This is a mock implementation that reads from the MocketSocketIO. + + Args: + buffers: List of buffers to receive into + ancbufsize: Ancillary buffer size (unused) + flags: Flags (unused) + address: Address (unused) + + Returns: + Number of bytes received + """ + if not buffers: + return 0 + + try: + data = self.recv(len(buffers[0])) + except BlockingIOError: + return 0 + + for i, buffer in enumerate(buffers): + if i < len(data): + buffer[: len(data)] = data + else: + buffer[:] = b"" + return len(data) + + def recvfrom_into( + self, + buffer: WriteableBuffer, + buffersize: int | None = None, + flags: int | None = None, + ) -> tuple[int, _RetAddress]: + """Receive data into a buffer and return the source address. + + Args: + buffer: Buffer to receive into + buffersize: Size to receive + flags: Flags (unused) + + Returns: + Tuple of (bytes_received, source_address) + """ + return self.recv_into(buffer, buffersize, flags), self._address + + def recv_into( + self, + buffer: WriteableBuffer, + buffersize: int | None = None, + flags: int | None = None, + ) -> int: + """Receive data into a buffer. + + Args: + buffer: Buffer to receive into + buffersize: Number of bytes to receive + flags: Flags (unused) + + Returns: + Number of bytes received + """ + if hasattr(buffer, "write"): + return buffer.write(self.recv(buffersize)) + + if buffersize is None: + buffersize = len(buffer) + + data = self.recv(buffersize) + if data: + buffer[: len(data)] = data + return len(data) + + def recvfrom( + self, buffersize: int, flags: int | None = None + ) -> tuple[bytes, _RetAddress]: + """Receive data and the source address. + + Args: + buffersize: Number of bytes to receive + flags: Flags (unused) + + Returns: + Tuple of (data, source_address) + """ + return self.recv(buffersize, flags), self._address + + def recv(self, buffersize: int, flags: int | None = None) -> bytes: + """Receive data from the socket. + + Args: + buffersize: Maximum number of bytes to receive + flags: Flags (unused) + + Returns: + Received bytes + + Raises: + BlockingIOError: If socket is non-blocking and no data available + """ + r_fd, _ = Mocket.get_pair((self._host, self._port)) + if r_fd: + return os.read(r_fd, buffersize) + data = self.io.read(buffersize) + if data: + return data + # used by Redis mock + exc = BlockingIOError() + exc.errno = errno.EWOULDBLOCK + exc.args = (0,) + raise exc + + def true_sendall(self, data: bytes, *args: Any, **kwargs: Any) -> bytes: + """Send data through the real socket and receive response. + + Args: + data: Data to send + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + Response bytes from the real socket + + Raises: + StrictMocketException: If operation not allowed in STRICT mode + """ + if not MocketMode.is_allowed(self._address): + MocketMode.raise_not_allowed(self._address, data) + + # try to get the response from recordings + if Mocket._record_storage: + record = Mocket._record_storage.get_record( + address=self._address, + request=data, + ) + if record is not None: + return record.response + + host, port = self._address + host = true_gethostbyname(host) + + with contextlib.suppress(OSError, ValueError): + # already connected + self._true_socket.connect((host, port)) + + self._true_socket.sendall(data, *args, **kwargs) + response = b"" + # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 + while True: + more_to_read = select.select([self._true_socket], [], [], 0.1)[0] + if not more_to_read and response: + break + new_content = self._true_socket.recv(self._buflen) + if not new_content: + break + response += new_content + + # store request+response in recordings + if Mocket._record_storage: + Mocket._record_storage.put_record( + address=self._address, + request=data, + response=response, + ) + + return response + + def send( + self, + data: ReadableBuffer, + *args: Any, + **kwargs: Any, + ) -> int: + """Send data through the socket. + + Args: + data: Data to send + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + Number of bytes sent + """ + entry = self.get_entry(data) + if not entry or (entry and self._entry != entry): + kwargs["entry"] = entry + self.sendall(data, *args, **kwargs) + else: + req = Mocket.last_request() + if hasattr(req, "add_data"): + req.add_data(data) + self._entry = entry + return len(data) + + def accept(self) -> tuple[MocketSocket, _RetAddress]: + """Accept a connection and return a new MocketSocket object. + + Returns: + Tuple of (new_socket, client_address) + """ + new_socket = MocketSocket( + family=self._family, + type=self._type, + proto=self._proto, + ) + new_socket._address = (self._host, self._port) + new_socket._host = self._host + new_socket._port = self._port + return new_socket, (self._host, self._port) + + def close(self) -> None: + """Close the socket and underlying true socket.""" + if self._true_socket and not self._true_socket._closed: + self._true_socket.close() + + def __getattr__(self, name: str) -> Any: + """Do-nothing catchall function for methods like shutdown(). + + Args: + name: Method name + + Returns: + A callable that does nothing + """ + + def do_nothing(*args: Any, **kwargs: Any) -> Any: + pass + + return do_nothing diff --git a/mocket/ssl/__init__.py b/mocket/ssl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py new file mode 100644 index 00000000..aeaab6b5 --- /dev/null +++ b/mocket/ssl/context.py @@ -0,0 +1,134 @@ +"""Mocket SSL context implementation.""" + +from __future__ import annotations + +from typing import Any + +from mocket.socket import MocketSocket +from mocket.ssl.socket import MocketSSLSocket + + +class _MocketSSLContext: + """Mock SSL context for Python 3.6 and newer.""" + + class FakeSetter(int): + """Descriptor that ignores assignment.""" + + def __set__(self, *args: Any) -> None: + """Ignore any assignment attempts.""" + pass + + minimum_version = FakeSetter() + options = FakeSetter() + verify_mode = FakeSetter() + verify_flags = FakeSetter() + + +class MocketSSLContext(_MocketSSLContext): + """Mock SSL context that wraps sockets in MocketSSLSocket.""" + + DUMMY_METHODS: tuple = ( + "load_default_certs", + "load_verify_locations", + "set_alpn_protocols", + "set_ciphers", + "set_default_verify_paths", + ) + sock: MocketSocket | None = None + post_handshake_auth: bool | None = None + _check_hostname: bool = False + + @property + def check_hostname(self) -> bool: + """Get the check_hostname setting. + + Returns: + Always False (mock implementation) + """ + return self._check_hostname + + @check_hostname.setter + def check_hostname(self, _: bool) -> None: + """Set the check_hostname setting (mocked). + + Args: + _: Value (ignored, always set to False) + """ + self._check_hostname = False + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the SSL context. + + Args: + *args: Positional arguments (ignored) + **kwargs: Keyword arguments (ignored) + """ + self._set_dummy_methods() + + def _set_dummy_methods(self) -> None: + """Set all dummy methods that do nothing.""" + + def dummy_method(*args: Any, **kwargs: Any) -> Any: + pass + + for m in self.DUMMY_METHODS: + setattr(self, m, dummy_method) + + def wrap_socket( + self, + sock: MocketSocket, + *args: Any, + **kwargs: Any, + ) -> MocketSSLSocket: + """Wrap a socket in an SSL socket. + + Args: + sock: Socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ + return MocketSSLSocket._create(sock, *args, **kwargs) + + def wrap_bio( + self, + incoming: Any, + outgoing: Any, + server_side: bool = False, + server_hostname: str | bytes | None = None, + ) -> MocketSSLSocket: + """Wrap BIO objects in an SSL socket (mock implementation). + + Args: + incoming: Incoming BIO (_ssl.MemoryBIO) + outgoing: Outgoing BIO (_ssl.MemoryBIO) + server_side: Whether this is server side + server_hostname: Server hostname + + Returns: + MocketSSLSocket instance + """ + ssl_obj = MocketSSLSocket() + ssl_obj._host = server_hostname + return ssl_obj + + +def mock_wrap_socket( + sock: MocketSocket, + *args: Any, + **kwargs: Any, +) -> MocketSSLSocket: + """Mock ssl.wrap_socket function. + + Args: + sock: Socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ + context = MocketSSLContext() + return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py new file mode 100644 index 00000000..94984fce --- /dev/null +++ b/mocket/ssl/socket.py @@ -0,0 +1,160 @@ +"""Mocket SSL socket implementation.""" + +from __future__ import annotations + +import ssl +from datetime import datetime, timedelta +from ssl import Options +from typing import Any + +from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket +from mocket.socket import MocketSocket +from mocket.types import _PeerCertRetDictType + + +class MocketSSLSocket(MocketSocket): + """Mock SSL socket that extends MocketSocket with SSL-specific behavior.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize an SSL socket. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + """ + super().__init__(*args, **kwargs) + + self._did_handshake: bool = False + self._sent_non_empty_bytes: bool = False + self._original_socket: MocketSocket = self + + def read(self, buffersize: int | None = None) -> bytes: + """Read data from the SSL socket. + + Args: + buffersize: Maximum bytes to read + + Returns: + Bytes read from the socket + + Raises: + ssl.SSLWantReadError: If handshake not completed and no data + """ + rv = self.io.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def write(self, data: bytes) -> int | None: + """Write data to the SSL socket. + + Args: + data: Bytes to write + + Returns: + Number of bytes written + """ + return self.send(encode_to_bytes(data)) + + def do_handshake(self) -> None: + """Perform SSL handshake (mock implementation).""" + self._did_handshake = True + + def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: + """Get the peer certificate (mock implementation). + + Args: + binary_form: Whether to return binary form (unused) + + Returns: + Mock certificate dictionary + """ + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", f"*.{self._host}"), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", f"*.{self._host}"),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", f"*.{self._host}"),), + ), + } + + def ciper(self) -> tuple[str, str, str]: + """Get cipher information (mock implementation). + + Returns: + Tuple of (cipher_name, protocol, key_exchange_algorithm) + """ + return "ADH", "AES256", "SHA" + + def compression(self) -> Options: + """Get compression options (mock implementation). + + Returns: + SSL options constant + """ + return ssl.OP_NO_COMPRESSION + + def unwrap(self) -> MocketSocket: + """Unwrap the SSL socket and return the underlying socket. + + Returns: + The original MocketSocket + """ + return self._original_socket + + @classmethod + def _create( + cls, + sock: MocketSocket, + ssl_context: ssl.SSLContext | None = None, + server_hostname: str | None = None, + *args: Any, + **kwargs: Any, + ) -> MocketSSLSocket: + """Create an SSL socket from a regular socket. + + Args: + sock: Socket to wrap + ssl_context: SSL context (optional) + server_hostname: Server hostname + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + New MocketSSLSocket instance + """ + ssl_socket = MocketSSLSocket() + ssl_socket._original_socket = sock + ssl_socket._true_socket = sock._true_socket + + if ssl_context: + ssl_socket._true_socket = ssl_context.wrap_socket( + sock=ssl_socket._true_socket, + server_hostname=server_hostname, + ) + + ssl_socket._kwargs = kwargs + + ssl_socket._timeout = sock._timeout + + ssl_socket._host = sock._host + ssl_socket._port = sock._port + ssl_socket._address = sock._address + + ssl_socket._io = sock._io + ssl_socket._entry = sock._entry + + return ssl_socket diff --git a/mocket/types.py b/mocket/types.py new file mode 100644 index 00000000..fedfd37f --- /dev/null +++ b/mocket/types.py @@ -0,0 +1,22 @@ +"""Type aliases and definitions for Mocket.""" + +from __future__ import annotations + +from typing import Any, Dict, Tuple, Union + +from typing_extensions import Buffer, TypeAlias + +Address = Tuple[str, int] + +# adapted from typeshed/stdlib/_typeshed/__init__.pyi +WriteableBuffer: TypeAlias = Buffer +ReadableBuffer: TypeAlias = Buffer + +# from typeshed/stdlib/_socket.pyi +_Address: TypeAlias = Union[Tuple[Any, ...], str, ReadableBuffer] +_RetAddress: TypeAlias = Any + +# from typeshed/stdlib/ssl.pyi +_PCTRTT: TypeAlias = Tuple[Tuple[str, str], ...] +_PCTRTTT: TypeAlias = Tuple[_PCTRTT, ...] +_PeerCertRetDictType: TypeAlias = Dict[str, Union[str, _PCTRTTT, _PCTRTT]] diff --git a/mocket/urllib3.py b/mocket/urllib3.py new file mode 100644 index 00000000..872efc5f --- /dev/null +++ b/mocket/urllib3.py @@ -0,0 +1,40 @@ +"""Urllib3 specific socket mocking.""" + +from __future__ import annotations + +from typing import Any + +from mocket.socket import MocketSocket +from mocket.ssl.context import MocketSSLContext +from mocket.ssl.socket import MocketSSLSocket + + +def mock_match_hostname(*args: Any) -> None: + """Mock urllib3's match_hostname function. + + Args: + *args: Ignored arguments + + Returns: + None + """ + return None + + +def mock_ssl_wrap_socket( + sock: MocketSocket, + *args: Any, + **kwargs: Any, +) -> MocketSSLSocket: + """Mock urllib3's ssl_wrap_socket function. + + Args: + sock: The socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ + context = MocketSSLContext() + return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/utils.py b/mocket/utils.py index 9efd6ad9..749b2b70 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,98 +1,100 @@ +"""Utility functions for Mocket.""" + from __future__ import annotations import binascii -import io -import os -import ssl -from typing import TYPE_CHECKING, Any, Callable, ClassVar +import contextlib +from typing import Any, Callable, Protocol, TypeVar, overload -from .compat import decode_from_bytes, encode_to_bytes -from .exceptions import StrictMocketException +import decorator +from typing_extensions import ParamSpec -if TYPE_CHECKING: # pragma: no cover - from typing import NoReturn +from mocket.compat import decode_from_bytes, encode_to_bytes +_P = ParamSpec("_P") +_R = TypeVar("_R") -SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 +class MocketizeDecorator(Protocol): + """Protocol for a flexible decorator that can be used in multiple ways. -class MocketSocketCore(io.BytesIO): - def __init__(self, address) -> None: - self._address = address - super().__init__() + This is a generic decorator signature, currently applicable to get_mocketize. - def write(self, content): - from mocket import Mocket + Decorators implementing this protocol can be used as: + 1. A function that transforms func (the parameter) into func1 (the returned object). + 2. A function that takes keyword arguments and returns a decorator. + """ - super().write(content) + @overload + def __call__(self, func: Callable[_P, _R], /) -> Callable[_P, _R]: ... - _, w_fd = Mocket.get_pair(self._address) - if w_fd: - os.write(w_fd, content) + @overload + def __call__( + self, **kwargs: Any + ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ... def hexdump(binary_string: bytes) -> str: - r""" - >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) - True + """Convert binary data to space-separated hex string. + + Args: + binary_string: Binary data to convert + + Returns: + Space-separated hexadecimal representation + + Example: + >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) + True """ bs = decode_from_bytes(binascii.hexlify(binary_string).upper()) return " ".join(a + b for a, b in zip(bs[::2], bs[1::2])) def hexload(string: str) -> bytes: - r""" - >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo") - True + """Convert space-separated hex string to binary data. + + Args: + string: Space-separated hexadecimal string + + Returns: + Binary data + + Raises: + ValueError: If the hex string is invalid + + Example: + >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo") + True """ string_no_spaces = "".join(string.split()) - return encode_to_bytes(binascii.unhexlify(string_no_spaces)) - - -def get_mocketize(wrapper_: Callable) -> Callable: - import decorator - - if decorator.__version__ < "5": # type: ignore[attr-defined] # pragma: no cover - return decorator.decorator(wrapper_) - return decorator.decorator( # type: ignore[call-arg] # kwsyntax - wrapper_, - kwsyntax=True, - ) - - -class MocketMode: - __shared_state: ClassVar[dict[str, Any]] = {} - STRICT: ClassVar = None - STRICT_ALLOWED: ClassVar = None - - def __init__(self) -> None: - self.__dict__ = self.__shared_state - - def is_allowed(self, location: str | tuple[str, int]) -> bool: - """ - Checks if (`host`, `port`) or at least `host` - are allowed locations to perform real `socket` calls - """ - if not self.STRICT: - return True - - host_allowed = False - if isinstance(location, tuple): - host_allowed = location[0] in self.STRICT_ALLOWED - return host_allowed or location in self.STRICT_ALLOWED - - @staticmethod - def raise_not_allowed() -> NoReturn: - from .mocket import Mocket - - current_entries = [ - (location, "\n ".join(map(str, entries))) - for location, entries in Mocket._entries.items() - ] - formatted_entries = "\n".join( - [f" {location}:\n {entries}" for location, entries in current_entries] - ) - raise StrictMocketException( - "Mocket tried to use the real `socket` module while STRICT mode was active.\n" - f"Registered entries:\n{formatted_entries}" - ) + try: + return encode_to_bytes(binascii.unhexlify(string_no_spaces)) + except binascii.Error as e: + raise ValueError from e + + +def get_mocketize(wrapper_: Callable) -> MocketizeDecorator: + """Get a mocketize decorator from a wrapper function. + + Decorators can be used as: + 1. A function that transforms func (the parameter) into func1 (the returned object). + 2. A function that takes keyword arguments and returns 1. + + Args: + wrapper_: The wrapper function to convert to a decorator + + Returns: + A MocketizeDecorator instance that can be used as a flexible decorator + """ + # trying to support different versions of `decorator` + with contextlib.suppress(TypeError): + return decorator.decorator(wrapper_, kwsyntax=True) # type: ignore[return-value, call-arg, unused-ignore] + return decorator.decorator(wrapper_) # type: ignore[return-value] + + +__all__ = ( + "get_mocketize", + "hexdump", + "hexload", +) diff --git a/pyproject.toml b/pyproject.toml index 203184cf..f20dbb93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,10 +6,12 @@ build-backend = "hatchling.build" requires-python = ">=3.8" name = "mocket" description = "Socket Mock Framework - for all kinds of socket animals, web-clients included - with gevent/asyncio/SSL support" -readme = { file = "README.rst", content-type = "text/x-rst" } -license = { file = "LICENSE" } +readme = "README.rst" +license = "BSD-3-Clause" +license-files = [ + "LICENSE", +] authors = [{ name = "Giorgio Salluzzo", email = "giorgio.salluzzo@gmail.com" }] -urls = { github = "https://github.com/mindflayer/python-mocket" } classifiers = [ "Development Status :: 6 - Mature", "Intended Audience :: Developers", @@ -19,6 +21,8 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Software Development", @@ -26,13 +30,18 @@ classifiers = [ "License :: OSI Approved :: BSD License", ] dependencies = [ - "python-magic>=0.4.5", + "puremagic", "decorator>=4.0.0", "urllib3>=1.25.3", - "httptools", + "h11", + "typing-extensions", ] dynamic = ["version"] +[project.urls] +Homepage = "https://pypi.org/project/mocket" +Repository = "https://github.com/mindflayer/python-mocket" + [project.optional-dependencies] test = [ "pre-commit", @@ -45,18 +54,18 @@ test = [ "redis", "gevent", "sure", - "pook", "flake8>5", "xxhash", "httpx", "pipfile", "build", - "twine", "fastapi", "aiohttp", "wait-for-it", - "mypy", + "mypy; platform_python_implementation!='PyPy'", "types-decorator", + "types-requests", + "trio", ] speedups = [ "xxhash;platform_python_implementation=='CPython'", @@ -86,9 +95,11 @@ exclude = [ [tool.pytest.ini_options] testpaths = [ - "tests", + "tests", "mocket", ] -addopts = "--doctest-modules --cov=mocket --cov-report=term-missing -v -x" +addopts = "--doctest-modules --cov=mocket --cov-report=xml --cov-report=term-missing --cov-append -v -x" +asyncio_default_fixture_loop_scope = "function" +asyncio_mode = "auto" [tool.ruff] src = ["mocket", "tests"] @@ -116,11 +127,14 @@ select = [ max-complexity = 8 [tool.mypy] -python_version = "3.8" +python_version = "3.13" files = [ "mocket/exceptions.py", "mocket/compat.py", "mocket/utils.py", + "mocket/plugins/httpretty/__init__.py", + "tests/test_httpretty.py", + "tests/test_utils.py", # "tests/" ] strict = true @@ -138,3 +152,11 @@ disable_error_code = ["no-untyped-def"] # enable this once full type-coverage is [[tool.mypy.overrides]] module = "tests.*" disable_error_code = ['type-arg', 'no-untyped-def'] + +[[tool.mypy.overrides]] +module = "mocket.plugins.*" +disallow_subclassing_any = false # mypy doesn't support dynamic imports + +[[tool.mypy.overrides]] +module = "tests.test_httpretty" +disallow_untyped_decorators = true diff --git a/scripts/patch_hosts.sh b/scripts/patch_hosts.sh index af7e453d..ec527d2d 100644 --- a/scripts/patch_hosts.sh +++ b/scripts/patch_hosts.sh @@ -1,5 +1,9 @@ -sudo grep -v httpbin.local /etc/hosts | sudo tee /etc/hosts.mocket -export CONTAINER_ID=$(docker compose ps -q proxy) -export CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' $CONTAINER_ID) -echo "$CONTAINER_IP httpbin.local" | sudo tee -a /etc/hosts.mocket -sudo mv /etc/hosts.mocket /etc/hosts +HOSTS=/etc/hosts +MOCKET_HOSTS=/etc/hosts.mocket +HTTPBIN_HOST=httpbin.local + +sudo grep -v ${HTTPBIN_HOST} ${HOSTS} | sudo tee ${MOCKET_HOSTS} +CONTAINER_ID=$(docker compose ps -q proxy) +CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' ${CONTAINER_ID}) +echo "${CONTAINER_IP} ${HTTPBIN_HOST}" | sudo tee -a ${MOCKET_HOSTS} +sudo mv ${MOCKET_HOSTS} ${HOSTS} diff --git a/tests/main/test_pook.py b/tests/main/test_pook.py deleted file mode 100644 index f398672e..00000000 --- a/tests/main/test_pook.py +++ /dev/null @@ -1,29 +0,0 @@ -import pook -import requests - -from mocket.plugins.pook_mock_engine import MocketEngine - -pook.set_mock_engine(MocketEngine) - - -@pook.on -def test_pook_engine(): - url = "http://twitter.com/api/1/foobar" - status = 404 - response_json = {"error": "foo"} - - mock = pook.get( - url, - headers={"content-type": "application/json"}, - reply=status, - response_json=response_json, - ) - mock.persist() - - requests.get(url) - assert mock.calls == 1 - - resp = requests.get(url) - assert resp.status_code == status - assert resp.json() == response_json - assert mock.calls == 2 diff --git a/tests/main/test_socket.py b/tests/main/test_socket.py deleted file mode 100644 index 8a6e65ad..00000000 --- a/tests/main/test_socket.py +++ /dev/null @@ -1,13 +0,0 @@ -import socket - -import pytest - -from mocket.mocket import MocketSocket - - -@pytest.mark.parametrize("blocking", (False, True)) -def test_blocking_socket(blocking): - sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(("locahost", 1234)) - sock.setblocking(blocking) - assert sock.getblocking() is blocking diff --git a/tests/main/test_asyncio.py b/tests/test_asyncio.py similarity index 78% rename from tests/main/test_asyncio.py rename to tests/test_asyncio.py index 0f9a7d17..a1eae240 100644 --- a/tests/main/test_asyncio.py +++ b/tests/test_asyncio.py @@ -9,9 +9,10 @@ from mocket import Mocketizer, async_mocketize from mocket.mockhttp import Entry +from mocket.plugins.aiohttp_connector import MocketTCPConnector -def test_asyncio_record_replay(event_loop): +def test_asyncio_record_replay(): async def test_asyncio_connection(): reader, writer = await asyncio.open_connection( host="google.com", @@ -32,7 +33,7 @@ async def test_asyncio_connection(): with tempfile.TemporaryDirectory() as temp_dir: with Mocketizer(truesocket_recording_dir=temp_dir): - event_loop.run_until_complete(test_asyncio_connection()) + asyncio.run(test_asyncio_connection()) files = glob.glob(f"{temp_dir}/*.json") assert len(files) == 1 @@ -46,6 +47,11 @@ async def test_asyncio_connection(): @pytest.mark.asyncio @async_mocketize async def test_aiohttp(): + """ + The alternative to using the custom `connector` would be importing + `aiohttp` when Mocket is already in control (inside the decorated test). + """ + url = "https://bar.foo/" data = {"message": "Hello"} @@ -57,7 +63,7 @@ async def test_aiohttp(): ) async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=3) + timeout=aiohttp.ClientTimeout(total=3), connector=MocketTCPConnector() ) as session, session.get(url) as response: response = await response.json() assert response == data diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 00000000..49b62ec7 --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,5 @@ +from mocket.compat import do_the_magic + + +def test_unknown_binary(): + assert do_the_magic(b"foobar-binary") == "application/octet-stream" diff --git a/tests/main/test_http.py b/tests/test_http.py similarity index 86% rename from tests/main/test_http.py rename to tests/test_http.py index 21fe448d..3d3e5b8e 100644 --- a/tests/main/test_http.py +++ b/tests/test_http.py @@ -12,7 +12,7 @@ import requests from mocket import Mocket, Mocketizer, mocketize -from mocket.mockhttp import Entry, Response +from mocket.mocks.mockhttp import Entry, Response class HttpTestCase(TestCase): @@ -258,22 +258,6 @@ def test_file_object(self): self.assertEqual(int(r.headers["Content-Length"]), len(local_content)) self.assertEqual(r.headers["Content-Type"], "image/png") - @mocketize - def test_file_object_with_no_lib_magic(self): - url = "http://github.com/fluidicon.png" - filename = "tests/fluidicon.png" - with open(filename, "rb") as file_obj: - Entry.register(Entry.GET, url, Response(body=file_obj, lib_magic=None)) - r = requests.get(url) - remote_content = r.content - with open(filename, "rb") as local_file_obj: - local_content = local_file_obj.read() - self.assertEqual(remote_content, local_content) - self.assertEqual(len(remote_content), len(local_content)) - self.assertEqual(int(r.headers["Content-Length"]), len(local_content)) - with self.assertRaises(KeyError): - self.assertEqual(r.headers["Content-Type"], "image/png") - @mocketize def test_same_url_different_methods(self): url = "http://bit.ly/fakeurl" @@ -308,10 +292,18 @@ def test_request_bodies(self): @mocketize(truesocket_recording_dir=os.path.dirname(__file__)) def test_truesendall_with_dump_from_recording(self): requests.get( - "http://httpbin.local/ip", headers={"user-agent": "Fake-User-Agent"} + "http://httpbin.local/ip", + headers={ + "user-agent": "Fake-User-Agent", + "Accept-Encoding": "gzip, deflate, zstd", + }, ) requests.get( - "http://httpbin.local/gzip", headers={"user-agent": "Fake-User-Agent"} + "http://httpbin.local/gzip", + headers={ + "user-agent": "Fake-User-Agent", + "Accept-Encoding": "gzip, deflate, zstd", + }, ) dump_filename = os.path.join( @@ -367,12 +359,12 @@ def test_sockets(self): sock = socket.socket(address[0], address[1], address[2]) sock.connect(address[-1]) - sock.write(f"{method} {path} HTTP/1.0\r\n") - sock.write(f"Host: {host}\r\n") - sock.write("Content-Type: application/json\r\n") - sock.write("Content-Length: %d\r\n" % len(data)) - sock.write("Connection: close\r\n\r\n") - sock.write(data) + sock.send(f"{method} {path} HTTP/1.0\r\n".encode()) + sock.send(f"Host: {host}\r\n".encode()) + sock.send(b"Content-Type: application/json\r\n") + sock.send(b"Content-Length: %d\r\n" % len(data)) + sock.send(b"Connection: close\r\n\r\n") + sock.send(data.encode()) sock.close() # Proof that worked. @@ -441,3 +433,52 @@ def test_suggestion_for_register_and_status(self): url, status=201, ) + + def test_invalid_config_key(self): + url = "http://foobar.com/path" + with self.assertRaises(KeyError): + Entry.register( + Entry.POST, + url, + Response(body='{"foo":"bar0"}', status=200), + invalid_key=True, + ) + + def test_add_trailing_slash(self): + url = "http://testme.org" + entry = Entry(url, "GET", [Response(body='{"foo":"bar0"}', status=200)]) + self.assertEqual(entry.path, "/") + + @mocketize + def test_mocket_with_no_path(self): + Entry.register(Entry.GET, "http://httpbin.local", Response(status=202)) + response = urlopen("http://httpbin.local/") + self.assertEqual(response.code, 202) + self.assertEqual(Mocket._entries[("httpbin.local", 80)][0].path, "/") + + @mocketize + def test_can_handle(self): + Entry.single_register( + Entry.POST, + "http://testme.org/foobar", + body=json.dumps({"message": "Spooky!"}), + match_querystring=False, + ) + Entry.single_register( + Entry.GET, + "http://testme.org/", + body=json.dumps({"message": "Gotcha!"}), + can_handle_fun=lambda p, q: p.endswith("/foobar") and "a" in q, + ) + Entry.single_register( + Entry.GET, + "http://testme.org/foobar", + body=json.dumps({"message": "Missed!"}), + match_querystring=False, + ) + response = requests.get("http://testme.org/foobar?a=1") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"message": "Gotcha!"}) + response = requests.get("http://testme.org/foobar?b=2") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"message": "Missed!"}) diff --git a/tests/main/test_http_gevent.py b/tests/test_http_gevent.py similarity index 68% rename from tests/main/test_http_gevent.py rename to tests/test_http_gevent.py index 88233071..4a3a9ff1 100644 --- a/tests/main/test_http_gevent.py +++ b/tests/test_http_gevent.py @@ -1,6 +1,6 @@ from gevent import monkey -from tests.main.test_http import HttpEntryTestCase +from tests.test_http import HttpEntryTestCase monkey.patch_socket() diff --git a/tests/tests38/test_http_httpx.py b/tests/test_http_httpx.py similarity index 100% rename from tests/tests38/test_http_httpx.py rename to tests/test_http_httpx.py diff --git a/tests/main/test_http_with_xxhash.py b/tests/test_http_with_xxhash.py similarity index 59% rename from tests/main/test_http_with_xxhash.py rename to tests/test_http_with_xxhash.py index 4600bf37..76d85343 100644 --- a/tests/main/test_http_with_xxhash.py +++ b/tests/test_http_with_xxhash.py @@ -4,17 +4,25 @@ import requests from mocket import Mocket, mocketize -from tests.main.test_http import HttpTestCase +from tests.test_http import HttpTestCase class HttpEntryTestCase(HttpTestCase): @mocketize(truesocket_recording_dir=os.path.dirname(__file__)) def test_truesendall_with_dump_from_recording(self): requests.get( - "http://httpbin.local/ip", headers={"user-agent": "Fake-User-Agent"} + "http://httpbin.local/ip", + headers={ + "user-agent": "Fake-User-Agent", + "Accept-Encoding": "gzip, deflate, zstd", + }, ) requests.get( - "http://httpbin.local/gzip", headers={"user-agent": "Fake-User-Agent"} + "http://httpbin.local/gzip", + headers={ + "user-agent": "Fake-User-Agent", + "Accept-Encoding": "gzip, deflate, zstd", + }, ) dump_filename = os.path.join( diff --git a/tests/main/test_httpretty.py b/tests/test_httpretty.py similarity index 100% rename from tests/main/test_httpretty.py rename to tests/test_httpretty.py diff --git a/tests/main/test_https.py b/tests/test_https.py similarity index 69% rename from tests/main/test_https.py rename to tests/test_https.py index f8c8549e..4685f4eb 100644 --- a/tests/main/test_https.py +++ b/tests/test_https.py @@ -7,7 +7,7 @@ import requests from mocket import Mocket, Mocketizer, mocketize -from mocket.mockhttp import Entry +from mocket.mockhttp import Entry # noqa - test retrocompatibility @pytest.fixture @@ -43,6 +43,7 @@ def test_json(response): @pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') +@pytest.mark.xfail(reason="Service down or blocking GitHub actions IPs") def test_truesendall_with_recording_https(url_to_mock): with tempfile.TemporaryDirectory() as temp_dir, Mocketizer( truesocket_recording_dir=temp_dir @@ -62,6 +63,7 @@ def test_truesendall_with_recording_https(url_to_mock): @pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') +@pytest.mark.xfail(reason="Service down or blocking GitHub actions IPs") def test_truesendall_after_mocket_session(url_to_mock): Mocket.enable() Mocket.disable() @@ -71,6 +73,7 @@ def test_truesendall_after_mocket_session(url_to_mock): @pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') +@pytest.mark.xfail(reason="Service down or blocking GitHub actions IPs") def test_real_request_session(url_to_mock): session = requests.Session() @@ -88,3 +91,24 @@ def test_raise_exception_from_single_register(): Entry.single_register(Entry.GET, url, exception=OSError()) with pytest.raises(requests.exceptions.ConnectionError): requests.get(url) + + +@mocketize +def test_can_handle(): + Entry.single_register( + Entry.GET, + "https://httpbin.org", + body=json.dumps({"message": "Nope... not this time!"}), + headers={"content-type": "application/json"}, + can_handle_fun=lambda path, qs_dict: path == "/ip" and qs_dict, + ) + Entry.single_register( + Entry.GET, + "https://httpbin.org", + body=json.dumps({"message": "There you go!"}), + headers={"content-type": "application/json"}, + can_handle_fun=lambda path, qs_dict: path == "/ip" and not qs_dict, + ) + resp = requests.get("https://httpbin.org/ip") + assert resp.status_code == 200 + assert resp.json() == {"message": "There you go!"} diff --git a/tests/main/test_httpx.py b/tests/test_httpx.py similarity index 91% rename from tests/main/test_httpx.py rename to tests/test_httpx.py index 889a7df8..add53de8 100644 --- a/tests/main/test_httpx.py +++ b/tests/test_httpx.py @@ -194,3 +194,22 @@ async def test_httpx_fixture(httpx_client): response = await client.get(url) assert response.json() == data + + +@pytest.mark.asyncio +async def test_httpx_fixture_with_can_handle_fun(httpx_client): + url = "https://foo.bar/barfoo" + data = {"message": "Gotcha!"} + + Entry.single_register( + Entry.GET, + "https://foo.bar", + body=json.dumps(data), + headers={"content-type": "application/json"}, + can_handle_fun=lambda p, q: p.endswith("foo"), + ) + + async with httpx_client as client: + response = await client.get(url) + + assert response.json() == data diff --git a/tests/main/test_mocket.py b/tests/test_mocket.py similarity index 98% rename from tests/main/test_mocket.py rename to tests/test_mocket.py index e6116dd1..8810a5b9 100644 --- a/tests/main/test_mocket.py +++ b/tests/test_mocket.py @@ -222,6 +222,7 @@ def test_patch( @pytest.mark.skipif(not psutil.POSIX, reason="Uses a POSIX-only API to test") +@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') @pytest.mark.asyncio async def test_no_dangling_fds(): url = "http://httpbin.local/ip" @@ -233,4 +234,4 @@ async def test_no_dangling_fds(): async with Mocketizer(strict_mode=False), httpx.AsyncClient() as client: await client.get(url) - assert proc.num_fds() == prev_num_fds + assert proc.num_fds() <= prev_num_fds diff --git a/tests/main/test_mode.py b/tests/test_mode.py similarity index 88% rename from tests/main/test_mode.py rename to tests/test_mode.py index 2a764949..bfdb2a79 100644 --- a/tests/main/test_mode.py +++ b/tests/test_mode.py @@ -4,7 +4,7 @@ from mocket import Mocketizer, mocketize from mocket.exceptions import StrictMocketException from mocket.mockhttp import Entry, Response -from mocket.utils import MocketMode +from mocket.mode import MocketMode @mocketize(strict_mode=True) @@ -52,6 +52,8 @@ def test_strict_mode_error_message(): str(exc_info.value) == """ Mocket tried to use the real `socket` module while STRICT mode was active. +Attempted address: httpbin.local:80 +Sent data: GET /ip HTTP/1.1 Registered entries: ('httpbin.local', 80): Entry(method='GET', schema='http', location=('httpbin.local', 80), path='/user.agent', query='') @@ -67,5 +69,5 @@ def test_strict_mode_false_with_allowed_hosts(): @pytest.mark.parametrize("strict_mode_on", (False, True)) def test_strict_mode_allowed_or_not(strict_mode_on): with Mocketizer(strict_mode=strict_mode_on): - assert MocketMode().is_allowed("foobar.com") is not strict_mode_on - assert MocketMode().is_allowed(("foobar.com", 443)) is not strict_mode_on + assert MocketMode.is_allowed("foobar.com") is not strict_mode_on + assert MocketMode.is_allowed(("foobar.com", 443)) is not strict_mode_on diff --git a/tests/test_pook.py b/tests/test_pook.py new file mode 100644 index 00000000..012fcdfb --- /dev/null +++ b/tests/test_pook.py @@ -0,0 +1,30 @@ +import contextlib + +with contextlib.suppress(ModuleNotFoundError): + import pook + import requests + from mocket.plugins.pook_mock_engine import MocketEngine + + pook.set_mock_engine(MocketEngine) + + @pook.on + def test_pook_engine(): + url = "http://twitter.com/api/1/foobar" + status = 404 + response_json = {"error": "foo"} + + mock = pook.get( + url, + headers={"content-type": "application/json"}, + reply=status, + response_json=response_json, + ) + mock.persist() + + requests.get(url) + assert mock.calls == 1 + + resp = requests.get(url) + assert resp.status_code == status + assert resp.json() == response_json + assert mock.calls == 2 diff --git a/tests/main/test_redis.py b/tests/test_redis.py similarity index 97% rename from tests/main/test_redis.py rename to tests/test_redis.py index 50b9beac..fb6ec355 100644 --- a/tests/main/test_redis.py +++ b/tests/test_redis.py @@ -158,9 +158,11 @@ def setUp(self): self.rclient = redis.StrictRedis() def mocketize_setup(self): + Entry.register_response("CLIENT SETINFO LIB-NAME redis-py", OK) + Entry.register_response(f"CLIENT SETINFO LIB-VER {redis.__version__}", OK) Entry.register_response("FLUSHDB", OK) self.rclient.flushdb() - self.assertEqual(len(Mocket.request_list()), 1) + self.assertEqual(len(Mocket.request_list()), 3) Mocket.reset() @mocketize diff --git a/tests/test_socket.py b/tests/test_socket.py new file mode 100644 index 00000000..68e71aee --- /dev/null +++ b/tests/test_socket.py @@ -0,0 +1,149 @@ +import socket +import struct +from unittest.mock import MagicMock + +import pytest + +from mocket import Mocket, MocketEntry, mocketize +from mocket.socket import MocketSocket + + +@pytest.mark.parametrize("blocking", (False, True)) +def test_blocking_socket(blocking): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(("locahost", 1234)) + sock.setblocking(blocking) + assert sock.getblocking() is blocking + + +@mocketize +def test_udp_socket(): + host = "127.0.0.1" + port = 9999 + request_data = b"ping" + response_data = b"pong" + + Mocket.register(MocketEntry((host, port), [response_data])) + + # Your UDP client code + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.sendto(request_data, (host, port)) + data, address = sock.recvfrom(1024) + + assert data == response_data + assert address == (host, port) + + +def test_recvmsg(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + test_data = b"hello world" + sock._io = type("MockIO", (), {"read": lambda self, n: test_data})() + data, ancdata = sock.recvmsg(1024) + assert data == test_data + assert ancdata == [] + + +def test_recvmsg_into(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + test_data = b"foobar" + sock._io = type("MockIO", (), {"read": lambda self, n: test_data})() + buf = bytearray(10) + buf2 = bytearray(10) + buffers = [buf, buf2] + nbytes = sock.recvmsg_into(buffers) + assert nbytes == len(test_data) + assert buf[: len(test_data)] == test_data + + +def test_recvmsg_into_empty_buffers(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.recvmsg_into([]) + assert result == 0 + + +def test_accept(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._host = "127.0.0.1" + sock._port = 8080 + new_sock, addr = sock.accept() + assert isinstance(new_sock, MocketSocket) + assert new_sock is not sock + assert addr == ("127.0.0.1", 8080) + assert new_sock._host == "127.0.0.1" + assert new_sock._port == 8080 + + +@mocketize +def test_sendmsg(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._host = "127.0.0.1" + sock._port = 8080 + response_data = b"pong" + + Mocket.register(MocketEntry((sock._host, sock._port), [response_data])) + + msg = [b"foo", b"bar", b"foobaz"] + total_sent = sock.sendmsg(msg) + assert total_sent == sum(len(m) for m in msg) + assert Mocket.last_request() == b"".join(msg) + + +def test_sendmsg_empty_buffers(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.sendmsg([]) + assert result == 0 + + +def test_recvmsg_no_data(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + # Mock _io.read to return empty bytes + sock._io = type("MockIO", (), {"read": lambda self, n: b""})() + data, ancdata = sock.recvmsg(1024) + assert data == b"" + assert ancdata == [] + + +def test_recvmsg_into_no_data(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + # Mock _io.read to return empty bytes + sock._io = type("MockIO", (), {"read": lambda self, n: b""})() + buf = bytearray(10) + nbytes = sock.recvmsg_into([buf]) + assert nbytes == 0 + assert buf == bytearray(10) + + +def test_getsockopt(): + # getsockopt is a static method, so we can call it directly + result = MocketSocket.getsockopt(0, 0) + assert result == socket.SOCK_STREAM + + +def test_recvfrom_into(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + test_data = b"abc123" + sock._io = type("MockIO", (), {"read": lambda self, n: test_data})() + buf = bytearray(10) + nbytes, addr = sock.recvfrom_into(buf) + assert nbytes == len(test_data) + assert buf[:nbytes] == test_data + assert addr == sock._address + + +def test_setsockopt_without_optlen(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._true_socket = MagicMock() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock._true_socket.setsockopt.assert_called_once_with( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 + ) + + +def test_setsockopt_with_optlen(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._true_socket = MagicMock() + linger_value = struct.pack("ii", 1, 5) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value)) + sock._true_socket.setsockopt.assert_called_once_with( + socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value) + ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..a791d136 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,55 @@ +from typing import Callable +from unittest import TestCase +from unittest.mock import NonCallableMock, patch + +import decorator + +from mocket.utils import get_mocketize, hexdump, hexload + + +def mock_decorator(func: Callable[[], None]) -> None: + return func() + + +class GetMocketizeTestCase(TestCase): + @patch.object(decorator, "decorator") + def test_get_mocketize_with_kwsyntax(self, dec: NonCallableMock) -> None: + get_mocketize(mock_decorator) + dec.assert_called_once_with(mock_decorator, kwsyntax=True) + + @patch.object(decorator, "decorator") + def test_get_mocketize_without_kwsyntax(self, dec: NonCallableMock) -> None: + dec.side_effect = [ + TypeError("kwsyntax is not supported in this version of decorator"), + mock_decorator, + ] + + get_mocketize(mock_decorator) + # First time called with kwsyntax=True, which failed with TypeError + dec.call_args_list[0].assert_compare_to((mock_decorator,), {"kwsyntax": True}) + # Second time without kwsyntax, which succeeds + dec.call_args_list[1].assert_compare_to((mock_decorator,)) + + +class HexdumpTestCase(TestCase): + def test_hexdump_converts_bytes_to_spaced_hex(self) -> None: + assert hexdump(b"Hi") == "48 69" + + def test_hexdump_empty_bytes(self) -> None: + assert hexdump(b"") == "" + + def test_hexdump_roundtrip_with_hexload(self) -> None: + data = b"bar foobar foo" + assert hexload(hexdump(data)) == data + + +class HexloadTestCase(TestCase): + def test_hexload_converts_spaced_hex_to_bytes(self) -> None: + assert hexload("48 69") == b"Hi" + + def test_hexload_empty_string(self) -> None: + assert hexload("") == b"" + + def test_hexload_invalid_hex_raises_value_error(self) -> None: + with self.assertRaises(ValueError): + hexload("ZZ ZZ") diff --git a/tests/main/tests.main.test_http.HttpEntryTestCase.test_truesendall_with_dump_from_recording.json b/tests/tests.test_http.HttpEntryTestCase.test_truesendall_with_dump_from_recording.json similarity index 100% rename from tests/main/tests.main.test_http.HttpEntryTestCase.test_truesendall_with_dump_from_recording.json rename to tests/tests.test_http.HttpEntryTestCase.test_truesendall_with_dump_from_recording.json diff --git a/tests/main/tests.main.test_http_gevent.GeventHttpEntryTestCase.test_truesendall_with_dump_from_recording.json b/tests/tests.test_http_gevent.GeventHttpEntryTestCase.test_truesendall_with_dump_from_recording.json similarity index 100% rename from tests/main/tests.main.test_http_gevent.GeventHttpEntryTestCase.test_truesendall_with_dump_from_recording.json rename to tests/tests.test_http_gevent.GeventHttpEntryTestCase.test_truesendall_with_dump_from_recording.json diff --git a/tests/main/tests.main.test_http_with_xxhash.HttpEntryTestCase.test_truesendall_with_dump_from_recording.json b/tests/tests.test_http_with_xxhash.HttpEntryTestCase.test_truesendall_with_dump_from_recording.json similarity index 100% rename from tests/main/tests.main.test_http_with_xxhash.HttpEntryTestCase.test_truesendall_with_dump_from_recording.json rename to tests/tests.test_http_with_xxhash.HttpEntryTestCase.test_truesendall_with_dump_from_recording.json diff --git a/tests/tests38/README.txt b/tests/tests38/README.txt deleted file mode 100644 index 9d9332be..00000000 --- a/tests/tests38/README.txt +++ /dev/null @@ -1 +0,0 @@ -Since IsolatedAsyncioTestCase is only available on Python >= 3.8, these tests won't be available to builds using previous versions.