diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4baa6f34..b5b8cff8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-24.04 strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.10'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', '3.14', 'pypy3.11'] env: # Configure a constant location for the uv cache UV_CACHE_DIR: /tmp/.uv-cache @@ -52,8 +52,11 @@ jobs: make services-up - name: Test run: | - make test - make services-down + if [[ "${{ matrix.python-version }}" == pypy* ]]; then + SKIP_MYPY=1 make test + else + make test + fi - name: Minimize uv cache run: uv cache prune --ci - name: Upload coverage reports to Codecov diff --git a/.gitignore b/.gitignore index 564b8ce6..9bacc469 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ shippable .vscode/ Pipfile.lock requirements.txt +coverage.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4edd2b69..b2cd5de5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: forbid-crlf - id: remove-crlf - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -15,8 +15,12 @@ repos: exclude: helm/ args: [ --unsafe ] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.11.11" + rev: "v0.14.0" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - id: ruff-format + - repo: https://github.com/rstcheck/rstcheck + rev: v6.2.5 + hooks: + - id: rstcheck diff --git a/LICENSE b/LICENSE index db612228..45cf27c5 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2017-2025 Giorgio Salluzzo and individual contributors. All rights reserved. +Copyright (c) 2017-2026 Giorgio Salluzzo and individual contributors. All rights reserved. Copyright (c) 2013-2017 Andrea de Marco, Giorgio Salluzzo and individual contributors. All rights reserved. diff --git a/Makefile b/Makefile index 3452a467..e35b0a57 100644 --- a/Makefile +++ b/Makefile @@ -26,8 +26,12 @@ setup: develop develop: install-dev-requirements install-test-requirements types: - @echo "Type checking Python files" - $(VENV_PATH)/mypy --pretty + @if [ -n "$$SKIP_MYPY" ]; then \ + echo "Skipping mypy types check because SKIP_MYPY is set"; \ + else \ + echo "Type checking Python files"; \ + $(VENV_PATH)/mypy --pretty; \ + fi @echo "" test: types @@ -45,7 +49,7 @@ publish: clean install-test-requirements uv publish clean: - rm -rf *.egg-info dist/ requirements.txt uv.lock || true + rm -rf *.egg-info dist/ requirements.txt uv.lock coverage.xml || true find . -type d -name __pycache__ -exec rm -rf {} \; || true .PHONY: clean publish safetest test setup develop lint-python test-python _services-up diff --git a/README.rst b/README.rst index 1c95e904..a6c662d4 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,6 @@ -=============== -mocket /mɔˈkɛt/ -=============== +================ +mocket /ˈmɔ.kɛt/ +================ .. image:: https://github.com/mindflayer/python-mocket/actions/workflows/main.yml/badge.svg?branch=main :target: https://github.com/mindflayer/python-mocket/actions?query=workflow%3A%22Mocket%27s+CI%22 @@ -14,6 +14,9 @@ mocket /mɔˈkɛt/ .. image:: https://img.shields.io/pypi/dm/mocket :target: https://pypistats.org/packages/mocket +.. image:: https://deepwiki.com/badge.svg + :target: https://deepwiki.com/mindflayer/python-mocket + .. image:: https://raw.githubusercontent.com/mindflayer/python-mocket/main/mocket.png :height: 256px :width: 256px @@ -27,15 +30,30 @@ A socket mock framework ...and then MicroPython's *urequests* (*mocket >= 3.9.1*) +What is it about? +================= + +In a nutshell, **Mocket** is *monkey-patching on steroids* for the ``socket`` and ``ssl`` modules. + +It’s designed to serve two main purposes: + +- As a **low-level framework** — for example, if you're building a client for a new database or protocol. +- As a **ready-to-use mock** — perfect for testing HTTP or HTTPS calls from any client library. + +To demonstrate that Mocket is more than just a web client mocking tool, it even includes a simple Redis mock. + +The main goal of Mocket is to make it easier to test Python clients that communicate using the ``socket`` protocol. + Outside GitHub ============== -Mocket packages are available for `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, and of course from `PyPI`_. +Mocket packages are available for `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, `AUR Arch Linux`_, and of course from `PyPI`_. .. _`openSUSE`: https://software.opensuse.org/search?baseproject=ALL&q=mocket .. _`NixOS`: https://search.nixos.org/packages?query=mocket .. _`ALT Linux`: https://packages.altlinux.org/en/sisyphus/srpms/python3-module-mocket/ .. _`NetBSD`: https://cdn.netbsd.org/pub/pkgsrc/current/pkgsrc/devel/py-mocket/index.html +.. _`AUR Arch Linux`: https://aur.archlinux.org/packages/python-mocket .. _`PyPI`: https://pypi.org/project/mocket/ @@ -69,12 +87,12 @@ The starting point to understand how to use *Mocket* to write a custom mock is t As next step, you are invited to have a look at the implementation of both the mocks it provides: -- HTTP mock (similar to HTTPretty) - https://github.com/mindflayer/python-mocket/blob/master/mocket/mocks/mockhttp.py -- Redis mock (basic implementation) - https://github.com/mindflayer/python-mocket/blob/master/mocket/mocks/mockredis.py +- HTTP mock (similar to HTTPretty) - https://github.com/mindflayer/python-mocket/blob/main/mocket/mocks/mockhttp.py +- Redis mock (basic implementation) - https://github.com/mindflayer/python-mocket/blob/main/mocket/mocks/mockredis.py Please also have a look at the huge test suite: -- Tests module at https://github.com/mindflayer/python-mocket/tree/master/tests +- Tests module at https://github.com/mindflayer/python-mocket/tree/main/tests Installation ============ @@ -211,6 +229,37 @@ It's very important that we test non-happy paths. with self.assertRaises(requests.exceptions.ConnectionError): requests.get(url) +Example of how to mock a call with a custom request matching logic +================================================================== +.. code-block:: python + + import json + + from mocket import mocketize + from mocket.mocks.mockhttp import Entry + import requests + + @mocketize + def test_can_handle(): + Entry.single_register( + Entry.GET, + url, + body=json.dumps({"message": "Nope... not this time!"}), + headers={"content-type": "application/json"}, + can_handle_fun=lambda path, qs_dict: path == "/ip" and qs_dict, + ) + Entry.single_register( + Entry.GET, + url, + body=json.dumps({"message": "There you go!"}), + headers={"content-type": "application/json"}, + can_handle_fun=lambda path, qs_dict: path == "/ip" and not qs_dict, + ) + resp = requests.get("https://httpbin.org/ip") + assert resp.status_code == 200 + assert resp.json() == {"message": "There you go!"} + + Example of how to record real socket traffic ============================================ @@ -237,10 +286,12 @@ You probably know what *VCRpy* is capable of, that's the *mocket*'s way of achie HTTPretty compatibility layer ============================= -Mocket HTTP mock can work as *HTTPretty* replacement for many different use cases. Two main features are missing: +Mocket HTTP mock can work as *HTTPretty* replacement for many different use cases. Two main features are missing, or better said, are implemented differently: + +- URL entries containing regular expressions, *Mocket* implements `can_handle_fun` which is way simpler to use and more powerful; +- response body from functions (used mostly to fake errors, *Mocket* accepts an `exception` instead). -- URL entries containing regular expressions; -- response body from functions (used mostly to fake errors, *mocket* doesn't need to do it this way). +Both features are documented above. Two features which are against the Zen of Python, at least imho (*mindflayer*), but of course I am open to call it into question. @@ -290,11 +341,11 @@ Example: .. code-block:: python # `aiohttp` creates SSLContext instances at import-time - # that's why Mocket would get stuck when dealing with HTTP + # that's why Mocket would get stuck when dealing with HTTPS # Importing the module while Mocket is in control (inside a # decorated test function or using its context manager would # be enough for making it work), the alternative is using a - # custom TCPConnector which always return a FakeSSLContext + # custom TCPConnector which always returns a FakeSSLContext # from Mocket like this example is showing. import aiohttp import pytest diff --git a/mocket/__init__.py b/mocket/__init__.py index eaf33dff..27ffad16 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,3 +1,5 @@ +"""Mocket - socket mocking library for Python.""" + import importlib import sys @@ -31,4 +33,4 @@ "FakeSSLContext", ) -__version__ = "3.13.10" +__version__ = "3.14.1" diff --git a/mocket/compat.py b/mocket/compat.py index 1ac2fc89..a8e726f6 100644 --- a/mocket/compat.py +++ b/mocket/compat.py @@ -11,23 +11,57 @@ def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes: + """Encode a string or bytes to bytes. + + Args: + s: String or bytes to encode + encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var) + + Returns: + Encoded bytes + """ if isinstance(s, str): s = s.encode(encoding) return bytes(s) def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str: + """Decode bytes or string to string. + + Args: + s: String or bytes to decode + encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var) + + Returns: + Decoded string + """ if isinstance(s, bytes): s = codecs.decode(s, encoding, "ignore") return str(s) def shsplit(s: str | bytes) -> list[str]: + """Split a shell command string into arguments. + + Args: + s: Shell command string or bytes + + Returns: + List of shell command arguments + """ s = decode_from_bytes(s) return shlex.split(s) -def do_the_magic(body): +def do_the_magic(body: bytes) -> str: + """Detect MIME type of binary data using puremagic. + + Args: + body: Binary data to analyze + + Returns: + MIME type string + """ try: magic = puremagic.magic_string(body) except puremagic.PureError: diff --git a/mocket/decorators/async_mocket.py b/mocket/decorators/async_mocket.py index 3839d5f1..53b966c0 100644 --- a/mocket/decorators/async_mocket.py +++ b/mocket/decorators/async_mocket.py @@ -1,15 +1,34 @@ +"""Async version of Mocket decorator.""" + +from __future__ import annotations + +from typing import Any, Callable + from mocket.decorators.mocketizer import Mocketizer from mocket.utils import get_mocketize async def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): + test: Callable, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + *args: Any, + **kwargs: Any, +) -> Any: + """Async wrapper function for @async_mocketize decorator. + + Args: + test: Async test function to wrap + truesocket_recording_dir: Directory for recording true socket calls + strict_mode: Enable STRICT mode to forbid real socket calls + strict_mode_allowed: List of allowed hosts in STRICT mode + *args: Test arguments + **kwargs: Test keyword arguments + + Returns: + Result of the test function + """ async with Mocketizer.factory( test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args ): diff --git a/mocket/decorators/mocketizer.py b/mocket/decorators/mocketizer.py index fb7c811b..b067ffdf 100644 --- a/mocket/decorators/mocketizer.py +++ b/mocket/decorators/mocketizer.py @@ -1,17 +1,34 @@ +"""Mocketizer decorator for managing Mocket lifecycle in tests.""" + +from __future__ import annotations + +from typing import Any, Callable + from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import get_mocketize class Mocketizer: + """Context manager and decorator for managing Mocket lifecycle in tests.""" + def __init__( self, - instance=None, - namespace=None, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - ): + instance: Any | None = None, + namespace: str | None = None, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + ) -> None: + """Initialize the Mocketizer. + + Args: + instance: Test instance (optional) + namespace: Namespace for recordings + truesocket_recording_dir: Directory for recording true socket calls + strict_mode: Enable STRICT mode to forbid real socket calls + strict_mode_allowed: List of allowed hosts in STRICT mode + """ self.instance = instance self.truesocket_recording_dir = truesocket_recording_dir self.namespace = namespace or str(id(self)) @@ -23,7 +40,8 @@ def __init__( "Allowed locations are only accepted when STRICT mode is active." ) - def enter(self): + def enter(self) -> None: + """Enter the Mocketizer context (enable Mocket).""" Mocket.enable( namespace=self.namespace, truesocket_recording_dir=self.truesocket_recording_dir, @@ -31,33 +49,80 @@ def enter(self): if self.instance: self.check_and_call("mocketize_setup") - def __enter__(self): + def __enter__(self) -> Mocketizer: + """Enter context manager. + + Returns: + Self for use in `with` statements + """ self.enter() return self - def exit(self): + def exit(self) -> None: + """Exit the Mocketizer context (disable Mocket).""" if self.instance: self.check_and_call("mocketize_teardown") Mocket.disable() - def __exit__(self, type, value, tb): + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + """Exit context manager. + + Args: + type: Exception type + value: Exception value + tb: Traceback + """ self.exit() - async def __aenter__(self, *args, **kwargs): + async def __aenter__(self, *args: Any, **kwargs: Any) -> Mocketizer: + """Enter async context manager. + + Returns: + Self for use in `async with` statements + """ self.enter() return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit async context manager. + + Args: + *args: Exception arguments + **kwargs: Exception keyword arguments + """ self.exit() - def check_and_call(self, method_name): + def check_and_call(self, method_name: str) -> None: + """Check if instance has a method and call it. + + Args: + method_name: Name of method to check and call + """ method = getattr(self.instance, method_name, None) if callable(method): method() @staticmethod - def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): + def factory( + test: Callable, + truesocket_recording_dir: str | None, + strict_mode: bool, + strict_mode_allowed: list | None, + args: tuple, + ) -> Mocketizer: + """Create a Mocketizer instance for a test function. + + Args: + test: Test function being decorated + truesocket_recording_dir: Recording directory + strict_mode: Enable STRICT mode + strict_mode_allowed: Allowed hosts in STRICT mode + args: Positional arguments to test + + Returns: + Configured Mocketizer instance + """ instance = args[0] if args else None namespace = None if truesocket_recording_dir: @@ -79,13 +144,26 @@ def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, ar def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): + test: Callable, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + *args: Any, + **kwargs: Any, +) -> Any: + """Wrapper function for @mocketize decorator. + + Args: + test: Test function to wrap + truesocket_recording_dir: Recording directory + strict_mode: Enable STRICT mode + strict_mode_allowed: Allowed hosts in STRICT mode + *args: Test arguments + **kwargs: Test keyword arguments + + Returns: + Result of the test function + """ with Mocketizer.factory( test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args ): diff --git a/mocket/entry.py b/mocket/entry.py index 9dbbf442..2d618472 100644 --- a/mocket/entry.py +++ b/mocket/entry.py @@ -1,22 +1,38 @@ +"""Mocket entry base class for registering mock responses.""" + +from __future__ import annotations + import collections.abc +from typing import Any from mocket.compat import encode_to_bytes from mocket.mocket import Mocket class MocketEntry: + """Base class for Mocket entries that match requests and return responses.""" + class Response(bytes): + """Response wrapper class that extends bytes.""" + @property - def data(self): + def data(self) -> bytes: + """Get the response data.""" return self - response_index = 0 - request_cls = bytes - response_cls = Response - responses = None - _served = None + response_index: int = 0 + request_cls: type = bytes + response_cls: type = Response + responses: list | None = None + _served: bool | None = None + + def __init__(self, location: tuple, responses: Any) -> None: + """Initialize a Mocket entry. - def __init__(self, location, responses): + Args: + location: Tuple of (host, port) + responses: Single response or list of responses to cycle through + """ self._served = False self.location = location @@ -34,18 +50,40 @@ def __init__(self, location, responses): r = self.response_cls(r) self.responses.append(r) - def __repr__(self): + def __repr__(self) -> str: + """Return a string representation of the entry.""" return f"{self.__class__.__name__}(location={self.location})" @staticmethod - def can_handle(data): + def can_handle(data: bytes) -> bool: + """Check if this entry can handle the given request data. + + Args: + data: Request data to check + + Returns: + True if this entry can handle the request, False otherwise + """ return True - def collect(self, data): + def collect(self, data: bytes) -> None: + """Collect the request data in the Mocket singleton. + + Args: + data: Request data to collect + """ req = self.request_cls(data) Mocket.collect(req) - def get_response(self): + def get_response(self) -> bytes: + """Get the next response to send. + + Returns: + Response bytes to send to the client + + Raises: + BaseException: If a response is an exception, it will be raised + """ response = self.responses[self.response_index] if self.response_index < len(self.responses) - 1: self.response_index += 1 diff --git a/mocket/exceptions.py b/mocket/exceptions.py index f5537568..db78dbf5 100644 --- a/mocket/exceptions.py +++ b/mocket/exceptions.py @@ -1,6 +1,13 @@ +"""Mocket exception classes.""" + + class MocketException(Exception): + """Base exception class for Mocket errors.""" + pass class StrictMocketException(MocketException): + """Exception raised when a socket operation is not allowed in STRICT mode.""" + pass diff --git a/mocket/inject.py b/mocket/inject.py index 866ee563..e788a929 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -1,3 +1,5 @@ +"""Socket patching and restoration for Mocket injection.""" + from __future__ import annotations import contextlib @@ -12,17 +14,31 @@ def _patch(module: ModuleType, name: str, patched_value: Any) -> None: + """Patch a module with a new value and store the original. + + Args: + module: Module to patch + name: Attribute name to patch + patched_value: New value to set + """ with contextlib.suppress(KeyError): original_value, module.__dict__[name] = module.__dict__[name], patched_value _patches_restore[(module, name)] = original_value def _restore(module: ModuleType, name: str) -> None: + """Restore a module's original attribute value. + + Args: + module: Module to restore + name: Attribute name to restore + """ if original_value := _patches_restore.pop((module, name)): module.__dict__[name] = original_value def enable() -> None: + """Enable Mocket by patching socket, ssl, and urllib3 modules.""" from mocket.socket import ( MocketSocket, mock_create_connection, @@ -71,6 +87,7 @@ def enable() -> None: def disable() -> None: + """Disable Mocket by restoring all patched modules.""" for module, name in list(_patches_restore.keys()): _restore(module, name) diff --git a/mocket/io.py b/mocket/io.py index 0334410b..e815e0ec 100644 --- a/mocket/io.py +++ b/mocket/io.py @@ -1,3 +1,7 @@ +"""Mocket socket I/O implementation.""" + +from __future__ import annotations + import io import os @@ -5,13 +9,29 @@ class MocketSocketIO(io.BytesIO): - def __init__(self, address) -> None: + """A BytesIO wrapper that integrates with Mocket's pipe-based I/O.""" + + def __init__(self, address: tuple) -> None: + """Initialize the socket I/O with a socket address. + + Args: + address: Tuple of (host, port) + """ self._address = address super().__init__() - def write(self, content): + def write(self, content: bytes) -> int: + """Write content to the buffer and the pipe if available. + + Args: + content: Bytes to write + + Returns: + Number of bytes written + """ super().write(content) _, w_fd = Mocket.get_pair(self._address) if w_fd: os.write(w_fd, content) + return len(content) diff --git a/mocket/mocket.py b/mocket/mocket.py index c9e6e204..75ae6285 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,3 +1,5 @@ +"""Core Mocket singleton for socket mocking management.""" + from __future__ import annotations import collections @@ -18,8 +20,10 @@ class Mocket: + """Singleton class managing all mock socket operations and entries.""" + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} - _address: ClassVar[Address] = (None, None) + _address: ClassVar[Address | tuple[None, None]] = (None, None) _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) _requests: ClassVar[list] = [] _record_storage: ClassVar[MocketRecordStorage | None] = None @@ -30,15 +34,19 @@ def enable( namespace: str | None = None, truesocket_recording_dir: str | None = None, ) -> None: + """Enable Mocket socket mocking. + + Args: + namespace: Namespace for recording storage (defaults to id of _entries) + truesocket_recording_dir: Directory to store recorded requests/responses + """ if namespace is None: namespace = str(id(cls._entries)) if truesocket_recording_dir is not None: recording_dir = Path(truesocket_recording_dir) - if not recording_dir.is_dir(): - # JSON dumps will be saved here - raise AssertionError + assert recording_dir.is_dir(), f"Not a directory: {recording_dir}" cls._record_storage = MocketRecordStorage( directory=recording_dir, @@ -49,33 +57,61 @@ def enable( @classmethod def disable(cls) -> None: + """Disable Mocket socket mocking and clean up resources.""" cls.reset() mocket.inject.disable() @classmethod def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: - """ + """Get the file descriptor pair for a socket address. + Given the id() of the caller, return a pair of file descriptors as a tuple of two integers: (, ) + + Args: + address: (host, port) tuple + + Returns: + Tuple of (read_fd, write_fd) or (None, None) if not found """ return cls._socket_pairs.get(address, (None, None)) @classmethod def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: - """ - Store a pair of file descriptors under the key `id_` + """Store a file descriptor pair for a socket address. + + Store a pair of file descriptors under the key `address` as a tuple of two integers: (, ) + + Args: + address: (host, port) tuple + pair: Tuple of (read_fd, write_fd) """ cls._socket_pairs[address] = pair @classmethod def register(cls, *entries: MocketEntry) -> None: + """Register mock entries with Mocket. + + Args: + *entries: Variable number of MocketEntry instances to register + """ for entry in entries: cls._entries[entry.location].append(entry) @classmethod - def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: + def get_entry(cls, host: str, port: int, data: Any) -> MocketEntry | None: + """Get a matching entry for the given request data. + + Args: + host: Hostname + port: Port number + data: Request data + + Returns: + Matching MocketEntry or None + """ host = host or cls._address[0] port = port or cls._address[1] entries = cls._entries.get((host, port), []) @@ -85,11 +121,17 @@ def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: return None @classmethod - def collect(cls, data) -> None: + def collect(cls, data: Any) -> None: + """Collect a request in the list of all requests. + + Args: + data: Request data to collect + """ cls._requests.append(data) @classmethod def reset(cls) -> None: + """Reset all Mocket state and clean up file descriptors.""" for r_fd, w_fd in cls._socket_pairs.values(): os.close(r_fd) os.close(w_fd) @@ -100,36 +142,62 @@ def reset(cls) -> None: @classmethod def last_request(cls) -> Any: + """Get the last request made. + + Returns: + Last request data or None if no requests + """ if cls.has_requests(): return cls._requests[-1] @classmethod def request_list(cls) -> list[Any]: + """Get the list of all requests. + + Returns: + List of all collected requests + """ return cls._requests @classmethod def remove_last_request(cls) -> None: + """Remove the last request from the request list.""" if cls.has_requests(): del cls._requests[-1] @classmethod def has_requests(cls) -> bool: + """Check if any requests have been made. + + Returns: + True if there are requests, False otherwise + """ return bool(cls.request_list()) @classmethod def get_namespace(cls) -> str | None: - if not cls._record_storage: - return None - return cls._record_storage.namespace + """Get the recording namespace. + + Returns: + Namespace string or None if recording is not enabled + """ + return cls._record_storage.namespace if cls._record_storage else None @classmethod def get_truesocket_recording_dir(cls) -> str | None: - if not cls._record_storage: - return None - return str(cls._record_storage.directory) + """Get the true socket recording directory. + + Returns: + Directory path as string or None if recording is not enabled + """ + return str(cls._record_storage.directory) if cls._record_storage else None @classmethod def assert_fail_if_entries_not_served(cls) -> None: - """Mocket checks that all entries have been served at least once.""" + """Assert that all registered entries have been served at least once. + + Raises: + AssertionError: If any entries have not been served + """ if not all(entry._served for entry in itertools.chain(*cls._entries.values())): raise AssertionError("Some Mocket entries have not been served") diff --git a/mocket/mocks/mockhttp.py b/mocket/mocks/mockhttp.py index 3db6a65d..5ec14a62 100644 --- a/mocket/mocks/mockhttp.py +++ b/mocket/mocks/mockhttp.py @@ -1,7 +1,12 @@ +"""HTTP mocking implementation for Mocket.""" + +from __future__ import annotations + import re import time from functools import cached_property from http.server import BaseHTTPRequestHandler +from typing import Any, Callable from urllib.parse import parse_qs, unquote, urlsplit from h11 import SERVER, Connection, Data @@ -11,42 +16,79 @@ from mocket.entry import MocketEntry from mocket.mocket import Mocket -STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} -CRLF = "\r\n" -ASCII = "ascii" +STATUS: dict = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} +CRLF: str = "\r\n" +ASCII: str = "ascii" class Request: - _parser = None - _event = None + """HTTP request parser using h11.""" + + _parser: Connection | None = None + _event: Any | None = None + + def __init__(self, data: bytes) -> None: + """Initialize the request parser. - def __init__(self, data): + Args: + data: Raw HTTP request data + """ self._parser = Connection(SERVER) self.add_data(data) - def add_data(self, data): + def add_data(self, data: bytes) -> None: + """Add more data to the request. + + Args: + data: Additional raw request data + """ self._parser.receive_data(data) @property - def event(self): + def event(self) -> Any: + """Get the parsed request event. + + Returns: + The h11 request event + """ if not self._event: self._event = self._parser.next_event() return self._event @cached_property - def method(self): + def method(self) -> str: + """Get the HTTP method. + + Returns: + HTTP method (GET, POST, etc.) + """ return self.event.method.decode(ASCII) @cached_property - def path(self): + def path(self) -> str: + """Get the request path. + + Returns: + Request path with query string + """ return self.event.target.decode(ASCII) @cached_property - def headers(self): + def headers(self) -> dict: + """Get the request headers. + + Returns: + Dictionary of header names to values + """ return {k.decode(ASCII): v.decode(ASCII) for k, v in self.event.headers} @cached_property - def querystring(self): + def querystring(self) -> dict: + """Get the parsed query string. + + Returns: + Dictionary of query parameter names to lists of values + """ parts = self.path.split("?", 1) return ( parse_qs(unquote(parts[1]), keep_blank_values=True) @@ -55,7 +97,12 @@ def querystring(self): ) @cached_property - def body(self): + def body(self) -> str: + """Get the request body. + + Returns: + Decoded request body string + """ while True: event = self._parser.next_event() if isinstance(event, H11Request): @@ -63,15 +110,31 @@ def body(self): elif isinstance(event, Data): return event.data.decode(ENCODING) - def __str__(self): + def __str__(self) -> str: + """Get string representation of request. + + Returns: + Formatted request string + """ return f"{self.method} - {self.path} - {self.headers}" class Response: - headers = None - is_file_object = False + """HTTP response builder.""" + + headers: dict | None = None + is_file_object: bool = False - def __init__(self, body="", status=200, headers=None): + def __init__( + self, body: Any = "", status: int = 200, headers: dict | None = None + ) -> None: + """Initialize an HTTP response. + + Args: + body: Response body (string, bytes, or file-like object) + status: HTTP status code + headers: Dictionary of response headers + """ headers = headers or {} try: # File Objects @@ -82,13 +145,19 @@ def __init__(self, body="", status=200, headers=None): self.status = status self.set_base_headers() - - if headers is not None: - self.set_extra_headers(headers) + self.set_extra_headers(headers) self.data = self.get_protocol_data() + self.body def get_protocol_data(self, str_format_fun_name: str = "capitalize") -> bytes: + """Get the HTTP protocol headers and status line. + + Args: + str_format_fun_name: Name of string formatting method to use + + Returns: + Bytes of protocol headers (status line and headers) + """ status_line = f"HTTP/1.1 {self.status} {STATUS[self.status]}" header_lines = CRLF.join( ( @@ -98,7 +167,8 @@ def get_protocol_data(self, str_format_fun_name: str = "capitalize") -> bytes: ) return f"{status_line}\r\n{header_lines}\r\n\r\n".encode(ENCODING) - def set_base_headers(self): + def set_base_headers(self) -> None: + """Set the base response headers.""" self.headers = { "Status": str(self.status), "Date": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()), @@ -111,8 +181,12 @@ def set_base_headers(self): else: self.headers["Content-Type"] = do_the_magic(self.body) - def set_extra_headers(self, headers): - r""" + def set_extra_headers(self, headers: dict) -> None: + r"""Add extra headers to the response. + + Args: + headers: Dictionary of additional headers + >>> r = Response(body="") >>> len(r.headers.keys()) 6 @@ -127,6 +201,8 @@ def set_extra_headers(self, headers): class Entry(MocketEntry): + """HTTP entry for matching and responding to HTTP requests.""" + CONNECT = "CONNECT" DELETE = "DELETE" GET = "GET" @@ -137,12 +213,33 @@ class Entry(MocketEntry): PUT = "PUT" TRACE = "TRACE" - METHODS = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE) - - request_cls = Request - response_cls = Response + METHODS: tuple = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE) + + request_cls: type = Request + response_cls: type = Response + + default_config: dict = {"match_querystring": True, "can_handle_fun": None} + _can_handle_fun: Callable | None = None + + def __init__( + self, + uri: str, + method: str, + responses: Any, + match_querystring: bool = True, + can_handle_fun: Callable | None = None, + ) -> None: + """Initialize an HTTP entry. + + Args: + uri: URI to match (http://host:port/path?query) + method: HTTP method (GET, POST, etc.) + responses: Response(s) to return + match_querystring: Whether to match query strings + can_handle_fun: Custom matching function + """ + self._can_handle_fun = can_handle_fun if can_handle_fun else self._can_handle - def __init__(self, uri, method, responses, match_querystring=True): uri = urlsplit(uri) port = uri.port @@ -151,16 +248,29 @@ def __init__(self, uri, method, responses, match_querystring=True): super().__init__((uri.hostname, port), responses) self.schema = uri.scheme - self.path = uri.path + self.path = uri.path or "/" self.query = uri.query self.method = method.upper() self._sent_data = b"" self._match_querystring = match_querystring - def __repr__(self): + def __repr__(self) -> str: + """Get string representation of the entry. + + Returns: + String representation + """ return f"{self.__class__.__name__}(method={self.method!r}, schema={self.schema!r}, location={self.location!r}, path={self.path!r}, query={self.query!r})" - def collect(self, data): + def collect(self, data: bytes) -> bool: + """Collect the request data. + + Args: + data: Request data + + Returns: + Whether to consume the response + """ consume_response = True decoded_data = decode_from_bytes(data) @@ -175,8 +285,32 @@ def collect(self, data): return consume_response - def can_handle(self, data): - r""" + def _can_handle(self, path: str, qs_dict: dict) -> bool: + """Default can_handle function checking path and query string. + + Args: + path: Request path + qs_dict: Parsed query string parameters + + Returns: + True if this entry can handle the request + """ + can_handle = path == self.path + if self._match_querystring: + can_handle = can_handle and qs_dict == parse_qs( + self.query, keep_blank_values=True + ) + return can_handle + + def can_handle(self, data: bytes) -> bool: + r"""Check if this entry can handle the given request data. + + Args: + data: Request data + + Returns: + True if this entry can handle the request + >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),)) >>> e.can_handle(b'GET /?bar=foo HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n') False @@ -190,22 +324,31 @@ def can_handle(self, data): except ValueError: return self is getattr(Mocket, "_last_entry", None) - uri = urlsplit(path) - can_handle = uri.path == self.path and method == self.method - if self._match_querystring: - kw = dict(keep_blank_values=True) - can_handle = can_handle and parse_qs(uri.query, **kw) == parse_qs( - self.query, **kw - ) + _request = urlsplit(path) + + can_handle = method == self.method and self._can_handle_fun( + _request.path, parse_qs(_request.query, keep_blank_values=True) + ) + if can_handle: Mocket._last_entry = self return can_handle @staticmethod - def _parse_requestline(line): - """ + def _parse_requestline(line: str) -> tuple: + """Parse an HTTP request line. + http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5 + Args: + line: HTTP request line string + + Returns: + Tuple of (method, path, version) + + Raises: + ValueError: If line is not a valid request line + >>> Entry._parse_requestline('GET / HTTP/1.0') == ('GET', '/', '1.0') True >>> Entry._parse_requestline('post /testurl htTP/1.1') == ('POST', '/testurl', '1.1') @@ -223,32 +366,61 @@ def _parse_requestline(line): raise ValueError("Not a Request-Line") @classmethod - def register(cls, method, uri, *responses, **config): + def register(cls, method: str, uri: str, *responses: Any, **config: Any) -> None: + """Register an HTTP entry for multiple responses. + + Args: + method: HTTP method (GET, POST, etc.) + uri: URI to match + *responses: Response(s) to cycle through + **config: Configuration options (match_querystring, can_handle_fun) + + Raises: + AttributeError: If using body/status params (use single_register instead) + KeyError: If invalid config keys provided + """ if "body" in config or "status" in config: raise AttributeError("Did you mean `Entry.single_register(...)`?") - default_config = dict(match_querystring=True, add_trailing_slash=True) - default_config.update(config) - config = default_config + if config.keys() - cls.default_config.keys(): + raise KeyError( + f"Invalid config keys: {config.keys() - cls.default_config.keys()}" + ) - if config["add_trailing_slash"] and not urlsplit(uri).path: - uri += "/" + _config = cls.default_config.copy() + _config.update({k: v for k, v in config.items() if k in _config}) - Mocket.register( - cls(uri, method, responses, match_querystring=config["match_querystring"]) - ) + Mocket.register(cls(uri, method, responses, **_config)) @classmethod def single_register( cls, - method, - uri, - body="", - status=200, - headers=None, - match_querystring=True, - exception=None, - ): + method: str, + uri: str, + body: Any = "", + status: int = 200, + headers: dict | None = None, + exception: Exception | None = None, + match_querystring: bool = True, + can_handle_fun: Callable | None = None, + **config: Any, + ) -> None: + """Register a single HTTP response for a URI and method. + + This is a convenience method that creates a single Response object + instead of requiring a list. + + Args: + method: HTTP method (GET, POST, etc.) + uri: URI to match + body: Response body content + status: HTTP status code + headers: Dictionary of response headers + exception: Exception to raise instead of returning response + match_querystring: Whether to match query strings + can_handle_fun: Custom matching function + **config: Additional configuration options + """ response = ( exception if exception @@ -260,4 +432,6 @@ def single_register( uri, response, match_querystring=match_querystring, + can_handle_fun=can_handle_fun, + **config, ) diff --git a/mocket/mocks/mockredis.py b/mocket/mocks/mockredis.py index fc386e2d..eee2d6c8 100644 --- a/mocket/mocks/mockredis.py +++ b/mocket/mocks/mockredis.py @@ -1,4 +1,9 @@ +"""Redis mocking implementation for Mocket.""" + +from __future__ import annotations + from itertools import chain +from typing import Any from mocket.compat import ( decode_from_bytes, @@ -7,29 +12,63 @@ ) from mocket.entry import MocketEntry from mocket.mocket import Mocket +from mocket.types import Address class Request: - def __init__(self, data): + """Redis request wrapper.""" + + def __init__(self, data: bytes) -> None: + """Initialize a Redis request. + + Args: + data: Raw Redis command data + """ self.data = data class Response: - def __init__(self, data=None): + """Redis response wrapper.""" + + def __init__(self, data: Any = None) -> None: + """Initialize a Redis response. + + Args: + data: Response data (will be "redisize"d) + """ self.data = Redisizer.redisize(data or OK) class Redisizer(bytes): + """Convert Python types to Redis protocol format.""" + @staticmethod - def tokens(iterable): + def tokens(iterable: list[Any]) -> list[bytes]: + """Convert an iterable to Redis tokens. + + Args: + iterable: List of items to convert + + Returns: + List of Redis protocol bytes + """ iterable = [encode_to_bytes(x) for x in iterable] return [f"*{len(iterable)}".encode()] + list( chain(*zip([f"${len(x)}".encode() for x in iterable], iterable)) ) @staticmethod - def redisize(data): - def get_conversion(t): + def redisize(data: Any) -> Redisizer: + """Convert Python data to Redis protocol format. + + Args: + data: Python data to convert + + Returns: + Redisizer bytes + """ + + def get_conversion(t: type) -> Any: return { dict: lambda x: b"\r\n".join( Redisizer.tokens(list(chain(*tuple(x.items())))) @@ -48,11 +87,28 @@ def get_conversion(t): return Redisizer(get_conversion(data.__class__)(data) + b"\r\n") @staticmethod - def command(description, _type="+"): + def command(description: str, _type: str = "+") -> Redisizer: + """Create a Redis command response. + + Args: + description: Response description + _type: Response type prefix (+, -, :, $, *) + + Returns: + Formatted Redis response + """ return Redisizer("{}{}{}".format(_type, description, "\r\n").encode("utf-8")) @staticmethod - def error(description): + def error(description: str) -> Redisizer: + """Create a Redis error response. + + Args: + description: Error description + + Returns: + Formatted Redis error response + """ return Redisizer.command(description, _type="-") @@ -62,20 +118,46 @@ def error(description): class Entry(MocketEntry): + """Redis entry for matching and responding to Redis commands.""" + request_cls = Request response_cls = Response - def __init__(self, addr, command, responses): + def __init__( + self, addr: Address | None, command: str, responses: list[Any] + ) -> None: + """Initialize a Redis entry. + + Args: + addr: (host, port) tuple or None for default + command: Redis command string to match + responses: List of responses to cycle through + """ super().__init__(addr or ("localhost", 6379), responses) d = shsplit(command) d[0] = d[0].upper() self.command = Redisizer.tokens(d) - def can_handle(self, data): + def can_handle(self, data: bytes) -> bool: + """Check if this entry can handle the given command. + + Args: + data: Raw Redis command data + + Returns: + True if this entry matches the command + """ return data.splitlines() == self.command @classmethod - def register(cls, addr, command, *responses): + def register(cls, addr: Address | None, command: str, *responses: Any) -> None: + """Register a Redis entry. + + Args: + addr: (host, port) tuple or None for default + command: Redis command to match + *responses: Responses to cycle through + """ responses = [ r if isinstance(r, BaseException) else cls.response_cls(r) for r in responses @@ -83,9 +165,27 @@ def register(cls, addr, command, *responses): Mocket.register(cls(addr, command, responses)) @classmethod - def register_response(cls, command, response, addr=None): + def register_response( + cls, command: str, response: Any, addr: Address | None = None + ) -> None: + """Register a single response for a command. + + Args: + command: Redis command to match + response: Response to return + addr: (host, port) tuple or None for default + """ cls.register(addr, command, response) @classmethod - def register_responses(cls, command, responses, addr=None): + def register_responses( + cls, command: str, responses: list[Any], addr: Address | None = None + ) -> None: + """Register multiple responses for a command. + + Args: + command: Redis command to match + responses: List of responses to cycle through + addr: (host, port) tuple or None for default + """ cls.register(addr, command, *responses) diff --git a/mocket/mode.py b/mocket/mode.py index ac2ca16a..ffb23a44 100644 --- a/mocket/mode.py +++ b/mocket/mode.py @@ -1,3 +1,5 @@ +"""Mocket mode management for strict socket enforcement.""" + from __future__ import annotations from typing import TYPE_CHECKING, Any, ClassVar @@ -10,17 +12,27 @@ class _MocketMode: + """Singleton class for managing Mocket's strict mode enforcement.""" + __shared_state: ClassVar[dict[str, Any]] = {} STRICT: ClassVar = None STRICT_ALLOWED: ClassVar = None def __init__(self) -> None: + """Initialize the MocketMode singleton with shared state.""" self.__dict__ = self.__shared_state def is_allowed(self, location: str | tuple[str, int]) -> bool: - """ + """Check if a location is allowed to perform real socket calls. + Checks if (`host`, `port`) or at least `host` are allowed locations to perform real `socket` calls + + Args: + location: Hostname string or (host, port) tuple + + Returns: + True if the location is allowed, False if in STRICT mode and not allowed """ if not self.STRICT: return True @@ -35,6 +47,15 @@ def raise_not_allowed( address: tuple[str, int] | None = None, data: bytes | None = None, ) -> NoReturn: + """Raise an exception when a socket operation is not allowed in STRICT mode. + + Args: + address: The (host, port) tuple that was attempted + data: The request data that was sent + + Raises: + StrictMocketException: Always raised with detailed context + """ current_entries = [ (location, "\n ".join(map(str, entries))) for location, entries in Mocket._entries.items() diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index 34de7932..fb40c0c5 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -139,6 +139,4 @@ def __getattr__(self, name): "HEAD", "PATCH", "register_uri", - "str", - "bytes", ) diff --git a/mocket/recording.py b/mocket/recording.py index 97d2adbe..95faf126 100644 --- a/mocket/recording.py +++ b/mocket/recording.py @@ -1,3 +1,5 @@ +"""Request/response recording for playback during tests.""" + from __future__ import annotations import contextlib @@ -6,12 +8,13 @@ from collections import defaultdict from dataclasses import dataclass from pathlib import Path +from typing import Any from mocket.compat import decode_from_bytes, encode_to_bytes from mocket.types import Address from mocket.utils import hexdump, hexload -hash_function = hashlib.md5 +hash_function: Any = hashlib.md5 with contextlib.suppress(ImportError): from xxhash_cffi import xxh32 as xxhash_cffi_xxh32 @@ -25,22 +28,48 @@ def _hash_prepare_request(data: bytes) -> bytes: + """Prepare request data for hashing by sorting headers. + + Args: + data: Raw request data + + Returns: + Prepared bytes for hashing + """ _data = decode_from_bytes(data) return encode_to_bytes("".join(sorted(_data.split("\r\n")))) def _hash_request(data: bytes) -> str: + """Hash a request using the best available hash function. + + Args: + data: Raw request data + + Returns: + Hex digest of the hash + """ _data = _hash_prepare_request(data) return hash_function(_data).hexdigest() def _hash_request_fallback(data: bytes) -> str: + """Hash a request using MD5 as fallback. + + Args: + data: Raw request data + + Returns: + Hex digest of the MD5 hash + """ _data = _hash_prepare_request(data) return hashlib.md5(_data).hexdigest() @dataclass class MocketRecord: + """A record of a request and its corresponding response.""" + host: str port: int request: bytes @@ -48,7 +77,15 @@ class MocketRecord: class MocketRecordStorage: + """Storage for recording and retrieving request/response pairs.""" + def __init__(self, directory: Path, namespace: str) -> None: + """Initialize the record storage. + + Args: + directory: Path to directory for storing recordings + namespace: Namespace for grouping records + """ self._directory = directory self._namespace = namespace self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = ( @@ -59,17 +96,33 @@ def __init__(self, directory: Path, namespace: str) -> None: @property def directory(self) -> Path: + """Get the recording directory. + + Returns: + Path to recording directory + """ return self._directory @property def namespace(self) -> str: + """Get the recording namespace. + + Returns: + Namespace string + """ return self._namespace @property def file(self) -> Path: + """Get the path to the namespace's JSON file. + + Returns: + Path to JSON recording file + """ return self._directory / f"{self._namespace}.json" def _load(self) -> None: + """Load recordings from disk.""" if not self.file.exists(): return @@ -92,6 +145,7 @@ def _load(self) -> None: ) def _save(self) -> None: + """Save recordings to disk.""" data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict( lambda: defaultdict(defaultdict) ) @@ -108,9 +162,26 @@ def _save(self) -> None: self.file.write_text(json_data) def get_records(self, address: Address) -> list[MocketRecord]: + """Get all records for an address. + + Args: + address: (host, port) tuple + + Returns: + List of MocketRecord instances + """ return list(self._records[address].values()) def get_record(self, address: Address, request: bytes) -> MocketRecord | None: + """Get a specific record matching the request. + + Args: + address: (host, port) tuple + request: Request bytes + + Returns: + Matching MocketRecord or None + """ # NOTE for backward-compat request_signature_fallback = _hash_request_fallback(request) if request_signature_fallback in self._records[address]: @@ -128,6 +199,13 @@ def put_record( request: bytes, response: bytes, ) -> None: + """Store a new record. + + Args: + address: (host, port) tuple + request: Request bytes + response: Response bytes + """ host, port = address record = MocketRecord( host=host, diff --git a/mocket/socket.py b/mocket/socket.py index e06a1a8e..bd79528c 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -1,3 +1,5 @@ +"""Mock socket implementation for Mocket.""" + from __future__ import annotations import contextlib @@ -25,7 +27,21 @@ true_socket = socket.socket -def mock_create_connection(address, timeout=None, source_address=None): +def mock_create_connection( + address: Address, + timeout: float | None = None, + source_address: Address | None = None, +) -> socket.socket: + """Create a mock socket connection. + + Args: + address: (host, port) tuple + timeout: Connection timeout in seconds + source_address: Source address for binding (unused) + + Returns: + MocketSocket instance + """ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: s.settimeout(timeout) @@ -41,29 +57,80 @@ def mock_getaddrinfo( proto: int = 0, flags: int = 0, ) -> list[tuple[int, int, int, str, tuple[str, int]]]: + """Mock socket.getaddrinfo function. + + Args: + host: Hostname + port: Port number + family: Address family (ignored) + type: Socket type (ignored) + proto: Protocol (ignored) + flags: Flags (ignored) + + Returns: + List of address info tuples + """ return [(2, 1, 6, "", (host, port))] def mock_gethostbyname(hostname: str) -> str: + """Mock socket.gethostbyname function. + + Args: + hostname: Hostname to resolve (unused) + + Returns: + Localhost IP address + """ return "127.0.0.1" def mock_gethostname() -> str: + """Mock socket.gethostname function. + + Returns: + Localhost hostname + """ return "localhost" def mock_inet_pton(address_family: int, ip_string: str) -> bytes: + """Mock socket.inet_pton function. + + Args: + address_family: Address family (unused) + ip_string: IP string (unused) + + Returns: + Localhost as bytes + """ return bytes("\x7f\x00\x00\x01", "utf-8") -def mock_socketpair(*args, **kwargs): - """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" +def mock_socketpair( + *args: Any, + **kwargs: Any, +) -> tuple[socket.socket, socket.socket]: + """Mock socket.socketpair function. + + Returns a real socketpair() used by asyncio loop for supporting + calls made by fastapi and similar services. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Tuple of two connected sockets + """ import _socket return _socket.socketpair(*args, **kwargs) class MocketSocket: + """Mock socket implementation for Mocket.""" + def __init__( self, family: socket.AddressFamily | int = socket.AF_INET, @@ -72,6 +139,15 @@ def __init__( fileno: int | None = None, **kwargs: Any, ) -> None: + """Initialize a Mocket socket. + + Args: + family: Address family + type: Socket type + proto: Protocol number + fileno: File descriptor (unused) + **kwargs: Additional keyword arguments + """ self._family = family self._type = type self._proto = proto @@ -90,9 +166,11 @@ def __init__( self._entry = None def __str__(self) -> str: + """Return a string representation of the socket.""" return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" def __enter__(self) -> Self: + """Enter context manager.""" return self def __exit__( @@ -101,27 +179,37 @@ def __exit__( value: BaseException | None, traceback: TracebackType | None, ) -> None: + """Exit context manager and close socket.""" self.close() @property def family(self) -> int: + """Get the address family.""" return self._family @property def type(self) -> int: + """Get the socket type.""" return self._type @property def proto(self) -> int: + """Get the protocol number.""" return self._proto @property def io(self) -> MocketSocketIO: + """Get or create the socket I/O object.""" if self._io is None: self._io = MocketSocketIO((self._host, self._port)) return self._io def fileno(self) -> int: + """Get the file descriptor for reading. + + Returns: + File descriptor number + """ address = (self._host, self._port) r_fd, _ = Mocket.get_pair(address) if not r_fd: @@ -130,52 +218,153 @@ def fileno(self) -> int: return r_fd def gettimeout(self) -> float | None: + """Get the socket timeout. + + Returns: + Timeout in seconds or None + """ return self._timeout - # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None` - def setsockopt(self, family: int, type: int, proto: int) -> None: - self._family = family - self._type = type - self._proto = proto + def setsockopt( + self, + level: int, + optname: int, + value: int | bytes | None, + optlen: int | None = None, + ) -> None: + """Set socket option. + Args: + level: Socket option level (e.g., socket.SOL_SOCKET) + optname: Socket option name (e.g., socket.SO_REUSEADDR) + value: Option value as an integer or bytes, or None when optlen is provided + optlen: Option length (used when value is None) + """ if self._true_socket: - self._true_socket.setsockopt(family, type, proto) + if optlen is not None: + self._true_socket.setsockopt(level, optname, value, optlen) + else: + self._true_socket.setsockopt(level, optname, value) def settimeout(self, timeout: float | None) -> None: + """Set the socket timeout. + + Args: + timeout: Timeout in seconds or None + """ self._timeout = timeout @staticmethod def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: + """Get socket option (mock implementation). + + Args: + level: Socket option level + optname: Socket option name + buflen: Buffer length (unused) + + Returns: + SOCK_STREAM constant + """ return socket.SOCK_STREAM def getpeername(self) -> _RetAddress: + """Get the remote socket address. + + Returns: + Address of the remote socket + """ return self._address def setblocking(self, block: bool) -> None: + """Set the socket to blocking or non-blocking mode. + + Args: + block: True for blocking, False for non-blocking + """ self.settimeout(None) if block else self.settimeout(0.0) def getblocking(self) -> bool: + """Check if the socket is in blocking mode. + + Returns: + True if blocking, False otherwise + """ return self.gettimeout() is None def getsockname(self) -> _RetAddress: + """Get the local socket address. + + Returns: + Local socket address + """ return socket.gethostbyname(self._address[0]), self._address[1] def connect(self, address: Address) -> None: + """Connect the socket to a remote address. + + Args: + address: (host, port) tuple + """ self._address = self._host, self._port = address Mocket._address = address def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO: + """Create a file object for the socket. + + Args: + mode: Mode string (unused) + bufsize: Buffer size (unused) + + Returns: + MocketSocketIO object + """ return self.io def get_entry(self, data: bytes) -> MocketEntry | None: + """Get a matching entry for the given data. + + Args: + data: Request data + + Returns: + Matching MocketEntry or None + """ return Mocket.get_entry(self._host, self._port, data) - def sendto(self, data: ReadableBuffer, address: Address | None = None) -> int: + def sendto( + self, + data: ReadableBuffer, + address: Address | None = None, + ) -> int: + """Send data to a specific address (UDP-like). + + Args: + data: Data to send + address: Destination address + + Returns: + Number of bytes sent + """ self.connect(address) self.sendall(data) return len(data) - def sendall(self, data, entry=None, *args, **kwargs): + def sendall( + self, + data: ReadableBuffer, + entry: MocketEntry | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """Send all data through the socket. + + Args: + data: Data to send + entry: Pre-matched entry (optional) + *args: Additional arguments + **kwargs: Additional keyword arguments + """ if entry is None: entry = self.get_entry(data) @@ -198,6 +387,17 @@ def sendmsg( flags: int = 0, address: Address | None = None, ) -> int: + """Send a message through multiple buffers. + + Args: + buffers: List of buffers to send + ancdata: Ancillary data (unused) + flags: Flags (unused) + address: Destination address (unused) + + Returns: + Number of bytes sent + """ if not buffers: return 0 @@ -211,16 +411,23 @@ def recvmsg( ancbufsize: int | None = None, flags: int = 0, ) -> tuple[bytes, list[tuple[int, bytes]]]: - """ - Receive a message from the socket. + """Receive a message from the socket. + This is a mock implementation that reads from the MocketSocketIO. + + Args: + buffersize: Size of buffer to receive + ancbufsize: Ancillary buffer size (unused) + flags: Flags (unused) + + Returns: + Tuple of (data, ancillary_data) """ try: data = self.recv(buffersize) except BlockingIOError: return b"", [] - # Mocking the ancillary data and flags as empty return data, [] def recvmsg_into( @@ -229,10 +436,19 @@ def recvmsg_into( ancbufsize: int | None = None, flags: int = 0, address: Address | None = None, - ): - """ - Receive a message into multiple buffers. + ) -> int: + """Receive a message into multiple buffers. + This is a mock implementation that reads from the MocketSocketIO. + + Args: + buffers: List of buffers to receive into + ancbufsize: Ancillary buffer size (unused) + flags: Flags (unused) + address: Address (unused) + + Returns: + Number of bytes received """ if not buffers: return 0 @@ -254,10 +470,16 @@ def recvfrom_into( buffer: WriteableBuffer, buffersize: int | None = None, flags: int | None = None, - ): - """ - Receive data into a buffer and return the number of bytes received. - This is a mock implementation that reads from the MocketSocketIO. + ) -> tuple[int, _RetAddress]: + """Receive data into a buffer and return the source address. + + Args: + buffer: Buffer to receive into + buffersize: Size to receive + flags: Flags (unused) + + Returns: + Tuple of (bytes_received, source_address) """ return self.recv_into(buffer, buffersize, flags), self._address @@ -267,10 +489,19 @@ def recv_into( buffersize: int | None = None, flags: int | None = None, ) -> int: + """Receive data into a buffer. + + Args: + buffer: Buffer to receive into + buffersize: Number of bytes to receive + flags: Flags (unused) + + Returns: + Number of bytes received + """ if hasattr(buffer, "write"): return buffer.write(self.recv(buffersize)) - # buffer is a memoryview if buffersize is None: buffersize = len(buffer) @@ -282,9 +513,30 @@ def recv_into( def recvfrom( self, buffersize: int, flags: int | None = None ) -> tuple[bytes, _RetAddress]: + """Receive data and the source address. + + Args: + buffersize: Number of bytes to receive + flags: Flags (unused) + + Returns: + Tuple of (data, source_address) + """ return self.recv(buffersize, flags), self._address def recv(self, buffersize: int, flags: int | None = None) -> bytes: + """Receive data from the socket. + + Args: + buffersize: Maximum number of bytes to receive + flags: Flags (unused) + + Returns: + Received bytes + + Raises: + BlockingIOError: If socket is non-blocking and no data available + """ r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) @@ -298,6 +550,19 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes: raise exc def true_sendall(self, data: bytes, *args: Any, **kwargs: Any) -> bytes: + """Send data through the real socket and receive response. + + Args: + data: Data to send + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + Response bytes from the real socket + + Raises: + StrictMocketException: If operation not allowed in STRICT mode + """ if not MocketMode.is_allowed(self._address): MocketMode.raise_not_allowed(self._address, data) @@ -344,7 +609,17 @@ def send( data: ReadableBuffer, *args: Any, **kwargs: Any, - ) -> int: # pragma: no cover + ) -> int: + """Send data through the socket. + + Args: + data: Data to send + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + Number of bytes sent + """ entry = self.get_entry(data) if not entry or (entry and self._entry != entry): kwargs["entry"] = entry @@ -357,7 +632,11 @@ def send( return len(data) def accept(self) -> tuple[MocketSocket, _RetAddress]: - """Accept a connection and return a new MocketSocket object.""" + """Accept a connection and return a new MocketSocket object. + + Returns: + Tuple of (new_socket, client_address) + """ new_socket = MocketSocket( family=self._family, type=self._type, @@ -369,11 +648,19 @@ def accept(self) -> tuple[MocketSocket, _RetAddress]: return new_socket, (self._host, self._port) def close(self) -> None: + """Close the socket and underlying true socket.""" if self._true_socket and not self._true_socket._closed: self._true_socket.close() def __getattr__(self, name: str) -> Any: - """Do nothing catchall function, for methods like shutdown()""" + """Do-nothing catchall function for methods like shutdown(). + + Args: + name: Method name + + Returns: + A callable that does nothing + """ def do_nothing(*args: Any, **kwargs: Any) -> Any: pass diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 6d5e7307..aeaab6b5 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,3 +1,5 @@ +"""Mocket SSL context implementation.""" + from __future__ import annotations from typing import Any @@ -7,10 +9,13 @@ class _MocketSSLContext: - """For Python 3.6 and newer.""" + """Mock SSL context for Python 3.6 and newer.""" class FakeSetter(int): + """Descriptor that ignores assignment.""" + def __set__(self, *args: Any) -> None: + """Ignore any assignment attempts.""" pass minimum_version = FakeSetter() @@ -20,29 +25,49 @@ def __set__(self, *args: Any) -> None: class MocketSSLContext(_MocketSSLContext): - DUMMY_METHODS = ( + """Mock SSL context that wraps sockets in MocketSSLSocket.""" + + DUMMY_METHODS: tuple = ( "load_default_certs", "load_verify_locations", "set_alpn_protocols", "set_ciphers", "set_default_verify_paths", ) - sock = None - post_handshake_auth = None - _check_hostname = False + sock: MocketSocket | None = None + post_handshake_auth: bool | None = None + _check_hostname: bool = False @property def check_hostname(self) -> bool: + """Get the check_hostname setting. + + Returns: + Always False (mock implementation) + """ return self._check_hostname @check_hostname.setter def check_hostname(self, _: bool) -> None: + """Set the check_hostname setting (mocked). + + Args: + _: Value (ignored, always set to False) + """ self._check_hostname = False def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the SSL context. + + Args: + *args: Positional arguments (ignored) + **kwargs: Keyword arguments (ignored) + """ self._set_dummy_methods() def _set_dummy_methods(self) -> None: + """Set all dummy methods that do nothing.""" + def dummy_method(*args: Any, **kwargs: Any) -> Any: pass @@ -55,15 +80,36 @@ def wrap_socket( *args: Any, **kwargs: Any, ) -> MocketSSLSocket: + """Wrap a socket in an SSL socket. + + Args: + sock: Socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ return MocketSSLSocket._create(sock, *args, **kwargs) def wrap_bio( self, - incoming: Any, # _ssl.MemoryBIO - outgoing: Any, # _ssl.MemoryBIO + incoming: Any, + outgoing: Any, server_side: bool = False, server_hostname: str | bytes | None = None, ) -> MocketSSLSocket: + """Wrap BIO objects in an SSL socket (mock implementation). + + Args: + incoming: Incoming BIO (_ssl.MemoryBIO) + outgoing: Outgoing BIO (_ssl.MemoryBIO) + server_side: Whether this is server side + server_hostname: Server hostname + + Returns: + MocketSSLSocket instance + """ ssl_obj = MocketSSLSocket() ssl_obj._host = server_hostname return ssl_obj @@ -74,5 +120,15 @@ def mock_wrap_socket( *args: Any, **kwargs: Any, ) -> MocketSSLSocket: + """Mock ssl.wrap_socket function. + + Args: + sock: Socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ context = MocketSSLContext() return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py index 6dcd7817..94984fce 100644 --- a/mocket/ssl/socket.py +++ b/mocket/ssl/socket.py @@ -1,3 +1,5 @@ +"""Mocket SSL socket implementation.""" + from __future__ import annotations import ssl @@ -12,14 +14,33 @@ class MocketSSLSocket(MocketSocket): + """Mock SSL socket that extends MocketSocket with SSL-specific behavior.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize an SSL socket. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + """ super().__init__(*args, **kwargs) - self._did_handshake = False - self._sent_non_empty_bytes = False + self._did_handshake: bool = False + self._sent_non_empty_bytes: bool = False self._original_socket: MocketSocket = self def read(self, buffersize: int | None = None) -> bytes: + """Read data from the SSL socket. + + Args: + buffersize: Maximum bytes to read + + Returns: + Bytes read from the socket + + Raises: + ssl.SSLWantReadError: If handshake not completed and no data + """ rv = self.io.read(buffersize) if rv: self._sent_non_empty_bytes = True @@ -28,12 +49,29 @@ def read(self, buffersize: int | None = None) -> bytes: return rv def write(self, data: bytes) -> int | None: + """Write data to the SSL socket. + + Args: + data: Bytes to write + + Returns: + Number of bytes written + """ return self.send(encode_to_bytes(data)) def do_handshake(self) -> None: + """Perform SSL handshake (mock implementation).""" self._did_handshake = True def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: + """Get the peer certificate (mock implementation). + + Args: + binary_form: Whether to return binary form (unused) + + Returns: + Mock certificate dictionary + """ if not (self._host and self._port): self._address = self._host, self._port = Mocket._address @@ -54,12 +92,27 @@ def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: } def ciper(self) -> tuple[str, str, str]: + """Get cipher information (mock implementation). + + Returns: + Tuple of (cipher_name, protocol, key_exchange_algorithm) + """ return "ADH", "AES256", "SHA" def compression(self) -> Options: + """Get compression options (mock implementation). + + Returns: + SSL options constant + """ return ssl.OP_NO_COMPRESSION def unwrap(self) -> MocketSocket: + """Unwrap the SSL socket and return the underlying socket. + + Returns: + The original MocketSocket + """ return self._original_socket @classmethod @@ -71,6 +124,18 @@ def _create( *args: Any, **kwargs: Any, ) -> MocketSSLSocket: + """Create an SSL socket from a regular socket. + + Args: + sock: Socket to wrap + ssl_context: SSL context (optional) + server_hostname: Server hostname + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + New MocketSSLSocket instance + """ ssl_socket = MocketSSLSocket() ssl_socket._original_socket = sock ssl_socket._true_socket = sock._true_socket diff --git a/mocket/types.py b/mocket/types.py index 562648c7..fedfd37f 100644 --- a/mocket/types.py +++ b/mocket/types.py @@ -1,3 +1,5 @@ +"""Type aliases and definitions for Mocket.""" + from __future__ import annotations from typing import Any, Dict, Tuple, Union diff --git a/mocket/urllib3.py b/mocket/urllib3.py index e89bc7b5..872efc5f 100644 --- a/mocket/urllib3.py +++ b/mocket/urllib3.py @@ -1,3 +1,5 @@ +"""Urllib3 specific socket mocking.""" + from __future__ import annotations from typing import Any @@ -8,6 +10,14 @@ def mock_match_hostname(*args: Any) -> None: + """Mock urllib3's match_hostname function. + + Args: + *args: Ignored arguments + + Returns: + None + """ return None @@ -16,5 +26,15 @@ def mock_ssl_wrap_socket( *args: Any, **kwargs: Any, ) -> MocketSSLSocket: + """Mock urllib3's ssl_wrap_socket function. + + Args: + sock: The socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ context = MocketSSLContext() return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/utils.py b/mocket/utils.py index 6180ae3f..749b2b70 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,3 +1,5 @@ +"""Utility functions for Mocket.""" + from __future__ import annotations import binascii @@ -14,12 +16,13 @@ class MocketizeDecorator(Protocol): - """ + """Protocol for a flexible decorator that can be used in multiple ways. + This is a generic decorator signature, currently applicable to get_mocketize. - Decorators can be used as: + Decorators implementing this protocol can be used as: 1. A function that transforms func (the parameter) into func1 (the returned object). - 2. A function that takes keyword arguments and returns 1. + 2. A function that takes keyword arguments and returns a decorator. """ @overload @@ -32,18 +35,37 @@ def __call__( def hexdump(binary_string: bytes) -> str: - r""" - >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) - True + """Convert binary data to space-separated hex string. + + Args: + binary_string: Binary data to convert + + Returns: + Space-separated hexadecimal representation + + Example: + >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) + True """ bs = decode_from_bytes(binascii.hexlify(binary_string).upper()) return " ".join(a + b for a, b in zip(bs[::2], bs[1::2])) def hexload(string: str) -> bytes: - r""" - >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo") - True + """Convert space-separated hex string to binary data. + + Args: + string: Space-separated hexadecimal string + + Returns: + Binary data + + Raises: + ValueError: If the hex string is invalid + + Example: + >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo") + True """ string_no_spaces = "".join(string.split()) try: @@ -53,6 +75,18 @@ def hexload(string: str) -> bytes: def get_mocketize(wrapper_: Callable) -> MocketizeDecorator: + """Get a mocketize decorator from a wrapper function. + + Decorators can be used as: + 1. A function that transforms func (the parameter) into func1 (the returned object). + 2. A function that takes keyword arguments and returns 1. + + Args: + wrapper_: The wrapper function to convert to a decorator + + Returns: + A MocketizeDecorator instance that can be used as a flexible decorator + """ # trying to support different versions of `decorator` with contextlib.suppress(TypeError): return decorator.decorator(wrapper_, kwsyntax=True) # type: ignore[return-value, call-arg, unused-ignore] diff --git a/pyproject.toml b/pyproject.toml index a921e223..f20dbb93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,11 @@ build-backend = "hatchling.build" requires-python = ">=3.8" name = "mocket" description = "Socket Mock Framework - for all kinds of socket animals, web-clients included - with gevent/asyncio/SSL support" -readme = { file = "README.rst", content-type = "text/x-rst" } -license = { file = "LICENSE" } +readme = "README.rst" +license = "BSD-3-Clause" +license-files = [ + "LICENSE", +] authors = [{ name = "Giorgio Salluzzo", email = "giorgio.salluzzo@gmail.com" }] classifiers = [ "Development Status :: 6 - Mature", @@ -19,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Software Development", @@ -58,7 +62,7 @@ test = [ "fastapi", "aiohttp", "wait-for-it", - "mypy", + "mypy; platform_python_implementation!='PyPy'", "types-decorator", "types-requests", "trio", @@ -123,7 +127,7 @@ select = [ max-complexity = 8 [tool.mypy] -python_version = "3.8" +python_version = "3.13" files = [ "mocket/exceptions.py", "mocket/compat.py", diff --git a/scripts/patch_hosts.sh b/scripts/patch_hosts.sh index af7e453d..ec527d2d 100644 --- a/scripts/patch_hosts.sh +++ b/scripts/patch_hosts.sh @@ -1,5 +1,9 @@ -sudo grep -v httpbin.local /etc/hosts | sudo tee /etc/hosts.mocket -export CONTAINER_ID=$(docker compose ps -q proxy) -export CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' $CONTAINER_ID) -echo "$CONTAINER_IP httpbin.local" | sudo tee -a /etc/hosts.mocket -sudo mv /etc/hosts.mocket /etc/hosts +HOSTS=/etc/hosts +MOCKET_HOSTS=/etc/hosts.mocket +HTTPBIN_HOST=httpbin.local + +sudo grep -v ${HTTPBIN_HOST} ${HOSTS} | sudo tee ${MOCKET_HOSTS} +CONTAINER_ID=$(docker compose ps -q proxy) +CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' ${CONTAINER_ID}) +echo "${CONTAINER_IP} ${HTTPBIN_HOST}" | sudo tee -a ${MOCKET_HOSTS} +sudo mv ${MOCKET_HOSTS} ${HOSTS} diff --git a/tests/test_http.py b/tests/test_http.py index afa31185..3d3e5b8e 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -12,7 +12,7 @@ import requests from mocket import Mocket, Mocketizer, mocketize -from mocket.mockhttp import Entry, Response +from mocket.mocks.mockhttp import Entry, Response class HttpTestCase(TestCase): @@ -433,3 +433,52 @@ def test_suggestion_for_register_and_status(self): url, status=201, ) + + def test_invalid_config_key(self): + url = "http://foobar.com/path" + with self.assertRaises(KeyError): + Entry.register( + Entry.POST, + url, + Response(body='{"foo":"bar0"}', status=200), + invalid_key=True, + ) + + def test_add_trailing_slash(self): + url = "http://testme.org" + entry = Entry(url, "GET", [Response(body='{"foo":"bar0"}', status=200)]) + self.assertEqual(entry.path, "/") + + @mocketize + def test_mocket_with_no_path(self): + Entry.register(Entry.GET, "http://httpbin.local", Response(status=202)) + response = urlopen("http://httpbin.local/") + self.assertEqual(response.code, 202) + self.assertEqual(Mocket._entries[("httpbin.local", 80)][0].path, "/") + + @mocketize + def test_can_handle(self): + Entry.single_register( + Entry.POST, + "http://testme.org/foobar", + body=json.dumps({"message": "Spooky!"}), + match_querystring=False, + ) + Entry.single_register( + Entry.GET, + "http://testme.org/", + body=json.dumps({"message": "Gotcha!"}), + can_handle_fun=lambda p, q: p.endswith("/foobar") and "a" in q, + ) + Entry.single_register( + Entry.GET, + "http://testme.org/foobar", + body=json.dumps({"message": "Missed!"}), + match_querystring=False, + ) + response = requests.get("http://testme.org/foobar?a=1") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"message": "Gotcha!"}) + response = requests.get("http://testme.org/foobar?b=2") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"message": "Missed!"}) diff --git a/tests/test_https.py b/tests/test_https.py index f8c8549e..4685f4eb 100644 --- a/tests/test_https.py +++ b/tests/test_https.py @@ -7,7 +7,7 @@ import requests from mocket import Mocket, Mocketizer, mocketize -from mocket.mockhttp import Entry +from mocket.mockhttp import Entry # noqa - test retrocompatibility @pytest.fixture @@ -43,6 +43,7 @@ def test_json(response): @pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') +@pytest.mark.xfail(reason="Service down or blocking GitHub actions IPs") def test_truesendall_with_recording_https(url_to_mock): with tempfile.TemporaryDirectory() as temp_dir, Mocketizer( truesocket_recording_dir=temp_dir @@ -62,6 +63,7 @@ def test_truesendall_with_recording_https(url_to_mock): @pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') +@pytest.mark.xfail(reason="Service down or blocking GitHub actions IPs") def test_truesendall_after_mocket_session(url_to_mock): Mocket.enable() Mocket.disable() @@ -71,6 +73,7 @@ def test_truesendall_after_mocket_session(url_to_mock): @pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') +@pytest.mark.xfail(reason="Service down or blocking GitHub actions IPs") def test_real_request_session(url_to_mock): session = requests.Session() @@ -88,3 +91,24 @@ def test_raise_exception_from_single_register(): Entry.single_register(Entry.GET, url, exception=OSError()) with pytest.raises(requests.exceptions.ConnectionError): requests.get(url) + + +@mocketize +def test_can_handle(): + Entry.single_register( + Entry.GET, + "https://httpbin.org", + body=json.dumps({"message": "Nope... not this time!"}), + headers={"content-type": "application/json"}, + can_handle_fun=lambda path, qs_dict: path == "/ip" and qs_dict, + ) + Entry.single_register( + Entry.GET, + "https://httpbin.org", + body=json.dumps({"message": "There you go!"}), + headers={"content-type": "application/json"}, + can_handle_fun=lambda path, qs_dict: path == "/ip" and not qs_dict, + ) + resp = requests.get("https://httpbin.org/ip") + assert resp.status_code == 200 + assert resp.json() == {"message": "There you go!"} diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 889a7df8..add53de8 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -194,3 +194,22 @@ async def test_httpx_fixture(httpx_client): response = await client.get(url) assert response.json() == data + + +@pytest.mark.asyncio +async def test_httpx_fixture_with_can_handle_fun(httpx_client): + url = "https://foo.bar/barfoo" + data = {"message": "Gotcha!"} + + Entry.single_register( + Entry.GET, + "https://foo.bar", + body=json.dumps(data), + headers={"content-type": "application/json"}, + can_handle_fun=lambda p, q: p.endswith("foo"), + ) + + async with httpx_client as client: + response = await client.get(url) + + assert response.json() == data diff --git a/tests/test_pook.py b/tests/test_pook.py index 56721b5f..012fcdfb 100644 --- a/tests/test_pook.py +++ b/tests/test_pook.py @@ -3,7 +3,6 @@ with contextlib.suppress(ModuleNotFoundError): import pook import requests - from mocket.plugins.pook_mock_engine import MocketEngine pook.set_mock_engine(MocketEngine) diff --git a/tests/test_socket.py b/tests/test_socket.py index dad62a33..68e71aee 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,4 +1,6 @@ import socket +import struct +from unittest.mock import MagicMock import pytest @@ -126,3 +128,22 @@ def test_recvfrom_into(): assert nbytes == len(test_data) assert buf[:nbytes] == test_data assert addr == sock._address + + +def test_setsockopt_without_optlen(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._true_socket = MagicMock() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock._true_socket.setsockopt.assert_called_once_with( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 + ) + + +def test_setsockopt_with_optlen(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._true_socket = MagicMock() + linger_value = struct.pack("ii", 1, 5) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value)) + sock._true_socket.setsockopt.assert_called_once_with( + socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value) + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index d3b5eba7..a791d136 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,7 @@ import decorator -from mocket.utils import get_mocketize +from mocket.utils import get_mocketize, hexdump, hexload def mock_decorator(func: Callable[[], None]) -> None: @@ -29,3 +29,27 @@ def test_get_mocketize_without_kwsyntax(self, dec: NonCallableMock) -> None: dec.call_args_list[0].assert_compare_to((mock_decorator,), {"kwsyntax": True}) # Second time without kwsyntax, which succeeds dec.call_args_list[1].assert_compare_to((mock_decorator,)) + + +class HexdumpTestCase(TestCase): + def test_hexdump_converts_bytes_to_spaced_hex(self) -> None: + assert hexdump(b"Hi") == "48 69" + + def test_hexdump_empty_bytes(self) -> None: + assert hexdump(b"") == "" + + def test_hexdump_roundtrip_with_hexload(self) -> None: + data = b"bar foobar foo" + assert hexload(hexdump(data)) == data + + +class HexloadTestCase(TestCase): + def test_hexload_converts_spaced_hex_to_bytes(self) -> None: + assert hexload("48 69") == b"Hi" + + def test_hexload_empty_string(self) -> None: + assert hexload("") == b"" + + def test_hexload_invalid_hex_raises_value_error(self) -> None: + with self.assertRaises(ValueError): + hexload("ZZ ZZ")