diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 81c07712..3a845672 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -4,17 +4,16 @@ on: [push, pull_request]
jobs:
build:
-
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install libolm
@@ -47,17 +46,17 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
- - uses: actions/setup-python@v3
+ - uses: actions/checkout@v6
+ - uses: actions/setup-python@v6
with:
- python-version: "3.10"
+ python-version: "3.14"
- uses: isort/isort-action@master
with:
sortPaths: "./mautrix"
- uses: psf/black@stable
with:
src: "./mautrix"
- version: "22.3.0"
+ version: "26.3.1"
- name: pre-commit
run: |
pip install pre-commit
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 8f3efe97..b0d7ab3f 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,6 +1,6 @@
build docs builder:
stage: build
- image: docker:stable
+ image: docker:latest
tags:
- amd64
only:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 56919be3..66065033 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.1.0
+ rev: v6.0.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
@@ -8,13 +8,13 @@ repos:
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
- rev: 22.3.0
+ rev: 26.3.1
hooks:
- id: black
language_version: python3
files: ^mautrix/.*\.pyi?$
- repo: https://github.com/PyCQA/isort
- rev: 5.10.1
+ rev: 8.0.1
hooks:
- id: isort
files: ^mautrix/.*\.pyi?$
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 70c15c8f..c7179c86 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,11 +1,223 @@
-## unreleased
+## v0.21.0 (2025-11-17)
+
+* *(event)* Added support for creator power in room v12+.
+* *(crypto)* Added support for generating and using recovery keys for verifying
+ the active device.
+* *(bridge)* Added config option for self-signing bot device.
+* *(bridge)* Removed check for login flows when using MSC4190
+ (thanks to [@meson800] in [#178]).
+* *(client)* Changed `set_displayname` and `set_avatar_url` to avoid setting
+ empty strings if the value is already unset (thanks to [@frebib] in [#171]).
+
+[@frebib]: https://github.com/frebib
+[@meson800]: https://github.com/meson800
+[#171]: https://github.com/mautrix/python/pull/171
+[#178]: https://github.com/mautrix/python/pull/178
+
+## v0.20.8 (2025-06-01)
+
+* *(bridge)* Added support for [MSC4190] (thanks to [@surakin] in [#175]).
+* *(appservice)* Renamed `push_ephemeral` in generated registrations to
+ `receive_ephemeral` to match the accepted version of [MSC2409].
+* *(bridge)* Fixed compatibility with breaking change in aiohttp 3.12.6.
+[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190
+[@surakin]: https://github.com/surakin
+[#175]: https://github.com/mautrix/python/pull/175
+
+## v0.20.7 (2025-01-03)
+
+* *(types)* Removed support for generating reply fallbacks to implement
+ [MSC2781]. Stripping fallbacks is still supported.
+
+[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781
+
+## v0.20.6 (2024-07-12)
+
+* *(bridge)* Added `/register` call if `/versions` fails with `M_FORBIDDEN`.
+
+## v0.20.5 (2024-07-09)
+
+**Note:** The `bridge` module is deprecated as all bridges are being rewritten
+in Go. See the announcement on the mau.fi blog for more info.
+
+* *(client)* Added support for authenticated media downloads.
+* *(bridge)* Stopped using cached homeserver URLs for double puppeting if one
+ is set in the config file.
+* *(crypto)* Fixed error when checking OTK counts before uploading new keys.
+* *(types)* Added MSC2530 (captions) fields to `MediaMessageEventContent`.
+
+## v0.20.4 (2024-01-09)
+
+* Dropped Python 3.9 support.
+* *(client)* Changed media download methods to log requests and to raise
+ exceptions on non-successful status codes.
+
+## v0.20.3 (2023-11-10)
+
+* *(client)* Deprecated MSC2716 methods and added new Beeper-specific batch
+ send methods, as upstream MSC2716 support has been abandoned.
+* *(util.async_db)* Added `PRAGMA synchronous = NORMAL;` to default pragmas.
+* *(types)* Fixed `guest_can_join` field name in room directory response
+ (thanks to [@ashfame] in [#163]).
+
+[@ashfame]: https://github.com/ashfame
+[#163]: https://github.com/mautrix/python/pull/163
+
+## v0.20.2 (2023-09-09)
+
+* *(crypto)* Changed `OlmMachine.share_keys` to make the OTK count parameter
+ optional. When omitted, the count is fetched from the server.
+* *(appservice)* Added option to run appservice transaction event handlers
+ synchronously.
+* *(appservice)* Added `log` and `hs_token` parameters to `AppServiceServerMixin`
+ to allow using it as a standalone class without extending.
+* *(api)* Added support for setting appservice `user_id` and `device_id` query
+ parameters manually without using `AppServiceAPI`.
+
+## v0.20.1 (2023-08-29)
+
+* *(util.program)* Removed `--base-config` flag in bridges, as there are no
+ valid use cases (package data should always work) and it's easy to cause
+ issues by pointing the flag at the wrong file.
+* *(bridge)* Added support for the `com.devture.shared_secret_auth` login type
+ for automatic double puppeting.
+* *(bridge)* Dropped support for syncing with double puppets. MSC2409 is now
+ the only way to receive ephemeral events.
+* *(bridge)* Added support for double puppeting with arbitrary `as_token`s.
+
+## v0.20.0 (2023-06-25)
+
+* Dropped Python 3.8 support.
+* **Breaking change *(.state_store)*** Removed legacy SQLAlchemy state store
+ implementations.
+* **Mildly breaking change *(util.async_db)*** Changed `SQLiteDatabase` to not
+ remove prefix slashes from database paths.
+ * Library users should use `sqlite:path.db` instead of `sqlite:///path.db`
+ for relative paths, and `sqlite:/path.db` instead of `sqlite:////path.db`
+ for absolute paths.
+ * Bridge configs do this migration automatically.
+* *(util.async_db)* Added warning log if using SQLite database path that isn't
+ writable.
+* *(util.program)* Fixed `manual_stop` not working if it's called during startup.
+* *(client)* Stabilized support for asynchronous uploads.
+ * `unstable_create_msc` was renamed to `create_mxc`, and the `max_stall_ms`
+ parameters for downloading were renamed to `timeout_ms`.
+* *(crypto)* Added option to not rotate keys when devices change.
+* *(crypto)* Added option to remove all keys that were received before the
+ automatic ratcheting was implemented (in v0.19.10).
+* *(types)* Improved reply fallback removal to have a smaller chance of false
+ positives for messages that don't use reply fallbacks.
+
+## v0.19.16 (2023-05-26)
+
+* *(appservice)* Fixed Python 3.8 compatibility.
+
+## v0.19.15 (2023-05-24)
+
+* *(client)* Fixed dispatching room ephemeral events (i.e. typing notifications) in syncer.
+
+## v0.19.14 (2023-05-16)
+
+* *(bridge)* Implemented appservice pinging using MSC2659.
+* *(bridge)* Started reusing aiosqlite connection pool for crypto db.
+ * This fixes the crypto pool getting stuck if the bridge exits unexpectedly
+ (the default pool is closed automatically at any type of exit).
+
+## v0.19.13 (2023-04-24)
+
+* *(crypto)* Fixed bug with redacting megolm sessions when device is deleted.
+
+## v0.19.12 (2023-04-18)
+
+* *(bridge)* Fixed backwards-compatibility with new key deletion config options.
+
+## v0.19.11 (2023-04-14)
+
+* *(crypto)* Fixed bug in previous release which caused errors if the `max_age`
+ of a megolm session was not known.
+* *(crypto)* Changed key receiving handler to fetch encryption config from
+ server if it's not cached locally (to find `max_age` and `max_messages` more
+ reliably).
+
+## v0.19.10 (2023-04-13)
+
+* *(crypto, bridge)* Added options to automatically ratchet/delete megolm
+ sessions to minimize access to old messages.
+
+## v0.19.9 (2023-04-12)
+
+* *(crypto)* Fixed bug in crypto store migration when using outbound sessions
+ with max age higher than usual.
+
+## v0.19.8 (2023-04-06)
+
+* *(crypto)* Updated crypto store schema to match mautrix-go.
+* *(types)* Fixed `set_thread_parent` adding reply fallbacks to the message body.
+
+## v0.19.7 (2023-03-22)
+
+* *(bridge, crypto)* Fixed key sharing trust checker not resolving cross-signing
+ signatures when minimum trust level is set to cross-signed.
+
+## v0.19.6 (2023-03-13)
+
+* *(crypto)* Added cache checks to prevent invalidating group session when the
+ server sends a duplicate member event in /sync.
+* *(util.proxy)* Fixed `min_wait_seconds` behavior and added `max_wait_seconds`
+ and `multiply_wait_seconds` to `proxy_with_retry`.
+
+## v0.19.5 (2023-03-07)
+
+* *(util.proxy)* Added utility for dynamic proxies (from mautrix-instagram/facebook).
+* *(types)* Added default value for `upload_size` in `MediaRepoConfig` as the
+ field is optional in the spec.
+* *(bridge)* Changed ghost invite handling to only process one per room at a time
+ (thanks to [@maltee1] in [#132]).
+
+[#132]: https://github.com/mautrix/python/pull/132
+
+## v0.19.4 (2023-02-12)
+
+* *(types)* Changed `set_thread_parent` to inherit the existing thread parent
+ if a `MessageEvent` is passed, as starting threads from a message in a thread
+ is not allowed.
+* *(util.background_task)* Added new utility for creating background tasks
+ safely, by ensuring that the task is not garbage collected before finishing
+ and logging uncaught exceptions immediately.
+
+## v0.19.3 (2023-01-27)
+
+* *(bridge)* Bumped default timeouts for decrypting incoming messages.
+
+## v0.19.2 (2023-01-14)
+
+* *(util.async_body)* Added utility for reading aiohttp response into a bytearray
+ (so that the output is mutable, e.g. for decrypting or encrypting media).
+* *(client.api)* Fixed retry loop for MSC3870 URL uploads not exiting properly
+ after too many errors.
+
+## v0.19.1 (2023-01-11)
+
+* Marked Python 3.11 as supported. Python 3.8 support will likely be dropped in
+ the coming months.
+* *(client.api)* Added request payload memory optimization to MSC3870 URL uploads.
+ * aiohttp will duplicate the entire request body if it's raw bytes, which
+ wastes a lot of memory. The optimization is passing an iterator instead of
+ raw bytes, so aiohttp won't accidentally duplicate the whole thing.
+ * The main `HTTPAPI` has had the optimization for a while, but uploading to
+ URL calls aiohttp manually.
+
+## v0.19.0 (2023-01-10)
+
+* **Breaking change *(appservice)*** Removed typing status from state store.
+* **Breaking change *(appservice)*** Removed `is_typing` parameter from
+ `IntentAPI.set_typing` to make the signature match `ClientAPI.set_typing`.
+ `timeout=0` is equivalent to the old `is_typing=False`.
+* **Breaking change *(types)*** Removed legacy fields in Beeper MSS events.
* *(bridge)* Removed accidentally nested reply loop when accepting invites as
the bridge bot.
* *(bridge)* Fixed decoding JSON values in config override env vars.
-* *(appservice)* Removed typing status from state store.
- * Additionally, the `is_typing` boolean in `set_typing` is now deprecated,
- and `timeout=0` should be used instead to match the `ClientAPI` behavior.
## v0.18.9 (2022-12-14)
diff --git a/README.rst b/README.rst
index 75ce8f6a..f1342c19 100644
--- a/README.rst
+++ b/README.rst
@@ -3,7 +3,7 @@ mautrix-python
|PyPI| |Python versions| |License| |Docs| |Code style| |Imports|
-A Python 3.8+ asyncio Matrix framework.
+A Python 3.10+ asyncio Matrix framework.
Matrix room: `#maunium:maunium.net`_
@@ -49,7 +49,7 @@ Components
.. _#maunium:maunium.net: https://matrix.to/#/#maunium:maunium.net
.. _python-appservice-framework: https://github.com/Cadair/python-appservice-framework/
-.. _Client API: https://matrix.org/docs/spec/client_server/r0.6.1.html
+.. _Client API: https://spec.matrix.org/latest/client-server-api/
.. _mautrix.api: https://docs.mau.fi/python/latest/api/mautrix.api.html
.. _mautrix.client.api: https://docs.mau.fi/python/latest/api/mautrix.client.api.html
diff --git a/dev-requirements.txt b/dev-requirements.txt
index e513c0df..dff6ee9d 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,3 +1,3 @@
pre-commit>=2.10.1,<3
-isort>=5.10.1,<6
-black>=22.3,<23
+isort>=8,<9
+black>=26,<27
diff --git a/docs/api/mautrix.client.state_store/index.rst b/docs/api/mautrix.client.state_store/index.rst
index 91f2ba42..0a934308 100644
--- a/docs/api/mautrix.client.state_store/index.rst
+++ b/docs/api/mautrix.client.state_store/index.rst
@@ -13,5 +13,4 @@ Implementations
In-memory
Async database (asyncpg/aiosqlite)
- Legacy database (SQLAlchemy)
Flat file
diff --git a/docs/api/mautrix.client.state_store/sqlalchemy.rst b/docs/api/mautrix.client.state_store/sqlalchemy.rst
deleted file mode 100644
index 767e0595..00000000
--- a/docs/api/mautrix.client.state_store/sqlalchemy.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-mautrix.client.state\_store.sqlalchemy
-======================================
-
-.. autoclass:: mautrix.client.state_store.sqlalchemy.SQLStateStore
- :no-undoc-members:
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 798e0f1b..7269d3a5 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -9,7 +9,7 @@ yarl
# that aren't used for anything that's in the docs
python-magic
ruamel.yaml
-SQLAlchemy
+SQLAlchemy<2
commonmark
asyncpg
aiosqlite
@@ -17,3 +17,4 @@ prometheus_client
python-olm
unpaddedbase64
pycryptodome
+base58
diff --git a/mautrix/__init__.py b/mautrix/__init__.py
index ae628d89..c94995f9 100644
--- a/mautrix/__init__.py
+++ b/mautrix/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.18.9"
+__version__ = "0.21.0"
__author__ = "Tulir Asokan <tulir@maunium.net>"
__all__ = [
"api",
diff --git a/mautrix/api.py b/mautrix/api.py
index a0561d08..1adde9ec 100644
--- a/mautrix/api.py
+++ b/mautrix/api.py
@@ -5,7 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import AsyncGenerator, ClassVar, Literal, Mapping, Union
+from typing import ClassVar, Literal, Mapping
from enum import Enum
from json.decoder import JSONDecodeError
from urllib.parse import quote as urllib_quote, urljoin as urllib_join
@@ -22,12 +22,13 @@
from mautrix import __optional_imports__, __version__ as mautrix_version
from mautrix.errors import MatrixConnectionError, MatrixRequestError, make_request_error
+from mautrix.util.async_body import AsyncBody, async_iter_bytes
from mautrix.util.logging import TraceLogger
from mautrix.util.opt_prometheus import Counter
if __optional_imports__:
# Safe to import, but it's not actually needed, so don't force-import the whole types module.
- from mautrix.types import JSON
+ from mautrix.types import JSON, DeviceID, UserID
API_CALLS = Counter(
name="bridge_matrix_api_calls",
@@ -155,7 +156,6 @@ def replace(self, find: str, replace: str) -> PathBuilder:
"""
_req_id = 0
-AsyncBody = AsyncGenerator[Union[bytes, bytearray, memoryview], None]
def _next_global_req_id() -> int:
@@ -164,12 +164,6 @@ def _next_global_req_id() -> int:
return _req_id
-async def _async_iter_bytes(data: bytearray | bytes, chunk_size: int = 1024**2) -> AsyncBody:
- with memoryview(data) as mv:
- for i in range(0, len(data), chunk_size):
- yield mv[i : i + chunk_size]
-
-
class HTTPAPI:
"""HTTPAPI is a simple asyncio Matrix API request sender."""
@@ -199,6 +193,13 @@ class HTTPAPI:
default_retry_count: int
"""The default retry count to use if a custom value is not passed to :meth:`request`"""
+ as_user_id: UserID | None
+ """An optional user ID to set as the user_id query parameter for appservice requests."""
+ as_device_id: DeviceID | None
+ """
An optional device ID to set as the device_id query parameter for appservice requests (MSC3202).
+ """
+
def __init__(
self,
base_url: URL | str,
@@ -209,6 +210,8 @@ def __init__(
txn_id: int = 0,
log: TraceLogger | None = None,
loop: asyncio.AbstractEventLoop | None = None,
+ as_user_id: UserID | None = None,
+ as_device_id: DeviceID | None = None,
) -> None:
"""
Args:
@@ -218,6 +221,10 @@ def __init__(
txn_id: The outgoing transaction ID to start with.
log: The :class:`logging.Logger` instance to log requests with.
default_retry_count: Default number of retries to do when encountering network errors.
+ as_user_id: An optional user ID to set as the user_id query parameter for
+ appservice requests.
+ as_device_id: An optional device ID to set as the device_id query parameter for
+ appservice requests (MSC3202).
"""
self.base_url = URL(base_url)
self.token = token
@@ -225,6 +232,8 @@ def __init__(
self.session = client_session or ClientSession(
loop=loop, headers={"User-Agent": self.default_ua}
)
+ self.as_user_id = as_user_id
+ self.as_device_id = as_device_id
if txn_id is not None:
self.txn_id = txn_id
if default_retry_count is not None:
@@ -266,7 +275,7 @@ def _log_request(
self,
method: Method,
url: URL,
- content: str | bytes | bytearray | AsyncBody,
+ content: str | bytes | bytearray | AsyncBody | None,
orig_content,
query_params: dict[str, str],
headers: dict[str, str],
@@ -305,7 +314,7 @@ def _log_request(
)
def _log_request_done(
- self, path: PathBuilder, req_id: int, duration: float, status: int
+ self, path: PathBuilder | str, req_id: int, duration: float, status: int
) -> None:
level = 5 if path == Path.v3.sync else 10
duration_str = f"{duration * 1000:.1f}ms" if duration < 1 else f"{duration:.3f}s"
@@ -325,6 +334,16 @@ def _full_path(self, path: PathBuilder | str) -> str:
base_path += "/"
return urllib_join(base_path, path)
+ def log_download_request(self, url: URL, query_params: dict[str, str]) -> int:
+ req_id = _next_global_req_id()
+ self._log_request(Method.GET, url, None, None, query_params, {}, req_id, False)
+ return req_id
+
+ def log_download_request_done(
+ self, url: URL, req_id: int, duration: float, status: int
+ ) -> None:
+ self._log_request_done(url.path.removeprefix("/_matrix/media/"), req_id, duration, status)
+
async def request(
self,
method: Method,
@@ -366,6 +385,11 @@ async def request(
query_params = query_params or {}
if isinstance(query_params, dict):
query_params = {k: v for k, v in query_params.items() if v is not None}
+ if self.as_user_id:
+ query_params["user_id"] = self.as_user_id
+ if self.as_device_id:
+ query_params["org.matrix.msc3202.device_id"] = self.as_device_id
+ query_params["device_id"] = self.as_device_id
if method != Method.GET:
content = content or {}
@@ -395,7 +419,7 @@ async def request(
method, log_url, content, orig_content, query_params, headers, req_id, sensitive
)
API_CALLS.labels(method=metrics_method).inc()
- req_content = _async_iter_bytes(content) if do_fake_iter else content
+ req_content = async_iter_bytes(content) if do_fake_iter else content
start = time.monotonic()
try:
resp_data, resp = await self._send(
@@ -438,6 +462,7 @@ def get_download_url(
mxc_uri: str,
download_type: Literal["download", "thumbnail"] = "download",
file_name: str | None = None,
+ authenticated: bool = False,
) -> URL:
"""
Get the full HTTP URL to download a ``mxc://`` URI.
@@ -446,6 +471,7 @@ def get_download_url(
mxc_uri: The MXC URI whose full URL to get.
download_type: The type of download ("download" or "thumbnail").
file_name: Optionally, a file name to include in the download URL.
+ authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11.
Returns:
The full HTTP URL.
@@ -461,7 +487,11 @@ def get_download_url(
"https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png"
"""
server_name, media_id = self.parse_mxc_uri(mxc_uri)
- url = self.base_url / str(APIPath.MEDIA) / "v3" / download_type / server_name / media_id
+ if authenticated:
+ url = self.base_url / str(APIPath.CLIENT) / "v1" / "media"
+ else:
+ url = self.base_url / str(APIPath.MEDIA) / "v3"
+ url = url / download_type / server_name / media_id
if file_name:
url /= file_name
return url
diff --git a/mautrix/appservice/api/appservice.py b/mautrix/appservice/api/appservice.py
index af6ae48d..6654ec20 100644
--- a/mautrix/appservice/api/appservice.py
+++ b/mautrix/appservice/api/appservice.py
@@ -51,6 +51,7 @@ def __init__(
client_session: ClientSession = None,
child: bool = False,
real_user: bool = False,
+ real_user_as_token: bool = False,
bridge_name: str | None = None,
default_retry_count: int = None,
loop: asyncio.AbstractEventLoop | None = None,
@@ -66,6 +67,7 @@ def __init__(
client_session: The aiohttp ClientSession to use.
child: Whether or not this is instance is a child of another AppServiceAPI.
real_user: Whether or not this is a real (non-appservice-managed) user.
+ real_user_as_token: Whether this real user is actually using another ``as_token``.
bridge_name: The name of the bridge to put in the ``fi.mau.double_puppet_source`` field
in outgoing message events sent through real users.
"""
@@ -85,6 +87,7 @@ def __init__(
self._bot_intent = None
self.state_store = state_store
self.is_real_user = real_user
+ self.is_real_user_as_token = real_user_as_token
self.bridge_name = bridge_name
if not child:
@@ -113,7 +116,9 @@ def user(self, user: UserID) -> ChildAppServiceAPI:
self.children[user] = child
return child
- def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> AppServiceAPI:
+ def real_user(
+ self, mxid: UserID, token: str, base_url: URL | None = None, as_token: bool = False
+ ) -> AppServiceAPI:
"""
Get the AppServiceAPI for a real (non-appservice-managed) Matrix user.
@@ -122,6 +127,8 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap
token: The access token for the user.
base_url: The base URL of the homeserver client-server API to use. Defaults to the
appservice homeserver URL.
+ as_token: Whether the token is actually an as_token
+ (meaning the ``user_id`` query parameter needs to be used).
Returns:
The AppServiceAPI object for the user.
@@ -136,6 +143,7 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap
child = self.real_users[mxid]
child.base_url = base_url or child.base_url
child.token = token or child.token
+ child.is_real_user_as_token = as_token
except KeyError:
child = type(self)(
base_url=base_url or self.base_url,
@@ -145,6 +153,7 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap
state_store=self.state_store,
client_session=self.session,
real_user=True,
+ real_user_as_token=as_token,
bridge_name=self.bridge_name,
default_retry_count=self.default_retry_count,
)
@@ -163,7 +172,11 @@ def bot_intent(self) -> as_api.IntentAPI:
return self._bot_intent
def intent(
- self, user: UserID = None, token: str | None = None, base_url: str | None = None
+ self,
+ user: UserID = None,
+ token: str | None = None,
+ base_url: str | None = None,
+ real_user_as_token: bool = False,
) -> as_api.IntentAPI:
"""
Get the intent API of a child user.
@@ -173,6 +186,8 @@ def intent(
token: The access token to use. Only applicable for non-appservice-managed users.
base_url: The base URL of the homeserver client-server API to use. Only applicable for
non-appservice users. Defaults to the appservice homeserver URL.
+ real_user_as_token: When providing a token, whether it's actually another as_token
+ (meaning the ``user_id`` query parameter needs to be used).
Returns:
The IntentAPI object for the given user.
@@ -184,7 +199,10 @@ def intent(
raise ValueError("Can't get child intent of real user")
if token:
return as_api.IntentAPI(
- user, self.real_user(user, token, base_url), self.bot_intent(), self.state_store
+ user,
+ self.real_user(user, token, base_url, as_token=real_user_as_token),
+ self.bot_intent(),
+ self.state_store,
)
return as_api.IntentAPI(user, self.user(user), self.bot_intent(), self.state_store)
@@ -229,7 +247,7 @@ def request(
if isinstance(timestamp, datetime):
timestamp = int(timestamp.replace(tzinfo=timezone.utc).timestamp() * 1000)
query_params["ts"] = timestamp
- if not self.is_real_user:
+ if not self.is_real_user or self.is_real_user_as_token:
query_params["user_id"] = self.identity or self.bot_mxid
return super().request(
diff --git a/mautrix/appservice/api/intent.py b/mautrix/appservice/api/intent.py
index 92966313..626b34f1 100644
--- a/mautrix/appservice/api/intent.py
+++ b/mautrix/appservice/api/intent.py
@@ -25,6 +25,7 @@
BatchSendEvent,
BatchSendResponse,
BatchSendStateEvent,
+ BeeperBatchSendResponse,
ContentURI,
EventContent,
EventID,
@@ -40,7 +41,6 @@
RoomNameStateEventContent,
RoomPinnedEventsStateEventContent,
RoomTopicStateEventContent,
- SerializableAttrs,
StateEventContent,
UserID,
)
@@ -71,7 +71,8 @@ def quote(*args, **kwargs):
ClientAPI.search_users,
ClientAPI.set_displayname,
ClientAPI.set_avatar_url,
- ClientAPI.unstable_create_mxc,
+ ClientAPI.beeper_update_profile,
+ ClientAPI.create_mxc,
ClientAPI.upload_media,
ClientAPI.send_receipt,
ClientAPI.set_fully_read_marker,
@@ -117,6 +118,8 @@ def __init__(
) -> None:
super().__init__(mxid=mxid, api=api, state_store=state_store)
self.bot = bot
+ if bot is not None:
+ self.versions_cache = bot.versions_cache
self.log = api.base_log.getChild("intent")
for method in ENSURE_REGISTERED_METHODS:
@@ -143,7 +146,11 @@ async def wrapper(*args, __self=self, __method=method, **kwargs):
setattr(self, method.__name__, wrapper)
def user(
- self, user_id: UserID, token: str | None = None, base_url: str | None = None
+ self,
+ user_id: UserID,
+ token: str | None = None,
+ base_url: str | None = None,
+ as_token: bool = False,
) -> IntentAPI:
"""
Get the intent API for a specific user.
@@ -156,15 +163,17 @@ def user(
user_id: The Matrix ID of the user whose intent API to get.
token: The access token to use for the Matrix ID.
base_url: An optional URL to use for API requests.
+ as_token: Whether the provided token is actually another as_token
+ (meaning the ``user_id`` query parameter needs to be used).
Returns:
The IntentAPI for the given user.
"""
if not self.bot:
- return self.api.intent(user_id, token, base_url)
+ return self.api.intent(user_id, token, base_url, real_user_as_token=as_token)
else:
self.log.warning("Called IntentAPI#user() of child intent object.")
- return self.bot.api.intent(user_id, token, base_url)
+ return self.bot.api.intent(user_id, token, base_url, real_user_as_token=as_token)
# region User actions
@@ -182,7 +191,7 @@ async def set_presence(
Args:
presence: The online status of the user.
status: The status message.
- ignore_cache: Whether or not to set presence even if the cache says the presence is
+ ignore_cache: Whether to set presence even if the cache says the presence is
already set to that value.
"""
await self.ensure_registered()
@@ -248,7 +257,9 @@ async def invite_user(
await self.state_store.joined(room_id, user_id)
except MatrixRequestError as e:
# TODO remove this once MSC3848 is released and minimum spec version is bumped
- if e.errcode == "M_FORBIDDEN" and "is already in the room" in e.message:
+ if e.errcode == "M_FORBIDDEN" and (
+ "already in the room" in e.message or "is already joined to room" in e.message
+ ):
await self.state_store.joined(room_id, user_id)
else:
raise
@@ -405,19 +416,10 @@ async def get_room_member_info(
async def set_typing(
self,
room_id: RoomID,
- is_typing: bool = True,
- timeout: int = 5000,
+ timeout: int = 0,
) -> None:
- """
- Args:
- room_id: The ID of the room in which the user is typing.
- is_typing: Whether the user is typing.
- .. deprecated:: 0.18.10
- Use ``timeout=0`` instead of setting this flag.
- timeout: The length of time in seconds to mark this user as typing.
- """
await self.ensure_joined(room_id)
- await super().set_typing(room_id, timeout if is_typing else 0)
+ await super().set_typing(room_id, timeout)
async def error_and_leave(
self, room_id: RoomID, text: str | None = None, html: str | None = None
@@ -496,6 +498,14 @@ async def mark_read(
)
self.state_store.set_read(room_id, self.mxid, event_id)
+ async def appservice_ping(self, appservice_id: str, txn_id: str | None = None) -> int:
+ resp = await self.api.request(
+ Method.POST,
+ Path.v1.appservice[appservice_id].ping,
+ content={"transaction_id": txn_id} if txn_id is not None else {},
+ )
+ return resp.get("duration_ms") or -1
+
async def batch_send(
self,
room_id: RoomID,
@@ -514,6 +524,9 @@ async def batch_send(
.. versionadded:: v0.12.5
+ .. deprecated:: v0.20.3
+ MSC2716 was abandoned by upstream and Beeper has forked the endpoint.
+
Args:
room_id: The room ID to send the events to.
prev_event_id: The anchor event. The batch will be inserted immediately after this event.
@@ -548,6 +561,52 @@ async def batch_send(
)
return BatchSendResponse.deserialize(resp)
+ async def beeper_batch_send(
+ self,
+ room_id: RoomID,
+ events: Iterable[BatchSendEvent],
+ *,
+ forward: bool = False,
+ forward_if_no_messages: bool = False,
+ send_notification: bool = False,
+ mark_read_by: UserID | None = None,
+ ) -> BeeperBatchSendResponse:
+ """
+ Send a batch of events into a room. Only for Beeper/hungryserv.
+
+ .. versionadded:: v0.20.3
+
+ Args:
+ room_id: The room ID to send the events to.
+ events: The events to send.
+ forward: Send events to the end of the room instead of the beginning
+ forward_if_no_messages: Send events to the end of the room, but only if there are no
+ messages in the room. If there are messages, send the new messages to the beginning.
+ send_notification: Send a push notification for the new messages.
+ Only applies when sending to the end of the room.
+ mark_read_by: Send a read receipt from the given user ID atomically.
+
+ Returns:
+ All the event IDs generated.
+ """
+ body = {
+ "events": [evt.serialize() for evt in events],
+ }
+ if forward:
+ body["forward"] = forward
+ elif forward_if_no_messages:
+ body["forward_if_no_messages"] = forward_if_no_messages
+ if send_notification:
+ body["send_notification"] = send_notification
+ if mark_read_by:
+ body["mark_read_by"] = mark_read_by
+ resp = await self.api.request(
+ Method.POST,
+ Path.unstable["com.beeper.backfill"].rooms[room_id].batch_send,
+ content=body,
+ )
+ return BeeperBatchSendResponse.deserialize(resp)
+
async def beeper_delete_room(self, room_id: RoomID) -> None:
versions = await self.versions()
if not versions.supports("com.beeper.room_yeeting"):
@@ -651,6 +710,8 @@ async def _ensure_has_power_level_for(
if not await self.state_store.has_power_levels_cached(room_id):
# TODO add option to not try to fetch power levels from server
await self.get_power_levels(room_id, ignore_cache=True, ensure_joined=False)
+ if not await self.state_store.has_create_cached(room_id):
+ await self.get_state_event(room_id, EventType.ROOM_CREATE, format="event")
if not await self.state_store.has_power_level(room_id, self.mxid, event_type):
# TODO implement something better
raise IntentError(
diff --git a/mautrix/appservice/appservice.py b/mautrix/appservice/appservice.py
index 551eab5c..65e202ae 100644
--- a/mautrix/appservice/appservice.py
+++ b/mautrix/appservice/appservice.py
@@ -13,7 +13,7 @@
from aiohttp import web
import aiohttp
-from mautrix.types import JSON, RoomAlias, UserID
+from mautrix.types import JSON, RoomAlias, UserID, VersionsResponse
from mautrix.util.logging import TraceLogger
from ..api import HTTPAPI
@@ -194,3 +194,6 @@ async def _liveness_probe(self, _: web.Request) -> web.Response:
async def _readiness_probe(self, _: web.Request) -> web.Response:
return web.Response(status=200 if self.ready else 500, text="{}")
+
+ async def ping_self(self, txn_id: str | None = None) -> int:
+ return await self.intent.appservice_ping(self.id, txn_id=txn_id)
diff --git a/mautrix/appservice/as_handler.py b/mautrix/appservice/as_handler.py
index 88715100..ec7e339f 100644
--- a/mautrix/appservice/as_handler.py
+++ b/mautrix/appservice/as_handler.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 Tulir Asokan
+# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,7 +8,6 @@
from typing import Any, Awaitable, Callable
from json import JSONDecodeError
-import asyncio
import json
import logging
@@ -27,17 +26,18 @@
SerializerError,
UserID,
)
+from mautrix.util import background_task
HandlerFunc = Callable[[Event], Awaitable]
class AppServiceServerMixin:
- loop: asyncio.AbstractEventLoop
log: logging.Logger
hs_token: str
ephemeral_events: bool
encryption_events: bool
+ synchronous_handlers: bool
query_user: Callable[[UserID], JSON]
query_alias: Callable[[RoomAlias], JSON]
@@ -48,7 +48,17 @@ class AppServiceServerMixin:
otk_handler: Callable[[dict[UserID, dict[DeviceID, DeviceOTKCount]]], Awaitable] | None
device_list_handler: Callable[[DeviceLists], Awaitable] | None
- def __init__(self, ephemeral_events: bool = False, encryption_events: bool = False) -> None:
+ def __init__(
+ self,
+ ephemeral_events: bool = False,
+ encryption_events: bool = False,
+ log: logging.Logger | None = None,
+ hs_token: str | None = None,
+ ) -> None:
+ if log is not None:
+ self.log = log
+ if hs_token is not None:
+ self.hs_token = hs_token
self.transactions = set()
self.event_handlers = []
self.to_device_handler = None
@@ -56,6 +66,7 @@ def __init__(self, ephemeral_events: bool = False, encryption_events: bool = Fal
self.device_list_handler = None
self.ephemeral_events = ephemeral_events
self.encryption_events = encryption_events
+ self.synchronous_handlers = False
async def default_query_handler(_):
return None
@@ -74,6 +85,7 @@ def register_routes(self, app: web.Application) -> None:
)
app.router.add_route("GET", "/_matrix/app/v1/rooms/{alias}", self._http_query_alias)
app.router.add_route("GET", "/_matrix/app/v1/users/{user_id}", self._http_query_user)
+ app.router.add_route("POST", "/_matrix/app/v1/ping", self._http_ping)
def _check_token(self, request: web.Request) -> bool:
try:
@@ -81,10 +93,12 @@ def _check_token(self, request: web.Request) -> bool:
except KeyError:
try:
token = request.headers["Authorization"].removeprefix("Bearer ")
- except (KeyError, AttributeError):
+ except KeyError:
+ self.log.debug("No access_token or Authorization header in request")
return False
if token != self.hs_token:
+ self.log.debug("Incorrect hs_token in request")
return False
return True
@@ -127,6 +141,23 @@ async def _http_query_alias(self, request: web.Request) -> web.Response:
return web.json_response({}, status=404)
return web.json_response(response)
+ async def _http_ping(self, request: web.Request) -> web.Response:
+ if not self._check_token(request):
+ raise web.HTTPUnauthorized(
+ content_type="application/json",
+ text=json.dumps({"error": "Invalid auth token", "errcode": "M_UNKNOWN_TOKEN"}),
+ )
+ try:
+ body = await request.json()
+ except JSONDecodeError:
+ raise web.HTTPBadRequest(
+ content_type="application/json",
+ text=json.dumps({"error": "Body is not JSON", "errcode": "M_NOT_JSON"}),
+ )
+ txn_id = body.get("transaction_id")
+ self.log.info(f"Received ping from homeserver with transaction ID {txn_id}")
+ return web.json_response({})
+
@staticmethod
def _get_with_fallback(
json: dict[str, Any], field: str, unstable_prefix: str, default: Any = None
@@ -269,7 +300,7 @@ async def handle_transaction(
else:
try:
await self.to_device_handler(td)
- except:
+ except Exception:
self.log.exception("Exception in Matrix to-device event handler")
if device_lists and self.device_list_handler:
try:
@@ -287,7 +318,7 @@ async def handle_transaction(
except SerializerError:
self.log.exception("Failed to deserialize ephemeral event %s", raw_edu)
else:
- self.handle_matrix_event(edu, ephemeral=True)
+ await self.handle_matrix_event(edu, ephemeral=True)
for raw_event in events:
try:
self._fix_prev_content(raw_event)
@@ -295,10 +326,10 @@ async def handle_transaction(
except SerializerError:
self.log.exception("Failed to deserialize event %s", raw_event)
else:
- self.handle_matrix_event(event)
+ await self.handle_matrix_event(event)
return {}
- def handle_matrix_event(self, event: Event, ephemeral: bool = False) -> None:
+ async def handle_matrix_event(self, event: Event, ephemeral: bool = False) -> None:
if ephemeral:
event.type = event.type.with_class(EventType.Class.EPHEMERAL)
elif getattr(event, "state_key", None) is not None:
@@ -312,9 +343,12 @@ async def try_handle(handler_func: HandlerFunc):
except Exception:
self.log.exception("Exception in Matrix event handler")
- for handler in self.event_handlers:
- # TODO add option to handle events synchronously
- asyncio.create_task(try_handle(handler))
+ if self.synchronous_handlers:
+ for handler in self.event_handlers:
+ await handler(event)
+ else:
+ for handler in self.event_handlers:
+ background_task.create(try_handle(handler))
def matrix_event_handler(self, func: HandlerFunc) -> HandlerFunc:
self.event_handlers.append(func)
diff --git a/mautrix/appservice/state_store/__init__.py b/mautrix/appservice/state_store/__init__.py
index bc23b5f5..771ac252 100644
--- a/mautrix/appservice/state_store/__init__.py
+++ b/mautrix/appservice/state_store/__init__.py
@@ -1,4 +1,4 @@
from .file import FileASStateStore
from .memory import ASStateStore
-__all__ = ["FileASStateStore", "ASStateStore", "sqlalchemy", "asyncpg"]
+__all__ = ["FileASStateStore", "ASStateStore", "asyncpg"]
diff --git a/mautrix/appservice/state_store/sqlalchemy.py b/mautrix/appservice/state_store/sqlalchemy.py
deleted file mode 100644
index 30c0b1fc..00000000
--- a/mautrix/appservice/state_store/sqlalchemy.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) 2022 Tulir Asokan
-#
-# This Source Code Form is subject to the terms of the Mozilla Public
-# License, v. 2.0. If a copy of the MPL was not distributed with this
-# file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from mautrix.client.state_store.sqlalchemy import SQLStateStore as SQLClientStateStore
-
-from .memory import ASStateStore
-
-
-class SQLASStateStore(SQLClientStateStore, ASStateStore):
- def __init__(self) -> None:
- SQLClientStateStore.__init__(self)
- ASStateStore.__init__(self)
diff --git a/mautrix/bridge/bridge.py b/mautrix/bridge/bridge.py
index 7005d775..9ce9360e 100644
--- a/mautrix/bridge/bridge.py
+++ b/mautrix/bridge/bridge.py
@@ -8,6 +8,7 @@
from typing import Any
from abc import ABC, abstractmethod
from enum import Enum
+import asyncio
import sys
from aiohttp import web
@@ -59,6 +60,8 @@ class Bridge(Program, ABC):
markdown_version: str
manhole: br.commands.manhole.ManholeState | None
homeserver_software: HomeserverSoftware
+ beeper_network_name: str | None = None
+ beeper_service_name: str | None = None
def __init__(
self,
@@ -133,7 +136,7 @@ def prepare_config(self) -> None:
self.config = self.config_class(
self.args.config,
self.args.registration,
- self.args.base_config,
+ self.base_config_path,
env_prefix=self.module.upper(),
)
if self.args.generate_registration:
@@ -244,6 +247,9 @@ async def start(self) -> None:
"correct, and do they match the values in the registration?"
)
sys.exit(16)
+ except Exception:
+ self.log.critical("Failed to check connection to homeserver", exc_info=True)
+ sys.exit(16)
await self.matrix.init_encryption()
self.add_startup_actions(self.matrix.init_as_bot())
@@ -253,12 +259,16 @@ async def start(self) -> None:
status_endpoint = self.config["homeserver.status_endpoint"]
if status_endpoint and await self.count_logged_in_users() == 0:
state = BridgeState(state_event=BridgeStateEvent.UNCONFIGURED).fill()
- await state.send(status_endpoint, self.az.as_token, self.log)
+ while not await state.send(status_endpoint, self.az.as_token, self.log):
+ await asyncio.sleep(5)
async def system_exit(self) -> None:
if hasattr(self, "db") and isinstance(self.db, Database):
- self.log.trace("Stopping database due to SystemExit")
+ self.log.debug("Stopping database due to SystemExit")
await self.db.stop()
+ self.log.debug("Database stopped")
+ elif getattr(self, "db", None):
+ self.log.trace("Database not started at SystemExit")
async def stop(self) -> None:
if self.manhole:
diff --git a/mautrix/bridge/commands/login_matrix.py b/mautrix/bridge/commands/login_matrix.py
index 2fd89d2d..6cc6f7aa 100644
--- a/mautrix/bridge/commands/login_matrix.py
+++ b/mautrix/bridge/commands/login_matrix.py
@@ -70,24 +70,3 @@ async def ping_matrix(evt: CommandEvent) -> EventID:
except InvalidAccessToken:
return await evt.reply("Your access token is invalid.")
return await evt.reply("Your Matrix login is working.")
-
-
-@command_handler(
- needs_auth=True,
- help_section=SECTION_AUTH,
- help_text="Clear the Matrix sync token stored for your double puppet.",
-)
-async def clear_cache_matrix(evt: CommandEvent) -> EventID:
- try:
- puppet = await evt.sender.get_puppet()
- except NotImplementedError:
- return await evt.reply("This bridge has not implemented the clear-cache-matrix command")
- if not puppet.is_real_user:
- return await evt.reply("You are not logged in with your Matrix account.")
- try:
- puppet.stop()
- puppet.next_batch = None
- await puppet.start()
- except InvalidAccessToken:
- return await evt.reply("Your access token is invalid.")
- return await evt.reply("Cleared cache successfully.")
diff --git a/mautrix/bridge/config.py b/mautrix/bridge/config.py
index 4289b178..defed222 100644
--- a/mautrix/bridge/config.py
+++ b/mautrix/bridge/config.py
@@ -115,7 +115,12 @@ def do_update(self, helper: ConfigUpdateHelper) -> None:
copy("appservice.tls_cert")
copy("appservice.tls_key")
- copy("appservice.database")
+ if "appservice.database" in self and self["appservice.database"].startswith("sqlite:///"):
+ helper.base["appservice.database"] = self["appservice.database"].replace(
+ "sqlite:///", "sqlite:"
+ )
+ else:
+ copy("appservice.database")
copy("appservice.database_opts")
copy("appservice.id")
@@ -138,6 +143,16 @@ def do_update(self, helper: ConfigUpdateHelper) -> None:
copy("bridge.encryption.default")
copy("bridge.encryption.require")
copy("bridge.encryption.appservice")
+ copy("bridge.encryption.msc4190")
+ copy("bridge.encryption.self_sign")
+ copy("bridge.encryption.delete_keys.delete_outbound_on_ack")
+ copy("bridge.encryption.delete_keys.dont_store_outbound")
+ copy("bridge.encryption.delete_keys.ratchet_on_decrypt")
+ copy("bridge.encryption.delete_keys.delete_fully_used_on_decrypt")
+ copy("bridge.encryption.delete_keys.delete_prev_on_new_session")
+ copy("bridge.encryption.delete_keys.delete_on_device_delete")
+ copy("bridge.encryption.delete_keys.periodically_delete_expired")
+ copy("bridge.encryption.delete_keys.delete_outdated_inbound")
copy("bridge.encryption.verification_levels.receive")
copy("bridge.encryption.verification_levels.send")
copy("bridge.encryption.verification_levels.share")
@@ -154,6 +169,7 @@ def do_update(self, helper: ConfigUpdateHelper) -> None:
copy("bridge.encryption.rotation.enable_custom")
copy("bridge.encryption.rotation.milliseconds")
copy("bridge.encryption.rotation.messages")
+ copy("bridge.encryption.rotation.disable_device_change_key_rotation")
copy("bridge.relay.enabled")
copy_dict("bridge.relay.message_formats", override_existing_map=False)
@@ -186,14 +202,18 @@ def namespaces(self) -> dict[str, list[dict[str, Any]]]:
"regex": re.escape(f"@{username_format}:{homeserver}").replace(regex_ph, ".*"),
}
],
- "aliases": [
- {
- "exclusive": True,
- "regex": re.escape(f"#{alias_format}:{homeserver}").replace(regex_ph, ".*"),
- }
- ]
- if alias_format
- else [],
+ "aliases": (
+ [
+ {
+ "exclusive": True,
+ "regex": re.escape(f"#{alias_format}:{homeserver}").replace(
+ regex_ph, ".*"
+ ),
+ }
+ ]
+ if alias_format
+ else []
+ ),
}
def generate_registration(self) -> None:
@@ -222,4 +242,7 @@ def generate_registration(self) -> None:
if self["appservice.ephemeral_events"]:
self._registration["de.sorunome.msc2409.push_ephemeral"] = True
- self._registration["push_ephemeral"] = True
+ self._registration["receive_ephemeral"] = True
+
+ if self["bridge.encryption.msc4190"]:
+ self._registration["io.element.msc4190"] = True
diff --git a/mautrix/bridge/crypto_state_store.py b/mautrix/bridge/crypto_state_store.py
index 84169e2e..f08dec2c 100644
--- a/mautrix/bridge/crypto_state_store.py
+++ b/mautrix/bridge/crypto_state_store.py
@@ -28,29 +28,6 @@ async def is_encrypted(self, room_id: RoomID) -> bool:
return portal.encrypted if portal else False
-try:
- from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
-
- class SQLCryptoStateStore(BaseCryptoStateStore):
- @staticmethod
- async def find_shared_rooms(user_id: UserID) -> list[RoomID]:
- return [profile.room_id for profile in UserProfile.find_rooms_with_user(user_id)]
-
- @staticmethod
- async def get_encryption_info(room_id: RoomID) -> RoomEncryptionStateEventContent | None:
- state = RoomState.get(room_id)
- if not state:
- return None
- return state.encryption
-
-except ImportError:
- if __optional_imports__:
- raise
- UserProfile = None
- RoomState = None
- SQLCryptoStateStore = None
-
-
class PgCryptoStateStore(BaseCryptoStateStore):
db: Database
diff --git a/mautrix/bridge/custom_puppet.py b/mautrix/bridge/custom_puppet.py
index 3d5c7f9e..f5befd5f 100644
--- a/mautrix/bridge/custom_puppet.py
+++ b/mautrix/bridge/custom_puppet.py
@@ -5,20 +5,16 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Awaitable, Iterator
from abc import ABC, abstractmethod
-from itertools import chain
import asyncio
import hashlib
import hmac
-import json
import logging
-from aiohttp import ClientConnectionError
from yarl import URL
-from mautrix.api import Path
from mautrix.appservice import AppService, IntentAPI
+from mautrix.client import ClientAPI
from mautrix.errors import (
IntentError,
MatrixError,
@@ -26,21 +22,7 @@
MatrixRequestError,
WellKnownError,
)
-from mautrix.types import (
- Event,
- EventFilter,
- EventType,
- Filter,
- FilterID,
- LoginType,
- PresenceState,
- RoomEventFilter,
- RoomFilter,
- RoomID,
- StateFilter,
- SyncToken,
- UserID,
-)
+from mautrix.types import LoginType, MatrixUserIdentifier, RoomID, UserID
from .. import bridge as br
@@ -114,7 +96,6 @@ class CustomPuppetMixin(ABC):
intent: The primary IntentAPI.
"""
- sync_with_custom_puppets: bool = True
allow_discover_url: bool = False
homeserver_url_map: dict[str, URL] = {}
only_handle_own_synced_events: bool = True
@@ -133,12 +114,9 @@ class CustomPuppetMixin(ABC):
custom_mxid: UserID | None
access_token: str | None
base_url: URL | None
- next_batch: SyncToken | None
intent: IntentAPI
- _sync_task: asyncio.Task | None = None
-
@abstractmethod
async def save(self) -> None:
"""Save the information of this puppet. Called from :meth:`switch_mxid`"""
@@ -154,6 +132,25 @@ def is_real_user(self) -> bool:
return bool(self.custom_mxid and self.access_token)
def _fresh_intent(self) -> IntentAPI:
+ if self.custom_mxid:
+ _, server = self.az.intent.parse_user_id(self.custom_mxid)
+ try:
+ self.base_url = self.homeserver_url_map[server]
+ except KeyError:
+ if server == self.az.domain:
+ self.base_url = self.az.intent.api.base_url
+ if self.access_token == "appservice-config" and self.custom_mxid:
+ try:
+ secret = self.login_shared_secret_map[server]
+ except KeyError:
+ raise AutologinError(f"No shared secret configured for {server}")
+ self.log.debug(f"Using as_token for double puppeting {self.custom_mxid}")
+ return self.az.intent.user(
+ self.custom_mxid,
+ secret.decode("utf-8").removeprefix("as_token:"),
+ self.base_url,
+ as_token=True,
+ )
return (
self.az.intent.user(self.custom_mxid, self.access_token, self.base_url)
if self.is_real_user
@@ -174,6 +171,8 @@ async def _login_with_shared_secret(cls, mxid: UserID) -> str:
secret = cls.login_shared_secret_map[server]
except KeyError:
raise AutologinError(f"No shared secret configured for {server}")
+ if secret.startswith(b"as_token:"):
+ return "appservice-config"
try:
base_url = cls.homeserver_url_map[server]
except KeyError:
@@ -181,31 +180,32 @@ async def _login_with_shared_secret(cls, mxid: UserID) -> str:
base_url = cls.az.intent.api.base_url
else:
raise AutologinError(f"No homeserver URL configured for {server}")
- url = base_url / str(Path.v3.login)
- headers = {"Content-Type": "application/json"}
- login_req = {
- "initial_device_display_name": cls.login_device_name,
- "device_id": cls.login_device_name,
- "identifier": {
- "type": "m.id.user",
- "user": mxid,
- },
- }
+ client = ClientAPI(base_url=base_url)
+ login_args = {}
if secret == b"appservice":
- login_req["type"] = str(LoginType.APPSERVICE)
- headers["Authorization"] = f"Bearer {cls.az.as_token}"
+ login_type = LoginType.APPSERVICE
+ client.api.token = cls.az.as_token
else:
- login_req["type"] = str(LoginType.PASSWORD)
- login_req["password"] = hmac.new(
- secret, mxid.encode("utf-8"), hashlib.sha512
- ).hexdigest()
- resp = await cls.az.http_session.post(url, data=json.dumps(login_req), headers=headers)
- data = await resp.json()
- try:
- return data["access_token"]
- except KeyError:
- error_msg = data.get("error", data.get("errcode", f"HTTP {resp.status}"))
- raise AutologinError(f"Didn't get an access token: {error_msg}") from None
+ flows = await client.get_login_flows()
+ flow = flows.get_first_of_type(LoginType.DEVTURE_SHARED_SECRET, LoginType.PASSWORD)
+ if not flow:
+ raise AutologinError("No supported shared secret auth login flows")
+ login_type = flow.type
+ token = hmac.new(secret, mxid.encode("utf-8"), hashlib.sha512).hexdigest()
+ if login_type == LoginType.DEVTURE_SHARED_SECRET:
+ login_args["token"] = token
+ elif login_type == LoginType.PASSWORD:
+ login_args["password"] = token
+ resp = await client.login(
+ identifier=MatrixUserIdentifier(user=mxid),
+ device_id=cls.login_device_name,
+ initial_device_display_name=cls.login_device_name,
+ login_type=login_type,
+ **login_args,
+ store_access_token=False,
+ update_hs_url=False,
+ )
+ return resp.access_token
async def switch_mxid(
self, access_token: str | None, mxid: UserID | None, start_sync_task: bool = True
@@ -218,11 +218,11 @@ async def switch_mxid(
the appservice-owned ID.
mxid: The expected Matrix user ID of the custom account, or ``None`` when
``access_token`` is None.
- start_sync_task: Whether or not syncing should be started after logging in.
"""
if access_token == "auto":
access_token = await self._login_with_shared_secret(mxid)
- self.log.debug(f"Logged in for {mxid} using shared secret")
+ if access_token != "appservice-config":
+ self.log.debug(f"Logged in for {mxid} using shared secret")
if mxid is not None:
_, mxid_domain = self.az.intent.parse_user_id(mxid)
@@ -248,7 +248,7 @@ async def switch_mxid(
self.base_url = base_url
self.intent = self._fresh_intent()
- await self.start(start_sync_task=start_sync_task, check_e2ee_keys=True)
+ await self.start(check_e2ee_keys=True)
try:
del self.by_custom_mxid[prev_mxid]
@@ -273,7 +273,6 @@ async def _invalidate_double_puppet(self) -> None:
del self.by_custom_mxid[self.custom_mxid]
self.custom_mxid = None
self.access_token = None
- self.next_batch = None
await self.save()
self.intent = self._fresh_intent()
@@ -293,7 +292,7 @@ async def start(
except MatrixInvalidToken as e:
if retry_auto_login and self.custom_mxid and self.can_auto_login(self.custom_mxid):
self.log.debug(f"Got {e.errcode} while trying to initialize custom mxid")
- await self.switch_mxid("auto", self.custom_mxid, start_sync_task=start_sync_task)
+ await self.switch_mxid("auto", self.custom_mxid)
return
self.log.warning(f"Got {e.errcode} while trying to initialize custom mxid")
whoami = None
@@ -316,19 +315,14 @@ async def start(
if device_keys and len(device_keys.keys) > 0:
await self._invalidate_double_puppet()
raise EncryptionKeysFound()
- if self.sync_with_custom_puppets and start_sync_task:
- if self._sync_task:
- self._sync_task.cancel()
- self.log.info(f"Initialized custom mxid: {whoami.user_id}. Starting sync task")
- self._sync_task = asyncio.create_task(self._try_sync())
- else:
- self.log.info(f"Initialized custom mxid: {whoami.user_id}. Not starting sync task")
+ self.log.info(f"Initialized custom mxid: {whoami.user_id}")
def stop(self) -> None:
- """Cancel the sync task."""
- if self._sync_task:
- self._sync_task.cancel()
- self._sync_task = None
+ """
+ No-op
+
+ .. deprecated:: 0.20.1
+ """
async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
"""
@@ -351,112 +345,3 @@ async def _leave_rooms_with_default_user(self) -> None:
await self.intent.ensure_joined(room_id)
except (IntentError, MatrixRequestError):
pass
-
- def _create_sync_filter(self) -> Awaitable[FilterID]:
- all_events = EventType.find("*")
- return self.intent.create_filter(
- Filter(
- account_data=EventFilter(types=[all_events]),
- presence=EventFilter(
- types=[EventType.PRESENCE],
- senders=[self.custom_mxid] if self.only_handle_own_synced_events else None,
- ),
- room=RoomFilter(
- include_leave=False,
- state=StateFilter(not_types=[all_events]),
- timeline=RoomEventFilter(not_types=[all_events]),
- account_data=RoomEventFilter(not_types=[all_events]),
- ephemeral=RoomEventFilter(
- types=[
- EventType.TYPING,
- EventType.RECEIPT,
- ]
- ),
- ),
- )
- )
-
- def _filter_events(self, room_id: RoomID, events: list[dict]) -> Iterator[Event]:
- for event in events:
- event["room_id"] = room_id
- if self.only_handle_own_synced_events:
- # We only want events about the custom puppet user, but we can't use
- # filters for typing and read receipt events.
- evt_type = EventType.find(event.get("type", None))
- event.setdefault("content", {})
- if evt_type == EventType.TYPING:
- is_typing = self.custom_mxid in event["content"].get("user_ids", [])
- event["content"]["user_ids"] = [self.custom_mxid] if is_typing else []
- elif evt_type == EventType.RECEIPT:
- try:
- event_id, receipt = event["content"].popitem()
- data = receipt["m.read"][self.custom_mxid]
- event["content"] = {event_id: {"m.read": {self.custom_mxid: data}}}
- except KeyError:
- continue
- yield event
-
- def _handle_sync(self, sync_resp: dict) -> None:
- # Get events from rooms -> join -> [room_id] -> ephemeral -> events (array)
- ephemeral_events = (
- event
- for room_id, data in sync_resp.get("rooms", {}).get("join", {}).items()
- for event in self._filter_events(room_id, data.get("ephemeral", {}).get("events", []))
- )
-
- # Get events from presence -> events (array)
- presence_events = sync_resp.get("presence", {}).get("events", [])
-
- # Deserialize and handle all events
- for event in chain(ephemeral_events, presence_events):
- asyncio.create_task(self.mx.try_handle_sync_event(Event.deserialize(event)))
-
- async def _try_sync(self) -> None:
- try:
- await self._sync()
- except asyncio.CancelledError:
- self.log.info(f"Syncing for {self.custom_mxid} cancelled")
- except Exception:
- self.log.critical(f"Fatal error syncing {self.custom_mxid}", exc_info=True)
-
- async def _sync(self) -> None:
- if not self.is_real_user:
- self.log.warning("Called sync() for non-custom puppet.")
- return
- custom_mxid: UserID = self.custom_mxid
- access_token_at_start: str = self.access_token
- errors: int = 0
- filter_id: FilterID = await self._create_sync_filter()
- self.log.debug(f"Starting syncer for {custom_mxid} with sync filter {filter_id}.")
- while access_token_at_start == self.access_token:
- try:
- cur_batch = self.next_batch
- sync_resp = await self.intent.sync(
- filter_id=filter_id, since=cur_batch, set_presence=PresenceState.OFFLINE
- )
- try:
- self.next_batch = sync_resp.get("next_batch", None)
- except Exception:
- self.log.warning("Failed to store next batch", exc_info=True)
- errors = 0
- if cur_batch is not None:
- self._handle_sync(sync_resp)
- except MatrixInvalidToken:
- # TODO when not using syncing, we should still check this occasionally and relogin
- self.log.warning(f"Access token for {custom_mxid} got invalidated, restarting...")
- await self.start(retry_auto_login=True, start_sync_task=False)
- if self.is_real_user:
- self.log.info("Successfully relogined custom puppet, continuing sync")
- filter_id = await self._create_sync_filter()
- access_token_at_start = self.access_token
- else:
- self.log.warning("Something went wrong during relogin")
- raise
- except (MatrixError, ClientConnectionError, asyncio.TimeoutError) as e:
- errors += 1
- wait = min(errors, 11) ** 2
- self.log.warning(
- f"Syncer for {custom_mxid} errored: {e}. Waiting for {wait} seconds..."
- )
- await asyncio.sleep(wait)
- self.log.debug(f"Syncer for custom puppet {custom_mxid} stopped.")
diff --git a/mautrix/bridge/e2ee.py b/mautrix/bridge/e2ee.py
index 516e54bd..1525b388 100644
--- a/mautrix/bridge/e2ee.py
+++ b/mautrix/bridge/e2ee.py
@@ -13,7 +13,7 @@
from mautrix.appservice import AppService
from mautrix.client import Client, InternalEventType, SyncStore
from mautrix.crypto import CryptoStore, OlmMachine, PgCryptoStore, RejectKeyShare, StateStore
-from mautrix.errors import EncryptionError, SessionNotFound
+from mautrix.errors import EncryptionError, MForbidden, MNotFound, SessionNotFound
from mautrix.types import (
JSON,
DeviceIdentity,
@@ -34,19 +34,13 @@
StateFilter,
TrustState,
)
+from mautrix.util import background_task
from mautrix.util.async_db import Database
from mautrix.util.logging import TraceLogger
from .. import bridge as br
from .crypto_state_store import PgCryptoStateStore
-try:
- from mautrix.client.state_store.sqlalchemy import UserProfile
-except ImportError:
- if __optional_imports__:
- raise
- UserProfile = None
-
class EncryptionManager:
loop: asyncio.AbstractEventLoop
@@ -61,6 +55,10 @@ class EncryptionManager:
min_send_trust: TrustState
key_sharing_enabled: bool
appservice_mode: bool
+ periodically_delete_expired_keys: bool
+ delete_outdated_inbound: bool
+ msc4190: bool
+ self_sign: bool
bridge: br.Bridge
az: AppService
@@ -68,6 +66,7 @@ class EncryptionManager:
_id_suffix: str
_share_session_events: dict[RoomID, asyncio.Event]
+ _key_delete_task: asyncio.Task | None
def __init__(
self,
@@ -100,6 +99,7 @@ def __init__(
sync_store=self.crypto_store,
log=self.log.getChild("client"),
default_retry_count=default_http_retry_count,
+ state_store=self.bridge.state_store,
)
self.crypto = OlmMachine(self.client, self.crypto_store, self.state_store)
self.client.add_event_handler(InternalEventType.SYNC_STOPPED, self._exit_on_sync_fail)
@@ -110,11 +110,30 @@ def __init__(
self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"])
self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"]
self.appservice_mode = bridge.config["bridge.encryption.appservice"]
+ self.msc4190 = bridge.config["bridge.encryption.msc4190"]
+ self.self_sign = bridge.config["bridge.encryption.self_sign"]
if self.appservice_mode:
self.az.otk_handler = self.crypto.handle_as_otk_counts
self.az.device_list_handler = self.crypto.handle_as_device_lists
self.az.to_device_handler = self.crypto.handle_as_to_device_event
+ self.periodically_delete_expired_keys = False
+ self.delete_outdated_inbound = False
+ self._key_delete_task = None
+ del_cfg = bridge.config["bridge.encryption.delete_keys"]
+ if del_cfg:
+ self.crypto.delete_outbound_keys_on_ack = del_cfg["delete_outbound_on_ack"]
+ self.crypto.dont_store_outbound_keys = del_cfg["dont_store_outbound"]
+ self.crypto.delete_previous_keys_on_receive = del_cfg["delete_prev_on_new_session"]
+ self.crypto.ratchet_keys_on_decrypt = del_cfg["ratchet_on_decrypt"]
+ self.crypto.delete_fully_used_keys_on_decrypt = del_cfg["delete_fully_used_on_decrypt"]
+ self.crypto.delete_keys_on_device_delete = del_cfg["delete_on_device_delete"]
+ self.periodically_delete_expired_keys = del_cfg["periodically_delete_expired"]
+ self.delete_outdated_inbound = del_cfg["delete_outdated_inbound"]
+ self.crypto.disable_device_change_key_rotation = bridge.config[
+ "bridge.encryption.rotation.disable_device_change_key_rotation"
+ ]
+
async def _exit_on_sync_fail(self, data) -> None:
if data["error"]:
self.log.critical("Exiting due to crypto sync error")
@@ -132,9 +151,9 @@ async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInf
f"Rejecting key request from blacklisted device "
f"{device.user_id}/{device.device_id}",
code=RoomKeyWithheldCode.BLACKLISTED,
- reason="You have been blacklisted by this device",
+ reason="Your device has been blacklisted by the bridge",
)
- elif device.trust >= self.crypto.share_keys_min_trust:
+ elif await self.crypto.resolve_trust(device) >= self.crypto.share_keys_min_trust:
portal = await self.bridge.get_portal(request.room_id)
if portal is None:
raise RejectKeyShare(
@@ -161,7 +180,7 @@ async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInf
f"Rejecting key request from unverified device "
f"{device.user_id}/{device.device_id}",
code=RoomKeyWithheldCode.UNVERIFIED,
- reason="You have not been verified by this device",
+ reason="Your device is not trusted by the bridge",
)
def _ignore_user(self, user_id: str) -> bool:
@@ -230,12 +249,13 @@ async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> M
return decrypted
async def start(self) -> None:
- flows = await self.client.get_login_flows()
- if not flows.supports_type(LoginType.APPSERVICE):
- self.log.critical(
- "Encryption enabled in config, but homeserver does not support appservice login"
- )
- sys.exit(30)
+ if not self.msc4190:
+ flows = await self.client.get_login_flows()
+ if not flows.supports_type(LoginType.APPSERVICE):
+ self.log.critical(
+ "Encryption enabled in config, but homeserver does not support appservice login"
+ )
+ sys.exit(30)
self.log.debug("Logging in with bridge bot user")
if self.crypto_db:
try:
@@ -246,27 +266,84 @@ async def start(self) -> None:
device_id = await self.crypto_store.get_device_id()
if device_id:
self.log.debug(f"Found device ID in database: {device_id}")
- # We set the API token to the AS token here to authenticate the appservice login
- # It'll get overridden after the login
- self.client.api.token = self.az.as_token
- await self.client.login(
- login_type=LoginType.APPSERVICE,
- device_name=self.device_name,
- device_id=device_id,
- store_access_token=True,
- update_hs_url=False,
- )
+
+ if self.msc4190:
+ if not device_id:
+ self.log.debug("Creating bot device with MSC4190")
+ self.client.api.token = self.az.as_token
+ await self.client.create_device_msc4190(
+ device_id=device_id, initial_display_name=self.device_name
+ )
+ else:
+ # We set the API token to the AS token here to authenticate the appservice login
+ # It'll get overridden after the login
+ self.client.api.token = self.az.as_token
+ await self.client.login(
+ login_type=LoginType.APPSERVICE,
+ device_name=self.device_name,
+ device_id=device_id,
+ store_access_token=True,
+ update_hs_url=False,
+ )
+
await self.crypto.load()
if not device_id:
await self.crypto_store.put_device_id(self.client.device_id)
self.log.debug(f"Logged in with new device ID {self.client.device_id}")
+ await self.crypto.share_keys()
elif self.crypto.account.shared:
await self._verify_keys_are_on_server()
+ else:
+ await self.crypto.share_keys()
+ if self.self_sign:
+ trust_state = await self.crypto.resolve_trust(self.crypto.own_identity)
+ if trust_state < TrustState.CROSS_SIGNED_UNTRUSTED:
+ recovery_key = await self.crypto.generate_recovery_key()
+ self.log.info(f"Generated recovery key and signed own device: {recovery_key}")
+ else:
+ self.log.debug(f"Own device is already verified ({trust_state})")
if self.appservice_mode:
self.log.info("End-to-bridge encryption support is enabled (appservice mode)")
else:
_ = self.client.start(self._filter)
self.log.info("End-to-bridge encryption support is enabled (sync mode)")
+ if self.delete_outdated_inbound:
+ deleted = await self.crypto_store.redact_outdated_group_sessions()
+ if len(deleted) > 0:
+ self.log.debug(
+ f"Deleted {len(deleted)} inbound keys which lacked expiration metadata"
+ )
+ if self.periodically_delete_expired_keys:
+ self._key_delete_task = background_task.create(self._periodically_delete_keys())
+ background_task.create(self._resync_encryption_info())
+
+ async def _resync_encryption_info(self) -> None:
+ rows = await self.crypto_db.fetch(
+ """SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'"""
+ )
+ room_ids = [row["room_id"] for row in rows]
+ if not room_ids:
+ return
+ self.log.debug(f"Resyncing encryption state event in rooms: {room_ids}")
+ for room_id in room_ids:
+ try:
+ evt = await self.client.get_state_event(room_id, EventType.ROOM_ENCRYPTION)
+ except (MNotFound, MForbidden) as e:
+ self.log.debug(f"Failed to get encryption state in {room_id}: {e}")
+ q = """
+ UPDATE mx_room_state SET encryption=NULL
+ WHERE room_id=$1 AND encryption='{"resync":true}'
+ """
+ await self.crypto_db.execute(q, room_id)
+ else:
+ self.log.debug(f"Resynced encryption state in {room_id}: {evt}")
+ q = """
+ UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2
+ WHERE room_id=$3 AND max_age IS NULL and max_messages IS NULL
+ """
+ await self.crypto_db.execute(
+ q, evt.rotation_period_ms, evt.rotation_period_msgs, room_id
+ )
async def _verify_keys_are_on_server(self) -> None:
self.log.debug("Making sure keys are still on server")
@@ -289,6 +366,9 @@ async def _verify_keys_are_on_server(self) -> None:
sys.exit(34)
async def stop(self) -> None:
+ if self._key_delete_task:
+ self._key_delete_task.cancel()
+ self._key_delete_task = None
self.client.stop()
await self.crypto_store.close()
if self.crypto_db:
@@ -308,3 +388,12 @@ def _filter(self) -> Filter:
ephemeral=RoomEventFilter(not_types=[all_events]),
),
)
+
+ async def _periodically_delete_keys(self) -> None:
+ while True:
+ deleted = await self.crypto_store.redact_expired_group_sessions()
+ if deleted:
+ self.log.info(f"Deleted expired megolm sessions: {deleted}")
+ else:
+ self.log.debug("No expired megolm sessions found")
+ await asyncio.sleep(24 * 60 * 60)
diff --git a/mautrix/bridge/matrix.py b/mautrix/bridge/matrix.py
index 144264aa..e5399094 100644
--- a/mautrix/bridge/matrix.py
+++ b/mautrix/bridge/matrix.py
@@ -5,6 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
+from collections import defaultdict
import asyncio
import logging
import sys
@@ -56,7 +57,7 @@
Version,
VersionsResponse,
)
-from mautrix.util import markdown
+from mautrix.util import background_task, markdown
from mautrix.util.logging import TraceLogger
from mautrix.util.message_send_checkpoint import (
CHECKPOINT_TYPES,
@@ -143,6 +144,7 @@ class BaseMatrixHandler:
media_config: MediaRepoConfig
versions: VersionsResponse
minimum_spec_version: Version = SpecVersions.V11
+ room_locks: dict[str, asyncio.Lock]
user_id_prefix: str
user_id_suffix: str
@@ -159,6 +161,7 @@ def __init__(
self.media_config = MediaRepoConfig(upload_size=50 * 1024 * 1024)
self.versions = VersionsResponse.deserialize({"versions": ["v1.3"]})
self.az.matrix_event_handler(self.int_handle_event)
+ self.room_locks = defaultdict(asyncio.Lock)
self.e2ee = None
self.require_e2ee = False
@@ -218,34 +221,46 @@ async def check_versions(self) -> None:
async def wait_for_connection(self) -> None:
self.log.info("Ensuring connectivity to homeserver")
- errors = 0
- tried_to_register = False
while True:
try:
self.versions = await self.az.intent.versions()
- await self.check_versions()
- await self.az.intent.whoami()
break
- except (MUnknownToken, MExclusive):
- # These are probably not going to resolve themselves by waiting
- raise
except MForbidden:
- if not tried_to_register:
- self.log.debug(
- "Whoami endpoint returned M_FORBIDDEN, "
- "trying to register bridge bot before retrying..."
- )
- await self.az.intent.ensure_registered()
- tried_to_register = True
- else:
- raise
+ self.log.debug(
+ "/versions endpoint returned M_FORBIDDEN, "
+ "trying to register bridge bot before retrying..."
+ )
+ await self.az.intent.ensure_registered()
except Exception:
- errors += 1
- if errors <= 6:
- self.log.exception("Connection to homeserver failed, retrying in 10 seconds")
- await asyncio.sleep(10)
- else:
- raise
+ self.log.exception("Connection to homeserver failed, retrying in 10 seconds")
+ await asyncio.sleep(10)
+ await self.check_versions()
+ try:
+ await self.az.intent.whoami()
+ except MForbidden:
+ self.log.debug(
+ "Whoami endpoint returned M_FORBIDDEN, "
+ "trying to register bridge bot before retrying..."
+ )
+ await self.az.intent.ensure_registered()
+ await self.az.intent.whoami()
+ if self.versions.supports("fi.mau.msc2659.stable") or self.versions.supports_at_least(
+ SpecVersions.V17
+ ):
+ try:
+ txn_id = self.az.intent.api.get_txn_id()
+ duration = await self.az.ping_self(txn_id)
+ self.log.debug(
+ "Homeserver->bridge connection works, "
+ f"roundtrip time is {duration} ms (txn ID: {txn_id})"
+ )
+ except Exception:
+ self.log.exception("Error checking homeserver -> bridge connection")
+ sys.exit(16)
+ else:
+ self.log.debug(
+ "Homeserver does not support checking status of homeserver -> bridge connection"
+ )
try:
self.media_config = await self.az.intent.get_media_repo_config()
except Exception:
@@ -269,6 +284,16 @@ async def init_as_bot(self) -> None:
except Exception:
self.log.exception("Failed to set bot avatar")
+ if self.bridge.homeserver_software.is_hungry and self.bridge.beeper_network_name:
+ self.log.debug("Setting contact info on the appservice bot")
+ await self.az.intent.beeper_update_profile(
+ {
+ "com.beeper.bridge.service": self.bridge.beeper_service_name,
+ "com.beeper.bridge.network": self.bridge.beeper_network_name,
+ "com.beeper.bridge.is_bridge_bot": True,
+ }
+ )
+
async def init_encryption(self) -> None:
if self.e2ee:
await self.e2ee.start()
@@ -382,7 +407,7 @@ async def handle_puppet_nonportal_invite(
members = await intent.get_room_members(room_id)
except MatrixError:
self.log.exception(f"Failed to get state after joining {room_id} as {intent.mxid}")
- asyncio.create_task(intent.leave_room(room_id, reason="Internal error"))
+ background_task.create(intent.leave_room(room_id, reason="Internal error"))
return
if create_evt.type == RoomType.SPACE:
await self.handle_puppet_space_invite(room_id, puppet, invited_by, evt)
@@ -400,19 +425,20 @@ async def handle_puppet_invite(
await intent.leave_room(room_id, reason="You're not allowed to invite this ghost.")
return
- portal = await self.bridge.get_portal(room_id)
- if portal:
- try:
- await portal.handle_matrix_invite(invited_by, puppet)
- except br.RejectMatrixInvite as e:
- await intent.leave_room(room_id, reason=e.message)
- except br.IgnoreMatrixInvite:
- pass
+ async with self.room_locks[room_id]:
+ portal = await self.bridge.get_portal(room_id)
+ if portal:
+ try:
+ await portal.handle_matrix_invite(invited_by, puppet)
+ except br.RejectMatrixInvite as e:
+ await intent.leave_room(room_id, reason=e.message)
+ except br.IgnoreMatrixInvite:
+ pass
+ else:
+ await intent.join_room(room_id)
+ return
else:
- await intent.join_room(room_id)
- return
- else:
- await self.handle_puppet_nonportal_invite(room_id, puppet, invited_by, evt)
+ await self.handle_puppet_nonportal_invite(room_id, puppet, invited_by, evt)
async def handle_invite(
self, room_id: RoomID, user_id: UserID, invited_by: br.BaseUser, evt: StateEvent
@@ -537,6 +563,28 @@ def is_command(self, message: MessageEventContent) -> tuple[bool, str]:
text = text[len(prefix) + 1 :].lstrip()
return is_command, text
+ async def _send_mss(
+ self,
+ evt: Event,
+ status: MessageStatus,
+ reason: MessageStatusReason | None = None,
+ error: str | None = None,
+ message: str | None = None,
+ ) -> None:
+ if not self.config.get("bridge.message_status_events", False):
+ return
+ status_content = BeeperMessageStatusEventContent(
+ network="", # TODO set network properly
+ relates_to=RelatesTo(rel_type=RelationType.REFERENCE, event_id=evt.event_id),
+ status=status,
+ reason=reason,
+ error=error,
+ message=message,
+ )
+ await self.az.intent.send_message_event(
+ evt.room_id, EventType.BEEPER_MESSAGE_STATUS, status_content
+ )
+
async def _send_crypto_status_error(
self,
evt: Event,
@@ -574,19 +622,13 @@ async def _send_crypto_status_error(
"it by default on the bridge (bridge -> encryption -> default)."
)
- if self.config.get("bridge.message_status_events", False):
- status_content = BeeperMessageStatusEventContent(
- network="", # TODO set network properly
- relates_to=RelatesTo(rel_type=RelationType.REFERENCE, event_id=evt.event_id),
- status=MessageStatus.RETRIABLE if is_final else MessageStatus.PENDING,
- reason=MessageStatusReason.UNDECRYPTABLE,
- error=msg,
- message=err.human_message if err else None,
- )
- status_content.fill_legacy_booleans()
- await self.az.intent.send_message_event(
- evt.room_id, EventType.BEEPER_MESSAGE_STATUS, status_content
- )
+ await self._send_mss(
+ evt,
+ status=MessageStatus.RETRIABLE if is_final else MessageStatus.PENDING,
+ reason=MessageStatusReason.UNDECRYPTABLE,
+ error=str(err),
+ message=err.human_message if err else None,
+ )
return event_id
@@ -678,6 +720,13 @@ async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) -
except Exception as e:
self.log.debug(f"Error handling command {command} from {sender}: {e}")
self._send_message_checkpoint(evt, MessageSendCheckpointStep.COMMAND, e)
+ await self._send_mss(
+ evt,
+ status=MessageStatus.FAIL,
+ reason=MessageStatusReason.GENERIC_ERROR,
+ error="",
+ message="Command execution failed",
+ )
else:
await MessageSendCheckpoint(
event_id=event_id,
@@ -693,6 +742,7 @@ async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) -
self.az.as_token,
self.log,
)
+ await self._send_mss(evt, status=MessageStatus.SUCCESS)
else:
self.log.debug(
f"Ignoring event {event_id} from {sender.mxid}:"
@@ -701,6 +751,13 @@ async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) -
self._send_message_checkpoint(
evt, MessageSendCheckpointStep.COMMAND, "not a command and not a portal room"
)
+ await self._send_mss(
+ evt,
+ status=MessageStatus.FAIL,
+ reason=MessageStatusReason.UNSUPPORTED,
+ error="Unknown room",
+ message="Unknown room",
+ )
async def _is_direct_chat(self, room_id: RoomID) -> tuple[bool, bool]:
try:
@@ -784,7 +841,7 @@ async def handle_encrypted(self, evt: EncryptedEvent) -> None:
try:
decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=3)
except SessionNotFound as e:
- await self._handle_encrypted_wait(evt, e, wait=6)
+ await self._handle_encrypted_wait(evt, e, wait=22)
except DecryptionError as e:
self.log.warning(f"Failed to decrypt {evt.event_id}: {e}")
self.log.trace("%s decryption traceback:", evt.event_id, exc_info=True)
@@ -799,7 +856,7 @@ async def _handle_encrypted_wait(
f"Couldn't find session {err.session_id} trying to decrypt {evt.event_id},"
" waiting even longer"
)
- asyncio.create_task(
+ background_task.create(
self.e2ee.crypto.request_room_key(
evt.room_id,
evt.content.sender_key,
@@ -876,7 +933,7 @@ def _send_message_checkpoint(
info=str(err) if err else None,
retry_num=retry_num,
)
- asyncio.create_task(checkpoint.send(endpoint, self.az.as_token, self.log))
+ background_task.create(checkpoint.send(endpoint, self.az.as_token, self.log))
allowed_event_classes: tuple[type, ...] = (
MessageEvent,
diff --git a/mautrix/bridge/portal.py b/mautrix/bridge/portal.py
index f51a0ae8..05d67e3b 100644
--- a/mautrix/bridge/portal.py
+++ b/mautrix/bridge/portal.py
@@ -30,6 +30,7 @@
TextMessageEventContent,
UserID,
)
+from mautrix.util import background_task
from mautrix.util.logging import TraceLogger
from mautrix.util.simple_lock import SimpleLock
@@ -385,7 +386,7 @@ async def _disappear_event(self, msg: br.AbstractDisappearingMessage) -> None:
await self._do_disappear(msg.event_id)
self.log.debug(f"Expired event {msg.event_id} disappeared successfully")
except Exception as e:
- self.log.warning(f"Failed to make expired event {msg.event_id} disappear: {e}", e)
+ self.log.warning(f"Failed to make expired event {msg.event_id} disappear: {e}")
async def _do_disappear(self, event_id: EventID) -> None:
await self.main_intent.redact(self.mxid, event_id)
@@ -402,7 +403,7 @@ async def restart_scheduled_disappearing(cls) -> None:
for msg in msgs:
portal = await cls.bridge.get_portal(msg.room_id)
if portal and portal.mxid:
- asyncio.create_task(portal._disappear_event(msg))
+ background_task.create(portal._disappear_event(msg))
else:
await msg.delete()
@@ -418,7 +419,7 @@ async def schedule_disappearing(self) -> None:
for msg in msgs:
msg.start_timer()
await msg.update()
- asyncio.create_task(self._disappear_event(msg))
+ background_task.create(self._disappear_event(msg))
async def _send_message(
self,
@@ -431,7 +432,7 @@ async def _send_message(
event_type, content = await self.matrix.e2ee.encrypt(self.mxid, event_type, content)
event_id = await intent.send_message_event(self.mxid, event_type, content, **kwargs)
if intent.api.is_real_user:
- asyncio.create_task(intent.mark_read(self.mxid, event_id))
+ background_task.create(intent.mark_read(self.mxid, event_id))
return event_id
@property
diff --git a/mautrix/bridge/state_store/__init__.py b/mautrix/bridge/state_store/__init__.py
index 0e8137a5..b990ac86 100644
--- a/mautrix/bridge/state_store/__init__.py
+++ b/mautrix/bridge/state_store/__init__.py
@@ -1 +1 @@
-__all__ = ["asyncpg", "sqlalchemy"]
+__all__ = ["asyncpg"]
diff --git a/mautrix/bridge/state_store/sqlalchemy.py b/mautrix/bridge/state_store/sqlalchemy.py
deleted file mode 100644
index 2fa7353e..00000000
--- a/mautrix/bridge/state_store/sqlalchemy.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) 2022 Tulir Asokan
-#
-# This Source Code Form is subject to the terms of the Mozilla Public
-# License, v. 2.0. If a copy of the MPL was not distributed with this
-# file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from __future__ import annotations
-
-from typing import Awaitable, Callable, Union
-
-from mautrix.appservice.state_store.sqlalchemy import SQLASStateStore
-from mautrix.types import UserID
-
-from ..puppet import BasePuppet
-
-GetPuppetFunc = Union[
- Callable[[UserID], Awaitable[BasePuppet]], Callable[[UserID, bool], Awaitable[BasePuppet]]
-]
-
-
-class SQLBridgeStateStore(SQLASStateStore):
- def __init__(self, get_puppet: GetPuppetFunc, get_double_puppet: GetPuppetFunc) -> None:
- super().__init__()
- self.get_puppet = get_puppet
- self.get_double_puppet = get_double_puppet
-
- async def is_registered(self, user_id: UserID) -> bool:
- puppet = await self.get_puppet(user_id)
- if puppet:
- return puppet.is_registered
- custom_puppet = await self.get_double_puppet(user_id)
- if custom_puppet:
- return True
- return await super().is_registered(user_id)
-
- async def registered(self, user_id: UserID) -> None:
- puppet = await self.get_puppet(user_id, True)
- if puppet:
- puppet.is_registered = True
- await puppet.save()
- else:
- await super().registered(user_id)
diff --git a/mautrix/bridge/user.py b/mautrix/bridge/user.py
index bea6bdf7..66860c04 100644
--- a/mautrix/bridge/user.py
+++ b/mautrix/bridge/user.py
@@ -16,6 +16,7 @@
from mautrix.appservice import AppService
from mautrix.errors import MNotFound
from mautrix.types import EventID, EventType, Membership, MessageType, RoomID, UserID
+from mautrix.util import background_task
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
from mautrix.util.logging import TraceLogger
from mautrix.util.message_send_checkpoint import (
@@ -244,7 +245,7 @@ def send_remote_checkpoint(
"""
if not self.bridge.config["homeserver.message_send_checkpoint_endpoint"]:
return WrappedTask(task=None)
- task = asyncio.create_task(
+ task = background_task.create(
MessageSendCheckpoint(
event_id=event_id,
room_id=room_id,
diff --git a/mautrix/client/api/authentication.py b/mautrix/client/api/authentication.py
index cd77272d..0f6249ea 100644
--- a/mautrix/client/api/authentication.py
+++ b/mautrix/client/api/authentication.py
@@ -5,6 +5,8 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
+import secrets
+
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import (
@@ -117,6 +119,19 @@ async def login(
self.api.base_url = base_url.rstrip("/")
return resp_data
+ async def create_device_msc4190(self, device_id: str, initial_display_name: str) -> None:
+ """
+ Create a device for a user of the homeserver using the appservice interface defined in MSC4190.
+ """
+ if len(device_id) == 0:
+ device_id = DeviceID(secrets.token_urlsafe(10))
+ self.api.as_user_id = self.mxid
+ await self.api.request(
+ Method.PUT, Path.v3.devices[device_id], {"display_name": initial_display_name}
+ )
+ self.api.as_device_id = device_id
+ self.device_id = device_id
+
async def logout(self, clear_access_token: bool = True) -> None:
"""
Invalidates an existing access token, so that it can no longer be used for authorization.
diff --git a/mautrix/client/api/events.py b/mautrix/client/api/events.py
index 99e1c178..bd415b9b 100644
--- a/mautrix/client/api/events.py
+++ b/mautrix/client/api/events.py
@@ -5,7 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Awaitable
+from typing import Awaitable, Literal, overload
import json
from mautrix.api import Method, Path
@@ -168,12 +168,41 @@ async def get_event_context(
)
return EventContext.deserialize(resp)
+ @overload
async def get_state_event(
self,
room_id: RoomID,
event_type: EventType,
state_key: str = "",
- ) -> StateEventContent:
+ *,
+ format: Literal["content"] = "content",
+ ) -> StateEventContent: ...
+ @overload
+ async def get_state_event(
+ self,
+ room_id: RoomID,
+ event_type: EventType,
+ state_key: str = "",
+ *,
+ format: Literal["event"],
+ ) -> StateEvent: ...
+ @overload
+ async def get_state_event(
+ self,
+ room_id: RoomID,
+ event_type: EventType,
+ state_key: str = "",
+ *,
+ format: str = "content",
+ ) -> StateEventContent | StateEvent: ...
+ async def get_state_event(
+ self,
+ room_id: RoomID,
+ event_type: EventType,
+ state_key: str = "",
+ *,
+ format: str = "content",
+ ) -> StateEventContent | StateEvent:
"""
Looks up the contents of a state event in a room. If the user is joined to the room then the
state is taken from the current state of the room. If the user has left the room then the
@@ -185,6 +214,9 @@ async def get_state_event(
room_id: The ID of the room to look up the state in.
event_type: The type of state to look up.
state_key: The key of the state to look up. Defaults to empty string.
+ format: The format of the state event to return. Defaults to "content", which only returns
+ the content of the state event. If set to "event", the full event is returned.
+ See https://github.com/matrix-org/matrix-spec/issues/1047 for more info.
Returns:
The state event.
@@ -192,11 +224,17 @@ async def get_state_event(
content = await self.api.request(
Method.GET,
Path.v3.rooms[room_id].state[event_type][state_key],
+ query_params={"format": format} if format != "content" else None,
metrics_method="getStateEvent",
)
- content["__mautrix_event_type"] = event_type
try:
- return StateEvent.deserialize_content(content)
+ if format == "content":
+ content["__mautrix_event_type"] = event_type
+ return StateEvent.deserialize_content(content)
+ elif format == "event":
+ return StateEvent.deserialize(content)
+ else:
+ return content
except SerializerError as e:
raise MatrixResponseError("Invalid state event in response") from e
@@ -357,15 +395,13 @@ async def get_messages(
try:
return PaginatedMessages(
content["start"],
- content["end"],
+ content.get("end"),
[Event.deserialize(event) for event in content["chunk"]],
)
except KeyError:
if "start" not in content:
raise MatrixResponseError("`start` not in response.")
- elif "end" not in content:
- raise MatrixResponseError("`start` not in response.")
- raise MatrixResponseError("`content` not in response.")
+ raise MatrixResponseError("`chunk` not in response.")
except SerializerError as e:
raise MatrixResponseError("Invalid events in response") from e
diff --git a/mautrix/client/api/filtering.py b/mautrix/client/api/filtering.py
index 09889ec5..b8df1f36 100644
--- a/mautrix/client/api/filtering.py
+++ b/mautrix/client/api/filtering.py
@@ -50,9 +50,11 @@ async def create_filter(self, filter_params: Filter) -> FilterID:
resp = await self.api.request(
Method.POST,
Path.v3.user[self.mxid].filter,
- filter_params.serialize()
- if isinstance(filter_params, Serializable)
- else filter_params,
+ (
+ filter_params.serialize()
+ if isinstance(filter_params, Serializable)
+ else filter_params
+ ),
)
try:
return resp["filter_id"]
diff --git a/mautrix/client/api/modules/crypto.py b/mautrix/client/api/modules/crypto.py
index af656916..2d879a63 100644
--- a/mautrix/client/api/modules/crypto.py
+++ b/mautrix/client/api/modules/crypto.py
@@ -5,13 +5,17 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Any
+from typing import Any, Union
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError
from mautrix.types import (
+ JSON,
ClaimKeysResponse,
+ CrossSigningKeys,
+ CrossSigningUsage,
DeviceID,
+ DeviceKeys,
EncryptionKeyAlgorithm,
EventType,
QueryKeysResponse,
@@ -82,7 +86,7 @@ async def send_to_one_device(
async def upload_keys(
self,
one_time_keys: dict[str, Any] | None = None,
- device_keys: dict[str, Any] | None = None,
+ device_keys: DeviceKeys | dict[str, Any] | None = None,
) -> dict[EncryptionKeyAlgorithm, int]:
"""
Publishes end-to-end encryption keys for the device.
@@ -102,8 +106,12 @@ async def upload_keys(
"""
data = {}
if device_keys:
+ if isinstance(device_keys, Serializable):
+ device_keys = device_keys.serialize()
data["device_keys"] = device_keys
if one_time_keys:
+ if isinstance(one_time_keys, Serializable):
+ one_time_keys = one_time_keys.serialize()
data["one_time_keys"] = one_time_keys
resp = await self.api.request(Method.POST, Path.v3.keys.upload, data)
try:
@@ -116,6 +124,43 @@ async def upload_keys(
except AttributeError as e:
raise MatrixResponseError("Invalid `one_time_key_counts` field in response.") from e
+ async def upload_cross_signing_keys(
+ self,
+ keys: dict[CrossSigningUsage, CrossSigningKeys],
+ auth: dict[str, JSON] | None = None,
+ ) -> None:
+ await self.api.request(
+ Method.POST,
+ Path.v3.keys.device_signing.upload,
+ {f"{usage}_key": key.serialize() for usage, key in keys.items()}
+ | ({"auth": auth} if auth else {}),
+ )
+
+ async def upload_one_signature(
+ self,
+ user_id: UserID,
+ device_id: DeviceID,
+ keys: Union[DeviceKeys, CrossSigningKeys],
+ ) -> None:
+ await self.api.request(
+ Method.POST, Path.v3.keys.signatures.upload, {user_id: {device_id: keys.serialize()}}
+ )
+ # TODO check failures
+
+ async def upload_many_signatures(
+ self,
+ signatures: dict[UserID, dict[DeviceID, Union[DeviceKeys, CrossSigningKeys]]],
+ ) -> None:
+ await self.api.request(
+ Method.POST,
+ Path.v3.keys.signatures.upload,
+ {
+ user_id: {device_id: keys.serialize() for device_id, keys in devices.items()}
+ for user_id, devices in signatures.items()
+ },
+ )
+ # TODO check failures
+
async def query_keys(
self,
device_keys: list[UserID] | set[UserID] | dict[UserID, list[DeviceID]],
diff --git a/mautrix/client/api/modules/media_repository.py b/mautrix/client/api/modules/media_repository.py
index b1de6751..b1d90cc1 100644
--- a/mautrix/client/api/modules/media_repository.py
+++ b/mautrix/client/api/modules/media_repository.py
@@ -10,6 +10,8 @@
import asyncio
import time
+from yarl import URL
+
from mautrix import __optional_imports__
from mautrix.api import MediaPath, Method
from mautrix.errors import MatrixResponseError, make_request_error
@@ -19,7 +21,10 @@
MediaRepoConfig,
MXOpenGraph,
SerializerError,
+ SpecVersions,
)
+from mautrix.util import background_task
+from mautrix.util.async_body import async_iter_bytes
from mautrix.util.opt_prometheus import Histogram
from ..base import BaseClientAPI
@@ -44,22 +49,19 @@ class MediaRepositoryMethods(BaseClientAPI):
downloading content from the media repository and for getting URL previews without leaking
client IPs.
- See also: `API reference `__
-
- There are also methods for supporting `MSC2246
- `__ which allows asynchronous
- uploads of media.
+ See also: `API reference `__
"""
- async def unstable_create_mxc(self) -> MediaCreateResponse:
+ async def create_mxc(self) -> MediaCreateResponse:
"""
- Create a media ID for uploading media to the homeserver. Requires the homeserver to have
- `MSC2246 `__ support.
+ Create a media ID for uploading media to the homeserver.
+
+ See also: `API reference `__
Returns:
- MediaCreateResponse Containing the MXC URI that can be used to upload a file to later, as well as an optional upload URL
+ MediaCreateResponse containing the MXC URI that can be used to upload a file to later
"""
- resp = await self.api.request(Method.POST, MediaPath.unstable["fi.mau.msc2246"].create)
+ resp = await self.api.request(Method.POST, MediaPath.v1.create)
return MediaCreateResponse.deserialize(resp)
@contextmanager
@@ -85,21 +87,18 @@ async def upload_media(
"""
Upload a file to the content repository.
- See also: `API reference `__
+ See also: `API reference `__
Args:
data: The data to upload.
mime_type: The MIME type to send with the upload request.
filename: The filename to send with the upload request.
size: The file size to send with the upload request.
- mxc: An existing MXC URI which doesn't have content yet to upload into. Requires the
- homeserver to have MSC2246_ support.
- async_upload: Should the media be uploaded in the background (using MSC2246_)?
- If ``True``, this will create a MXC URI, start uploading in the background and then
- immediately return the created URI. This is mutually exclusive with manually
- passing the ``mxc`` parameter.
-
- .. _MSC2246: https://github.com/matrix-org/matrix-spec-proposals/pull/2246
+ mxc: An existing MXC URI which doesn't have content yet to upload into.
+ async_upload: Should the media be uploaded in the background?
+ If ``True``, this will create a MXC URI using :meth:`create_mxc`, start uploading
+ in the background, and then immediately return the created URI. This is mutually
+ exclusive with manually passing the ``mxc`` parameter.
Returns:
The MXC URI to the uploaded file.
@@ -126,19 +125,21 @@ async def upload_media(
if async_upload:
if mxc:
raise ValueError("async_upload and mxc can't be provided simultaneously")
- create_response = await self.unstable_create_mxc()
+ create_response = await self.create_mxc()
mxc = create_response.content_uri
- upload_url = create_response.upload_url
+ upload_url = create_response.unstable_upload_url
path = MediaPath.v3.upload
method = Method.POST
if mxc:
server_name, media_id = self.api.parse_mxc_uri(mxc)
if upload_url is None:
- path = MediaPath.unstable["fi.mau.msc2246"].upload[server_name][media_id]
+ path = MediaPath.v3.upload[server_name][media_id]
method = Method.PUT
else:
- path = MediaPath.unstable["fi.mau.msc2246"].upload[server_name][media_id].complete
+ path = (
+ MediaPath.unstable["com.beeper.msc3870"].upload[server_name][media_id].complete
+ )
if upload_url is not None:
task = self._upload_to_url(upload_url, path, headers, data, post_upload_query=query)
@@ -156,7 +157,7 @@ async def _try_upload():
except Exception as e:
self.log.error(f"Failed to upload {mxc}: {type(e).__name__}: {e}")
- asyncio.create_task(_try_upload())
+ background_task.create(_try_upload())
return mxc
else:
with self._observe_upload_time(size):
@@ -166,27 +167,40 @@ async def _try_upload():
except KeyError:
raise MatrixResponseError("`content_uri` not in response.")
- async def download_media(self, url: ContentURI, max_stall_ms: int | None = None) -> bytes:
+ async def download_media(self, url: ContentURI, timeout_ms: int | None = None) -> bytes:
"""
Download a file from the content repository.
- See also: `API reference `__
+ See also: `API reference `__
Args:
url: The MXC URI to download.
- max_stall_ms: The maximum number of milliseconds that the client is willing to wait to
- start receiving data. Used for MSC2246 Asynchronous Uploads.
+ timeout_ms: The maximum number of milliseconds that the client is willing to wait to
+ start receiving data. Used for asynchronous uploads.
Returns:
The raw downloaded data.
"""
- url = self.api.get_download_url(url)
+ authenticated = (await self.versions()).supports(SpecVersions.V111)
+ url = self.api.get_download_url(url, authenticated=authenticated)
query_params: dict[str, Any] = {"allow_redirect": "true"}
- if max_stall_ms is not None:
- query_params["max_stall_ms"] = max_stall_ms
- query_params["fi.mau.msc2246.max_stall_ms"] = max_stall_ms
- async with self.api.session.get(url, params=query_params) as response:
- return await response.read()
+ if timeout_ms is not None:
+ query_params["timeout_ms"] = timeout_ms
+ headers: dict[str, str] = {}
+ if authenticated:
+ headers["Authorization"] = f"Bearer {self.api.token}"
+ if self.api.as_user_id:
+ query_params["user_id"] = self.api.as_user_id
+ req_id = self.api.log_download_request(url, query_params)
+ start = time.monotonic()
+ async with self.api.session.get(url, params=query_params, headers=headers) as response:
+ try:
+ response.raise_for_status()
+ return await response.read()
+ finally:
+ self.api.log_download_request_done(
+ url, req_id, time.monotonic() - start, response.status
+ )
async def download_thumbnail(
self,
@@ -194,13 +208,13 @@ async def download_thumbnail(
width: int | None = None,
height: int | None = None,
resize_method: Literal["crop", "scale"] = None,
- allow_remote: bool = True,
- max_stall_ms: int | None = None,
+ allow_remote: bool | None = None,
+ timeout_ms: int | None = None,
):
"""
Download a thumbnail for a file in the content repository.
- See also: `API reference `__
+ See also: `API reference `__
Args:
url: The MXC URI to download.
@@ -212,13 +226,16 @@ async def download_thumbnail(
allow_remote: Indicates to the server that it should not attempt to fetch the media if
it is deemed remote. This is to prevent routing loops where the server contacts
itself.
- max_stall_ms: The maximum number of milliseconds that the client is willing to wait to
- start receiving data. Used for MSC2246 Asynchronous Uploads.
+ timeout_ms: The maximum number of milliseconds that the client is willing to wait to
+ start receiving data. Used for asynchronous uploads.
Returns:
The raw downloaded data.
"""
- url = self.api.get_download_url(url, download_type="thumbnail")
+ authenticated = (await self.versions()).supports(SpecVersions.V111)
+ url = self.api.get_download_url(
+ url, download_type="thumbnail", authenticated=authenticated
+ )
query_params: dict[str, Any] = {"allow_redirect": "true"}
if width is not None:
query_params["width"] = width
@@ -227,12 +244,24 @@ async def download_thumbnail(
if resize_method is not None:
query_params["method"] = resize_method
if allow_remote is not None:
- query_params["allow_remote"] = allow_remote
- if max_stall_ms is not None:
- query_params["max_stall_ms"] = max_stall_ms
- query_params["fi.mau.msc2246.max_stall_ms"] = max_stall_ms
- async with self.api.session.get(url, params=query_params) as response:
- return await response.read()
+ query_params["allow_remote"] = str(allow_remote).lower()
+ if timeout_ms is not None:
+ query_params["timeout_ms"] = timeout_ms
+ headers: dict[str, str] = {}
+ if authenticated:
+ headers["Authorization"] = f"Bearer {self.api.token}"
+ if self.api.as_user_id:
+ query_params["user_id"] = self.api.as_user_id
+ req_id = self.api.log_download_request(url, query_params)
+ start = time.monotonic()
+ async with self.api.session.get(url, params=query_params, headers=headers) as response:
+ try:
+ response.raise_for_status()
+ return await response.read()
+ finally:
+ self.api.log_download_request_done(
+ url, req_id, time.monotonic() - start, response.status
+ )
async def get_url_preview(self, url: str, timestamp: int | None = None) -> MXOpenGraph:
"""
@@ -286,19 +315,24 @@ async def _upload_to_url(
headers: dict[str, str],
data: bytes | bytearray | AsyncIterable[bytes],
post_upload_query: dict[str, str],
+ min_iter_size: int = 25 * 1024 * 1024,
) -> None:
retry_count = self.api.default_retry_count
- backoff = 4
+ backoff = 2
+ do_fake_iter = data and hasattr(data, "__len__") and len(data) > min_iter_size
+ if do_fake_iter:
+ headers["Content-Length"] = str(len(data))
while True:
self.log.debug("Uploading media to external URL %s", upload_url)
upload_response = None
try:
+ req_data = async_iter_bytes(data) if do_fake_iter else data
upload_response = await self.api.session.put(
- upload_url, data=data, headers=headers
+ upload_url, data=req_data, headers=headers
)
upload_response.raise_for_status()
except Exception as e:
- if retry_count == 0:
+ if retry_count <= 0:
raise make_request_error(
http_status=upload_response.status if upload_response else -1,
text=(await upload_response.text()) if upload_response else "",
@@ -311,7 +345,7 @@ async def _upload_to_url(
)
await asyncio.sleep(backoff)
backoff *= 2
- retry_count = -1
+ retry_count -= 1
else:
break
diff --git a/mautrix/client/api/modules/misc.py b/mautrix/client/api/modules/misc.py
index 443877c1..8ff9c85a 100644
--- a/mautrix/client/api/modules/misc.py
+++ b/mautrix/client/api/modules/misc.py
@@ -50,7 +50,7 @@ async def set_typing(self, room_id: RoomID, timeout: int = 0) -> None:
Args:
room_id: The ID of the room in which the user is typing.
- timeout: The length of time in seconds to mark this user as typing.
+ timeout: The length of time in milliseconds to mark this user as typing.
"""
if timeout > 0:
content = {"typing": True, "timeout": timeout}
diff --git a/mautrix/client/api/rooms.py b/mautrix/client/api/rooms.py
index 9f935430..a488fba6 100644
--- a/mautrix/client/api/rooms.py
+++ b/mautrix/client/api/rooms.py
@@ -354,9 +354,12 @@ async def join_room(
except KeyError:
raise MatrixResponseError("`room_id` not in response.")
- fill_member_event_callback: Callable[
- [RoomID, UserID, MemberStateEventContent], Awaitable[MemberStateEventContent | None]
- ] | None
+ fill_member_event_callback: (
+ Callable[
+ [RoomID, UserID, MemberStateEventContent], Awaitable[MemberStateEventContent | None]
+ ]
+ | None
+ )
async def fill_member_event(
self, room_id: RoomID, user_id: UserID, content: MemberStateEventContent
diff --git a/mautrix/client/api/user_data.py b/mautrix/client/api/user_data.py
index f37cb769..9c380335 100644
--- a/mautrix/client/api/user_data.py
+++ b/mautrix/client/api/user_data.py
@@ -5,6 +5,8 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
+from typing import Any
+
from mautrix.api import Method, Path
from mautrix.errors import MatrixResponseError, MNotFound
from mautrix.types import ContentURI, Member, SerializerError, User, UserID, UserSearchResults
@@ -69,7 +71,7 @@ async def search_users(self, search_query: str, limit: int | None = 10) -> UserS
# region 10.2 Profiles
# API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#profiles
- async def set_displayname(self, displayname: str, check_current: bool = True) -> None:
+ async def set_displayname(self, displayname: str | None, check_current: bool = True) -> None:
"""
Set the display name of the current user.
@@ -79,7 +81,9 @@ async def set_displayname(self, displayname: str, check_current: bool = True) ->
displayname: The new display name for the user.
check_current: Whether or not to check if the displayname is already set.
"""
- if check_current and await self.get_displayname(self.mxid) == displayname:
+ if check_current and str_or_none(await self.get_displayname(self.mxid)) == str_or_none(
+ displayname
+ ):
return
await self.api.request(
Method.PUT,
@@ -110,7 +114,9 @@ async def get_displayname(self, user_id: UserID) -> str | None:
except KeyError:
return None
- async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = True) -> None:
+ async def set_avatar_url(
+ self, avatar_url: ContentURI | None, check_current: bool = True
+ ) -> None:
"""
Set the avatar of the current user.
@@ -120,7 +126,9 @@ async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = Tru
avatar_url: The ``mxc://`` URI to the new avatar.
check_current: Whether or not to check if the avatar is already set.
"""
- if check_current and await self.get_avatar_url(self.mxid) == avatar_url:
+ if check_current and str_or_none(await self.get_avatar_url(self.mxid)) == str_or_none(
+ avatar_url
+ ):
return
await self.api.request(
Method.PUT,
@@ -170,3 +178,23 @@ async def get_profile(self, user_id: UserID) -> Member:
raise MatrixResponseError("Invalid member in response") from e
# endregion
+
+ # region Beeper Custom Fields API
+
+ async def beeper_update_profile(self, custom_fields: dict[str, Any]) -> None:
+ """
+ Set custom fields on the user's profile. Only works on Hungryserv.
+
+ Args:
+ custom_fields: A dictionary of fields to set in the custom content of the profile.
+ """
+ await self.api.request(Method.PATCH, Path.v3.profile[self.mxid], custom_fields)
+
+ # endregion
+
+
+def str_or_none(v: str | None) -> str | None:
+ """
+    Normalize empty string values to None.
+ """
+ return None if v == "" else v
diff --git a/mautrix/client/state_store/__init__.py b/mautrix/client/state_store/__init__.py
index 5f76c7dc..98150ca4 100644
--- a/mautrix/client/state_store/__init__.py
+++ b/mautrix/client/state_store/__init__.py
@@ -10,5 +10,4 @@
"MemorySyncStore",
"SyncStore",
"asyncpg",
- "sqlalchemy",
]
diff --git a/mautrix/client/state_store/abstract.py b/mautrix/client/state_store/abstract.py
index e241d8f1..d5b1f5f2 100644
--- a/mautrix/client/state_store/abstract.py
+++ b/mautrix/client/state_store/abstract.py
@@ -121,6 +121,18 @@ async def set_power_levels(
) -> None:
pass
+ @abstractmethod
+ async def has_create_cached(self, room_id: RoomID) -> bool:
+ pass
+
+ @abstractmethod
+ async def get_create(self, room_id: RoomID) -> StateEvent | None:
+ pass
+
+ @abstractmethod
+ async def set_create(self, event: StateEvent) -> None:
+ pass
+
@abstractmethod
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
pass
@@ -135,7 +147,7 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent
@abstractmethod
async def set_encryption_info(
- self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, any]
+ self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
) -> None:
pass
@@ -143,9 +155,14 @@ async def update_state(self, evt: StateEvent) -> None:
if evt.type == EventType.ROOM_POWER_LEVELS:
await self.set_power_levels(evt.room_id, evt.content)
elif evt.type == EventType.ROOM_MEMBER:
+ evt.unsigned["mautrix_prev_membership"] = await self.get_member(
+ evt.room_id, UserID(evt.state_key)
+ )
await self.set_member(evt.room_id, UserID(evt.state_key), evt.content)
elif evt.type == EventType.ROOM_ENCRYPTION:
await self.set_encryption_info(evt.room_id, evt.content)
+ elif evt.type == EventType.ROOM_CREATE and evt.sender:
+ await self.set_create(evt)
async def get_membership(self, room_id: RoomID, user_id: UserID) -> Membership:
member = await self.get_member(room_id, user_id)
@@ -169,4 +186,7 @@ async def has_power_level(
room_levels = await self.get_power_levels(room_id)
if not room_levels:
return None
- return room_levels.get_user_level(user_id) >= room_levels.get_event_level(event_type)
+ create_event = await self.get_create(room_id)
+ return room_levels.get_user_level(user_id, create_event) >= room_levels.get_event_level(
+ event_type
+ )
diff --git a/mautrix/client/state_store/asyncpg/store.py b/mautrix/client/state_store/asyncpg/store.py
index d78c04d6..f4f8436f 100644
--- a/mautrix/client/state_store/asyncpg/store.py
+++ b/mautrix/client/state_store/asyncpg/store.py
@@ -16,6 +16,7 @@
RoomEncryptionStateEventContent,
RoomID,
Serializable,
+ StateEvent,
UserID,
)
from mautrix.util.async_db import Database, Scheme
@@ -223,6 +224,29 @@ async def set_power_levels(
json.dumps(content.serialize() if isinstance(content, Serializable) else content),
)
+ async def has_create_cached(self, room_id: RoomID) -> bool:
+ return bool(
+ await self.db.fetchval(
+ "SELECT create_event IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id
+ )
+ )
+
+ async def get_create(self, room_id: RoomID) -> StateEvent | None:
+ create_event_json = await self.db.fetchval(
+ "SELECT create_event FROM mx_room_state WHERE room_id=$1", room_id
+ )
+ if create_event_json is None:
+ return None
+ return StateEvent.parse_json(create_event_json)
+
+ async def set_create(self, event: StateEvent) -> None:
+ await self.db.execute(
+ "INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) "
+ "ON CONFLICT (room_id) DO UPDATE SET create_event=$2",
+ event.room_id,
+ json.dumps(event.serialize() if isinstance(event, Serializable) else event),
+ )
+
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
return bool(
await self.db.fetchval(
diff --git a/mautrix/client/state_store/asyncpg/upgrade.py b/mautrix/client/state_store/asyncpg/upgrade.py
index 0a489aae..20b2e5b2 100644
--- a/mautrix/client/state_store/asyncpg/upgrade.py
+++ b/mautrix/client/state_store/asyncpg/upgrade.py
@@ -14,17 +14,18 @@
)
-@upgrade_table.register(description="Latest revision", upgrades_to=2)
-async def upgrade_blank_to_v2(conn: Connection, scheme: Scheme) -> None:
- await conn.execute(
- """CREATE TABLE mx_room_state (
+@upgrade_table.register(description="Latest revision", upgrades_to=4)
+async def upgrade_blank_to_v4(conn: Connection, scheme: Scheme) -> None:
+ await conn.execute("""
+ CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
is_encrypted BOOLEAN,
has_full_member_list BOOLEAN,
encryption TEXT,
- power_levels TEXT
- )"""
- )
+ power_levels TEXT,
+ create_event TEXT
+ )
+ """)
membership_check = ""
if scheme != Scheme.SQLITE:
await conn.execute(
@@ -32,16 +33,16 @@ async def upgrade_blank_to_v2(conn: Connection, scheme: Scheme) -> None:
)
else:
membership_check = "CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock'))"
- await conn.execute(
- f"""CREATE TABLE mx_user_profile (
+ await conn.execute(f"""
+ CREATE TABLE mx_user_profile (
room_id TEXT,
user_id TEXT,
membership membership NOT NULL {membership_check},
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
- )"""
- )
+ )
+ """)
@upgrade_table.register(description="Stop using size-limited string fields")
@@ -54,3 +55,21 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT")
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN displayname TYPE TEXT")
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT")
+
+
+@upgrade_table.register(description="Mark rooms that need crypto state event resynced")
+async def upgrade_v3(conn: Connection) -> None:
+ if await conn.table_exists("portal"):
+ await conn.execute("""
+ INSERT INTO mx_room_state (room_id, encryption)
+ SELECT portal.mxid, '{"resync":true}' FROM portal
+ WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
+ ON CONFLICT (room_id) DO UPDATE
+ SET encryption=excluded.encryption
+ WHERE mx_room_state.encryption IS NULL
+ """)
+
+
+@upgrade_table.register(description="Add create event to room state cache")
+async def upgrade_v4(conn: Connection) -> None:
+ await conn.execute("ALTER TABLE mx_room_state ADD COLUMN create_event TEXT")
diff --git a/mautrix/client/state_store/file.py b/mautrix/client/state_store/file.py
index d567c853..a5d53663 100644
--- a/mautrix/client/state_store/file.py
+++ b/mautrix/client/state_store/file.py
@@ -15,6 +15,7 @@
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
+ StateEvent,
UserID,
)
from mautrix.util.file_store import Filer, FileStore
@@ -65,3 +66,7 @@ async def set_power_levels(
) -> None:
await super().set_power_levels(room_id, content)
self._time_limited_flush()
+
+ async def set_create(self, event: StateEvent) -> None:
+ await super().set_create(event)
+ self._time_limited_flush()
diff --git a/mautrix/client/state_store/memory.py b/mautrix/client/state_store/memory.py
index 8f2edac5..2b010a75 100644
--- a/mautrix/client/state_store/memory.py
+++ b/mautrix/client/state_store/memory.py
@@ -14,6 +14,7 @@
PowerLevelStateEventContent,
RoomEncryptionStateEventContent,
RoomID,
+ StateEvent,
UserID,
)
@@ -25,6 +26,7 @@ class SerializedStateStore(TypedDict):
full_member_list: dict[RoomID, bool]
power_levels: dict[RoomID, Any]
encryption: dict[RoomID, Any]
+ create: dict[RoomID, Any]
class MemoryStateStore(StateStore):
@@ -32,12 +34,14 @@ class MemoryStateStore(StateStore):
full_member_list: dict[RoomID, bool]
power_levels: dict[RoomID, PowerLevelStateEventContent]
encryption: dict[RoomID, RoomEncryptionStateEventContent | None]
+ create: dict[RoomID, StateEvent]
def __init__(self) -> None:
self.members = {}
self.full_member_list = {}
self.power_levels = {}
self.encryption = {}
+ self.create = {}
def serialize(self) -> SerializedStateStore:
"""
@@ -58,6 +62,7 @@ def serialize(self) -> SerializedStateStore:
room_id: (content.serialize() if content is not None else None)
for room_id, content in self.encryption.items()
},
+ "create": {room_id: evt.serialize() for room_id, evt in self.create.items()},
}
def deserialize(self, data: SerializedStateStore) -> None:
@@ -84,6 +89,9 @@ def deserialize(self, data: SerializedStateStore) -> None:
)
for room_id, content in data["encryption"].items()
}
+ self.create = {
+ room_id: StateEvent.deserialize(evt) for room_id, evt in data["create"].items()
+ }
async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
try:
@@ -176,6 +184,17 @@ async def set_power_levels(
content = PowerLevelStateEventContent.deserialize(content)
self.power_levels[room_id] = content
+ async def has_create_cached(self, room_id: RoomID) -> bool:
+ return room_id in self.create
+
+ async def get_create(self, room_id: RoomID) -> StateEvent | None:
+ return self.create.get(room_id)
+
+ async def set_create(self, event: StateEvent | dict[str, Any]) -> None:
+ if not isinstance(event, StateEvent):
+ event = StateEvent.deserialize(event)
+ self.create[event.room_id] = event
+
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
return room_id in self.encryption
diff --git a/mautrix/client/state_store/sqlalchemy/__init__.py b/mautrix/client/state_store/sqlalchemy/__init__.py
deleted file mode 100644
index 6ce78f46..00000000
--- a/mautrix/client/state_store/sqlalchemy/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .mx_room_state import RoomState, SerializableType
-from .mx_user_profile import UserProfile
-from .sqlstatestore import SQLStateStore
diff --git a/mautrix/client/state_store/sqlalchemy/mx_room_state.py b/mautrix/client/state_store/sqlalchemy/mx_room_state.py
deleted file mode 100644
index 9f97056f..00000000
--- a/mautrix/client/state_store/sqlalchemy/mx_room_state.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright (c) 2022 Tulir Asokan
-#
-# This Source Code Form is subject to the terms of the Mozilla Public
-# License, v. 2.0. If a copy of the MPL was not distributed with this
-# file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from __future__ import annotations
-
-from typing import Type
-import json
-
-from sqlalchemy import Boolean, Column, Text, types
-
-from mautrix.types import (
- PowerLevelStateEventContent as PowerLevels,
- RoomEncryptionStateEventContent as EncryptionInfo,
- RoomID,
- Serializable,
-)
-from mautrix.util.db import Base
-
-
-class SerializableType(types.TypeDecorator):
- impl = types.Text
-
- def __init__(self, python_type: Type[Serializable], *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self._python_type = python_type
-
- @property
- def python_type(self) -> Type[Serializable]:
- return self._python_type
-
- def process_bind_param(self, value: Serializable, dialect) -> str | None:
- if value is not None:
- return json.dumps(value.serialize() if isinstance(value, Serializable) else value)
- return None
-
- def process_result_value(self, value: str, dialect) -> Serializable | None:
- if value is not None:
- return self.python_type.deserialize(json.loads(value))
- return None
-
- def process_literal_param(self, value, dialect):
- return value
-
-
-class RoomState(Base):
- __tablename__ = "mx_room_state"
-
- room_id: RoomID = Column(Text, primary_key=True)
- is_encrypted: bool = Column(Boolean, nullable=True)
- has_full_member_list: bool = Column(Boolean, nullable=True)
- encryption: EncryptionInfo = Column(SerializableType(EncryptionInfo), nullable=True)
- power_levels: PowerLevels = Column(SerializableType(PowerLevels), nullable=True)
-
- @property
- def has_power_levels(self) -> bool:
- return bool(self.power_levels)
-
- @property
- def has_encryption_info(self) -> bool:
- return self.is_encrypted is not None
-
- @classmethod
- def get(cls, room_id: RoomID) -> RoomState | None:
- return cls._select_one_or_none(cls.c.room_id == room_id)
diff --git a/mautrix/client/state_store/sqlalchemy/mx_user_profile.py b/mautrix/client/state_store/sqlalchemy/mx_user_profile.py
deleted file mode 100644
index 0bb6b766..00000000
--- a/mautrix/client/state_store/sqlalchemy/mx_user_profile.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright (c) 2022 Tulir Asokan
-#
-# This Source Code Form is subject to the terms of the Mozilla Public
-# License, v. 2.0. If a copy of the MPL was not distributed with this
-# file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from __future__ import annotations
-
-from typing import Iterable
-
-from sqlalchemy import Column, Enum, Text
-
-from mautrix.types import ContentURI, Member, Membership, RoomID, UserID
-from mautrix.util.db import Base
-
-from .mx_room_state import RoomState
-
-
-class UserProfile(Base):
- __tablename__ = "mx_user_profile"
-
- room_id: RoomID = Column(Text, primary_key=True)
- user_id: UserID = Column(Text, primary_key=True)
- membership: Membership = Column(Enum(Membership), nullable=False, default=Membership.LEAVE)
- displayname: str = Column(Text, nullable=True)
- avatar_url: ContentURI = Column(Text, nullable=True)
-
- def member(self) -> Member:
- return Member(
- membership=self.membership, displayname=self.displayname, avatar_url=self.avatar_url
- )
-
- @classmethod
- def get(cls, room_id: RoomID, user_id: UserID) -> UserProfile | None:
- return cls._select_one_or_none((cls.c.room_id == room_id) & (cls.c.user_id == user_id))
-
- @classmethod
- def all_in_room(
- cls,
- room_id: RoomID,
- memberships: tuple[Membership, ...],
- prefix: str = None,
- suffix: str = None,
- bot: str = None,
- ) -> Iterable[UserProfile]:
- clause = cls.c.membership == memberships[0]
- for membership in memberships[1:]:
- clause |= cls.c.membership == membership
- clause &= cls.c.room_id == room_id
- if bot:
- clause &= cls.c.user_id != bot
- if prefix:
- clause &= ~cls.c.user_id.startswith(prefix, autoescape=True)
- if suffix:
- clause &= ~cls.c.user_id.startswith(suffix, autoescape=True)
- return cls._select_all(clause)
-
- @classmethod
- def find_rooms_with_user(cls, user_id: UserID) -> Iterable[UserProfile]:
- return cls._select_all(
- (cls.c.user_id == user_id)
- & (cls.c.room_id == RoomState.c.room_id)
- & (RoomState.c.is_encrypted == True)
- )
-
- @classmethod
- def delete_all(cls, room_id: RoomID) -> None:
- with cls.db.begin() as conn:
- conn.execute(cls.t.delete().where(cls.c.room_id == room_id))
-
- @classmethod
- def bulk_replace(
- cls,
- room_id: RoomID,
- members: dict[UserID, Member],
- only_membership: Membership | None = None,
- ) -> None:
- with cls.db.begin() as conn:
- delete_condition = cls.c.room_id == room_id
- if only_membership is not None:
- delete_condition &= cls.c.membership == only_membership
- cls.db.execute(cls.t.delete().where(delete_condition))
- conn.execute(
- cls.t.insert(),
- [
- dict(
- room_id=room_id,
- user_id=user_id,
- membership=member.membership,
- displayname=member.displayname,
- avatar_url=member.avatar_url,
- )
- for user_id, member in members.items()
- ],
- )
diff --git a/mautrix/client/state_store/sqlalchemy/sqlstatestore.py b/mautrix/client/state_store/sqlalchemy/sqlstatestore.py
deleted file mode 100644
index 1145dec7..00000000
--- a/mautrix/client/state_store/sqlalchemy/sqlstatestore.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright (c) 2022 Tulir Asokan
-#
-# This Source Code Form is subject to the terms of the Mozilla Public
-# License, v. 2.0. If a copy of the MPL was not distributed with this
-# file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from __future__ import annotations
-
-from typing import Any
-
-from mautrix.types import (
- Member,
- Membership,
- PowerLevelStateEventContent,
- RoomEncryptionStateEventContent,
- RoomID,
- UserID,
-)
-
-from ..abstract import StateStore
-from .mx_room_state import RoomState
-from .mx_user_profile import UserProfile
-
-
-class SQLStateStore(StateStore):
- _profile_cache: dict[RoomID, dict[UserID, UserProfile]]
- _room_state_cache: dict[RoomID, RoomState]
-
- def __init__(self) -> None:
- super().__init__()
- self._profile_cache = {}
- self._room_state_cache = {}
-
- def _get_user_profile(
- self, room_id: RoomID, user_id: UserID, create: bool = False
- ) -> UserProfile:
- if not room_id:
- raise ValueError("room_id is empty")
- elif not user_id:
- raise ValueError("user_id is empty")
- try:
- return self._profile_cache[room_id][user_id]
- except KeyError:
- pass
- if room_id not in self._profile_cache:
- self._profile_cache[room_id] = {}
-
- profile = UserProfile.get(room_id, user_id)
- if profile:
- self._profile_cache[room_id][user_id] = profile
- elif create:
- profile = UserProfile(room_id=room_id, user_id=user_id, membership=Membership.LEAVE)
- profile.insert()
- self._profile_cache[room_id][user_id] = profile
- return profile
-
- async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
- profile = self._get_user_profile(room_id, user_id)
- if not profile:
- return None
- return profile.member()
-
- async def set_member(self, room_id: RoomID, user_id: UserID, member: Member) -> None:
- if not member:
- raise ValueError("member info is empty")
- profile = self._get_user_profile(room_id, user_id, create=True)
- profile.edit(
- membership=member.membership,
- displayname=member.displayname or profile.displayname,
- avatar_url=member.avatar_url or profile.avatar_url,
- )
-
- async def set_membership(
- self, room_id: RoomID, user_id: UserID, membership: Membership
- ) -> None:
- await self.set_member(room_id, user_id, Member(membership=membership))
-
- async def get_member_profiles(
- self,
- room_id: RoomID,
- memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
- ) -> dict[UserID, Member]:
- self._profile_cache[room_id] = {}
- for profile in UserProfile.all_in_room(room_id, memberships):
- self._profile_cache[room_id][profile.user_id] = profile
- return {
- profile.user_id: profile.member() for profile in self._profile_cache[room_id].values()
- }
-
- async def get_members_filtered(
- self,
- room_id: RoomID,
- not_prefix: str,
- not_suffix: str,
- not_id: str,
- memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
- ) -> list[UserID]:
- return [
- profile.user_id
- for profile in UserProfile.all_in_room(
- room_id, memberships, not_suffix, not_prefix, not_id
- )
- ]
-
- async def set_members(
- self,
- room_id: RoomID,
- members: dict[UserID, Member],
- only_membership: Membership | None = None,
- ) -> None:
- UserProfile.bulk_replace(room_id, members, only_membership=only_membership)
- self._get_room_state(room_id, create=True).edit(has_full_member_list=True)
- try:
- del self._profile_cache[room_id]
- except KeyError:
- pass
-
- async def has_full_member_list(self, room_id: RoomID) -> bool:
- room = self._get_room_state(room_id)
- if not room:
- return False
- return room.has_full_member_list
-
- async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]:
- return [profile.room_id for profile in UserProfile.find_rooms_with_user(user_id)]
-
- def _get_room_state(self, room_id: RoomID, create: bool = False) -> RoomState:
- if not room_id:
- raise ValueError("room_id is empty")
- try:
- return self._room_state_cache[room_id]
- except KeyError:
- pass
-
- room = RoomState.get(room_id)
- if room:
- self._room_state_cache[room_id] = room
- elif create:
- room = RoomState(room_id=room_id)
- room.insert()
- self._room_state_cache[room_id] = room
- return room
-
- async def has_power_levels_cached(self, room_id: RoomID) -> bool:
- room = self._get_room_state(room_id)
- if not room:
- return False
- return room.has_power_levels
-
- async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None:
- room = self._get_room_state(room_id)
- if not room:
- return None
- return room.power_levels
-
- async def set_power_levels(
- self, room_id: RoomID, content: PowerLevelStateEventContent
- ) -> None:
- if not content:
- raise ValueError("content is empty")
- self._get_room_state(room_id, create=True).edit(power_levels=content)
-
- async def is_encrypted(self, room_id: RoomID) -> bool | None:
- room = self._get_room_state(room_id)
- if not room:
- return None
- return room.is_encrypted
-
- async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
- room = self._get_room_state(room_id)
- return room and room.has_encryption_info
-
- async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
- room = self._get_room_state(room_id)
- if not room:
- return None
- return room.encryption
-
- async def set_encryption_info(
- self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
- ) -> None:
- if not content:
- raise ValueError("content is empty")
- self._get_room_state(room_id, create=True).edit(encryption=content, is_encrypted=True)
diff --git a/mautrix/client/state_store/tests/store_test.py b/mautrix/client/state_store/tests/store_test.py
index dbbd376a..46eee750 100644
--- a/mautrix/client/state_store/tests/store_test.py
+++ b/mautrix/client/state_store/tests/store_test.py
@@ -16,15 +16,12 @@
import asyncpg
import pytest
-import sqlalchemy as sql
from mautrix.types import EncryptionAlgorithm, Member, Membership, RoomID, StateEvent, UserID
from mautrix.util.async_db import Database
-from mautrix.util.db import Base
from .. import MemoryStateStore, StateStore
from ..asyncpg import PgStateStore
-from ..sqlalchemy import RoomState, SQLStateStore, UserProfile
@asynccontextmanager
@@ -54,7 +51,7 @@ async def async_postgres_store() -> AsyncIterator[PgStateStore]:
@asynccontextmanager
async def async_sqlite_store() -> AsyncIterator[PgStateStore]:
db = Database.create(
- "sqlite:///:memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1}
+ "sqlite::memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1}
)
store = PgStateStore(db)
await db.start()
@@ -62,23 +59,12 @@ async def async_sqlite_store() -> AsyncIterator[PgStateStore]:
await db.stop()
-@asynccontextmanager
-async def alchemy_store() -> AsyncIterator[SQLStateStore]:
- db = sql.create_engine("sqlite:///:memory:")
- Base.metadata.bind = db
- for table in (RoomState, UserProfile):
- table.bind(db)
- Base.metadata.create_all()
- yield SQLStateStore()
- db.dispose()
-
-
@asynccontextmanager
async def memory_store() -> AsyncIterator[MemoryStateStore]:
yield MemoryStateStore()
-@pytest.fixture(params=[async_postgres_store, async_sqlite_store, alchemy_store, memory_store])
+@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store])
async def store(request) -> AsyncIterator[StateStore]:
param: Callable[[], AsyncContextManager[StateStore]] = request.param
async with param() as state_store:
diff --git a/mautrix/client/store_updater.py b/mautrix/client/store_updater.py
index a5530bde..35280324 100644
--- a/mautrix/client/store_updater.py
+++ b/mautrix/client/store_updater.py
@@ -5,6 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
+from typing import Literal
import asyncio
from mautrix.errors import MForbidden, MNotFound
@@ -196,20 +197,28 @@ async def send_state_event(
return event_id
async def get_state_event(
- self, room_id: RoomID, event_type: EventType, state_key: str = ""
- ) -> StateEventContent:
- event = await super().get_state_event(room_id, event_type, state_key)
+ self,
+ room_id: RoomID,
+ event_type: EventType,
+ state_key: str = "",
+ *,
+ format: str = "content",
+ ) -> StateEventContent | StateEvent:
+ event = await super().get_state_event(room_id, event_type, state_key, format=format)
if self.state_store:
- fake_event = StateEvent(
- type=event_type,
- room_id=room_id,
- event_id=EventID(""),
- sender=UserID(""),
- state_key=state_key,
- timestamp=0,
- content=event,
- )
- await self.state_store.update_state(fake_event)
+ if isinstance(event, StateEvent):
+ await self.state_store.update_state(event)
+ else:
+ fake_event = StateEvent(
+ type=event_type,
+ room_id=room_id,
+ event_id=EventID(""),
+ sender=UserID(""),
+ state_key=state_key,
+ timestamp=0,
+ content=event,
+ )
+ await self.state_store.update_state(fake_event)
return event
async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]:
diff --git a/mautrix/client/syncer.py b/mautrix/client/syncer.py
index 9f942665..4b8661a2 100644
--- a/mautrix/client/syncer.py
+++ b/mautrix/client/syncer.py
@@ -5,11 +5,11 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Any, Awaitable, Callable, Type, TypeVar
+from typing import Any, Awaitable, Callable, NamedTuple, Optional, Type, TypeVar
from abc import ABC, abstractmethod
-from contextlib import suppress
from enum import Enum, Flag, auto
import asyncio
+import itertools
import time
from mautrix.errors import MUnknownToken
@@ -25,7 +25,6 @@
Filter,
FilterID,
GenericEvent,
- MessageEvent,
PresenceState,
SerializerError,
StateEvent,
@@ -34,6 +33,7 @@
ToDeviceEvent,
UserID,
)
+from mautrix.util import background_task
from mautrix.util.logging import TraceLogger
from . import dispatcher
@@ -78,13 +78,18 @@ class InternalEventType(Enum):
DEVICE_OTK_COUNT = auto()
+class EventHandlerProps(NamedTuple):
+ wait_sync: bool
+ sync_stream: Optional[SyncStream]
+
+
class Syncer(ABC):
loop: asyncio.AbstractEventLoop
log: TraceLogger
mxid: UserID
- global_event_handlers: list[tuple[EventHandler, bool]]
- event_handlers: dict[EventType | InternalEventType, list[tuple[EventHandler, bool]]]
+ global_event_handlers: dict[EventHandler, EventHandlerProps]
+ event_handlers: dict[EventType | InternalEventType, dict[EventHandler, EventHandlerProps]]
dispatchers: dict[Type[dispatcher.Dispatcher], dispatcher.Dispatcher]
syncing_task: asyncio.Task | None
ignore_initial_sync: bool
@@ -94,7 +99,7 @@ class Syncer(ABC):
sync_store: SyncStore
def __init__(self, sync_store: SyncStore) -> None:
- self.global_event_handlers = []
+ self.global_event_handlers = {}
self.event_handlers = {}
self.dispatchers = {}
self.syncing_task = None
@@ -157,6 +162,7 @@ def add_event_handler(
event_type: InternalEventType | EventType,
handler: EventHandler,
wait_sync: bool = False,
+ sync_stream: Optional[SyncStream] = None,
) -> None:
"""
Add a new event handler.
@@ -166,13 +172,15 @@ def add_event_handler(
event types.
handler: The handler function to add.
wait_sync: Whether or not the handler should be awaited before the next sync request.
+ sync_stream: The sync streams to listen to. Defaults to all.
"""
if not isinstance(event_type, (EventType, InternalEventType)):
raise ValueError("Invalid event type")
+ props = EventHandlerProps(wait_sync=wait_sync, sync_stream=sync_stream)
if event_type == EventType.ALL:
- self.global_event_handlers.append((handler, wait_sync))
+ self.global_event_handlers[handler] = props
else:
- self.event_handlers.setdefault(event_type, []).append((handler, wait_sync))
+ self.event_handlers.setdefault(event_type, {})[handler] = props
def remove_event_handler(
self, event_type: EventType | InternalEventType, handler: EventHandler
@@ -196,11 +204,7 @@ def remove_event_handler(
# No handlers for this event type registered
return
- # FIXME this is a bit hacky
- with suppress(ValueError):
- handler_list.remove((handler, True))
- with suppress(ValueError):
- handler_list.remove((handler, False))
+ handler_list.pop(handler, None)
if len(handler_list) == 0 and event_type != EventType.ALL:
del self.event_handlers[event_type]
@@ -228,7 +232,9 @@ def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asynci
else:
event.type = event.type.with_class(EventType.Class.MESSAGE)
setattr(event, "source", source)
- return self.dispatch_manual_event(event.type, event, include_global_handlers=True)
+ return self.dispatch_manual_event(
+ event.type, event, include_global_handlers=True, source=source
+ )
async def _catch_errors(self, handler: EventHandler, data: Any) -> None:
try:
@@ -242,15 +248,25 @@ def dispatch_manual_event(
data: Any,
include_global_handlers: bool = False,
force_synchronous: bool = False,
+ source: Optional[SyncStream] = None,
) -> list[asyncio.Task]:
- handlers = self.event_handlers.get(event_type, [])
+ handlers = self.event_handlers.get(event_type, {}).items()
if include_global_handlers:
- handlers = self.global_event_handlers + handlers
+ handlers = itertools.chain(self.global_event_handlers.items(), handlers)
tasks = []
- for handler, wait_sync in handlers:
- task = asyncio.create_task(self._catch_errors(handler, data))
- if force_synchronous or wait_sync:
- tasks.append(task)
+ if source is None:
+ source = getattr(data, "source", None)
+ for handler, props in handlers:
+ if (
+ props.sync_stream is not None
+ and source is not None
+ and not props.sync_stream & source
+ ):
+ continue
+ if force_synchronous or props.wait_sync:
+ tasks.append(asyncio.create_task(self._catch_errors(handler, data)))
+ else:
+ background_task.create(self._catch_errors(handler, data))
return tasks
async def run_internal_event(
@@ -261,6 +277,7 @@ async def run_internal_event(
event_type,
custom_type if custom_type is not None else kwargs,
include_global_handlers=False,
+ source=SyncStream.INTERNAL,
)
await asyncio.gather(*tasks)
@@ -272,6 +289,7 @@ def dispatch_internal_event(
event_type,
custom_type if custom_type is not None else kwargs,
include_global_handlers=False,
+ source=SyncStream.INTERNAL,
)
def _try_deserialize(self, type: Type[T], data: JSON) -> T | GenericEvent:
@@ -340,16 +358,29 @@ def handle_sync(self, data: JSON) -> list[asyncio.Task]:
self._try_deserialize(Event, raw_event),
source=SyncStream.JOINED_ROOM | SyncStream.TIMELINE,
)
+
+ for raw_event in room_data.get("ephemeral", {}).get("events", []):
+ raw_event["room_id"] = room_id
+ tasks += self.dispatch_event(
+ self._try_deserialize(EphemeralEvent, raw_event),
+ source=SyncStream.JOINED_ROOM | SyncStream.EPHEMERAL,
+ )
for room_id, room_data in rooms.get("invite", {}).items():
events: list[dict[str, JSON]] = room_data.get("invite_state", {}).get("events", [])
for raw_event in events:
raw_event["room_id"] = room_id
- raw_invite = next(
- raw_event
- for raw_event in events
- if raw_event.get("type", "") == "m.room.member"
- and raw_event.get("state_key", "") == self.mxid
- )
+ try:
+ raw_invite = next(
+ raw_event
+ for raw_event in events
+ if raw_event.get("type", "") == "m.room.member"
+ and raw_event.get("state_key", "") == self.mxid
+ )
+ except StopIteration:
+ self.log.warning(
+ f"Corrupted invite section in sync: no invite event present for {room_id}"
+ )
+ continue
# These aren't required by the spec, so make sure they're set
raw_invite.setdefault("event_id", None)
raw_invite.setdefault("origin_server_ts", int(time.time() * 1000))
diff --git a/mautrix/crypto/__init__.py b/mautrix/crypto/__init__.py
index 39867225..743fc2c6 100644
--- a/mautrix/crypto/__init__.py
+++ b/mautrix/crypto/__init__.py
@@ -1,6 +1,6 @@
from .account import OlmAccount
from .key_share import RejectKeyShare
-from .sessions import InboundGroupSession, OutboundGroupSession, Session
+from .sessions import InboundGroupSession, OutboundGroupSession, RatchetSafety, Session
# These have to be last
from .store import ( # isort: skip
diff --git a/mautrix/crypto/account.py b/mautrix/crypto/account.py
index db508262..a00ada71 100644
--- a/mautrix/crypto/account.py
+++ b/mautrix/crypto/account.py
@@ -10,15 +10,17 @@
from mautrix.types import (
DeviceID,
+ DeviceKeys,
EncryptionAlgorithm,
EncryptionKeyAlgorithm,
IdentityKey,
+ KeyID,
SigningKey,
UserID,
)
-from . import base
from .sessions import Session
+from .signature import sign_olm
class OlmAccount(olm.Account):
@@ -74,19 +76,18 @@ def new_outbound_session(self, target_key: IdentityKey, one_time_key: IdentityKe
session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now()
)
- def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> Dict[str, Any]:
- device_keys = {
- "user_id": user_id,
- "device_id": device_id,
- "algorithms": [EncryptionAlgorithm.OLM_V1.value, EncryptionAlgorithm.MEGOLM_V1.value],
- "keys": {
- f"{algorithm}:{device_id}": key for algorithm, key in self.identity_keys.items()
+ def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> DeviceKeys:
+ device_keys = DeviceKeys(
+ user_id=user_id,
+ device_id=device_id,
+ algorithms=[EncryptionAlgorithm.OLM_V1, EncryptionAlgorithm.MEGOLM_V1],
+ keys={
+ KeyID(algorithm=EncryptionKeyAlgorithm(algorithm), key_id=device_id): key
+ for algorithm, key in self.identity_keys.items()
},
- }
- signature = self.sign(base.canonical_json(device_keys))
- device_keys["signatures"] = {
- user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature}
- }
+ signatures={},
+ )
+ device_keys.signatures[user_id] = {KeyID.ed25519(device_id): sign_olm(device_keys, self)}
return device_keys
def get_one_time_keys(
@@ -97,12 +98,12 @@ def get_one_time_keys(
self.generate_one_time_keys(new_count)
keys = {}
for key_id, key in self.one_time_keys.get("curve25519", {}).items():
- signature = self.sign(base.canonical_json({"key": key}))
- keys[f"{EncryptionKeyAlgorithm.SIGNED_CURVE25519}:{key_id}"] = {
+ keys[str(KeyID.signed_curve25519(IdentityKey(key_id)))] = {
"key": key,
"signatures": {
- user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature}
+ user_id: {
+ str(KeyID.ed25519(device_id)): sign_olm({"key": key}, self),
+ }
},
}
- self.mark_keys_as_published()
return keys
diff --git a/mautrix/crypto/base.py b/mautrix/crypto/base.py
index 4d12e6cf..f7015ca1 100644
--- a/mautrix/crypto/base.py
+++ b/mautrix/crypto/base.py
@@ -1,41 +1,34 @@
-# Copyright (c) 2022 Tulir Asokan
+# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Any, Awaitable, Callable, TypedDict
+from typing import Awaitable, Callable
import asyncio
-import functools
-import json
-
-import olm
+from mautrix.errors import MForbidden, MNotFound
from mautrix.types import (
- DeviceID,
- EncryptionKeyAlgorithm,
+ EventType,
IdentityKey,
- KeyID,
RequestedKeyInfo,
+ RoomEncryptionStateEventContent,
RoomID,
+ RoomKeyEventContent,
SessionID,
- SigningKey,
TrustState,
UserID,
)
from mautrix.util.logging import TraceLogger
from .. import client as cli, crypto
-
-
-class SignedObject(TypedDict):
- signatures: dict[UserID, dict[str, str]]
- unsigned: Any
+from .ssss import Machine as SSSSMachine
class BaseOlmMachine:
client: cli.Client
+ ssss: SSSSMachine
log: TraceLogger
crypto_store: crypto.CryptoStore
state_store: crypto.StateStore
@@ -46,6 +39,14 @@ class BaseOlmMachine:
share_keys_min_trust: TrustState
allow_key_share: Callable[[crypto.DeviceIdentity, RequestedKeyInfo], Awaitable[bool]]
+ delete_outbound_keys_on_ack: bool
+ dont_store_outbound_keys: bool
+ delete_previous_keys_on_receive: bool
+ ratchet_keys_on_decrypt: bool
+ delete_fully_used_keys_on_decrypt: bool
+ delete_keys_on_device_delete: bool
+ disable_device_change_key_rotation: bool
+
# Futures that wait for responses to a key request
_key_request_waiters: dict[SessionID, asyncio.Future]
# Futures that wait for a session to be received (either normally or through a key request)
@@ -53,6 +54,9 @@ class BaseOlmMachine:
_prev_unwedge: dict[IdentityKey, float]
_fetch_keys_lock: asyncio.Lock
+ _megolm_decrypt_lock: asyncio.Lock
+ _share_keys_lock: asyncio.Lock
+ _last_key_share: float
_cs_fetch_attempted: set[UserID]
async def wait_for_session(
@@ -74,26 +78,30 @@ def _mark_session_received(self, session_id: SessionID) -> None:
except KeyError:
return
-
-canonical_json = functools.partial(
- json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True
-)
-
-
-def verify_signature_json(
- data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey
-) -> bool:
- data_copy = {**data}
- data_copy.pop("unsigned", None)
- signatures = data_copy.pop("signatures")
- key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name))
- try:
- signature = signatures[user_id][key_id]
- except KeyError:
- return False
- signed_data = canonical_json(data_copy)
- try:
- olm.ed25519_verify(key, signed_data, signature)
- return True
- except olm.OlmVerifyError:
- return False
+ async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None:
+ encryption_info = await self.state_store.get_encryption_info(evt.room_id)
+ if not encryption_info:
+ self.log.warning(
+ f"Encryption info for {evt.room_id} not found in state store, fetching from server"
+ )
+ try:
+ encryption_info = await self.client.get_state_event(
+ evt.room_id, EventType.ROOM_ENCRYPTION
+ )
+ except (MNotFound, MForbidden) as e:
+ self.log.warning(
+ f"Failed to get encryption info for {evt.room_id} from server: {e},"
+ " using defaults"
+ )
+ encryption_info = RoomEncryptionStateEventContent()
+ if not encryption_info:
+ self.log.warning(
+ f"Didn't find encryption info for {evt.room_id} on server either,"
+ " using defaults"
+ )
+ encryption_info = RoomEncryptionStateEventContent()
+
+ if not evt.beeper_max_age_ms:
+ evt.beeper_max_age_ms = encryption_info.rotation_period_ms
+ if not evt.beeper_max_messages:
+ evt.beeper_max_messages = encryption_info.rotation_period_msgs
diff --git a/mautrix/crypto/cross_signing.py b/mautrix/crypto/cross_signing.py
new file mode 100644
index 00000000..7577cd5d
--- /dev/null
+++ b/mautrix/crypto/cross_signing.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from ..types import (
+ JSON,
+ CrossSigner,
+ CrossSigningKeys,
+ CrossSigningUsage,
+ DeviceIdentity,
+ EventType,
+ KeyID,
+ UserID,
+)
+from .cross_signing_key import CrossSigningPrivateKeys, CrossSigningPublicKeys, CrossSigningSeeds
+from .device_lists import DeviceListMachine
+from .signature import sign_olm
+from .ssss import Key as SSSSKey
+
+
+class CrossSigningMachine(DeviceListMachine):
+ _cross_signing_public_keys: CrossSigningPublicKeys | None
+ _cross_signing_public_keys_fetched: bool
+ _cross_signing_private_keys: CrossSigningPrivateKeys | None
+
+ async def verify_with_recovery_key(self, recovery_key: str) -> None:
+ if not self.account.shared:
+ raise ValueError("Device keys must be shared before verifying with recovery key")
+ key_id, key_data = await self.ssss.get_default_key_data()
+ ssss_key = key_data.verify_recovery_key(key_id, recovery_key)
+ seeds = await self._fetch_cross_signing_keys_from_ssss(ssss_key)
+ self._import_cross_signing_keys(seeds)
+ await self.sign_own_device(self.own_identity)
+
+ def _import_cross_signing_keys(self, seeds: CrossSigningSeeds) -> None:
+ self._cross_signing_private_keys = seeds.to_keys()
+ self._cross_signing_public_keys = self._cross_signing_private_keys.public_keys
+
+ async def generate_recovery_key(
+ self, passphrase: str | None = None, seeds: CrossSigningSeeds | None = None
+ ) -> str:
+ if not self.account.shared:
+ raise ValueError("Device keys must be shared before generating recovery key")
+ seeds = seeds or CrossSigningSeeds.generate()
+ ssss_key = await self.ssss.generate_and_upload_key(passphrase)
+ await self._upload_cross_signing_keys_to_ssss(ssss_key, seeds)
+ await self._publish_cross_signing_keys(seeds.to_keys())
+ await self.ssss.set_default_key_id(ssss_key.id)
+ await self.sign_own_device(self.own_identity)
+ return ssss_key.recovery_key
+
+ async def _fetch_cross_signing_keys_from_ssss(self, key: SSSSKey) -> CrossSigningSeeds:
+ return CrossSigningSeeds(
+ master_key=await self.ssss.get_decrypted_account_data(
+ EventType.CROSS_SIGNING_MASTER, key
+ ),
+ user_signing_key=await self.ssss.get_decrypted_account_data(
+ EventType.CROSS_SIGNING_USER_SIGNING, key
+ ),
+ self_signing_key=await self.ssss.get_decrypted_account_data(
+ EventType.CROSS_SIGNING_SELF_SIGNING, key
+ ),
+ )
+
+ async def _upload_cross_signing_keys_to_ssss(
+ self, key: SSSSKey, seeds: CrossSigningSeeds
+ ) -> None:
+ await self.ssss.set_encrypted_account_data(
+ EventType.CROSS_SIGNING_MASTER, seeds.master_key, key
+ )
+ await self.ssss.set_encrypted_account_data(
+ EventType.CROSS_SIGNING_USER_SIGNING, seeds.user_signing_key, key
+ )
+ await self.ssss.set_encrypted_account_data(
+ EventType.CROSS_SIGNING_SELF_SIGNING, seeds.self_signing_key, key
+ )
+
+ async def get_own_cross_signing_public_keys(self) -> CrossSigningPublicKeys | None:
+ if self._cross_signing_public_keys or self._cross_signing_public_keys_fetched:
+ return self._cross_signing_public_keys
+ keys = await self.get_cross_signing_public_keys(self.client.mxid)
+ self._cross_signing_public_keys_fetched = True
+ if keys:
+ self._cross_signing_public_keys = keys
+ return keys
+
+ async def get_cross_signing_public_keys(
+ self, user_id: UserID
+ ) -> CrossSigningPublicKeys | None:
+ db_keys = await self.crypto_store.get_cross_signing_keys(user_id)
+ if CrossSigningUsage.MASTER not in db_keys and user_id not in self._cs_fetch_attempted:
+ self.log.debug(f"Didn't find any cross-signing keys for {user_id}, fetching...")
+ async with self._fetch_keys_lock:
+ if user_id not in self._cs_fetch_attempted:
+ self._cs_fetch_attempted.add(user_id)
+ await self._fetch_keys([user_id], include_untracked=True)
+ db_keys = await self.crypto_store.get_cross_signing_keys(user_id)
+ if CrossSigningUsage.MASTER not in db_keys:
+ return None
+ return CrossSigningPublicKeys(
+ master_key=db_keys[CrossSigningUsage.MASTER].key,
+ self_signing_key=(
+ db_keys[CrossSigningUsage.SELF].key if CrossSigningUsage.SELF in db_keys else None
+ ),
+ user_signing_key=(
+ db_keys[CrossSigningUsage.USER].key if CrossSigningUsage.USER in db_keys else None
+ ),
+ )
+
+ async def sign_own_device(self, device: DeviceIdentity) -> None:
+ full_keys = await self._get_full_device_keys(device)
+ ssk = self._cross_signing_private_keys.self_signing_key
+ signature = sign_olm(full_keys, ssk)
+ full_keys.signatures = {self.client.mxid: {KeyID.ed25519(ssk.public_key): signature}}
+ await self.client.upload_one_signature(device.user_id, device.device_id, full_keys)
+ await self.crypto_store.put_signature(
+ CrossSigner(device.user_id, device.signing_key),
+ CrossSigner(self.client.mxid, ssk.public_key),
+ signature,
+ )
+
+ async def _publish_cross_signing_keys(
+ self,
+ keys: CrossSigningPrivateKeys,
+ auth: dict[str, JSON] | None = None,
+ ) -> None:
+ public = keys.public_keys
+ master_key = CrossSigningKeys(
+ user_id=self.client.mxid,
+ usage=[CrossSigningUsage.MASTER],
+ keys={KeyID.ed25519(public.master_key): public.master_key},
+ )
+ master_key.signatures = {
+ self.client.mxid: {
+ KeyID.ed25519(self.client.device_id): sign_olm(master_key, self.account),
+ }
+ }
+ self_key = CrossSigningKeys(
+ user_id=self.client.mxid,
+ usage=[CrossSigningUsage.SELF],
+ keys={KeyID.ed25519(public.self_signing_key): public.self_signing_key},
+ )
+ self_key.signatures = {
+ self.client.mxid: {
+ KeyID.ed25519(public.master_key): sign_olm(self_key, keys.master_key),
+ }
+ }
+ user_key = CrossSigningKeys(
+ user_id=self.client.mxid,
+ usage=[CrossSigningUsage.USER],
+ keys={KeyID.ed25519(public.user_signing_key): public.user_signing_key},
+ )
+ user_key.signatures = {
+ self.client.mxid: {
+ KeyID.ed25519(public.master_key): sign_olm(user_key, keys.master_key),
+ }
+ }
+ await self.client.upload_cross_signing_keys(
+ keys={
+ CrossSigningUsage.MASTER: master_key,
+ CrossSigningUsage.SELF: self_key,
+ CrossSigningUsage.USER: user_key,
+ },
+ auth=auth,
+ )
+ await self.crypto_store.put_cross_signing_key(
+ self.client.mxid, CrossSigningUsage.MASTER, public.master_key
+ )
+ await self.crypto_store.put_cross_signing_key(
+ self.client.mxid, CrossSigningUsage.SELF, public.self_signing_key
+ )
+ await self.crypto_store.put_cross_signing_key(
+ self.client.mxid, CrossSigningUsage.USER, public.user_signing_key
+ )
+ self._cross_signing_private_keys = keys
+ self._cross_signing_public_keys = public
diff --git a/mautrix/crypto/cross_signing_key.py b/mautrix/crypto/cross_signing_key.py
new file mode 100644
index 00000000..f4e3c1c1
--- /dev/null
+++ b/mautrix/crypto/cross_signing_key.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from typing import NamedTuple
+
+import olm
+
+from mautrix.crypto.ssss.util import cryptorand
+from mautrix.types import SigningKey
+
+
+class CrossSigningPublicKeys(NamedTuple):
+ master_key: SigningKey
+ self_signing_key: SigningKey
+ user_signing_key: SigningKey
+
+
+class CrossSigningPrivateKeys(NamedTuple):
+ master_key: olm.PkSigning
+ self_signing_key: olm.PkSigning
+ user_signing_key: olm.PkSigning
+
+ @property
+ def public_keys(self) -> CrossSigningPublicKeys:
+ return CrossSigningPublicKeys(
+ master_key=self.master_key.public_key,
+ self_signing_key=self.self_signing_key.public_key,
+ user_signing_key=self.user_signing_key.public_key,
+ )
+
+
+class CrossSigningSeeds(NamedTuple):
+ master_key: bytes
+ self_signing_key: bytes
+ user_signing_key: bytes
+
+ def to_keys(self) -> CrossSigningPrivateKeys:
+ return CrossSigningPrivateKeys(
+ master_key=olm.PkSigning(self.master_key),
+ self_signing_key=olm.PkSigning(self.self_signing_key),
+ user_signing_key=olm.PkSigning(self.user_signing_key),
+ )
+
+ @classmethod
+ def generate(cls) -> "CrossSigningSeeds":
+ return cls(
+ master_key=cryptorand.read(32),
+ self_signing_key=cryptorand.read(32),
+ user_signing_key=cryptorand.read(32),
+ )
diff --git a/mautrix/crypto/decrypt_megolm.py b/mautrix/crypto/decrypt_megolm.py
index 8edf5aaf..fd7a7eef 100644
--- a/mautrix/crypto/decrypt_megolm.py
+++ b/mautrix/crypto/decrypt_megolm.py
@@ -25,6 +25,7 @@
)
from .device_lists import DeviceListMachine
+from .sessions import InboundGroupSession
class MegolmDecryptionMachine(DeviceListMachine):
@@ -45,18 +46,22 @@ async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event:
raise DecryptionError("Unsupported event content class")
elif evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1:
raise DecryptionError("Unsupported event encryption algorithm")
- session = await self.crypto_store.get_group_session(evt.room_id, evt.content.session_id)
- if session is None:
- # TODO check if olm session is wedged
- raise SessionNotFound(evt.content.session_id, evt.content.sender_key)
- try:
- plaintext, index = session.decrypt(evt.content.ciphertext)
- except olm.OlmGroupSessionError as e:
- raise DecryptionError("Failed to decrypt megolm event") from e
- if not await self.crypto_store.validate_message_index(
- session.sender_key, SessionID(session.id), evt.event_id, index, evt.timestamp
- ):
- raise DuplicateMessageIndex()
+ async with self._megolm_decrypt_lock:
+ session = await self.crypto_store.get_group_session(
+ evt.room_id, evt.content.session_id
+ )
+ if session is None:
+ # TODO check if olm session is wedged
+ raise SessionNotFound(evt.content.session_id, evt.content.sender_key)
+ try:
+ plaintext, index = session.decrypt(evt.content.ciphertext)
+ except olm.OlmGroupSessionError as e:
+ raise DecryptionError("Failed to decrypt megolm event") from e
+ if not await self.crypto_store.validate_message_index(
+ session.sender_key, SessionID(session.id), evt.event_id, index, evt.timestamp
+ ):
+ raise DuplicateMessageIndex()
+ await self._ratchet_session(session, index)
forwarded_keys = False
if (
@@ -133,3 +138,55 @@ async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event:
"was_encrypted": True,
}
return result
+
+ async def _ratchet_session(self, sess: InboundGroupSession, index: int) -> None:
+ expected_message_index = sess.ratchet_safety.next_index
+ did_modify = True
+ if index > expected_message_index:
+ sess.ratchet_safety.missed_indices += list(range(expected_message_index, index))
+ sess.ratchet_safety.next_index = index + 1
+ elif index == expected_message_index:
+ sess.ratchet_safety.next_index = index + 1
+ else:
+ try:
+ sess.ratchet_safety.missed_indices.remove(index)
+ except ValueError:
+ did_modify = False
+ # Use presence of received_at as a sign that this is a recent megolm session,
+ # and therefore it's safe to drop missed indices entirely.
+ if (
+ sess.received_at
+ and sess.ratchet_safety.missed_indices
+ and sess.ratchet_safety.missed_indices[0] < expected_message_index - 10
+ ):
+ i = 0
+ for i, lost_index in enumerate(sess.ratchet_safety.missed_indices):
+ if lost_index < expected_message_index - 10:
+ sess.ratchet_safety.lost_indices.append(lost_index)
+ else:
+ break
+ sess.ratchet_safety.missed_indices = sess.ratchet_safety.missed_indices[i + 1 :]
+ ratchet_target_index = sess.ratchet_safety.next_index
+ if len(sess.ratchet_safety.missed_indices) > 0:
+ ratchet_target_index = min(sess.ratchet_safety.missed_indices)
+ self.log.debug(
+ f"Ratchet safety info for {sess.id}: {sess.ratchet_safety}, {ratchet_target_index=}"
+ )
+ sess_id = SessionID(sess.id)
+ if (
+ sess.max_messages
+ and ratchet_target_index >= sess.max_messages
+ and not sess.ratchet_safety.missed_indices
+ and self.delete_fully_used_keys_on_decrypt
+ ):
+ self.log.info(f"Deleting fully used session {sess.id}")
+ await self.crypto_store.redact_group_session(
+ sess.room_id, sess_id, reason="maximum messages reached"
+ )
+ return
+ elif sess.first_known_index < ratchet_target_index and self.ratchet_keys_on_decrypt:
+ self.log.info(f"Ratcheting session {sess.id} to {ratchet_target_index}")
+ sess = sess.ratchet_to(ratchet_target_index)
+ elif not did_modify:
+ return
+ await self.crypto_store.put_group_session(sess.room_id, sess.sender_key, sess_id, sess)
diff --git a/mautrix/crypto/decrypt_olm.py b/mautrix/crypto/decrypt_olm.py
index 8182b160..6eef76f5 100644
--- a/mautrix/crypto/decrypt_olm.py
+++ b/mautrix/crypto/decrypt_olm.py
@@ -19,6 +19,7 @@
ToDeviceEvent,
UserID,
)
+from mautrix.util import background_task
from .base import BaseOlmMachine
from .sessions import Session
@@ -74,19 +75,19 @@ async def _decrypt_olm_ciphertext(
f"Found matching session yet decryption failed for sender {sender}"
f" with key {sender_key}"
)
- asyncio.create_task(self._unwedge_session(sender, sender_key))
+ background_task.create(self._unwedge_session(sender, sender_key))
raise
if not plaintext:
if message.type != OlmMsgType.PREKEY:
- asyncio.create_task(self._unwedge_session(sender, sender_key))
+ background_task.create(self._unwedge_session(sender, sender_key))
raise DecryptionError("Decryption failed for normal message")
self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}")
try:
session = await self._create_inbound_session(sender_key, message.body)
except olm.OlmSessionError as e:
- asyncio.create_task(self._unwedge_session(sender, sender_key))
+ background_task.create(self._unwedge_session(sender, sender_key))
raise DecryptionError("Failed to create new session from prekey message") from e
self.log.debug(
f"Created inbound session {session.id} for {sender} (sender key: {sender_key})"
diff --git a/mautrix/crypto/device_lists.py b/mautrix/crypto/device_lists.py
index c0cea43f..b00b104c 100644
--- a/mautrix/crypto/device_lists.py
+++ b/mautrix/crypto/device_lists.py
@@ -23,10 +23,23 @@
UserID,
)
-from .base import BaseOlmMachine, verify_signature_json
+from .base import BaseOlmMachine
+from .signature import verify_signature_json
class DeviceListMachine(BaseOlmMachine):
+ @property
+ def own_identity(self) -> DeviceIdentity:
+ return DeviceIdentity(
+ user_id=self.client.mxid,
+ device_id=self.client.device_id,
+ identity_key=self.account.identity_key,
+ signing_key=self.account.signing_key,
+ trust=TrustState.VERIFIED,
+ deleted=False,
+ name="",
+ )
+
async def _fetch_keys(
self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False
) -> dict[UserID, dict[DeviceID, DeviceIdentity]]:
@@ -46,51 +59,67 @@ async def _fetch_keys(
data = {}
for user_id, devices in resp.device_keys.items():
missing_users.remove(user_id)
-
- new_devices = {}
- existing_devices = (await self.crypto_store.get_devices(user_id)) or {}
-
- self.log.trace(
- f"Updating devices for {user_id}, got {len(devices)}, "
- f"have {len(existing_devices)} in store"
- )
- changed = False
- ssks = resp.self_signing_keys.get(user_id)
- ssk = ssks.first_ed25519_key if ssks else None
- for device_id, device_keys in devices.items():
- try:
- existing = existing_devices[device_id]
- except KeyError:
- existing = None
- changed = True
- self.log.trace(f"Validating device {device_keys} of {user_id}")
- try:
- new_device = await self._validate_device(
- user_id, device_id, device_keys, existing
- )
- except DeviceValidationError as e:
- self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
- else:
- if new_device:
- new_devices[device_id] = new_device
- await self._store_device_self_signatures(device_keys, ssk)
- self.log.debug(
- f"Storing new device list for {user_id} containing {len(new_devices)} devices"
- )
- await self.crypto_store.put_devices(user_id, new_devices)
- data[user_id] = new_devices
-
- if changed or len(new_devices) != len(existing_devices):
- await self.on_devices_changed(user_id)
+ async with self.crypto_store.transaction():
+ data[user_id] = await self._process_fetched_keys(user_id, devices, resp)
for user_id in missing_users:
self.log.warning(f"Didn't get any devices for user {user_id}")
- for user_id in users:
- await self._store_cross_signing_keys(resp, user_id)
-
return data
+ async def _process_fetched_keys(
+ self,
+ user_id: UserID,
+ devices: dict[DeviceID, DeviceKeys],
+ resp: QueryKeysResponse,
+ ) -> dict[DeviceID, DeviceIdentity]:
+ new_devices = {}
+ existing_devices = (await self.crypto_store.get_devices(user_id)) or {}
+
+ self.log.trace(
+ f"Updating devices for {user_id}, got {len(devices)}, "
+ f"have {len(existing_devices)} in store"
+ )
+ changed = False
+ ssks = resp.self_signing_keys.get(user_id)
+ ssk = ssks.first_ed25519_key if ssks else None
+ for device_id, device_keys in devices.items():
+ try:
+ existing = existing_devices[device_id]
+ except KeyError:
+ existing = None
+ changed = True
+ self.log.trace(f"Validating device {device_keys} of {user_id}")
+ try:
+ new_device = await self._validate_device(user_id, device_id, device_keys, existing)
+ except DeviceValidationError as e:
+ self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
+ else:
+ if new_device:
+ new_devices[device_id] = new_device
+ await self._store_device_self_signatures(device_keys, ssk)
+ self.log.debug(
+ f"Storing new device list for {user_id} containing {len(new_devices)} devices"
+ )
+ await self.crypto_store.put_devices(user_id, new_devices)
+
+ if changed or len(new_devices) != len(existing_devices):
+ if self.delete_keys_on_device_delete:
+ for device_id in existing_devices.keys() - new_devices.keys():
+ device = existing_devices[device_id]
+ removed_ids = await self.crypto_store.redact_group_sessions(
+ room_id=None, sender_key=device.identity_key, reason="device removed"
+ )
+ self.log.info(
+ "Redacted megolm sessions sent by removed device "
+ f"{device.user_id}/{device.device_id}: {removed_ids}"
+ )
+ await self.on_devices_changed(user_id)
+
+ await self._store_cross_signing_keys(resp, user_id)
+
+ return new_devices
+
async def _store_device_self_signatures(
self, device_keys: DeviceKeys, self_signing_key: SigningKey | None
) -> None:
@@ -183,7 +212,7 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User
signing_key = device.ed25519
except KeyError:
pass
- if len(signing_key) != 43:
+ if not signing_key or len(signing_key) != 43:
self.log.debug(
f"Cross-signing key {user_id}/{actual_key} has a signature from "
f"an unknown key {key_id}"
@@ -209,6 +238,12 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User
else:
self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}")
+ async def _get_full_device_keys(self, device: DeviceIdentity) -> DeviceKeys:
+ resp = await self.client.query_keys({device.user_id: [device.device_id]})
+ keys = resp.device_keys[device.user_id][device.device_id]
+ await self._validate_device(device.user_id, device.device_id, keys, device)
+ return keys
+
async def get_or_fetch_device(
self, user_id: UserID, device_id: DeviceID
) -> DeviceIdentity | None:
@@ -234,6 +269,8 @@ async def get_or_fetch_device_by_key(
return None
async def on_devices_changed(self, user_id: UserID) -> None:
+ if self.disable_device_change_key_rotation:
+ return
shared_rooms = await self.state_store.find_shared_rooms(user_id)
self.log.debug(
f"Devices of {user_id} changed, invalidating group session in {shared_rooms}"
@@ -284,18 +321,23 @@ async def _validate_device(
deleted=False,
)
- async def resolve_trust(self, device: DeviceIdentity) -> TrustState:
+ async def resolve_trust(self, device: DeviceIdentity, allow_fetch: bool = True) -> TrustState:
try:
- return await self._try_resolve_trust(device)
+ return await self._try_resolve_trust(device, allow_fetch)
except Exception:
self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}")
return TrustState.UNVERIFIED
- async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState:
- if device.trust in (TrustState.VERIFIED, TrustState.BLACKLISTED):
+ async def _try_resolve_trust(
+ self, device: DeviceIdentity, allow_fetch: bool = True
+ ) -> TrustState:
+ if device.device_id != self.client.device_id and device.trust in (
+ TrustState.VERIFIED,
+ TrustState.BLACKLISTED,
+ ):
return device.trust
their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id)
- if len(their_keys) == 0 and device.user_id not in self._cs_fetch_attempted:
+ if len(their_keys) == 0 and allow_fetch and device.user_id not in self._cs_fetch_attempted:
self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...")
async with self._fetch_keys_lock:
if device.user_id not in self._cs_fetch_attempted:
@@ -306,7 +348,8 @@ async def _try_resolve_trust(self, device: DeviceIdentity) -> TrustState:
msk = their_keys[CrossSigningUsage.MASTER]
ssk = their_keys[CrossSigningUsage.SELF]
except KeyError as e:
- self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
+ if allow_fetch:
+ self.log.warning(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
return TrustState.UNVERIFIED
ssk_signed = await self.crypto_store.is_key_signed_by(
target=CrossSigner(device.user_id, ssk.key),
diff --git a/mautrix/crypto/encrypt_megolm.py b/mautrix/crypto/encrypt_megolm.py
index 9b37e486..459bbf1f 100644
--- a/mautrix/crypto/encrypt_megolm.py
+++ b/mautrix/crypto/encrypt_megolm.py
@@ -5,7 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Dict, List, Tuple, Union
from collections import defaultdict
-from datetime import timedelta
+from datetime import datetime, timedelta
import asyncio
import json
import time
@@ -95,9 +95,9 @@ async def _encrypt_megolm_event(
{
"room_id": room_id,
"type": event_type.serialize(),
- "content": content.serialize()
- if isinstance(content, Serializable)
- else content,
+ "content": (
+ content.serialize() if isinstance(content, Serializable) else content
+ ),
}
)
)
@@ -173,21 +173,6 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No
session = await self._new_outbound_group_session(room_id)
self.log.debug(f"Sharing group session {session.id} for room {room_id} with {users}")
- encryption_info = await self.state_store.get_encryption_info(room_id)
- if encryption_info:
- if encryption_info.algorithm != EncryptionAlgorithm.MEGOLM_V1:
- raise SessionShareError("Room encryption algorithm is not supported")
- session.max_messages = encryption_info.rotation_period_msgs or session.max_messages
- session.max_age = (
- timedelta(milliseconds=encryption_info.rotation_period_ms)
- if encryption_info.rotation_period_ms
- else session.max_age
- )
- self.log.debug(
- "Got stored encryption state event and configured session to rotate "
- f"after {session.max_messages} messages or {session.max_age}"
- )
-
olm_sessions: DeviceMap = defaultdict(lambda: {})
withhold_key_msgs = defaultdict(lambda: {})
missing_sessions: Dict[UserID, Dict[DeviceID, DeviceIdentity]] = defaultdict(lambda: {})
@@ -253,13 +238,33 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No
async def _new_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession:
session = OutboundGroupSession(room_id)
- await self._create_group_session(
- self.account.identity_key,
- self.account.signing_key,
- room_id,
- SessionID(session.id),
- session.session_key,
- )
+
+ encryption_info = await self.state_store.get_encryption_info(room_id)
+ if encryption_info:
+ if encryption_info.algorithm != EncryptionAlgorithm.MEGOLM_V1:
+ raise SessionShareError("Room encryption algorithm is not supported")
+ session.max_messages = encryption_info.rotation_period_msgs or session.max_messages
+ session.max_age = (
+ timedelta(milliseconds=encryption_info.rotation_period_ms)
+ if encryption_info.rotation_period_ms
+ else session.max_age
+ )
+ self.log.debug(
+ "Got stored encryption state event and configured session to rotate "
+ f"after {session.max_messages} messages or {session.max_age}"
+ )
+
+ if not self.dont_store_outbound_keys:
+ await self._create_group_session(
+ self.account.identity_key,
+ self.account.signing_key,
+ room_id,
+ SessionID(session.id),
+ session.session_key,
+ max_messages=session.max_messages,
+ max_age=session.max_age,
+ is_scheduled=False,
+ )
return session
async def _encrypt_and_share_group_session(
@@ -286,6 +291,9 @@ async def _create_group_session(
room_id: RoomID,
session_id: SessionID,
session_key: str,
+ max_age: Union[timedelta, int],
+ max_messages: int,
+ is_scheduled: bool = False,
) -> None:
start = time.monotonic()
session = InboundGroupSession(
@@ -293,6 +301,10 @@ async def _create_group_session(
signing_key=signing_key,
sender_key=sender_key,
room_id=room_id,
+ received_at=datetime.utcnow(),
+ max_age=max_age,
+ max_messages=max_messages,
+ is_scheduled=is_scheduled,
)
olm_duration = time.monotonic() - start
if olm_duration > 5:
@@ -302,7 +314,10 @@ async def _create_group_session(
session_id = session.id
await self.crypto_store.put_group_session(room_id, sender_key, session_id, session)
self._mark_session_received(session_id)
- self.log.debug(f"Created inbound group session {room_id}/{sender_key}/{session_id}")
+ self.log.debug(
+ f"Created inbound group session {room_id}/{sender_key}/{session_id} "
+ f"(max {max_age} / {max_messages} messages, {is_scheduled=})"
+ )
async def _find_olm_sessions(
self,
diff --git a/mautrix/crypto/encrypt_olm.py b/mautrix/crypto/encrypt_olm.py
index 620cbbc8..029ad1e5 100644
--- a/mautrix/crypto/encrypt_olm.py
+++ b/mautrix/crypto/encrypt_olm.py
@@ -18,8 +18,9 @@
UserID,
)
-from .base import BaseOlmMachine, verify_signature_json
+from .base import BaseOlmMachine
from .sessions import Session
+from .signature import verify_signature_json
ClaimKeysList = Dict[UserID, Dict[DeviceID, DeviceIdentity]]
diff --git a/mautrix/crypto/key_request.py b/mautrix/crypto/key_request.py
index 9bbd2d67..ecaf994d 100644
--- a/mautrix/crypto/key_request.py
+++ b/mautrix/crypto/key_request.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 Tulir Asokan
+# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -120,9 +120,18 @@ async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None:
f"{evt.sender_device}, as crypto store says we have it already"
)
return
+ if not key.beeper_max_messages or not key.beeper_max_age_ms:
+ await self._fill_encryption_info(key)
key.forwarding_key_chain.append(evt.sender_key)
sess = InboundGroupSession.import_session(
- key.session_key, key.signing_key, key.sender_key, key.room_id, key.forwarding_key_chain
+ key.session_key,
+ key.signing_key,
+ key.sender_key,
+ key.room_id,
+ key.forwarding_key_chain,
+ max_age=key.beeper_max_age_ms,
+ max_messages=key.beeper_max_messages,
+ is_scheduled=key.beeper_is_scheduled,
)
if key.session_id != sess.id:
self.log.warning(
diff --git a/mautrix/crypto/key_share.py b/mautrix/crypto/key_share.py
index 58525bbf..8d6b6e73 100644
--- a/mautrix/crypto/key_share.py
+++ b/mautrix/crypto/key_share.py
@@ -76,7 +76,7 @@ async def default_allow_key_share(
code=RoomKeyWithheldCode.BLACKLISTED,
reason="You have been blacklisted by this device",
)
- elif device.trust >= self.share_keys_min_trust:
+ elif await self.resolve_trust(device) >= self.share_keys_min_trust:
self.log.debug(f"Accepting key request from trusted device {device.device_id}")
return True
else:
diff --git a/mautrix/crypto/machine.py b/mautrix/crypto/machine.py
index a1fd4e16..60c65677 100644
--- a/mautrix/crypto/machine.py
+++ b/mautrix/crypto/machine.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 Tulir Asokan
+# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,8 +8,10 @@
from typing import Optional
import asyncio
import logging
+import time
from mautrix import client as cli
+from mautrix.errors import GroupSessionWithheldError
from mautrix.types import (
ASToDeviceEvent,
DecryptedOlmEvent,
@@ -17,20 +19,25 @@
DeviceLists,
DeviceOTKCount,
EncryptionAlgorithm,
+ EncryptionKeyAlgorithm,
EventType,
+ Member,
Membership,
StateEvent,
ToDeviceEvent,
TrustState,
UserID,
)
+from mautrix.util import background_task
from mautrix.util.logging import TraceLogger
from .account import OlmAccount
+from .cross_signing import CrossSigningMachine
from .decrypt_megolm import MegolmDecryptionMachine
from .encrypt_megolm import MegolmEncryptionMachine
from .key_request import KeyRequestingMachine
from .key_share import KeySharingMachine
+from .ssss import Machine as SSSSMachine
from .store import CryptoStore, StateStore
from .unwedge import OlmUnwedgingMachine
@@ -41,6 +48,7 @@ class OlmMachine(
OlmUnwedgingMachine,
KeySharingMachine,
KeyRequestingMachine,
+ CrossSigningMachine,
):
"""
OlmMachine is the main class for handling things related to Matrix end-to-end encryption with
@@ -53,6 +61,7 @@ class OlmMachine(
log: TraceLogger
crypto_store: CryptoStore
state_store: StateStore
+ ssss: SSSSMachine
account: Optional[OlmAccount]
@@ -65,6 +74,7 @@ def __init__(
) -> None:
super().__init__()
self.client = client
+ self.ssss = SSSSMachine(client)
self.log = log or logging.getLogger("mau.crypto")
self.crypto_store = crypto_store
self.state_store = state_store
@@ -74,18 +84,34 @@ def __init__(
self.share_keys_min_trust = TrustState.CROSS_SIGNED_TOFU
self.allow_key_share = self.default_allow_key_share
+ self.delete_outbound_keys_on_ack = False
+ self.dont_store_outbound_keys = False
+ self.delete_previous_keys_on_receive = False
+ self.ratchet_keys_on_decrypt = False
+ self.delete_fully_used_keys_on_decrypt = False
+ self.delete_keys_on_device_delete = False
+ self.disable_device_change_key_rotation = False
+
self._fetch_keys_lock = asyncio.Lock()
+ self._megolm_decrypt_lock = asyncio.Lock()
+ self._share_keys_lock = asyncio.Lock()
+ self._last_key_share = time.monotonic() - 60
self._key_request_waiters = {}
self._inbound_session_waiters = {}
self._prev_unwedge = {}
self._cs_fetch_attempted = set()
+ self._cross_signing_public_keys = None
+ self._cross_signing_public_keys_fetched = False
+ self._cross_signing_private_keys = None
+
self.client.add_event_handler(
cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True
)
self.client.add_event_handler(cli.InternalEventType.DEVICE_LISTS, self.handle_device_lists)
self.client.add_event_handler(EventType.TO_DEVICE_ENCRYPTED, self.handle_to_device_event)
self.client.add_event_handler(EventType.ROOM_KEY_REQUEST, self.handle_room_key_request)
+ self.client.add_event_handler(EventType.BEEPER_ROOM_KEY_ACK, self.handle_beep_room_key_ack)
# self.client.add_event_handler(EventType.ROOM_KEY_WITHHELD, self.handle_room_key_withheld)
# self.client.add_event_handler(EventType.ORG_MATRIX_ROOM_KEY_WITHHELD,
# self.handle_room_key_withheld)
@@ -109,7 +135,7 @@ async def handle_as_otk_counts(
self.log.warning(f"Got OTK count for unknown device {user_id}/{device_id}")
async def handle_as_device_lists(self, device_lists: DeviceLists) -> None:
- asyncio.create_task(self.handle_device_lists(device_lists))
+ background_task.create(self.handle_device_lists(device_lists))
async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None:
if evt.to_user_id != self.client.mxid or evt.to_device_id != self.client.device_id:
@@ -121,6 +147,8 @@ async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None:
await self.handle_to_device_event(evt)
elif evt.type == EventType.ROOM_KEY_REQUEST:
await self.handle_room_key_request(evt)
+ elif evt.type == EventType.BEEPER_ROOM_KEY_ACK:
+ await self.handle_beep_room_key_ack(evt)
else:
self.log.debug(f"Got unknown to-device event {evt.type} from {evt.sender}")
@@ -170,9 +198,19 @@ async def handle_member_event(self, evt: StateEvent) -> None:
}
if prev == cur or ignored_changes.get(prev) == cur:
return
+ src = getattr(evt, "source", None)
+ prev_cache = evt.unsigned.get("mautrix_prev_membership")
+ if isinstance(prev_cache, Member) and prev_cache.membership == cur:
+ self.log.debug(
+ f"Got duplicate membership state event in {evt.room_id} changing {evt.state_key} "
+ f"from {prev} to {cur}, cached state was {prev_cache} (event ID: {evt.event_id}, "
+ f"sync source: {src})"
+ )
+ return
self.log.debug(
f"Got membership state event in {evt.room_id} changing {evt.state_key} from "
- f"{prev} to {cur}, invalidating group session"
+ f"{prev} to {cur} (event ID: {evt.event_id}, sync source: {src}, "
+ f"cached: {prev_cache.membership if prev_cache else None}), invalidating group session"
)
await self.crypto_store.remove_outbound_group_session(evt.room_id)
@@ -184,6 +222,11 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None:
passed to the OlmMachine is syncing. You shouldn't need to call this yourself unless you
do syncing in some manual way.
"""
+ if isinstance(evt, DecryptedOlmEvent):
+ self.log.warning(
+ f"Dropping unexpected nested encrypted to-device event from {evt.sender}"
+ )
+ return
self.log.trace(
f"Handling encrypted to-device event from {evt.content.sender_key} ({evt.sender})"
)
@@ -192,28 +235,81 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None:
await self._receive_room_key(decrypted_evt)
elif decrypted_evt.type == EventType.FORWARDED_ROOM_KEY:
await self._receive_forwarded_room_key(decrypted_evt)
+ else:
+ self.client.dispatch_event(decrypted_evt, source=evt.source)
async def _receive_room_key(self, evt: DecryptedOlmEvent) -> None:
# TODO nio had a comment saying "handle this better"
# for the case where evt.Keys.Ed25519 is none?
if evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1 or not evt.keys.ed25519:
return
+ if not evt.content.beeper_max_messages or not evt.content.beeper_max_age_ms:
+ await self._fill_encryption_info(evt.content)
+ if self.delete_previous_keys_on_receive and not evt.content.beeper_is_scheduled:
+ removed_ids = await self.crypto_store.redact_group_sessions(
+ evt.content.room_id, evt.sender_key, reason="received new key from device"
+ )
+ self.log.info(f"Redacted previous megolm sessions: {removed_ids}")
await self._create_group_session(
evt.sender_key,
evt.keys.ed25519,
evt.content.room_id,
evt.content.session_id,
evt.content.session_key,
+ max_age=evt.content.beeper_max_age_ms,
+ max_messages=evt.content.beeper_max_messages,
+ is_scheduled=evt.content.beeper_is_scheduled,
)
- async def share_keys(self, current_otk_count: int) -> None:
+ async def handle_beep_room_key_ack(self, evt: ToDeviceEvent) -> None:
+ try:
+ sess = await self.crypto_store.get_group_session(
+ evt.content.room_id, evt.content.session_id
+ )
+ except GroupSessionWithheldError:
+ self.log.debug(
+ f"Ignoring room key ack for session {evt.content.session_id}"
+ " that was already redacted"
+ )
+ return
+ if not sess:
+ self.log.debug(f"Ignoring room key ack for unknown session {evt.content.session_id}")
+ return
+ if (
+ sess.sender_key == self.account.identity_key
+ and self.delete_outbound_keys_on_ack
+ and evt.content.first_message_index == 0
+ ):
+ self.log.debug("Redacting inbound copy of outbound group session after ack")
+ await self.crypto_store.redact_group_session(
+ evt.content.room_id, evt.content.session_id, reason="outbound session acked"
+ )
+ else:
+ self.log.debug(f"Received room key ack for {sess.id}")
+
+ async def share_keys(self, current_otk_count: int | None = None) -> None:
"""
Share any keys that need to be shared. This is automatically called from
:meth:`handle_otk_count`, so you should not need to call this yourself.
Args:
current_otk_count: The current number of signed curve25519 keys present on the server.
+ If omitted, the count will be fetched from the server.
"""
+ async with self._share_keys_lock:
+ await self._share_keys(current_otk_count)
+
+ async def _share_keys(self, current_otk_count: int | None) -> None:
+ if current_otk_count is None or (
+ # If the last key share was recent and the new count is very low, re-check the count
+ # from the server to avoid any race conditions.
+ self._last_key_share + 60 > time.monotonic()
+ and current_otk_count < 10
+ ):
+ self.log.debug("Checking OTK count on server")
+ current_otk_count = (await self.client.upload_keys()).get(
+ EncryptionKeyAlgorithm.SIGNED_CURVE25519, 0
+ )
device_keys = (
self.account.get_device_keys(self.client.mxid, self.client.device_id)
if not self.account.shared
@@ -228,7 +324,9 @@ async def share_keys(self, current_otk_count: int) -> None:
if device_keys:
self.log.debug("Going to upload initial account keys")
self.log.debug(f"Uploading {len(one_time_keys)} one-time keys")
- await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys)
+ resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys)
self.account.shared = True
+ self.account.mark_keys_as_published()
+ self._last_key_share = time.monotonic()
await self.crypto_store.put_account(self.account)
- self.log.debug("Shared keys and saved account")
+ self.log.debug(f"Shared keys and saved account, new keys: {resp}")
diff --git a/mautrix/crypto/sessions.py b/mautrix/crypto/sessions.py
index c8164b27..b0b16a18 100644
--- a/mautrix/crypto/sessions.py
+++ b/mautrix/crypto/sessions.py
@@ -3,10 +3,11 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import List, Optional, Set, Tuple, cast
+from typing import List, Optional, Set, Tuple, Union, cast
from datetime import datetime, timedelta
from _libolm import ffi, lib
+from attr import dataclass
import olm
from mautrix.errors import EncryptionError
@@ -18,8 +19,10 @@
OlmMsgType,
RoomID,
RoomKeyEventContent,
+ SerializableAttrs,
SigningKey,
UserID,
+ field,
)
@@ -93,12 +96,25 @@ def describe(self) -> str:
return "describe not supported"
+@dataclass
+class RatchetSafety(SerializableAttrs):
+ next_index: int = 0
+ missed_indices: List[int] = field(factory=lambda: [])
+ lost_indices: List[int] = field(factory=lambda: [])
+
+
class InboundGroupSession(olm.InboundGroupSession):
room_id: RoomID
signing_key: SigningKey
sender_key: IdentityKey
forwarding_chain: List[IdentityKey]
+ ratchet_safety: RatchetSafety
+ received_at: datetime
+ max_age: timedelta
+ max_messages: int
+ is_scheduled: bool
+
def __init__(
self,
session_key: str,
@@ -106,11 +122,23 @@ def __init__(
sender_key: IdentityKey,
room_id: RoomID,
forwarding_chain: Optional[List[IdentityKey]] = None,
+ ratchet_safety: Optional[RatchetSafety] = None,
+ received_at: Optional[datetime] = None,
+ max_age: Union[timedelta, int, None] = None,
+ max_messages: Optional[int] = None,
+ is_scheduled: bool = False,
) -> None:
self.signing_key = signing_key
self.sender_key = sender_key
self.room_id = room_id
self.forwarding_chain = forwarding_chain or []
+ self.ratchet_safety = ratchet_safety or RatchetSafety()
+ self.received_at = received_at or datetime.utcnow()
+ if isinstance(max_age, int):
+ max_age = timedelta(milliseconds=max_age)
+ self.max_age = max_age
+ self.max_messages = max_messages
+ self.is_scheduled = is_scheduled
super().__init__(session_key)
def __new__(cls, *args, **kwargs):
@@ -125,12 +153,22 @@ def from_pickle(
sender_key: IdentityKey,
room_id: RoomID,
forwarding_chain: Optional[List[IdentityKey]] = None,
+ ratchet_safety: Optional[RatchetSafety] = None,
+ received_at: Optional[datetime] = None,
+ max_age: Optional[timedelta] = None,
+ max_messages: Optional[int] = None,
+ is_scheduled: bool = False,
) -> "InboundGroupSession":
session = super().from_pickle(pickle, passphrase)
session.signing_key = signing_key
session.sender_key = sender_key
session.room_id = room_id
session.forwarding_chain = forwarding_chain or []
+ session.ratchet_safety = ratchet_safety or RatchetSafety()
+ session.received_at = received_at
+ session.max_age = max_age
+ session.max_messages = max_messages
+ session.is_scheduled = is_scheduled
return session
@classmethod
@@ -141,14 +179,41 @@ def import_session(
sender_key: IdentityKey,
room_id: RoomID,
forwarding_chain: Optional[List[str]] = None,
+ ratchet_safety: Optional[RatchetSafety] = None,
+ received_at: Optional[datetime] = None,
+ max_age: Union[timedelta, int, None] = None,
+ max_messages: Optional[int] = None,
+ is_scheduled: bool = False,
) -> "InboundGroupSession":
session = super().import_session(session_key)
session.signing_key = signing_key
session.sender_key = sender_key
session.room_id = room_id
session.forwarding_chain = forwarding_chain or []
+ session.ratchet_safety = ratchet_safety or RatchetSafety()
+ session.received_at = received_at or datetime.utcnow()
+ if isinstance(max_age, int):
+ max_age = timedelta(milliseconds=max_age)
+ session.max_age = max_age
+ session.max_messages = max_messages
+ session.is_scheduled = is_scheduled
return session
+ def ratchet_to(self, index: int) -> "InboundGroupSession":
+ exported = self.export_session(index)
+ return self.import_session(
+ exported,
+ signing_key=self.signing_key,
+ sender_key=self.sender_key,
+ room_id=self.room_id,
+ forwarding_chain=self.forwarding_chain,
+ ratchet_safety=self.ratchet_safety,
+ received_at=self.received_at,
+ max_age=self.max_age,
+ max_messages=self.max_messages,
+ is_scheduled=self.is_scheduled,
+ )
+
class OutboundGroupSession(olm.OutboundGroupSession):
"""Outbound group session aware of the users it is shared with.
diff --git a/mautrix/crypto/signature.py b/mautrix/crypto/signature.py
new file mode 100644
index 00000000..6dc13e65
--- /dev/null
+++ b/mautrix/crypto/signature.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from typing import Any, TypedDict
+import functools
+import json
+
+import olm
+import unpaddedbase64
+
+from mautrix.types import (
+ JSON,
+ DeviceID,
+ EncryptionKeyAlgorithm,
+ KeyID,
+ Serializable,
+ Signature,
+ SigningKey,
+ UserID,
+)
+
+try:
+ from Crypto.PublicKey import ECC
+ from Crypto.Signature import eddsa
+except ImportError:
+ from Cryptodome.PublicKey import ECC
+ from Cryptodome.Signature import eddsa
+
+canonical_json = functools.partial(
+ json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True
+)
+
+
+class SignedObject(TypedDict):
+ signatures: dict[UserID, dict[str, str]]
+ unsigned: Any
+
+
+def sign_olm(data: dict[str, JSON] | Serializable, key: olm.PkSigning | olm.Account) -> Signature:
+ if isinstance(data, Serializable):
+ data = data.serialize()
+ data.pop("signatures", None)
+ data.pop("unsigned", None)
+ return Signature(key.sign(canonical_json(data)))
+
+
+def verify_signature_json(
+ data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey
+) -> bool:
+ data_copy = {**data}
+ data_copy.pop("unsigned", None)
+ signatures = data_copy.pop("signatures")
+ key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name))
+ try:
+ signature = signatures[user_id][key_id]
+ decoded_key = unpaddedbase64.decode_base64(key)
+ # pycryptodome doesn't accept raw keys, so wrap it in a DER structure
+ der_key = b"\x30\x2a\x30\x05\x06\x03\x2b\x65\x70\x03\x21\x00" + decoded_key
+ decoded_signature = unpaddedbase64.decode_base64(signature)
+ parsed_key = ECC.import_key(der_key)
+ verifier = eddsa.new(parsed_key, "rfc8032")
+ verifier.verify(canonical_json(data_copy).encode("utf-8"), decoded_signature)
+ return True
+ except (KeyError, ValueError):
+ return False
diff --git a/mautrix/crypto/signature_test.py b/mautrix/crypto/signature_test.py
new file mode 100644
index 00000000..115e4836
--- /dev/null
+++ b/mautrix/crypto/signature_test.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from mautrix.types import SigningKey, UserID
+
+from .signature import verify_signature_json
+
+
+def test_verify_signature_json() -> None:
+ assert verify_signature_json(
+ # This is actually a federation PDU rather than a device signature,
+ # but they're both 25519 curves so it doesn't make a difference.
+ {
+ "auth_events": [
+ "$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw",
+ "$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM",
+ "$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo",
+ ],
+ "content": {},
+ "depth": 3212,
+ "hashes": {"sha256": "K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},
+ "origin_server_ts": 1754242687127,
+ "prev_events": ["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],
+ "room_id": "!offtopic-2:continuwuity.org",
+ "sender": "@tulir:maunium.net",
+ "type": "m.room.message",
+ "signatures": {
+ UserID("maunium.net"): {
+ "ed25519:a_xxeS": "SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"
+ }
+ },
+ "unsigned": {"age_ts": 1754242687146},
+ },
+ UserID("maunium.net"),
+ "a_xxeS",
+ SigningKey("lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg"),
+ )
diff --git a/mautrix/crypto/ssss/__init__.py b/mautrix/crypto/ssss/__init__.py
new file mode 100644
index 00000000..9224418d
--- /dev/null
+++ b/mautrix/crypto/ssss/__init__.py
@@ -0,0 +1,8 @@
+from .key import Key, KeyMetadata, PassphraseMetadata
+from .machine import Machine
+from .types import (
+ Algorithm,
+ EncryptedAccountDataEventContent,
+ EncryptedKeyData,
+ PassphraseAlgorithm,
+)
diff --git a/mautrix/crypto/ssss/key.py b/mautrix/crypto/ssss/key.py
new file mode 100644
index 00000000..691ded71
--- /dev/null
+++ b/mautrix/crypto/ssss/key.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from typing import Optional
+import base64
+import hashlib
+import hmac
+
+from attr import dataclass
+import unpaddedbase64
+
+from mautrix.types import EventType, SerializableAttrs
+
+from .types import Algorithm, EncryptedKeyData, PassphraseAlgorithm
+from .util import (
+ calculate_hash,
+ cryptorand,
+ decode_base58_recovery_key,
+ derive_keys,
+ encode_base58_recovery_key,
+ prepare_aes,
+)
+
+try:
+ from Crypto.Cipher import AES
+ from Crypto.Util import Counter
+except ImportError:
+ from Cryptodome.Cipher import AES
+ from Cryptodome.Util import Counter
+
+
+@dataclass
+class PassphraseMetadata(SerializableAttrs):
+ algorithm: PassphraseAlgorithm
+ iterations: int
+ salt: str
+ bits: int = 256
+
+ def get_key(self, passphrase: str) -> bytes:
+ if self.algorithm != PassphraseAlgorithm.PBKDF2:
+ raise ValueError(f"Unsupported passphrase algorithm {self.algorithm}")
+ return hashlib.pbkdf2_hmac(
+ "sha512",
+ passphrase.encode("utf-8"),
+ self.salt.encode("utf-8"),
+ self.iterations,
+ self.bits // 8,
+ )
+
+
+@dataclass
+class KeyMetadata(SerializableAttrs):
+ algorithm: Algorithm
+
+ iv: str | None = None
+ mac: str | None = None
+
+ name: str | None = None
+ passphrase: Optional[PassphraseMetadata] = None
+
+ def verify_passphrase(self, key_id: str, phrase: str) -> "Key":
+ if not self.passphrase:
+ raise ValueError("Passphrase not set on this key")
+ return self.verify_raw_key(key_id, self.passphrase.get_key(phrase))
+
+ def verify_recovery_key(self, key_id: str, recovery_key: str) -> "Key":
+ decoded_key = decode_base58_recovery_key(recovery_key)
+ if not decoded_key:
+ raise ValueError("Invalid recovery key syntax")
+ return self.verify_raw_key(key_id, decoded_key)
+
+ def verify_raw_key(self, key_id: str, key: bytes) -> "Key":
+ if self.mac.rstrip("=") != calculate_hash(key, self.iv):
+ raise ValueError("Key MAC does not match")
+ return Key(id=key_id, key=key, metadata=self)
+
+
+@dataclass
+class Key:
+ id: str
+ key: bytes
+ metadata: KeyMetadata
+
+ @classmethod
+ def generate(cls, passphrase: str | None = None) -> "Key":
+ passphrase_meta = (
+ PassphraseMetadata(
+ algorithm=PassphraseAlgorithm.PBKDF2,
+ iterations=500_000,
+ salt=base64.b64encode(cryptorand.read(24)).decode("utf-8"),
+ bits=256,
+ )
+ if passphrase
+ else None
+ )
+ key = passphrase_meta.get_key(passphrase) if passphrase else cryptorand.read(32)
+ iv = unpaddedbase64.encode_base64(cryptorand.read(16))
+ metadata = KeyMetadata(
+ algorithm=Algorithm.AES_HMAC_SHA2,
+ passphrase=passphrase_meta,
+ mac=calculate_hash(key, iv),
+ iv=iv,
+ )
+ key_id = unpaddedbase64.encode_base64(cryptorand.read(24))
+ return cls(key=key, id=key_id, metadata=metadata)
+
+ @property
+ def recovery_key(self) -> str:
+ return encode_base58_recovery_key(self.key)
+
+ def encrypt(self, event_type: str | EventType, data: str | bytes) -> EncryptedKeyData:
+ if isinstance(data, str):
+ data = data.encode("utf-8")
+ data = base64.b64encode(data).rstrip(b"=")
+
+ aes_key, hmac_key = derive_keys(self.key, event_type)
+ iv = bytearray(cryptorand.read(16))
+ iv[8] &= 0x7F
+ ciphertext = prepare_aes(aes_key, iv).encrypt(data)
+ digest = hmac.digest(hmac_key, ciphertext, hashlib.sha256)
+ return EncryptedKeyData(
+ ciphertext=unpaddedbase64.encode_base64(ciphertext),
+ iv=unpaddedbase64.encode_base64(iv),
+ mac=unpaddedbase64.encode_base64(digest),
+ )
+
+ def decrypt(self, event_type: str | EventType, data: EncryptedKeyData) -> bytes:
+ aes_key, hmac_key = derive_keys(self.key, event_type)
+ ciphertext = unpaddedbase64.decode_base64(data.ciphertext)
+ mac = unpaddedbase64.decode_base64(data.mac)
+
+ expected_mac = hmac.digest(hmac_key, ciphertext, hashlib.sha256)
+ if not hmac.compare_digest(mac, expected_mac):
+ raise ValueError("Invalid MAC")
+
+ plaintext = prepare_aes(aes_key, data.iv).decrypt(ciphertext)
+ return unpaddedbase64.decode_base64(plaintext.decode("utf-8"))
diff --git a/mautrix/crypto/ssss/key_test.py b/mautrix/crypto/ssss/key_test.py
new file mode 100644
index 00000000..e60f1e3c
--- /dev/null
+++ b/mautrix/crypto/ssss/key_test.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+import pytest
+
+from ...types.event.type import EventType
+from .key import Key, KeyMetadata
+from .types import EncryptedAccountDataEventContent
+
+KEY1_CROSS_SIGNING_MASTER_KEY = """{
+ "encrypted": {
+ "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0": {
+ "iv": "BpKP9nQJTE9jrsAssoxPqQ==",
+ "ciphertext": "fNRiiiidezjerTgV+G6pUtmeF3izzj5re/mVvY0hO2kM6kYGrxLuIu2ej80=",
+ "mac": "/gWGDGMyOLmbJp+aoSLh5JxCs0AdS6nAhjzpe+9G2Q0="
+ }
+ }
+}"""
+
+KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED = bytes(
+ [
+ 0x68,
+ 0xF9,
+ 0x7F,
+ 0xD1,
+ 0x92,
+ 0x2E,
+ 0xEC,
+ 0xF6,
+ 0xB8,
+ 0x2B,
+ 0xB8,
+ 0x90,
+ 0xD2,
+ 0x4D,
+ 0x06,
+ 0x52,
+ 0x98,
+ 0x4E,
+ 0x7A,
+ 0x1D,
+ 0x70,
+ 0x3B,
+ 0x9E,
+ 0x86,
+ 0x7B,
+ 0x7E,
+ 0xBA,
+ 0xF7,
+ 0xFE,
+ 0xB9,
+ 0x5B,
+ 0x6F,
+ ]
+)
+
+KEY1_META = """{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "passphrase": {
+ "algorithm": "m.pbkdf2",
+ "iterations": 500000,
+ "salt": "y863BOoqOadgDp8S3FtHXikDJEalsQ7d"
+ },
+ "iv": "xxkTK0L4UzxgAFkQ6XPwsw",
+ "mac": "MEhooO0ZhFJNxUhvRMSxBnJfL20wkLgle3ocY0ee/eA"
+}"""
+KEY1_ID = "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0"
+KEY1_RECOVERY_KEY = "EsTE s92N EtaX s2h6 VQYF 9Kao tHYL mkyL GKMh isZb KJ4E tvoC"
+KEY1_PASSPHRASE = "correct horse battery staple"
+
+KEY2_META = """{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "O0BOvTqiIAYjC+RMcyHfWw==",
+ "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
+}"""
+KEY2_ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv"
+KEY2_RECOVERY_KEY = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe"
+
+KEY2_META_BROKEN_IV = """{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow",
+ "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
+}"""
+
+KEY2_META_BROKEN_MAC = """{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "O0BOvTqiIAYjC+RMcyHfWw==",
+ "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow"
+}"""
+
+
+def get_key_meta(meta: str) -> KeyMetadata:
+ return KeyMetadata.parse_json(meta)
+
+
+def get_key1() -> Key:
+ return get_key_meta(KEY1_META).verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY)
+
+
+def get_key2() -> Key:
+ return get_key_meta(KEY2_META).verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+
+
+def get_encrypted_master_key() -> EncryptedAccountDataEventContent:
+ return EncryptedAccountDataEventContent.parse_json(KEY1_CROSS_SIGNING_MASTER_KEY)
+
+
+def test_decrypt_success() -> None:
+ key = get_key1()
+ emk = get_encrypted_master_key()
+ assert (
+ emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) == KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED
+ )
+
+
+def test_decrypt_fail_wrong_key() -> None:
+ key = get_key2()
+ emk = get_encrypted_master_key()
+ with pytest.raises(ValueError):
+ emk.decrypt(EventType.CROSS_SIGNING_MASTER, key)
+
+
+def test_decrypt_fail_fake_key() -> None:
+ key = get_key2()
+ key.id = KEY1_ID
+ emk = get_encrypted_master_key()
+ with pytest.raises(ValueError):
+ emk.decrypt(EventType.CROSS_SIGNING_MASTER, key)
+
+
+def test_decrypt_fail_wrong_type() -> None:
+ key = get_key1()
+ emk = get_encrypted_master_key()
+ with pytest.raises(ValueError):
+ emk.decrypt(EventType.CROSS_SIGNING_SELF_SIGNING, key)
+
+
+def test_encrypt_roundtrip() -> None:
+ key = get_key1()
+ data = bytes([0xDE, 0xAD, 0xBE, 0xEF])
+ ciphertext = key.encrypt("net.maunium.data", data)
+ plaintext = key.decrypt("net.maunium.data", ciphertext)
+ assert plaintext == data
+
+
+def test_verify_recovery_key_correct() -> None:
+ meta = get_key_meta(KEY1_META)
+ key = meta.verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY)
+ assert key.recovery_key == KEY1_RECOVERY_KEY
+
+
+def test_verify_recovery_key_correct2() -> None:
+ meta = get_key_meta(KEY2_META)
+ key = meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+ assert key.recovery_key == KEY2_RECOVERY_KEY
+
+
+def test_verify_recovery_key_invalid() -> None:
+ meta = get_key_meta(KEY1_META)
+ with pytest.raises(ValueError):
+ meta.verify_recovery_key(KEY1_ID, "foo")
+
+
+def test_verify_recovery_key_incorrect() -> None:
+ meta = get_key_meta(KEY1_META)
+ with pytest.raises(ValueError):
+ meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+
+
+def test_verify_recovery_key_broken_iv() -> None:
+ meta = get_key_meta(KEY2_META_BROKEN_IV)
+ with pytest.raises(ValueError):
+ meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+
+
+def test_verify_recovery_key_broken_mac() -> None:
+ meta = get_key_meta(KEY2_META_BROKEN_MAC)
+ with pytest.raises(ValueError):
+ meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY)
+
+
+def test_verify_passphrase_correct() -> None:
+ meta = get_key_meta(KEY1_META)
+ key = meta.verify_passphrase(KEY1_ID, KEY1_PASSPHRASE)
+ assert key.recovery_key == KEY1_RECOVERY_KEY
+
+
+def test_verify_passphrase_incorrect() -> None:
+ meta = get_key_meta(KEY1_META)
+ with pytest.raises(ValueError):
+ meta.verify_passphrase(KEY1_ID, "incorrect horse battery staple")
+
+
+def test_verify_passphrase_notset() -> None:
+ meta = get_key_meta(KEY2_META)
+ with pytest.raises(ValueError):
+ meta.verify_passphrase(KEY2_ID, "hmm")
diff --git a/mautrix/crypto/ssss/machine.py b/mautrix/crypto/ssss/machine.py
new file mode 100644
index 00000000..c43e25e3
--- /dev/null
+++ b/mautrix/crypto/ssss/machine.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from mautrix import client as cli
+from mautrix.errors import MNotFound
+from mautrix.types import EventType, SecretStorageDefaultKeyEventContent
+
+from .key import Key, KeyMetadata
+from .types import EncryptedAccountDataEventContent
+
+
+class Machine:
+    """Helper for reading and writing secret storage (SSSS) account data."""
+
+    client: cli.Client
+
+    def __init__(self, client: cli.Client) -> None:
+        self.client = client
+
+    async def get_default_key_id(self) -> str | None:
+        """Get the ID of the default secret storage key, or ``None`` if not set."""
+        try:
+            data = await self.client.get_account_data(EventType.SECRET_STORAGE_DEFAULT_KEY)
+            return SecretStorageDefaultKeyEventContent.deserialize(data).key
+        except (MNotFound, ValueError):
+            # A missing event and unparseable content both mean "no default key".
+            return None
+
+    async def set_default_key_id(self, key_id: str) -> None:
+        """Mark the given key ID as the default secret storage key."""
+        await self.client.set_account_data(
+            EventType.SECRET_STORAGE_DEFAULT_KEY,
+            SecretStorageDefaultKeyEventContent(key=key_id),
+        )
+
+    async def get_key_data(self, key_id: str) -> KeyMetadata:
+        """Fetch and deserialize the metadata for the given secret storage key ID."""
+        data = await self.client.get_account_data(f"m.secret_storage.key.{key_id}")
+        return KeyMetadata.deserialize(data)
+
+    async def set_key_data(self, key_id: str, data: KeyMetadata) -> None:
+        """Store the metadata for the given secret storage key ID."""
+        await self.client.set_account_data(f"m.secret_storage.key.{key_id}", data)
+
+    async def get_default_key_data(self) -> tuple[str, KeyMetadata]:
+        """Get the default key's ID and metadata.
+
+        Raises:
+            ValueError: if no default key ID is set.
+        """
+        key_id = await self.get_default_key_id()
+        if not key_id:
+            raise ValueError("No default key ID set")
+        return key_id, await self.get_key_data(key_id)
+
+    async def get_decrypted_account_data(self, event_type: EventType | str, key: Key) -> bytes:
+        """Fetch the given account data event and decrypt its content with the given key."""
+        data = await self.client.get_account_data(event_type)
+        parsed = EncryptedAccountDataEventContent.deserialize(data)
+        return parsed.decrypt(event_type, key)
+
+    async def set_encrypted_account_data(
+        self, event_type: EventType | str, data: bytes, *keys: Key
+    ) -> None:
+        """Encrypt the given data with each of the given keys and store it as account data."""
+        encrypted_data = {}
+        for key in keys:
+            encrypted_data[key.id] = key.encrypt(event_type, data)
+        await self.client.set_account_data(
+            event_type,
+            EncryptedAccountDataEventContent(encrypted=encrypted_data),
+        )
+
+    async def generate_and_upload_key(self, passphrase: str | None = None) -> Key:
+        """Generate a new secret storage key and upload its metadata.
+
+        Args:
+            passphrase: Optional passphrase (passed through to ``Key.generate``).
+
+        Returns:
+            The generated key.
+        """
+        key = Key.generate(passphrase)
+        await self.set_key_data(key.id, key.metadata)
+        return key
diff --git a/mautrix/crypto/ssss/types.py b/mautrix/crypto/ssss/types.py
new file mode 100644
index 00000000..4a47f743
--- /dev/null
+++ b/mautrix/crypto/ssss/types.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from typing import TYPE_CHECKING
+
+from attr import dataclass
+
+from mautrix.types import EventType, SerializableAttrs, SerializableEnum
+from mautrix.types.event.account_data import account_data_event_content_map
+
+if TYPE_CHECKING:
+    from .key import Key
+
+
+class Algorithm(SerializableEnum):
+    """Secret storage encryption algorithms."""
+
+    AES_HMAC_SHA2 = "m.secret_storage.v1.aes-hmac-sha2"
+    CURVE25519_AES_SHA2 = "m.secret_storage.v1.curve25519-aes-sha2"
+
+
+class PassphraseAlgorithm(SerializableEnum):
+    """Key derivation algorithms for passphrase-based secret storage keys."""
+
+    PBKDF2 = "m.pbkdf2"
+
+
+@dataclass
+class EncryptedKeyData(SerializableAttrs):
+    # Unpadded base64 strings: the encrypted secret, the IV and the MAC.
+    ciphertext: str
+    iv: str
+    mac: str
+
+
+@dataclass
+class EncryptedAccountDataEventContent(SerializableAttrs):
+    # Encrypted copies of the secret, keyed by secret storage key ID.
+    encrypted: dict[str, EncryptedKeyData]
+
+    def decrypt(self, event_type: str | EventType, key: "Key") -> bytes:
+        """Decrypt this event's content using the given secret storage key.
+
+        Raises:
+            ValueError: if the content is not encrypted for the given key.
+        """
+        try:
+            encrypted_data = self.encrypted[key.id]
+        except KeyError as e:
+            raise ValueError("Event not encrypted for provided key") from e
+        return key.decrypt(event_type, encrypted_data)
+
+
+# Register the encrypted event types so generic account data parsing
+# deserializes their content as EncryptedAccountDataEventContent.
+for encrypted_account_data_type in (
+    EventType.CROSS_SIGNING_MASTER,
+    EventType.CROSS_SIGNING_USER_SIGNING,
+    EventType.CROSS_SIGNING_SELF_SIGNING,
+    EventType.MEGOLM_BACKUP_V1,
+):
+    account_data_event_content_map[encrypted_account_data_type] = EncryptedAccountDataEventContent
diff --git a/mautrix/crypto/ssss/util.py b/mautrix/crypto/ssss/util.py
new file mode 100644
index 00000000..b58c941a
--- /dev/null
+++ b/mautrix/crypto/ssss/util.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2025 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+import hashlib
+import hmac
+
+import base58
+import unpaddedbase64
+
+from mautrix.types import EventType
+
+try:
+    from Crypto import Random
+    from Crypto.Cipher import AES
+    from Crypto.Hash import SHA256
+    from Crypto.Protocol.KDF import HKDF
+    from Crypto.Util import Counter
+except ImportError:
+    from Cryptodome import Random
+    from Cryptodome.Cipher import AES
+    from Cryptodome.Hash import SHA256
+    from Cryptodome.Protocol.KDF import HKDF
+    from Cryptodome.Util import Counter
+
+# Shared random source (presumably used for key generation elsewhere in the package).
+cryptorand = Random.new()
+
+
+def decode_base58_recovery_key(key: str) -> bytes | None:
+    """Decode a base58-encoded recovery key into the raw 32-byte key.
+
+    Returns ``None`` if the input has the wrong length, the wrong prefix
+    (0x8B 0x01) or a parity byte that doesn't match.
+    """
+    key_bytes = base58.b58decode(key.replace(" ", ""))
+    if len(key_bytes) != 35 or key_bytes[0] != 0x8B or key_bytes[1] != 1:
+        return None
+    parity = 0
+    for byte in key_bytes[:34]:
+        parity ^= byte
+    return key_bytes[2:34] if parity == key_bytes[34] else None
+
+
+def encode_base58_recovery_key(key: bytes) -> str:
+    """Encode a raw 32-byte key as a recovery key.
+
+    The output is base58(0x8B 0x01 + key + XOR parity byte), split into
+    space-separated groups of four characters.
+    """
+    key_bytes = bytearray(35)
+    key_bytes[0] = 0x8B
+    key_bytes[1] = 1
+    key_bytes[2:34] = key
+    parity = 0
+    for byte in key_bytes:
+        parity ^= byte
+    key_bytes[34] = parity
+    encoded_key = base58.b58encode(key_bytes).decode("utf-8")
+    return " ".join(encoded_key[i : i + 4] for i in range(0, len(encoded_key), 4))
+
+
+def derive_keys(key: bytes, name: str | EventType = "") -> tuple[bytes, bytes]:
+    """Derive the (AES, HMAC) key pair from the given key with HKDF-SHA256.
+
+    The name (usually an event type) is used as the HKDF context, so the
+    derived keys are bound to the secret they protect.
+    """
+    aes_key, hmac_key = HKDF(
+        master=key,
+        key_len=32,
+        salt=b"\x00" * 32,
+        hashmod=SHA256,
+        num_keys=2,
+        context=str(name).encode("utf-8"),
+    )
+    return aes_key, hmac_key
+
+
+# NOTE(review): the original -> AES return annotation was removed; AES is the
+#     module, not the type of the cipher object returned by AES.new().
+def prepare_aes(key: bytes, iv: str | bytes):
+    """Create an AES-CTR cipher from the given key and IV.
+
+    The IV may be given as unpadded base64 or raw bytes; the full 16-byte
+    IV is used as the initial 128-bit counter value.
+    """
+    if isinstance(iv, str):
+        iv = unpaddedbase64.decode_base64(iv)
+    # initial_value = struct.unpack(">Q", iv[8:])[0]
+    # counter = Counter.new(64, prefix=iv[:8], initial_value=initial_value)
+    counter = Counter.new(128, initial_value=int.from_bytes(iv, byteorder="big"))
+    return AES.new(key=key, mode=AES.MODE_CTR, counter=counter)
+
+
+def calculate_hash(key: bytes, iv: str | bytes) -> str:
+    """Calculate the unpadded-base64 HMAC-SHA256 used to verify a key.
+
+    Encrypts 32 zero bytes with keys derived from the given key (empty
+    name) and MACs the result. CTR encryption and decryption are the same
+    operation, so calling .decrypt here is equivalent to .encrypt.
+    """
+    aes_key, hmac_key = derive_keys(key)
+    ciphertext = prepare_aes(aes_key, iv).decrypt(b"\x00" * 32)
+    digest = hmac.digest(hmac_key, ciphertext, hashlib.sha256)
+    return unpaddedbase64.encode_base64(digest)
diff --git a/mautrix/crypto/store/abstract.py b/mautrix/crypto/store/abstract.py
index 7828524d..0916a2d7 100644
--- a/mautrix/crypto/store/abstract.py
+++ b/mautrix/crypto/store/abstract.py
@@ -5,8 +5,9 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import NamedTuple
+from typing import AsyncContextManager, NamedTuple
from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
from mautrix.types import (
CrossSigner,
@@ -87,6 +88,11 @@ async def close(self) -> None:
async def flush(self) -> None:
"""Flush the store. If all the methods persist data immediately, this can be a no-op."""
+ @asynccontextmanager
+ async def transaction(self) -> None:
+ """Run a database transaction. If the store doesn't support transactions, this can be a no-op."""
+ yield
+
@abstractmethod
async def delete(self) -> None:
"""Delete the data in the store."""
@@ -197,6 +203,56 @@ async def get_group_session(
The :class:`InboundGroupSession` object, or ``None`` if not found.
"""
+ @abstractmethod
+ async def redact_group_session(
+ self, room_id: RoomID, session_id: SessionID, reason: str
+ ) -> None:
+ """
+ Remove the keys for a specific Megolm group session.
+
+ Args:
+ room_id: The room where the session is.
+ session_id: The session ID to remove.
+ reason: The reason the session is being removed.
+ """
+
+ @abstractmethod
+ async def redact_group_sessions(
+ self, room_id: RoomID | None, sender_key: IdentityKey | None, reason: str
+ ) -> list[SessionID]:
+ """
+ Remove the keys for multiple Megolm group sessions,
+ based on the room ID and/or sender device.
+
+ Args:
+ room_id: The room ID to delete keys from.
+ sender_key: The Olm identity key of the device to delete keys from.
+ reason: The reason why the keys are being deleted.
+
+ Returns:
+ The list of session IDs that were deleted.
+ """
+
+ @abstractmethod
+ async def redact_expired_group_sessions(self) -> list[SessionID]:
+ """
+ Remove all Megolm group sessions where at least twice the maximum age has passed since
+ receiving the keys.
+
+ Returns:
+ The list of session IDs that were deleted.
+ """
+
+ @abstractmethod
+ async def redact_outdated_group_sessions(self) -> list[SessionID]:
+ """
+ Remove all Megolm group sessions which lack the metadata to determine when they should
+ expire.
+
+ Returns:
+ The list of session IDs that were deleted.
+ """
+
@abstractmethod
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
"""
diff --git a/mautrix/crypto/store/asyncpg/store.py b/mautrix/crypto/store/asyncpg/store.py
index 8609e12d..bdc37ddd 100644
--- a/mautrix/crypto/store/asyncpg/store.py
+++ b/mautrix/crypto/store/asyncpg/store.py
@@ -6,12 +6,14 @@
from __future__ import annotations
from collections import defaultdict
+from contextlib import asynccontextmanager
from datetime import timedelta
from asyncpg import UniqueViolationError
from mautrix.client.state_store import SyncStore
from mautrix.client.state_store.asyncpg import PgStateStore
+from mautrix.errors import GroupSessionWithheldError
from mautrix.types import (
CrossSigner,
CrossSigningUsage,
@@ -20,6 +22,7 @@
EventID,
IdentityKey,
RoomID,
+ RoomKeyWithheldCode,
SessionID,
SigningKey,
SyncToken,
@@ -30,7 +33,7 @@
from mautrix.util.async_db import Database, Scheme
from mautrix.util.logging import TraceLogger
-from ... import InboundGroupSession, OlmAccount, OutboundGroupSession, Session
+from ... import InboundGroupSession, OlmAccount, OutboundGroupSession, RatchetSafety, Session
from ..abstract import CryptoStore, StateStore
from .upgrade import upgrade_table
@@ -77,6 +80,11 @@ def __init__(self, account_id: str, pickle_key: str, db: Database) -> None:
self._account = None
self._olm_cache = defaultdict(lambda: {})
+ @asynccontextmanager
+ async def transaction(self) -> None:
+ async with self.db.acquire() as conn, conn.transaction():
+ yield
+
async def delete(self) -> None:
tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session")
async with self.db.acquire() as conn, conn.transaction():
@@ -117,7 +125,7 @@ async def put_account(self, account: OlmAccount) -> None:
await self.db.execute(
q,
self.account_id,
- self._device_id,
+ self._device_id or "",
account.shared,
self._sync_token or "",
pickle,
@@ -234,8 +242,15 @@ async def put_group_session(
forwarding_chains = ",".join(session.forwarding_chain)
q = """
INSERT INTO crypto_megolm_inbound_session (
- session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id
- ) VALUES ($1, $2, $3, $4, $5, $6, $7)
+ session_id, sender_key, signing_key, room_id, session, forwarding_chains,
+ ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+ ON CONFLICT (session_id, account_id) DO UPDATE
+ SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key,
+ signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session,
+ forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety,
+ received_at=excluded.received_at, max_age=excluded.max_age,
+ max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
"""
try:
await self.db.execute(
@@ -246,6 +261,11 @@ async def put_group_session(
room_id,
pickle,
forwarding_chains,
+ session.ratchet_safety.json(),
+ session.received_at,
+ int(session.max_age.total_seconds() * 1000) if session.max_age else None,
+ session.max_messages,
+ session.is_scheduled,
self.account_id,
)
except (IntegrityError, UniqueViolationError):
@@ -255,13 +275,17 @@ async def get_group_session(
self, room_id: RoomID, session_id: SessionID
) -> InboundGroupSession | None:
q = """
- SELECT sender_key, signing_key, session, forwarding_chains
+ SELECT
+ sender_key, signing_key, session, forwarding_chains, withheld_code,
+ ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND session_id=$2 AND account_id=$3
"""
row = await self.db.fetchrow(q, room_id, session_id, self.account_id)
if row is None:
return None
+ if row["withheld_code"] is not None:
+ raise GroupSessionWithheldError(session_id, row["withheld_code"])
forwarding_chain = row["forwarding_chains"].split(",") if row["forwarding_chains"] else []
return InboundGroupSession.from_pickle(
row["session"],
@@ -270,21 +294,106 @@ async def get_group_session(
sender_key=row["sender_key"],
room_id=room_id,
forwarding_chain=forwarding_chain,
+ ratchet_safety=RatchetSafety.parse_json(row["ratchet_safety"] or "{}"),
+ received_at=row["received_at"],
+ max_age=timedelta(milliseconds=row["max_age"]) if row["max_age"] else None,
+ max_messages=row["max_messages"],
+ is_scheduled=row["is_scheduled"],
)
+ async def redact_group_session(
+ self, room_id: RoomID, session_id: SessionID, reason: str
+ ) -> None:
+ q = """
+ UPDATE crypto_megolm_inbound_session
+ SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
+ WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
+ """
+ await self.db.execute(
+ q,
+ RoomKeyWithheldCode.BEEPER_REDACTED.value,
+ f"Session redacted: {reason}",
+ session_id,
+ self.account_id,
+ )
+
+    async def redact_group_sessions(
+        self, room_id: RoomID | None, sender_key: IdentityKey | None, reason: str
+    ) -> list[SessionID]:
+        if not room_id and not sender_key:
+            raise ValueError("Either room_id or sender_key must be provided")
+        q = """
+        UPDATE crypto_megolm_inbound_session
+        SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
+        WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
+        AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
+        RETURNING session_id
+        """
+        rows = await self.db.fetch(
+            q,
+            RoomKeyWithheldCode.BEEPER_REDACTED.value,
+            f"Session redacted: {reason}",
+            room_id or "",
+            sender_key or "",
+            self.account_id,
+        )
+        return [row["session_id"] for row in rows]
+
+    async def redact_expired_group_sessions(self) -> list[SessionID]:
+        if self.db.scheme == Scheme.SQLITE:
+            q = """
+            UPDATE crypto_megolm_inbound_session
+            SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
+            WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
+            AND received_at IS NOT NULL AND max_age IS NOT NULL
+            AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch('now')
+            RETURNING session_id
+            """
+        elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
+            q = """
+            UPDATE crypto_megolm_inbound_session
+            SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
+            WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
+            AND received_at IS NOT NULL AND max_age IS NOT NULL
+            AND received_at + 2 * (max_age * interval '1 millisecond') < now()
+            RETURNING session_id
+            """
+        else:
+            raise RuntimeError(f"Unsupported dialect {self.db.scheme}")
+        rows = await self.db.fetch(
+            q,
+            RoomKeyWithheldCode.BEEPER_REDACTED.value,
+            "Session redacted: expired",
+            self.account_id,
+        )
+        return [row["session_id"] for row in rows]
+
+    async def redact_outdated_group_sessions(self) -> list[SessionID]:
+        q = """
+        UPDATE crypto_megolm_inbound_session
+        SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
+        WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
+        RETURNING session_id
+        """
+        rows = await self.db.fetch(
+            q,
+            RoomKeyWithheldCode.BEEPER_REDACTED.value,
+            "Session redacted: outdated",
+            self.account_id,
+        )
+        return [row["session_id"] for row in rows]
+
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
q = """
SELECT COUNT(session) FROM crypto_megolm_inbound_session
- WHERE room_id=$1 AND session_id=$2 AND account_id=$3
+ WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL
"""
count = await self.db.fetchval(q, room_id, session_id, self.account_id)
return count > 0
async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
pickle = session.pickle(self.pickle_key)
- max_age = session.max_age
- if self.db.scheme == Scheme.SQLITE:
- max_age = max_age.total_seconds()
+ max_age = int(session.max_age.total_seconds() * 1000)
q = """
INSERT INTO crypto_megolm_outbound_session (
room_id, session_id, session, shared, max_messages, message_count,
@@ -334,9 +443,6 @@ async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSess
row = await self.db.fetchrow(q, room_id, self.account_id)
if row is None:
return None
- max_age = row["max_age"]
- if self.db.scheme == Scheme.SQLITE:
- max_age = timedelta(seconds=max_age)
return OutboundGroupSession.from_pickle(
row["session"],
passphrase=self.pickle_key,
@@ -344,7 +450,7 @@ async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSess
shared=row["shared"],
max_messages=row["max_messages"],
message_count=row["message_count"],
- max_age=max_age,
+ max_age=timedelta(milliseconds=row["max_age"]),
use_time=row["last_used"],
creation_time=row["created_at"],
)
diff --git a/mautrix/crypto/store/asyncpg/upgrade.py b/mautrix/crypto/store/asyncpg/upgrade.py
index f7f50a74..8d413858 100644
--- a/mautrix/crypto/store/asyncpg/upgrade.py
+++ b/mautrix/crypto/store/asyncpg/upgrade.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 Tulir Asokan
+# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -16,34 +16,34 @@
)
-@upgrade_table.register(description="Latest revision", upgrades_to=5)
-async def upgrade_blank_to_v4(conn: Connection) -> None:
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_account (
- account_id TEXT PRIMARY KEY,
- device_id TEXT,
- shared BOOLEAN NOT NULL,
- sync_token TEXT NOT NULL,
- account bytea NOT NULL
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_message_index (
+@upgrade_table.register(description="Latest revision", upgrades_to=10)
+async def upgrade_blank_to_latest(conn: Connection) -> None:
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_account (
+ account_id TEXT PRIMARY KEY,
+ device_id TEXT NOT NULL,
+ shared BOOLEAN NOT NULL,
+ sync_token TEXT NOT NULL,
+ account bytea NOT NULL
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_message_index (
sender_key CHAR(43),
session_id CHAR(43),
"index" INTEGER,
event_id TEXT NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (sender_key, session_id, "index")
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_tracked_user (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_tracked_user (
user_id TEXT PRIMARY KEY
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_device (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_device (
user_id TEXT,
device_id TEXT,
identity_key CHAR(43) NOT NULL,
@@ -52,10 +52,10 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
deleted BOOLEAN NOT NULL,
name TEXT NOT NULL,
PRIMARY KEY (user_id, device_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_olm_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_olm_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
@@ -64,22 +64,29 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
last_decrypted timestamp NOT NULL,
last_encrypted timestamp NOT NULL,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
- account_id TEXT,
- session_id CHAR(43),
- sender_key CHAR(43) NOT NULL,
- signing_key CHAR(43) NOT NULL,
- room_id TEXT NOT NULL,
- session bytea NOT NULL,
- forwarding_chains TEXT NOT NULL,
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
+ account_id TEXT,
+ session_id CHAR(43),
+ sender_key CHAR(43) NOT NULL,
+ signing_key CHAR(43),
+ room_id TEXT NOT NULL,
+ session bytea,
+ forwarding_chains TEXT,
+ withheld_code TEXT,
+ withheld_reason TEXT,
+ ratchet_safety jsonb,
+ received_at timestamp,
+ max_age BIGINT,
+ max_messages INTEGER,
+ is_scheduled BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
account_id TEXT,
room_id TEXT,
session_id CHAR(43) NOT NULL UNIQUE,
@@ -87,31 +94,33 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
- max_age INTERVAL NOT NULL,
+ max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, room_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_cross_signing_keys (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
- key CHAR(43),
- first_seen_key CHAR(43),
+ key CHAR(43) NOT NULL,
+
+ first_seen_key CHAR(43) NOT NULL,
+
PRIMARY KEY (user_id, usage)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_cross_signing_signatures (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
- signature TEXT,
+ signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
- )"""
- )
+ )
+ """)
@upgrade_table.register(description="Add account_id primary key column")
@@ -121,17 +130,17 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
await conn.execute("DROP TABLE crypto_olm_session")
await conn.execute("DROP TABLE crypto_megolm_inbound_session")
await conn.execute("DROP TABLE crypto_megolm_outbound_session")
- await conn.execute(
- """CREATE TABLE crypto_account (
+ await conn.execute("""
+ CREATE TABLE crypto_account (
account_id VARCHAR(255) PRIMARY KEY,
device_id VARCHAR(255) NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_olm_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_olm_session (
account_id VARCHAR(255),
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
@@ -139,10 +148,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_megolm_inbound_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_megolm_inbound_session (
account_id VARCHAR(255),
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
@@ -151,10 +160,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
session bytea NOT NULL,
forwarding_chains TEXT NOT NULL,
PRIMARY KEY (account_id, session_id)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_megolm_outbound_session (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_megolm_outbound_session (
account_id VARCHAR(255),
room_id VARCHAR(255),
session_id CHAR(43) NOT NULL UNIQUE,
@@ -162,12 +171,12 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
- max_age INTERVAL NOT NULL,
+ max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL,
PRIMARY KEY (account_id, room_id)
- )"""
- )
+ )
+ """)
else:
async def add_account_id_column(table: str, pkey_columns: list[str]) -> None:
@@ -224,25 +233,25 @@ async def upgrade_v4(conn: Connection, scheme: Scheme) -> None:
@upgrade_table.register(description="Add cross-signing key and signature caches")
async def upgrade_v5(conn: Connection) -> None:
- await conn.execute(
- """CREATE TABLE crypto_cross_signing_keys (
+ await conn.execute("""
+ CREATE TABLE crypto_cross_signing_keys (
user_id TEXT,
usage TEXT,
key CHAR(43),
first_seen_key CHAR(43),
PRIMARY KEY (user_id, usage)
- )"""
- )
- await conn.execute(
- """CREATE TABLE crypto_cross_signing_signatures (
+ )
+ """)
+ await conn.execute("""
+ CREATE TABLE crypto_cross_signing_signatures (
signed_user_id TEXT,
signed_key TEXT,
signer_user_id TEXT,
signer_key TEXT,
signature TEXT,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
- )"""
- )
+ )
+ """)
@upgrade_table.register(description="Update trust state values")
@@ -250,3 +259,177 @@ async def upgrade_v6(conn: Connection) -> None:
await conn.execute("UPDATE crypto_device SET trust=300 WHERE trust=1") # verified
await conn.execute("UPDATE crypto_device SET trust=-100 WHERE trust=2") # blacklisted
await conn.execute("UPDATE crypto_device SET trust=0 WHERE trust=3") # ignored -> unset
+
+
+@upgrade_table.register(
+ description="Synchronize schema with mautrix-go", upgrades_to=9, transaction=False
+)
+async def upgrade_v9(conn: Connection, scheme: Scheme) -> None:
+ if scheme == Scheme.POSTGRES:
+ async with conn.transaction():
+ await upgrade_v9_postgres(conn)
+ else:
+ await upgrade_v9_sqlite(conn)
+
+
+# These two are never used because the previous one jumps from 6 to 9.
+@upgrade_table.register
+async def upgrade_noop_7_to_8(_: Connection) -> None:
+ pass
+
+
+@upgrade_table.register
+async def upgrade_noop_8_to_9(_: Connection) -> None:
+ pass
+
+
+async def upgrade_v9_postgres(conn: Connection) -> None:
+ await conn.execute("UPDATE crypto_account SET device_id='' WHERE device_id IS NULL")
+ await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id SET NOT NULL")
+
+ await conn.execute(
+ "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN signing_key DROP NOT NULL"
+ )
+ await conn.execute(
+ "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN session DROP NOT NULL"
+ )
+ await conn.execute(
+ "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN forwarding_chains DROP NOT NULL"
+ )
+ await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_code TEXT")
+ await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_reason TEXT")
+
+ await conn.execute("DELETE FROM crypto_cross_signing_keys WHERE key IS NULL")
+ await conn.execute(
+ "UPDATE crypto_cross_signing_keys SET first_seen_key=key WHERE first_seen_key IS NULL"
+ )
+ await conn.execute("ALTER TABLE crypto_cross_signing_keys ALTER COLUMN key SET NOT NULL")
+ await conn.execute(
+ "ALTER TABLE crypto_cross_signing_keys ALTER COLUMN first_seen_key SET NOT NULL"
+ )
+
+ await conn.execute("DELETE FROM crypto_cross_signing_signatures WHERE signature IS NULL")
+ await conn.execute(
+ "ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature SET NOT NULL"
+ )
+
+ await conn.execute(
+ "ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN max_age TYPE BIGINT "
+ "USING (EXTRACT(EPOCH from max_age)*1000)::bigint"
+ )
+
+
+async def upgrade_v9_sqlite(conn: Connection) -> None:
+ await conn.execute("PRAGMA foreign_keys = OFF")
+ async with conn.transaction():
+ await conn.execute("""
+ CREATE TABLE new_crypto_account (
+ account_id TEXT PRIMARY KEY,
+ device_id TEXT NOT NULL,
+ shared BOOLEAN NOT NULL,
+ sync_token TEXT NOT NULL,
+ account bytea NOT NULL
+ )
+ """)
+ await conn.execute("""
+ INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account)
+ SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account
+ FROM crypto_account
+ """)
+ await conn.execute("DROP TABLE crypto_account")
+ await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account")
+
+ await conn.execute("""
+ CREATE TABLE new_crypto_megolm_inbound_session (
+ account_id TEXT,
+ session_id CHAR(43),
+ sender_key CHAR(43) NOT NULL,
+ signing_key CHAR(43),
+ room_id TEXT NOT NULL,
+ session bytea,
+ forwarding_chains TEXT,
+ withheld_code TEXT,
+ withheld_reason TEXT,
+ PRIMARY KEY (account_id, session_id)
+ )
+ """)
+ await conn.execute("""
+ INSERT INTO new_crypto_megolm_inbound_session (
+ account_id, session_id, sender_key, signing_key, room_id, session,
+ forwarding_chains
+ )
+ SELECT account_id, session_id, sender_key, signing_key, room_id, session,
+ forwarding_chains
+ FROM crypto_megolm_inbound_session
+ """)
+ await conn.execute("DROP TABLE crypto_megolm_inbound_session")
+ await conn.execute(
+ "ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session"
+ )
+
+ await conn.execute("UPDATE crypto_megolm_outbound_session SET max_age=max_age*1000")
+
+ await conn.execute("""
+ CREATE TABLE new_crypto_cross_signing_keys (
+ user_id TEXT,
+ usage TEXT,
+ key CHAR(43) NOT NULL,
+
+ first_seen_key CHAR(43) NOT NULL,
+
+ PRIMARY KEY (user_id, usage)
+ )
+ """)
+ await conn.execute("""
+ INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key)
+ SELECT user_id, usage, key, COALESCE(first_seen_key, key)
+ FROM crypto_cross_signing_keys
+ WHERE key IS NOT NULL
+ """)
+ await conn.execute("DROP TABLE crypto_cross_signing_keys")
+ await conn.execute(
+ "ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys"
+ )
+
+ await conn.execute("""
+ CREATE TABLE new_crypto_cross_signing_signatures (
+ signed_user_id TEXT,
+ signed_key TEXT,
+ signer_user_id TEXT,
+ signer_key TEXT,
+ signature CHAR(88) NOT NULL,
+ PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
+ )
+ """)
+ await conn.execute("""
+ INSERT INTO new_crypto_cross_signing_signatures (
+ signed_user_id, signed_key, signer_user_id, signer_key, signature
+ )
+ SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature
+ FROM crypto_cross_signing_signatures
+ WHERE signature IS NOT NULL
+ """)
+ await conn.execute("DROP TABLE crypto_cross_signing_signatures")
+ await conn.execute(
+ "ALTER TABLE new_crypto_cross_signing_signatures "
+ "RENAME TO crypto_cross_signing_signatures"
+ )
+
+ await conn.execute("PRAGMA foreign_key_check")
+ await conn.execute("PRAGMA foreign_keys = ON")
+
+
+@upgrade_table.register(
+ description="Add metadata for detecting when megolm sessions are safe to delete"
+)
+async def upgrade_v10(conn: Connection) -> None:
+ await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb")
+ await conn.execute(
+ "ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp"
+ )
+ await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT")
+ await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER")
+ await conn.execute(
+ "ALTER TABLE crypto_megolm_inbound_session "
+ "ADD COLUMN is_scheduled BOOLEAN NOT NULL DEFAULT false"
+ )
diff --git a/mautrix/crypto/store/memory.py b/mautrix/crypto/store/memory.py
index 35dc26b5..c26f86bc 100644
--- a/mautrix/crypto/store/memory.py
+++ b/mautrix/crypto/store/memory.py
@@ -110,6 +110,33 @@ async def get_group_session(
) -> InboundGroupSession:
return self._inbound_sessions.get((room_id, session_id))
+ async def redact_group_session(
+ self, room_id: RoomID, session_id: SessionID, reason: str
+ ) -> None:
+ self._inbound_sessions.pop((room_id, session_id), None)
+
+ async def redact_group_sessions(
+ self, room_id: RoomID, sender_key: IdentityKey, reason: str
+ ) -> list[SessionID]:
+ if not room_id and not sender_key:
+ raise ValueError("Either room_id or sender_key must be provided")
+ deleted = []
+ keys = list(self._inbound_sessions.keys())
+ for key in keys:
+ item = self._inbound_sessions[key]
+ if (not room_id or item.room_id == room_id) and (
+ not sender_key or item.sender_key == sender_key
+ ):
+ deleted.append(SessionID(item.id))
+ del self._inbound_sessions[key]
+ return deleted
+
+ async def redact_expired_group_sessions(self) -> list[SessionID]:
+ raise NotImplementedError()
+
+ async def redact_outdated_group_sessions(self) -> list[SessionID]:
+ raise NotImplementedError()
+
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
return (room_id, session_id) in self._inbound_sessions
diff --git a/mautrix/crypto/store/tests/store_test.py b/mautrix/crypto/store/tests/store_test.py
index 8d3fc851..949949b8 100644
--- a/mautrix/crypto/store/tests/store_test.py
+++ b/mautrix/crypto/store/tests/store_test.py
@@ -50,7 +50,7 @@ async def async_postgres_store() -> AsyncIterator[PgCryptoStore]:
@asynccontextmanager
async def async_sqlite_store() -> AsyncIterator[PgCryptoStore]:
db = Database.create(
- "sqlite:///:memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1}
+ "sqlite::memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1}
)
store = PgCryptoStore("", "test", db)
await db.start()
diff --git a/mautrix/errors/__init__.py b/mautrix/errors/__init__.py
index fdec6e3a..afe68dc9 100644
--- a/mautrix/errors/__init__.py
+++ b/mautrix/errors/__init__.py
@@ -6,6 +6,7 @@
DeviceValidationError,
DuplicateMessageIndex,
EncryptionError,
+ GroupSessionWithheldError,
MatchingSessionDecryptionError,
MismatchingRoomError,
SessionNotFound,
@@ -72,6 +73,7 @@
"DeviceValidationError",
"DuplicateMessageIndex",
"EncryptionError",
+ "GroupSessionWithheldError",
"MatchingSessionDecryptionError",
"MismatchingRoomError",
"SessionNotFound",
diff --git a/mautrix/errors/crypto.py b/mautrix/errors/crypto.py
index 97592b05..4a65048c 100644
--- a/mautrix/errors/crypto.py
+++ b/mautrix/errors/crypto.py
@@ -36,6 +36,12 @@ class MatchingSessionDecryptionError(DecryptionError):
pass
+class GroupSessionWithheldError(DecryptionError):
+ def __init__(self, session_id: SessionID, withheld_code: str) -> None:
+ super().__init__(f"Session ID {session_id} was withheld ({withheld_code})")
+ self.withheld_code = withheld_code
+
+
class SessionNotFound(DecryptionError):
def __init__(self, session_id: SessionID, sender_key: IdentityKey | None = None) -> None:
super().__init__(
diff --git a/mautrix/types/__init__.py b/mautrix/types/__init__.py
index 6d2ffa7c..ceceeaca 100644
--- a/mautrix/types/__init__.py
+++ b/mautrix/types/__init__.py
@@ -57,6 +57,7 @@
CallRejectEventContent,
CallSelectAnswerEventContent,
CanonicalAliasStateEventContent,
+ DirectAccountDataEventContent,
EncryptedEvent,
EncryptedEventContent,
EncryptedFile,
@@ -122,6 +123,7 @@
RoomTombstoneStateEventContent,
RoomTopicStateEventContent,
RoomType,
+ SecretStorageDefaultKeyEventContent,
SingleReceiptEventContent,
SpaceChildStateEventContent,
SpaceParentStateEventContent,
@@ -149,6 +151,7 @@
)
from .misc import (
BatchSendResponse,
+ BeeperBatchSendResponse,
DeviceLists,
DeviceOTKCount,
DirectoryPaginationToken,
@@ -258,6 +261,7 @@
"CallRejectEventContent",
"CallSelectAnswerEventContent",
"CanonicalAliasStateEventContent",
+ "DirectAccountDataEventContent",
"EncryptedEvent",
"EncryptedEventContent",
"EncryptedFile",
@@ -323,6 +327,7 @@
"RoomTombstoneStateEventContent",
"RoomTopicStateEventContent",
"RoomType",
+ "SecretStorageDefaultKeyEventContent",
"SingleReceiptEventContent",
"SpaceChildStateEventContent",
"SpaceParentStateEventContent",
@@ -353,6 +358,7 @@
"OpenGraphImage",
"OpenGraphVideo",
"BatchSendResponse",
+ "BeeperBatchSendResponse",
"DeviceLists",
"DeviceOTKCount",
"DirectoryPaginationToken",
diff --git a/mautrix/types/auth.py b/mautrix/types/auth.py
index 198d2f53..ad582118 100644
--- a/mautrix/types/auth.py
+++ b/mautrix/types/auth.py
@@ -26,6 +26,8 @@ class LoginType(ExtensibleEnum):
UNSTABLE_JWT: "LoginType" = "org.matrix.login.jwt"
+ DEVTURE_SHARED_SECRET: "LoginType" = "com.devture.shared_secret_auth"
+
@dataclass
class LoginFlow(SerializableAttrs):
diff --git a/mautrix/types/crypto.py b/mautrix/types/crypto.py
index 821599de..fe4ab742 100644
--- a/mautrix/types/crypto.py
+++ b/mautrix/types/crypto.py
@@ -47,9 +47,9 @@ def curve25519(self) -> Optional[IdentityKey]:
class CrossSigningUsage(ExtensibleEnum):
- MASTER = "master"
- SELF = "self_signing"
- USER = "user_signing"
+ MASTER: "CrossSigningUsage" = "master"
+ SELF: "CrossSigningUsage" = "self_signing"
+ USER: "CrossSigningUsage" = "user_signing"
@dataclass
@@ -71,6 +71,8 @@ def first_ed25519_key(self) -> Optional[SigningKey]:
return self.first_key_with_algorithm(EncryptionKeyAlgorithm.ED25519)
def first_key_with_algorithm(self, alg: EncryptionKeyAlgorithm) -> Optional[SigningKey]:
+ if not self.keys:
+ return None
try:
return next(key for key_id, key in self.keys.items() if key_id.algorithm == alg)
except StopIteration:
diff --git a/mautrix/types/event/__init__.py b/mautrix/types/event/__init__.py
index b391e912..db0658db 100644
--- a/mautrix/types/event/__init__.py
+++ b/mautrix/types/event/__init__.py
@@ -6,8 +6,10 @@
from .account_data import (
AccountDataEvent,
AccountDataEventContent,
+ DirectAccountDataEventContent,
RoomTagAccountDataEventContent,
RoomTagInfo,
+ SecretStorageDefaultKeyEventContent,
)
from .base import BaseEvent, BaseRoomEvent, BaseUnsigned, GenericEvent
from .batch import BatchSendEvent, BatchSendStateEvent
diff --git a/mautrix/types/event/account_data.py b/mautrix/types/event/account_data.py
index 3c144a36..fdfa0a30 100644
--- a/mautrix/types/event/account_data.py
+++ b/mautrix/types/event/account_data.py
@@ -3,7 +3,7 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Dict, List, Union
+from typing import TYPE_CHECKING, Dict, List, Union
from attr import dataclass
import attr
@@ -12,6 +12,9 @@
from ..util import Obj, SerializableAttrs, deserializer
from .base import BaseEvent, EventType
+if TYPE_CHECKING:
+ from mautrix.crypto.ssss import EncryptedAccountDataEventContent, KeyMetadata
+
@dataclass
class RoomTagInfo(SerializableAttrs):
@@ -23,11 +26,24 @@ class RoomTagAccountDataEventContent(SerializableAttrs):
tags: Dict[str, RoomTagInfo] = attr.ib(default=None, metadata={"json": "tags"})
+@dataclass
+class SecretStorageDefaultKeyEventContent(SerializableAttrs):
+ key: str
+
+
DirectAccountDataEventContent = Dict[UserID, List[RoomID]]
-AccountDataEventContent = Union[RoomTagAccountDataEventContent, DirectAccountDataEventContent, Obj]
+AccountDataEventContent = Union[
+ RoomTagAccountDataEventContent,
+ DirectAccountDataEventContent,
+ SecretStorageDefaultKeyEventContent,
+ "EncryptedAccountDataEventContent",
+ "KeyMetadata",
+ Obj,
+]
account_data_event_content_map = {
EventType.TAG: RoomTagAccountDataEventContent,
+ EventType.SECRET_STORAGE_DEFAULT_KEY: SecretStorageDefaultKeyEventContent,
# m.direct doesn't really need deserializing
# EventType.DIRECT: DirectAccountDataEventContent,
}
diff --git a/mautrix/types/event/beeper.py b/mautrix/types/event/beeper.py
index dc588800..0ec16479 100644
--- a/mautrix/types/event/beeper.py
+++ b/mautrix/types/event/beeper.py
@@ -7,7 +7,7 @@
from attr import dataclass
-from ..primitive import EventID
+from ..primitive import EventID, RoomID, SessionID
from ..util import SerializableAttrs, SerializableEnum, field
from .base import BaseRoomEvent
from .message import RelatesTo
@@ -49,20 +49,16 @@ class BeeperMessageStatusEventContent(SerializableAttrs):
error: Optional[str] = None
message: Optional[str] = None
- success: Optional[bool] = None
- still_working: Optional[bool] = None
- can_retry: Optional[bool] = None
- is_certain: Optional[bool] = None
-
last_retry: Optional[EventID] = None
- def fill_legacy_booleans(self) -> None:
- self.success = self.status == MessageStatus.SUCCESS
- if not self.success:
- self.still_working = self.status == MessageStatus.PENDING
- self.can_retry = self.status in (MessageStatus.PENDING, MessageStatus.RETRIABLE)
-
@dataclass
class BeeperMessageStatusEvent(BaseRoomEvent, SerializableAttrs):
content: BeeperMessageStatusEventContent
+
+
+@dataclass
+class BeeperRoomKeyAckEventContent(SerializableAttrs):
+ room_id: RoomID
+ session_id: SessionID
+ first_message_index: int
diff --git a/mautrix/types/event/encrypted.py b/mautrix/types/event/encrypted.py
index cd08fc94..735e481a 100644
--- a/mautrix/types/event/encrypted.py
+++ b/mautrix/types/event/encrypted.py
@@ -9,7 +9,7 @@
from attr import dataclass
-from ..primitive import JSON, DeviceID, IdentityKey, SessionID
+from ..primitive import JSON, DeviceID, IdentityKey, SessionID, SigningKey
from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field
from .base import BaseRoomEvent, BaseUnsigned
from .message import RelatesTo
@@ -43,6 +43,18 @@ def deserialize(cls, raw: JSON) -> "KeyID":
def __str__(self) -> str:
return f"{self.algorithm.value}:{self.key_id}"
+ @classmethod
+ def ed25519(cls, key_id: SigningKey | DeviceID) -> "KeyID":
+ return cls(EncryptionKeyAlgorithm.ED25519, key_id)
+
+ @classmethod
+ def curve25519(cls, key_id: IdentityKey) -> "KeyID":
+ return cls(EncryptionKeyAlgorithm.CURVE25519, key_id)
+
+ @classmethod
+ def signed_curve25519(cls, key_id: IdentityKey) -> "KeyID":
+ return cls(EncryptionKeyAlgorithm.SIGNED_CURVE25519, key_id)
+
class OlmMsgType(Serializable, IntEnum):
PREKEY = 0
diff --git a/mautrix/types/event/message.py b/mautrix/types/event/message.py
index eab4ba95..32033581 100644
--- a/mautrix/types/event/message.py
+++ b/mautrix/types/event/message.py
@@ -112,6 +112,12 @@ def set_thread_parent(
self.relates_to.event_id = (
thread_parent if isinstance(thread_parent, str) else thread_parent.event_id
)
+ if isinstance(thread_parent, MessageEvent) and isinstance(
+ thread_parent.content, BaseMessageEventContentFuncs
+ ):
+ self.relates_to.event_id = (
+ thread_parent.content.get_thread_parent() or self.relates_to.event_id
+ )
if not disable_reply_fallback:
self.set_reply(last_event_in_thread or thread_parent, **kwargs)
self.relates_to.is_falling_back = True
@@ -265,33 +271,6 @@ class LocationInfo(SerializableAttrs):
# region Event content
-@dataclass
-class MediaMessageEventContent(BaseMessageEventContent, SerializableAttrs):
- """The content of a media message event (m.image, m.audio, m.video, m.file)"""
-
- url: Optional[ContentURI] = None
- info: Optional[MediaInfo] = None
- file: Optional[EncryptedFile] = None
-
- @staticmethod
- @deserializer(MediaInfo)
- @deserializer(Optional[MediaInfo])
- def deserialize_info(data: JSON) -> MediaInfo:
- if not isinstance(data, dict):
- return Obj()
- msgtype = data.pop("__mautrix_msgtype", None)
- if msgtype == "m.image" or msgtype == "m.sticker":
- return ImageInfo.deserialize(data)
- elif msgtype == "m.video":
- return VideoInfo.deserialize(data)
- elif msgtype == "m.audio":
- return AudioInfo.deserialize(data)
- elif msgtype == "m.file":
- return FileInfo.deserialize(data)
- else:
- return Obj(**data)
-
-
@dataclass
class LocationMessageEventContent(BaseMessageEventContent, SerializableAttrs):
geo_uri: str = None
@@ -308,21 +287,6 @@ class TextMessageEventContent(BaseMessageEventContent, SerializableAttrs):
format: Format = None
formatted_body: str = None
- def set_reply(
- self, reply_to: Union["MessageEvent", EventID], *, displayname: Optional[str] = None
- ) -> None:
- super().set_reply(reply_to)
- if isinstance(reply_to, str):
- return
- if isinstance(reply_to, MessageEvent):
- self.ensure_has_html()
- if isinstance(reply_to.content, TextMessageEventContent):
- reply_to.content.trim_reply_fallback()
- self.formatted_body = (
- reply_to.make_reply_fallback_html(displayname) + self.formatted_body
- )
- self.body = reply_to.make_reply_fallback_text(displayname) + self.body
-
def ensure_has_html(self) -> None:
if not self.formatted_body or self.format != Format.HTML:
self.format = Format.HTML
@@ -340,20 +304,48 @@ def trim_reply_fallback(self) -> None:
setattr(self, "__reply_fallback_trimmed", True)
def _trim_reply_fallback_text(self) -> None:
- if not self.body.startswith("> ") or "\n" not in self.body:
+ if (
+ not self.body.startswith("> <") and not self.body.startswith("> * <")
+ ) or "\n" not in self.body:
return
lines = self.body.split("\n")
while len(lines) > 0 and lines[0].startswith("> "):
lines.pop(0)
- # Pop extra newline at end of fallback
- lines.pop(0)
- self.body = "\n".join(lines)
+ self.body = "\n".join(lines).strip()
def _trim_reply_fallback_html(self) -> None:
if self.formatted_body and self.format == Format.HTML:
self.formatted_body = html_reply_fallback_regex.sub("", self.formatted_body)
+@dataclass
+class MediaMessageEventContent(TextMessageEventContent, SerializableAttrs):
+ """The content of a media message event (m.image, m.audio, m.video, m.file)"""
+
+ url: Optional[ContentURI] = None
+ info: Optional[MediaInfo] = None
+ file: Optional[EncryptedFile] = None
+ filename: Optional[str] = None
+
+ @staticmethod
+ @deserializer(MediaInfo)
+ @deserializer(Optional[MediaInfo])
+ def deserialize_info(data: JSON) -> MediaInfo:
+ if not isinstance(data, dict):
+ return Obj()
+ msgtype = data.pop("__mautrix_msgtype", None)
+ if msgtype == "m.image" or msgtype == "m.sticker":
+ return ImageInfo.deserialize(data)
+ elif msgtype == "m.video":
+ return VideoInfo.deserialize(data)
+ elif msgtype == "m.audio":
+ return AudioInfo.deserialize(data)
+ elif msgtype == "m.file":
+ return FileInfo.deserialize(data)
+ else:
+ return Obj(**data)
+
+
MessageEventContent = Union[
TextMessageEventContent, MediaMessageEventContent, LocationMessageEventContent, Obj
]
@@ -413,36 +405,3 @@ def deserialize_content(data: JSON) -> MessageEventContent:
return LocationMessageEventContent.deserialize(data)
else:
return Obj(**data)
-
- def make_reply_fallback_html(self, displayname: Optional[str] = None) -> str:
- """Generate the HTML fallback for messages replying to this event."""
- if self.content.msgtype.is_text:
- body = self.content.formatted_body or escape(self.content.body).replace("\n", "<br/>")
- else:
- sent_type = media_reply_fallback_body_map[self.content.msgtype] or "a message"
- body = f"sent {sent_type}"
- displayname = escape(displayname) if displayname else self.sender
- return html_reply_fallback_format.format(
- room_id=self.room_id,
- event_id=self.event_id,
- sender=self.sender,
- displayname=displayname,
- content=body,
- )
-
- def make_reply_fallback_text(self, displayname: Optional[str] = None) -> str:
- """Generate the plaintext fallback for messages replying to this event."""
- if self.content.msgtype.is_text:
- body = self.content.body
- else:
- try:
- body = media_reply_fallback_body_map[self.content.msgtype]
- except KeyError:
- body = "an unknown message type"
- lines = body.strip().split("\n")
- first_line, lines = lines[0], lines[1:]
- fallback_text = f"> <{displayname or self.sender}> {first_line}"
- for line in lines:
- fallback_text += f"\n> {line}"
- fallback_text += "\n\n"
- return fallback_text
diff --git a/mautrix/types/event/state.py b/mautrix/types/event/state.py
index c6b351c6..5ffc855f 100644
--- a/mautrix/types/event/state.py
+++ b/mautrix/types/event/state.py
@@ -9,7 +9,7 @@
import attr
from ..primitive import JSON, ContentURI, EventID, RoomAlias, RoomID, UserID
-from ..util import Obj, SerializableAttrs, SerializableEnum, deserializer, field
+from ..util import ExtensibleEnum, Obj, SerializableAttrs, SerializableEnum, deserializer, field
from .base import BaseRoomEvent, BaseUnsigned
from .encrypted import EncryptionAlgorithm
from .type import EventType, RoomType
@@ -41,7 +41,19 @@ class PowerLevelStateEventContent(SerializableAttrs):
ban: int = 50
redact: int = 50
- def get_user_level(self, user_id: UserID) -> int:
+ def get_user_level(
+ self,
+ user_id: UserID,
+ create: Optional["StateEvent"] = None,
+ ) -> int:
+ if (
+ create
+ and create.content.supports_creator_power
+ and (user_id == create.sender or user_id in (create.content.additional_creators or []))
+ ):
+ # This is really meant to be infinity, but involving floats would be annoying,
+ # so we use an integer larger than the maximum power level (2^53-1) instead.
+ return 2**60 - 1
return int(self.users.get(user_id, self.users_default))
def set_user_level(self, user_id: UserID, level: int) -> None:
@@ -50,7 +62,16 @@ def set_user_level(self, user_id: UserID, level: int) -> None:
else:
self.users[user_id] = level
- def ensure_user_level(self, user_id: UserID, level: int) -> bool:
+ def ensure_user_level(
+ self, user_id: UserID, level: int, create: Optional["StateEvent"] = None
+ ) -> bool:
+ if (
+ create
+ and create.content.supports_creator_power
+ and (user_id == create.sender or user_id in (create.content.additional_creators or []))
+ ):
+ # Don't try to set creator power levels
+ return False
if self.get_user_level(user_id) != level:
self.set_user_level(user_id, level)
return True
@@ -138,16 +159,15 @@ class RoomAvatarStateEventContent(SerializableAttrs):
url: Optional[ContentURI] = None
-class JoinRule(SerializableEnum):
+class JoinRule(ExtensibleEnum):
PUBLIC = "public"
KNOCK = "knock"
RESTRICTED = "restricted"
INVITE = "invite"
- PRIVATE = "private"
KNOCK_RESTRICTED = "knock_restricted"
-class JoinRestrictionType(SerializableEnum):
+class JoinRestrictionType(ExtensibleEnum):
ROOM_MEMBERSHIP = "m.room_membership"
@@ -193,6 +213,24 @@ class RoomCreateStateEventContent(SerializableAttrs):
federate: bool = field(json="m.federate", omit_default=True, default=True)
predecessor: Optional[RoomPredecessor] = None
type: Optional[RoomType] = None
+ additional_creators: Optional[List[UserID]] = None
+
+ @property
+ def supports_creator_power(self) -> bool:
+ return self.room_version not in (
+ "",
+ "1",
+ "2",
+ "3",
+ "4",
+ "5",
+ "6",
+ "7",
+ "8",
+ "9",
+ "10",
+ "11",
+ )
@dataclass
diff --git a/mautrix/types/event/to_device.py b/mautrix/types/event/to_device.py
index 6a17e36e..d6392177 100644
--- a/mautrix/types/event/to_device.py
+++ b/mautrix/types/event/to_device.py
@@ -9,8 +9,9 @@
import attr
from ..primitive import JSON, DeviceID, IdentityKey, RoomID, SessionID, SigningKey, UserID
-from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer
+from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field
from .base import BaseEvent, EventType
+from .beeper import BeeperRoomKeyAckEventContent
from .encrypted import EncryptedOlmEventContent, EncryptionAlgorithm
@@ -21,6 +22,8 @@ class RoomKeyWithheldCode(ExtensibleEnum):
UNAVAILABLE: "RoomKeyWithheldCode" = "m.unavailable"
NO_OLM_SESSION: "RoomKeyWithheldCode" = "m.no_olm"
+ BEEPER_REDACTED: "RoomKeyWithheldCode" = "com.beeper.redacted"
+
@dataclass
class RoomKeyWithheldEventContent(SerializableAttrs):
@@ -39,6 +42,10 @@ class RoomKeyEventContent(SerializableAttrs):
session_id: SessionID
session_key: str
+ beeper_max_age_ms: Optional[int] = field(json="com.beeper.max_age_ms", default=None)
+ beeper_max_messages: Optional[int] = field(json="com.beeper.max_messages", default=None)
+ beeper_is_scheduled: Optional[bool] = field(json="com.beeper.is_scheduled", default=False)
+
class KeyRequestAction(ExtensibleEnum):
REQUEST: "KeyRequestAction" = "request"
@@ -61,7 +68,7 @@ class RoomKeyRequestEventContent(SerializableAttrs):
body: Optional[RequestedKeyInfo] = None
-@dataclass
+@dataclass(kw_only=True)
class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs):
sender_key: IdentityKey
signing_key: SigningKey = attr.ib(metadata={"json": "sender_claimed_ed25519_key"})
@@ -75,6 +82,7 @@ class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs):
RoomKeyEventContent,
RoomKeyRequestEventContent,
ForwardedRoomKeyEventContent,
+ BeeperRoomKeyAckEventContent,
]
to_device_event_content_map = {
EventType.TO_DEVICE_ENCRYPTED: EncryptedOlmEventContent,
@@ -82,12 +90,10 @@ class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs):
EventType.ROOM_KEY_REQUEST: RoomKeyRequestEventContent,
EventType.ROOM_KEY: RoomKeyEventContent,
EventType.FORWARDED_ROOM_KEY: ForwardedRoomKeyEventContent,
+ EventType.BEEPER_ROOM_KEY_ACK: BeeperRoomKeyAckEventContent,
}
-# TODO remaining account data event types
-
-
@dataclass
class ToDeviceEvent(BaseEvent, SerializableAttrs):
sender: UserID
diff --git a/mautrix/types/event/type.py b/mautrix/types/event/type.py
index 609f40ff..5faf3785 100644
--- a/mautrix/types/event/type.py
+++ b/mautrix/types/event/type.py
@@ -207,6 +207,11 @@ def is_to_device(self) -> bool:
"m.push_rules": "PUSH_RULES",
"m.tag": "TAG",
"m.ignored_user_list": "IGNORED_USER_LIST",
+ "m.secret_storage.default_key": "SECRET_STORAGE_DEFAULT_KEY",
+ "m.cross_signing.master": "CROSS_SIGNING_MASTER",
+ "m.cross_signing.self_signing": "CROSS_SIGNING_SELF_SIGNING",
+ "m.cross_signing.user_signing": "CROSS_SIGNING_USER_SIGNING",
+ "m.megolm_backup.v1": "MEGOLM_BACKUP_V1",
},
EventType.Class.TO_DEVICE: {
"m.room.encrypted": "TO_DEVICE_ENCRYPTED",
@@ -216,6 +221,7 @@ def is_to_device(self) -> bool:
"m.room_key_request": "ROOM_KEY_REQUEST",
"m.forwarded_room_key": "FORWARDED_ROOM_KEY",
"m.dummy": "TO_DEVICE_DUMMY",
+ "com.beeper.room_key.ack": "BEEPER_ROOM_KEY_ACK",
},
EventType.Class.UNKNOWN: {
"__ALL__": "ALL", # This is not a real event type
diff --git a/mautrix/types/event/type.pyi b/mautrix/types/event/type.pyi
index 6c4c6e65..a2788d6f 100644
--- a/mautrix/types/event/type.pyi
+++ b/mautrix/types/event/type.pyi
@@ -18,6 +18,7 @@ class EventType(Serializable):
ACCOUNT_DATA = "account_data"
EPHEMERAL = "ephemeral"
TO_DEVICE = "to_device"
+
_by_event_type: ClassVar[dict[str, EventType]]
ROOM_CANONICAL_ALIAS: "EventType"
@@ -60,6 +61,11 @@ class EventType(Serializable):
PUSH_RULES: "EventType"
TAG: "EventType"
IGNORED_USER_LIST: "EventType"
+ SECRET_STORAGE_DEFAULT_KEY: "EventType"
+ CROSS_SIGNING_MASTER: "EventType"
+ CROSS_SIGNING_SELF_SIGNING: "EventType"
+ CROSS_SIGNING_USER_SIGNING: "EventType"
+ MEGOLM_BACKUP_V1: "EventType"
TO_DEVICE_ENCRYPTED: "EventType"
TO_DEVICE_DUMMY: "EventType"
@@ -68,6 +74,7 @@ class EventType(Serializable):
ORG_MATRIX_ROOM_KEY_WITHHELD: "EventType"
ROOM_KEY_REQUEST: "EventType"
FORWARDED_ROOM_KEY: "EventType"
+ BEEPER_ROOM_KEY_ACK: "EventType"
ALL: "EventType"
diff --git a/mautrix/types/media.py b/mautrix/types/media.py
index 72aa6c4f..1d0ee66a 100644
--- a/mautrix/types/media.py
+++ b/mautrix/types/media.py
@@ -20,7 +20,7 @@ class MediaRepoConfig(SerializableAttrs):
https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3config
"""
- upload_size: int = field(json="m.upload.size")
+ upload_size: int = field(default=50 * 1024 * 1024, json="m.upload.size")
@dataclass
@@ -71,4 +71,4 @@ class MediaCreateResponse(SerializableAttrs):
content_uri: ContentURI
unused_expired_at: Optional[int] = None
- upload_url: Optional[str] = None
+ unstable_upload_url: Optional[str] = field(default=None, json="com.beeper.msc3870.upload_url")
diff --git a/mautrix/types/misc.py b/mautrix/types/misc.py
index 7a978bd1..5a07699c 100644
--- a/mautrix/types/misc.py
+++ b/mautrix/types/misc.py
@@ -87,7 +87,7 @@ class PublicRoomInfo(SerializableAttrs):
num_joined_members: int
world_readable: bool
- guests_can_join: bool
+ guest_can_join: bool
name: str = None
topic: str = None
@@ -106,7 +106,7 @@ class RoomDirectoryResponse(SerializableAttrs):
PaginatedMessages = NamedTuple(
- "PaginatedMessages", start=SyncToken, end=SyncToken, events=List[Event]
+ "PaginatedMessages", start=SyncToken, end=Optional[SyncToken], events=List[Event]
)
@@ -129,3 +129,8 @@ class BatchSendResponse(SerializableAttrs):
batch_event_id: EventID
next_batch_id: BatchID
base_insertion_event_id: Optional[EventID] = None
+
+
+@dataclass
+class BeeperBatchSendResponse(SerializableAttrs):
+ event_ids: List[EventID]
diff --git a/mautrix/types/versions.py b/mautrix/types/versions.py
index 8b42de39..52a62f59 100644
--- a/mautrix/types/versions.py
+++ b/mautrix/types/versions.py
@@ -70,6 +70,14 @@ class SpecVersions:
V11 = Version.deserialize("v1.1")
V12 = Version.deserialize("v1.2")
V13 = Version.deserialize("v1.3")
+ V14 = Version.deserialize("v1.4")
+ V15 = Version.deserialize("v1.5")
+ V16 = Version.deserialize("v1.6")
+ V17 = Version.deserialize("v1.7")
+ V18 = Version.deserialize("v1.8")
+ V19 = Version.deserialize("v1.9")
+ V110 = Version.deserialize("v1.10")
+ V111 = Version.deserialize("v1.11")
@dataclass
diff --git a/mautrix/util/__init__.py b/mautrix/util/__init__.py
index 6a7827d1..fd349bef 100644
--- a/mautrix/util/__init__.py
+++ b/mautrix/util/__init__.py
@@ -1,22 +1,28 @@
__all__ = [
+ # Directory modules
+ "async_db",
+ "config",
+ "db",
"formatter",
"logging",
- "config",
- "signed_token",
- "simple_template",
- "manhole",
- "markdown",
- "simple_lock",
+ # File modules
+ "async_body",
+ "async_getter_lock",
+ "background_task",
+ "bridge_state",
+ "color_log",
+ "ffmpeg",
"file_store",
- "program",
- "async_db",
- "db",
- "opt_prometheus",
+ "format_duration",
"magic",
- "bridge_state",
+ "manhole",
+ "markdown",
"message_send_checkpoint",
- "variation_selector",
- "format_duration",
- "ffmpeg",
+ "opt_prometheus",
+ "program",
+ "signed_token",
+ "simple_lock",
+ "simple_template",
"utf16_surrogate",
+ "variation_selector",
]
diff --git a/mautrix/util/async_body.py b/mautrix/util/async_body.py
new file mode 100644
index 00000000..4db4d1e5
--- /dev/null
+++ b/mautrix/util/async_body.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2023 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from __future__ import annotations
+
+from typing import AsyncGenerator, Union
+import logging
+
+import aiohttp
+
+AsyncBody = AsyncGenerator[Union[bytes, bytearray, memoryview], None]
+
+
+async def async_iter_bytes(data: bytearray | bytes, chunk_size: int = 1024**2) -> AsyncBody:
+ """
+ Return memory views into a byte array in chunks. This is used to prevent aiohttp from copying
+ the entire request body.
+
+ Args:
+ data: The underlying data to iterate through.
+ chunk_size: How big each returned chunk should be.
+
+ Returns:
+ An async generator that yields the given data in chunks.
+ """
+ with memoryview(data) as mv:
+ for i in range(0, len(data), chunk_size):
+ yield mv[i : i + chunk_size]
+
+
+class FileTooLargeError(Exception):
+ def __init__(self, max_size: int) -> None:
+ super().__init__(f"File size larger than maximum ({max_size / 1024 / 1024} MiB)")
+
+
+_default_dl_log = logging.getLogger("mau.util.download")
+
+
+async def read_response_chunks(
+ resp: aiohttp.ClientResponse, max_size: int, log: logging.Logger = _default_dl_log
+) -> bytearray:
+ """
+ Read the body from an aiohttp response in chunks into a mutable bytearray.
+
+ Args:
+ resp: The aiohttp response object to read the body from.
+ max_size: The maximum size to read. FileTooLargeError will be raised if the Content-Length
+ is higher than this, or if the body exceeds this size during reading.
+ log: A logger for logging download status.
+
+ Returns:
+ The body data as a byte array.
+
+ Raises:
+ FileTooLargeError: if the body is larger than the provided max_size.
+ """
+ content_length = int(resp.headers.get("Content-Length", "0"))
+ if 0 < max_size < content_length:
+ raise FileTooLargeError(max_size)
+ size_str = "unknown length" if content_length == 0 else f"{content_length} bytes"
+ log.info(f"Reading file download response with {size_str} (max: {max_size})")
+ data = bytearray(content_length)
+ mv = memoryview(data) if content_length > 0 else None
+ read_size = 0
+ max_size += 1
+ while True:
+ block = await resp.content.readany()
+ if not block:
+ break
+ max_size -= len(block)
+ if max_size <= 0:
+ raise FileTooLargeError(max_size)
+ if len(data) >= read_size + len(block):
+ mv[read_size : read_size + len(block)] = block
+ elif len(data) > read_size:
+ log.warning("File being downloaded is bigger than expected")
+ mv[read_size:] = block[: len(data) - read_size]
+ mv.release()
+ mv = None
+ data.extend(block[len(data) - read_size :])
+ else:
+ if mv is not None:
+ mv.release()
+ mv = None
+ data.extend(block)
+ read_size += len(block)
+ if mv is not None:
+ mv.release()
+ log.info(f"Successfully read {read_size} bytes of file download response")
+ return data
+
+
+__all__ = ["AsyncBody", "FileTooLargeError", "async_iter_bytes", "read_response_chunks"]
diff --git a/mautrix/util/async_db/aiosqlite.py b/mautrix/util/async_db/aiosqlite.py
index 9b7fa16a..934379a8 100644
--- a/mautrix/util/async_db/aiosqlite.py
+++ b/mautrix/util/async_db/aiosqlite.py
@@ -5,10 +5,12 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Any
+from typing import Any, AsyncContextManager
from contextlib import asynccontextmanager
+from contextvars import ContextVar
import asyncio
import logging
+import os
import re
import sqlite3
@@ -23,6 +25,9 @@
POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)")
+in_transaction = ContextVar("in_transaction", default=False)
+
+
class TxnConnection(aiosqlite.Connection):
def __init__(self, path: str, **kwargs) -> None:
def connector() -> sqlite3.Connection:
@@ -34,7 +39,11 @@ def connector() -> sqlite3.Connection:
@asynccontextmanager
async def transaction(self) -> None:
+ if in_transaction.get():
+ yield
+ return
await self.execute("BEGIN TRANSACTION")
+ token = in_transaction.set(True)
try:
yield
except Exception:
@@ -42,6 +51,8 @@ async def transaction(self) -> None:
raise
else:
await self.commit()
+ finally:
+ in_transaction.reset(token)
def __execute(self, query: str, *args: Any):
query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query)
@@ -81,6 +92,7 @@ async def fetchval(
class SQLiteDatabase(Database):
scheme = Scheme.SQLITE
+ _parent: SQLiteDatabase | None
_pool: asyncio.Queue[TxnConnection]
_stopped: bool
_conns: int
@@ -103,9 +115,8 @@ def __init__(
owner_name=owner_name,
ignore_foreign_tables=ignore_foreign_tables,
)
+ self._parent = None
self._path = url.path
- if self._path.startswith("/"):
- self._path = self._path[1:]
self._pool = asyncio.Queue(self._db_args.pop("min_size", 1))
self._db_args.pop("max_size", None)
self._stopped = False
@@ -116,6 +127,7 @@ def __init__(
def _add_missing_pragmas(init_commands: list[str]) -> list[str]:
has_foreign_keys = False
has_journal_mode = False
+ has_synchronous = False
has_busy_timeout = False
for cmd in init_commands:
if "PRAGMA" not in cmd:
@@ -124,23 +136,39 @@ def _add_missing_pragmas(init_commands: list[str]) -> list[str]:
has_foreign_keys = True
elif "journal_mode" in cmd:
has_journal_mode = True
+ elif "synchronous" in cmd:
+ has_synchronous = True
elif "busy_timeout" in cmd:
has_busy_timeout = True
if not has_foreign_keys:
init_commands.append("PRAGMA foreign_keys = ON")
if not has_journal_mode:
init_commands.append("PRAGMA journal_mode = WAL")
+ if not has_synchronous and "PRAGMA journal_mode = WAL" in init_commands:
+ init_commands.append("PRAGMA synchronous = NORMAL")
if not has_busy_timeout:
init_commands.append("PRAGMA busy_timeout = 5000")
return init_commands
+ def override_pool(self, db: Database) -> None:
+ assert isinstance(db, SQLiteDatabase)
+ self._parent = db
+
async def start(self) -> None:
+ if self._parent:
+ await super().start()
+ return
if self._conns:
raise RuntimeError("database pool has already been started")
elif self._stopped:
raise RuntimeError("database pool can't be restarted")
self.log.debug(f"Connecting to {self.url}")
self.log.debug(f"Database connection init commands: {self._init_commands}")
+ if os.path.exists(self._path):
+ if not os.access(self._path, os.W_OK):
+ self.log.warning("Database file doesn't seem writable")
+ elif not os.access(os.path.dirname(os.path.abspath(self._path)), os.W_OK):
+ self.log.warning("Database file doesn't exist and directory doesn't seem writable")
for _ in range(self._pool.maxsize):
conn = await TxnConnection(self._path, **self._db_args)
if self._init_commands:
@@ -155,14 +183,21 @@ async def start(self) -> None:
await super().start()
async def stop(self) -> None:
+ if self._parent:
+ return
self._stopped = True
while self._conns > 0:
conn = await self._pool.get()
self._conns -= 1
await conn.close()
+ def acquire_direct(self) -> AsyncContextManager[LoggingConnection]:
+ if self._parent:
+ return self._parent.acquire()
+ return self._acquire()
+
@asynccontextmanager
- async def acquire(self) -> LoggingConnection:
+ async def _acquire(self) -> LoggingConnection:
if self._stopped:
raise RuntimeError("database pool has been stopped")
conn = await self._pool.get()
diff --git a/mautrix/util/async_db/asyncpg.py b/mautrix/util/async_db/asyncpg.py
index 07ef7ad7..97b49f6c 100644
--- a/mautrix/util/async_db/asyncpg.py
+++ b/mautrix/util/async_db/asyncpg.py
@@ -96,7 +96,7 @@ async def _handle_exception(self, err: Exception) -> None:
sys.exit(26)
@asynccontextmanager
- async def acquire(self) -> LoggingConnection:
+ async def acquire_direct(self) -> LoggingConnection:
async with self.pool.acquire() as conn:
yield LoggingConnection(
self.scheme, conn, self.log, handle_exception=self._handle_exception
diff --git a/mautrix/util/async_db/database.py b/mautrix/util/async_db/database.py
index 0f23b02d..b5128b74 100644
--- a/mautrix/util/async_db/database.py
+++ b/mautrix/util/async_db/database.py
@@ -7,6 +7,8 @@
from typing import Any, AsyncContextManager, Type
from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
+from contextvars import ContextVar
import logging
from yarl import URL
@@ -23,6 +25,8 @@
from aiosqlite import Cursor
from asyncpg import Record
+conn_var: ContextVar[LoggingConnection | None] = ContextVar("db_connection", default=None)
+
class Database(ABC):
schemes: dict[str, Type[Database]] = {}
@@ -111,15 +115,17 @@ async def _check_foreign_tables(self) -> None:
raise ForeignTablesFound("found roomserver_rooms likely belonging to Dendrite")
async def _check_owner(self) -> None:
- await self.execute(
- """CREATE TABLE IF NOT EXISTS database_owner (
+ await self.execute("""
+ CREATE TABLE IF NOT EXISTS database_owner (
key INTEGER PRIMARY KEY DEFAULT 0,
owner TEXT NOT NULL
- )"""
- )
+ )
+ """)
owner = await self.fetchval("SELECT owner FROM database_owner WHERE key=0")
if not owner:
- await self.execute("INSERT INTO database_owner (owner) VALUES ($1)", self.owner_name)
+ await self.execute(
+ "INSERT INTO database_owner (key, owner) VALUES (0, $1)", self.owner_name
+ )
elif owner != self.owner_name:
raise DatabaseNotOwned(owner)
@@ -128,9 +134,22 @@ async def stop(self) -> None:
pass
@abstractmethod
- def acquire(self) -> AsyncContextManager[LoggingConnection]:
+ def acquire_direct(self) -> AsyncContextManager[LoggingConnection]:
pass
+ @asynccontextmanager
+ async def acquire(self) -> LoggingConnection:
+ conn = conn_var.get(None)
+ if conn is not None:
+ yield conn
+ return
+ async with self.acquire_direct() as conn:
+ token = conn_var.set(conn)
+ try:
+ yield conn
+ finally:
+ conn_var.reset(token)
+
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor:
async with self.acquire() as conn:
return await conn.execute(query, *args, timeout=timeout)
diff --git a/mautrix/util/async_db/upgrade.py b/mautrix/util/async_db/upgrade.py
index d69dac56..c084d28b 100644
--- a/mautrix/util/async_db/upgrade.py
+++ b/mautrix/util/async_db/upgrade.py
@@ -21,7 +21,7 @@
UpgradeWithoutScheme = Callable[[LoggingConnection], Awaitable[Optional[int]]]
-async def noop_upgrade(_: LoggingConnection) -> None:
+async def noop_upgrade(_: LoggingConnection, _2: Scheme) -> None:
pass
@@ -97,11 +97,11 @@ async def _save_version(self, conn: LoggingConnection, version: int) -> None:
await conn.execute(f"INSERT INTO {self.version_table_name} (version) VALUES ($1)", version)
async def upgrade(self, db: async_db.Database) -> None:
- await db.execute(
- f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} (
+ await db.execute(f"""
+ CREATE TABLE IF NOT EXISTS {self.version_table_name} (
version INTEGER PRIMARY KEY
- )"""
- )
+ )
+ """)
row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1")
version = row["version"] if row else 0
@@ -178,6 +178,6 @@ def _find_upgrade_table(fn: Upgrade) -> UpgradeTable:
def register_upgrade(index: int = -1, description: str = "") -> Callable[[Upgrade], Upgrade]:
def actually_register(fn: Upgrade) -> Upgrade:
- return _find_upgrade_table(fn).register(index, description, fn)
+ return _find_upgrade_table(fn).register(fn, index=index, description=description)
return actually_register
diff --git a/mautrix/util/background_task.py b/mautrix/util/background_task.py
new file mode 100644
index 00000000..e22e74f1
--- /dev/null
+++ b/mautrix/util/background_task.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2023 Tulir Asokan
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from __future__ import annotations
+
+from typing import Coroutine
+import asyncio
+import logging
+
+_tasks = set()
+log = logging.getLogger("mau.background_task")
+
+
+async def catch(coro: Coroutine, caller: str) -> None:
+ try:
+ await coro
+ except Exception:
+ log.exception(f"Uncaught error in background task (created in {caller})")
+
+
+# Logger.findCaller finds the 3rd stack frame, so add an intermediate function
+# to get the caller of create().
+def _find_caller() -> tuple[str, int, str, None]:
+ return log.findCaller()
+
+
+def create(coro: Coroutine, *, name: str | None = None, catch_errors: bool = True) -> asyncio.Task:
+ """
+ Create a background asyncio task safely, ensuring a reference is kept until the task completes.
+ It also catches and logs uncaught errors (unless disabled via the parameter).
+
+ Args:
+ coro: The coroutine to wrap in a task and execute.
+ name: An optional name for the created task.
+ catch_errors: Should the task be wrapped in a try-except block to log any uncaught errors?
+
+ Returns:
+ An asyncio Task object wrapping the given coroutine.
+ """
+ if catch_errors:
+ try:
+ file_name, line_number, function_name, _ = _find_caller()
+ caller = f"{function_name} at {file_name}:{line_number}"
+ except ValueError:
+ caller = "unknown function"
+ task = asyncio.create_task(catch(coro, caller), name=name)
+ else:
+ task = asyncio.create_task(coro, name=name)
+ _tasks.add(task)
+ task.add_done_callback(_tasks.discard)
+ return task
diff --git a/mautrix/util/bridge_state.py b/mautrix/util/bridge_state.py
index b7c346f0..d28448bf 100644
--- a/mautrix/util/bridge_state.py
+++ b/mautrix/util/bridge_state.py
@@ -62,8 +62,8 @@ class BridgeStateEvent(SerializableEnum):
class BridgeState(SerializableAttrs):
human_readable_errors: ClassVar[Dict[Optional[str], str]] = {}
default_source: ClassVar[str] = "bridge"
- default_error_ttl: ClassVar[int] = 60
- default_ok_ttl: ClassVar[int] = 240
+ default_error_ttl: ClassVar[int] = 3600
+ default_ok_ttl: ClassVar[int] = 21600
state_event: BridgeStateEvent
user_id: Optional[UserID] = None
@@ -106,8 +106,8 @@ def should_deduplicate(self, prev_state: Optional["BridgeState"]) -> bool:
):
# If there's no previous state or the state was different, send this one.
return False
- # If there's more than ⅘ of the previous pong's time-to-live left, drop this one
- return prev_state.timestamp + (prev_state.ttl / 5) > self.timestamp
+ # If the previous state is recent, drop this one
+ return prev_state.timestamp + prev_state.ttl > self.timestamp
async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool = True) -> bool:
if not url:
@@ -115,9 +115,10 @@ async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool =
self.send_attempts_ += 1
headers = {"Authorization": f"Bearer {token}", "User-Agent": HTTPAPI.default_ua}
try:
- async with aiohttp.ClientSession() as sess, sess.post(
- url, json=self.serialize(), headers=headers
- ) as resp:
+ async with (
+ aiohttp.ClientSession() as sess,
+ sess.post(url, json=self.serialize(), headers=headers) as resp,
+ ):
if not 200 <= resp.status < 300:
text = await resp.text()
text = text.replace("\n", "\\n")
diff --git a/mautrix/util/ffmpeg.py b/mautrix/util/ffmpeg.py
index f158ae52..41cebf32 100644
--- a/mautrix/util/ffmpeg.py
+++ b/mautrix/util/ffmpeg.py
@@ -5,9 +5,11 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
-from typing import Iterable
+from typing import Any, Iterable
from pathlib import Path
import asyncio
+import json
+import logging
import mimetypes
import os
import shutil
@@ -34,7 +36,89 @@ def __init__(self) -> None:
ffmpeg_path = _abswhich("ffmpeg")
-ffmpeg_default_params = ("-hide_banner", "-loglevel", "warning")
+ffmpeg_default_params = ("-hide_banner", "-loglevel", "warning", "-y")
+
+ffprobe_path = _abswhich("ffprobe")
+ffprobe_default_params = (
+ "-loglevel",
+ "quiet",
+ "-print_format",
+ "json",
+ "-show_optional_fields",
+ "1",
+ "-show_format",
+ "-show_streams",
+)
+
+
+async def probe_path(
+ input_file: os.PathLike[str] | str,
+ logger: logging.Logger | None = None,
+) -> Any:
+ """
+ Probes a media file on the disk using ffprobe.
+
+ Args:
+ input_file: The full path to the file.
+
+ Returns:
+ A Python object containing the parsed JSON response from ffprobe
+
+ Raises:
+ ConverterError: if ffprobe returns a non-zero exit code.
+ """
+ if ffprobe_path is None:
+ raise NotInstalledError()
+
+ input_file = Path(input_file)
+ proc = await asyncio.create_subprocess_exec(
+ ffprobe_path,
+ *ffprobe_default_params,
+ str(input_file),
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ stdin=asyncio.subprocess.PIPE,
+ )
+ stdout, stderr = await proc.communicate()
+ if proc.returncode != 0:
+ err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})"
+ raise ConverterError(f"ffprobe error: {err_text}")
+ elif stderr and logger:
+ logger.warning(f"ffprobe warning: {stderr.decode('utf-8')}")
+ return json.loads(stdout)
+
+
+async def probe_bytes(
+ data: bytes,
+ input_mime: str | None = None,
+ logger: logging.Logger | None = None,
+) -> Any:
+ """
+ Probe media file data using ffprobe.
+
+ Args:
+ data: The bytes of the file to probe.
+ input_mime: The mime type of the input data. If not specified, will be guessed using magic.
+
+ Returns:
+ A Python object containing the parsed JSON response from ffprobe
+
+ Raises:
+ ConverterError: if ffprobe returns a non-zero exit code.
+ """
+ if ffprobe_path is None:
+ raise NotInstalledError()
+
+ if input_mime is None:
+ if magic is None:
+ raise ValueError("input_mime was not specified and magic is not installed")
+ input_mime = magic.mimetype(data)
+ input_extension = mimetypes.guess_extension(input_mime)
+ with tempfile.TemporaryDirectory(prefix="mautrix_ffmpeg_") as tmpdir:
+ input_file = Path(tmpdir) / f"data{input_extension}"
+ with open(input_file, "wb") as file:
+ file.write(data)
+ return await probe_path(input_file=input_file, logger=logger)
async def convert_path(
@@ -44,6 +128,7 @@ async def convert_path(
output_args: Iterable[str] | None = None,
remove_input: bool = False,
output_path_override: os.PathLike[str] | str | None = None,
+ logger: logging.Logger | None = None,
) -> Path | bytes:
"""
Convert a media file on the disk using ffmpeg.
@@ -76,6 +161,10 @@ async def convert_path(
else:
input_file = Path(input_file)
output_file = input_file.parent / f"{input_file.stem}{output_extension}"
+ if input_file == output_file:
+ output_file = Path(output_file)
+ output_file = output_file.parent / f"{output_file.stem}-new{output_extension}"
+
proc = await asyncio.create_subprocess_exec(
ffmpeg_path,
*ffmpeg_default_params,
@@ -92,9 +181,8 @@ async def convert_path(
if proc.returncode != 0:
err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})"
raise ConverterError(f"ffmpeg error: {err_text}")
- elif stderr:
- # TODO log warnings?
- pass
+ elif stderr and logger:
+ logger.warning(f"ffmpeg warning: {stderr.decode('utf-8')}")
if remove_input and isinstance(input_file, Path):
input_file.unlink(missing_ok=True)
return stdout if output_file == "-" else output_file
@@ -106,6 +194,7 @@ async def convert_bytes(
input_args: Iterable[str] | None = None,
output_args: Iterable[str] | None = None,
input_mime: str | None = None,
+ logger: logging.Logger | None = None,
) -> bytes:
"""
Convert media file data using ffmpeg.
@@ -140,6 +229,7 @@ async def convert_bytes(
output_extension=output_extension,
input_args=input_args,
output_args=output_args,
+ logger=logger,
)
with open(output_file, "rb") as file:
return file.read()
@@ -152,4 +242,6 @@ async def convert_bytes(
"NotInstalledError",
"convert_bytes",
"convert_path",
+ "probe_bytes",
+ "probe_path",
]
diff --git a/mautrix/util/message_send_checkpoint.py b/mautrix/util/message_send_checkpoint.py
index 61eb691d..ee0c17f3 100644
--- a/mautrix/util/message_send_checkpoint.py
+++ b/mautrix/util/message_send_checkpoint.py
@@ -29,6 +29,7 @@ class MessageSendCheckpointStatus(SerializableEnum):
PERM_FAILURE = "PERM_FAILURE"
UNSUPPORTED = "UNSUPPORTED"
TIMEOUT = "TIMEOUT"
+ DELIVERY_FAILED = "DELIVERY_FAILED"
class MessageSendCheckpointReportedBy(SerializableEnum):
@@ -56,12 +57,15 @@ async def send(self, endpoint: str, as_token: str, log: logging.Logger) -> None:
return
try:
headers = {"Authorization": f"Bearer {as_token}", "User-Agent": HTTPAPI.default_ua}
- async with aiohttp.ClientSession() as sess, sess.post(
- endpoint,
- json={"checkpoints": [self.serialize()]},
- headers=headers,
- timeout=ClientTimeout(30),
- ) as resp:
+ async with (
+ aiohttp.ClientSession() as sess,
+ sess.post(
+ endpoint,
+ json={"checkpoints": [self.serialize()]},
+ headers=headers,
+ timeout=ClientTimeout(30),
+ ) as resp,
+ ):
if not 200 <= resp.status < 300:
text = await resp.text()
text = text.replace("\n", "\\n")
diff --git a/mautrix/util/program.py b/mautrix/util/program.py
index 80d2298b..74752ca5 100644
--- a/mautrix/util/program.py
+++ b/mautrix/util/program.py
@@ -98,7 +98,7 @@ def _prepare(self) -> None:
self.log.info(f"Initializing {self.name} {self.version}")
try:
- self.prepare()
+ self.loop.run_until_complete(self._async_prepare())
except Exception:
self.log.critical("Unexpected error in initialization", exc_info=True)
sys.exit(1)
@@ -117,9 +117,10 @@ def preinit(self) -> None:
self.prepare_config()
self.prepare_log()
self.check_config()
+ self.init_loop()
@property
- def _default_base_config(self) -> str:
+ def base_config_path(self) -> str:
return f"pkg://{self.module}/example-config.yaml"
def prepare_arg_parser(self) -> None:
@@ -133,21 +134,13 @@ def prepare_arg_parser(self) -> None:
metavar="",
help="the path to your config file",
)
- self.parser.add_argument(
- "-b",
- "--base-config",
- type=str,
- default=self._default_base_config,
- metavar="",
- help="the path to the example config (for automatic config updates)",
- )
self.parser.add_argument(
"-n", "--no-update", action="store_true", help="Don't save updated config to disk"
)
def prepare_config(self) -> None:
"""Pre-init lifecycle method. Extend this if you want to customize config loading."""
- self.config = self.config_class(self.args.config, self.args.base_config)
+ self.config = self.config_class(self.args.config, self.base_config_path)
self.load_and_update_config()
def load_and_update_config(self) -> None:
@@ -155,13 +148,10 @@ def load_and_update_config(self) -> None:
try:
self.config.update(save=not self.args.no_update)
except BaseMissingError:
- if self.args.base_config != self._default_base_config:
- print(f"Failed to read base config from {self.args.base_config}")
- else:
- print(
- "Failed to read base config from the default path "
- f"({self._default_base_config}). Maybe your installation is corrupted?"
- )
+ print(
+ "Failed to read base config from the default path "
+ f"({self.base_config_path}). Maybe your installation is corrupted?"
+ )
sys.exit(12)
def check_config(self) -> None:
@@ -179,14 +169,16 @@ def prepare_log(self) -> None:
logging.config.dictConfig(copy.deepcopy(self.config["logging"]))
self.log = cast(TraceLogger, logging.getLogger("mau.init"))
+ async def _async_prepare(self) -> None:
+ self.prepare()
+
def prepare(self) -> None:
"""
Lifecycle method where the primary program initialization happens.
Use this to fill startup_actions with async startup tasks.
"""
- self.prepare_loop()
- def prepare_loop(self) -> None:
+ def init_loop(self) -> None:
"""Init lifecycle method where the asyncio event loop is created."""
if uvloop is not None:
uvloop.install()
@@ -199,6 +191,7 @@ def start_prometheus(self) -> None:
try:
enabled = self.config["metrics.enabled"]
listen_port = self.config["metrics.listen_port"]
+ hostname = self.config.get("metrics.hostname", "0.0.0.0")
except KeyError:
return
if not enabled:
@@ -208,12 +201,13 @@ def start_prometheus(self) -> None:
"Metrics are enabled in config, but prometheus_client is not installed"
)
return
- prometheus.start_http_server(listen_port)
+ prometheus.start_http_server(listen_port, addr=hostname)
def _run(self) -> None:
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
+ self._stop_task = self.loop.create_future()
exit_code = 0
try:
self.log.debug("Running startup actions...")
@@ -224,7 +218,6 @@ def _run(self) -> None:
f"Startup actions complete in {round(end_ts - start_ts, 2)} seconds, "
"now running forever"
)
- self._stop_task = self.loop.create_future()
exit_code = self.loop.run_until_complete(self._stop_task)
self.log.debug("manual_stop() called, stopping...")
except KeyboardInterrupt:
diff --git a/mautrix/util/proxy.py b/mautrix/util/proxy.py
new file mode 100644
index 00000000..f36da73d
--- /dev/null
+++ b/mautrix/util/proxy.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+from typing import Awaitable, Callable, TypeVar
+import asyncio
+import json
+import logging
+import time
+import urllib.request
+
+from aiohttp import ClientConnectionError
+from yarl import URL
+
+from mautrix.util.logging import TraceLogger
+
+try:
+ from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError
+except ImportError:
+
+ class ProxyError(Exception):
+ pass
+
+ ProxyConnectionError = ProxyTimeoutError = ProxyError
+
+RETRYABLE_PROXY_EXCEPTIONS = (
+ ProxyError,
+ ProxyTimeoutError,
+ ProxyConnectionError,
+ ClientConnectionError,
+ ConnectionError,
+ asyncio.TimeoutError,
+)
+
+
+class ProxyHandler:
+ current_proxy_url: str | None = None
+ log = logging.getLogger("mau.proxy")
+
+ def __init__(self, api_url: str | None) -> None:
+ self.api_url = api_url
+
+ def get_proxy_url_from_api(self, reason: str | None = None) -> str | None:
+ assert self.api_url is not None
+
+ api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {}))
+
+        # NOTE: using urllib.request to intentionally block the whole bridge until the proxy change is applied
+ request = urllib.request.Request(api_url, method="GET")
+ self.log.debug("Requesting proxy from: %s", api_url)
+
+ try:
+ with urllib.request.urlopen(request) as f:
+ response = json.loads(f.read().decode())
+ except Exception:
+ self.log.exception("Failed to retrieve proxy from API")
+ return self.current_proxy_url
+ else:
+ return response["proxy_url"]
+
+ def update_proxy_url(self, reason: str | None = None) -> bool:
+ old_proxy = self.current_proxy_url
+ new_proxy = None
+
+ if self.api_url is not None:
+ new_proxy = self.get_proxy_url_from_api(reason)
+ else:
+ new_proxy = urllib.request.getproxies().get("http")
+
+ if old_proxy != new_proxy:
+ self.log.debug("Set new proxy URL: %s", new_proxy)
+ self.current_proxy_url = new_proxy
+ return True
+
+ self.log.debug("Got same proxy URL: %s", new_proxy)
+ return False
+
+ def get_proxy_url(self) -> str | None:
+ if not self.current_proxy_url:
+ self.update_proxy_url()
+
+ return self.current_proxy_url
+
+
+T = TypeVar("T")
+
+
+async def proxy_with_retry(
+ name: str,
+ func: Callable[[], Awaitable[T]],
+ logger: TraceLogger,
+ proxy_handler: ProxyHandler,
+ on_proxy_change: Callable[[], Awaitable[None]],
+ max_retries: int = 10,
+ min_wait_seconds: int = 0,
+ max_wait_seconds: int = 60,
+ multiply_wait_seconds: int = 10,
+    retryable_exceptions: tuple[type[Exception], ...] = RETRYABLE_PROXY_EXCEPTIONS,
+ reset_after_seconds: int | None = None,
+) -> T:
+ errors = 0
+ last_error = 0
+
+ while True:
+ try:
+ return await func()
+ except retryable_exceptions as e:
+ errors += 1
+ if errors > max_retries:
+ raise
+ wait = errors * multiply_wait_seconds
+ wait = max(wait, min_wait_seconds)
+ wait = min(wait, max_wait_seconds)
+ logger.warning(
+ "%s while trying to %s, retrying in %d seconds",
+ e.__class__.__name__,
+ name,
+ wait,
+ )
+ if errors > 1 and proxy_handler.update_proxy_url(
+ f"{e.__class__.__name__} while trying to {name}"
+ ):
+ await on_proxy_change()
+
+ # If sufficient time has passed since the previous error, reset the
+ # error count. Useful for long running tasks with rare failures.
+ if reset_after_seconds is not None:
+ now = time.time()
+ if last_error and now - last_error > reset_after_seconds:
+ errors = 0
+ last_error = now
diff --git a/mautrix/util/simple_lock.py b/mautrix/util/simple_lock.py
index ad979aec..c6bd08ce 100644
--- a/mautrix/util/simple_lock.py
+++ b/mautrix/util/simple_lock.py
@@ -47,7 +47,7 @@ def locked(self) -> bool:
return not self.noop_mode and not self._event.is_set()
async def wait(self, task: str | None = None) -> None:
- if not self.noop_mode and not self._event.is_set():
+ if self.locked:
if self.log and self.message:
self.log.debug(self.message, task)
await self._event.wait()
diff --git a/mautrix/util/utf16_surrogate.py b/mautrix/util/utf16_surrogate.py
index d202e15e..92dc49c2 100644
--- a/mautrix/util/utf16_surrogate.py
+++ b/mautrix/util/utf16_surrogate.py
@@ -16,9 +16,11 @@ def add(text: str) -> str:
The text with surrogate pairs.
"""
return "".join(
- "".join(chr(y) for y in struct.unpack(" dict[str, str]:
@@ -43,11 +43,12 @@ async def fetch_data() -> dict[str, str]:
if __name__ == "__main__":
import asyncio
+ import importlib.resources
+ import pathlib
import sys
- import pkg_resources
-
- path = pkg_resources.resource_filename("mautrix.util", "variation_selector.json")
+ path = importlib.resources.files("mautrix.util").joinpath("variation_selector.json")
+ assert isinstance(path, pathlib.Path)
emojis = asyncio.run(fetch_data())
with open(path, "w") as file:
json.dump(emojis, file, indent=" ", ensure_ascii=False)
@@ -59,11 +60,11 @@ async def fetch_data() -> dict[str, str]:
ADD_VARIATION_TRANSLATION = str.maketrans(
{ord(emoji): f"{emoji}{VARIATION_SELECTOR_16}" for emoji in read_data().values()}
)
-SKIN_TONE_MODIFIERS = ("\U0001F3FB", "\U0001F3FC", "\U0001F3FD", "\U0001F3FE", "\U0001F3FF")
+SKIN_TONE_MODIFIERS = ("\U0001f3fb", "\U0001f3fc", "\U0001f3fd", "\U0001f3fe", "\U0001f3ff")
SKIN_TONE_REPLACEMENTS = {f"{VARIATION_SELECTOR_16}{mod}": mod for mod in SKIN_TONE_MODIFIERS}
VARIATION_SELECTOR_REPLACEMENTS = {
**SKIN_TONE_REPLACEMENTS,
- "\U0001F408\ufe0f\u200d\u2b1b\ufe0f": "\U0001F408\u200d\u2b1b",
+ "\U0001f408\ufe0f\u200d\u2b1b\ufe0f": "\U0001f408\u200d\u2b1b",
}
diff --git a/optional-requirements.txt b/optional-requirements.txt
index 6cfdcece..a6e0227d 100644
--- a/optional-requirements.txt
+++ b/optional-requirements.txt
@@ -1,6 +1,6 @@
python-magic
ruamel.yaml
-SQLAlchemy
+SQLAlchemy<2
commonmark
lxml
asyncpg
@@ -11,3 +11,4 @@ uvloop
python-olm
unpaddedbase64
pycryptodome
+base58
diff --git a/pyproject.toml b/pyproject.toml
index c0e41313..bc17b4d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,8 @@ line_length = 99
[tool.black]
line-length = 99
-target-version = ["py38"]
+target-version = ["py310"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
+addopts = "--ignore mautrix/util/db/ --ignore mautrix/bridge/"
diff --git a/setup.py b/setup.py
index f7596154..23ac1cb1 100644
--- a/setup.py
+++ b/setup.py
@@ -2,8 +2,8 @@
from mautrix import __version__
-encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome"]
-test_dependencies = ["aiosqlite", "sqlalchemy", "asyncpg", *encryption_dependencies]
+encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome", "base58"]
+test_dependencies = ["aiosqlite", "asyncpg", "ruamel.yaml", *encryption_dependencies]
setuptools.setup(
name="mautrix",
@@ -28,12 +28,12 @@
],
extras_require={
"detect_mimetype": ["python-magic>=0.4.15,<0.5"],
- "lint": ["black==22.1.0", "isort"],
+ "lint": ["black~=25.1", "isort"],
"test": ["pytest", "pytest-asyncio", *test_dependencies],
"encryption": encryption_dependencies,
},
tests_require=test_dependencies,
- python_requires="~=3.8",
+ python_requires="~=3.10",
classifiers=[
"Development Status :: 4 - Beta",
@@ -42,9 +42,11 @@
"Framework :: AsyncIO",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
],
package_data={