diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 5a71c970..3a845672 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -8,12 +8,12 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install libolm @@ -46,17 +46,17 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: - python-version: "3.12" + python-version: "3.14" - uses: isort/isort-action@master with: sortPaths: "./mautrix" - uses: psf/black@stable with: src: "./mautrix" - version: "24.1.1" + version: "26.3.1" - name: pre-commit run: | pip install pre-commit diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8f3efe97..b0d7ab3f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,6 +1,6 @@ build docs builder: stage: build - image: docker:stable + image: docker:latest tags: - amd64 only: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a056ffdb..66065033 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -8,13 +8,13 @@ repos: - id: check-yaml - id: check-added-large-files - repo: https://github.com/psf/black - rev: 24.1.1 + rev: 26.3.1 hooks: - id: black language_version: python3 files: ^mautrix/.*\.pyi?$ - repo: https://github.com/PyCQA/isort - rev: 5.13.2 + rev: 8.0.1 hooks: - id: isort files: ^mautrix/.*\.pyi?$ diff --git a/CHANGELOG.md b/CHANGELOG.md index aad156fd..c7179c86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,52 @@ +## v0.21.0 (2025-11-17) + +* *(event)* Added support for creator power in room v12+. +* *(crypto)* Added support for generating and using recovery keys for verifying + the active device. +* *(bridge)* Added config option for self-signing bot device. +* *(bridge)* Removed check for login flows when using MSC4190 + (thanks to [@meson800] in [#178]). +* *(client)* Changed `set_displayname` and `set_avatar_url` to avoid setting + empty strings if the value is already unset (thanks to [@frebib] in [#171]). + +[@frebib]: https://github.com/frebib +[@meson800]: https://github.com/meson800 +[#171]: https://github.com/mautrix/python/pull/171 +[#178]: https://github.com/mautrix/python/pull/178 + +## v0.20.8 (2025-06-01) + +* *(bridge)* Added support for [MSC4190] (thanks to [@surakin] in [#175]). +* *(appservice)* Renamed `push_ephemeral` in generated registrations to + `receive_ephemeral` to match the accepted version of [MSC2409]. +* *(bridge)* Fixed compatibility with breaking change in aiohttp 3.12.6. + +[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 +[@surakin]: https://github.com/surakin +[#175]: https://github.com/mautrix/python/pull/175 + +## v0.20.7 (2025-01-03) + +* *(types)* Removed support for generating reply fallbacks to implement + [MSC2781]. Stripping fallbacks is still supported. + +[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 + +## v0.20.6 (2024-07-12) + +* *(bridge)* Added `/register` call if `/versions` fails with `M_FORBIDDEN`. + +## v0.20.5 (2024-07-09) + +**Note:** The `bridge` module is deprecated as all bridges are being rewritten +in Go. See for more info. + +* *(client)* Added support for authenticated media downloads. +* *(bridge)* Stopped using cached homeserver URLs for double puppeting if one + is set in the config file. +* *(crypto)* Fixed error when checking OTK counts before uploading new keys. +* *(types)* Added MSC2530 (captions) fields to `MediaMessageEventContent`. + ## v0.20.4 (2024-01-09) * Dropped Python 3.9 support. diff --git a/README.rst b/README.rst index e03596d9..f1342c19 100644 --- a/README.rst +++ b/README.rst @@ -49,7 +49,7 @@ Components .. _#maunium:maunium.net: https://matrix.to/#/#maunium:maunium.net .. _python-appservice-framework: https://github.com/Cadair/python-appservice-framework/ -.. _Client API: https://matrix.org/docs/spec/client_server/r0.6.1.html +.. _Client API: https://spec.matrix.org/latest/client-server-api/ .. _mautrix.api: https://docs.mau.fi/python/latest/api/mautrix.api.html .. _mautrix.client.api: https://docs.mau.fi/python/latest/api/mautrix.client.api.html diff --git a/dev-requirements.txt b/dev-requirements.txt index bb8c2a0a..dff6ee9d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ pre-commit>=2.10.1,<3 -isort>=5.10.1,<6 -black>=24,<25 +isort>=8,<9 +black>=26,<27 diff --git a/docs/requirements.txt b/docs/requirements.txt index 798e0f1b..7269d3a5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,7 +9,7 @@ yarl # that aren't used for anything that's in the docs python-magic ruamel.yaml -SQLAlchemy +SQLAlchemy<2 commonmark asyncpg aiosqlite @@ -17,3 +17,4 @@ prometheus_client python-olm unpaddedbase64 pycryptodome +base58 diff --git a/mautrix/__init__.py b/mautrix/__init__.py index 5d4ce561..c94995f9 100644 --- a/mautrix/__init__.py +++ b/mautrix/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.20.4" +__version__ = "0.21.0" __author__ = "Tulir Asokan " __all__ = [ "api", diff --git a/mautrix/api.py b/mautrix/api.py index 39871bb5..1adde9ec 100644 --- a/mautrix/api.py +++ b/mautrix/api.py @@ -462,6 +462,7 @@ def get_download_url( mxc_uri: str, download_type: Literal["download", "thumbnail"] = "download", file_name: str | None = None, + authenticated: bool = False, ) -> URL: """ Get the full HTTP URL to download a ``mxc://`` URI. @@ -470,6 +471,7 @@ def get_download_url( mxc_uri: The MXC URI whose full URL to get. download_type: The type of download ("download" or "thumbnail"). file_name: Optionally, a file name to include in the download URL. + authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11. Returns: The full HTTP URL. @@ -485,7 +487,11 @@ def get_download_url( "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png" """ server_name, media_id = self.parse_mxc_uri(mxc_uri) - url = self.base_url / str(APIPath.MEDIA) / "v3" / download_type / server_name / media_id + if authenticated: + url = self.base_url / str(APIPath.CLIENT) / "v1" / "media" + else: + url = self.base_url / str(APIPath.MEDIA) / "v3" + url = url / download_type / server_name / media_id if file_name: url /= file_name return url diff --git a/mautrix/appservice/api/intent.py b/mautrix/appservice/api/intent.py index 00990671..626b34f1 100644 --- a/mautrix/appservice/api/intent.py +++ b/mautrix/appservice/api/intent.py @@ -118,6 +118,8 @@ def __init__( ) -> None: super().__init__(mxid=mxid, api=api, state_store=state_store) self.bot = bot + if bot is not None: + self.versions_cache = bot.versions_cache self.log = api.base_log.getChild("intent") for method in ENSURE_REGISTERED_METHODS: @@ -708,6 +710,8 @@ async def _ensure_has_power_level_for( if not await self.state_store.has_power_levels_cached(room_id): # TODO add option to not try to fetch power levels from server await self.get_power_levels(room_id, ignore_cache=True, ensure_joined=False) + if not await self.state_store.has_create_cached(room_id): + await self.get_state_event(room_id, EventType.ROOM_CREATE, format="event") if not await self.state_store.has_power_level(room_id, self.mxid, event_type): # TODO implement something better raise IntentError( diff --git a/mautrix/appservice/appservice.py b/mautrix/appservice/appservice.py index 892db195..65e202ae 100644 --- a/mautrix/appservice/appservice.py +++ b/mautrix/appservice/appservice.py @@ -13,7 +13,7 @@ from aiohttp import web import aiohttp -from mautrix.types import JSON, RoomAlias, UserID +from mautrix.types import JSON, RoomAlias, UserID, VersionsResponse from mautrix.util.logging import TraceLogger from ..api import HTTPAPI diff --git a/mautrix/bridge/config.py b/mautrix/bridge/config.py index b98ebc52..defed222 100644 --- a/mautrix/bridge/config.py +++ b/mautrix/bridge/config.py @@ -143,6 +143,8 @@ def do_update(self, helper: ConfigUpdateHelper) -> None: copy("bridge.encryption.default") copy("bridge.encryption.require") copy("bridge.encryption.appservice") + copy("bridge.encryption.msc4190") + copy("bridge.encryption.self_sign") copy("bridge.encryption.delete_keys.delete_outbound_on_ack") copy("bridge.encryption.delete_keys.dont_store_outbound") copy("bridge.encryption.delete_keys.ratchet_on_decrypt") @@ -240,4 +242,7 @@ def generate_registration(self) -> None: if self["appservice.ephemeral_events"]: self._registration["de.sorunome.msc2409.push_ephemeral"] = True - self._registration["push_ephemeral"] = True + self._registration["receive_ephemeral"] = True + + if self["bridge.encryption.msc4190"]: + self._registration["io.element.msc4190"] = True diff --git a/mautrix/bridge/custom_puppet.py b/mautrix/bridge/custom_puppet.py index 9057877c..f5befd5f 100644 --- a/mautrix/bridge/custom_puppet.py +++ b/mautrix/bridge/custom_puppet.py @@ -132,8 +132,14 @@ def is_real_user(self) -> bool: return bool(self.custom_mxid and self.access_token) def _fresh_intent(self) -> IntentAPI: - if self.access_token == "appservice-config" and self.custom_mxid: + if self.custom_mxid: _, server = self.az.intent.parse_user_id(self.custom_mxid) + try: + self.base_url = self.homeserver_url_map[server] + except KeyError: + if server == self.az.domain: + self.base_url = self.az.intent.api.base_url + if self.access_token == "appservice-config" and self.custom_mxid: try: secret = self.login_shared_secret_map[server] except KeyError: diff --git a/mautrix/bridge/e2ee.py b/mautrix/bridge/e2ee.py index 7ae66abf..1525b388 100644 --- a/mautrix/bridge/e2ee.py +++ b/mautrix/bridge/e2ee.py @@ -57,6 +57,8 @@ class EncryptionManager: appservice_mode: bool periodically_delete_expired_keys: bool delete_outdated_inbound: bool + msc4190: bool + self_sign: bool bridge: br.Bridge az: AppService @@ -108,6 +110,8 @@ def __init__( self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"]) self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"] self.appservice_mode = bridge.config["bridge.encryption.appservice"] + self.msc4190 = bridge.config["bridge.encryption.msc4190"] + self.self_sign = bridge.config["bridge.encryption.self_sign"] if self.appservice_mode: self.az.otk_handler = self.crypto.handle_as_otk_counts self.az.device_list_handler = self.crypto.handle_as_device_lists @@ -245,12 +249,13 @@ async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> M return decrypted async def start(self) -> None: - flows = await self.client.get_login_flows() - if not flows.supports_type(LoginType.APPSERVICE): - self.log.critical( - "Encryption enabled in config, but homeserver does not support appservice login" - ) - sys.exit(30) + if not self.msc4190: + flows = await self.client.get_login_flows() + if not flows.supports_type(LoginType.APPSERVICE): + self.log.critical( + "Encryption enabled in config, but homeserver does not support appservice login" + ) + sys.exit(30) self.log.debug("Logging in with bridge bot user") if self.crypto_db: try: @@ -261,22 +266,42 @@ async def start(self) -> None: device_id = await self.crypto_store.get_device_id() if device_id: self.log.debug(f"Found device ID in database: {device_id}") - # We set the API token to the AS token here to authenticate the appservice login - # It'll get overridden after the login - self.client.api.token = self.az.as_token - await self.client.login( - login_type=LoginType.APPSERVICE, - device_name=self.device_name, - device_id=device_id, - store_access_token=True, - update_hs_url=False, - ) + + if self.msc4190: + if not device_id: + self.log.debug("Creating bot device with MSC4190") + self.client.api.token = self.az.as_token + await self.client.create_device_msc4190( + device_id=device_id, initial_display_name=self.device_name + ) + else: + # We set the API token to the AS token here to authenticate the appservice login + # It'll get overridden after the login + self.client.api.token = self.az.as_token + await self.client.login( + login_type=LoginType.APPSERVICE, + device_name=self.device_name, + device_id=device_id, + store_access_token=True, + update_hs_url=False, + ) + await self.crypto.load() if not device_id: await self.crypto_store.put_device_id(self.client.device_id) self.log.debug(f"Logged in with new device ID {self.client.device_id}") + await self.crypto.share_keys() elif self.crypto.account.shared: await self._verify_keys_are_on_server() + else: + await self.crypto.share_keys() + if self.self_sign: + trust_state = await self.crypto.resolve_trust(self.crypto.own_identity) + if trust_state < TrustState.CROSS_SIGNED_UNTRUSTED: + recovery_key = await self.crypto.generate_recovery_key() + self.log.info(f"Generated recovery key and signed own device: {recovery_key}") + else: + self.log.debug(f"Own device is already verified ({trust_state})") if self.appservice_mode: self.log.info("End-to-bridge encryption support is enabled (appservice mode)") else: diff --git a/mautrix/bridge/matrix.py b/mautrix/bridge/matrix.py index da42a95e..e5399094 100644 --- a/mautrix/bridge/matrix.py +++ b/mautrix/bridge/matrix.py @@ -225,6 +225,12 @@ async def wait_for_connection(self) -> None: try: self.versions = await self.az.intent.versions() break + except MForbidden: + self.log.debug( + "/versions endpoint returned M_FORBIDDEN, " + "trying to register bridge bot before retrying..." + ) + await self.az.intent.ensure_registered() except Exception: self.log.exception("Connection to homeserver failed, retrying in 10 seconds") await asyncio.sleep(10) diff --git a/mautrix/client/api/authentication.py b/mautrix/client/api/authentication.py index cd77272d..0f6249ea 100644 --- a/mautrix/client/api/authentication.py +++ b/mautrix/client/api/authentication.py @@ -5,6 +5,8 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +import secrets + from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( @@ -117,6 +119,19 @@ async def login( self.api.base_url = base_url.rstrip("/") return resp_data + async def create_device_msc4190(self, device_id: str, initial_display_name: str) -> None: + """ + Create a Device for a user of the homeserver using appservice interface defined in MSC4190 + """ + if len(device_id) == 0: + device_id = DeviceID(secrets.token_urlsafe(10)) + self.api.as_user_id = self.mxid + await self.api.request( + Method.PUT, Path.v3.devices[device_id], {"display_name": initial_display_name} + ) + self.api.as_device_id = device_id + self.device_id = device_id + async def logout(self, clear_access_token: bool = True) -> None: """ Invalidates an existing access token, so that it can no longer be used for authorization. diff --git a/mautrix/client/api/events.py b/mautrix/client/api/events.py index 99e1c178..bd415b9b 100644 --- a/mautrix/client/api/events.py +++ b/mautrix/client/api/events.py @@ -5,7 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Awaitable +from typing import Awaitable, Literal, overload import json from mautrix.api import Method, Path @@ -168,12 +168,41 @@ async def get_event_context( ) return EventContext.deserialize(resp) + @overload async def get_state_event( self, room_id: RoomID, event_type: EventType, state_key: str = "", - ) -> StateEventContent: + *, + format: Literal["content"] = "content", + ) -> StateEventContent: ... + @overload + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: Literal["event"], + ) -> StateEvent: ... + @overload + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: ... + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: """ Looks up the contents of a state event in a room. If the user is joined to the room then the state is taken from the current state of the room. If the user has left the room then the @@ -185,6 +214,9 @@ async def get_state_event( room_id: The ID of the room to look up the state in. event_type: The type of state to look up. state_key: The key of the state to look up. Defaults to empty string. + format: The format of the state event to return. Defaults to "content", which only returns + the content of the state event. If set to "event", the full event is returned. + See https://github.com/matrix-org/matrix-spec/issues/1047 for more info. Returns: The state event. @@ -192,11 +224,17 @@ async def get_state_event( content = await self.api.request( Method.GET, Path.v3.rooms[room_id].state[event_type][state_key], + query_params={"format": format} if format != "content" else None, metrics_method="getStateEvent", ) - content["__mautrix_event_type"] = event_type try: - return StateEvent.deserialize_content(content) + if format == "content": + content["__mautrix_event_type"] = event_type + return StateEvent.deserialize_content(content) + elif format == "event": + return StateEvent.deserialize(content) + else: + return content except SerializerError as e: raise MatrixResponseError("Invalid state event in response") from e @@ -357,15 +395,13 @@ async def get_messages( try: return PaginatedMessages( content["start"], - content["end"], + content.get("end"), [Event.deserialize(event) for event in content["chunk"]], ) except KeyError: if "start" not in content: raise MatrixResponseError("`start` not in response.") - elif "end" not in content: - raise MatrixResponseError("`start` not in response.") - raise MatrixResponseError("`content` not in response.") + raise MatrixResponseError("`chunk` not in response.") except SerializerError as e: raise MatrixResponseError("Invalid events in response") from e diff --git a/mautrix/client/api/modules/crypto.py b/mautrix/client/api/modules/crypto.py index af656916..2d879a63 100644 --- a/mautrix/client/api/modules/crypto.py +++ b/mautrix/client/api/modules/crypto.py @@ -5,13 +5,17 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any +from typing import Any, Union from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( + JSON, ClaimKeysResponse, + CrossSigningKeys, + CrossSigningUsage, DeviceID, + DeviceKeys, EncryptionKeyAlgorithm, EventType, QueryKeysResponse, @@ -82,7 +86,7 @@ async def send_to_one_device( async def upload_keys( self, one_time_keys: dict[str, Any] | None = None, - device_keys: dict[str, Any] | None = None, + device_keys: DeviceKeys | dict[str, Any] | None = None, ) -> dict[EncryptionKeyAlgorithm, int]: """ Publishes end-to-end encryption keys for the device. @@ -102,8 +106,12 @@ async def upload_keys( """ data = {} if device_keys: + if isinstance(device_keys, Serializable): + device_keys = device_keys.serialize() data["device_keys"] = device_keys if one_time_keys: + if isinstance(one_time_keys, Serializable): + one_time_keys = one_time_keys.serialize() data["one_time_keys"] = one_time_keys resp = await self.api.request(Method.POST, Path.v3.keys.upload, data) try: @@ -116,6 +124,43 @@ async def upload_keys( except AttributeError as e: raise MatrixResponseError("Invalid `one_time_key_counts` field in response.") from e + async def upload_cross_signing_keys( + self, + keys: dict[CrossSigningUsage, CrossSigningKeys], + auth: dict[str, JSON] | None = None, + ) -> None: + await self.api.request( + Method.POST, + Path.v3.keys.device_signing.upload, + {f"{usage}_key": key.serialize() for usage, key in keys.items()} + | ({"auth": auth} if auth else {}), + ) + + async def upload_one_signature( + self, + user_id: UserID, + device_id: DeviceID, + keys: Union[DeviceKeys, CrossSigningKeys], + ) -> None: + await self.api.request( + Method.POST, Path.v3.keys.signatures.upload, {user_id: {device_id: keys.serialize()}} + ) + # TODO check failures + + async def upload_many_signatures( + self, + signatures: dict[UserID, dict[DeviceID, Union[DeviceKeys, CrossSigningKeys]]], + ) -> None: + await self.api.request( + Method.POST, + Path.v3.keys.signatures.upload, + { + user_id: {device_id: keys.serialize() for device_id, keys in devices.items()} + for user_id, devices in signatures.items() + }, + ) + # TODO check failures + async def query_keys( self, device_keys: list[UserID] | set[UserID] | dict[UserID, list[DeviceID]], diff --git a/mautrix/client/api/modules/media_repository.py b/mautrix/client/api/modules/media_repository.py index fe52252a..b1d90cc1 100644 --- a/mautrix/client/api/modules/media_repository.py +++ b/mautrix/client/api/modules/media_repository.py @@ -10,6 +10,8 @@ import asyncio import time +from yarl import URL + from mautrix import __optional_imports__ from mautrix.api import MediaPath, Method from mautrix.errors import MatrixResponseError, make_request_error @@ -19,6 +21,7 @@ MediaRepoConfig, MXOpenGraph, SerializerError, + SpecVersions, ) from mautrix.util import background_task from mautrix.util.async_body import async_iter_bytes @@ -178,13 +181,19 @@ async def download_media(self, url: ContentURI, timeout_ms: int | None = None) - Returns: The raw downloaded data. """ - url = self.api.get_download_url(url) + authenticated = (await self.versions()).supports(SpecVersions.V111) + url = self.api.get_download_url(url, authenticated=authenticated) query_params: dict[str, Any] = {"allow_redirect": "true"} if timeout_ms is not None: query_params["timeout_ms"] = timeout_ms + headers: dict[str, str] = {} + if authenticated: + headers["Authorization"] = f"Bearer {self.api.token}" + if self.api.as_user_id: + query_params["user_id"] = self.api.as_user_id req_id = self.api.log_download_request(url, query_params) start = time.monotonic() - async with self.api.session.get(url, params=query_params) as response: + async with self.api.session.get(url, params=query_params, headers=headers) as response: try: response.raise_for_status() return await response.read() @@ -199,7 +208,7 @@ async def download_thumbnail( width: int | None = None, height: int | None = None, resize_method: Literal["crop", "scale"] = None, - allow_remote: bool = True, + allow_remote: bool | None = None, timeout_ms: int | None = None, ): """ @@ -223,7 +232,10 @@ async def download_thumbnail( Returns: The raw downloaded data. """ - url = self.api.get_download_url(url, download_type="thumbnail") + authenticated = (await self.versions()).supports(SpecVersions.V111) + url = self.api.get_download_url( + url, download_type="thumbnail", authenticated=authenticated + ) query_params: dict[str, Any] = {"allow_redirect": "true"} if width is not None: query_params["width"] = width @@ -232,12 +244,17 @@ async def download_thumbnail( if resize_method is not None: query_params["method"] = resize_method if allow_remote is not None: - query_params["allow_remote"] = allow_remote + query_params["allow_remote"] = str(allow_remote).lower() if timeout_ms is not None: query_params["timeout_ms"] = timeout_ms + headers: dict[str, str] = {} + if authenticated: + headers["Authorization"] = f"Bearer {self.api.token}" + if self.api.as_user_id: + query_params["user_id"] = self.api.as_user_id req_id = self.api.log_download_request(url, query_params) start = time.monotonic() - async with self.api.session.get(url, params=query_params) as response: + async with self.api.session.get(url, params=query_params, headers=headers) as response: try: response.raise_for_status() return await response.read() diff --git a/mautrix/client/api/user_data.py b/mautrix/client/api/user_data.py index 4c3d437a..9c380335 100644 --- a/mautrix/client/api/user_data.py +++ b/mautrix/client/api/user_data.py @@ -71,7 +71,7 @@ async def search_users(self, search_query: str, limit: int | None = 10) -> UserS # region 10.2 Profiles # API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#profiles - async def set_displayname(self, displayname: str, check_current: bool = True) -> None: + async def set_displayname(self, displayname: str | None, check_current: bool = True) -> None: """ Set the display name of the current user. @@ -81,7 +81,9 @@ async def set_displayname(self, displayname: str, check_current: bool = True) -> displayname: The new display name for the user. check_current: Whether or not to check if the displayname is already set. """ - if check_current and await self.get_displayname(self.mxid) == displayname: + if check_current and str_or_none(await self.get_displayname(self.mxid)) == str_or_none( + displayname + ): return await self.api.request( Method.PUT, @@ -112,7 +114,9 @@ async def get_displayname(self, user_id: UserID) -> str | None: except KeyError: return None - async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = True) -> None: + async def set_avatar_url( + self, avatar_url: ContentURI | None, check_current: bool = True + ) -> None: """ Set the avatar of the current user. @@ -122,7 +126,9 @@ async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = Tru avatar_url: The ``mxc://`` URI to the new avatar. check_current: Whether or not to check if the avatar is already set. """ - if check_current and await self.get_avatar_url(self.mxid) == avatar_url: + if check_current and str_or_none(await self.get_avatar_url(self.mxid)) == str_or_none( + avatar_url + ): return await self.api.request( Method.PUT, @@ -185,3 +191,10 @@ async def beeper_update_profile(self, custom_fields: dict[str, Any]) -> None: await self.api.request(Method.PATCH, Path.v3.profile[self.mxid], custom_fields) # endregion + + +def str_or_none(v: str | None) -> str | None: + """ + str_or_none empty string values to None + """ + return None if v == "" else v diff --git a/mautrix/client/state_store/abstract.py b/mautrix/client/state_store/abstract.py index 3d37087d..d5b1f5f2 100644 --- a/mautrix/client/state_store/abstract.py +++ b/mautrix/client/state_store/abstract.py @@ -121,6 +121,18 @@ async def set_power_levels( ) -> None: pass + @abstractmethod + async def has_create_cached(self, room_id: RoomID) -> bool: + pass + + @abstractmethod + async def get_create(self, room_id: RoomID) -> StateEvent | None: + pass + + @abstractmethod + async def set_create(self, event: StateEvent) -> None: + pass + @abstractmethod async def has_encryption_info_cached(self, room_id: RoomID) -> bool: pass @@ -135,7 +147,7 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent @abstractmethod async def set_encryption_info( - self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, any] + self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: pass @@ -149,6 +161,8 @@ async def update_state(self, evt: StateEvent) -> None: await self.set_member(evt.room_id, UserID(evt.state_key), evt.content) elif evt.type == EventType.ROOM_ENCRYPTION: await self.set_encryption_info(evt.room_id, evt.content) + elif evt.type == EventType.ROOM_CREATE and evt.sender: + await self.set_create(evt) async def get_membership(self, room_id: RoomID, user_id: UserID) -> Membership: member = await self.get_member(room_id, user_id) @@ -172,4 +186,7 @@ async def has_power_level( room_levels = await self.get_power_levels(room_id) if not room_levels: return None - return room_levels.get_user_level(user_id) >= room_levels.get_event_level(event_type) + create_event = await self.get_create(room_id) + return room_levels.get_user_level(user_id, create_event) >= room_levels.get_event_level( + event_type + ) diff --git a/mautrix/client/state_store/asyncpg/store.py b/mautrix/client/state_store/asyncpg/store.py index d78c04d6..f4f8436f 100644 --- a/mautrix/client/state_store/asyncpg/store.py +++ b/mautrix/client/state_store/asyncpg/store.py @@ -16,6 +16,7 @@ RoomEncryptionStateEventContent, RoomID, Serializable, + StateEvent, UserID, ) from mautrix.util.async_db import Database, Scheme @@ -223,6 +224,29 @@ async def set_power_levels( json.dumps(content.serialize() if isinstance(content, Serializable) else content), ) + async def has_create_cached(self, room_id: RoomID) -> bool: + return bool( + await self.db.fetchval( + "SELECT create_event IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id + ) + ) + + async def get_create(self, room_id: RoomID) -> StateEvent | None: + create_event_json = await self.db.fetchval( + "SELECT create_event FROM mx_room_state WHERE room_id=$1", room_id + ) + if create_event_json is None: + return None + return StateEvent.parse_json(create_event_json) + + async def set_create(self, event: StateEvent) -> None: + await self.db.execute( + "INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) " + "ON CONFLICT (room_id) DO UPDATE SET create_event=$2", + event.room_id, + json.dumps(event.serialize() if isinstance(event, Serializable) else event), + ) + async def has_encryption_info_cached(self, room_id: RoomID) -> bool: return bool( await self.db.fetchval( diff --git a/mautrix/client/state_store/asyncpg/upgrade.py b/mautrix/client/state_store/asyncpg/upgrade.py index 88f115f2..20b2e5b2 100644 --- a/mautrix/client/state_store/asyncpg/upgrade.py +++ b/mautrix/client/state_store/asyncpg/upgrade.py @@ -14,17 +14,18 @@ ) -@upgrade_table.register(description="Latest revision", upgrades_to=3) -async def upgrade_blank_to_v3(conn: Connection, scheme: Scheme) -> None: - await conn.execute( - """CREATE TABLE mx_room_state ( +@upgrade_table.register(description="Latest revision", upgrades_to=4) +async def upgrade_blank_to_v4(conn: Connection, scheme: Scheme) -> None: + await conn.execute(""" + CREATE TABLE mx_room_state ( room_id TEXT PRIMARY KEY, is_encrypted BOOLEAN, has_full_member_list BOOLEAN, encryption TEXT, - power_levels TEXT - )""" - ) + power_levels TEXT, + create_event TEXT + ) + """) membership_check = "" if scheme != Scheme.SQLITE: await conn.execute( @@ -32,16 +33,16 @@ async def upgrade_blank_to_v3(conn: Connection, scheme: Scheme) -> None: ) else: membership_check = "CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock'))" - await conn.execute( - f"""CREATE TABLE mx_user_profile ( + await conn.execute(f""" + CREATE TABLE mx_user_profile ( room_id TEXT, user_id TEXT, membership membership NOT NULL {membership_check}, displayname TEXT, avatar_url TEXT, PRIMARY KEY (room_id, user_id) - )""" - ) + ) + """) @upgrade_table.register(description="Stop using size-limited string fields") @@ -59,13 +60,16 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: @upgrade_table.register(description="Mark rooms that need crypto state event resynced") async def upgrade_v3(conn: Connection) -> None: if await conn.table_exists("portal"): - await conn.execute( - """ + await conn.execute(""" INSERT INTO mx_room_state (room_id, encryption) SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption WHERE mx_room_state.encryption IS NULL - """ - ) + """) + + +@upgrade_table.register(description="Add create event to room state cache") +async def upgrade_v4(conn: Connection) -> None: + await conn.execute("ALTER TABLE mx_room_state ADD COLUMN create_event TEXT") diff --git a/mautrix/client/state_store/file.py b/mautrix/client/state_store/file.py index d567c853..a5d53663 100644 --- a/mautrix/client/state_store/file.py +++ b/mautrix/client/state_store/file.py @@ -15,6 +15,7 @@ PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, + StateEvent, UserID, ) from mautrix.util.file_store import Filer, FileStore @@ -65,3 +66,7 @@ async def set_power_levels( ) -> None: await super().set_power_levels(room_id, content) self._time_limited_flush() + + async def set_create(self, event: StateEvent) -> None: + await super().set_create(event) + self._time_limited_flush() diff --git a/mautrix/client/state_store/memory.py b/mautrix/client/state_store/memory.py index 8f2edac5..2b010a75 100644 --- a/mautrix/client/state_store/memory.py +++ b/mautrix/client/state_store/memory.py @@ -14,6 +14,7 @@ PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, + StateEvent, UserID, ) @@ -25,6 +26,7 @@ class SerializedStateStore(TypedDict): full_member_list: dict[RoomID, bool] power_levels: dict[RoomID, Any] encryption: dict[RoomID, Any] + create: dict[RoomID, Any] class MemoryStateStore(StateStore): @@ -32,12 +34,14 @@ class MemoryStateStore(StateStore): full_member_list: dict[RoomID, bool] power_levels: dict[RoomID, PowerLevelStateEventContent] encryption: dict[RoomID, RoomEncryptionStateEventContent | None] + create: dict[RoomID, StateEvent] def __init__(self) -> None: self.members = {} self.full_member_list = {} self.power_levels = {} self.encryption = {} + self.create = {} def serialize(self) -> SerializedStateStore: """ @@ -58,6 +62,7 @@ def serialize(self) -> SerializedStateStore: room_id: (content.serialize() if content is not None else None) for room_id, content in self.encryption.items() }, + "create": {room_id: evt.serialize() for room_id, evt in self.create.items()}, } def deserialize(self, data: SerializedStateStore) -> None: @@ -84,6 +89,9 @@ def deserialize(self, data: SerializedStateStore) -> None: ) for room_id, content in data["encryption"].items() } + self.create = { + room_id: StateEvent.deserialize(evt) for room_id, evt in data["create"].items() + } async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None: try: @@ -176,6 +184,17 @@ async def set_power_levels( content = PowerLevelStateEventContent.deserialize(content) self.power_levels[room_id] = content + async def has_create_cached(self, room_id: RoomID) -> bool: + return room_id in self.create + + async def get_create(self, room_id: RoomID) -> StateEvent | None: + return self.create.get(room_id) + + async def set_create(self, event: StateEvent | dict[str, Any]) -> None: + if not isinstance(event, StateEvent): + event = StateEvent.deserialize(event) + self.create[event.room_id] = event + async def has_encryption_info_cached(self, room_id: RoomID) -> bool: return room_id in self.encryption diff --git a/mautrix/client/store_updater.py b/mautrix/client/store_updater.py index a5530bde..35280324 100644 --- a/mautrix/client/store_updater.py +++ b/mautrix/client/store_updater.py @@ -5,6 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +from typing import Literal import asyncio from mautrix.errors import MForbidden, MNotFound @@ -196,20 +197,28 @@ async def send_state_event( return event_id async def get_state_event( - self, room_id: RoomID, event_type: EventType, state_key: str = "" - ) -> StateEventContent: - event = await super().get_state_event(room_id, event_type, state_key) + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: + event = await super().get_state_event(room_id, event_type, state_key, format=format) if self.state_store: - fake_event = StateEvent( - type=event_type, - room_id=room_id, - event_id=EventID(""), - sender=UserID(""), - state_key=state_key, - timestamp=0, - content=event, - ) - await self.state_store.update_state(fake_event) + if isinstance(event, StateEvent): + await self.state_store.update_state(event) + else: + fake_event = StateEvent( + type=event_type, + room_id=room_id, + event_id=EventID(""), + sender=UserID(""), + state_key=state_key, + timestamp=0, + content=event, + ) + await self.state_store.update_state(fake_event) return event async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]: diff --git a/mautrix/client/syncer.py b/mautrix/client/syncer.py index 8e115bbb..4b8661a2 100644 --- a/mautrix/client/syncer.py +++ b/mautrix/client/syncer.py @@ -5,11 +5,11 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any, Awaitable, Callable, Type, TypeVar +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Type, TypeVar from abc import ABC, abstractmethod -from contextlib import suppress from enum import Enum, Flag, auto import asyncio +import itertools import time from mautrix.errors import MUnknownToken @@ -25,7 +25,6 @@ Filter, FilterID, GenericEvent, - MessageEvent, PresenceState, SerializerError, StateEvent, @@ -79,13 +78,18 @@ class InternalEventType(Enum): DEVICE_OTK_COUNT = auto() +class EventHandlerProps(NamedTuple): + wait_sync: bool + sync_stream: Optional[SyncStream] + + class Syncer(ABC): loop: asyncio.AbstractEventLoop log: TraceLogger mxid: UserID - global_event_handlers: list[tuple[EventHandler, bool]] - event_handlers: dict[EventType | InternalEventType, list[tuple[EventHandler, bool]]] + global_event_handlers: dict[EventHandler, EventHandlerProps] + event_handlers: dict[EventType | InternalEventType, dict[EventHandler, EventHandlerProps]] dispatchers: dict[Type[dispatcher.Dispatcher], dispatcher.Dispatcher] syncing_task: asyncio.Task | None ignore_initial_sync: bool @@ -95,7 +99,7 @@ class Syncer(ABC): sync_store: SyncStore def __init__(self, sync_store: SyncStore) -> None: - self.global_event_handlers = [] + self.global_event_handlers = {} self.event_handlers = {} self.dispatchers = {} self.syncing_task = None @@ -158,6 +162,7 @@ def add_event_handler( event_type: InternalEventType | EventType, handler: EventHandler, wait_sync: bool = False, + sync_stream: Optional[SyncStream] = None, ) -> None: """ Add a new event handler. @@ -167,13 +172,15 @@ def add_event_handler( event types. handler: The handler function to add. wait_sync: Whether or not the handler should be awaited before the next sync request. + sync_stream: The sync streams to listen to. Defaults to all. """ if not isinstance(event_type, (EventType, InternalEventType)): raise ValueError("Invalid event type") + props = EventHandlerProps(wait_sync=wait_sync, sync_stream=sync_stream) if event_type == EventType.ALL: - self.global_event_handlers.append((handler, wait_sync)) + self.global_event_handlers[handler] = props else: - self.event_handlers.setdefault(event_type, []).append((handler, wait_sync)) + self.event_handlers.setdefault(event_type, {})[handler] = props def remove_event_handler( self, event_type: EventType | InternalEventType, handler: EventHandler @@ -197,11 +204,7 @@ def remove_event_handler( # No handlers for this event type registered return - # FIXME this is a bit hacky - with suppress(ValueError): - handler_list.remove((handler, True)) - with suppress(ValueError): - handler_list.remove((handler, False)) + handler_list.pop(handler, None) if len(handler_list) == 0 and event_type != EventType.ALL: del self.event_handlers[event_type] @@ -229,7 +232,9 @@ def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asynci else: event.type = event.type.with_class(EventType.Class.MESSAGE) setattr(event, "source", source) - return self.dispatch_manual_event(event.type, event, include_global_handlers=True) + return self.dispatch_manual_event( + event.type, event, include_global_handlers=True, source=source + ) async def _catch_errors(self, handler: EventHandler, data: Any) -> None: try: @@ -243,13 +248,22 @@ def dispatch_manual_event( data: Any, include_global_handlers: bool = False, force_synchronous: bool = False, + source: Optional[SyncStream] = None, ) -> list[asyncio.Task]: - handlers = self.event_handlers.get(event_type, []) + handlers = self.event_handlers.get(event_type, {}).items() if include_global_handlers: - handlers = self.global_event_handlers + handlers + handlers = itertools.chain(self.global_event_handlers.items(), handlers) tasks = [] - for handler, wait_sync in handlers: - if force_synchronous or wait_sync: + if source is None: + source = getattr(data, "source", None) + for handler, props in handlers: + if ( + props.sync_stream is not None + and source is not None + and not props.sync_stream & source + ): + continue + if force_synchronous or props.wait_sync: tasks.append(asyncio.create_task(self._catch_errors(handler, data))) else: background_task.create(self._catch_errors(handler, data)) @@ -263,6 +277,7 @@ async def run_internal_event( event_type, custom_type if custom_type is not None else kwargs, include_global_handlers=False, + source=SyncStream.INTERNAL, ) await asyncio.gather(*tasks) @@ -274,6 +289,7 @@ def dispatch_internal_event( event_type, custom_type if custom_type is not None else kwargs, include_global_handlers=False, + source=SyncStream.INTERNAL, ) def _try_deserialize(self, type: Type[T], data: JSON) -> T | GenericEvent: @@ -353,12 +369,18 @@ def handle_sync(self, data: JSON) -> list[asyncio.Task]: events: list[dict[str, JSON]] = room_data.get("invite_state", {}).get("events", []) for raw_event in events: raw_event["room_id"] = room_id - raw_invite = next( - raw_event - for raw_event in events - if raw_event.get("type", "") == "m.room.member" - and raw_event.get("state_key", "") == self.mxid - ) + try: + raw_invite = next( + raw_event + for raw_event in events + if raw_event.get("type", "") == "m.room.member" + and raw_event.get("state_key", "") == self.mxid + ) + except StopIteration: + self.log.warning( + f"Corrupted invite section in sync: no invite event present for {room_id}" + ) + continue # These aren't required by the spec, so make sure they're set raw_invite.setdefault("event_id", None) raw_invite.setdefault("origin_server_ts", int(time.time() * 1000)) diff --git a/mautrix/crypto/account.py b/mautrix/crypto/account.py index db508262..a00ada71 100644 --- a/mautrix/crypto/account.py +++ b/mautrix/crypto/account.py @@ -10,15 +10,17 @@ from mautrix.types import ( DeviceID, + DeviceKeys, EncryptionAlgorithm, EncryptionKeyAlgorithm, IdentityKey, + KeyID, SigningKey, UserID, ) -from . import base from .sessions import Session +from .signature import sign_olm class OlmAccount(olm.Account): @@ -74,19 +76,18 @@ def new_outbound_session(self, target_key: IdentityKey, one_time_key: IdentityKe session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now() ) - def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> Dict[str, Any]: - device_keys = { - "user_id": user_id, - "device_id": device_id, - "algorithms": [EncryptionAlgorithm.OLM_V1.value, EncryptionAlgorithm.MEGOLM_V1.value], - "keys": { - f"{algorithm}:{device_id}": key for algorithm, key in self.identity_keys.items() + def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> DeviceKeys: + device_keys = DeviceKeys( + user_id=user_id, + device_id=device_id, + algorithms=[EncryptionAlgorithm.OLM_V1, EncryptionAlgorithm.MEGOLM_V1], + keys={ + KeyID(algorithm=EncryptionKeyAlgorithm(algorithm), key_id=device_id): key + for algorithm, key in self.identity_keys.items() }, - } - signature = self.sign(base.canonical_json(device_keys)) - device_keys["signatures"] = { - user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature} - } + signatures={}, + ) + device_keys.signatures[user_id] = {KeyID.ed25519(device_id): sign_olm(device_keys, self)} return device_keys def get_one_time_keys( @@ -97,12 +98,12 @@ def get_one_time_keys( self.generate_one_time_keys(new_count) keys = {} for key_id, key in self.one_time_keys.get("curve25519", {}).items(): - signature = self.sign(base.canonical_json({"key": key})) - keys[f"{EncryptionKeyAlgorithm.SIGNED_CURVE25519}:{key_id}"] = { + keys[str(KeyID.signed_curve25519(IdentityKey(key_id)))] = { "key": key, "signatures": { - user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature} + user_id: { + str(KeyID.ed25519(device_id)): sign_olm({"key": key}, self), + } }, } - self.mark_keys_as_published() return keys diff --git a/mautrix/crypto/base.py b/mautrix/crypto/base.py index 1f8e9a62..f7015ca1 100644 --- a/mautrix/crypto/base.py +++ b/mautrix/crypto/base.py @@ -5,41 +5,30 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any, Awaitable, Callable, TypedDict +from typing import Awaitable, Callable import asyncio -import functools -import json - -import olm from mautrix.errors import MForbidden, MNotFound from mautrix.types import ( - DeviceID, - EncryptionKeyAlgorithm, EventType, IdentityKey, - KeyID, RequestedKeyInfo, RoomEncryptionStateEventContent, RoomID, RoomKeyEventContent, SessionID, - SigningKey, TrustState, UserID, ) from mautrix.util.logging import TraceLogger from .. import client as cli, crypto - - -class SignedObject(TypedDict): - signatures: dict[UserID, dict[str, str]] - unsigned: Any +from .ssss import Machine as SSSSMachine class BaseOlmMachine: client: cli.Client + ssss: SSSSMachine log: TraceLogger crypto_store: crypto.CryptoStore state_store: crypto.StateStore @@ -116,27 +105,3 @@ async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None: evt.beeper_max_age_ms = encryption_info.rotation_period_ms if not evt.beeper_max_messages: evt.beeper_max_messages = encryption_info.rotation_period_msgs - - -canonical_json = functools.partial( - json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True -) - - -def verify_signature_json( - data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey -) -> bool: - data_copy = {**data} - data_copy.pop("unsigned", None) - signatures = data_copy.pop("signatures") - key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name)) - try: - signature = signatures[user_id][key_id] - except KeyError: - return False - signed_data = canonical_json(data_copy) - try: - olm.ed25519_verify(key, signed_data, signature) - return True - except olm.OlmVerifyError: - return False diff --git a/mautrix/crypto/cross_signing.py b/mautrix/crypto/cross_signing.py new file mode 100644 index 00000000..7577cd5d --- /dev/null +++ b/mautrix/crypto/cross_signing.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from ..types import ( + JSON, + CrossSigner, + CrossSigningKeys, + CrossSigningUsage, + DeviceIdentity, + EventType, + KeyID, + UserID, +) +from .cross_signing_key import CrossSigningPrivateKeys, CrossSigningPublicKeys, CrossSigningSeeds +from .device_lists import DeviceListMachine +from .signature import sign_olm +from .ssss import Key as SSSSKey + + +class CrossSigningMachine(DeviceListMachine): + _cross_signing_public_keys: CrossSigningPublicKeys | None + _cross_signing_public_keys_fetched: bool + _cross_signing_private_keys: CrossSigningPrivateKeys | None + + async def verify_with_recovery_key(self, recovery_key: str) -> None: + if not self.account.shared: + raise ValueError("Device keys must be shared before verifying with recovery key") + key_id, key_data = await self.ssss.get_default_key_data() + ssss_key = key_data.verify_recovery_key(key_id, recovery_key) + seeds = await self._fetch_cross_signing_keys_from_ssss(ssss_key) + self._import_cross_signing_keys(seeds) + await self.sign_own_device(self.own_identity) + + def _import_cross_signing_keys(self, seeds: CrossSigningSeeds) -> None: + self._cross_signing_private_keys = seeds.to_keys() + self._cross_signing_public_keys = self._cross_signing_private_keys.public_keys + + async def generate_recovery_key( + self, passphrase: str | None = None, seeds: CrossSigningSeeds | None = None + ) -> str: + if not self.account.shared: + raise ValueError("Device keys must be shared before generating recovery key") + seeds = seeds or CrossSigningSeeds.generate() + ssss_key = await self.ssss.generate_and_upload_key(passphrase) + await self._upload_cross_signing_keys_to_ssss(ssss_key, seeds) + await self._publish_cross_signing_keys(seeds.to_keys()) + await self.ssss.set_default_key_id(ssss_key.id) + await self.sign_own_device(self.own_identity) + return ssss_key.recovery_key + + async def _fetch_cross_signing_keys_from_ssss(self, key: SSSSKey) -> CrossSigningSeeds: + return CrossSigningSeeds( + master_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_MASTER, key + ), + user_signing_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_USER_SIGNING, key + ), + self_signing_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_SELF_SIGNING, key + ), + ) + + async def _upload_cross_signing_keys_to_ssss( + self, key: SSSSKey, seeds: CrossSigningSeeds + ) -> None: + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_MASTER, seeds.master_key, key + ) + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_USER_SIGNING, seeds.user_signing_key, key + ) + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_SELF_SIGNING, seeds.self_signing_key, key + ) + + async def get_own_cross_signing_public_keys(self) -> CrossSigningPublicKeys | None: + if self._cross_signing_public_keys or self._cross_signing_public_keys_fetched: + return self._cross_signing_public_keys + keys = await self.get_cross_signing_public_keys(self.client.mxid) + self._cross_signing_public_keys_fetched = True + if keys: + self._cross_signing_public_keys = keys + return keys + + async def get_cross_signing_public_keys( + self, user_id: UserID + ) -> CrossSigningPublicKeys | None: + db_keys = await self.crypto_store.get_cross_signing_keys(user_id) + if CrossSigningUsage.MASTER not in db_keys and user_id not in self._cs_fetch_attempted: + self.log.debug(f"Didn't find any cross-signing keys for {user_id}, fetching...") + async with self._fetch_keys_lock: + if user_id not in self._cs_fetch_attempted: + self._cs_fetch_attempted.add(user_id) + await self._fetch_keys([user_id], include_untracked=True) + db_keys = await self.crypto_store.get_cross_signing_keys(user_id) + if CrossSigningUsage.MASTER not in db_keys: + return None + return CrossSigningPublicKeys( + master_key=db_keys[CrossSigningUsage.MASTER].key, + self_signing_key=( + db_keys[CrossSigningUsage.SELF].key if CrossSigningUsage.SELF in db_keys else None + ), + user_signing_key=( + db_keys[CrossSigningUsage.USER].key if CrossSigningUsage.USER in db_keys else None + ), + ) + + async def sign_own_device(self, device: DeviceIdentity) -> None: + full_keys = await self._get_full_device_keys(device) + ssk = self._cross_signing_private_keys.self_signing_key + signature = sign_olm(full_keys, ssk) + full_keys.signatures = {self.client.mxid: {KeyID.ed25519(ssk.public_key): signature}} + await self.client.upload_one_signature(device.user_id, device.device_id, full_keys) + await self.crypto_store.put_signature( + CrossSigner(device.user_id, device.signing_key), + CrossSigner(self.client.mxid, ssk.public_key), + signature, + ) + + async def _publish_cross_signing_keys( + self, + keys: CrossSigningPrivateKeys, + auth: dict[str, JSON] | None = None, + ) -> None: + public = keys.public_keys + master_key = CrossSigningKeys( + user_id=self.client.mxid, + usage=[CrossSigningUsage.MASTER], + keys={KeyID.ed25519(public.master_key): public.master_key}, + ) + master_key.signatures = { + self.client.mxid: { + KeyID.ed25519(self.client.device_id): sign_olm(master_key, self.account), + } + } + self_key = CrossSigningKeys( + user_id=self.client.mxid, + usage=[CrossSigningUsage.SELF], + keys={KeyID.ed25519(public.self_signing_key): public.self_signing_key}, + ) + self_key.signatures = { + self.client.mxid: { + KeyID.ed25519(public.master_key): sign_olm(self_key, keys.master_key), + } + } + user_key = CrossSigningKeys( + user_id=self.client.mxid, + usage=[CrossSigningUsage.USER], + keys={KeyID.ed25519(public.user_signing_key): public.user_signing_key}, + ) + user_key.signatures = { + self.client.mxid: { + KeyID.ed25519(public.master_key): sign_olm(user_key, keys.master_key), + } + } + await self.client.upload_cross_signing_keys( + keys={ + CrossSigningUsage.MASTER: master_key, + CrossSigningUsage.SELF: self_key, + CrossSigningUsage.USER: user_key, + }, + auth=auth, + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.MASTER, public.master_key + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.SELF, public.self_signing_key + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.USER, public.user_signing_key + ) + self._cross_signing_private_keys = keys + self._cross_signing_public_keys = public diff --git a/mautrix/crypto/cross_signing_key.py b/mautrix/crypto/cross_signing_key.py new file mode 100644 index 00000000..f4e3c1c1 --- /dev/null +++ b/mautrix/crypto/cross_signing_key.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import NamedTuple + +import olm + +from mautrix.crypto.ssss.util import cryptorand +from mautrix.types import SigningKey + + +class CrossSigningPublicKeys(NamedTuple): + master_key: SigningKey + self_signing_key: SigningKey + user_signing_key: SigningKey + + +class CrossSigningPrivateKeys(NamedTuple): + master_key: olm.PkSigning + self_signing_key: olm.PkSigning + user_signing_key: olm.PkSigning + + @property + def public_keys(self) -> CrossSigningPublicKeys: + return CrossSigningPublicKeys( + master_key=self.master_key.public_key, + self_signing_key=self.self_signing_key.public_key, + user_signing_key=self.user_signing_key.public_key, + ) + + +class CrossSigningSeeds(NamedTuple): + master_key: bytes + self_signing_key: bytes + user_signing_key: bytes + + def to_keys(self) -> CrossSigningPrivateKeys: + return CrossSigningPrivateKeys( + master_key=olm.PkSigning(self.master_key), + self_signing_key=olm.PkSigning(self.self_signing_key), + user_signing_key=olm.PkSigning(self.user_signing_key), + ) + + @classmethod + def generate(cls) -> "CrossSigningSeeds": + return cls( + master_key=cryptorand.read(32), + self_signing_key=cryptorand.read(32), + user_signing_key=cryptorand.read(32), + ) diff --git a/mautrix/crypto/device_lists.py b/mautrix/crypto/device_lists.py index 1b5e0dbd..b00b104c 100644 --- a/mautrix/crypto/device_lists.py +++ b/mautrix/crypto/device_lists.py @@ -23,10 +23,23 @@ UserID, ) -from .base import BaseOlmMachine, verify_signature_json +from .base import BaseOlmMachine +from .signature import verify_signature_json class DeviceListMachine(BaseOlmMachine): + @property + def own_identity(self) -> DeviceIdentity: + return DeviceIdentity( + user_id=self.client.mxid, + device_id=self.client.device_id, + identity_key=self.account.identity_key, + signing_key=self.account.signing_key, + trust=TrustState.VERIFIED, + deleted=False, + name="", + ) + async def _fetch_keys( self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False ) -> dict[UserID, dict[DeviceID, DeviceIdentity]]: @@ -46,61 +59,67 @@ async def _fetch_keys( data = {} for user_id, devices in resp.device_keys.items(): missing_users.remove(user_id) - - new_devices = {} - existing_devices = (await self.crypto_store.get_devices(user_id)) or {} - - self.log.trace( - f"Updating devices for {user_id}, got {len(devices)}, " - f"have {len(existing_devices)} in store" - ) - changed = False - ssks = resp.self_signing_keys.get(user_id) - ssk = ssks.first_ed25519_key if ssks else None - for device_id, device_keys in devices.items(): - try: - existing = existing_devices[device_id] - except KeyError: - existing = None - changed = True - self.log.trace(f"Validating device {device_keys} of {user_id}") - try: - new_device = await self._validate_device( - user_id, device_id, device_keys, existing - ) - except DeviceValidationError as e: - self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}") - else: - if new_device: - new_devices[device_id] = new_device - await self._store_device_self_signatures(device_keys, ssk) - self.log.debug( - f"Storing new device list for {user_id} containing {len(new_devices)} devices" - ) - await self.crypto_store.put_devices(user_id, new_devices) - data[user_id] = new_devices - - if changed or len(new_devices) != len(existing_devices): - if self.delete_keys_on_device_delete: - for device_id in existing_devices.keys() - new_devices.keys(): - device = existing_devices[device_id] - removed_ids = await self.crypto_store.redact_group_sessions( - room_id=None, sender_key=device.identity_key, reason="device removed" - ) - self.log.info( - "Redacted megolm sessions sent by removed device " - f"{device.user_id}/{device.device_id}: {removed_ids}" - ) - await self.on_devices_changed(user_id) + async with self.crypto_store.transaction(): + data[user_id] = await self._process_fetched_keys(user_id, devices, resp) for user_id in missing_users: self.log.warning(f"Didn't get any devices for user {user_id}") - for user_id in users: - await self._store_cross_signing_keys(resp, user_id) - return data + async def _process_fetched_keys( + self, + user_id: UserID, + devices: dict[DeviceID, DeviceKeys], + resp: QueryKeysResponse, + ) -> dict[DeviceID, DeviceIdentity]: + new_devices = {} + existing_devices = (await self.crypto_store.get_devices(user_id)) or {} + + self.log.trace( + f"Updating devices for {user_id}, got {len(devices)}, " + f"have {len(existing_devices)} in store" + ) + changed = False + ssks = resp.self_signing_keys.get(user_id) + ssk = ssks.first_ed25519_key if ssks else None + for device_id, device_keys in devices.items(): + try: + existing = existing_devices[device_id] + except KeyError: + existing = None + changed = True + self.log.trace(f"Validating device {device_keys} of {user_id}") + try: + new_device = await self._validate_device(user_id, device_id, device_keys, existing) + except DeviceValidationError as e: + self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}") + else: + if new_device: + new_devices[device_id] = new_device + await self._store_device_self_signatures(device_keys, ssk) + self.log.debug( + f"Storing new device list for {user_id} containing {len(new_devices)} devices" + ) + await self.crypto_store.put_devices(user_id, new_devices) + + if changed or len(new_devices) != len(existing_devices): + if self.delete_keys_on_device_delete: + for device_id in existing_devices.keys() - new_devices.keys(): + device = existing_devices[device_id] + removed_ids = await self.crypto_store.redact_group_sessions( + room_id=None, sender_key=device.identity_key, reason="device removed" + ) + self.log.info( + "Redacted megolm sessions sent by removed device " + f"{device.user_id}/{device.device_id}: {removed_ids}" + ) + await self.on_devices_changed(user_id) + + await self._store_cross_signing_keys(resp, user_id) + + return new_devices + async def _store_device_self_signatures( self, device_keys: DeviceKeys, self_signing_key: SigningKey | None ) -> None: @@ -193,7 +212,7 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User signing_key = device.ed25519 except KeyError: pass - if len(signing_key) != 43: + if not signing_key or len(signing_key) != 43: self.log.debug( f"Cross-signing key {user_id}/{actual_key} has a signature from " f"an unknown key {key_id}" @@ -219,6 +238,12 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User else: self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}") + async def _get_full_device_keys(self, device: DeviceIdentity) -> DeviceKeys: + resp = await self.client.query_keys({device.user_id: [device.device_id]}) + keys = resp.device_keys[device.user_id][device.device_id] + await self._validate_device(device.user_id, device.device_id, keys, device) + return keys + async def get_or_fetch_device( self, user_id: UserID, device_id: DeviceID ) -> DeviceIdentity | None: @@ -296,18 +321,23 @@ async def _validate_device( deleted=False, ) - async def resolve_trust(self, device: DeviceIdentity) -> TrustState: + async def resolve_trust(self, device: DeviceIdentity, allow_fetch: bool = True) -> TrustState: try: - return await self._try_resolve_trust(device) + return await self._try_resolve_trust(device, allow_fetch) except Exception: self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}") return TrustState.UNVERIFIED - async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState: - if device.trust in (TrustState.VERIFIED, TrustState.BLACKLISTED): + async def _try_resolve_trust( + self, device: DeviceIdentity, allow_fetch: bool = True + ) -> TrustState: + if device.device_id != self.client.device_id and device.trust in ( + TrustState.VERIFIED, + TrustState.BLACKLISTED, + ): return device.trust their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id) - if len(their_keys) == 0 and device.user_id not in self._cs_fetch_attempted: + if len(their_keys) == 0 and allow_fetch and device.user_id not in self._cs_fetch_attempted: self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...") async with self._fetch_keys_lock: if device.user_id not in self._cs_fetch_attempted: @@ -318,7 +348,8 @@ async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState: msk = their_keys[CrossSigningUsage.MASTER] ssk = their_keys[CrossSigningUsage.SELF] except KeyError as e: - self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}") + if allow_fetch: + self.log.warning(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}") return TrustState.UNVERIFIED ssk_signed = await self.crypto_store.is_key_signed_by( target=CrossSigner(device.user_id, ssk.key), diff --git a/mautrix/crypto/encrypt_olm.py b/mautrix/crypto/encrypt_olm.py index 620cbbc8..029ad1e5 100644 --- a/mautrix/crypto/encrypt_olm.py +++ b/mautrix/crypto/encrypt_olm.py @@ -18,8 +18,9 @@ UserID, ) -from .base import BaseOlmMachine, verify_signature_json +from .base import BaseOlmMachine from .sessions import Session +from .signature import verify_signature_json ClaimKeysList = Dict[UserID, Dict[DeviceID, DeviceIdentity]] diff --git a/mautrix/crypto/machine.py b/mautrix/crypto/machine.py index 0cfe6ea3..60c65677 100644 --- a/mautrix/crypto/machine.py +++ b/mautrix/crypto/machine.py @@ -32,10 +32,12 @@ from mautrix.util.logging import TraceLogger from .account import OlmAccount +from .cross_signing import CrossSigningMachine from .decrypt_megolm import MegolmDecryptionMachine from .encrypt_megolm import MegolmEncryptionMachine from .key_request import KeyRequestingMachine from .key_share import KeySharingMachine +from .ssss import Machine as SSSSMachine from .store import CryptoStore, StateStore from .unwedge import OlmUnwedgingMachine @@ -46,6 +48,7 @@ class OlmMachine( OlmUnwedgingMachine, KeySharingMachine, KeyRequestingMachine, + CrossSigningMachine, ): """ OlmMachine is the main class for handling things related to Matrix end-to-end encryption with @@ -58,6 +61,7 @@ class OlmMachine( log: TraceLogger crypto_store: CryptoStore state_store: StateStore + ssss: SSSSMachine account: Optional[OlmAccount] @@ -70,6 +74,7 @@ def __init__( ) -> None: super().__init__() self.client = client + self.ssss = SSSSMachine(client) self.log = log or logging.getLogger("mau.crypto") self.crypto_store = crypto_store self.state_store = state_store @@ -96,6 +101,10 @@ def __init__( self._prev_unwedge = {} self._cs_fetch_attempted = set() + self._cross_signing_public_keys = None + self._cross_signing_public_keys_fetched = False + self._cross_signing_private_keys = None + self.client.add_event_handler( cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True ) @@ -213,6 +222,11 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None: passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you do syncing in some manual way. """ + if isinstance(evt, DecryptedOlmEvent): + self.log.warning( + f"Dropping unexpected nested encrypted to-device event from {evt.sender}" + ) + return self.log.trace( f"Handling encrypted to-device event from {evt.content.sender_key} ({evt.sender})" ) @@ -221,6 +235,8 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None: await self._receive_room_key(decrypted_evt) elif decrypted_evt.type == EventType.FORWARDED_ROOM_KEY: await self._receive_forwarded_room_key(decrypted_evt) + else: + self.client.dispatch_event(decrypted_evt, source=evt.source) async def _receive_room_key(self, evt: DecryptedOlmEvent) -> None: # TODO nio had a comment saying "handle this better" @@ -292,7 +308,7 @@ async def _share_keys(self, current_otk_count: int | None) -> None: ): self.log.debug("Checking OTK count on server") current_otk_count = (await self.client.upload_keys()).get( - EncryptionKeyAlgorithm.SIGNED_CURVE25519 + EncryptionKeyAlgorithm.SIGNED_CURVE25519, 0 ) device_keys = ( self.account.get_device_keys(self.client.mxid, self.client.device_id) @@ -310,6 +326,7 @@ async def _share_keys(self, current_otk_count: int | None) -> None: self.log.debug(f"Uploading {len(one_time_keys)} one-time keys") resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys) self.account.shared = True + self.account.mark_keys_as_published() self._last_key_share = time.monotonic() await self.crypto_store.put_account(self.account) self.log.debug(f"Shared keys and saved account, new keys: {resp}") diff --git a/mautrix/crypto/signature.py b/mautrix/crypto/signature.py new file mode 100644 index 00000000..6dc13e65 --- /dev/null +++ b/mautrix/crypto/signature.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Any, TypedDict +import functools +import json + +import olm +import unpaddedbase64 + +from mautrix.types import ( + JSON, + DeviceID, + EncryptionKeyAlgorithm, + KeyID, + Serializable, + Signature, + SigningKey, + UserID, +) + +try: + from Crypto.PublicKey import ECC + from Crypto.Signature import eddsa +except ImportError: + from Cryptodome.PublicKey import ECC + from Cryptodome.Signature import eddsa + +canonical_json = functools.partial( + json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True +) + + +class SignedObject(TypedDict): + signatures: dict[UserID, dict[str, str]] + unsigned: Any + + +def sign_olm(data: dict[str, JSON] | Serializable, key: olm.PkSigning | olm.Account) -> Signature: + if isinstance(data, Serializable): + data = data.serialize() + data.pop("signatures", None) + data.pop("unsigned", None) + return Signature(key.sign(canonical_json(data))) + + +def verify_signature_json( + data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey +) -> bool: + data_copy = {**data} + data_copy.pop("unsigned", None) + signatures = data_copy.pop("signatures") + key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name)) + try: + signature = signatures[user_id][key_id] + decoded_key = unpaddedbase64.decode_base64(key) + # pycryptodome doesn't accept raw keys, so wrap it in a DER structure + der_key = b"\x30\x2a\x30\x05\x06\x03\x2b\x65\x70\x03\x21\x00" + decoded_key + decoded_signature = unpaddedbase64.decode_base64(signature) + parsed_key = ECC.import_key(der_key) + verifier = eddsa.new(parsed_key, "rfc8032") + verifier.verify(canonical_json(data_copy).encode("utf-8"), decoded_signature) + return True + except (KeyError, ValueError): + return False diff --git a/mautrix/crypto/signature_test.py b/mautrix/crypto/signature_test.py new file mode 100644 index 00000000..115e4836 --- /dev/null +++ b/mautrix/crypto/signature_test.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from mautrix.types import SigningKey, UserID + +from .signature import verify_signature_json + + +def test_verify_signature_json() -> None: + assert verify_signature_json( + # This is actually a federation PDU rather than a device signature, + # but they're both 25519 curves so it doesn't make a difference. + { + "auth_events": [ + "$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw", + "$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM", + "$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo", + ], + "content": {}, + "depth": 3212, + "hashes": {"sha256": "K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"}, + "origin_server_ts": 1754242687127, + "prev_events": ["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"], + "room_id": "!offtopic-2:continuwuity.org", + "sender": "@tulir:maunium.net", + "type": "m.room.message", + "signatures": { + UserID("maunium.net"): { + "ed25519:a_xxeS": "SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw" + } + }, + "unsigned": {"age_ts": 1754242687146}, + }, + UserID("maunium.net"), + "a_xxeS", + SigningKey("lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg"), + ) diff --git a/mautrix/crypto/ssss/__init__.py b/mautrix/crypto/ssss/__init__.py new file mode 100644 index 00000000..9224418d --- /dev/null +++ b/mautrix/crypto/ssss/__init__.py @@ -0,0 +1,8 @@ +from .key import Key, KeyMetadata, PassphraseMetadata +from .machine import Machine +from .types import ( + Algorithm, + EncryptedAccountDataEventContent, + EncryptedKeyData, + PassphraseAlgorithm, +) diff --git a/mautrix/crypto/ssss/key.py b/mautrix/crypto/ssss/key.py new file mode 100644 index 00000000..691ded71 --- /dev/null +++ b/mautrix/crypto/ssss/key.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Optional +import base64 +import hashlib +import hmac + +from attr import dataclass +import unpaddedbase64 + +from mautrix.types import EventType, SerializableAttrs + +from .types import Algorithm, EncryptedKeyData, PassphraseAlgorithm +from .util import ( + calculate_hash, + cryptorand, + decode_base58_recovery_key, + derive_keys, + encode_base58_recovery_key, + prepare_aes, +) + +try: + from Crypto.Cipher import AES + from Crypto.Util import Counter +except ImportError: + from Cryptodome.Cipher import AES + from Cryptodome.Util import Counter + + +@dataclass +class PassphraseMetadata(SerializableAttrs): + algorithm: PassphraseAlgorithm + iterations: int + salt: str + bits: int = 256 + + def get_key(self, passphrase: str) -> bytes: + if self.algorithm != PassphraseAlgorithm.PBKDF2: + raise ValueError(f"Unsupported passphrase algorithm {self.algorithm}") + return hashlib.pbkdf2_hmac( + "sha512", + passphrase.encode("utf-8"), + self.salt.encode("utf-8"), + self.iterations, + self.bits // 8, + ) + + +@dataclass +class KeyMetadata(SerializableAttrs): + algorithm: Algorithm + + iv: str | None = None + mac: str | None = None + + name: str | None = None + passphrase: Optional[PassphraseMetadata] = None + + def verify_passphrase(self, key_id: str, phrase: str) -> "Key": + if not self.passphrase: + raise ValueError("Passphrase not set on this key") + return self.verify_raw_key(key_id, self.passphrase.get_key(phrase)) + + def verify_recovery_key(self, key_id: str, recovery_key: str) -> "Key": + decoded_key = decode_base58_recovery_key(recovery_key) + if not decoded_key: + raise ValueError("Invalid recovery key syntax") + return self.verify_raw_key(key_id, decoded_key) + + def verify_raw_key(self, key_id: str, key: bytes) -> "Key": + if self.mac.rstrip("=") != calculate_hash(key, self.iv): + raise ValueError("Key MAC does not match") + return Key(id=key_id, key=key, metadata=self) + + +@dataclass +class Key: + id: str + key: bytes + metadata: KeyMetadata + + @classmethod + def generate(cls, passphrase: str | None = None) -> "Key": + passphrase_meta = ( + PassphraseMetadata( + algorithm=PassphraseAlgorithm.PBKDF2, + iterations=500_000, + salt=base64.b64encode(cryptorand.read(24)).decode("utf-8"), + bits=256, + ) + if passphrase + else None + ) + key = passphrase_meta.get_key(passphrase) if passphrase else cryptorand.read(32) + iv = unpaddedbase64.encode_base64(cryptorand.read(16)) + metadata = KeyMetadata( + algorithm=Algorithm.AES_HMAC_SHA2, + passphrase=passphrase_meta, + mac=calculate_hash(key, iv), + iv=iv, + ) + key_id = unpaddedbase64.encode_base64(cryptorand.read(24)) + return cls(key=key, id=key_id, metadata=metadata) + + @property + def recovery_key(self) -> str: + return encode_base58_recovery_key(self.key) + + def encrypt(self, event_type: str | EventType, data: str | bytes) -> EncryptedKeyData: + if isinstance(data, str): + data = data.encode("utf-8") + data = base64.b64encode(data).rstrip(b"=") + + aes_key, hmac_key = derive_keys(self.key, event_type) + iv = bytearray(cryptorand.read(16)) + iv[8] &= 0x7F + ciphertext = prepare_aes(aes_key, iv).encrypt(data) + digest = hmac.digest(hmac_key, ciphertext, hashlib.sha256) + return EncryptedKeyData( + ciphertext=unpaddedbase64.encode_base64(ciphertext), + iv=unpaddedbase64.encode_base64(iv), + mac=unpaddedbase64.encode_base64(digest), + ) + + def decrypt(self, event_type: str | EventType, data: EncryptedKeyData) -> bytes: + aes_key, hmac_key = derive_keys(self.key, event_type) + ciphertext = unpaddedbase64.decode_base64(data.ciphertext) + mac = unpaddedbase64.decode_base64(data.mac) + + expected_mac = hmac.digest(hmac_key, ciphertext, hashlib.sha256) + if not hmac.compare_digest(mac, expected_mac): + raise ValueError("Invalid MAC") + + plaintext = prepare_aes(aes_key, data.iv).decrypt(ciphertext) + return unpaddedbase64.decode_base64(plaintext.decode("utf-8")) diff --git a/mautrix/crypto/ssss/key_test.py b/mautrix/crypto/ssss/key_test.py new file mode 100644 index 00000000..e60f1e3c --- /dev/null +++ b/mautrix/crypto/ssss/key_test.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import pytest + +from ...types.event.type import EventType +from .key import Key, KeyMetadata +from .types import EncryptedAccountDataEventContent + +KEY1_CROSS_SIGNING_MASTER_KEY = """{ + "encrypted": { + "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0": { + "iv": "BpKP9nQJTE9jrsAssoxPqQ==", + "ciphertext": "fNRiiiidezjerTgV+G6pUtmeF3izzj5re/mVvY0hO2kM6kYGrxLuIu2ej80=", + "mac": "/gWGDGMyOLmbJp+aoSLh5JxCs0AdS6nAhjzpe+9G2Q0=" + } + } +}""" + +KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED = bytes( + [ + 0x68, + 0xF9, + 0x7F, + 0xD1, + 0x92, + 0x2E, + 0xEC, + 0xF6, + 0xB8, + 0x2B, + 0xB8, + 0x90, + 0xD2, + 0x4D, + 0x06, + 0x52, + 0x98, + 0x4E, + 0x7A, + 0x1D, + 0x70, + 0x3B, + 0x9E, + 0x86, + 0x7B, + 0x7E, + 0xBA, + 0xF7, + 0xFE, + 0xB9, + 0x5B, + 0x6F, + ] +) + +KEY1_META = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "passphrase": { + "algorithm": "m.pbkdf2", + "iterations": 500000, + "salt": "y863BOoqOadgDp8S3FtHXikDJEalsQ7d" + }, + "iv": "xxkTK0L4UzxgAFkQ6XPwsw", + "mac": "MEhooO0ZhFJNxUhvRMSxBnJfL20wkLgle3ocY0ee/eA" +}""" +KEY1_ID = "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0" +KEY1_RECOVERY_KEY = "EsTE s92N EtaX s2h6 VQYF 9Kao tHYL mkyL GKMh isZb KJ4E tvoC" +KEY1_PASSPHRASE = "correct horse battery staple" + +KEY2_META = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +}""" +KEY2_ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv" +KEY2_RECOVERY_KEY = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe" + +KEY2_META_BROKEN_IV = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +}""" + +KEY2_META_BROKEN_MAC = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow" +}""" + + +def get_key_meta(meta: str) -> KeyMetadata: + return KeyMetadata.parse_json(meta) + + +def get_key1() -> Key: + return get_key_meta(KEY1_META).verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY) + + +def get_key2() -> Key: + return get_key_meta(KEY2_META).verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def get_encrypted_master_key() -> EncryptedAccountDataEventContent: + return EncryptedAccountDataEventContent.parse_json(KEY1_CROSS_SIGNING_MASTER_KEY) + + +def test_decrypt_success() -> None: + key = get_key1() + emk = get_encrypted_master_key() + assert ( + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) == KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED + ) + + +def test_decrypt_fail_wrong_key() -> None: + key = get_key2() + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) + + +def test_decrypt_fail_fake_key() -> None: + key = get_key2() + key.id = KEY1_ID + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) + + +def test_decrypt_fail_wrong_type() -> None: + key = get_key1() + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_SELF_SIGNING, key) + + +def test_encrypt_roundtrip() -> None: + key = get_key1() + data = bytes([0xDE, 0xAD, 0xBE, 0xEF]) + ciphertext = key.encrypt("net.maunium.data", data) + plaintext = key.decrypt("net.maunium.data", ciphertext) + assert plaintext == data + + +def test_verify_recovery_key_correct() -> None: + meta = get_key_meta(KEY1_META) + key = meta.verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY) + assert key.recovery_key == KEY1_RECOVERY_KEY + + +def test_verify_recovery_key_correct2() -> None: + meta = get_key_meta(KEY2_META) + key = meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + assert key.recovery_key == KEY2_RECOVERY_KEY + + +def test_verify_recovery_key_invalid() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY1_ID, "foo") + + +def test_verify_recovery_key_incorrect() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_recovery_key_broken_iv() -> None: + meta = get_key_meta(KEY2_META_BROKEN_IV) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_recovery_key_broken_mac() -> None: + meta = get_key_meta(KEY2_META_BROKEN_MAC) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_passphrase_correct() -> None: + meta = get_key_meta(KEY1_META) + key = meta.verify_passphrase(KEY1_ID, KEY1_PASSPHRASE) + assert key.recovery_key == KEY1_RECOVERY_KEY + + +def test_verify_passphrase_incorrect() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_passphrase(KEY1_ID, "incorrect horse battery staple") + + +def test_verify_passphrase_notset() -> None: + meta = get_key_meta(KEY2_META) + with pytest.raises(ValueError): + meta.verify_passphrase(KEY2_ID, "hmm") diff --git a/mautrix/crypto/ssss/machine.py b/mautrix/crypto/ssss/machine.py new file mode 100644 index 00000000..c43e25e3 --- /dev/null +++ b/mautrix/crypto/ssss/machine.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from mautrix import client as cli +from mautrix.errors import MNotFound +from mautrix.types import EventType, SecretStorageDefaultKeyEventContent + +from .key import Key, KeyMetadata +from .types import EncryptedAccountDataEventContent + + +class Machine: + client: cli.Client + + def __init__(self, client: cli.Client) -> None: + self.client = client + + async def get_default_key_id(self) -> str | None: + try: + data = await self.client.get_account_data(EventType.SECRET_STORAGE_DEFAULT_KEY) + return SecretStorageDefaultKeyEventContent.deserialize(data).key + except (MNotFound, ValueError): + return None + + async def set_default_key_id(self, key_id: str) -> None: + await self.client.set_account_data( + EventType.SECRET_STORAGE_DEFAULT_KEY, + SecretStorageDefaultKeyEventContent(key=key_id), + ) + + async def get_key_data(self, key_id: str) -> KeyMetadata: + data = await self.client.get_account_data(f"m.secret_storage.key.{key_id}") + return KeyMetadata.deserialize(data) + + async def set_key_data(self, key_id: str, data: KeyMetadata) -> None: + await self.client.set_account_data(f"m.secret_storage.key.{key_id}", data) + + async def get_default_key_data(self) -> tuple[str, KeyMetadata]: + key_id = await self.get_default_key_id() + if not key_id: + raise ValueError("No default key ID set") + return key_id, await self.get_key_data(key_id) + + async def get_decrypted_account_data(self, event_type: EventType | str, key: Key) -> bytes: + data = await self.client.get_account_data(event_type) + parsed = EncryptedAccountDataEventContent.deserialize(data) + return parsed.decrypt(event_type, key) + + async def set_encrypted_account_data( + self, event_type: EventType | str, data: bytes, *keys: Key + ) -> None: + encrypted_data = {} + for key in keys: + encrypted_data[key.id] = key.encrypt(event_type, data) + await self.client.set_account_data( + event_type, + EncryptedAccountDataEventContent(encrypted=encrypted_data), + ) + + async def generate_and_upload_key(self, passphrase: str | None = None) -> Key: + key = Key.generate(passphrase) + await self.set_key_data(key.id, key.metadata) + return key diff --git a/mautrix/crypto/ssss/types.py b/mautrix/crypto/ssss/types.py new file mode 100644 index 00000000..4a47f743 --- /dev/null +++ b/mautrix/crypto/ssss/types.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import TYPE_CHECKING + +from attr import dataclass + +from mautrix.types import EventType, SerializableAttrs, SerializableEnum +from mautrix.types.event.account_data import account_data_event_content_map + +if TYPE_CHECKING: + from .key import Key + + +class Algorithm(SerializableEnum): + AES_HMAC_SHA2 = "m.secret_storage.v1.aes-hmac-sha2" + CURVE25519_AES_SHA2 = "m.secret_storage.v1.curve25519-aes-sha2" + + +class PassphraseAlgorithm(SerializableEnum): + PBKDF2 = "m.pbkdf2" + + +@dataclass +class EncryptedKeyData(SerializableAttrs): + ciphertext: str + iv: str + mac: str + + +@dataclass +class EncryptedAccountDataEventContent(SerializableAttrs): + encrypted: dict[str, EncryptedKeyData] + + def decrypt(self, event_type: str | EventType, key: "Key") -> bytes: + try: + encrypted_data = self.encrypted[key.id] + except KeyError as e: + raise ValueError(f"Event not encrypted for provided key") from e + return key.decrypt(event_type, encrypted_data) + + +for encrypted_account_data_type in ( + EventType.CROSS_SIGNING_MASTER, + EventType.CROSS_SIGNING_USER_SIGNING, + EventType.CROSS_SIGNING_SELF_SIGNING, + EventType.MEGOLM_BACKUP_V1, +): + account_data_event_content_map[encrypted_account_data_type] = EncryptedAccountDataEventContent diff --git a/mautrix/crypto/ssss/util.py b/mautrix/crypto/ssss/util.py new file mode 100644 index 00000000..b58c941a --- /dev/null +++ b/mautrix/crypto/ssss/util.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import hashlib +import hmac + +import base58 +import unpaddedbase64 + +from mautrix.types import EventType + +try: + from Crypto import Random + from Crypto.Cipher import AES + from Crypto.Hash import SHA256 + from Crypto.Protocol.KDF import HKDF + from Crypto.Util import Counter +except ImportError: + from Cryptodome import Random + from Cryptodome.Cipher import AES + from Cryptodome.Hash import SHA256 + from Cryptodome.Protocol.KDF import HKDF + from Cryptodome.Util import Counter + +cryptorand = Random.new() + + +def decode_base58_recovery_key(key: str) -> bytes | None: + key_bytes = base58.b58decode(key.replace(" ", "")) + if len(key_bytes) != 35 or key_bytes[0] != 0x8B or key_bytes[1] != 1: + return None + parity = 0 + for byte in key_bytes[:34]: + parity ^= byte + return key_bytes[2:34] if parity == key_bytes[34] else None + + +def encode_base58_recovery_key(key: bytes) -> str: + key_bytes = bytearray(35) + key_bytes[0] = 0x8B + key_bytes[1] = 1 + key_bytes[2:34] = key + parity = 0 + for byte in key_bytes: + parity ^= byte + key_bytes[34] = parity + encoded_key = base58.b58encode(key_bytes).decode("utf-8") + return " ".join(encoded_key[i : i + 4] for i in range(0, len(encoded_key), 4)) + + +def derive_keys(key: bytes, name: str | EventType = "") -> tuple[bytes, bytes]: + aes_key, hmac_key = HKDF( + master=key, + key_len=32, + salt=b"\x00" * 32, + hashmod=SHA256, + num_keys=2, + context=str(name).encode("utf-8"), + ) + return aes_key, hmac_key + + +def prepare_aes(key: bytes, iv: str | bytes) -> AES: + if isinstance(iv, str): + iv = unpaddedbase64.decode_base64(iv) + # initial_value = struct.unpack(">Q", iv[8:])[0] + # counter = Counter.new(64, prefix=iv[:8], initial_value=initial_value) + counter = Counter.new(128, initial_value=int.from_bytes(iv, byteorder="big")) + return AES.new(key=key, mode=AES.MODE_CTR, counter=counter) + + +def calculate_hash(key: bytes, iv: str | bytes) -> str: + aes_key, hmac_key = derive_keys(key) + cipher = prepare_aes(aes_key, iv).decrypt(b"\x00" * 32) + digest = hmac.digest(hmac_key, cipher, hashlib.sha256) + return unpaddedbase64.encode_base64(digest) diff --git a/mautrix/crypto/store/abstract.py b/mautrix/crypto/store/abstract.py index 787d5225..0916a2d7 100644 --- a/mautrix/crypto/store/abstract.py +++ b/mautrix/crypto/store/abstract.py @@ -5,8 +5,9 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import NamedTuple +from typing import AsyncContextManager, NamedTuple from abc import ABC, abstractmethod +from contextlib import asynccontextmanager from mautrix.types import ( CrossSigner, @@ -87,6 +88,11 @@ async def close(self) -> None: async def flush(self) -> None: """Flush the store. If all the methods persist data immediately, this can be a no-op.""" + @asynccontextmanager + async def transaction(self) -> None: + """Run a database transaction. If the store doesn't support transactions, this can be a no-op.""" + yield + @abstractmethod async def delete(self) -> None: """Delete the data in the store.""" diff --git a/mautrix/crypto/store/asyncpg/store.py b/mautrix/crypto/store/asyncpg/store.py index a29f7737..bdc37ddd 100644 --- a/mautrix/crypto/store/asyncpg/store.py +++ b/mautrix/crypto/store/asyncpg/store.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections import defaultdict +from contextlib import asynccontextmanager from datetime import timedelta from asyncpg import UniqueViolationError @@ -79,6 +80,11 @@ def __init__(self, account_id: str, pickle_key: str, db: Database) -> None: self._account = None self._olm_cache = defaultdict(lambda: {}) + @asynccontextmanager + async def transaction(self) -> None: + async with self.db.acquire() as conn, conn.transaction(): + yield + async def delete(self) -> None: tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session") async with self.db.acquire() as conn, conn.transaction(): diff --git a/mautrix/crypto/store/asyncpg/upgrade.py b/mautrix/crypto/store/asyncpg/upgrade.py index e097c5d9..8d413858 100644 --- a/mautrix/crypto/store/asyncpg/upgrade.py +++ b/mautrix/crypto/store/asyncpg/upgrade.py @@ -18,32 +18,32 @@ @upgrade_table.register(description="Latest revision", upgrades_to=10) async def upgrade_blank_to_latest(conn: Connection) -> None: - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_account ( + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_message_index ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_message_index ( sender_key CHAR(43), session_id CHAR(43), "index" INTEGER, event_id TEXT NOT NULL, timestamp BIGINT NOT NULL, PRIMARY KEY (sender_key, session_id, "index") - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_tracked_user ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_tracked_user ( user_id TEXT PRIMARY KEY - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_device ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_device ( user_id TEXT, device_id TEXT, identity_key CHAR(43) NOT NULL, @@ -52,10 +52,10 @@ async def upgrade_blank_to_latest(conn: Connection) -> None: deleted BOOLEAN NOT NULL, name TEXT NOT NULL, PRIMARY KEY (user_id, device_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_olm_session ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_olm_session ( account_id TEXT, session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -64,10 +64,10 @@ async def upgrade_blank_to_latest(conn: Connection) -> None: last_decrypted timestamp NOT NULL, last_encrypted timestamp NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( account_id TEXT, session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -83,10 +83,10 @@ async def upgrade_blank_to_latest(conn: Connection) -> None: max_messages INTEGER, is_scheduled BOOLEAN NOT NULL DEFAULT false, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, room_id TEXT, session_id CHAR(43) NOT NULL UNIQUE, @@ -98,10 +98,10 @@ async def upgrade_blank_to_latest(conn: Connection) -> None: created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_cross_signing_keys ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_keys ( user_id TEXT, usage TEXT, key CHAR(43) NOT NULL, @@ -109,18 +109,18 @@ async def upgrade_blank_to_latest(conn: Connection) -> None: first_seen_key CHAR(43) NOT NULL, PRIMARY KEY (user_id, usage) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_cross_signing_signatures ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_signatures ( signed_user_id TEXT, signed_key TEXT, signer_user_id TEXT, signer_key TEXT, signature CHAR(88) NOT NULL, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) - )""" - ) + ) + """) @upgrade_table.register(description="Add account_id primary key column") @@ -130,17 +130,17 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: await conn.execute("DROP TABLE crypto_olm_session") await conn.execute("DROP TABLE crypto_megolm_inbound_session") await conn.execute("DROP TABLE crypto_megolm_outbound_session") - await conn.execute( - """CREATE TABLE crypto_account ( + await conn.execute(""" + CREATE TABLE crypto_account ( account_id VARCHAR(255) PRIMARY KEY, device_id VARCHAR(255) NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL - )""" - ) - await conn.execute( - """CREATE TABLE crypto_olm_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_olm_session ( account_id VARCHAR(255), session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -148,10 +148,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_megolm_inbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_megolm_inbound_session ( account_id VARCHAR(255), session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -160,10 +160,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: session bytea NOT NULL, forwarding_chains TEXT NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_megolm_outbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_megolm_outbound_session ( account_id VARCHAR(255), room_id VARCHAR(255), session_id CHAR(43) NOT NULL UNIQUE, @@ -175,8 +175,8 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) - )""" - ) + ) + """) else: async def add_account_id_column(table: str, pkey_columns: list[str]) -> None: @@ -233,25 +233,25 @@ async def upgrade_v4(conn: Connection, scheme: Scheme) -> None: @upgrade_table.register(description="Add cross-signing key and signature caches") async def upgrade_v5(conn: Connection) -> None: - await conn.execute( - """CREATE TABLE crypto_cross_signing_keys ( + await conn.execute(""" + CREATE TABLE crypto_cross_signing_keys ( user_id TEXT, usage TEXT, key CHAR(43), first_seen_key CHAR(43), PRIMARY KEY (user_id, usage) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_cross_signing_signatures ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_signatures ( signed_user_id TEXT, signed_key TEXT, signer_user_id TEXT, signer_key TEXT, signature TEXT, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) - )""" - ) + ) + """) @upgrade_table.register(description="Update trust state values") @@ -322,27 +322,25 @@ async def upgrade_v9_postgres(conn: Connection) -> None: async def upgrade_v9_sqlite(conn: Connection) -> None: await conn.execute("PRAGMA foreign_keys = OFF") async with conn.transaction(): - await conn.execute( - """CREATE TABLE new_crypto_account ( + await conn.execute(""" + CREATE TABLE new_crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL - )""" - ) - await conn.execute( - """ + ) + """) + await conn.execute(""" INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account) SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account FROM crypto_account - """ - ) + """) await conn.execute("DROP TABLE crypto_account") await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account") - await conn.execute( - """CREATE TABLE new_crypto_megolm_inbound_session ( + await conn.execute(""" + CREATE TABLE new_crypto_megolm_inbound_session ( account_id TEXT, session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -353,10 +351,9 @@ async def upgrade_v9_sqlite(conn: Connection) -> None: withheld_code TEXT, withheld_reason TEXT, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """ + ) + """) + await conn.execute(""" INSERT INTO new_crypto_megolm_inbound_session ( account_id, session_id, sender_key, signing_key, room_id, session, forwarding_chains @@ -364,8 +361,7 @@ async def upgrade_v9_sqlite(conn: Connection) -> None: SELECT account_id, session_id, sender_key, signing_key, room_id, session, forwarding_chains FROM crypto_megolm_inbound_session - """ - ) + """) await conn.execute("DROP TABLE crypto_megolm_inbound_session") await conn.execute( "ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session" @@ -373,8 +369,8 @@ async def upgrade_v9_sqlite(conn: Connection) -> None: await conn.execute("UPDATE crypto_megolm_outbound_session SET max_age=max_age*1000") - await conn.execute( - """CREATE TABLE new_crypto_cross_signing_keys ( + await conn.execute(""" + CREATE TABLE new_crypto_cross_signing_keys ( user_id TEXT, usage TEXT, key CHAR(43) NOT NULL, @@ -382,41 +378,37 @@ async def upgrade_v9_sqlite(conn: Connection) -> None: first_seen_key CHAR(43) NOT NULL, PRIMARY KEY (user_id, usage) - )""" - ) - await conn.execute( - """ + ) + """) + await conn.execute(""" INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key) SELECT user_id, usage, key, COALESCE(first_seen_key, key) FROM crypto_cross_signing_keys WHERE key IS NOT NULL - """ - ) + """) await conn.execute("DROP TABLE crypto_cross_signing_keys") await conn.execute( "ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys" ) - await conn.execute( - """CREATE TABLE new_crypto_cross_signing_signatures ( + await conn.execute(""" + CREATE TABLE new_crypto_cross_signing_signatures ( signed_user_id TEXT, signed_key TEXT, signer_user_id TEXT, signer_key TEXT, signature CHAR(88) NOT NULL, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) - )""" - ) - await conn.execute( - """ + ) + """) + await conn.execute(""" INSERT INTO new_crypto_cross_signing_signatures ( signed_user_id, signed_key, signer_user_id, signer_key, signature ) SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature FROM crypto_cross_signing_signatures WHERE signature IS NOT NULL - """ - ) + """) await conn.execute("DROP TABLE crypto_cross_signing_signatures") await conn.execute( "ALTER TABLE new_crypto_cross_signing_signatures " diff --git a/mautrix/types/__init__.py b/mautrix/types/__init__.py index 42b9068c..ceceeaca 100644 --- a/mautrix/types/__init__.py +++ b/mautrix/types/__init__.py @@ -57,6 +57,7 @@ CallRejectEventContent, CallSelectAnswerEventContent, CanonicalAliasStateEventContent, + DirectAccountDataEventContent, EncryptedEvent, EncryptedEventContent, EncryptedFile, @@ -122,6 +123,7 @@ RoomTombstoneStateEventContent, RoomTopicStateEventContent, RoomType, + SecretStorageDefaultKeyEventContent, SingleReceiptEventContent, SpaceChildStateEventContent, SpaceParentStateEventContent, @@ -259,6 +261,7 @@ "CallRejectEventContent", "CallSelectAnswerEventContent", "CanonicalAliasStateEventContent", + "DirectAccountDataEventContent", "EncryptedEvent", "EncryptedEventContent", "EncryptedFile", @@ -324,6 +327,7 @@ "RoomTombstoneStateEventContent", "RoomTopicStateEventContent", "RoomType", + "SecretStorageDefaultKeyEventContent", "SingleReceiptEventContent", "SpaceChildStateEventContent", "SpaceParentStateEventContent", @@ -354,6 +358,7 @@ "OpenGraphImage", "OpenGraphVideo", "BatchSendResponse", + "BeeperBatchSendResponse", "DeviceLists", "DeviceOTKCount", "DirectoryPaginationToken", diff --git a/mautrix/types/crypto.py b/mautrix/types/crypto.py index 3bf7e96b..fe4ab742 100644 --- a/mautrix/types/crypto.py +++ b/mautrix/types/crypto.py @@ -47,9 +47,9 @@ def curve25519(self) -> Optional[IdentityKey]: class CrossSigningUsage(ExtensibleEnum): - MASTER = "master" - SELF = "self_signing" - USER = "user_signing" + MASTER: "CrossSigningUsage" = "master" + SELF: "CrossSigningUsage" = "self_signing" + USER: "CrossSigningUsage" = "user_signing" @dataclass diff --git a/mautrix/types/event/__init__.py b/mautrix/types/event/__init__.py index b391e912..db0658db 100644 --- a/mautrix/types/event/__init__.py +++ b/mautrix/types/event/__init__.py @@ -6,8 +6,10 @@ from .account_data import ( AccountDataEvent, AccountDataEventContent, + DirectAccountDataEventContent, RoomTagAccountDataEventContent, RoomTagInfo, + SecretStorageDefaultKeyEventContent, ) from .base import BaseEvent, BaseRoomEvent, BaseUnsigned, GenericEvent from .batch import BatchSendEvent, BatchSendStateEvent diff --git a/mautrix/types/event/account_data.py b/mautrix/types/event/account_data.py index 3c144a36..fdfa0a30 100644 --- a/mautrix/types/event/account_data.py +++ b/mautrix/types/event/account_data.py @@ -3,7 +3,7 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Union from attr import dataclass import attr @@ -12,6 +12,9 @@ from ..util import Obj, SerializableAttrs, deserializer from .base import BaseEvent, EventType +if TYPE_CHECKING: + from mautrix.crypto.ssss import EncryptedAccountDataEventContent, KeyMetadata + @dataclass class RoomTagInfo(SerializableAttrs): @@ -23,11 +26,24 @@ class RoomTagAccountDataEventContent(SerializableAttrs): tags: Dict[str, RoomTagInfo] = attr.ib(default=None, metadata={"json": "tags"}) +@dataclass +class SecretStorageDefaultKeyEventContent(SerializableAttrs): + key: str + + DirectAccountDataEventContent = Dict[UserID, List[RoomID]] -AccountDataEventContent = Union[RoomTagAccountDataEventContent, DirectAccountDataEventContent, Obj] +AccountDataEventContent = Union[ + RoomTagAccountDataEventContent, + DirectAccountDataEventContent, + SecretStorageDefaultKeyEventContent, + "EncryptedAccountDataEventContent", + "KeyMetadata", + Obj, +] account_data_event_content_map = { EventType.TAG: RoomTagAccountDataEventContent, + EventType.SECRET_STORAGE_DEFAULT_KEY: SecretStorageDefaultKeyEventContent, # m.direct doesn't really need deserializing # EventType.DIRECT: DirectAccountDataEventContent, } diff --git a/mautrix/types/event/encrypted.py b/mautrix/types/event/encrypted.py index cd08fc94..735e481a 100644 --- a/mautrix/types/event/encrypted.py +++ b/mautrix/types/event/encrypted.py @@ -9,7 +9,7 @@ from attr import dataclass -from ..primitive import JSON, DeviceID, IdentityKey, SessionID +from ..primitive import JSON, DeviceID, IdentityKey, SessionID, SigningKey from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field from .base import BaseRoomEvent, BaseUnsigned from .message import RelatesTo @@ -43,6 +43,18 @@ def deserialize(cls, raw: JSON) -> "KeyID": def __str__(self) -> str: return f"{self.algorithm.value}:{self.key_id}" + @classmethod + def ed25519(cls, key_id: SigningKey | DeviceID) -> "KeyID": + return cls(EncryptionKeyAlgorithm.ED25519, key_id) + + @classmethod + def curve25519(cls, key_id: IdentityKey) -> "KeyID": + return cls(EncryptionKeyAlgorithm.CURVE25519, key_id) + + @classmethod + def signed_curve25519(cls, key_id: IdentityKey) -> "KeyID": + return cls(EncryptionKeyAlgorithm.SIGNED_CURVE25519, key_id) + class OlmMsgType(Serializable, IntEnum): PREKEY = 0 diff --git a/mautrix/types/event/message.py b/mautrix/types/event/message.py index 8d5f1bba..32033581 100644 --- a/mautrix/types/event/message.py +++ b/mautrix/types/event/message.py @@ -119,7 +119,7 @@ def set_thread_parent( thread_parent.content.get_thread_parent() or self.relates_to.event_id ) if not disable_reply_fallback: - self.set_reply(last_event_in_thread or thread_parent, disable_fallback=True, **kwargs) + self.set_reply(last_event_in_thread or thread_parent, **kwargs) self.relates_to.is_falling_back = True def set_edit(self, edits: Union[EventID, "MessageEvent"]) -> None: @@ -271,33 +271,6 @@ class LocationInfo(SerializableAttrs): # region Event content -@dataclass -class MediaMessageEventContent(BaseMessageEventContent, SerializableAttrs): - """The content of a media message event (m.image, m.audio, m.video, m.file)""" - - url: Optional[ContentURI] = None - info: Optional[MediaInfo] = None - file: Optional[EncryptedFile] = None - - @staticmethod - @deserializer(MediaInfo) - @deserializer(Optional[MediaInfo]) - def deserialize_info(data: JSON) -> MediaInfo: - if not isinstance(data, dict): - return Obj() - msgtype = data.pop("__mautrix_msgtype", None) - if msgtype == "m.image" or msgtype == "m.sticker": - return ImageInfo.deserialize(data) - elif msgtype == "m.video": - return VideoInfo.deserialize(data) - elif msgtype == "m.audio": - return AudioInfo.deserialize(data) - elif msgtype == "m.file": - return FileInfo.deserialize(data) - else: - return Obj(**data) - - @dataclass class LocationMessageEventContent(BaseMessageEventContent, SerializableAttrs): geo_uri: str = None @@ -314,25 +287,6 @@ class TextMessageEventContent(BaseMessageEventContent, SerializableAttrs): format: Format = None formatted_body: str = None - def set_reply( - self, - reply_to: Union["MessageEvent", EventID], - *, - displayname: Optional[str] = None, - disable_fallback: bool = False, - ) -> None: - super().set_reply(reply_to) - if isinstance(reply_to, str): - return - if isinstance(reply_to, MessageEvent) and not disable_fallback: - self.ensure_has_html() - if isinstance(reply_to.content, TextMessageEventContent): - reply_to.content.trim_reply_fallback() - self.formatted_body = ( - reply_to.make_reply_fallback_html(displayname) + self.formatted_body - ) - self.body = reply_to.make_reply_fallback_text(displayname) + self.body - def ensure_has_html(self) -> None: if not self.formatted_body or self.format != Format.HTML: self.format = Format.HTML @@ -364,6 +318,34 @@ def _trim_reply_fallback_html(self) -> None: self.formatted_body = html_reply_fallback_regex.sub("", self.formatted_body) +@dataclass +class MediaMessageEventContent(TextMessageEventContent, SerializableAttrs): + """The content of a media message event (m.image, m.audio, m.video, m.file)""" + + url: Optional[ContentURI] = None + info: Optional[MediaInfo] = None + file: Optional[EncryptedFile] = None + filename: Optional[str] = None + + @staticmethod + @deserializer(MediaInfo) + @deserializer(Optional[MediaInfo]) + def deserialize_info(data: JSON) -> MediaInfo: + if not isinstance(data, dict): + return Obj() + msgtype = data.pop("__mautrix_msgtype", None) + if msgtype == "m.image" or msgtype == "m.sticker": + return ImageInfo.deserialize(data) + elif msgtype == "m.video": + return VideoInfo.deserialize(data) + elif msgtype == "m.audio": + return AudioInfo.deserialize(data) + elif msgtype == "m.file": + return FileInfo.deserialize(data) + else: + return Obj(**data) + + MessageEventContent = Union[ TextMessageEventContent, MediaMessageEventContent, LocationMessageEventContent, Obj ] @@ -423,36 +405,3 @@ def deserialize_content(data: JSON) -> MessageEventContent: return LocationMessageEventContent.deserialize(data) else: return Obj(**data) - - def make_reply_fallback_html(self, displayname: Optional[str] = None) -> str: - """Generate the HTML fallback for messages replying to this event.""" - if self.content.msgtype.is_text: - body = self.content.formatted_body or escape(self.content.body).replace("\n", "
") - else: - sent_type = media_reply_fallback_body_map[self.content.msgtype] or "a message" - body = f"sent {sent_type}" - displayname = escape(displayname) if displayname else self.sender - return html_reply_fallback_format.format( - room_id=self.room_id, - event_id=self.event_id, - sender=self.sender, - displayname=displayname, - content=body, - ) - - def make_reply_fallback_text(self, displayname: Optional[str] = None) -> str: - """Generate the plaintext fallback for messages replying to this event.""" - if self.content.msgtype.is_text: - body = self.content.body - else: - try: - body = media_reply_fallback_body_map[self.content.msgtype] - except KeyError: - body = "an unknown message type" - lines = body.strip().split("\n") - first_line, lines = lines[0], lines[1:] - fallback_text = f"> <{displayname or self.sender}> {first_line}" - for line in lines: - fallback_text += f"\n> {line}" - fallback_text += "\n\n" - return fallback_text diff --git a/mautrix/types/event/state.py b/mautrix/types/event/state.py index c6b351c6..5ffc855f 100644 --- a/mautrix/types/event/state.py +++ b/mautrix/types/event/state.py @@ -9,7 +9,7 @@ import attr from ..primitive import JSON, ContentURI, EventID, RoomAlias, RoomID, UserID -from ..util import Obj, SerializableAttrs, SerializableEnum, deserializer, field +from ..util import ExtensibleEnum, Obj, SerializableAttrs, SerializableEnum, deserializer, field from .base import BaseRoomEvent, BaseUnsigned from .encrypted import EncryptionAlgorithm from .type import EventType, RoomType @@ -41,7 +41,19 @@ class PowerLevelStateEventContent(SerializableAttrs): ban: int = 50 redact: int = 50 - def get_user_level(self, user_id: UserID) -> int: + def get_user_level( + self, + user_id: UserID, + create: Optional["StateEvent"] = None, + ) -> int: + if ( + create + and create.content.supports_creator_power + and (user_id == create.sender or user_id in (create.content.additional_creators or [])) + ): + # This is really meant to be infinity, but involving floats would be annoying, + # so we use an integer larger than the maximum power level (2^53-1) instead. + return 2**60 - 1 return int(self.users.get(user_id, self.users_default)) def set_user_level(self, user_id: UserID, level: int) -> None: @@ -50,7 +62,16 @@ def set_user_level(self, user_id: UserID, level: int) -> None: else: self.users[user_id] = level - def ensure_user_level(self, user_id: UserID, level: int) -> bool: + def ensure_user_level( + self, user_id: UserID, level: int, create: Optional["StateEvent"] = None + ) -> bool: + if ( + create + and create.content.supports_creator_power + and (user_id == create.sender or user_id in (create.content.additional_creators or [])) + ): + # Don't try to set creator power levels + return False if self.get_user_level(user_id) != level: self.set_user_level(user_id, level) return True @@ -138,16 +159,15 @@ class RoomAvatarStateEventContent(SerializableAttrs): url: Optional[ContentURI] = None -class JoinRule(SerializableEnum): +class JoinRule(ExtensibleEnum): PUBLIC = "public" KNOCK = "knock" RESTRICTED = "restricted" INVITE = "invite" - PRIVATE = "private" KNOCK_RESTRICTED = "knock_restricted" -class JoinRestrictionType(SerializableEnum): +class JoinRestrictionType(ExtensibleEnum): ROOM_MEMBERSHIP = "m.room_membership" @@ -193,6 +213,24 @@ class RoomCreateStateEventContent(SerializableAttrs): federate: bool = field(json="m.federate", omit_default=True, default=True) predecessor: Optional[RoomPredecessor] = None type: Optional[RoomType] = None + additional_creators: Optional[List[UserID]] = None + + @property + def supports_creator_power(self) -> bool: + return self.room_version not in ( + "", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + ) @dataclass diff --git a/mautrix/types/event/type.py b/mautrix/types/event/type.py index 509d6857..5faf3785 100644 --- a/mautrix/types/event/type.py +++ b/mautrix/types/event/type.py @@ -207,6 +207,11 @@ def is_to_device(self) -> bool: "m.push_rules": "PUSH_RULES", "m.tag": "TAG", "m.ignored_user_list": "IGNORED_USER_LIST", + "m.secret_storage.default_key": "SECRET_STORAGE_DEFAULT_KEY", + "m.cross_signing.master": "CROSS_SIGNING_MASTER", + "m.cross_signing.self_signing": "CROSS_SIGNING_SELF_SIGNING", + "m.cross_signing.user_signing": "CROSS_SIGNING_USER_SIGNING", + "m.megolm_backup.v1": "MEGOLM_BACKUP_V1", }, EventType.Class.TO_DEVICE: { "m.room.encrypted": "TO_DEVICE_ENCRYPTED", diff --git a/mautrix/types/event/type.pyi b/mautrix/types/event/type.pyi index 22922288..a2788d6f 100644 --- a/mautrix/types/event/type.pyi +++ b/mautrix/types/event/type.pyi @@ -61,6 +61,11 @@ class EventType(Serializable): PUSH_RULES: "EventType" TAG: "EventType" IGNORED_USER_LIST: "EventType" + SECRET_STORAGE_DEFAULT_KEY: "EventType" + CROSS_SIGNING_MASTER: "EventType" + CROSS_SIGNING_SELF_SIGNING: "EventType" + CROSS_SIGNING_USER_SIGNING: "EventType" + MEGOLM_BACKUP_V1: "EventType" TO_DEVICE_ENCRYPTED: "EventType" TO_DEVICE_DUMMY: "EventType" diff --git a/mautrix/types/misc.py b/mautrix/types/misc.py index cf576c2b..5a07699c 100644 --- a/mautrix/types/misc.py +++ b/mautrix/types/misc.py @@ -106,7 +106,7 @@ class RoomDirectoryResponse(SerializableAttrs): PaginatedMessages = NamedTuple( - "PaginatedMessages", start=SyncToken, end=SyncToken, events=List[Event] + "PaginatedMessages", start=SyncToken, end=Optional[SyncToken], events=List[Event] ) diff --git a/mautrix/types/versions.py b/mautrix/types/versions.py index 9ad81710..52a62f59 100644 --- a/mautrix/types/versions.py +++ b/mautrix/types/versions.py @@ -74,6 +74,10 @@ class SpecVersions: V15 = Version.deserialize("v1.5") V16 = Version.deserialize("v1.6") V17 = Version.deserialize("v1.7") + V18 = Version.deserialize("v1.8") + V19 = Version.deserialize("v1.9") + V110 = Version.deserialize("v1.10") + V111 = Version.deserialize("v1.11") @dataclass diff --git a/mautrix/util/async_db/aiosqlite.py b/mautrix/util/async_db/aiosqlite.py index 80fa3cf8..934379a8 100644 --- a/mautrix/util/async_db/aiosqlite.py +++ b/mautrix/util/async_db/aiosqlite.py @@ -7,6 +7,7 @@ from typing import Any, AsyncContextManager from contextlib import asynccontextmanager +from contextvars import ContextVar import asyncio import logging import os @@ -24,6 +25,9 @@ POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)") +in_transaction = ContextVar("in_transaction", default=False) + + class TxnConnection(aiosqlite.Connection): def __init__(self, path: str, **kwargs) -> None: def connector() -> sqlite3.Connection: @@ -35,7 +39,11 @@ def connector() -> sqlite3.Connection: @asynccontextmanager async def transaction(self) -> None: + if in_transaction.get(): + yield + return await self.execute("BEGIN TRANSACTION") + token = in_transaction.set(True) try: yield except Exception: @@ -43,6 +51,8 @@ async def transaction(self) -> None: raise else: await self.commit() + finally: + in_transaction.reset(token) def __execute(self, query: str, *args: Any): query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query) @@ -181,7 +191,7 @@ async def stop(self) -> None: self._conns -= 1 await conn.close() - def acquire(self) -> AsyncContextManager[LoggingConnection]: + def acquire_direct(self) -> AsyncContextManager[LoggingConnection]: if self._parent: return self._parent.acquire() return self._acquire() diff --git a/mautrix/util/async_db/asyncpg.py b/mautrix/util/async_db/asyncpg.py index 07ef7ad7..97b49f6c 100644 --- a/mautrix/util/async_db/asyncpg.py +++ b/mautrix/util/async_db/asyncpg.py @@ -96,7 +96,7 @@ async def _handle_exception(self, err: Exception) -> None: sys.exit(26) @asynccontextmanager - async def acquire(self) -> LoggingConnection: + async def acquire_direct(self) -> LoggingConnection: async with self.pool.acquire() as conn: yield LoggingConnection( self.scheme, conn, self.log, handle_exception=self._handle_exception diff --git a/mautrix/util/async_db/database.py b/mautrix/util/async_db/database.py index 0f23b02d..b5128b74 100644 --- a/mautrix/util/async_db/database.py +++ b/mautrix/util/async_db/database.py @@ -7,6 +7,8 @@ from typing import Any, AsyncContextManager, Type from abc import ABC, abstractmethod +from contextlib import asynccontextmanager +from contextvars import ContextVar import logging from yarl import URL @@ -23,6 +25,8 @@ from aiosqlite import Cursor from asyncpg import Record +conn_var: ContextVar[LoggingConnection | None] = ContextVar("db_connection", default=None) + class Database(ABC): schemes: dict[str, Type[Database]] = {} @@ -111,15 +115,17 @@ async def _check_foreign_tables(self) -> None: raise ForeignTablesFound("found roomserver_rooms likely belonging to Dendrite") async def _check_owner(self) -> None: - await self.execute( - """CREATE TABLE IF NOT EXISTS database_owner ( + await self.execute(""" + CREATE TABLE IF NOT EXISTS database_owner ( key INTEGER PRIMARY KEY DEFAULT 0, owner TEXT NOT NULL - )""" - ) + ) + """) owner = await self.fetchval("SELECT owner FROM database_owner WHERE key=0") if not owner: - await self.execute("INSERT INTO database_owner (owner) VALUES ($1)", self.owner_name) + await self.execute( + "INSERT INTO database_owner (key, owner) VALUES (0, $1)", self.owner_name + ) elif owner != self.owner_name: raise DatabaseNotOwned(owner) @@ -128,9 +134,22 @@ async def stop(self) -> None: pass @abstractmethod - def acquire(self) -> AsyncContextManager[LoggingConnection]: + def acquire_direct(self) -> AsyncContextManager[LoggingConnection]: pass + @asynccontextmanager + async def acquire(self) -> LoggingConnection: + conn = conn_var.get(None) + if conn is not None: + yield conn + return + async with self.acquire_direct() as conn: + token = conn_var.set(conn) + try: + yield conn + finally: + conn_var.reset(token) + async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor: async with self.acquire() as conn: return await conn.execute(query, *args, timeout=timeout) diff --git a/mautrix/util/async_db/upgrade.py b/mautrix/util/async_db/upgrade.py index 9e593ece..c084d28b 100644 --- a/mautrix/util/async_db/upgrade.py +++ b/mautrix/util/async_db/upgrade.py @@ -97,11 +97,11 @@ async def _save_version(self, conn: LoggingConnection, version: int) -> None: await conn.execute(f"INSERT INTO {self.version_table_name} (version) VALUES ($1)", version) async def upgrade(self, db: async_db.Database) -> None: - await db.execute( - f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} ( + await db.execute(f""" + CREATE TABLE IF NOT EXISTS {self.version_table_name} ( version INTEGER PRIMARY KEY - )""" - ) + ) + """) row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1") version = row["version"] if row else 0 diff --git a/mautrix/util/program.py b/mautrix/util/program.py index 2fb6b141..74752ca5 100644 --- a/mautrix/util/program.py +++ b/mautrix/util/program.py @@ -98,7 +98,7 @@ def _prepare(self) -> None: self.log.info(f"Initializing {self.name} {self.version}") try: - self.prepare() + self.loop.run_until_complete(self._async_prepare()) except Exception: self.log.critical("Unexpected error in initialization", exc_info=True) sys.exit(1) @@ -117,6 +117,7 @@ def preinit(self) -> None: self.prepare_config() self.prepare_log() self.check_config() + self.init_loop() @property def base_config_path(self) -> str: @@ -168,14 +169,16 @@ def prepare_log(self) -> None: logging.config.dictConfig(copy.deepcopy(self.config["logging"])) self.log = cast(TraceLogger, logging.getLogger("mau.init")) + async def _async_prepare(self) -> None: + self.prepare() + def prepare(self) -> None: """ Lifecycle method where the primary program initialization happens. Use this to fill startup_actions with async startup tasks. """ - self.prepare_loop() - def prepare_loop(self) -> None: + def init_loop(self) -> None: """Init lifecycle method where the asyncio event loop is created.""" if uvloop is not None: uvloop.install() @@ -188,6 +191,7 @@ def start_prometheus(self) -> None: try: enabled = self.config["metrics.enabled"] listen_port = self.config["metrics.listen_port"] + hostname = self.config.get("metrics.hostname", "0.0.0.0") except KeyError: return if not enabled: @@ -197,7 +201,7 @@ def start_prometheus(self) -> None: "Metrics are enabled in config, but prometheus_client is not installed" ) return - prometheus.start_http_server(listen_port) + prometheus.start_http_server(listen_port, addr=hostname) def _run(self) -> None: signal.signal(signal.SIGINT, signal.default_int_handler) diff --git a/mautrix/util/variation_selector.json b/mautrix/util/variation_selector.json index 0205ce94..f27acb7b 100644 --- a/mautrix/util/variation_selector.json +++ b/mautrix/util/variation_selector.json @@ -31,9 +31,12 @@ "23CF": "⏏", "23E9": "⏩", "23EA": "⏪", + "23EB": "⏫", + "23EC": "⏬", "23ED": "⏭", "23EE": "⏮", "23EF": "⏯", + "23F0": "⏰", "23F1": "⏱", "23F2": "⏲", "23F3": "⏳", @@ -114,6 +117,7 @@ "26C4": "⛄", "26C5": "⛅", "26C8": "⛈", + "26CE": "⛎", "26CF": "⛏", "26D1": "⛑", "26D3": "⛓", @@ -132,8 +136,11 @@ "26FA": "⛺", "26FD": "⛽", "2702": "✂", + "2705": "✅", "2708": "✈", "2709": "✉", + "270A": "✊", + "270B": "✋", "270C": "✌", "270D": "✍", "270F": "✏", @@ -142,15 +149,25 @@ "2716": "✖", "271D": "✝", "2721": "✡", + "2728": "✨", "2733": "✳", "2734": "✴", "2744": "❄", "2747": "❇", + "274C": "❌", + "274E": "❎", "2753": "❓", + "2754": "❔", + "2755": "❕", "2757": "❗", "2763": "❣", "2764": "❤", + "2795": "➕", + "2796": "➖", + "2797": "➗", "27A1": "➡", + "27B0": "➰", + "27BF": "➿", "2934": "⤴", "2935": "⤵", "2B05": "⬅", diff --git a/mautrix/util/variation_selector.py b/mautrix/util/variation_selector.py index 498f1abe..e1430293 100644 --- a/mautrix/util/variation_selector.py +++ b/mautrix/util/variation_selector.py @@ -10,7 +10,7 @@ import aiohttp -EMOJI_VAR_URL = "https://www.unicode.org/Public/14.0.0/ucd/emoji/emoji-variation-sequences.txt" +EMOJI_VAR_URL = "https://www.unicode.org/Public/17.0.0/ucd/emoji/emoji-variation-sequences.txt" def read_data() -> dict[str, str]: @@ -43,11 +43,12 @@ async def fetch_data() -> dict[str, str]: if __name__ == "__main__": import asyncio + import importlib.resources + import pathlib import sys - import pkg_resources - - path = pkg_resources.resource_filename("mautrix.util", "variation_selector.json") + path = importlib.resources.files("mautrix.util").joinpath("variation_selector.json") + assert isinstance(path, pathlib.Path) emojis = asyncio.run(fetch_data()) with open(path, "w") as file: json.dump(emojis, file, indent=" ", ensure_ascii=False) @@ -59,11 +60,11 @@ async def fetch_data() -> dict[str, str]: ADD_VARIATION_TRANSLATION = str.maketrans( {ord(emoji): f"{emoji}{VARIATION_SELECTOR_16}" for emoji in read_data().values()} ) -SKIN_TONE_MODIFIERS = ("\U0001F3FB", "\U0001F3FC", "\U0001F3FD", "\U0001F3FE", "\U0001F3FF") +SKIN_TONE_MODIFIERS = ("\U0001f3fb", "\U0001f3fc", "\U0001f3fd", "\U0001f3fe", "\U0001f3ff") SKIN_TONE_REPLACEMENTS = {f"{VARIATION_SELECTOR_16}{mod}": mod for mod in SKIN_TONE_MODIFIERS} VARIATION_SELECTOR_REPLACEMENTS = { **SKIN_TONE_REPLACEMENTS, - "\U0001F408\ufe0f\u200d\u2b1b\ufe0f": "\U0001F408\u200d\u2b1b", + "\U0001f408\ufe0f\u200d\u2b1b\ufe0f": "\U0001f408\u200d\u2b1b", } diff --git a/optional-requirements.txt b/optional-requirements.txt index a5660f4a..a6e0227d 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -11,3 +11,4 @@ uvloop python-olm unpaddedbase64 pycryptodome +base58 diff --git a/setup.py b/setup.py index b01c50a1..23ac1cb1 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from mautrix import __version__ -encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome"] +encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome", "base58"] test_dependencies = ["aiosqlite", "asyncpg", "ruamel.yaml", *encryption_dependencies] setuptools.setup( @@ -28,7 +28,7 @@ ], extras_require={ "detect_mimetype": ["python-magic>=0.4.15,<0.5"], - "lint": ["black~=24.1", "isort"], + "lint": ["black~=25.1", "isort"], "test": ["pytest", "pytest-asyncio", *test_dependencies], "encryption": encryption_dependencies, }, @@ -45,6 +45,8 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ], package_data={