diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 5a71c970..3a845672 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -8,12 +8,12 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.10", "3.11", "3.12"]
+ python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install libolm
@@ -46,17 +46,17 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-python@v5
+ - uses: actions/checkout@v6
+ - uses: actions/setup-python@v6
with:
- python-version: "3.12"
+ python-version: "3.14"
- uses: isort/isort-action@master
with:
sortPaths: "./mautrix"
- uses: psf/black@stable
with:
src: "./mautrix"
- version: "24.1.1"
+ version: "26.3.1"
- name: pre-commit
run: |
pip install pre-commit
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 8f3efe97..b0d7ab3f 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,6 +1,6 @@
build docs builder:
stage: build
- image: docker:stable
+ image: docker:latest
tags:
- amd64
only:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a056ffdb..66065033 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.5.0
+ rev: v6.0.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
@@ -8,13 +8,13 @@ repos:
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
- rev: 24.1.1
+ rev: 26.3.1
hooks:
- id: black
language_version: python3
files: ^mautrix/.*\.pyi?$
- repo: https://github.com/PyCQA/isort
- rev: 5.13.2
+ rev: 8.0.1
hooks:
- id: isort
files: ^mautrix/.*\.pyi?$
diff --git a/CHANGELOG.md b/CHANGELOG.md
index aad156fd..c7179c86 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,52 @@
+## v0.21.0 (2025-11-17)
+
+* *(event)* Added support for creator power in room v12+.
+* *(crypto)* Added support for generating and using recovery keys for verifying
+ the active device.
+* *(bridge)* Added config option for self-signing bot device.
+* *(bridge)* Removed check for login flows when using MSC4190
+ (thanks to [@meson800] in [#178]).
+* *(client)* Changed `set_displayname` and `set_avatar_url` to avoid setting
+ empty strings if the value is already unset (thanks to [@frebib] in [#171]).
+
+[@frebib]: https://github.com/frebib
+[@meson800]: https://github.com/meson800
+[#171]: https://github.com/mautrix/python/pull/171
+[#178]: https://github.com/mautrix/python/pull/178
+
+## v0.20.8 (2025-06-01)
+
+* *(bridge)* Added support for [MSC4190] (thanks to [@surakin] in [#175]).
+* *(appservice)* Renamed `push_ephemeral` in generated registrations to
+ `receive_ephemeral` to match the accepted version of [MSC2409].
+* *(bridge)* Fixed compatibility with breaking change in aiohttp 3.12.6.
+
+[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781
+[@surakin]: https://github.com/surakin
+[#175]: https://github.com/mautrix/python/pull/175
+
+## v0.20.7 (2025-01-03)
+
+* *(types)* Removed support for generating reply fallbacks to implement
+ [MSC2781]. Stripping fallbacks is still supported.
+
+[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781
+
+## v0.20.6 (2024-07-12)
+
+* *(bridge)* Added `/register` call if `/versions` fails with `M_FORBIDDEN`.
+
+## v0.20.5 (2024-07-09)
+
+**Note:** The `bridge` module is deprecated as all bridges are being rewritten
+in Go. See for more info.
+
+* *(client)* Added support for authenticated media downloads.
+* *(bridge)* Stopped using cached homeserver URLs for double puppeting if one
+ is set in the config file.
+* *(crypto)* Fixed error when checking OTK counts before uploading new keys.
+* *(types)* Added MSC2530 (captions) fields to `MediaMessageEventContent`.
+
## v0.20.4 (2024-01-09)
* Dropped Python 3.9 support.
diff --git a/README.rst b/README.rst
index e03596d9..f1342c19 100644
--- a/README.rst
+++ b/README.rst
@@ -49,7 +49,7 @@ Components
.. _#maunium:maunium.net: https://matrix.to/#/#maunium:maunium.net
.. _python-appservice-framework: https://github.com/Cadair/python-appservice-framework/
-.. _Client API: https://matrix.org/docs/spec/client_server/r0.6.1.html
+.. _Client API: https://spec.matrix.org/latest/client-server-api/
.. _mautrix.api: https://docs.mau.fi/python/latest/api/mautrix.api.html
.. _mautrix.client.api: https://docs.mau.fi/python/latest/api/mautrix.client.api.html
diff --git a/dev-requirements.txt b/dev-requirements.txt
index bb8c2a0a..dff6ee9d 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,3 +1,3 @@
pre-commit>=2.10.1,<3
-isort>=5.10.1,<6
-black>=24,<25
+isort>=8,<9
+black>=26,<27
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 798e0f1b..7269d3a5 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -9,7 +9,7 @@ yarl
# that aren't used for anything that's in the docs
python-magic
ruamel.yaml
-SQLAlchemy
+SQLAlchemy<2
commonmark
asyncpg
aiosqlite
@@ -17,3 +17,4 @@ prometheus_client
python-olm
unpaddedbase64
pycryptodome
+base58
diff --git a/mautrix/__init__.py b/mautrix/__init__.py
index 5d4ce561..c94995f9 100644
--- a/mautrix/__init__.py
+++ b/mautrix/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.20.4"
+__version__ = "0.21.0"
__author__ = "Tulir Asokan "
__all__ = [
"api",
diff --git a/mautrix/api.py b/mautrix/api.py
index 39871bb5..1adde9ec 100644
--- a/mautrix/api.py
+++ b/mautrix/api.py
@@ -462,6 +462,7 @@ def get_download_url(
mxc_uri: str,
download_type: Literal["download", "thumbnail"] = "download",
file_name: str | None = None,
+ authenticated: bool = False,
) -> URL:
"""
Get the full HTTP URL to download a ``mxc://`` URI.
@@ -470,6 +471,7 @@ def get_download_url(
mxc_uri: The MXC URI whose full URL to get.
download_type: The type of download ("download" or "thumbnail").
file_name: Optionally, a file name to include in the download URL.
+ authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11.
Returns:
The full HTTP URL.
@@ -485,7 +487,11 @@ def get_download_url(
"https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png"
"""
server_name, media_id = self.parse_mxc_uri(mxc_uri)
- url = self.base_url / str(APIPath.MEDIA) / "v3" / download_type / server_name / media_id
+ if authenticated:
+ url = self.base_url / str(APIPath.CLIENT) / "v1" / "media"
+ else:
+ url = self.base_url / str(APIPath.MEDIA) / "v3"
+ url = url / download_type / server_name / media_id
if file_name:
url /= file_name
return url
diff --git a/mautrix/appservice/api/intent.py b/mautrix/appservice/api/intent.py
index 00990671..626b34f1 100644
--- a/mautrix/appservice/api/intent.py
+++ b/mautrix/appservice/api/intent.py
@@ -118,6 +118,8 @@ def __init__(
) -> None:
super().__init__(mxid=mxid, api=api, state_store=state_store)
self.bot = bot
+ if bot is not None:
+ self.versions_cache = bot.versions_cache
self.log = api.base_log.getChild("intent")
for method in ENSURE_REGISTERED_METHODS:
@@ -708,6 +710,8 @@ async def _ensure_has_power_level_for(
if not await self.state_store.has_power_levels_cached(room_id):
# TODO add option to not try to fetch power levels from server
await self.get_power_levels(room_id, ignore_cache=True, ensure_joined=False)
+ if not await self.state_store.has_create_cached(room_id):
+ await self.get_state_event(room_id, EventType.ROOM_CREATE, format="event")
if not await self.state_store.has_power_level(room_id, self.mxid, event_type):
# TODO implement something better
raise IntentError(
diff --git a/mautrix/appservice/appservice.py b/mautrix/appservice/appservice.py
index 892db195..65e202ae 100644
--- a/mautrix/appservice/appservice.py
+++ b/mautrix/appservice/appservice.py
@@ -13,7 +13,7 @@
from aiohttp import web
import aiohttp
-from mautrix.types import JSON, RoomAlias, UserID
+from mautrix.types import JSON, RoomAlias, UserID, VersionsResponse
from mautrix.util.logging import TraceLogger
from ..api import HTTPAPI
diff --git a/mautrix/bridge/config.py b/mautrix/bridge/config.py
index b98ebc52..defed222 100644
--- a/mautrix/bridge/config.py
+++ b/mautrix/bridge/config.py
@@ -143,6 +143,8 @@ def do_update(self, helper: ConfigUpdateHelper) -> None:
copy("bridge.encryption.default")
copy("bridge.encryption.require")
copy("bridge.encryption.appservice")
+ copy("bridge.encryption.msc4190")
+ copy("bridge.encryption.self_sign")
copy("bridge.encryption.delete_keys.delete_outbound_on_ack")
copy("bridge.encryption.delete_keys.dont_store_outbound")
copy("bridge.encryption.delete_keys.ratchet_on_decrypt")
@@ -240,4 +242,7 @@ def generate_registration(self) -> None:
if self["appservice.ephemeral_events"]:
self._registration["de.sorunome.msc2409.push_ephemeral"] = True
- self._registration["push_ephemeral"] = True
+ self._registration["receive_ephemeral"] = True
+
+ if self["bridge.encryption.msc4190"]:
+ self._registration["io.element.msc4190"] = True
diff --git a/mautrix/bridge/custom_puppet.py b/mautrix/bridge/custom_puppet.py
index 9057877c..f5befd5f 100644
--- a/mautrix/bridge/custom_puppet.py
+++ b/mautrix/bridge/custom_puppet.py
@@ -132,8 +132,14 @@ def is_real_user(self) -> bool:
return bool(self.custom_mxid and self.access_token)
def _fresh_intent(self) -> IntentAPI:
- if self.access_token == "appservice-config" and self.custom_mxid:
+ if self.custom_mxid:
_, server = self.az.intent.parse_user_id(self.custom_mxid)
+ try:
+ self.base_url = self.homeserver_url_map[server]
+ except KeyError:
+ if server == self.az.domain:
+ self.base_url = self.az.intent.api.base_url
+ if self.access_token == "appservice-config" and self.custom_mxid:
try:
secret = self.login_shared_secret_map[server]
except KeyError:
diff --git a/mautrix/bridge/e2ee.py b/mautrix/bridge/e2ee.py
index 7ae66abf..1525b388 100644
--- a/mautrix/bridge/e2ee.py
+++ b/mautrix/bridge/e2ee.py
@@ -57,6 +57,8 @@ class EncryptionManager:
appservice_mode: bool
periodically_delete_expired_keys: bool
delete_outdated_inbound: bool
+ msc4190: bool
+ self_sign: bool
bridge: br.Bridge
az: AppService
@@ -108,6 +110,8 @@ def __init__(
self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"])
self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"]
self.appservice_mode = bridge.config["bridge.encryption.appservice"]
+ self.msc4190 = bridge.config["bridge.encryption.msc4190"]
+ self.self_sign = bridge.config["bridge.encryption.self_sign"]
if self.appservice_mode:
self.az.otk_handler = self.crypto.handle_as_otk_counts
self.az.device_list_handler = self.crypto.handle_as_device_lists
@@ -245,12 +249,13 @@ async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> M
return decrypted
async def start(self) -> None:
- flows = await self.client.get_login_flows()
- if not flows.supports_type(LoginType.APPSERVICE):
- self.log.critical(
- "Encryption enabled in config, but homeserver does not support appservice login"
- )
- sys.exit(30)
+ if not self.msc4190:
+ flows = await self.client.get_login_flows()
+ if not flows.supports_type(LoginType.APPSERVICE):
+ self.log.critical(
+ "Encryption enabled in config, but homeserver does not support appservice login"
+ )
+ sys.exit(30)
self.log.debug("Logging in with bridge bot user")
if self.crypto_db:
try:
@@ -261,22 +266,42 @@ async def start(self) -> None:
device_id = await self.crypto_store.get_device_id()
if device_id:
self.log.debug(f"Found device ID in database: {device_id}")
- # We set the API token to the AS token here to authenticate the appservice login
- # It'll get overridden after the login
- self.client.api.token = self.az.as_token
- await self.client.login(
- login_type=LoginType.APPSERVICE,
- device_name=self.device_name,
- device_id=device_id,
- store_access_token=True,
- update_hs_url=False,
- )
+
+ if self.msc4190:
+ if not device_id:
+ self.log.debug("Creating bot device with MSC4190")
+ self.client.api.token = self.az.as_token
+ await self.client.create_device_msc4190(
+ device_id=device_id, initial_display_name=self.device_name
+ )
+ else:
+ # We set the API token to the AS token here to authenticate the appservice login
+ # It'll get overridden after the login
+ self.client.api.token = self.az.as_token
+ await self.client.login(
+ login_type=LoginType.APPSERVICE,
+ device_name=self.device_name,
+ device_id=device_id,
+ store_access_token=True,
+ update_hs_url=False,
+ )
+
await self.crypto.load()
if not device_id:
await self.crypto_store.put_device_id(self.client.device_id)
self.log.debug(f"Logged in with new device ID {self.client.device_id}")
+ await self.crypto.share_keys()
elif self.crypto.account.shared:
await self._verify_keys_are_on_server()
+ else:
+ await self.crypto.share_keys()
+ if self.self_sign:
+ trust_state = await self.crypto.resolve_trust(self.crypto.own_identity)
+ if trust_state < TrustState.CROSS_SIGNED_UNTRUSTED:
+ recovery_key = await self.crypto.generate_recovery_key()
+ self.log.info(f"Generated recovery key and signed own device: {recovery_key}")
+ else:
+ self.log.debug(f"Own device is already verified ({trust_state})")
if self.appservice_mode:
self.log.info("End-to-bridge encryption support is enabled (appservice mode)")
else:
diff --git a/mautrix/bridge/matrix.py b/mautrix/bridge/matrix.py
index da42a95e..e5399094 100644
--- a/mautrix/bridge/matrix.py
+++ b/mautrix/bridge/matrix.py
@@ -225,6 +225,12 @@ async def wait_for_connection(self) -> None:
try:
self.versions = await self.az.intent.versions()
break
+ except MForbidden:
+ self.log.debug(
+ "/versions endpoint returned M_FORBIDDEN, "
+ "trying to register bridge bot before retrying..."
+ )
+ await self.az.intent.ensure_registered()
except Exception:
self.log.exception("Connection to homeserver failed, retrying in 10 seconds")
await asyncio.sleep(10)
diff --git a/mautrix/client/api/authentication.py b/mautrix/client/api/authentication.py
index cd77272d..0f6249ea 100644
--- a/mautrix/client/api/authentication.py
+++ b/mautrix/client/api/authentication.py
@@ -5,6 +5,8 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
+import secrets
+
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import (
@@ -117,6 +119,19 @@ async def login(
self.api.base_url = base_url.rstrip("/")
return resp_data
+ async def create_device_msc4190(self, device_id: str, initial_display_name: str) -> None:
+ """
+ Create a Device for a user of the homeserver using appservice interface defined in MSC4190
+ """
+ if len(device_id) == 0:
+ device_id = DeviceID(secrets.token_urlsafe(10))
+ self.api.as_user_id = self.mxid
+ await self.api.request(
+ Method.PUT, Path.v3.devices[device_id], {"display_name": initial_display_name}
+ )
+ self.api.as_device_id = device_id
+ self.device_id = device_id
+
async def logout(self, clear_access_token: bool = True) -> None:
"""
Invalidates an existing access token, so that it can no longer be used for authorization.
diff --git a/mautrix/client/api/events.py b/mautrix/client/api/events.py
index 99e1c178..bd415b9b 100644
--- a/mautrix/client/api/events.py
+++ b/mautrix/client/api/events.py
@@ -5,7 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Awaitable
+from typing import Awaitable, Literal, overload
import json
from mautrix.api import Method, Path
@@ -168,12 +168,41 @@ async def get_event_context(
)
return EventContext.deserialize(resp)
+ @overload
async def get_state_event(
self,
room_id: RoomID,
event_type: EventType,
state_key: str = "",
- ) -> StateEventContent:
+ *,
+ format: Literal["content"] = "content",
+ ) -> StateEventContent: ...
+ @overload
+ async def get_state_event(
+ self,
+ room_id: RoomID,
+ event_type: EventType,
+ state_key: str = "",
+ *,
+ format: Literal["event"],
+ ) -> StateEvent: ...
+ @overload
+ async def get_state_event(
+ self,
+ room_id: RoomID,
+ event_type: EventType,
+ state_key: str = "",
+ *,
+ format: str = "content",
+ ) -> StateEventContent | StateEvent: ...
+ async def get_state_event(
+ self,
+ room_id: RoomID,
+ event_type: EventType,
+ state_key: str = "",
+ *,
+ format: str = "content",
+ ) -> StateEventContent | StateEvent:
"""
Looks up the contents of a state event in a room. If the user is joined to the room then the
state is taken from the current state of the room. If the user has left the room then the
@@ -185,6 +214,9 @@ async def get_state_event(
room_id: The ID of the room to look up the state in.
event_type: The type of state to look up.
state_key: The key of the state to look up. Defaults to empty string.
+ format: The format of the state event to return. Defaults to "content", which only returns
+ the content of the state event. If set to "event", the full event is returned.
+ See https://github.com/matrix-org/matrix-spec/issues/1047 for more info.
Returns:
The state event.
@@ -192,11 +224,17 @@ async def get_state_event(
content = await self.api.request(
Method.GET,
Path.v3.rooms[room_id].state[event_type][state_key],
+ query_params={"format": format} if format != "content" else None,
metrics_method="getStateEvent",
)
- content["__mautrix_event_type"] = event_type
try:
- return StateEvent.deserialize_content(content)
+ if format == "content":
+ content["__mautrix_event_type"] = event_type
+ return StateEvent.deserialize_content(content)
+ elif format == "event":
+ return StateEvent.deserialize(content)
+ else:
+ return content
except SerializerError as e:
raise MatrixResponseError("Invalid state event in response") from e
@@ -357,15 +395,13 @@ async def get_messages(
try:
return PaginatedMessages(
content["start"],
- content["end"],
+ content.get("end"),
[Event.deserialize(event) for event in content["chunk"]],
)
except KeyError:
if "start" not in content:
raise MatrixResponseError("`start` not in response.")
- elif "end" not in content:
- raise MatrixResponseError("`start` not in response.")
- raise MatrixResponseError("`content` not in response.")
+ raise MatrixResponseError("`chunk` not in response.")
except SerializerError as e:
raise MatrixResponseError("Invalid events in response") from e
diff --git a/mautrix/client/api/modules/crypto.py b/mautrix/client/api/modules/crypto.py
index af656916..2d879a63 100644
--- a/mautrix/client/api/modules/crypto.py
+++ b/mautrix/client/api/modules/crypto.py
@@ -5,13 +5,17 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Any
+from typing import Any, Union
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import (
+ JSON,
ClaimKeysResponse,
+ CrossSigningKeys,
+ CrossSigningUsage,
DeviceID,
+ DeviceKeys,
EncryptionKeyAlgorithm,
EventType,
QueryKeysResponse,
@@ -82,7 +86,7 @@ async def send_to_one_device(
async def upload_keys(
self,
one_time_keys: dict[str, Any] | None = None,
- device_keys: dict[str, Any] | None = None,
+ device_keys: DeviceKeys | dict[str, Any] | None = None,
) -> dict[EncryptionKeyAlgorithm, int]:
"""
Publishes end-to-end encryption keys for the device.
@@ -102,8 +106,12 @@ async def upload_keys(
"""
data = {}
if device_keys:
+ if isinstance(device_keys, Serializable):
+ device_keys = device_keys.serialize()
data["device_keys"] = device_keys
if one_time_keys:
+ if isinstance(one_time_keys, Serializable):
+ one_time_keys = one_time_keys.serialize()
data["one_time_keys"] = one_time_keys
resp = await self.api.request(Method.POST, Path.v3.keys.upload, data)
try:
@@ -116,6 +124,43 @@ async def upload_keys(
except AttributeError as e:
raise MatrixResponseError("Invalid `one_time_key_counts` field in response.") from e
+ async def upload_cross_signing_keys(
+ self,
+ keys: dict[CrossSigningUsage, CrossSigningKeys],
+ auth: dict[str, JSON] | None = None,
+ ) -> None:
+ await self.api.request(
+ Method.POST,
+ Path.v3.keys.device_signing.upload,
+ {f"{usage}_key": key.serialize() for usage, key in keys.items()}
+ | ({"auth": auth} if auth else {}),
+ )
+
+ async def upload_one_signature(
+ self,
+ user_id: UserID,
+ device_id: DeviceID,
+ keys: Union[DeviceKeys, CrossSigningKeys],
+ ) -> None:
+ await self.api.request(
+ Method.POST, Path.v3.keys.signatures.upload, {user_id: {device_id: keys.serialize()}}
+ )
+ # TODO check failures
+
+ async def upload_many_signatures(
+ self,
+ signatures: dict[UserID, dict[DeviceID, Union[DeviceKeys, CrossSigningKeys]]],
+ ) -> None:
+ await self.api.request(
+ Method.POST,
+ Path.v3.keys.signatures.upload,
+ {
+ user_id: {device_id: keys.serialize() for device_id, keys in devices.items()}
+ for user_id, devices in signatures.items()
+ },
+ )
+ # TODO check failures
+
async def query_keys(
self,
device_keys: list[UserID] | set[UserID] | dict[UserID, list[DeviceID]],
diff --git a/mautrix/client/api/modules/media_repository.py b/mautrix/client/api/modules/media_repository.py
index fe52252a..b1d90cc1 100644
--- a/mautrix/client/api/modules/media_repository.py
+++ b/mautrix/client/api/modules/media_repository.py
@@ -10,6 +10,8 @@
import asyncio
import time
+from yarl import URL
+
from mautrix import __optional_imports__
from mautrix.api import MediaPath, Method
from mautrix.errors import MatrixResponseError, make_request_error
@@ -19,6 +21,7 @@
MediaRepoConfig,
MXOpenGraph,
SerializerError,
+ SpecVersions,
)
from mautrix.util import background_task
from mautrix.util.async_body import async_iter_bytes
@@ -178,13 +181,19 @@ async def download_media(self, url: ContentURI, timeout_ms: int | None = None) -
Returns:
The raw downloaded data.
"""
- url = self.api.get_download_url(url)
+ authenticated = (await self.versions()).supports(SpecVersions.V111)
+ url = self.api.get_download_url(url, authenticated=authenticated)
query_params: dict[str, Any] = {"allow_redirect": "true"}
if timeout_ms is not None:
query_params["timeout_ms"] = timeout_ms
+ headers: dict[str, str] = {}
+ if authenticated:
+ headers["Authorization"] = f"Bearer {self.api.token}"
+ if self.api.as_user_id:
+ query_params["user_id"] = self.api.as_user_id
req_id = self.api.log_download_request(url, query_params)
start = time.monotonic()
- async with self.api.session.get(url, params=query_params) as response:
+ async with self.api.session.get(url, params=query_params, headers=headers) as response:
try:
response.raise_for_status()
return await response.read()
@@ -199,7 +208,7 @@ async def download_thumbnail(
width: int | None = None,
height: int | None = None,
resize_method: Literal["crop", "scale"] = None,
- allow_remote: bool = True,
+ allow_remote: bool | None = None,
timeout_ms: int | None = None,
):
"""
@@ -223,7 +232,10 @@ async def download_thumbnail(
Returns:
The raw downloaded data.
"""
- url = self.api.get_download_url(url, download_type="thumbnail")
+ authenticated = (await self.versions()).supports(SpecVersions.V111)
+ url = self.api.get_download_url(
+ url, download_type="thumbnail", authenticated=authenticated
+ )
query_params: dict[str, Any] = {"allow_redirect": "true"}
if width is not None:
query_params["width"] = width
@@ -232,12 +244,17 @@ async def download_thumbnail(
if resize_method is not None:
query_params["method"] = resize_method
if allow_remote is not None:
- query_params["allow_remote"] = allow_remote
+ query_params["allow_remote"] = str(allow_remote).lower()
if timeout_ms is not None:
query_params["timeout_ms"] = timeout_ms
+ headers: dict[str, str] = {}
+ if authenticated:
+ headers["Authorization"] = f"Bearer {self.api.token}"
+ if self.api.as_user_id:
+ query_params["user_id"] = self.api.as_user_id
req_id = self.api.log_download_request(url, query_params)
start = time.monotonic()
- async with self.api.session.get(url, params=query_params) as response:
+ async with self.api.session.get(url, params=query_params, headers=headers) as response:
try:
response.raise_for_status()
return await response.read()
diff --git a/mautrix/client/api/user_data.py b/mautrix/client/api/user_data.py
index 4c3d437a..9c380335 100644
--- a/mautrix/client/api/user_data.py
+++ b/mautrix/client/api/user_data.py
@@ -71,7 +71,7 @@ async def search_users(self, search_query: str, limit: int | None = 10) -> UserS
# region 10.2 Profiles
# API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#profiles
- async def set_displayname(self, displayname: str, check_current: bool = True) -> None:
+ async def set_displayname(self, displayname: str | None, check_current: bool = True) -> None:
"""
Set the display name of the current user.
@@ -81,7 +81,9 @@ async def set_displayname(self, displayname: str, check_current: bool = True) ->
displayname: The new display name for the user.
check_current: Whether or not to check if the displayname is already set.
"""
- if check_current and await self.get_displayname(self.mxid) == displayname:
+ if check_current and str_or_none(await self.get_displayname(self.mxid)) == str_or_none(
+ displayname
+ ):
return
await self.api.request(
Method.PUT,
@@ -112,7 +114,9 @@ async def get_displayname(self, user_id: UserID) -> str | None:
except KeyError:
return None
- async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = True) -> None:
+ async def set_avatar_url(
+ self, avatar_url: ContentURI | None, check_current: bool = True
+ ) -> None:
"""
Set the avatar of the current user.
@@ -122,7 +126,9 @@ async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = Tru
avatar_url: The ``mxc://`` URI to the new avatar.
check_current: Whether or not to check if the avatar is already set.
"""
- if check_current and await self.get_avatar_url(self.mxid) == avatar_url:
+ if check_current and str_or_none(await self.get_avatar_url(self.mxid)) == str_or_none(
+ avatar_url
+ ):
return
await self.api.request(
Method.PUT,
@@ -185,3 +191,10 @@ async def beeper_update_profile(self, custom_fields: dict[str, Any]) -> None:
await self.api.request(Method.PATCH, Path.v3.profile[self.mxid], custom_fields)
# endregion
+
+
+def str_or_none(v: str | None) -> str | None:
+ """
+ str_or_none empty string values to None
+ """
+ return None if v == "" else v
diff --git a/mautrix/client/state_store/abstract.py b/mautrix/client/state_store/abstract.py
index 3d37087d..d5b1f5f2 100644
--- a/mautrix/client/state_store/abstract.py
+++ b/mautrix/client/state_store/abstract.py
@@ -121,6 +121,18 @@ async def set_power_levels(
) -> None:
pass
+ @abstractmethod
+ async def has_create_cached(self, room_id: RoomID) -> bool:
+ pass
+
+ @abstractmethod
+ async def get_create(self, room_id: RoomID) -> StateEvent | None:
+ pass
+
+ @abstractmethod
+ async def set_create(self, event: StateEvent) -> None:
+ pass
+
@abstractmethod
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
pass
@@ -135,7 +147,7 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent
@abstractmethod
async def set_encryption_info(
- self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, any]
+ self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
) -> None:
pass
@@ -149,6 +161,8 @@ async def update_state(self, evt: StateEvent) -> None:
await self.set_member(evt.room_id, UserID(evt.state_key), evt.content)
elif evt.type == EventType.ROOM_ENCRYPTION:
await self.set_encryption_info(evt.room_id, evt.content)
+ elif evt.type == EventType.ROOM_CREATE and evt.sender:
+ await self.set_create(evt)
async def get_membership(self, room_id: RoomID, user_id: UserID) -> Membership:
member = await self.get_member(room_id, user_id)
@@ -172,4 +186,7 @@ async def has_power_level(
room_levels = await self.get_power_levels(room_id)
if not room_levels:
return None
- return room_levels.get_user_level(user_id) >= room_levels.get_event_level(event_type)
+ create_event = await self.get_create(room_id)
+ return room_levels.get_user_level(user_id, create_event) >= room_levels.get_event_level(
+ event_type
+ )
diff --git a/mautrix/client/state_store/asyncpg/store.py b/mautrix/client/state_store/asyncpg/store.py
index d78c04d6..f4f8436f 100644
--- a/mautrix/client/state_store/asyncpg/store.py
+++ b/mautrix/client/state_store/asyncpg/store.py
@@ -16,6 +16,7 @@
RoomEncryptionStateEventContent,
RoomID,
Serializable,
+ StateEvent,
UserID,
)
from mautrix.util.async_db import Database, Scheme
@@ -223,6 +224,29 @@ async def set_power_levels(
json.dumps(content.serialize() if isinstance(content, Serializable) else content),
)
+ async def has_create_cached(self, room_id: RoomID) -> bool:
+ return bool(
+ await self.db.fetchval(
+ "SELECT create_event IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id
+ )
+ )
+
+ async def get_create(self, room_id: RoomID) -> StateEvent | None:
+ create_event_json = await self.db.fetchval(
+ "SELECT create_event FROM mx_room_state WHERE room_id=$1", room_id
+ )
+ if create_event_json is None:
+ return None
+ return StateEvent.parse_json(create_event_json)
+
+ async def set_create(self, event: StateEvent) -> None:
+ await self.db.execute(
+ "INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) "
+ "ON CONFLICT (room_id) DO UPDATE SET create_event=$2",
+ event.room_id,
+ json.dumps(event.serialize() if isinstance(event, Serializable) else event),
+ )
+
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
return bool(
await self.db.fetchval(
diff --git a/mautrix/client/state_store/asyncpg/upgrade.py b/mautrix/client/state_store/asyncpg/upgrade.py
index 88f115f2..20b2e5b2 100644
--- a/mautrix/client/state_store/asyncpg/upgrade.py
+++ b/mautrix/client/state_store/asyncpg/upgrade.py
@@ -14,17 +14,18 @@
)
-@upgrade_table.register(description="Latest revision", upgrades_to=3)
-async def upgrade_blank_to_v3(conn: Connection, scheme: Scheme) -> None:
- await conn.execute(
- """CREATE TABLE mx_room_state (
+@upgrade_table.register(description="Latest revision", upgrades_to=4)
+async def upgrade_blank_to_v4(conn: Connection, scheme: Scheme) -> None:
+ await conn.execute("""
+ CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
is_encrypted BOOLEAN,
has_full_member_list BOOLEAN,
encryption TEXT,
- power_levels TEXT
- )"""
- )
+ power_levels TEXT,
+ create_event TEXT
+ )
+ """)
membership_check = ""
if scheme != Scheme.SQLITE:
await conn.execute(
@@ -32,16 +33,16 @@ async def upgrade_blank_to_v3(conn: Connection, scheme: Scheme) -> None:
)
else:
membership_check = "CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock'))"
- await conn.execute(
- f"""CREATE TABLE mx_user_profile (
+ await conn.execute(f"""
+ CREATE TABLE mx_user_profile (
room_id TEXT,
user_id TEXT,
membership membership NOT NULL {membership_check},
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
- )"""
- )
+ )
+ """)
@upgrade_table.register(description="Stop using size-limited string fields")
@@ -59,13 +60,16 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
@upgrade_table.register(description="Mark rooms that need crypto state event resynced")
async def upgrade_v3(conn: Connection) -> None:
if await conn.table_exists("portal"):
- await conn.execute(
- """
+ await conn.execute("""
INSERT INTO mx_room_state (room_id, encryption)
SELECT portal.mxid, '{"resync":true}' FROM portal
WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
ON CONFLICT (room_id) DO UPDATE
SET encryption=excluded.encryption
WHERE mx_room_state.encryption IS NULL
- """
- )
+ """)
+
+
+@upgrade_table.register(description="Add create event to room state cache")
+async def upgrade_v4(conn: Connection) -> None:
+ await conn.execute("ALTER TABLE mx_room_state ADD COLUMN create_event TEXT")
diff --git a/mautrix/client/state_store/file.py b/mautrix/client/state_store/file.py
index d567c853..a5d53663 100644
--- a/mautrix/client/state_store/file.py
+++ b/mautrix/client/state_store/file.py
@@ -15,6 +15,7 @@
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
+ StateEvent,
UserID,
)
from mautrix.util.file_store import Filer, FileStore
@@ -65,3 +66,7 @@ async def set_power_levels(
) -> None:
await super().set_power_levels(room_id, content)
self._time_limited_flush()
+
+ async def set_create(self, event: StateEvent) -> None:
+ await super().set_create(event)
+ self._time_limited_flush()
diff --git a/mautrix/client/state_store/memory.py b/mautrix/client/state_store/memory.py
index 8f2edac5..2b010a75 100644
--- a/mautrix/client/state_store/memory.py
+++ b/mautrix/client/state_store/memory.py
@@ -14,6 +14,7 @@
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
+ StateEvent,
UserID,
)
@@ -25,6 +26,7 @@ class SerializedStateStore(TypedDict):
full_member_list: dict[RoomID, bool]
power_levels: dict[RoomID, Any]
encryption: dict[RoomID, Any]
+ create: dict[RoomID, Any]
class MemoryStateStore(StateStore):
@@ -32,12 +34,14 @@ class MemoryStateStore(StateStore):
full_member_list: dict[RoomID, bool]
power_levels: dict[RoomID, PowerLevelStateEventContent]
encryption: dict[RoomID, RoomEncryptionStateEventContent | None]
+ create: dict[RoomID, StateEvent]
def __init__(self) -> None:
self.members = {}
self.full_member_list = {}
self.power_levels = {}
self.encryption = {}
+ self.create = {}
def serialize(self) -> SerializedStateStore:
"""
@@ -58,6 +62,7 @@ def serialize(self) -> SerializedStateStore:
room_id: (content.serialize() if content is not None else None)
for room_id, content in self.encryption.items()
},
+ "create": {room_id: evt.serialize() for room_id, evt in self.create.items()},
}
def deserialize(self, data: SerializedStateStore) -> None:
@@ -84,6 +89,9 @@ def deserialize(self, data: SerializedStateStore) -> None:
)
for room_id, content in data["encryption"].items()
}
+ self.create = {
+ room_id: StateEvent.deserialize(evt) for room_id, evt in data["create"].items()
+ }
async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
try:
@@ -176,6 +184,17 @@ async def set_power_levels(
content = PowerLevelStateEventContent.deserialize(content)
self.power_levels[room_id] = content
+ async def has_create_cached(self, room_id: RoomID) -> bool:
+ return room_id in self.create
+
+ async def get_create(self, room_id: RoomID) -> StateEvent | None:
+ return self.create.get(room_id)
+
+ async def set_create(self, event: StateEvent | dict[str, Any]) -> None:
+ if not isinstance(event, StateEvent):
+ event = StateEvent.deserialize(event)
+ self.create[event.room_id] = event
+
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
return room_id in self.encryption
diff --git a/mautrix/client/store_updater.py b/mautrix/client/store_updater.py
index a5530bde..35280324 100644
--- a/mautrix/client/store_updater.py
+++ b/mautrix/client/store_updater.py
@@ -5,6 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
+from typing import Literal
import asyncio
from mautrix.errors import MForbidden, MNotFound
@@ -196,20 +197,28 @@ async def send_state_event(
return event_id
async def get_state_event(
- self, room_id: RoomID, event_type: EventType, state_key: str = ""
- ) -> StateEventContent:
- event = await super().get_state_event(room_id, event_type, state_key)
+ self,
+ room_id: RoomID,
+ event_type: EventType,
+ state_key: str = "",
+ *,
+ format: str = "content",
+ ) -> StateEventContent | StateEvent:
+ event = await super().get_state_event(room_id, event_type, state_key, format=format)
if self.state_store:
- fake_event = StateEvent(
- type=event_type,
- room_id=room_id,
- event_id=EventID(""),
- sender=UserID(""),
- state_key=state_key,
- timestamp=0,
- content=event,
- )
- await self.state_store.update_state(fake_event)
+ if isinstance(event, StateEvent):
+ await self.state_store.update_state(event)
+ else:
+ fake_event = StateEvent(
+ type=event_type,
+ room_id=room_id,
+ event_id=EventID(""),
+ sender=UserID(""),
+ state_key=state_key,
+ timestamp=0,
+ content=event,
+ )
+ await self.state_store.update_state(fake_event)
return event
async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]:
diff --git a/mautrix/client/syncer.py b/mautrix/client/syncer.py
index 8e115bbb..4b8661a2 100644
--- a/mautrix/client/syncer.py
+++ b/mautrix/client/syncer.py
@@ -5,11 +5,11 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Any, Awaitable, Callable, Type, TypeVar
+from typing import Any, Awaitable, Callable, NamedTuple, Optional, Type, TypeVar
from abc import ABC, abstractmethod
-from contextlib import suppress
from enum import Enum, Flag, auto
import asyncio
+import itertools
import time
from mautrix.errors import MUnknownToken
@@ -25,7 +25,6 @@
Filter,
FilterID,
GenericEvent,
- MessageEvent,
PresenceState,
SerializerError,
StateEvent,
@@ -79,13 +78,18 @@ class InternalEventType(Enum):
DEVICE_OTK_COUNT = auto()
+class EventHandlerProps(NamedTuple):
+ wait_sync: bool
+ sync_stream: Optional[SyncStream]
+
+
class Syncer(ABC):
loop: asyncio.AbstractEventLoop
log: TraceLogger
mxid: UserID
- global_event_handlers: list[tuple[EventHandler, bool]]
- event_handlers: dict[EventType | InternalEventType, list[tuple[EventHandler, bool]]]
+ global_event_handlers: dict[EventHandler, EventHandlerProps]
+ event_handlers: dict[EventType | InternalEventType, dict[EventHandler, EventHandlerProps]]
dispatchers: dict[Type[dispatcher.Dispatcher], dispatcher.Dispatcher]
syncing_task: asyncio.Task | None
ignore_initial_sync: bool
@@ -95,7 +99,7 @@ class Syncer(ABC):
sync_store: SyncStore
def __init__(self, sync_store: SyncStore) -> None:
- self.global_event_handlers = []
+ self.global_event_handlers = {}
self.event_handlers = {}
self.dispatchers = {}
self.syncing_task = None
@@ -158,6 +162,7 @@ def add_event_handler(
event_type: InternalEventType | EventType,
handler: EventHandler,
wait_sync: bool = False,
+ sync_stream: Optional[SyncStream] = None,
) -> None:
"""
Add a new event handler.
@@ -167,13 +172,15 @@ def add_event_handler(
event types.
handler: The handler function to add.
wait_sync: Whether or not the handler should be awaited before the next sync request.
+ sync_stream: The sync streams to listen to. Defaults to all.
"""
if not isinstance(event_type, (EventType, InternalEventType)):
raise ValueError("Invalid event type")
+ props = EventHandlerProps(wait_sync=wait_sync, sync_stream=sync_stream)
if event_type == EventType.ALL:
- self.global_event_handlers.append((handler, wait_sync))
+ self.global_event_handlers[handler] = props
else:
- self.event_handlers.setdefault(event_type, []).append((handler, wait_sync))
+ self.event_handlers.setdefault(event_type, {})[handler] = props
def remove_event_handler(
self, event_type: EventType | InternalEventType, handler: EventHandler
@@ -197,11 +204,7 @@ def remove_event_handler(
# No handlers for this event type registered
return
- # FIXME this is a bit hacky
- with suppress(ValueError):
- handler_list.remove((handler, True))
- with suppress(ValueError):
- handler_list.remove((handler, False))
+ handler_list.pop(handler, None)
if len(handler_list) == 0 and event_type != EventType.ALL:
del self.event_handlers[event_type]
@@ -229,7 +232,9 @@ def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asynci
else:
event.type = event.type.with_class(EventType.Class.MESSAGE)
setattr(event, "source", source)
- return self.dispatch_manual_event(event.type, event, include_global_handlers=True)
+ return self.dispatch_manual_event(
+ event.type, event, include_global_handlers=True, source=source
+ )
async def _catch_errors(self, handler: EventHandler, data: Any) -> None:
try:
@@ -243,13 +248,22 @@ def dispatch_manual_event(
data: Any,
include_global_handlers: bool = False,
force_synchronous: bool = False,
+ source: Optional[SyncStream] = None,
) -> list[asyncio.Task]:
- handlers = self.event_handlers.get(event_type, [])
+ handlers = self.event_handlers.get(event_type, {}).items()
if include_global_handlers:
- handlers = self.global_event_handlers + handlers
+ handlers = itertools.chain(self.global_event_handlers.items(), handlers)
tasks = []
- for handler, wait_sync in handlers:
- if force_synchronous or wait_sync:
+ if source is None:
+ source = getattr(data, "source", None)
+ for handler, props in handlers:
+ if (
+ props.sync_stream is not None
+ and source is not None
+ and not props.sync_stream & source
+ ):
+ continue
+ if force_synchronous or props.wait_sync:
tasks.append(asyncio.create_task(self._catch_errors(handler, data)))
else:
background_task.create(self._catch_errors(handler, data))
@@ -263,6 +277,7 @@ async def run_internal_event(
event_type,
custom_type if custom_type is not None else kwargs,
include_global_handlers=False,
+ source=SyncStream.INTERNAL,
)
await asyncio.gather(*tasks)
@@ -274,6 +289,7 @@ def dispatch_internal_event(
event_type,
custom_type if custom_type is not None else kwargs,
include_global_handlers=False,
+ source=SyncStream.INTERNAL,
)
def _try_deserialize(self, type: Type[T], data: JSON) -> T | GenericEvent:
@@ -353,12 +369,18 @@ def handle_sync(self, data: JSON) -> list[asyncio.Task]:
events: list[dict[str, JSON]] = room_data.get("invite_state", {}).get("events", [])
for raw_event in events:
raw_event["room_id"] = room_id
- raw_invite = next(
- raw_event
- for raw_event in events
- if raw_event.get("type", "") == "m.room.member"
- and raw_event.get("state_key", "") == self.mxid
- )
+ try:
+ raw_invite = next(
+ raw_event
+ for raw_event in events
+ if raw_event.get("type", "") == "m.room.member"
+ and raw_event.get("state_key", "") == self.mxid
+ )
+ except StopIteration:
+ self.log.warning(
+ f"Corrupted invite section in sync: no invite event present for {room_id}"
+ )
+ continue
# These aren't required by the spec, so make sure they're set
raw_invite.setdefault("event_id", None)
raw_invite.setdefault("origin_server_ts", int(time.time() * 1000))
diff --git a/mautrix/crypto/account.py b/mautrix/crypto/account.py
index db508262..a00ada71 100644
--- a/mautrix/crypto/account.py
+++ b/mautrix/crypto/account.py
@@ -10,15 +10,17 @@
from mautrix.types import (
DeviceID,
+ DeviceKeys,
EncryptionAlgorithm,
EncryptionKeyAlgorithm,
IdentityKey,
+ KeyID,
SigningKey,
UserID,
)
-from . import base
from .sessions import Session
+from .signature import sign_olm
class OlmAccount(olm.Account):
@@ -74,19 +76,18 @@ def new_outbound_session(self, target_key: IdentityKey, one_time_key: IdentityKe
session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now()
)
- def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> Dict[str, Any]:
- device_keys = {
- "user_id": user_id,
- "device_id": device_id,
- "algorithms": [EncryptionAlgorithm.OLM_V1.value, EncryptionAlgorithm.MEGOLM_V1.value],
- "keys": {
- f"{algorithm}:{device_id}": key for algorithm, key in self.identity_keys.items()
+ def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> DeviceKeys:
+ device_keys = DeviceKeys(
+ user_id=user_id,
+ device_id=device_id,
+ algorithms=[EncryptionAlgorithm.OLM_V1, EncryptionAlgorithm.MEGOLM_V1],
+ keys={
+ KeyID(algorithm=EncryptionKeyAlgorithm(algorithm), key_id=device_id): key
+ for algorithm, key in self.identity_keys.items()
},
- }
- signature = self.sign(base.canonical_json(device_keys))
- device_keys["signatures"] = {
- user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature}
- }
+ signatures={},
+ )
+ device_keys.signatures[user_id] = {KeyID.ed25519(device_id): sign_olm(device_keys, self)}
return device_keys
def get_one_time_keys(
@@ -97,12 +98,12 @@ def get_one_time_keys(
self.generate_one_time_keys(new_count)
keys = {}
for key_id, key in self.one_time_keys.get("curve25519", {}).items():
- signature = self.sign(base.canonical_json({"key": key}))
- keys[f"{EncryptionKeyAlgorithm.SIGNED_CURVE25519}:{key_id}"] = {
+ keys[str(KeyID.signed_curve25519(IdentityKey(key_id)))] = {
"key": key,
"signatures": {
- user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature}
+ user_id: {
+ str(KeyID.ed25519(device_id)): sign_olm({"key": key}, self),
+ }
},
}
- self.mark_keys_as_published()
return keys
diff --git a/mautrix/crypto/base.py b/mautrix/crypto/base.py
index 1f8e9a62..f7015ca1 100644
--- a/mautrix/crypto/base.py
+++ b/mautrix/crypto/base.py
@@ -5,41 +5,30 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Any, Awaitable, Callable, TypedDict
+from typing import Awaitable, Callable
import asyncio
-import functools
-import json
-
-import olm
from mautrix.errors import MForbidden, MNotFound
from mautrix.types import (
- DeviceID,
- EncryptionKeyAlgorithm,
EventType,
IdentityKey,
- KeyID,
RequestedKeyInfo,
RoomEncryptionStateEventContent,
RoomID,
RoomKeyEventContent,
SessionID,
- SigningKey,
TrustState,
UserID,
)
from mautrix.util.logging import TraceLogger
from .. import client as cli, crypto
-
-
-class SignedObject(TypedDict):
- signatures: dict[UserID, dict[str, str]]
- unsigned: Any
+from .ssss import Machine as SSSSMachine
class BaseOlmMachine:
client: cli.Client
+ ssss: SSSSMachine
log: TraceLogger
crypto_store: crypto.CryptoStore
state_store: crypto.StateStore
@@ -116,27 +105,3 @@ async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None:
evt.beeper_max_age_ms = encryption_info.rotation_period_ms
if not evt.beeper_max_messages:
evt.beeper_max_messages = encryption_info.rotation_period_msgs
-
-
-canonical_json = functools.partial(
- json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True
-)
-
-
-def verify_signature_json(
- data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey
-) -> bool:
- data_copy = {**data}
- data_copy.pop("unsigned", None)
- signatures = data_copy.pop("signatures")
- key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name))
- try:
- signature = signatures[user_id][key_id]
- except KeyError:
- return False
- signed_data = canonical_json(data_copy)
- try:
- olm.ed25519_verify(key, signed_data, signature)
- return True
- except olm.OlmVerifyError:
- return False
diff --git a/mautrix/crypto/cross_signing.py b/mautrix/crypto/cross_signing.py
new file mode 100644
index 00000000..7577cd5d
--- /dev/null
+++ b/mautrix/crypto/cross_signing.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from ..types import (
+ JSON,
+ CrossSigner,
+ CrossSigningKeys,
+ CrossSigningUsage,
+ DeviceIdentity,
+ EventType,
+ KeyID,
+ UserID,
+)
+from .cross_signing_key import CrossSigningPrivateKeys, CrossSigningPublicKeys, CrossSigningSeeds
+from .device_lists import DeviceListMachine
+from .signature import sign_olm
+from .ssss import Key as SSSSKey
+
+
+class CrossSigningMachine(DeviceListMachine):
+ _cross_signing_public_keys: CrossSigningPublicKeys | None
+ _cross_signing_public_keys_fetched: bool
+ _cross_signing_private_keys: CrossSigningPrivateKeys | None
+
+ async def verify_with_recovery_key(self, recovery_key: str) -> None:
+ if not self.account.shared:
+ raise ValueError("Device keys must be shared before verifying with recovery key")
+ key_id, key_data = await self.ssss.get_default_key_data()
+ ssss_key = key_data.verify_recovery_key(key_id, recovery_key)
+ seeds = await self._fetch_cross_signing_keys_from_ssss(ssss_key)
+ self._import_cross_signing_keys(seeds)
+ await self.sign_own_device(self.own_identity)
+
+ def _import_cross_signing_keys(self, seeds: CrossSigningSeeds) -> None:
+ self._cross_signing_private_keys = seeds.to_keys()
+ self._cross_signing_public_keys = self._cross_signing_private_keys.public_keys
+
+ async def generate_recovery_key(
+ self, passphrase: str | None = None, seeds: CrossSigningSeeds | None = None
+ ) -> str:
+ if not self.account.shared:
+ raise ValueError("Device keys must be shared before generating recovery key")
+ seeds = seeds or CrossSigningSeeds.generate()
+ ssss_key = await self.ssss.generate_and_upload_key(passphrase)
+ await self._upload_cross_signing_keys_to_ssss(ssss_key, seeds)
+ await self._publish_cross_signing_keys(seeds.to_keys())
+ await self.ssss.set_default_key_id(ssss_key.id)
+ await self.sign_own_device(self.own_identity)
+ return ssss_key.recovery_key
+
+ async def _fetch_cross_signing_keys_from_ssss(self, key: SSSSKey) -> CrossSigningSeeds:
+ return CrossSigningSeeds(
+ master_key=await self.ssss.get_decrypted_account_data(
+ EventType.CROSS_SIGNING_MASTER, key
+ ),
+ user_signing_key=await self.ssss.get_decrypted_account_data(
+ EventType.CROSS_SIGNING_USER_SIGNING, key
+ ),
+ self_signing_key=await self.ssss.get_decrypted_account_data(
+ EventType.CROSS_SIGNING_SELF_SIGNING, key
+ ),
+ )
+
+ async def _upload_cross_signing_keys_to_ssss(
+ self, key: SSSSKey, seeds: CrossSigningSeeds
+ ) -> None:
+ await self.ssss.set_encrypted_account_data(
+ EventType.CROSS_SIGNING_MASTER, seeds.master_key, key
+ )
+ await self.ssss.set_encrypted_account_data(
+ EventType.CROSS_SIGNING_USER_SIGNING, seeds.user_signing_key, key
+ )
+ await self.ssss.set_encrypted_account_data(
+ EventType.CROSS_SIGNING_SELF_SIGNING, seeds.self_signing_key, key
+ )
+
+ async def get_own_cross_signing_public_keys(self) -> CrossSigningPublicKeys | None:
+ if self._cross_signing_public_keys or self._cross_signing_public_keys_fetched:
+ return self._cross_signing_public_keys
+ keys = await self.get_cross_signing_public_keys(self.client.mxid)
+ self._cross_signing_public_keys_fetched = True
+ if keys:
+ self._cross_signing_public_keys = keys
+ return keys
+
+ async def get_cross_signing_public_keys(
+ self, user_id: UserID
+ ) -> CrossSigningPublicKeys | None:
+ db_keys = await self.crypto_store.get_cross_signing_keys(user_id)
+ if CrossSigningUsage.MASTER not in db_keys and user_id not in self._cs_fetch_attempted:
+ self.log.debug(f"Didn't find any cross-signing keys for {user_id}, fetching...")
+ async with self._fetch_keys_lock:
+ if user_id not in self._cs_fetch_attempted:
+ self._cs_fetch_attempted.add(user_id)
+ await self._fetch_keys([user_id], include_untracked=True)
+ db_keys = await self.crypto_store.get_cross_signing_keys(user_id)
+ if CrossSigningUsage.MASTER not in db_keys:
+ return None
+ return CrossSigningPublicKeys(
+ master_key=db_keys[CrossSigningUsage.MASTER].key,
+ self_signing_key=(
+ db_keys[CrossSigningUsage.SELF].key if CrossSigningUsage.SELF in db_keys else None
+ ),
+ user_signing_key=(
+ db_keys[CrossSigningUsage.USER].key if CrossSigningUsage.USER in db_keys else None
+ ),
+ )
+
+ async def sign_own_device(self, device: DeviceIdentity) -> None:
+ full_keys = await self._get_full_device_keys(device)
+ ssk = self._cross_signing_private_keys.self_signing_key
+ signature = sign_olm(full_keys, ssk)
+ full_keys.signatures = {self.client.mxid: {KeyID.ed25519(ssk.public_key): signature}}
+ await self.client.upload_one_signature(device.user_id, device.device_id, full_keys)
+ await self.crypto_store.put_signature(
+ CrossSigner(device.user_id, device.signing_key),
+ CrossSigner(self.client.mxid, ssk.public_key),
+ signature,
+ )
+
+ async def _publish_cross_signing_keys(
+ self,
+ keys: CrossSigningPrivateKeys,
+ auth: dict[str, JSON] | None = None,
+ ) -> None:
+ public = keys.public_keys
+ master_key = CrossSigningKeys(
+ user_id=self.client.mxid,
+ usage=[CrossSigningUsage.MASTER],
+ keys={KeyID.ed25519(public.master_key): public.master_key},
+ )
+ master_key.signatures = {
+ self.client.mxid: {
+ KeyID.ed25519(self.client.device_id): sign_olm(master_key, self.account),
+ }
+ }
+ self_key = CrossSigningKeys(
+ user_id=self.client.mxid,
+ usage=[CrossSigningUsage.SELF],
+ keys={KeyID.ed25519(public.self_signing_key): public.self_signing_key},
+ )
+ self_key.signatures = {
+ self.client.mxid: {
+ KeyID.ed25519(public.master_key): sign_olm(self_key, keys.master_key),
+ }
+ }
+ user_key = CrossSigningKeys(
+ user_id=self.client.mxid,
+ usage=[CrossSigningUsage.USER],
+ keys={KeyID.ed25519(public.user_signing_key): public.user_signing_key},
+ )
+ user_key.signatures = {
+ self.client.mxid: {
+ KeyID.ed25519(public.master_key): sign_olm(user_key, keys.master_key),
+ }
+ }
+ await self.client.upload_cross_signing_keys(
+ keys={
+ CrossSigningUsage.MASTER: master_key,
+ CrossSigningUsage.SELF: self_key,
+ CrossSigningUsage.USER: user_key,
+ },
+ auth=auth,
+ )
+ await self.crypto_store.put_cross_signing_key(
+ self.client.mxid, CrossSigningUsage.MASTER, public.master_key
+ )
+ await self.crypto_store.put_cross_signing_key(
+ self.client.mxid, CrossSigningUsage.SELF, public.self_signing_key
+ )
+ await self.crypto_store.put_cross_signing_key(
+ self.client.mxid, CrossSigningUsage.USER, public.user_signing_key
+ )
+ self._cross_signing_private_keys = keys
+ self._cross_signing_public_keys = public
diff --git a/mautrix/crypto/cross_signing_key.py b/mautrix/crypto/cross_signing_key.py
new file mode 100644
index 00000000..f4e3c1c1
--- /dev/null
+++ b/mautrix/crypto/cross_signing_key.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from typing import NamedTuple
+
+import olm
+
+from mautrix.crypto.ssss.util import cryptorand
+from mautrix.types import SigningKey
+
+
+class CrossSigningPublicKeys(NamedTuple):
+ master_key: SigningKey
+ self_signing_key: SigningKey
+ user_signing_key: SigningKey
+
+
+class CrossSigningPrivateKeys(NamedTuple):
+ master_key: olm.PkSigning
+ self_signing_key: olm.PkSigning
+ user_signing_key: olm.PkSigning
+
+ @property
+ def public_keys(self) -> CrossSigningPublicKeys:
+ return CrossSigningPublicKeys(
+ master_key=self.master_key.public_key,
+ self_signing_key=self.self_signing_key.public_key,
+ user_signing_key=self.user_signing_key.public_key,
+ )
+
+
+class CrossSigningSeeds(NamedTuple):
+ master_key: bytes
+ self_signing_key: bytes
+ user_signing_key: bytes
+
+ def to_keys(self) -> CrossSigningPrivateKeys:
+ return CrossSigningPrivateKeys(
+ master_key=olm.PkSigning(self.master_key),
+ self_signing_key=olm.PkSigning(self.self_signing_key),
+ user_signing_key=olm.PkSigning(self.user_signing_key),
+ )
+
+ @classmethod
+ def generate(cls) -> "CrossSigningSeeds":
+ return cls(
+ master_key=cryptorand.read(32),
+ self_signing_key=cryptorand.read(32),
+ user_signing_key=cryptorand.read(32),
+ )
diff --git a/mautrix/crypto/device_lists.py b/mautrix/crypto/device_lists.py
index 1b5e0dbd..b00b104c 100644
--- a/mautrix/crypto/device_lists.py
+++ b/mautrix/crypto/device_lists.py
@@ -23,10 +23,23 @@
UserID,
)
-from .base import BaseOlmMachine, verify_signature_json
+from .base import BaseOlmMachine
+from .signature import verify_signature_json
class DeviceListMachine(BaseOlmMachine):
+ @property
+ def own_identity(self) -> DeviceIdentity:
+ return DeviceIdentity(
+ user_id=self.client.mxid,
+ device_id=self.client.device_id,
+ identity_key=self.account.identity_key,
+ signing_key=self.account.signing_key,
+ trust=TrustState.VERIFIED,
+ deleted=False,
+ name="",
+ )
+
async def _fetch_keys(
self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False
) -> dict[UserID, dict[DeviceID, DeviceIdentity]]:
@@ -46,61 +59,67 @@ async def _fetch_keys(
data = {}
for user_id, devices in resp.device_keys.items():
missing_users.remove(user_id)
-
- new_devices = {}
- existing_devices = (await self.crypto_store.get_devices(user_id)) or {}
-
- self.log.trace(
- f"Updating devices for {user_id}, got {len(devices)}, "
- f"have {len(existing_devices)} in store"
- )
- changed = False
- ssks = resp.self_signing_keys.get(user_id)
- ssk = ssks.first_ed25519_key if ssks else None
- for device_id, device_keys in devices.items():
- try:
- existing = existing_devices[device_id]
- except KeyError:
- existing = None
- changed = True
- self.log.trace(f"Validating device {device_keys} of {user_id}")
- try:
- new_device = await self._validate_device(
- user_id, device_id, device_keys, existing
- )
- except DeviceValidationError as e:
- self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
- else:
- if new_device:
- new_devices[device_id] = new_device
- await self._store_device_self_signatures(device_keys, ssk)
- self.log.debug(
- f"Storing new device list for {user_id} containing {len(new_devices)} devices"
- )
- await self.crypto_store.put_devices(user_id, new_devices)
- data[user_id] = new_devices
-
- if changed or len(new_devices) != len(existing_devices):
- if self.delete_keys_on_device_delete:
- for device_id in existing_devices.keys() - new_devices.keys():
- device = existing_devices[device_id]
- removed_ids = await self.crypto_store.redact_group_sessions(
- room_id=None, sender_key=device.identity_key, reason="device removed"
- )
- self.log.info(
- "Redacted megolm sessions sent by removed device "
- f"{device.user_id}/{device.device_id}: {removed_ids}"
- )
- await self.on_devices_changed(user_id)
+ async with self.crypto_store.transaction():
+ data[user_id] = await self._process_fetched_keys(user_id, devices, resp)
for user_id in missing_users:
self.log.warning(f"Didn't get any devices for user {user_id}")
- for user_id in users:
- await self._store_cross_signing_keys(resp, user_id)
-
return data
+ async def _process_fetched_keys(
+ self,
+ user_id: UserID,
+ devices: dict[DeviceID, DeviceKeys],
+ resp: QueryKeysResponse,
+ ) -> dict[DeviceID, DeviceIdentity]:
+ new_devices = {}
+ existing_devices = (await self.crypto_store.get_devices(user_id)) or {}
+
+ self.log.trace(
+ f"Updating devices for {user_id}, got {len(devices)}, "
+ f"have {len(existing_devices)} in store"
+ )
+ changed = False
+ ssks = resp.self_signing_keys.get(user_id)
+ ssk = ssks.first_ed25519_key if ssks else None
+ for device_id, device_keys in devices.items():
+ try:
+ existing = existing_devices[device_id]
+ except KeyError:
+ existing = None
+ changed = True
+ self.log.trace(f"Validating device {device_keys} of {user_id}")
+ try:
+ new_device = await self._validate_device(user_id, device_id, device_keys, existing)
+ except DeviceValidationError as e:
+ self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
+ else:
+ if new_device:
+ new_devices[device_id] = new_device
+ await self._store_device_self_signatures(device_keys, ssk)
+ self.log.debug(
+ f"Storing new device list for {user_id} containing {len(new_devices)} devices"
+ )
+ await self.crypto_store.put_devices(user_id, new_devices)
+
+ if changed or len(new_devices) != len(existing_devices):
+ if self.delete_keys_on_device_delete:
+ for device_id in existing_devices.keys() - new_devices.keys():
+ device = existing_devices[device_id]
+ removed_ids = await self.crypto_store.redact_group_sessions(
+ room_id=None, sender_key=device.identity_key, reason="device removed"
+ )
+ self.log.info(
+ "Redacted megolm sessions sent by removed device "
+ f"{device.user_id}/{device.device_id}: {removed_ids}"
+ )
+ await self.on_devices_changed(user_id)
+
+ await self._store_cross_signing_keys(resp, user_id)
+
+ return new_devices
+
async def _store_device_self_signatures(
self, device_keys: DeviceKeys, self_signing_key: SigningKey | None
) -> None:
@@ -193,7 +212,7 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User
signing_key = device.ed25519
except KeyError:
pass
- if len(signing_key) != 43:
+ if not signing_key or len(signing_key) != 43:
self.log.debug(
f"Cross-signing key {user_id}/{actual_key} has a signature from "
f"an unknown key {key_id}"
@@ -219,6 +238,12 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User
else:
self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}")
+ async def _get_full_device_keys(self, device: DeviceIdentity) -> DeviceKeys:
+ resp = await self.client.query_keys({device.user_id: [device.device_id]})
+ keys = resp.device_keys[device.user_id][device.device_id]
+ await self._validate_device(device.user_id, device.device_id, keys, device)
+ return keys
+
async def get_or_fetch_device(
self, user_id: UserID, device_id: DeviceID
) -> DeviceIdentity | None:
@@ -296,18 +321,23 @@ async def _validate_device(
deleted=False,
)
- async def resolve_trust(self, device: DeviceIdentity) -> TrustState:
+ async def resolve_trust(self, device: DeviceIdentity, allow_fetch: bool = True) -> TrustState:
try:
- return await self._try_resolve_trust(device)
+ return await self._try_resolve_trust(device, allow_fetch)
except Exception:
self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}")
return TrustState.UNVERIFIED
- async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState:
- if device.trust in (TrustState.VERIFIED, TrustState.BLACKLISTED):
+ async def _try_resolve_trust(
+ self, device: DeviceIdentity, allow_fetch: bool = True
+ ) -> TrustState:
+ if device.device_id != self.client.device_id and device.trust in (
+ TrustState.VERIFIED,
+ TrustState.BLACKLISTED,
+ ):
return device.trust
their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id)
- if len(their_keys) == 0 and device.user_id not in self._cs_fetch_attempted:
+ if len(their_keys) == 0 and allow_fetch and device.user_id not in self._cs_fetch_attempted:
self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...")
async with self._fetch_keys_lock:
if device.user_id not in self._cs_fetch_attempted:
@@ -318,7 +348,8 @@ async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState:
msk = their_keys[CrossSigningUsage.MASTER]
ssk = their_keys[CrossSigningUsage.SELF]
except KeyError as e:
- self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
+ if allow_fetch:
+ self.log.warning(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
return TrustState.UNVERIFIED
ssk_signed = await self.crypto_store.is_key_signed_by(
target=CrossSigner(device.user_id, ssk.key),
diff --git a/mautrix/crypto/encrypt_olm.py b/mautrix/crypto/encrypt_olm.py
index 620cbbc8..029ad1e5 100644
--- a/mautrix/crypto/encrypt_olm.py
+++ b/mautrix/crypto/encrypt_olm.py
@@ -18,8 +18,9 @@
UserID,
)
-from .base import BaseOlmMachine, verify_signature_json
+from .base import BaseOlmMachine
from .sessions import Session
+from .signature import verify_signature_json
ClaimKeysList = Dict[UserID, Dict[DeviceID, DeviceIdentity]]
diff --git a/mautrix/crypto/machine.py b/mautrix/crypto/machine.py
index 0cfe6ea3..60c65677 100644
--- a/mautrix/crypto/machine.py
+++ b/mautrix/crypto/machine.py
@@ -32,10 +32,12 @@
from mautrix.util.logging import TraceLogger
from .account import OlmAccount
+from .cross_signing import CrossSigningMachine
from .decrypt_megolm import MegolmDecryptionMachine
from .encrypt_megolm import MegolmEncryptionMachine
from .key_request import KeyRequestingMachine
from .key_share import KeySharingMachine
+from .ssss import Machine as SSSSMachine
from .store import CryptoStore, StateStore
from .unwedge import OlmUnwedgingMachine
@@ -46,6 +48,7 @@ class OlmMachine(
OlmUnwedgingMachine,
KeySharingMachine,
KeyRequestingMachine,
+ CrossSigningMachine,
):
"""
OlmMachine is the main class for handling things related to Matrix end-to-end encryption with
@@ -58,6 +61,7 @@ class OlmMachine(
log: TraceLogger
crypto_store: CryptoStore
state_store: StateStore
+ ssss: SSSSMachine
account: Optional[OlmAccount]
@@ -70,6 +74,7 @@ def __init__(
) -> None:
super().__init__()
self.client = client
+ self.ssss = SSSSMachine(client)
self.log = log or logging.getLogger("mau.crypto")
self.crypto_store = crypto_store
self.state_store = state_store
@@ -96,6 +101,10 @@ def __init__(
self._prev_unwedge = {}
self._cs_fetch_attempted = set()
+ self._cross_signing_public_keys = None
+ self._cross_signing_public_keys_fetched = False
+ self._cross_signing_private_keys = None
+
self.client.add_event_handler(
cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True
)
@@ -213,6 +222,11 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None:
passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
do syncing in some manual way.
"""
+ if isinstance(evt, DecryptedOlmEvent):
+ self.log.warning(
+ f"Dropping unexpected nested encrypted to-device event from {evt.sender}"
+ )
+ return
self.log.trace(
f"Handling encrypted to-device event from {evt.content.sender_key} ({evt.sender})"
)
@@ -221,6 +235,8 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None:
await self._receive_room_key(decrypted_evt)
elif decrypted_evt.type == EventType.FORWARDED_ROOM_KEY:
await self._receive_forwarded_room_key(decrypted_evt)
+ else:
+ self.client.dispatch_event(decrypted_evt, source=evt.source)
async def _receive_room_key(self, evt: DecryptedOlmEvent) -> None:
# TODO nio had a comment saying "handle this better"
@@ -292,7 +308,7 @@ async def _share_keys(self, current_otk_count: int | None) -> None:
):
self.log.debug("Checking OTK count on server")
current_otk_count = (await self.client.upload_keys()).get(
- EncryptionKeyAlgorithm.SIGNED_CURVE25519
+ EncryptionKeyAlgorithm.SIGNED_CURVE25519, 0
)
device_keys = (
self.account.get_device_keys(self.client.mxid, self.client.device_id)
@@ -310,6 +326,7 @@ async def _share_keys(self, current_otk_count: int | None) -> None:
self.log.debug(f"Uploading {len(one_time_keys)} one-time keys")
resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys)
self.account.shared = True
+ self.account.mark_keys_as_published()
self._last_key_share = time.monotonic()
await self.crypto_store.put_account(self.account)
self.log.debug(f"Shared keys and saved account, new keys: {resp}")
diff --git a/mautrix/crypto/signature.py b/mautrix/crypto/signature.py
new file mode 100644
index 00000000..6dc13e65
--- /dev/null
+++ b/mautrix/crypto/signature.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from typing import Any, TypedDict
+import functools
+import json
+
+import olm
+import unpaddedbase64
+
+from mautrix.types import (
+ JSON,
+ DeviceID,
+ EncryptionKeyAlgorithm,
+ KeyID,
+ Serializable,
+ Signature,
+ SigningKey,
+ UserID,
+)
+
+try:
+ from Crypto.PublicKey import ECC
+ from Crypto.Signature import eddsa
+except ImportError:
+ from Cryptodome.PublicKey import ECC
+ from Cryptodome.Signature import eddsa
+
+canonical_json = functools.partial(
+ json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True
+)
+
+
+class SignedObject(TypedDict):
+ signatures: dict[UserID, dict[str, str]]
+ unsigned: Any
+
+
+def sign_olm(data: dict[str, JSON] | Serializable, key: olm.PkSigning | olm.Account) -> Signature:
+ if isinstance(data, Serializable):
+ data = data.serialize()
+ data.pop("signatures", None)
+ data.pop("unsigned", None)
+ return Signature(key.sign(canonical_json(data)))
+
+
+def verify_signature_json(
+ data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey
+) -> bool:
+ data_copy = {**data}
+ data_copy.pop("unsigned", None)
+ signatures = data_copy.pop("signatures")
+ key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name))
+ try:
+ signature = signatures[user_id][key_id]
+ decoded_key = unpaddedbase64.decode_base64(key)
+ # pycryptodome doesn't accept raw keys, so wrap it in a DER structure
+ der_key = b"\x30\x2a\x30\x05\x06\x03\x2b\x65\x70\x03\x21\x00" + decoded_key
+ decoded_signature = unpaddedbase64.decode_base64(signature)
+ parsed_key = ECC.import_key(der_key)
+ verifier = eddsa.new(parsed_key, "rfc8032")
+ verifier.verify(canonical_json(data_copy).encode("utf-8"), decoded_signature)
+ return True
+ except (KeyError, ValueError):
+ return False
diff --git a/mautrix/crypto/signature_test.py b/mautrix/crypto/signature_test.py
new file mode 100644
index 00000000..115e4836
--- /dev/null
+++ b/mautrix/crypto/signature_test.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from mautrix.types import SigningKey, UserID
+
+from .signature import verify_signature_json
+
+
+def test_verify_signature_json() -> None:
+ assert verify_signature_json(
+ # This is actually a federation PDU rather than a device signature,
+ # but they're both 25519 curves so it doesn't make a difference.
+ {
+ "auth_events": [
+ "$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw",
+ "$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM",
+ "$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo",
+ ],
+ "content": {},
+ "depth": 3212,
+ "hashes": {"sha256": "K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},
+ "origin_server_ts": 1754242687127,
+ "prev_events": ["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],
+ "room_id": "!offtopic-2:continuwuity.org",
+ "sender": "@tulir:maunium.net",
+ "type": "m.room.message",
+ "signatures": {
+ UserID("maunium.net"): {
+ "ed25519:a_xxeS": "SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"
+ }
+ },
+ "unsigned": {"age_ts": 1754242687146},
+ },
+ UserID("maunium.net"),
+ "a_xxeS",
+ SigningKey("lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg"),
+ )
diff --git a/mautrix/crypto/ssss/__init__.py b/mautrix/crypto/ssss/__init__.py
new file mode 100644
index 00000000..9224418d
--- /dev/null
+++ b/mautrix/crypto/ssss/__init__.py
@@ -0,0 +1,8 @@
+from .key import Key, KeyMetadata, PassphraseMetadata
+from .machine import Machine
+from .types import (
+ Algorithm,
+ EncryptedAccountDataEventContent,
+ EncryptedKeyData,
+ PassphraseAlgorithm,
+)
diff --git a/mautrix/crypto/ssss/key.py b/mautrix/crypto/ssss/key.py
new file mode 100644
index 00000000..691ded71
--- /dev/null
+++ b/mautrix/crypto/ssss/key.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from typing import Optional
+import base64
+import hashlib
+import hmac
+
+from attr import dataclass
+import unpaddedbase64
+
+from mautrix.types import EventType, SerializableAttrs
+
+from .types import Algorithm, EncryptedKeyData, PassphraseAlgorithm
+from .util import (
+ calculate_hash,
+ cryptorand,
+ decode_base58_recovery_key,
+ derive_keys,
+ encode_base58_recovery_key,
+ prepare_aes,
+)
+
+try:
+ from Crypto.Cipher import AES
+ from Crypto.Util import Counter
+except ImportError:
+ from Cryptodome.Cipher import AES
+ from Cryptodome.Util import Counter
+
+
+@dataclass
+class PassphraseMetadata(SerializableAttrs):
+ algorithm: PassphraseAlgorithm
+ iterations: int
+ salt: str
+ bits: int = 256
+
+ def get_key(self, passphrase: str) -> bytes:
+ if self.algorithm != PassphraseAlgorithm.PBKDF2:
+ raise ValueError(f"Unsupported passphrase algorithm {self.algorithm}")
+ return hashlib.pbkdf2_hmac(
+ "sha512",
+ passphrase.encode("utf-8"),
+ self.salt.encode("utf-8"),
+ self.iterations,
+ self.bits // 8,
+ )
+
+
+@dataclass
+class KeyMetadata(SerializableAttrs):
+ algorithm: Algorithm
+
+ iv: str | None = None
+ mac: str | None = None
+
+ name: str | None = None
+ passphrase: Optional[PassphraseMetadata] = None
+
+ def verify_passphrase(self, key_id: str, phrase: str) -> "Key":
+ if not self.passphrase:
+ raise ValueError("Passphrase not set on this key")
+ return self.verify_raw_key(key_id, self.passphrase.get_key(phrase))
+
+ def verify_recovery_key(self, key_id: str, recovery_key: str) -> "Key":
+ decoded_key = decode_base58_recovery_key(recovery_key)
+ if not decoded_key:
+ raise ValueError("Invalid recovery key syntax")
+ return self.verify_raw_key(key_id, decoded_key)
+
+ def verify_raw_key(self, key_id: str, key: bytes) -> "Key":
+ if self.mac.rstrip("=") != calculate_hash(key, self.iv):
+ raise ValueError("Key MAC does not match")
+ return Key(id=key_id, key=key, metadata=self)
+
+
+@dataclass
+class Key:
+ id: str
+ key: bytes
+ metadata: KeyMetadata
+
+ @classmethod
+ def generate(cls, passphrase: str | None = None) -> "Key":
+ passphrase_meta = (
+ PassphraseMetadata(
+ algorithm=PassphraseAlgorithm.PBKDF2,
+ iterations=500_000,
+ salt=base64.b64encode(cryptorand.read(24)).decode("utf-8"),
+ bits=256,
+ )
+ if passphrase
+ else None
+ )
+ key = passphrase_meta.get_key(passphrase) if passphrase else cryptorand.read(32)
+ iv = unpaddedbase64.encode_base64(cryptorand.read(16))
+ metadata = KeyMetadata(
+ algorithm=Algorithm.AES_HMAC_SHA2,
+ passphrase=passphrase_meta,
+ mac=calculate_hash(key, iv),
+ iv=iv,
+ )
+ key_id = unpaddedbase64.encode_base64(cryptorand.read(24))
+ return cls(key=key, id=key_id, metadata=metadata)
+
+ @property
+ def recovery_key(self) -> str:
+ return encode_base58_recovery_key(self.key)
+
+ def encrypt(self, event_type: str | EventType, data: str | bytes) -> EncryptedKeyData:
+ if isinstance(data, str):
+ data = data.encode("utf-8")
+ data = base64.b64encode(data).rstrip(b"=")
+
+ aes_key, hmac_key = derive_keys(self.key, event_type)
+ iv = bytearray(cryptorand.read(16))
+ iv[8] &= 0x7F
+ ciphertext = prepare_aes(aes_key, iv).encrypt(data)
+ digest = hmac.digest(hmac_key, ciphertext, hashlib.sha256)
+ return EncryptedKeyData(
+ ciphertext=unpaddedbase64.encode_base64(ciphertext),
+ iv=unpaddedbase64.encode_base64(iv),
+ mac=unpaddedbase64.encode_base64(digest),
+ )
+
+ def decrypt(self, event_type: str | EventType, data: EncryptedKeyData) -> bytes:
+ aes_key, hmac_key = derive_keys(self.key, event_type)
+ ciphertext = unpaddedbase64.decode_base64(data.ciphertext)
+ mac = unpaddedbase64.decode_base64(data.mac)
+
+ expected_mac = hmac.digest(hmac_key, ciphertext, hashlib.sha256)
+ if not hmac.compare_digest(mac, expected_mac):
+ raise ValueError("Invalid MAC")
+
+ plaintext = prepare_aes(aes_key, data.iv).decrypt(ciphertext)
+ return unpaddedbase64.decode_base64(plaintext.decode("utf-8"))
diff --git a/mautrix/crypto/ssss/key_test.py b/mautrix/crypto/ssss/key_test.py
new file mode 100644
index 00000000..e60f1e3c
--- /dev/null
+++ b/mautrix/crypto/ssss/key_test.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+import pytest
+
+from ...types.event.type import EventType
+from .key import Key, KeyMetadata
+from .types import EncryptedAccountDataEventContent
+
+KEY1_CROSS_SIGNING_MASTER_KEY = """{
+ "encrypted": {
+ "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0": {
+ "iv": "BpKP9nQJTE9jrsAssoxPqQ==",
+ "ciphertext": "fNRiiiidezjerTgV+G6pUtmeF3izzj5re/mVvY0hO2kM6kYGrxLuIu2ej80=",
+ "mac": "/gWGDGMyOLmbJp+aoSLh5JxCs0AdS6nAhjzpe+9G2Q0="
+ }
+ }
+}"""
+
+KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED = bytes(
+ [
+ 0x68,
+ 0xF9,
+ 0x7F,
+ 0xD1,
+ 0x92,
+ 0x2E,
+ 0xEC,
+ 0xF6,
+ 0xB8,
+ 0x2B,
+ 0xB8,
+ 0x90,
+ 0xD2,
+ 0x4D,
+ 0x06,
+ 0x52,
+ 0x98,
+ 0x4E,
+ 0x7A,
+ 0x1D,
+ 0x70,
+ 0x3B,
+ 0x9E,
+ 0x86,
+ 0x7B,
+ 0x7E,
+ 0xBA,
+ 0xF7,
+ 0xFE,
+ 0xB9,
+ 0x5B,
+ 0x6F,
+ ]
+)
+
+KEY1_META = """{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "passphrase": {
+ "algorithm": "m.pbkdf2",
+ "iterations": 500000,
+ "salt": "y863BOoqOadgDp8S3FtHXikDJEalsQ7d"
+ },
+ "iv": "xxkTK0L4UzxgAFkQ6XPwsw",
+ "mac": "MEhooO0ZhFJNxUhvRMSxBnJfL20wkLgle3ocY0ee/eA"
+}"""
+KEY1_ID = "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0"
+KEY1_RECOVERY_KEY = "EsTE s92N EtaX s2h6 VQYF 9Kao tHYL mkyL GKMh isZb KJ4E tvoC"
+KEY1_PASSPHRASE = "correct horse battery staple"
+
+KEY2_META = """{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "O0BOvTqiIAYjC+RMcyHfWw==",
+ "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
+}"""
+KEY2_ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv"
+KEY2_RECOVERY_KEY = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe"
+
+KEY2_META_BROKEN_IV = """{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow",
+ "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
+}"""
+
+KEY2_META_BROKEN_MAC = """{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "O0BOvTqiIAYjC+RMcyHfWw==",
+ "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow"
+}"""
+
+
+def get_key_meta(meta: str) -> KeyMetadata:
+ return KeyMetadata.parse_json(meta)
+
+
+def get_key1() -> Key:
+ return get_key_meta(KEY1_META).verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY)
+
+
+def get_key2() -> Key:
+ return get_key_meta(KEY2_META).verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+
+
+def get_encrypted_master_key() -> EncryptedAccountDataEventContent:
+ return EncryptedAccountDataEventContent.parse_json(KEY1_CROSS_SIGNING_MASTER_KEY)
+
+
+def test_decrypt_success() -> None:
+ key = get_key1()
+ emk = get_encrypted_master_key()
+ assert (
+ emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) == KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED
+ )
+
+
+def test_decrypt_fail_wrong_key() -> None:
+ key = get_key2()
+ emk = get_encrypted_master_key()
+ with pytest.raises(ValueError):
+ emk.decrypt(EventType.CROSS_SIGNING_MASTER, key)
+
+
+def test_decrypt_fail_fake_key() -> None:
+ key = get_key2()
+ key.id = KEY1_ID
+ emk = get_encrypted_master_key()
+ with pytest.raises(ValueError):
+ emk.decrypt(EventType.CROSS_SIGNING_MASTER, key)
+
+
+def test_decrypt_fail_wrong_type() -> None:
+ key = get_key1()
+ emk = get_encrypted_master_key()
+ with pytest.raises(ValueError):
+ emk.decrypt(EventType.CROSS_SIGNING_SELF_SIGNING, key)
+
+
+def test_encrypt_roundtrip() -> None:
+ key = get_key1()
+ data = bytes([0xDE, 0xAD, 0xBE, 0xEF])
+ ciphertext = key.encrypt("net.maunium.data", data)
+ plaintext = key.decrypt("net.maunium.data", ciphertext)
+ assert plaintext == data
+
+
+def test_verify_recovery_key_correct() -> None:
+ meta = get_key_meta(KEY1_META)
+ key = meta.verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY)
+ assert key.recovery_key == KEY1_RECOVERY_KEY
+
+
+def test_verify_recovery_key_correct2() -> None:
+ meta = get_key_meta(KEY2_META)
+ key = meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+ assert key.recovery_key == KEY2_RECOVERY_KEY
+
+
+def test_verify_recovery_key_invalid() -> None:
+ meta = get_key_meta(KEY1_META)
+ with pytest.raises(ValueError):
+ meta.verify_recovery_key(KEY1_ID, "foo")
+
+
+def test_verify_recovery_key_incorrect() -> None:
+ meta = get_key_meta(KEY1_META)
+ with pytest.raises(ValueError):
+ meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+
+
+def test_verify_recovery_key_broken_iv() -> None:
+ meta = get_key_meta(KEY2_META_BROKEN_IV)
+ with pytest.raises(ValueError):
+ meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+
+
+def test_verify_recovery_key_broken_mac() -> None:
+ meta = get_key_meta(KEY2_META_BROKEN_MAC)
+ with pytest.raises(ValueError):
+ meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+
+
+def test_verify_passphrase_correct() -> None:
+ meta = get_key_meta(KEY1_META)
+ key = meta.verify_passphrase(KEY1_ID, KEY1_PASSPHRASE)
+ assert key.recovery_key == KEY1_RECOVERY_KEY
+
+
+def test_verify_passphrase_incorrect() -> None:
+ meta = get_key_meta(KEY1_META)
+ with pytest.raises(ValueError):
+ meta.verify_passphrase(KEY1_ID, "incorrect horse battery staple")
+
+
+def test_verify_passphrase_notset() -> None:
+ meta = get_key_meta(KEY2_META)
+ with pytest.raises(ValueError):
+ meta.verify_passphrase(KEY2_ID, "hmm")
diff --git a/mautrix/crypto/ssss/machine.py b/mautrix/crypto/ssss/machine.py
new file mode 100644
index 00000000..c43e25e3
--- /dev/null
+++ b/mautrix/crypto/ssss/machine.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from mautrix import client as cli
+from mautrix.errors import MNotFound
+from mautrix.types import EventType, SecretStorageDefaultKeyEventContent
+
+from .key import Key, KeyMetadata
+from .types import EncryptedAccountDataEventContent
+
+
+class Machine:
+ client: cli.Client
+
+ def __init__(self, client: cli.Client) -> None:
+ self.client = client
+
+ async def get_default_key_id(self) -> str | None:
+ try:
+ data = await self.client.get_account_data(EventType.SECRET_STORAGE_DEFAULT_KEY)
+ return SecretStorageDefaultKeyEventContent.deserialize(data).key
+ except (MNotFound, ValueError):
+ return None
+
+ async def set_default_key_id(self, key_id: str) -> None:
+ await self.client.set_account_data(
+ EventType.SECRET_STORAGE_DEFAULT_KEY,
+ SecretStorageDefaultKeyEventContent(key=key_id),
+ )
+
+ async def get_key_data(self, key_id: str) -> KeyMetadata:
+ data = await self.client.get_account_data(f"m.secret_storage.key.{key_id}")
+ return KeyMetadata.deserialize(data)
+
+ async def set_key_data(self, key_id: str, data: KeyMetadata) -> None:
+ await self.client.set_account_data(f"m.secret_storage.key.{key_id}", data)
+
+ async def get_default_key_data(self) -> tuple[str, KeyMetadata]:
+ key_id = await self.get_default_key_id()
+ if not key_id:
+ raise ValueError("No default key ID set")
+ return key_id, await self.get_key_data(key_id)
+
+ async def get_decrypted_account_data(self, event_type: EventType | str, key: Key) -> bytes:
+ data = await self.client.get_account_data(event_type)
+ parsed = EncryptedAccountDataEventContent.deserialize(data)
+ return parsed.decrypt(event_type, key)
+
+ async def set_encrypted_account_data(
+ self, event_type: EventType | str, data: bytes, *keys: Key
+ ) -> None:
+ encrypted_data = {}
+ for key in keys:
+ encrypted_data[key.id] = key.encrypt(event_type, data)
+ await self.client.set_account_data(
+ event_type,
+ EncryptedAccountDataEventContent(encrypted=encrypted_data),
+ )
+
+ async def generate_and_upload_key(self, passphrase: str | None = None) -> Key:
+ key = Key.generate(passphrase)
+ await self.set_key_data(key.id, key.metadata)
+ return key
diff --git a/mautrix/crypto/ssss/types.py b/mautrix/crypto/ssss/types.py
new file mode 100644
index 00000000..4a47f743
--- /dev/null
+++ b/mautrix/crypto/ssss/types.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from typing import TYPE_CHECKING
+
+from attr import dataclass
+
+from mautrix.types import EventType, SerializableAttrs, SerializableEnum
+from mautrix.types.event.account_data import account_data_event_content_map
+
+if TYPE_CHECKING:
+ from .key import Key
+
+
+class Algorithm(SerializableEnum):
+ AES_HMAC_SHA2 = "m.secret_storage.v1.aes-hmac-sha2"
+ CURVE25519_AES_SHA2 = "m.secret_storage.v1.curve25519-aes-sha2"
+
+
+class PassphraseAlgorithm(SerializableEnum):
+ PBKDF2 = "m.pbkdf2"
+
+
+@dataclass
+class EncryptedKeyData(SerializableAttrs):
+ ciphertext: str
+ iv: str
+ mac: str
+
+
+@dataclass
+class EncryptedAccountDataEventContent(SerializableAttrs):
+ encrypted: dict[str, EncryptedKeyData]
+
+ def decrypt(self, event_type: str | EventType, key: "Key") -> bytes:
+ try:
+ encrypted_data = self.encrypted[key.id]
+ except KeyError as e:
+ raise ValueError(f"Event not encrypted for provided key") from e
+ return key.decrypt(event_type, encrypted_data)
+
+
+for encrypted_account_data_type in (
+ EventType.CROSS_SIGNING_MASTER,
+ EventType.CROSS_SIGNING_USER_SIGNING,
+ EventType.CROSS_SIGNING_SELF_SIGNING,
+ EventType.MEGOLM_BACKUP_V1,
+):
+ account_data_event_content_map[encrypted_account_data_type] = EncryptedAccountDataEventContent
diff --git a/mautrix/crypto/ssss/util.py b/mautrix/crypto/ssss/util.py
new file mode 100644
index 00000000..b58c941a
--- /dev/null
+++ b/mautrix/crypto/ssss/util.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+import hashlib
+import hmac
+
+import base58
+import unpaddedbase64
+
+from mautrix.types import EventType
+
+try:
+ from Crypto import Random
+ from Crypto.Cipher import AES
+ from Crypto.Hash import SHA256
+ from Crypto.Protocol.KDF import HKDF
+ from Crypto.Util import Counter
+except ImportError:
+ from Cryptodome import Random
+ from Cryptodome.Cipher import AES
+ from Cryptodome.Hash import SHA256
+ from Cryptodome.Protocol.KDF import HKDF
+ from Cryptodome.Util import Counter
+
+cryptorand = Random.new()
+
+
+def decode_base58_recovery_key(key: str) -> bytes | None:
+ key_bytes = base58.b58decode(key.replace(" ", ""))
+ if len(key_bytes) != 35 or key_bytes[0] != 0x8B or key_bytes[1] != 1:
+ return None
+ parity = 0
+ for byte in key_bytes[:34]:
+ parity ^= byte
+ return key_bytes[2:34] if parity == key_bytes[34] else None
+
+
+def encode_base58_recovery_key(key: bytes) -> str:
+ key_bytes = bytearray(35)
+ key_bytes[0] = 0x8B
+ key_bytes[1] = 1
+ key_bytes[2:34] = key
+ parity = 0
+ for byte in key_bytes:
+ parity ^= byte
+ key_bytes[34] = parity
+ encoded_key = base58.b58encode(key_bytes).decode("utf-8")
+ return " ".join(encoded_key[i : i + 4] for i in range(0, len(encoded_key), 4))
+
+
+def derive_keys(key: bytes, name: str | EventType = "") -> tuple[bytes, bytes]:
+ aes_key, hmac_key = HKDF(
+ master=key,
+ key_len=32,
+ salt=b"\x00" * 32,
+ hashmod=SHA256,
+ num_keys=2,
+ context=str(name).encode("utf-8"),
+ )
+ return aes_key, hmac_key
+
+
+def prepare_aes(key: bytes, iv: str | bytes) -> AES:
+ if isinstance(iv, str):
+ iv = unpaddedbase64.decode_base64(iv)
+ # initial_value = struct.unpack(">Q", iv[8:])[0]
+ # counter = Counter.new(64, prefix=iv[:8], initial_value=initial_value)
+ counter = Counter.new(128, initial_value=int.from_bytes(iv, byteorder="big"))
+ return AES.new(key=key, mode=AES.MODE_CTR, counter=counter)
+
+
+def calculate_hash(key: bytes, iv: str | bytes) -> str:
+ aes_key, hmac_key = derive_keys(key)
+ cipher = prepare_aes(aes_key, iv).decrypt(b"\x00" * 32)
+ digest = hmac.digest(hmac_key, cipher, hashlib.sha256)
+ return unpaddedbase64.encode_base64(digest)
diff --git a/mautrix/crypto/store/abstract.py b/mautrix/crypto/store/abstract.py
index 787d5225..0916a2d7 100644
--- a/mautrix/crypto/store/abstract.py
+++ b/mautrix/crypto/store/abstract.py
@@ -5,8 +5,9 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import NamedTuple
+from typing import AsyncContextManager, NamedTuple
from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
from mautrix.types import (
CrossSigner,
@@ -87,6 +88,11 @@ async def close(self) -> None:
async def flush(self) -> None:
"""Flush the store. If all the methods persist data immediately, this can be a no-op."""
+ @asynccontextmanager
+ async def transaction(self) -> None:
+ """Run a database transaction. If the store doesn't support transactions, this can be a no-op."""
+ yield
+
@abstractmethod
async def delete(self) -> None:
"""Delete the data in the store."""
diff --git a/mautrix/crypto/store/asyncpg/store.py b/mautrix/crypto/store/asyncpg/store.py
index a29f7737..bdc37ddd 100644
--- a/mautrix/crypto/store/asyncpg/store.py
+++ b/mautrix/crypto/store/asyncpg/store.py
@@ -6,6 +6,7 @@
from __future__ import annotations
from collections import defaultdict
+from contextlib import asynccontextmanager
from datetime import timedelta
from asyncpg import UniqueViolationError
@@ -79,6 +80,11 @@ def __init__(self, account_id: str, pickle_key: str, db: Database) -> None:
self._account = None
self._olm_cache = defaultdict(lambda: {})
+ @asynccontextmanager
+ async def transaction(self) -> None:
+ async with self.db.acquire() as conn, conn.transaction():
+ yield
+
async def delete(self) -> None:
tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session")
async with self.db.acquire() as conn, conn.transaction():
diff --git a/mautrix/crypto/store/asyncpg/upgrade.py b/mautrix/crypto/store/asyncpg/upgrade.py
index e097c5d9..8d413858 100644
--- a/mautrix/crypto/store/asyncpg/upgrade.py
+++ b/mautrix/crypto/store/asyncpg/upgrade.py
@@ -18,32 +18,32 @@
@upgrade_table.register(description="Latest revision", upgrades_to=10)
async def upgrade_blank_to_latest(conn: Connection) -> None:
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_account (
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_message_index (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_message_index (
sender_key CHAR(43),
session_id CHAR(43),
"index" INTEGER,
event_id TEXT NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (sender_key, session_id, "index")
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_tracked_user (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_tracked_user (
user_id TEXT PRIMARY KEY
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_device (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_device (
user_id TEXT,
device_id TEXT,
identity_key CHAR(43) NOT NULL,
@@ -52,10 +52,10 @@ async def upgrade_blank_to_latest(conn: Connection) -> None:
deleted BOOLEAN NOT NULL,
name TEXT NOT NULL,
PRIMARY KEY (user_id, device_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_olm_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_olm_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
@@ -64,10 +64,10 @@ async def upgrade_blank_to_latest(conn: Connection) -> None:
last_decrypted timestamp NOT NULL,
last_encrypted timestamp NOT NULL,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
@@ -83,10 +83,10 @@ async def upgrade_blank_to_latest(conn: Connection) -> None:
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
account_id TEXT,
room_id TEXT,
session_id CHAR(43) NOT NULL UNIQUE,
@@ -98,10 +98,10 @@ async def upgrade_blank_to_latest(conn: Connection) -> None:
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, room_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_cross_signing_keys (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
key CHAR(43) NOT NULL,
@@ -109,18 +109,18 @@ async def upgrade_blank_to_latest(conn: Connection) -> None:
first_seen_key CHAR(43) NOT NULL,
PRIMARY KEY (user_id, usage)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_cross_signing_signatures (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
- )"""
- )
+ )
+ """)
@upgrade_table.register(description="Add account_id primary key column")
@@ -130,17 +130,17 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
await conn.execute("DROP TABLE crypto_olm_session")
await conn.execute("DROP TABLE crypto_megolm_inbound_session")
await conn.execute("DROP TABLE crypto_megolm_outbound_session")
- await conn.execute(
- """CREATE TABLE crypto_account (
+ await conn.execute("""
+ CREATE TABLE crypto_account (
account_id VARCHAR(255) PRIMARY KEY,
device_id VARCHAR(255) NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_olm_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_olm_session (
account_id VARCHAR(255),
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
@@ -148,10 +148,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_megolm_inbound_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_megolm_inbound_session (
account_id VARCHAR(255),
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
@@ -160,10 +160,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
session bytea NOT NULL,
forwarding_chains TEXT NOT NULL,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_megolm_outbound_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_megolm_outbound_session (
account_id VARCHAR(255),
room_id VARCHAR(255),
session_id CHAR(43) NOT NULL UNIQUE,
@@ -175,8 +175,8 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, room_id)
- )"""
- )
+ )
+ """)
else:
async def add_account_id_column(table: str, pkey_columns: list[str]) -> None:
@@ -233,25 +233,25 @@ async def upgrade_v4(conn: Connection, scheme: Scheme) -> None:
@upgrade_table.register(description="Add cross-signing key and signature caches")
async def upgrade_v5(conn: Connection) -> None:
- await conn.execute(
- """CREATE TABLE crypto_cross_signing_keys (
+ await conn.execute("""
+ CREATE TABLE crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
key CHAR(43),
first_seen_key CHAR(43),
PRIMARY KEY (user_id, usage)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_cross_signing_signatures (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
signature TEXT,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
- )"""
- )
+ )
+ """)
@upgrade_table.register(description="Update trust state values")
@@ -322,27 +322,25 @@ async def upgrade_v9_postgres(conn: Connection) -> None:
async def upgrade_v9_sqlite(conn: Connection) -> None:
await conn.execute("PRAGMA foreign_keys = OFF")
async with conn.transaction():
- await conn.execute(
- """CREATE TABLE new_crypto_account (
+ await conn.execute("""
+ CREATE TABLE new_crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
- )"""
- )
- await conn.execute(
- """
+ )
+ """)
+ await conn.execute("""
INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account)
SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account
FROM crypto_account
- """
- )
+ """)
await conn.execute("DROP TABLE crypto_account")
await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account")
- await conn.execute(
- """CREATE TABLE new_crypto_megolm_inbound_session (
+ await conn.execute("""
+ CREATE TABLE new_crypto_megolm_inbound_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
@@ -353,10 +351,9 @@ async def upgrade_v9_sqlite(conn: Connection) -> None:
withheld_code TEXT,
withheld_reason TEXT,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """
+ )
+ """)
+ await conn.execute("""
INSERT INTO new_crypto_megolm_inbound_session (
account_id, session_id, sender_key, signing_key, room_id, session,
forwarding_chains
@@ -364,8 +361,7 @@ async def upgrade_v9_sqlite(conn: Connection) -> None:
SELECT account_id, session_id, sender_key, signing_key, room_id, session,
forwarding_chains
FROM crypto_megolm_inbound_session
- """
- )
+ """)
await conn.execute("DROP TABLE crypto_megolm_inbound_session")
await conn.execute(
"ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session"
@@ -373,8 +369,8 @@ async def upgrade_v9_sqlite(conn: Connection) -> None:
await conn.execute("UPDATE crypto_megolm_outbound_session SET max_age=max_age*1000")
- await conn.execute(
- """CREATE TABLE new_crypto_cross_signing_keys (
+ await conn.execute("""
+ CREATE TABLE new_crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
key CHAR(43) NOT NULL,
@@ -382,41 +378,37 @@ async def upgrade_v9_sqlite(conn: Connection) -> None:
first_seen_key CHAR(43) NOT NULL,
PRIMARY KEY (user_id, usage)
- )"""
- )
- await conn.execute(
- """
+ )
+ """)
+ await conn.execute("""
INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key)
SELECT user_id, usage, key, COALESCE(first_seen_key, key)
FROM crypto_cross_signing_keys
WHERE key IS NOT NULL
- """
- )
+ """)
await conn.execute("DROP TABLE crypto_cross_signing_keys")
await conn.execute(
"ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys"
)
- await conn.execute(
- """CREATE TABLE new_crypto_cross_signing_signatures (
+ await conn.execute("""
+ CREATE TABLE new_crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
- )"""
- )
- await conn.execute(
- """
+ )
+ """)
+ await conn.execute("""
INSERT INTO new_crypto_cross_signing_signatures (
signed_user_id, signed_key, signer_user_id, signer_key, signature
)
SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature
FROM crypto_cross_signing_signatures
WHERE signature IS NOT NULL
- """
- )
+ """)
await conn.execute("DROP TABLE crypto_cross_signing_signatures")
await conn.execute(
"ALTER TABLE new_crypto_cross_signing_signatures "
diff --git a/mautrix/types/__init__.py b/mautrix/types/__init__.py
index 42b9068c..ceceeaca 100644
--- a/mautrix/types/__init__.py
+++ b/mautrix/types/__init__.py
@@ -57,6 +57,7 @@
CallRejectEventContent,
CallSelectAnswerEventContent,
CanonicalAliasStateEventContent,
+ DirectAccountDataEventContent,
EncryptedEvent,
EncryptedEventContent,
EncryptedFile,
@@ -122,6 +123,7 @@
RoomTombstoneStateEventContent,
RoomTopicStateEventContent,
RoomType,
+ SecretStorageDefaultKeyEventContent,
SingleReceiptEventContent,
SpaceChildStateEventContent,
SpaceParentStateEventContent,
@@ -259,6 +261,7 @@
"CallRejectEventContent",
"CallSelectAnswerEventContent",
"CanonicalAliasStateEventContent",
+ "DirectAccountDataEventContent",
"EncryptedEvent",
"EncryptedEventContent",
"EncryptedFile",
@@ -324,6 +327,7 @@
"RoomTombstoneStateEventContent",
"RoomTopicStateEventContent",
"RoomType",
+ "SecretStorageDefaultKeyEventContent",
"SingleReceiptEventContent",
"SpaceChildStateEventContent",
"SpaceParentStateEventContent",
@@ -354,6 +358,7 @@
"OpenGraphImage",
"OpenGraphVideo",
"BatchSendResponse",
+ "BeeperBatchSendResponse",
"DeviceLists",
"DeviceOTKCount",
"DirectoryPaginationToken",
diff --git a/mautrix/types/crypto.py b/mautrix/types/crypto.py
index 3bf7e96b..fe4ab742 100644
--- a/mautrix/types/crypto.py
+++ b/mautrix/types/crypto.py
@@ -47,9 +47,9 @@ def curve25519(self) -> Optional[IdentityKey]:
class CrossSigningUsage(ExtensibleEnum):
- MASTER = "master"
- SELF = "self_signing"
- USER = "user_signing"
+ MASTER: "CrossSigningUsage" = "master"
+ SELF: "CrossSigningUsage" = "self_signing"
+ USER: "CrossSigningUsage" = "user_signing"
@dataclass
diff --git a/mautrix/types/event/__init__.py b/mautrix/types/event/__init__.py
index b391e912..db0658db 100644
--- a/mautrix/types/event/__init__.py
+++ b/mautrix/types/event/__init__.py
@@ -6,8 +6,10 @@
from .account_data import (
AccountDataEvent,
AccountDataEventContent,
+ DirectAccountDataEventContent,
RoomTagAccountDataEventContent,
RoomTagInfo,
+ SecretStorageDefaultKeyEventContent,
)
from .base import BaseEvent, BaseRoomEvent, BaseUnsigned, GenericEvent
from .batch import BatchSendEvent, BatchSendStateEvent
diff --git a/mautrix/types/event/account_data.py b/mautrix/types/event/account_data.py
index 3c144a36..fdfa0a30 100644
--- a/mautrix/types/event/account_data.py
+++ b/mautrix/types/event/account_data.py
@@ -3,7 +3,7 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Dict, List, Union
+from typing import TYPE_CHECKING, Dict, List, Union
from attr import dataclass
import attr
@@ -12,6 +12,9 @@
from ..util import Obj, SerializableAttrs, deserializer
from .base import BaseEvent, EventType
+if TYPE_CHECKING:
+ from mautrix.crypto.ssss import EncryptedAccountDataEventContent, KeyMetadata
+
@dataclass
class RoomTagInfo(SerializableAttrs):
@@ -23,11 +26,24 @@ class RoomTagAccountDataEventContent(SerializableAttrs):
tags: Dict[str, RoomTagInfo] = attr.ib(default=None, metadata={"json": "tags"})
+@dataclass
+class SecretStorageDefaultKeyEventContent(SerializableAttrs):
+ key: str
+
+
DirectAccountDataEventContent = Dict[UserID, List[RoomID]]
-AccountDataEventContent = Union[RoomTagAccountDataEventContent, DirectAccountDataEventContent, Obj]
+AccountDataEventContent = Union[
+ RoomTagAccountDataEventContent,
+ DirectAccountDataEventContent,
+ SecretStorageDefaultKeyEventContent,
+ "EncryptedAccountDataEventContent",
+ "KeyMetadata",
+ Obj,
+]
account_data_event_content_map = {
EventType.TAG: RoomTagAccountDataEventContent,
+ EventType.SECRET_STORAGE_DEFAULT_KEY: SecretStorageDefaultKeyEventContent,
# m.direct doesn't really need deserializing
# EventType.DIRECT: DirectAccountDataEventContent,
}
diff --git a/mautrix/types/event/encrypted.py b/mautrix/types/event/encrypted.py
index cd08fc94..735e481a 100644
--- a/mautrix/types/event/encrypted.py
+++ b/mautrix/types/event/encrypted.py
@@ -9,7 +9,7 @@
from attr import dataclass
-from ..primitive import JSON, DeviceID, IdentityKey, SessionID
+from ..primitive import JSON, DeviceID, IdentityKey, SessionID, SigningKey
from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field
from .base import BaseRoomEvent, BaseUnsigned
from .message import RelatesTo
@@ -43,6 +43,18 @@ def deserialize(cls, raw: JSON) -> "KeyID":
def __str__(self) -> str:
return f"{self.algorithm.value}:{self.key_id}"
+ @classmethod
+ def ed25519(cls, key_id: SigningKey | DeviceID) -> "KeyID":
+ return cls(EncryptionKeyAlgorithm.ED25519, key_id)
+
+ @classmethod
+ def curve25519(cls, key_id: IdentityKey) -> "KeyID":
+ return cls(EncryptionKeyAlgorithm.CURVE25519, key_id)
+
+ @classmethod
+ def signed_curve25519(cls, key_id: IdentityKey) -> "KeyID":
+ return cls(EncryptionKeyAlgorithm.SIGNED_CURVE25519, key_id)
+
class OlmMsgType(Serializable, IntEnum):
PREKEY = 0
diff --git a/mautrix/types/event/message.py b/mautrix/types/event/message.py
index 8d5f1bba..32033581 100644
--- a/mautrix/types/event/message.py
+++ b/mautrix/types/event/message.py
@@ -119,7 +119,7 @@ def set_thread_parent(
thread_parent.content.get_thread_parent() or self.relates_to.event_id
)
if not disable_reply_fallback:
- self.set_reply(last_event_in_thread or thread_parent, disable_fallback=True, **kwargs)
+ self.set_reply(last_event_in_thread or thread_parent, **kwargs)
self.relates_to.is_falling_back = True
def set_edit(self, edits: Union[EventID, "MessageEvent"]) -> None:
@@ -271,33 +271,6 @@ class LocationInfo(SerializableAttrs):
# region Event content
-@dataclass
-class MediaMessageEventContent(BaseMessageEventContent, SerializableAttrs):
- """The content of a media message event (m.image, m.audio, m.video, m.file)"""
-
- url: Optional[ContentURI] = None
- info: Optional[MediaInfo] = None
- file: Optional[EncryptedFile] = None
-
- @staticmethod
- @deserializer(MediaInfo)
- @deserializer(Optional[MediaInfo])
- def deserialize_info(data: JSON) -> MediaInfo:
- if not isinstance(data, dict):
- return Obj()
- msgtype = data.pop("__mautrix_msgtype", None)
- if msgtype == "m.image" or msgtype == "m.sticker":
- return ImageInfo.deserialize(data)
- elif msgtype == "m.video":
- return VideoInfo.deserialize(data)
- elif msgtype == "m.audio":
- return AudioInfo.deserialize(data)
- elif msgtype == "m.file":
- return FileInfo.deserialize(data)
- else:
- return Obj(**data)
-
-
@dataclass
class LocationMessageEventContent(BaseMessageEventContent, SerializableAttrs):
geo_uri: str = None
@@ -314,25 +287,6 @@ class TextMessageEventContent(BaseMessageEventContent, SerializableAttrs):
format: Format = None
formatted_body: str = None
- def set_reply(
- self,
- reply_to: Union["MessageEvent", EventID],
- *,
- displayname: Optional[str] = None,
- disable_fallback: bool = False,
- ) -> None:
- super().set_reply(reply_to)
- if isinstance(reply_to, str):
- return
- if isinstance(reply_to, MessageEvent) and not disable_fallback:
- self.ensure_has_html()
- if isinstance(reply_to.content, TextMessageEventContent):
- reply_to.content.trim_reply_fallback()
- self.formatted_body = (
- reply_to.make_reply_fallback_html(displayname) + self.formatted_body
- )
- self.body = reply_to.make_reply_fallback_text(displayname) + self.body
-
def ensure_has_html(self) -> None:
if not self.formatted_body or self.format != Format.HTML:
self.format = Format.HTML
@@ -364,6 +318,34 @@ def _trim_reply_fallback_html(self) -> None:
self.formatted_body = html_reply_fallback_regex.sub("", self.formatted_body)
+@dataclass
+class MediaMessageEventContent(TextMessageEventContent, SerializableAttrs):
+ """The content of a media message event (m.image, m.audio, m.video, m.file)"""
+
+ url: Optional[ContentURI] = None
+ info: Optional[MediaInfo] = None
+ file: Optional[EncryptedFile] = None
+ filename: Optional[str] = None
+
+ @staticmethod
+ @deserializer(MediaInfo)
+ @deserializer(Optional[MediaInfo])
+ def deserialize_info(data: JSON) -> MediaInfo:
+ if not isinstance(data, dict):
+ return Obj()
+ msgtype = data.pop("__mautrix_msgtype", None)
+ if msgtype == "m.image" or msgtype == "m.sticker":
+ return ImageInfo.deserialize(data)
+ elif msgtype == "m.video":
+ return VideoInfo.deserialize(data)
+ elif msgtype == "m.audio":
+ return AudioInfo.deserialize(data)
+ elif msgtype == "m.file":
+ return FileInfo.deserialize(data)
+ else:
+ return Obj(**data)
+
+
MessageEventContent = Union[
TextMessageEventContent, MediaMessageEventContent, LocationMessageEventContent, Obj
]
@@ -423,36 +405,3 @@ def deserialize_content(data: JSON) -> MessageEventContent:
return LocationMessageEventContent.deserialize(data)
else:
return Obj(**data)
-
- def make_reply_fallback_html(self, displayname: Optional[str] = None) -> str:
- """Generate the HTML fallback for messages replying to this event."""
- if self.content.msgtype.is_text:
- body = self.content.formatted_body or escape(self.content.body).replace("\n", "
")
- else:
- sent_type = media_reply_fallback_body_map[self.content.msgtype] or "a message"
- body = f"sent {sent_type}"
- displayname = escape(displayname) if displayname else self.sender
- return html_reply_fallback_format.format(
- room_id=self.room_id,
- event_id=self.event_id,
- sender=self.sender,
- displayname=displayname,
- content=body,
- )
-
- def make_reply_fallback_text(self, displayname: Optional[str] = None) -> str:
- """Generate the plaintext fallback for messages replying to this event."""
- if self.content.msgtype.is_text:
- body = self.content.body
- else:
- try:
- body = media_reply_fallback_body_map[self.content.msgtype]
- except KeyError:
- body = "an unknown message type"
- lines = body.strip().split("\n")
- first_line, lines = lines[0], lines[1:]
- fallback_text = f"> <{displayname or self.sender}> {first_line}"
- for line in lines:
- fallback_text += f"\n> {line}"
- fallback_text += "\n\n"
- return fallback_text
diff --git a/mautrix/types/event/state.py b/mautrix/types/event/state.py
index c6b351c6..5ffc855f 100644
--- a/mautrix/types/event/state.py
+++ b/mautrix/types/event/state.py
@@ -9,7 +9,7 @@
import attr
from ..primitive import JSON, ContentURI, EventID, RoomAlias, RoomID, UserID
-from ..util import Obj, SerializableAttrs, SerializableEnum, deserializer, field
+from ..util import ExtensibleEnum, Obj, SerializableAttrs, SerializableEnum, deserializer, field
from .base import BaseRoomEvent, BaseUnsigned
from .encrypted import EncryptionAlgorithm
from .type import EventType, RoomType
@@ -41,7 +41,19 @@ class PowerLevelStateEventContent(SerializableAttrs):
ban: int = 50
redact: int = 50
- def get_user_level(self, user_id: UserID) -> int:
+ def get_user_level(
+ self,
+ user_id: UserID,
+ create: Optional["StateEvent"] = None,
+ ) -> int:
+ if (
+ create
+ and create.content.supports_creator_power
+ and (user_id == create.sender or user_id in (create.content.additional_creators or []))
+ ):
+ # This is really meant to be infinity, but involving floats would be annoying,
+ # so we use an integer larger than the maximum power level (2^53-1) instead.
+ return 2**60 - 1
return int(self.users.get(user_id, self.users_default))
def set_user_level(self, user_id: UserID, level: int) -> None:
@@ -50,7 +62,16 @@ def set_user_level(self, user_id: UserID, level: int) -> None:
else:
self.users[user_id] = level
- def ensure_user_level(self, user_id: UserID, level: int) -> bool:
+ def ensure_user_level(
+ self, user_id: UserID, level: int, create: Optional["StateEvent"] = None
+ ) -> bool:
+ if (
+ create
+ and create.content.supports_creator_power
+ and (user_id == create.sender or user_id in (create.content.additional_creators or []))
+ ):
+ # Don't try to set creator power levels
+ return False
if self.get_user_level(user_id) != level:
self.set_user_level(user_id, level)
return True
@@ -138,16 +159,15 @@ class RoomAvatarStateEventContent(SerializableAttrs):
url: Optional[ContentURI] = None
-class JoinRule(SerializableEnum):
+class JoinRule(ExtensibleEnum):
PUBLIC = "public"
KNOCK = "knock"
RESTRICTED = "restricted"
INVITE = "invite"
- PRIVATE = "private"
KNOCK_RESTRICTED = "knock_restricted"
-class JoinRestrictionType(SerializableEnum):
+class JoinRestrictionType(ExtensibleEnum):
ROOM_MEMBERSHIP = "m.room_membership"
@@ -193,6 +213,24 @@ class RoomCreateStateEventContent(SerializableAttrs):
federate: bool = field(json="m.federate", omit_default=True, default=True)
predecessor: Optional[RoomPredecessor] = None
type: Optional[RoomType] = None
+ additional_creators: Optional[List[UserID]] = None
+
+ @property
+ def supports_creator_power(self) -> bool:
+ return self.room_version not in (
+ "",
+ "1",
+ "2",
+ "3",
+ "4",
+ "5",
+ "6",
+ "7",
+ "8",
+ "9",
+ "10",
+ "11",
+ )
@dataclass
diff --git a/mautrix/types/event/type.py b/mautrix/types/event/type.py
index 509d6857..5faf3785 100644
--- a/mautrix/types/event/type.py
+++ b/mautrix/types/event/type.py
@@ -207,6 +207,11 @@ def is_to_device(self) -> bool:
"m.push_rules": "PUSH_RULES",
"m.tag": "TAG",
"m.ignored_user_list": "IGNORED_USER_LIST",
+ "m.secret_storage.default_key": "SECRET_STORAGE_DEFAULT_KEY",
+ "m.cross_signing.master": "CROSS_SIGNING_MASTER",
+ "m.cross_signing.self_signing": "CROSS_SIGNING_SELF_SIGNING",
+ "m.cross_signing.user_signing": "CROSS_SIGNING_USER_SIGNING",
+ "m.megolm_backup.v1": "MEGOLM_BACKUP_V1",
},
EventType.Class.TO_DEVICE: {
"m.room.encrypted": "TO_DEVICE_ENCRYPTED",
diff --git a/mautrix/types/event/type.pyi b/mautrix/types/event/type.pyi
index 22922288..a2788d6f 100644
--- a/mautrix/types/event/type.pyi
+++ b/mautrix/types/event/type.pyi
@@ -61,6 +61,11 @@ class EventType(Serializable):
PUSH_RULES: "EventType"
TAG: "EventType"
IGNORED_USER_LIST: "EventType"
+ SECRET_STORAGE_DEFAULT_KEY: "EventType"
+ CROSS_SIGNING_MASTER: "EventType"
+ CROSS_SIGNING_SELF_SIGNING: "EventType"
+ CROSS_SIGNING_USER_SIGNING: "EventType"
+ MEGOLM_BACKUP_V1: "EventType"
TO_DEVICE_ENCRYPTED: "EventType"
TO_DEVICE_DUMMY: "EventType"
diff --git a/mautrix/types/misc.py b/mautrix/types/misc.py
index cf576c2b..5a07699c 100644
--- a/mautrix/types/misc.py
+++ b/mautrix/types/misc.py
@@ -106,7 +106,7 @@ class RoomDirectoryResponse(SerializableAttrs):
PaginatedMessages = NamedTuple(
- "PaginatedMessages", start=SyncToken, end=SyncToken, events=List[Event]
+ "PaginatedMessages", start=SyncToken, end=Optional[SyncToken], events=List[Event]
)
diff --git a/mautrix/types/versions.py b/mautrix/types/versions.py
index 9ad81710..52a62f59 100644
--- a/mautrix/types/versions.py
+++ b/mautrix/types/versions.py
@@ -74,6 +74,10 @@ class SpecVersions:
V15 = Version.deserialize("v1.5")
V16 = Version.deserialize("v1.6")
V17 = Version.deserialize("v1.7")
+ V18 = Version.deserialize("v1.8")
+ V19 = Version.deserialize("v1.9")
+ V110 = Version.deserialize("v1.10")
+ V111 = Version.deserialize("v1.11")
@dataclass
diff --git a/mautrix/util/async_db/aiosqlite.py b/mautrix/util/async_db/aiosqlite.py
index 80fa3cf8..934379a8 100644
--- a/mautrix/util/async_db/aiosqlite.py
+++ b/mautrix/util/async_db/aiosqlite.py
@@ -7,6 +7,7 @@
from typing import Any, AsyncContextManager
from contextlib import asynccontextmanager
+from contextvars import ContextVar
import asyncio
import logging
import os
@@ -24,6 +25,9 @@
POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)")
+in_transaction = ContextVar("in_transaction", default=False)
+
+
class TxnConnection(aiosqlite.Connection):
def __init__(self, path: str, **kwargs) -> None:
def connector() -> sqlite3.Connection:
@@ -35,7 +39,11 @@ def connector() -> sqlite3.Connection:
@asynccontextmanager
async def transaction(self) -> None:
+ if in_transaction.get():
+ yield
+ return
await self.execute("BEGIN TRANSACTION")
+ token = in_transaction.set(True)
try:
yield
except Exception:
@@ -43,6 +51,8 @@ async def transaction(self) -> None:
raise
else:
await self.commit()
+ finally:
+ in_transaction.reset(token)
def __execute(self, query: str, *args: Any):
query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query)
@@ -181,7 +191,7 @@ async def stop(self) -> None:
self._conns -= 1
await conn.close()
- def acquire(self) -> AsyncContextManager[LoggingConnection]:
+ def acquire_direct(self) -> AsyncContextManager[LoggingConnection]:
if self._parent:
return self._parent.acquire()
return self._acquire()
diff --git a/mautrix/util/async_db/asyncpg.py b/mautrix/util/async_db/asyncpg.py
index 07ef7ad7..97b49f6c 100644
--- a/mautrix/util/async_db/asyncpg.py
+++ b/mautrix/util/async_db/asyncpg.py
@@ -96,7 +96,7 @@ async def _handle_exception(self, err: Exception) -> None:
sys.exit(26)
@asynccontextmanager
- async def acquire(self) -> LoggingConnection:
+ async def acquire_direct(self) -> LoggingConnection:
async with self.pool.acquire() as conn:
yield LoggingConnection(
self.scheme, conn, self.log, handle_exception=self._handle_exception
diff --git a/mautrix/util/async_db/database.py b/mautrix/util/async_db/database.py
index 0f23b02d..b5128b74 100644
--- a/mautrix/util/async_db/database.py
+++ b/mautrix/util/async_db/database.py
@@ -7,6 +7,8 @@
from typing import Any, AsyncContextManager, Type
from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
+from contextvars import ContextVar
import logging
from yarl import URL
@@ -23,6 +25,8 @@
from aiosqlite import Cursor
from asyncpg import Record
+conn_var: ContextVar[LoggingConnection | None] = ContextVar("db_connection", default=None)
+
class Database(ABC):
schemes: dict[str, Type[Database]] = {}
@@ -111,15 +115,17 @@ async def _check_foreign_tables(self) -> None:
raise ForeignTablesFound("found roomserver_rooms likely belonging to Dendrite")
async def _check_owner(self) -> None:
- await self.execute(
- """CREATE TABLE IF NOT EXISTS database_owner (
+ await self.execute("""
+ CREATE TABLE IF NOT EXISTS database_owner (
key INTEGER PRIMARY KEY DEFAULT 0,
owner TEXT NOT NULL
- )"""
- )
+ )
+ """)
owner = await self.fetchval("SELECT owner FROM database_owner WHERE key=0")
if not owner:
- await self.execute("INSERT INTO database_owner (owner) VALUES ($1)", self.owner_name)
+ await self.execute(
+ "INSERT INTO database_owner (key, owner) VALUES (0, $1)", self.owner_name
+ )
elif owner != self.owner_name:
raise DatabaseNotOwned(owner)
@@ -128,9 +134,22 @@ async def stop(self) -> None:
pass
@abstractmethod
- def acquire(self) -> AsyncContextManager[LoggingConnection]:
+ def acquire_direct(self) -> AsyncContextManager[LoggingConnection]:
pass
+ @asynccontextmanager
+ async def acquire(self) -> LoggingConnection:
+ conn = conn_var.get(None)
+ if conn is not None:
+ yield conn
+ return
+ async with self.acquire_direct() as conn:
+ token = conn_var.set(conn)
+ try:
+ yield conn
+ finally:
+ conn_var.reset(token)
+
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor:
async with self.acquire() as conn:
return await conn.execute(query, *args, timeout=timeout)
diff --git a/mautrix/util/async_db/upgrade.py b/mautrix/util/async_db/upgrade.py
index 9e593ece..c084d28b 100644
--- a/mautrix/util/async_db/upgrade.py
+++ b/mautrix/util/async_db/upgrade.py
@@ -97,11 +97,11 @@ async def _save_version(self, conn: LoggingConnection, version: int) -> None:
await conn.execute(f"INSERT INTO {self.version_table_name} (version) VALUES ($1)", version)
async def upgrade(self, db: async_db.Database) -> None:
- await db.execute(
- f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} (
+ await db.execute(f"""
+ CREATE TABLE IF NOT EXISTS {self.version_table_name} (
version INTEGER PRIMARY KEY
- )"""
- )
+ )
+ """)
row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1")
version = row["version"] if row else 0
diff --git a/mautrix/util/program.py b/mautrix/util/program.py
index 2fb6b141..74752ca5 100644
--- a/mautrix/util/program.py
+++ b/mautrix/util/program.py
@@ -98,7 +98,7 @@ def _prepare(self) -> None:
self.log.info(f"Initializing {self.name} {self.version}")
try:
- self.prepare()
+ self.loop.run_until_complete(self._async_prepare())
except Exception:
self.log.critical("Unexpected error in initialization", exc_info=True)
sys.exit(1)
@@ -117,6 +117,7 @@ def preinit(self) -> None:
self.prepare_config()
self.prepare_log()
self.check_config()
+ self.init_loop()
@property
def base_config_path(self) -> str:
@@ -168,14 +169,16 @@ def prepare_log(self) -> None:
logging.config.dictConfig(copy.deepcopy(self.config["logging"]))
self.log = cast(TraceLogger, logging.getLogger("mau.init"))
+ async def _async_prepare(self) -> None:
+ self.prepare()
+
def prepare(self) -> None:
"""
Lifecycle method where the primary program initialization happens.
Use this to fill startup_actions with async startup tasks.
"""
- self.prepare_loop()
- def prepare_loop(self) -> None:
+ def init_loop(self) -> None:
"""Init lifecycle method where the asyncio event loop is created."""
if uvloop is not None:
uvloop.install()
@@ -188,6 +191,7 @@ def start_prometheus(self) -> None:
try:
enabled = self.config["metrics.enabled"]
listen_port = self.config["metrics.listen_port"]
+ hostname = self.config.get("metrics.hostname", "0.0.0.0")
except KeyError:
return
if not enabled:
@@ -197,7 +201,7 @@ def start_prometheus(self) -> None:
"Metrics are enabled in config, but prometheus_client is not installed"
)
return
- prometheus.start_http_server(listen_port)
+ prometheus.start_http_server(listen_port, addr=hostname)
def _run(self) -> None:
signal.signal(signal.SIGINT, signal.default_int_handler)
diff --git a/mautrix/util/variation_selector.json b/mautrix/util/variation_selector.json
index 0205ce94..f27acb7b 100644
--- a/mautrix/util/variation_selector.json
+++ b/mautrix/util/variation_selector.json
@@ -31,9 +31,12 @@
"23CF": "⏏",
"23E9": "⏩",
"23EA": "⏪",
+ "23EB": "⏫",
+ "23EC": "⏬",
"23ED": "⏭",
"23EE": "⏮",
"23EF": "⏯",
+ "23F0": "⏰",
"23F1": "⏱",
"23F2": "⏲",
"23F3": "⏳",
@@ -114,6 +117,7 @@
"26C4": "⛄",
"26C5": "⛅",
"26C8": "⛈",
+ "26CE": "⛎",
"26CF": "⛏",
"26D1": "⛑",
"26D3": "⛓",
@@ -132,8 +136,11 @@
"26FA": "⛺",
"26FD": "⛽",
"2702": "✂",
+ "2705": "✅",
"2708": "✈",
"2709": "✉",
+ "270A": "✊",
+ "270B": "✋",
"270C": "✌",
"270D": "✍",
"270F": "✏",
@@ -142,15 +149,25 @@
"2716": "✖",
"271D": "✝",
"2721": "✡",
+ "2728": "✨",
"2733": "✳",
"2734": "✴",
"2744": "❄",
"2747": "❇",
+ "274C": "❌",
+ "274E": "❎",
"2753": "❓",
+ "2754": "❔",
+ "2755": "❕",
"2757": "❗",
"2763": "❣",
"2764": "❤",
+ "2795": "➕",
+ "2796": "➖",
+ "2797": "➗",
"27A1": "➡",
+ "27B0": "➰",
+ "27BF": "➿",
"2934": "⤴",
"2935": "⤵",
"2B05": "⬅",
diff --git a/mautrix/util/variation_selector.py b/mautrix/util/variation_selector.py
index 498f1abe..e1430293 100644
--- a/mautrix/util/variation_selector.py
+++ b/mautrix/util/variation_selector.py
@@ -10,7 +10,7 @@
import aiohttp
-EMOJI_VAR_URL = "https://www.unicode.org/Public/14.0.0/ucd/emoji/emoji-variation-sequences.txt"
+EMOJI_VAR_URL = "https://www.unicode.org/Public/17.0.0/ucd/emoji/emoji-variation-sequences.txt"
def read_data() -> dict[str, str]:
@@ -43,11 +43,12 @@ async def fetch_data() -> dict[str, str]:
if __name__ == "__main__":
import asyncio
+ import importlib.resources
+ import pathlib
import sys
- import pkg_resources
-
- path = pkg_resources.resource_filename("mautrix.util", "variation_selector.json")
+ path = importlib.resources.files("mautrix.util").joinpath("variation_selector.json")
+ assert isinstance(path, pathlib.Path)
emojis = asyncio.run(fetch_data())
with open(path, "w") as file:
json.dump(emojis, file, indent=" ", ensure_ascii=False)
@@ -59,11 +60,11 @@ async def fetch_data() -> dict[str, str]:
ADD_VARIATION_TRANSLATION = str.maketrans(
{ord(emoji): f"{emoji}{VARIATION_SELECTOR_16}" for emoji in read_data().values()}
)
-SKIN_TONE_MODIFIERS = ("\U0001F3FB", "\U0001F3FC", "\U0001F3FD", "\U0001F3FE", "\U0001F3FF")
+SKIN_TONE_MODIFIERS = ("\U0001f3fb", "\U0001f3fc", "\U0001f3fd", "\U0001f3fe", "\U0001f3ff")
SKIN_TONE_REPLACEMENTS = {f"{VARIATION_SELECTOR_16}{mod}": mod for mod in SKIN_TONE_MODIFIERS}
VARIATION_SELECTOR_REPLACEMENTS = {
**SKIN_TONE_REPLACEMENTS,
- "\U0001F408\ufe0f\u200d\u2b1b\ufe0f": "\U0001F408\u200d\u2b1b",
+ "\U0001f408\ufe0f\u200d\u2b1b\ufe0f": "\U0001f408\u200d\u2b1b",
}
diff --git a/optional-requirements.txt b/optional-requirements.txt
index a5660f4a..a6e0227d 100644
--- a/optional-requirements.txt
+++ b/optional-requirements.txt
@@ -11,3 +11,4 @@ uvloop
python-olm
unpaddedbase64
pycryptodome
+base58
diff --git a/setup.py b/setup.py
index b01c50a1..23ac1cb1 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
from mautrix import __version__
-encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome"]
+encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome", "base58"]
test_dependencies = ["aiosqlite", "asyncpg", "ruamel.yaml", *encryption_dependencies]
setuptools.setup(
@@ -28,7 +28,7 @@
],
extras_require={
"detect_mimetype": ["python-magic>=0.4.15,<0.5"],
- "lint": ["black~=24.1", "isort"],
+ "lint": ["black~=25.1", "isort"],
"test": ["pytest", "pytest-asyncio", *test_dependencies],
"encryption": encryption_dependencies,
},
@@ -45,6 +45,8 @@
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
],
package_data={