diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 81c07712..3a845672 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -4,17 +4,16 @@ on: [push, pull_request] jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install libolm @@ -47,17 +46,17 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: "3.14" - uses: isort/isort-action@master with: sortPaths: "./mautrix" - uses: psf/black@stable with: src: "./mautrix" - version: "22.3.0" + version: "26.3.1" - name: pre-commit run: | pip install pre-commit diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8f3efe97..b0d7ab3f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,6 +1,6 @@ build docs builder: stage: build - image: docker:stable + image: docker:latest tags: - amd64 only: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 56919be3..66065033 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -8,13 +8,13 @@ repos: - id: check-yaml - id: check-added-large-files - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 26.3.1 hooks: - id: black language_version: python3 files: ^mautrix/.*\.pyi?$ - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 8.0.1 hooks: - id: isort files: ^mautrix/.*\.pyi?$ diff --git a/CHANGELOG.md b/CHANGELOG.md index 70c15c8f..c7179c86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,223 @@ -## unreleased +## v0.21.0 (2025-11-17) + +* *(event)* Added support for creator power in room v12+. +* *(crypto)* Added support for generating and using recovery keys for verifying + the active device. +* *(bridge)* Added config option for self-signing bot device. +* *(bridge)* Removed check for login flows when using MSC4190 + (thanks to [@meson800] in [#178]). +* *(client)* Changed `set_displayname` and `set_avatar_url` to avoid setting + empty strings if the value is already unset (thanks to [@frebib] in [#171]). + +[@frebib]: https://github.com/frebib +[@meson800]: https://github.com/meson800 +[#171]: https://github.com/mautrix/python/pull/171 +[#178]: https://github.com/mautrix/python/pull/178 + +## v0.20.8 (2025-06-01) + +* *(bridge)* Added support for [MSC4190] (thanks to [@surakin] in [#175]). +* *(appservice)* Renamed `push_ephemeral` in generated registrations to + `receive_ephemeral` to match the accepted version of [MSC2409]. +* *(bridge)* Fixed compatibility with breaking change in aiohttp 3.12.6. +[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 +[@surakin]: https://github.com/surakin +[#175]: https://github.com/mautrix/python/pull/175 + +## v0.20.7 (2025-01-03) + +* *(types)* Removed support for generating reply fallbacks to implement + [MSC2781]. Stripping fallbacks is still supported. + +[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 + +## v0.20.6 (2024-07-12) + +* *(bridge)* Added `/register` call if `/versions` fails with `M_FORBIDDEN`. + +## v0.20.5 (2024-07-09) + +**Note:** The `bridge` module is deprecated as all bridges are being rewritten +in Go. See for more info. + +* *(client)* Added support for authenticated media downloads. +* *(bridge)* Stopped using cached homeserver URLs for double puppeting if one + is set in the config file. +* *(crypto)* Fixed error when checking OTK counts before uploading new keys. +* *(types)* Added MSC2530 (captions) fields to `MediaMessageEventContent`. + +## v0.20.4 (2024-01-09) + +* Dropped Python 3.9 support. +* *(client)* Changed media download methods to log requests and to raise + exceptions on non-successful status codes. + +## v0.20.3 (2023-11-10) + +* *(client)* Deprecated MSC2716 methods and added new Beeper-specific batch + send methods, as upstream MSC2716 support has been abandoned. +* *(util.async_db)* Added `PRAGMA synchronous = NORMAL;` to default pragmas. +* *(types)* Fixed `guest_can_join` field name in room directory response + (thanks to [@ashfame] in [#163]). + +[@ashfame]: https://github.com/ashfame +[#163]: https://github.com/mautrix/python/pull/163 + +## v0.20.2 (2023-09-09) + +* *(crypto)* Changed `OlmMachine.share_keys` to make the OTK count parameter + optional. When omitted, the count is fetched from the server. +* *(appservice)* Added option to run appservice transaction event handlers + synchronously. +* *(appservice)* Added `log` and `hs_token` parameters to `AppServiceServerMixin` + to allow using it as a standalone class without extending. +* *(api)* Added support for setting appservice `user_id` and `device_id` query + parameters manually without using `AppServiceAPI`. + +## v0.20.1 (2023-08-29) + +* *(util.program)* Removed `--base-config` flag in bridges, as there are no + valid use cases (package data should always work) and it's easy to cause + issues by pointing the flag at the wrong file. +* *(bridge)* Added support for the `com.devture.shared_secret_auth` login type + for automatic double puppeting. +* *(bridge)* Dropped support for syncing with double puppets. MSC2409 is now + the only way to receive ephemeral events. +* *(bridge)* Added support for double puppeting with arbitrary `as_token`s. + +## v0.20.0 (2023-06-25) + +* Dropped Python 3.8 support. +* **Breaking change *(.state_store)*** Removed legacy SQLAlchemy state store + implementations. +* **Mildly breaking change *(util.async_db)*** Changed `SQLiteDatabase` to not + remove prefix slashes from database paths. + * Library users should use `sqlite:path.db` instead of `sqlite:///path.db` + for relative paths, and `sqlite:/path.db` instead of `sqlite:////path.db` + for absolute paths. + * Bridge configs do this migration automatically. +* *(util.async_db)* Added warning log if using SQLite database path that isn't + writable. +* *(util.program)* Fixed `manual_stop` not working if it's called during startup. +* *(client)* Stabilized support for asynchronous uploads. + * `unstable_create_msc` was renamed to `create_mxc`, and the `max_stall_ms` + parameters for downloading were renamed to `timeout_ms`. +* *(crypto)* Added option to not rotate keys when devices change. +* *(crypto)* Added option to remove all keys that were received before the + automatic ratcheting was implemented (in v0.19.10). +* *(types)* Improved reply fallback removal to have a smaller chance of false + positives for messages that don't use reply fallbacks. + +## v0.19.16 (2023-05-26) + +* *(appservice)* Fixed Python 3.8 compatibility. + +## v0.19.15 (2023-05-24) + +* *(client)* Fixed dispatching room ephemeral events (i.e. typing notifications) in syncer. + +## v0.19.14 (2023-05-16) + +* *(bridge)* Implemented appservice pinging using MSC2659. +* *(bridge)* Started reusing aiosqlite connection pool for crypto db. + * This fixes the crypto pool getting stuck if the bridge exits unexpectedly + (the default pool is closed automatically at any type of exit). + +## v0.19.13 (2023-04-24) + +* *(crypto)* Fixed bug with redacting megolm sessions when device is deleted. + +## v0.19.12 (2023-04-18) + +* *(bridge)* Fixed backwards-compatibility with new key deletion config options. + +## v0.19.11 (2023-04-14) + +* *(crypto)* Fixed bug in previous release which caused errors if the `max_age` + of a megolm session was not known. +* *(crypto)* Changed key receiving handler to fetch encryption config from + server if it's not cached locally (to find `max_age` and `max_messages` more + reliably). + +## v0.19.10 (2023-04-13) + +* *(crypto, bridge)* Added options to automatically ratchet/delete megolm + sessions to minimize access to old messages. + +## v0.19.9 (2023-04-12) + +* *(crypto)* Fixed bug in crypto store migration when using outbound sessions + with max age higher than usual. + +## v0.19.8 (2023-04-06) + +* *(crypto)* Updated crypto store schema to match mautrix-go. +* *(types)* Fixed `set_thread_parent` adding reply fallbacks to the message body. + +## v0.19.7 (2023-03-22) + +* *(bridge, crypto)* Fixed key sharing trust checker not resolving cross-signing + signatures when minimum trust level is set to cross-signed. + +## v0.19.6 (2023-03-13) + +* *(crypto)* Added cache checks to prevent invalidating group session when the + server sends a duplicate member event in /sync. +* *(util.proxy)* Fixed `min_wait_seconds` behavior and added `max_wait_seconds` + and `multiply_wait_seconds` to `proxy_with_retry`. + +## v0.19.5 (2023-03-07) + +* *(util.proxy)* Added utility for dynamic proxies (from mautrix-instagram/facebook). +* *(types)* Added default value for `upload_size` in `MediaRepoConfig` as the + field is optional in the spec. +* *(bridge)* Changed ghost invite handling to only process one per room at a time + (thanks to [@maltee1] in [#132]). + +[#132]: https://github.com/mautrix/python/pull/132 + +## v0.19.4 (2023-02-12) + +* *(types)* Changed `set_thread_parent` to inherit the existing thread parent + if a `MessageEvent` is passed, as starting threads from a message in a thread + is not allowed. +* *(util.background_task)* Added new utility for creating background tasks + safely, by ensuring that the task is not garbage collected before finishing + and logging uncaught exceptions immediately. + +## v0.19.3 (2023-01-27) + +* *(bridge)* Bumped default timeouts for decrypting incoming messages. + +## v0.19.2 (2023-01-14) + +* *(util.async_body)* Added utility for reading aiohttp response into a bytearray + (so that the output is mutable, e.g. for decrypting or encrypting media). +* *(client.api)* Fixed retry loop for MSC3870 URL uploads not exiting properly + after too many errors. + +## v0.19.1 (2023-01-11) + +* Marked Python 3.11 as supported. Python 3.8 support will likely be dropped in + the coming months. +* *(client.api)* Added request payload memory optimization to MSC3870 URL uploads. + * aiohttp will duplicate the entire request body if it's raw bytes, which + wastes a lot of memory. The optimization is passing an iterator instead of + raw bytes, so aiohttp won't accidentally duplicate the whole thing. + * The main `HTTPAPI` has had the optimization for a while, but uploading to + URL calls aiohttp manually. + +## v0.19.0 (2023-01-10) + +* **Breaking change *(appservice)*** Removed typing status from state store. +* **Breaking change *(appservice)*** Removed `is_typing` parameter from + `IntentAPI.set_typing` to make the signature match `ClientAPI.set_typing`. + `timeout=0` is equivalent to the old `is_typing=False`. +* **Breaking change *(types)*** Removed legacy fields in Beeper MSS events. * *(bridge)* Removed accidentally nested reply loop when accepting invites as the bridge bot. * *(bridge)* Fixed decoding JSON values in config override env vars. -* *(appservice)* Removed typing status from state store. - * Additionally, the `is_typing` boolean in `set_typing` is now deprecated, - and `timeout=0` should be used instead to match the `ClientAPI` behavior. ## v0.18.9 (2022-12-14) diff --git a/README.rst b/README.rst index 75ce8f6a..f1342c19 100644 --- a/README.rst +++ b/README.rst @@ -3,7 +3,7 @@ mautrix-python |PyPI| |Python versions| |License| |Docs| |Code style| |Imports| -A Python 3.8+ asyncio Matrix framework. +A Python 3.10+ asyncio Matrix framework. Matrix room: `#maunium:maunium.net`_ @@ -49,7 +49,7 @@ Components .. _#maunium:maunium.net: https://matrix.to/#/#maunium:maunium.net .. _python-appservice-framework: https://github.com/Cadair/python-appservice-framework/ -.. _Client API: https://matrix.org/docs/spec/client_server/r0.6.1.html +.. _Client API: https://spec.matrix.org/latest/client-server-api/ .. _mautrix.api: https://docs.mau.fi/python/latest/api/mautrix.api.html .. _mautrix.client.api: https://docs.mau.fi/python/latest/api/mautrix.client.api.html diff --git a/dev-requirements.txt b/dev-requirements.txt index e513c0df..dff6ee9d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ pre-commit>=2.10.1,<3 -isort>=5.10.1,<6 -black>=22.3,<23 +isort>=8,<9 +black>=26,<27 diff --git a/docs/api/mautrix.client.state_store/index.rst b/docs/api/mautrix.client.state_store/index.rst index 91f2ba42..0a934308 100644 --- a/docs/api/mautrix.client.state_store/index.rst +++ b/docs/api/mautrix.client.state_store/index.rst @@ -13,5 +13,4 @@ Implementations In-memory Async database (asyncpg/aiosqlite) - Legacy database (SQLAlchemy) Flat file diff --git a/docs/api/mautrix.client.state_store/sqlalchemy.rst b/docs/api/mautrix.client.state_store/sqlalchemy.rst deleted file mode 100644 index 767e0595..00000000 --- a/docs/api/mautrix.client.state_store/sqlalchemy.rst +++ /dev/null @@ -1,5 +0,0 @@ -mautrix.client.state\_store.sqlalchemy -====================================== - -.. autoclass:: mautrix.client.state_store.sqlalchemy.SQLStateStore - :no-undoc-members: diff --git a/docs/requirements.txt b/docs/requirements.txt index 798e0f1b..7269d3a5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,7 +9,7 @@ yarl # that aren't used for anything that's in the docs python-magic ruamel.yaml -SQLAlchemy +SQLAlchemy<2 commonmark asyncpg aiosqlite @@ -17,3 +17,4 @@ prometheus_client python-olm unpaddedbase64 pycryptodome +base58 diff --git a/mautrix/__init__.py b/mautrix/__init__.py index ae628d89..c94995f9 100644 --- a/mautrix/__init__.py +++ b/mautrix/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.18.9" +__version__ = "0.21.0" __author__ = "Tulir Asokan " __all__ = [ "api", diff --git a/mautrix/api.py b/mautrix/api.py index a0561d08..1adde9ec 100644 --- a/mautrix/api.py +++ b/mautrix/api.py @@ -5,7 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import AsyncGenerator, ClassVar, Literal, Mapping, Union +from typing import ClassVar, Literal, Mapping from enum import Enum from json.decoder import JSONDecodeError from urllib.parse import quote as urllib_quote, urljoin as urllib_join @@ -22,12 +22,13 @@ from mautrix import __optional_imports__, __version__ as mautrix_version from mautrix.errors import MatrixConnectionError, MatrixRequestError, make_request_error +from mautrix.util.async_body import AsyncBody, async_iter_bytes from mautrix.util.logging import TraceLogger from mautrix.util.opt_prometheus import Counter if __optional_imports__: # Safe to import, but it's not actually needed, so don't force-import the whole types module. - from mautrix.types import JSON + from mautrix.types import JSON, DeviceID, UserID API_CALLS = Counter( name="bridge_matrix_api_calls", @@ -155,7 +156,6 @@ def replace(self, find: str, replace: str) -> PathBuilder: """ _req_id = 0 -AsyncBody = AsyncGenerator[Union[bytes, bytearray, memoryview], None] def _next_global_req_id() -> int: @@ -164,12 +164,6 @@ def _next_global_req_id() -> int: return _req_id -async def _async_iter_bytes(data: bytearray | bytes, chunk_size: int = 1024**2) -> AsyncBody: - with memoryview(data) as mv: - for i in range(0, len(data), chunk_size): - yield mv[i : i + chunk_size] - - class HTTPAPI: """HTTPAPI is a simple asyncio Matrix API request sender.""" @@ -199,6 +193,13 @@ class HTTPAPI: default_retry_count: int """The default retry count to use if a custom value is not passed to :meth:`request`""" + as_user_id: UserID | None + """An optional user ID to set as the user_id query parameter for appservice requests.""" + as_device_id: DeviceID | None + """ + An optional device ID to set as the user_id query parameter for appservice requests (MSC3202). + """ + def __init__( self, base_url: URL | str, @@ -209,6 +210,8 @@ def __init__( txn_id: int = 0, log: TraceLogger | None = None, loop: asyncio.AbstractEventLoop | None = None, + as_user_id: UserID | None = None, + as_device_id: UserID | None = None, ) -> None: """ Args: @@ -218,6 +221,10 @@ def __init__( txn_id: The outgoing transaction ID to start with. log: The :class:`logging.Logger` instance to log requests with. default_retry_count: Default number of retries to do when encountering network errors. + as_user_id: An optional user ID to set as the user_id query parameter for + appservice requests. + as_device_id: An optional device ID to set as the user_id query parameter for + appservice requests (MSC3202). """ self.base_url = URL(base_url) self.token = token @@ -225,6 +232,8 @@ def __init__( self.session = client_session or ClientSession( loop=loop, headers={"User-Agent": self.default_ua} ) + self.as_user_id = as_user_id + self.as_device_id = as_device_id if txn_id is not None: self.txn_id = txn_id if default_retry_count is not None: @@ -266,7 +275,7 @@ def _log_request( self, method: Method, url: URL, - content: str | bytes | bytearray | AsyncBody, + content: str | bytes | bytearray | AsyncBody | None, orig_content, query_params: dict[str, str], headers: dict[str, str], @@ -305,7 +314,7 @@ def _log_request( ) def _log_request_done( - self, path: PathBuilder, req_id: int, duration: float, status: int + self, path: PathBuilder | str, req_id: int, duration: float, status: int ) -> None: level = 5 if path == Path.v3.sync else 10 duration_str = f"{duration * 1000:.1f}ms" if duration < 1 else f"{duration:.3f}s" @@ -325,6 +334,16 @@ def _full_path(self, path: PathBuilder | str) -> str: base_path += "/" return urllib_join(base_path, path) + def log_download_request(self, url: URL, query_params: dict[str, str]) -> int: + req_id = _next_global_req_id() + self._log_request(Method.GET, url, None, None, query_params, {}, req_id, False) + return req_id + + def log_download_request_done( + self, url: URL, req_id: int, duration: float, status: int + ) -> None: + self._log_request_done(url.path.removeprefix("/_matrix/media/"), req_id, duration, status) + async def request( self, method: Method, @@ -366,6 +385,11 @@ async def request( query_params = query_params or {} if isinstance(query_params, dict): query_params = {k: v for k, v in query_params.items() if v is not None} + if self.as_user_id: + query_params["user_id"] = self.as_user_id + if self.as_device_id: + query_params["org.matrix.msc3202.device_id"] = self.as_device_id + query_params["device_id"] = self.as_device_id if method != Method.GET: content = content or {} @@ -395,7 +419,7 @@ async def request( method, log_url, content, orig_content, query_params, headers, req_id, sensitive ) API_CALLS.labels(method=metrics_method).inc() - req_content = _async_iter_bytes(content) if do_fake_iter else content + req_content = async_iter_bytes(content) if do_fake_iter else content start = time.monotonic() try: resp_data, resp = await self._send( @@ -438,6 +462,7 @@ def get_download_url( mxc_uri: str, download_type: Literal["download", "thumbnail"] = "download", file_name: str | None = None, + authenticated: bool = False, ) -> URL: """ Get the full HTTP URL to download a ``mxc://`` URI. @@ -446,6 +471,7 @@ def get_download_url( mxc_uri: The MXC URI whose full URL to get. download_type: The type of download ("download" or "thumbnail"). file_name: Optionally, a file name to include in the download URL. + authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11. Returns: The full HTTP URL. @@ -461,7 +487,11 @@ def get_download_url( "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png" """ server_name, media_id = self.parse_mxc_uri(mxc_uri) - url = self.base_url / str(APIPath.MEDIA) / "v3" / download_type / server_name / media_id + if authenticated: + url = self.base_url / str(APIPath.CLIENT) / "v1" / "media" + else: + url = self.base_url / str(APIPath.MEDIA) / "v3" + url = url / download_type / server_name / media_id if file_name: url /= file_name return url diff --git a/mautrix/appservice/api/appservice.py b/mautrix/appservice/api/appservice.py index af6ae48d..6654ec20 100644 --- a/mautrix/appservice/api/appservice.py +++ b/mautrix/appservice/api/appservice.py @@ -51,6 +51,7 @@ def __init__( client_session: ClientSession = None, child: bool = False, real_user: bool = False, + real_user_as_token: bool = False, bridge_name: str | None = None, default_retry_count: int = None, loop: asyncio.AbstractEventLoop | None = None, @@ -66,6 +67,7 @@ def __init__( client_session: The aiohttp ClientSession to use. child: Whether or not this is instance is a child of another AppServiceAPI. real_user: Whether or not this is a real (non-appservice-managed) user. + real_user_as_token: Whether this real user is actually using another ``as_token``. bridge_name: The name of the bridge to put in the ``fi.mau.double_puppet_source`` field in outgoing message events sent through real users. """ @@ -85,6 +87,7 @@ def __init__( self._bot_intent = None self.state_store = state_store self.is_real_user = real_user + self.is_real_user_as_token = real_user_as_token self.bridge_name = bridge_name if not child: @@ -113,7 +116,9 @@ def user(self, user: UserID) -> ChildAppServiceAPI: self.children[user] = child return child - def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> AppServiceAPI: + def real_user( + self, mxid: UserID, token: str, base_url: URL | None = None, as_token: bool = False + ) -> AppServiceAPI: """ Get the AppServiceAPI for a real (non-appservice-managed) Matrix user. @@ -122,6 +127,8 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap token: The access token for the user. base_url: The base URL of the homeserver client-server API to use. Defaults to the appservice homeserver URL. + as_token: Whether the token is actually an as_token + (meaning the ``user_id`` query parameter needs to be used). Returns: The AppServiceAPI object for the user. @@ -136,6 +143,7 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap child = self.real_users[mxid] child.base_url = base_url or child.base_url child.token = token or child.token + child.is_real_user_as_token = as_token except KeyError: child = type(self)( base_url=base_url or self.base_url, @@ -145,6 +153,7 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap state_store=self.state_store, client_session=self.session, real_user=True, + real_user_as_token=as_token, bridge_name=self.bridge_name, default_retry_count=self.default_retry_count, ) @@ -163,7 +172,11 @@ def bot_intent(self) -> as_api.IntentAPI: return self._bot_intent def intent( - self, user: UserID = None, token: str | None = None, base_url: str | None = None + self, + user: UserID = None, + token: str | None = None, + base_url: str | None = None, + real_user_as_token: bool = False, ) -> as_api.IntentAPI: """ Get the intent API of a child user. @@ -173,6 +186,8 @@ def intent( token: The access token to use. Only applicable for non-appservice-managed users. base_url: The base URL of the homeserver client-server API to use. Only applicable for non-appservice users. Defaults to the appservice homeserver URL. + real_user_as_token: When providing a token, whether it's actually another as_token + (meaning the ``user_id`` query parameter needs to be used). Returns: The IntentAPI object for the given user. @@ -184,7 +199,10 @@ def intent( raise ValueError("Can't get child intent of real user") if token: return as_api.IntentAPI( - user, self.real_user(user, token, base_url), self.bot_intent(), self.state_store + user, + self.real_user(user, token, base_url, as_token=real_user_as_token), + self.bot_intent(), + self.state_store, ) return as_api.IntentAPI(user, self.user(user), self.bot_intent(), self.state_store) @@ -229,7 +247,7 @@ def request( if isinstance(timestamp, datetime): timestamp = int(timestamp.replace(tzinfo=timezone.utc).timestamp() * 1000) query_params["ts"] = timestamp - if not self.is_real_user: + if not self.is_real_user or self.is_real_user_as_token: query_params["user_id"] = self.identity or self.bot_mxid return super().request( diff --git a/mautrix/appservice/api/intent.py b/mautrix/appservice/api/intent.py index 92966313..626b34f1 100644 --- a/mautrix/appservice/api/intent.py +++ b/mautrix/appservice/api/intent.py @@ -25,6 +25,7 @@ BatchSendEvent, BatchSendResponse, BatchSendStateEvent, + BeeperBatchSendResponse, ContentURI, EventContent, EventID, @@ -40,7 +41,6 @@ RoomNameStateEventContent, RoomPinnedEventsStateEventContent, RoomTopicStateEventContent, - SerializableAttrs, StateEventContent, UserID, ) @@ -71,7 +71,8 @@ def quote(*args, **kwargs): ClientAPI.search_users, ClientAPI.set_displayname, ClientAPI.set_avatar_url, - ClientAPI.unstable_create_mxc, + ClientAPI.beeper_update_profile, + ClientAPI.create_mxc, ClientAPI.upload_media, ClientAPI.send_receipt, ClientAPI.set_fully_read_marker, @@ -117,6 +118,8 @@ def __init__( ) -> None: super().__init__(mxid=mxid, api=api, state_store=state_store) self.bot = bot + if bot is not None: + self.versions_cache = bot.versions_cache self.log = api.base_log.getChild("intent") for method in ENSURE_REGISTERED_METHODS: @@ -143,7 +146,11 @@ async def wrapper(*args, __self=self, __method=method, **kwargs): setattr(self, method.__name__, wrapper) def user( - self, user_id: UserID, token: str | None = None, base_url: str | None = None + self, + user_id: UserID, + token: str | None = None, + base_url: str | None = None, + as_token: bool = False, ) -> IntentAPI: """ Get the intent API for a specific user. @@ -156,15 +163,17 @@ def user( user_id: The Matrix ID of the user whose intent API to get. token: The access token to use for the Matrix ID. base_url: An optional URL to use for API requests. + as_token: Whether the provided token is actually another as_token + (meaning the ``user_id`` query parameter needs to be used). Returns: The IntentAPI for the given user. """ if not self.bot: - return self.api.intent(user_id, token, base_url) + return self.api.intent(user_id, token, base_url, real_user_as_token=as_token) else: self.log.warning("Called IntentAPI#user() of child intent object.") - return self.bot.api.intent(user_id, token, base_url) + return self.bot.api.intent(user_id, token, base_url, real_user_as_token=as_token) # region User actions @@ -182,7 +191,7 @@ async def set_presence( Args: presence: The online status of the user. status: The status message. - ignore_cache: Whether or not to set presence even if the cache says the presence is + ignore_cache: Whether to set presence even if the cache says the presence is already set to that value. """ await self.ensure_registered() @@ -248,7 +257,9 @@ async def invite_user( await self.state_store.joined(room_id, user_id) except MatrixRequestError as e: # TODO remove this once MSC3848 is released and minimum spec version is bumped - if e.errcode == "M_FORBIDDEN" and "is already in the room" in e.message: + if e.errcode == "M_FORBIDDEN" and ( + "already in the room" in e.message or "is already joined to room" in e.message + ): await self.state_store.joined(room_id, user_id) else: raise @@ -405,19 +416,10 @@ async def get_room_member_info( async def set_typing( self, room_id: RoomID, - is_typing: bool = True, - timeout: int = 5000, + timeout: int = 0, ) -> None: - """ - Args: - room_id: The ID of the room in which the user is typing. - is_typing: Whether the user is typing. - .. deprecated:: 0.18.10 - Use ``timeout=0`` instead of setting this flag. - timeout: The length of time in seconds to mark this user as typing. - """ await self.ensure_joined(room_id) - await super().set_typing(room_id, timeout if is_typing else 0) + await super().set_typing(room_id, timeout) async def error_and_leave( self, room_id: RoomID, text: str | None = None, html: str | None = None @@ -496,6 +498,14 @@ async def mark_read( ) self.state_store.set_read(room_id, self.mxid, event_id) + async def appservice_ping(self, appservice_id: str, txn_id: str | None = None) -> int: + resp = await self.api.request( + Method.POST, + Path.v1.appservice[appservice_id].ping, + content={"transaction_id": txn_id} if txn_id is not None else {}, + ) + return resp.get("duration_ms") or -1 + async def batch_send( self, room_id: RoomID, @@ -514,6 +524,9 @@ async def batch_send( .. versionadded:: v0.12.5 + .. deprecated:: v0.20.3 + MSC2716 was abandoned by upstream and Beeper has forked the endpoint. + Args: room_id: The room ID to send the events to. prev_event_id: The anchor event. The batch will be inserted immediately after this event. @@ -548,6 +561,52 @@ async def batch_send( ) return BatchSendResponse.deserialize(resp) + async def beeper_batch_send( + self, + room_id: RoomID, + events: Iterable[BatchSendEvent], + *, + forward: bool = False, + forward_if_no_messages: bool = False, + send_notification: bool = False, + mark_read_by: UserID | None = None, + ) -> BeeperBatchSendResponse: + """ + Send a batch of events into a room. Only for Beeper/hungryserv. + + .. versionadded:: v0.20.3 + + Args: + room_id: The room ID to send the events to. + events: The events to send. + forward: Send events to the end of the room instead of the beginning + forward_if_no_messages: Send events to the end of the room, but only if there are no + messages in the room. If there are messages, send the new messages to the beginning. + send_notification: Send a push notification for the new messages. + Only applies when sending to the end of the room. + mark_read_by: Send a read receipt from the given user ID atomically. + + Returns: + All the event IDs generated. + """ + body = { + "events": [evt.serialize() for evt in events], + } + if forward: + body["forward"] = forward + elif forward_if_no_messages: + body["forward_if_no_messages"] = forward_if_no_messages + if send_notification: + body["send_notification"] = send_notification + if mark_read_by: + body["mark_read_by"] = mark_read_by + resp = await self.api.request( + Method.POST, + Path.unstable["com.beeper.backfill"].rooms[room_id].batch_send, + content=body, + ) + return BeeperBatchSendResponse.deserialize(resp) + async def beeper_delete_room(self, room_id: RoomID) -> None: versions = await self.versions() if not versions.supports("com.beeper.room_yeeting"): @@ -651,6 +710,8 @@ async def _ensure_has_power_level_for( if not await self.state_store.has_power_levels_cached(room_id): # TODO add option to not try to fetch power levels from server await self.get_power_levels(room_id, ignore_cache=True, ensure_joined=False) + if not await self.state_store.has_create_cached(room_id): + await self.get_state_event(room_id, EventType.ROOM_CREATE, format="event") if not await self.state_store.has_power_level(room_id, self.mxid, event_type): # TODO implement something better raise IntentError( diff --git a/mautrix/appservice/appservice.py b/mautrix/appservice/appservice.py index 551eab5c..65e202ae 100644 --- a/mautrix/appservice/appservice.py +++ b/mautrix/appservice/appservice.py @@ -13,7 +13,7 @@ from aiohttp import web import aiohttp -from mautrix.types import JSON, RoomAlias, UserID +from mautrix.types import JSON, RoomAlias, UserID, VersionsResponse from mautrix.util.logging import TraceLogger from ..api import HTTPAPI @@ -194,3 +194,6 @@ async def _liveness_probe(self, _: web.Request) -> web.Response: async def _readiness_probe(self, _: web.Request) -> web.Response: return web.Response(status=200 if self.ready else 500, text="{}") + + async def ping_self(self, txn_id: str | None = None) -> int: + return await self.intent.appservice_ping(self.id, txn_id=txn_id) diff --git a/mautrix/appservice/as_handler.py b/mautrix/appservice/as_handler.py index 88715100..ec7e339f 100644 --- a/mautrix/appservice/as_handler.py +++ b/mautrix/appservice/as_handler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,7 +8,6 @@ from typing import Any, Awaitable, Callable from json import JSONDecodeError -import asyncio import json import logging @@ -27,17 +26,18 @@ SerializerError, UserID, ) +from mautrix.util import background_task HandlerFunc = Callable[[Event], Awaitable] class AppServiceServerMixin: - loop: asyncio.AbstractEventLoop log: logging.Logger hs_token: str ephemeral_events: bool encryption_events: bool + synchronous_handlers: bool query_user: Callable[[UserID], JSON] query_alias: Callable[[RoomAlias], JSON] @@ -48,7 +48,17 @@ class AppServiceServerMixin: otk_handler: Callable[[dict[UserID, dict[DeviceID, DeviceOTKCount]]], Awaitable] | None device_list_handler: Callable[[DeviceLists], Awaitable] | None - def __init__(self, ephemeral_events: bool = False, encryption_events: bool = False) -> None: + def __init__( + self, + ephemeral_events: bool = False, + encryption_events: bool = False, + log: logging.Logger | None = None, + hs_token: str | None = None, + ) -> None: + if log is not None: + self.log = log + if hs_token is not None: + self.hs_token = hs_token self.transactions = set() self.event_handlers = [] self.to_device_handler = None @@ -56,6 +66,7 @@ def __init__(self, ephemeral_events: bool = False, encryption_events: bool = Fal self.device_list_handler = None self.ephemeral_events = ephemeral_events self.encryption_events = encryption_events + self.synchronous_handlers = False async def default_query_handler(_): return None @@ -74,6 +85,7 @@ def register_routes(self, app: web.Application) -> None: ) app.router.add_route("GET", "/_matrix/app/v1/rooms/{alias}", self._http_query_alias) app.router.add_route("GET", "/_matrix/app/v1/users/{user_id}", self._http_query_user) + app.router.add_route("POST", "/_matrix/app/v1/ping", self._http_ping) def _check_token(self, request: web.Request) -> bool: try: @@ -81,10 +93,12 @@ def _check_token(self, request: web.Request) -> bool: except KeyError: try: token = request.headers["Authorization"].removeprefix("Bearer ") - except (KeyError, AttributeError): + except KeyError: + self.log.debug("No access_token nor Authorization header in request") return False if token != self.hs_token: + self.log.debug(f"Incorrect hs_token in request") return False return True @@ -127,6 +141,23 @@ async def _http_query_alias(self, request: web.Request) -> web.Response: return web.json_response({}, status=404) return web.json_response(response) + async def _http_ping(self, request: web.Request) -> web.Response: + if not self._check_token(request): + raise web.HTTPUnauthorized( + content_type="application/json", + text=json.dumps({"error": "Invalid auth token", "errcode": "M_UNKNOWN_TOKEN"}), + ) + try: + body = await request.json() + except JSONDecodeError: + raise web.HTTPBadRequest( + content_type="application/json", + text=json.dumps({"error": "Body is not JSON", "errcode": "M_NOT_JSON"}), + ) + txn_id = body.get("transaction_id") + self.log.info(f"Received ping from homeserver with transaction ID {txn_id}") + return web.json_response({}) + @staticmethod def _get_with_fallback( json: dict[str, Any], field: str, unstable_prefix: str, default: Any = None @@ -269,7 +300,7 @@ async def handle_transaction( else: try: await self.to_device_handler(td) - except: + except Exception: self.log.exception("Exception in Matrix to-device event handler") if device_lists and self.device_list_handler: try: @@ -287,7 +318,7 @@ async def handle_transaction( except SerializerError: self.log.exception("Failed to deserialize ephemeral event %s", raw_edu) else: - self.handle_matrix_event(edu, ephemeral=True) + await self.handle_matrix_event(edu, ephemeral=True) for raw_event in events: try: self._fix_prev_content(raw_event) @@ -295,10 +326,10 @@ async def handle_transaction( except SerializerError: self.log.exception("Failed to deserialize event %s", raw_event) else: - self.handle_matrix_event(event) + await self.handle_matrix_event(event) return {} - def handle_matrix_event(self, event: Event, ephemeral: bool = False) -> None: + async def handle_matrix_event(self, event: Event, ephemeral: bool = False) -> None: if ephemeral: event.type = event.type.with_class(EventType.Class.EPHEMERAL) elif getattr(event, "state_key", None) is not None: @@ -312,9 +343,12 @@ async def try_handle(handler_func: HandlerFunc): except Exception: self.log.exception("Exception in Matrix event handler") - for handler in self.event_handlers: - # TODO add option to handle events synchronously - asyncio.create_task(try_handle(handler)) + if self.synchronous_handlers: + for handler in self.event_handlers: + await handler(event) + else: + for handler in self.event_handlers: + background_task.create(try_handle(handler)) def matrix_event_handler(self, func: HandlerFunc) -> HandlerFunc: self.event_handlers.append(func) diff --git a/mautrix/appservice/state_store/__init__.py b/mautrix/appservice/state_store/__init__.py index bc23b5f5..771ac252 100644 --- a/mautrix/appservice/state_store/__init__.py +++ b/mautrix/appservice/state_store/__init__.py @@ -1,4 +1,4 @@ from .file import FileASStateStore from .memory import ASStateStore -__all__ = ["FileASStateStore", "ASStateStore", "sqlalchemy", "asyncpg"] +__all__ = ["FileASStateStore", "ASStateStore", "asyncpg"] diff --git a/mautrix/appservice/state_store/sqlalchemy.py b/mautrix/appservice/state_store/sqlalchemy.py deleted file mode 100644 index 30c0b1fc..00000000 --- a/mautrix/appservice/state_store/sqlalchemy.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2022 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from mautrix.client.state_store.sqlalchemy import SQLStateStore as SQLClientStateStore - -from .memory import ASStateStore - - -class SQLASStateStore(SQLClientStateStore, ASStateStore): - def __init__(self) -> None: - SQLClientStateStore.__init__(self) - ASStateStore.__init__(self) diff --git a/mautrix/bridge/bridge.py b/mautrix/bridge/bridge.py index 7005d775..9ce9360e 100644 --- a/mautrix/bridge/bridge.py +++ b/mautrix/bridge/bridge.py @@ -8,6 +8,7 @@ from typing import Any from abc import ABC, abstractmethod from enum import Enum +import asyncio import sys from aiohttp import web @@ -59,6 +60,8 @@ class Bridge(Program, ABC): markdown_version: str manhole: br.commands.manhole.ManholeState | None homeserver_software: HomeserverSoftware + beeper_network_name: str | None = None + beeper_service_name: str | None = None def __init__( self, @@ -133,7 +136,7 @@ def prepare_config(self) -> None: self.config = self.config_class( self.args.config, self.args.registration, - self.args.base_config, + self.base_config_path, env_prefix=self.module.upper(), ) if self.args.generate_registration: @@ -244,6 +247,9 @@ async def start(self) -> None: "correct, and do they match the values in the registration?" ) sys.exit(16) + except Exception: + self.log.critical("Failed to check connection to homeserver", exc_info=True) + sys.exit(16) await self.matrix.init_encryption() self.add_startup_actions(self.matrix.init_as_bot()) @@ -253,12 +259,16 @@ async def start(self) -> None: status_endpoint = self.config["homeserver.status_endpoint"] if status_endpoint and await self.count_logged_in_users() == 0: state = BridgeState(state_event=BridgeStateEvent.UNCONFIGURED).fill() - await state.send(status_endpoint, self.az.as_token, self.log) + while not await state.send(status_endpoint, self.az.as_token, self.log): + await asyncio.sleep(5) async def system_exit(self) -> None: if hasattr(self, "db") and isinstance(self.db, Database): - self.log.trace("Stopping database due to SystemExit") + self.log.debug("Stopping database due to SystemExit") await self.db.stop() + self.log.debug("Database stopped") + elif getattr(self, "db", None): + self.log.trace("Database not started at SystemExit") async def stop(self) -> None: if self.manhole: diff --git a/mautrix/bridge/commands/login_matrix.py b/mautrix/bridge/commands/login_matrix.py index 2fd89d2d..6cc6f7aa 100644 --- a/mautrix/bridge/commands/login_matrix.py +++ b/mautrix/bridge/commands/login_matrix.py @@ -70,24 +70,3 @@ async def ping_matrix(evt: CommandEvent) -> EventID: except InvalidAccessToken: return await evt.reply("Your access token is invalid.") return await evt.reply("Your Matrix login is working.") - - -@command_handler( - needs_auth=True, - help_section=SECTION_AUTH, - help_text="Clear the Matrix sync token stored for your double puppet.", -) -async def clear_cache_matrix(evt: CommandEvent) -> EventID: - try: - puppet = await evt.sender.get_puppet() - except NotImplementedError: - return await evt.reply("This bridge has not implemented the clear-cache-matrix command") - if not puppet.is_real_user: - return await evt.reply("You are not logged in with your Matrix account.") - try: - puppet.stop() - puppet.next_batch = None - await puppet.start() - except InvalidAccessToken: - return await evt.reply("Your access token is invalid.") - return await evt.reply("Cleared cache successfully.") diff --git a/mautrix/bridge/config.py b/mautrix/bridge/config.py index 4289b178..defed222 100644 --- a/mautrix/bridge/config.py +++ b/mautrix/bridge/config.py @@ -115,7 +115,12 @@ def do_update(self, helper: ConfigUpdateHelper) -> None: copy("appservice.tls_cert") copy("appservice.tls_key") - copy("appservice.database") + if "appservice.database" in self and self["appservice.database"].startswith("sqlite:///"): + helper.base["appservice.database"] = self["appservice.database"].replace( + "sqlite:///", "sqlite:" + ) + else: + copy("appservice.database") copy("appservice.database_opts") copy("appservice.id") @@ -138,6 +143,16 @@ def do_update(self, helper: ConfigUpdateHelper) -> None: copy("bridge.encryption.default") copy("bridge.encryption.require") copy("bridge.encryption.appservice") + copy("bridge.encryption.msc4190") + copy("bridge.encryption.self_sign") + copy("bridge.encryption.delete_keys.delete_outbound_on_ack") + copy("bridge.encryption.delete_keys.dont_store_outbound") + copy("bridge.encryption.delete_keys.ratchet_on_decrypt") + copy("bridge.encryption.delete_keys.delete_fully_used_on_decrypt") + copy("bridge.encryption.delete_keys.delete_prev_on_new_session") + copy("bridge.encryption.delete_keys.delete_on_device_delete") + copy("bridge.encryption.delete_keys.periodically_delete_expired") + copy("bridge.encryption.delete_keys.delete_outdated_inbound") copy("bridge.encryption.verification_levels.receive") copy("bridge.encryption.verification_levels.send") copy("bridge.encryption.verification_levels.share") @@ -154,6 +169,7 @@ def do_update(self, helper: ConfigUpdateHelper) -> None: copy("bridge.encryption.rotation.enable_custom") copy("bridge.encryption.rotation.milliseconds") copy("bridge.encryption.rotation.messages") + copy("bridge.encryption.rotation.disable_device_change_key_rotation") copy("bridge.relay.enabled") copy_dict("bridge.relay.message_formats", override_existing_map=False) @@ -186,14 +202,18 @@ def namespaces(self) -> dict[str, list[dict[str, Any]]]: "regex": re.escape(f"@{username_format}:{homeserver}").replace(regex_ph, ".*"), } ], - "aliases": [ - { - "exclusive": True, - "regex": re.escape(f"#{alias_format}:{homeserver}").replace(regex_ph, ".*"), - } - ] - if alias_format - else [], + "aliases": ( + [ + { + "exclusive": True, + "regex": re.escape(f"#{alias_format}:{homeserver}").replace( + regex_ph, ".*" + ), + } + ] + if alias_format + else [] + ), } def generate_registration(self) -> None: @@ -222,4 +242,7 @@ def generate_registration(self) -> None: if self["appservice.ephemeral_events"]: self._registration["de.sorunome.msc2409.push_ephemeral"] = True - self._registration["push_ephemeral"] = True + self._registration["receive_ephemeral"] = True + + if self["bridge.encryption.msc4190"]: + self._registration["io.element.msc4190"] = True diff --git a/mautrix/bridge/crypto_state_store.py b/mautrix/bridge/crypto_state_store.py index 84169e2e..f08dec2c 100644 --- a/mautrix/bridge/crypto_state_store.py +++ b/mautrix/bridge/crypto_state_store.py @@ -28,29 +28,6 @@ async def is_encrypted(self, room_id: RoomID) -> bool: return portal.encrypted if portal else False -try: - from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile - - class SQLCryptoStateStore(BaseCryptoStateStore): - @staticmethod - async def find_shared_rooms(user_id: UserID) -> list[RoomID]: - return [profile.room_id for profile in UserProfile.find_rooms_with_user(user_id)] - - @staticmethod - async def get_encryption_info(room_id: RoomID) -> RoomEncryptionStateEventContent | None: - state = RoomState.get(room_id) - if not state: - return None - return state.encryption - -except ImportError: - if __optional_imports__: - raise - UserProfile = None - RoomState = None - SQLCryptoStateStore = None - - class PgCryptoStateStore(BaseCryptoStateStore): db: Database diff --git a/mautrix/bridge/custom_puppet.py b/mautrix/bridge/custom_puppet.py index 3d5c7f9e..f5befd5f 100644 --- a/mautrix/bridge/custom_puppet.py +++ b/mautrix/bridge/custom_puppet.py @@ -5,20 +5,16 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Awaitable, Iterator from abc import ABC, abstractmethod -from itertools import chain import asyncio import hashlib import hmac -import json import logging -from aiohttp import ClientConnectionError from yarl import URL -from mautrix.api import Path from mautrix.appservice import AppService, IntentAPI +from mautrix.client import ClientAPI from mautrix.errors import ( IntentError, MatrixError, @@ -26,21 +22,7 @@ MatrixRequestError, WellKnownError, ) -from mautrix.types import ( - Event, - EventFilter, - EventType, - Filter, - FilterID, - LoginType, - PresenceState, - RoomEventFilter, - RoomFilter, - RoomID, - StateFilter, - SyncToken, - UserID, -) +from mautrix.types import LoginType, MatrixUserIdentifier, RoomID, UserID from .. import bridge as br @@ -114,7 +96,6 @@ class CustomPuppetMixin(ABC): intent: The primary IntentAPI. """ - sync_with_custom_puppets: bool = True allow_discover_url: bool = False homeserver_url_map: dict[str, URL] = {} only_handle_own_synced_events: bool = True @@ -133,12 +114,9 @@ class CustomPuppetMixin(ABC): custom_mxid: UserID | None access_token: str | None base_url: URL | None - next_batch: SyncToken | None intent: IntentAPI - _sync_task: asyncio.Task | None = None - @abstractmethod async def save(self) -> None: """Save the information of this puppet. Called from :meth:`switch_mxid`""" @@ -154,6 +132,25 @@ def is_real_user(self) -> bool: return bool(self.custom_mxid and self.access_token) def _fresh_intent(self) -> IntentAPI: + if self.custom_mxid: + _, server = self.az.intent.parse_user_id(self.custom_mxid) + try: + self.base_url = self.homeserver_url_map[server] + except KeyError: + if server == self.az.domain: + self.base_url = self.az.intent.api.base_url + if self.access_token == "appservice-config" and self.custom_mxid: + try: + secret = self.login_shared_secret_map[server] + except KeyError: + raise AutologinError(f"No shared secret configured for {server}") + self.log.debug(f"Using as_token for double puppeting {self.custom_mxid}") + return self.az.intent.user( + self.custom_mxid, + secret.decode("utf-8").removeprefix("as_token:"), + self.base_url, + as_token=True, + ) return ( self.az.intent.user(self.custom_mxid, self.access_token, self.base_url) if self.is_real_user @@ -174,6 +171,8 @@ async def _login_with_shared_secret(cls, mxid: UserID) -> str: secret = cls.login_shared_secret_map[server] except KeyError: raise AutologinError(f"No shared secret configured for {server}") + if secret.startswith(b"as_token:"): + return "appservice-config" try: base_url = cls.homeserver_url_map[server] except KeyError: @@ -181,31 +180,32 @@ async def _login_with_shared_secret(cls, mxid: UserID) -> str: base_url = cls.az.intent.api.base_url else: raise AutologinError(f"No homeserver URL configured for {server}") - url = base_url / str(Path.v3.login) - headers = {"Content-Type": "application/json"} - login_req = { - "initial_device_display_name": cls.login_device_name, - "device_id": cls.login_device_name, - "identifier": { - "type": "m.id.user", - "user": mxid, - }, - } + client = ClientAPI(base_url=base_url) + login_args = {} if secret == b"appservice": - login_req["type"] = str(LoginType.APPSERVICE) - headers["Authorization"] = f"Bearer {cls.az.as_token}" + login_type = LoginType.APPSERVICE + client.api.token = cls.az.as_token else: - login_req["type"] = str(LoginType.PASSWORD) - login_req["password"] = hmac.new( - secret, mxid.encode("utf-8"), hashlib.sha512 - ).hexdigest() - resp = await cls.az.http_session.post(url, data=json.dumps(login_req), headers=headers) - data = await resp.json() - try: - return data["access_token"] - except KeyError: - error_msg = data.get("error", data.get("errcode", f"HTTP {resp.status}")) - raise AutologinError(f"Didn't get an access token: {error_msg}") from None + flows = await client.get_login_flows() + flow = flows.get_first_of_type(LoginType.DEVTURE_SHARED_SECRET, LoginType.PASSWORD) + if not flow: + raise AutologinError("No supported shared secret auth login flows") + login_type = flow.type + token = hmac.new(secret, mxid.encode("utf-8"), hashlib.sha512).hexdigest() + if login_type == LoginType.DEVTURE_SHARED_SECRET: + login_args["token"] = token + elif login_type == LoginType.PASSWORD: + login_args["password"] = token + resp = await client.login( + identifier=MatrixUserIdentifier(user=mxid), + device_id=cls.login_device_name, + initial_device_display_name=cls.login_device_name, + login_type=login_type, + **login_args, + store_access_token=False, + update_hs_url=False, + ) + return resp.access_token async def switch_mxid( self, access_token: str | None, mxid: UserID | None, start_sync_task: bool = True @@ -218,11 +218,11 @@ async def switch_mxid( the appservice-owned ID. mxid: The expected Matrix user ID of the custom account, or ``None`` when ``access_token`` is None. - start_sync_task: Whether or not syncing should be started after logging in. """ if access_token == "auto": access_token = await self._login_with_shared_secret(mxid) - self.log.debug(f"Logged in for {mxid} using shared secret") + if access_token != "appservice-config": + self.log.debug(f"Logged in for {mxid} using shared secret") if mxid is not None: _, mxid_domain = self.az.intent.parse_user_id(mxid) @@ -248,7 +248,7 @@ async def switch_mxid( self.base_url = base_url self.intent = self._fresh_intent() - await self.start(start_sync_task=start_sync_task, check_e2ee_keys=True) + await self.start(check_e2ee_keys=True) try: del self.by_custom_mxid[prev_mxid] @@ -273,7 +273,6 @@ async def _invalidate_double_puppet(self) -> None: del self.by_custom_mxid[self.custom_mxid] self.custom_mxid = None self.access_token = None - self.next_batch = None await self.save() self.intent = self._fresh_intent() @@ -293,7 +292,7 @@ async def start( except MatrixInvalidToken as e: if retry_auto_login and self.custom_mxid and self.can_auto_login(self.custom_mxid): self.log.debug(f"Got {e.errcode} while trying to initialize custom mxid") - await self.switch_mxid("auto", self.custom_mxid, start_sync_task=start_sync_task) + await self.switch_mxid("auto", self.custom_mxid) return self.log.warning(f"Got {e.errcode} while trying to initialize custom mxid") whoami = None @@ -316,19 +315,14 @@ async def start( if device_keys and len(device_keys.keys) > 0: await self._invalidate_double_puppet() raise EncryptionKeysFound() - if self.sync_with_custom_puppets and start_sync_task: - if self._sync_task: - self._sync_task.cancel() - self.log.info(f"Initialized custom mxid: {whoami.user_id}. Starting sync task") - self._sync_task = asyncio.create_task(self._try_sync()) - else: - self.log.info(f"Initialized custom mxid: {whoami.user_id}. Not starting sync task") + self.log.info(f"Initialized custom mxid: {whoami.user_id}") def stop(self) -> None: - """Cancel the sync task.""" - if self._sync_task: - self._sync_task.cancel() - self._sync_task = None + """ + No-op + + .. deprecated:: 0.20.1 + """ async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool: """ @@ -351,112 +345,3 @@ async def _leave_rooms_with_default_user(self) -> None: await self.intent.ensure_joined(room_id) except (IntentError, MatrixRequestError): pass - - def _create_sync_filter(self) -> Awaitable[FilterID]: - all_events = EventType.find("*") - return self.intent.create_filter( - Filter( - account_data=EventFilter(types=[all_events]), - presence=EventFilter( - types=[EventType.PRESENCE], - senders=[self.custom_mxid] if self.only_handle_own_synced_events else None, - ), - room=RoomFilter( - include_leave=False, - state=StateFilter(not_types=[all_events]), - timeline=RoomEventFilter(not_types=[all_events]), - account_data=RoomEventFilter(not_types=[all_events]), - ephemeral=RoomEventFilter( - types=[ - EventType.TYPING, - EventType.RECEIPT, - ] - ), - ), - ) - ) - - def _filter_events(self, room_id: RoomID, events: list[dict]) -> Iterator[Event]: - for event in events: - event["room_id"] = room_id - if self.only_handle_own_synced_events: - # We only want events about the custom puppet user, but we can't use - # filters for typing and read receipt events. - evt_type = EventType.find(event.get("type", None)) - event.setdefault("content", {}) - if evt_type == EventType.TYPING: - is_typing = self.custom_mxid in event["content"].get("user_ids", []) - event["content"]["user_ids"] = [self.custom_mxid] if is_typing else [] - elif evt_type == EventType.RECEIPT: - try: - event_id, receipt = event["content"].popitem() - data = receipt["m.read"][self.custom_mxid] - event["content"] = {event_id: {"m.read": {self.custom_mxid: data}}} - except KeyError: - continue - yield event - - def _handle_sync(self, sync_resp: dict) -> None: - # Get events from rooms -> join -> [room_id] -> ephemeral -> events (array) - ephemeral_events = ( - event - for room_id, data in sync_resp.get("rooms", {}).get("join", {}).items() - for event in self._filter_events(room_id, data.get("ephemeral", {}).get("events", [])) - ) - - # Get events from presence -> events (array) - presence_events = sync_resp.get("presence", {}).get("events", []) - - # Deserialize and handle all events - for event in chain(ephemeral_events, presence_events): - asyncio.create_task(self.mx.try_handle_sync_event(Event.deserialize(event))) - - async def _try_sync(self) -> None: - try: - await self._sync() - except asyncio.CancelledError: - self.log.info(f"Syncing for {self.custom_mxid} cancelled") - except Exception: - self.log.critical(f"Fatal error syncing {self.custom_mxid}", exc_info=True) - - async def _sync(self) -> None: - if not self.is_real_user: - self.log.warning("Called sync() for non-custom puppet.") - return - custom_mxid: UserID = self.custom_mxid - access_token_at_start: str = self.access_token - errors: int = 0 - filter_id: FilterID = await self._create_sync_filter() - self.log.debug(f"Starting syncer for {custom_mxid} with sync filter {filter_id}.") - while access_token_at_start == self.access_token: - try: - cur_batch = self.next_batch - sync_resp = await self.intent.sync( - filter_id=filter_id, since=cur_batch, set_presence=PresenceState.OFFLINE - ) - try: - self.next_batch = sync_resp.get("next_batch", None) - except Exception: - self.log.warning("Failed to store next batch", exc_info=True) - errors = 0 - if cur_batch is not None: - self._handle_sync(sync_resp) - except MatrixInvalidToken: - # TODO when not using syncing, we should still check this occasionally and relogin - self.log.warning(f"Access token for {custom_mxid} got invalidated, restarting...") - await self.start(retry_auto_login=True, start_sync_task=False) - if self.is_real_user: - self.log.info("Successfully relogined custom puppet, continuing sync") - filter_id = await self._create_sync_filter() - access_token_at_start = self.access_token - else: - self.log.warning("Something went wrong during relogin") - raise - except (MatrixError, ClientConnectionError, asyncio.TimeoutError) as e: - errors += 1 - wait = min(errors, 11) ** 2 - self.log.warning( - f"Syncer for {custom_mxid} errored: {e}. Waiting for {wait} seconds..." - ) - await asyncio.sleep(wait) - self.log.debug(f"Syncer for custom puppet {custom_mxid} stopped.") diff --git a/mautrix/bridge/e2ee.py b/mautrix/bridge/e2ee.py index 516e54bd..1525b388 100644 --- a/mautrix/bridge/e2ee.py +++ b/mautrix/bridge/e2ee.py @@ -13,7 +13,7 @@ from mautrix.appservice import AppService from mautrix.client import Client, InternalEventType, SyncStore from mautrix.crypto import CryptoStore, OlmMachine, PgCryptoStore, RejectKeyShare, StateStore -from mautrix.errors import EncryptionError, SessionNotFound +from mautrix.errors import EncryptionError, MForbidden, MNotFound, SessionNotFound from mautrix.types import ( JSON, DeviceIdentity, @@ -34,19 +34,13 @@ StateFilter, TrustState, ) +from mautrix.util import background_task from mautrix.util.async_db import Database from mautrix.util.logging import TraceLogger from .. import bridge as br from .crypto_state_store import PgCryptoStateStore -try: - from mautrix.client.state_store.sqlalchemy import UserProfile -except ImportError: - if __optional_imports__: - raise - UserProfile = None - class EncryptionManager: loop: asyncio.AbstractEventLoop @@ -61,6 +55,10 @@ class EncryptionManager: min_send_trust: TrustState key_sharing_enabled: bool appservice_mode: bool + periodically_delete_expired_keys: bool + delete_outdated_inbound: bool + msc4190: bool + self_sign: bool bridge: br.Bridge az: AppService @@ -68,6 +66,7 @@ class EncryptionManager: _id_suffix: str _share_session_events: dict[RoomID, asyncio.Event] + _key_delete_task: asyncio.Task | None def __init__( self, @@ -100,6 +99,7 @@ def __init__( sync_store=self.crypto_store, log=self.log.getChild("client"), default_retry_count=default_http_retry_count, + state_store=self.bridge.state_store, ) self.crypto = OlmMachine(self.client, self.crypto_store, self.state_store) self.client.add_event_handler(InternalEventType.SYNC_STOPPED, self._exit_on_sync_fail) @@ -110,11 +110,30 @@ def __init__( self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"]) self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"] self.appservice_mode = bridge.config["bridge.encryption.appservice"] + self.msc4190 = bridge.config["bridge.encryption.msc4190"] + self.self_sign = bridge.config["bridge.encryption.self_sign"] if self.appservice_mode: self.az.otk_handler = self.crypto.handle_as_otk_counts self.az.device_list_handler = self.crypto.handle_as_device_lists self.az.to_device_handler = self.crypto.handle_as_to_device_event + self.periodically_delete_expired_keys = False + self.delete_outdated_inbound = False + self._key_delete_task = None + del_cfg = bridge.config["bridge.encryption.delete_keys"] + if del_cfg: + self.crypto.delete_outbound_keys_on_ack = del_cfg["delete_outbound_on_ack"] + self.crypto.dont_store_outbound_keys = del_cfg["dont_store_outbound"] + self.crypto.delete_previous_keys_on_receive = del_cfg["delete_prev_on_new_session"] + self.crypto.ratchet_keys_on_decrypt = del_cfg["ratchet_on_decrypt"] + self.crypto.delete_fully_used_keys_on_decrypt = del_cfg["delete_fully_used_on_decrypt"] + self.crypto.delete_keys_on_device_delete = del_cfg["delete_on_device_delete"] + self.periodically_delete_expired_keys = del_cfg["periodically_delete_expired"] + self.delete_outdated_inbound = del_cfg["delete_outdated_inbound"] + self.crypto.disable_device_change_key_rotation = bridge.config[ + "bridge.encryption.rotation.disable_device_change_key_rotation" + ] + async def _exit_on_sync_fail(self, data) -> None: if data["error"]: self.log.critical("Exiting due to crypto sync error") @@ -132,9 +151,9 @@ async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInf f"Rejecting key request from blacklisted device " f"{device.user_id}/{device.device_id}", code=RoomKeyWithheldCode.BLACKLISTED, - reason="You have been blacklisted by this device", + reason="Your device has been blacklisted by the bridge", ) - elif device.trust >= self.crypto.share_keys_min_trust: + elif await self.crypto.resolve_trust(device) >= self.crypto.share_keys_min_trust: portal = await self.bridge.get_portal(request.room_id) if portal is None: raise RejectKeyShare( @@ -161,7 +180,7 @@ async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInf f"Rejecting key request from unverified device " f"{device.user_id}/{device.device_id}", code=RoomKeyWithheldCode.UNVERIFIED, - reason="You have not been verified by this device", + reason="Your device is not trusted by the bridge", ) def _ignore_user(self, user_id: str) -> bool: @@ -230,12 +249,13 @@ async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> M return decrypted async def start(self) -> None: - flows = await self.client.get_login_flows() - if not flows.supports_type(LoginType.APPSERVICE): - self.log.critical( - "Encryption enabled in config, but homeserver does not support appservice login" - ) - sys.exit(30) + if not self.msc4190: + flows = await self.client.get_login_flows() + if not flows.supports_type(LoginType.APPSERVICE): + self.log.critical( + "Encryption enabled in config, but homeserver does not support appservice login" + ) + sys.exit(30) self.log.debug("Logging in with bridge bot user") if self.crypto_db: try: @@ -246,27 +266,84 @@ async def start(self) -> None: device_id = await self.crypto_store.get_device_id() if device_id: self.log.debug(f"Found device ID in database: {device_id}") - # We set the API token to the AS token here to authenticate the appservice login - # It'll get overridden after the login - self.client.api.token = self.az.as_token - await self.client.login( - login_type=LoginType.APPSERVICE, - device_name=self.device_name, - device_id=device_id, - store_access_token=True, - update_hs_url=False, - ) + + if self.msc4190: + if not device_id: + self.log.debug("Creating bot device with MSC4190") + self.client.api.token = self.az.as_token + await self.client.create_device_msc4190( + device_id=device_id, initial_display_name=self.device_name + ) + else: + # We set the API token to the AS token here to authenticate the appservice login + # It'll get overridden after the login + self.client.api.token = self.az.as_token + await self.client.login( + login_type=LoginType.APPSERVICE, + device_name=self.device_name, + device_id=device_id, + store_access_token=True, + update_hs_url=False, + ) + await self.crypto.load() if not device_id: await self.crypto_store.put_device_id(self.client.device_id) self.log.debug(f"Logged in with new device ID {self.client.device_id}") + await self.crypto.share_keys() elif self.crypto.account.shared: await self._verify_keys_are_on_server() + else: + await self.crypto.share_keys() + if self.self_sign: + trust_state = await self.crypto.resolve_trust(self.crypto.own_identity) + if trust_state < TrustState.CROSS_SIGNED_UNTRUSTED: + recovery_key = await self.crypto.generate_recovery_key() + self.log.info(f"Generated recovery key and signed own device: {recovery_key}") + else: + self.log.debug(f"Own device is already verified ({trust_state})") if self.appservice_mode: self.log.info("End-to-bridge encryption support is enabled (appservice mode)") else: _ = self.client.start(self._filter) self.log.info("End-to-bridge encryption support is enabled (sync mode)") + if self.delete_outdated_inbound: + deleted = await self.crypto_store.redact_outdated_group_sessions() + if len(deleted) > 0: + self.log.debug( + f"Deleted {len(deleted)} inbound keys which lacked expiration metadata" + ) + if self.periodically_delete_expired_keys: + self._key_delete_task = background_task.create(self._periodically_delete_keys()) + background_task.create(self._resync_encryption_info()) + + async def _resync_encryption_info(self) -> None: + rows = await self.crypto_db.fetch( + """SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'""" + ) + room_ids = [row["room_id"] for row in rows] + if not room_ids: + return + self.log.debug(f"Resyncing encryption state event in rooms: {room_ids}") + for room_id in room_ids: + try: + evt = await self.client.get_state_event(room_id, EventType.ROOM_ENCRYPTION) + except (MNotFound, MForbidden) as e: + self.log.debug(f"Failed to get encryption state in {room_id}: {e}") + q = """ + UPDATE mx_room_state SET encryption=NULL + WHERE room_id=$1 AND encryption='{"resync":true}' + """ + await self.crypto_db.execute(q, room_id) + else: + self.log.debug(f"Resynced encryption state in {room_id}: {evt}") + q = """ + UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2 + WHERE room_id=$3 AND max_age IS NULL and max_messages IS NULL + """ + await self.crypto_db.execute( + q, evt.rotation_period_ms, evt.rotation_period_msgs, room_id + ) async def _verify_keys_are_on_server(self) -> None: self.log.debug("Making sure keys are still on server") @@ -289,6 +366,9 @@ async def _verify_keys_are_on_server(self) -> None: sys.exit(34) async def stop(self) -> None: + if self._key_delete_task: + self._key_delete_task.cancel() + self._key_delete_task = None self.client.stop() await self.crypto_store.close() if self.crypto_db: @@ -308,3 +388,12 @@ def _filter(self) -> Filter: ephemeral=RoomEventFilter(not_types=[all_events]), ), ) + + async def _periodically_delete_keys(self) -> None: + while True: + deleted = await self.crypto_store.redact_expired_group_sessions() + if deleted: + self.log.info(f"Deleted expired megolm sessions: {deleted}") + else: + self.log.debug("No expired megolm sessions found") + await asyncio.sleep(24 * 60 * 60) diff --git a/mautrix/bridge/matrix.py b/mautrix/bridge/matrix.py index 144264aa..e5399094 100644 --- a/mautrix/bridge/matrix.py +++ b/mautrix/bridge/matrix.py @@ -5,6 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +from collections import defaultdict import asyncio import logging import sys @@ -56,7 +57,7 @@ Version, VersionsResponse, ) -from mautrix.util import markdown +from mautrix.util import background_task, markdown from mautrix.util.logging import TraceLogger from mautrix.util.message_send_checkpoint import ( CHECKPOINT_TYPES, @@ -143,6 +144,7 @@ class BaseMatrixHandler: media_config: MediaRepoConfig versions: VersionsResponse minimum_spec_version: Version = SpecVersions.V11 + room_locks: dict[str, asyncio.Lock] user_id_prefix: str user_id_suffix: str @@ -159,6 +161,7 @@ def __init__( self.media_config = MediaRepoConfig(upload_size=50 * 1024 * 1024) self.versions = VersionsResponse.deserialize({"versions": ["v1.3"]}) self.az.matrix_event_handler(self.int_handle_event) + self.room_locks = defaultdict(asyncio.Lock) self.e2ee = None self.require_e2ee = False @@ -218,34 +221,46 @@ async def check_versions(self) -> None: async def wait_for_connection(self) -> None: self.log.info("Ensuring connectivity to homeserver") - errors = 0 - tried_to_register = False while True: try: self.versions = await self.az.intent.versions() - await self.check_versions() - await self.az.intent.whoami() break - except (MUnknownToken, MExclusive): - # These are probably not going to resolve themselves by waiting - raise except MForbidden: - if not tried_to_register: - self.log.debug( - "Whoami endpoint returned M_FORBIDDEN, " - "trying to register bridge bot before retrying..." - ) - await self.az.intent.ensure_registered() - tried_to_register = True - else: - raise + self.log.debug( + "/versions endpoint returned M_FORBIDDEN, " + "trying to register bridge bot before retrying..." + ) + await self.az.intent.ensure_registered() except Exception: - errors += 1 - if errors <= 6: - self.log.exception("Connection to homeserver failed, retrying in 10 seconds") - await asyncio.sleep(10) - else: - raise + self.log.exception("Connection to homeserver failed, retrying in 10 seconds") + await asyncio.sleep(10) + await self.check_versions() + try: + await self.az.intent.whoami() + except MForbidden: + self.log.debug( + "Whoami endpoint returned M_FORBIDDEN, " + "trying to register bridge bot before retrying..." + ) + await self.az.intent.ensure_registered() + await self.az.intent.whoami() + if self.versions.supports("fi.mau.msc2659.stable") or self.versions.supports_at_least( + SpecVersions.V17 + ): + try: + txn_id = self.az.intent.api.get_txn_id() + duration = await self.az.ping_self(txn_id) + self.log.debug( + "Homeserver->bridge connection works, " + f"roundtrip time is {duration} ms (txn ID: {txn_id})" + ) + except Exception: + self.log.exception("Error checking homeserver -> bridge connection") + sys.exit(16) + else: + self.log.debug( + "Homeserver does not support checking status of homeserver -> bridge connection" + ) try: self.media_config = await self.az.intent.get_media_repo_config() except Exception: @@ -269,6 +284,16 @@ async def init_as_bot(self) -> None: except Exception: self.log.exception("Failed to set bot avatar") + if self.bridge.homeserver_software.is_hungry and self.bridge.beeper_network_name: + self.log.debug("Setting contact info on the appservice bot") + await self.az.intent.beeper_update_profile( + { + "com.beeper.bridge.service": self.bridge.beeper_service_name, + "com.beeper.bridge.network": self.bridge.beeper_network_name, + "com.beeper.bridge.is_bridge_bot": True, + } + ) + async def init_encryption(self) -> None: if self.e2ee: await self.e2ee.start() @@ -382,7 +407,7 @@ async def handle_puppet_nonportal_invite( members = await intent.get_room_members(room_id) except MatrixError: self.log.exception(f"Failed to get state after joining {room_id} as {intent.mxid}") - asyncio.create_task(intent.leave_room(room_id, reason="Internal error")) + background_task.create(intent.leave_room(room_id, reason="Internal error")) return if create_evt.type == RoomType.SPACE: await self.handle_puppet_space_invite(room_id, puppet, invited_by, evt) @@ -400,19 +425,20 @@ async def handle_puppet_invite( await intent.leave_room(room_id, reason="You're not allowed to invite this ghost.") return - portal = await self.bridge.get_portal(room_id) - if portal: - try: - await portal.handle_matrix_invite(invited_by, puppet) - except br.RejectMatrixInvite as e: - await intent.leave_room(room_id, reason=e.message) - except br.IgnoreMatrixInvite: - pass + async with self.room_locks[room_id]: + portal = await self.bridge.get_portal(room_id) + if portal: + try: + await portal.handle_matrix_invite(invited_by, puppet) + except br.RejectMatrixInvite as e: + await intent.leave_room(room_id, reason=e.message) + except br.IgnoreMatrixInvite: + pass + else: + await intent.join_room(room_id) + return else: - await intent.join_room(room_id) - return - else: - await self.handle_puppet_nonportal_invite(room_id, puppet, invited_by, evt) + await self.handle_puppet_nonportal_invite(room_id, puppet, invited_by, evt) async def handle_invite( self, room_id: RoomID, user_id: UserID, invited_by: br.BaseUser, evt: StateEvent @@ -537,6 +563,28 @@ def is_command(self, message: MessageEventContent) -> tuple[bool, str]: text = text[len(prefix) + 1 :].lstrip() return is_command, text + async def _send_mss( + self, + evt: Event, + status: MessageStatus, + reason: MessageStatusReason | None = None, + error: str | None = None, + message: str | None = None, + ) -> None: + if not self.config.get("bridge.message_status_events", False): + return + status_content = BeeperMessageStatusEventContent( + network="", # TODO set network properly + relates_to=RelatesTo(rel_type=RelationType.REFERENCE, event_id=evt.event_id), + status=status, + reason=reason, + error=error, + message=message, + ) + await self.az.intent.send_message_event( + evt.room_id, EventType.BEEPER_MESSAGE_STATUS, status_content + ) + async def _send_crypto_status_error( self, evt: Event, @@ -574,19 +622,13 @@ async def _send_crypto_status_error( "it by default on the bridge (bridge -> encryption -> default)." ) - if self.config.get("bridge.message_status_events", False): - status_content = BeeperMessageStatusEventContent( - network="", # TODO set network properly - relates_to=RelatesTo(rel_type=RelationType.REFERENCE, event_id=evt.event_id), - status=MessageStatus.RETRIABLE if is_final else MessageStatus.PENDING, - reason=MessageStatusReason.UNDECRYPTABLE, - error=msg, - message=err.human_message if err else None, - ) - status_content.fill_legacy_booleans() - await self.az.intent.send_message_event( - evt.room_id, EventType.BEEPER_MESSAGE_STATUS, status_content - ) + await self._send_mss( + evt, + status=MessageStatus.RETRIABLE if is_final else MessageStatus.PENDING, + reason=MessageStatusReason.UNDECRYPTABLE, + error=str(err), + message=err.human_message if err else None, + ) return event_id @@ -678,6 +720,13 @@ async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) - except Exception as e: self.log.debug(f"Error handling command {command} from {sender}: {e}") self._send_message_checkpoint(evt, MessageSendCheckpointStep.COMMAND, e) + await self._send_mss( + evt, + status=MessageStatus.FAIL, + reason=MessageStatusReason.GENERIC_ERROR, + error="", + message="Command execution failed", + ) else: await MessageSendCheckpoint( event_id=event_id, @@ -693,6 +742,7 @@ async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) - self.az.as_token, self.log, ) + await self._send_mss(evt, status=MessageStatus.SUCCESS) else: self.log.debug( f"Ignoring event {event_id} from {sender.mxid}:" @@ -701,6 +751,13 @@ async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) - self._send_message_checkpoint( evt, MessageSendCheckpointStep.COMMAND, "not a command and not a portal room" ) + await self._send_mss( + evt, + status=MessageStatus.FAIL, + reason=MessageStatusReason.UNSUPPORTED, + error="Unknown room", + message="Unknown room", + ) async def _is_direct_chat(self, room_id: RoomID) -> tuple[bool, bool]: try: @@ -784,7 +841,7 @@ async def handle_encrypted(self, evt: EncryptedEvent) -> None: try: decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=3) except SessionNotFound as e: - await self._handle_encrypted_wait(evt, e, wait=6) + await self._handle_encrypted_wait(evt, e, wait=22) except DecryptionError as e: self.log.warning(f"Failed to decrypt {evt.event_id}: {e}") self.log.trace("%s decryption traceback:", evt.event_id, exc_info=True) @@ -799,7 +856,7 @@ async def _handle_encrypted_wait( f"Couldn't find session {err.session_id} trying to decrypt {evt.event_id}," " waiting even longer" ) - asyncio.create_task( + background_task.create( self.e2ee.crypto.request_room_key( evt.room_id, evt.content.sender_key, @@ -876,7 +933,7 @@ def _send_message_checkpoint( info=str(err) if err else None, retry_num=retry_num, ) - asyncio.create_task(checkpoint.send(endpoint, self.az.as_token, self.log)) + background_task.create(checkpoint.send(endpoint, self.az.as_token, self.log)) allowed_event_classes: tuple[type, ...] = ( MessageEvent, diff --git a/mautrix/bridge/portal.py b/mautrix/bridge/portal.py index f51a0ae8..05d67e3b 100644 --- a/mautrix/bridge/portal.py +++ b/mautrix/bridge/portal.py @@ -30,6 +30,7 @@ TextMessageEventContent, UserID, ) +from mautrix.util import background_task from mautrix.util.logging import TraceLogger from mautrix.util.simple_lock import SimpleLock @@ -385,7 +386,7 @@ async def _disappear_event(self, msg: br.AbstractDisappearingMessage) -> None: await self._do_disappear(msg.event_id) self.log.debug(f"Expired event {msg.event_id} disappeared successfully") except Exception as e: - self.log.warning(f"Failed to make expired event {msg.event_id} disappear: {e}", e) + self.log.warning(f"Failed to make expired event {msg.event_id} disappear: {e}") async def _do_disappear(self, event_id: EventID) -> None: await self.main_intent.redact(self.mxid, event_id) @@ -402,7 +403,7 @@ async def restart_scheduled_disappearing(cls) -> None: for msg in msgs: portal = await cls.bridge.get_portal(msg.room_id) if portal and portal.mxid: - asyncio.create_task(portal._disappear_event(msg)) + background_task.create(portal._disappear_event(msg)) else: await msg.delete() @@ -418,7 +419,7 @@ async def schedule_disappearing(self) -> None: for msg in msgs: msg.start_timer() await msg.update() - asyncio.create_task(self._disappear_event(msg)) + background_task.create(self._disappear_event(msg)) async def _send_message( self, @@ -431,7 +432,7 @@ async def _send_message( event_type, content = await self.matrix.e2ee.encrypt(self.mxid, event_type, content) event_id = await intent.send_message_event(self.mxid, event_type, content, **kwargs) if intent.api.is_real_user: - asyncio.create_task(intent.mark_read(self.mxid, event_id)) + background_task.create(intent.mark_read(self.mxid, event_id)) return event_id @property diff --git a/mautrix/bridge/state_store/__init__.py b/mautrix/bridge/state_store/__init__.py index 0e8137a5..b990ac86 100644 --- a/mautrix/bridge/state_store/__init__.py +++ b/mautrix/bridge/state_store/__init__.py @@ -1 +1 @@ -__all__ = ["asyncpg", "sqlalchemy"] +__all__ = ["asyncpg"] diff --git a/mautrix/bridge/state_store/sqlalchemy.py b/mautrix/bridge/state_store/sqlalchemy.py deleted file mode 100644 index 2fa7353e..00000000 --- a/mautrix/bridge/state_store/sqlalchemy.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) 2022 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from __future__ import annotations - -from typing import Awaitable, Callable, Union - -from mautrix.appservice.state_store.sqlalchemy import SQLASStateStore -from mautrix.types import UserID - -from ..puppet import BasePuppet - -GetPuppetFunc = Union[ - Callable[[UserID], Awaitable[BasePuppet]], Callable[[UserID, bool], Awaitable[BasePuppet]] -] - - -class SQLBridgeStateStore(SQLASStateStore): - def __init__(self, get_puppet: GetPuppetFunc, get_double_puppet: GetPuppetFunc) -> None: - super().__init__() - self.get_puppet = get_puppet - self.get_double_puppet = get_double_puppet - - async def is_registered(self, user_id: UserID) -> bool: - puppet = await self.get_puppet(user_id) - if puppet: - return puppet.is_registered - custom_puppet = await self.get_double_puppet(user_id) - if custom_puppet: - return True - return await super().is_registered(user_id) - - async def registered(self, user_id: UserID) -> None: - puppet = await self.get_puppet(user_id, True) - if puppet: - puppet.is_registered = True - await puppet.save() - else: - await super().registered(user_id) diff --git a/mautrix/bridge/user.py b/mautrix/bridge/user.py index bea6bdf7..66860c04 100644 --- a/mautrix/bridge/user.py +++ b/mautrix/bridge/user.py @@ -16,6 +16,7 @@ from mautrix.appservice import AppService from mautrix.errors import MNotFound from mautrix.types import EventID, EventType, Membership, MessageType, RoomID, UserID +from mautrix.util import background_task from mautrix.util.bridge_state import BridgeState, BridgeStateEvent from mautrix.util.logging import TraceLogger from mautrix.util.message_send_checkpoint import ( @@ -244,7 +245,7 @@ def send_remote_checkpoint( """ if not self.bridge.config["homeserver.message_send_checkpoint_endpoint"]: return WrappedTask(task=None) - task = asyncio.create_task( + task = background_task.create( MessageSendCheckpoint( event_id=event_id, room_id=room_id, diff --git a/mautrix/client/api/authentication.py b/mautrix/client/api/authentication.py index cd77272d..0f6249ea 100644 --- a/mautrix/client/api/authentication.py +++ b/mautrix/client/api/authentication.py @@ -5,6 +5,8 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +import secrets + from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( @@ -117,6 +119,19 @@ async def login( self.api.base_url = base_url.rstrip("/") return resp_data + async def create_device_msc4190(self, device_id: str, initial_display_name: str) -> None: + """ + Create a Device for a user of the homeserver using appservice interface defined in MSC4190 + """ + if len(device_id) == 0: + device_id = DeviceID(secrets.token_urlsafe(10)) + self.api.as_user_id = self.mxid + await self.api.request( + Method.PUT, Path.v3.devices[device_id], {"display_name": initial_display_name} + ) + self.api.as_device_id = device_id + self.device_id = device_id + async def logout(self, clear_access_token: bool = True) -> None: """ Invalidates an existing access token, so that it can no longer be used for authorization. diff --git a/mautrix/client/api/events.py b/mautrix/client/api/events.py index 99e1c178..bd415b9b 100644 --- a/mautrix/client/api/events.py +++ b/mautrix/client/api/events.py @@ -5,7 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Awaitable +from typing import Awaitable, Literal, overload import json from mautrix.api import Method, Path @@ -168,12 +168,41 @@ async def get_event_context( ) return EventContext.deserialize(resp) + @overload async def get_state_event( self, room_id: RoomID, event_type: EventType, state_key: str = "", - ) -> StateEventContent: + *, + format: Literal["content"] = "content", + ) -> StateEventContent: ... + @overload + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: Literal["event"], + ) -> StateEvent: ... + @overload + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: ... + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: """ Looks up the contents of a state event in a room. If the user is joined to the room then the state is taken from the current state of the room. If the user has left the room then the @@ -185,6 +214,9 @@ async def get_state_event( room_id: The ID of the room to look up the state in. event_type: The type of state to look up. state_key: The key of the state to look up. Defaults to empty string. + format: The format of the state event to return. Defaults to "content", which only returns + the content of the state event. If set to "event", the full event is returned. + See https://github.com/matrix-org/matrix-spec/issues/1047 for more info. Returns: The state event. @@ -192,11 +224,17 @@ async def get_state_event( content = await self.api.request( Method.GET, Path.v3.rooms[room_id].state[event_type][state_key], + query_params={"format": format} if format != "content" else None, metrics_method="getStateEvent", ) - content["__mautrix_event_type"] = event_type try: - return StateEvent.deserialize_content(content) + if format == "content": + content["__mautrix_event_type"] = event_type + return StateEvent.deserialize_content(content) + elif format == "event": + return StateEvent.deserialize(content) + else: + return content except SerializerError as e: raise MatrixResponseError("Invalid state event in response") from e @@ -357,15 +395,13 @@ async def get_messages( try: return PaginatedMessages( content["start"], - content["end"], + content.get("end"), [Event.deserialize(event) for event in content["chunk"]], ) except KeyError: if "start" not in content: raise MatrixResponseError("`start` not in response.") - elif "end" not in content: - raise MatrixResponseError("`start` not in response.") - raise MatrixResponseError("`content` not in response.") + raise MatrixResponseError("`chunk` not in response.") except SerializerError as e: raise MatrixResponseError("Invalid events in response") from e diff --git a/mautrix/client/api/filtering.py b/mautrix/client/api/filtering.py index 09889ec5..b8df1f36 100644 --- a/mautrix/client/api/filtering.py +++ b/mautrix/client/api/filtering.py @@ -50,9 +50,11 @@ async def create_filter(self, filter_params: Filter) -> FilterID: resp = await self.api.request( Method.POST, Path.v3.user[self.mxid].filter, - filter_params.serialize() - if isinstance(filter_params, Serializable) - else filter_params, + ( + filter_params.serialize() + if isinstance(filter_params, Serializable) + else filter_params + ), ) try: return resp["filter_id"] diff --git a/mautrix/client/api/modules/crypto.py b/mautrix/client/api/modules/crypto.py index af656916..2d879a63 100644 --- a/mautrix/client/api/modules/crypto.py +++ b/mautrix/client/api/modules/crypto.py @@ -5,13 +5,17 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any +from typing import Any, Union from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( + JSON, ClaimKeysResponse, + CrossSigningKeys, + CrossSigningUsage, DeviceID, + DeviceKeys, EncryptionKeyAlgorithm, EventType, QueryKeysResponse, @@ -82,7 +86,7 @@ async def send_to_one_device( async def upload_keys( self, one_time_keys: dict[str, Any] | None = None, - device_keys: dict[str, Any] | None = None, + device_keys: DeviceKeys | dict[str, Any] | None = None, ) -> dict[EncryptionKeyAlgorithm, int]: """ Publishes end-to-end encryption keys for the device. @@ -102,8 +106,12 @@ async def upload_keys( """ data = {} if device_keys: + if isinstance(device_keys, Serializable): + device_keys = device_keys.serialize() data["device_keys"] = device_keys if one_time_keys: + if isinstance(one_time_keys, Serializable): + one_time_keys = one_time_keys.serialize() data["one_time_keys"] = one_time_keys resp = await self.api.request(Method.POST, Path.v3.keys.upload, data) try: @@ -116,6 +124,43 @@ async def upload_keys( except AttributeError as e: raise MatrixResponseError("Invalid `one_time_key_counts` field in response.") from e + async def upload_cross_signing_keys( + self, + keys: dict[CrossSigningUsage, CrossSigningKeys], + auth: dict[str, JSON] | None = None, + ) -> None: + await self.api.request( + Method.POST, + Path.v3.keys.device_signing.upload, + {f"{usage}_key": key.serialize() for usage, key in keys.items()} + | ({"auth": auth} if auth else {}), + ) + + async def upload_one_signature( + self, + user_id: UserID, + device_id: DeviceID, + keys: Union[DeviceKeys, CrossSigningKeys], + ) -> None: + await self.api.request( + Method.POST, Path.v3.keys.signatures.upload, {user_id: {device_id: keys.serialize()}} + ) + # TODO check failures + + async def upload_many_signatures( + self, + signatures: dict[UserID, dict[DeviceID, Union[DeviceKeys, CrossSigningKeys]]], + ) -> None: + await self.api.request( + Method.POST, + Path.v3.keys.signatures.upload, + { + user_id: {device_id: keys.serialize() for device_id, keys in devices.items()} + for user_id, devices in signatures.items() + }, + ) + # TODO check failures + async def query_keys( self, device_keys: list[UserID] | set[UserID] | dict[UserID, list[DeviceID]], diff --git a/mautrix/client/api/modules/media_repository.py b/mautrix/client/api/modules/media_repository.py index b1de6751..b1d90cc1 100644 --- a/mautrix/client/api/modules/media_repository.py +++ b/mautrix/client/api/modules/media_repository.py @@ -10,6 +10,8 @@ import asyncio import time +from yarl import URL + from mautrix import __optional_imports__ from mautrix.api import MediaPath, Method from mautrix.errors import MatrixResponseError, make_request_error @@ -19,7 +21,10 @@ MediaRepoConfig, MXOpenGraph, SerializerError, + SpecVersions, ) +from mautrix.util import background_task +from mautrix.util.async_body import async_iter_bytes from mautrix.util.opt_prometheus import Histogram from ..base import BaseClientAPI @@ -44,22 +49,19 @@ class MediaRepositoryMethods(BaseClientAPI): downloading content from the media repository and for getting URL previews without leaking client IPs. - See also: `API reference `__ - - There are also methods for supporting `MSC2246 - `__ which allows asynchronous - uploads of media. + See also: `API reference `__ """ - async def unstable_create_mxc(self) -> MediaCreateResponse: + async def create_mxc(self) -> MediaCreateResponse: """ - Create a media ID for uploading media to the homeserver. Requires the homeserver to have - `MSC2246 `__ support. + Create a media ID for uploading media to the homeserver. + + See also: `API reference `__ Returns: - MediaCreateResponse Containing the MXC URI that can be used to upload a file to later, as well as an optional upload URL + MediaCreateResponse Containing the MXC URI that can be used to upload a file to later """ - resp = await self.api.request(Method.POST, MediaPath.unstable["fi.mau.msc2246"].create) + resp = await self.api.request(Method.POST, MediaPath.v1.create) return MediaCreateResponse.deserialize(resp) @contextmanager @@ -85,21 +87,18 @@ async def upload_media( """ Upload a file to the content repository. - See also: `API reference `__ + See also: `API reference `__ Args: data: The data to upload. mime_type: The MIME type to send with the upload request. filename: The filename to send with the upload request. size: The file size to send with the upload request. - mxc: An existing MXC URI which doesn't have content yet to upload into. Requires the - homeserver to have MSC2246_ support. - async_upload: Should the media be uploaded in the background (using MSC2246_)? - If ``True``, this will create a MXC URI, start uploading in the background and then - immediately return the created URI. This is mutually exclusive with manually - passing the ``mxc`` parameter. - - .. _MSC2246: https://github.com/matrix-org/matrix-spec-proposals/pull/2246 + mxc: An existing MXC URI which doesn't have content yet to upload into. + async_upload: Should the media be uploaded in the background? + If ``True``, this will create a MXC URI using :meth:`create_mxc`, start uploading + in the background, and then immediately return the created URI. This is mutually + exclusive with manually passing the ``mxc`` parameter. Returns: The MXC URI to the uploaded file. @@ -126,19 +125,21 @@ async def upload_media( if async_upload: if mxc: raise ValueError("async_upload and mxc can't be provided simultaneously") - create_response = await self.unstable_create_mxc() + create_response = await self.create_mxc() mxc = create_response.content_uri - upload_url = create_response.upload_url + upload_url = create_response.unstable_upload_url path = MediaPath.v3.upload method = Method.POST if mxc: server_name, media_id = self.api.parse_mxc_uri(mxc) if upload_url is None: - path = MediaPath.unstable["fi.mau.msc2246"].upload[server_name][media_id] + path = MediaPath.v3.upload[server_name][media_id] method = Method.PUT else: - path = MediaPath.unstable["fi.mau.msc2246"].upload[server_name][media_id].complete + path = ( + MediaPath.unstable["com.beeper.msc3870"].upload[server_name][media_id].complete + ) if upload_url is not None: task = self._upload_to_url(upload_url, path, headers, data, post_upload_query=query) @@ -156,7 +157,7 @@ async def _try_upload(): except Exception as e: self.log.error(f"Failed to upload {mxc}: {type(e).__name__}: {e}") - asyncio.create_task(_try_upload()) + background_task.create(_try_upload()) return mxc else: with self._observe_upload_time(size): @@ -166,27 +167,40 @@ async def _try_upload(): except KeyError: raise MatrixResponseError("`content_uri` not in response.") - async def download_media(self, url: ContentURI, max_stall_ms: int | None = None) -> bytes: + async def download_media(self, url: ContentURI, timeout_ms: int | None = None) -> bytes: """ Download a file from the content repository. - See also: `API reference `__ + See also: `API reference `__ Args: url: The MXC URI to download. - max_stall_ms: The maximum number of milliseconds that the client is willing to wait to - start receiving data. Used for MSC2246 Asynchronous Uploads. + timeout_ms: The maximum number of milliseconds that the client is willing to wait to + start receiving data. Used for asynchronous uploads. Returns: The raw downloaded data. """ - url = self.api.get_download_url(url) + authenticated = (await self.versions()).supports(SpecVersions.V111) + url = self.api.get_download_url(url, authenticated=authenticated) query_params: dict[str, Any] = {"allow_redirect": "true"} - if max_stall_ms is not None: - query_params["max_stall_ms"] = max_stall_ms - query_params["fi.mau.msc2246.max_stall_ms"] = max_stall_ms - async with self.api.session.get(url, params=query_params) as response: - return await response.read() + if timeout_ms is not None: + query_params["timeout_ms"] = timeout_ms + headers: dict[str, str] = {} + if authenticated: + headers["Authorization"] = f"Bearer {self.api.token}" + if self.api.as_user_id: + query_params["user_id"] = self.api.as_user_id + req_id = self.api.log_download_request(url, query_params) + start = time.monotonic() + async with self.api.session.get(url, params=query_params, headers=headers) as response: + try: + response.raise_for_status() + return await response.read() + finally: + self.api.log_download_request_done( + url, req_id, time.monotonic() - start, response.status + ) async def download_thumbnail( self, @@ -194,13 +208,13 @@ async def download_thumbnail( width: int | None = None, height: int | None = None, resize_method: Literal["crop", "scale"] = None, - allow_remote: bool = True, - max_stall_ms: int | None = None, + allow_remote: bool | None = None, + timeout_ms: int | None = None, ): """ Download a thumbnail for a file in the content repository. - See also: `API reference `__ + See also: `API reference `__ Args: url: The MXC URI to download. @@ -212,13 +226,16 @@ async def download_thumbnail( allow_remote: Indicates to the server that it should not attempt to fetch the media if it is deemed remote. This is to prevent routing loops where the server contacts itself. - max_stall_ms: The maximum number of milliseconds that the client is willing to wait to - start receiving data. Used for MSC2246 Asynchronous Uploads. + timeout_ms: The maximum number of milliseconds that the client is willing to wait to + start receiving data. Used for asynchronous Uploads. Returns: The raw downloaded data. """ - url = self.api.get_download_url(url, download_type="thumbnail") + authenticated = (await self.versions()).supports(SpecVersions.V111) + url = self.api.get_download_url( + url, download_type="thumbnail", authenticated=authenticated + ) query_params: dict[str, Any] = {"allow_redirect": "true"} if width is not None: query_params["width"] = width @@ -227,12 +244,24 @@ async def download_thumbnail( if resize_method is not None: query_params["method"] = resize_method if allow_remote is not None: - query_params["allow_remote"] = allow_remote - if max_stall_ms is not None: - query_params["max_stall_ms"] = max_stall_ms - query_params["fi.mau.msc2246.max_stall_ms"] = max_stall_ms - async with self.api.session.get(url, params=query_params) as response: - return await response.read() + query_params["allow_remote"] = str(allow_remote).lower() + if timeout_ms is not None: + query_params["timeout_ms"] = timeout_ms + headers: dict[str, str] = {} + if authenticated: + headers["Authorization"] = f"Bearer {self.api.token}" + if self.api.as_user_id: + query_params["user_id"] = self.api.as_user_id + req_id = self.api.log_download_request(url, query_params) + start = time.monotonic() + async with self.api.session.get(url, params=query_params, headers=headers) as response: + try: + response.raise_for_status() + return await response.read() + finally: + self.api.log_download_request_done( + url, req_id, time.monotonic() - start, response.status + ) async def get_url_preview(self, url: str, timestamp: int | None = None) -> MXOpenGraph: """ @@ -286,19 +315,24 @@ async def _upload_to_url( headers: dict[str, str], data: bytes | bytearray | AsyncIterable[bytes], post_upload_query: dict[str, str], + min_iter_size: int = 25 * 1024 * 1024, ) -> None: retry_count = self.api.default_retry_count - backoff = 4 + backoff = 2 + do_fake_iter = data and hasattr(data, "__len__") and len(data) > min_iter_size + if do_fake_iter: + headers["Content-Length"] = str(len(data)) while True: self.log.debug("Uploading media to external URL %s", upload_url) upload_response = None try: + req_data = async_iter_bytes(data) if do_fake_iter else data upload_response = await self.api.session.put( - upload_url, data=data, headers=headers + upload_url, data=req_data, headers=headers ) upload_response.raise_for_status() except Exception as e: - if retry_count == 0: + if retry_count <= 0: raise make_request_error( http_status=upload_response.status if upload_response else -1, text=(await upload_response.text()) if upload_response else "", @@ -311,7 +345,7 @@ async def _upload_to_url( ) await asyncio.sleep(backoff) backoff *= 2 - retry_count = -1 + retry_count -= 1 else: break diff --git a/mautrix/client/api/modules/misc.py b/mautrix/client/api/modules/misc.py index 443877c1..8ff9c85a 100644 --- a/mautrix/client/api/modules/misc.py +++ b/mautrix/client/api/modules/misc.py @@ -50,7 +50,7 @@ async def set_typing(self, room_id: RoomID, timeout: int = 0) -> None: Args: room_id: The ID of the room in which the user is typing. - timeout: The length of time in seconds to mark this user as typing. + timeout: The length of time in milliseconds to mark this user as typing. """ if timeout > 0: content = {"typing": True, "timeout": timeout} diff --git a/mautrix/client/api/rooms.py b/mautrix/client/api/rooms.py index 9f935430..a488fba6 100644 --- a/mautrix/client/api/rooms.py +++ b/mautrix/client/api/rooms.py @@ -354,9 +354,12 @@ async def join_room( except KeyError: raise MatrixResponseError("`room_id` not in response.") - fill_member_event_callback: Callable[ - [RoomID, UserID, MemberStateEventContent], Awaitable[MemberStateEventContent | None] - ] | None + fill_member_event_callback: ( + Callable[ + [RoomID, UserID, MemberStateEventContent], Awaitable[MemberStateEventContent | None] + ] + | None + ) async def fill_member_event( self, room_id: RoomID, user_id: UserID, content: MemberStateEventContent diff --git a/mautrix/client/api/user_data.py b/mautrix/client/api/user_data.py index f37cb769..9c380335 100644 --- a/mautrix/client/api/user_data.py +++ b/mautrix/client/api/user_data.py @@ -5,6 +5,8 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +from typing import Any + from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError, MNotFound from mautrix.types import ContentURI, Member, SerializerError, User, UserID, UserSearchResults @@ -69,7 +71,7 @@ async def search_users(self, search_query: str, limit: int | None = 10) -> UserS # region 10.2 Profiles # API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#profiles - async def set_displayname(self, displayname: str, check_current: bool = True) -> None: + async def set_displayname(self, displayname: str | None, check_current: bool = True) -> None: """ Set the display name of the current user. @@ -79,7 +81,9 @@ async def set_displayname(self, displayname: str, check_current: bool = True) -> displayname: The new display name for the user. check_current: Whether or not to check if the displayname is already set. """ - if check_current and await self.get_displayname(self.mxid) == displayname: + if check_current and str_or_none(await self.get_displayname(self.mxid)) == str_or_none( + displayname + ): return await self.api.request( Method.PUT, @@ -110,7 +114,9 @@ async def get_displayname(self, user_id: UserID) -> str | None: except KeyError: return None - async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = True) -> None: + async def set_avatar_url( + self, avatar_url: ContentURI | None, check_current: bool = True + ) -> None: """ Set the avatar of the current user. @@ -120,7 +126,9 @@ async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = Tru avatar_url: The ``mxc://`` URI to the new avatar. check_current: Whether or not to check if the avatar is already set. """ - if check_current and await self.get_avatar_url(self.mxid) == avatar_url: + if check_current and str_or_none(await self.get_avatar_url(self.mxid)) == str_or_none( + avatar_url + ): return await self.api.request( Method.PUT, @@ -170,3 +178,23 @@ async def get_profile(self, user_id: UserID) -> Member: raise MatrixResponseError("Invalid member in response") from e # endregion + + # region Beeper Custom Fields API + + async def beeper_update_profile(self, custom_fields: dict[str, Any]) -> None: + """ + Set custom fields on the user's profile. Only works on Hungryserv. + + Args: + custom_fields: A dictionary of fields to set in the custom content of the profile. + """ + await self.api.request(Method.PATCH, Path.v3.profile[self.mxid], custom_fields) + + # endregion + + +def str_or_none(v: str | None) -> str | None: + """ + str_or_none empty string values to None + """ + return None if v == "" else v diff --git a/mautrix/client/state_store/__init__.py b/mautrix/client/state_store/__init__.py index 5f76c7dc..98150ca4 100644 --- a/mautrix/client/state_store/__init__.py +++ b/mautrix/client/state_store/__init__.py @@ -10,5 +10,4 @@ "MemorySyncStore", "SyncStore", "asyncpg", - "sqlalchemy", ] diff --git a/mautrix/client/state_store/abstract.py b/mautrix/client/state_store/abstract.py index e241d8f1..d5b1f5f2 100644 --- a/mautrix/client/state_store/abstract.py +++ b/mautrix/client/state_store/abstract.py @@ -121,6 +121,18 @@ async def set_power_levels( ) -> None: pass + @abstractmethod + async def has_create_cached(self, room_id: RoomID) -> bool: + pass + + @abstractmethod + async def get_create(self, room_id: RoomID) -> StateEvent | None: + pass + + @abstractmethod + async def set_create(self, event: StateEvent) -> None: + pass + @abstractmethod async def has_encryption_info_cached(self, room_id: RoomID) -> bool: pass @@ -135,7 +147,7 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent @abstractmethod async def set_encryption_info( - self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, any] + self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: pass @@ -143,9 +155,14 @@ async def update_state(self, evt: StateEvent) -> None: if evt.type == EventType.ROOM_POWER_LEVELS: await self.set_power_levels(evt.room_id, evt.content) elif evt.type == EventType.ROOM_MEMBER: + evt.unsigned["mautrix_prev_membership"] = await self.get_member( + evt.room_id, UserID(evt.state_key) + ) await self.set_member(evt.room_id, UserID(evt.state_key), evt.content) elif evt.type == EventType.ROOM_ENCRYPTION: await self.set_encryption_info(evt.room_id, evt.content) + elif evt.type == EventType.ROOM_CREATE and evt.sender: + await self.set_create(evt) async def get_membership(self, room_id: RoomID, user_id: UserID) -> Membership: member = await self.get_member(room_id, user_id) @@ -169,4 +186,7 @@ async def has_power_level( room_levels = await self.get_power_levels(room_id) if not room_levels: return None - return room_levels.get_user_level(user_id) >= room_levels.get_event_level(event_type) + create_event = await self.get_create(room_id) + return room_levels.get_user_level(user_id, create_event) >= room_levels.get_event_level( + event_type + ) diff --git a/mautrix/client/state_store/asyncpg/store.py b/mautrix/client/state_store/asyncpg/store.py index d78c04d6..f4f8436f 100644 --- a/mautrix/client/state_store/asyncpg/store.py +++ b/mautrix/client/state_store/asyncpg/store.py @@ -16,6 +16,7 @@ RoomEncryptionStateEventContent, RoomID, Serializable, + StateEvent, UserID, ) from mautrix.util.async_db import Database, Scheme @@ -223,6 +224,29 @@ async def set_power_levels( json.dumps(content.serialize() if isinstance(content, Serializable) else content), ) + async def has_create_cached(self, room_id: RoomID) -> bool: + return bool( + await self.db.fetchval( + "SELECT create_event IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id + ) + ) + + async def get_create(self, room_id: RoomID) -> StateEvent | None: + create_event_json = await self.db.fetchval( + "SELECT create_event FROM mx_room_state WHERE room_id=$1", room_id + ) + if create_event_json is None: + return None + return StateEvent.parse_json(create_event_json) + + async def set_create(self, event: StateEvent) -> None: + await self.db.execute( + "INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) " + "ON CONFLICT (room_id) DO UPDATE SET create_event=$2", + event.room_id, + json.dumps(event.serialize() if isinstance(event, Serializable) else event), + ) + async def has_encryption_info_cached(self, room_id: RoomID) -> bool: return bool( await self.db.fetchval( diff --git a/mautrix/client/state_store/asyncpg/upgrade.py b/mautrix/client/state_store/asyncpg/upgrade.py index 0a489aae..20b2e5b2 100644 --- a/mautrix/client/state_store/asyncpg/upgrade.py +++ b/mautrix/client/state_store/asyncpg/upgrade.py @@ -14,17 +14,18 @@ ) -@upgrade_table.register(description="Latest revision", upgrades_to=2) -async def upgrade_blank_to_v2(conn: Connection, scheme: Scheme) -> None: - await conn.execute( - """CREATE TABLE mx_room_state ( +@upgrade_table.register(description="Latest revision", upgrades_to=4) +async def upgrade_blank_to_v4(conn: Connection, scheme: Scheme) -> None: + await conn.execute(""" + CREATE TABLE mx_room_state ( room_id TEXT PRIMARY KEY, is_encrypted BOOLEAN, has_full_member_list BOOLEAN, encryption TEXT, - power_levels TEXT - )""" - ) + power_levels TEXT, + create_event TEXT + ) + """) membership_check = "" if scheme != Scheme.SQLITE: await conn.execute( @@ -32,16 +33,16 @@ async def upgrade_blank_to_v2(conn: Connection, scheme: Scheme) -> None: ) else: membership_check = "CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock'))" - await conn.execute( - f"""CREATE TABLE mx_user_profile ( + await conn.execute(f""" + CREATE TABLE mx_user_profile ( room_id TEXT, user_id TEXT, membership membership NOT NULL {membership_check}, displayname TEXT, avatar_url TEXT, PRIMARY KEY (room_id, user_id) - )""" - ) + ) + """) @upgrade_table.register(description="Stop using size-limited string fields") @@ -54,3 +55,21 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT") await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN displayname TYPE TEXT") await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT") + + +@upgrade_table.register(description="Mark rooms that need crypto state event resynced") +async def upgrade_v3(conn: Connection) -> None: + if await conn.table_exists("portal"): + await conn.execute(""" + INSERT INTO mx_room_state (room_id, encryption) + SELECT portal.mxid, '{"resync":true}' FROM portal + WHERE portal.encrypted=true AND portal.mxid IS NOT NULL + ON CONFLICT (room_id) DO UPDATE + SET encryption=excluded.encryption + WHERE mx_room_state.encryption IS NULL + """) + + +@upgrade_table.register(description="Add create event to room state cache") +async def upgrade_v4(conn: Connection) -> None: + await conn.execute("ALTER TABLE mx_room_state ADD COLUMN create_event TEXT") diff --git a/mautrix/client/state_store/file.py b/mautrix/client/state_store/file.py index d567c853..a5d53663 100644 --- a/mautrix/client/state_store/file.py +++ b/mautrix/client/state_store/file.py @@ -15,6 +15,7 @@ PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, + StateEvent, UserID, ) from mautrix.util.file_store import Filer, FileStore @@ -65,3 +66,7 @@ async def set_power_levels( ) -> None: await super().set_power_levels(room_id, content) self._time_limited_flush() + + async def set_create(self, event: StateEvent) -> None: + await super().set_create(event) + self._time_limited_flush() diff --git a/mautrix/client/state_store/memory.py b/mautrix/client/state_store/memory.py index 8f2edac5..2b010a75 100644 --- a/mautrix/client/state_store/memory.py +++ b/mautrix/client/state_store/memory.py @@ -14,6 +14,7 @@ PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, + StateEvent, UserID, ) @@ -25,6 +26,7 @@ class SerializedStateStore(TypedDict): full_member_list: dict[RoomID, bool] power_levels: dict[RoomID, Any] encryption: dict[RoomID, Any] + create: dict[RoomID, Any] class MemoryStateStore(StateStore): @@ -32,12 +34,14 @@ class MemoryStateStore(StateStore): full_member_list: dict[RoomID, bool] power_levels: dict[RoomID, PowerLevelStateEventContent] encryption: dict[RoomID, RoomEncryptionStateEventContent | None] + create: dict[RoomID, StateEvent] def __init__(self) -> None: self.members = {} self.full_member_list = {} self.power_levels = {} self.encryption = {} + self.create = {} def serialize(self) -> SerializedStateStore: """ @@ -58,6 +62,7 @@ def serialize(self) -> SerializedStateStore: room_id: (content.serialize() if content is not None else None) for room_id, content in self.encryption.items() }, + "create": {room_id: evt.serialize() for room_id, evt in self.create.items()}, } def deserialize(self, data: SerializedStateStore) -> None: @@ -84,6 +89,9 @@ def deserialize(self, data: SerializedStateStore) -> None: ) for room_id, content in data["encryption"].items() } + self.create = { + room_id: StateEvent.deserialize(evt) for room_id, evt in data["create"].items() + } async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None: try: @@ -176,6 +184,17 @@ async def set_power_levels( content = PowerLevelStateEventContent.deserialize(content) self.power_levels[room_id] = content + async def has_create_cached(self, room_id: RoomID) -> bool: + return room_id in self.create + + async def get_create(self, room_id: RoomID) -> StateEvent | None: + return self.create.get(room_id) + + async def set_create(self, event: StateEvent | dict[str, Any]) -> None: + if not isinstance(event, StateEvent): + event = StateEvent.deserialize(event) + self.create[event.room_id] = event + async def has_encryption_info_cached(self, room_id: RoomID) -> bool: return room_id in self.encryption diff --git a/mautrix/client/state_store/sqlalchemy/__init__.py b/mautrix/client/state_store/sqlalchemy/__init__.py deleted file mode 100644 index 6ce78f46..00000000 --- a/mautrix/client/state_store/sqlalchemy/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mx_room_state import RoomState, SerializableType -from .mx_user_profile import UserProfile -from .sqlstatestore import SQLStateStore diff --git a/mautrix/client/state_store/sqlalchemy/mx_room_state.py b/mautrix/client/state_store/sqlalchemy/mx_room_state.py deleted file mode 100644 index 9f97056f..00000000 --- a/mautrix/client/state_store/sqlalchemy/mx_room_state.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2022 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from __future__ import annotations - -from typing import Type -import json - -from sqlalchemy import Boolean, Column, Text, types - -from mautrix.types import ( - PowerLevelStateEventContent as PowerLevels, - RoomEncryptionStateEventContent as EncryptionInfo, - RoomID, - Serializable, -) -from mautrix.util.db import Base - - -class SerializableType(types.TypeDecorator): - impl = types.Text - - def __init__(self, python_type: Type[Serializable], *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._python_type = python_type - - @property - def python_type(self) -> Type[Serializable]: - return self._python_type - - def process_bind_param(self, value: Serializable, dialect) -> str | None: - if value is not None: - return json.dumps(value.serialize() if isinstance(value, Serializable) else value) - return None - - def process_result_value(self, value: str, dialect) -> Serializable | None: - if value is not None: - return self.python_type.deserialize(json.loads(value)) - return None - - def process_literal_param(self, value, dialect): - return value - - -class RoomState(Base): - __tablename__ = "mx_room_state" - - room_id: RoomID = Column(Text, primary_key=True) - is_encrypted: bool = Column(Boolean, nullable=True) - has_full_member_list: bool = Column(Boolean, nullable=True) - encryption: EncryptionInfo = Column(SerializableType(EncryptionInfo), nullable=True) - power_levels: PowerLevels = Column(SerializableType(PowerLevels), nullable=True) - - @property - def has_power_levels(self) -> bool: - return bool(self.power_levels) - - @property - def has_encryption_info(self) -> bool: - return self.is_encrypted is not None - - @classmethod - def get(cls, room_id: RoomID) -> RoomState | None: - return cls._select_one_or_none(cls.c.room_id == room_id) diff --git a/mautrix/client/state_store/sqlalchemy/mx_user_profile.py b/mautrix/client/state_store/sqlalchemy/mx_user_profile.py deleted file mode 100644 index 0bb6b766..00000000 --- a/mautrix/client/state_store/sqlalchemy/mx_user_profile.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2022 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from __future__ import annotations - -from typing import Iterable - -from sqlalchemy import Column, Enum, Text - -from mautrix.types import ContentURI, Member, Membership, RoomID, UserID -from mautrix.util.db import Base - -from .mx_room_state import RoomState - - -class UserProfile(Base): - __tablename__ = "mx_user_profile" - - room_id: RoomID = Column(Text, primary_key=True) - user_id: UserID = Column(Text, primary_key=True) - membership: Membership = Column(Enum(Membership), nullable=False, default=Membership.LEAVE) - displayname: str = Column(Text, nullable=True) - avatar_url: ContentURI = Column(Text, nullable=True) - - def member(self) -> Member: - return Member( - membership=self.membership, displayname=self.displayname, avatar_url=self.avatar_url - ) - - @classmethod - def get(cls, room_id: RoomID, user_id: UserID) -> UserProfile | None: - return cls._select_one_or_none((cls.c.room_id == room_id) & (cls.c.user_id == user_id)) - - @classmethod - def all_in_room( - cls, - room_id: RoomID, - memberships: tuple[Membership, ...], - prefix: str = None, - suffix: str = None, - bot: str = None, - ) -> Iterable[UserProfile]: - clause = cls.c.membership == memberships[0] - for membership in memberships[1:]: - clause |= cls.c.membership == membership - clause &= cls.c.room_id == room_id - if bot: - clause &= cls.c.user_id != bot - if prefix: - clause &= ~cls.c.user_id.startswith(prefix, autoescape=True) - if suffix: - clause &= ~cls.c.user_id.startswith(suffix, autoescape=True) - return cls._select_all(clause) - - @classmethod - def find_rooms_with_user(cls, user_id: UserID) -> Iterable[UserProfile]: - return cls._select_all( - (cls.c.user_id == user_id) - & (cls.c.room_id == RoomState.c.room_id) - & (RoomState.c.is_encrypted == True) - ) - - @classmethod - def delete_all(cls, room_id: RoomID) -> None: - with cls.db.begin() as conn: - conn.execute(cls.t.delete().where(cls.c.room_id == room_id)) - - @classmethod - def bulk_replace( - cls, - room_id: RoomID, - members: dict[UserID, Member], - only_membership: Membership | None = None, - ) -> None: - with cls.db.begin() as conn: - delete_condition = cls.c.room_id == room_id - if only_membership is not None: - delete_condition &= cls.c.membership == only_membership - cls.db.execute(cls.t.delete().where(delete_condition)) - conn.execute( - cls.t.insert(), - [ - dict( - room_id=room_id, - user_id=user_id, - membership=member.membership, - displayname=member.displayname, - avatar_url=member.avatar_url, - ) - for user_id, member in members.items() - ], - ) diff --git a/mautrix/client/state_store/sqlalchemy/sqlstatestore.py b/mautrix/client/state_store/sqlalchemy/sqlstatestore.py deleted file mode 100644 index 1145dec7..00000000 --- a/mautrix/client/state_store/sqlalchemy/sqlstatestore.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) 2022 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from __future__ import annotations - -from typing import Any - -from mautrix.types import ( - Member, - Membership, - PowerLevelStateEventContent, - RoomEncryptionStateEventContent, - RoomID, - UserID, -) - -from ..abstract import StateStore -from .mx_room_state import RoomState -from .mx_user_profile import UserProfile - - -class SQLStateStore(StateStore): - _profile_cache: dict[RoomID, dict[UserID, UserProfile]] - _room_state_cache: dict[RoomID, RoomState] - - def __init__(self) -> None: - super().__init__() - self._profile_cache = {} - self._room_state_cache = {} - - def _get_user_profile( - self, room_id: RoomID, user_id: UserID, create: bool = False - ) -> UserProfile: - if not room_id: - raise ValueError("room_id is empty") - elif not user_id: - raise ValueError("user_id is empty") - try: - return self._profile_cache[room_id][user_id] - except KeyError: - pass - if room_id not in self._profile_cache: - self._profile_cache[room_id] = {} - - profile = UserProfile.get(room_id, user_id) - if profile: - self._profile_cache[room_id][user_id] = profile - elif create: - profile = UserProfile(room_id=room_id, user_id=user_id, membership=Membership.LEAVE) - profile.insert() - self._profile_cache[room_id][user_id] = profile - return profile - - async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None: - profile = self._get_user_profile(room_id, user_id) - if not profile: - return None - return profile.member() - - async def set_member(self, room_id: RoomID, user_id: UserID, member: Member) -> None: - if not member: - raise ValueError("member info is empty") - profile = self._get_user_profile(room_id, user_id, create=True) - profile.edit( - membership=member.membership, - displayname=member.displayname or profile.displayname, - avatar_url=member.avatar_url or profile.avatar_url, - ) - - async def set_membership( - self, room_id: RoomID, user_id: UserID, membership: Membership - ) -> None: - await self.set_member(room_id, user_id, Member(membership=membership)) - - async def get_member_profiles( - self, - room_id: RoomID, - memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), - ) -> dict[UserID, Member]: - self._profile_cache[room_id] = {} - for profile in UserProfile.all_in_room(room_id, memberships): - self._profile_cache[room_id][profile.user_id] = profile - return { - profile.user_id: profile.member() for profile in self._profile_cache[room_id].values() - } - - async def get_members_filtered( - self, - room_id: RoomID, - not_prefix: str, - not_suffix: str, - not_id: str, - memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), - ) -> list[UserID]: - return [ - profile.user_id - for profile in UserProfile.all_in_room( - room_id, memberships, not_suffix, not_prefix, not_id - ) - ] - - async def set_members( - self, - room_id: RoomID, - members: dict[UserID, Member], - only_membership: Membership | None = None, - ) -> None: - UserProfile.bulk_replace(room_id, members, only_membership=only_membership) - self._get_room_state(room_id, create=True).edit(has_full_member_list=True) - try: - del self._profile_cache[room_id] - except KeyError: - pass - - async def has_full_member_list(self, room_id: RoomID) -> bool: - room = self._get_room_state(room_id) - if not room: - return False - return room.has_full_member_list - - async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]: - return [profile.room_id for profile in UserProfile.find_rooms_with_user(user_id)] - - def _get_room_state(self, room_id: RoomID, create: bool = False) -> RoomState: - if not room_id: - raise ValueError("room_id is empty") - try: - return self._room_state_cache[room_id] - except KeyError: - pass - - room = RoomState.get(room_id) - if room: - self._room_state_cache[room_id] = room - elif create: - room = RoomState(room_id=room_id) - room.insert() - self._room_state_cache[room_id] = room - return room - - async def has_power_levels_cached(self, room_id: RoomID) -> bool: - room = self._get_room_state(room_id) - if not room: - return False - return room.has_power_levels - - async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None: - room = self._get_room_state(room_id) - if not room: - return None - return room.power_levels - - async def set_power_levels( - self, room_id: RoomID, content: PowerLevelStateEventContent - ) -> None: - if not content: - raise ValueError("content is empty") - self._get_room_state(room_id, create=True).edit(power_levels=content) - - async def is_encrypted(self, room_id: RoomID) -> bool | None: - room = self._get_room_state(room_id) - if not room: - return None - return room.is_encrypted - - async def has_encryption_info_cached(self, room_id: RoomID) -> bool: - room = self._get_room_state(room_id) - return room and room.has_encryption_info - - async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None: - room = self._get_room_state(room_id) - if not room: - return None - return room.encryption - - async def set_encryption_info( - self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] - ) -> None: - if not content: - raise ValueError("content is empty") - self._get_room_state(room_id, create=True).edit(encryption=content, is_encrypted=True) diff --git a/mautrix/client/state_store/tests/store_test.py b/mautrix/client/state_store/tests/store_test.py index dbbd376a..46eee750 100644 --- a/mautrix/client/state_store/tests/store_test.py +++ b/mautrix/client/state_store/tests/store_test.py @@ -16,15 +16,12 @@ import asyncpg import pytest -import sqlalchemy as sql from mautrix.types import EncryptionAlgorithm, Member, Membership, RoomID, StateEvent, UserID from mautrix.util.async_db import Database -from mautrix.util.db import Base from .. import MemoryStateStore, StateStore from ..asyncpg import PgStateStore -from ..sqlalchemy import RoomState, SQLStateStore, UserProfile @asynccontextmanager @@ -54,7 +51,7 @@ async def async_postgres_store() -> AsyncIterator[PgStateStore]: @asynccontextmanager async def async_sqlite_store() -> AsyncIterator[PgStateStore]: db = Database.create( - "sqlite:///:memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1} + "sqlite::memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1} ) store = PgStateStore(db) await db.start() @@ -62,23 +59,12 @@ async def async_sqlite_store() -> AsyncIterator[PgStateStore]: await db.stop() -@asynccontextmanager -async def alchemy_store() -> AsyncIterator[SQLStateStore]: - db = sql.create_engine("sqlite:///:memory:") - Base.metadata.bind = db - for table in (RoomState, UserProfile): - table.bind(db) - Base.metadata.create_all() - yield SQLStateStore() - db.dispose() - - @asynccontextmanager async def memory_store() -> AsyncIterator[MemoryStateStore]: yield MemoryStateStore() -@pytest.fixture(params=[async_postgres_store, async_sqlite_store, alchemy_store, memory_store]) +@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store]) async def store(request) -> AsyncIterator[StateStore]: param: Callable[[], AsyncContextManager[StateStore]] = request.param async with param() as state_store: diff --git a/mautrix/client/store_updater.py b/mautrix/client/store_updater.py index a5530bde..35280324 100644 --- a/mautrix/client/store_updater.py +++ b/mautrix/client/store_updater.py @@ -5,6 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +from typing import Literal import asyncio from mautrix.errors import MForbidden, MNotFound @@ -196,20 +197,28 @@ async def send_state_event( return event_id async def get_state_event( - self, room_id: RoomID, event_type: EventType, state_key: str = "" - ) -> StateEventContent: - event = await super().get_state_event(room_id, event_type, state_key) + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: + event = await super().get_state_event(room_id, event_type, state_key, format=format) if self.state_store: - fake_event = StateEvent( - type=event_type, - room_id=room_id, - event_id=EventID(""), - sender=UserID(""), - state_key=state_key, - timestamp=0, - content=event, - ) - await self.state_store.update_state(fake_event) + if isinstance(event, StateEvent): + await self.state_store.update_state(event) + else: + fake_event = StateEvent( + type=event_type, + room_id=room_id, + event_id=EventID(""), + sender=UserID(""), + state_key=state_key, + timestamp=0, + content=event, + ) + await self.state_store.update_state(fake_event) return event async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]: diff --git a/mautrix/client/syncer.py b/mautrix/client/syncer.py index 9f942665..4b8661a2 100644 --- a/mautrix/client/syncer.py +++ b/mautrix/client/syncer.py @@ -5,11 +5,11 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any, Awaitable, Callable, Type, TypeVar +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Type, TypeVar from abc import ABC, abstractmethod -from contextlib import suppress from enum import Enum, Flag, auto import asyncio +import itertools import time from mautrix.errors import MUnknownToken @@ -25,7 +25,6 @@ Filter, FilterID, GenericEvent, - MessageEvent, PresenceState, SerializerError, StateEvent, @@ -34,6 +33,7 @@ ToDeviceEvent, UserID, ) +from mautrix.util import background_task from mautrix.util.logging import TraceLogger from . import dispatcher @@ -78,13 +78,18 @@ class InternalEventType(Enum): DEVICE_OTK_COUNT = auto() +class EventHandlerProps(NamedTuple): + wait_sync: bool + sync_stream: Optional[SyncStream] + + class Syncer(ABC): loop: asyncio.AbstractEventLoop log: TraceLogger mxid: UserID - global_event_handlers: list[tuple[EventHandler, bool]] - event_handlers: dict[EventType | InternalEventType, list[tuple[EventHandler, bool]]] + global_event_handlers: dict[EventHandler, EventHandlerProps] + event_handlers: dict[EventType | InternalEventType, dict[EventHandler, EventHandlerProps]] dispatchers: dict[Type[dispatcher.Dispatcher], dispatcher.Dispatcher] syncing_task: asyncio.Task | None ignore_initial_sync: bool @@ -94,7 +99,7 @@ class Syncer(ABC): sync_store: SyncStore def __init__(self, sync_store: SyncStore) -> None: - self.global_event_handlers = [] + self.global_event_handlers = {} self.event_handlers = {} self.dispatchers = {} self.syncing_task = None @@ -157,6 +162,7 @@ def add_event_handler( event_type: InternalEventType | EventType, handler: EventHandler, wait_sync: bool = False, + sync_stream: Optional[SyncStream] = None, ) -> None: """ Add a new event handler. @@ -166,13 +172,15 @@ def add_event_handler( event types. handler: The handler function to add. wait_sync: Whether or not the handler should be awaited before the next sync request. + sync_stream: The sync streams to listen to. Defaults to all. """ if not isinstance(event_type, (EventType, InternalEventType)): raise ValueError("Invalid event type") + props = EventHandlerProps(wait_sync=wait_sync, sync_stream=sync_stream) if event_type == EventType.ALL: - self.global_event_handlers.append((handler, wait_sync)) + self.global_event_handlers[handler] = props else: - self.event_handlers.setdefault(event_type, []).append((handler, wait_sync)) + self.event_handlers.setdefault(event_type, {})[handler] = props def remove_event_handler( self, event_type: EventType | InternalEventType, handler: EventHandler @@ -196,11 +204,7 @@ def remove_event_handler( # No handlers for this event type registered return - # FIXME this is a bit hacky - with suppress(ValueError): - handler_list.remove((handler, True)) - with suppress(ValueError): - handler_list.remove((handler, False)) + handler_list.pop(handler, None) if len(handler_list) == 0 and event_type != EventType.ALL: del self.event_handlers[event_type] @@ -228,7 +232,9 @@ def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asynci else: event.type = event.type.with_class(EventType.Class.MESSAGE) setattr(event, "source", source) - return self.dispatch_manual_event(event.type, event, include_global_handlers=True) + return self.dispatch_manual_event( + event.type, event, include_global_handlers=True, source=source + ) async def _catch_errors(self, handler: EventHandler, data: Any) -> None: try: @@ -242,15 +248,25 @@ def dispatch_manual_event( data: Any, include_global_handlers: bool = False, force_synchronous: bool = False, + source: Optional[SyncStream] = None, ) -> list[asyncio.Task]: - handlers = self.event_handlers.get(event_type, []) + handlers = self.event_handlers.get(event_type, {}).items() if include_global_handlers: - handlers = self.global_event_handlers + handlers + handlers = itertools.chain(self.global_event_handlers.items(), handlers) tasks = [] - for handler, wait_sync in handlers: - task = asyncio.create_task(self._catch_errors(handler, data)) - if force_synchronous or wait_sync: - tasks.append(task) + if source is None: + source = getattr(data, "source", None) + for handler, props in handlers: + if ( + props.sync_stream is not None + and source is not None + and not props.sync_stream & source + ): + continue + if force_synchronous or props.wait_sync: + tasks.append(asyncio.create_task(self._catch_errors(handler, data))) + else: + background_task.create(self._catch_errors(handler, data)) return tasks async def run_internal_event( @@ -261,6 +277,7 @@ async def run_internal_event( event_type, custom_type if custom_type is not None else kwargs, include_global_handlers=False, + source=SyncStream.INTERNAL, ) await asyncio.gather(*tasks) @@ -272,6 +289,7 @@ def dispatch_internal_event( event_type, custom_type if custom_type is not None else kwargs, include_global_handlers=False, + source=SyncStream.INTERNAL, ) def _try_deserialize(self, type: Type[T], data: JSON) -> T | GenericEvent: @@ -340,16 +358,29 @@ def handle_sync(self, data: JSON) -> list[asyncio.Task]: self._try_deserialize(Event, raw_event), source=SyncStream.JOINED_ROOM | SyncStream.TIMELINE, ) + + for raw_event in room_data.get("ephemeral", {}).get("events", []): + raw_event["room_id"] = room_id + tasks += self.dispatch_event( + self._try_deserialize(EphemeralEvent, raw_event), + source=SyncStream.JOINED_ROOM | SyncStream.EPHEMERAL, + ) for room_id, room_data in rooms.get("invite", {}).items(): events: list[dict[str, JSON]] = room_data.get("invite_state", {}).get("events", []) for raw_event in events: raw_event["room_id"] = room_id - raw_invite = next( - raw_event - for raw_event in events - if raw_event.get("type", "") == "m.room.member" - and raw_event.get("state_key", "") == self.mxid - ) + try: + raw_invite = next( + raw_event + for raw_event in events + if raw_event.get("type", "") == "m.room.member" + and raw_event.get("state_key", "") == self.mxid + ) + except StopIteration: + self.log.warning( + f"Corrupted invite section in sync: no invite event present for {room_id}" + ) + continue # These aren't required by the spec, so make sure they're set raw_invite.setdefault("event_id", None) raw_invite.setdefault("origin_server_ts", int(time.time() * 1000)) diff --git a/mautrix/crypto/__init__.py b/mautrix/crypto/__init__.py index 39867225..743fc2c6 100644 --- a/mautrix/crypto/__init__.py +++ b/mautrix/crypto/__init__.py @@ -1,6 +1,6 @@ from .account import OlmAccount from .key_share import RejectKeyShare -from .sessions import InboundGroupSession, OutboundGroupSession, Session +from .sessions import InboundGroupSession, OutboundGroupSession, RatchetSafety, Session # These have to be last from .store import ( # isort: skip diff --git a/mautrix/crypto/account.py b/mautrix/crypto/account.py index db508262..a00ada71 100644 --- a/mautrix/crypto/account.py +++ b/mautrix/crypto/account.py @@ -10,15 +10,17 @@ from mautrix.types import ( DeviceID, + DeviceKeys, EncryptionAlgorithm, EncryptionKeyAlgorithm, IdentityKey, + KeyID, SigningKey, UserID, ) -from . import base from .sessions import Session +from .signature import sign_olm class OlmAccount(olm.Account): @@ -74,19 +76,18 @@ def new_outbound_session(self, target_key: IdentityKey, one_time_key: IdentityKe session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now() ) - def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> Dict[str, Any]: - device_keys = { - "user_id": user_id, - "device_id": device_id, - "algorithms": [EncryptionAlgorithm.OLM_V1.value, EncryptionAlgorithm.MEGOLM_V1.value], - "keys": { - f"{algorithm}:{device_id}": key for algorithm, key in self.identity_keys.items() + def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> DeviceKeys: + device_keys = DeviceKeys( + user_id=user_id, + device_id=device_id, + algorithms=[EncryptionAlgorithm.OLM_V1, EncryptionAlgorithm.MEGOLM_V1], + keys={ + KeyID(algorithm=EncryptionKeyAlgorithm(algorithm), key_id=device_id): key + for algorithm, key in self.identity_keys.items() }, - } - signature = self.sign(base.canonical_json(device_keys)) - device_keys["signatures"] = { - user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature} - } + signatures={}, + ) + device_keys.signatures[user_id] = {KeyID.ed25519(device_id): sign_olm(device_keys, self)} return device_keys def get_one_time_keys( @@ -97,12 +98,12 @@ def get_one_time_keys( self.generate_one_time_keys(new_count) keys = {} for key_id, key in self.one_time_keys.get("curve25519", {}).items(): - signature = self.sign(base.canonical_json({"key": key})) - keys[f"{EncryptionKeyAlgorithm.SIGNED_CURVE25519}:{key_id}"] = { + keys[str(KeyID.signed_curve25519(IdentityKey(key_id)))] = { "key": key, "signatures": { - user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature} + user_id: { + str(KeyID.ed25519(device_id)): sign_olm({"key": key}, self), + } }, } - self.mark_keys_as_published() return keys diff --git a/mautrix/crypto/base.py b/mautrix/crypto/base.py index 4d12e6cf..f7015ca1 100644 --- a/mautrix/crypto/base.py +++ b/mautrix/crypto/base.py @@ -1,41 +1,34 @@ -# Copyright (c) 2022 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any, Awaitable, Callable, TypedDict +from typing import Awaitable, Callable import asyncio -import functools -import json - -import olm +from mautrix.errors import MForbidden, MNotFound from mautrix.types import ( - DeviceID, - EncryptionKeyAlgorithm, + EventType, IdentityKey, - KeyID, RequestedKeyInfo, + RoomEncryptionStateEventContent, RoomID, + RoomKeyEventContent, SessionID, - SigningKey, TrustState, UserID, ) from mautrix.util.logging import TraceLogger from .. import client as cli, crypto - - -class SignedObject(TypedDict): - signatures: dict[UserID, dict[str, str]] - unsigned: Any +from .ssss import Machine as SSSSMachine class BaseOlmMachine: client: cli.Client + ssss: SSSSMachine log: TraceLogger crypto_store: crypto.CryptoStore state_store: crypto.StateStore @@ -46,6 +39,14 @@ class BaseOlmMachine: share_keys_min_trust: TrustState allow_key_share: Callable[[crypto.DeviceIdentity, RequestedKeyInfo], Awaitable[bool]] + delete_outbound_keys_on_ack: bool + dont_store_outbound_keys: bool + delete_previous_keys_on_receive: bool + ratchet_keys_on_decrypt: bool + delete_fully_used_keys_on_decrypt: bool + delete_keys_on_device_delete: bool + disable_device_change_key_rotation: bool + # Futures that wait for responses to a key request _key_request_waiters: dict[SessionID, asyncio.Future] # Futures that wait for a session to be received (either normally or through a key request) @@ -53,6 +54,9 @@ class BaseOlmMachine: _prev_unwedge: dict[IdentityKey, float] _fetch_keys_lock: asyncio.Lock + _megolm_decrypt_lock: asyncio.Lock + _share_keys_lock: asyncio.Lock + _last_key_share: float _cs_fetch_attempted: set[UserID] async def wait_for_session( @@ -74,26 +78,30 @@ def _mark_session_received(self, session_id: SessionID) -> None: except KeyError: return - -canonical_json = functools.partial( - json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True -) - - -def verify_signature_json( - data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey -) -> bool: - data_copy = {**data} - data_copy.pop("unsigned", None) - signatures = data_copy.pop("signatures") - key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name)) - try: - signature = signatures[user_id][key_id] - except KeyError: - return False - signed_data = canonical_json(data_copy) - try: - olm.ed25519_verify(key, signed_data, signature) - return True - except olm.OlmVerifyError: - return False + async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None: + encryption_info = await self.state_store.get_encryption_info(evt.room_id) + if not encryption_info: + self.log.warning( + f"Encryption info for {evt.room_id} not found in state store, fetching from server" + ) + try: + encryption_info = await self.client.get_state_event( + evt.room_id, EventType.ROOM_ENCRYPTION + ) + except (MNotFound, MForbidden) as e: + self.log.warning( + f"Failed to get encryption info for {evt.room_id} from server: {e}," + " using defaults" + ) + encryption_info = RoomEncryptionStateEventContent() + if not encryption_info: + self.log.warning( + f"Didn't find encryption info for {evt.room_id} on server either," + " using defaults" + ) + encryption_info = RoomEncryptionStateEventContent() + + if not evt.beeper_max_age_ms: + evt.beeper_max_age_ms = encryption_info.rotation_period_ms + if not evt.beeper_max_messages: + evt.beeper_max_messages = encryption_info.rotation_period_msgs diff --git a/mautrix/crypto/cross_signing.py b/mautrix/crypto/cross_signing.py new file mode 100644 index 00000000..7577cd5d --- /dev/null +++ b/mautrix/crypto/cross_signing.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from ..types import ( + JSON, + CrossSigner, + CrossSigningKeys, + CrossSigningUsage, + DeviceIdentity, + EventType, + KeyID, + UserID, +) +from .cross_signing_key import CrossSigningPrivateKeys, CrossSigningPublicKeys, CrossSigningSeeds +from .device_lists import DeviceListMachine +from .signature import sign_olm +from .ssss import Key as SSSSKey + + +class CrossSigningMachine(DeviceListMachine): + _cross_signing_public_keys: CrossSigningPublicKeys | None + _cross_signing_public_keys_fetched: bool + _cross_signing_private_keys: CrossSigningPrivateKeys | None + + async def verify_with_recovery_key(self, recovery_key: str) -> None: + if not self.account.shared: + raise ValueError("Device keys must be shared before verifying with recovery key") + key_id, key_data = await self.ssss.get_default_key_data() + ssss_key = key_data.verify_recovery_key(key_id, recovery_key) + seeds = await self._fetch_cross_signing_keys_from_ssss(ssss_key) + self._import_cross_signing_keys(seeds) + await self.sign_own_device(self.own_identity) + + def _import_cross_signing_keys(self, seeds: CrossSigningSeeds) -> None: + self._cross_signing_private_keys = seeds.to_keys() + self._cross_signing_public_keys = self._cross_signing_private_keys.public_keys + + async def generate_recovery_key( + self, passphrase: str | None = None, seeds: CrossSigningSeeds | None = None + ) -> str: + if not self.account.shared: + raise ValueError("Device keys must be shared before generating recovery key") + seeds = seeds or CrossSigningSeeds.generate() + ssss_key = await self.ssss.generate_and_upload_key(passphrase) + await self._upload_cross_signing_keys_to_ssss(ssss_key, seeds) + await self._publish_cross_signing_keys(seeds.to_keys()) + await self.ssss.set_default_key_id(ssss_key.id) + await self.sign_own_device(self.own_identity) + return ssss_key.recovery_key + + async def _fetch_cross_signing_keys_from_ssss(self, key: SSSSKey) -> CrossSigningSeeds: + return CrossSigningSeeds( + master_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_MASTER, key + ), + user_signing_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_USER_SIGNING, key + ), + self_signing_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_SELF_SIGNING, key + ), + ) + + async def _upload_cross_signing_keys_to_ssss( + self, key: SSSSKey, seeds: CrossSigningSeeds + ) -> None: + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_MASTER, seeds.master_key, key + ) + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_USER_SIGNING, seeds.user_signing_key, key + ) + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_SELF_SIGNING, seeds.self_signing_key, key + ) + + async def get_own_cross_signing_public_keys(self) -> CrossSigningPublicKeys | None: + if self._cross_signing_public_keys or self._cross_signing_public_keys_fetched: + return self._cross_signing_public_keys + keys = await self.get_cross_signing_public_keys(self.client.mxid) + self._cross_signing_public_keys_fetched = True + if keys: + self._cross_signing_public_keys = keys + return keys + + async def get_cross_signing_public_keys( + self, user_id: UserID + ) -> CrossSigningPublicKeys | None: + db_keys = await self.crypto_store.get_cross_signing_keys(user_id) + if CrossSigningUsage.MASTER not in db_keys and user_id not in self._cs_fetch_attempted: + self.log.debug(f"Didn't find any cross-signing keys for {user_id}, fetching...") + async with self._fetch_keys_lock: + if user_id not in self._cs_fetch_attempted: + self._cs_fetch_attempted.add(user_id) + await self._fetch_keys([user_id], include_untracked=True) + db_keys = await self.crypto_store.get_cross_signing_keys(user_id) + if CrossSigningUsage.MASTER not in db_keys: + return None + return CrossSigningPublicKeys( + master_key=db_keys[CrossSigningUsage.MASTER].key, + self_signing_key=( + db_keys[CrossSigningUsage.SELF].key if CrossSigningUsage.SELF in db_keys else None + ), + user_signing_key=( + db_keys[CrossSigningUsage.USER].key if CrossSigningUsage.USER in db_keys else None + ), + ) + + async def sign_own_device(self, device: DeviceIdentity) -> None: + full_keys = await self._get_full_device_keys(device) + ssk = self._cross_signing_private_keys.self_signing_key + signature = sign_olm(full_keys, ssk) + full_keys.signatures = {self.client.mxid: {KeyID.ed25519(ssk.public_key): signature}} + await self.client.upload_one_signature(device.user_id, device.device_id, full_keys) + await self.crypto_store.put_signature( + CrossSigner(device.user_id, device.signing_key), + CrossSigner(self.client.mxid, ssk.public_key), + signature, + ) + + async def _publish_cross_signing_keys( + self, + keys: CrossSigningPrivateKeys, + auth: dict[str, JSON] | None = None, + ) -> None: + public = keys.public_keys + master_key = CrossSigningKeys( + user_id=self.client.mxid, + usage=[CrossSigningUsage.MASTER], + keys={KeyID.ed25519(public.master_key): public.master_key}, + ) + master_key.signatures = { + self.client.mxid: { + KeyID.ed25519(self.client.device_id): sign_olm(master_key, self.account), + } + } + self_key = CrossSigningKeys( + user_id=self.client.mxid, + usage=[CrossSigningUsage.SELF], + keys={KeyID.ed25519(public.self_signing_key): public.self_signing_key}, + ) + self_key.signatures = { + self.client.mxid: { + KeyID.ed25519(public.master_key): sign_olm(self_key, keys.master_key), + } + } + user_key = CrossSigningKeys( + user_id=self.client.mxid, + usage=[CrossSigningUsage.USER], + keys={KeyID.ed25519(public.user_signing_key): public.user_signing_key}, + ) + user_key.signatures = { + self.client.mxid: { + KeyID.ed25519(public.master_key): sign_olm(user_key, keys.master_key), + } + } + await self.client.upload_cross_signing_keys( + keys={ + CrossSigningUsage.MASTER: master_key, + CrossSigningUsage.SELF: self_key, + CrossSigningUsage.USER: user_key, + }, + auth=auth, + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.MASTER, public.master_key + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.SELF, public.self_signing_key + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.USER, public.user_signing_key + ) + self._cross_signing_private_keys = keys + self._cross_signing_public_keys = public diff --git a/mautrix/crypto/cross_signing_key.py b/mautrix/crypto/cross_signing_key.py new file mode 100644 index 00000000..f4e3c1c1 --- /dev/null +++ b/mautrix/crypto/cross_signing_key.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import NamedTuple + +import olm + +from mautrix.crypto.ssss.util import cryptorand +from mautrix.types import SigningKey + + +class CrossSigningPublicKeys(NamedTuple): + master_key: SigningKey + self_signing_key: SigningKey + user_signing_key: SigningKey + + +class CrossSigningPrivateKeys(NamedTuple): + master_key: olm.PkSigning + self_signing_key: olm.PkSigning + user_signing_key: olm.PkSigning + + @property + def public_keys(self) -> CrossSigningPublicKeys: + return CrossSigningPublicKeys( + master_key=self.master_key.public_key, + self_signing_key=self.self_signing_key.public_key, + user_signing_key=self.user_signing_key.public_key, + ) + + +class CrossSigningSeeds(NamedTuple): + master_key: bytes + self_signing_key: bytes + user_signing_key: bytes + + def to_keys(self) -> CrossSigningPrivateKeys: + return CrossSigningPrivateKeys( + master_key=olm.PkSigning(self.master_key), + self_signing_key=olm.PkSigning(self.self_signing_key), + user_signing_key=olm.PkSigning(self.user_signing_key), + ) + + @classmethod + def generate(cls) -> "CrossSigningSeeds": + return cls( + master_key=cryptorand.read(32), + self_signing_key=cryptorand.read(32), + user_signing_key=cryptorand.read(32), + ) diff --git a/mautrix/crypto/decrypt_megolm.py b/mautrix/crypto/decrypt_megolm.py index 8edf5aaf..fd7a7eef 100644 --- a/mautrix/crypto/decrypt_megolm.py +++ b/mautrix/crypto/decrypt_megolm.py @@ -25,6 +25,7 @@ ) from .device_lists import DeviceListMachine +from .sessions import InboundGroupSession class MegolmDecryptionMachine(DeviceListMachine): @@ -45,18 +46,22 @@ async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event: raise DecryptionError("Unsupported event content class") elif evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1: raise DecryptionError("Unsupported event encryption algorithm") - session = await self.crypto_store.get_group_session(evt.room_id, evt.content.session_id) - if session is None: - # TODO check if olm session is wedged - raise SessionNotFound(evt.content.session_id, evt.content.sender_key) - try: - plaintext, index = session.decrypt(evt.content.ciphertext) - except olm.OlmGroupSessionError as e: - raise DecryptionError("Failed to decrypt megolm event") from e - if not await self.crypto_store.validate_message_index( - session.sender_key, SessionID(session.id), evt.event_id, index, evt.timestamp - ): - raise DuplicateMessageIndex() + async with self._megolm_decrypt_lock: + session = await self.crypto_store.get_group_session( + evt.room_id, evt.content.session_id + ) + if session is None: + # TODO check if olm session is wedged + raise SessionNotFound(evt.content.session_id, evt.content.sender_key) + try: + plaintext, index = session.decrypt(evt.content.ciphertext) + except olm.OlmGroupSessionError as e: + raise DecryptionError("Failed to decrypt megolm event") from e + if not await self.crypto_store.validate_message_index( + session.sender_key, SessionID(session.id), evt.event_id, index, evt.timestamp + ): + raise DuplicateMessageIndex() + await self._ratchet_session(session, index) forwarded_keys = False if ( @@ -133,3 +138,55 @@ async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event: "was_encrypted": True, } return result + + async def _ratchet_session(self, sess: InboundGroupSession, index: int) -> None: + expected_message_index = sess.ratchet_safety.next_index + did_modify = True + if index > expected_message_index: + sess.ratchet_safety.missed_indices += list(range(expected_message_index, index)) + sess.ratchet_safety.next_index = index + 1 + elif index == expected_message_index: + sess.ratchet_safety.next_index = index + 1 + else: + try: + sess.ratchet_safety.missed_indices.remove(index) + except ValueError: + did_modify = False + # Use presence of received_at as a sign that this is a recent megolm session, + # and therefore it's safe to drop missed indices entirely. + if ( + sess.received_at + and sess.ratchet_safety.missed_indices + and sess.ratchet_safety.missed_indices[0] < expected_message_index - 10 + ): + i = 0 + for i, lost_index in enumerate(sess.ratchet_safety.missed_indices): + if lost_index < expected_message_index - 10: + sess.ratchet_safety.lost_indices.append(lost_index) + else: + break + sess.ratchet_safety.missed_indices = sess.ratchet_safety.missed_indices[i + 1 :] + ratchet_target_index = sess.ratchet_safety.next_index + if len(sess.ratchet_safety.missed_indices) > 0: + ratchet_target_index = min(sess.ratchet_safety.missed_indices) + self.log.debug( + f"Ratchet safety info for {sess.id}: {sess.ratchet_safety}, {ratchet_target_index=}" + ) + sess_id = SessionID(sess.id) + if ( + sess.max_messages + and ratchet_target_index >= sess.max_messages + and not sess.ratchet_safety.missed_indices + and self.delete_fully_used_keys_on_decrypt + ): + self.log.info(f"Deleting fully used session {sess.id}") + await self.crypto_store.redact_group_session( + sess.room_id, sess_id, reason="maximum messages reached" + ) + return + elif sess.first_known_index < ratchet_target_index and self.ratchet_keys_on_decrypt: + self.log.info(f"Ratcheting session {sess.id} to {ratchet_target_index}") + sess = sess.ratchet_to(ratchet_target_index) + elif not did_modify: + return + await self.crypto_store.put_group_session(sess.room_id, sess.sender_key, sess_id, sess) diff --git a/mautrix/crypto/decrypt_olm.py b/mautrix/crypto/decrypt_olm.py index 8182b160..6eef76f5 100644 --- a/mautrix/crypto/decrypt_olm.py +++ b/mautrix/crypto/decrypt_olm.py @@ -19,6 +19,7 @@ ToDeviceEvent, UserID, ) +from mautrix.util import background_task from .base import BaseOlmMachine from .sessions import Session @@ -74,19 +75,19 @@ async def _decrypt_olm_ciphertext( f"Found matching session yet decryption failed for sender {sender}" f" with key {sender_key}" ) - asyncio.create_task(self._unwedge_session(sender, sender_key)) + background_task.create(self._unwedge_session(sender, sender_key)) raise if not plaintext: if message.type != OlmMsgType.PREKEY: - asyncio.create_task(self._unwedge_session(sender, sender_key)) + background_task.create(self._unwedge_session(sender, sender_key)) raise DecryptionError("Decryption failed for normal message") self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}") try: session = await self._create_inbound_session(sender_key, message.body) except olm.OlmSessionError as e: - asyncio.create_task(self._unwedge_session(sender, sender_key)) + background_task.create(self._unwedge_session(sender, sender_key)) raise DecryptionError("Failed to create new session from prekey message") from e self.log.debug( f"Created inbound session {session.id} for {sender} (sender key: {sender_key})" diff --git a/mautrix/crypto/device_lists.py b/mautrix/crypto/device_lists.py index c0cea43f..b00b104c 100644 --- a/mautrix/crypto/device_lists.py +++ b/mautrix/crypto/device_lists.py @@ -23,10 +23,23 @@ UserID, ) -from .base import BaseOlmMachine, verify_signature_json +from .base import BaseOlmMachine +from .signature import verify_signature_json class DeviceListMachine(BaseOlmMachine): + @property + def own_identity(self) -> DeviceIdentity: + return DeviceIdentity( + user_id=self.client.mxid, + device_id=self.client.device_id, + identity_key=self.account.identity_key, + signing_key=self.account.signing_key, + trust=TrustState.VERIFIED, + deleted=False, + name="", + ) + async def _fetch_keys( self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False ) -> dict[UserID, dict[DeviceID, DeviceIdentity]]: @@ -46,51 +59,67 @@ async def _fetch_keys( data = {} for user_id, devices in resp.device_keys.items(): missing_users.remove(user_id) - - new_devices = {} - existing_devices = (await self.crypto_store.get_devices(user_id)) or {} - - self.log.trace( - f"Updating devices for {user_id}, got {len(devices)}, " - f"have {len(existing_devices)} in store" - ) - changed = False - ssks = resp.self_signing_keys.get(user_id) - ssk = ssks.first_ed25519_key if ssks else None - for device_id, device_keys in devices.items(): - try: - existing = existing_devices[device_id] - except KeyError: - existing = None - changed = True - self.log.trace(f"Validating device {device_keys} of {user_id}") - try: - new_device = await self._validate_device( - user_id, device_id, device_keys, existing - ) - except DeviceValidationError as e: - self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}") - else: - if new_device: - new_devices[device_id] = new_device - await self._store_device_self_signatures(device_keys, ssk) - self.log.debug( - f"Storing new device list for {user_id} containing {len(new_devices)} devices" - ) - await self.crypto_store.put_devices(user_id, new_devices) - data[user_id] = new_devices - - if changed or len(new_devices) != len(existing_devices): - await self.on_devices_changed(user_id) + async with self.crypto_store.transaction(): + data[user_id] = await self._process_fetched_keys(user_id, devices, resp) for user_id in missing_users: self.log.warning(f"Didn't get any devices for user {user_id}") - for user_id in users: - await self._store_cross_signing_keys(resp, user_id) - return data + async def _process_fetched_keys( + self, + user_id: UserID, + devices: dict[DeviceID, DeviceKeys], + resp: QueryKeysResponse, + ) -> dict[DeviceID, DeviceIdentity]: + new_devices = {} + existing_devices = (await self.crypto_store.get_devices(user_id)) or {} + + self.log.trace( + f"Updating devices for {user_id}, got {len(devices)}, " + f"have {len(existing_devices)} in store" + ) + changed = False + ssks = resp.self_signing_keys.get(user_id) + ssk = ssks.first_ed25519_key if ssks else None + for device_id, device_keys in devices.items(): + try: + existing = existing_devices[device_id] + except KeyError: + existing = None + changed = True + self.log.trace(f"Validating device {device_keys} of {user_id}") + try: + new_device = await self._validate_device(user_id, device_id, device_keys, existing) + except DeviceValidationError as e: + self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}") + else: + if new_device: + new_devices[device_id] = new_device + await self._store_device_self_signatures(device_keys, ssk) + self.log.debug( + f"Storing new device list for {user_id} containing {len(new_devices)} devices" + ) + await self.crypto_store.put_devices(user_id, new_devices) + + if changed or len(new_devices) != len(existing_devices): + if self.delete_keys_on_device_delete: + for device_id in existing_devices.keys() - new_devices.keys(): + device = existing_devices[device_id] + removed_ids = await self.crypto_store.redact_group_sessions( + room_id=None, sender_key=device.identity_key, reason="device removed" + ) + self.log.info( + "Redacted megolm sessions sent by removed device " + f"{device.user_id}/{device.device_id}: {removed_ids}" + ) + await self.on_devices_changed(user_id) + + await self._store_cross_signing_keys(resp, user_id) + + return new_devices + async def _store_device_self_signatures( self, device_keys: DeviceKeys, self_signing_key: SigningKey | None ) -> None: @@ -183,7 +212,7 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User signing_key = device.ed25519 except KeyError: pass - if len(signing_key) != 43: + if not signing_key or len(signing_key) != 43: self.log.debug( f"Cross-signing key {user_id}/{actual_key} has a signature from " f"an unknown key {key_id}" @@ -209,6 +238,12 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User else: self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}") + async def _get_full_device_keys(self, device: DeviceIdentity) -> DeviceKeys: + resp = await self.client.query_keys({device.user_id: [device.device_id]}) + keys = resp.device_keys[device.user_id][device.device_id] + await self._validate_device(device.user_id, device.device_id, keys, device) + return keys + async def get_or_fetch_device( self, user_id: UserID, device_id: DeviceID ) -> DeviceIdentity | None: @@ -234,6 +269,8 @@ async def get_or_fetch_device_by_key( return None async def on_devices_changed(self, user_id: UserID) -> None: + if self.disable_device_change_key_rotation: + return shared_rooms = await self.state_store.find_shared_rooms(user_id) self.log.debug( f"Devices of {user_id} changed, invalidating group session in {shared_rooms}" @@ -284,18 +321,23 @@ async def _validate_device( deleted=False, ) - async def resolve_trust(self, device: DeviceIdentity) -> TrustState: + async def resolve_trust(self, device: DeviceIdentity, allow_fetch: bool = True) -> TrustState: try: - return await self._try_resolve_trust(device) + return await self._try_resolve_trust(device, allow_fetch) except Exception: self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}") return TrustState.UNVERIFIED - async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState: - if device.trust in (TrustState.VERIFIED, TrustState.BLACKLISTED): + async def _try_resolve_trust( + self, device: DeviceIdentity, allow_fetch: bool = True + ) -> TrustState: + if device.device_id != self.client.device_id and device.trust in ( + TrustState.VERIFIED, + TrustState.BLACKLISTED, + ): return device.trust their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id) - if len(their_keys) == 0 and device.user_id not in self._cs_fetch_attempted: + if len(their_keys) == 0 and allow_fetch and device.user_id not in self._cs_fetch_attempted: self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...") async with self._fetch_keys_lock: if device.user_id not in self._cs_fetch_attempted: @@ -306,7 +348,8 @@ async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState: msk = their_keys[CrossSigningUsage.MASTER] ssk = their_keys[CrossSigningUsage.SELF] except KeyError as e: - self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}") + if allow_fetch: + self.log.warning(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}") return TrustState.UNVERIFIED ssk_signed = await self.crypto_store.is_key_signed_by( target=CrossSigner(device.user_id, ssk.key), diff --git a/mautrix/crypto/encrypt_megolm.py b/mautrix/crypto/encrypt_megolm.py index 9b37e486..459bbf1f 100644 --- a/mautrix/crypto/encrypt_megolm.py +++ b/mautrix/crypto/encrypt_megolm.py @@ -5,7 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Dict, List, Tuple, Union from collections import defaultdict -from datetime import timedelta +from datetime import datetime, timedelta import asyncio import json import time @@ -95,9 +95,9 @@ async def _encrypt_megolm_event( { "room_id": room_id, "type": event_type.serialize(), - "content": content.serialize() - if isinstance(content, Serializable) - else content, + "content": ( + content.serialize() if isinstance(content, Serializable) else content + ), } ) ) @@ -173,21 +173,6 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No session = await self._new_outbound_group_session(room_id) self.log.debug(f"Sharing group session {session.id} for room {room_id} with {users}") - encryption_info = await self.state_store.get_encryption_info(room_id) - if encryption_info: - if encryption_info.algorithm != EncryptionAlgorithm.MEGOLM_V1: - raise SessionShareError("Room encryption algorithm is not supported") - session.max_messages = encryption_info.rotation_period_msgs or session.max_messages - session.max_age = ( - timedelta(milliseconds=encryption_info.rotation_period_ms) - if encryption_info.rotation_period_ms - else session.max_age - ) - self.log.debug( - "Got stored encryption state event and configured session to rotate " - f"after {session.max_messages} messages or {session.max_age}" - ) - olm_sessions: DeviceMap = defaultdict(lambda: {}) withhold_key_msgs = defaultdict(lambda: {}) missing_sessions: Dict[UserID, Dict[DeviceID, DeviceIdentity]] = defaultdict(lambda: {}) @@ -253,13 +238,33 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No async def _new_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession: session = OutboundGroupSession(room_id) - await self._create_group_session( - self.account.identity_key, - self.account.signing_key, - room_id, - SessionID(session.id), - session.session_key, - ) + + encryption_info = await self.state_store.get_encryption_info(room_id) + if encryption_info: + if encryption_info.algorithm != EncryptionAlgorithm.MEGOLM_V1: + raise SessionShareError("Room encryption algorithm is not supported") + session.max_messages = encryption_info.rotation_period_msgs or session.max_messages + session.max_age = ( + timedelta(milliseconds=encryption_info.rotation_period_ms) + if encryption_info.rotation_period_ms + else session.max_age + ) + self.log.debug( + "Got stored encryption state event and configured session to rotate " + f"after {session.max_messages} messages or {session.max_age}" + ) + + if not self.dont_store_outbound_keys: + await self._create_group_session( + self.account.identity_key, + self.account.signing_key, + room_id, + SessionID(session.id), + session.session_key, + max_messages=session.max_messages, + max_age=session.max_age, + is_scheduled=False, + ) return session async def _encrypt_and_share_group_session( @@ -286,6 +291,9 @@ async def _create_group_session( room_id: RoomID, session_id: SessionID, session_key: str, + max_age: Union[timedelta, int], + max_messages: int, + is_scheduled: bool = False, ) -> None: start = time.monotonic() session = InboundGroupSession( @@ -293,6 +301,10 @@ async def _create_group_session( signing_key=signing_key, sender_key=sender_key, room_id=room_id, + received_at=datetime.utcnow(), + max_age=max_age, + max_messages=max_messages, + is_scheduled=is_scheduled, ) olm_duration = time.monotonic() - start if olm_duration > 5: @@ -302,7 +314,10 @@ async def _create_group_session( session_id = session.id await self.crypto_store.put_group_session(room_id, sender_key, session_id, session) self._mark_session_received(session_id) - self.log.debug(f"Created inbound group session {room_id}/{sender_key}/{session_id}") + self.log.debug( + f"Created inbound group session {room_id}/{sender_key}/{session_id} " + f"(max {max_age} / {max_messages} messages, {is_scheduled=})" + ) async def _find_olm_sessions( self, diff --git a/mautrix/crypto/encrypt_olm.py b/mautrix/crypto/encrypt_olm.py index 620cbbc8..029ad1e5 100644 --- a/mautrix/crypto/encrypt_olm.py +++ b/mautrix/crypto/encrypt_olm.py @@ -18,8 +18,9 @@ UserID, ) -from .base import BaseOlmMachine, verify_signature_json +from .base import BaseOlmMachine from .sessions import Session +from .signature import verify_signature_json ClaimKeysList = Dict[UserID, Dict[DeviceID, DeviceIdentity]] diff --git a/mautrix/crypto/key_request.py b/mautrix/crypto/key_request.py index 9bbd2d67..ecaf994d 100644 --- a/mautrix/crypto/key_request.py +++ b/mautrix/crypto/key_request.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -120,9 +120,18 @@ async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None: f"{evt.sender_device}, as crypto store says we have it already" ) return + if not key.beeper_max_messages or not key.beeper_max_age_ms: + await self._fill_encryption_info(key) key.forwarding_key_chain.append(evt.sender_key) sess = InboundGroupSession.import_session( - key.session_key, key.signing_key, key.sender_key, key.room_id, key.forwarding_key_chain + key.session_key, + key.signing_key, + key.sender_key, + key.room_id, + key.forwarding_key_chain, + max_age=key.beeper_max_age_ms, + max_messages=key.beeper_max_messages, + is_scheduled=key.beeper_is_scheduled, ) if key.session_id != sess.id: self.log.warning( diff --git a/mautrix/crypto/key_share.py b/mautrix/crypto/key_share.py index 58525bbf..8d6b6e73 100644 --- a/mautrix/crypto/key_share.py +++ b/mautrix/crypto/key_share.py @@ -76,7 +76,7 @@ async def default_allow_key_share( code=RoomKeyWithheldCode.BLACKLISTED, reason="You have been blacklisted by this device", ) - elif device.trust >= self.share_keys_min_trust: + elif await self.resolve_trust(device) >= self.share_keys_min_trust: self.log.debug(f"Accepting key request from trusted device {device.device_id}") return True else: diff --git a/mautrix/crypto/machine.py b/mautrix/crypto/machine.py index a1fd4e16..60c65677 100644 --- a/mautrix/crypto/machine.py +++ b/mautrix/crypto/machine.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,8 +8,10 @@ from typing import Optional import asyncio import logging +import time from mautrix import client as cli +from mautrix.errors import GroupSessionWithheldError from mautrix.types import ( ASToDeviceEvent, DecryptedOlmEvent, @@ -17,20 +19,25 @@ DeviceLists, DeviceOTKCount, EncryptionAlgorithm, + EncryptionKeyAlgorithm, EventType, + Member, Membership, StateEvent, ToDeviceEvent, TrustState, UserID, ) +from mautrix.util import background_task from mautrix.util.logging import TraceLogger from .account import OlmAccount +from .cross_signing import CrossSigningMachine from .decrypt_megolm import MegolmDecryptionMachine from .encrypt_megolm import MegolmEncryptionMachine from .key_request import KeyRequestingMachine from .key_share import KeySharingMachine +from .ssss import Machine as SSSSMachine from .store import CryptoStore, StateStore from .unwedge import OlmUnwedgingMachine @@ -41,6 +48,7 @@ class OlmMachine( OlmUnwedgingMachine, KeySharingMachine, KeyRequestingMachine, + CrossSigningMachine, ): """ OlmMachine is the main class for handling things related to Matrix end-to-end encryption with @@ -53,6 +61,7 @@ class OlmMachine( log: TraceLogger crypto_store: CryptoStore state_store: StateStore + ssss: SSSSMachine account: Optional[OlmAccount] @@ -65,6 +74,7 @@ def __init__( ) -> None: super().__init__() self.client = client + self.ssss = SSSSMachine(client) self.log = log or logging.getLogger("mau.crypto") self.crypto_store = crypto_store self.state_store = state_store @@ -74,18 +84,34 @@ def __init__( self.share_keys_min_trust = TrustState.CROSS_SIGNED_TOFU self.allow_key_share = self.default_allow_key_share + self.delete_outbound_keys_on_ack = False + self.dont_store_outbound_keys = False + self.delete_previous_keys_on_receive = False + self.ratchet_keys_on_decrypt = False + self.delete_fully_used_keys_on_decrypt = False + self.delete_keys_on_device_delete = False + self.disable_device_change_key_rotation = False + self._fetch_keys_lock = asyncio.Lock() + self._megolm_decrypt_lock = asyncio.Lock() + self._share_keys_lock = asyncio.Lock() + self._last_key_share = time.monotonic() - 60 self._key_request_waiters = {} self._inbound_session_waiters = {} self._prev_unwedge = {} self._cs_fetch_attempted = set() + self._cross_signing_public_keys = None + self._cross_signing_public_keys_fetched = False + self._cross_signing_private_keys = None + self.client.add_event_handler( cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True ) self.client.add_event_handler(cli.InternalEventType.DEVICE_LISTS, self.handle_device_lists) self.client.add_event_handler(EventType.TO_DEVICE_ENCRYPTED, self.handle_to_device_event) self.client.add_event_handler(EventType.ROOM_KEY_REQUEST, self.handle_room_key_request) + self.client.add_event_handler(EventType.BEEPER_ROOM_KEY_ACK, self.handle_beep_room_key_ack) # self.client.add_event_handler(EventType.ROOM_KEY_WITHHELD, self.handle_room_key_withheld) # self.client.add_event_handler(EventType.ORG_MATRIX_ROOM_KEY_WITHHELD, # self.handle_room_key_withheld) @@ -109,7 +135,7 @@ async def handle_as_otk_counts( self.log.warning(f"Got OTK count for unknown device {user_id}/{device_id}") async def handle_as_device_lists(self, device_lists: DeviceLists) -> None: - asyncio.create_task(self.handle_device_lists(device_lists)) + background_task.create(self.handle_device_lists(device_lists)) async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None: if evt.to_user_id != self.client.mxid or evt.to_device_id != self.client.device_id: @@ -121,6 +147,8 @@ async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None: await self.handle_to_device_event(evt) elif evt.type == EventType.ROOM_KEY_REQUEST: await self.handle_room_key_request(evt) + elif evt.type == EventType.BEEPER_ROOM_KEY_ACK: + await self.handle_beep_room_key_ack(evt) else: self.log.debug(f"Got unknown to-device event {evt.type} from {evt.sender}") @@ -170,9 +198,19 @@ async def handle_member_event(self, evt: StateEvent) -> None: } if prev == cur or ignored_changes.get(prev) == cur: return + src = getattr(evt, "source", None) + prev_cache = evt.unsigned.get("mautrix_prev_membership") + if isinstance(prev_cache, Member) and prev_cache.membership == cur: + self.log.debug( + f"Got duplicate membership state event in {evt.room_id} changing {evt.state_key} " + f"from {prev} to {cur}, cached state was {prev_cache} (event ID: {evt.event_id}, " + f"sync source: {src})" + ) + return self.log.debug( f"Got membership state event in {evt.room_id} changing {evt.state_key} from " - f"{prev} to {cur}, invalidating group session" + f"{prev} to {cur} (event ID: {evt.event_id}, sync source: {src}, " + f"cached: {prev_cache.membership if prev_cache else None}), invalidating group session" ) await self.crypto_store.remove_outbound_group_session(evt.room_id) @@ -184,6 +222,11 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None: passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you do syncing in some manual way. """ + if isinstance(evt, DecryptedOlmEvent): + self.log.warning( + f"Dropping unexpected nested encrypted to-device event from {evt.sender}" + ) + return self.log.trace( f"Handling encrypted to-device event from {evt.content.sender_key} ({evt.sender})" ) @@ -192,28 +235,81 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None: await self._receive_room_key(decrypted_evt) elif decrypted_evt.type == EventType.FORWARDED_ROOM_KEY: await self._receive_forwarded_room_key(decrypted_evt) + else: + self.client.dispatch_event(decrypted_evt, source=evt.source) async def _receive_room_key(self, evt: DecryptedOlmEvent) -> None: # TODO nio had a comment saying "handle this better" # for the case where evt.Keys.Ed25519 is none? if evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1 or not evt.keys.ed25519: return + if not evt.content.beeper_max_messages or not evt.content.beeper_max_age_ms: + await self._fill_encryption_info(evt.content) + if self.delete_previous_keys_on_receive and not evt.content.beeper_is_scheduled: + removed_ids = await self.crypto_store.redact_group_sessions( + evt.content.room_id, evt.sender_key, reason="received new key from device" + ) + self.log.info(f"Redacted previous megolm sessions: {removed_ids}") await self._create_group_session( evt.sender_key, evt.keys.ed25519, evt.content.room_id, evt.content.session_id, evt.content.session_key, + max_age=evt.content.beeper_max_age_ms, + max_messages=evt.content.beeper_max_messages, + is_scheduled=evt.content.beeper_is_scheduled, ) - async def share_keys(self, current_otk_count: int) -> None: + async def handle_beep_room_key_ack(self, evt: ToDeviceEvent) -> None: + try: + sess = await self.crypto_store.get_group_session( + evt.content.room_id, evt.content.session_id + ) + except GroupSessionWithheldError: + self.log.debug( + f"Ignoring room key ack for session {evt.content.session_id}" + " that was already redacted" + ) + return + if not sess: + self.log.debug(f"Ignoring room key ack for unknown session {evt.content.session_id}") + return + if ( + sess.sender_key == self.account.identity_key + and self.delete_outbound_keys_on_ack + and evt.content.first_message_index == 0 + ): + self.log.debug("Redacting inbound copy of outbound group session after ack") + await self.crypto_store.redact_group_session( + evt.content.room_id, evt.content.session_id, reason="outbound session acked" + ) + else: + self.log.debug(f"Received room key ack for {sess.id}") + + async def share_keys(self, current_otk_count: int | None = None) -> None: """ Share any keys that need to be shared. This is automatically called from :meth:`handle_otk_count`, so you should not need to call this yourself. Args: current_otk_count: The current number of signed curve25519 keys present on the server. + If omitted, the count will be fetched from the server. """ + async with self._share_keys_lock: + await self._share_keys(current_otk_count) + + async def _share_keys(self, current_otk_count: int | None) -> None: + if current_otk_count is None or ( + # If the last key share was recent and the new count is very low, re-check the count + # from the server to avoid any race conditions. + self._last_key_share + 60 > time.monotonic() + and current_otk_count < 10 + ): + self.log.debug("Checking OTK count on server") + current_otk_count = (await self.client.upload_keys()).get( + EncryptionKeyAlgorithm.SIGNED_CURVE25519, 0 + ) device_keys = ( self.account.get_device_keys(self.client.mxid, self.client.device_id) if not self.account.shared @@ -228,7 +324,9 @@ async def share_keys(self, current_otk_count: int) -> None: if device_keys: self.log.debug("Going to upload initial account keys") self.log.debug(f"Uploading {len(one_time_keys)} one-time keys") - await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys) + resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys) self.account.shared = True + self.account.mark_keys_as_published() + self._last_key_share = time.monotonic() await self.crypto_store.put_account(self.account) - self.log.debug("Shared keys and saved account") + self.log.debug(f"Shared keys and saved account, new keys: {resp}") diff --git a/mautrix/crypto/sessions.py b/mautrix/crypto/sessions.py index c8164b27..b0b16a18 100644 --- a/mautrix/crypto/sessions.py +++ b/mautrix/crypto/sessions.py @@ -3,10 +3,11 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import List, Optional, Set, Tuple, cast +from typing import List, Optional, Set, Tuple, Union, cast from datetime import datetime, timedelta from _libolm import ffi, lib +from attr import dataclass import olm from mautrix.errors import EncryptionError @@ -18,8 +19,10 @@ OlmMsgType, RoomID, RoomKeyEventContent, + SerializableAttrs, SigningKey, UserID, + field, ) @@ -93,12 +96,25 @@ def describe(self) -> str: return "describe not supported" +@dataclass +class RatchetSafety(SerializableAttrs): + next_index: int = 0 + missed_indices: List[int] = field(factory=lambda: []) + lost_indices: List[int] = field(factory=lambda: []) + + class InboundGroupSession(olm.InboundGroupSession): room_id: RoomID signing_key: SigningKey sender_key: IdentityKey forwarding_chain: List[IdentityKey] + ratchet_safety: RatchetSafety + received_at: datetime + max_age: timedelta + max_messages: int + is_scheduled: bool + def __init__( self, session_key: str, @@ -106,11 +122,23 @@ def __init__( sender_key: IdentityKey, room_id: RoomID, forwarding_chain: Optional[List[IdentityKey]] = None, + ratchet_safety: Optional[RatchetSafety] = None, + received_at: Optional[datetime] = None, + max_age: Union[timedelta, int, None] = None, + max_messages: Optional[int] = None, + is_scheduled: bool = False, ) -> None: self.signing_key = signing_key self.sender_key = sender_key self.room_id = room_id self.forwarding_chain = forwarding_chain or [] + self.ratchet_safety = ratchet_safety or RatchetSafety() + self.received_at = received_at or datetime.utcnow() + if isinstance(max_age, int): + max_age = timedelta(milliseconds=max_age) + self.max_age = max_age + self.max_messages = max_messages + self.is_scheduled = is_scheduled super().__init__(session_key) def __new__(cls, *args, **kwargs): @@ -125,12 +153,22 @@ def from_pickle( sender_key: IdentityKey, room_id: RoomID, forwarding_chain: Optional[List[IdentityKey]] = None, + ratchet_safety: Optional[RatchetSafety] = None, + received_at: Optional[datetime] = None, + max_age: Optional[timedelta] = None, + max_messages: Optional[int] = None, + is_scheduled: bool = False, ) -> "InboundGroupSession": session = super().from_pickle(pickle, passphrase) session.signing_key = signing_key session.sender_key = sender_key session.room_id = room_id session.forwarding_chain = forwarding_chain or [] + session.ratchet_safety = ratchet_safety or RatchetSafety() + session.received_at = received_at + session.max_age = max_age + session.max_messages = max_messages + session.is_scheduled = is_scheduled return session @classmethod @@ -141,14 +179,41 @@ def import_session( sender_key: IdentityKey, room_id: RoomID, forwarding_chain: Optional[List[str]] = None, + ratchet_safety: Optional[RatchetSafety] = None, + received_at: Optional[datetime] = None, + max_age: Union[timedelta, int, None] = None, + max_messages: Optional[int] = None, + is_scheduled: bool = False, ) -> "InboundGroupSession": session = super().import_session(session_key) session.signing_key = signing_key session.sender_key = sender_key session.room_id = room_id session.forwarding_chain = forwarding_chain or [] + session.ratchet_safety = ratchet_safety or RatchetSafety() + session.received_at = received_at or datetime.utcnow() + if isinstance(max_age, int): + max_age = timedelta(milliseconds=max_age) + session.max_age = max_age + session.max_messages = max_messages + session.is_scheduled = is_scheduled return session + def ratchet_to(self, index: int) -> "InboundGroupSession": + exported = self.export_session(index) + return self.import_session( + exported, + signing_key=self.signing_key, + sender_key=self.sender_key, + room_id=self.room_id, + forwarding_chain=self.forwarding_chain, + ratchet_safety=self.ratchet_safety, + received_at=self.received_at, + max_age=self.max_age, + max_messages=self.max_messages, + is_scheduled=self.is_scheduled, + ) + class OutboundGroupSession(olm.OutboundGroupSession): """Outbound group session aware of the users it is shared with. diff --git a/mautrix/crypto/signature.py b/mautrix/crypto/signature.py new file mode 100644 index 00000000..6dc13e65 --- /dev/null +++ b/mautrix/crypto/signature.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Any, TypedDict +import functools +import json + +import olm +import unpaddedbase64 + +from mautrix.types import ( + JSON, + DeviceID, + EncryptionKeyAlgorithm, + KeyID, + Serializable, + Signature, + SigningKey, + UserID, +) + +try: + from Crypto.PublicKey import ECC + from Crypto.Signature import eddsa +except ImportError: + from Cryptodome.PublicKey import ECC + from Cryptodome.Signature import eddsa + +canonical_json = functools.partial( + json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True +) + + +class SignedObject(TypedDict): + signatures: dict[UserID, dict[str, str]] + unsigned: Any + + +def sign_olm(data: dict[str, JSON] | Serializable, key: olm.PkSigning | olm.Account) -> Signature: + if isinstance(data, Serializable): + data = data.serialize() + data.pop("signatures", None) + data.pop("unsigned", None) + return Signature(key.sign(canonical_json(data))) + + +def verify_signature_json( + data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey +) -> bool: + data_copy = {**data} + data_copy.pop("unsigned", None) + signatures = data_copy.pop("signatures") + key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name)) + try: + signature = signatures[user_id][key_id] + decoded_key = unpaddedbase64.decode_base64(key) + # pycryptodome doesn't accept raw keys, so wrap it in a DER structure + der_key = b"\x30\x2a\x30\x05\x06\x03\x2b\x65\x70\x03\x21\x00" + decoded_key + decoded_signature = unpaddedbase64.decode_base64(signature) + parsed_key = ECC.import_key(der_key) + verifier = eddsa.new(parsed_key, "rfc8032") + verifier.verify(canonical_json(data_copy).encode("utf-8"), decoded_signature) + return True + except (KeyError, ValueError): + return False diff --git a/mautrix/crypto/signature_test.py b/mautrix/crypto/signature_test.py new file mode 100644 index 00000000..115e4836 --- /dev/null +++ b/mautrix/crypto/signature_test.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from mautrix.types import SigningKey, UserID + +from .signature import verify_signature_json + + +def test_verify_signature_json() -> None: + assert verify_signature_json( + # This is actually a federation PDU rather than a device signature, + # but they're both 25519 curves so it doesn't make a difference. + { + "auth_events": [ + "$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw", + "$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM", + "$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo", + ], + "content": {}, + "depth": 3212, + "hashes": {"sha256": "K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"}, + "origin_server_ts": 1754242687127, + "prev_events": ["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"], + "room_id": "!offtopic-2:continuwuity.org", + "sender": "@tulir:maunium.net", + "type": "m.room.message", + "signatures": { + UserID("maunium.net"): { + "ed25519:a_xxeS": "SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw" + } + }, + "unsigned": {"age_ts": 1754242687146}, + }, + UserID("maunium.net"), + "a_xxeS", + SigningKey("lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg"), + ) diff --git a/mautrix/crypto/ssss/__init__.py b/mautrix/crypto/ssss/__init__.py new file mode 100644 index 00000000..9224418d --- /dev/null +++ b/mautrix/crypto/ssss/__init__.py @@ -0,0 +1,8 @@ +from .key import Key, KeyMetadata, PassphraseMetadata +from .machine import Machine +from .types import ( + Algorithm, + EncryptedAccountDataEventContent, + EncryptedKeyData, + PassphraseAlgorithm, +) diff --git a/mautrix/crypto/ssss/key.py b/mautrix/crypto/ssss/key.py new file mode 100644 index 00000000..691ded71 --- /dev/null +++ b/mautrix/crypto/ssss/key.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Optional +import base64 +import hashlib +import hmac + +from attr import dataclass +import unpaddedbase64 + +from mautrix.types import EventType, SerializableAttrs + +from .types import Algorithm, EncryptedKeyData, PassphraseAlgorithm +from .util import ( + calculate_hash, + cryptorand, + decode_base58_recovery_key, + derive_keys, + encode_base58_recovery_key, + prepare_aes, +) + +try: + from Crypto.Cipher import AES + from Crypto.Util import Counter +except ImportError: + from Cryptodome.Cipher import AES + from Cryptodome.Util import Counter + + +@dataclass +class PassphraseMetadata(SerializableAttrs): + algorithm: PassphraseAlgorithm + iterations: int + salt: str + bits: int = 256 + + def get_key(self, passphrase: str) -> bytes: + if self.algorithm != PassphraseAlgorithm.PBKDF2: + raise ValueError(f"Unsupported passphrase algorithm {self.algorithm}") + return hashlib.pbkdf2_hmac( + "sha512", + passphrase.encode("utf-8"), + self.salt.encode("utf-8"), + self.iterations, + self.bits // 8, + ) + + +@dataclass +class KeyMetadata(SerializableAttrs): + algorithm: Algorithm + + iv: str | None = None + mac: str | None = None + + name: str | None = None + passphrase: Optional[PassphraseMetadata] = None + + def verify_passphrase(self, key_id: str, phrase: str) -> "Key": + if not self.passphrase: + raise ValueError("Passphrase not set on this key") + return self.verify_raw_key(key_id, self.passphrase.get_key(phrase)) + + def verify_recovery_key(self, key_id: str, recovery_key: str) -> "Key": + decoded_key = decode_base58_recovery_key(recovery_key) + if not decoded_key: + raise ValueError("Invalid recovery key syntax") + return self.verify_raw_key(key_id, decoded_key) + + def verify_raw_key(self, key_id: str, key: bytes) -> "Key": + if self.mac.rstrip("=") != calculate_hash(key, self.iv): + raise ValueError("Key MAC does not match") + return Key(id=key_id, key=key, metadata=self) + + +@dataclass +class Key: + id: str + key: bytes + metadata: KeyMetadata + + @classmethod + def generate(cls, passphrase: str | None = None) -> "Key": + passphrase_meta = ( + PassphraseMetadata( + algorithm=PassphraseAlgorithm.PBKDF2, + iterations=500_000, + salt=base64.b64encode(cryptorand.read(24)).decode("utf-8"), + bits=256, + ) + if passphrase + else None + ) + key = passphrase_meta.get_key(passphrase) if passphrase else cryptorand.read(32) + iv = unpaddedbase64.encode_base64(cryptorand.read(16)) + metadata = KeyMetadata( + algorithm=Algorithm.AES_HMAC_SHA2, + passphrase=passphrase_meta, + mac=calculate_hash(key, iv), + iv=iv, + ) + key_id = unpaddedbase64.encode_base64(cryptorand.read(24)) + return cls(key=key, id=key_id, metadata=metadata) + + @property + def recovery_key(self) -> str: + return encode_base58_recovery_key(self.key) + + def encrypt(self, event_type: str | EventType, data: str | bytes) -> EncryptedKeyData: + if isinstance(data, str): + data = data.encode("utf-8") + data = base64.b64encode(data).rstrip(b"=") + + aes_key, hmac_key = derive_keys(self.key, event_type) + iv = bytearray(cryptorand.read(16)) + iv[8] &= 0x7F + ciphertext = prepare_aes(aes_key, iv).encrypt(data) + digest = hmac.digest(hmac_key, ciphertext, hashlib.sha256) + return EncryptedKeyData( + ciphertext=unpaddedbase64.encode_base64(ciphertext), + iv=unpaddedbase64.encode_base64(iv), + mac=unpaddedbase64.encode_base64(digest), + ) + + def decrypt(self, event_type: str | EventType, data: EncryptedKeyData) -> bytes: + aes_key, hmac_key = derive_keys(self.key, event_type) + ciphertext = unpaddedbase64.decode_base64(data.ciphertext) + mac = unpaddedbase64.decode_base64(data.mac) + + expected_mac = hmac.digest(hmac_key, ciphertext, hashlib.sha256) + if not hmac.compare_digest(mac, expected_mac): + raise ValueError("Invalid MAC") + + plaintext = prepare_aes(aes_key, data.iv).decrypt(ciphertext) + return unpaddedbase64.decode_base64(plaintext.decode("utf-8")) diff --git a/mautrix/crypto/ssss/key_test.py b/mautrix/crypto/ssss/key_test.py new file mode 100644 index 00000000..e60f1e3c --- /dev/null +++ b/mautrix/crypto/ssss/key_test.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import pytest + +from ...types.event.type import EventType +from .key import Key, KeyMetadata +from .types import EncryptedAccountDataEventContent + +KEY1_CROSS_SIGNING_MASTER_KEY = """{ + "encrypted": { + "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0": { + "iv": "BpKP9nQJTE9jrsAssoxPqQ==", + "ciphertext": "fNRiiiidezjerTgV+G6pUtmeF3izzj5re/mVvY0hO2kM6kYGrxLuIu2ej80=", + "mac": "/gWGDGMyOLmbJp+aoSLh5JxCs0AdS6nAhjzpe+9G2Q0=" + } + } +}""" + +KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED = bytes( + [ + 0x68, + 0xF9, + 0x7F, + 0xD1, + 0x92, + 0x2E, + 0xEC, + 0xF6, + 0xB8, + 0x2B, + 0xB8, + 0x90, + 0xD2, + 0x4D, + 0x06, + 0x52, + 0x98, + 0x4E, + 0x7A, + 0x1D, + 0x70, + 0x3B, + 0x9E, + 0x86, + 0x7B, + 0x7E, + 0xBA, + 0xF7, + 0xFE, + 0xB9, + 0x5B, + 0x6F, + ] +) + +KEY1_META = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "passphrase": { + "algorithm": "m.pbkdf2", + "iterations": 500000, + "salt": "y863BOoqOadgDp8S3FtHXikDJEalsQ7d" + }, + "iv": "xxkTK0L4UzxgAFkQ6XPwsw", + "mac": "MEhooO0ZhFJNxUhvRMSxBnJfL20wkLgle3ocY0ee/eA" +}""" +KEY1_ID = "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0" +KEY1_RECOVERY_KEY = "EsTE s92N EtaX s2h6 VQYF 9Kao tHYL mkyL GKMh isZb KJ4E tvoC" +KEY1_PASSPHRASE = "correct horse battery staple" + +KEY2_META = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +}""" +KEY2_ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv" +KEY2_RECOVERY_KEY = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe" + +KEY2_META_BROKEN_IV = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +}""" + +KEY2_META_BROKEN_MAC = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow" +}""" + + +def get_key_meta(meta: str) -> KeyMetadata: + return KeyMetadata.parse_json(meta) + + +def get_key1() -> Key: + return get_key_meta(KEY1_META).verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY) + + +def get_key2() -> Key: + return get_key_meta(KEY2_META).verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def get_encrypted_master_key() -> EncryptedAccountDataEventContent: + return EncryptedAccountDataEventContent.parse_json(KEY1_CROSS_SIGNING_MASTER_KEY) + + +def test_decrypt_success() -> None: + key = get_key1() + emk = get_encrypted_master_key() + assert ( + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) == KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED + ) + + +def test_decrypt_fail_wrong_key() -> None: + key = get_key2() + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) + + +def test_decrypt_fail_fake_key() -> None: + key = get_key2() + key.id = KEY1_ID + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) + + +def test_decrypt_fail_wrong_type() -> None: + key = get_key1() + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_SELF_SIGNING, key) + + +def test_encrypt_roundtrip() -> None: + key = get_key1() + data = bytes([0xDE, 0xAD, 0xBE, 0xEF]) + ciphertext = key.encrypt("net.maunium.data", data) + plaintext = key.decrypt("net.maunium.data", ciphertext) + assert plaintext == data + + +def test_verify_recovery_key_correct() -> None: + meta = get_key_meta(KEY1_META) + key = meta.verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY) + assert key.recovery_key == KEY1_RECOVERY_KEY + + +def test_verify_recovery_key_correct2() -> None: + meta = get_key_meta(KEY2_META) + key = meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + assert key.recovery_key == KEY2_RECOVERY_KEY + + +def test_verify_recovery_key_invalid() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY1_ID, "foo") + + +def test_verify_recovery_key_incorrect() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_recovery_key_broken_iv() -> None: + meta = get_key_meta(KEY2_META_BROKEN_IV) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_recovery_key_broken_mac() -> None: + meta = get_key_meta(KEY2_META_BROKEN_MAC) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_passphrase_correct() -> None: + meta = get_key_meta(KEY1_META) + key = meta.verify_passphrase(KEY1_ID, KEY1_PASSPHRASE) + assert key.recovery_key == KEY1_RECOVERY_KEY + + +def test_verify_passphrase_incorrect() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_passphrase(KEY1_ID, "incorrect horse battery staple") + + +def test_verify_passphrase_notset() -> None: + meta = get_key_meta(KEY2_META) + with pytest.raises(ValueError): + meta.verify_passphrase(KEY2_ID, "hmm") diff --git a/mautrix/crypto/ssss/machine.py b/mautrix/crypto/ssss/machine.py new file mode 100644 index 00000000..c43e25e3 --- /dev/null +++ b/mautrix/crypto/ssss/machine.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from mautrix import client as cli +from mautrix.errors import MNotFound +from mautrix.types import EventType, SecretStorageDefaultKeyEventContent + +from .key import Key, KeyMetadata +from .types import EncryptedAccountDataEventContent + + +class Machine: + client: cli.Client + + def __init__(self, client: cli.Client) -> None: + self.client = client + + async def get_default_key_id(self) -> str | None: + try: + data = await self.client.get_account_data(EventType.SECRET_STORAGE_DEFAULT_KEY) + return SecretStorageDefaultKeyEventContent.deserialize(data).key + except (MNotFound, ValueError): + return None + + async def set_default_key_id(self, key_id: str) -> None: + await self.client.set_account_data( + EventType.SECRET_STORAGE_DEFAULT_KEY, + SecretStorageDefaultKeyEventContent(key=key_id), + ) + + async def get_key_data(self, key_id: str) -> KeyMetadata: + data = await self.client.get_account_data(f"m.secret_storage.key.{key_id}") + return KeyMetadata.deserialize(data) + + async def set_key_data(self, key_id: str, data: KeyMetadata) -> None: + await self.client.set_account_data(f"m.secret_storage.key.{key_id}", data) + + async def get_default_key_data(self) -> tuple[str, KeyMetadata]: + key_id = await self.get_default_key_id() + if not key_id: + raise ValueError("No default key ID set") + return key_id, await self.get_key_data(key_id) + + async def get_decrypted_account_data(self, event_type: EventType | str, key: Key) -> bytes: + data = await self.client.get_account_data(event_type) + parsed = EncryptedAccountDataEventContent.deserialize(data) + return parsed.decrypt(event_type, key) + + async def set_encrypted_account_data( + self, event_type: EventType | str, data: bytes, *keys: Key + ) -> None: + encrypted_data = {} + for key in keys: + encrypted_data[key.id] = key.encrypt(event_type, data) + await self.client.set_account_data( + event_type, + EncryptedAccountDataEventContent(encrypted=encrypted_data), + ) + + async def generate_and_upload_key(self, passphrase: str | None = None) -> Key: + key = Key.generate(passphrase) + await self.set_key_data(key.id, key.metadata) + return key diff --git a/mautrix/crypto/ssss/types.py b/mautrix/crypto/ssss/types.py new file mode 100644 index 00000000..4a47f743 --- /dev/null +++ b/mautrix/crypto/ssss/types.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import TYPE_CHECKING + +from attr import dataclass + +from mautrix.types import EventType, SerializableAttrs, SerializableEnum +from mautrix.types.event.account_data import account_data_event_content_map + +if TYPE_CHECKING: + from .key import Key + + +class Algorithm(SerializableEnum): + AES_HMAC_SHA2 = "m.secret_storage.v1.aes-hmac-sha2" + CURVE25519_AES_SHA2 = "m.secret_storage.v1.curve25519-aes-sha2" + + +class PassphraseAlgorithm(SerializableEnum): + PBKDF2 = "m.pbkdf2" + + +@dataclass +class EncryptedKeyData(SerializableAttrs): + ciphertext: str + iv: str + mac: str + + +@dataclass +class EncryptedAccountDataEventContent(SerializableAttrs): + encrypted: dict[str, EncryptedKeyData] + + def decrypt(self, event_type: str | EventType, key: "Key") -> bytes: + try: + encrypted_data = self.encrypted[key.id] + except KeyError as e: + raise ValueError(f"Event not encrypted for provided key") from e + return key.decrypt(event_type, encrypted_data) + + +for encrypted_account_data_type in ( + EventType.CROSS_SIGNING_MASTER, + EventType.CROSS_SIGNING_USER_SIGNING, + EventType.CROSS_SIGNING_SELF_SIGNING, + EventType.MEGOLM_BACKUP_V1, +): + account_data_event_content_map[encrypted_account_data_type] = EncryptedAccountDataEventContent diff --git a/mautrix/crypto/ssss/util.py b/mautrix/crypto/ssss/util.py new file mode 100644 index 00000000..b58c941a --- /dev/null +++ b/mautrix/crypto/ssss/util.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import hashlib +import hmac + +import base58 +import unpaddedbase64 + +from mautrix.types import EventType + +try: + from Crypto import Random + from Crypto.Cipher import AES + from Crypto.Hash import SHA256 + from Crypto.Protocol.KDF import HKDF + from Crypto.Util import Counter +except ImportError: + from Cryptodome import Random + from Cryptodome.Cipher import AES + from Cryptodome.Hash import SHA256 + from Cryptodome.Protocol.KDF import HKDF + from Cryptodome.Util import Counter + +cryptorand = Random.new() + + +def decode_base58_recovery_key(key: str) -> bytes | None: + key_bytes = base58.b58decode(key.replace(" ", "")) + if len(key_bytes) != 35 or key_bytes[0] != 0x8B or key_bytes[1] != 1: + return None + parity = 0 + for byte in key_bytes[:34]: + parity ^= byte + return key_bytes[2:34] if parity == key_bytes[34] else None + + +def encode_base58_recovery_key(key: bytes) -> str: + key_bytes = bytearray(35) + key_bytes[0] = 0x8B + key_bytes[1] = 1 + key_bytes[2:34] = key + parity = 0 + for byte in key_bytes: + parity ^= byte + key_bytes[34] = parity + encoded_key = base58.b58encode(key_bytes).decode("utf-8") + return " ".join(encoded_key[i : i + 4] for i in range(0, len(encoded_key), 4)) + + +def derive_keys(key: bytes, name: str | EventType = "") -> tuple[bytes, bytes]: + aes_key, hmac_key = HKDF( + master=key, + key_len=32, + salt=b"\x00" * 32, + hashmod=SHA256, + num_keys=2, + context=str(name).encode("utf-8"), + ) + return aes_key, hmac_key + + +def prepare_aes(key: bytes, iv: str | bytes) -> AES: + if isinstance(iv, str): + iv = unpaddedbase64.decode_base64(iv) + # initial_value = struct.unpack(">Q", iv[8:])[0] + # counter = Counter.new(64, prefix=iv[:8], initial_value=initial_value) + counter = Counter.new(128, initial_value=int.from_bytes(iv, byteorder="big")) + return AES.new(key=key, mode=AES.MODE_CTR, counter=counter) + + +def calculate_hash(key: bytes, iv: str | bytes) -> str: + aes_key, hmac_key = derive_keys(key) + cipher = prepare_aes(aes_key, iv).decrypt(b"\x00" * 32) + digest = hmac.digest(hmac_key, cipher, hashlib.sha256) + return unpaddedbase64.encode_base64(digest) diff --git a/mautrix/crypto/store/abstract.py b/mautrix/crypto/store/abstract.py index 7828524d..0916a2d7 100644 --- a/mautrix/crypto/store/abstract.py +++ b/mautrix/crypto/store/abstract.py @@ -5,8 +5,9 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import NamedTuple +from typing import AsyncContextManager, NamedTuple from abc import ABC, abstractmethod +from contextlib import asynccontextmanager from mautrix.types import ( CrossSigner, @@ -87,6 +88,11 @@ async def close(self) -> None: async def flush(self) -> None: """Flush the store. If all the methods persist data immediately, this can be a no-op.""" + @asynccontextmanager + async def transaction(self) -> None: + """Run a database transaction. If the store doesn't support transactions, this can be a no-op.""" + yield + @abstractmethod async def delete(self) -> None: """Delete the data in the store.""" @@ -197,6 +203,56 @@ async def get_group_session( The :class:`InboundGroupSession` object, or ``None`` if not found. """ + @abstractmethod + async def redact_group_session( + self, room_id: RoomID, session_id: SessionID, reason: str + ) -> None: + """ + Remove the keys for a specific Megolm group session. + + Args: + room_id: The room where the session is. + session_id: The session ID to remove. + reason: The reason the session is being removed. + """ + + @abstractmethod + async def redact_group_sessions( + self, room_id: RoomID | None, sender_key: IdentityKey | None, reason: str + ) -> list[SessionID]: + """ + Remove the keys for multiple Megolm group sessions, + based on the room ID and/or sender device. + + Args: + room_id: The room ID to delete keys from. + sender_key: The Olm identity key of the device to delete keys from. + reason: The reason why the keys are being deleted. + + Returns: + The list of session IDs that were deleted. + """ + + @abstractmethod + async def redact_expired_group_sessions(self) -> list[SessionID]: + """ + Remove all Megolm group sessions where at least twice the maximum age has passed since + receiving the keys. + + Returns: + The list of session IDs that were deleted. + """ + + @abstractmethod + async def redact_outdated_group_sessions(self) -> list[SessionID]: + """ + Remove all Megolm group sessions which lack the metadata to determine when they should + expire. + + Returns: + The list of session IDs that were deleted. + """ + @abstractmethod async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: """ diff --git a/mautrix/crypto/store/asyncpg/store.py b/mautrix/crypto/store/asyncpg/store.py index 8609e12d..bdc37ddd 100644 --- a/mautrix/crypto/store/asyncpg/store.py +++ b/mautrix/crypto/store/asyncpg/store.py @@ -6,12 +6,14 @@ from __future__ import annotations from collections import defaultdict +from contextlib import asynccontextmanager from datetime import timedelta from asyncpg import UniqueViolationError from mautrix.client.state_store import SyncStore from mautrix.client.state_store.asyncpg import PgStateStore +from mautrix.errors import GroupSessionWithheldError from mautrix.types import ( CrossSigner, CrossSigningUsage, @@ -20,6 +22,7 @@ EventID, IdentityKey, RoomID, + RoomKeyWithheldCode, SessionID, SigningKey, SyncToken, @@ -30,7 +33,7 @@ from mautrix.util.async_db import Database, Scheme from mautrix.util.logging import TraceLogger -from ... import InboundGroupSession, OlmAccount, OutboundGroupSession, Session +from ... import InboundGroupSession, OlmAccount, OutboundGroupSession, RatchetSafety, Session from ..abstract import CryptoStore, StateStore from .upgrade import upgrade_table @@ -77,6 +80,11 @@ def __init__(self, account_id: str, pickle_key: str, db: Database) -> None: self._account = None self._olm_cache = defaultdict(lambda: {}) + @asynccontextmanager + async def transaction(self) -> None: + async with self.db.acquire() as conn, conn.transaction(): + yield + async def delete(self) -> None: tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session") async with self.db.acquire() as conn, conn.transaction(): @@ -117,7 +125,7 @@ async def put_account(self, account: OlmAccount) -> None: await self.db.execute( q, self.account_id, - self._device_id, + self._device_id or "", account.shared, self._sync_token or "", pickle, @@ -234,8 +242,15 @@ async def put_group_session( forwarding_chains = ",".join(session.forwarding_chain) q = """ INSERT INTO crypto_megolm_inbound_session ( - session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7) + session_id, sender_key, signing_key, room_id, session, forwarding_chains, + ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ON CONFLICT (session_id, account_id) DO UPDATE + SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, + signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, + forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, + received_at=excluded.received_at, max_age=excluded.max_age, + max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled """ try: await self.db.execute( @@ -246,6 +261,11 @@ async def put_group_session( room_id, pickle, forwarding_chains, + session.ratchet_safety.json(), + session.received_at, + int(session.max_age.total_seconds() * 1000) if session.max_age else None, + session.max_messages, + session.is_scheduled, self.account_id, ) except (IntegrityError, UniqueViolationError): @@ -255,13 +275,17 @@ async def get_group_session( self, room_id: RoomID, session_id: SessionID ) -> InboundGroupSession | None: q = """ - SELECT sender_key, signing_key, session, forwarding_chains + SELECT + sender_key, signing_key, session, forwarding_chains, withheld_code, + ratchet_safety, received_at, max_age, max_messages, is_scheduled FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3 """ row = await self.db.fetchrow(q, room_id, session_id, self.account_id) if row is None: return None + if row["withheld_code"] is not None: + raise GroupSessionWithheldError(session_id, row["withheld_code"]) forwarding_chain = row["forwarding_chains"].split(",") if row["forwarding_chains"] else [] return InboundGroupSession.from_pickle( row["session"], @@ -270,21 +294,106 @@ async def get_group_session( sender_key=row["sender_key"], room_id=room_id, forwarding_chain=forwarding_chain, + ratchet_safety=RatchetSafety.parse_json(row["ratchet_safety"] or "{}"), + received_at=row["received_at"], + max_age=timedelta(milliseconds=row["max_age"]) if row["max_age"] else None, + max_messages=row["max_messages"], + is_scheduled=row["is_scheduled"], ) + async def redact_group_session( + self, room_id: RoomID, session_id: SessionID, reason: str + ) -> None: + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL + """ + await self.db.execute( + q, + RoomKeyWithheldCode.BEEPER_REDACTED.value, + f"Session redacted: {reason}", + session_id, + self.account_id, + ) + + async def redact_group_sessions( + self, room_id: RoomID, sender_key: IdentityKey, reason: str + ) -> list[SessionID]: + if not room_id and not sender_key: + raise ValueError("Either room_id or sender_key must be provided") + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5 + AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL + RETURNING session_id + """ + rows = await self.db.fetch( + q, + RoomKeyWithheldCode.BEEPER_REDACTED.value, + f"Session redacted: {reason}", + room_id, + sender_key, + self.account_id, + ) + return [row["session_id"] for row in rows] + + async def redact_expired_group_sessions(self) -> list[SessionID]: + if self.db.scheme == Scheme.SQLITE: + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false + AND received_at IS NOT NULL and max_age IS NOT NULL + AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now')) + RETURNING session_id + """ + elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false + AND received_at IS NOT NULL and max_age IS NOT NULL + AND received_at + 2 * (max_age * interval '1 millisecond') < now() + RETURNING session_id + """ + else: + raise RuntimeError(f"Unsupported dialect {self.db.scheme}") + rows = await self.db.fetch( + q, + RoomKeyWithheldCode.BEEPER_REDACTED.value, + f"Session redacted: expired", + self.account_id, + ) + return [row["session_id"] for row in rows] + + async def redact_outdated_group_sessions(self) -> list[SessionID]: + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL + RETURNING session_id + """ + rows = await self.db.fetch( + q, + RoomKeyWithheldCode.BEEPER_REDACTED.value, + f"Session redacted: outdated", + self.account_id, + ) + return [row["session_id"] for row in rows] + async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: q = """ SELECT COUNT(session) FROM crypto_megolm_inbound_session - WHERE room_id=$1 AND session_id=$2 AND account_id=$3 + WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL """ count = await self.db.fetchval(q, room_id, session_id, self.account_id) return count > 0 async def add_outbound_group_session(self, session: OutboundGroupSession) -> None: pickle = session.pickle(self.pickle_key) - max_age = session.max_age - if self.db.scheme == Scheme.SQLITE: - max_age = max_age.total_seconds() + max_age = int(session.max_age.total_seconds() * 1000) q = """ INSERT INTO crypto_megolm_outbound_session ( room_id, session_id, session, shared, max_messages, message_count, @@ -334,9 +443,6 @@ async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSess row = await self.db.fetchrow(q, room_id, self.account_id) if row is None: return None - max_age = row["max_age"] - if self.db.scheme == Scheme.SQLITE: - max_age = timedelta(seconds=max_age) return OutboundGroupSession.from_pickle( row["session"], passphrase=self.pickle_key, @@ -344,7 +450,7 @@ async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSess shared=row["shared"], max_messages=row["max_messages"], message_count=row["message_count"], - max_age=max_age, + max_age=timedelta(milliseconds=row["max_age"]), use_time=row["last_used"], creation_time=row["created_at"], ) diff --git a/mautrix/crypto/store/asyncpg/upgrade.py b/mautrix/crypto/store/asyncpg/upgrade.py index f7f50a74..8d413858 100644 --- a/mautrix/crypto/store/asyncpg/upgrade.py +++ b/mautrix/crypto/store/asyncpg/upgrade.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -16,34 +16,34 @@ ) -@upgrade_table.register(description="Latest revision", upgrades_to=5) -async def upgrade_blank_to_v4(conn: Connection) -> None: - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_account ( - account_id TEXT PRIMARY KEY, - device_id TEXT, - shared BOOLEAN NOT NULL, - sync_token TEXT NOT NULL, - account bytea NOT NULL - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_message_index ( +@upgrade_table.register(description="Latest revision", upgrades_to=10) +async def upgrade_blank_to_latest(conn: Connection) -> None: + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_account ( + account_id TEXT PRIMARY KEY, + device_id TEXT NOT NULL, + shared BOOLEAN NOT NULL, + sync_token TEXT NOT NULL, + account bytea NOT NULL + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_message_index ( sender_key CHAR(43), session_id CHAR(43), "index" INTEGER, event_id TEXT NOT NULL, timestamp BIGINT NOT NULL, PRIMARY KEY (sender_key, session_id, "index") - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_tracked_user ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_tracked_user ( user_id TEXT PRIMARY KEY - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_device ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_device ( user_id TEXT, device_id TEXT, identity_key CHAR(43) NOT NULL, @@ -52,10 +52,10 @@ async def upgrade_blank_to_v4(conn: Connection) -> None: deleted BOOLEAN NOT NULL, name TEXT NOT NULL, PRIMARY KEY (user_id, device_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_olm_session ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_olm_session ( account_id TEXT, session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -64,22 +64,29 @@ async def upgrade_blank_to_v4(conn: Connection) -> None: last_decrypted timestamp NOT NULL, last_encrypted timestamp NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( - account_id TEXT, - session_id CHAR(43), - sender_key CHAR(43) NOT NULL, - signing_key CHAR(43) NOT NULL, - room_id TEXT NOT NULL, - session bytea NOT NULL, - forwarding_chains TEXT NOT NULL, + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( + account_id TEXT, + session_id CHAR(43), + sender_key CHAR(43) NOT NULL, + signing_key CHAR(43), + room_id TEXT NOT NULL, + session bytea, + forwarding_chains TEXT, + withheld_code TEXT, + withheld_reason TEXT, + ratchet_safety jsonb, + received_at timestamp, + max_age BIGINT, + max_messages INTEGER, + is_scheduled BOOLEAN NOT NULL DEFAULT false, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, room_id TEXT, session_id CHAR(43) NOT NULL UNIQUE, @@ -87,31 +94,33 @@ async def upgrade_blank_to_v4(conn: Connection) -> None: shared BOOLEAN NOT NULL, max_messages INTEGER NOT NULL, message_count INTEGER NOT NULL, - max_age INTERVAL NOT NULL, + max_age BIGINT NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_cross_signing_keys ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_keys ( user_id TEXT, usage TEXT, - key CHAR(43), - first_seen_key CHAR(43), + key CHAR(43) NOT NULL, + + first_seen_key CHAR(43) NOT NULL, + PRIMARY KEY (user_id, usage) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_cross_signing_signatures ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_signatures ( signed_user_id TEXT, signed_key TEXT, signer_user_id TEXT, signer_key TEXT, - signature TEXT, + signature CHAR(88) NOT NULL, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) - )""" - ) + ) + """) @upgrade_table.register(description="Add account_id primary key column") @@ -121,17 +130,17 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: await conn.execute("DROP TABLE crypto_olm_session") await conn.execute("DROP TABLE crypto_megolm_inbound_session") await conn.execute("DROP TABLE crypto_megolm_outbound_session") - await conn.execute( - """CREATE TABLE crypto_account ( + await conn.execute(""" + CREATE TABLE crypto_account ( account_id VARCHAR(255) PRIMARY KEY, device_id VARCHAR(255) NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL - )""" - ) - await conn.execute( - """CREATE TABLE crypto_olm_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_olm_session ( account_id VARCHAR(255), session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -139,10 +148,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_megolm_inbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_megolm_inbound_session ( account_id VARCHAR(255), session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -151,10 +160,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: session bytea NOT NULL, forwarding_chains TEXT NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_megolm_outbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_megolm_outbound_session ( account_id VARCHAR(255), room_id VARCHAR(255), session_id CHAR(43) NOT NULL UNIQUE, @@ -162,12 +171,12 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: shared BOOLEAN NOT NULL, max_messages INTEGER NOT NULL, message_count INTEGER NOT NULL, - max_age INTERVAL NOT NULL, + max_age BIGINT NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) - )""" - ) + ) + """) else: async def add_account_id_column(table: str, pkey_columns: list[str]) -> None: @@ -224,25 +233,25 @@ async def upgrade_v4(conn: Connection, scheme: Scheme) -> None: @upgrade_table.register(description="Add cross-signing key and signature caches") async def upgrade_v5(conn: Connection) -> None: - await conn.execute( - """CREATE TABLE crypto_cross_signing_keys ( + await conn.execute(""" + CREATE TABLE crypto_cross_signing_keys ( user_id TEXT, usage TEXT, key CHAR(43), first_seen_key CHAR(43), PRIMARY KEY (user_id, usage) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_cross_signing_signatures ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_signatures ( signed_user_id TEXT, signed_key TEXT, signer_user_id TEXT, signer_key TEXT, signature TEXT, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) - )""" - ) + ) + """) @upgrade_table.register(description="Update trust state values") @@ -250,3 +259,177 @@ async def upgrade_v6(conn: Connection) -> None: await conn.execute("UPDATE crypto_device SET trust=300 WHERE trust=1") # verified await conn.execute("UPDATE crypto_device SET trust=-100 WHERE trust=2") # blacklisted await conn.execute("UPDATE crypto_device SET trust=0 WHERE trust=3") # ignored -> unset + + +@upgrade_table.register( + description="Synchronize schema with mautrix-go", upgrades_to=9, transaction=False +) +async def upgrade_v9(conn: Connection, scheme: Scheme) -> None: + if scheme == Scheme.POSTGRES: + async with conn.transaction(): + await upgrade_v9_postgres(conn) + else: + await upgrade_v9_sqlite(conn) + + +# These two are never used because the previous one jumps from 6 to 9. +@upgrade_table.register +async def upgrade_noop_7_to_8(_: Connection) -> None: + pass + + +@upgrade_table.register +async def upgrade_noop_8_to_9(_: Connection) -> None: + pass + + +async def upgrade_v9_postgres(conn: Connection) -> None: + await conn.execute("UPDATE crypto_account SET device_id='' WHERE device_id IS NULL") + await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id SET NOT NULL") + + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN signing_key DROP NOT NULL" + ) + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN session DROP NOT NULL" + ) + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN forwarding_chains DROP NOT NULL" + ) + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_code TEXT") + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_reason TEXT") + + await conn.execute("DELETE FROM crypto_cross_signing_keys WHERE key IS NULL") + await conn.execute( + "UPDATE crypto_cross_signing_keys SET first_seen_key=key WHERE first_seen_key IS NULL" + ) + await conn.execute("ALTER TABLE crypto_cross_signing_keys ALTER COLUMN key SET NOT NULL") + await conn.execute( + "ALTER TABLE crypto_cross_signing_keys ALTER COLUMN first_seen_key SET NOT NULL" + ) + + await conn.execute("DELETE FROM crypto_cross_signing_signatures WHERE signature IS NULL") + await conn.execute( + "ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature SET NOT NULL" + ) + + await conn.execute( + "ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN max_age TYPE BIGINT " + "USING (EXTRACT(EPOCH from max_age)*1000)::bigint" + ) + + +async def upgrade_v9_sqlite(conn: Connection) -> None: + await conn.execute("PRAGMA foreign_keys = OFF") + async with conn.transaction(): + await conn.execute(""" + CREATE TABLE new_crypto_account ( + account_id TEXT PRIMARY KEY, + device_id TEXT NOT NULL, + shared BOOLEAN NOT NULL, + sync_token TEXT NOT NULL, + account bytea NOT NULL + ) + """) + await conn.execute(""" + INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account) + SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account + FROM crypto_account + """) + await conn.execute("DROP TABLE crypto_account") + await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account") + + await conn.execute(""" + CREATE TABLE new_crypto_megolm_inbound_session ( + account_id TEXT, + session_id CHAR(43), + sender_key CHAR(43) NOT NULL, + signing_key CHAR(43), + room_id TEXT NOT NULL, + session bytea, + forwarding_chains TEXT, + withheld_code TEXT, + withheld_reason TEXT, + PRIMARY KEY (account_id, session_id) + ) + """) + await conn.execute(""" + INSERT INTO new_crypto_megolm_inbound_session ( + account_id, session_id, sender_key, signing_key, room_id, session, + forwarding_chains + ) + SELECT account_id, session_id, sender_key, signing_key, room_id, session, + forwarding_chains + FROM crypto_megolm_inbound_session + """) + await conn.execute("DROP TABLE crypto_megolm_inbound_session") + await conn.execute( + "ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session" + ) + + await conn.execute("UPDATE crypto_megolm_outbound_session SET max_age=max_age*1000") + + await conn.execute(""" + CREATE TABLE new_crypto_cross_signing_keys ( + user_id TEXT, + usage TEXT, + key CHAR(43) NOT NULL, + + first_seen_key CHAR(43) NOT NULL, + + PRIMARY KEY (user_id, usage) + ) + """) + await conn.execute(""" + INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key) + SELECT user_id, usage, key, COALESCE(first_seen_key, key) + FROM crypto_cross_signing_keys + WHERE key IS NOT NULL + """) + await conn.execute("DROP TABLE crypto_cross_signing_keys") + await conn.execute( + "ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys" + ) + + await conn.execute(""" + CREATE TABLE new_crypto_cross_signing_signatures ( + signed_user_id TEXT, + signed_key TEXT, + signer_user_id TEXT, + signer_key TEXT, + signature CHAR(88) NOT NULL, + PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) + ) + """) + await conn.execute(""" + INSERT INTO new_crypto_cross_signing_signatures ( + signed_user_id, signed_key, signer_user_id, signer_key, signature + ) + SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature + FROM crypto_cross_signing_signatures + WHERE signature IS NOT NULL + """) + await conn.execute("DROP TABLE crypto_cross_signing_signatures") + await conn.execute( + "ALTER TABLE new_crypto_cross_signing_signatures " + "RENAME TO crypto_cross_signing_signatures" + ) + + await conn.execute("PRAGMA foreign_key_check") + await conn.execute("PRAGMA foreign_keys = ON") + + +@upgrade_table.register( + description="Add metadata for detecting when megolm sessions are safe to delete" +) +async def upgrade_v10(conn: Connection) -> None: + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb") + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp" + ) + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT") + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER") + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session " + "ADD COLUMN is_scheduled BOOLEAN NOT NULL DEFAULT false" + ) diff --git a/mautrix/crypto/store/memory.py b/mautrix/crypto/store/memory.py index 35dc26b5..c26f86bc 100644 --- a/mautrix/crypto/store/memory.py +++ b/mautrix/crypto/store/memory.py @@ -110,6 +110,33 @@ async def get_group_session( ) -> InboundGroupSession: return self._inbound_sessions.get((room_id, session_id)) + async def redact_group_session( + self, room_id: RoomID, session_id: SessionID, reason: str + ) -> None: + self._inbound_sessions.pop((room_id, session_id), None) + + async def redact_group_sessions( + self, room_id: RoomID, sender_key: IdentityKey, reason: str + ) -> list[SessionID]: + if not room_id and not sender_key: + raise ValueError("Either room_id or sender_key must be provided") + deleted = [] + keys = list(self._inbound_sessions.keys()) + for key in keys: + item = self._inbound_sessions[key] + if (not room_id or item.room_id == room_id) and ( + not sender_key or item.sender_key == sender_key + ): + deleted.append(SessionID(item.id)) + del self._inbound_sessions[key] + return deleted + + async def redact_expired_group_sessions(self) -> list[SessionID]: + raise NotImplementedError() + + async def redact_outdated_group_sessions(self) -> list[SessionID]: + raise NotImplementedError() + async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: return (room_id, session_id) in self._inbound_sessions diff --git a/mautrix/crypto/store/tests/store_test.py b/mautrix/crypto/store/tests/store_test.py index 8d3fc851..949949b8 100644 --- a/mautrix/crypto/store/tests/store_test.py +++ b/mautrix/crypto/store/tests/store_test.py @@ -50,7 +50,7 @@ async def async_postgres_store() -> AsyncIterator[PgCryptoStore]: @asynccontextmanager async def async_sqlite_store() -> AsyncIterator[PgCryptoStore]: db = Database.create( - "sqlite:///:memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1} + "sqlite::memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1} ) store = PgCryptoStore("", "test", db) await db.start() diff --git a/mautrix/errors/__init__.py b/mautrix/errors/__init__.py index fdec6e3a..afe68dc9 100644 --- a/mautrix/errors/__init__.py +++ b/mautrix/errors/__init__.py @@ -6,6 +6,7 @@ DeviceValidationError, DuplicateMessageIndex, EncryptionError, + GroupSessionWithheldError, MatchingSessionDecryptionError, MismatchingRoomError, SessionNotFound, @@ -72,6 +73,7 @@ "DeviceValidationError", "DuplicateMessageIndex", "EncryptionError", + "GroupSessionWithheldError", "MatchingSessionDecryptionError", "MismatchingRoomError", "SessionNotFound", diff --git a/mautrix/errors/crypto.py b/mautrix/errors/crypto.py index 97592b05..4a65048c 100644 --- a/mautrix/errors/crypto.py +++ b/mautrix/errors/crypto.py @@ -36,6 +36,12 @@ class MatchingSessionDecryptionError(DecryptionError): pass +class GroupSessionWithheldError(DecryptionError): + def __init__(self, session_id: SessionID, withheld_code: str) -> None: + super().__init__(f"Session ID {session_id} was withheld ({withheld_code})") + self.withheld_code = withheld_code + + class SessionNotFound(DecryptionError): def __init__(self, session_id: SessionID, sender_key: IdentityKey | None = None) -> None: super().__init__( diff --git a/mautrix/types/__init__.py b/mautrix/types/__init__.py index 6d2ffa7c..ceceeaca 100644 --- a/mautrix/types/__init__.py +++ b/mautrix/types/__init__.py @@ -57,6 +57,7 @@ CallRejectEventContent, CallSelectAnswerEventContent, CanonicalAliasStateEventContent, + DirectAccountDataEventContent, EncryptedEvent, EncryptedEventContent, EncryptedFile, @@ -122,6 +123,7 @@ RoomTombstoneStateEventContent, RoomTopicStateEventContent, RoomType, + SecretStorageDefaultKeyEventContent, SingleReceiptEventContent, SpaceChildStateEventContent, SpaceParentStateEventContent, @@ -149,6 +151,7 @@ ) from .misc import ( BatchSendResponse, + BeeperBatchSendResponse, DeviceLists, DeviceOTKCount, DirectoryPaginationToken, @@ -258,6 +261,7 @@ "CallRejectEventContent", "CallSelectAnswerEventContent", "CanonicalAliasStateEventContent", + "DirectAccountDataEventContent", "EncryptedEvent", "EncryptedEventContent", "EncryptedFile", @@ -323,6 +327,7 @@ "RoomTombstoneStateEventContent", "RoomTopicStateEventContent", "RoomType", + "SecretStorageDefaultKeyEventContent", "SingleReceiptEventContent", "SpaceChildStateEventContent", "SpaceParentStateEventContent", @@ -353,6 +358,7 @@ "OpenGraphImage", "OpenGraphVideo", "BatchSendResponse", + "BeeperBatchSendResponse", "DeviceLists", "DeviceOTKCount", "DirectoryPaginationToken", diff --git a/mautrix/types/auth.py b/mautrix/types/auth.py index 198d2f53..ad582118 100644 --- a/mautrix/types/auth.py +++ b/mautrix/types/auth.py @@ -26,6 +26,8 @@ class LoginType(ExtensibleEnum): UNSTABLE_JWT: "LoginType" = "org.matrix.login.jwt" + DEVTURE_SHARED_SECRET: "LoginType" = "com.devture.shared_secret_auth" + @dataclass class LoginFlow(SerializableAttrs): diff --git a/mautrix/types/crypto.py b/mautrix/types/crypto.py index 821599de..fe4ab742 100644 --- a/mautrix/types/crypto.py +++ b/mautrix/types/crypto.py @@ -47,9 +47,9 @@ def curve25519(self) -> Optional[IdentityKey]: class CrossSigningUsage(ExtensibleEnum): - MASTER = "master" - SELF = "self_signing" - USER = "user_signing" + MASTER: "CrossSigningUsage" = "master" + SELF: "CrossSigningUsage" = "self_signing" + USER: "CrossSigningUsage" = "user_signing" @dataclass @@ -71,6 +71,8 @@ def first_ed25519_key(self) -> Optional[SigningKey]: return self.first_key_with_algorithm(EncryptionKeyAlgorithm.ED25519) def first_key_with_algorithm(self, alg: EncryptionKeyAlgorithm) -> Optional[SigningKey]: + if not self.keys: + return None try: return next(key for key_id, key in self.keys.items() if key_id.algorithm == alg) except StopIteration: diff --git a/mautrix/types/event/__init__.py b/mautrix/types/event/__init__.py index b391e912..db0658db 100644 --- a/mautrix/types/event/__init__.py +++ b/mautrix/types/event/__init__.py @@ -6,8 +6,10 @@ from .account_data import ( AccountDataEvent, AccountDataEventContent, + DirectAccountDataEventContent, RoomTagAccountDataEventContent, RoomTagInfo, + SecretStorageDefaultKeyEventContent, ) from .base import BaseEvent, BaseRoomEvent, BaseUnsigned, GenericEvent from .batch import BatchSendEvent, BatchSendStateEvent diff --git a/mautrix/types/event/account_data.py b/mautrix/types/event/account_data.py index 3c144a36..fdfa0a30 100644 --- a/mautrix/types/event/account_data.py +++ b/mautrix/types/event/account_data.py @@ -3,7 +3,7 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Union from attr import dataclass import attr @@ -12,6 +12,9 @@ from ..util import Obj, SerializableAttrs, deserializer from .base import BaseEvent, EventType +if TYPE_CHECKING: + from mautrix.crypto.ssss import EncryptedAccountDataEventContent, KeyMetadata + @dataclass class RoomTagInfo(SerializableAttrs): @@ -23,11 +26,24 @@ class RoomTagAccountDataEventContent(SerializableAttrs): tags: Dict[str, RoomTagInfo] = attr.ib(default=None, metadata={"json": "tags"}) +@dataclass +class SecretStorageDefaultKeyEventContent(SerializableAttrs): + key: str + + DirectAccountDataEventContent = Dict[UserID, List[RoomID]] -AccountDataEventContent = Union[RoomTagAccountDataEventContent, DirectAccountDataEventContent, Obj] +AccountDataEventContent = Union[ + RoomTagAccountDataEventContent, + DirectAccountDataEventContent, + SecretStorageDefaultKeyEventContent, + "EncryptedAccountDataEventContent", + "KeyMetadata", + Obj, +] account_data_event_content_map = { EventType.TAG: RoomTagAccountDataEventContent, + EventType.SECRET_STORAGE_DEFAULT_KEY: SecretStorageDefaultKeyEventContent, # m.direct doesn't really need deserializing # EventType.DIRECT: DirectAccountDataEventContent, } diff --git a/mautrix/types/event/beeper.py b/mautrix/types/event/beeper.py index dc588800..0ec16479 100644 --- a/mautrix/types/event/beeper.py +++ b/mautrix/types/event/beeper.py @@ -7,7 +7,7 @@ from attr import dataclass -from ..primitive import EventID +from ..primitive import EventID, RoomID, SessionID from ..util import SerializableAttrs, SerializableEnum, field from .base import BaseRoomEvent from .message import RelatesTo @@ -49,20 +49,16 @@ class BeeperMessageStatusEventContent(SerializableAttrs): error: Optional[str] = None message: Optional[str] = None - success: Optional[bool] = None - still_working: Optional[bool] = None - can_retry: Optional[bool] = None - is_certain: Optional[bool] = None - last_retry: Optional[EventID] = None - def fill_legacy_booleans(self) -> None: - self.success = self.status == MessageStatus.SUCCESS - if not self.success: - self.still_working = self.status == MessageStatus.PENDING - self.can_retry = self.status in (MessageStatus.PENDING, MessageStatus.RETRIABLE) - @dataclass class BeeperMessageStatusEvent(BaseRoomEvent, SerializableAttrs): content: BeeperMessageStatusEventContent + + +@dataclass +class BeeperRoomKeyAckEventContent(SerializableAttrs): + room_id: RoomID + session_id: SessionID + first_message_index: int diff --git a/mautrix/types/event/encrypted.py b/mautrix/types/event/encrypted.py index cd08fc94..735e481a 100644 --- a/mautrix/types/event/encrypted.py +++ b/mautrix/types/event/encrypted.py @@ -9,7 +9,7 @@ from attr import dataclass -from ..primitive import JSON, DeviceID, IdentityKey, SessionID +from ..primitive import JSON, DeviceID, IdentityKey, SessionID, SigningKey from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field from .base import BaseRoomEvent, BaseUnsigned from .message import RelatesTo @@ -43,6 +43,18 @@ def deserialize(cls, raw: JSON) -> "KeyID": def __str__(self) -> str: return f"{self.algorithm.value}:{self.key_id}" + @classmethod + def ed25519(cls, key_id: SigningKey | DeviceID) -> "KeyID": + return cls(EncryptionKeyAlgorithm.ED25519, key_id) + + @classmethod + def curve25519(cls, key_id: IdentityKey) -> "KeyID": + return cls(EncryptionKeyAlgorithm.CURVE25519, key_id) + + @classmethod + def signed_curve25519(cls, key_id: IdentityKey) -> "KeyID": + return cls(EncryptionKeyAlgorithm.SIGNED_CURVE25519, key_id) + class OlmMsgType(Serializable, IntEnum): PREKEY = 0 diff --git a/mautrix/types/event/message.py b/mautrix/types/event/message.py index eab4ba95..32033581 100644 --- a/mautrix/types/event/message.py +++ b/mautrix/types/event/message.py @@ -112,6 +112,12 @@ def set_thread_parent( self.relates_to.event_id = ( thread_parent if isinstance(thread_parent, str) else thread_parent.event_id ) + if isinstance(thread_parent, MessageEvent) and isinstance( + thread_parent.content, BaseMessageEventContentFuncs + ): + self.relates_to.event_id = ( + thread_parent.content.get_thread_parent() or self.relates_to.event_id + ) if not disable_reply_fallback: self.set_reply(last_event_in_thread or thread_parent, **kwargs) self.relates_to.is_falling_back = True @@ -265,33 +271,6 @@ class LocationInfo(SerializableAttrs): # region Event content -@dataclass -class MediaMessageEventContent(BaseMessageEventContent, SerializableAttrs): - """The content of a media message event (m.image, m.audio, m.video, m.file)""" - - url: Optional[ContentURI] = None - info: Optional[MediaInfo] = None - file: Optional[EncryptedFile] = None - - @staticmethod - @deserializer(MediaInfo) - @deserializer(Optional[MediaInfo]) - def deserialize_info(data: JSON) -> MediaInfo: - if not isinstance(data, dict): - return Obj() - msgtype = data.pop("__mautrix_msgtype", None) - if msgtype == "m.image" or msgtype == "m.sticker": - return ImageInfo.deserialize(data) - elif msgtype == "m.video": - return VideoInfo.deserialize(data) - elif msgtype == "m.audio": - return AudioInfo.deserialize(data) - elif msgtype == "m.file": - return FileInfo.deserialize(data) - else: - return Obj(**data) - - @dataclass class LocationMessageEventContent(BaseMessageEventContent, SerializableAttrs): geo_uri: str = None @@ -308,21 +287,6 @@ class TextMessageEventContent(BaseMessageEventContent, SerializableAttrs): format: Format = None formatted_body: str = None - def set_reply( - self, reply_to: Union["MessageEvent", EventID], *, displayname: Optional[str] = None - ) -> None: - super().set_reply(reply_to) - if isinstance(reply_to, str): - return - if isinstance(reply_to, MessageEvent): - self.ensure_has_html() - if isinstance(reply_to.content, TextMessageEventContent): - reply_to.content.trim_reply_fallback() - self.formatted_body = ( - reply_to.make_reply_fallback_html(displayname) + self.formatted_body - ) - self.body = reply_to.make_reply_fallback_text(displayname) + self.body - def ensure_has_html(self) -> None: if not self.formatted_body or self.format != Format.HTML: self.format = Format.HTML @@ -340,20 +304,48 @@ def trim_reply_fallback(self) -> None: setattr(self, "__reply_fallback_trimmed", True) def _trim_reply_fallback_text(self) -> None: - if not self.body.startswith("> ") or "\n" not in self.body: + if ( + not self.body.startswith("> <") and not self.body.startswith("> * <") + ) or "\n" not in self.body: return lines = self.body.split("\n") while len(lines) > 0 and lines[0].startswith("> "): lines.pop(0) - # Pop extra newline at end of fallback - lines.pop(0) - self.body = "\n".join(lines) + self.body = "\n".join(lines).strip() def _trim_reply_fallback_html(self) -> None: if self.formatted_body and self.format == Format.HTML: self.formatted_body = html_reply_fallback_regex.sub("", self.formatted_body) +@dataclass +class MediaMessageEventContent(TextMessageEventContent, SerializableAttrs): + """The content of a media message event (m.image, m.audio, m.video, m.file)""" + + url: Optional[ContentURI] = None + info: Optional[MediaInfo] = None + file: Optional[EncryptedFile] = None + filename: Optional[str] = None + + @staticmethod + @deserializer(MediaInfo) + @deserializer(Optional[MediaInfo]) + def deserialize_info(data: JSON) -> MediaInfo: + if not isinstance(data, dict): + return Obj() + msgtype = data.pop("__mautrix_msgtype", None) + if msgtype == "m.image" or msgtype == "m.sticker": + return ImageInfo.deserialize(data) + elif msgtype == "m.video": + return VideoInfo.deserialize(data) + elif msgtype == "m.audio": + return AudioInfo.deserialize(data) + elif msgtype == "m.file": + return FileInfo.deserialize(data) + else: + return Obj(**data) + + MessageEventContent = Union[ TextMessageEventContent, MediaMessageEventContent, LocationMessageEventContent, Obj ] @@ -413,36 +405,3 @@ def deserialize_content(data: JSON) -> MessageEventContent: return LocationMessageEventContent.deserialize(data) else: return Obj(**data) - - def make_reply_fallback_html(self, displayname: Optional[str] = None) -> str: - """Generate the HTML fallback for messages replying to this event.""" - if self.content.msgtype.is_text: - body = self.content.formatted_body or escape(self.content.body).replace("\n", "
") - else: - sent_type = media_reply_fallback_body_map[self.content.msgtype] or "a message" - body = f"sent {sent_type}" - displayname = escape(displayname) if displayname else self.sender - return html_reply_fallback_format.format( - room_id=self.room_id, - event_id=self.event_id, - sender=self.sender, - displayname=displayname, - content=body, - ) - - def make_reply_fallback_text(self, displayname: Optional[str] = None) -> str: - """Generate the plaintext fallback for messages replying to this event.""" - if self.content.msgtype.is_text: - body = self.content.body - else: - try: - body = media_reply_fallback_body_map[self.content.msgtype] - except KeyError: - body = "an unknown message type" - lines = body.strip().split("\n") - first_line, lines = lines[0], lines[1:] - fallback_text = f"> <{displayname or self.sender}> {first_line}" - for line in lines: - fallback_text += f"\n> {line}" - fallback_text += "\n\n" - return fallback_text diff --git a/mautrix/types/event/state.py b/mautrix/types/event/state.py index c6b351c6..5ffc855f 100644 --- a/mautrix/types/event/state.py +++ b/mautrix/types/event/state.py @@ -9,7 +9,7 @@ import attr from ..primitive import JSON, ContentURI, EventID, RoomAlias, RoomID, UserID -from ..util import Obj, SerializableAttrs, SerializableEnum, deserializer, field +from ..util import ExtensibleEnum, Obj, SerializableAttrs, SerializableEnum, deserializer, field from .base import BaseRoomEvent, BaseUnsigned from .encrypted import EncryptionAlgorithm from .type import EventType, RoomType @@ -41,7 +41,19 @@ class PowerLevelStateEventContent(SerializableAttrs): ban: int = 50 redact: int = 50 - def get_user_level(self, user_id: UserID) -> int: + def get_user_level( + self, + user_id: UserID, + create: Optional["StateEvent"] = None, + ) -> int: + if ( + create + and create.content.supports_creator_power + and (user_id == create.sender or user_id in (create.content.additional_creators or [])) + ): + # This is really meant to be infinity, but involving floats would be annoying, + # so we use an integer larger than the maximum power level (2^53-1) instead. + return 2**60 - 1 return int(self.users.get(user_id, self.users_default)) def set_user_level(self, user_id: UserID, level: int) -> None: @@ -50,7 +62,16 @@ def set_user_level(self, user_id: UserID, level: int) -> None: else: self.users[user_id] = level - def ensure_user_level(self, user_id: UserID, level: int) -> bool: + def ensure_user_level( + self, user_id: UserID, level: int, create: Optional["StateEvent"] = None + ) -> bool: + if ( + create + and create.content.supports_creator_power + and (user_id == create.sender or user_id in (create.content.additional_creators or [])) + ): + # Don't try to set creator power levels + return False if self.get_user_level(user_id) != level: self.set_user_level(user_id, level) return True @@ -138,16 +159,15 @@ class RoomAvatarStateEventContent(SerializableAttrs): url: Optional[ContentURI] = None -class JoinRule(SerializableEnum): +class JoinRule(ExtensibleEnum): PUBLIC = "public" KNOCK = "knock" RESTRICTED = "restricted" INVITE = "invite" - PRIVATE = "private" KNOCK_RESTRICTED = "knock_restricted" -class JoinRestrictionType(SerializableEnum): +class JoinRestrictionType(ExtensibleEnum): ROOM_MEMBERSHIP = "m.room_membership" @@ -193,6 +213,24 @@ class RoomCreateStateEventContent(SerializableAttrs): federate: bool = field(json="m.federate", omit_default=True, default=True) predecessor: Optional[RoomPredecessor] = None type: Optional[RoomType] = None + additional_creators: Optional[List[UserID]] = None + + @property + def supports_creator_power(self) -> bool: + return self.room_version not in ( + "", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + ) @dataclass diff --git a/mautrix/types/event/to_device.py b/mautrix/types/event/to_device.py index 6a17e36e..d6392177 100644 --- a/mautrix/types/event/to_device.py +++ b/mautrix/types/event/to_device.py @@ -9,8 +9,9 @@ import attr from ..primitive import JSON, DeviceID, IdentityKey, RoomID, SessionID, SigningKey, UserID -from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer +from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field from .base import BaseEvent, EventType +from .beeper import BeeperRoomKeyAckEventContent from .encrypted import EncryptedOlmEventContent, EncryptionAlgorithm @@ -21,6 +22,8 @@ class RoomKeyWithheldCode(ExtensibleEnum): UNAVAILABLE: "RoomKeyWithheldCode" = "m.unavailable" NO_OLM_SESSION: "RoomKeyWithheldCode" = "m.no_olm" + BEEPER_REDACTED: "RoomKeyWithheldCode" = "com.beeper.redacted" + @dataclass class RoomKeyWithheldEventContent(SerializableAttrs): @@ -39,6 +42,10 @@ class RoomKeyEventContent(SerializableAttrs): session_id: SessionID session_key: str + beeper_max_age_ms: Optional[int] = field(json="com.beeper.max_age_ms", default=None) + beeper_max_messages: Optional[int] = field(json="com.beeper.max_messages", default=None) + beeper_is_scheduled: Optional[bool] = field(json="com.beeper.is_scheduled", default=False) + class KeyRequestAction(ExtensibleEnum): REQUEST: "KeyRequestAction" = "request" @@ -61,7 +68,7 @@ class RoomKeyRequestEventContent(SerializableAttrs): body: Optional[RequestedKeyInfo] = None -@dataclass +@dataclass(kw_only=True) class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs): sender_key: IdentityKey signing_key: SigningKey = attr.ib(metadata={"json": "sender_claimed_ed25519_key"}) @@ -75,6 +82,7 @@ class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs): RoomKeyEventContent, RoomKeyRequestEventContent, ForwardedRoomKeyEventContent, + BeeperRoomKeyAckEventContent, ] to_device_event_content_map = { EventType.TO_DEVICE_ENCRYPTED: EncryptedOlmEventContent, @@ -82,12 +90,10 @@ class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs): EventType.ROOM_KEY_REQUEST: RoomKeyRequestEventContent, EventType.ROOM_KEY: RoomKeyEventContent, EventType.FORWARDED_ROOM_KEY: ForwardedRoomKeyEventContent, + EventType.BEEPER_ROOM_KEY_ACK: BeeperRoomKeyAckEventContent, } -# TODO remaining account data event types - - @dataclass class ToDeviceEvent(BaseEvent, SerializableAttrs): sender: UserID diff --git a/mautrix/types/event/type.py b/mautrix/types/event/type.py index 609f40ff..5faf3785 100644 --- a/mautrix/types/event/type.py +++ b/mautrix/types/event/type.py @@ -207,6 +207,11 @@ def is_to_device(self) -> bool: "m.push_rules": "PUSH_RULES", "m.tag": "TAG", "m.ignored_user_list": "IGNORED_USER_LIST", + "m.secret_storage.default_key": "SECRET_STORAGE_DEFAULT_KEY", + "m.cross_signing.master": "CROSS_SIGNING_MASTER", + "m.cross_signing.self_signing": "CROSS_SIGNING_SELF_SIGNING", + "m.cross_signing.user_signing": "CROSS_SIGNING_USER_SIGNING", + "m.megolm_backup.v1": "MEGOLM_BACKUP_V1", }, EventType.Class.TO_DEVICE: { "m.room.encrypted": "TO_DEVICE_ENCRYPTED", @@ -216,6 +221,7 @@ def is_to_device(self) -> bool: "m.room_key_request": "ROOM_KEY_REQUEST", "m.forwarded_room_key": "FORWARDED_ROOM_KEY", "m.dummy": "TO_DEVICE_DUMMY", + "com.beeper.room_key.ack": "BEEPER_ROOM_KEY_ACK", }, EventType.Class.UNKNOWN: { "__ALL__": "ALL", # This is not a real event type diff --git a/mautrix/types/event/type.pyi b/mautrix/types/event/type.pyi index 6c4c6e65..a2788d6f 100644 --- a/mautrix/types/event/type.pyi +++ b/mautrix/types/event/type.pyi @@ -18,6 +18,7 @@ class EventType(Serializable): ACCOUNT_DATA = "account_data" EPHEMERAL = "ephemeral" TO_DEVICE = "to_device" + _by_event_type: ClassVar[dict[str, EventType]] ROOM_CANONICAL_ALIAS: "EventType" @@ -60,6 +61,11 @@ class EventType(Serializable): PUSH_RULES: "EventType" TAG: "EventType" IGNORED_USER_LIST: "EventType" + SECRET_STORAGE_DEFAULT_KEY: "EventType" + CROSS_SIGNING_MASTER: "EventType" + CROSS_SIGNING_SELF_SIGNING: "EventType" + CROSS_SIGNING_USER_SIGNING: "EventType" + MEGOLM_BACKUP_V1: "EventType" TO_DEVICE_ENCRYPTED: "EventType" TO_DEVICE_DUMMY: "EventType" @@ -68,6 +74,7 @@ class EventType(Serializable): ORG_MATRIX_ROOM_KEY_WITHHELD: "EventType" ROOM_KEY_REQUEST: "EventType" FORWARDED_ROOM_KEY: "EventType" + BEEPER_ROOM_KEY_ACK: "EventType" ALL: "EventType" diff --git a/mautrix/types/media.py b/mautrix/types/media.py index 72aa6c4f..1d0ee66a 100644 --- a/mautrix/types/media.py +++ b/mautrix/types/media.py @@ -20,7 +20,7 @@ class MediaRepoConfig(SerializableAttrs): https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3config """ - upload_size: int = field(json="m.upload.size") + upload_size: int = field(default=50 * 1024 * 1024, json="m.upload.size") @dataclass @@ -71,4 +71,4 @@ class MediaCreateResponse(SerializableAttrs): content_uri: ContentURI unused_expired_at: Optional[int] = None - upload_url: Optional[str] = None + unstable_upload_url: Optional[str] = field(default=None, json="com.beeper.msc3870.upload_url") diff --git a/mautrix/types/misc.py b/mautrix/types/misc.py index 7a978bd1..5a07699c 100644 --- a/mautrix/types/misc.py +++ b/mautrix/types/misc.py @@ -87,7 +87,7 @@ class PublicRoomInfo(SerializableAttrs): num_joined_members: int world_readable: bool - guests_can_join: bool + guest_can_join: bool name: str = None topic: str = None @@ -106,7 +106,7 @@ class RoomDirectoryResponse(SerializableAttrs): PaginatedMessages = NamedTuple( - "PaginatedMessages", start=SyncToken, end=SyncToken, events=List[Event] + "PaginatedMessages", start=SyncToken, end=Optional[SyncToken], events=List[Event] ) @@ -129,3 +129,8 @@ class BatchSendResponse(SerializableAttrs): batch_event_id: EventID next_batch_id: BatchID base_insertion_event_id: Optional[EventID] = None + + +@dataclass +class BeeperBatchSendResponse(SerializableAttrs): + event_ids: List[EventID] diff --git a/mautrix/types/versions.py b/mautrix/types/versions.py index 8b42de39..52a62f59 100644 --- a/mautrix/types/versions.py +++ b/mautrix/types/versions.py @@ -70,6 +70,14 @@ class SpecVersions: V11 = Version.deserialize("v1.1") V12 = Version.deserialize("v1.2") V13 = Version.deserialize("v1.3") + V14 = Version.deserialize("v1.4") + V15 = Version.deserialize("v1.5") + V16 = Version.deserialize("v1.6") + V17 = Version.deserialize("v1.7") + V18 = Version.deserialize("v1.8") + V19 = Version.deserialize("v1.9") + V110 = Version.deserialize("v1.10") + V111 = Version.deserialize("v1.11") @dataclass diff --git a/mautrix/util/__init__.py b/mautrix/util/__init__.py index 6a7827d1..fd349bef 100644 --- a/mautrix/util/__init__.py +++ b/mautrix/util/__init__.py @@ -1,22 +1,28 @@ __all__ = [ + # Directory modules + "async_db", + "config", + "db", "formatter", "logging", - "config", - "signed_token", - "simple_template", - "manhole", - "markdown", - "simple_lock", + # File modules + "async_body", + "async_getter_lock", + "background_task", + "bridge_state", + "color_log", + "ffmpeg", "file_store", - "program", - "async_db", - "db", - "opt_prometheus", + "format_duration", "magic", - "bridge_state", + "manhole", + "markdown", "message_send_checkpoint", - "variation_selector", - "format_duration", - "ffmpeg", + "opt_prometheus", + "program", + "signed_token", + "simple_lock", + "simple_template", "utf16_surrogate", + "variation_selector", ] diff --git a/mautrix/util/async_body.py b/mautrix/util/async_body.py new file mode 100644 index 00000000..4db4d1e5 --- /dev/null +++ b/mautrix/util/async_body.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from __future__ import annotations + +from typing import AsyncGenerator, Union +import logging + +import aiohttp + +AsyncBody = AsyncGenerator[Union[bytes, bytearray, memoryview], None] + + +async def async_iter_bytes(data: bytearray | bytes, chunk_size: int = 1024**2) -> AsyncBody: + """ + Return memory views into a byte array in chunks. This is used to prevent aiohttp from copying + the entire request body. + + Args: + data: The underlying data to iterate through. + chunk_size: How big each returned chunk should be. + + Returns: + An async generator that yields the given data in chunks. + """ + with memoryview(data) as mv: + for i in range(0, len(data), chunk_size): + yield mv[i : i + chunk_size] + + +class FileTooLargeError(Exception): + def __init__(self, max_size: int) -> None: + super().__init__(f"File size larger than maximum ({max_size / 1024 / 1024} MiB)") + + +_default_dl_log = logging.getLogger("mau.util.download") + + +async def read_response_chunks( + resp: aiohttp.ClientResponse, max_size: int, log: logging.Logger = _default_dl_log +) -> bytearray: + """ + Read the body from an aiohttp response in chunks into a mutable bytearray. + + Args: + resp: The aiohttp response object to read the body from. + max_size: The maximum size to read. FileTooLargeError will be raised if the Content-Length + is higher than this, or if the body exceeds this size during reading. + log: A logger for logging download status. + + Returns: + The body data as a byte array. + + Raises: + FileTooLargeError: if the body is larger than the provided max_size. + """ + content_length = int(resp.headers.get("Content-Length", "0")) + if 0 < max_size < content_length: + raise FileTooLargeError(max_size) + size_str = "unknown length" if content_length == 0 else f"{content_length} bytes" + log.info(f"Reading file download response with {size_str} (max: {max_size})") + data = bytearray(content_length) + mv = memoryview(data) if content_length > 0 else None + read_size = 0 + max_size += 1 + while True: + block = await resp.content.readany() + if not block: + break + max_size -= len(block) + if max_size <= 0: + raise FileTooLargeError(max_size) + if len(data) >= read_size + len(block): + mv[read_size : read_size + len(block)] = block + elif len(data) > read_size: + log.warning("File being downloaded is bigger than expected") + mv[read_size:] = block[: len(data) - read_size] + mv.release() + mv = None + data.extend(block[len(data) - read_size :]) + else: + if mv is not None: + mv.release() + mv = None + data.extend(block) + read_size += len(block) + if mv is not None: + mv.release() + log.info(f"Successfully read {read_size} bytes of file download response") + return data + + +__all__ = ["AsyncBody", "FileTooLargeError", "async_iter_bytes", "async_read_bytes"] diff --git a/mautrix/util/async_db/aiosqlite.py b/mautrix/util/async_db/aiosqlite.py index 9b7fa16a..934379a8 100644 --- a/mautrix/util/async_db/aiosqlite.py +++ b/mautrix/util/async_db/aiosqlite.py @@ -5,10 +5,12 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any +from typing import Any, AsyncContextManager from contextlib import asynccontextmanager +from contextvars import ContextVar import asyncio import logging +import os import re import sqlite3 @@ -23,6 +25,9 @@ POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)") +in_transaction = ContextVar("in_transaction", default=False) + + class TxnConnection(aiosqlite.Connection): def __init__(self, path: str, **kwargs) -> None: def connector() -> sqlite3.Connection: @@ -34,7 +39,11 @@ def connector() -> sqlite3.Connection: @asynccontextmanager async def transaction(self) -> None: + if in_transaction.get(): + yield + return await self.execute("BEGIN TRANSACTION") + token = in_transaction.set(True) try: yield except Exception: @@ -42,6 +51,8 @@ async def transaction(self) -> None: raise else: await self.commit() + finally: + in_transaction.reset(token) def __execute(self, query: str, *args: Any): query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query) @@ -81,6 +92,7 @@ async def fetchval( class SQLiteDatabase(Database): scheme = Scheme.SQLITE + _parent: SQLiteDatabase | None _pool: asyncio.Queue[TxnConnection] _stopped: bool _conns: int @@ -103,9 +115,8 @@ def __init__( owner_name=owner_name, ignore_foreign_tables=ignore_foreign_tables, ) + self._parent = None self._path = url.path - if self._path.startswith("/"): - self._path = self._path[1:] self._pool = asyncio.Queue(self._db_args.pop("min_size", 1)) self._db_args.pop("max_size", None) self._stopped = False @@ -116,6 +127,7 @@ def __init__( def _add_missing_pragmas(init_commands: list[str]) -> list[str]: has_foreign_keys = False has_journal_mode = False + has_synchronous = False has_busy_timeout = False for cmd in init_commands: if "PRAGMA" not in cmd: @@ -124,23 +136,39 @@ def _add_missing_pragmas(init_commands: list[str]) -> list[str]: has_foreign_keys = True elif "journal_mode" in cmd: has_journal_mode = True + elif "synchronous" in cmd: + has_synchronous = True elif "busy_timeout" in cmd: has_busy_timeout = True if not has_foreign_keys: init_commands.append("PRAGMA foreign_keys = ON") if not has_journal_mode: init_commands.append("PRAGMA journal_mode = WAL") + if not has_synchronous and "PRAGMA journal_mode = WAL" in init_commands: + init_commands.append("PRAGMA synchronous = NORMAL") if not has_busy_timeout: init_commands.append("PRAGMA busy_timeout = 5000") return init_commands + def override_pool(self, db: Database) -> None: + assert isinstance(db, SQLiteDatabase) + self._parent = db + async def start(self) -> None: + if self._parent: + await super().start() + return if self._conns: raise RuntimeError("database pool has already been started") elif self._stopped: raise RuntimeError("database pool can't be restarted") self.log.debug(f"Connecting to {self.url}") self.log.debug(f"Database connection init commands: {self._init_commands}") + if os.path.exists(self._path): + if not os.access(self._path, os.W_OK): + self.log.warning("Database file doesn't seem writable") + elif not os.access(os.path.dirname(os.path.abspath(self._path)), os.W_OK): + self.log.warning("Database file doesn't exist and directory doesn't seem writable") for _ in range(self._pool.maxsize): conn = await TxnConnection(self._path, **self._db_args) if self._init_commands: @@ -155,14 +183,21 @@ async def start(self) -> None: await super().start() async def stop(self) -> None: + if self._parent: + return self._stopped = True while self._conns > 0: conn = await self._pool.get() self._conns -= 1 await conn.close() + def acquire_direct(self) -> AsyncContextManager[LoggingConnection]: + if self._parent: + return self._parent.acquire() + return self._acquire() + @asynccontextmanager - async def acquire(self) -> LoggingConnection: + async def _acquire(self) -> LoggingConnection: if self._stopped: raise RuntimeError("database pool has been stopped") conn = await self._pool.get() diff --git a/mautrix/util/async_db/asyncpg.py b/mautrix/util/async_db/asyncpg.py index 07ef7ad7..97b49f6c 100644 --- a/mautrix/util/async_db/asyncpg.py +++ b/mautrix/util/async_db/asyncpg.py @@ -96,7 +96,7 @@ async def _handle_exception(self, err: Exception) -> None: sys.exit(26) @asynccontextmanager - async def acquire(self) -> LoggingConnection: + async def acquire_direct(self) -> LoggingConnection: async with self.pool.acquire() as conn: yield LoggingConnection( self.scheme, conn, self.log, handle_exception=self._handle_exception diff --git a/mautrix/util/async_db/database.py b/mautrix/util/async_db/database.py index 0f23b02d..b5128b74 100644 --- a/mautrix/util/async_db/database.py +++ b/mautrix/util/async_db/database.py @@ -7,6 +7,8 @@ from typing import Any, AsyncContextManager, Type from abc import ABC, abstractmethod +from contextlib import asynccontextmanager +from contextvars import ContextVar import logging from yarl import URL @@ -23,6 +25,8 @@ from aiosqlite import Cursor from asyncpg import Record +conn_var: ContextVar[LoggingConnection | None] = ContextVar("db_connection", default=None) + class Database(ABC): schemes: dict[str, Type[Database]] = {} @@ -111,15 +115,17 @@ async def _check_foreign_tables(self) -> None: raise ForeignTablesFound("found roomserver_rooms likely belonging to Dendrite") async def _check_owner(self) -> None: - await self.execute( - """CREATE TABLE IF NOT EXISTS database_owner ( + await self.execute(""" + CREATE TABLE IF NOT EXISTS database_owner ( key INTEGER PRIMARY KEY DEFAULT 0, owner TEXT NOT NULL - )""" - ) + ) + """) owner = await self.fetchval("SELECT owner FROM database_owner WHERE key=0") if not owner: - await self.execute("INSERT INTO database_owner (owner) VALUES ($1)", self.owner_name) + await self.execute( + "INSERT INTO database_owner (key, owner) VALUES (0, $1)", self.owner_name + ) elif owner != self.owner_name: raise DatabaseNotOwned(owner) @@ -128,9 +134,22 @@ async def stop(self) -> None: pass @abstractmethod - def acquire(self) -> AsyncContextManager[LoggingConnection]: + def acquire_direct(self) -> AsyncContextManager[LoggingConnection]: pass + @asynccontextmanager + async def acquire(self) -> LoggingConnection: + conn = conn_var.get(None) + if conn is not None: + yield conn + return + async with self.acquire_direct() as conn: + token = conn_var.set(conn) + try: + yield conn + finally: + conn_var.reset(token) + async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor: async with self.acquire() as conn: return await conn.execute(query, *args, timeout=timeout) diff --git a/mautrix/util/async_db/upgrade.py b/mautrix/util/async_db/upgrade.py index d69dac56..c084d28b 100644 --- a/mautrix/util/async_db/upgrade.py +++ b/mautrix/util/async_db/upgrade.py @@ -21,7 +21,7 @@ UpgradeWithoutScheme = Callable[[LoggingConnection], Awaitable[Optional[int]]] -async def noop_upgrade(_: LoggingConnection) -> None: +async def noop_upgrade(_: LoggingConnection, _2: Scheme) -> None: pass @@ -97,11 +97,11 @@ async def _save_version(self, conn: LoggingConnection, version: int) -> None: await conn.execute(f"INSERT INTO {self.version_table_name} (version) VALUES ($1)", version) async def upgrade(self, db: async_db.Database) -> None: - await db.execute( - f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} ( + await db.execute(f""" + CREATE TABLE IF NOT EXISTS {self.version_table_name} ( version INTEGER PRIMARY KEY - )""" - ) + ) + """) row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1") version = row["version"] if row else 0 @@ -178,6 +178,6 @@ def _find_upgrade_table(fn: Upgrade) -> UpgradeTable: def register_upgrade(index: int = -1, description: str = "") -> Callable[[Upgrade], Upgrade]: def actually_register(fn: Upgrade) -> Upgrade: - return _find_upgrade_table(fn).register(index, description, fn) + return _find_upgrade_table(fn).register(fn, index=index, description=description) return actually_register diff --git a/mautrix/util/background_task.py b/mautrix/util/background_task.py new file mode 100644 index 00000000..e22e74f1 --- /dev/null +++ b/mautrix/util/background_task.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from __future__ import annotations + +from typing import Coroutine +import asyncio +import logging + +_tasks = set() +log = logging.getLogger("mau.background_task") + + +async def catch(coro: Coroutine, caller: str) -> None: + try: + await coro + except Exception: + log.exception(f"Uncaught error in background task (created in {caller})") + + +# Logger.findCaller finds the 3rd stack frame, so add an intermediate function +# to get the caller of create(). +def _find_caller() -> tuple[str, int, str, None]: + return log.findCaller() + + +def create(coro: Coroutine, *, name: str | None = None, catch_errors: bool = True) -> asyncio.Task: + """ + Create a background asyncio task safely, ensuring a reference is kept until the task completes. + It also catches and logs uncaught errors (unless disabled via the parameter). + + Args: + coro: The coroutine to wrap in a task and execute. + name: An optional name for the created task. + catch_errors: Should the task be wrapped in a try-except block to log any uncaught errors? + + Returns: + An asyncio Task object wrapping the given coroutine. + """ + if catch_errors: + try: + file_name, line_number, function_name, _ = _find_caller() + caller = f"{function_name} at {file_name}:{line_number}" + except ValueError: + caller = "unknown function" + task = asyncio.create_task(catch(coro, caller), name=name) + else: + task = asyncio.create_task(coro, name=name) + _tasks.add(task) + task.add_done_callback(_tasks.discard) + return task diff --git a/mautrix/util/bridge_state.py b/mautrix/util/bridge_state.py index b7c346f0..d28448bf 100644 --- a/mautrix/util/bridge_state.py +++ b/mautrix/util/bridge_state.py @@ -62,8 +62,8 @@ class BridgeStateEvent(SerializableEnum): class BridgeState(SerializableAttrs): human_readable_errors: ClassVar[Dict[Optional[str], str]] = {} default_source: ClassVar[str] = "bridge" - default_error_ttl: ClassVar[int] = 60 - default_ok_ttl: ClassVar[int] = 240 + default_error_ttl: ClassVar[int] = 3600 + default_ok_ttl: ClassVar[int] = 21600 state_event: BridgeStateEvent user_id: Optional[UserID] = None @@ -106,8 +106,8 @@ def should_deduplicate(self, prev_state: Optional["BridgeState"]) -> bool: ): # If there's no previous state or the state was different, send this one. return False - # If there's more than ⅘ of the previous pong's time-to-live left, drop this one - return prev_state.timestamp + (prev_state.ttl / 5) > self.timestamp + # If the previous state is recent, drop this one + return prev_state.timestamp + prev_state.ttl > self.timestamp async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool = True) -> bool: if not url: @@ -115,9 +115,10 @@ async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool = self.send_attempts_ += 1 headers = {"Authorization": f"Bearer {token}", "User-Agent": HTTPAPI.default_ua} try: - async with aiohttp.ClientSession() as sess, sess.post( - url, json=self.serialize(), headers=headers - ) as resp: + async with ( + aiohttp.ClientSession() as sess, + sess.post(url, json=self.serialize(), headers=headers) as resp, + ): if not 200 <= resp.status < 300: text = await resp.text() text = text.replace("\n", "\\n") diff --git a/mautrix/util/ffmpeg.py b/mautrix/util/ffmpeg.py index f158ae52..41cebf32 100644 --- a/mautrix/util/ffmpeg.py +++ b/mautrix/util/ffmpeg.py @@ -5,9 +5,11 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Iterable +from typing import Any, Iterable from pathlib import Path import asyncio +import json +import logging import mimetypes import os import shutil @@ -34,7 +36,89 @@ def __init__(self) -> None: ffmpeg_path = _abswhich("ffmpeg") -ffmpeg_default_params = ("-hide_banner", "-loglevel", "warning") +ffmpeg_default_params = ("-hide_banner", "-loglevel", "warning", "-y") + +ffprobe_path = _abswhich("ffprobe") +ffprobe_default_params = ( + "-loglevel", + "quiet", + "-print_format", + "json", + "-show_optional_fields", + "1", + "-show_format", + "-show_streams", +) + + +async def probe_path( + input_file: os.PathLike[str] | str, + logger: logging.Logger | None = None, +) -> Any: + """ + Probes a media file on the disk using ffprobe. + + Args: + input_file: The full path to the file. + + Returns: + A Python object containing the parsed JSON response from ffprobe + + Raises: + ConverterError: if ffprobe returns a non-zero exit code. + """ + if ffprobe_path is None: + raise NotInstalledError() + + input_file = Path(input_file) + proc = await asyncio.create_subprocess_exec( + ffprobe_path, + *ffprobe_default_params, + str(input_file), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})" + raise ConverterError(f"ffprobe error: {err_text}") + elif stderr and logger: + logger.warning(f"ffprobe warning: {stderr.decode('utf-8')}") + return json.loads(stdout) + + +async def probe_bytes( + data: bytes, + input_mime: str | None = None, + logger: logging.Logger | None = None, +) -> Any: + """ + Probe media file data using ffprobe. + + Args: + data: The bytes of the file to probe. + input_mime: The mime type of the input data. If not specified, will be guessed using magic. + + Returns: + A Python object containing the parsed JSON response from ffprobe + + Raises: + ConverterError: if ffprobe returns a non-zero exit code. + """ + if ffprobe_path is None: + raise NotInstalledError() + + if input_mime is None: + if magic is None: + raise ValueError("input_mime was not specified and magic is not installed") + input_mime = magic.mimetype(data) + input_extension = mimetypes.guess_extension(input_mime) + with tempfile.TemporaryDirectory(prefix="mautrix_ffmpeg_") as tmpdir: + input_file = Path(tmpdir) / f"data{input_extension}" + with open(input_file, "wb") as file: + file.write(data) + return await probe_path(input_file=input_file, logger=logger) async def convert_path( @@ -44,6 +128,7 @@ async def convert_path( output_args: Iterable[str] | None = None, remove_input: bool = False, output_path_override: os.PathLike[str] | str | None = None, + logger: logging.Logger | None = None, ) -> Path | bytes: """ Convert a media file on the disk using ffmpeg. @@ -76,6 +161,10 @@ async def convert_path( else: input_file = Path(input_file) output_file = input_file.parent / f"{input_file.stem}{output_extension}" + if input_file == output_file: + output_file = Path(output_file) + output_file = output_file.parent / f"{output_file.stem}-new{output_extension}" + proc = await asyncio.create_subprocess_exec( ffmpeg_path, *ffmpeg_default_params, @@ -92,9 +181,8 @@ async def convert_path( if proc.returncode != 0: err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})" raise ConverterError(f"ffmpeg error: {err_text}") - elif stderr: - # TODO log warnings? - pass + elif stderr and logger: + logger.warning(f"ffmpeg warning: {stderr.decode('utf-8')}") if remove_input and isinstance(input_file, Path): input_file.unlink(missing_ok=True) return stdout if output_file == "-" else output_file @@ -106,6 +194,7 @@ async def convert_bytes( input_args: Iterable[str] | None = None, output_args: Iterable[str] | None = None, input_mime: str | None = None, + logger: logging.Logger | None = None, ) -> bytes: """ Convert media file data using ffmpeg. @@ -140,6 +229,7 @@ async def convert_bytes( output_extension=output_extension, input_args=input_args, output_args=output_args, + logger=logger, ) with open(output_file, "rb") as file: return file.read() @@ -152,4 +242,6 @@ async def convert_bytes( "NotInstalledError", "convert_bytes", "convert_path", + "probe_bytes", + "probe_path", ] diff --git a/mautrix/util/message_send_checkpoint.py b/mautrix/util/message_send_checkpoint.py index 61eb691d..ee0c17f3 100644 --- a/mautrix/util/message_send_checkpoint.py +++ b/mautrix/util/message_send_checkpoint.py @@ -29,6 +29,7 @@ class MessageSendCheckpointStatus(SerializableEnum): PERM_FAILURE = "PERM_FAILURE" UNSUPPORTED = "UNSUPPORTED" TIMEOUT = "TIMEOUT" + DELIVERY_FAILED = "DELIVERY_FAILED" class MessageSendCheckpointReportedBy(SerializableEnum): @@ -56,12 +57,15 @@ async def send(self, endpoint: str, as_token: str, log: logging.Logger) -> None: return try: headers = {"Authorization": f"Bearer {as_token}", "User-Agent": HTTPAPI.default_ua} - async with aiohttp.ClientSession() as sess, sess.post( - endpoint, - json={"checkpoints": [self.serialize()]}, - headers=headers, - timeout=ClientTimeout(30), - ) as resp: + async with ( + aiohttp.ClientSession() as sess, + sess.post( + endpoint, + json={"checkpoints": [self.serialize()]}, + headers=headers, + timeout=ClientTimeout(30), + ) as resp, + ): if not 200 <= resp.status < 300: text = await resp.text() text = text.replace("\n", "\\n") diff --git a/mautrix/util/program.py b/mautrix/util/program.py index 80d2298b..74752ca5 100644 --- a/mautrix/util/program.py +++ b/mautrix/util/program.py @@ -98,7 +98,7 @@ def _prepare(self) -> None: self.log.info(f"Initializing {self.name} {self.version}") try: - self.prepare() + self.loop.run_until_complete(self._async_prepare()) except Exception: self.log.critical("Unexpected error in initialization", exc_info=True) sys.exit(1) @@ -117,9 +117,10 @@ def preinit(self) -> None: self.prepare_config() self.prepare_log() self.check_config() + self.init_loop() @property - def _default_base_config(self) -> str: + def base_config_path(self) -> str: return f"pkg://{self.module}/example-config.yaml" def prepare_arg_parser(self) -> None: @@ -133,21 +134,13 @@ def prepare_arg_parser(self) -> None: metavar="", help="the path to your config file", ) - self.parser.add_argument( - "-b", - "--base-config", - type=str, - default=self._default_base_config, - metavar="", - help="the path to the example config (for automatic config updates)", - ) self.parser.add_argument( "-n", "--no-update", action="store_true", help="Don't save updated config to disk" ) def prepare_config(self) -> None: """Pre-init lifecycle method. Extend this if you want to customize config loading.""" - self.config = self.config_class(self.args.config, self.args.base_config) + self.config = self.config_class(self.args.config, self.base_config_path) self.load_and_update_config() def load_and_update_config(self) -> None: @@ -155,13 +148,10 @@ def load_and_update_config(self) -> None: try: self.config.update(save=not self.args.no_update) except BaseMissingError: - if self.args.base_config != self._default_base_config: - print(f"Failed to read base config from {self.args.base_config}") - else: - print( - "Failed to read base config from the default path " - f"({self._default_base_config}). Maybe your installation is corrupted?" - ) + print( + "Failed to read base config from the default path " + f"({self.base_config_path}). Maybe your installation is corrupted?" + ) sys.exit(12) def check_config(self) -> None: @@ -179,14 +169,16 @@ def prepare_log(self) -> None: logging.config.dictConfig(copy.deepcopy(self.config["logging"])) self.log = cast(TraceLogger, logging.getLogger("mau.init")) + async def _async_prepare(self) -> None: + self.prepare() + def prepare(self) -> None: """ Lifecycle method where the primary program initialization happens. Use this to fill startup_actions with async startup tasks. """ - self.prepare_loop() - def prepare_loop(self) -> None: + def init_loop(self) -> None: """Init lifecycle method where the asyncio event loop is created.""" if uvloop is not None: uvloop.install() @@ -199,6 +191,7 @@ def start_prometheus(self) -> None: try: enabled = self.config["metrics.enabled"] listen_port = self.config["metrics.listen_port"] + hostname = self.config.get("metrics.hostname", "0.0.0.0") except KeyError: return if not enabled: @@ -208,12 +201,13 @@ def start_prometheus(self) -> None: "Metrics are enabled in config, but prometheus_client is not installed" ) return - prometheus.start_http_server(listen_port) + prometheus.start_http_server(listen_port, addr=hostname) def _run(self) -> None: signal.signal(signal.SIGINT, signal.default_int_handler) signal.signal(signal.SIGTERM, signal.default_int_handler) + self._stop_task = self.loop.create_future() exit_code = 0 try: self.log.debug("Running startup actions...") @@ -224,7 +218,6 @@ def _run(self) -> None: f"Startup actions complete in {round(end_ts - start_ts, 2)} seconds, " "now running forever" ) - self._stop_task = self.loop.create_future() exit_code = self.loop.run_until_complete(self._stop_task) self.log.debug("manual_stop() called, stopping...") except KeyboardInterrupt: diff --git a/mautrix/util/proxy.py b/mautrix/util/proxy.py new file mode 100644 index 00000000..f36da73d --- /dev/null +++ b/mautrix/util/proxy.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Awaitable, Callable, TypeVar +import asyncio +import json +import logging +import time +import urllib.request + +from aiohttp import ClientConnectionError +from yarl import URL + +from mautrix.util.logging import TraceLogger + +try: + from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError +except ImportError: + + class ProxyError(Exception): + pass + + ProxyConnectionError = ProxyTimeoutError = ProxyError + +RETRYABLE_PROXY_EXCEPTIONS = ( + ProxyError, + ProxyTimeoutError, + ProxyConnectionError, + ClientConnectionError, + ConnectionError, + asyncio.TimeoutError, +) + + +class ProxyHandler: + current_proxy_url: str | None = None + log = logging.getLogger("mau.proxy") + + def __init__(self, api_url: str | None) -> None: + self.api_url = api_url + + def get_proxy_url_from_api(self, reason: str | None = None) -> str | None: + assert self.api_url is not None + + api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {})) + + # NOTE: using urllib.request to intentionally block the whole bridge until the proxy change applied + request = urllib.request.Request(api_url, method="GET") + self.log.debug("Requesting proxy from: %s", api_url) + + try: + with urllib.request.urlopen(request) as f: + response = json.loads(f.read().decode()) + except Exception: + self.log.exception("Failed to retrieve proxy from API") + return self.current_proxy_url + else: + return response["proxy_url"] + + def update_proxy_url(self, reason: str | None = None) -> bool: + old_proxy = self.current_proxy_url + new_proxy = None + + if self.api_url is not None: + new_proxy = self.get_proxy_url_from_api(reason) + else: + new_proxy = urllib.request.getproxies().get("http") + + if old_proxy != new_proxy: + self.log.debug("Set new proxy URL: %s", new_proxy) + self.current_proxy_url = new_proxy + return True + + self.log.debug("Got same proxy URL: %s", new_proxy) + return False + + def get_proxy_url(self) -> str | None: + if not self.current_proxy_url: + self.update_proxy_url() + + return self.current_proxy_url + + +T = TypeVar("T") + + +async def proxy_with_retry( + name: str, + func: Callable[[], Awaitable[T]], + logger: TraceLogger, + proxy_handler: ProxyHandler, + on_proxy_change: Callable[[], Awaitable[None]], + max_retries: int = 10, + min_wait_seconds: int = 0, + max_wait_seconds: int = 60, + multiply_wait_seconds: int = 10, + retryable_exceptions: tuple[Exception] = RETRYABLE_PROXY_EXCEPTIONS, + reset_after_seconds: int | None = None, +) -> T: + errors = 0 + last_error = 0 + + while True: + try: + return await func() + except retryable_exceptions as e: + errors += 1 + if errors > max_retries: + raise + wait = errors * multiply_wait_seconds + wait = max(wait, min_wait_seconds) + wait = min(wait, max_wait_seconds) + logger.warning( + "%s while trying to %s, retrying in %d seconds", + e.__class__.__name__, + name, + wait, + ) + if errors > 1 and proxy_handler.update_proxy_url( + f"{e.__class__.__name__} while trying to {name}" + ): + await on_proxy_change() + + # If sufficient time has passed since the previous error, reset the + # error count. Useful for long running tasks with rare failures. + if reset_after_seconds is not None: + now = time.time() + if last_error and now - last_error > reset_after_seconds: + errors = 0 + last_error = now diff --git a/mautrix/util/simple_lock.py b/mautrix/util/simple_lock.py index ad979aec..c6bd08ce 100644 --- a/mautrix/util/simple_lock.py +++ b/mautrix/util/simple_lock.py @@ -47,7 +47,7 @@ def locked(self) -> bool: return not self.noop_mode and not self._event.is_set() async def wait(self, task: str | None = None) -> None: - if not self.noop_mode and not self._event.is_set(): + if self.locked: if self.log and self.message: self.log.debug(self.message, task) await self._event.wait() diff --git a/mautrix/util/utf16_surrogate.py b/mautrix/util/utf16_surrogate.py index d202e15e..92dc49c2 100644 --- a/mautrix/util/utf16_surrogate.py +++ b/mautrix/util/utf16_surrogate.py @@ -16,9 +16,11 @@ def add(text: str) -> str: The text with surrogate pairs. """ return "".join( - "".join(chr(y) for y in struct.unpack(" dict[str, str]: @@ -43,11 +43,12 @@ async def fetch_data() -> dict[str, str]: if __name__ == "__main__": import asyncio + import importlib.resources + import pathlib import sys - import pkg_resources - - path = pkg_resources.resource_filename("mautrix.util", "variation_selector.json") + path = importlib.resources.files("mautrix.util").joinpath("variation_selector.json") + assert isinstance(path, pathlib.Path) emojis = asyncio.run(fetch_data()) with open(path, "w") as file: json.dump(emojis, file, indent=" ", ensure_ascii=False) @@ -59,11 +60,11 @@ async def fetch_data() -> dict[str, str]: ADD_VARIATION_TRANSLATION = str.maketrans( {ord(emoji): f"{emoji}{VARIATION_SELECTOR_16}" for emoji in read_data().values()} ) -SKIN_TONE_MODIFIERS = ("\U0001F3FB", "\U0001F3FC", "\U0001F3FD", "\U0001F3FE", "\U0001F3FF") +SKIN_TONE_MODIFIERS = ("\U0001f3fb", "\U0001f3fc", "\U0001f3fd", "\U0001f3fe", "\U0001f3ff") SKIN_TONE_REPLACEMENTS = {f"{VARIATION_SELECTOR_16}{mod}": mod for mod in SKIN_TONE_MODIFIERS} VARIATION_SELECTOR_REPLACEMENTS = { **SKIN_TONE_REPLACEMENTS, - "\U0001F408\ufe0f\u200d\u2b1b\ufe0f": "\U0001F408\u200d\u2b1b", + "\U0001f408\ufe0f\u200d\u2b1b\ufe0f": "\U0001f408\u200d\u2b1b", } diff --git a/optional-requirements.txt b/optional-requirements.txt index 6cfdcece..a6e0227d 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -1,6 +1,6 @@ python-magic ruamel.yaml -SQLAlchemy +SQLAlchemy<2 commonmark lxml asyncpg @@ -11,3 +11,4 @@ uvloop python-olm unpaddedbase64 pycryptodome +base58 diff --git a/pyproject.toml b/pyproject.toml index c0e41313..bc17b4d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,8 @@ line_length = 99 [tool.black] line-length = 99 -target-version = ["py38"] +target-version = ["py310"] [tool.pytest.ini_options] asyncio_mode = "auto" +addopts = "--ignore mautrix/util/db/ --ignore mautrix/bridge/" diff --git a/setup.py b/setup.py index f7596154..23ac1cb1 100644 --- a/setup.py +++ b/setup.py @@ -2,8 +2,8 @@ from mautrix import __version__ -encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome"] -test_dependencies = ["aiosqlite", "sqlalchemy", "asyncpg", *encryption_dependencies] +encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome", "base58"] +test_dependencies = ["aiosqlite", "asyncpg", "ruamel.yaml", *encryption_dependencies] setuptools.setup( name="mautrix", @@ -28,12 +28,12 @@ ], extras_require={ "detect_mimetype": ["python-magic>=0.4.15,<0.5"], - "lint": ["black==22.1.0", "isort"], + "lint": ["black~=25.1", "isort"], "test": ["pytest", "pytest-asyncio", *test_dependencies], "encryption": encryption_dependencies, }, tests_require=test_dependencies, - python_requires="~=3.8", + python_requires="~=3.10", classifiers=[ "Development Status :: 4 - Beta", @@ -42,9 +42,11 @@ "Framework :: AsyncIO", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ], package_data={