diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index c1ee7705..3a845672 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -4,22 +4,24 @@ on: [push, pull_request] jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Install libolm + run: sudo apt-get install libolm3 - name: Install dependencies run: | python -m pip install --upgrade pip + python -m pip install python-olm --extra-index-url https://gitlab.matrix.org/api/v4/projects/27/packages/pypi/simple python -m pip install .[test] - name: Test with pytest run: | @@ -44,17 +46,17 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: "3.14" - uses: isort/isort-action@master with: sortPaths: "./mautrix" - uses: psf/black@stable with: src: "./mautrix" - version: "22.1.0" + version: "26.3.1" - name: pre-commit run: | pip install pre-commit diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8f3efe97..b0d7ab3f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,6 +1,6 @@ build docs builder: stage: build - image: docker:stable + image: docker:latest tags: - amd64 only: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d9f684a..66065033 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,20 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] - id: end-of-file-fixer - 
id: check-yaml - id: check-added-large-files - # TODO convert to use the upstream psf/black when - # https://github.com/psf/black/issues/2493 gets fixed - - repo: local + - repo: https://github.com/psf/black + rev: 26.3.1 hooks: - id: black - name: black - entry: black --check - language: system - files: ^mautrix/.*\.py$ + language_version: python3 + files: ^mautrix/.*\.pyi?$ - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 8.0.1 hooks: - id: isort - files: ^mautrix/.*$ + files: ^mautrix/.*\.pyi?$ diff --git a/CHANGELOG.md b/CHANGELOG.md index fe9f5eb1..c7179c86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,579 @@ -## v0.14.11 (unreleased) +## v0.21.0 (2025-11-17) + +* *(event)* Added support for creator power in room v12+. +* *(crypto)* Added support for generating and using recovery keys for verifying + the active device. +* *(bridge)* Added config option for self-signing bot device. +* *(bridge)* Removed check for login flows when using MSC4190 + (thanks to [@meson800] in [#178]). +* *(client)* Changed `set_displayname` and `set_avatar_url` to avoid setting + empty strings if the value is already unset (thanks to [@frebib] in [#171]). + +[@frebib]: https://github.com/frebib +[@meson800]: https://github.com/meson800 +[#171]: https://github.com/mautrix/python/pull/171 +[#178]: https://github.com/mautrix/python/pull/178 + +## v0.20.8 (2025-06-01) + +* *(bridge)* Added support for [MSC4190] (thanks to [@surakin] in [#175]). +* *(appservice)* Renamed `push_ephemeral` in generated registrations to + `receive_ephemeral` to match the accepted version of [MSC2409]. +* *(bridge)* Fixed compatibility with breaking change in aiohttp 3.12.6. -* Removed Python 3.7 support. +[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190 +[@surakin]: https://github.com/surakin +[#175]: https://github.com/mautrix/python/pull/175 + +## v0.20.7 (2025-01-03) + +* *(types)* Removed support for generating reply fallbacks to implement + [MSC2781].
Stripping fallbacks is still supported. + +[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 + +## v0.20.6 (2024-07-12) + +* *(bridge)* Added `/register` call if `/versions` fails with `M_FORBIDDEN`. + +## v0.20.5 (2024-07-09) + +**Note:** The `bridge` module is deprecated as all bridges are being rewritten +in Go. See <https://mau.fi/blog/2024-h1-mautrix-updates/> for more info. + +* *(client)* Added support for authenticated media downloads. +* *(bridge)* Stopped using cached homeserver URLs for double puppeting if one + is set in the config file. +* *(crypto)* Fixed error when checking OTK counts before uploading new keys. +* *(types)* Added MSC2530 (captions) fields to `MediaMessageEventContent`. + +## v0.20.4 (2024-01-09) + +* Dropped Python 3.9 support. +* *(client)* Changed media download methods to log requests and to raise + exceptions on non-successful status codes. + +## v0.20.3 (2023-11-10) + +* *(client)* Deprecated MSC2716 methods and added new Beeper-specific batch + send methods, as upstream MSC2716 support has been abandoned. +* *(util.async_db)* Added `PRAGMA synchronous = NORMAL;` to default pragmas. +* *(types)* Fixed `guest_can_join` field name in room directory response + (thanks to [@ashfame] in [#163]). + +[@ashfame]: https://github.com/ashfame +[#163]: https://github.com/mautrix/python/pull/163 + +## v0.20.2 (2023-09-09) + +* *(crypto)* Changed `OlmMachine.share_keys` to make the OTK count parameter + optional. When omitted, the count is fetched from the server. +* *(appservice)* Added option to run appservice transaction event handlers + synchronously. +* *(appservice)* Added `log` and `hs_token` parameters to `AppServiceServerMixin` + to allow using it as a standalone class without extending. +* *(api)* Added support for setting appservice `user_id` and `device_id` query + parameters manually without using `AppServiceAPI`.
+ +## v0.20.1 (2023-08-29) + +* *(util.program)* Removed `--base-config` flag in bridges, as there are no + valid use cases (package data should always work) and it's easy to cause + issues by pointing the flag at the wrong file. +* *(bridge)* Added support for the `com.devture.shared_secret_auth` login type + for automatic double puppeting. +* *(bridge)* Dropped support for syncing with double puppets. MSC2409 is now + the only way to receive ephemeral events. +* *(bridge)* Added support for double puppeting with arbitrary `as_token`s. + +## v0.20.0 (2023-06-25) + +* Dropped Python 3.8 support. +* **Breaking change *(.state_store)*** Removed legacy SQLAlchemy state store + implementations. +* **Mildly breaking change *(util.async_db)*** Changed `SQLiteDatabase` to not + remove prefix slashes from database paths. + * Library users should use `sqlite:path.db` instead of `sqlite:///path.db` + for relative paths, and `sqlite:/path.db` instead of `sqlite:////path.db` + for absolute paths. + * Bridge configs do this migration automatically. +* *(util.async_db)* Added warning log if using SQLite database path that isn't + writable. +* *(util.program)* Fixed `manual_stop` not working if it's called during startup. +* *(client)* Stabilized support for asynchronous uploads. + * `unstable_create_mxc` was renamed to `create_mxc`, and the `max_stall_ms` + parameters for downloading were renamed to `timeout_ms`. +* *(crypto)* Added option to not rotate keys when devices change. +* *(crypto)* Added option to remove all keys that were received before the + automatic ratcheting was implemented (in v0.19.10). +* *(types)* Improved reply fallback removal to have a smaller chance of false + positives for messages that don't use reply fallbacks. + +## v0.19.16 (2023-05-26) + +* *(appservice)* Fixed Python 3.8 compatibility. + +## v0.19.15 (2023-05-24) + +* *(client)* Fixed dispatching room ephemeral events (i.e. typing notifications) in syncer.
+ +## v0.19.14 (2023-05-16) + +* *(bridge)* Implemented appservice pinging using MSC2659. +* *(bridge)* Started reusing aiosqlite connection pool for crypto db. + * This fixes the crypto pool getting stuck if the bridge exits unexpectedly + (the default pool is closed automatically at any type of exit). + +## v0.19.13 (2023-04-24) + +* *(crypto)* Fixed bug with redacting megolm sessions when device is deleted. + +## v0.19.12 (2023-04-18) + +* *(bridge)* Fixed backwards-compatibility with new key deletion config options. + +## v0.19.11 (2023-04-14) + +* *(crypto)* Fixed bug in previous release which caused errors if the `max_age` + of a megolm session was not known. +* *(crypto)* Changed key receiving handler to fetch encryption config from + server if it's not cached locally (to find `max_age` and `max_messages` more + reliably). + +## v0.19.10 (2023-04-13) + +* *(crypto, bridge)* Added options to automatically ratchet/delete megolm + sessions to minimize access to old messages. + +## v0.19.9 (2023-04-12) + +* *(crypto)* Fixed bug in crypto store migration when using outbound sessions + with max age higher than usual. + +## v0.19.8 (2023-04-06) + +* *(crypto)* Updated crypto store schema to match mautrix-go. +* *(types)* Fixed `set_thread_parent` adding reply fallbacks to the message body. + +## v0.19.7 (2023-03-22) + +* *(bridge, crypto)* Fixed key sharing trust checker not resolving cross-signing + signatures when minimum trust level is set to cross-signed. + +## v0.19.6 (2023-03-13) + +* *(crypto)* Added cache checks to prevent invalidating group session when the + server sends a duplicate member event in /sync. +* *(util.proxy)* Fixed `min_wait_seconds` behavior and added `max_wait_seconds` + and `multiply_wait_seconds` to `proxy_with_retry`. + +## v0.19.5 (2023-03-07) + +* *(util.proxy)* Added utility for dynamic proxies (from mautrix-instagram/facebook). 
+* *(types)* Added default value for `upload_size` in `MediaRepoConfig` as the + field is optional in the spec. +* *(bridge)* Changed ghost invite handling to only process one per room at a time + (thanks to [@maltee1] in [#132]). + +[#132]: https://github.com/mautrix/python/pull/132 + +## v0.19.4 (2023-02-12) + +* *(types)* Changed `set_thread_parent` to inherit the existing thread parent + if a `MessageEvent` is passed, as starting threads from a message in a thread + is not allowed. +* *(util.background_task)* Added new utility for creating background tasks + safely, by ensuring that the task is not garbage collected before finishing + and logging uncaught exceptions immediately. + +## v0.19.3 (2023-01-27) + +* *(bridge)* Bumped default timeouts for decrypting incoming messages. + +## v0.19.2 (2023-01-14) + +* *(util.async_body)* Added utility for reading aiohttp response into a bytearray + (so that the output is mutable, e.g. for decrypting or encrypting media). +* *(client.api)* Fixed retry loop for MSC3870 URL uploads not exiting properly + after too many errors. + +## v0.19.1 (2023-01-11) + +* Marked Python 3.11 as supported. Python 3.8 support will likely be dropped in + the coming months. +* *(client.api)* Added request payload memory optimization to MSC3870 URL uploads. + * aiohttp will duplicate the entire request body if it's raw bytes, which + wastes a lot of memory. The optimization is passing an iterator instead of + raw bytes, so aiohttp won't accidentally duplicate the whole thing. + * The main `HTTPAPI` has had the optimization for a while, but uploading to + URL calls aiohttp manually. + +## v0.19.0 (2023-01-10) + +* **Breaking change *(appservice)*** Removed typing status from state store. +* **Breaking change *(appservice)*** Removed `is_typing` parameter from + `IntentAPI.set_typing` to make the signature match `ClientAPI.set_typing`. + `timeout=0` is equivalent to the old `is_typing=False`. 
+* **Breaking change *(types)*** Removed legacy fields in Beeper MSS events. +* *(bridge)* Removed accidentally nested reply loop when accepting invites as + the bridge bot. +* *(bridge)* Fixed decoding JSON values in config override env vars. + +## v0.18.9 (2022-12-14) + +* *(util.async_db)* Changed aiosqlite connector to force-enable foreign keys, + WAL mode and busy_timeout. + * The values can be changed by manually specifying the same PRAGMAs in the + `init_commands` db arg, e.g. `- PRAGMA foreign_keys = OFF`. +* *(types)* Added workaround to `StateEvent.deserialize` to handle Conduit's + broken `unsigned` fields. +* *(client.state_store)* Fixed `set_power_level` to allow raw dicts the same + way as `set_encryption_info` does (thanks to [@bramenn] in [#127]). + +[@bramenn]: https://github.com/bramenn +[#127]: https://github.com/mautrix/python/pull/127 + +## v0.18.8 (2022-11-18) + +* *(crypto.store.asyncpg)* Fixed bug causing `put_group_session` to fail when + trying to log unique key errors. +* *(client)* Added wrapper for `create_room` to update the state store with + initial state and invites (applies to anything extending `StoreUpdatingAPI`, + such as the high-level `Client` and appservice `IntentAPI` classes). + +## v0.18.7 (2022-11-08) + +## v0.18.6 (2022-10-24) + +* *(util.formatter)* Added conversion method for `<hr>` tag and defaulted to + converting back to `---`. + +## v0.18.5 (2022-10-20) + +* *(appservice)* Added try blocks around [MSC3202] handler functions to log + errors instead of failing the entire transaction. This matches the behavior + of errors in normal appservice event handlers. + +## v0.18.4 (2022-10-13) + +* *(client.api)* Added option to pass custom data to `/createRoom` to enable + using custom fields and testing MSCs without changing the library. +* *(client.api)* Updated [MSC3870] support to send file name in upload complete + call. +* *(types)* Changed `set_edit` to clear reply metadata as edits can't change + the reply status. +* *(util.formatter)* Fixed edge case causing negative entity lengths when + splitting entity strings. + +## v0.18.3 (2022-10-11) + +* *(util.async_db)* Fixed mistake in default no-op database error handler + causing the wrong exception to be raised. +* *(crypto.store.asyncpg)* Updated `put_group_session` to catch unique key + errors and log instead of raising. +* *(client.api)* Updated [MSC3870] support to catch and retry on all + connection errors instead of only non-200 status codes when uploading. + +## v0.18.2 (2022-09-24) + +* *(crypto)* Fixed handling key requests when using appservice-mode (MSC2409) + encryption. +* *(appservice)* Added workaround for dumb servers that send `"unsigned": null` + in events. + +## v0.18.1 (2022-09-15) + +* *(crypto)* Fixed error sharing megolm session if a single recipient device + has run out of one-time keys. + +## v0.18.0 (2022-09-15) + +* **Breaking change *(util.async_db)*** Added checks to prevent calling + `.start()` on a database multiple times. +* *(appservice)* Fixed [MSC2409] support to read to-device events from the + correct field. +* *(appservice)* Added support for automatically calling functions when a + transaction contains [MSC2409] to-device events or [MSC3202] encryption data. +* *(bridge)* Added option to use [MSC2409] and [MSC3202] for end-to-bridge + encryption.
However, this may not work with the Synapse implementation as it + hasn't been tested yet. +* *(bridge)* Replaced `homeserver` -> `asmux` flag with more generic `software` + field. +* *(bridge)* Added support for overriding parts of config with environment + variables. + * If the value starts with `json::`, it'll be parsed as JSON instead of using + as a raw string. +* *(client.api)* Added support for [MSC3870] for both uploading and downloading + media. +* *(types)* Added `knock_restricted` join rule to `JoinRule` enum. +* *(crypto)* Added warning logs if claiming one-time keys for other users fails. + +[MSC3870]: https://github.com/matrix-org/matrix-spec-proposals/pull/3870 + +## v0.17.8 (2022-08-22) + +* *(crypto)* Fixed parsing `/keys/claim` responses with no `failures` field. +* *(bridge)* Fixed parsing e2ee key sharing allow/minimum level config. + +## v0.17.7 (2022-08-22) + +* *(util.async_db)* Added `init_commands` to run commands on each SQLite + connection (e.g. to enable `PRAGMA`s). No-op on Postgres. +* *(bridge)* Added check to make sure e2ee keys are intact on server. + If they aren't, the crypto database will be wiped and the bridge will stop. + +## v0.17.6 (2022-08-17) + +* *(bridge)* Added hidden option to use appservice login for double puppeting. +* *(client)* Fixed sync handling throwing an error if event parsing failed. +* *(errors)* Added `M_UNKNOWN_ENDPOINT` error code from [MSC3743] +* *(appservice)* Updated [MSC3202] support to handle one time keys correctly. + +[MSC3743]: https://github.com/matrix-org/matrix-spec-proposals/pull/3743 + +## v0.17.5 (2022-08-15) + +* *(types)* Added `m.read.private` to receipt types. +* *(appservice)* Stopped `ensure_registered` and `invite_user` raising + `IntentError`s (now they raise the original Matrix error instead). + +## v0.17.4 (2022-07-28) + +* *(bridge)* Started rejecting reusing access tokens when enabling double + puppeting. Reuse is detected by presence of encryption keys on the device. 
+* *(client.api)* Added wrapper method for the `/context` API. +* *(api, errors)* Implemented new error codes from [MSC3848]. +* *(types)* Disabled deserializing `m.direct` content (it didn't work and it + wasn't really necessary). +* *(client.state_store)* Updated `set_encryption_info` to allow raw dicts. + This fixes the bug where sending a `m.room.encryption` event with a raw dict + as the content would throw an error from the state store. +* *(crypto)* Fixed error when fetching keys for user with no cross-signing keys + (thanks to [@maltee1] in [#109]). + +[MSC3848]: https://github.com/matrix-org/matrix-spec-proposals/pull/3848 +[#109]: https://github.com/mautrix/python/pull/109 + +## v0.17.3 (2022-07-12) + +* *(types)* Updated `BeeperMessageStatusEventContent` fields. + +## v0.17.2 (2022-07-06) + +* *(api)* Updated request logging to log full URL instead of only path. +* *(bridge)* Fixed migrating key sharing allow flag to new config format. +* *(appservice)* Added `beeper_new_messages` flag for `batch_send` method. + +## v0.17.1 (2022-07-05) + +* *(crypto)* Fixed Python 3.8/9 compatibility broken in v0.17.0. +* *(crypto)* Added some tests for attachments and store code. +* *(crypto)* Improved logging when device change validation fails. + +## v0.17.0 (2022-07-05) + +* **Breaking change *(bridge)*** Added options to check cross-signing status + for bridge users. This requires changes to the base config. + * New options include requiring cross-signed devices (with TOFU) for sending + and/or receiving messages, and an option to drop any unencrypted messages. +* **Breaking change *(crypto)*** Removed `sender_key` parameter from + CryptoStore's `has_group_session` and `put_group_session`, and also + OlmMachine's `wait_for_session`. +* **Breaking change *(crypto.store.memory)*** Updated the key of the + `_inbound_sessions` dict to be (room_id, session_id), removing the identity + key in the middle. This only affects custom stores based on the memory store. 
+* *(crypto)* Added basic cross-signing validation code. +* *(crypto)* Marked device_id and sender_key as deprecated in Megolm events + as per Matrix 1.3. +* *(api)* Bumped request logs to `DEBUG` level. + * Also added new `sensitive` parameter to the `request` method to prevent + logging content in sensitive requests. The `login` method was updated to + mark the content as sensitive if a password or token is provided. +* *(bridge.commands)* Switched the order of the user ID parameter in `set-pl`, + `set-avatar` and `set-displayname`. + +## v0.16.11 (2022-06-28) + +* *(appservice)* Fixed the `extra_content` parameter in membership methods + causing duplicate join events through the `ensure_joined` mechanism. + +## v0.16.10 (2022-06-24) + +* *(bridge)* Started requiring Matrix v1.1 support from homeservers. +* *(bridge)* Added hack to automatically send a read receipt for messages sent + to Matrix with double puppeting (to work around weird unread count issues). + +## v0.16.9 (2022-06-22) + +* *(client)* Added support for knocking on rooms (thanks to [@maltee1] in [#105]). +* *(bridge)* Added config option to set key rotation settings with e2be. + +[#105]: https://github.com/mautrix/python/pull/105 + +## v0.16.8 (2022-06-20) + +* *(bridge)* Updated e2be helper to stop bridge if syncing fails. +* *(util.async_db)* Updated asyncpg connector to stop program if an asyncpg + `InternalClientError` is thrown. These errors usually cause everything to + get stuck. + * The behavior can be disabled by passing `meow_exit_on_ice` = `false` in + the `db_args`. + +## v0.16.7 (2022-06-19) + +* *(util.formatter)* Added support for parsing `img` tags + * By default, the `alt` or `title` attribute will be used as plaintext. +* *(types)* Added `notifications` object to power level content class. +* *(bridge)* Added utility methods for handling incoming knocks in + `MatrixHandler` (thanks to [@maltee1] in [#103]). 
+* *(appservice)* Updated `IntentAPI` to add the `fi.mau.double_puppet_source` + to all state events sent with double puppeted intents (previously it was only + added to non-state events). + +[#103]: https://github.com/mautrix/python/pull/103 + +## v0.16.6 (2022-06-02) + +* *(bridge)* Fixed double puppeting `start` method not handling some errors + from /whoami correctly. +* *(types)* Added `com.beeper.message_send_status` event type for bridging + status. + +## v0.16.5 (2022-05-26) + +* *(bridge.commands)* Added `reason` field for `CommandEvent.redact`. +* *(client.api)* Added `reason` field for the `unban_user` method + (thanks to [@maltee1] in [#101]). +* *(bridge)* Changed automatic DM portal creation to only apply when the invite + event specifies `"is_direct": true` (thanks to [@maltee1] in [#102]). +* *(util.program)* Changed `Program` to create and set an event loop + explicitly instead of using `get_event_loop`. +* *(util.program)* Added optional `exit_code` parameter to `manual_stop`. +* *(util.manhole)* Removed usage of loop parameters to fix Python 3.10 + compatibility. +* *(appservice.api)* Switched `IntentAPI.batch_send` method to use custom Event + classes instead of the default ones (since some normal event fields aren't + applicable when batch sending). + +[@maltee1]: https://github.com/maltee1 +[#101]: https://github.com/mautrix/python/pull/101 +[#102]: https://github.com/mautrix/python/pull/102 + +## v0.16.4 (2022-05-10) + +* *(types, bridge)* Dropped support for appservice login with unstable prefix. +* *(util.async_db)* Fixed some database start errors causing unnecessary noise + in logs. +* *(bridge.commands)* Added helper method to redact bridge commands. + +## v0.16.3 (2022-04-21) + +* *(types)* Changed `set_thread_parent` to have an explicit option for + disabling the thread-as-reply fallback. + +## v0.16.2 (2022-04-21) + +* *(types)* Added `get_thread_parent` and `set_thread_parent` helper methods + for `MessageEventContent`.
+* *(bridge)* Increased timeout for `MessageSendCheckpoint.send`. + +## v0.16.1 (2022-04-17) + +* **Breaking change** Removed `r0` path support. + * The new `v3` paths are implemented since Synapse 1.48, Dendrite 0.6.5, + and Conduit 0.4.0. Servers older than these are no longer supported. + +## v0.16.0 (2022-04-11) + +* **Breaking change *(types)*** Removed custom `REPLY` relation type and + changed `RelatesTo` structure to match the actual event content. + * Applications using `content.get_reply_to()` and `content.set_reply()` will + keep working with no changes. +* *(types)* Added `THREAD` relation type and `is_falling_back` field to + `RelatesTo`. + +## v0.15.8 (2022-04-08) + +* *(client.api)* Added experimental prometheus metric for file upload speed. +* *(util.async_db)* Improved type hints for `UpgradeTable.register` +* *(util.async_db)* Changed connection string log to redact database password. + +## v0.15.7 (2022-04-05) + +* *(api)* Added `file_name` parameter to `HTTPAPI.get_download_url`. + +## v0.15.6 (2022-03-30) + +* *(types)* Fixed removing nested (i.e. malformed) reply fallbacks generated by + some clients. +* *(types)* Added automatic reply fallback trimming to `set_reply()` to prevent + accidentally creating nested reply fallbacks. + +## v0.15.5 (2022-03-28) + +* *(crypto)* Changed default behavior of OlmMachine to ignore instead of reject + key requests from other users. +* Fixed some type hints + +## v0.15.3 & v0.15.4 (2022-03-25) + +* *(client.api)* Fixed incorrect HTTP methods in async media uploads. + +## v0.15.2 (2022-03-25) + +* *(client.api)* Added support for async media uploads ([MSC2246]). +* Moved `async_getter_lock` decorator to `mautrix.util` (from `mautrix.bridge`). + * The old import path will keep working. 
+ +[MSC2246]: https://github.com/matrix-org/matrix-spec-proposals/pull/2246 + +## v0.15.1 (2022-03-23) + +* *(types)* Added `ensure_has_html` method for `TextMessageEventContent` to + generate a HTML `formatted_body` from the plaintext `body` correctly (i.e. + escaping HTML and replacing newlines). + +## v0.15.0 (2022-03-16) + +* **Breaking change** Removed Python 3.7 support. +* **Breaking change *(api)*** Removed `r0` from default path builders in order + to update to `v3` and per-endpoint versioning. + * The client API modules have been updated to specify v3 in the paths, other + direct usage of `Path`, `ClientPath` and `MediaPath` will have to be + updated manually. `UnstableClientPath` no longer exists and should be + replaced with `Path.unstable`. + * There's a temporary hacky backwards-compatibility layer which replaces /v3 + with /r0 if the server doesn't advertise support for Matrix v1.1 or higher. + It can be activated by calling the `.versions()` method in `ClientAPI`. + The bridge module calls that method automatically. +* **Breaking change *(util.formatter)*** Removed lxml-based HTML parser. + * The parsed data format is still compatible with lxml, so it is possible to + use lxml with `MatrixParser` by setting `lxml.html.fromstring` as the + `read_html` method. +* **Breaking change *(crypto)*** Moved `TrustState`, `DeviceIdentity`, + `OlmEventKeys` and `DecryptedOlmEvent` dataclasses from `crypto.types` + into `types.crypto`. +* **Breaking change *(bridge)*** Made `User.get_puppet` abstract and added new + abstract `User.get_portal_with` and `Portal.get_dm_puppet` methods. +* Added a redundant `__all__` to various `__init__.py` files to appease pyright. +* *(api)* Reduced aiohttp memory usage when uploading large files by making + an in-memory async iterable instead of passing the bytes directly. * *(bridge)* Removed legacy community utilities. +* *(bridge)* Added support for creating DM portals with minimal bridge-specific code. 
* *(util.async_db)* Fixed counting number of db upgrades. * *(util.async_db)* Added support for schema migrations that jump versions. +* *(util.async_db)* Added system for preventing using the same database for + multiple programs. + * To enable it, provide a unique program name as the `owner_name` parameter + in `Database.create`. + * Additionally, if `ignore_foreign_tables` is set to `True`, it will check + for tables of some known software like Synapse and Dendrite. + * The `bridge` module enables both options by default. +* *(util.db)* Module deprecated. The async_db module is recommended. However, + the SQLAlchemy helpers will remain until maubot has switched to asyncpg. +* *(util.magic)* Allowed `bytearray` as an input type for the `mimetype` method. +* *(crypto.attachments)* Added method to encrypt a `bytearray` in-place to + avoid unnecessarily duplicating data in memory. ## v0.14.10 (2022-02-01) @@ -165,12 +735,14 @@ ## v0.12.5 (2021-11-30) -* Added wrapper for [MSC2716]'s `/batch_send` endpoint in `IntentAPI` -* Added some Matrix request metrics (thanks to @jaller94 in #68) +* Added wrapper for [MSC2716]'s `/batch_send` endpoint in `IntentAPI`. +* Added some Matrix request metrics (thanks to [@jaller94] in [#68]). * Added utility method for adding variation selector 16 to emoji strings the same way as Element does (using emojibase data). -[MSC2716]: https://github.com/matrix-org/matrix-doc/pull/2716 +[MSC2716]: https://github.com/matrix-org/matrix-spec-proposals/pull/2716 +[@jaller94]: https://github.com/jaller94 +[#68]: https://github.com/mautrix/python/pull/68 ## v0.12.4 (2021-11-25) @@ -340,8 +912,8 @@ * Fixed receiving appservice transactions with `Authorization` header (i.e. fixed [MSC2832] support).
-[MSC3202]: https://github.com/matrix-org/matrix-doc/pull/3202 -[MSC2832]: https://github.com/matrix-org/matrix-doc/pull/2832 +[MSC3202]: https://github.com/matrix-org/matrix-spec-proposals/pull/3202 +[MSC2832]: https://github.com/matrix-org/matrix-spec-proposals/pull/2832 [@sumnerevans]: https://github.com/sumnerevans [#49]: https://github.com/mautrix/python/pull/49 @@ -574,8 +1146,8 @@ `EventType.Class.UNKNOWN` as the type class. * Fixed regex escaping in bridge registration generation. -[MSC2778]: https://github.com/matrix-org/matrix-doc/pull/2778 -[MSC2409]: https://github.com/matrix-org/matrix-doc/pull/2409 +[MSC2778]: https://github.com/matrix-org/matrix-spec-proposals/pull/2778 +[MSC2409]: https://github.com/matrix-org/matrix-spec-proposals/pull/2409 [@ShadowJonathan]: https://github.com/ShadowJonathan [@witchent]: https://github.com/witchent [#26]: https://github.com/mautrix/python/pull/26 diff --git a/README.rst b/README.rst index 75ce8f6a..f1342c19 100644 --- a/README.rst +++ b/README.rst @@ -3,7 +3,7 @@ mautrix-python |PyPI| |Python versions| |License| |Docs| |Code style| |Imports| -A Python 3.8+ asyncio Matrix framework. +A Python 3.10+ asyncio Matrix framework. Matrix room: `#maunium:maunium.net`_ @@ -49,7 +49,7 @@ Components .. _#maunium:maunium.net: https://matrix.to/#/#maunium:maunium.net .. _python-appservice-framework: https://github.com/Cadair/python-appservice-framework/ -.. _Client API: https://matrix.org/docs/spec/client_server/r0.6.1.html +.. _Client API: https://spec.matrix.org/latest/client-server-api/ .. _mautrix.api: https://docs.mau.fi/python/latest/api/mautrix.api.html .. 
_mautrix.client.api: https://docs.mau.fi/python/latest/api/mautrix.client.api.html diff --git a/dev-requirements.txt b/dev-requirements.txt index 232f7249..dff6ee9d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ pre-commit>=2.10.1,<3 -isort>=5.10.1,<6 -black==22.1.0 +isort>=8,<9 +black>=26,<27 diff --git a/docs/api/mautrix.client.state_store/index.rst b/docs/api/mautrix.client.state_store/index.rst index 91f2ba42..0a934308 100644 --- a/docs/api/mautrix.client.state_store/index.rst +++ b/docs/api/mautrix.client.state_store/index.rst @@ -13,5 +13,4 @@ Implementations In-memory Async database (asyncpg/aiosqlite) - Legacy database (SQLAlchemy) Flat file diff --git a/docs/api/mautrix.client.state_store/sqlalchemy.rst b/docs/api/mautrix.client.state_store/sqlalchemy.rst deleted file mode 100644 index 767e0595..00000000 --- a/docs/api/mautrix.client.state_store/sqlalchemy.rst +++ /dev/null @@ -1,5 +0,0 @@ -mautrix.client.state\_store.sqlalchemy -====================================== - -.. autoclass:: mautrix.client.state_store.sqlalchemy.SQLStateStore - :no-undoc-members: diff --git a/docs/api/mautrix.util/db.rst b/docs/api/mautrix.util/db.rst index 4607ad48..4aec7409 100644 --- a/docs/api/mautrix.util/db.rst +++ b/docs/api/mautrix.util/db.rst @@ -3,3 +3,6 @@ db .. automodule:: mautrix.util.db :imported-members: + + .. deprecated:: 0.15.0 + The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy. 
diff --git a/docs/requirements.txt b/docs/requirements.txt index 798e0f1b..7269d3a5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,7 +9,7 @@ yarl # that aren't used for anything that's in the docs python-magic ruamel.yaml -SQLAlchemy +SQLAlchemy<2 commonmark asyncpg aiosqlite @@ -17,3 +17,4 @@ prometheus_client python-olm unpaddedbase64 pycryptodome +base58 diff --git a/mautrix/__init__.py b/mautrix/__init__.py index f5e75340..c94995f9 100644 --- a/mautrix/__init__.py +++ b/mautrix/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.14.10" +__version__ = "0.21.0" __author__ = "Tulir Asokan " __all__ = [ "api", diff --git a/mautrix/api.py b/mautrix/api.py index 56023c20..1adde9ec 100644 --- a/mautrix/api.py +++ b/mautrix/api.py @@ -1,38 +1,34 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import ClassVar, Mapping +from typing import ClassVar, Literal, Mapping from enum import Enum from json.decoder import JSONDecodeError -from time import time from urllib.parse import quote as urllib_quote, urljoin as urllib_join import asyncio +import inspect import json import logging import platform -import sys +import time -from aiohttp import ClientSession, __version__ as aiohttp_version +from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version from aiohttp.client_exceptions import ClientError, ContentTypeError from yarl import URL from mautrix import __optional_imports__, __version__ as mautrix_version from mautrix.errors import MatrixConnectionError, MatrixRequestError, make_request_error +from mautrix.util.async_body import AsyncBody, async_iter_bytes from mautrix.util.logging import TraceLogger from mautrix.util.opt_prometheus import Counter -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - if __optional_imports__: # Safe to import, but it's not actually needed, so don't force-import the whole types module. - from mautrix.types import JSON + from mautrix.types import JSON, DeviceID, UserID API_CALLS = Counter( name="bridge_matrix_api_calls", @@ -52,9 +48,8 @@ class APIPath(Enum): These don't start with a slash so they can be used nicely with yarl. 
""" - CLIENT = "_matrix/client/r0" - CLIENT_UNSTABLE = "_matrix/client/unstable" - MEDIA = "_matrix/media/r0" + CLIENT = "_matrix/client" + MEDIA = "_matrix/media" SYNAPSE_ADMIN = "_synapse/admin" def __repr__(self): @@ -88,8 +83,8 @@ class PathBuilder: >>> from mautrix.api import Path >>> room_id = "!foo:example.com" >>> event_id = "$bar:example.com" - >>> str(Path.rooms[room_id].event[event_id]) - "_matrix/client/r0/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com" + >>> str(Path.v3.rooms[room_id].event[event_id]) + "_matrix/client/v3/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com" """ def __init__(self, path: str | APIPath = "") -> None: @@ -129,36 +124,34 @@ def __getitem__(self, append: str | int) -> PathBuilder: return self return PathBuilder(f"{self.path}/{self._quote(str(append))}") + def replace(self, find: str, replace: str) -> PathBuilder: + return PathBuilder(self.path.replace(find, replace)) + ClientPath = PathBuilder(APIPath.CLIENT) ClientPath.__doc__ = """ -A path builder with the standard client r0 prefix ( ``/_matrix/client/r0``, :attr:`APIPath.CLIENT`) +A path builder with the standard client prefix ( ``/_matrix/client``, :attr:`APIPath.CLIENT`). 
""" Path = PathBuilder(APIPath.CLIENT) Path.__doc__ = """A shorter alias for :attr:`ClientPath`""" -UnstableClientPath = PathBuilder(APIPath.CLIENT_UNSTABLE) -UnstableClientPath.__doc__ = """ -A path builder for client endpoints that haven't reached the spec yet -(``/_matrix/client/unstable``, :attr:`APIPath.CLIENT_UNSTABLE`) -""" MediaPath = PathBuilder(APIPath.MEDIA) MediaPath.__doc__ = """ -A path builder for standard media r0 paths (``/_matrix/media/r0``, :attr:`APIPath.MEDIA`) +A path builder with the standard media prefix (``/_matrix/media``, :attr:`APIPath.MEDIA`) Examples: >>> from mautrix.api import MediaPath - >>> str(MediaPath.config) - "_matrix/media/r0/config" + >>> str(MediaPath.v3.config) + "_matrix/media/v3/config" """ SynapseAdminPath = PathBuilder(APIPath.SYNAPSE_ADMIN) SynapseAdminPath.__doc__ = """ A path builder for synapse-specific admin API paths -(``/_synapse/admin/v1``, :attr:`APIPath.SYNAPSE_ADMIN`) +(``/_synapse/admin``, :attr:`APIPath.SYNAPSE_ADMIN`) Examples: >>> from mautrix.api import SynapseAdminPath >>> user_id = "@user:example.com" - >>> str(SynapseAdminPath.users[user_id]/login) + >>> str(SynapseAdminPath.v1.users[user_id]/login) "_synapse/admin/v1/users/%40user%3Aexample.com/login" """ @@ -200,6 +193,13 @@ class HTTPAPI: default_retry_count: int """The default retry count to use if a custom value is not passed to :meth:`request`""" + as_user_id: UserID | None + """An optional user ID to set as the user_id query parameter for appservice requests.""" + as_device_id: DeviceID | None + """ + An optional device ID to set as the user_id query parameter for appservice requests (MSC3202). 
+ """ + def __init__( self, base_url: URL | str, @@ -210,6 +210,8 @@ def __init__( txn_id: int = 0, log: TraceLogger | None = None, loop: asyncio.AbstractEventLoop | None = None, + as_user_id: UserID | None = None, + as_device_id: UserID | None = None, ) -> None: """ Args: @@ -219,6 +221,10 @@ def __init__( txn_id: The outgoing transaction ID to start with. log: The :class:`logging.Logger` instance to log requests with. default_retry_count: Default number of retries to do when encountering network errors. + as_user_id: An optional user ID to set as the user_id query parameter for + appservice requests. + as_device_id: An optional device ID to set as the user_id query parameter for + appservice requests (MSC3202). """ self.base_url = URL(base_url) self.token = token @@ -226,6 +232,8 @@ def __init__( self.session = client_session or ClientSession( loop=loop, headers={"User-Agent": self.default_ua} ) + self.as_user_id = as_user_id + self.as_device_id = as_device_id if txn_id is not None: self.txn_id = txn_id if default_retry_count is not None: @@ -237,20 +245,21 @@ async def _send( self, method: Method, url: URL, - content: bytes | str, + content: bytes | bytearray | str | AsyncBody, query_params: dict[str, str], headers: dict[str, str], - ) -> JSON: + ) -> tuple[JSON, ClientResponse]: request = self.session.request( str(method), url, data=content, params=query_params, headers=headers ) async with request as response: if response.status < 200 or response.status >= 300: - errcode = message = None + errcode = unstable_errcode = message = None try: response_data = await response.json() errcode = response_data["errcode"] message = response_data["error"] + unstable_errcode = response_data.get("org.matrix.msc3848.unstable.errcode") except (JSONDecodeError, ContentTypeError, KeyError): pass raise make_request_error( @@ -258,39 +267,64 @@ async def _send( text=await response.text(), errcode=errcode, message=message, + unstable_errcode=unstable_errcode, ) - return await 
response.json() + return await response.json(), response def _log_request( self, method: Method, - path: PathBuilder, - content: str | bytes, + url: URL, + content: str | bytes | bytearray | AsyncBody | None, orig_content, query_params: dict[str, str], + headers: dict[str, str], req_id: int, + sensitive: bool, ) -> None: if not self.log: return - log_content = content if not isinstance(content, bytes) else f"<{len(content)} bytes>" + if isinstance(content, (bytes, bytearray)): + log_content = f"<{len(content)} bytes>" + elif inspect.isasyncgen(content): + size = headers.get("Content-Length", None) + log_content = f"<{size} async bytes>" if size else f"" + elif sensitive: + log_content = f"<{len(content)} sensitive bytes>" + else: + log_content = content as_user = query_params.get("user_id", None) - level = 1 if path == Path.sync else 5 + level = 5 if url.path.endswith("/v3/sync") else 10 self.log.log( level, - f"{method}#{req_id} /{path} {log_content}".strip(" "), + f"req #{req_id}: {method} {url} {log_content}".strip(" "), extra={ "matrix_http_request": { "req_id": req_id, "method": str(method), - "path": str(path), + "url": str(url), "content": ( - orig_content if isinstance(orig_content, (dict, list)) else log_content + orig_content + if isinstance(orig_content, (dict, list)) and not sensitive + else log_content ), "user": as_user, } }, ) + def _log_request_done( + self, path: PathBuilder | str, req_id: int, duration: float, status: int + ) -> None: + level = 5 if path == Path.v3.sync else 10 + duration_str = f"{duration * 1000:.1f}ms" if duration < 1 else f"{duration:.3f}s" + path_without_prefix = f"/{path}".replace("/_matrix/client", "") + self.log.log( + level, + f"req #{req_id} ({path_without_prefix}) completed in {duration_str} " + f"with status {status}", + ) + def _full_path(self, path: PathBuilder | str) -> str: path = str(path) if path and path[0] == "/": @@ -300,15 +334,27 @@ def _full_path(self, path: PathBuilder | str) -> str: base_path += "/" return 
urllib_join(base_path, path) + def log_download_request(self, url: URL, query_params: dict[str, str]) -> int: + req_id = _next_global_req_id() + self._log_request(Method.GET, url, None, None, query_params, {}, req_id, False) + return req_id + + def log_download_request_done( + self, url: URL, req_id: int, duration: float, status: int + ) -> None: + self._log_request_done(url.path.removeprefix("/_matrix/media/"), req_id, duration, status) + async def request( self, method: Method, path: PathBuilder | str, - content: dict | list | bytes | str | None = None, + content: dict | list | bytes | bytearray | str | AsyncBody | None = None, headers: dict[str, str] | None = None, query_params: Mapping[str, str] | None = None, retry_count: int | None = None, metrics_method: str = "", + min_iter_size: int = 25 * 1024 * 1024, + sensitive: bool = False, ) -> JSON: """ Make a raw Matrix API request. @@ -325,6 +371,10 @@ async def request( retry_count: Number of times to retry if the homeserver isn't reachable. Defaults to :attr:`default_retry_count`. metrics_method: Name of the method to include in Prometheus timing metrics. + min_iter_size: If the request body is larger than this value, it will be passed to + aiohttp as an async iterable to stop it from copying the whole thing + in memory. + sensitive: If True, the request content will not be logged. Returns: The parsed response JSON. 
@@ -335,6 +385,11 @@ async def request( query_params = query_params or {} if isinstance(query_params, dict): query_params = {k: v for k, v in query_params.items() if v is not None} + if self.as_user_id: + query_params["user_id"] = self.as_user_id + if self.as_device_id: + query_params["org.matrix.msc3202.device_id"] = self.as_device_id + query_params["device_id"] = self.as_device_id if method != Method.GET: content = content or {} @@ -351,12 +406,27 @@ async def request( if retry_count is None: retry_count = self.default_retry_count + if inspect.isasyncgen(content): + # Can't retry with non-static body + retry_count = 0 + do_fake_iter = content and hasattr(content, "__len__") and len(content) > min_iter_size + if do_fake_iter: + headers["Content-Length"] = str(len(content)) backoff = 4 + log_url = full_url.with_query(query_params) while True: - self._log_request(method, path, content, orig_content, query_params, req_id) + self._log_request( + method, log_url, content, orig_content, query_params, headers, req_id, sensitive + ) API_CALLS.labels(method=metrics_method).inc() + req_content = async_iter_bytes(content) if do_fake_iter else content + start = time.monotonic() try: - return await self._send(method, full_url, content, query_params, headers or {}) + resp_data, resp = await self._send( + method, full_url, req_content, query_params, headers or {} + ) + self._log_request_done(path, req_id, time.monotonic() - start, resp.status) + return resp_data except MatrixRequestError as e: API_CALLS_FAILED.labels(method=metrics_method).inc() if retry_count > 0 and e.http_status in (502, 503, 504): @@ -365,6 +435,7 @@ async def request( f"retrying in {backoff} seconds" ) else: + self._log_request_done(path, req_id, time.monotonic() - start, e.http_status) raise except ClientError as e: API_CALLS_FAILED.labels(method=metrics_method).inc() @@ -384,10 +455,14 @@ async def request( def get_txn_id(self) -> str: """Get a new unique transaction ID.""" self.txn_id += 1 - return 
f"mautrix-python_R{self.txn_id}@T{int(time() * 1000)}" + return f"mautrix-python_{time.time_ns()}_{self.txn_id}" def get_download_url( - self, mxc_uri: str, download_type: Literal["download", "thumbnail"] = "download" + self, + mxc_uri: str, + download_type: Literal["download", "thumbnail"] = "download", + file_name: str | None = None, + authenticated: bool = False, ) -> URL: """ Get the full HTTP URL to download a ``mxc://`` URI. @@ -395,6 +470,8 @@ def get_download_url( Args: mxc_uri: The MXC URI whose full URL to get. download_type: The type of download ("download" or "thumbnail"). + file_name: Optionally, a file name to include in the download URL. + authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11. Returns: The full HTTP URL. @@ -403,11 +480,38 @@ def get_download_url( ValueError: If `mxc_uri` doesn't begin with ``mxc://``. Examples: - >>> api = HTTPAPI(...) + >>> api = HTTPAPI(base_url="https://matrix-client.matrix.org", ...) >>> api.get_download_url("mxc://matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6") - "https://matrix.org/_matrix/media/r0/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6" + "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6" + >>> api.get_download_url("mxc://matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6", file_name="hello.png") + "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png" + """ + server_name, media_id = self.parse_mxc_uri(mxc_uri) + if authenticated: + url = self.base_url / str(APIPath.CLIENT) / "v1" / "media" + else: + url = self.base_url / str(APIPath.MEDIA) / "v3" + url = url / download_type / server_name / media_id + if file_name: + url /= file_name + return url + + @staticmethod + def parse_mxc_uri(mxc_uri: str) -> tuple[str, str]: + """ + Parse a ``mxc://`` URI. + + Args: + mxc_uri: The MXC URI to parse. + + Returns: + A tuple containing the server and media ID of the MXC URI. 
+ + Raises: + ValueError: If `mxc_uri` doesn't begin with ``mxc://``. """ if mxc_uri.startswith("mxc://"): - return self.base_url / str(APIPath.MEDIA) / download_type / mxc_uri[6:] + server_name, media_id = mxc_uri[6:].split("/") + return server_name, media_id else: raise ValueError("MXC URI did not begin with `mxc://`") diff --git a/mautrix/appservice/__init__.py b/mautrix/appservice/__init__.py index a6028208..c4b2ae74 100644 --- a/mautrix/appservice/__init__.py +++ b/mautrix/appservice/__init__.py @@ -11,4 +11,5 @@ "ASStateStore", "AppServiceServerMixin", "DOUBLE_PUPPET_SOURCE_KEY", + "state_store", ] diff --git a/mautrix/appservice/api/appservice.py b/mautrix/appservice/api/appservice.py index afbcc689..6654ec20 100644 --- a/mautrix/appservice/api/appservice.py +++ b/mautrix/appservice/api/appservice.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -51,6 +51,7 @@ def __init__( client_session: ClientSession = None, child: bool = False, real_user: bool = False, + real_user_as_token: bool = False, bridge_name: str | None = None, default_retry_count: int = None, loop: asyncio.AbstractEventLoop | None = None, @@ -66,6 +67,7 @@ def __init__( client_session: The aiohttp ClientSession to use. child: Whether or not this is instance is a child of another AppServiceAPI. real_user: Whether or not this is a real (non-appservice-managed) user. + real_user_as_token: Whether this real user is actually using another ``as_token``. bridge_name: The name of the bridge to put in the ``fi.mau.double_puppet_source`` field in outgoing message events sent through real users. 
""" @@ -85,6 +87,7 @@ def __init__( self._bot_intent = None self.state_store = state_store self.is_real_user = real_user + self.is_real_user_as_token = real_user_as_token self.bridge_name = bridge_name if not child: @@ -113,7 +116,9 @@ def user(self, user: UserID) -> ChildAppServiceAPI: self.children[user] = child return child - def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> AppServiceAPI: + def real_user( + self, mxid: UserID, token: str, base_url: URL | None = None, as_token: bool = False + ) -> AppServiceAPI: """ Get the AppServiceAPI for a real (non-appservice-managed) Matrix user. @@ -122,6 +127,8 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap token: The access token for the user. base_url: The base URL of the homeserver client-server API to use. Defaults to the appservice homeserver URL. + as_token: Whether the token is actually an as_token + (meaning the ``user_id`` query parameter needs to be used). Returns: The AppServiceAPI object for the user. 
@@ -136,6 +143,7 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap child = self.real_users[mxid] child.base_url = base_url or child.base_url child.token = token or child.token + child.is_real_user_as_token = as_token except KeyError: child = type(self)( base_url=base_url or self.base_url, @@ -145,6 +153,7 @@ def real_user(self, mxid: UserID, token: str, base_url: URL | None = None) -> Ap state_store=self.state_store, client_session=self.session, real_user=True, + real_user_as_token=as_token, bridge_name=self.bridge_name, default_retry_count=self.default_retry_count, ) @@ -163,7 +172,11 @@ def bot_intent(self) -> as_api.IntentAPI: return self._bot_intent def intent( - self, user: UserID = None, token: str | None = None, base_url: str | None = None + self, + user: UserID = None, + token: str | None = None, + base_url: str | None = None, + real_user_as_token: bool = False, ) -> as_api.IntentAPI: """ Get the intent API of a child user. @@ -173,6 +186,8 @@ def intent( token: The access token to use. Only applicable for non-appservice-managed users. base_url: The base URL of the homeserver client-server API to use. Only applicable for non-appservice users. Defaults to the appservice homeserver URL. + real_user_as_token: When providing a token, whether it's actually another as_token + (meaning the ``user_id`` query parameter needs to be used). Returns: The IntentAPI object for the given user. 
@@ -184,7 +199,10 @@ def intent( raise ValueError("Can't get child intent of real user") if token: return as_api.IntentAPI( - user, self.real_user(user, token, base_url), self.bot_intent(), self.state_store + user, + self.real_user(user, token, base_url, as_token=real_user_as_token), + self.bot_intent(), + self.state_store, ) return as_api.IntentAPI(user, self.user(user), self.bot_intent(), self.state_store) @@ -198,31 +216,38 @@ def request( query_params: dict[str, Any] | None = None, retry_count: int | None = None, metrics_method: str | None = "", + min_iter_size: int = 25 * 1024 * 1024, ) -> Awaitable[dict]: """ - Make a raw HTTP request, with optional AppService timestamp massaging and external_url - setting. + Make a raw Matrix API request, acting as the appservice user assigned to this AppServiceAPI + instance and optionally including timestamp massaging. Args: method: The HTTP method to use. - path: The API endpoint to call. - Does not include the base path (e.g. /_matrix/client/r0). - content: The content to post as a dict (json) or bytes/str (raw). + path: The full API endpoint to call (including the _matrix/... prefix) + content: The content to post as a dict/list (will be serialized as JSON) + or bytes/str (will be sent as-is). timestamp: The timestamp query param used for timestamp massaging. - headers: The dict of HTTP headers to send. - query_params: The dict of query parameters to send. + headers: A dict of HTTP headers to send. If the headers don't contain ``Content-Type``, + it'll be set to ``application/json``. The ``Authorization`` header is always + overridden if :attr:`token` is set. + query_params: A dict of query parameters to send. retry_count: Number of times to retry if the homeserver isn't reachable. + Defaults to :attr:`default_retry_count`. metrics_method: Name of the method to include in Prometheus timing metrics. 
+ min_iter_size: If the request body is larger than this value, it will be passed to + aiohttp as an async iterable to stop it from copying the whole thing + in memory. Returns: - The response as a dict. + The parsed response JSON. """ query_params = query_params or {} if timestamp is not None: if isinstance(timestamp, datetime): timestamp = int(timestamp.replace(tzinfo=timezone.utc).timestamp() * 1000) query_params["ts"] = timestamp - if not self.is_real_user: + if not self.is_real_user or self.is_real_user_as_token: query_params["user_id"] = self.identity or self.bot_mxid return super().request( diff --git a/mautrix/appservice/api/intent.py b/mautrix/appservice/api/intent.py index 29d41cd4..626b34f1 100644 --- a/mautrix/appservice/api/intent.py +++ b/mautrix/appservice/api/intent.py @@ -1,18 +1,18 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import Any, Awaitable, Iterable +from typing import Any, Awaitable, Iterable, TypeVar from urllib.parse import quote as urllib_quote -from mautrix import __optional_imports__ -from mautrix.api import Method, Path, UnstableClientPath +from mautrix.api import Method, Path from mautrix.client import ClientAPI, StoreUpdatingAPI from mautrix.errors import ( IntentError, + MAlreadyJoined, MatrixRequestError, MBadState, MForbidden, @@ -22,7 +22,10 @@ from mautrix.types import ( JSON, BatchID, + BatchSendEvent, BatchSendResponse, + BatchSendStateEvent, + BeeperBatchSendResponse, ContentURI, EventContent, EventID, @@ -31,7 +34,6 @@ JoinRulesStateEventContent, Member, Membership, - MessageEvent, PowerLevelStateEventContent, PresenceState, RoomAvatarStateEventContent, @@ -39,7 +41,6 @@ RoomNameStateEventContent, RoomPinnedEventsStateEventContent, RoomTopicStateEventContent, - StateEvent, StateEventContent, UserID, ) @@ -70,6 +71,8 @@ def quote(*args, **kwargs): ClientAPI.search_users, ClientAPI.set_displayname, ClientAPI.set_avatar_url, + ClientAPI.beeper_update_profile, + ClientAPI.create_mxc, ClientAPI.upload_media, ClientAPI.send_receipt, ClientAPI.set_fully_read_marker, @@ -91,6 +94,8 @@ def quote(*args, **kwargs): DOUBLE_PUPPET_SOURCE_KEY = "fi.mau.double_puppet_source" +T = TypeVar("T") + class IntentAPI(StoreUpdatingAPI): """ @@ -113,6 +118,8 @@ def __init__( ) -> None: super().__init__(mxid=mxid, api=api, state_store=state_store) self.bot = bot + if bot is not None: + self.versions_cache = bot.versions_cache self.log = api.base_log.getChild("intent") for method in ENSURE_REGISTERED_METHODS: @@ -131,13 +138,19 @@ async def wrapper(*args, __self=self, __method=method, **kwargs): room_id = kwargs.get("room_id", None) if not room_id: room_id = args[0] - await __self.ensure_joined(room_id) + ensure_joined = kwargs.pop("ensure_joined", True) + if ensure_joined: + await __self.ensure_joined(room_id) return await __method(*args, 
**kwargs) setattr(self, method.__name__, wrapper) def user( - self, user_id: UserID, token: str | None = None, base_url: str | None = None + self, + user_id: UserID, + token: str | None = None, + base_url: str | None = None, + as_token: bool = False, ) -> IntentAPI: """ Get the intent API for a specific user. @@ -150,15 +163,17 @@ def user( user_id: The Matrix ID of the user whose intent API to get. token: The access token to use for the Matrix ID. base_url: An optional URL to use for API requests. + as_token: Whether the provided token is actually another as_token + (meaning the ``user_id`` query parameter needs to be used). Returns: The IntentAPI for the given user. """ if not self.bot: - return self.api.intent(user_id, token, base_url) + return self.api.intent(user_id, token, base_url, real_user_as_token=as_token) else: self.log.warning("Called IntentAPI#user() of child intent object.") - return self.bot.api.intent(user_id, token, base_url) + return self.bot.api.intent(user_id, token, base_url, real_user_as_token=as_token) # region User actions @@ -176,7 +191,7 @@ async def set_presence( Args: presence: The online status of the user. status: The status message. - ignore_cache: Whether or not to set presence even if the cache says the presence is + ignore_cache: Whether to set presence even if the cache says the presence is already set to that value. 
""" await self.ensure_registered() @@ -188,6 +203,13 @@ async def set_presence( # endregion # region Room actions + def _add_source_key(self, content: T = None) -> T: + if self.api.is_real_user and self.api.bridge_name: + if not content: + content = {} + content[DOUBLE_PUPPET_SOURCE_KEY] = self.api.bridge_name + return content + async def invite_user( self, room_id: RoomID, @@ -226,37 +248,95 @@ async def invite_user( await self.state_store.get_membership(room_id, user_id) not in ok_states ) if do_invite: + extra_content = self._add_source_key(extra_content) await super().invite_user( room_id, user_id, reason=reason, extra_content=extra_content ) await self.state_store.invited(room_id, user_id) + except MAlreadyJoined as e: + await self.state_store.joined(room_id, user_id) except MatrixRequestError as e: - if e.errcode == "M_FORBIDDEN" and "is already in the room" in e.message: + # TODO remove this once MSC3848 is released and minimum spec version is bumped + if e.errcode == "M_FORBIDDEN" and ( + "already in the room" in e.message or "is already joined to room" in e.message + ): await self.state_store.joined(room_id, user_id) else: - raise IntentError(f"Failed to invite {user_id} to {room_id}", e) + raise + + async def kick_user( + self, + room_id: RoomID, + user_id: UserID, + reason: str = "", + extra_content: dict[str, JSON] | None = None, + ) -> None: + extra_content = self._add_source_key(extra_content) + await super().kick_user(room_id, user_id, reason=reason, extra_content=extra_content) + + async def ban_user( + self, + room_id: RoomID, + user_id: UserID, + reason: str = "", + extra_content: dict[str, JSON] | None = None, + ) -> None: + extra_content = self._add_source_key(extra_content) + await super().ban_user(room_id, user_id, reason=reason, extra_content=extra_content) + + async def unban_user( + self, + room_id: RoomID, + user_id: UserID, + reason: str = "", + extra_content: dict[str, JSON] | None = None, + ) -> None: + extra_content = 
self._add_source_key(extra_content) + await super().unban_user(room_id, user_id, reason=reason, extra_content=extra_content) + + async def join_room_by_id( + self, + room_id: RoomID, + third_party_signed: JSON = None, + extra_content: dict[str, JSON] | None = None, + ) -> RoomID: + extra_content = self._add_source_key(extra_content) + return await super().join_room_by_id( + room_id, third_party_signed=third_party_signed, extra_content=extra_content + ) + + async def leave_room( + self, + room_id: RoomID, + reason: str | None = None, + extra_content: dict[str, JSON] | None = None, + raise_not_in_room: bool = False, + ) -> None: + extra_content = self._add_source_key(extra_content) + await super().leave_room(room_id, reason, extra_content, raise_not_in_room) def set_room_avatar( self, room_id: RoomID, avatar_url: ContentURI | None, **kwargs ) -> Awaitable[EventID]: - return self.send_state_event( - room_id, EventType.ROOM_AVATAR, RoomAvatarStateEventContent(url=avatar_url), **kwargs - ) + content = RoomAvatarStateEventContent(url=avatar_url) + content = self._add_source_key(content) + return self.send_state_event(room_id, EventType.ROOM_AVATAR, content, **kwargs) def set_room_name(self, room_id: RoomID, name: str, **kwargs) -> Awaitable[EventID]: - return self.send_state_event( - room_id, EventType.ROOM_NAME, RoomNameStateEventContent(name=name), **kwargs - ) + content = RoomNameStateEventContent(name=name) + content = self._add_source_key(content) + return self.send_state_event(room_id, EventType.ROOM_NAME, content, **kwargs) def set_room_topic(self, room_id: RoomID, topic: str, **kwargs) -> Awaitable[EventID]: - return self.send_state_event( - room_id, EventType.ROOM_TOPIC, RoomTopicStateEventContent(topic=topic), **kwargs - ) + content = RoomTopicStateEventContent(topic=topic) + content = self._add_source_key(content) + return self.send_state_event(room_id, EventType.ROOM_TOPIC, content, **kwargs) async def get_power_levels( - self, room_id: RoomID, ignore_cache: 
bool = False + self, room_id: RoomID, ignore_cache: bool = False, ensure_joined: bool = True ) -> PowerLevelStateEventContent: - await self.ensure_joined(room_id) + if ensure_joined: + await self.ensure_joined(room_id) if not ignore_cache: levels = await self.state_store.get_power_levels(room_id) if levels: @@ -265,12 +345,17 @@ async def get_power_levels( levels = await self.get_state_event(room_id, EventType.ROOM_POWER_LEVELS) except MNotFound: levels = PowerLevelStateEventContent() + except MForbidden: + if not ensure_joined: + return PowerLevelStateEventContent() + raise await self.state_store.set_power_levels(room_id, levels) return levels async def set_power_levels( self, room_id: RoomID, content: PowerLevelStateEventContent, **kwargs ) -> EventID: + content = self._add_source_key(content) response = await self.send_state_event( room_id, EventType.ROOM_POWER_LEVELS, content, **kwargs ) @@ -289,12 +374,9 @@ async def get_pinned_messages(self, room_id: RoomID) -> list[EventID]: def set_pinned_messages( self, room_id: RoomID, events: list[EventID], **kwargs ) -> Awaitable[EventID]: - return self.send_state_event( - room_id, - EventType.ROOM_PINNED_EVENTS, - RoomPinnedEventsStateEventContent(pinned=events), - **kwargs, - ) + content = RoomPinnedEventsStateEventContent(pinned=events) + content = self._add_source_key(content) + return self.send_state_event(room_id, EventType.ROOM_PINNED_EVENTS, content, **kwargs) async def pin_message(self, room_id: RoomID, event_id: EventID) -> None: events = await self.get_pinned_messages(room_id) @@ -309,12 +391,9 @@ async def unpin_message(self, room_id: RoomID, event_id: EventID): await self.set_pinned_messages(room_id, events) async def set_join_rule(self, room_id: RoomID, join_rule: JoinRule, **kwargs): - await self.send_state_event( - room_id, - EventType.ROOM_JOIN_RULES, - JoinRulesStateEventContent(join_rule=join_rule), - **kwargs, - ) + content = JoinRulesStateEventContent(join_rule=join_rule) + content = 
self._add_source_key(content) + await self.send_state_event(room_id, EventType.ROOM_JOIN_RULES, content, **kwargs) async def get_room_displayname( self, room_id: RoomID, user_id: UserID, ignore_cache=False @@ -337,15 +416,10 @@ async def get_room_member_info( async def set_typing( self, room_id: RoomID, - is_typing: bool = True, - timeout: int = 5000, - ignore_cache: bool = False, + timeout: int = 0, ) -> None: await self.ensure_joined(room_id) - if not ignore_cache and is_typing == self.state_store.is_typing(room_id, self.mxid): - return - await super().set_typing(room_id, timeout if is_typing else 0) - self.state_store.set_typing(room_id, self.mxid, is_typing, timeout) + await super().set_typing(room_id, timeout) async def error_and_leave( self, room_id: RoomID, text: str | None = None, html: str | None = None @@ -358,10 +432,7 @@ async def send_message_event( self, room_id: RoomID, event_type: EventType, content: EventContent, **kwargs ) -> EventID: await self._ensure_has_power_level_for(room_id, event_type) - - if self.api.is_real_user and self.api.bridge_name is not None: - content[DOUBLE_PUPPET_SOURCE_KEY] = self.api.bridge_name - + content = self._add_source_key(content) return await super().send_message_event(room_id, event_type, content, **kwargs) async def redact( @@ -373,10 +444,7 @@ async def redact( **kwargs, ) -> EventID: await self._ensure_has_power_level_for(room_id, EventType.ROOM_REDACTION) - if self.api.is_real_user and self.api.bridge_name: - if not extra_content: - extra_content = {} - extra_content[DOUBLE_PUPPET_SOURCE_KEY] = self.api.bridge_name + extra_content = self._add_source_key(extra_content) return await super().redact( room_id, event_id, reason, extra_content=extra_content, **kwargs ) @@ -390,6 +458,7 @@ async def send_state_event( **kwargs, ) -> EventID: await self._ensure_has_power_level_for(room_id, event_type, state_key=state_key) + content = self._add_source_key(content) return await super().send_state_event(room_id, event_type, 
content, state_key, **kwargs) async def get_room_members( @@ -429,14 +498,24 @@ async def mark_read( ) self.state_store.set_read(room_id, self.mxid, event_id) + async def appservice_ping(self, appservice_id: str, txn_id: str | None = None) -> int: + resp = await self.api.request( + Method.POST, + Path.v1.appservice[appservice_id].ping, + content={"transaction_id": txn_id} if txn_id is not None else {}, + ) + return resp.get("duration_ms") or -1 + async def batch_send( self, room_id: RoomID, prev_event_id: EventID, *, batch_id: BatchID | None = None, - events: Iterable[MessageEvent], - state_events_at_start: Iterable[StateEvent] = None, + events: Iterable[BatchSendEvent], + state_events_at_start: Iterable[BatchSendStateEvent] = (), + beeper_new_messages: bool = False, + beeper_mark_read_by: UserID | None = None, ) -> BatchSendResponse: """ Send a batch of historical events into a room. See `MSC2716`_ for more info. @@ -445,6 +524,9 @@ async def batch_send( .. versionadded:: v0.12.5 + .. deprecated:: v0.20.3 + MSC2716 was abandoned by upstream and Beeper has forked the endpoint. + Args: room_id: The room ID to send the events to. prev_event_id: The anchor event. The batch will be inserted immediately after this event. @@ -454,14 +536,20 @@ async def batch_send( state_events_at_start: The state events to send at the start of the batch. These will be sent as outlier events, which means they won't be a part of the actual room state. + beeper_new_messages: Custom flag to tell the server that the messages can be sent to + the end of the room as normal messages instead of history. Returns: All the event IDs generated, plus a batch ID that can be passed back to this method. 
""" - path = UnstableClientPath["org.matrix.msc2716"].rooms[room_id].batch_send - query = {"prev_event_id": prev_event_id} + path = Path.unstable["org.matrix.msc2716"].rooms[room_id].batch_send + query: JSON = {"prev_event_id": prev_event_id} if batch_id: query["batch_id"] = batch_id + if beeper_new_messages: + query["com.beeper.new_messages"] = "true" + if beeper_mark_read_by: + query["com.beeper.mark_read_by"] = beeper_mark_read_by resp = await self.api.request( Method.POST, path, @@ -473,6 +561,58 @@ async def batch_send( ) return BatchSendResponse.deserialize(resp) + async def beeper_batch_send( + self, + room_id: RoomID, + events: Iterable[BatchSendEvent], + *, + forward: bool = False, + forward_if_no_messages: bool = False, + send_notification: bool = False, + mark_read_by: UserID | None = None, + ) -> BeeperBatchSendResponse: + """ + Send a batch of events into a room. Only for Beeper/hungryserv. + + .. versionadded:: v0.20.3 + + Args: + room_id: The room ID to send the events to. + events: The events to send. + forward: Send events to the end of the room instead of the beginning + forward_if_no_messages: Send events to the end of the room, but only if there are no + messages in the room. If there are messages, send the new messages to the beginning. + send_notification: Send a push notification for the new messages. + Only applies when sending to the end of the room. + mark_read_by: Send a read receipt from the given user ID atomically. + + Returns: + All the event IDs generated. 
+ """ + body = { + "events": [evt.serialize() for evt in events], + } + if forward: + body["forward"] = forward + elif forward_if_no_messages: + body["forward_if_no_messages"] = forward_if_no_messages + if send_notification: + body["send_notification"] = send_notification + if mark_read_by: + body["mark_read_by"] = mark_read_by + resp = await self.api.request( + Method.POST, + Path.unstable["com.beeper.backfill"].rooms[room_id].batch_send, + content=body, + ) + return BeeperBatchSendResponse.deserialize(resp) + + async def beeper_delete_room(self, room_id: RoomID) -> None: + versions = await self.versions() + if not versions.supports("com.beeper.room_yeeting"): + raise RuntimeError("Homeserver does not support yeeting rooms") + await self.api.request(Method.POST, Path.unstable["com.beeper.yeet"].rooms[room_id].delete) + # endregion # region Ensure functions @@ -538,7 +678,7 @@ def _register(self) -> Awaitable[dict]: "inhibit_login": True, } query_params = {"kind": "user"} - return self.api.request(Method.POST, Path.register, content, query_params=query_params) + return self.api.request(Method.POST, Path.v3.register, content, query_params=query_params) async def ensure_registered(self) -> None: """ @@ -554,8 +694,6 @@ async def ensure_registered(self) -> None: await self._register() except MUserInUse: pass - except MatrixRequestError as e: - raise IntentError(f"Failed to register {self.mxid}", e) await self.state_store.registered(self.mxid) async def _ensure_has_power_level_for( @@ -571,7 +709,9 @@ async def _ensure_has_power_level_for( return if not await self.state_store.has_power_levels_cached(room_id): # TODO add option to not try to fetch power levels from server - await self.get_power_levels(room_id, ignore_cache=True) + await self.get_power_levels(room_id, ignore_cache=True, ensure_joined=False) + if not await self.state_store.has_create_cached(room_id): + await self.get_state_event(room_id, EventType.ROOM_CREATE, format="event") if not await 
self.state_store.has_power_level(room_id, self.mxid, event_type): # TODO implement something better raise IntentError( diff --git a/mautrix/appservice/appservice.py b/mautrix/appservice/appservice.py index 396e4610..65e202ae 100644 --- a/mautrix/appservice/appservice.py +++ b/mautrix/appservice/appservice.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -13,7 +13,7 @@ from aiohttp import web import aiohttp -from mautrix.types import JSON, RoomAlias, UserID +from mautrix.types import JSON, RoomAlias, UserID, VersionsResponse from mautrix.util.logging import TraceLogger from ..api import HTTPAPI @@ -36,8 +36,8 @@ class AppService(AppServiceServerMixin): domain: str id: str verify_ssl: bool - tls_cert: str - tls_key: str + tls_cert: str | None + tls_key: str | None as_token: str hs_token: str bot_mxid: UserID @@ -56,7 +56,10 @@ class AppService(AppServiceServerMixin): loop: asyncio.AbstractEventLoop log: TraceLogger app: web.Application - runner: web.AppRunner + runner: web.AppRunner | None + + _http_session: aiohttp.ClientSession | None + _intent: IntentAPI | None def __init__( self, @@ -77,11 +80,12 @@ def __init__( state_store: ASStateStore = None, aiohttp_params: dict = None, ephemeral_events: bool = False, + encryption_events: bool = False, default_ua: str = HTTPAPI.default_ua, default_http_retry_count: int = 0, connection_limit: int | None = None, ) -> None: - super().__init__(ephemeral_events=ephemeral_events) + super().__init__(ephemeral_events=ephemeral_events, encryption_events=encryption_events) self.server = server self.domain = domain self.id = id @@ -177,10 +181,12 @@ async def start(self, host: str = "127.0.0.1", port: int = 8080) -> None: async def stop(self) -> None: self.log.debug("Stopping appservice web server") - await self.runner.cleanup() + if self.runner: + await 
self.runner.cleanup() self._intent = None - await self._http_session.close() - self._http_session = None + if self._http_session: + await self._http_session.close() + self._http_session = None await self.state_store.close() async def _liveness_probe(self, _: web.Request) -> web.Response: @@ -188,3 +194,6 @@ async def _liveness_probe(self, _: web.Request) -> web.Response: async def _readiness_probe(self, _: web.Request) -> web.Response: return web.Response(status=200 if self.ready else 500, text="{}") + + async def ping_self(self, txn_id: str | None = None) -> int: + return await self.intent.appservice_ping(self.id, txn_id=txn_id) diff --git a/mautrix/appservice/as_handler.py b/mautrix/appservice/as_handler.py index c06de019..ec7e339f 100644 --- a/mautrix/appservice/as_handler.py +++ b/mautrix/appservice/as_handler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -8,7 +8,6 @@ from typing import Any, Awaitable, Callable from json import JSONDecodeError -import asyncio import json import logging @@ -16,35 +15,58 @@ from mautrix.types import ( JSON, + ASToDeviceEvent, + DeviceID, DeviceLists, DeviceOTKCount, EphemeralEvent, Event, + EventType, RoomAlias, SerializerError, UserID, ) +from mautrix.util import background_task HandlerFunc = Callable[[Event], Awaitable] class AppServiceServerMixin: - loop: asyncio.AbstractEventLoop log: logging.Logger hs_token: str ephemeral_events: bool + encryption_events: bool + synchronous_handlers: bool query_user: Callable[[UserID], JSON] query_alias: Callable[[RoomAlias], JSON] transactions: set[str] event_handlers: list[HandlerFunc] + to_device_handler: HandlerFunc | None + otk_handler: Callable[[dict[UserID, dict[DeviceID, DeviceOTKCount]]], Awaitable] | None + device_list_handler: Callable[[DeviceLists], Awaitable] | None - def __init__(self, ephemeral_events: bool = False) -> None: + def __init__( + self, + ephemeral_events: bool = False, + encryption_events: bool = False, + log: logging.Logger | None = None, + hs_token: str | None = None, + ) -> None: + if log is not None: + self.log = log + if hs_token is not None: + self.hs_token = hs_token self.transactions = set() self.event_handlers = [] + self.to_device_handler = None + self.otk_handler = None + self.device_list_handler = None self.ephemeral_events = ephemeral_events + self.encryption_events = encryption_events + self.synchronous_handlers = False async def default_query_handler(_): return None @@ -63,6 +85,7 @@ def register_routes(self, app: web.Application) -> None: ) app.router.add_route("GET", "/_matrix/app/v1/rooms/{alias}", self._http_query_alias) app.router.add_route("GET", "/_matrix/app/v1/users/{user_id}", self._http_query_user) + app.router.add_route("POST", "/_matrix/app/v1/ping", self._http_ping) def _check_token(self, request: web.Request) -> bool: try: @@ -70,10 
+93,12 @@ def _check_token(self, request: web.Request) -> bool: except KeyError: try: token = request.headers["Authorization"].removeprefix("Bearer ") - except (KeyError, AttributeError): + except KeyError: + self.log.debug("No access_token nor Authorization header in request") return False if token != self.hs_token: + self.log.debug(f"Incorrect hs_token in request") return False return True @@ -116,6 +141,23 @@ async def _http_query_alias(self, request: web.Request) -> web.Response: return web.json_response({}, status=404) return web.json_response(response) + async def _http_ping(self, request: web.Request) -> web.Response: + if not self._check_token(request): + raise web.HTTPUnauthorized( + content_type="application/json", + text=json.dumps({"error": "Invalid auth token", "errcode": "M_UNKNOWN_TOKEN"}), + ) + try: + body = await request.json() + except JSONDecodeError: + raise web.HTTPBadRequest( + content_type="application/json", + text=json.dumps({"error": "Body is not JSON", "errcode": "M_NOT_JSON"}), + ) + txn_id = body.get("transaction_id") + self.log.info(f"Received ping from homeserver with transaction ID {txn_id}") + return web.json_response({}) + @staticmethod def _get_with_fallback( json: dict[str, Any], field: str, unstable_prefix: str, default: Any = None @@ -150,8 +192,11 @@ async def _read_transaction_header(self, request: web.Request) -> tuple[str, dic async def _http_handle_transaction(self, request: web.Request) -> web.Response: transaction_id, data = await self._read_transaction_header(request) + txn_content_log = [] try: events = data.pop("events") + if events: + txn_content_log.append(f"{len(events)} PDUs") except KeyError: raise web.HTTPBadRequest( content_type="application/json", @@ -160,20 +205,46 @@ async def _http_handle_transaction(self, request: web.Request) -> web.Response: ), ) - ephemeral = ( - self._get_with_fallback(data, "ephemeral", "de.sorunome.msc2409") - if self.ephemeral_events - else None - ) - device_lists = 
DeviceLists.deserialize( - self._get_with_fallback(data, "device_lists", "org.matrix.msc3202") - ) - otk_counts = { - user_id: DeviceOTKCount.deserialize(count) - for user_id, count in self._get_with_fallback( - data, "device_one_time_keys_count", "org.matrix.msc3202", default={} - ).items() - } + if self.ephemeral_events: + ephemeral = self._get_with_fallback(data, "ephemeral", "de.sorunome.msc2409") + if ephemeral: + txn_content_log.append(f"{len(ephemeral)} EDUs") + else: + ephemeral = None + if self.encryption_events: + to_device = self._get_with_fallback(data, "to_device", "de.sorunome.msc2409") + device_lists = DeviceLists.deserialize( + self._get_with_fallback(data, "device_lists", "org.matrix.msc3202") + ) + otk_counts = { + user_id: { + device_id: DeviceOTKCount.deserialize(count) + for device_id, count in devices.items() + } + for user_id, devices in self._get_with_fallback( + data, "device_one_time_keys_count", "org.matrix.msc3202", default={} + ).items() + } + if to_device: + txn_content_log.append(f"{len(to_device)} to-device events") + if device_lists.changed: + txn_content_log.append(f"{len(device_lists.changed)} device list changes") + if otk_counts: + txn_content_log.append( + f"{sum(len(vals) for vals in otk_counts.values())} OTK counts" + ) + else: + otk_counts = {} + device_lists = None + to_device = None + + if len(txn_content_log) > 2: + txn_content_log = [", ".join(txn_content_log[:-1]), txn_content_log[-1]] + if not txn_content_log: + txn_description = "nothing?" 
+ else: + txn_description = " and ".join(txn_content_log) + self.log.debug(f"Handling transaction {transaction_id} with {txn_description}") try: output = await self.handle_transaction( @@ -181,12 +252,15 @@ async def _http_handle_transaction(self, request: web.Request) -> web.Response: events=events, extra_data=data, ephemeral=ephemeral, + to_device=to_device, device_lists=device_lists, - device_otk_count=otk_counts, + otk_counts=otk_counts, ) except Exception: self.log.exception("Exception in transaction handler") output = None + finally: + self.log.debug(f"Finished handling transaction {transaction_id}") self.transactions.add(transaction_id) @@ -194,6 +268,11 @@ async def _http_handle_transaction(self, request: web.Request) -> web.Response: @staticmethod def _fix_prev_content(raw_event: JSON) -> None: + try: + if raw_event["unsigned"] is None: + del raw_event["unsigned"] + except KeyError: + pass try: raw_event["unsigned"]["prev_content"] except KeyError: @@ -209,16 +288,37 @@ async def handle_transaction( events: list[JSON], extra_data: JSON, ephemeral: list[JSON] | None = None, - device_otk_count: dict[UserID, DeviceOTKCount] | None = None, + to_device: list[JSON] | None = None, + otk_counts: dict[UserID, dict[DeviceID, DeviceOTKCount]] | None = None, device_lists: DeviceLists | None = None, ) -> JSON: + for raw_td in to_device or []: + try: + td = ASToDeviceEvent.deserialize(raw_td) + except SerializerError: + self.log.exception("Failed to deserialize to-device event %s", raw_td) + else: + try: + await self.to_device_handler(td) + except Exception: + self.log.exception("Exception in Matrix to-device event handler") + if device_lists and self.device_list_handler: + try: + await self.device_list_handler(device_lists) + except Exception: + self.log.exception("Exception in Matrix device list change handler") + if otk_counts and self.otk_handler: + try: + await self.otk_handler(otk_counts) + except Exception: + self.log.exception("Exception in Matrix OTK count 
handler") for raw_edu in ephemeral or []: try: edu = EphemeralEvent.deserialize(raw_edu) except SerializerError: self.log.exception("Failed to deserialize ephemeral event %s", raw_edu) else: - self.handle_matrix_event(edu) + await self.handle_matrix_event(edu, ephemeral=True) for raw_event in events: try: self._fix_prev_content(raw_event) @@ -226,13 +326,16 @@ async def handle_transaction( except SerializerError: self.log.exception("Failed to deserialize event %s", raw_event) else: - self.handle_matrix_event(event) + await self.handle_matrix_event(event) return {} - def handle_matrix_event(self, event: Event) -> None: - if event.type.is_state and event.state_key is None: - self.log.debug(f"Not sending {event.event_id} to handlers: expected state_key.") - return + async def handle_matrix_event(self, event: Event, ephemeral: bool = False) -> None: + if ephemeral: + event.type = event.type.with_class(EventType.Class.EPHEMERAL) + elif getattr(event, "state_key", None) is not None: + event.type = event.type.with_class(EventType.Class.STATE) + else: + event.type = event.type.with_class(EventType.Class.MESSAGE) async def try_handle(handler_func: HandlerFunc): try: @@ -240,8 +343,12 @@ async def try_handle(handler_func: HandlerFunc): except Exception: self.log.exception("Exception in Matrix event handler") - for handler in self.event_handlers: - asyncio.create_task(try_handle(handler)) + if self.synchronous_handlers: + for handler in self.event_handlers: + await handler(event) + else: + for handler in self.event_handlers: + background_task.create(try_handle(handler)) def matrix_event_handler(self, func: HandlerFunc) -> HandlerFunc: self.event_handlers.append(func) diff --git a/mautrix/appservice/state_store/__init__.py b/mautrix/appservice/state_store/__init__.py index 40e4bead..771ac252 100644 --- a/mautrix/appservice/state_store/__init__.py +++ b/mautrix/appservice/state_store/__init__.py @@ -1,2 +1,4 @@ from .file import FileASStateStore from .memory import ASStateStore 
+ +__all__ = ["FileASStateStore", "ASStateStore", "asyncpg"] diff --git a/mautrix/appservice/state_store/asyncpg.py b/mautrix/appservice/state_store/asyncpg.py index 33927279..f144269e 100644 --- a/mautrix/appservice/state_store/asyncpg.py +++ b/mautrix/appservice/state_store/asyncpg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/appservice/state_store/file.py b/mautrix/appservice/state_store/file.py index 9d4faded..fff99bbc 100644 --- a/mautrix/appservice/state_store/file.py +++ b/mautrix/appservice/state_store/file.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/appservice/state_store/memory.py b/mautrix/appservice/state_store/memory.py index d5fadf81..a9aba4ea 100644 --- a/mautrix/appservice/state_store/memory.py +++ b/mautrix/appservice/state_store/memory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -13,7 +13,6 @@ class ASStateStore(ClientStateStore, ABC): _presence: Dict[UserID, str] - _typing: Dict[Tuple[RoomID, UserID], int] _read: Dict[Tuple[RoomID, UserID], EventID] _registered: Dict[UserID, bool] @@ -21,7 +20,6 @@ def __init__(self) -> None: self._registered = {} # Non-persistent storage self._presence = {} - self._typing = {} self._read = {} async def is_registered(self, user_id: UserID) -> bool: @@ -69,22 +67,3 @@ def get_read(self, room_id: RoomID, user_id: UserID) -> Optional[EventID]: return self._read[(room_id, user_id)] except KeyError: return None - - def set_typing( - self, room_id: RoomID, user_id: UserID, is_typing: bool, timeout: int = 0 - ) -> None: - if is_typing: - ts = int(round(time.time() * 1000)) - self._typing[(room_id, user_id)] = ts + timeout - else: - try: - del self._typing[(room_id, user_id)] - except KeyError: - pass - - def is_typing(self, room_id: RoomID, user_id: UserID) -> bool: - ts = int(round(time.time() * 1000)) - try: - return self._typing[(room_id, user_id)] > ts - except KeyError: - return False diff --git a/mautrix/appservice/state_store/sqlalchemy.py b/mautrix/appservice/state_store/sqlalchemy.py deleted file mode 100644 index 6e876579..00000000 --- a/mautrix/appservice/state_store/sqlalchemy.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2021 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from mautrix.client.state_store.sqlalchemy import SQLStateStore as SQLClientStateStore - -from .memory import ASStateStore - - -class SQLASStateStore(SQLClientStateStore, ASStateStore): - def __init__(self) -> None: - SQLClientStateStore.__init__(self) - ASStateStore.__init__(self) diff --git a/mautrix/bridge/__init__.py b/mautrix/bridge/__init__.py index 83ad6473..ff6c0f7b 100644 --- a/mautrix/bridge/__init__.py +++ b/mautrix/bridge/__init__.py @@ -1,10 +1,11 @@ -from .async_getter_lock import async_getter_lock -from .bridge import Bridge +from ..util.async_getter_lock import async_getter_lock +from .bridge import Bridge, HomeserverSoftware from .config import BaseBridgeConfig from .custom_puppet import ( AutologinError, CustomPuppetError, CustomPuppetMixin, + EncryptionKeysFound, HomeserverURLNotFound, InvalidAccessToken, OnlyLoginSelf, @@ -13,6 +14,28 @@ from .disappearing_message import AbstractDisappearingMessage from .matrix import BaseMatrixHandler from .notification_disabler import NotificationDisabler -from .portal import BasePortal +from .portal import BasePortal, DMCreateError, IgnoreMatrixInvite, RejectMatrixInvite from .puppet import BasePuppet from .user import BaseUser + +__all__ = [ + "async_getter_lock", + "Bridge", + "HomeserverSoftware", + "BaseBridgeConfig", + "AutologinError", + "CustomPuppetError", + "CustomPuppetMixin", + "HomeserverURLNotFound", + "InvalidAccessToken", + "OnlyLoginSelf", + "OnlyLoginTrustedDomain", + "AbstractDisappearingMessage", + "BaseMatrixHandler", + "NotificationDisabler", + "BasePortal", + "BasePuppet", + "BaseUser", + "state_store", + "commands", +] diff --git a/mautrix/bridge/async_getter_lock.py b/mautrix/bridge/async_getter_lock.py deleted file mode 100644 index 78497578..00000000 --- a/mautrix/bridge/async_getter_lock.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) 2021 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. 
If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import TYPE_CHECKING, Any -import functools - -if TYPE_CHECKING: - from typing import Awaitable, Callable, ParamSpec - - Param = ParamSpec("Param") - Func = Callable[Param, Awaitable[Any]] - - -def async_getter_lock(fn: "Func") -> "Func": - @functools.wraps(fn) - async def wrapper(cls, *args, **kwargs) -> Any: - async with cls._async_get_locks[args]: - return await fn(cls, *args, **kwargs) - - return wrapper diff --git a/mautrix/bridge/bridge.py b/mautrix/bridge/bridge.py index 1545f383..9ce9360e 100644 --- a/mautrix/bridge/bridge.py +++ b/mautrix/bridge/bridge.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,6 +7,8 @@ from typing import Any from abc import ABC, abstractmethod +from enum import Enum +import asyncio import sys from aiohttp import web @@ -17,7 +19,7 @@ from mautrix.client.state_store.asyncpg import PgStateStore as PgClientStateStore from mautrix.errors import MExclusive, MUnknownToken from mautrix.types import RoomID, UserID -from mautrix.util.async_db import Database, UpgradeTable +from mautrix.util.async_db import Database, DatabaseException, UpgradeTable from mautrix.util.bridge_state import BridgeState, BridgeStateEvent, GlobalBridgeState from mautrix.util.program import Program @@ -30,6 +32,20 @@ uvloop = None +class HomeserverSoftware(Enum): + STANDARD = "standard" + ASMUX = "asmux" + HUNGRY = "hungry" + + @property + def is_hungry(self) -> bool: + return self == self.HUNGRY + + @property + def is_asmux(self) -> bool: + return self == self.ASMUX + + class Bridge(Program, ABC): db: Database az: AppService @@ -42,7 +58,10 @@ class Bridge(Program, ABC): matrix: br.BaseMatrixHandler repo_url: str markdown_version: str - manhole: 
br.manhole.ManholeState | None + manhole: br.commands.manhole.ManholeState | None + homeserver_software: HomeserverSoftware + beeper_network_name: str | None = None + beeper_service_name: str | None = None def __init__( self, @@ -81,6 +100,16 @@ def prepare_arg_parser(self) -> None: "(not needed for running the bridge)" ), ) + self.parser.add_argument( + "--ignore-unsupported-database", + action="store_true", + help="Run even if the database schema is too new", + ) + self.parser.add_argument( + "--ignore-foreign-tables", + action="store_true", + help="Run even if the database contains tables from other programs (like Synapse)", + ) def preinit(self) -> None: super().preinit() @@ -89,14 +118,26 @@ def preinit(self) -> None: sys.exit(0) def prepare(self) -> None: + if self.config.env: + self.log.debug( + "Loaded config overrides from environment: %s", list(self.config.env.keys()) + ) super().prepare() + try: + self.homeserver_software = HomeserverSoftware(self.config["homeserver.software"]) + except Exception: + self.log.fatal("Invalid value for homeserver.software in config") + sys.exit(11) self.prepare_db() self.prepare_appservice() self.prepare_bridge() def prepare_config(self) -> None: self.config = self.config_class( - self.args.config, self.args.registration, self.args.base_config + self.args.config, + self.args.registration, + self.base_config_path, + env_prefix=self.module.upper(), ) if self.args.generate_registration: self.config._check_tokens = False @@ -135,6 +176,7 @@ def prepare_appservice(self) -> None: tls_key=self.config.get("appservice.tls_key", None), bot_localpart=self.config["appservice.bot_username"], ephemeral_events=self.config["appservice.ephemeral_events"], + encryption_events=self.config["bridge.encryption.appservice"], default_ua=HTTPAPI.default_ua, default_http_retry_count=default_http_retry_count, log="mau.as", @@ -152,19 +194,34 @@ def prepare_db(self) -> None: self.config["appservice.database"], upgrade_table=self.upgrade_table, 
db_args=self.config["appservice.database_opts"], + owner_name=self.name, + ignore_foreign_tables=self.args.ignore_foreign_tables, ) def prepare_bridge(self) -> None: self.matrix = self.matrix_class(bridge=self) + def _log_db_error(self, e: Exception) -> None: + self.log.critical("Failed to initialize database", exc_info=e) + if isinstance(e, DatabaseException) and e.explanation: + self.log.info(e.explanation) + sys.exit(25) + async def start_db(self) -> None: if hasattr(self, "db") and isinstance(self.db, Database): self.log.debug("Starting database...") - await self.db.start() - if isinstance(self.state_store, PgClientStateStore): - await self.state_store.upgrade_table.upgrade(self.db) - if self.matrix.e2ee: - self.matrix.e2ee.crypto_db.override_pool(self.db) + ignore_unsupported = self.args.ignore_unsupported_database + self.db.upgrade_table.allow_unsupported = ignore_unsupported + try: + await self.db.start() + if isinstance(self.state_store, PgClientStateStore): + self.state_store.upgrade_table.allow_unsupported = ignore_unsupported + await self.state_store.upgrade_table.upgrade(self.db) + if self.matrix.e2ee: + self.matrix.e2ee.crypto_db.allow_unsupported = ignore_unsupported + self.matrix.e2ee.crypto_db.override_pool(self.db) + except Exception as e: + self._log_db_error(e) async def stop_db(self) -> None: if hasattr(self, "db") and isinstance(self.db, Database): @@ -190,6 +247,9 @@ async def start(self) -> None: "correct, and do they match the values in the registration?" 
) sys.exit(16) + except Exception: + self.log.critical("Failed to check connection to homeserver", exc_info=True) + sys.exit(16) await self.matrix.init_encryption() self.add_startup_actions(self.matrix.init_as_bot()) @@ -199,12 +259,16 @@ async def start(self) -> None: status_endpoint = self.config["homeserver.status_endpoint"] if status_endpoint and await self.count_logged_in_users() == 0: state = BridgeState(state_event=BridgeStateEvent.UNCONFIGURED).fill() - await state.send(status_endpoint, self.az.as_token, self.log) + while not await state.send(status_endpoint, self.az.as_token, self.log): + await asyncio.sleep(5) async def system_exit(self) -> None: if hasattr(self, "db") and isinstance(self.db, Database): - self.log.trace("Stopping database due to SystemExit") + self.log.debug("Stopping database due to SystemExit") await self.db.stop() + self.log.debug("Database stopped") + elif getattr(self, "db", None): + self.log.trace("Database not started at SystemExit") async def stop(self) -> None: if self.manhole: diff --git a/mautrix/bridge/commands/admin.py b/mautrix/bridge/commands/admin.py index b4508673..80ff15e8 100644 --- a/mautrix/bridge/commands/admin.py +++ b/mautrix/bridge/commands/admin.py @@ -1,13 +1,14 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import Optional +from __future__ import annotations from mautrix.errors import IntentError, MatrixRequestError, MForbidden from mautrix.types import ContentURI, EventID, UserID +from ... 
import bridge as br from .handler import SECTION_ADMIN, CommandEvent, command_handler @@ -16,19 +17,28 @@ needs_auth=False, name="set-pl", help_section=SECTION_ADMIN, - help_args="<_level_> [_mxid_]", + help_args="[_mxid_] <_level_>", help_text="Set a temporary power level without affecting the remote platform.", ) async def set_power_level(evt: CommandEvent) -> EventID: + try: + user_id = UserID(evt.args[0]) + except IndexError: + return await evt.reply(f"**Usage:** `$cmdprefix+sp set-pl [mxid] `") + + if user_id.startswith("@"): + evt.args.pop(0) + else: + user_id = evt.sender.mxid + try: level = int(evt.args[0]) except (KeyError, IndexError): - return await evt.reply("**Usage:** `$cmdprefix+sp set-pl [mxid]`") + return await evt.reply("**Usage:** `$cmdprefix+sp set-pl [mxid] `") except ValueError: return await evt.reply("The level must be an integer.") - mxid = evt.args[1] if len(evt.args) > 1 else evt.sender.mxid levels = await evt.main_intent.get_power_levels(evt.room_id, ignore_cache=True) - levels.users[mxid] = level + levels.users[user_id] = level try: return await evt.main_intent.set_power_levels(evt.room_id, levels) except MForbidden as e: @@ -38,35 +48,50 @@ async def set_power_level(evt: CommandEvent) -> EventID: return await evt.reply("Failed to update power levels (see logs for more details)") +async def _get_mxid_param( + evt: CommandEvent, args: str +) -> tuple[br.BasePuppet | None, EventID | None]: + try: + user_id = UserID(evt.args[0]) + except IndexError: + return None, await evt.reply(f"**Usage:** `$cmdprefix+sp {evt.command} {args}`") + + if user_id.startswith("@") and ":" in user_id: + # TODO support parsing mention pills instead of requiring a plaintext mxid + puppet = await evt.bridge.get_puppet(user_id) + if not puppet: + return None, await evt.reply("The given user ID is not a valid ghost user.") + evt.args.pop(0) + return puppet, None + elif evt.is_portal and (puppet := await evt.portal.get_dm_puppet()): + return puppet, None + return 
None, await evt.reply( + "This is not a private chat portal, you must pass a user ID explicitly." + ) + + @command_handler( needs_admin=True, needs_auth=False, name="set-avatar", help_section=SECTION_ADMIN, - help_args="<_mxc:// uri_> [_mxid_]", + help_args="[_mxid_] <_mxc:// uri_>", help_text="Set an avatar for a ghost user.", ) -async def set_ghost_avatar(evt: CommandEvent) -> Optional[EventID]: +async def set_ghost_avatar(evt: CommandEvent) -> EventID | None: + puppet, err = await _get_mxid_param(evt, "[mxid] ") + if err: + return err + try: mxc_uri = ContentURI(evt.args[0]) - except (KeyError, IndexError): - return await evt.reply("**Usage:** `$cmdprefix+sp set-avatar [mxid]`") + except IndexError: + return await evt.reply("**Usage:** `$cmdprefix+sp set-avatar [mxid] `") if not mxc_uri.startswith("mxc://"): - return await evt.reply("The URI has to start with mxc://.") - if len(evt.args) > 1: - # TODO support parsing mention pills instead of requiring a plaintext mxid - puppet = await evt.processor.bridge.get_puppet(UserID(evt.args[1])) - if puppet is None: - return await evt.reply("The given mxid was not a valid ghost user.") - intent = puppet.intent - elif evt.is_portal: - intent = evt.portal.main_intent - if intent == evt.az.intent: - return await evt.reply("No mxid given and the main intent is not a ghost user.") - else: - return await evt.reply("No mxid given and not in a portal.") + return await evt.reply("The avatar URL must start with `mxc://`") + try: - return await intent.set_avatar_url(mxc_uri) + return await puppet.default_mxid_intent.set_avatar_url(mxc_uri) except (MatrixRequestError, IntentError): evt.log.exception("Failed to set avatar.") return await evt.reply("Failed to set avatar (see logs for more details).") @@ -80,20 +105,12 @@ async def set_ghost_avatar(evt: CommandEvent) -> Optional[EventID]: help_args="[_mxid_]", help_text="Remove the avatar for a ghost user.", ) -async def remove_ghost_avatar(evt: CommandEvent) -> Optional[EventID]: - if 
len(evt.args) > 0: - puppet = await evt.processor.bridge.get_puppet(UserID(evt.args[0])) - if puppet is None: - return await evt.reply("The given mxid was not a valid ghost user.") - intent = puppet.intent - elif evt.is_portal: - intent = evt.portal.main_intent - if intent == evt.az.intent: - return await evt.reply("No mxid given and the main intent is not a ghost user.") - else: - return await evt.reply("No mxid given and not in a portal.") +async def remove_ghost_avatar(evt: CommandEvent) -> EventID | None: + puppet, err = await _get_mxid_param(evt, "[mxid]") + if err: + return err try: - return await intent.set_avatar_url(ContentURI("")) + return await puppet.default_mxid_intent.set_avatar_url(ContentURI("")) except (MatrixRequestError, IntentError): evt.log.exception("Failed to remove avatar.") return await evt.reply("Failed to remove avatar (see logs for more details).") @@ -104,29 +121,15 @@ async def remove_ghost_avatar(evt: CommandEvent) -> Optional[EventID]: needs_auth=False, name="set-displayname", help_section=SECTION_ADMIN, - help_args="<_displayname_> [_mxid_]", + help_args="[_mxid_] <_displayname_>", help_text="Set the display name for a ghost user.", ) -async def set_ghost_display_name(evt: CommandEvent) -> Optional[EventID]: - if len(evt.args) > 1: - # This allows whitespaces in the name - puppet = await evt.processor.bridge.get_puppet(UserID(evt.args[len(evt.args) - 1])) - if puppet is None: - return await evt.reply( - "The given mxid was not a valid ghost user. 
" - "If the display name has whitespaces mxid is required" - ) - intent = puppet.intent - displayname = " ".join(evt.args[:-1]) - elif evt.is_portal: - intent = evt.portal.main_intent - if intent == evt.az.intent: - return await evt.reply("No mxid given and the main intent is not a ghost user.") - displayname = evt.args[0] - else: - return await evt.reply("No mxid given and not in a portal.") +async def set_ghost_display_name(evt: CommandEvent) -> EventID | None: + puppet, err = await _get_mxid_param(evt, "[mxid] ") + if err: + return err try: - return await intent.set_displayname(displayname) + return await puppet.default_mxid_intent.set_displayname(" ".join(evt.args)) except (MatrixRequestError, IntentError): evt.log.exception("Failed to set display name.") return await evt.reply("Failed to set display name (see logs for more details).") @@ -140,20 +143,12 @@ async def set_ghost_display_name(evt: CommandEvent) -> Optional[EventID]: help_args="[_mxid_]", help_text="Remove the display name for a ghost user.", ) -async def set_ghost_display_name(evt: CommandEvent) -> Optional[EventID]: - if len(evt.args) > 0: - puppet = await evt.processor.bridge.get_puppet(UserID(evt.args[0])) - if puppet is None: - return await evt.reply("The given mxid was not a valid ghost user.") - intent = puppet.intent - elif evt.is_portal: - intent = evt.portal.main_intent - if intent == evt.az.intent: - return await evt.reply("No mxid given and the main intent is not a ghost user.") - else: - return await evt.reply("No mxid given and not in a portal (see logs for more details).") +async def remove_ghost_display_name(evt: CommandEvent) -> EventID | None: + puppet, err = await _get_mxid_param(evt, "[mxid]") + if err: + return err try: - return await intent.set_displayname("") + return await puppet.default_mxid_intent.set_displayname("") except (MatrixRequestError, IntentError): evt.log.exception("Failed to remove display name.") return await evt.reply("Failed to remove display name (see logs 
for more details).") diff --git a/mautrix/bridge/commands/clean_rooms.py b/mautrix/bridge/commands/clean_rooms.py index ac508028..cdf72b35 100644 --- a/mautrix/bridge/commands/clean_rooms.py +++ b/mautrix/bridge/commands/clean_rooms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/bridge/commands/crypto.py b/mautrix/bridge/commands/crypto.py index b06d9cf7..b8cb280a 100644 --- a/mautrix/bridge/commands/crypto.py +++ b/mautrix/bridge/commands/crypto.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/bridge/commands/delete_portal.py b/mautrix/bridge/commands/delete_portal.py index b40aef12..a821ce77 100644 --- a/mautrix/bridge/commands/delete_portal.py +++ b/mautrix/bridge/commands/delete_portal.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/bridge/commands/handler.py b/mautrix/bridge/commands/handler.py index 3dda0c7a..4e6caa3d 100644 --- a/mautrix/bridge/commands/handler.py +++ b/mautrix/bridge/commands/handler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -12,6 +12,7 @@ import traceback from mautrix.appservice import AppService, IntentAPI +from mautrix.errors import MForbidden from mautrix.types import EventID, MessageEventContent, RoomID from mautrix.util import markdown from mautrix.util.logging import TraceLogger @@ -144,6 +145,22 @@ def print_error_traceback(self) -> bool: def main_intent(self) -> IntentAPI: return self.portal.main_intent if self.portal else self.az.intent + async def redact(self, reason: str | None = None) -> None: + """ + Try to redact the command. + + If the redaction fails with M_FORBIDDEN, the error will be logged and ignored. + """ + try: + if self.has_bridge_bot: + await self.az.intent.redact(self.room_id, self.event_id, reason=reason) + else: + await self.main_intent.redact(self.room_id, self.event_id, reason=reason) + except MForbidden as e: + self.log.warning(f"Failed to redact command {self.command}: {e}") + except Exception: + self.log.warning(f"Failed to redact command {self.command}", exc_info=True) + def reply( self, message: str, allow_html: bool = False, render_markdown: bool = True ) -> Awaitable[EventID]: @@ -287,9 +304,9 @@ async def get_permission_error(self, evt: CommandEvent) -> str | None: "you may only run it in management rooms." ) elif self.needs_admin and not evt.sender.is_admin: - return "This command requires administrator privileges." + return "That command is limited to bridge administrators." elif self.needs_auth and not await evt.sender.is_logged_in(): - return "This command requires you to be logged in." + return "That command requires you to be logged in." 
return None def has_permission(self, key: HelpCacheKey) -> bool: diff --git a/mautrix/bridge/commands/login_matrix.py b/mautrix/bridge/commands/login_matrix.py index e3c4a4e7..6cc6f7aa 100644 --- a/mautrix/bridge/commands/login_matrix.py +++ b/mautrix/bridge/commands/login_matrix.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -6,13 +6,7 @@ from mautrix.client import Client from mautrix.types import EventID -from ..custom_puppet import ( - AutologinError, - HomeserverURLNotFound, - InvalidAccessToken, - OnlyLoginSelf, - OnlyLoginTrustedDomain, -) +from ..custom_puppet import AutologinError, CustomPuppetError, InvalidAccessToken from .handler import SECTION_AUTH, CommandEvent, command_handler @@ -36,20 +30,10 @@ async def login_matrix(evt: CommandEvent) -> None: try: await puppet.switch_mxid(evt.args[0], evt.sender.mxid) await evt.reply("Successfully enabled double puppeting.") - except OnlyLoginSelf: - await evt.reply("You may only enable double puppeting with your own Matrix account.") - except OnlyLoginTrustedDomain: - await evt.reply(f"This bridge does not allow double puppeting from {homeserver}.") - except HomeserverURLNotFound: - await evt.reply( - f"Unable to find the base URL for {homeserver}. Please ensure a client" - " .well-known file is set up, or ask the bridge administrator to add the" - " homeserver URL to the bridge config." 
- ) except AutologinError as e: await evt.reply(f"Failed to create an access token: {e}") - except InvalidAccessToken: - await evt.reply("Invalid access token.") + except CustomPuppetError as e: + await evt.reply(str(e)) @command_handler( @@ -86,24 +70,3 @@ async def ping_matrix(evt: CommandEvent) -> EventID: except InvalidAccessToken: return await evt.reply("Your access token is invalid.") return await evt.reply("Your Matrix login is working.") - - -@command_handler( - needs_auth=True, - help_section=SECTION_AUTH, - help_text="Clear the Matrix sync token stored for your double puppet.", -) -async def clear_cache_matrix(evt: CommandEvent) -> EventID: - try: - puppet = await evt.sender.get_puppet() - except NotImplementedError: - return await evt.reply("This bridge has not implemented the clear-cache-matrix command") - if not puppet.is_real_user: - return await evt.reply("You are not logged in with your Matrix account.") - try: - puppet.stop() - puppet.next_batch = None - await puppet.start() - except InvalidAccessToken: - return await evt.reply("Your access token is invalid.") - return await evt.reply("Cleared cache successfully.") diff --git a/mautrix/bridge/commands/manhole.py b/mautrix/bridge/commands/manhole.py index 75ace8d2..0b5edd2a 100644 --- a/mautrix/bridge/commands/manhole.py +++ b/mautrix/bridge/commands/manhole.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/bridge/commands/meta.py b/mautrix/bridge/commands/meta.py index b7f513a8..d0f066b0 100644 --- a/mautrix/bridge/commands/meta.py +++ b/mautrix/bridge/commands/meta.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/bridge/commands/relay.py b/mautrix/bridge/commands/relay.py index ff131580..9b7ec622 100644 --- a/mautrix/bridge/commands/relay.py +++ b/mautrix/bridge/commands/relay.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/bridge/config.py b/mautrix/bridge/config.py index b669c8cf..defed222 100644 --- a/mautrix/bridge/config.py +++ b/mautrix/bridge/config.py @@ -1,12 +1,14 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any +from typing import Any, ClassVar from abc import ABC +import json +import os import re import secrets import time @@ -21,15 +23,40 @@ class BaseBridgeConfig(BaseFileConfig, BaseValidatableConfig, ABC): + env_prefix: str | None = None registration_path: str _registration: dict | None _check_tokens: bool + env: dict[str, Any] - def __init__(self, path: str, registration_path: str, base_path: str) -> None: + def __init__( + self, path: str, registration_path: str, base_path: str, env_prefix: str | None = None + ) -> None: super().__init__(path, base_path) self.registration_path = registration_path self._registration = None self._check_tokens = True + self.env = {} + if not self.env_prefix: + self.env_prefix = env_prefix + if self.env_prefix: + env_prefix = f"{self.env_prefix}_" + for key, value in os.environ.items(): + if not key.startswith(env_prefix): + continue + key = key.removeprefix(env_prefix) + if value.startswith("json::"): + value = json.loads(value.removeprefix("json::")) + self.env[key] = value + 
+ def __getitem__(self, item: str) -> Any: + if self.env: + try: + sanitized_item = item.replace(".", "_").replace("[", "").replace("]", "").upper() + return self.env[sanitized_item] + except KeyError: + pass + return super().__getitem__(item) def save(self) -> None: super().save() @@ -45,6 +72,7 @@ def _new_token() -> str: def forbidden_defaults(self) -> list[ForbiddenDefault]: return [ ForbiddenDefault("homeserver.address", "https://example.com"), + ForbiddenDefault("homeserver.address", "https://matrix.example.com"), ForbiddenDefault("homeserver.domain", "example.com"), ] + ( [ @@ -73,6 +101,11 @@ def do_update(self, helper: ConfigUpdateHelper) -> None: copy("homeserver.connection_limit") copy("homeserver.status_endpoint") copy("homeserver.message_send_checkpoint_endpoint") + copy("homeserver.async_media") + if self.get("homeserver.asmux", False): + helper.base["homeserver.software"] = "asmux" + else: + copy("homeserver.software") copy("appservice.address") copy("appservice.hostname") @@ -82,7 +115,12 @@ def do_update(self, helper: ConfigUpdateHelper) -> None: copy("appservice.tls_cert") copy("appservice.tls_key") - copy("appservice.database") + if "appservice.database" in self and self["appservice.database"].startswith("sqlite:///"): + helper.base["appservice.database"] = self["appservice.database"].replace( + "sqlite:///", "sqlite:" + ) + else: + copy("appservice.database") copy("appservice.database_opts") copy("appservice.id") @@ -101,6 +139,38 @@ def do_update(self, helper: ConfigUpdateHelper) -> None: copy("bridge.management_room_text.additional_help") copy("bridge.management_room_multiple_messages") + copy("bridge.encryption.allow") + copy("bridge.encryption.default") + copy("bridge.encryption.require") + copy("bridge.encryption.appservice") + copy("bridge.encryption.msc4190") + copy("bridge.encryption.self_sign") + copy("bridge.encryption.delete_keys.delete_outbound_on_ack") + copy("bridge.encryption.delete_keys.dont_store_outbound") + 
copy("bridge.encryption.delete_keys.ratchet_on_decrypt") + copy("bridge.encryption.delete_keys.delete_fully_used_on_decrypt") + copy("bridge.encryption.delete_keys.delete_prev_on_new_session") + copy("bridge.encryption.delete_keys.delete_on_device_delete") + copy("bridge.encryption.delete_keys.periodically_delete_expired") + copy("bridge.encryption.delete_keys.delete_outdated_inbound") + copy("bridge.encryption.verification_levels.receive") + copy("bridge.encryption.verification_levels.send") + copy("bridge.encryption.verification_levels.share") + copy("bridge.encryption.allow_key_sharing") + if self.get("bridge.encryption.key_sharing.allow", False): + helper.base["bridge.encryption.allow_key_sharing"] = True + require_verif = self.get("bridge.encryption.key_sharing.require_verification", True) + require_cs = self.get("bridge.encryption.key_sharing.require_cross_signing", False) + if require_verif: + helper.base["bridge.encryption.verification_levels.share"] = "verified" + elif not require_cs: + helper.base["bridge.encryption.verification_levels.share"] = "unverified" + # else: default (cross-signed-tofu) + copy("bridge.encryption.rotation.enable_custom") + copy("bridge.encryption.rotation.milliseconds") + copy("bridge.encryption.rotation.messages") + copy("bridge.encryption.rotation.disable_device_change_key_rotation") + copy("bridge.relay.enabled") copy_dict("bridge.relay.message_formats", override_existing_map=False) @@ -132,14 +202,18 @@ def namespaces(self) -> dict[str, list[dict[str, Any]]]: "regex": re.escape(f"@{username_format}:{homeserver}").replace(regex_ph, ".*"), } ], - "aliases": [ - { - "exclusive": True, - "regex": re.escape(f"#{alias_format}:{homeserver}").replace(regex_ph, ".*"), - } - ] - if alias_format - else [], + "aliases": ( + [ + { + "exclusive": True, + "regex": re.escape(f"#{alias_format}:{homeserver}").replace( + regex_ph, ".*" + ), + } + ] + if alias_format + else [] + ), } def generate_registration(self) -> None: @@ -168,4 +242,7 @@ 
def generate_registration(self) -> None: if self["appservice.ephemeral_events"]: self._registration["de.sorunome.msc2409.push_ephemeral"] = True - self._registration["push_ephemeral"] = True + self._registration["receive_ephemeral"] = True + + if self["bridge.encryption.msc4190"]: + self._registration["io.element.msc4190"] = True diff --git a/mautrix/bridge/crypto_state_store.py b/mautrix/bridge/crypto_state_store.py index 84c05c4b..f08dec2c 100644 --- a/mautrix/bridge/crypto_state_store.py +++ b/mautrix/bridge/crypto_state_store.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -28,29 +28,6 @@ async def is_encrypted(self, room_id: RoomID) -> bool: return portal.encrypted if portal else False -try: - from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile - - class SQLCryptoStateStore(BaseCryptoStateStore): - @staticmethod - async def find_shared_rooms(user_id: UserID) -> list[RoomID]: - return [profile.room_id for profile in UserProfile.find_rooms_with_user(user_id)] - - @staticmethod - async def get_encryption_info(room_id: RoomID) -> RoomEncryptionStateEventContent | None: - state = RoomState.get(room_id) - if not state: - return None - return state.encryption - -except ImportError: - if __optional_imports__: - raise - UserProfile = None - RoomState = None - SQLCryptoStateStore = None - - class PgCryptoStateStore(BaseCryptoStateStore): db: Database diff --git a/mautrix/bridge/custom_puppet.py b/mautrix/bridge/custom_puppet.py index b3c15c19..f5befd5f 100644 --- a/mautrix/bridge/custom_puppet.py +++ b/mautrix/bridge/custom_puppet.py @@ -1,24 +1,20 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Awaitable, Iterator from abc import ABC, abstractmethod -from itertools import chain import asyncio import hashlib import hmac -import json import logging -from aiohttp import ClientConnectionError from yarl import URL -from mautrix.api import Path from mautrix.appservice import AppService, IntentAPI +from mautrix.client import ClientAPI from mautrix.errors import ( IntentError, MatrixError, @@ -26,21 +22,7 @@ MatrixRequestError, WellKnownError, ) -from mautrix.types import ( - Event, - EventFilter, - EventType, - Filter, - FilterID, - LoginType, - PresenceState, - RoomEventFilter, - RoomFilter, - RoomID, - StateFilter, - SyncToken, - UserID, -) +from mautrix.types import LoginType, MatrixUserIdentifier, RoomID, UserID from .. import bridge as br @@ -56,12 +38,24 @@ def __init__(self): class OnlyLoginSelf(CustomPuppetError): def __init__(self): - super().__init__("You may only replace your puppet with your own Matrix account.") + super().__init__("You may only enable double puppeting with your own Matrix account.") + + +class EncryptionKeysFound(CustomPuppetError): + def __init__(self): + super().__init__( + "The given access token is for a device that has encryption keys set up. " + "Please provide a fresh token, don't reuse one from another client." + ) class HomeserverURLNotFound(CustomPuppetError): def __init__(self, domain: str): - super().__init__(f"Could not discover a valid homeserver URL for {domain}") + super().__init__( + f"Could not discover a valid homeserver URL for {domain}." + " Please ensure a client .well-known file is set up, or ask the bridge administrator " + "to add the homeserver URL to the bridge config." + ) class OnlyLoginTrustedDomain(CustomPuppetError): @@ -102,7 +96,6 @@ class CustomPuppetMixin(ABC): intent: The primary IntentAPI. 
""" - sync_with_custom_puppets: bool = True allow_discover_url: bool = False homeserver_url_map: dict[str, URL] = {} only_handle_own_synced_events: bool = True @@ -121,12 +114,9 @@ class CustomPuppetMixin(ABC): custom_mxid: UserID | None access_token: str | None base_url: URL | None - next_batch: SyncToken | None intent: IntentAPI - _sync_task: asyncio.Task | None = None - @abstractmethod async def save(self) -> None: """Save the information of this puppet. Called from :meth:`switch_mxid`""" @@ -142,6 +132,25 @@ def is_real_user(self) -> bool: return bool(self.custom_mxid and self.access_token) def _fresh_intent(self) -> IntentAPI: + if self.custom_mxid: + _, server = self.az.intent.parse_user_id(self.custom_mxid) + try: + self.base_url = self.homeserver_url_map[server] + except KeyError: + if server == self.az.domain: + self.base_url = self.az.intent.api.base_url + if self.access_token == "appservice-config" and self.custom_mxid: + try: + secret = self.login_shared_secret_map[server] + except KeyError: + raise AutologinError(f"No shared secret configured for {server}") + self.log.debug(f"Using as_token for double puppeting {self.custom_mxid}") + return self.az.intent.user( + self.custom_mxid, + secret.decode("utf-8").removeprefix("as_token:"), + self.base_url, + as_token=True, + ) return ( self.az.intent.user(self.custom_mxid, self.access_token, self.base_url) if self.is_real_user @@ -162,6 +171,8 @@ async def _login_with_shared_secret(cls, mxid: UserID) -> str: secret = cls.login_shared_secret_map[server] except KeyError: raise AutologinError(f"No shared secret configured for {server}") + if secret.startswith(b"as_token:"): + return "appservice-config" try: base_url = cls.homeserver_url_map[server] except KeyError: @@ -169,30 +180,32 @@ async def _login_with_shared_secret(cls, mxid: UserID) -> str: base_url = cls.az.intent.api.base_url else: raise AutologinError(f"No homeserver URL configured for {server}") - password = hmac.new(secret, mxid.encode("utf-8"), 
hashlib.sha512).hexdigest() - url = base_url / str(Path.login) - resp = await cls.az.http_session.post( - url, - data=json.dumps( - { - "type": str(LoginType.PASSWORD), - "initial_device_display_name": cls.login_device_name, - "device_id": cls.login_device_name, - "identifier": { - "type": "m.id.user", - "user": mxid, - }, - "password": password, - } - ), - headers={"Content-Type": "application/json"}, + client = ClientAPI(base_url=base_url) + login_args = {} + if secret == b"appservice": + login_type = LoginType.APPSERVICE + client.api.token = cls.az.as_token + else: + flows = await client.get_login_flows() + flow = flows.get_first_of_type(LoginType.DEVTURE_SHARED_SECRET, LoginType.PASSWORD) + if not flow: + raise AutologinError("No supported shared secret auth login flows") + login_type = flow.type + token = hmac.new(secret, mxid.encode("utf-8"), hashlib.sha512).hexdigest() + if login_type == LoginType.DEVTURE_SHARED_SECRET: + login_args["token"] = token + elif login_type == LoginType.PASSWORD: + login_args["password"] = token + resp = await client.login( + identifier=MatrixUserIdentifier(user=mxid), + device_id=cls.login_device_name, + initial_device_display_name=cls.login_device_name, + login_type=login_type, + **login_args, + store_access_token=False, + update_hs_url=False, ) - data = await resp.json() - try: - return data["access_token"] - except KeyError: - error_msg = data.get("error", data.get("errcode", f"HTTP {resp.status}")) - raise AutologinError(f"Didn't get an access token: {error_msg}") from None + return resp.access_token async def switch_mxid( self, access_token: str | None, mxid: UserID | None, start_sync_task: bool = True @@ -205,11 +218,11 @@ async def switch_mxid( the appservice-owned ID. mxid: The expected Matrix user ID of the custom account, or ``None`` when ``access_token`` is None. - start_sync_task: Whether or not syncing should be started after logging in. 
""" if access_token == "auto": access_token = await self._login_with_shared_secret(mxid) - self.log.debug(f"Logged in for {mxid} using shared secret") + if access_token != "appservice-config": + self.log.debug(f"Logged in for {mxid} using shared secret") if mxid is not None: _, mxid_domain = self.az.intent.parse_user_id(mxid) @@ -235,7 +248,7 @@ async def switch_mxid( self.base_url = base_url self.intent = self._fresh_intent() - await self.start(start_sync_task=start_sync_task) + await self.start(check_e2ee_keys=True) try: del self.by_custom_mxid[prev_mxid] @@ -255,7 +268,20 @@ async def try_start(self, retry_auto_login: bool = True) -> None: except Exception: self.log.exception("Failed to initialize custom mxid") - async def start(self, retry_auto_login: bool = False, start_sync_task: bool = True) -> None: + async def _invalidate_double_puppet(self) -> None: + if self.custom_mxid and self.by_custom_mxid.get(self.custom_mxid) == self: + del self.by_custom_mxid[self.custom_mxid] + self.custom_mxid = None + self.access_token = None + await self.save() + self.intent = self._fresh_intent() + + async def start( + self, + retry_auto_login: bool = False, + start_sync_task: bool = True, + check_e2ee_keys: bool = False, + ) -> None: """Initialize the custom account this puppet uses. Should be called at startup to start the /sync task. 
Is called by :meth:`switch_mxid` automatically.""" if not self.is_real_user: @@ -266,34 +292,37 @@ async def start(self, retry_auto_login: bool = False, start_sync_task: bool = Tr except MatrixInvalidToken as e: if retry_auto_login and self.custom_mxid and self.can_auto_login(self.custom_mxid): self.log.debug(f"Got {e.errcode} while trying to initialize custom mxid") - await self.switch_mxid("auto", self.custom_mxid, start_sync_task=start_sync_task) + await self.switch_mxid("auto", self.custom_mxid) return self.log.warning(f"Got {e.errcode} while trying to initialize custom mxid") whoami = None if not whoami or whoami.user_id != self.custom_mxid: - if self.custom_mxid and self.by_custom_mxid.get(self.custom_mxid) == self: - del self.by_custom_mxid[self.custom_mxid] - self.custom_mxid = None - self.access_token = None - self.next_batch = None - await self.save() - self.intent = self._fresh_intent() - if whoami.user_id != self.custom_mxid: + prev_custom_mxid = self.custom_mxid + await self._invalidate_double_puppet() + if whoami and whoami.user_id != prev_custom_mxid: raise OnlyLoginSelf() raise InvalidAccessToken() - if self.sync_with_custom_puppets and start_sync_task: - if self._sync_task: - self._sync_task.cancel() - self.log.info(f"Initialized custom mxid: {whoami.user_id}. Starting sync task") - self._sync_task = asyncio.create_task(self._try_sync()) - else: - self.log.info(f"Initialized custom mxid: {whoami.user_id}. 
Not starting sync task") + if check_e2ee_keys: + try: + devices = await self.intent.query_keys({whoami.user_id: [whoami.device_id]}) + device_keys = devices.device_keys.get(whoami.user_id, {}).get(whoami.device_id) + except Exception: + self.log.warning( + "Failed to query keys to check if double puppeting token was reused", + exc_info=True, + ) + else: + if device_keys and len(device_keys.keys) > 0: + await self._invalidate_double_puppet() + raise EncryptionKeysFound() + self.log.info(f"Initialized custom mxid: {whoami.user_id}") def stop(self) -> None: - """Cancel the sync task.""" - if self._sync_task: - self._sync_task.cancel() - self._sync_task = None + """ + No-op + + .. deprecated:: 0.20.1 + """ async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool: """ @@ -316,112 +345,3 @@ async def _leave_rooms_with_default_user(self) -> None: await self.intent.ensure_joined(room_id) except (IntentError, MatrixRequestError): pass - - def _create_sync_filter(self) -> Awaitable[FilterID]: - all_events = EventType.find("*") - return self.intent.create_filter( - Filter( - account_data=EventFilter(types=[all_events]), - presence=EventFilter( - types=[EventType.PRESENCE], - senders=[self.custom_mxid] if self.only_handle_own_synced_events else None, - ), - room=RoomFilter( - include_leave=False, - state=StateFilter(not_types=[all_events]), - timeline=RoomEventFilter(not_types=[all_events]), - account_data=RoomEventFilter(not_types=[all_events]), - ephemeral=RoomEventFilter( - types=[ - EventType.TYPING, - EventType.RECEIPT, - ] - ), - ), - ) - ) - - def _filter_events(self, room_id: RoomID, events: list[dict]) -> Iterator[Event]: - for event in events: - event["room_id"] = room_id - if self.only_handle_own_synced_events: - # We only want events about the custom puppet user, but we can't use - # filters for typing and read receipt events. 
- evt_type = EventType.find(event.get("type", None)) - event.setdefault("content", {}) - if evt_type == EventType.TYPING: - is_typing = self.custom_mxid in event["content"].get("user_ids", []) - event["content"]["user_ids"] = [self.custom_mxid] if is_typing else [] - elif evt_type == EventType.RECEIPT: - try: - event_id, receipt = event["content"].popitem() - data = receipt["m.read"][self.custom_mxid] - event["content"] = {event_id: {"m.read": {self.custom_mxid: data}}} - except KeyError: - continue - yield event - - def _handle_sync(self, sync_resp: dict) -> None: - # Get events from rooms -> join -> [room_id] -> ephemeral -> events (array) - ephemeral_events = ( - event - for room_id, data in sync_resp.get("rooms", {}).get("join", {}).items() - for event in self._filter_events(room_id, data.get("ephemeral", {}).get("events", [])) - ) - - # Get events from presence -> events (array) - presence_events = sync_resp.get("presence", {}).get("events", []) - - # Deserialize and handle all events - for event in chain(ephemeral_events, presence_events): - asyncio.create_task(self.mx.try_handle_sync_event(Event.deserialize(event))) - - async def _try_sync(self) -> None: - try: - await self._sync() - except asyncio.CancelledError: - self.log.info(f"Syncing for {self.custom_mxid} cancelled") - except Exception: - self.log.critical(f"Fatal error syncing {self.custom_mxid}", exc_info=True) - - async def _sync(self) -> None: - if not self.is_real_user: - self.log.warning("Called sync() for non-custom puppet.") - return - custom_mxid: UserID = self.custom_mxid - access_token_at_start: str = self.access_token - errors: int = 0 - filter_id: FilterID = await self._create_sync_filter() - self.log.debug(f"Starting syncer for {custom_mxid} with sync filter {filter_id}.") - while access_token_at_start == self.access_token: - try: - cur_batch = self.next_batch - sync_resp = await self.intent.sync( - filter_id=filter_id, since=cur_batch, set_presence=PresenceState.OFFLINE - ) - try: - 
self.next_batch = sync_resp.get("next_batch", None) - except Exception: - self.log.warning("Failed to store next batch", exc_info=True) - errors = 0 - if cur_batch is not None: - self._handle_sync(sync_resp) - except MatrixInvalidToken: - # TODO when not using syncing, we should still check this occasionally and relogin - self.log.warning(f"Access token for {custom_mxid} got invalidated, restarting...") - await self.start(retry_auto_login=True, start_sync_task=False) - if self.is_real_user: - self.log.info("Successfully relogined custom puppet, continuing sync") - filter_id = await self._create_sync_filter() - access_token_at_start = self.access_token - else: - self.log.warning("Something went wrong during relogin") - raise - except (MatrixError, ClientConnectionError, asyncio.TimeoutError) as e: - errors += 1 - wait = min(errors, 11) ** 2 - self.log.warning( - f"Syncer for {custom_mxid} errored: {e}. Waiting for {wait} seconds..." - ) - await asyncio.sleep(wait) - self.log.debug(f"Syncer for custom puppet {custom_mxid} stopped.") diff --git a/mautrix/bridge/e2ee.py b/mautrix/bridge/e2ee.py index 633f98de..1525b388 100644 --- a/mautrix/bridge/e2ee.py +++ b/mautrix/bridge/e2ee.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -11,19 +11,12 @@ from mautrix import __optional_imports__ from mautrix.appservice import AppService -from mautrix.client import Client, SyncStore -from mautrix.crypto import ( - CryptoStore, - DeviceIdentity, - OlmMachine, - PgCryptoStore, - RejectKeyShare, - StateStore, - TrustState, -) -from mautrix.errors import EncryptionError, SessionNotFound +from mautrix.client import Client, InternalEventType, SyncStore +from mautrix.crypto import CryptoStore, OlmMachine, PgCryptoStore, RejectKeyShare, StateStore +from mautrix.errors import EncryptionError, MForbidden, MNotFound, SessionNotFound from mautrix.types import ( JSON, + DeviceIdentity, EncryptedEvent, EncryptedMegolmEventContent, EventFilter, @@ -39,26 +32,15 @@ Serializable, StateEvent, StateFilter, + TrustState, ) +from mautrix.util import background_task +from mautrix.util.async_db import Database from mautrix.util.logging import TraceLogger from .. import bridge as br from .crypto_state_store import PgCryptoStateStore -try: - from mautrix.client.state_store.sqlalchemy import UserProfile -except ImportError: - if __optional_imports__: - raise - UserProfile = None - -try: - from mautrix.util.async_db import Database -except ImportError: - if __optional_imports__: - raise - Database = None - class EncryptionManager: loop: asyncio.AbstractEventLoop @@ -70,12 +52,21 @@ class EncryptionManager: crypto_db: Database | None state_store: StateStore + min_send_trust: TrustState + key_sharing_enabled: bool + appservice_mode: bool + periodically_delete_expired_keys: bool + delete_outdated_inbound: bool + msc4190: bool + self_sign: bool + bridge: br.Bridge az: AppService _id_prefix: str _id_suffix: str _share_session_events: dict[RoomID, asyncio.Event] + _key_delete_task: asyncio.Task | None def __init__( self, @@ -84,7 +75,6 @@ def __init__( user_id_prefix: str, user_id_suffix: str, db_url: str, - key_sharing_config: dict[str, bool] = None, ) -> None: self.loop = 
bridge.loop or asyncio.get_event_loop() self.bridge = bridge @@ -93,7 +83,6 @@ def __init__( self._id_prefix = user_id_prefix self._id_suffix = user_id_suffix self._share_session_events = {} - self.key_sharing_config = key_sharing_config or {} pickle_key = "mautrix.bridge.e2ee" self.crypto_db = Database.create( url=db_url, @@ -110,14 +99,48 @@ def __init__( sync_store=self.crypto_store, log=self.log.getChild("client"), default_retry_count=default_http_retry_count, + state_store=self.bridge.state_store, ) self.crypto = OlmMachine(self.client, self.crypto_store, self.state_store) + self.client.add_event_handler(InternalEventType.SYNC_STOPPED, self._exit_on_sync_fail) self.crypto.allow_key_share = self.allow_key_share + verification_levels = bridge.config["bridge.encryption.verification_levels"] + self.min_send_trust = TrustState.parse(verification_levels["send"]) + self.crypto.share_keys_min_trust = TrustState.parse(verification_levels["share"]) + self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"]) + self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"] + self.appservice_mode = bridge.config["bridge.encryption.appservice"] + self.msc4190 = bridge.config["bridge.encryption.msc4190"] + self.self_sign = bridge.config["bridge.encryption.self_sign"] + if self.appservice_mode: + self.az.otk_handler = self.crypto.handle_as_otk_counts + self.az.device_list_handler = self.crypto.handle_as_device_lists + self.az.to_device_handler = self.crypto.handle_as_to_device_event + + self.periodically_delete_expired_keys = False + self.delete_outdated_inbound = False + self._key_delete_task = None + del_cfg = bridge.config["bridge.encryption.delete_keys"] + if del_cfg: + self.crypto.delete_outbound_keys_on_ack = del_cfg["delete_outbound_on_ack"] + self.crypto.dont_store_outbound_keys = del_cfg["dont_store_outbound"] + self.crypto.delete_previous_keys_on_receive = del_cfg["delete_prev_on_new_session"] + 
self.crypto.ratchet_keys_on_decrypt = del_cfg["ratchet_on_decrypt"] + self.crypto.delete_fully_used_keys_on_decrypt = del_cfg["delete_fully_used_on_decrypt"] + self.crypto.delete_keys_on_device_delete = del_cfg["delete_on_device_delete"] + self.periodically_delete_expired_keys = del_cfg["periodically_delete_expired"] + self.delete_outdated_inbound = del_cfg["delete_outdated_inbound"] + self.crypto.disable_device_change_key_rotation = bridge.config[ + "bridge.encryption.rotation.disable_device_change_key_rotation" + ] + + async def _exit_on_sync_fail(self, data) -> None: + if data["error"]: + self.log.critical("Exiting due to crypto sync error") + sys.exit(32) async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInfo) -> bool: - require_verification = self.key_sharing_config.get("require_verification", True) - allow = self.key_sharing_config.get("allow", False) - if not allow: + if not self.key_sharing_enabled: self.log.debug( f"Key sharing not enabled, ignoring key request from " f"{device.user_id}/{device.device_id}" @@ -128,9 +151,9 @@ async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInf f"Rejecting key request from blacklisted device " f"{device.user_id}/{device.device_id}", code=RoomKeyWithheldCode.BLACKLISTED, - reason="You have been blacklisted by this device", + reason="Your device has been blacklisted by the bridge", ) - elif device.trust == TrustState.VERIFIED or not require_verification: + elif await self.crypto.resolve_trust(device) >= self.crypto.share_keys_min_trust: portal = await self.bridge.get_portal(request.room_id) if portal is None: raise RejectKeyShare( @@ -157,7 +180,7 @@ async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInf f"Rejecting key request from unverified device " f"{device.user_id}/{device.device_id}", code=RoomKeyWithheldCode.UNVERIFIED, - reason="You have not been verified by this device", + reason="Your device is not trusted by the bridge", ) def 
_ignore_user(self, user_id: str) -> bool: @@ -212,7 +235,7 @@ async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> M f" waiting {wait_session_timeout} seconds..." ) got_keys = await self.crypto.wait_for_session( - evt.room_id, e.sender_key, e.session_id, timeout=wait_session_timeout + evt.room_id, e.session_id, timeout=wait_session_timeout ) if got_keys: self.log.debug( @@ -226,39 +249,126 @@ async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> M return decrypted async def start(self) -> None: - flows = await self.client.get_login_flows() - flow = flows.get_first_of_type(LoginType.APPSERVICE, LoginType.UNSTABLE_APPSERVICE) - if flow is None: - self.log.critical( - "Encryption enabled in config, but homeserver does not " - "advertise appservice login" - ) - sys.exit(30) - self.log.debug(f"Logging in with bridge bot user (using login type {flow.type.value})") + if not self.msc4190: + flows = await self.client.get_login_flows() + if not flows.supports_type(LoginType.APPSERVICE): + self.log.critical( + "Encryption enabled in config, but homeserver does not support appservice login" + ) + sys.exit(30) + self.log.debug("Logging in with bridge bot user") if self.crypto_db: - await self.crypto_db.start() + try: + await self.crypto_db.start() + except Exception as e: + self.bridge._log_db_error(e) await self.crypto_store.open() device_id = await self.crypto_store.get_device_id() if device_id: self.log.debug(f"Found device ID in database: {device_id}") - # We set the API token to the AS token here to authenticate the appservice login - # It'll get overridden after the login - self.client.api.token = self.az.as_token - await self.client.login( - login_type=flow.type, - device_name=self.device_name, - device_id=device_id, - store_access_token=True, - update_hs_url=False, - ) + + if self.msc4190: + if not device_id: + self.log.debug("Creating bot device with MSC4190") + self.client.api.token = self.az.as_token + await 
self.client.create_device_msc4190( + device_id=device_id, initial_display_name=self.device_name + ) + else: + # We set the API token to the AS token here to authenticate the appservice login + # It'll get overridden after the login + self.client.api.token = self.az.as_token + await self.client.login( + login_type=LoginType.APPSERVICE, + device_name=self.device_name, + device_id=device_id, + store_access_token=True, + update_hs_url=False, + ) + await self.crypto.load() if not device_id: await self.crypto_store.put_device_id(self.client.device_id) self.log.debug(f"Logged in with new device ID {self.client.device_id}") - _ = self.client.start(self._filter) - self.log.info("End-to-bridge encryption support is enabled") + await self.crypto.share_keys() + elif self.crypto.account.shared: + await self._verify_keys_are_on_server() + else: + await self.crypto.share_keys() + if self.self_sign: + trust_state = await self.crypto.resolve_trust(self.crypto.own_identity) + if trust_state < TrustState.CROSS_SIGNED_UNTRUSTED: + recovery_key = await self.crypto.generate_recovery_key() + self.log.info(f"Generated recovery key and signed own device: {recovery_key}") + else: + self.log.debug(f"Own device is already verified ({trust_state})") + if self.appservice_mode: + self.log.info("End-to-bridge encryption support is enabled (appservice mode)") + else: + _ = self.client.start(self._filter) + self.log.info("End-to-bridge encryption support is enabled (sync mode)") + if self.delete_outdated_inbound: + deleted = await self.crypto_store.redact_outdated_group_sessions() + if len(deleted) > 0: + self.log.debug( + f"Deleted {len(deleted)} inbound keys which lacked expiration metadata" + ) + if self.periodically_delete_expired_keys: + self._key_delete_task = background_task.create(self._periodically_delete_keys()) + background_task.create(self._resync_encryption_info()) + + async def _resync_encryption_info(self) -> None: + rows = await self.crypto_db.fetch( + """SELECT room_id FROM 
mx_room_state WHERE encryption='{"resync":true}'""" + ) + room_ids = [row["room_id"] for row in rows] + if not room_ids: + return + self.log.debug(f"Resyncing encryption state event in rooms: {room_ids}") + for room_id in room_ids: + try: + evt = await self.client.get_state_event(room_id, EventType.ROOM_ENCRYPTION) + except (MNotFound, MForbidden) as e: + self.log.debug(f"Failed to get encryption state in {room_id}: {e}") + q = """ + UPDATE mx_room_state SET encryption=NULL + WHERE room_id=$1 AND encryption='{"resync":true}' + """ + await self.crypto_db.execute(q, room_id) + else: + self.log.debug(f"Resynced encryption state in {room_id}: {evt}") + q = """ + UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2 + WHERE room_id=$3 AND max_age IS NULL and max_messages IS NULL + """ + await self.crypto_db.execute( + q, evt.rotation_period_ms, evt.rotation_period_msgs, room_id + ) + + async def _verify_keys_are_on_server(self) -> None: + self.log.debug("Making sure keys are still on server") + try: + resp = await self.client.query_keys([self.client.mxid]) + except Exception: + self.log.critical( + "Failed to query own keys to make sure device still exists", exc_info=True + ) + sys.exit(33) + try: + own_keys = resp.device_keys[self.client.mxid][self.client.device_id] + if len(own_keys.keys) > 0: + return + except KeyError: + pass + self.log.critical("Existing device doesn't have keys on server, resetting crypto") + await self.crypto.crypto_store.delete() + await self.client.logout_all() + sys.exit(34) async def stop(self) -> None: + if self._key_delete_task: + self._key_delete_task.cancel() + self._key_delete_task = None self.client.stop() await self.crypto_store.close() if self.crypto_db: @@ -278,3 +388,12 @@ def _filter(self) -> Filter: ephemeral=RoomEventFilter(not_types=[all_events]), ), ) + + async def _periodically_delete_keys(self) -> None: + while True: + deleted = await self.crypto_store.redact_expired_group_sessions() + if deleted: + 
self.log.info(f"Deleted expired megolm sessions: {deleted}") + else: + self.log.debug("No expired megolm sessions found") + await asyncio.sleep(24 * 60 * 60) diff --git a/mautrix/bridge/matrix.py b/mautrix/bridge/matrix.py index 8a9c2228..e5399094 100644 --- a/mautrix/bridge/matrix.py +++ b/mautrix/bridge/matrix.py @@ -1,10 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +from collections import defaultdict import asyncio import logging import sys @@ -23,6 +24,7 @@ ) from mautrix.types import ( BaseRoomEvent, + BeeperMessageStatusEventContent, EncryptedEvent, Event, EventID, @@ -32,21 +34,30 @@ MemberStateEventContent, MessageEvent, MessageEventContent, + MessageStatus, + MessageStatusReason, MessageType, PresenceEvent, ReactionEvent, ReceiptEvent, ReceiptType, RedactionEvent, + RelatesTo, + RelationType, RoomID, + RoomType, SingleReceiptEventContent, + SpecVersions, StateEvent, StateUnsigned, TextMessageEventContent, + TrustState, TypingEvent, UserID, + Version, + VersionsResponse, ) -from mautrix.util import markdown +from mautrix.util import background_task, markdown from mautrix.util.logging import TraceLogger from mautrix.util.message_send_checkpoint import ( CHECKPOINT_TYPES, @@ -84,6 +95,44 @@ ) +class UnencryptedMessageError(DecryptionError): + def __init__(self) -> None: + super().__init__("unencrypted message") + + @property + def human_message(self) -> str: + return "the message is not encrypted" + + +class EncryptionUnsupportedError(DecryptionError): + def __init__(self) -> None: + super().__init__("encryption is not supported") + + @property + def human_message(self) -> str: + return "the bridge is not configured to support encryption" + + +class DeviceUntrustedError(DecryptionError): + def 
__init__(self, trust: TrustState) -> None: + explanation = { + TrustState.BLACKLISTED: "device is blacklisted", + TrustState.UNVERIFIED: "unverified", + TrustState.UNKNOWN_DEVICE: "device info not found", + TrustState.FORWARDED: "keys were forwarded from an unknown device", + TrustState.CROSS_SIGNED_UNTRUSTED: ( + "cross-signing keys changed after setting up the bridge" + ), + }.get(trust) + base = "your device is not trusted" + self.message = f"{base} ({explanation})" if explanation else base + super().__init__(self.message) + + @property + def human_message(self) -> str: + return self.message + + class BaseMatrixHandler: log: TraceLogger = logging.getLogger("mau.mx") az: AppService @@ -91,7 +140,11 @@ class BaseMatrixHandler: config: config.BaseBridgeConfig bridge: br.Bridge e2ee: EncryptionManager | None + require_e2ee: bool media_config: MediaRepoConfig + versions: VersionsResponse + minimum_spec_version: Version = SpecVersions.V11 + room_locks: dict[str, asyncio.Lock] user_id_prefix: str user_id_suffix: str @@ -106,9 +159,12 @@ def __init__( self.bridge = bridge self.commands = command_processor or cmd.CommandProcessor(bridge=bridge) self.media_config = MediaRepoConfig(upload_size=50 * 1024 * 1024) + self.versions = VersionsResponse.deserialize({"versions": ["v1.3"]}) self.az.matrix_event_handler(self.int_handle_event) + self.room_locks = defaultdict(asyncio.Lock) self.e2ee = None + self.require_e2ee = False if self.config["bridge.encryption.allow"]: if not EncryptionManager: self.log.fatal( @@ -129,8 +185,8 @@ def __init__( user_id_suffix=self.user_id_suffix, homeserver_address=self.config["homeserver.address"], db_url=self.config["appservice.database"], - key_sharing_config=self.config["bridge.encryption.key_sharing"], ) + self.require_e2ee = self.config["bridge.encryption.require"] self.management_room_text = self.config.get( "bridge.management_room_text", @@ -145,34 +201,66 @@ def __init__( False, ) + async def check_versions(self) -> None: + if not 
self.versions.supports_at_least(self.minimum_spec_version): + self.log.fatal( + "The homeserver is outdated " + "(server supports Matrix %s, but the bridge requires at least %s)", + self.versions.latest_version, + self.minimum_spec_version, + ) + sys.exit(18) + if self.bridge.homeserver_software.is_hungry and not self.versions.supports( + "com.beeper.hungry" + ): + self.log.fatal( + "The config claims the homeserver is hungryserv, " + "but the /versions response didn't confirm it" + ) + sys.exit(18) + async def wait_for_connection(self) -> None: self.log.info("Ensuring connectivity to homeserver") - errors = 0 - tried_to_register = False while True: try: - await self.az.intent.whoami() + self.versions = await self.az.intent.versions() break - except (MUnknownToken, MExclusive): - # These are probably not going to resolve themselves by waiting - raise except MForbidden: - if not tried_to_register: - self.log.debug( - "Whoami endpoint returned M_FORBIDDEN, " - "trying to register bridge bot before retrying..." - ) - await self.az.intent.ensure_registered() - tried_to_register = True - else: - raise + self.log.debug( + "/versions endpoint returned M_FORBIDDEN, " + "trying to register bridge bot before retrying..." + ) + await self.az.intent.ensure_registered() except Exception: - errors += 1 - if errors <= 6: - self.log.exception("Connection to homeserver failed, retrying in 10 seconds") - await asyncio.sleep(10) - else: - raise + self.log.exception("Connection to homeserver failed, retrying in 10 seconds") + await asyncio.sleep(10) + await self.check_versions() + try: + await self.az.intent.whoami() + except MForbidden: + self.log.debug( + "Whoami endpoint returned M_FORBIDDEN, " + "trying to register bridge bot before retrying..." 
+ ) + await self.az.intent.ensure_registered() + await self.az.intent.whoami() + if self.versions.supports("fi.mau.msc2659.stable") or self.versions.supports_at_least( + SpecVersions.V17 + ): + try: + txn_id = self.az.intent.api.get_txn_id() + duration = await self.az.ping_self(txn_id) + self.log.debug( + "Homeserver->bridge connection works, " + f"roundtrip time is {duration} ms (txn ID: {txn_id})" + ) + except Exception: + self.log.exception("Error checking homeserver -> bridge connection") + sys.exit(16) + else: + self.log.debug( + "Homeserver does not support checking status of homeserver -> bridge connection" + ) try: self.media_config = await self.az.intent.get_media_repo_config() except Exception: @@ -196,6 +284,16 @@ async def init_as_bot(self) -> None: except Exception: self.log.exception("Failed to set bot avatar") + if self.bridge.homeserver_software.is_hungry and self.bridge.beeper_network_name: + self.log.debug("Setting contact info on the appservice bot") + await self.az.intent.beeper_update_profile( + { + "com.beeper.bridge.service": self.bridge.beeper_service_name, + "com.beeper.bridge.network": self.bridge.beeper_network_name, + "com.beeper.bridge.is_bridge_bot": True, + } + ) + async def init_encryption(self) -> None: if self.e2ee: await self.e2ee.start() @@ -213,6 +311,10 @@ async def allow_command(user: br.BaseUser) -> bool: async def allow_bridging_message(user: br.BaseUser, portal: br.BasePortal) -> bool: return await user.is_logged_in() or (user.relay_whitelisted and portal.has_relay) + @staticmethod + async def allow_puppet_invite(user: br.BaseUser, puppet: br.BasePuppet) -> bool: + return await user.is_logged_in() + async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None: pass @@ -234,6 +336,26 @@ async def handle_unban( async def handle_join(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None: pass + async def handle_knock( + self, room_id: RoomID, user_id: UserID, reason: str, event_id: 
EventID + ) -> None: + pass + + async def handle_retract_knock( + self, room_id: RoomID, user_id: UserID, reason: str, event_id: EventID + ) -> None: + pass + + async def handle_reject_knock( + self, room_id: RoomID, user_id: UserID, sender: UserID, reason: str, event_id: EventID + ) -> None: + pass + + async def handle_accept_knock( + self, room_id: RoomID, user_id: UserID, sender: UserID, reason: str, event_id: EventID + ) -> None: + pass + async def handle_member_info_change( self, room_id: RoomID, @@ -244,13 +366,82 @@ async def handle_member_info_change( ) -> None: pass + async def handle_puppet_group_invite( + self, + room_id: RoomID, + puppet: br.BasePuppet, + invited_by: br.BaseUser, + evt: StateEvent, + members: list[UserID], + ) -> None: + if self.az.bot_mxid not in members: + await puppet.default_mxid_intent.leave_room( + room_id, reason="This ghost does not join multi-user rooms without the bridge bot." + ) + + async def handle_puppet_dm_invite( + self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent + ) -> None: + portal = await invited_by.get_portal_with(puppet) + if portal: + await portal.accept_matrix_dm(room_id, invited_by, puppet) + else: + await puppet.default_mxid_intent.leave_room( + room_id, reason="This bridge does not support creating DMs." + ) + + async def handle_puppet_space_invite( + self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent + ) -> None: + await puppet.default_mxid_intent.leave_room( + room_id, reason="This ghost does not join spaces." 
+ ) + + async def handle_puppet_nonportal_invite( + self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent + ) -> None: + intent = puppet.default_mxid_intent + await intent.join_room(room_id) + try: + create_evt = await intent.get_state_event(room_id, EventType.ROOM_CREATE) + members = await intent.get_room_members(room_id) + except MatrixError: + self.log.exception(f"Failed to get state after joining {room_id} as {intent.mxid}") + background_task.create(intent.leave_room(room_id, reason="Internal error")) + return + if create_evt.type == RoomType.SPACE: + await self.handle_puppet_space_invite(room_id, puppet, invited_by, evt) + elif len(members) > 2 or not evt.content.is_direct: + await self.handle_puppet_group_invite(room_id, puppet, invited_by, evt, members) + else: + await self.handle_puppet_dm_invite(room_id, puppet, invited_by, evt) + async def handle_puppet_invite( - self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, event_id: EventID + self, room_id: RoomID, puppet: br.BasePuppet, invited_by: br.BaseUser, evt: StateEvent ) -> None: - pass + intent = puppet.default_mxid_intent + if not await self.allow_puppet_invite(invited_by, puppet): + self.log.debug(f"Rejecting invite for {intent.mxid} to {room_id}: user can't invite") + await intent.leave_room(room_id, reason="You're not allowed to invite this ghost.") + return + + async with self.room_locks[room_id]: + portal = await self.bridge.get_portal(room_id) + if portal: + try: + await portal.handle_matrix_invite(invited_by, puppet) + except br.RejectMatrixInvite as e: + await intent.leave_room(room_id, reason=e.message) + except br.IgnoreMatrixInvite: + pass + else: + await intent.join_room(room_id) + return + else: + await self.handle_puppet_nonportal_invite(room_id, puppet, invited_by, evt) async def handle_invite( - self, room_id: RoomID, user_id: UserID, inviter: br.BaseUser, event_id: EventID + self, room_id: RoomID, user_id: UserID, invited_by: 
br.BaseUser, evt: StateEvent ) -> None: pass @@ -307,23 +498,11 @@ async def send_permission_error(self, room_id: RoomID) -> None: ) async def accept_bot_invite(self, room_id: RoomID, inviter: br.BaseUser) -> None: - tries = 0 - while tries < 5: - try: - await self.az.intent.join_room(room_id) - break - except (IntentError, MatrixError): - tries += 1 - wait_for_seconds = (tries + 1) * 10 - if tries < 5: - self.log.exception( - f"Failed to join room {room_id} with bridge bot, " - f"retrying in {wait_for_seconds} seconds..." - ) - await asyncio.sleep(wait_for_seconds) - else: - self.log.exception(f"Failed to join room {room_id}, giving up.") - return + try: + await self.az.intent.join_room(room_id) + except Exception: + self.log.exception(f"Failed to join room {room_id} as bridge bot") + return if not await self.allow_command(inviter): await self.send_permission_error(room_id) @@ -359,26 +538,22 @@ async def send_welcome_message(self, room_id: RoomID, inviter: br.BaseUser) -> N combined_html = "".join(map(markdown.render, welcome_messages)) await self.az.intent.send_notice(room_id, text=combined, html=combined_html) - async def int_handle_invite( - self, room_id: RoomID, user_id: UserID, invited_by: UserID, event_id: EventID - ) -> None: - self.log.debug(f"{invited_by} invited {user_id} to {room_id}") - inviter = await self.bridge.get_user(invited_by) + async def int_handle_invite(self, evt: StateEvent) -> None: + self.log.debug(f"{evt.sender} invited {evt.state_key} to {evt.room_id}") + inviter = await self.bridge.get_user(evt.sender) if inviter is None: - self.log.exception(f"Failed to find user with Matrix ID {invited_by}") - return - elif user_id == self.az.bot_mxid: - await self.accept_bot_invite(room_id, inviter) + self.log.exception(f"Failed to find user with Matrix ID {evt.sender}") return - elif not await self.allow_command(inviter): + elif evt.state_key == self.az.bot_mxid: + await self.accept_bot_invite(evt.room_id, inviter) return - puppet = await 
self.bridge.get_puppet(user_id) + puppet = await self.bridge.get_puppet(UserID(evt.state_key)) if puppet: - await self.handle_puppet_invite(room_id, puppet, inviter, event_id) + await self.handle_puppet_invite(evt.room_id, puppet, inviter, evt) return - await self.handle_invite(room_id, user_id, inviter, event_id) + await self.handle_invite(evt.room_id, UserID(evt.state_key), inviter, evt) def is_command(self, message: MessageEventContent) -> tuple[bool, str]: text = message.body @@ -388,33 +563,95 @@ def is_command(self, message: MessageEventContent) -> tuple[bool, str]: text = text[len(prefix) + 1 :].lstrip() return is_command, text - async def handle_message( - self, room_id: RoomID, user_id: UserID, message: MessageEventContent, event_id: EventID + async def _send_mss( + self, + evt: Event, + status: MessageStatus, + reason: MessageStatusReason | None = None, + error: str | None = None, + message: str | None = None, ) -> None: - async def bail(error_text: str, step=MessageSendCheckpointStep.REMOTE) -> None: - self.log.debug(error_text) - await MessageSendCheckpoint( - event_id=event_id, - room_id=room_id, - step=step, - timestamp=int(time.time() * 1000), - status=MessageSendCheckpointStatus.PERM_FAILURE, - reported_by=MessageSendCheckpointReportedBy.BRIDGE, - event_type=EventType.ROOM_MESSAGE, - message_type=message.msgtype, - info=error_text, - ).send( - self.bridge.config["homeserver.message_send_checkpoint_endpoint"], - self.az.as_token, - self.log, - ) + if not self.config.get("bridge.message_status_events", False): + return + status_content = BeeperMessageStatusEventContent( + network="", # TODO set network properly + relates_to=RelatesTo(rel_type=RelationType.REFERENCE, event_id=evt.event_id), + status=status, + reason=reason, + error=error, + message=message, + ) + await self.az.intent.send_message_event( + evt.room_id, EventType.BEEPER_MESSAGE_STATUS, status_content + ) + + async def _send_crypto_status_error( + self, + evt: Event, + err: 
DecryptionError | None = None, + retry_num: int = 0, + is_final: bool = True, + edit: EventID | None = None, + wait_for: int | None = None, + ) -> EventID | None: + msg = str(err) + if isinstance(err, (SessionNotFound, UnencryptedMessageError)): + msg = err.human_message + self._send_message_checkpoint( + evt, MessageSendCheckpointStep.DECRYPTED, msg, permanent=is_final, retry_num=retry_num + ) + + if wait_for: + msg += f". The bridge will retry for {wait_for} seconds" + full_msg = f"\u26a0 Your message was not bridged: {msg}." + if isinstance(err, EncryptionUnsupportedError): + full_msg = "🔒️ This bridge has not been configured to support encryption" + event_id = None + if self.config.get("bridge.delivery_error_reports", True): + try: + content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=full_msg) + if edit: + content.set_edit(edit) + event_id = await self.az.intent.send_message(evt.room_id, content) + except IntentError: + self.log.debug("IntentError while sending encryption error", exc_info=True) + self.log.error( + "Got IntentError while trying to send encryption error message. " + "This likely means the bridge bot is not in the room, which can " + "happen if you force-enable e2ee on the homeserver without enabling " + "it by default on the bridge (bridge -> encryption -> default)." 
+ ) + + await self._send_mss( + evt, + status=MessageStatus.RETRIABLE if is_final else MessageStatus.PENDING, + reason=MessageStatusReason.UNDECRYPTABLE, + error=str(err), + message=err.human_message if err else None, + ) + + return event_id + + async def handle_message(self, evt: MessageEvent, was_encrypted: bool = False) -> None: + room_id = evt.room_id + user_id = evt.sender + event_id = evt.event_id + message = evt.content + + if not was_encrypted and self.require_e2ee: + self.log.warning(f"Dropping {event_id} from {user_id} as it's not encrypted!") + await self._send_crypto_status_error(evt, UnencryptedMessageError(), 0) + return sender = await self.bridge.get_user(user_id) if not sender or not await self.allow_message(sender): - await bail( + self.log.debug( f"Ignoring message {event_id} from {user_id} to {room_id}:" " user is not whitelisted." ) + self._send_message_checkpoint( + evt, MessageSendCheckpointStep.BRIDGE, "user is not whitelisted" + ) return self.log.debug(f"Received Matrix event {event_id} from {sender.mxid} in {room_id}") self.log.trace("Event %s content: %s", event_id, message) @@ -428,19 +665,31 @@ async def bail(error_text: str, step=MessageSendCheckpointStep.REMOTE) -> None: if await self.allow_bridging_message(sender, portal): await portal.handle_matrix_message(sender, message, event_id) else: - await bail( + self.log.debug( f"Ignoring event {event_id} from {sender.mxid}:" " not allowed to send to portal" ) + self._send_message_checkpoint( + evt, + MessageSendCheckpointStep.BRIDGE, + "user is not allowed to send to the portal", + ) return if message.msgtype != MessageType.TEXT: - await bail(f"Ignoring event {event_id}: not a portal room and not a m.text message") + self.log.debug( + f"Ignoring event {event_id}: not a portal room and not a m.text message" + ) + self._send_message_checkpoint( + evt, MessageSendCheckpointStep.BRIDGE, "not a portal room and not a m.text message" + ) return elif not await self.allow_command(sender): - await 
bail( - f"Ignoring command {event_id} from {sender.mxid}: not allowed to perform command", - step=MessageSendCheckpointStep.COMMAND, + self.log.debug( + f"Ignoring command {event_id} from {sender.mxid}: not allowed to run commands" + ) + self._send_message_checkpoint( + evt, MessageSendCheckpointStep.COMMAND, "not allowed to run commands" ) return @@ -469,7 +718,15 @@ async def bail(error_text: str, step=MessageSendCheckpointStep.REMOTE) -> None: bridge_bot_in_room, ) except Exception as e: - await bail(repr(e), step=MessageSendCheckpointStep.COMMAND) + self.log.debug(f"Error handling command {command} from {sender}: {e}") + self._send_message_checkpoint(evt, MessageSendCheckpointStep.COMMAND, e) + await self._send_mss( + evt, + status=MessageStatus.FAIL, + reason=MessageStatusReason.GENERIC_ERROR, + error="", + message="Command execution failed", + ) else: await MessageSendCheckpoint( event_id=event_id, @@ -485,11 +742,22 @@ async def bail(error_text: str, step=MessageSendCheckpointStep.REMOTE) -> None: self.az.as_token, self.log, ) + await self._send_mss(evt, status=MessageStatus.SUCCESS) else: - await bail( + self.log.debug( f"Ignoring event {event_id} from {sender.mxid}:" " not a command and not a portal room" ) + self._send_message_checkpoint( + evt, MessageSendCheckpointStep.COMMAND, "not a command and not a portal room" + ) + await self._send_mss( + evt, + status=MessageStatus.FAIL, + reason=MessageStatusReason.UNSUPPORTED, + error="Unknown room", + message="Unknown room", + ) async def _is_direct_chat(self, room_id: RoomID) -> tuple[bool, bool]: try: @@ -500,7 +768,7 @@ async def _is_direct_chat(self, room_id: RoomID) -> tuple[bool, bool]: async def handle_receipt(self, evt: ReceiptEvent) -> None: for event_id, receipts in evt.content.items(): - for user_id, data in receipts[ReceiptType.READ].items(): + for user_id, data in receipts.get(ReceiptType.READ, {}).items(): user = await self.bridge.get_user(user_id, create=False) if not user or not await 
user.is_logged_in(): continue @@ -537,41 +805,49 @@ async def try_handle_sync_event(self, evt: Event) -> None: except Exception: self.log.exception("Error handling manually received Matrix event") - async def send_encryption_error_notice( - self, evt: EncryptedEvent, error: DecryptionError + async def _post_decrypt( + self, evt: Event, retry_num: int = 0, error_event_id: EventID | None = None ) -> None: - await self.az.intent.send_notice( - evt.room_id, f"\u26a0 Your message was not bridged: {error}" + trust_state = evt["mautrix"]["trust_state"] + if trust_state < self.e2ee.min_send_trust: + self.log.warning( + f"Dropping {evt.event_id} from {evt.sender} due to insufficient verification level" + f" (event: {trust_state}, required: {self.e2ee.min_send_trust})" + ) + await self._send_crypto_status_error( + evt, + retry_num=retry_num, + err=DeviceUntrustedError(trust_state), + edit=error_event_id, + ) + return + + self._send_message_checkpoint( + evt, MessageSendCheckpointStep.DECRYPTED, retry_num=retry_num ) + if error_event_id: + await self.az.intent.redact(evt.room_id, error_event_id) + await self.int_handle_event(evt, was_encrypted=True) async def handle_encrypted(self, evt: EncryptedEvent) -> None: if not self.e2ee: - self.send_decrypted_checkpoint(evt, "Encryption unsupported", True) - await self.handle_encrypted_unsupported(evt) + self.log.debug( + "Got encrypted message %s from %s, but encryption is not enabled", + evt.event_id, + evt.sender, + ) + await self._send_crypto_status_error(evt, EncryptionUnsupportedError()) return try: - decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=5) + decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=3) except SessionNotFound as e: - self.send_decrypted_checkpoint(evt, e, False) - await self._handle_encrypted_wait(evt, e, wait=15) + await self._handle_encrypted_wait(evt, e, wait=22) except DecryptionError as e: self.log.warning(f"Failed to decrypt {evt.event_id}: {e}") self.log.trace("%s decryption 
traceback:", evt.event_id, exc_info=True) - self.send_decrypted_checkpoint(evt, e, True) - await self.send_encryption_error_notice(evt, e) + await self._send_crypto_status_error(evt, e) else: - self.send_decrypted_checkpoint(decrypted) - await self.int_handle_event(decrypted, send_bridge_checkpoint=False) - - async def handle_encrypted_unsupported(self, evt: EncryptedEvent) -> None: - self.log.debug( - "Got encrypted message %s from %s, but encryption is not enabled", - evt.event_id, - evt.sender, - ) - await self.az.intent.send_notice( - evt.room_id, "🔒️ This bridge has not been configured to support encryption" - ) + await self._post_decrypt(decrypted) async def _handle_encrypted_wait( self, evt: EncryptedEvent, err: SessionNotFound, wait: int @@ -580,12 +856,7 @@ async def _handle_encrypted_wait( f"Couldn't find session {err.session_id} trying to decrypt {evt.event_id}," " waiting even longer" ) - msg = ( - "\u26a0 Your message was not bridged: the bridge hasn't received the decryption " - f"keys. The bridge will retry for {wait} seconds. If this error keeps happening, " - "try restarting your client." - ) - asyncio.create_task( + background_task.create( self.e2ee.crypto.request_room_key( evt.room_id, evt.content.sender_key, @@ -593,19 +864,9 @@ async def _handle_encrypted_wait( from_devices={evt.sender: [evt.content.device_id]}, ) ) - try: - event_id = await self.az.intent.send_notice(evt.room_id, msg) - except IntentError: - self.log.debug("IntentError while sending encryption error", exc_info=True) - self.log.error( - "Got IntentError while trying to send encryption error message. " - "This likely means the bridge bot is not in the room, which can " - "happen if you force-enable e2ee on the homeserver without enabling " - "it by default on the bridge (bridge -> encryption -> default)." 
- ) - return + event_id = await self._send_crypto_status_error(evt, err, is_final=False, wait_for=wait) got_keys = await self.e2ee.crypto.wait_for_session( - evt.room_id, err.sender_key, err.session_id, timeout=wait + evt.room_id, err.session_id, timeout=wait ) if got_keys: self.log.debug( @@ -615,26 +876,17 @@ async def _handle_encrypted_wait( try: decrypted = await self.e2ee.decrypt(evt, wait_session_timeout=0) except DecryptionError as e: - self.send_decrypted_checkpoint(evt, e, True, retry_num=1) + await self._send_crypto_status_error(evt, e, retry_num=1, edit=event_id) self.log.warning(f"Failed to decrypt {evt.event_id}: {e}") self.log.trace("%s decryption traceback:", evt.event_id, exc_info=True) - msg = f"\u26a0 Your message was not bridged: {e}" else: - self.send_decrypted_checkpoint(decrypted, retry_num=1) - await self.az.intent.redact(evt.room_id, event_id) - await self.int_handle_event(decrypted, send_bridge_checkpoint=False) + await self._post_decrypt(decrypted, retry_num=1, error_event_id=event_id) return else: - error_message = f"Didn't get {err.session_id}, giving up on {evt.event_id}" - self.log.warning(error_message) - self.send_decrypted_checkpoint(evt, error_message, True, retry_num=1) - msg = ( - "\u26a0 Your message was not bridged: the bridge hasn't received the decryption" - " keys. If this error keeps happening, try restarting your client." 
+ self.log.warning(f"Didn't get {err.session_id}, giving up on {evt.event_id}") + await self._send_crypto_status_error( + evt, SessionNotFound(err.session_id), retry_num=1, edit=event_id ) - content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=msg) - content.set_edit(event_id) - await self.az.intent.send_message(evt.room_id, content) async def handle_encryption(self, evt: StateEvent) -> None: await self.az.state_store.set_encryption_info(evt.room_id, evt.content) @@ -646,12 +898,12 @@ async def handle_encryption(self, evt: StateEvent) -> None: portal.log.debug("Received encryption event in direct portal: %s", evt.content) await portal.enable_dm_encryption() - def send_message_send_checkpoint( + def _send_message_checkpoint( self, evt: Event, step: MessageSendCheckpointStep, err: Exception | str | None = None, - permanent: bool = False, + permanent: bool = True, retry_num: int = 0, ) -> None: endpoint = self.bridge.config["homeserver.message_send_checkpoint_endpoint"] @@ -681,21 +933,7 @@ def send_message_send_checkpoint( info=str(err) if err else None, retry_num=retry_num, ) - asyncio.create_task(checkpoint.send(endpoint, self.az.as_token, self.log)) - - def send_bridge_checkpoint(self, evt: Event) -> None: - self.send_message_send_checkpoint(evt, MessageSendCheckpointStep.BRIDGE) - - def send_decrypted_checkpoint( - self, - evt: Event, - err: Exception | str | None = None, - permanent: bool = False, - retry_num: int = 0, - ) -> None: - self.send_message_send_checkpoint( - evt, MessageSendCheckpointStep.DECRYPTED, err, permanent, retry_num - ) + background_task.create(checkpoint.send(endpoint, self.az.as_token, self.log)) allowed_event_classes: tuple[type, ...] = ( MessageEvent, @@ -727,15 +965,15 @@ async def allow_matrix_event(self, evt: Event) -> bool: # For non-room events and non-bridge-originated room events, allow. 
return True - async def int_handle_event(self, evt: Event, send_bridge_checkpoint: bool = True) -> None: + async def int_handle_event(self, evt: Event, was_encrypted: bool = False) -> None: if isinstance(evt, StateEvent) and evt.type == EventType.ROOM_MEMBER and self.e2ee: await self.e2ee.handle_member_event(evt) if not await self.allow_matrix_event(evt): return self.log.trace("Received event: %s", evt) - if send_bridge_checkpoint: - self.send_bridge_checkpoint(evt) + if not was_encrypted: + self._send_message_checkpoint(evt, MessageSendCheckpointStep.BRIDGE) start_time = time.time() if evt.type == EventType.ROOM_MEMBER: @@ -744,9 +982,16 @@ async def int_handle_event(self, evt: Event, send_bridge_checkpoint: bool = True prev_content = unsigned.prev_content or MemberStateEventContent() prev_membership = prev_content.membership if prev_content else Membership.JOIN if evt.content.membership == Membership.INVITE: - await self.int_handle_invite( - evt.room_id, UserID(evt.state_key), evt.sender, evt.event_id - ) + if prev_membership == Membership.KNOCK: + await self.handle_accept_knock( + evt.room_id, + UserID(evt.state_key), + evt.sender, + evt.content.reason, + evt.event_id, + ) + else: + await self.int_handle_invite(evt) elif evt.content.membership == Membership.LEAVE: if prev_membership == Membership.BAN: await self.handle_unban( @@ -769,6 +1014,20 @@ async def int_handle_event(self, evt: Event, send_bridge_checkpoint: bool = True evt.content.reason, evt.event_id, ) + elif prev_membership == Membership.KNOCK: + if evt.sender == evt.state_key: + await self.handle_retract_knock( + evt.room_id, UserID(evt.state_key), evt.content.reason, evt.event_id + ) + else: + await self.handle_reject_knock( + evt.room_id, + UserID(evt.state_key), + evt.sender, + evt.content.reason, + evt.event_id, + ) + elif evt.sender == evt.state_key: await self.handle_leave(evt.room_id, UserID(evt.state_key), evt.event_id) else: @@ -794,11 +1053,18 @@ async def int_handle_event(self, evt: Event, 
send_bridge_checkpoint: bool = True await self.handle_member_info_change( evt.room_id, UserID(evt.state_key), evt.content, prev_content, evt.event_id ) + elif evt.content.membership == Membership.KNOCK: + await self.handle_knock( + evt.room_id, + UserID(evt.state_key), + evt.content.reason, + evt.event_id, + ) elif evt.type in (EventType.ROOM_MESSAGE, EventType.STICKER): evt: MessageEvent if evt.type != EventType.ROOM_MESSAGE: evt.content.msgtype = MessageType(str(evt.type)) - await self.handle_message(evt.room_id, evt.sender, evt.content, evt.event_id) + await self.handle_message(evt, was_encrypted=was_encrypted) elif evt.type == EventType.ROOM_ENCRYPTED: await self.handle_encrypted(evt) elif evt.type == EventType.ROOM_ENCRYPTION: diff --git a/mautrix/bridge/notification_disabler.py b/mautrix/bridge/notification_disabler.py index 51f52ddf..bd53a6ea 100644 --- a/mautrix/bridge/notification_disabler.py +++ b/mautrix/bridge/notification_disabler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -34,7 +34,7 @@ def __init__(self, room_id: RoomID, user: BaseUser) -> None: @property def _path(self) -> PathBuilder: - return Path.pushrules["global"].override[ + return Path.v3.pushrules["global"].override[ f"net.maunium.silence_while_backfilling:{self.room_id}" ] diff --git a/mautrix/bridge/portal.py b/mautrix/bridge/portal.py index 3a1e0f8b..05d67e3b 100644 --- a/mautrix/bridge/portal.py +++ b/mautrix/bridge/portal.py @@ -1,11 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import Any, ClassVar, NamedTuple +from typing import Any, NamedTuple from abc import ABC, abstractmethod from collections import defaultdict from string import Template @@ -14,9 +14,10 @@ import logging import time -from mautrix.appservice import DOUBLE_PUPPET_SOURCE_KEY, AppService, IntentAPI -from mautrix.errors import MatrixError, MatrixRequestError, MNotFound +from mautrix.appservice import AppService, IntentAPI +from mautrix.errors import MatrixError, MatrixRequestError, MForbidden, MNotFound from mautrix.types import ( + JSON, EncryptionAlgorithm, EventID, EventType, @@ -25,8 +26,11 @@ MessageType, RoomEncryptionStateEventContent, RoomID, + RoomTombstoneStateEventContent, + TextMessageEventContent, UserID, ) +from mautrix.util import background_task from mautrix.util.logging import TraceLogger from mautrix.util.simple_lock import SimpleLock @@ -38,6 +42,24 @@ class RelaySender(NamedTuple): is_relay: bool +class RejectMatrixInvite(Exception): + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + +class IgnoreMatrixInvite(Exception): + pass + + +class DMCreateError(RejectMatrixInvite): + """ + An error raised by :meth:`BasePortal.prepare_remote_dm` if the DM can't be set up. + + The message in the exception will be sent to the user as a message before the ghost leaves. + """ + + class BasePortal(ABC): log: TraceLogger = logging.getLogger("mau.portal") _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) @@ -66,12 +88,151 @@ def __init__(self) -> None: async def save(self) -> None: pass + @abstractmethod + async def get_dm_puppet(self) -> br.BasePuppet | None: + """ + Get the ghost representing the other end of this direct chat. + + Returns: + A puppet entity, or ``None`` if this is not a 1:1 chat.
+ """ + @abstractmethod async def handle_matrix_message( self, sender: br.BaseUser, message: MessageEventContent, event_id: EventID ) -> None: pass + async def prepare_remote_dm( + self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet + ) -> str: + """ + Do whatever is needed on the remote platform to set up a direct chat between the user + and the ghost. By default, this does nothing (and lets :meth:`setup_matrix_dm` handle + everything). + + Args: + room_id: The room ID that will be used. + invited_by: The Matrix user who invited the ghost. + puppet: The ghost who was invited. + + Returns: + A simple message indicating what was done (will be sent as a notice to the room). + If empty, the message won't be sent. + + Raises: + DMCreateError: if the DM could not be created and the ghost should leave the room. + """ + return "Portal to private chat created." + + async def postprocess_matrix_dm(self, user: br.BaseUser, puppet: br.BasePuppet) -> None: + await self.update_bridge_info() + + async def reject_duplicate_dm( + self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet + ) -> None: + try: + await puppet.default_mxid_intent.send_notice( + room_id, + text=f"You already have a private chat with me: {self.mxid}", + html=( + "You already have a private chat with me: " + f"Link to room" + ), + ) + except Exception as e: + self.log.debug(f"Failed to send notice to duplicate private chat room: {e}") + + try: + await puppet.default_mxid_intent.send_state_event( + room_id, + event_type=EventType.ROOM_TOMBSTONE, + content=RoomTombstoneStateEventContent( + replacement_room=self.mxid, + body="You already have a private chat with me", + ), + ) + except Exception as e: + self.log.debug(f"Failed to send tombstone to duplicate private chat room: {e}") + + await puppet.default_mxid_intent.leave_room(room_id) + + async def accept_matrix_dm( + self, room_id: RoomID, invited_by: br.BaseUser, puppet: br.BasePuppet + ) -> None: + """ + Set up a room as a 
direct chat portal. + + The ghost has already accepted the invite at this point, so this method needs to make it + leave if the DM can't be created for some reason. + + By default, this checks if there's an existing portal and redirects the user there if it + does exist. If a portal doesn't exist, this will call :meth:`prepare_remote_dm` and then + save the room ID, enable encryption and update bridge info. If the portal exists, but isn't + usable, the old room will be cleaned up and the function will continue. + + Args: + room_id: The room ID that will be used. + invited_by: The Matrix user who invited the ghost. + puppet: The ghost who was invited. + """ + if self.mxid: + try: + portal_members = await self.main_intent.get_room_members(self.mxid) + except (MForbidden, MNotFound): + portal_members = [] + if invited_by.mxid in portal_members: + await self.reject_duplicate_dm(room_id, invited_by, puppet) + return + self.log.debug( + f"{invited_by.mxid} isn't in old portal room {self.mxid}," + " cleaning up and accepting new room as the DM portal" + ) + await self.cleanup_portal( + message="User seems to have left DM portal", puppets_only=True + ) + try: + message = await self.prepare_remote_dm(room_id, invited_by, puppet) + except DMCreateError as e: + if e.message: + await puppet.default_mxid_intent.send_notice(room_id, text=e.message) + await puppet.default_mxid_intent.leave_room(room_id, reason="Failed to create DM") + return + self.mxid = room_id + e2be_ok = await self.check_dm_encryption() + await self.save() + if e2be_ok is False: + message += "\n\nWarning: Failed to enable end-to-bridge encryption."
+ if message: + await self._send_message( + puppet.default_mxid_intent, + TextMessageEventContent( + msgtype=MessageType.NOTICE, + body=message, + ), + ) + await self.postprocess_matrix_dm(invited_by, puppet) + + async def handle_matrix_invite(self, invited_by: br.BaseUser, puppet: br.BasePuppet) -> None: + """ + Called when a Matrix user invites a bridge ghost to a room to process the invite (and check + if it should be accepted). + + Args: + invited_by: The user who invited the ghost. + puppet: The ghost who was invited. + + Raises: + RejectMatrixInvite: if the invite should be rejected. + IgnoreMatrixInvite: if the invite should be ignored (e.g. if it was already accepted). + """ + if self.is_direct: + raise RejectMatrixInvite("You can't invite additional users to private chats.") + raise RejectMatrixInvite("This bridge does not implement inviting users to portals.") + + async def update_bridge_info(self) -> None: + """Resend the ``m.bridge`` event into the room.""" + @property def _relay_is_implemented(self) -> bool: return hasattr(self, "relay_user_id") and hasattr(self, "_relay_user") @@ -162,6 +323,13 @@ async def check_dm_encryption(self) -> bool | None: return await self.enable_dm_encryption() return None + def get_encryption_state_event_json(self) -> JSON: + evt = RoomEncryptionStateEventContent(EncryptionAlgorithm.MEGOLM_V1) + if self.bridge.config["bridge.encryption.rotation.enable_custom"]: + evt.rotation_period_ms = self.bridge.config["bridge.encryption.rotation.milliseconds"] + evt.rotation_period_msgs = self.bridge.config["bridge.encryption.rotation.messages"] + return evt.serialize() + async def enable_dm_encryption(self) -> bool: self.log.debug("Inviting bridge bot to room for end-to-bridge encryption") try: @@ -171,15 +339,29 @@ async def enable_dm_encryption(self) -> bool: await self.main_intent.send_state_event( self.mxid, EventType.ROOM_ENCRYPTION, - RoomEncryptionStateEventContent(EncryptionAlgorithm.MEGOLM_V1), + 
self.get_encryption_state_event_json(), ) except Exception: self.log.warning(f"Failed to enable end-to-bridge encryption", exc_info=True) return False self.encrypted = True + await self.update_info_from_puppet() return True + async def update_info_from_puppet(self, puppet: br.BasePuppet | None = None) -> None: + """ + Update the room metadata to match the ghost's name/avatar. + + This is called after enabling encryption, as the bridge bot needs to join for e2ee, + but that messes up the default name generation. If/when canonical DMs happen, + this might not be necessary anymore. + + Args: + puppet: The ghost that is the other participant in the room. + If ``None``, the entity should be fetched as necessary. + """ + @property def disappearing_enabled(self) -> bool: return bool(self.disappearing_msg_class) @@ -204,7 +386,7 @@ async def _disappear_event(self, msg: br.AbstractDisappearingMessage) -> None: await self._do_disappear(msg.event_id) self.log.debug(f"Expired event {msg.event_id} disappeared successfully") except Exception as e: - self.log.warning(f"Failed to make expired event {msg.event_id} disappear: {e}", e) + self.log.warning(f"Failed to make expired event {msg.event_id} disappear: {e}") async def _do_disappear(self, event_id: EventID) -> None: await self.main_intent.redact(self.mxid, event_id) @@ -221,7 +403,7 @@ async def restart_scheduled_disappearing(cls) -> None: for msg in msgs: portal = await cls.bridge.get_portal(msg.room_id) if portal and portal.mxid: - asyncio.create_task(portal._disappear_event(msg)) + background_task.create(portal._disappear_event(msg)) else: await msg.delete() @@ -237,7 +419,7 @@ async def schedule_disappearing(self) -> None: for msg in msgs: msg.start_timer() await msg.update() - asyncio.create_task(self._disappear_event(msg)) + background_task.create(self._disappear_event(msg)) async def _send_message( self, @@ -248,7 +430,10 @@ async def _send_message( ) -> EventID: if self.encrypted and self.matrix.e2ee: event_type, 
content = await self.matrix.e2ee.encrypt(self.mxid, event_type, content) - return await intent.send_message_event(self.mxid, event_type, content, **kwargs) + event_id = await intent.send_message_event(self.mxid, event_type, content, **kwargs) + if intent.api.is_real_user: + background_task.create(intent.mark_read(self.mxid, event_id)) + return event_id @property @abstractmethod @@ -274,6 +459,19 @@ async def cleanup_room( message: str = "Cleaning room", puppets_only: bool = False, ) -> None: + if not puppets_only and cls.bridge.homeserver_software.is_hungry: + try: + await intent.beeper_delete_room(room_id) + return + except MNotFound as err: + cls.log.debug(f"Hungryserv yeet returned {err}, assuming the room is already gone") + return + except Exception: + cls.log.warning( + f"Failed to delete {room_id} using hungryserv yeet endpoint, " + f"falling back to normal method", + exc_info=True, + ) try: members = await intent.get_room_members(room_id) except MatrixError: @@ -292,8 +490,7 @@ async def cleanup_room( left = False if custom_puppet: try: - extra_content = {DOUBLE_PUPPET_SOURCE_KEY: cls.bridge.name} - await custom_puppet.intent.leave_room(room_id, extra_content=extra_content) + await custom_puppet.intent.leave_room(room_id) await custom_puppet.intent.forget_room(room_id) except MatrixError: pass diff --git a/mautrix/bridge/puppet.py b/mautrix/bridge/puppet.py index 4f4e63d3..a09f3548 100644 --- a/mautrix/bridge/puppet.py +++ b/mautrix/bridge/puppet.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/bridge/state_store/__init__.py b/mautrix/bridge/state_store/__init__.py index e69de29b..b990ac86 100644 --- a/mautrix/bridge/state_store/__init__.py +++ b/mautrix/bridge/state_store/__init__.py @@ -0,0 +1 @@ +__all__ = ["asyncpg"] diff --git a/mautrix/bridge/state_store/asyncpg.py b/mautrix/bridge/state_store/asyncpg.py index 3456bd2a..d9c476ce 100644 --- a/mautrix/bridge/state_store/asyncpg.py +++ b/mautrix/bridge/state_store/asyncpg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/bridge/state_store/sqlalchemy.py b/mautrix/bridge/state_store/sqlalchemy.py deleted file mode 100644 index 26340021..00000000 --- a/mautrix/bridge/state_store/sqlalchemy.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) 2021 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from __future__ import annotations - -from typing import Awaitable, Callable, Union - -from mautrix.appservice.state_store.sqlalchemy import SQLASStateStore -from mautrix.types import UserID - -from ..puppet import BasePuppet - -GetPuppetFunc = Union[ - Callable[[UserID], Awaitable[BasePuppet]], Callable[[UserID, bool], Awaitable[BasePuppet]] -] - - -class SQLBridgeStateStore(SQLASStateStore): - def __init__(self, get_puppet: GetPuppetFunc, get_double_puppet: GetPuppetFunc) -> None: - super().__init__() - self.get_puppet = get_puppet - self.get_double_puppet = get_double_puppet - - async def is_registered(self, user_id: UserID) -> bool: - puppet = await self.get_puppet(user_id) - if puppet: - return puppet.is_registered - custom_puppet = await self.get_double_puppet(user_id) - if custom_puppet: - return True - return await super().is_registered(user_id) - - async def registered(self, user_id: UserID) -> None: - puppet = await self.get_puppet(user_id, True) - if puppet: - puppet.is_registered = True - await puppet.save() - else: - await super().registered(user_id) diff --git a/mautrix/bridge/user.py b/mautrix/bridge/user.py index 318ee82a..66860c04 100644 --- a/mautrix/bridge/user.py +++ b/mautrix/bridge/user.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -12,10 +12,11 @@ import logging import time -from mautrix.api import Method, UnstableClientPath +from mautrix.api import Method, Path from mautrix.appservice import AppService from mautrix.errors import MNotFound from mautrix.types import EventID, EventType, Membership, MessageType, RoomID, UserID +from mautrix.util import background_task from mautrix.util.bridge_state import BridgeState, BridgeStateEvent from mautrix.util.logging import TraceLogger from mautrix.util.message_send_checkpoint import ( @@ -28,7 +29,7 @@ from .. 
import bridge as br -AsmuxPath = UnstableClientPath["com.beeper.asmux"] +AsmuxPath = Path.unstable["com.beeper.asmux"] class WrappedTask(NamedTuple): @@ -68,9 +69,33 @@ def __init__(self) -> None: async def is_logged_in(self) -> bool: raise NotImplementedError() + @abstractmethod async def get_puppet(self) -> br.BasePuppet | None: + """ + Get the ghost that represents this Matrix user on the remote network. + + Returns: + The puppet entity, or ``None`` if the user is not logged in, + or it's otherwise not possible to find the remote ghost. + """ raise NotImplementedError() + @abstractmethod + async def get_portal_with( + self, puppet: br.BasePuppet, create: bool = True + ) -> br.BasePortal | None: + """ + Get a private chat portal between this user and the given ghost. + + Args: + puppet: The ghost who the portal should be with. + create: ``True`` if the portal entity should be created if it doesn't exist. + + Returns: + The portal entity, or ``None`` if it can't be found, + or doesn't exist and ``create`` is ``False``. 
+ """ + async def needs_relay(self, portal: br.BasePortal) -> bool: return not await self.is_logged_in() @@ -104,7 +129,7 @@ async def update_direct_chats(self, dms: dict[UserID, list[RoomID]] | None = Non self.log.debug("Updating m.direct list on homeserver") replace = dms is None dms = dms or await self.get_direct_chats() - if self.bridge.config.get("homeserver.asmux", False): + if self.bridge.homeserver_software.is_asmux: # This uses a secret endpoint for atomically updating the DM list await puppet.intent.api.request( Method.PUT if replace else Method.PATCH, @@ -154,6 +179,8 @@ async def push_bridge_state( message: str | None = None, ttl: int | None = None, remote_id: str | None = None, + info: dict[str, Any] | None = None, + reason: str | None = None, ) -> None: if not self.bridge.config["homeserver.status_endpoint"]: return @@ -164,6 +191,8 @@ async def push_bridge_state( message=message, ttl=ttl, remote_id=remote_id, + info=info, + reason=reason, ) await self.fill_bridge_state(state) if state.should_deduplicate(self._prev_bridge_status): @@ -216,7 +245,7 @@ def send_remote_checkpoint( """ if not self.bridge.config["homeserver.message_send_checkpoint_endpoint"]: return WrappedTask(task=None) - task = asyncio.create_task( + task = background_task.create( MessageSendCheckpoint( event_id=event_id, room_id=room_id, diff --git a/mautrix/client/__init__.py b/mautrix/client/__init__.py index 2fe700ba..30c65988 100644 --- a/mautrix/client/__init__.py +++ b/mautrix/client/__init__.py @@ -5,3 +5,24 @@ from .state_store import FileStateStore, MemoryStateStore, MemorySyncStore, StateStore, SyncStore from .store_updater import StoreUpdatingAPI from .syncer import EventHandler, InternalEventType, Syncer, SyncStream + +__all__ = [ + "ClientAPI", + "Client", + "Dispatcher", + "MembershipEventDispatcher", + "SimpleDispatcher", + "DecryptionDispatcher", + "EncryptingAPI", + "FileStateStore", + "MemoryStateStore", + "MemorySyncStore", + "StateStore", + "SyncStore", + 
"StoreUpdatingAPI", + "EventHandler", + "InternalEventType", + "Syncer", + "SyncStream", + "state_store", +] diff --git a/mautrix/client/api/__init__.py b/mautrix/client/api/__init__.py index 783eb248..74081a4e 100644 --- a/mautrix/client/api/__init__.py +++ b/mautrix/client/api/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/client/api/authentication.py b/mautrix/client/api/authentication.py index e485a0b9..0f6249ea 100644 --- a/mautrix/client/api/authentication.py +++ b/mautrix/client/api/authentication.py @@ -1,13 +1,16 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +import secrets + from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( + DeviceID, LoginFlowList, LoginResponse, LoginType, @@ -40,7 +43,7 @@ async def get_login_flows(self) -> LoginFlowList: Returns: The list of login flows that the homeserver supports. 
""" - resp = await self.api.request(Method.GET, Path.login) + resp = await self.api.request(Method.GET, Path.v3.login) try: return LoginFlowList.deserialize(resp) except KeyError: @@ -93,12 +96,13 @@ async def login( kwargs["device_id"] = self.device_id resp = await self.api.request( Method.POST, - Path.login, + Path.v3.login, { "type": str(login_type), "identifier": identifier.serialize(), **kwargs, }, + sensitive="password" in kwargs or "token" in kwargs, ) resp_data = LoginResponse.deserialize(resp) if store_access_token: @@ -115,6 +119,19 @@ async def login( self.api.base_url = base_url.rstrip("/") return resp_data + async def create_device_msc4190(self, device_id: str, initial_display_name: str) -> None: + """ + Create a Device for a user of the homeserver using appservice interface defined in MSC4190 + """ + if len(device_id) == 0: + device_id = DeviceID(secrets.token_urlsafe(10)) + self.api.as_user_id = self.mxid + await self.api.request( + Method.PUT, Path.v3.devices[device_id], {"display_name": initial_display_name} + ) + self.api.as_device_id = device_id + self.device_id = device_id + async def logout(self, clear_access_token: bool = True) -> None: """ Invalidates an existing access token, so that it can no longer be used for authorization. @@ -127,10 +144,10 @@ async def logout(self, clear_access_token: bool = True) -> None: Args: clear_access_token: Whether or not mautrix-python should forget the stored access token. """ - await self.api.request(Method.POST, Path.logout) + await self.api.request(Method.POST, Path.v3.logout) if clear_access_token: self.api.token = "" - self.device_id = "" + self.device_id = DeviceID("") async def logout_all(self, clear_access_token: bool = True) -> None: """ @@ -151,10 +168,10 @@ async def logout_all(self, clear_access_token: bool = True) -> None: Args: clear_access_token: Whether or not mautrix-python should forget the stored access token. 
""" - await self.api.request(Method.POST, Path.logout.all) + await self.api.request(Method.POST, Path.v3.logout.all) if clear_access_token: self.api.token = "" - self.device_id = "" + self.device_id = DeviceID("") # endregion @@ -170,7 +187,7 @@ async def whoami(self) -> WhoamiResponse: Returns: The user ID and device ID of the current user. """ - resp = await self.api.request(Method.GET, Path.account.whoami) + resp = await self.api.request(Method.GET, Path.v3.account.whoami) return WhoamiResponse.deserialize(resp) # endregion diff --git a/mautrix/client/api/base.py b/mautrix/client/api/base.py index 743fd879..07659894 100644 --- a/mautrix/client/api/base.py +++ b/mautrix/client/api/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -10,7 +10,7 @@ from aiohttp import ClientError, ClientSession, ContentTypeError from yarl import URL -from mautrix.api import HTTPAPI, Method +from mautrix.api import HTTPAPI, Method, Path from mautrix.errors import ( WellKnownInvalidVersionsResponse, WellKnownMissingHomeserver, @@ -37,6 +37,7 @@ class BaseClientAPI: device_id: DeviceID api: HTTPAPI log: TraceLogger + versions_cache: VersionsResponse | None def __init__( self, mxid: UserID = "", device_id: DeviceID = "", api: HTTPAPI | None = None, **kwargs @@ -62,6 +63,7 @@ def __init__( self.localpart = None self.domain = None self.fill_member_event_callback = None + self.versions_cache = None self.device_id = device_id self.api = api or HTTPAPI(**kwargs) self.log = self.api.log @@ -101,10 +103,21 @@ def mxid(self, mxid: UserID) -> None: self.localpart, self.domain = self.parse_user_id(mxid) self._mxid = mxid - async def versions(self) -> VersionsResponse: - """Get client-server spec versions supported by the server.""" - resp = await self.api.request(Method.GET, "_matrix/client/versions") - return 
VersionsResponse.deserialize(resp) + async def versions(self, no_cache: bool = False) -> VersionsResponse: + """ + Get client-server spec versions supported by the server. + + Args: + no_cache: If true, the versions will always be fetched from the server + rather than using cached results when available. + + Returns: + The supported Matrix spec versions and unstable features. + """ + if no_cache or not self.versions_cache: + resp = await self.api.request(Method.GET, Path.versions) + self.versions_cache = VersionsResponse.deserialize(resp) + return self.versions_cache @classmethod async def discover(cls, domain: str, session: ClientSession | None = None) -> URL | None: diff --git a/mautrix/client/api/client.py b/mautrix/client/api/client.py index c66ec7af..5c4c5338 100644 --- a/mautrix/client/api/client.py +++ b/mautrix/client/api/client.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -24,7 +24,7 @@ class ClientAPI( functions for accessing the client-server API. This class can be used directly, but generally you should use the higher-level wrappers that - inherit from this class, such as :class:`mautrix.client.ClientAPI` + inherit from this class, such as :class:`mautrix.client.Client` or :class:`mautrix.appservice.IntentAPI`. Examples: diff --git a/mautrix/client/api/events.py b/mautrix/client/api/events.py index 14c06079..bd415b9b 100644 --- a/mautrix/client/api/events.py +++ b/mautrix/client/api/events.py @@ -1,11 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations -from typing import Awaitable +from typing import Awaitable, Literal, overload import json from mautrix.api import Method, Path @@ -16,6 +16,7 @@ ContentURI, Event, EventContent, + EventContext, EventID, EventType, FilterID, @@ -43,7 +44,6 @@ TextMessageEventContent, UserID, ) -from mautrix.types.event.state import state_event_content_map from mautrix.util.formatter import parse_html from .base import BaseClientAPI @@ -97,7 +97,7 @@ def sync( if set_presence: request["set_presence"] = str(set_presence) return self.api.request( - Method.GET, Path.sync, query_params=request, retry_count=0, metrics_method="sync" + Method.GET, Path.v3.sync, query_params=request, retry_count=0, metrics_method="sync" ) # endregion @@ -119,19 +119,90 @@ async def get_event(self, room_id: RoomID, event_id: EventID) -> Event: The event. """ content = await self.api.request( - Method.GET, Path.rooms[room_id].event[event_id], metrics_method="getEvent" + Method.GET, Path.v3.rooms[room_id].event[event_id], metrics_method="getEvent" ) try: return Event.deserialize(content) except SerializerError as e: raise MatrixResponseError("Invalid event in response") from e + async def get_event_context( + self, + room_id: RoomID, + event_id: EventID, + limit: int | None = 10, + filter: RoomEventFilter | None = None, + ) -> EventContext: + """ + Get a number of events that happened just before and after the specified event. + This allows clients to get the context surrounding an event, as well as get the state at + an event and paginate in either direction. + + Args: + room_id: The room to get events from. + event_id: The event to get context around. + limit: The maximum number of events to return. The limit applies to the total number of + events before and after the requested event. A limit of 0 means no other events + are returned, while 2 means one event before and one after are returned. + filter: A JSON RoomEventFilter_ to filter returned events with. 
+ + Returns: + The event itself, up to ``limit/2`` events before and after the event, the room state + at the event, and pagination tokens to scroll up and down. + + .. _RoomEventFilter: + https://spec.matrix.org/v1.1/client-server-api/#filtering + """ + query_params = {} + if limit is not None: + query_params["limit"] = str(limit) + if filter is not None: + query_params["filter"] = ( + filter.serialize() if isinstance(filter, Serializable) else filter + ) + resp = await self.api.request( + Method.GET, + Path.v3.rooms[room_id].context[event_id], + query_params=query_params, + metrics_method="get_event_context", + ) + return EventContext.deserialize(resp) + + @overload async def get_state_event( self, room_id: RoomID, event_type: EventType, state_key: str = "", - ) -> StateEventContent: + *, + format: Literal["content"] = "content", + ) -> StateEventContent: ... + @overload + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: Literal["event"], + ) -> StateEvent: ... + @overload + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: ... + async def get_state_event( + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: """ Looks up the contents of a state event in a room. If the user is joined to the room then the state is taken from the current state of the room. If the user has left the room then the @@ -143,19 +214,27 @@ async def get_state_event( room_id: The ID of the room to look up the state in. event_type: The type of state to look up. state_key: The key of the state to look up. Defaults to empty string. + format: The format of the state event to return. Defaults to "content", which only returns + the content of the state event. If set to "event", the full event is returned. 
+ See https://github.com/matrix-org/matrix-spec/issues/1047 for more info. Returns: The state event. """ content = await self.api.request( Method.GET, - Path.rooms[room_id].state[event_type][state_key], + Path.v3.rooms[room_id].state[event_type][state_key], + query_params={"format": format} if format != "content" else None, metrics_method="getStateEvent", ) try: - return state_event_content_map[event_type].deserialize(content) - except KeyError: - return Obj(**content) + if format == "content": + content["__mautrix_event_type"] = event_type + return StateEvent.deserialize_content(content) + elif format == "event": + return StateEvent.deserialize(content) + else: + return content except SerializerError as e: raise MatrixResponseError("Invalid state event in response") from e @@ -172,7 +251,7 @@ async def get_state(self, room_id: RoomID) -> list[StateEvent]: A list of state events with the most recent of each event_type/state_key pair. """ content = await self.api.request( - Method.GET, Path.rooms[room_id].state, metrics_method="getState" + Method.GET, Path.v3.rooms[room_id].state, metrics_method="getState" ) try: return [StateEvent.deserialize(event) for event in content] @@ -215,7 +294,7 @@ async def get_members( query["not_membership"] = not_membership.value content = await self.api.request( Method.GET, - Path.rooms[room_id].members, + Path.v3.rooms[room_id].members, query_params=query, metrics_method="getMembers", ) @@ -245,7 +324,7 @@ async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]: https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3roomsroomidmembers """ content = await self.api.request( - Method.GET, Path.rooms[room_id].joined_members, metrics_method="getJoinedMembers" + Method.GET, Path.v3.rooms[room_id].joined_members, metrics_method="getJoinedMembers" ) try: return { @@ -265,7 +344,7 @@ async def get_messages( self, room_id: RoomID, direction: PaginationDirection, - from_token: SyncToken, + from_token: SyncToken | 
None = None, to_token: SyncToken | None = None, limit: int | None = None, filter_json: str | dict | RoomEventFilter | None = None, @@ -274,7 +353,7 @@ async def get_messages( Get a list of message and state events for a room. Pagination parameters are used to paginate history in the room. - See also: `API reference `__ + See also: `API reference `__ Args: room_id: The ID of the room to get events from. @@ -282,6 +361,9 @@ async def get_messages( from_token: The token to start returning events from. This token can be obtained from a ``prev_batch`` token returned for each room by the `sync endpoint`_, or from a ``start`` or ``end`` token returned by a previous request to this endpoint. + + Starting from Matrix v1.3, this field can be omitted to fetch events from the + beginning or end of the room. to_token: The token to stop returning events at. limit: The maximum number of events to return. Defaults to 10. filter_json: A JSON RoomEventFilter_ to filter returned events with. @@ -289,9 +371,9 @@ async def get_messages( Returns: .. _RoomEventFilter: - https://spec.matrix.org/v1.1/client-server-api/#filtering + https://spec.matrix.org/v1.3/client-server-api/#filtering .. 
_sync endpoint: - https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync + https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3sync """ if isinstance(filter_json, Serializable): filter_json = filter_json.json() @@ -306,22 +388,20 @@ async def get_messages( } content = await self.api.request( Method.GET, - Path.rooms[room_id].messages, + Path.v3.rooms[room_id].messages, query_params=query_params, metrics_method="getMessages", ) try: return PaginatedMessages( content["start"], - content["end"], + content.get("end"), [Event.deserialize(event) for event in content["chunk"]], ) except KeyError: if "start" not in content: raise MatrixResponseError("`start` not in response.") - elif "end" not in content: - raise MatrixResponseError("`start` not in response.") - raise MatrixResponseError("`content` not in response.") + raise MatrixResponseError("`chunk` not in response.") except SerializerError as e: raise MatrixResponseError("Invalid events in response") from e @@ -335,6 +415,7 @@ async def send_state_event( event_type: EventType, content: StateEventContent, state_key: str = "", + ensure_joined: bool = True, **kwargs, ) -> EventID: """ @@ -348,6 +429,8 @@ async def send_state_event( event_type: The type of state to send. content: The content to send. state_key: The key for the state to send. Defaults to empty string. + ensure_joined: Used by IntentAPI to determine if it should ensure the user is joined + before sending the event. **kwargs: Optional parameters to pass to the :meth:`HTTPAPI.request` method. Used by :class:`IntentAPI` to pass the timestamp massaging field to :meth:`AppServiceAPI.request`. 
@@ -358,7 +441,7 @@ async def send_state_event( content = content.serialize() if isinstance(content, Serializable) else content resp = await self.api.request( Method.PUT, - Path.rooms[room_id].state[event_type][state_key], + Path.v3.rooms[room_id].state[event_type][state_key], content, **kwargs, metrics_method="sendStateEvent", @@ -398,7 +481,7 @@ async def send_message_event( raise ValueError("Room ID not given") elif not event_type: raise ValueError("Event type not given") - url = Path.rooms[room_id].send[event_type][txn_id or self.api.get_txn_id()] + url = Path.v3.rooms[room_id].send[event_type][txn_id or self.api.get_txn_id()] content = content.serialize() if isinstance(content, Serializable) else content resp = await self.api.request( Method.PUT, url, content, **kwargs, metrics_method="sendMessageEvent" @@ -669,7 +752,7 @@ async def redact( Returns: The ID of the event that was sent to redact the other event. """ - url = Path.rooms[room_id].redact[event_id][self.api.get_txn_id()] + url = Path.v3.rooms[room_id].redact[event_id][self.api.get_txn_id()] content = extra_content or {} if reason: content["reason"] = reason diff --git a/mautrix/client/api/filtering.py b/mautrix/client/api/filtering.py index bcaa3361..b8df1f36 100644 --- a/mautrix/client/api/filtering.py +++ b/mautrix/client/api/filtering.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -32,7 +32,7 @@ async def get_filter(self, filter_id: FilterID) -> Filter: Returns: The filter data. 
""" - content = await self.api.request(Method.GET, Path.user[self.mxid].filter[filter_id]) + content = await self.api.request(Method.GET, Path.v3.user[self.mxid].filter[filter_id]) return Filter.deserialize(content) async def create_filter(self, filter_params: Filter) -> FilterID: @@ -49,10 +49,12 @@ async def create_filter(self, filter_params: Filter) -> FilterID: """ resp = await self.api.request( Method.POST, - Path.user[self.mxid].filter, - filter_params.serialize() - if isinstance(filter_params, Serializable) - else filter_params, + Path.v3.user[self.mxid].filter, + ( + filter_params.serialize() + if isinstance(filter_params, Serializable) + else filter_params + ), ) try: return resp["filter_id"] diff --git a/mautrix/client/api/modules/__init__.py b/mautrix/client/api/modules/__init__.py index d134a32b..e6c0449d 100644 --- a/mautrix/client/api/modules/__init__.py +++ b/mautrix/client/api/modules/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/client/api/modules/account_data.py b/mautrix/client/api/modules/account_data.py index ecf7764d..18531ad4 100644 --- a/mautrix/client/api/modules/account_data.py +++ b/mautrix/client/api/modules/account_data.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -33,7 +33,7 @@ async def get_account_data(self, type: EventType | str, room_id: RoomID | None = """ if isinstance(type, EventType) and not type.is_account_data: raise ValueError("Event type is not an account data event type") - base_path = Path.user[self.mxid] + base_path = Path.v3.user[self.mxid] if room_id: base_path = base_path.rooms[room_id] return await self.api.request(Method.GET, base_path.account_data[type]) @@ -56,7 +56,7 @@ async def set_account_data( """ if isinstance(type, EventType) and not type.is_account_data: raise ValueError("Event type is not an account data event type") - base_path = Path.user[self.mxid] + base_path = Path.v3.user[self.mxid] if room_id: base_path = base_path.rooms[room_id] await self.api.request( diff --git a/mautrix/client/api/modules/crypto.py b/mautrix/client/api/modules/crypto.py index bcec60d3..2d879a63 100644 --- a/mautrix/client/api/modules/crypto.py +++ b/mautrix/client/api/modules/crypto.py @@ -1,17 +1,21 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import Any +from typing import Any, Union from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError from mautrix.types import ( + JSON, ClaimKeysResponse, + CrossSigningKeys, + CrossSigningUsage, DeviceID, + DeviceKeys, EncryptionKeyAlgorithm, EventType, QueryKeysResponse, @@ -47,7 +51,7 @@ async def send_to_device( raise ValueError("Event type must be a to-device event type") await self.api.request( Method.PUT, - Path.sendToDevice[event_type][self.api.get_txn_id()], + Path.v3.sendToDevice[event_type][self.api.get_txn_id()], { "messages": { user_id: { @@ -82,7 +86,7 @@ async def send_to_one_device( async def upload_keys( self, one_time_keys: dict[str, Any] | None = None, - device_keys: dict[str, Any] | None = None, + device_keys: DeviceKeys | dict[str, Any] | None = None, ) -> dict[EncryptionKeyAlgorithm, int]: """ Publishes end-to-end encryption keys for the device. @@ -102,10 +106,14 @@ async def upload_keys( """ data = {} if device_keys: + if isinstance(device_keys, Serializable): + device_keys = device_keys.serialize() data["device_keys"] = device_keys if one_time_keys: + if isinstance(one_time_keys, Serializable): + one_time_keys = one_time_keys.serialize() data["one_time_keys"] = one_time_keys - resp = await self.api.request(Method.POST, Path.keys.upload, data) + resp = await self.api.request(Method.POST, Path.v3.keys.upload, data) try: return { EncryptionKeyAlgorithm.deserialize(alg): count @@ -116,6 +124,43 @@ async def upload_keys( except AttributeError as e: raise MatrixResponseError("Invalid `one_time_key_counts` field in response.") from e + async def upload_cross_signing_keys( + self, + keys: dict[CrossSigningUsage, CrossSigningKeys], + auth: dict[str, JSON] | None = None, + ) -> None: + await self.api.request( + Method.POST, + Path.v3.keys.device_signing.upload, + {f"{usage}_key": key.serialize() for usage, key in keys.items()} + | ({"auth": auth} if auth else {}), + ) + + async 
def upload_one_signature( + self, + user_id: UserID, + device_id: DeviceID, + keys: Union[DeviceKeys, CrossSigningKeys], + ) -> None: + await self.api.request( + Method.POST, Path.v3.keys.signatures.upload, {user_id: {device_id: keys.serialize()}} + ) + # TODO check failures + + async def upload_many_signatures( + self, + signatures: dict[UserID, dict[DeviceID, Union[DeviceKeys, CrossSigningKeys]]], + ) -> None: + await self.api.request( + Method.POST, + Path.v3.keys.signatures.upload, + { + user_id: {device_id: keys.serialize() for device_id, keys in devices.items()} + for user_id, devices in signatures.items() + }, + ) + # TODO check failures + async def query_keys( self, device_keys: list[UserID] | set[UserID] | dict[UserID, list[DeviceID]], @@ -147,7 +192,7 @@ async def query_keys( } if token: data["token"] = token - resp = await self.api.request(Method.POST, Path.keys.query, data) + resp = await self.api.request(Method.POST, Path.v3.keys.query, data) return QueryKeysResponse.deserialize(resp) async def claim_keys( @@ -171,7 +216,7 @@ async def claim_keys( """ resp = await self.api.request( Method.POST, - Path.keys.claim, + Path.v3.keys.claim, { "timeout": timeout, "one_time_keys": { diff --git a/mautrix/client/api/modules/media_repository.py b/mautrix/client/api/modules/media_repository.py index 0cc6a2b3..b1d90cc1 100644 --- a/mautrix/client/api/modules/media_repository.py +++ b/mautrix/client/api/modules/media_repository.py @@ -1,25 +1,34 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import AsyncIterable -import sys +from typing import Any, AsyncIterable, Literal +from contextlib import contextmanager +import asyncio +import time + +from yarl import URL from mautrix import __optional_imports__ from mautrix.api import MediaPath, Method -from mautrix.errors import MatrixResponseError -from mautrix.types import ContentURI, MediaRepoConfig, MXOpenGraph, SerializerError +from mautrix.errors import MatrixResponseError, make_request_error +from mautrix.types import ( + ContentURI, + MediaCreateResponse, + MediaRepoConfig, + MXOpenGraph, + SerializerError, + SpecVersions, +) +from mautrix.util import background_task +from mautrix.util.async_body import async_iter_bytes +from mautrix.util.opt_prometheus import Histogram from ..base import BaseClientAPI -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - try: from mautrix.util import magic except ImportError: @@ -27,6 +36,12 @@ raise magic = None # type: ignore +UPLOAD_TIME = Histogram( + "bridge_media_upload_time", + "Time spent uploading media (milliseconds per megabyte)", + buckets=[10, 25, 50, 100, 250, 500, 750, 1000, 2500, 5000, 10000], +) + class MediaRepositoryMethods(BaseClientAPI): """ @@ -34,31 +49,63 @@ class MediaRepositoryMethods(BaseClientAPI): downloading content from the media repository and for getting URL previews without leaking client IPs. - See also: `API reference `__""" + See also: `API reference `__ + """ + + async def create_mxc(self) -> MediaCreateResponse: + """ + Create a media ID for uploading media to the homeserver. 
+ + See also: `API reference `__ + + Returns: + A MediaCreateResponse containing the MXC URI that a file can later be uploaded to. + """ + resp = await self.api.request(Method.POST, MediaPath.v1.create) + return MediaCreateResponse.deserialize(resp) + + @contextmanager + def _observe_upload_time(self, size: int | None, mxc: ContentURI | None = None) -> None: + start = time.monotonic_ns() + yield + duration = time.monotonic_ns() - start + if mxc: + duration_sec = duration / 1000**3 + self.log.debug(f"Completed asynchronous upload of {mxc} in {duration_sec:.3f} seconds") + if size: + UPLOAD_TIME.observe(duration / size) async def upload_media( self, - data: bytes | AsyncIterable[bytes], + data: bytes | bytearray | AsyncIterable[bytes], mime_type: str | None = None, filename: str | None = None, size: int | None = None, + mxc: ContentURI | None = None, + async_upload: bool = False, ) -> ContentURI: """ Upload a file to the content repository. - See also: `API reference `__ + See also: `API reference `__ Args: data: The data to upload. mime_type: The MIME type to send with the upload request. filename: The filename to send with the upload request. size: The file size to send with the upload request. + mxc: An existing MXC URI which doesn't have content yet to upload into. + async_upload: Should the media be uploaded in the background? + If ``True``, this will create a MXC URI using :meth:`create_mxc`, start uploading + in the background, and then immediately return the created URI. This is mutually + exclusive with manually passing the ``mxc`` parameter. Returns: The MXC URI to the uploaded file. Raises: MatrixResponseError: If the response does not contain a ``content_uri`` field. + ValueError: if both ``async_upload`` and ``mxc`` are provided at the same time.
""" if magic and isinstance(data, bytes): mime_type = mime_type or magic.mimetype(data) @@ -67,32 +114,93 @@ async def upload_media( headers["Content-Type"] = mime_type if size: headers["Content-Length"] = str(size) + elif isinstance(data, (bytes, bytearray)): + size = len(data) query = {} if filename: query["filename"] = filename - resp = await self.api.request( - Method.POST, MediaPath.upload, content=data, headers=headers, query_params=query - ) - try: - return resp["content_uri"] - except KeyError: - raise MatrixResponseError("`content_uri` not in response.") - async def download_media(self, url: ContentURI) -> bytes: + upload_url = None + + if async_upload: + if mxc: + raise ValueError("async_upload and mxc can't be provided simultaneously") + create_response = await self.create_mxc() + mxc = create_response.content_uri + upload_url = create_response.unstable_upload_url + + path = MediaPath.v3.upload + method = Method.POST + if mxc: + server_name, media_id = self.api.parse_mxc_uri(mxc) + if upload_url is None: + path = MediaPath.v3.upload[server_name][media_id] + method = Method.PUT + else: + path = ( + MediaPath.unstable["com.beeper.msc3870"].upload[server_name][media_id].complete + ) + + if upload_url is not None: + task = self._upload_to_url(upload_url, path, headers, data, post_upload_query=query) + else: + task = self.api.request( + method, path, content=data, headers=headers, query_params=query + ) + + if async_upload: + + async def _try_upload(): + try: + with self._observe_upload_time(size, mxc): + await task + except Exception as e: + self.log.error(f"Failed to upload {mxc}: {type(e).__name__}: {e}") + + background_task.create(_try_upload()) + return mxc + else: + with self._observe_upload_time(size): + resp = await task + try: + return resp["content_uri"] + except KeyError: + raise MatrixResponseError("`content_uri` not in response.") + + async def download_media(self, url: ContentURI, timeout_ms: int | None = None) -> bytes: """ Download a file from 
the content repository. - See also: `API reference `__ + See also: `API reference `__ Args: url: The MXC URI to download. + timeout_ms: The maximum number of milliseconds that the client is willing to wait to + start receiving data. Used for asynchronous uploads. Returns: The raw downloaded data. """ - url = self.api.get_download_url(url) - async with self.api.session.get(url) as response: - return await response.read() + authenticated = (await self.versions()).supports(SpecVersions.V111) + url = self.api.get_download_url(url, authenticated=authenticated) + query_params: dict[str, Any] = {"allow_redirect": "true"} + if timeout_ms is not None: + query_params["timeout_ms"] = timeout_ms + headers: dict[str, str] = {} + if authenticated: + headers["Authorization"] = f"Bearer {self.api.token}" + if self.api.as_user_id: + query_params["user_id"] = self.api.as_user_id + req_id = self.api.log_download_request(url, query_params) + start = time.monotonic() + async with self.api.session.get(url, params=query_params, headers=headers) as response: + try: + response.raise_for_status() + return await response.read() + finally: + self.api.log_download_request_done( + url, req_id, time.monotonic() - start, response.status + ) async def download_thumbnail( self, @@ -100,12 +208,13 @@ async def download_thumbnail( width: int | None = None, height: int | None = None, resize_method: Literal["crop", "scale"] = None, - allow_remote: bool = True, + allow_remote: bool | None = None, + timeout_ms: int | None = None, ): """ Download a thumbnail for a file in the content repository. - See also: `API reference `__ + See also: `API reference `__ Args: url: The MXC URI to download. @@ -117,12 +226,17 @@ async def download_thumbnail( allow_remote: Indicates to the server that it should not attempt to fetch the media if it is deemed remote. This is to prevent routing loops where the server contacts itself. 
+ timeout_ms: The maximum number of milliseconds that the client is willing to wait to + start receiving data. Used for asynchronous uploads. Returns: The raw downloaded data. """ - url = self.api.get_download_url(url, download_type="thumbnail") - query_params = {} + authenticated = (await self.versions()).supports(SpecVersions.V111) + url = self.api.get_download_url( + url, download_type="thumbnail", authenticated=authenticated + ) + query_params: dict[str, Any] = {"allow_redirect": "true"} if width is not None: query_params["width"] = width if height is not None: @@ -130,15 +244,30 @@ async def download_thumbnail( if resize_method is not None: query_params["method"] = resize_method if allow_remote is not None: - query_params["allow_remote"] = allow_remote - async with self.api.session.get(url, params=query_params) as response: - return await response.read() + query_params["allow_remote"] = str(allow_remote).lower() + if timeout_ms is not None: + query_params["timeout_ms"] = timeout_ms + headers: dict[str, str] = {} + if authenticated: + headers["Authorization"] = f"Bearer {self.api.token}" + if self.api.as_user_id: + query_params["user_id"] = self.api.as_user_id + req_id = self.api.log_download_request(url, query_params) + start = time.monotonic() + async with self.api.session.get(url, params=query_params, headers=headers) as response: + try: + response.raise_for_status() + return await response.read() + finally: + self.api.log_download_request_done( + url, req_id, time.monotonic() - start, response.status + ) async def get_url_preview(self, url: str, timestamp: int | None = None) -> MXOpenGraph: """ Get information about a URL for a client. - See also: `API reference `__ + See also: `API reference `__ Args: url: The URL to get a preview of.
@@ -149,7 +278,7 @@ async def get_url_preview(self, url: str, timestamp: int | None = None) -> MXOpe if timestamp is not None: query_params["ts"] = timestamp content = await self.api.request( - Method.GET, MediaPath.preview_url, query_params=query_params + Method.GET, MediaPath.v3.preview_url, query_params=query_params ) try: return MXOpenGraph.deserialize(content) @@ -173,8 +302,51 @@ async def get_media_repo_config(self) -> MediaRepoConfig: Returns: The media repository config. """ - content = await self.api.request(Method.GET, MediaPath.config) + content = await self.api.request(Method.GET, MediaPath.v3.config) try: return MediaRepoConfig.deserialize(content) except SerializerError as e: raise MatrixResponseError("Invalid MediaRepoConfig in response") from e + + async def _upload_to_url( + self, + upload_url: str, + post_upload_path: str, + headers: dict[str, str], + data: bytes | bytearray | AsyncIterable[bytes], + post_upload_query: dict[str, str], + min_iter_size: int = 25 * 1024 * 1024, + ) -> None: + retry_count = self.api.default_retry_count + backoff = 2 + do_fake_iter = data and hasattr(data, "__len__") and len(data) > min_iter_size + if do_fake_iter: + headers["Content-Length"] = str(len(data)) + while True: + self.log.debug("Uploading media to external URL %s", upload_url) + upload_response = None + try: + req_data = async_iter_bytes(data) if do_fake_iter else data + upload_response = await self.api.session.put( + upload_url, data=req_data, headers=headers + ) + upload_response.raise_for_status() + except Exception as e: + if retry_count <= 0: + raise make_request_error( + http_status=upload_response.status if upload_response else -1, + text=(await upload_response.text()) if upload_response else "", + errcode="COM.BEEPER.EXTERNAL_UPLOAD_ERROR", + message=None, + ) + self.log.warning( + f"Uploading media to external URL {upload_url} failed: {e}, " + f"retrying in {backoff} seconds", + ) + await asyncio.sleep(backoff) + backoff *= 2 + retry_count -= 1 + 
else: + break + + await self.api.request(Method.POST, post_upload_path, query_params=post_upload_query) diff --git a/mautrix/client/api/modules/misc.py b/mautrix/client/api/modules/misc.py index 485f345b..8ff9c85a 100644 --- a/mautrix/client/api/modules/misc.py +++ b/mautrix/client/api/modules/misc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -50,13 +50,13 @@ async def set_typing(self, room_id: RoomID, timeout: int = 0) -> None: Args: room_id: The ID of the room in which the user is typing. - timeout: The length of time in seconds to mark this user as typing. + timeout: The length of time in milliseconds to mark this user as typing. """ if timeout > 0: content = {"typing": True, "timeout": timeout} else: content = {"typing": False} - await self.api.request(Method.PUT, Path.rooms[room_id].typing[self.mxid], content) + await self.api.request(Method.PUT, Path.v3.rooms[room_id].typing[self.mxid], content) # endregion # region 13.5 Receipts @@ -77,7 +77,7 @@ async def send_receipt( event_id: The last event ID to acknowledge. receipt_type: The type of receipt to send. Currently only ``m.read`` is supported. 
""" - await self.api.request(Method.POST, Path.rooms[room_id].receipt[receipt_type][event_id]) + await self.api.request(Method.POST, Path.v3.rooms[room_id].receipt[receipt_type][event_id]) # endregion # region 13.6 Fully read markers @@ -110,7 +110,7 @@ async def set_fully_read_marker( content["m.read"] = read_receipt if extra_content: content.update(extra_content) - await self.api.request(Method.POST, Path.rooms[room_id].read_markers, content) + await self.api.request(Method.POST, Path.v3.rooms[room_id].read_markers, content) # endregion # region 13.7 Presence @@ -134,7 +134,7 @@ async def set_presence( } if status: content["status_msg"] = status - await self.api.request(Method.PUT, Path.presence[self.mxid].status, content) + await self.api.request(Method.PUT, Path.v3.presence[self.mxid].status, content) async def get_presence(self, user_id: UserID) -> PresenceEventContent: """ @@ -148,7 +148,7 @@ async def get_presence(self, user_id: UserID) -> PresenceEventContent: Returns: The presence info of the given user. """ - content = await self.api.request(Method.GET, Path.presence[user_id].status) + content = await self.api.request(Method.GET, Path.v3.presence[user_id].status) try: return PresenceEventContent.deserialize(content) except SerializerError: diff --git a/mautrix/client/api/modules/push_rules.py b/mautrix/client/api/modules/push_rules.py index 31392b0c..38554ac3 100644 --- a/mautrix/client/api/modules/push_rules.py +++ b/mautrix/client/api/modules/push_rules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -41,7 +41,7 @@ async def get_push_rule( Returns: The push rule information. 
""" - resp = await self.api.request(Method.GET, Path.pushrules[scope][kind][rule_id]) + resp = await self.api.request(Method.GET, Path.v3.pushrules[scope][kind][rule_id]) return PushRule.deserialize(resp) async def set_push_rule( @@ -81,7 +81,10 @@ async def set_push_rule( if pattern: content["pattern"] = pattern await self.api.request( - Method.PUT, Path.pushrules[scope][kind][rule_id], query_params=query, content=content + Method.PUT, + Path.v3.pushrules[scope][kind][rule_id], + query_params=query, + content=content, ) async def remove_push_rule( @@ -97,4 +100,4 @@ async def remove_push_rule( kind: The kind of rule. rule_id: The identifier of the rule. """ - await self.api.request(Method.DELETE, Path.pushrules[scope][kind][rule_id]) + await self.api.request(Method.DELETE, Path.v3.pushrules[scope][kind][rule_id]) diff --git a/mautrix/client/api/modules/room_tag.py b/mautrix/client/api/modules/room_tag.py index fe0c6271..0dd9a1a8 100644 --- a/mautrix/client/api/modules/room_tag.py +++ b/mautrix/client/api/modules/room_tag.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -31,7 +31,7 @@ async def get_room_tags(self, room_id: RoomID) -> RoomTagAccountDataEventContent Returns: The m.tag account data event. 
""" - resp = await self.api.request(Method.GET, Path.user[self.mxid].rooms[room_id].tags) + resp = await self.api.request(Method.GET, Path.v3.user[self.mxid].rooms[room_id].tags) return RoomTagAccountDataEventContent.deserialize(resp) async def get_room_tag(self, room_id: RoomID, tag: str) -> RoomTagInfo | None: @@ -66,7 +66,7 @@ async def set_room_tag( """ await self.api.request( Method.PUT, - Path.user[self.mxid].rooms[room_id].tags[tag], + Path.v3.user[self.mxid].rooms[room_id].tags[tag], content=(info.serialize() if isinstance(info, Serializable) else (info or {})), ) @@ -80,4 +80,4 @@ async def remove_room_tag(self, room_id: RoomID, tag: str) -> None: room_id: The room ID to remove the tag from. tag: The tag to remove. """ - await self.api.request(Method.DELETE, Path.user[self.mxid].rooms[room_id].tags[tag]) + await self.api.request(Method.DELETE, Path.v3.user[self.mxid].rooms[room_id].tags[tag]) diff --git a/mautrix/client/api/rooms.py b/mautrix/client/api/rooms.py index 32d86960..a488fba6 100644 --- a/mautrix/client/api/rooms.py +++ b/mautrix/client/api/rooms.py @@ -1,17 +1,23 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import Any, Awaitable, Callable, Iterable +from typing import Any, Awaitable, Callable import asyncio from multidict import CIMultiDict from mautrix.api import Method, Path -from mautrix.errors import MatrixRequestError, MatrixResponseError, MNotFound, MRoomInUse +from mautrix.errors import ( + MatrixRequestError, + MatrixResponseError, + MNotFound, + MNotJoined, + MRoomInUse, +) from mautrix.types import ( JSON, DirectoryPaginationToken, @@ -58,10 +64,12 @@ async def create_room( topic: str | None = None, is_direct: bool = False, invitees: list[UserID] | None = None, - initial_state: Iterable[StateEvent | StrippedStateEvent | dict[str, JSON]] | None = None, + initial_state: list[StateEvent | StrippedStateEvent | dict[str, JSON]] | None = None, room_version: str = None, creation_content: RoomCreateStateEventContent | dict[str, JSON] | None = None, power_level_override: PowerLevelStateEventContent | dict[str, JSON] | None = None, + beeper_auto_join_invites: bool = False, + custom_request_fields: dict[str, Any] | None = None, ) -> RoomID: """ Create a new room with various configuration options. @@ -105,6 +113,10 @@ async def create_room( power_level_override: The power level content to override in the default power level event. This object is applied on top of the generated ``m.room.power_levels`` event content prior to it being sent to the room. Defaults to overriding nothing. + beeper_auto_join_invites: A Beeper-specific extension which auto-joins all members in + the invite array instead of sending invites. + custom_request_fields: Additional fields to put in the top-level /createRoom content. + Non-custom fields take precedence over fields here. Returns: The ID of the newly created room. @@ -118,6 +130,7 @@ async def create_room( .. 
_m.room.member: https://spec.matrix.org/v1.1/client-server-api/#mroommember """ content = { + **(custom_request_fields or {}), "visibility": visibility.value, "is_direct": is_direct, "preset": preset.value, @@ -126,6 +139,8 @@ async def create_room( content["room_alias_name"] = alias_localpart if invitees: content["invite"] = invitees + if beeper_auto_join_invites: + content["com.beeper.auto_join_invites"] = True if name: content["name"] = name if topic: @@ -152,7 +167,7 @@ async def create_room( else power_level_override ) - resp = await self.api.request(Method.POST, Path.createRoom, content) + resp = await self.api.request(Method.POST, Path.v3.createRoom, content) try: return resp["room_id"] except KeyError: @@ -179,12 +194,12 @@ async def add_room_alias( room_alias = f"#{alias_localpart}:{self.domain}" content = {"room_id": room_id} try: - await self.api.request(Method.PUT, Path.directory.room[room_alias], content) + await self.api.request(Method.PUT, Path.v3.directory.room[room_alias], content) except MatrixRequestError as e: if e.http_status == 409: if override: await self.remove_room_alias(alias_localpart) - await self.api.request(Method.PUT, Path.directory.room[room_alias], content) + await self.api.request(Method.PUT, Path.v3.directory.room[room_alias], content) else: raise MRoomInUse(e.http_status, e.message) from e else: @@ -205,7 +220,7 @@ async def remove_room_alias(self, alias_localpart: str, raise_404: bool = False) """ room_alias = f"#{alias_localpart}:{self.domain}" try: - await self.api.request(Method.DELETE, Path.directory.room[room_alias]) + await self.api.request(Method.DELETE, Path.v3.directory.room[room_alias]) except MNotFound: if raise_404: raise @@ -226,7 +241,7 @@ async def resolve_room_alias(self, room_alias: RoomAlias) -> RoomAliasInfo: Returns: The room ID and a list of servers that are aware of the room. 
""" - content = await self.api.request(Method.GET, Path.directory.room[room_alias]) + content = await self.api.request(Method.GET, Path.v3.directory.room[room_alias]) return RoomAliasInfo.deserialize(content) # endregion @@ -235,7 +250,7 @@ async def resolve_room_alias(self, room_alias: RoomAlias) -> RoomAliasInfo: async def get_joined_rooms(self) -> list[RoomID]: """Get the list of rooms the user is in.""" - content = await self.api.request(Method.GET, Path.joined_rooms) + content = await self.api.request(Method.GET, Path.v3.joined_rooms) try: return content["joined_rooms"] except KeyError: @@ -273,7 +288,7 @@ async def join_room_by_id( return room_id content = await self.api.request( Method.POST, - Path.rooms[room_id].join, + Path.v3.rooms[room_id].join, {"third_party_signed": third_party_signed} if third_party_signed is not None else None, ) try: @@ -319,7 +334,7 @@ async def join_room( try: content = await self.api.request( Method.POST, - Path.join[room_id_or_alias], + Path.v3.join[room_id_or_alias], content=content, query_params=query_params, ) @@ -339,9 +354,12 @@ async def join_room( except KeyError: raise MatrixResponseError("`room_id` not in response.") - fill_member_event_callback: Callable[ - [RoomID, UserID, MemberStateEventContent], Awaitable[MemberStateEventContent | None] - ] | None + fill_member_event_callback: ( + Callable[ + [RoomID, UserID, MemberStateEventContent], Awaitable[MemberStateEventContent | None] + ] + | None + ) async def fill_member_event( self, room_id: RoomID, user_id: UserID, content: MemberStateEventContent @@ -391,7 +409,7 @@ async def send_member_event( content[key] = value content = await self.fill_member_event(room_id, user_id, content) or content return await self.send_state_event( - room_id, EventType.ROOM_MEMBER, content=content, state_key=user_id + room_id, EventType.ROOM_MEMBER, content=content, state_key=user_id, ensure_joined=False ) async def invite_user( @@ -431,7 +449,7 @@ async def invite_user( data = {"user_id": 
user_id} if reason: data["reason"] = reason - await self.api.request(Method.POST, Path.rooms[room_id].invite, content=data) + await self.api.request(Method.POST, Path.v3.rooms[room_id].invite, content=data) # endregion # region 8.4.2 Leaving rooms @@ -477,11 +495,54 @@ async def leave_room( data = {} if reason: data["reason"] = reason - await self.api.request(Method.POST, Path.rooms[room_id].leave, content=data) + await self.api.request(Method.POST, Path.v3.rooms[room_id].leave, content=data) + except MNotJoined: + if raise_not_in_room: + raise except MatrixRequestError as e: + # TODO remove this once MSC3848 is released and minimum spec version is bumped if "not in room" not in e.message or raise_not_in_room: raise + async def knock_room( + self, + room_id_or_alias: RoomID | RoomAlias, + reason: str | None = None, + servers: list[str] | None = None, + ) -> RoomID: + """ + Knock on a room, i.e. request to join it by its ID or alias, with an optional list of + servers to ask about the ID from. + + See also: `API reference `__ + + Args: + room_id_or_alias: The ID of the room to knock on, or an alias pointing to the room. + reason: The reason for knocking on the room. This will be supplied as the ``reason`` on + the updated `m.room.member`_ event. + servers: A list of servers to ask about the room ID to knock. Not applicable for aliases, + as aliases already contain the necessary server information. + + Returns: + The ID of the room the user knocked on. + """ + data = {} + if reason: + data["reason"] = reason + query_params = CIMultiDict() + for server_name in servers or []: + query_params.add("server_name", server_name) + content = await self.api.request( + Method.POST, + Path.v3.knock[room_id_or_alias], + content=data, + query_params=query_params, + ) + try: + return content["room_id"] + except KeyError: + raise MatrixResponseError("`room_id` not in response.") + async def forget_room(self, room_id: RoomID) -> None: """ Stop remembering a particular room, i.e. 
forget it. @@ -498,7 +559,7 @@ async def forget_room(self, room_id: RoomID) -> None: Args: room_id: The ID of the room to forget. """ - await self.api.request(Method.POST, Path.rooms[room_id].forget) + await self.api.request(Method.POST, Path.v3.rooms[room_id].forget) async def kick_user( self, @@ -538,7 +599,7 @@ async def kick_user( ) return await self.api.request( - Method.POST, Path.rooms[room_id].kick, {"user_id": user_id, "reason": reason} + Method.POST, Path.v3.rooms[room_id].kick, {"user_id": user_id, "reason": reason} ) # endregion @@ -578,10 +639,16 @@ async def ban_user( ) return await self.api.request( - Method.POST, Path.rooms[room_id].ban, {"user_id": user_id, "reason": reason} + Method.POST, Path.v3.rooms[room_id].ban, {"user_id": user_id, "reason": reason} ) - async def unban_user(self, room_id: RoomID, user_id: UserID) -> None: + async def unban_user( + self, + room_id: RoomID, + user_id: UserID, + reason: str = "", + extra_content: dict[str, JSON] | None = None, + ) -> None: """ Unban a user from the room. This allows them to be invited to the room, and join if they would otherwise be allowed to join according to its join rules. The caller must have the @@ -592,8 +659,22 @@ async def unban_user(self, room_id: RoomID, user_id: UserID) -> None: Args: room_id: The ID of the room from which the user should be unbanned. user_id: The fully qualified user ID of the user being banned. + reason: The reason the user has been unbanned. This will be supplied as the ``reason`` on + the target's updated `m.room.member`_ event. + extra_content: Additional properties for the unban (leave) event content. + If a non-empty dict is passed, the unban will be created using + the ``PUT /state/m.room.member/...`` endpoint instead of ``POST /unban``. 
""" - await self.api.request(Method.POST, Path.rooms[room_id].unban, {"user_id": user_id}) + if extra_content: + if reason and "reason" not in extra_content: + extra_content["reason"] = reason + await self.send_member_event( + room_id, user_id, Membership.LEAVE, extra_content=extra_content + ) + return + await self.api.request( + Method.POST, Path.v3.rooms[room_id].unban, {"user_id": user_id, "reason": reason} + ) # endregion @@ -613,7 +694,7 @@ async def get_room_directory_visibility(self, room_id: RoomID) -> RoomDirectoryV Returns: The visibility of the room in the directory. """ - resp = await self.api.request(Method.GET, Path.directory.list.room[room_id]) + resp = await self.api.request(Method.GET, Path.v3.directory.list.room[room_id]) try: return RoomDirectoryVisibility(resp["visibility"]) except KeyError: @@ -640,7 +721,7 @@ async def set_room_directory_visibility( """ await self.api.request( Method.PUT, - Path.directory.list.room[room_id], + Path.v3.directory.list.room[room_id], { "visibility": visibility.value, }, @@ -700,7 +781,7 @@ async def get_room_directory( query_params = {"server": server} if server is not None else None content = await self.api.request( - method, Path.publicRooms, content, query_params=query_params + method, Path.v3.publicRooms, content, query_params=query_params ) return RoomDirectoryResponse.deserialize(content) diff --git a/mautrix/client/api/user_data.py b/mautrix/client/api/user_data.py index 68807be6..9c380335 100644 --- a/mautrix/client/api/user_data.py +++ b/mautrix/client/api/user_data.py @@ -1,10 +1,12 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations +from typing import Any + from mautrix.api import Method, Path from mautrix.errors import MatrixResponseError, MNotFound from mautrix.types import ContentURI, Member, SerializerError, User, UserID, UserSearchResults @@ -46,7 +48,7 @@ async def search_users(self, search_query: str, limit: int | None = 10) -> UserS """ content = await self.api.request( Method.POST, - Path.user_directory.search, + Path.v3.user_directory.search, { "search_term": search_query, "limit": limit, @@ -69,7 +71,7 @@ async def search_users(self, search_query: str, limit: int | None = 10) -> UserS # region 10.2 Profiles # API reference: https://matrix.org/docs/spec/client_server/r0.4.0.html#profiles - async def set_displayname(self, displayname: str, check_current: bool = True) -> None: + async def set_displayname(self, displayname: str | None, check_current: bool = True) -> None: """ Set the display name of the current user. @@ -79,11 +81,13 @@ async def set_displayname(self, displayname: str, check_current: bool = True) -> displayname: The new display name for the user. check_current: Whether or not to check if the displayname is already set. """ - if check_current and await self.get_displayname(self.mxid) == displayname: + if check_current and str_or_none(await self.get_displayname(self.mxid)) == str_or_none( + displayname + ): return await self.api.request( Method.PUT, - Path.profile[self.mxid].displayname, + Path.v3.profile[self.mxid].displayname, { "displayname": displayname, }, @@ -102,7 +106,7 @@ async def get_displayname(self, user_id: UserID) -> str | None: The display name of the given user. 
""" try: - content = await self.api.request(Method.GET, Path.profile[user_id].displayname) + content = await self.api.request(Method.GET, Path.v3.profile[user_id].displayname) except MNotFound: return None try: @@ -110,7 +114,9 @@ async def get_displayname(self, user_id: UserID) -> str | None: except KeyError: return None - async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = True) -> None: + async def set_avatar_url( + self, avatar_url: ContentURI | None, check_current: bool = True + ) -> None: """ Set the avatar of the current user. @@ -120,11 +126,13 @@ async def set_avatar_url(self, avatar_url: ContentURI, check_current: bool = Tru avatar_url: The ``mxc://`` URI to the new avatar. check_current: Whether or not to check if the avatar is already set. """ - if check_current and await self.get_avatar_url(self.mxid) == avatar_url: + if check_current and str_or_none(await self.get_avatar_url(self.mxid)) == str_or_none( + avatar_url + ): return await self.api.request( Method.PUT, - Path.profile[self.mxid].avatar_url, + Path.v3.profile[self.mxid].avatar_url, { "avatar_url": avatar_url, }, @@ -143,7 +151,7 @@ async def get_avatar_url(self, user_id: UserID) -> ContentURI | None: The ``mxc://`` URI to the user's avatar. """ try: - content = await self.api.request(Method.GET, Path.profile[user_id].avatar_url) + content = await self.api.request(Method.GET, Path.v3.profile[user_id].avatar_url) except MNotFound: return None try: @@ -163,10 +171,30 @@ async def get_profile(self, user_id: UserID) -> Member: Returns: The profile information of the given user. 
""" - content = await self.api.request(Method.GET, Path.profile[user_id]) + content = await self.api.request(Method.GET, Path.v3.profile[user_id]) try: return Member.deserialize(content) except SerializerError as e: raise MatrixResponseError("Invalid member in response") from e # endregion + + # region Beeper Custom Fields API + + async def beeper_update_profile(self, custom_fields: dict[str, Any]) -> None: + """ + Set custom fields on the user's profile. Only works on Hungryserv. + + Args: + custom_fields: A dictionary of fields to set in the custom content of the profile. + """ + await self.api.request(Method.PATCH, Path.v3.profile[self.mxid], custom_fields) + + # endregion + + +def str_or_none(v: str | None) -> str | None: + """ + Convert empty string values to None + """ + return None if v == "" else v diff --git a/mautrix/client/client.py b/mautrix/client/client.py index 830bf70f..5d4e5682 100644 --- a/mautrix/client/client.py +++ b/mautrix/client/client.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/client/dispatcher.py b/mautrix/client/dispatcher.py index e5565a44..7eb51619 100644 --- a/mautrix/client/dispatcher.py +++ b/mautrix/client/dispatcher.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/client/encryption_manager.py b/mautrix/client/encryption_manager.py index 5a97a8ca..60a9b908 100644 --- a/mautrix/client/encryption_manager.py +++ b/mautrix/client/encryption_manager.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -40,8 +40,10 @@ class EncryptingAPI(store_updater.StoreUpdatingAPI): """The logger to use for crypto-related things.""" _share_session_events: dict[RoomID, asyncio.Event] - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, crypto_log: TraceLogger | None = None, **kwargs) -> None: super().__init__(*args, **kwargs) + if crypto_log: + self.crypto_log = crypto_log self._crypto = None self._share_session_events = {} diff --git a/mautrix/client/state_store/__init__.py b/mautrix/client/state_store/__init__.py index 0c8deb58..98150ca4 100644 --- a/mautrix/client/state_store/__init__.py +++ b/mautrix/client/state_store/__init__.py @@ -2,3 +2,12 @@ from .file import FileStateStore from .memory import MemoryStateStore from .sync import MemorySyncStore, SyncStore + +__all__ = [ + "StateStore", + "FileStateStore", + "MemoryStateStore", + "MemorySyncStore", + "SyncStore", + "asyncpg", +] diff --git a/mautrix/client/state_store/abstract.py b/mautrix/client/state_store/abstract.py index 98f198de..d5b1f5f2 100644 --- a/mautrix/client/state_store/abstract.py +++ b/mautrix/client/state_store/abstract.py @@ -1,11 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import Awaitable +from typing import Any, Awaitable from abc import ABC, abstractmethod from mautrix.types import ( @@ -121,6 +121,18 @@ async def set_power_levels( ) -> None: pass + @abstractmethod + async def has_create_cached(self, room_id: RoomID) -> bool: + pass + + @abstractmethod + async def get_create(self, room_id: RoomID) -> StateEvent | None: + pass + + @abstractmethod + async def set_create(self, event: StateEvent) -> None: + pass + @abstractmethod async def has_encryption_info_cached(self, room_id: RoomID) -> bool: pass @@ -135,7 +147,7 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent @abstractmethod async def set_encryption_info( - self, room_id: RoomID, content: RoomEncryptionStateEventContent + self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: pass @@ -143,9 +155,14 @@ async def update_state(self, evt: StateEvent) -> None: if evt.type == EventType.ROOM_POWER_LEVELS: await self.set_power_levels(evt.room_id, evt.content) elif evt.type == EventType.ROOM_MEMBER: + evt.unsigned["mautrix_prev_membership"] = await self.get_member( + evt.room_id, UserID(evt.state_key) + ) await self.set_member(evt.room_id, UserID(evt.state_key), evt.content) elif evt.type == EventType.ROOM_ENCRYPTION: await self.set_encryption_info(evt.room_id, evt.content) + elif evt.type == EventType.ROOM_CREATE and evt.sender: + await self.set_create(evt) async def get_membership(self, room_id: RoomID, user_id: UserID) -> Membership: member = await self.get_member(room_id, user_id) @@ -169,4 +186,7 @@ async def has_power_level( room_levels = await self.get_power_levels(room_id) if not room_levels: return None - return room_levels.get_user_level(user_id) >= room_levels.get_event_level(event_type) + create_event = await self.get_create(room_id) + return room_levels.get_user_level(user_id, create_event) >= room_levels.get_event_level( + event_type + ) diff --git 
a/mautrix/client/state_store/asyncpg/__init__.py b/mautrix/client/state_store/asyncpg/__init__.py index 7b97ca3f..a7fd8c63 100644 --- a/mautrix/client/state_store/asyncpg/__init__.py +++ b/mautrix/client/state_store/asyncpg/__init__.py @@ -1 +1,3 @@ from .store import PgStateStore + +__all__ = ["PgStateStore"] diff --git a/mautrix/client/state_store/asyncpg/store.py b/mautrix/client/state_store/asyncpg/store.py index 240d6e2c..f4f8436f 100644 --- a/mautrix/client/state_store/asyncpg/store.py +++ b/mautrix/client/state_store/asyncpg/store.py @@ -1,11 +1,12 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import NamedTuple +from typing import Any, NamedTuple +import json from mautrix.types import ( Member, @@ -14,6 +15,8 @@ PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, + Serializable, + StateEvent, UserID, ) from mautrix.util.async_db import Database, Scheme @@ -212,13 +215,36 @@ async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent return PowerLevelStateEventContent.parse_json(power_levels_json) async def set_power_levels( - self, room_id: RoomID, content: PowerLevelStateEventContent + self, room_id: RoomID, content: PowerLevelStateEventContent | dict[str, Any] ) -> None: await self.db.execute( "INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) " "ON CONFLICT (room_id) DO UPDATE SET power_levels=$2", room_id, - content.json(), + json.dumps(content.serialize() if isinstance(content, Serializable) else content), + ) + + async def has_create_cached(self, room_id: RoomID) -> bool: + return bool( + await self.db.fetchval( + "SELECT create_event IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id + ) + ) + + async def 
get_create(self, room_id: RoomID) -> StateEvent | None: + create_event_json = await self.db.fetchval( + "SELECT create_event FROM mx_room_state WHERE room_id=$1", room_id + ) + if create_event_json is None: + return None + return StateEvent.parse_json(create_event_json) + + async def set_create(self, event: StateEvent) -> None: + await self.db.execute( + "INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) " + "ON CONFLICT (room_id) DO UPDATE SET create_event=$2", + event.room_id, + json.dumps(event.serialize() if isinstance(event, Serializable) else event), ) async def has_encryption_info_cached(self, room_id: RoomID) -> bool: @@ -242,10 +268,14 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent return RoomEncryptionStateEventContent.parse_json(row["encryption"]) async def set_encryption_info( - self, room_id: RoomID, content: RoomEncryptionStateEventContent + self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: q = ( "INSERT INTO mx_room_state (room_id, is_encrypted, encryption) VALUES ($1, true, $2) " "ON CONFLICT (room_id) DO UPDATE SET is_encrypted=true, encryption=$2" ) - await self.db.execute(q, room_id, content.json()) + await self.db.execute( + q, + room_id, + json.dumps(content.serialize() if isinstance(content, Serializable) else content), + ) diff --git a/mautrix/client/state_store/asyncpg/upgrade.py b/mautrix/client/state_store/asyncpg/upgrade.py index bb611b21..20b2e5b2 100644 --- a/mautrix/client/state_store/asyncpg/upgrade.py +++ b/mautrix/client/state_store/asyncpg/upgrade.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -14,31 +14,35 @@ ) -@upgrade_table.register(description="Latest revision", upgrades_to=2) -async def upgrade_blank_to_v2(conn: Connection, scheme: Scheme) -> None: - await conn.execute( - """CREATE TABLE mx_room_state ( +@upgrade_table.register(description="Latest revision", upgrades_to=4) +async def upgrade_blank_to_v4(conn: Connection, scheme: Scheme) -> None: + await conn.execute(""" + CREATE TABLE mx_room_state ( room_id TEXT PRIMARY KEY, is_encrypted BOOLEAN, has_full_member_list BOOLEAN, encryption TEXT, - power_levels TEXT - )""" - ) + power_levels TEXT, + create_event TEXT + ) + """) + membership_check = "" if scheme != Scheme.SQLITE: await conn.execute( "CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')" ) - await conn.execute( - """CREATE TABLE mx_user_profile ( + else: + membership_check = "CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock'))" + await conn.execute(f""" + CREATE TABLE mx_user_profile ( room_id TEXT, user_id TEXT, - membership membership NOT NULL, + membership membership NOT NULL {membership_check}, displayname TEXT, avatar_url TEXT, PRIMARY KEY (room_id, user_id) - )""" - ) + ) + """) @upgrade_table.register(description="Stop using size-limited string fields") @@ -51,3 +55,21 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT") await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN displayname TYPE TEXT") await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT") + + +@upgrade_table.register(description="Mark rooms that need crypto state event resynced") +async def upgrade_v3(conn: Connection) -> None: + if await conn.table_exists("portal"): + await conn.execute(""" + INSERT INTO mx_room_state (room_id, encryption) + SELECT portal.mxid, '{"resync":true}' FROM portal + WHERE portal.encrypted=true AND portal.mxid IS NOT NULL + 
ON CONFLICT (room_id) DO UPDATE + SET encryption=excluded.encryption + WHERE mx_room_state.encryption IS NULL + """) + + +@upgrade_table.register(description="Add create event to room state cache") +async def upgrade_v4(conn: Connection) -> None: + await conn.execute("ALTER TABLE mx_room_state ADD COLUMN create_event TEXT") diff --git a/mautrix/client/state_store/file.py b/mautrix/client/state_store/file.py index 57c7a5ec..a5d53663 100644 --- a/mautrix/client/state_store/file.py +++ b/mautrix/client/state_store/file.py @@ -1,11 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import IO +from typing import IO, Any from pathlib import Path from mautrix.types import ( @@ -15,6 +15,7 @@ PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, + StateEvent, UserID, ) from mautrix.util.file_store import Filer, FileStore @@ -55,7 +56,7 @@ async def set_members( self._time_limited_flush() async def set_encryption_info( - self, room_id: RoomID, content: RoomEncryptionStateEventContent + self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: await super().set_encryption_info(room_id, content) self._time_limited_flush() @@ -65,3 +66,7 @@ async def set_power_levels( ) -> None: await super().set_power_levels(room_id, content) self._time_limited_flush() + + async def set_create(self, event: StateEvent) -> None: + await super().set_create(event) + self._time_limited_flush() diff --git a/mautrix/client/state_store/memory.py b/mautrix/client/state_store/memory.py index 6328c00d..2b010a75 100644 --- a/mautrix/client/state_store/memory.py +++ b/mautrix/client/state_store/memory.py @@ -1,12 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir 
Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any -import sys +from typing import Any, TypedDict from mautrix.types import ( Member, @@ -15,22 +14,19 @@ PowerLevelStateEventContent, RoomEncryptionStateEventContent, RoomID, + StateEvent, UserID, ) from .abstract import StateStore -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - class SerializedStateStore(TypedDict): members: dict[RoomID, dict[UserID, Any]] full_member_list: dict[RoomID, bool] power_levels: dict[RoomID, Any] encryption: dict[RoomID, Any] + create: dict[RoomID, Any] class MemoryStateStore(StateStore): @@ -38,12 +34,14 @@ class MemoryStateStore(StateStore): full_member_list: dict[RoomID, bool] power_levels: dict[RoomID, PowerLevelStateEventContent] encryption: dict[RoomID, RoomEncryptionStateEventContent | None] + create: dict[RoomID, StateEvent] def __init__(self) -> None: self.members = {} self.full_member_list = {} self.power_levels = {} self.encryption = {} + self.create = {} def serialize(self) -> SerializedStateStore: """ @@ -64,6 +62,7 @@ def serialize(self) -> SerializedStateStore: room_id: (content.serialize() if content is not None else None) for room_id, content in self.encryption.items() }, + "create": {room_id: evt.serialize() for room_id, evt in self.create.items()}, } def deserialize(self, data: SerializedStateStore) -> None: @@ -90,6 +89,9 @@ def deserialize(self, data: SerializedStateStore) -> None: ) for room_id, content in data["encryption"].items() } + self.create = { + room_id: StateEvent.deserialize(evt) for room_id, evt in data["create"].items() + } async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None: try: @@ -176,10 +178,23 @@ async def get_power_levels(self, room_id: RoomID) 
-> PowerLevelStateEventContent return self.power_levels.get(room_id) async def set_power_levels( - self, room_id: RoomID, content: PowerLevelStateEventContent + self, room_id: RoomID, content: PowerLevelStateEventContent | dict[str, Any] ) -> None: + if not isinstance(content, PowerLevelStateEventContent): + content = PowerLevelStateEventContent.deserialize(content) self.power_levels[room_id] = content + async def has_create_cached(self, room_id: RoomID) -> bool: + return room_id in self.create + + async def get_create(self, room_id: RoomID) -> StateEvent | None: + return self.create.get(room_id) + + async def set_create(self, event: StateEvent | dict[str, Any]) -> None: + if not isinstance(event, StateEvent): + event = StateEvent.deserialize(event) + self.create[event.room_id] = event + async def has_encryption_info_cached(self, room_id: RoomID) -> bool: return room_id in self.encryption @@ -193,6 +208,8 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent return self.encryption.get(room_id) async def set_encryption_info( - self, room_id: RoomID, content: RoomEncryptionStateEventContent + self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any] ) -> None: + if not isinstance(content, RoomEncryptionStateEventContent): + content = RoomEncryptionStateEventContent.deserialize(content) self.encryption[room_id] = content diff --git a/mautrix/client/state_store/sqlalchemy/__init__.py b/mautrix/client/state_store/sqlalchemy/__init__.py deleted file mode 100644 index 6ce78f46..00000000 --- a/mautrix/client/state_store/sqlalchemy/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mx_room_state import RoomState, SerializableType -from .mx_user_profile import UserProfile -from .sqlstatestore import SQLStateStore diff --git a/mautrix/client/state_store/sqlalchemy/mx_room_state.py b/mautrix/client/state_store/sqlalchemy/mx_room_state.py deleted file mode 100644 index 18adcbe0..00000000 --- 
a/mautrix/client/state_store/sqlalchemy/mx_room_state.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2021 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from __future__ import annotations - -from typing import Type -import json - -from sqlalchemy import Boolean, Column, Text, types - -from mautrix.types import ( - PowerLevelStateEventContent as PowerLevels, - RoomEncryptionStateEventContent as EncryptionInfo, - RoomID, - Serializable, -) -from mautrix.util.db import Base - - -class SerializableType(types.TypeDecorator): - impl = types.Text - - def __init__(self, python_type: Type[Serializable], *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._python_type = python_type - - @property - def python_type(self) -> Type[Serializable]: - return self._python_type - - def process_bind_param(self, value: Serializable, dialect) -> str | None: - if value is not None: - return json.dumps(value.serialize()) - return None - - def process_result_value(self, value: str, dialect) -> Serializable | None: - if value is not None: - return self.python_type.deserialize(json.loads(value)) - return None - - def process_literal_param(self, value, dialect): - return value - - -class RoomState(Base): - __tablename__ = "mx_room_state" - - room_id: RoomID = Column(Text, primary_key=True) - is_encrypted: bool = Column(Boolean, nullable=True) - has_full_member_list: bool = Column(Boolean, nullable=True) - encryption: EncryptionInfo = Column(SerializableType(EncryptionInfo), nullable=True) - power_levels: PowerLevels = Column(SerializableType(PowerLevels), nullable=True) - - @property - def has_power_levels(self) -> bool: - return bool(self.power_levels) - - @property - def has_encryption_info(self) -> bool: - return self.is_encrypted is not None - - @classmethod - def get(cls, room_id: RoomID) -> 
RoomState | None: - return cls._select_one_or_none(cls.c.room_id == room_id) diff --git a/mautrix/client/state_store/sqlalchemy/mx_user_profile.py b/mautrix/client/state_store/sqlalchemy/mx_user_profile.py deleted file mode 100644 index 731b6885..00000000 --- a/mautrix/client/state_store/sqlalchemy/mx_user_profile.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2021 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from __future__ import annotations - -from typing import Iterable - -from sqlalchemy import Column, Enum, Text - -from mautrix.types import ContentURI, Member, Membership, RoomID, UserID -from mautrix.util.db import Base - -from .mx_room_state import RoomState - - -class UserProfile(Base): - __tablename__ = "mx_user_profile" - - room_id: RoomID = Column(Text, primary_key=True) - user_id: UserID = Column(Text, primary_key=True) - membership: Membership = Column(Enum(Membership), nullable=False, default=Membership.LEAVE) - displayname: str = Column(Text, nullable=True) - avatar_url: ContentURI = Column(Text, nullable=True) - - def member(self) -> Member: - return Member( - membership=self.membership, displayname=self.displayname, avatar_url=self.avatar_url - ) - - @classmethod - def get(cls, room_id: RoomID, user_id: UserID) -> UserProfile | None: - return cls._select_one_or_none((cls.c.room_id == room_id) & (cls.c.user_id == user_id)) - - @classmethod - def all_in_room( - cls, - room_id: RoomID, - memberships: tuple[Membership, ...], - prefix: str = None, - suffix: str = None, - bot: str = None, - ) -> Iterable[UserProfile]: - clause = cls.c.membership == memberships[0] - for membership in memberships[1:]: - clause |= cls.c.membership == membership - clause &= cls.c.room_id == room_id - if bot: - clause &= cls.c.user_id != bot - if prefix: - clause &= ~cls.c.user_id.startswith(prefix, 
autoescape=True) - if suffix: - clause &= ~cls.c.user_id.startswith(suffix, autoescape=True) - return cls._select_all(clause) - - @classmethod - def find_rooms_with_user(cls, user_id: UserID) -> Iterable[UserProfile]: - return cls._select_all( - (cls.c.user_id == user_id) - & (cls.c.room_id == RoomState.c.room_id) - & (RoomState.c.is_encrypted == True) - ) - - @classmethod - def delete_all(cls, room_id: RoomID) -> None: - with cls.db.begin() as conn: - conn.execute(cls.t.delete().where(cls.c.room_id == room_id)) - - @classmethod - def bulk_replace( - cls, - room_id: RoomID, - members: dict[UserID, Member], - only_membership: Membership | None = None, - ) -> None: - with cls.db.begin() as conn: - delete_condition = cls.c.room_id == room_id - if only_membership is not None: - delete_condition &= cls.c.membership == only_membership - cls.db.execute(cls.t.delete().where(delete_condition)) - conn.execute( - cls.t.insert(), - [ - dict( - room_id=room_id, - user_id=user_id, - membership=member.membership, - displayname=member.displayname, - avatar_url=member.avatar_url, - ) - for user_id, member in members.items() - ], - ) diff --git a/mautrix/client/state_store/sqlalchemy/sqlstatestore.py b/mautrix/client/state_store/sqlalchemy/sqlstatestore.py deleted file mode 100644 index 159b4747..00000000 --- a/mautrix/client/state_store/sqlalchemy/sqlstatestore.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright (c) 2021 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from __future__ import annotations - -from mautrix.types import ( - Member, - Membership, - PowerLevelStateEventContent, - RoomEncryptionStateEventContent, - RoomID, - UserID, -) - -from ..abstract import StateStore -from .mx_room_state import RoomState -from .mx_user_profile import UserProfile - - -class SQLStateStore(StateStore): - _profile_cache: dict[RoomID, dict[UserID, UserProfile]] - _room_state_cache: dict[RoomID, RoomState] - - def __init__(self) -> None: - super().__init__() - self._profile_cache = {} - self._room_state_cache = {} - - def _get_user_profile( - self, room_id: RoomID, user_id: UserID, create: bool = False - ) -> UserProfile: - if not room_id: - raise ValueError("room_id is empty") - elif not user_id: - raise ValueError("user_id is empty") - try: - return self._profile_cache[room_id][user_id] - except KeyError: - pass - if room_id not in self._profile_cache: - self._profile_cache[room_id] = {} - - profile = UserProfile.get(room_id, user_id) - if profile: - self._profile_cache[room_id][user_id] = profile - elif create: - profile = UserProfile(room_id=room_id, user_id=user_id, membership=Membership.LEAVE) - profile.insert() - self._profile_cache[room_id][user_id] = profile - return profile - - async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None: - profile = self._get_user_profile(room_id, user_id) - if not profile: - return None - return profile.member() - - async def set_member(self, room_id: RoomID, user_id: UserID, member: Member) -> None: - if not member: - raise ValueError("member info is empty") - profile = self._get_user_profile(room_id, user_id, create=True) - profile.edit( - membership=member.membership, - displayname=member.displayname or profile.displayname, - avatar_url=member.avatar_url or profile.avatar_url, - ) - - async def set_membership( - self, room_id: RoomID, user_id: UserID, membership: Membership - ) -> None: - await self.set_member(room_id, user_id, Member(membership=membership)) - - async def 
get_member_profiles( - self, - room_id: RoomID, - memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), - ) -> dict[UserID, Member]: - self._profile_cache[room_id] = {} - for profile in UserProfile.all_in_room(room_id, memberships): - self._profile_cache[room_id][profile.user_id] = profile - return { - profile.user_id: profile.member() for profile in self._profile_cache[room_id].values() - } - - async def get_members_filtered( - self, - room_id: RoomID, - not_prefix: str, - not_suffix: str, - not_id: str, - memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE), - ) -> list[UserID]: - return [ - profile.user_id - for profile in UserProfile.all_in_room( - room_id, memberships, not_suffix, not_prefix, not_id - ) - ] - - async def set_members( - self, - room_id: RoomID, - members: dict[UserID, Member], - only_membership: Membership | None = None, - ) -> None: - UserProfile.bulk_replace(room_id, members, only_membership=only_membership) - self._get_room_state(room_id, create=True).edit(has_full_member_list=True) - try: - del self._profile_cache[room_id] - except KeyError: - pass - - async def has_full_member_list(self, room_id: RoomID) -> bool: - room = self._get_room_state(room_id) - if not room: - return False - return room.has_full_member_list - - async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]: - return [profile.room_id for profile in UserProfile.find_rooms_with_user(user_id)] - - def _get_room_state(self, room_id: RoomID, create: bool = False) -> RoomState: - if not room_id: - raise ValueError("room_id is empty") - try: - return self._room_state_cache[room_id] - except KeyError: - pass - - room = RoomState.get(room_id) - if room: - self._room_state_cache[room_id] = room - elif create: - room = RoomState(room_id=room_id) - room.insert() - self._room_state_cache[room_id] = room - return room - - async def has_power_levels_cached(self, room_id: RoomID) -> bool: - room = self._get_room_state(room_id) - if 
not room: - return False - return room.has_power_levels - - async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None: - room = self._get_room_state(room_id) - if not room: - return None - return room.power_levels - - async def set_power_levels( - self, room_id: RoomID, content: PowerLevelStateEventContent - ) -> None: - if not content: - raise ValueError("content is empty") - self._get_room_state(room_id, create=True).edit(power_levels=content) - - async def is_encrypted(self, room_id: RoomID) -> bool | None: - room = self._get_room_state(room_id) - if not room: - return None - return room.is_encrypted - - async def has_encryption_info_cached(self, room_id: RoomID) -> bool: - room = self._get_room_state(room_id) - return room and room.has_encryption_info - - async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None: - room = self._get_room_state(room_id) - if not room: - return None - return room.encryption - - async def set_encryption_info( - self, room_id: RoomID, content: RoomEncryptionStateEventContent - ) -> None: - if not content: - raise ValueError("content is empty") - self._get_room_state(room_id, create=True).edit(encryption=content, is_encrypted=True) diff --git a/mautrix/client/state_store/sync.py b/mautrix/client/state_store/sync.py index 0f5e087f..cc054495 100644 --- a/mautrix/client/state_store/sync.py +++ b/mautrix/client/state_store/sync.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/client/state_store/tests/store_test.py b/mautrix/client/state_store/tests/store_test.py index 3e9b6884..46eee750 100644 --- a/mautrix/client/state_store/tests/store_test.py +++ b/mautrix/client/state_store/tests/store_test.py @@ -1,9 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import AsyncContextManager, AsyncIterator, Callable, Dict, List +from __future__ import annotations + +from typing import AsyncContextManager, AsyncIterator, Callable from contextlib import asynccontextmanager import json import os @@ -14,15 +16,12 @@ import asyncpg import pytest -import sqlalchemy as sql from mautrix.types import EncryptionAlgorithm, Member, Membership, RoomID, StateEvent, UserID from mautrix.util.async_db import Database -from mautrix.util.db import Base from .. 
import MemoryStateStore, StateStore from ..asyncpg import PgStateStore -from ..sqlalchemy import RoomState, SQLStateStore, UserProfile @asynccontextmanager @@ -52,7 +51,7 @@ async def async_postgres_store() -> AsyncIterator[PgStateStore]: @asynccontextmanager async def async_sqlite_store() -> AsyncIterator[PgStateStore]: db = Database.create( - "sqlite:///:memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1} + "sqlite::memory:", upgrade_table=PgStateStore.upgrade_table, db_args={"min_size": 1} ) store = PgStateStore(db) await db.start() @@ -60,30 +59,19 @@ async def async_sqlite_store() -> AsyncIterator[PgStateStore]: await db.stop() -@asynccontextmanager -async def alchemy_store() -> AsyncIterator[SQLStateStore]: - db = sql.create_engine("sqlite:///:memory:") - Base.metadata.bind = db - for table in (RoomState, UserProfile): - table.bind(db) - Base.metadata.create_all() - yield SQLStateStore() - db.dispose() - - @asynccontextmanager async def memory_store() -> AsyncIterator[MemoryStateStore]: yield MemoryStateStore() -@pytest.fixture(params=[async_postgres_store, async_sqlite_store, alchemy_store, memory_store]) +@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store]) async def store(request) -> AsyncIterator[StateStore]: param: Callable[[], AsyncContextManager[StateStore]] = request.param async with param() as state_store: yield state_store -def read_state_file(request, file) -> Dict[RoomID, List[StateEvent]]: +def read_state_file(request, file) -> dict[RoomID, list[StateEvent]]: path = pathlib.Path(request.node.fspath).with_name(file) with path.open() as fp: content = json.load(fp) @@ -122,7 +110,6 @@ async def get_joined_members(request, store: StateStore) -> None: await store.set_members(room_id, parsed_members, only_membership=Membership.JOIN) -@pytest.mark.asyncio async def test_basic(store: StateStore) -> None: room_id = RoomID("!foo:example.com") user_id = UserID("@tulir:example.com") @@ -136,7 +123,6 @@ 
async def test_basic(store: StateStore) -> None: assert await store.is_encrypted(RoomID("!unknown-room:example.com")) is None -@pytest.mark.asyncio async def test_basic_updated(request, store: StateStore) -> None: await store_room_state(request, store) test_group = RoomID("!telegram-group:example.com") @@ -145,7 +131,6 @@ async def test_basic_updated(request, store: StateStore) -> None: assert not await store.is_encrypted(RoomID("!unencrypted-room:example.com")) -@pytest.mark.asyncio async def test_updates(request, store: StateStore) -> None: await store_room_state(request, store) room_id = RoomID("!telegram-group:example.com") diff --git a/mautrix/client/store_updater.py b/mautrix/client/store_updater.py index 27905044..35280324 100644 --- a/mautrix/client/store_updater.py +++ b/mautrix/client/store_updater.py @@ -1,10 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations +from typing import Literal import asyncio from mautrix.errors import MForbidden, MNotFound @@ -77,6 +78,17 @@ async def leave_room( if not extra_content and self.state_store: await self.state_store.set_membership(room_id, self.mxid, Membership.LEAVE) + async def knock_room( + self, + room_id_or_alias: RoomID | RoomAlias, + reason: str | None = None, + servers: list[str] | None = None, + ) -> RoomID: + room_id = await super().knock_room(room_id_or_alias, reason, servers) + if room_id and self.state_store: + await self.state_store.set_membership(room_id, self.mxid, Membership.KNOCK) + return room_id + async def invite_user( self, room_id: RoomID, @@ -110,8 +122,14 @@ async def ban_user( if not extra_content and self.state_store: await self.state_store.set_membership(room_id, user_id, Membership.BAN) - async def unban_user(self, room_id: RoomID, user_id: UserID) -> None: - await super().unban_user(room_id, user_id) + async def unban_user( + self, + room_id: RoomID, + user_id: UserID, + reason: str = "", + extra_content: dict[str, JSON] | None = None, + ) -> None: + await super().unban_user(room_id, user_id, reason=reason, extra_content=extra_content) if self.state_store: await self.state_store.set_membership(room_id, user_id, Membership.LEAVE) @@ -132,6 +150,28 @@ async def get_state(self, room_id: RoomID) -> list[StateEvent]: ) return state + async def create_room(self, *args, **kwargs) -> RoomID: + room_id = await super().create_room(*args, **kwargs) + if self.state_store: + invitee_membership = Membership.INVITE + if kwargs.get("beeper_auto_join_invites"): + invitee_membership = Membership.JOIN + for user_id in kwargs.get("invitees", []): + await self.state_store.set_membership(room_id, user_id, invitee_membership) + for evt in kwargs.get("initial_state", []): + await self.state_store.update_state( + StateEvent( + type=EventType.find(evt["type"], t_class=EventType.Class.STATE), + room_id=room_id, + 
event_id=EventID("$fake-create-id"), + sender=self.mxid, + state_key=evt.get("state_key", ""), + timestamp=0, + content=evt["content"], + ) + ) + return room_id + async def send_state_event( self, room_id: RoomID, @@ -157,20 +197,28 @@ async def send_state_event( return event_id async def get_state_event( - self, room_id: RoomID, event_type: EventType, state_key: str = "" - ) -> StateEventContent: - event = await super().get_state_event(room_id, event_type, state_key) + self, + room_id: RoomID, + event_type: EventType, + state_key: str = "", + *, + format: str = "content", + ) -> StateEventContent | StateEvent: + event = await super().get_state_event(room_id, event_type, state_key, format=format) if self.state_store: - fake_event = StateEvent( - type=event_type, - room_id=room_id, - event_id=EventID(""), - sender=UserID(""), - state_key=state_key, - timestamp=0, - content=event, - ) - await self.state_store.update_state(fake_event) + if isinstance(event, StateEvent): + await self.state_store.update_state(event) + else: + fake_event = StateEvent( + type=event_type, + room_id=room_id, + event_id=EventID(""), + sender=UserID(""), + state_key=state_key, + timestamp=0, + content=event, + ) + await self.state_store.update_state(fake_event) return event async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]: diff --git a/mautrix/client/syncer.py b/mautrix/client/syncer.py index a7c7bf23..4b8661a2 100644 --- a/mautrix/client/syncer.py +++ b/mautrix/client/syncer.py @@ -1,21 +1,22 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import Any, Awaitable, Callable, Type +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Type, TypeVar from abc import ABC, abstractmethod -from contextlib import suppress from enum import Enum, Flag, auto import asyncio +import itertools import time from mautrix.errors import MUnknownToken from mautrix.types import ( JSON, AccountDataEvent, + BaseMessageEventContentFuncs, DeviceLists, DeviceOTKCount, EphemeralEvent, @@ -23,15 +24,16 @@ EventType, Filter, FilterID, - MessageEvent, + GenericEvent, PresenceState, + SerializerError, StateEvent, StrippedStateEvent, SyncToken, ToDeviceEvent, UserID, ) -from mautrix.types.event.message import BaseMessageEventContentFuncs +from mautrix.util import background_task from mautrix.util.logging import TraceLogger from . import dispatcher @@ -39,6 +41,8 @@ EventHandler = Callable[[Event], Awaitable[None]] +T = TypeVar("T", bound=Event) + class SyncStream(Flag): INTERNAL = auto() @@ -74,13 +78,18 @@ class InternalEventType(Enum): DEVICE_OTK_COUNT = auto() +class EventHandlerProps(NamedTuple): + wait_sync: bool + sync_stream: Optional[SyncStream] + + class Syncer(ABC): loop: asyncio.AbstractEventLoop log: TraceLogger mxid: UserID - global_event_handlers: list[tuple[EventHandler, bool]] - event_handlers: dict[EventType | InternalEventType, list[tuple[EventHandler, bool]]] + global_event_handlers: dict[EventHandler, EventHandlerProps] + event_handlers: dict[EventType | InternalEventType, dict[EventHandler, EventHandlerProps]] dispatchers: dict[Type[dispatcher.Dispatcher], dispatcher.Dispatcher] syncing_task: asyncio.Task | None ignore_initial_sync: bool @@ -90,7 +99,7 @@ class Syncer(ABC): sync_store: SyncStore def __init__(self, sync_store: SyncStore) -> None: - self.global_event_handlers = [] + self.global_event_handlers = {} self.event_handlers = {} self.dispatchers = {} self.syncing_task = None @@ -153,6 +162,7 @@ def add_event_handler( event_type: 
InternalEventType | EventType, handler: EventHandler, wait_sync: bool = False, + sync_stream: Optional[SyncStream] = None, ) -> None: """ Add a new event handler. @@ -162,13 +172,15 @@ def add_event_handler( event types. handler: The handler function to add. wait_sync: Whether or not the handler should be awaited before the next sync request. + sync_stream: The sync streams to listen to. Defaults to all. """ if not isinstance(event_type, (EventType, InternalEventType)): raise ValueError("Invalid event type") + props = EventHandlerProps(wait_sync=wait_sync, sync_stream=sync_stream) if event_type == EventType.ALL: - self.global_event_handlers.append((handler, wait_sync)) + self.global_event_handlers[handler] = props else: - self.event_handlers.setdefault(event_type, []).append((handler, wait_sync)) + self.event_handlers.setdefault(event_type, {})[handler] = props def remove_event_handler( self, event_type: EventType | InternalEventType, handler: EventHandler @@ -192,16 +204,12 @@ def remove_event_handler( # No handlers for this event type registered return - # FIXME this is a bit hacky - with suppress(ValueError): - handler_list.remove((handler, True)) - with suppress(ValueError): - handler_list.remove((handler, False)) + handler_list.pop(handler, None) if len(handler_list) == 0 and event_type != EventType.ALL: del self.event_handlers[event_type] - def dispatch_event(self, event: Event, source: SyncStream) -> list[asyncio.Task]: + def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asyncio.Task]: """ Send the given event to all applicable event handlers. @@ -209,6 +217,8 @@ def dispatch_event(self, event: Event, source: SyncStream) -> list[asyncio.Task] event: The event to send. source: The sync stream the event was received in. 
""" + if event is None: + return [] if isinstance(event.content, BaseMessageEventContentFuncs): event.content.trim_reply_fallback() if getattr(event, "state_key", None) is not None: @@ -222,7 +232,9 @@ def dispatch_event(self, event: Event, source: SyncStream) -> list[asyncio.Task] else: event.type = event.type.with_class(EventType.Class.MESSAGE) setattr(event, "source", source) - return self.dispatch_manual_event(event.type, event, include_global_handlers=True) + return self.dispatch_manual_event( + event.type, event, include_global_handlers=True, source=source + ) async def _catch_errors(self, handler: EventHandler, data: Any) -> None: try: @@ -236,15 +248,25 @@ def dispatch_manual_event( data: Any, include_global_handlers: bool = False, force_synchronous: bool = False, + source: Optional[SyncStream] = None, ) -> list[asyncio.Task]: - handlers = self.event_handlers.get(event_type, []) + handlers = self.event_handlers.get(event_type, {}).items() if include_global_handlers: - handlers = self.global_event_handlers + handlers + handlers = itertools.chain(self.global_event_handlers.items(), handlers) tasks = [] - for handler, wait_sync in handlers: - task = asyncio.create_task(self._catch_errors(handler, data)) - if force_synchronous or wait_sync: - tasks.append(task) + if source is None: + source = getattr(data, "source", None) + for handler, props in handlers: + if ( + props.sync_stream is not None + and source is not None + and not props.sync_stream & source + ): + continue + if force_synchronous or props.wait_sync: + tasks.append(asyncio.create_task(self._catch_errors(handler, data))) + else: + background_task.create(self._catch_errors(handler, data)) return tasks async def run_internal_event( @@ -252,7 +274,10 @@ async def run_internal_event( ) -> None: kwargs["source"] = SyncStream.INTERNAL tasks = self.dispatch_manual_event( - event_type, custom_type or kwargs, include_global_handlers=False + event_type, + custom_type if custom_type is not None else kwargs, + 
include_global_handlers=False, + source=SyncStream.INTERNAL, ) await asyncio.gather(*tasks) @@ -261,9 +286,23 @@ def dispatch_internal_event( ) -> list[asyncio.Task]: kwargs["source"] = SyncStream.INTERNAL return self.dispatch_manual_event( - event_type, custom_type or kwargs, include_global_handlers=False + event_type, + custom_type if custom_type is not None else kwargs, + include_global_handlers=False, + source=SyncStream.INTERNAL, ) + def _try_deserialize(self, type: Type[T], data: JSON) -> T | GenericEvent: + try: + return type.deserialize(data) + except SerializerError as e: + self.log.trace("Deserialization error traceback", exc_info=True) + self.log.warning(f"Failed to deserialize {data} into {type.__name__}: {e}") + try: + return GenericEvent.deserialize(data) + except SerializerError: + return None + def handle_sync(self, data: JSON) -> list[asyncio.Task]: """ Handle a /sync object. @@ -286,21 +325,22 @@ def handle_sync(self, data: JSON) -> list[asyncio.Task]: tasks += self.dispatch_internal_event( InternalEventType.DEVICE_LISTS, custom_type=DeviceLists( - changed=device_lists.get("changed", []), left=device_lists.get("left", []) + changed=device_lists.get("changed", []), + left=device_lists.get("left", []), ), ) for raw_event in data.get("account_data", {}).get("events", []): tasks += self.dispatch_event( - AccountDataEvent.deserialize(raw_event), source=SyncStream.ACCOUNT_DATA + self._try_deserialize(AccountDataEvent, raw_event), source=SyncStream.ACCOUNT_DATA ) for raw_event in data.get("ephemeral", {}).get("events", []): tasks += self.dispatch_event( - EphemeralEvent.deserialize(raw_event), source=SyncStream.EPHEMERAL + self._try_deserialize(EphemeralEvent, raw_event), source=SyncStream.EPHEMERAL ) for raw_event in data.get("to_device", {}).get("events", []): tasks += self.dispatch_event( - ToDeviceEvent.deserialize(raw_event), source=SyncStream.TO_DEVICE + self._try_deserialize(ToDeviceEvent, raw_event), source=SyncStream.TO_DEVICE ) rooms = 
data.get("rooms", {}) @@ -308,33 +348,46 @@ def handle_sync(self, data: JSON) -> list[asyncio.Task]: for raw_event in room_data.get("state", {}).get("events", []): raw_event["room_id"] = room_id tasks += self.dispatch_event( - StateEvent.deserialize(raw_event), + self._try_deserialize(StateEvent, raw_event), source=SyncStream.JOINED_ROOM | SyncStream.STATE, ) for raw_event in room_data.get("timeline", {}).get("events", []): raw_event["room_id"] = room_id tasks += self.dispatch_event( - Event.deserialize(raw_event), + self._try_deserialize(Event, raw_event), source=SyncStream.JOINED_ROOM | SyncStream.TIMELINE, ) + + for raw_event in room_data.get("ephemeral", {}).get("events", []): + raw_event["room_id"] = room_id + tasks += self.dispatch_event( + self._try_deserialize(EphemeralEvent, raw_event), + source=SyncStream.JOINED_ROOM | SyncStream.EPHEMERAL, + ) for room_id, room_data in rooms.get("invite", {}).items(): events: list[dict[str, JSON]] = room_data.get("invite_state", {}).get("events", []) for raw_event in events: raw_event["room_id"] = room_id - raw_invite = next( - raw_event - for raw_event in events - if raw_event.get("type", "") == "m.room.member" - and raw_event.get("state_key", "") == self.mxid - ) + try: + raw_invite = next( + raw_event + for raw_event in events + if raw_event.get("type", "") == "m.room.member" + and raw_event.get("state_key", "") == self.mxid + ) + except StopIteration: + self.log.warning( + f"Corrupted invite section in sync: no invite event present for {room_id}" + ) + continue # These aren't required by the spec, so make sure they're set raw_invite.setdefault("event_id", None) raw_invite.setdefault("origin_server_ts", int(time.time() * 1000)) - invite = StateEvent.deserialize(raw_invite) + invite = self._try_deserialize(StateEvent, raw_invite) invite.unsigned.invite_room_state = [ - StrippedStateEvent.deserialize(raw_event) + self._try_deserialize(StrippedStateEvent, raw_event) for raw_event in events if raw_event != raw_invite ] @@ 
-344,7 +397,7 @@ def handle_sync(self, data: JSON) -> list[asyncio.Task]: if "state_key" in raw_event: raw_event["room_id"] = room_id tasks += self.dispatch_event( - StateEvent.deserialize(raw_event), + self._try_deserialize(StateEvent, raw_event), source=SyncStream.LEFT_ROOM | SyncStream.TIMELINE, ) return tasks @@ -403,7 +456,8 @@ async def _start(self, filter_id: FilterID | None) -> None: raise except Exception as e: self.log.warning( - f"Sync request errored: {e}, waiting {fail_sleep} seconds before continuing" + f"Sync request errored: {type(e).__name__}: {e}, waiting {fail_sleep}" + " seconds before continuing" ) await self.run_internal_event( InternalEventType.SYNC_ERRORED, error=e, sleep_for=fail_sleep diff --git a/mautrix/crypto/__init__.py b/mautrix/crypto/__init__.py index d7c934b9..743fc2c6 100644 --- a/mautrix/crypto/__init__.py +++ b/mautrix/crypto/__init__.py @@ -1,7 +1,6 @@ from .account import OlmAccount from .key_share import RejectKeyShare -from .sessions import InboundGroupSession, OutboundGroupSession, Session -from .types import DecryptedOlmEvent, DeviceIdentity, TrustState +from .sessions import InboundGroupSession, OutboundGroupSession, RatchetSafety, Session # These have to be last from .store import ( # isort: skip @@ -13,3 +12,18 @@ ) from .machine import OlmMachine # isort: skip + +__all__ = [ + "OlmAccount", + "RejectKeyShare", + "InboundGroupSession", + "OutboundGroupSession", + "Session", + "CryptoStore", + "MemoryCryptoStore", + "PgCryptoStateStore", + "PgCryptoStore", + "StateStore", + "OlmMachine", + "attachments", +] diff --git a/mautrix/crypto/account.py b/mautrix/crypto/account.py index 48a433f3..a00ada71 100644 --- a/mautrix/crypto/account.py +++ b/mautrix/crypto/account.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -10,15 +10,17 @@ from mautrix.types import ( DeviceID, + DeviceKeys, EncryptionAlgorithm, EncryptionKeyAlgorithm, IdentityKey, + KeyID, SigningKey, UserID, ) -from . import base from .sessions import Session +from .signature import sign_olm class OlmAccount(olm.Account): @@ -74,19 +76,18 @@ def new_outbound_session(self, target_key: IdentityKey, one_time_key: IdentityKe session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now() ) - def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> Dict[str, Any]: - device_keys = { - "user_id": user_id, - "device_id": device_id, - "algorithms": [EncryptionAlgorithm.OLM_V1.value, EncryptionAlgorithm.MEGOLM_V1.value], - "keys": { - f"{algorithm}:{device_id}": key for algorithm, key in self.identity_keys.items() + def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> DeviceKeys: + device_keys = DeviceKeys( + user_id=user_id, + device_id=device_id, + algorithms=[EncryptionAlgorithm.OLM_V1, EncryptionAlgorithm.MEGOLM_V1], + keys={ + KeyID(algorithm=EncryptionKeyAlgorithm(algorithm), key_id=device_id): key + for algorithm, key in self.identity_keys.items() }, - } - signature = self.sign(base.canonical_json(device_keys)) - device_keys["signatures"] = { - user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature} - } + signatures={}, + ) + device_keys.signatures[user_id] = {KeyID.ed25519(device_id): sign_olm(device_keys, self)} return device_keys def get_one_time_keys( @@ -97,12 +98,12 @@ def get_one_time_keys( self.generate_one_time_keys(new_count) keys = {} for key_id, key in self.one_time_keys.get("curve25519", {}).items(): - signature = self.sign(base.canonical_json({"key": key})) - keys[f"{EncryptionKeyAlgorithm.SIGNED_CURVE25519}:{key_id}"] = { + keys[str(KeyID.signed_curve25519(IdentityKey(key_id)))] = { "key": key, "signatures": { - user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature} + user_id: { + 
str(KeyID.ed25519(device_id)): sign_olm({"key": key}, self), + } }, } - self.mark_keys_as_published() return keys diff --git a/mautrix/crypto/attachments/__init__.py b/mautrix/crypto/attachments/__init__.py index ecf9d839..41ee2c34 100644 --- a/mautrix/crypto/attachments/__init__.py +++ b/mautrix/crypto/attachments/__init__.py @@ -1,2 +1,21 @@ -from .async_attachments import async_encrypt_attachment, async_generator_from_data -from .attachments import decrypt_attachment, encrypt_attachment, encrypted_attachment_generator +from .async_attachments import ( + async_encrypt_attachment, + async_generator_from_data, + async_inplace_encrypt_attachment, +) +from .attachments import ( + decrypt_attachment, + encrypt_attachment, + encrypted_attachment_generator, + inplace_encrypt_attachment, +) + +__all__ = [ + "async_encrypt_attachment", + "async_generator_from_data", + "async_inplace_encrypt_attachment", + "decrypt_attachment", + "encrypt_attachment", + "encrypted_attachment_generator", + "inplace_encrypt_attachment", +] diff --git a/mautrix/crypto/attachments/async_attachments.py b/mautrix/crypto/attachments/async_attachments.py index 7d1d78a7..57b6fc6f 100644 --- a/mautrix/crypto/attachments/async_attachments.py +++ b/mautrix/crypto/attachments/async_attachments.py @@ -5,7 +5,7 @@ # any purpose with or without fee is hereby granted, provided that the # above copyright notice and this permission notice appear in all copies. # -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -19,7 +19,7 @@ from mautrix.types import EncryptedFile -from .attachments import AES, SHA256, Counter, Random, _get_decryption_info +from .attachments import _get_decryption_info, _prepare_encryption, inplace_encrypt_attachment async def async_encrypt_attachment( @@ -48,16 +48,9 @@ async def async_encrypt_attachment( | hashes.sha256: Base64 encoded SHA-256 hash of the ciphertext. """ - key = Random.new().read(32) - # 8 bytes IV - iv = Random.new().read(8) - # 8 bytes counter, prefixed by the IV - ctr = Counter.new(64, prefix=iv, initial_value=0) + key, iv, cipher, sha256 = _prepare_encryption() - cipher = AES.new(key, AES.MODE_CTR, counter=ctr) - sha256 = SHA256.new() - - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() async for chunk in async_generator_from_data(data): update_crypt = partial(cipher.encrypt, chunk) @@ -71,6 +64,11 @@ async def async_encrypt_attachment( yield _get_decryption_info(key, iv, sha256) +async def async_inplace_encrypt_attachment(data: bytearray) -> EncryptedFile: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, partial(inplace_encrypt_attachment, data)) + + async def async_generator_from_data( data: bytes | Iterable[bytes] | AsyncIterable[bytes] | io.BufferedIOBase, chunk_size: int = 4 * 1024, diff --git a/mautrix/crypto/attachments/async_attachments_test.py b/mautrix/crypto/attachments/async_attachments_test.py new file mode 100644 index 00000000..571f8ead --- /dev/null +++ b/mautrix/crypto/attachments/async_attachments_test.py @@ -0,0 +1,46 @@ +# Copyright © 2019 Damir Jelić (under the Apache 2.0 license) +# Copyright © 2019 miruka (under the Apache 2.0 license) +# Copyright (c) 2022 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from __future__ import annotations + +from mautrix.types import EncryptedFile + +from .async_attachments import async_encrypt_attachment, async_inplace_encrypt_attachment +from .attachments import decrypt_attachment + +try: + from Crypto import Random +except ImportError: + from Cryptodome import Random + + +async def _get_data_cypher_keys(data: bytes) -> tuple[bytes, EncryptedFile]: + *chunks, keys = [i async for i in async_encrypt_attachment(data)] + return b"".join(chunks), keys + + +async def test_async_encrypt(): + data = b"Test bytes" + + cyphertext, keys = await _get_data_cypher_keys(data) + + plaintext = decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], keys.iv) + + assert data == plaintext + + +async def test_async_inplace_encrypt(): + orig_data = b"Test bytes" + data = bytearray(orig_data) + + keys = await async_inplace_encrypt_attachment(data) + + assert data != orig_data + + decrypt_attachment(data, keys.key.key, keys.hashes["sha256"], keys.iv, inplace=True) + + assert data == orig_data diff --git a/mautrix/crypto/attachments/attachments.py b/mautrix/crypto/attachments/attachments.py index c0054aaf..80b62212 100644 --- a/mautrix/crypto/attachments/attachments.py +++ b/mautrix/crypto/attachments/attachments.py @@ -1,6 +1,6 @@ # Copyright 2018 Zil0 (under the Apache 2.0 license) # Copyright © 2019 Damir Jelić (under the Apache 2.0 license) -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -29,7 +29,9 @@ from Cryptodome.Util import Counter -def decrypt_attachment(ciphertext: bytes, key: str, hash: str, iv: str) -> bytes: +def decrypt_attachment( + ciphertext: bytes | bytearray | memoryview, key: str, hash: str, iv: str, inplace: bool = False +) -> bytes: """Decrypt an encrypted attachment. 
Args: @@ -37,12 +39,12 @@ def decrypt_attachment(ciphertext: bytes, key: str, hash: str, iv: str) -> bytes key: AES_CTR JWK key object. hash: Base64 encoded SHA-256 hash of the ciphertext. iv: Base64 encoded 16 byte AES-CTR IV. + inplace: Should the decryption be performed in-place? + The input must be a bytearray or writable memoryview to use this. Returns: The plaintext bytes. Raises: EncryptionError: if the integrity check fails. - - """ expected_hash = unpaddedbase64.decode_base64(hash) @@ -50,21 +52,23 @@ def decrypt_attachment(ciphertext: bytes, key: str, hash: str, iv: str) -> bytes h.update(ciphertext) if h.digest() != expected_hash: - raise DecryptionError("Mismatched SHA-256 digest.") + raise DecryptionError("Mismatched SHA-256 digest") try: byte_key: bytes = unpaddedbase64.decode_base64(key) except (binascii.Error, TypeError): - raise DecryptionError("Error decoding key.") + raise DecryptionError("Error decoding key") try: byte_iv: bytes = unpaddedbase64.decode_base64(iv) + if len(byte_iv) != 16: + raise DecryptionError("Invalid IV length") prefix = byte_iv[:8] # A non-zero IV counter is not spec-compliant, but some clients still do it, # so decode the counter part too. 
initial_value = struct.unpack(">Q", byte_iv[8:])[0] except (binascii.Error, TypeError, IndexError, struct.error): - raise DecryptionError("Error decoding initial values.") + raise DecryptionError("Error decoding IV") ctr = Counter.new(64, prefix=prefix, initial_value=initial_value) @@ -73,7 +77,11 @@ def decrypt_attachment(ciphertext: bytes, key: str, hash: str, iv: str) -> bytes except ValueError as e: raise DecryptionError("Failed to create AES cipher") from e - return cipher.decrypt(ciphertext) + if inplace: + cipher.decrypt(ciphertext, ciphertext) + return ciphertext + else: + return cipher.decrypt(ciphertext) def encrypt_attachment(plaintext: bytes) -> tuple[bytes, EncryptedFile]: @@ -90,6 +98,28 @@ def encrypt_attachment(plaintext: bytes) -> tuple[bytes, EncryptedFile]: return b"".join(values[:-1]), values[-1] +def _prepare_encryption() -> tuple[bytes, bytes, AES, SHA256.SHA256Hash]: + key = Random.new().read(32) + # 8 bytes IV + iv = Random.new().read(8) + # 8 bytes counter, prefixed by the IV + ctr = Counter.new(64, prefix=iv, initial_value=0) + + cipher = AES.new(key, AES.MODE_CTR, counter=ctr) + sha256 = SHA256.new() + + return key, iv, cipher, sha256 + + +def inplace_encrypt_attachment(data: bytearray | memoryview) -> EncryptedFile: + key, iv, cipher, sha256 = _prepare_encryption() + + cipher.encrypt(plaintext=data, output=data) + sha256.update(data) + + return _get_decryption_info(key, iv, sha256) + + def encrypted_attachment_generator( data: bytes | Iterable[bytes], ) -> Generator[bytes | EncryptedFile, None, None]: @@ -105,21 +135,10 @@ def encrypted_attachment_generator( Yields: The encrypted bytes for each chunk of data. - The last yielded value will be a dict containing the info needed to - decrypt data. The keys are: - | key: AES-CTR JWK key object. - | iv: Base64 encoded 16 byte AES-CTR IV. - | hashes.sha256: Base64 encoded SHA-256 hash of the ciphertext. + The last yielded value will be a dict containing the info needed to decrypt data. 
""" - key = Random.new().read(32) - # 8 bytes IV - iv = Random.new().read(8) - # 8 bytes counter, prefixed by the IV - ctr = Counter.new(64, prefix=iv, initial_value=0) - - cipher = AES.new(key, AES.MODE_CTR, counter=ctr) - sha256 = SHA256.new() + key, iv, cipher, sha256 = _prepare_encryption() if isinstance(data, bytes): data = [data] diff --git a/mautrix/crypto/attachments/attachments_test.py b/mautrix/crypto/attachments/attachments_test.py new file mode 100644 index 00000000..a8cb83b6 --- /dev/null +++ b/mautrix/crypto/attachments/attachments_test.py @@ -0,0 +1,111 @@ +# Copyright © 2019 Damir Jelić (under the Apache 2.0 license) +# Copyright (c) 2022 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import pytest +import unpaddedbase64 + +from mautrix.errors import DecryptionError + +from .attachments import decrypt_attachment, encrypt_attachment, inplace_encrypt_attachment + +try: + from Crypto import Random +except ImportError: + from Cryptodome import Random + + +def test_encrypt(): + data = b"Test bytes" + + cyphertext, keys = encrypt_attachment(data) + + plaintext = decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], keys.iv) + + assert data == plaintext + + +def test_inplace_encrypt(): + orig_data = b"Test bytes" + data = bytearray(orig_data) + + keys = inplace_encrypt_attachment(data) + + assert data != orig_data + + decrypt_attachment(data, keys.key.key, keys.hashes["sha256"], keys.iv, inplace=True) + + assert data == orig_data + + +def test_hash_verification(): + data = b"Test bytes" + + cyphertext, keys = encrypt_attachment(data) + + with pytest.raises(DecryptionError): + decrypt_attachment(cyphertext, keys.key.key, "Fake hash", keys.iv) + + +def test_invalid_key(): + data = b"Test bytes" + + cyphertext, keys = encrypt_attachment(data) + + with 
pytest.raises(DecryptionError): + decrypt_attachment(cyphertext, "Fake key", keys.hashes["sha256"], keys.iv) + + +def test_invalid_iv(): + data = b"Test bytes" + + cyphertext, keys = encrypt_attachment(data) + + with pytest.raises(DecryptionError): + decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], "Fake iv") + + +def test_short_key(): + data = b"Test bytes" + + cyphertext, keys = encrypt_attachment(data) + + with pytest.raises(DecryptionError): + decrypt_attachment( + cyphertext, + unpaddedbase64.encode_base64(b"Fake key", urlsafe=True), + keys["hashes"]["sha256"], + keys["iv"], + ) + + +def test_short_iv(): + data = b"Test bytes" + + cyphertext, keys = encrypt_attachment(data) + + with pytest.raises(DecryptionError): + decrypt_attachment( + cyphertext, + keys.key.key, + keys.hashes["sha256"], + unpaddedbase64.encode_base64(b"F" + b"\x00" * 8), + ) + + +def test_fake_key(): + data = b"Test bytes" + + cyphertext, keys = encrypt_attachment(data) + + fake_key = Random.new().read(32) + + plaintext = decrypt_attachment( + cyphertext, + unpaddedbase64.encode_base64(fake_key, urlsafe=True), + keys["hashes"]["sha256"], + keys["iv"], + ) + assert plaintext != data diff --git a/mautrix/crypto/base.py b/mautrix/crypto/base.py index 4b27b4c9..f7015ca1 100644 --- a/mautrix/crypto/base.py +++ b/mautrix/crypto/base.py @@ -1,64 +1,66 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import Any, Awaitable, Callable, Dict +from typing import Awaitable, Callable import asyncio -import functools -import json -import sys - -import olm +from mautrix.errors import MForbidden, MNotFound from mautrix.types import ( - DeviceID, - EncryptionKeyAlgorithm, + EventType, IdentityKey, RequestedKeyInfo, + RoomEncryptionStateEventContent, RoomID, + RoomKeyEventContent, SessionID, - SigningKey, + TrustState, UserID, ) from mautrix.util.logging import TraceLogger from .. import client as cli, crypto - -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -class SignedObject(TypedDict): - signatures: Dict[UserID, Dict[str, str]] - unsigned: Any +from .ssss import Machine as SSSSMachine class BaseOlmMachine: client: cli.Client + ssss: SSSSMachine log: TraceLogger crypto_store: crypto.CryptoStore state_store: crypto.StateStore account: account.OlmAccount - allow_unverified_devices: bool - share_to_unverified_devices: bool + send_keys_min_trust: TrustState + share_keys_min_trust: TrustState allow_key_share: Callable[[crypto.DeviceIdentity, RequestedKeyInfo], Awaitable[bool]] + delete_outbound_keys_on_ack: bool + dont_store_outbound_keys: bool + delete_previous_keys_on_receive: bool + ratchet_keys_on_decrypt: bool + delete_fully_used_keys_on_decrypt: bool + delete_keys_on_device_delete: bool + disable_device_change_key_rotation: bool + # Futures that wait for responses to a key request - _key_request_waiters: Dict[SessionID, asyncio.Future] + _key_request_waiters: dict[SessionID, asyncio.Future] # Futures that wait for a session to be received (either normally or through a key request) - _inbound_session_waiters: Dict[SessionID, asyncio.Future] + _inbound_session_waiters: dict[SessionID, asyncio.Future] - _prev_unwedge: Dict[IdentityKey, float] + _prev_unwedge: dict[IdentityKey, float] + _fetch_keys_lock: asyncio.Lock + _megolm_decrypt_lock: asyncio.Lock + 
_share_keys_lock: asyncio.Lock + _last_key_share: float + _cs_fetch_attempted: set[UserID] async def wait_for_session( - self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID, timeout: float = 3 + self, room_id: RoomID, session_id: SessionID, timeout: float = 3 ) -> bool: try: fut = self._inbound_session_waiters[session_id] @@ -68,7 +70,7 @@ async def wait_for_session( try: return await asyncio.wait_for(asyncio.shield(fut), timeout) except asyncio.TimeoutError: - return await self.crypto_store.has_group_session(room_id, sender_key, session_id) + return await self.crypto_store.has_group_session(room_id, session_id) def _mark_session_received(self, session_id: SessionID) -> None: try: @@ -76,22 +78,30 @@ def _mark_session_received(self, session_id: SessionID) -> None: except KeyError: return - -canonical_json = functools.partial( - json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True -) - - -def verify_signature_json( - data: "SignedObject", user_id: UserID, device_id: DeviceID, key: SigningKey -) -> bool: - data_copy = {**data} - data_copy.pop("unsigned", None) - signatures = data_copy.pop("signatures") - signature = signatures[user_id][f"{EncryptionKeyAlgorithm.ED25519}:{device_id}"] - signed_data = canonical_json(data_copy) - try: - olm.ed25519_verify(key, signed_data, signature) - return True - except olm.OlmVerifyError: - return False + async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None: + encryption_info = await self.state_store.get_encryption_info(evt.room_id) + if not encryption_info: + self.log.warning( + f"Encryption info for {evt.room_id} not found in state store, fetching from server" + ) + try: + encryption_info = await self.client.get_state_event( + evt.room_id, EventType.ROOM_ENCRYPTION + ) + except (MNotFound, MForbidden) as e: + self.log.warning( + f"Failed to get encryption info for {evt.room_id} from server: {e}," + " using defaults" + ) + encryption_info = RoomEncryptionStateEventContent() + if 
not encryption_info: + self.log.warning( + f"Didn't find encryption info for {evt.room_id} on server either," + " using defaults" + ) + encryption_info = RoomEncryptionStateEventContent() + + if not evt.beeper_max_age_ms: + evt.beeper_max_age_ms = encryption_info.rotation_period_ms + if not evt.beeper_max_messages: + evt.beeper_max_messages = encryption_info.rotation_period_msgs diff --git a/mautrix/crypto/cross_signing.py b/mautrix/crypto/cross_signing.py new file mode 100644 index 00000000..7577cd5d --- /dev/null +++ b/mautrix/crypto/cross_signing.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from ..types import ( + JSON, + CrossSigner, + CrossSigningKeys, + CrossSigningUsage, + DeviceIdentity, + EventType, + KeyID, + UserID, +) +from .cross_signing_key import CrossSigningPrivateKeys, CrossSigningPublicKeys, CrossSigningSeeds +from .device_lists import DeviceListMachine +from .signature import sign_olm +from .ssss import Key as SSSSKey + + +class CrossSigningMachine(DeviceListMachine): + _cross_signing_public_keys: CrossSigningPublicKeys | None + _cross_signing_public_keys_fetched: bool + _cross_signing_private_keys: CrossSigningPrivateKeys | None + + async def verify_with_recovery_key(self, recovery_key: str) -> None: + if not self.account.shared: + raise ValueError("Device keys must be shared before verifying with recovery key") + key_id, key_data = await self.ssss.get_default_key_data() + ssss_key = key_data.verify_recovery_key(key_id, recovery_key) + seeds = await self._fetch_cross_signing_keys_from_ssss(ssss_key) + self._import_cross_signing_keys(seeds) + await self.sign_own_device(self.own_identity) + + def _import_cross_signing_keys(self, seeds: CrossSigningSeeds) -> None: + self._cross_signing_private_keys = seeds.to_keys() + 
self._cross_signing_public_keys = self._cross_signing_private_keys.public_keys + + async def generate_recovery_key( + self, passphrase: str | None = None, seeds: CrossSigningSeeds | None = None + ) -> str: + if not self.account.shared: + raise ValueError("Device keys must be shared before generating recovery key") + seeds = seeds or CrossSigningSeeds.generate() + ssss_key = await self.ssss.generate_and_upload_key(passphrase) + await self._upload_cross_signing_keys_to_ssss(ssss_key, seeds) + await self._publish_cross_signing_keys(seeds.to_keys()) + await self.ssss.set_default_key_id(ssss_key.id) + await self.sign_own_device(self.own_identity) + return ssss_key.recovery_key + + async def _fetch_cross_signing_keys_from_ssss(self, key: SSSSKey) -> CrossSigningSeeds: + return CrossSigningSeeds( + master_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_MASTER, key + ), + user_signing_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_USER_SIGNING, key + ), + self_signing_key=await self.ssss.get_decrypted_account_data( + EventType.CROSS_SIGNING_SELF_SIGNING, key + ), + ) + + async def _upload_cross_signing_keys_to_ssss( + self, key: SSSSKey, seeds: CrossSigningSeeds + ) -> None: + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_MASTER, seeds.master_key, key + ) + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_USER_SIGNING, seeds.user_signing_key, key + ) + await self.ssss.set_encrypted_account_data( + EventType.CROSS_SIGNING_SELF_SIGNING, seeds.self_signing_key, key + ) + + async def get_own_cross_signing_public_keys(self) -> CrossSigningPublicKeys | None: + if self._cross_signing_public_keys or self._cross_signing_public_keys_fetched: + return self._cross_signing_public_keys + keys = await self.get_cross_signing_public_keys(self.client.mxid) + self._cross_signing_public_keys_fetched = True + if keys: + self._cross_signing_public_keys = keys + return keys + + async def 
get_cross_signing_public_keys( + self, user_id: UserID + ) -> CrossSigningPublicKeys | None: + db_keys = await self.crypto_store.get_cross_signing_keys(user_id) + if CrossSigningUsage.MASTER not in db_keys and user_id not in self._cs_fetch_attempted: + self.log.debug(f"Didn't find any cross-signing keys for {user_id}, fetching...") + async with self._fetch_keys_lock: + if user_id not in self._cs_fetch_attempted: + self._cs_fetch_attempted.add(user_id) + await self._fetch_keys([user_id], include_untracked=True) + db_keys = await self.crypto_store.get_cross_signing_keys(user_id) + if CrossSigningUsage.MASTER not in db_keys: + return None + return CrossSigningPublicKeys( + master_key=db_keys[CrossSigningUsage.MASTER].key, + self_signing_key=( + db_keys[CrossSigningUsage.SELF].key if CrossSigningUsage.SELF in db_keys else None + ), + user_signing_key=( + db_keys[CrossSigningUsage.USER].key if CrossSigningUsage.USER in db_keys else None + ), + ) + + async def sign_own_device(self, device: DeviceIdentity) -> None: + full_keys = await self._get_full_device_keys(device) + ssk = self._cross_signing_private_keys.self_signing_key + signature = sign_olm(full_keys, ssk) + full_keys.signatures = {self.client.mxid: {KeyID.ed25519(ssk.public_key): signature}} + await self.client.upload_one_signature(device.user_id, device.device_id, full_keys) + await self.crypto_store.put_signature( + CrossSigner(device.user_id, device.signing_key), + CrossSigner(self.client.mxid, ssk.public_key), + signature, + ) + + async def _publish_cross_signing_keys( + self, + keys: CrossSigningPrivateKeys, + auth: dict[str, JSON] | None = None, + ) -> None: + public = keys.public_keys + master_key = CrossSigningKeys( + user_id=self.client.mxid, + usage=[CrossSigningUsage.MASTER], + keys={KeyID.ed25519(public.master_key): public.master_key}, + ) + master_key.signatures = { + self.client.mxid: { + KeyID.ed25519(self.client.device_id): sign_olm(master_key, self.account), + } + } + self_key = CrossSigningKeys( 
+ user_id=self.client.mxid, + usage=[CrossSigningUsage.SELF], + keys={KeyID.ed25519(public.self_signing_key): public.self_signing_key}, + ) + self_key.signatures = { + self.client.mxid: { + KeyID.ed25519(public.master_key): sign_olm(self_key, keys.master_key), + } + } + user_key = CrossSigningKeys( + user_id=self.client.mxid, + usage=[CrossSigningUsage.USER], + keys={KeyID.ed25519(public.user_signing_key): public.user_signing_key}, + ) + user_key.signatures = { + self.client.mxid: { + KeyID.ed25519(public.master_key): sign_olm(user_key, keys.master_key), + } + } + await self.client.upload_cross_signing_keys( + keys={ + CrossSigningUsage.MASTER: master_key, + CrossSigningUsage.SELF: self_key, + CrossSigningUsage.USER: user_key, + }, + auth=auth, + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.MASTER, public.master_key + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.SELF, public.self_signing_key + ) + await self.crypto_store.put_cross_signing_key( + self.client.mxid, CrossSigningUsage.USER, public.user_signing_key + ) + self._cross_signing_private_keys = keys + self._cross_signing_public_keys = public diff --git a/mautrix/crypto/cross_signing_key.py b/mautrix/crypto/cross_signing_key.py new file mode 100644 index 00000000..f4e3c1c1 --- /dev/null +++ b/mautrix/crypto/cross_signing_key.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from typing import NamedTuple + +import olm + +from mautrix.crypto.ssss.util import cryptorand +from mautrix.types import SigningKey + + +class CrossSigningPublicKeys(NamedTuple): + master_key: SigningKey + self_signing_key: SigningKey + user_signing_key: SigningKey + + +class CrossSigningPrivateKeys(NamedTuple): + master_key: olm.PkSigning + self_signing_key: olm.PkSigning + user_signing_key: olm.PkSigning + + @property + def public_keys(self) -> CrossSigningPublicKeys: + return CrossSigningPublicKeys( + master_key=self.master_key.public_key, + self_signing_key=self.self_signing_key.public_key, + user_signing_key=self.user_signing_key.public_key, + ) + + +class CrossSigningSeeds(NamedTuple): + master_key: bytes + self_signing_key: bytes + user_signing_key: bytes + + def to_keys(self) -> CrossSigningPrivateKeys: + return CrossSigningPrivateKeys( + master_key=olm.PkSigning(self.master_key), + self_signing_key=olm.PkSigning(self.self_signing_key), + user_signing_key=olm.PkSigning(self.user_signing_key), + ) + + @classmethod + def generate(cls) -> "CrossSigningSeeds": + return cls( + master_key=cryptorand.read(32), + self_signing_key=cryptorand.read(32), + user_signing_key=cryptorand.read(32), + ) diff --git a/mautrix/crypto/decrypt_megolm.py b/mautrix/crypto/decrypt_megolm.py index fd654808..fd7a7eef 100644 --- a/mautrix/crypto/decrypt_megolm.py +++ b/mautrix/crypto/decrypt_megolm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -15,13 +15,20 @@ SessionNotFound, VerificationError, ) -from mautrix.types import EncryptedEvent, EncryptedMegolmEventContent, EncryptionAlgorithm, Event +from mautrix.types import ( + EncryptedEvent, + EncryptedMegolmEventContent, + EncryptionAlgorithm, + Event, + SessionID, + TrustState, +) -from .base import BaseOlmMachine -from .types import TrustState +from .device_lists import DeviceListMachine +from .sessions import InboundGroupSession -class MegolmDecryptionMachine(BaseOlmMachine): +class MegolmDecryptionMachine(DeviceListMachine): async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event: """ Decrypt an event that was encrypted using Megolm. @@ -39,38 +46,64 @@ async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event: raise DecryptionError("Unsupported event content class") elif evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1: raise DecryptionError("Unsupported event encryption algorithm") - session = await self.crypto_store.get_group_session( - evt.room_id, evt.content.sender_key, evt.content.session_id - ) - if session is None: - # TODO check if olm session is wedged - raise SessionNotFound(evt.content.session_id, evt.content.sender_key) - try: - plaintext, index = session.decrypt(evt.content.ciphertext) - except olm.OlmGroupSessionError as e: - raise DecryptionError("Failed to decrypt megolm event") from e - if not await self.crypto_store.validate_message_index( - evt.content.sender_key, evt.content.session_id, evt.event_id, index, evt.timestamp - ): - raise DuplicateMessageIndex() + async with self._megolm_decrypt_lock: + session = await self.crypto_store.get_group_session( + evt.room_id, evt.content.session_id + ) + if session is None: + # TODO check if olm session is wedged + raise SessionNotFound(evt.content.session_id, evt.content.sender_key) + try: + plaintext, index = session.decrypt(evt.content.ciphertext) + except olm.OlmGroupSessionError as e: + raise 
DecryptionError("Failed to decrypt megolm event") from e + if not await self.crypto_store.validate_message_index( + session.sender_key, SessionID(session.id), evt.event_id, index, evt.timestamp + ): + raise DuplicateMessageIndex() + await self._ratchet_session(session, index) - verified = False + forwarded_keys = False if ( evt.content.device_id == self.client.device_id and session.signing_key == self.account.signing_key - and evt.content.sender_key == self.account.identity_key + and session.sender_key == self.account.identity_key + and not session.forwarding_chain ): - verified = True + trust_level = TrustState.VERIFIED else: - device = await self.crypto_store.get_device(evt.sender, evt.content.device_id) - if device and device.trust == TrustState.VERIFIED and not session.forwarding_chain: - if ( + device = await self.get_or_fetch_device_by_key(evt.sender, session.sender_key) + if not session.forwarding_chain or ( + len(session.forwarding_chain) == 1 + and session.forwarding_chain[0] == session.sender_key + ): + if not device: + self.log.debug( + f"Couldn't resolve trust level of session {session.id}: " + f"sent by unknown device {evt.sender}/{session.sender_key}" + ) + trust_level = TrustState.UNKNOWN_DEVICE + elif ( device.signing_key != session.signing_key - or device.identity_key != evt.content.sender_key + or device.identity_key != session.sender_key ): raise VerificationError() - verified = True - # else: TODO query device keys? 
+ else: + trust_level = await self.resolve_trust(device) + else: + forwarded_keys = True + last_chain_item = session.forwarding_chain[-1] + received_from = await self.crypto_store.find_device_by_key( + evt.sender, last_chain_item + ) + if received_from: + trust_level = await self.resolve_trust(received_from) + else: + self.log.debug( + f"Couldn't resolve trust level of session {session.id}: " + f"forwarding chain ends with unknown device {last_chain_item}" + ) + trust_level = TrustState.FORWARDED try: data = json.loads(plaintext) @@ -100,6 +133,60 @@ async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event: result.unsigned = evt.unsigned result.type = result.type.with_class(evt.type.t_class) result["mautrix"] = { - "verified": verified, + "trust_state": trust_level, + "forwarded_keys": forwarded_keys, + "was_encrypted": True, } return result + + async def _ratchet_session(self, sess: InboundGroupSession, index: int) -> None: + expected_message_index = sess.ratchet_safety.next_index + did_modify = True + if index > expected_message_index: + sess.ratchet_safety.missed_indices += list(range(expected_message_index, index)) + sess.ratchet_safety.next_index = index + 1 + elif index == expected_message_index: + sess.ratchet_safety.next_index = index + 1 + else: + try: + sess.ratchet_safety.missed_indices.remove(index) + except ValueError: + did_modify = False + # Use presence of received_at as a sign that this is a recent megolm session, + # and therefore it's safe to drop missed indices entirely. 
+ if ( + sess.received_at + and sess.ratchet_safety.missed_indices + and sess.ratchet_safety.missed_indices[0] < expected_message_index - 10 + ): + i = 0 + for i, lost_index in enumerate(sess.ratchet_safety.missed_indices): + if lost_index < expected_message_index - 10: + sess.ratchet_safety.lost_indices.append(lost_index) + else: + break + sess.ratchet_safety.missed_indices = sess.ratchet_safety.missed_indices[i + 1 :] + ratchet_target_index = sess.ratchet_safety.next_index + if len(sess.ratchet_safety.missed_indices) > 0: + ratchet_target_index = min(sess.ratchet_safety.missed_indices) + self.log.debug( + f"Ratchet safety info for {sess.id}: {sess.ratchet_safety}, {ratchet_target_index=}" + ) + sess_id = SessionID(sess.id) + if ( + sess.max_messages + and ratchet_target_index >= sess.max_messages + and not sess.ratchet_safety.missed_indices + and self.delete_fully_used_keys_on_decrypt + ): + self.log.info(f"Deleting fully used session {sess.id}") + await self.crypto_store.redact_group_session( + sess.room_id, sess_id, reason="maximum messages reached" + ) + return + elif sess.first_known_index < ratchet_target_index and self.ratchet_keys_on_decrypt: + self.log.info(f"Ratcheting session {sess.id} to {ratchet_target_index}") + sess = sess.ratchet_to(ratchet_target_index) + elif not did_modify: + return + await self.crypto_store.put_group_session(sess.room_id, sess.sender_key, sess_id, sess) diff --git a/mautrix/crypto/decrypt_olm.py b/mautrix/crypto/decrypt_olm.py index 4cc777c9..6eef76f5 100644 --- a/mautrix/crypto/decrypt_olm.py +++ b/mautrix/crypto/decrypt_olm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -10,6 +10,7 @@ from mautrix.errors import DecryptionError, MatchingSessionDecryptionError from mautrix.types import ( + DecryptedOlmEvent, EncryptedOlmEventContent, EncryptionAlgorithm, IdentityKey, @@ -18,10 +19,10 @@ ToDeviceEvent, UserID, ) +from mautrix.util import background_task from .base import BaseOlmMachine from .sessions import Session -from .types import DecryptedOlmEvent class OlmDecryptionMachine(BaseOlmMachine): @@ -74,19 +75,19 @@ async def _decrypt_olm_ciphertext( f"Found matching session yet decryption failed for sender {sender}" f" with key {sender_key}" ) - asyncio.create_task(self._unwedge_session(sender, sender_key)) + background_task.create(self._unwedge_session(sender, sender_key)) raise if not plaintext: if message.type != OlmMsgType.PREKEY: - asyncio.create_task(self._unwedge_session(sender, sender_key)) + background_task.create(self._unwedge_session(sender, sender_key)) raise DecryptionError("Decryption failed for normal message") self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}") try: session = await self._create_inbound_session(sender_key, message.body) except olm.OlmSessionError as e: - asyncio.create_task(self._unwedge_session(sender, sender_key)) + background_task.create(self._unwedge_session(sender, sender_key)) raise DecryptionError("Failed to create new session from prekey message") from e self.log.debug( f"Created inbound session {session.id} for {sender} (sender key: {sender_key})" diff --git a/mautrix/crypto/device_lists.py b/mautrix/crypto/device_lists.py index 9fe01ea7..b00b104c 100644 --- a/mautrix/crypto/device_lists.py +++ b/mautrix/crypto/device_lists.py @@ -1,21 +1,48 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from typing import Dict, List, Optional +from __future__ import annotations from mautrix.errors import DeviceValidationError -from mautrix.types import DeviceID, DeviceKeys, IdentityKey, SyncToken, UserID +from mautrix.types import ( + CrossSigner, + CrossSigningKeys, + CrossSigningUsage, + DeviceID, + DeviceIdentity, + DeviceKeys, + EncryptionKeyAlgorithm, + IdentityKey, + KeyID, + QueryKeysResponse, + SigningKey, + SyncToken, + TrustState, + UserID, +) -from .base import BaseOlmMachine, verify_signature_json -from .types import DeviceIdentity, TrustState +from .base import BaseOlmMachine +from .signature import verify_signature_json class DeviceListMachine(BaseOlmMachine): + @property + def own_identity(self) -> DeviceIdentity: + return DeviceIdentity( + user_id=self.client.mxid, + device_id=self.client.device_id, + identity_key=self.account.identity_key, + signing_key=self.account.signing_key, + trust=TrustState.VERIFIED, + deleted=False, + name="", + ) + async def _fetch_keys( - self, users: List[UserID], since: SyncToken = "", include_untracked: bool = False - ) -> Dict[UserID, Dict[DeviceID, DeviceIdentity]]: + self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False + ) -> dict[UserID, dict[DeviceID, DeviceIdentity]]: if not include_untracked: users = await self.crypto_store.filter_tracked_users(users) if len(users) == 0: @@ -23,54 +50,203 @@ async def _fetch_keys( users = set(users) self.log.trace(f"Querying keys for {users}") - keys = await self.client.query_keys(users, token=since) + resp = await self.client.query_keys(users, token=since) + missing_users = users.copy() - for server, err in keys.failures.items(): + for server, err in resp.failures.items(): self.log.warning(f"Query keys failure for {server}: {err}") data = {} - for user_id, devices in keys.device_keys.items(): - users.remove(user_id) + for user_id, devices in resp.device_keys.items(): + missing_users.remove(user_id) + async with self.crypto_store.transaction(): + 
data[user_id] = await self._process_fetched_keys(user_id, devices, resp) + + for user_id in missing_users: + self.log.warning(f"Didn't get any devices for user {user_id}") + + return data + + async def _process_fetched_keys( + self, + user_id: UserID, + devices: dict[DeviceID, DeviceKeys], + resp: QueryKeysResponse, + ) -> dict[DeviceID, DeviceIdentity]: + new_devices = {} + existing_devices = (await self.crypto_store.get_devices(user_id)) or {} - new_devices = {} - existing_devices = (await self.crypto_store.get_devices(user_id)) or {} + self.log.trace( + f"Updating devices for {user_id}, got {len(devices)}, " + f"have {len(existing_devices)} in store" + ) + changed = False + ssks = resp.self_signing_keys.get(user_id) + ssk = ssks.first_ed25519_key if ssks else None + for device_id, device_keys in devices.items(): + try: + existing = existing_devices[device_id] + except KeyError: + existing = None + changed = True + self.log.trace(f"Validating device {device_keys} of {user_id}") + try: + new_device = await self._validate_device(user_id, device_id, device_keys, existing) + except DeviceValidationError as e: + self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}") + else: + if new_device: + new_devices[device_id] = new_device + await self._store_device_self_signatures(device_keys, ssk) + self.log.debug( + f"Storing new device list for {user_id} containing {len(new_devices)} devices" + ) + await self.crypto_store.put_devices(user_id, new_devices) - self.log.trace( - f"Updating devices for {user_id}, got {len(devices)}, " - f"have {len(existing_devices)} in store" + if changed or len(new_devices) != len(existing_devices): + if self.delete_keys_on_device_delete: + for device_id in existing_devices.keys() - new_devices.keys(): + device = existing_devices[device_id] + removed_ids = await self.crypto_store.redact_group_sessions( + room_id=None, sender_key=device.identity_key, reason="device removed" + ) + self.log.info( + "Redacted megolm sessions 
sent by removed device " + f"{device.user_id}/{device.device_id}: {removed_ids}" + ) + await self.on_devices_changed(user_id) + + await self._store_cross_signing_keys(resp, user_id) + + return new_devices + + async def _store_device_self_signatures( + self, device_keys: DeviceKeys, self_signing_key: SigningKey | None + ) -> None: + device_desc = f"Device {device_keys.user_id}/{device_keys.device_id}" + try: + self_signatures = device_keys.signatures[device_keys.user_id].copy() + except KeyError: + self.log.warning(f"{device_desc} doesn't have any signatures from the user") + return + if len(device_keys.signatures) > 1: + self.log.debug( + f"{device_desc} has signatures from other users (%s)", + set(device_keys.signatures.keys()) - {device_keys.user_id}, + ) + + device_self_sig = self_signatures.pop( + KeyID(EncryptionKeyAlgorithm.ED25519, device_keys.device_id) + ) + target = CrossSigner(device_keys.user_id, device_keys.ed25519) + # This one is already validated by _validate_device + await self.crypto_store.put_signature(target, target, device_self_sig) + + try: + cs_self_sig = self_signatures.pop( + KeyID(EncryptionKeyAlgorithm.ED25519, self_signing_key) ) - changed = False - for device_id, keys in devices.items(): - try: - existing = existing_devices[device_id] - except KeyError: - existing = None - changed = True - self.log.trace(f"Validating device {keys} of {user_id}") - try: - new_device = await self._validate_device(user_id, device_id, keys, existing) - except DeviceValidationError as e: - self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}") - else: - if new_device: - new_devices[device_id] = new_device + except KeyError: + self.log.warning(f"{device_desc} isn't cross-signed") + else: + is_valid_self_sig = verify_signature_json( + device_keys.serialize(), device_keys.user_id, self_signing_key, self_signing_key + ) + if is_valid_self_sig: + signer = CrossSigner(device_keys.user_id, self_signing_key) + await 
self.crypto_store.put_signature(target, signer, cs_self_sig) + else: + self.log.warning(f"{device_desc} doesn't have a valid cross-signing signature") + + if len(self_signatures) > 0: self.log.debug( - f"Storing new device list for {user_id} containing {len(new_devices)} devices" + f"{device_desc} has signatures from unexpected keys (%s)", + set(self_signatures.keys()), ) - await self.crypto_store.put_devices(user_id, new_devices) - data[user_id] = new_devices - if changed or len(new_devices) != len(existing_devices): - await self.on_devices_changed(user_id) + async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: UserID) -> None: + new_keys: dict[CrossSigningUsage, CrossSigningKeys] = {} + try: + master = new_keys[CrossSigningUsage.MASTER] = resp.master_keys[user_id] + except KeyError: + self.log.debug(f"Didn't get a cross-signing master key for {user_id}") + return + try: + new_keys[CrossSigningUsage.SELF] = resp.self_signing_keys[user_id] + except KeyError: + self.log.debug(f"Didn't get a cross-signing self-signing key for {user_id}") + return + try: + new_keys[CrossSigningUsage.USER] = resp.user_signing_keys[user_id] + except KeyError: + pass + current_keys = await self.crypto_store.get_cross_signing_keys(user_id) + for usage, key in current_keys.items(): + if usage in new_keys and key.key != new_keys[usage].first_ed25519_key: + num = await self.crypto_store.drop_signatures_by_key(CrossSigner(user_id, key.key)) + if num >= 0: + self.log.debug( + f"Dropped {num} signatures made by key {user_id}/{key.key} ({usage})" + " as it has been replaced" + ) + for usage, key in new_keys.items(): + actual_key = key.first_ed25519_key + self.log.debug(f"Storing cross-signing key for {user_id}: {actual_key} (type {usage})") + await self.crypto_store.put_cross_signing_key(user_id, usage, actual_key) - for user_id in users: - self.log.warning(f"Didn't get any keys for user {user_id}") + if usage != CrossSigningUsage.MASTER and ( + 
KeyID(EncryptionKeyAlgorithm.ED25519, master.first_ed25519_key) + not in key.signatures[user_id] + ): + self.log.warning( + f"Cross-signing key {user_id}/{actual_key}/{usage}" + " doesn't seem to have a signature from the master key" + ) - return data + for signer_user_id, signatures in key.signatures.items(): + for key_id, signature in signatures.items(): + signing_key = SigningKey(key_id.key_id) + if signer_user_id == user_id: + try: + device = resp.device_keys[signer_user_id][DeviceID(key_id.key_id)] + signing_key = device.ed25519 + except KeyError: + pass + if not signing_key or len(signing_key) != 43: + self.log.debug( + f"Cross-signing key {user_id}/{actual_key} has a signature from " + f"an unknown key {key_id}" + ) + continue + signing_key_log = signing_key + if signing_key != key_id.key_id: + signing_key_log = f"{signing_key} ({key_id})" + self.log.debug( + f"Verifying cross-signing key {user_id}/{actual_key} " + f"with key {signer_user_id}/{signing_key_log}" + ) + is_valid_sig = verify_signature_json( + key.serialize(), signer_user_id, key_id.key_id, signing_key + ) + if is_valid_sig: + self.log.debug(f"Signature from {signing_key_log} for {key_id} verified") + await self.crypto_store.put_signature( + target=CrossSigner(user_id, actual_key), + signer=CrossSigner(signer_user_id, signing_key), + signature=signature, + ) + else: + self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}") + + async def _get_full_device_keys(self, device: DeviceIdentity) -> DeviceKeys: + resp = await self.client.query_keys({device.user_id: [device.device_id]}) + keys = resp.device_keys[device.user_id][device.device_id] + await self._validate_device(device.user_id, device.device_id, keys, device) + return keys async def get_or_fetch_device( self, user_id: UserID, device_id: DeviceID - ) -> Optional[DeviceIdentity]: + ) -> DeviceIdentity | None: device = await self.crypto_store.get_device(user_id, device_id) if device is not None: return device @@ -82,7 +258,7 
@@ async def get_or_fetch_device( async def get_or_fetch_device_by_key( self, user_id: UserID, identity_key: IdentityKey - ) -> Optional[DeviceIdentity]: + ) -> DeviceIdentity | None: device = await self.crypto_store.find_device_by_key(user_id, identity_key) if device is not None: return device @@ -93,6 +269,8 @@ async def get_or_fetch_device_by_key( return None async def on_devices_changed(self, user_id: UserID) -> None: + if self.disable_device_change_key_rotation: + return shared_rooms = await self.state_store.find_shared_rooms(user_id) self.log.debug( f"Devices of {user_id} changed, invalidating group session in {shared_rooms}" @@ -104,12 +282,16 @@ async def _validate_device( user_id: UserID, device_id: DeviceID, device_keys: DeviceKeys, - existing: Optional[DeviceIdentity] = None, + existing: DeviceIdentity | None = None, ) -> DeviceIdentity: if user_id != device_keys.user_id: - raise DeviceValidationError("mismatching user ID in parameter and keys object") + raise DeviceValidationError( + f"mismatching user ID (expected {user_id}, got {device_keys.user_id})" + ) elif device_id != device_keys.device_id: - raise DeviceValidationError("mismatching device ID in parameter and keys object") + raise DeviceValidationError( + f"mismatching device ID (expected {device_id}, got {device_keys.device_id})" + ) signing_key = device_keys.ed25519 if not signing_key: @@ -119,7 +301,10 @@ async def _validate_device( raise DeviceValidationError("didn't find curve25519 identity key") if existing and existing.signing_key != signing_key: - raise DeviceValidationError("received update for device with different signing key") + raise DeviceValidationError( + f"received update for device with different signing key " + f"(expected {existing.signing_key}, got {signing_key})" + ) if not verify_signature_json(device_keys.serialize(), user_id, device_id, signing_key): raise DeviceValidationError("invalid signature on device keys") @@ -131,7 +316,62 @@ async def _validate_device( 
device_id=device_id, identity_key=identity_key, signing_key=signing_key, - trust=TrustState.UNSET, + trust=TrustState.UNVERIFIED, name=name, deleted=False, ) + + async def resolve_trust(self, device: DeviceIdentity, allow_fetch: bool = True) -> TrustState: + try: + return await self._try_resolve_trust(device, allow_fetch) + except Exception: + self.log.exception(f"Failed to resolve trust of {device.user_id}/{device.device_id}") + return TrustState.UNVERIFIED + + async def _try_resolve_trust( + self, device: DeviceIdentity, allow_fetch: bool = True + ) -> TrustState: + if device.device_id != self.client.device_id and device.trust in ( + TrustState.VERIFIED, + TrustState.BLACKLISTED, + ): + return device.trust + their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id) + if len(their_keys) == 0 and allow_fetch and device.user_id not in self._cs_fetch_attempted: + self.log.debug(f"Didn't find any cross-signing keys for {device.user_id}, fetching...") + async with self._fetch_keys_lock: + if device.user_id not in self._cs_fetch_attempted: + self._cs_fetch_attempted.add(device.user_id) + await self._fetch_keys([device.user_id]) + their_keys = await self.crypto_store.get_cross_signing_keys(device.user_id) + try: + msk = their_keys[CrossSigningUsage.MASTER] + ssk = their_keys[CrossSigningUsage.SELF] + except KeyError as e: + if allow_fetch: + self.log.warning(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}") + return TrustState.UNVERIFIED + ssk_signed = await self.crypto_store.is_key_signed_by( + target=CrossSigner(device.user_id, ssk.key), + signer=CrossSigner(device.user_id, msk.key), + ) + if not ssk_signed: + self.log.warning( + f"Self-signing key of {device.user_id} is not signed by their master key" + ) + return TrustState.UNVERIFIED + device_signed = await self.crypto_store.is_key_signed_by( + target=CrossSigner(device.user_id, device.signing_key), + signer=CrossSigner(device.user_id, ssk.key), + ) + if device_signed: + if await 
self.is_user_trusted(device.user_id): + return TrustState.CROSS_SIGNED_TRUSTED + elif msk.key == msk.first: + return TrustState.CROSS_SIGNED_TOFU + return TrustState.CROSS_SIGNED_UNTRUSTED + return TrustState.UNVERIFIED + + async def is_user_trusted(self, user_id: UserID) -> bool: + # TODO implement once own cross-signing key stuff is ready + return False diff --git a/mautrix/crypto/encrypt_megolm.py b/mautrix/crypto/encrypt_megolm.py index 353e4f7e..459bbf1f 100644 --- a/mautrix/crypto/encrypt_megolm.py +++ b/mautrix/crypto/encrypt_megolm.py @@ -1,11 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Any, Dict, List, Tuple, Union from collections import defaultdict -from datetime import timedelta +from datetime import datetime, timedelta import asyncio import json import time @@ -13,6 +13,7 @@ from mautrix.errors import EncryptionError, SessionShareError from mautrix.types import ( DeviceID, + DeviceIdentity, EncryptedMegolmEventContent, EncryptionAlgorithm, EventType, @@ -24,13 +25,13 @@ Serializable, SessionID, SigningKey, + TrustState, UserID, ) from .device_lists import DeviceListMachine from .encrypt_olm import OlmEncryptionMachine from .sessions import InboundGroupSession, OutboundGroupSession, Session -from .types import DeviceIdentity, TrustState class Sentinel: @@ -94,9 +95,9 @@ async def _encrypt_megolm_event( { "room_id": room_id, "type": event_type.serialize(), - "content": content.serialize() - if isinstance(content, Serializable) - else content, + "content": ( + content.serialize() if isinstance(content, Serializable) else content + ), } ) ) @@ -172,21 +173,6 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No session = await self._new_outbound_group_session(room_id) 
self.log.debug(f"Sharing group session {session.id} for room {room_id} with {users}") - encryption_info = await self.state_store.get_encryption_info(room_id) - if encryption_info: - if encryption_info.algorithm != EncryptionAlgorithm.MEGOLM_V1: - raise SessionShareError("Room encryption algorithm is not supported") - session.max_messages = encryption_info.rotation_period_msgs or session.max_messages - session.max_age = ( - timedelta(milliseconds=encryption_info.rotation_period_ms) - if encryption_info.rotation_period_ms - else session.max_age - ) - self.log.debug( - "Got stored encryption state event and configured session to rotate " - f"after {session.max_messages} messages or {session.max_age}" - ) - olm_sessions: DeviceMap = defaultdict(lambda: {}) withhold_key_msgs = defaultdict(lambda: {}) missing_sessions: Dict[UserID, Dict[DeviceID, DeviceIdentity]] = defaultdict(lambda: {}) @@ -220,7 +206,10 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No if missing_sessions: self.log.debug(f"Creating missing outbound sessions {missing_sessions}") - await self._create_outbound_sessions(missing_sessions) + try: + await self._create_outbound_sessions(missing_sessions) + except Exception: + self.log.exception("Failed to create missing outbound sessions") for user_id, devices in missing_sessions.items(): for device_id, device in devices.items(): @@ -249,13 +238,33 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No async def _new_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession: session = OutboundGroupSession(room_id) - await self._create_group_session( - self.account.identity_key, - self.account.signing_key, - room_id, - SessionID(session.id), - session.session_key, - ) + + encryption_info = await self.state_store.get_encryption_info(room_id) + if encryption_info: + if encryption_info.algorithm != EncryptionAlgorithm.MEGOLM_V1: + raise SessionShareError("Room encryption algorithm is not 
supported") + session.max_messages = encryption_info.rotation_period_msgs or session.max_messages + session.max_age = ( + timedelta(milliseconds=encryption_info.rotation_period_ms) + if encryption_info.rotation_period_ms + else session.max_age + ) + self.log.debug( + "Got stored encryption state event and configured session to rotate " + f"after {session.max_messages} messages or {session.max_age}" + ) + + if not self.dont_store_outbound_keys: + await self._create_group_session( + self.account.identity_key, + self.account.signing_key, + room_id, + SessionID(session.id), + session.session_key, + max_messages=session.max_messages, + max_age=session.max_age, + is_scheduled=False, + ) return session async def _encrypt_and_share_group_session( @@ -282,6 +291,9 @@ async def _create_group_session( room_id: RoomID, session_id: SessionID, session_key: str, + max_age: Union[timedelta, int], + max_messages: int, + is_scheduled: bool = False, ) -> None: start = time.monotonic() session = InboundGroupSession( @@ -289,6 +301,10 @@ async def _create_group_session( signing_key=signing_key, sender_key=sender_key, room_id=room_id, + received_at=datetime.utcnow(), + max_age=max_age, + max_messages=max_messages, + is_scheduled=is_scheduled, ) olm_duration = time.monotonic() - start if olm_duration > 5: @@ -298,7 +314,10 @@ async def _create_group_session( session_id = session.id await self.crypto_store.put_group_session(room_id, sender_key, session_id, session) self._mark_session_received(session_id) - self.log.debug(f"Created inbound group session {room_id}/{sender_key}/{session_id}") + self.log.debug( + f"Created inbound group session {room_id}/{sender_key}/{session_id} " + f"(max {max_age} / {max_messages} messages, {is_scheduled=})" + ) async def _find_olm_sessions( self, @@ -314,7 +333,8 @@ async def _find_olm_sessions( session.users_ignored.add(key) return already_shared - if device.trust == TrustState.BLACKLISTED: + trust = await self.resolve_trust(device) + if trust == 
TrustState.BLACKLISTED: self.log.debug( f"Not encrypting group session {session.id} for {device_id} " f"of {user_id}: device is blacklisted" @@ -328,10 +348,11 @@ async def _find_olm_sessions( code=RoomKeyWithheldCode.BLACKLISTED, reason="Device is blacklisted", ) - elif not self.allow_unverified_devices and device.trust == TrustState.UNSET: + elif self.send_keys_min_trust > trust: self.log.debug( f"Not encrypting group session {session.id} for {device_id} " - f"of {user_id}: device is not verified" + f"of {user_id}: device is not trusted " + f"(min: {self.send_keys_min_trust}, device: {trust})" ) session.users_ignored.add(key) return RoomKeyWithheldEventContent( diff --git a/mautrix/crypto/encrypt_olm.py b/mautrix/crypto/encrypt_olm.py index de41c486..029ad1e5 100644 --- a/mautrix/crypto/encrypt_olm.py +++ b/mautrix/crypto/encrypt_olm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -7,17 +7,20 @@ import asyncio from mautrix.types import ( + DecryptedOlmEvent, DeviceID, + DeviceIdentity, EncryptedOlmEventContent, EncryptionKeyAlgorithm, EventType, + OlmEventKeys, ToDeviceEventContent, UserID, ) -from .base import BaseOlmMachine, verify_signature_json +from .base import BaseOlmMachine from .sessions import Session -from .types import DecryptedOlmEvent, DeviceIdentity, OlmEventKeys +from .signature import verify_signature_json ClaimKeysList = Dict[UserID, Dict[DeviceID, DeviceIdentity]] @@ -58,6 +61,7 @@ async def _create_outbound_sessions_locked( self, users: ClaimKeysList, _force_recreate_session: bool = False ) -> None: request: Dict[UserID, Dict[DeviceID, EncryptionKeyAlgorithm]] = {} + expected_devices = set() for user_id, devices in users.items(): request[user_id] = {} for device_id, identity in devices.items(): @@ -65,13 +69,18 @@ async def _create_outbound_sessions_locked( identity.identity_key ): request[user_id][device_id] = EncryptionKeyAlgorithm.SIGNED_CURVE25519 + expected_devices.add((user_id, device_id)) if not request[user_id]: del request[user_id] if not request: return + request_device_count = len(expected_devices) keys = await self.client.claim_keys(request) + for server, info in (keys.failures or {}).items(): + self.log.warning(f"Key claim failure for {server}: {info}") for user_id, devices in keys.one_time_keys.items(): for device_id, one_time_keys in devices.items(): + expected_devices.discard((user_id, device_id)) key_id, one_time_key_data = one_time_keys.popitem() one_time_key = one_time_key_data["key"] identity = users[user_id][device_id] @@ -88,6 +97,19 @@ async def _create_outbound_sessions_locked( f"Created new Olm session with {user_id}/{device_id} " f"(OTK ID: {key_id})" ) + if expected_devices: + if request_device_count == 1: + raise Exception( + "Key claim response didn't contain key " + f"for queried device {expected_devices.pop()}" + ) + else: + 
self.log.warning( + "Key claim response didn't contain keys for %d out of %d expected devices: %s", + len(expected_devices), + request_device_count, + expected_devices, + ) async def send_encrypted_to_device( self, diff --git a/mautrix/crypto/key_request.py b/mautrix/crypto/key_request.py index d27263d5..ecaf994d 100644 --- a/mautrix/crypto/key_request.py +++ b/mautrix/crypto/key_request.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ import uuid from mautrix.types import ( + DecryptedOlmEvent, DeviceID, EncryptionAlgorithm, EventType, @@ -23,7 +24,6 @@ from .base import BaseOlmMachine from .sessions import InboundGroupSession -from .types import DecryptedOlmEvent class KeyRequestingMachine(BaseOlmMachine): @@ -114,15 +114,24 @@ async def request_room_key( async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None: key: ForwardedRoomKeyEventContent = evt.content - if await self.crypto_store.has_group_session(key.room_id, key.sender_key, key.session_id): + if await self.crypto_store.has_group_session(key.room_id, key.session_id): self.log.debug( f"Ignoring received session {key.session_id} from {evt.sender}/" f"{evt.sender_device}, as crypto store says we have it already" ) return + if not key.beeper_max_messages or not key.beeper_max_age_ms: + await self._fill_encryption_info(key) key.forwarding_key_chain.append(evt.sender_key) sess = InboundGroupSession.import_session( - key.session_key, key.signing_key, key.sender_key, key.room_id, key.forwarding_key_chain + key.session_key, + key.signing_key, + key.sender_key, + key.room_id, + key.forwarding_key_chain, + max_age=key.beeper_max_age_ms, + max_messages=key.beeper_max_messages, + is_scheduled=key.beeper_is_scheduled, ) if key.session_id != sess.id: self.log.warning( diff --git 
a/mautrix/crypto/key_share.py b/mautrix/crypto/key_share.py index c8c134d3..8d6b6e73 100644 --- a/mautrix/crypto/key_share.py +++ b/mautrix/crypto/key_share.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,6 +7,7 @@ from mautrix.errors import MatrixConnectionError, MatrixError, MatrixRequestError from mautrix.types import ( + DeviceIdentity, EncryptionAlgorithm, EventType, ForwardedRoomKeyEventContent, @@ -15,12 +16,13 @@ RoomKeyRequestEventContent, RoomKeyWithheldCode, RoomKeyWithheldEventContent, + SessionID, ToDeviceEvent, + TrustState, ) from .device_lists import DeviceListMachine from .encrypt_olm import OlmEncryptionMachine -from .types import DeviceIdentity, TrustState class RejectKeyShare(MatrixError): @@ -64,9 +66,7 @@ async def default_allow_key_share( """ if device.user_id != self.client.mxid: raise RejectKeyShare( - f"Rejecting key request from a different user ({device.user_id})", - code=RoomKeyWithheldCode.UNAUTHORIZED, - reason="This device does not share keys to other users", + f"Ignoring key request from a different user ({device.user_id})", code=None ) elif device.device_id == self.client.device_id: raise RejectKeyShare("Ignoring key request from ourselves", code=None) @@ -76,18 +76,12 @@ async def default_allow_key_share( code=RoomKeyWithheldCode.BLACKLISTED, reason="You have been blacklisted by this device", ) - elif device.trust == TrustState.VERIFIED: - self.log.debug(f"Accepting key request from verified device {device.device_id}") - return True - elif self.share_to_unverified_devices: - self.log.debug( - f"Accepting key request from unverified device {device.device_id}, " - f"as share_to_unverified_devices is True" - ) + elif await self.resolve_trust(device) >= self.share_keys_min_trust: + self.log.debug(f"Accepting key request from trusted device 
{device.device_id}") return True else: raise RejectKeyShare( - f"Rejecting key request from unverified device {device.device_id}", + f"Rejecting key request from untrusted device {device.device_id}", code=RoomKeyWithheldCode.UNVERIFIED, reason="You have not been verified by this device", ) @@ -168,9 +162,7 @@ async def _handle_room_key_request( if not await self.allow_key_share(device, request): return - sess = await self.crypto_store.get_group_session( - request.room_id, request.sender_key, request.session_id - ) + sess = await self.crypto_store.get_group_session(request.room_id, request.session_id) if sess is None: raise RejectKeyShare( f"Didn't find group session {request.session_id} to forward to " @@ -183,7 +175,7 @@ async def _handle_room_key_request( forward_content = ForwardedRoomKeyEventContent( algorithm=EncryptionAlgorithm.MEGOLM_V1, room_id=sess.room_id, - session_id=sess.id, + session_id=SessionID(sess.id), session_key=exported_key, sender_key=sess.sender_key, forwarding_key_chain=sess.forwarding_chain, diff --git a/mautrix/crypto/machine.py b/mautrix/crypto/machine.py index 700aa0e9..60c65677 100644 --- a/mautrix/crypto/machine.py +++ b/mautrix/crypto/machine.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -8,26 +8,37 @@ from typing import Optional import asyncio import logging +import time from mautrix import client as cli +from mautrix.errors import GroupSessionWithheldError from mautrix.types import ( + ASToDeviceEvent, + DecryptedOlmEvent, + DeviceID, DeviceLists, DeviceOTKCount, EncryptionAlgorithm, + EncryptionKeyAlgorithm, EventType, + Member, Membership, StateEvent, ToDeviceEvent, + TrustState, + UserID, ) +from mautrix.util import background_task from mautrix.util.logging import TraceLogger from .account import OlmAccount +from .cross_signing import CrossSigningMachine from .decrypt_megolm import MegolmDecryptionMachine from .encrypt_megolm import MegolmEncryptionMachine from .key_request import KeyRequestingMachine from .key_share import KeySharingMachine +from .ssss import Machine as SSSSMachine from .store import CryptoStore, StateStore -from .types import DecryptedOlmEvent from .unwedge import OlmUnwedgingMachine @@ -37,6 +48,7 @@ class OlmMachine( OlmUnwedgingMachine, KeySharingMachine, KeyRequestingMachine, + CrossSigningMachine, ): """ OlmMachine is the main class for handling things related to Matrix end-to-end encryption with @@ -49,13 +61,10 @@ class OlmMachine( log: TraceLogger crypto_store: CryptoStore state_store: StateStore - - _fetch_keys_lock: asyncio.Lock + ssss: SSSSMachine account: Optional[OlmAccount] - allow_unverified_devices: bool - def __init__( self, client: cli.Client, @@ -65,19 +74,36 @@ def __init__( ) -> None: super().__init__() self.client = client + self.ssss = SSSSMachine(client) self.log = log or logging.getLogger("mau.crypto") self.crypto_store = crypto_store self.state_store = state_store self.account = None - self.allow_unverified_devices = True - self.share_to_unverified_devices = False + self.send_keys_min_trust = TrustState.UNVERIFIED + self.share_keys_min_trust = TrustState.CROSS_SIGNED_TOFU self.allow_key_share = self.default_allow_key_share + 
self.delete_outbound_keys_on_ack = False + self.dont_store_outbound_keys = False + self.delete_previous_keys_on_receive = False + self.ratchet_keys_on_decrypt = False + self.delete_fully_used_keys_on_decrypt = False + self.delete_keys_on_device_delete = False + self.disable_device_change_key_rotation = False + self._fetch_keys_lock = asyncio.Lock() + self._megolm_decrypt_lock = asyncio.Lock() + self._share_keys_lock = asyncio.Lock() + self._last_key_share = time.monotonic() - 60 self._key_request_waiters = {} self._inbound_session_waiters = {} self._prev_unwedge = {} + self._cs_fetch_attempted = set() + + self._cross_signing_public_keys = None + self._cross_signing_public_keys_fetched = False + self._cross_signing_private_keys = None self.client.add_event_handler( cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True @@ -85,6 +111,7 @@ def __init__( self.client.add_event_handler(cli.InternalEventType.DEVICE_LISTS, self.handle_device_lists) self.client.add_event_handler(EventType.TO_DEVICE_ENCRYPTED, self.handle_to_device_event) self.client.add_event_handler(EventType.ROOM_KEY_REQUEST, self.handle_room_key_request) + self.client.add_event_handler(EventType.BEEPER_ROOM_KEY_ACK, self.handle_beep_room_key_ack) # self.client.add_event_handler(EventType.ROOM_KEY_WITHHELD, self.handle_room_key_withheld) # self.client.add_event_handler(EventType.ORG_MATRIX_ROOM_KEY_WITHHELD, # self.handle_room_key_withheld) @@ -97,6 +124,34 @@ async def load(self) -> None: self.account = OlmAccount() await self.crypto_store.put_account(self.account) + async def handle_as_otk_counts( + self, otk_counts: dict[UserID, dict[DeviceID, DeviceOTKCount]] + ) -> None: + for user_id, devices in otk_counts.items(): + for device_id, count in devices.items(): + if user_id == self.client.mxid and device_id == self.client.device_id: + await self.handle_otk_count(count) + else: + self.log.warning(f"Got OTK count for unknown device {user_id}/{device_id}") + + async def 
handle_as_device_lists(self, device_lists: DeviceLists) -> None: + background_task.create(self.handle_device_lists(device_lists)) + + async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None: + if evt.to_user_id != self.client.mxid or evt.to_device_id != self.client.device_id: + self.log.warning( + f"Got to-device event for unknown device {evt.to_user_id}/{evt.to_device_id}" + ) + return + if evt.type == EventType.TO_DEVICE_ENCRYPTED: + await self.handle_to_device_event(evt) + elif evt.type == EventType.ROOM_KEY_REQUEST: + await self.handle_room_key_request(evt) + elif evt.type == EventType.BEEPER_ROOM_KEY_ACK: + await self.handle_beep_room_key_ack(evt) + else: + self.log.debug(f"Got unknown to-device event {evt.type} from {evt.sender}") + async def handle_otk_count(self, otk_count: DeviceOTKCount) -> None: """ Handle the ``device_one_time_keys_count`` data in a sync response. @@ -143,9 +198,19 @@ async def handle_member_event(self, evt: StateEvent) -> None: } if prev == cur or ignored_changes.get(prev) == cur: return + src = getattr(evt, "source", None) + prev_cache = evt.unsigned.get("mautrix_prev_membership") + if isinstance(prev_cache, Member) and prev_cache.membership == cur: + self.log.debug( + f"Got duplicate membership state event in {evt.room_id} changing {evt.state_key} " + f"from {prev} to {cur}, cached state was {prev_cache} (event ID: {evt.event_id}, " + f"sync source: {src})" + ) + return self.log.debug( f"Got membership state event in {evt.room_id} changing {evt.state_key} from " - f"{prev} to {cur}, invalidating group session" + f"{prev} to {cur} (event ID: {evt.event_id}, sync source: {src}, " + f"cached: {prev_cache.membership if prev_cache else None}), invalidating group session" ) await self.crypto_store.remove_outbound_group_session(evt.room_id) @@ -157,6 +222,11 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None: passed to the OlmMachine is syncing. 
You shouldn't need to call this yourself unless you do syncing in some manual way. """ + if isinstance(evt, DecryptedOlmEvent): + self.log.warning( + f"Dropping unexpected nested encrypted to-device event from {evt.sender}" + ) + return self.log.trace( f"Handling encrypted to-device event from {evt.content.sender_key} ({evt.sender})" ) @@ -165,28 +235,81 @@ async def handle_to_device_event(self, evt: ToDeviceEvent) -> None: await self._receive_room_key(decrypted_evt) elif decrypted_evt.type == EventType.FORWARDED_ROOM_KEY: await self._receive_forwarded_room_key(decrypted_evt) + else: + self.client.dispatch_event(decrypted_evt, source=evt.source) async def _receive_room_key(self, evt: DecryptedOlmEvent) -> None: # TODO nio had a comment saying "handle this better" # for the case where evt.Keys.Ed25519 is none? if evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1 or not evt.keys.ed25519: return + if not evt.content.beeper_max_messages or not evt.content.beeper_max_age_ms: + await self._fill_encryption_info(evt.content) + if self.delete_previous_keys_on_receive and not evt.content.beeper_is_scheduled: + removed_ids = await self.crypto_store.redact_group_sessions( + evt.content.room_id, evt.sender_key, reason="received new key from device" + ) + self.log.info(f"Redacted previous megolm sessions: {removed_ids}") await self._create_group_session( evt.sender_key, evt.keys.ed25519, evt.content.room_id, evt.content.session_id, evt.content.session_key, + max_age=evt.content.beeper_max_age_ms, + max_messages=evt.content.beeper_max_messages, + is_scheduled=evt.content.beeper_is_scheduled, ) - async def share_keys(self, current_otk_count: int) -> None: + async def handle_beep_room_key_ack(self, evt: ToDeviceEvent) -> None: + try: + sess = await self.crypto_store.get_group_session( + evt.content.room_id, evt.content.session_id + ) + except GroupSessionWithheldError: + self.log.debug( + f"Ignoring room key ack for session {evt.content.session_id}" + " that was already 
redacted" + ) + return + if not sess: + self.log.debug(f"Ignoring room key ack for unknown session {evt.content.session_id}") + return + if ( + sess.sender_key == self.account.identity_key + and self.delete_outbound_keys_on_ack + and evt.content.first_message_index == 0 + ): + self.log.debug("Redacting inbound copy of outbound group session after ack") + await self.crypto_store.redact_group_session( + evt.content.room_id, evt.content.session_id, reason="outbound session acked" + ) + else: + self.log.debug(f"Received room key ack for {sess.id}") + + async def share_keys(self, current_otk_count: int | None = None) -> None: """ Share any keys that need to be shared. This is automatically called from :meth:`handle_otk_count`, so you should not need to call this yourself. Args: current_otk_count: The current number of signed curve25519 keys present on the server. + If omitted, the count will be fetched from the server. """ + async with self._share_keys_lock: + await self._share_keys(current_otk_count) + + async def _share_keys(self, current_otk_count: int | None) -> None: + if current_otk_count is None or ( + # If the last key share was recent and the new count is very low, re-check the count + # from the server to avoid any race conditions. 
+ self._last_key_share + 60 > time.monotonic() + and current_otk_count < 10 + ): + self.log.debug("Checking OTK count on server") + current_otk_count = (await self.client.upload_keys()).get( + EncryptionKeyAlgorithm.SIGNED_CURVE25519, 0 + ) device_keys = ( self.account.get_device_keys(self.client.mxid, self.client.device_id) if not self.account.shared @@ -201,7 +324,9 @@ async def share_keys(self, current_otk_count: int) -> None: if device_keys: self.log.debug("Going to upload initial account keys") self.log.debug(f"Uploading {len(one_time_keys)} one-time keys") - await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys) + resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys) self.account.shared = True + self.account.mark_keys_as_published() + self._last_key_share = time.monotonic() await self.crypto_store.put_account(self.account) - self.log.debug("Shared keys and saved account") + self.log.debug(f"Shared keys and saved account, new keys: {resp}") diff --git a/mautrix/crypto/sessions.py b/mautrix/crypto/sessions.py index cae175c5..b0b16a18 100644 --- a/mautrix/crypto/sessions.py +++ b/mautrix/crypto/sessions.py @@ -1,12 +1,13 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from typing import List, Optional, Set, Tuple, cast +from typing import List, Optional, Set, Tuple, Union, cast from datetime import datetime, timedelta from _libolm import ffi, lib +from attr import dataclass import olm from mautrix.errors import EncryptionError @@ -18,8 +19,10 @@ OlmMsgType, RoomID, RoomKeyEventContent, + SerializableAttrs, SigningKey, UserID, + field, ) @@ -93,11 +96,24 @@ def describe(self) -> str: return "describe not supported" +@dataclass +class RatchetSafety(SerializableAttrs): + next_index: int = 0 + missed_indices: List[int] = field(factory=lambda: []) + lost_indices: List[int] = field(factory=lambda: []) + + class InboundGroupSession(olm.InboundGroupSession): room_id: RoomID signing_key: SigningKey sender_key: IdentityKey - forwarding_chain: List[str] + forwarding_chain: List[IdentityKey] + + ratchet_safety: RatchetSafety + received_at: datetime + max_age: timedelta + max_messages: int + is_scheduled: bool def __init__( self, @@ -105,12 +121,24 @@ def __init__( signing_key: SigningKey, sender_key: IdentityKey, room_id: RoomID, - forwarding_chain: Optional[List[str]] = None, + forwarding_chain: Optional[List[IdentityKey]] = None, + ratchet_safety: Optional[RatchetSafety] = None, + received_at: Optional[datetime] = None, + max_age: Union[timedelta, int, None] = None, + max_messages: Optional[int] = None, + is_scheduled: bool = False, ) -> None: self.signing_key = signing_key self.sender_key = sender_key self.room_id = room_id self.forwarding_chain = forwarding_chain or [] + self.ratchet_safety = ratchet_safety or RatchetSafety() + self.received_at = received_at or datetime.utcnow() + if isinstance(max_age, int): + max_age = timedelta(milliseconds=max_age) + self.max_age = max_age + self.max_messages = max_messages + self.is_scheduled = is_scheduled super().__init__(session_key) def __new__(cls, *args, **kwargs): @@ -124,13 +152,23 @@ def from_pickle( signing_key: SigningKey, sender_key: IdentityKey, room_id: RoomID, - forwarding_chain: 
Optional[List[str]] = None, + forwarding_chain: Optional[List[IdentityKey]] = None, + ratchet_safety: Optional[RatchetSafety] = None, + received_at: Optional[datetime] = None, + max_age: Optional[timedelta] = None, + max_messages: Optional[int] = None, + is_scheduled: bool = False, ) -> "InboundGroupSession": session = super().from_pickle(pickle, passphrase) session.signing_key = signing_key session.sender_key = sender_key session.room_id = room_id session.forwarding_chain = forwarding_chain or [] + session.ratchet_safety = ratchet_safety or RatchetSafety() + session.received_at = received_at + session.max_age = max_age + session.max_messages = max_messages + session.is_scheduled = is_scheduled return session @classmethod @@ -141,14 +179,41 @@ def import_session( sender_key: IdentityKey, room_id: RoomID, forwarding_chain: Optional[List[str]] = None, + ratchet_safety: Optional[RatchetSafety] = None, + received_at: Optional[datetime] = None, + max_age: Union[timedelta, int, None] = None, + max_messages: Optional[int] = None, + is_scheduled: bool = False, ) -> "InboundGroupSession": session = super().import_session(session_key) session.signing_key = signing_key session.sender_key = sender_key session.room_id = room_id session.forwarding_chain = forwarding_chain or [] + session.ratchet_safety = ratchet_safety or RatchetSafety() + session.received_at = received_at or datetime.utcnow() + if isinstance(max_age, int): + max_age = timedelta(milliseconds=max_age) + session.max_age = max_age + session.max_messages = max_messages + session.is_scheduled = is_scheduled return session + def ratchet_to(self, index: int) -> "InboundGroupSession": + exported = self.export_session(index) + return self.import_session( + exported, + signing_key=self.signing_key, + sender_key=self.sender_key, + room_id=self.room_id, + forwarding_chain=self.forwarding_chain, + ratchet_safety=self.ratchet_safety, + received_at=self.received_at, + max_age=self.max_age, + max_messages=self.max_messages, + 
is_scheduled=self.is_scheduled, + ) + class OutboundGroupSession(olm.OutboundGroupSession): """Outbound group session aware of the users it is shared with. diff --git a/mautrix/crypto/signature.py b/mautrix/crypto/signature.py new file mode 100644 index 00000000..6dc13e65 --- /dev/null +++ b/mautrix/crypto/signature.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Any, TypedDict +import functools +import json + +import olm +import unpaddedbase64 + +from mautrix.types import ( + JSON, + DeviceID, + EncryptionKeyAlgorithm, + KeyID, + Serializable, + Signature, + SigningKey, + UserID, +) + +try: + from Crypto.PublicKey import ECC + from Crypto.Signature import eddsa +except ImportError: + from Cryptodome.PublicKey import ECC + from Cryptodome.Signature import eddsa + +canonical_json = functools.partial( + json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True +) + + +class SignedObject(TypedDict): + signatures: dict[UserID, dict[str, str]] + unsigned: Any + + +def sign_olm(data: dict[str, JSON] | Serializable, key: olm.PkSigning | olm.Account) -> Signature: + if isinstance(data, Serializable): + data = data.serialize() + data.pop("signatures", None) + data.pop("unsigned", None) + return Signature(key.sign(canonical_json(data))) + + +def verify_signature_json( + data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey +) -> bool: + data_copy = {**data} + data_copy.pop("unsigned", None) + signatures = data_copy.pop("signatures") + key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name)) + try: + signature = signatures[user_id][key_id] + decoded_key = unpaddedbase64.decode_base64(key) + # pycryptodome doesn't accept raw keys, so wrap it in a DER structure + der_key = 
b"\x30\x2a\x30\x05\x06\x03\x2b\x65\x70\x03\x21\x00" + decoded_key + decoded_signature = unpaddedbase64.decode_base64(signature) + parsed_key = ECC.import_key(der_key) + verifier = eddsa.new(parsed_key, "rfc8032") + verifier.verify(canonical_json(data_copy).encode("utf-8"), decoded_signature) + return True + except (KeyError, ValueError): + return False diff --git a/mautrix/crypto/signature_test.py b/mautrix/crypto/signature_test.py new file mode 100644 index 00000000..115e4836 --- /dev/null +++ b/mautrix/crypto/signature_test.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from mautrix.types import SigningKey, UserID + +from .signature import verify_signature_json + + +def test_verify_signature_json() -> None: + assert verify_signature_json( + # This is actually a federation PDU rather than a device signature, + # but they're both 25519 curves so it doesn't make a difference. 
+ { + "auth_events": [ + "$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw", + "$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM", + "$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo", + ], + "content": {}, + "depth": 3212, + "hashes": {"sha256": "K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"}, + "origin_server_ts": 1754242687127, + "prev_events": ["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"], + "room_id": "!offtopic-2:continuwuity.org", + "sender": "@tulir:maunium.net", + "type": "m.room.message", + "signatures": { + UserID("maunium.net"): { + "ed25519:a_xxeS": "SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw" + } + }, + "unsigned": {"age_ts": 1754242687146}, + }, + UserID("maunium.net"), + "a_xxeS", + SigningKey("lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg"), + ) diff --git a/mautrix/crypto/ssss/__init__.py b/mautrix/crypto/ssss/__init__.py new file mode 100644 index 00000000..9224418d --- /dev/null +++ b/mautrix/crypto/ssss/__init__.py @@ -0,0 +1,8 @@ +from .key import Key, KeyMetadata, PassphraseMetadata +from .machine import Machine +from .types import ( + Algorithm, + EncryptedAccountDataEventContent, + EncryptedKeyData, + PassphraseAlgorithm, +) diff --git a/mautrix/crypto/ssss/key.py b/mautrix/crypto/ssss/key.py new file mode 100644 index 00000000..691ded71 --- /dev/null +++ b/mautrix/crypto/ssss/key.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from typing import Optional +import base64 +import hashlib +import hmac + +from attr import dataclass +import unpaddedbase64 + +from mautrix.types import EventType, SerializableAttrs + +from .types import Algorithm, EncryptedKeyData, PassphraseAlgorithm +from .util import ( + calculate_hash, + cryptorand, + decode_base58_recovery_key, + derive_keys, + encode_base58_recovery_key, + prepare_aes, +) + +try: + from Crypto.Cipher import AES + from Crypto.Util import Counter +except ImportError: + from Cryptodome.Cipher import AES + from Cryptodome.Util import Counter + + +@dataclass +class PassphraseMetadata(SerializableAttrs): + algorithm: PassphraseAlgorithm + iterations: int + salt: str + bits: int = 256 + + def get_key(self, passphrase: str) -> bytes: + if self.algorithm != PassphraseAlgorithm.PBKDF2: + raise ValueError(f"Unsupported passphrase algorithm {self.algorithm}") + return hashlib.pbkdf2_hmac( + "sha512", + passphrase.encode("utf-8"), + self.salt.encode("utf-8"), + self.iterations, + self.bits // 8, + ) + + +@dataclass +class KeyMetadata(SerializableAttrs): + algorithm: Algorithm + + iv: str | None = None + mac: str | None = None + + name: str | None = None + passphrase: Optional[PassphraseMetadata] = None + + def verify_passphrase(self, key_id: str, phrase: str) -> "Key": + if not self.passphrase: + raise ValueError("Passphrase not set on this key") + return self.verify_raw_key(key_id, self.passphrase.get_key(phrase)) + + def verify_recovery_key(self, key_id: str, recovery_key: str) -> "Key": + decoded_key = decode_base58_recovery_key(recovery_key) + if not decoded_key: + raise ValueError("Invalid recovery key syntax") + return self.verify_raw_key(key_id, decoded_key) + + def verify_raw_key(self, key_id: str, key: bytes) -> "Key": + if self.mac.rstrip("=") != calculate_hash(key, self.iv): + raise ValueError("Key MAC does not match") + return Key(id=key_id, key=key, metadata=self) + + +@dataclass +class Key: + id: str + key: bytes + metadata: KeyMetadata + 
+ @classmethod + def generate(cls, passphrase: str | None = None) -> "Key": + passphrase_meta = ( + PassphraseMetadata( + algorithm=PassphraseAlgorithm.PBKDF2, + iterations=500_000, + salt=base64.b64encode(cryptorand.read(24)).decode("utf-8"), + bits=256, + ) + if passphrase + else None + ) + key = passphrase_meta.get_key(passphrase) if passphrase else cryptorand.read(32) + iv = unpaddedbase64.encode_base64(cryptorand.read(16)) + metadata = KeyMetadata( + algorithm=Algorithm.AES_HMAC_SHA2, + passphrase=passphrase_meta, + mac=calculate_hash(key, iv), + iv=iv, + ) + key_id = unpaddedbase64.encode_base64(cryptorand.read(24)) + return cls(key=key, id=key_id, metadata=metadata) + + @property + def recovery_key(self) -> str: + return encode_base58_recovery_key(self.key) + + def encrypt(self, event_type: str | EventType, data: str | bytes) -> EncryptedKeyData: + if isinstance(data, str): + data = data.encode("utf-8") + data = base64.b64encode(data).rstrip(b"=") + + aes_key, hmac_key = derive_keys(self.key, event_type) + iv = bytearray(cryptorand.read(16)) + iv[8] &= 0x7F + ciphertext = prepare_aes(aes_key, iv).encrypt(data) + digest = hmac.digest(hmac_key, ciphertext, hashlib.sha256) + return EncryptedKeyData( + ciphertext=unpaddedbase64.encode_base64(ciphertext), + iv=unpaddedbase64.encode_base64(iv), + mac=unpaddedbase64.encode_base64(digest), + ) + + def decrypt(self, event_type: str | EventType, data: EncryptedKeyData) -> bytes: + aes_key, hmac_key = derive_keys(self.key, event_type) + ciphertext = unpaddedbase64.decode_base64(data.ciphertext) + mac = unpaddedbase64.decode_base64(data.mac) + + expected_mac = hmac.digest(hmac_key, ciphertext, hashlib.sha256) + if not hmac.compare_digest(mac, expected_mac): + raise ValueError("Invalid MAC") + + plaintext = prepare_aes(aes_key, data.iv).decrypt(ciphertext) + return unpaddedbase64.decode_base64(plaintext.decode("utf-8")) diff --git a/mautrix/crypto/ssss/key_test.py b/mautrix/crypto/ssss/key_test.py new file mode 100644 
index 00000000..e60f1e3c --- /dev/null +++ b/mautrix/crypto/ssss/key_test.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import pytest + +from ...types.event.type import EventType +from .key import Key, KeyMetadata +from .types import EncryptedAccountDataEventContent + +KEY1_CROSS_SIGNING_MASTER_KEY = """{ + "encrypted": { + "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0": { + "iv": "BpKP9nQJTE9jrsAssoxPqQ==", + "ciphertext": "fNRiiiidezjerTgV+G6pUtmeF3izzj5re/mVvY0hO2kM6kYGrxLuIu2ej80=", + "mac": "/gWGDGMyOLmbJp+aoSLh5JxCs0AdS6nAhjzpe+9G2Q0=" + } + } +}""" + +KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED = bytes( + [ + 0x68, + 0xF9, + 0x7F, + 0xD1, + 0x92, + 0x2E, + 0xEC, + 0xF6, + 0xB8, + 0x2B, + 0xB8, + 0x90, + 0xD2, + 0x4D, + 0x06, + 0x52, + 0x98, + 0x4E, + 0x7A, + 0x1D, + 0x70, + 0x3B, + 0x9E, + 0x86, + 0x7B, + 0x7E, + 0xBA, + 0xF7, + 0xFE, + 0xB9, + 0x5B, + 0x6F, + ] +) + +KEY1_META = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "passphrase": { + "algorithm": "m.pbkdf2", + "iterations": 500000, + "salt": "y863BOoqOadgDp8S3FtHXikDJEalsQ7d" + }, + "iv": "xxkTK0L4UzxgAFkQ6XPwsw", + "mac": "MEhooO0ZhFJNxUhvRMSxBnJfL20wkLgle3ocY0ee/eA" +}""" +KEY1_ID = "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0" +KEY1_RECOVERY_KEY = "EsTE s92N EtaX s2h6 VQYF 9Kao tHYL mkyL GKMh isZb KJ4E tvoC" +KEY1_PASSPHRASE = "correct horse battery staple" + +KEY2_META = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +}""" +KEY2_ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv" +KEY2_RECOVERY_KEY = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe" + +KEY2_META_BROKEN_IV = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow", + "mac": 
"7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +}""" + +KEY2_META_BROKEN_MAC = """{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow" +}""" + + +def get_key_meta(meta: str) -> KeyMetadata: + return KeyMetadata.parse_json(meta) + + +def get_key1() -> Key: + return get_key_meta(KEY1_META).verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY) + + +def get_key2() -> Key: + return get_key_meta(KEY2_META).verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def get_encrypted_master_key() -> EncryptedAccountDataEventContent: + return EncryptedAccountDataEventContent.parse_json(KEY1_CROSS_SIGNING_MASTER_KEY) + + +def test_decrypt_success() -> None: + key = get_key1() + emk = get_encrypted_master_key() + assert ( + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) == KEY1_CROSS_SIGNING_MASTER_KEY_DECRYPTED + ) + + +def test_decrypt_fail_wrong_key() -> None: + key = get_key2() + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) + + +def test_decrypt_fail_fake_key() -> None: + key = get_key2() + key.id = KEY1_ID + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_MASTER, key) + + +def test_decrypt_fail_wrong_type() -> None: + key = get_key1() + emk = get_encrypted_master_key() + with pytest.raises(ValueError): + emk.decrypt(EventType.CROSS_SIGNING_SELF_SIGNING, key) + + +def test_encrypt_roundtrip() -> None: + key = get_key1() + data = bytes([0xDE, 0xAD, 0xBE, 0xEF]) + ciphertext = key.encrypt("net.maunium.data", data) + plaintext = key.decrypt("net.maunium.data", ciphertext) + assert plaintext == data + + +def test_verify_recovery_key_correct() -> None: + meta = get_key_meta(KEY1_META) + key = meta.verify_recovery_key(KEY1_ID, KEY1_RECOVERY_KEY) + assert key.recovery_key == KEY1_RECOVERY_KEY + + +def test_verify_recovery_key_correct2() -> None: + meta = 
get_key_meta(KEY2_META) + key = meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + assert key.recovery_key == KEY2_RECOVERY_KEY + + +def test_verify_recovery_key_invalid() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY1_ID, "foo") + + +def test_verify_recovery_key_incorrect() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_recovery_key_broken_iv() -> None: + meta = get_key_meta(KEY2_META_BROKEN_IV) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_recovery_key_broken_mac() -> None: + meta = get_key_meta(KEY2_META_BROKEN_MAC) + with pytest.raises(ValueError): + meta.verify_recovery_key(KEY2_ID, KEY2_RECOVERY_KEY) + + +def test_verify_passphrase_correct() -> None: + meta = get_key_meta(KEY1_META) + key = meta.verify_passphrase(KEY1_ID, KEY1_PASSPHRASE) + assert key.recovery_key == KEY1_RECOVERY_KEY + + +def test_verify_passphrase_incorrect() -> None: + meta = get_key_meta(KEY1_META) + with pytest.raises(ValueError): + meta.verify_passphrase(KEY1_ID, "incorrect horse battery staple") + + +def test_verify_passphrase_notset() -> None: + meta = get_key_meta(KEY2_META) + with pytest.raises(ValueError): + meta.verify_passphrase(KEY2_ID, "hmm") diff --git a/mautrix/crypto/ssss/machine.py b/mautrix/crypto/ssss/machine.py new file mode 100644 index 00000000..c43e25e3 --- /dev/null +++ b/mautrix/crypto/ssss/machine.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from mautrix import client as cli +from mautrix.errors import MNotFound +from mautrix.types import EventType, SecretStorageDefaultKeyEventContent + +from .key import Key, KeyMetadata +from .types import EncryptedAccountDataEventContent + + +class Machine: + client: cli.Client + + def __init__(self, client: cli.Client) -> None: + self.client = client + + async def get_default_key_id(self) -> str | None: + try: + data = await self.client.get_account_data(EventType.SECRET_STORAGE_DEFAULT_KEY) + return SecretStorageDefaultKeyEventContent.deserialize(data).key + except (MNotFound, ValueError): + return None + + async def set_default_key_id(self, key_id: str) -> None: + await self.client.set_account_data( + EventType.SECRET_STORAGE_DEFAULT_KEY, + SecretStorageDefaultKeyEventContent(key=key_id), + ) + + async def get_key_data(self, key_id: str) -> KeyMetadata: + data = await self.client.get_account_data(f"m.secret_storage.key.{key_id}") + return KeyMetadata.deserialize(data) + + async def set_key_data(self, key_id: str, data: KeyMetadata) -> None: + await self.client.set_account_data(f"m.secret_storage.key.{key_id}", data) + + async def get_default_key_data(self) -> tuple[str, KeyMetadata]: + key_id = await self.get_default_key_id() + if not key_id: + raise ValueError("No default key ID set") + return key_id, await self.get_key_data(key_id) + + async def get_decrypted_account_data(self, event_type: EventType | str, key: Key) -> bytes: + data = await self.client.get_account_data(event_type) + parsed = EncryptedAccountDataEventContent.deserialize(data) + return parsed.decrypt(event_type, key) + + async def set_encrypted_account_data( + self, event_type: EventType | str, data: bytes, *keys: Key + ) -> None: + encrypted_data = {} + for key in keys: + encrypted_data[key.id] = key.encrypt(event_type, data) + await self.client.set_account_data( + event_type, + EncryptedAccountDataEventContent(encrypted=encrypted_data), + ) + + async def generate_and_upload_key(self, passphrase: 
str | None = None) -> Key: + key = Key.generate(passphrase) + await self.set_key_data(key.id, key.metadata) + return key diff --git a/mautrix/crypto/ssss/types.py b/mautrix/crypto/ssss/types.py new file mode 100644 index 00000000..4a47f743 --- /dev/null +++ b/mautrix/crypto/ssss/types.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import TYPE_CHECKING + +from attr import dataclass + +from mautrix.types import EventType, SerializableAttrs, SerializableEnum +from mautrix.types.event.account_data import account_data_event_content_map + +if TYPE_CHECKING: + from .key import Key + + +class Algorithm(SerializableEnum): + AES_HMAC_SHA2 = "m.secret_storage.v1.aes-hmac-sha2" + CURVE25519_AES_SHA2 = "m.secret_storage.v1.curve25519-aes-sha2" + + +class PassphraseAlgorithm(SerializableEnum): + PBKDF2 = "m.pbkdf2" + + +@dataclass +class EncryptedKeyData(SerializableAttrs): + ciphertext: str + iv: str + mac: str + + +@dataclass +class EncryptedAccountDataEventContent(SerializableAttrs): + encrypted: dict[str, EncryptedKeyData] + + def decrypt(self, event_type: str | EventType, key: "Key") -> bytes: + try: + encrypted_data = self.encrypted[key.id] + except KeyError as e: + raise ValueError(f"Event not encrypted for provided key") from e + return key.decrypt(event_type, encrypted_data) + + +for encrypted_account_data_type in ( + EventType.CROSS_SIGNING_MASTER, + EventType.CROSS_SIGNING_USER_SIGNING, + EventType.CROSS_SIGNING_SELF_SIGNING, + EventType.MEGOLM_BACKUP_V1, +): + account_data_event_content_map[encrypted_account_data_type] = EncryptedAccountDataEventContent diff --git a/mautrix/crypto/ssss/util.py b/mautrix/crypto/ssss/util.py new file mode 100644 index 00000000..b58c941a --- /dev/null +++ b/mautrix/crypto/ssss/util.py @@ -0,0 +1,78 @@ +# Copyright 
(c) 2025 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import hashlib +import hmac + +import base58 +import unpaddedbase64 + +from mautrix.types import EventType + +try: + from Crypto import Random + from Crypto.Cipher import AES + from Crypto.Hash import SHA256 + from Crypto.Protocol.KDF import HKDF + from Crypto.Util import Counter +except ImportError: + from Cryptodome import Random + from Cryptodome.Cipher import AES + from Cryptodome.Hash import SHA256 + from Cryptodome.Protocol.KDF import HKDF + from Cryptodome.Util import Counter + +cryptorand = Random.new() + + +def decode_base58_recovery_key(key: str) -> bytes | None: + key_bytes = base58.b58decode(key.replace(" ", "")) + if len(key_bytes) != 35 or key_bytes[0] != 0x8B or key_bytes[1] != 1: + return None + parity = 0 + for byte in key_bytes[:34]: + parity ^= byte + return key_bytes[2:34] if parity == key_bytes[34] else None + + +def encode_base58_recovery_key(key: bytes) -> str: + key_bytes = bytearray(35) + key_bytes[0] = 0x8B + key_bytes[1] = 1 + key_bytes[2:34] = key + parity = 0 + for byte in key_bytes: + parity ^= byte + key_bytes[34] = parity + encoded_key = base58.b58encode(key_bytes).decode("utf-8") + return " ".join(encoded_key[i : i + 4] for i in range(0, len(encoded_key), 4)) + + +def derive_keys(key: bytes, name: str | EventType = "") -> tuple[bytes, bytes]: + aes_key, hmac_key = HKDF( + master=key, + key_len=32, + salt=b"\x00" * 32, + hashmod=SHA256, + num_keys=2, + context=str(name).encode("utf-8"), + ) + return aes_key, hmac_key + + +def prepare_aes(key: bytes, iv: str | bytes) -> AES: + if isinstance(iv, str): + iv = unpaddedbase64.decode_base64(iv) + # initial_value = struct.unpack(">Q", iv[8:])[0] + # counter = Counter.new(64, prefix=iv[:8], initial_value=initial_value) + counter = Counter.new(128, 
initial_value=int.from_bytes(iv, byteorder="big")) + return AES.new(key=key, mode=AES.MODE_CTR, counter=counter) + + +def calculate_hash(key: bytes, iv: str | bytes) -> str: + aes_key, hmac_key = derive_keys(key) + cipher = prepare_aes(aes_key, iv).decrypt(b"\x00" * 32) + digest = hmac.digest(hmac_key, cipher, hashlib.sha256) + return unpaddedbase64.encode_base64(digest) diff --git a/mautrix/crypto/store/abstract.py b/mautrix/crypto/store/abstract.py index 8adc9387..0916a2d7 100644 --- a/mautrix/crypto/store/abstract.py +++ b/mautrix/crypto/store/abstract.py @@ -1,23 +1,31 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations +from typing import AsyncContextManager, NamedTuple from abc import ABC, abstractmethod +from contextlib import asynccontextmanager from mautrix.types import ( + CrossSigner, + CrossSigningUsage, DeviceID, + DeviceIdentity, EventID, IdentityKey, RoomEncryptionStateEventContent, RoomID, SessionID, + SigningKey, + TOFUSigningKey, UserID, ) -from .. import DeviceIdentity, InboundGroupSession, OlmAccount, OutboundGroupSession, Session +from ..account import OlmAccount +from ..sessions import InboundGroupSession, OutboundGroupSession, Session class StateStore(ABC): @@ -80,6 +88,11 @@ async def close(self) -> None: async def flush(self) -> None: """Flush the store. If all the methods persist data immediately, this can be a no-op.""" + @asynccontextmanager + async def transaction(self) -> None: + """Run a database transaction. 
If the store doesn't support transactions, this can be a no-op.""" + yield + @abstractmethod async def delete(self) -> None: """Delete the data in the store.""" @@ -176,7 +189,7 @@ async def put_group_session( @abstractmethod async def get_group_session( - self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID + self, room_id: RoomID, session_id: SessionID ) -> InboundGroupSession | None: """ Get an inbound Megolm group session that was previously inserted with @@ -184,7 +197,6 @@ async def get_group_session( Args: room_id: The room ID for which the session was made. - sender_key: The curve25519 identity key of the user who made the session. session_id: The unique identifier of the session. Returns: @@ -192,16 +204,63 @@ async def get_group_session( """ @abstractmethod - async def has_group_session( - self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID - ) -> bool: + async def redact_group_session( + self, room_id: RoomID, session_id: SessionID, reason: str + ) -> None: + """ + Remove the keys for a specific Megolm group session. + + Args: + room_id: The room where the session is. + session_id: The session ID to remove. + reason: The reason the session is being removed. + """ + + @abstractmethod + async def redact_group_sessions( + self, room_id: RoomID | None, sender_key: IdentityKey | None, reason: str + ) -> list[SessionID]: + """ + Remove the keys for multiple Megolm group sessions, + based on the room ID and/or sender device. + + Args: + room_id: The room ID to delete keys from. + sender_key: The Olm identity key of the device to delete keys from. + reason: The reason why the keys are being deleted. + + Returns: + The list of session IDs that were deleted. + """ + + @abstractmethod + async def redact_expired_group_sessions(self) -> list[SessionID]: + """ + Remove all Megolm group sessions where at least twice the maximum age has passed since + receiving the keys. + + Returns: + The list of session IDs that were deleted. 
+ """ + + @abstractmethod + async def redact_outdated_group_sessions(self) -> list[SessionID]: + """ + Remove all Megolm group sessions which lack the metadata to determine when they should + expire. + + Returns: + The list of session IDs that were deleted. + """ + + @abstractmethod + async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: """ Check whether or not a specific inbound Megolm session is in the store. This is used before importing forwarded keys. Args: room_id: The room ID for which the session was made. - sender_key: The curve25519 identity key of the user who made the session. session_id: The unique identifier of the session. Returns: @@ -362,3 +421,69 @@ async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]: A filtered version of the input list that only includes users who have had a previous call to :meth:`put_devices` (even if the call was with an empty dict). """ + + @abstractmethod + async def put_cross_signing_key( + self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey + ) -> None: + """ + Store a single cross-signing key. + + Args: + user_id: The user whose cross-signing key is being stored. + usage: The type of key being stored. + key: The key itself. + """ + + @abstractmethod + async def get_cross_signing_keys( + self, user_id: UserID + ) -> dict[CrossSigningUsage, TOFUSigningKey]: + """ + Retrieve stored cross-signing keys for a specific user. + + Args: + user_id: The user whose cross-signing keys to get. + + Returns: + A map from the type of key to a tuple containing the current key and the key that was + seen first. If the keys are different, it should be treated as a local TOFU violation. + """ + + @abstractmethod + async def put_signature( + self, target: CrossSigner, signer: CrossSigner, signature: str + ) -> None: + """ + Store a signature for a given key from a given key. + + Args: + target: The user ID and key being signed. 
+ signer: The user ID and key who are doing the signing. + signature: The signature. + """ + + @abstractmethod + async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool: + """ + Check if a given key is signed by the given signer. + + Args: + target: The key to check. + signer: The signer who is expected to have signed the key. + + Returns: + ``True`` if the database contains a signature for the key, ``False`` otherwise. + """ + + @abstractmethod + async def drop_signatures_by_key(self, signer: CrossSigner) -> int: + """ + Delete signatures made by the given key. + + Args: + signer: The key whose signatures to delete. + + Returns: + The number of signatures deleted. + """ diff --git a/mautrix/crypto/store/asyncpg/__init__.py b/mautrix/crypto/store/asyncpg/__init__.py index 253ccb85..fe2645da 100644 --- a/mautrix/crypto/store/asyncpg/__init__.py +++ b/mautrix/crypto/store/asyncpg/__init__.py @@ -1 +1,3 @@ from .store import PgCryptoStateStore, PgCryptoStore + +__all__ = ["PgCryptoStore", "PgCryptoStateStore"] diff --git a/mautrix/crypto/store/asyncpg/store.py b/mautrix/crypto/store/asyncpg/store.py index 204afa74..bdc37ddd 100644 --- a/mautrix/crypto/store/asyncpg/store.py +++ b/mautrix/crypto/store/asyncpg/store.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -6,25 +6,48 @@ from __future__ import annotations from collections import defaultdict +from contextlib import asynccontextmanager from datetime import timedelta +from asyncpg import UniqueViolationError + from mautrix.client.state_store import SyncStore from mautrix.client.state_store.asyncpg import PgStateStore -from mautrix.types import DeviceID, EventID, IdentityKey, RoomID, SessionID, SyncToken, UserID -from mautrix.util.async_db import Database, Scheme -from mautrix.util.logging import TraceLogger - -from ... import ( +from mautrix.errors import GroupSessionWithheldError +from mautrix.types import ( + CrossSigner, + CrossSigningUsage, + DeviceID, DeviceIdentity, - InboundGroupSession, - OlmAccount, - OutboundGroupSession, - Session, + EventID, + IdentityKey, + RoomID, + RoomKeyWithheldCode, + SessionID, + SigningKey, + SyncToken, + TOFUSigningKey, TrustState, + UserID, ) +from mautrix.util.async_db import Database, Scheme +from mautrix.util.logging import TraceLogger + +from ... 
import InboundGroupSession, OlmAccount, OutboundGroupSession, RatchetSafety, Session from ..abstract import CryptoStore, StateStore from .upgrade import upgrade_table +try: + from sqlite3 import IntegrityError, sqlite_version_info as sqlite_version + + from aiosqlite import Cursor +except ImportError: + Cursor = None + sqlite_version = (0, 0, 0) + + class IntegrityError(Exception): + pass + class PgCryptoStateStore(PgStateStore, StateStore): """ @@ -50,12 +73,18 @@ def __init__(self, account_id: str, pickle_key: str, db: Database) -> None: self.db = db self.account_id = account_id self.pickle_key = pickle_key + self.log = db.log self._sync_token = None self._device_id = DeviceID("") self._account = None self._olm_cache = defaultdict(lambda: {}) + @asynccontextmanager + async def transaction(self) -> None: + async with self.db.acquire() as conn, conn.transaction(): + yield + async def delete(self) -> None: tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session") async with self.db.acquire() as conn, conn.transaction(): @@ -63,45 +92,40 @@ async def delete(self) -> None: await conn.execute(f"DELETE FROM {table} WHERE account_id=$1", self.account_id) async def get_device_id(self) -> DeviceID | None: - device_id = await self.db.fetchval( - "SELECT device_id FROM crypto_account WHERE account_id=$1", self.account_id - ) + q = "SELECT device_id FROM crypto_account WHERE account_id=$1" + device_id = await self.db.fetchval(q, self.account_id) self._device_id = device_id or self._device_id return self._device_id async def put_device_id(self, device_id: DeviceID) -> None: - await self.db.fetchval( - "UPDATE crypto_account SET device_id=$1 WHERE account_id=$2", - device_id, - self.account_id, - ) + q = "UPDATE crypto_account SET device_id=$1 WHERE account_id=$2" + await self.db.fetchval(q, device_id, self.account_id) self._device_id = device_id async def put_next_batch(self, next_batch: SyncToken) -> None: self._sync_token = next_batch - await 
self.db.execute( - "UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2", - self._sync_token, - self.account_id, - ) + q = "UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2" + await self.db.execute(q, self._sync_token, self.account_id) async def get_next_batch(self) -> SyncToken: if self._sync_token is None: - self._sync_token = await self.db.fetchval( - "SELECT sync_token FROM crypto_account WHERE account_id=$1", self.account_id - ) + q = "SELECT sync_token FROM crypto_account WHERE account_id=$1" + self._sync_token = await self.db.fetchval(q, self.account_id) return self._sync_token async def put_account(self, account: OlmAccount) -> None: self._account = account pickle = account.pickle(self.pickle_key) + q = """ + INSERT INTO crypto_account (account_id, device_id, shared, sync_token, account) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (account_id) DO UPDATE + SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account + """ await self.db.execute( - "INSERT INTO crypto_account (account_id, device_id, shared, " - "sync_token, account) VALUES($1, $2, $3, $4, $5) " - "ON CONFLICT (account_id) DO UPDATE SET shared=$3, sync_token=$4," - " account=$5", + q, self.account_id, - self._device_id, + self._device_id or "", account.shared, self._sync_token or "", pickle, @@ -109,10 +133,8 @@ async def put_account(self, account: OlmAccount) -> None: async def get_account(self) -> OlmAccount: if self._account is None: - row = await self.db.fetchrow( - "SELECT shared, account, device_id FROM crypto_account WHERE account_id=$1", - self.account_id, - ) + q = "SELECT shared, account, device_id FROM crypto_account WHERE account_id=$1" + row = await self.db.fetchrow(q, self.account_id) if row is not None: self._account = OlmAccount.from_pickle( row["account"], passphrase=self.pickle_key, shared=row["shared"] @@ -122,19 +144,16 @@ async def get_account(self) -> OlmAccount: async def has_session(self, key: IdentityKey) -> bool: if 
len(self._olm_cache[key]) > 0: return True - val = await self.db.fetchval( - "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2", - key, - self.account_id, - ) + q = "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2" + val = await self.db.fetchval(q, key, self.account_id) return val is not None async def get_sessions(self, key: IdentityKey) -> list[Session]: - q = ( - "SELECT session_id, session, created_at, last_encrypted, last_decrypted " - "FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 " - "ORDER BY last_decrypted DESC" - ) + q = """ + SELECT session_id, session, created_at, last_encrypted, last_decrypted + FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 + ORDER BY last_decrypted DESC + """ rows = await self.db.fetch(q, key, self.account_id) sessions = [] for row in rows: @@ -153,11 +172,11 @@ async def get_sessions(self, key: IdentityKey) -> list[Session]: return sessions async def get_latest_session(self, key: IdentityKey) -> Session | None: - q = ( - "SELECT session_id, session, created_at, last_encrypted, last_decrypted " - "FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 " - "ORDER BY last_decrypted DESC LIMIT 1" - ) + q = """ + SELECT session_id, session, created_at, last_encrypted, last_decrypted + FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 + ORDER BY last_decrypted DESC LIMIT 1 + """ row = await self.db.fetchrow(q, key, self.account_id) if row is None: return None @@ -180,9 +199,9 @@ async def add_session(self, key: IdentityKey, session: Session) -> None: self._olm_cache[key][SessionID(session.id)] = session pickle = session.pickle(self.pickle_key) q = """ - INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, - last_encrypted, last_decrypted, account_id) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO crypto_olm_session ( + session_id, sender_key, session, created_at, last_encrypted, last_decrypted, 
account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7) """ await self.db.execute( q, @@ -204,10 +223,10 @@ async def update_session(self, key: IdentityKey, session: Session) -> None: f"isn't equal to the one being saved to the database ({e})" ) pickle = session.pickle(self.pickle_key) - q = ( - "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 " - "WHERE session_id=$4 AND account_id=$5" - ) + q = """ + UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 + WHERE session_id=$4 AND account_id=$5 + """ await self.db.execute( q, pickle, session.last_encrypted, session.last_decrypted, session.id, self.account_id ) @@ -221,65 +240,172 @@ async def put_group_session( ) -> None: pickle = session.pickle(self.pickle_key) forwarding_chains = ",".join(session.forwarding_chain) - await self.db.execute( - "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, " - "signing_key, room_id, session, forwarding_chains, account_id) " - "VALUES ($1, $2, $3, $4, $5, $6, $7)", - session_id, - sender_key, - session.signing_key, - room_id, - pickle, - forwarding_chains, - self.account_id, - ) + q = """ + INSERT INTO crypto_megolm_inbound_session ( + session_id, sender_key, signing_key, room_id, session, forwarding_chains, + ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ON CONFLICT (session_id, account_id) DO UPDATE + SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, + signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, + forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, + received_at=excluded.received_at, max_age=excluded.max_age, + max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled + """ + try: + await self.db.execute( + q, + session_id, + sender_key, + session.signing_key, + room_id, + pickle, + forwarding_chains, 
+ session.ratchet_safety.json(), + session.received_at, + int(session.max_age.total_seconds() * 1000) if session.max_age else None, + session.max_messages, + session.is_scheduled, + self.account_id, + ) + except (IntegrityError, UniqueViolationError): + self.log.exception(f"Failed to insert megolm session {session_id}") async def get_group_session( - self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID + self, room_id: RoomID, session_id: SessionID ) -> InboundGroupSession | None: - row = await self.db.fetchrow( - "SELECT signing_key, session, forwarding_chains FROM crypto_megolm_inbound_session " - "WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4", - room_id, - sender_key, - session_id, - self.account_id, - ) + q = """ + SELECT + sender_key, signing_key, session, forwarding_chains, withheld_code, + ratchet_safety, received_at, max_age, max_messages, is_scheduled + FROM crypto_megolm_inbound_session + WHERE room_id=$1 AND session_id=$2 AND account_id=$3 + """ + row = await self.db.fetchrow(q, room_id, session_id, self.account_id) if row is None: return None + if row["withheld_code"] is not None: + raise GroupSessionWithheldError(session_id, row["withheld_code"]) + forwarding_chain = row["forwarding_chains"].split(",") if row["forwarding_chains"] else [] return InboundGroupSession.from_pickle( row["session"], passphrase=self.pickle_key, signing_key=row["signing_key"], - sender_key=sender_key, + sender_key=row["sender_key"], room_id=room_id, - forwarding_chain=row["forwarding_chains"].split(","), + forwarding_chain=forwarding_chain, + ratchet_safety=RatchetSafety.parse_json(row["ratchet_safety"] or "{}"), + received_at=row["received_at"], + max_age=timedelta(milliseconds=row["max_age"]) if row["max_age"] else None, + max_messages=row["max_messages"], + is_scheduled=row["is_scheduled"], ) - async def has_group_session( - self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID - ) -> bool: - count = await 
self.db.fetchval( - "SELECT COUNT(session) FROM crypto_megolm_inbound_session " - "WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4", + async def redact_group_session( + self, room_id: RoomID, session_id: SessionID, reason: str + ) -> None: + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL + """ + await self.db.execute( + q, + RoomKeyWithheldCode.BEEPER_REDACTED.value, + f"Session redacted: {reason}", + session_id, + self.account_id, + ) + + async def redact_group_sessions( + self, room_id: RoomID, sender_key: IdentityKey, reason: str + ) -> list[SessionID]: + if not room_id and not sender_key: + raise ValueError("Either room_id or sender_key must be provided") + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5 + AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL + RETURNING session_id + """ + rows = await self.db.fetch( + q, + RoomKeyWithheldCode.BEEPER_REDACTED.value, + f"Session redacted: {reason}", room_id, sender_key, - session_id, self.account_id, ) + return [row["session_id"] for row in rows] + + async def redact_expired_group_sessions(self) -> list[SessionID]: + if self.db.scheme == Scheme.SQLITE: + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false + AND received_at IS NOT NULL and max_age IS NOT NULL + AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now')) + RETURNING session_id + """ + elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, 
forwarding_chains=NULL + WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false + AND received_at IS NOT NULL and max_age IS NOT NULL + AND received_at + 2 * (max_age * interval '1 millisecond') < now() + RETURNING session_id + """ + else: + raise RuntimeError(f"Unsupported dialect {self.db.scheme}") + rows = await self.db.fetch( + q, + RoomKeyWithheldCode.BEEPER_REDACTED.value, + f"Session redacted: expired", + self.account_id, + ) + return [row["session_id"] for row in rows] + + async def redact_outdated_group_sessions(self) -> list[SessionID]: + q = """ + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL + WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL + RETURNING session_id + """ + rows = await self.db.fetch( + q, + RoomKeyWithheldCode.BEEPER_REDACTED.value, + f"Session redacted: outdated", + self.account_id, + ) + return [row["session_id"] for row in rows] + + async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: + q = """ + SELECT COUNT(session) FROM crypto_megolm_inbound_session + WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL + """ + count = await self.db.fetchval(q, room_id, session_id, self.account_id) return count > 0 async def add_outbound_group_session(self, session: OutboundGroupSession) -> None: pickle = session.pickle(self.pickle_key) - max_age = session.max_age - if self.db.scheme == Scheme.SQLITE: - max_age = max_age.total_seconds() + max_age = int(session.max_age.total_seconds() * 1000) + q = """ + INSERT INTO crypto_megolm_outbound_session ( + room_id, session_id, session, shared, max_messages, message_count, + max_age, created_at, last_used, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (account_id, room_id) DO UPDATE + SET session_id=excluded.session_id, session=excluded.session, shared=excluded.shared, + max_messages=excluded.max_messages, 
message_count=excluded.message_count, + max_age=excluded.max_age, created_at=excluded.created_at, last_used=excluded.last_used + """ await self.db.execute( - "INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, " - "max_messages, message_count, max_age, created_at, last_used, account_id) " - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" - "ON CONFLICT (account_id, room_id) DO UPDATE SET session_id=$2, session=$3, shared=$4," - " max_messages=$5, message_count=$6, max_age=$7, created_at=$8, last_used=$9", + q, session.room_id, session.id, pickle, @@ -294,9 +420,12 @@ async def add_outbound_group_session(self, session: OutboundGroupSession) -> Non async def update_outbound_group_session(self, session: OutboundGroupSession) -> None: pickle = session.pickle(self.pickle_key) + q = """ + UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 + WHERE room_id=$4 AND session_id=$5 AND account_id=$6 + """ await self.db.execute( - "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 " - "WHERE room_id=$4 AND session_id=$5 AND account_id=$6", + q, pickle, session.message_count, session.use_time, @@ -306,18 +435,14 @@ async def update_outbound_group_session(self, session: OutboundGroupSession) -> ) async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None: - row = await self.db.fetchrow( - "SELECT room_id, session_id, session, shared, max_messages, message_count, max_age, " - " created_at, last_used " - "FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2", - room_id, - self.account_id, - ) + q = """ + SELECT room_id, session_id, session, shared, max_messages, message_count, max_age, + created_at, last_used + FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2 + """ + row = await self.db.fetchrow(q, room_id, self.account_id) if row is None: return None - max_age = row["max_age"] - if self.db.scheme == 
Scheme.SQLITE: - max_age = timedelta(seconds=max_age) return OutboundGroupSession.from_pickle( row["session"], passphrase=self.pickle_key, @@ -325,45 +450,35 @@ async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSess shared=row["shared"], max_messages=row["max_messages"], message_count=row["message_count"], - max_age=max_age, + max_age=timedelta(milliseconds=row["max_age"]), use_time=row["last_used"], creation_time=row["created_at"], ) async def remove_outbound_group_session(self, room_id: RoomID) -> None: - await self.db.execute( - "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2", - room_id, - self.account_id, - ) + q = "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2" + await self.db.execute(q, room_id, self.account_id) async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None: if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): - await self.db.execute( - "DELETE FROM crypto_megolm_outbound_session " - "WHERE account_id=$1 AND room_id=ANY($2)", - self.account_id, - rooms, - ) + q = """ + DELETE FROM crypto_megolm_outbound_session WHERE account_id=$1 AND room_id=ANY($2) + """ + await self.db.execute(q, self.account_id, rooms) else: params = ",".join(["?"] * len(rooms)) - await self.db.execute( - "DELETE FROM crypto_megolm_outbound_session " - f"WHERE account_id=? AND room_id IN ({params})", - self.account_id, - *rooms, - ) - - _validate_message_index_query = ( - "WITH existing AS (" - " INSERT INTO crypto_message_index(sender_key, session_id, index, event_id, timestamp)" - " VALUES ($1, $2, $3, $4, $5)" - # have to update something so that RETURNING * always returns the row - " ON CONFLICT (sender_key, session_id, index) DO UPDATE SET sender_key=$1" - " RETURNING *" - ")" - "SELECT * FROM existing" - ) + q = f""" + DELETE FROM crypto_megolm_outbound_session WHERE account_id=? 
AND room_id IN ({params}) + """ + await self.db.execute(q, self.account_id, *rooms) + + _validate_message_index_query = """ + INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) + VALUES ($1, $2, $3, $4, $5) + -- have to update something so that RETURNING * always returns the row + ON CONFLICT (sender_key, session_id, "index") DO UPDATE SET sender_key=excluded.sender_key + RETURNING * + """ async def validate_message_index( self, @@ -373,7 +488,11 @@ async def validate_message_index( index: int, timestamp: int, ) -> bool: - if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): + if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH) or ( + # RETURNING was added in SQLite 3.35.0 https://www.sqlite.org/lang_returning.html + self.db.scheme == Scheme.SQLITE + and sqlite_version >= (3, 35) + ): row = await self.db.fetchrow( self._validate_message_index_query, sender_key, @@ -406,16 +525,15 @@ async def validate_message_index( return True async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None: - tracked_user_id = await self.db.fetchval( - "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", user_id - ) + q = "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1" + tracked_user_id = await self.db.fetchval(q, user_id) if tracked_user_id is None: return None - rows = await self.db.fetch( - "SELECT device_id, identity_key, signing_key, trust, deleted, " - "name FROM crypto_device WHERE user_id=$1", - user_id, - ) + q = """ + SELECT device_id, identity_key, signing_key, trust, deleted, name + FROM crypto_device WHERE user_id=$1 + """ + rows = await self.db.fetch(q, user_id) result = {} for row in rows: result[row["device_id"]] = DeviceIdentity( @@ -430,12 +548,11 @@ async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | return result async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None: - row = await self.db.fetchrow( - "SELECT 
identity_key, signing_key, trust, deleted, name " - "FROM crypto_device WHERE user_id=$1 AND device_id=$2", - user_id, - device_id, - ) + q = """ + SELECT identity_key, signing_key, trust, deleted, name FROM crypto_device + WHERE user_id=$1 AND device_id=$2 + """ + row = await self.db.fetchrow(q, user_id, device_id) if row is None: return None return DeviceIdentity( @@ -451,9 +568,12 @@ async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdenti async def find_device_by_key( self, user_id: UserID, identity_key: IdentityKey ) -> DeviceIdentity | None: + q = """ + SELECT device_id, signing_key, trust, deleted, name FROM crypto_device + WHERE user_id=$1 AND identity_key=$2 + """ row = await self.db.fetchrow( - "SELECT device_id, signing_key, trust, deleted, name " - "FROM crypto_device WHERE user_id=$1 AND identity_key=$2", + q, user_id, identity_key, ) @@ -492,30 +612,105 @@ async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdent "name", ] async with self.db.acquire() as conn, conn.transaction(): - await conn.execute( - "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) " - "ON CONFLICT (user_id) DO NOTHING", - user_id, - ) + q = """ + INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING + """ + await conn.execute(q, user_id) await conn.execute("DELETE FROM crypto_device WHERE user_id=$1", user_id) if self.db.scheme == Scheme.POSTGRES: await conn.copy_records_to_table("crypto_device", records=data, columns=columns) else: - await conn.executemany( - "INSERT INTO crypto_device (user_id, device_id, " - "identity_key, signing_key, trust, deleted, name) " - "VALUES ($1, $2, $3, $4, $5, $6, $7)", - data, - ) + q = """ + INSERT INTO crypto_device ( + user_id, device_id, identity_key, signing_key, trust, deleted, name + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + """ + await conn.executemany(q, data) async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]: if self.db.scheme in 
(Scheme.POSTGRES, Scheme.COCKROACH): - rows = await self.db.fetch( - "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", users - ) + q = "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)" + rows = await self.db.fetch(q, users) else: params = ",".join(["?"] * len(users)) - rows = await self.db.fetch( - f"SELECT user_id FROM crypto_tracked_user WHERE user_id IN ({params})", *users - ) + q = f"SELECT user_id FROM crypto_tracked_user WHERE user_id IN ({params})" + rows = await self.db.fetch(q, *users) return [row["user_id"] for row in rows] + + async def put_cross_signing_key( + self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey + ) -> None: + q = """ + INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key + """ + try: + await self.db.execute(q, user_id, usage.value, key, key) + except Exception: + self.log.exception(f"Failed to store cross-signing key {user_id}/{key}/{usage}") + + async def get_cross_signing_keys( + self, user_id: UserID + ) -> dict[CrossSigningUsage, TOFUSigningKey]: + q = "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1" + return { + CrossSigningUsage(row["usage"]): TOFUSigningKey( + key=SigningKey(row["key"]), + first=SigningKey(row["first_seen_key"]), + ) + for row in await self.db.fetch(q, user_id) + } + + async def put_signature( + self, target: CrossSigner, signer: CrossSigner, signature: str + ) -> None: + q = """ + INSERT INTO crypto_cross_signing_signatures ( + signed_user_id, signed_key, signer_user_id, signer_key, signature + ) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) + DO UPDATE SET signature=excluded.signature + """ + signed_user_id, signed_key = target + signer_user_id, signer_key = signer + try: + await self.db.execute( + q, signed_user_id, signed_key, signer_user_id, signer_key, signature + ) 
+ except Exception: + self.log.exception( + f"Failed to store signature from {signer_user_id}/{signer_key} " + f"for {signed_user_id}/{signed_key}" + ) + + async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool: + q = """ + SELECT EXISTS( + SELECT 1 FROM crypto_cross_signing_signatures + WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4 + ) + """ + signed_user_id, signed_key = target + signer_user_id, signer_key = signer + return await self.db.fetchval(q, signed_user_id, signed_key, signer_user_id, signer_key) + + async def drop_signatures_by_key(self, signer: CrossSigner) -> int: + signer_user_id, signer_key = signer + q = "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2" + try: + res = await self.db.execute(q, signer_user_id, signer_key) + except Exception: + self.log.exception( + f"Failed to drop old signatures made by replaced key {signer_user_id}/{signer_key}" + ) + return -1 + if Cursor is not None and isinstance(res, Cursor): + return res.rowcount + elif ( + isinstance(res, str) + and res.startswith("DELETE ") + and (intPart := res[len("DELETE ") :]).isdecimal() + ): + return int(intPart) + return -1 diff --git a/mautrix/crypto/store/asyncpg/upgrade.py b/mautrix/crypto/store/asyncpg/upgrade.py index 12b1ec54..8d413858 100644 --- a/mautrix/crypto/store/asyncpg/upgrade.py +++ b/mautrix/crypto/store/asyncpg/upgrade.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 Tulir Asokan +# Copyright (c) 2023 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -16,34 +16,34 @@ ) -@upgrade_table.register(description="Latest revision", upgrades_to=4) -async def upgrade_blank_to_v4(conn: Connection) -> None: - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_account ( - account_id TEXT PRIMARY KEY, - device_id TEXT, - shared BOOLEAN NOT NULL, - sync_token TEXT NOT NULL, - account bytea NOT NULL - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_message_index ( +@upgrade_table.register(description="Latest revision", upgrades_to=10) +async def upgrade_blank_to_latest(conn: Connection) -> None: + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_account ( + account_id TEXT PRIMARY KEY, + device_id TEXT NOT NULL, + shared BOOLEAN NOT NULL, + sync_token TEXT NOT NULL, + account bytea NOT NULL + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_message_index ( sender_key CHAR(43), session_id CHAR(43), "index" INTEGER, event_id TEXT NOT NULL, timestamp BIGINT NOT NULL, PRIMARY KEY (sender_key, session_id, "index") - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_tracked_user ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_tracked_user ( user_id TEXT PRIMARY KEY - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_device ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_device ( user_id TEXT, device_id TEXT, identity_key CHAR(43) NOT NULL, @@ -52,10 +52,10 @@ async def upgrade_blank_to_v4(conn: Connection) -> None: deleted BOOLEAN NOT NULL, name TEXT NOT NULL, PRIMARY KEY (user_id, device_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_olm_session ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_olm_session ( account_id TEXT, session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -64,22 +64,29 @@ async def upgrade_blank_to_v4(conn: Connection) -> None: last_decrypted timestamp 
NOT NULL, last_encrypted timestamp NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( - account_id TEXT, - session_id CHAR(43), - sender_key CHAR(43) NOT NULL, - signing_key CHAR(43) NOT NULL, - room_id TEXT NOT NULL, - session bytea NOT NULL, - forwarding_chains TEXT NOT NULL, + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( + account_id TEXT, + session_id CHAR(43), + sender_key CHAR(43) NOT NULL, + signing_key CHAR(43), + room_id TEXT NOT NULL, + session bytea, + forwarding_chains TEXT, + withheld_code TEXT, + withheld_reason TEXT, + ratchet_safety jsonb, + received_at timestamp, + max_age BIGINT, + max_messages INTEGER, + is_scheduled BOOLEAN NOT NULL DEFAULT false, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, room_id TEXT, session_id CHAR(43) NOT NULL UNIQUE, @@ -87,12 +94,33 @@ async def upgrade_blank_to_v4(conn: Connection) -> None: shared BOOLEAN NOT NULL, max_messages INTEGER NOT NULL, message_count INTEGER NOT NULL, - max_age INTERVAL NOT NULL, + max_age BIGINT NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) - )""" - ) + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_keys ( + user_id TEXT, + usage TEXT, + key CHAR(43) NOT NULL, + + first_seen_key CHAR(43) NOT NULL, + + PRIMARY KEY (user_id, usage) + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_signatures ( + signed_user_id TEXT, + signed_key TEXT, + signer_user_id TEXT, + signer_key TEXT, + signature CHAR(88) NOT NULL, + PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) + ) + """) @upgrade_table.register(description="Add account_id primary key 
column") @@ -102,17 +130,17 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: await conn.execute("DROP TABLE crypto_olm_session") await conn.execute("DROP TABLE crypto_megolm_inbound_session") await conn.execute("DROP TABLE crypto_megolm_outbound_session") - await conn.execute( - """CREATE TABLE crypto_account ( + await conn.execute(""" + CREATE TABLE crypto_account ( account_id VARCHAR(255) PRIMARY KEY, device_id VARCHAR(255) NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL - )""" - ) - await conn.execute( - """CREATE TABLE crypto_olm_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_olm_session ( account_id VARCHAR(255), session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -120,10 +148,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_megolm_inbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_megolm_inbound_session ( account_id VARCHAR(255), session_id CHAR(43), sender_key CHAR(43) NOT NULL, @@ -132,10 +160,10 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: session bytea NOT NULL, forwarding_chains TEXT NOT NULL, PRIMARY KEY (account_id, session_id) - )""" - ) - await conn.execute( - """CREATE TABLE crypto_megolm_outbound_session ( + ) + """) + await conn.execute(""" + CREATE TABLE crypto_megolm_outbound_session ( account_id VARCHAR(255), room_id VARCHAR(255), session_id CHAR(43) NOT NULL UNIQUE, @@ -143,12 +171,12 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: shared BOOLEAN NOT NULL, max_messages INTEGER NOT NULL, message_count INTEGER NOT NULL, - max_age INTERVAL NOT NULL, + max_age BIGINT NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) - )""" - ) + ) + """) else: async def 
add_account_id_column(table: str, pkey_columns: list[str]) -> None: @@ -201,3 +229,207 @@ async def upgrade_v4(conn: Connection, scheme: Scheme) -> None: await conn.execute( "ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL" ) + + +@upgrade_table.register(description="Add cross-signing key and signature caches") +async def upgrade_v5(conn: Connection) -> None: + await conn.execute(""" + CREATE TABLE crypto_cross_signing_keys ( + user_id TEXT, + usage TEXT, + key CHAR(43), + first_seen_key CHAR(43), + PRIMARY KEY (user_id, usage) + ) + """) + await conn.execute(""" + CREATE TABLE crypto_cross_signing_signatures ( + signed_user_id TEXT, + signed_key TEXT, + signer_user_id TEXT, + signer_key TEXT, + signature TEXT, + PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) + ) + """) + + +@upgrade_table.register(description="Update trust state values") +async def upgrade_v6(conn: Connection) -> None: + await conn.execute("UPDATE crypto_device SET trust=300 WHERE trust=1") # verified + await conn.execute("UPDATE crypto_device SET trust=-100 WHERE trust=2") # blacklisted + await conn.execute("UPDATE crypto_device SET trust=0 WHERE trust=3") # ignored -> unset + + +@upgrade_table.register( + description="Synchronize schema with mautrix-go", upgrades_to=9, transaction=False +) +async def upgrade_v9(conn: Connection, scheme: Scheme) -> None: + if scheme == Scheme.POSTGRES: + async with conn.transaction(): + await upgrade_v9_postgres(conn) + else: + await upgrade_v9_sqlite(conn) + + +# These two are never used because the previous one jumps from 6 to 9. 
+@upgrade_table.register +async def upgrade_noop_7_to_8(_: Connection) -> None: + pass + + +@upgrade_table.register +async def upgrade_noop_8_to_9(_: Connection) -> None: + pass + + +async def upgrade_v9_postgres(conn: Connection) -> None: + await conn.execute("UPDATE crypto_account SET device_id='' WHERE device_id IS NULL") + await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id SET NOT NULL") + + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN signing_key DROP NOT NULL" + ) + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN session DROP NOT NULL" + ) + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN forwarding_chains DROP NOT NULL" + ) + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_code TEXT") + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_reason TEXT") + + await conn.execute("DELETE FROM crypto_cross_signing_keys WHERE key IS NULL") + await conn.execute( + "UPDATE crypto_cross_signing_keys SET first_seen_key=key WHERE first_seen_key IS NULL" + ) + await conn.execute("ALTER TABLE crypto_cross_signing_keys ALTER COLUMN key SET NOT NULL") + await conn.execute( + "ALTER TABLE crypto_cross_signing_keys ALTER COLUMN first_seen_key SET NOT NULL" + ) + + await conn.execute("DELETE FROM crypto_cross_signing_signatures WHERE signature IS NULL") + await conn.execute( + "ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature SET NOT NULL" + ) + + await conn.execute( + "ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN max_age TYPE BIGINT " + "USING (EXTRACT(EPOCH from max_age)*1000)::bigint" + ) + + +async def upgrade_v9_sqlite(conn: Connection) -> None: + await conn.execute("PRAGMA foreign_keys = OFF") + async with conn.transaction(): + await conn.execute(""" + CREATE TABLE new_crypto_account ( + account_id TEXT PRIMARY KEY, + device_id TEXT NOT NULL, + shared BOOLEAN NOT NULL, + 
sync_token TEXT NOT NULL, + account bytea NOT NULL + ) + """) + await conn.execute(""" + INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account) + SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account + FROM crypto_account + """) + await conn.execute("DROP TABLE crypto_account") + await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account") + + await conn.execute(""" + CREATE TABLE new_crypto_megolm_inbound_session ( + account_id TEXT, + session_id CHAR(43), + sender_key CHAR(43) NOT NULL, + signing_key CHAR(43), + room_id TEXT NOT NULL, + session bytea, + forwarding_chains TEXT, + withheld_code TEXT, + withheld_reason TEXT, + PRIMARY KEY (account_id, session_id) + ) + """) + await conn.execute(""" + INSERT INTO new_crypto_megolm_inbound_session ( + account_id, session_id, sender_key, signing_key, room_id, session, + forwarding_chains + ) + SELECT account_id, session_id, sender_key, signing_key, room_id, session, + forwarding_chains + FROM crypto_megolm_inbound_session + """) + await conn.execute("DROP TABLE crypto_megolm_inbound_session") + await conn.execute( + "ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session" + ) + + await conn.execute("UPDATE crypto_megolm_outbound_session SET max_age=max_age*1000") + + await conn.execute(""" + CREATE TABLE new_crypto_cross_signing_keys ( + user_id TEXT, + usage TEXT, + key CHAR(43) NOT NULL, + + first_seen_key CHAR(43) NOT NULL, + + PRIMARY KEY (user_id, usage) + ) + """) + await conn.execute(""" + INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key) + SELECT user_id, usage, key, COALESCE(first_seen_key, key) + FROM crypto_cross_signing_keys + WHERE key IS NOT NULL + """) + await conn.execute("DROP TABLE crypto_cross_signing_keys") + await conn.execute( + "ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys" + ) + + await conn.execute(""" + CREATE TABLE 
new_crypto_cross_signing_signatures ( + signed_user_id TEXT, + signed_key TEXT, + signer_user_id TEXT, + signer_key TEXT, + signature CHAR(88) NOT NULL, + PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) + ) + """) + await conn.execute(""" + INSERT INTO new_crypto_cross_signing_signatures ( + signed_user_id, signed_key, signer_user_id, signer_key, signature + ) + SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature + FROM crypto_cross_signing_signatures + WHERE signature IS NOT NULL + """) + await conn.execute("DROP TABLE crypto_cross_signing_signatures") + await conn.execute( + "ALTER TABLE new_crypto_cross_signing_signatures " + "RENAME TO crypto_cross_signing_signatures" + ) + + await conn.execute("PRAGMA foreign_key_check") + await conn.execute("PRAGMA foreign_keys = ON") + + +@upgrade_table.register( + description="Add metadata for detecting when megolm sessions are safe to delete" +) +async def upgrade_v10(conn: Connection) -> None: + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb") + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp" + ) + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT") + await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER") + await conn.execute( + "ALTER TABLE crypto_megolm_inbound_session " + "ADD COLUMN is_scheduled BOOLEAN NOT NULL DEFAULT false" + ) diff --git a/mautrix/crypto/store/memory.py b/mautrix/crypto/store/memory.py index 507680a9..c26f86bc 100644 --- a/mautrix/crypto/store/memory.py +++ b/mautrix/crypto/store/memory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -6,9 +6,23 @@ from __future__ import annotations from mautrix.client.state_store import SyncStore -from mautrix.types import DeviceID, EventID, IdentityKey, RoomID, SessionID, SyncToken, UserID - -from .. import DeviceIdentity, InboundGroupSession, OlmAccount, OutboundGroupSession, Session +from mautrix.types import ( + CrossSigner, + CrossSigningUsage, + DeviceID, + DeviceIdentity, + EventID, + IdentityKey, + RoomID, + SessionID, + SigningKey, + SyncToken, + TOFUSigningKey, + UserID, +) + +from ..account import OlmAccount +from ..sessions import InboundGroupSession, OutboundGroupSession, Session from .abstract import CryptoStore @@ -19,8 +33,10 @@ class MemoryCryptoStore(CryptoStore, SyncStore): _message_indices: dict[tuple[IdentityKey, SessionID, int], tuple[EventID, int]] _devices: dict[UserID, dict[DeviceID, DeviceIdentity]] _olm_sessions: dict[IdentityKey, list[Session]] - _inbound_sessions: dict[tuple[RoomID, IdentityKey, SessionID], InboundGroupSession] + _inbound_sessions: dict[tuple[RoomID, SessionID], InboundGroupSession] _outbound_sessions: dict[RoomID, OutboundGroupSession] + _signatures: dict[CrossSigner, dict[CrossSigner, str]] + _cross_signing_keys: dict[UserID, dict[CrossSigningUsage, TOFUSigningKey]] def __init__(self, account_id: str, pickle_key: str) -> None: self.account_id = account_id @@ -34,6 +50,8 @@ def __init__(self, account_id: str, pickle_key: str) -> None: self._olm_sessions = {} self._inbound_sessions = {} self._outbound_sessions = {} + self._signatures = {} + self._cross_signing_keys = {} async def get_device_id(self) -> DeviceID | None: return self._device_id @@ -85,17 +103,42 @@ async def put_group_session( session_id: SessionID, session: InboundGroupSession, ) -> None: - self._inbound_sessions[(room_id, sender_key, session_id)] = session + self._inbound_sessions[(room_id, session_id)] = session async def get_group_session( - self, room_id: RoomID, sender_key: IdentityKey, 
session_id: SessionID + self, room_id: RoomID, session_id: SessionID ) -> InboundGroupSession: - return self._inbound_sessions.get((room_id, sender_key, session_id)) + return self._inbound_sessions.get((room_id, session_id)) - async def has_group_session( - self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID - ) -> bool: - return (room_id, sender_key, session_id) in self._inbound_sessions + async def redact_group_session( + self, room_id: RoomID, session_id: SessionID, reason: str + ) -> None: + self._inbound_sessions.pop((room_id, session_id), None) + + async def redact_group_sessions( + self, room_id: RoomID, sender_key: IdentityKey, reason: str + ) -> list[SessionID]: + if not room_id and not sender_key: + raise ValueError("Either room_id or sender_key must be provided") + deleted = [] + keys = list(self._inbound_sessions.keys()) + for key in keys: + item = self._inbound_sessions[key] + if (not room_id or item.room_id == room_id) and ( + not sender_key or item.sender_key == sender_key + ): + deleted.append(SessionID(item.id)) + del self._inbound_sessions[key] + return deleted + + async def redact_expired_group_sessions(self) -> list[SessionID]: + raise NotImplementedError() + + async def redact_outdated_group_sessions(self) -> list[SessionID]: + raise NotImplementedError() + + async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool: + return (room_id, session_id) in self._inbound_sessions async def add_outbound_group_session(self, session: OutboundGroupSession) -> None: self._outbound_sessions[session.room_id] = session @@ -147,3 +190,32 @@ async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdent async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]: return [user_id for user_id in users if user_id in self._devices] + + async def put_cross_signing_key( + self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey + ) -> None: + try: + current = 
self._cross_signing_keys[user_id][usage] + except KeyError: + self._cross_signing_keys.setdefault(user_id, {})[usage] = TOFUSigningKey( + key=key, first=key + ) + else: + current.key = key + + async def get_cross_signing_keys( + self, user_id: UserID + ) -> dict[CrossSigningUsage, TOFUSigningKey]: + return self._cross_signing_keys.get(user_id, {}) + + async def put_signature( + self, target: CrossSigner, signer: CrossSigner, signature: str + ) -> None: + self._signatures.setdefault(signer, {})[target] = signature + + async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool: + return target in self._signatures.get(signer, {}) + + async def drop_signatures_by_key(self, signer: CrossSigner) -> int: + deleted = self._signatures.pop(signer, None) + return len(deleted) diff --git a/mautrix/crypto/store/tests/__init__.py b/mautrix/crypto/store/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mautrix/crypto/store/tests/store_test.py b/mautrix/crypto/store/tests/store_test.py new file mode 100644 index 00000000..949949b8 --- /dev/null +++ b/mautrix/crypto/store/tests/store_test.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from __future__ import annotations + +from typing import AsyncContextManager, AsyncIterator, Callable +from contextlib import asynccontextmanager +import os +import random +import string +import time + +import asyncpg +import pytest + +from mautrix.client.state_store import SyncStore +from mautrix.crypto import InboundGroupSession, OlmAccount, OutboundGroupSession +from mautrix.types import DeviceID, EventID, RoomID, SessionID, SyncToken +from mautrix.util.async_db import Database + +from .. 
import CryptoStore, MemoryCryptoStore, PgCryptoStore + + +@asynccontextmanager +async def async_postgres_store() -> AsyncIterator[PgCryptoStore]: + try: + pg_url = os.environ["MEOW_TEST_PG_URL"] + except KeyError: + pytest.skip("Skipped Postgres tests (MEOW_TEST_PG_URL not specified)") + return + conn: asyncpg.Connection = await asyncpg.connect(pg_url) + schema_name = "".join(random.choices(string.ascii_lowercase, k=8)) + schema_name = f"test_schema_{schema_name}_{int(time.time())}" + await conn.execute(f"CREATE SCHEMA {schema_name}") + db = Database.create( + pg_url, + upgrade_table=PgCryptoStore.upgrade_table, + db_args={"min_size": 1, "max_size": 3, "server_settings": {"search_path": schema_name}}, + ) + store = PgCryptoStore("", "test", db) + await db.start() + yield store + await db.stop() + await conn.execute(f"DROP SCHEMA {schema_name} CASCADE") + await conn.close() + + +@asynccontextmanager +async def async_sqlite_store() -> AsyncIterator[PgCryptoStore]: + db = Database.create( + "sqlite::memory:", upgrade_table=PgCryptoStore.upgrade_table, db_args={"min_size": 1} + ) + store = PgCryptoStore("", "test", db) + await db.start() + yield store + await db.stop() + + +@asynccontextmanager +async def memory_store() -> AsyncIterator[MemoryCryptoStore]: + yield MemoryCryptoStore("", "test") + + +@pytest.fixture(params=[async_postgres_store, async_sqlite_store, memory_store]) +async def crypto_store(request) -> AsyncIterator[CryptoStore]: + param: Callable[[], AsyncContextManager[CryptoStore]] = request.param + async with param() as state_store: + yield state_store + + +async def test_basic(crypto_store: CryptoStore) -> None: + acc = OlmAccount() + keys = acc.identity_keys + await crypto_store.put_account(acc) + await crypto_store.put_device_id(DeviceID("TEST")) + if isinstance(crypto_store, SyncStore): + await crypto_store.put_next_batch(SyncToken("TEST")) + + assert await crypto_store.get_device_id() == "TEST" + assert (await 
crypto_store.get_account()).identity_keys == keys + if isinstance(crypto_store, SyncStore): + assert await crypto_store.get_next_batch() == "TEST" + + +def _make_group_sess( + acc: OlmAccount, room_id: RoomID +) -> tuple[InboundGroupSession, OutboundGroupSession]: + outbound = OutboundGroupSession(room_id) + inbound = InboundGroupSession( + session_key=outbound.session_key, + signing_key=acc.signing_key, + sender_key=acc.identity_key, + room_id=room_id, + ) + return inbound, outbound + + +async def test_validate_message_index(crypto_store: CryptoStore) -> None: + acc = OlmAccount() + + inbound, outbound = _make_group_sess(acc, RoomID("!foo:bar.com")) + outbound.shared = True + orig_plaintext = "hello world" + ciphertext = outbound.encrypt(orig_plaintext) + ts = int(time.time() * 1000) + plaintext, index = inbound.decrypt(ciphertext) + assert plaintext == orig_plaintext + + assert await crypto_store.validate_message_index( + acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + ), "Initial validation returns True" + assert await crypto_store.validate_message_index( + acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + ), "Validating the same details again returns True" + assert not await crypto_store.validate_message_index( + acc.identity_key, SessionID(inbound.id), EventID("$bar"), index, ts + ), "Different event ID causes validation to fail" + assert not await crypto_store.validate_message_index( + acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + 1 + ), "Different timestamp causes validation to fail" + assert not await crypto_store.validate_message_index( + acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + 1 + ), "Validating incorrect details twice fails" + assert await crypto_store.validate_message_index( + acc.identity_key, SessionID(inbound.id), EventID("$foo"), index, ts + ), "Validating the same details after fails still returns True" + + +# TODO tests for device identity storage, 
group session storage +# and cross-signing key/signature storage diff --git a/mautrix/crypto/types.py b/mautrix/crypto/types.py deleted file mode 100644 index e210afa5..00000000 --- a/mautrix/crypto/types.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2021 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import Optional -from enum import IntEnum - -from attr import dataclass -import attr - -from mautrix.types import ( - DeviceID, - IdentityKey, - SerializableAttrs, - SigningKey, - ToDeviceEvent, - UserID, -) - - -class TrustState(IntEnum): - UNSET = 0 - VERIFIED = 1 - BLACKLISTED = 2 - IGNORED = 3 - - -@dataclass -class DeviceIdentity: - user_id: UserID - device_id: DeviceID - identity_key: IdentityKey - signing_key: SigningKey - - trust: TrustState - deleted: bool - name: str - - -@dataclass -class OlmEventKeys(SerializableAttrs): - ed25519: SigningKey - - -@dataclass -class DecryptedOlmEvent(ToDeviceEvent, SerializableAttrs): - keys: OlmEventKeys - recipient: UserID - recipient_keys: OlmEventKeys - sender_device: Optional[DeviceID] = None - sender_key: IdentityKey = attr.ib(metadata={"hidden": True}, default=None) diff --git a/mautrix/crypto/unwedge.py b/mautrix/crypto/unwedge.py index 3f07d528..c3e952f6 100644 --- a/mautrix/crypto/unwedge.py +++ b/mautrix/crypto/unwedge.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/errors/__init__.py b/mautrix/errors/__init__.py index 0adb3e2c..afe68dc9 100644 --- a/mautrix/errors/__init__.py +++ b/mautrix/errors/__init__.py @@ -6,6 +6,7 @@ DeviceValidationError, DuplicateMessageIndex, EncryptionError, + GroupSessionWithheldError, MatchingSessionDecryptionError, MismatchingRoomError, SessionNotFound, @@ -13,6 +14,7 @@ VerificationError, ) from .request import ( + MAlreadyJoined, MatrixBadContent, MatrixBadRequest, MatrixInvalidToken, @@ -27,6 +29,7 @@ MForbidden, MGuestAccessForbidden, MIncompatibleRoomVersion, + MInsufficientPower, MInvalidParam, MInvalidRoomState, MInvalidUsername, @@ -34,11 +37,13 @@ MMissingParam, MMissingToken, MNotFound, + MNotJoined, MNotJSON, MRoomInUse, MTooLarge, MUnauthorized, MUnknown, + MUnknownEndpoint, MUnknownToken, MUnrecognized, MUnsupportedRoomVersion, @@ -56,3 +61,66 @@ WellKnownUnexpectedStatus, WellKnownUnsupportedScheme, ) + +__all__ = [ + "IntentError", + "MatrixConnectionError", + "MatrixError", + "MatrixResponseError", + "CryptoError", + "DecryptedPayloadError", + "DecryptionError", + "DeviceValidationError", + "DuplicateMessageIndex", + "EncryptionError", + "GroupSessionWithheldError", + "MatchingSessionDecryptionError", + "MismatchingRoomError", + "SessionNotFound", + "SessionShareError", + "VerificationError", + "MAlreadyJoined", + "MatrixBadContent", + "MatrixBadRequest", + "MatrixInvalidToken", + "MatrixRequestError", + "MatrixStandardRequestError", + "MatrixUnknownRequestError", + "MBadJSON", + "MBadState", + "MCaptchaInvalid", + "MCaptchaNeeded", + "MExclusive", + "MForbidden", + "MGuestAccessForbidden", + "MIncompatibleRoomVersion", + "MInsufficientPower", + "MInvalidParam", + "MInvalidRoomState", + "MInvalidUsername", + "MLimitExceeded", + "MMissingParam", + "MMissingToken", + "MNotFound", + "MNotJoined", + "MNotJSON", + "MRoomInUse", + "MTooLarge", + "MUnauthorized", + "MUnknown", + "MUnknownEndpoint", + 
"MUnknownToken", + "MUnrecognized", + "MUnsupportedRoomVersion", + "MUserDeactivated", + "MUserInUse", + "make_request_error", + "standard_error", + "WellKnownError", + "WellKnownInvalidVersionsResponse", + "WellKnownMissingHomeserver", + "WellKnownNotJSON", + "WellKnownNotURL", + "WellKnownUnexpectedStatus", + "WellKnownUnsupportedScheme", +] diff --git a/mautrix/errors/base.py b/mautrix/errors/base.py index 086cc32f..8eb4cc1e 100644 --- a/mautrix/errors/base.py +++ b/mautrix/errors/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/errors/crypto.py b/mautrix/errors/crypto.py index 70e07c52..4a65048c 100644 --- a/mautrix/errors/crypto.py +++ b/mautrix/errors/crypto.py @@ -1,8 +1,12 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from __future__ import annotations + +import warnings + from mautrix.types import IdentityKey, SessionID from .base import MatrixError @@ -23,20 +27,44 @@ class SessionShareError(CryptoError): class DecryptionError(CryptoError): - pass + @property + def human_message(self) -> str: + return "the bridge failed to decrypt the message" class MatchingSessionDecryptionError(DecryptionError): pass +class GroupSessionWithheldError(DecryptionError): + def __init__(self, session_id: SessionID, withheld_code: str) -> None: + super().__init__(f"Session ID {session_id} was withheld ({withheld_code})") + self.withheld_code = withheld_code + + class SessionNotFound(DecryptionError): - def __init__(self, session_id: SessionID, sender_key: IdentityKey) -> None: + def __init__(self, session_id: SessionID, sender_key: IdentityKey | None = None) -> None: super().__init__( f"Failed to decrypt megolm event: no session with given ID {session_id} found" ) self.session_id = session_id - self.sender_key = sender_key + self._sender_key = sender_key + + @property + def human_message(self) -> str: + return "the bridge hasn't received the decryption keys" + + @property + def sender_key(self) -> IdentityKey | None: + """ + .. deprecated:: 0.17.0 + Matrix v1.3 deprecated the device_id and sender_key fields in megolm events. 
+ """ + warnings.warn( + "The sender_key field in Megolm events was deprecated in Matrix 1.3", + DeprecationWarning, + ) + return self._sender_key class DuplicateMessageIndex(DecryptionError): @@ -46,7 +74,7 @@ def __init__(self) -> None: class VerificationError(DecryptionError): def __init__(self) -> None: - super().__init__("Device keys in event and verified device info do not match") + super().__init__("Device keys in session and cached device info do not match") class DecryptedPayloadError(DecryptionError): diff --git a/mautrix/errors/request.py b/mautrix/errors/request.py index d0facbe0..ebff4d76 100644 --- a/mautrix/errors/request.py +++ b/mautrix/errors/request.py @@ -1,9 +1,11 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import Callable, Dict, Optional, Type +from __future__ import annotations + +from typing import Callable, Type from .base import MatrixError @@ -12,25 +14,30 @@ class MatrixRequestError(MatrixError): """An error that was returned by the homeserver.""" http_status: int - message: Optional[str] + message: str | None errcode: str class MatrixUnknownRequestError(MatrixRequestError): """An unknown error type returned by the homeserver.""" + http_status: int + text: str + errcode: str | None + message: str | None + def __init__( self, http_status: int = 0, text: str = "", - errcode: Optional[str] = None, - message: Optional[str] = None, + errcode: str | None = None, + message: str | None = None, ) -> None: super().__init__(f"{http_status}: {text}") - self.http_status: int = http_status - self.text: str = text - self.errcode: Optional[str] = errcode - self.message: Optional[str] = message + self.http_status = http_status + self.text = text + self.errcode = errcode + self.message = message class 
MatrixStandardRequestError(MatrixRequestError): @@ -45,20 +52,28 @@ def __init__(self, http_status: int, message: str = "") -> None: MxSRE = Type[MatrixStandardRequestError] -ec_map: Dict[str, MxSRE] = {} +ec_map: dict[str, MxSRE] = {} +uec_map: dict[str, MxSRE] = {} -def standard_error(code: str) -> Callable[[MxSRE], MxSRE]: +def standard_error(code: str, unstable: str | None = None) -> Callable[[MxSRE], MxSRE]: def decorator(cls: MxSRE) -> MxSRE: cls.errcode = code ec_map[code] = cls + if unstable: + cls.unstable_errcode = unstable + uec_map[unstable] = cls return cls return decorator def make_request_error( - http_status: int, text: str, errcode: str, message: str + http_status: int, + text: str, + errcode: str | None, + message: str | None, + unstable_errcode: str | None = None, ) -> MatrixRequestError: """ Determine the correct exception class for the error code and create an instance of that class @@ -69,7 +84,14 @@ def make_request_error( text: The raw response text. errcode: The errcode field in the response JSON. message: The error field in the response JSON. + unstable_errcode: The MSC3848 error code field in the response JSON. 
""" + if unstable_errcode: + try: + ec_class = uec_map[unstable_errcode] + return ec_class(http_status, message) + except KeyError: + pass try: ec_class = ec_map[errcode] return ec_class(http_status, message) @@ -77,7 +99,7 @@ def make_request_error( return MatrixUnknownRequestError(http_status, text, errcode, message) -# Standard error codes from https://matrix.org/docs/spec/client_server/r0.4.0.html#api-standards +# Standard error codes from https://spec.matrix.org/v1.3/client-server-api/#api-standards # Additionally some combining superclasses for some of the error codes @@ -86,6 +108,26 @@ class MForbidden(MatrixStandardRequestError): pass +@standard_error("M_ALREADY_JOINED", unstable="ORG.MATRIX.MSC3848.ALREADY_JOINED") +class MAlreadyJoined(MForbidden): + pass + + +@standard_error("M_NOT_JOINED", unstable="ORG.MATRIX.MSC3848.NOT_JOINED") +class MNotJoined(MForbidden): + pass + + +@standard_error("M_INSUFFICIENT_POWER", unstable="ORG.MATRIX.MSC3848.INSUFFICIENT_POWER") +class MInsufficientPower(MForbidden): + pass + + +@standard_error("M_UNKNOWN_ENDPOINT") +class MUnknownEndpoint(MatrixStandardRequestError): + pass + + @standard_error("M_USER_DEACTIVATED") class MUserDeactivated(MForbidden): pass diff --git a/mautrix/errors/well_known.py b/mautrix/errors/well_known.py index 8ff98d72..d1c3f035 100644 --- a/mautrix/errors/well_known.py +++ b/mautrix/errors/well_known.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/genall.py b/mautrix/genall.py new file mode 100644 index 00000000..14fda230 --- /dev/null +++ b/mautrix/genall.py @@ -0,0 +1,41 @@ +# This script generates the __all__ arrays for types/__init__.py and errors/__init__.py +# to avoid having to manually add both the import and the __all__ entry. 
+# See https://github.com/mautrix/python/issues/90 for why __all__ is needed at all. +from pathlib import Path +import ast + +import black + +root_module = Path(__file__).parent + +black_cfg = black.parse_pyproject_toml(str(root_module.parent / "pyproject.toml")) +black_mode = black.Mode( + target_versions={black.TargetVersion[ver.upper()] for ver in black_cfg["target_version"]}, + line_length=black_cfg["line_length"], +) + + +def add_imports_to_all(dir: str) -> None: + init_file = root_module / dir / "__init__.py" + with open(init_file) as f: + init_ast = ast.parse(f.read(), filename=f"mautrix/{dir}/__init__.py") + + imports: list[str] = [] + all_node: ast.List | None = None + + for node in ast.iter_child_nodes(init_ast): + if isinstance(node, (ast.Import, ast.ImportFrom)): + imports += (name.name for name in node.names) + elif isinstance(node, ast.Assign) and isinstance(node.value, ast.List): + target = node.targets[0] + if len(node.targets) == 1 and isinstance(target, ast.Name) and target.id == "__all__": + all_node = node.value + + all_node.elts = [ast.Constant(name) for name in imports] + + with open(init_file, "w") as f: + f.write(black.format_str(ast.unparse(init_ast), mode=black_mode)) + + +add_imports_to_all("types") +add_imports_to_all("errors") diff --git a/mautrix/types/__init__.py b/mautrix/types/__init__.py index 6e8da575..ceceeaca 100644 --- a/mautrix/types/__init__.py +++ b/mautrix/types/__init__.py @@ -14,16 +14,35 @@ UserIdentifierType, WhoamiResponse, ) -from .crypto import ClaimKeysResponse, DeviceKeys, QueryKeysResponse, UnsignedDeviceInfo +from .crypto import ( + ClaimKeysResponse, + CrossSigner, + CrossSigningKeys, + CrossSigningUsage, + DecryptedOlmEvent, + DeviceIdentity, + DeviceKeys, + OlmEventKeys, + QueryKeysResponse, + TOFUSigningKey, + TrustState, + UnsignedDeviceInfo, +) from .event import ( AccountDataEvent, AccountDataEventContent, + ASToDeviceEvent, AudioInfo, BaseEvent, BaseFileInfo, BaseMessageEventContent, + 
BaseMessageEventContentFuncs, BaseRoomEvent, BaseUnsigned, + BatchSendEvent, + BatchSendStateEvent, + BeeperMessageStatusEvent, + BeeperMessageStatusEventContent, CallAnswerEventContent, CallCandidate, CallCandidatesEventContent, @@ -38,6 +57,7 @@ CallRejectEventContent, CallSelectAnswerEventContent, CanonicalAliasStateEventContent, + DirectAccountDataEventContent, EncryptedEvent, EncryptedEventContent, EncryptedFile, @@ -54,9 +74,11 @@ ForwardedRoomKeyEventContent, GenericEvent, ImageInfo, + InReplyTo, JoinRule, JoinRulesStateEventContent, JSONWebKey, + KeyID, KeyRequestAction, LocationInfo, LocationMessageEventContent, @@ -66,6 +88,8 @@ MemberStateEventContent, MessageEvent, MessageEventContent, + MessageStatus, + MessageStatusReason, MessageType, MessageUnsigned, OlmCiphertext, @@ -98,6 +122,8 @@ RoomTagInfo, RoomTombstoneStateEventContent, RoomTopicStateEventContent, + RoomType, + SecretStorageDefaultKeyEventContent, SingleReceiptEventContent, SpaceChildStateEventContent, SpaceParentStateEventContent, @@ -115,19 +141,27 @@ ) from .filter import EventFilter, Filter, RoomEventFilter, RoomFilter, StateFilter from .matrixuri import IdentifierType, MatrixURI, MatrixURIError, URIAction -from .media import MediaRepoConfig, MXOpenGraph, OpenGraphAudio, OpenGraphImage, OpenGraphVideo +from .media import ( + MediaCreateResponse, + MediaRepoConfig, + MXOpenGraph, + OpenGraphAudio, + OpenGraphImage, + OpenGraphVideo, +) from .misc import ( BatchSendResponse, + BeeperBatchSendResponse, DeviceLists, DeviceOTKCount, DirectoryPaginationToken, + EventContext, PaginatedMessages, PaginationDirection, RoomAliasInfo, RoomCreatePreset, RoomDirectoryResponse, RoomDirectoryVisibility, - VersionsResponse, ) from .primitive import ( JSON, @@ -140,6 +174,7 @@ RoomAlias, RoomID, SessionID, + Signature, SigningKey, SyncToken, UserID, @@ -169,3 +204,210 @@ field, serializer, ) +from .versions import SpecVersions, Version, VersionFormat, VersionsResponse + +__all__ = [ + 
"DiscoveryInformation", + "DiscoveryIntegrations", + "DiscoveryIntegrationServer", + "DiscoveryServer", + "LoginFlow", + "LoginFlowList", + "LoginResponse", + "LoginType", + "MatrixUserIdentifier", + "PhoneIdentifier", + "ThirdPartyIdentifier", + "UserIdentifier", + "UserIdentifierType", + "WhoamiResponse", + "ClaimKeysResponse", + "CrossSigner", + "CrossSigningKeys", + "CrossSigningUsage", + "DecryptedOlmEvent", + "DeviceIdentity", + "DeviceKeys", + "OlmEventKeys", + "QueryKeysResponse", + "TOFUSigningKey", + "TrustState", + "UnsignedDeviceInfo", + "AccountDataEvent", + "AccountDataEventContent", + "ASToDeviceEvent", + "AudioInfo", + "BaseEvent", + "BaseFileInfo", + "BaseMessageEventContent", + "BaseMessageEventContentFuncs", + "BaseRoomEvent", + "BaseUnsigned", + "BatchSendEvent", + "BatchSendStateEvent", + "BeeperMessageStatusEvent", + "BeeperMessageStatusEventContent", + "CallAnswerEventContent", + "CallCandidate", + "CallCandidatesEventContent", + "CallData", + "CallDataType", + "CallEvent", + "CallEventContent", + "CallHangupEventContent", + "CallHangupReason", + "CallInviteEventContent", + "CallNegotiateEventContent", + "CallRejectEventContent", + "CallSelectAnswerEventContent", + "CanonicalAliasStateEventContent", + "DirectAccountDataEventContent", + "EncryptedEvent", + "EncryptedEventContent", + "EncryptedFile", + "EncryptedMegolmEventContent", + "EncryptedOlmEventContent", + "EncryptionAlgorithm", + "EncryptionKeyAlgorithm", + "EphemeralEvent", + "Event", + "EventContent", + "EventType", + "FileInfo", + "Format", + "ForwardedRoomKeyEventContent", + "GenericEvent", + "ImageInfo", + "InReplyTo", + "JoinRule", + "JoinRulesStateEventContent", + "JSONWebKey", + "KeyID", + "KeyRequestAction", + "LocationInfo", + "LocationMessageEventContent", + "MediaInfo", + "MediaMessageEventContent", + "Membership", + "MemberStateEventContent", + "MessageEvent", + "MessageEventContent", + "MessageStatus", + "MessageStatusReason", + "MessageType", + "MessageUnsigned", + 
"OlmCiphertext", + "OlmMsgType", + "PowerLevelStateEventContent", + "PresenceEvent", + "PresenceEventContent", + "PresenceState", + "ReactionEvent", + "ReactionEventContent", + "ReceiptEvent", + "ReceiptEventContent", + "ReceiptType", + "RedactionEvent", + "RedactionEventContent", + "RelatesTo", + "RelationType", + "RequestedKeyInfo", + "RoomAvatarStateEventContent", + "RoomCreateStateEventContent", + "RoomEncryptionStateEventContent", + "RoomKeyEventContent", + "RoomKeyRequestEventContent", + "RoomKeyWithheldCode", + "RoomKeyWithheldEventContent", + "RoomNameStateEventContent", + "RoomPinnedEventsStateEventContent", + "RoomPredecessor", + "RoomTagAccountDataEventContent", + "RoomTagInfo", + "RoomTombstoneStateEventContent", + "RoomTopicStateEventContent", + "RoomType", + "SecretStorageDefaultKeyEventContent", + "SingleReceiptEventContent", + "SpaceChildStateEventContent", + "SpaceParentStateEventContent", + "StateEvent", + "StateEventContent", + "StateUnsigned", + "StrippedStateEvent", + "TextMessageEventContent", + "ThumbnailInfo", + "ToDeviceEvent", + "ToDeviceEventContent", + "TypingEvent", + "TypingEventContent", + "VideoInfo", + "EventFilter", + "Filter", + "RoomEventFilter", + "RoomFilter", + "StateFilter", + "IdentifierType", + "MatrixURI", + "MatrixURIError", + "URIAction", + "MediaCreateResponse", + "MediaRepoConfig", + "MXOpenGraph", + "OpenGraphAudio", + "OpenGraphImage", + "OpenGraphVideo", + "BatchSendResponse", + "BeeperBatchSendResponse", + "DeviceLists", + "DeviceOTKCount", + "DirectoryPaginationToken", + "EventContext", + "PaginatedMessages", + "PaginationDirection", + "RoomAliasInfo", + "RoomCreatePreset", + "RoomDirectoryResponse", + "RoomDirectoryVisibility", + "JSON", + "BatchID", + "ContentURI", + "DeviceID", + "EventID", + "FilterID", + "IdentityKey", + "RoomAlias", + "RoomID", + "SessionID", + "Signature", + "SigningKey", + "SyncToken", + "UserID", + "PushAction", + "PushActionDict", + "PushActionType", + "PushCondition", + 
"PushConditionKind", + "PushOperator", + "PushRule", + "PushRuleID", + "PushRuleKind", + "PushRuleScope", + "Member", + "User", + "UserSearchResults", + "ExtensibleEnum", + "Lst", + "Obj", + "Serializable", + "SerializableAttrs", + "SerializableEnum", + "SerializerError", + "deserializer", + "field", + "serializer", + "SpecVersions", + "Version", + "VersionFormat", + "VersionsResponse", +] diff --git a/mautrix/types/auth.py b/mautrix/types/auth.py index 847615ab..ad582118 100644 --- a/mautrix/types/auth.py +++ b/mautrix/types/auth.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -6,10 +6,9 @@ from typing import List, NewType, Optional, Union from attr import dataclass -import attr from .primitive import JSON, DeviceID, UserID -from .util import ExtensibleEnum, Obj, SerializableAttrs, deserializer +from .util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field class LoginType(ExtensibleEnum): @@ -17,7 +16,7 @@ class LoginType(ExtensibleEnum): A login type, as specified in the `POST /login endpoint`_ .. _POST /login endpoint: - https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login """ PASSWORD: "LoginType" = "m.login.password" @@ -26,7 +25,8 @@ class LoginType(ExtensibleEnum): APPSERVICE: "LoginType" = "m.login.application_service" UNSTABLE_JWT: "LoginType" = "org.matrix.login.jwt" - UNSTABLE_APPSERVICE: "LoginType" = "uk.half-shot.msc2778.login.application_service" + + DEVTURE_SHARED_SECRET: "LoginType" = "com.devture.shared_secret_auth" @dataclass @@ -35,7 +35,7 @@ class LoginFlow(SerializableAttrs): A login flow, as specified in the `GET /login endpoint`_ .. 
_GET /login endpoint: - https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-login + https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login """ type: LoginType @@ -60,7 +60,7 @@ class UserIdentifierType(ExtensibleEnum): A user identifier type, as specified in the `Identifier types`_ section of the login spec. .. _Identifier types: - https://matrix.org/docs/spec/client_server/latest#identifier-types + https://spec.matrix.org/v1.2/client-server-api/#identifier-types """ MATRIX_USER: "UserIdentifierType" = "m.id.user" @@ -89,9 +89,9 @@ class ThirdPartyIdentifier(SerializableAttrs): Appendix for a list of Third-party ID media. .. _/account/3pid: - https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-account-3pid + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3account3pid .. _3PID Types: - https://matrix.org/docs/spec/appendices.html#pid-types + https://spec.matrix.org/v1.2/appendices/#3pid-types """ medium: str @@ -109,7 +109,7 @@ class PhoneIdentifier(SerializableAttrs): identifier type with a ``medium`` of ``msisdn`` instead. .. 
_/account/3pid: - https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-account-3pid + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3account3pid """ country: str @@ -154,19 +154,24 @@ class DiscoveryIntegrationServer(SerializableAttrs): @dataclass class DiscoveryIntegrations(SerializableAttrs): - managers: List[DiscoveryIntegrationServer] = attr.ib(factory=lambda: []) + managers: List[DiscoveryIntegrationServer] = field(factory=lambda: []) @dataclass class DiscoveryInformation(SerializableAttrs): - homeserver: Optional[DiscoveryServer] = attr.ib( - metadata={"json": "m.homeserver"}, factory=DiscoveryServer - ) - identity_server: Optional[DiscoveryServer] = attr.ib( - metadata={"json": "m.identity_server"}, factory=DiscoveryServer + """ + .well-known discovery information, as specified in the `GET /.well-known/matrix/client endpoint`_ + + .. _GET /.well-known/matrix/client endpoint: + https://spec.matrix.org/v1.2/client-server-api/#getwell-knownmatrixclient + """ + + homeserver: Optional[DiscoveryServer] = field(json="m.homeserver", factory=DiscoveryServer) + identity_server: Optional[DiscoveryServer] = field( + json="m.identity_server", factory=DiscoveryServer ) - integrations: Optional[DiscoveryServer] = attr.ib( - metadata={"json": "m.integrations"}, factory=DiscoveryIntegrations + integrations: Optional[DiscoveryServer] = field( + json="m.integrations", factory=DiscoveryIntegrations ) @@ -176,13 +181,13 @@ class LoginResponse(SerializableAttrs): The response for a login request, as specified in the `POST /login endpoint`_ .. 
_POST /login endpoint: - https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login """ user_id: UserID device_id: DeviceID access_token: str - well_known: DiscoveryInformation = attr.ib(factory=DiscoveryInformation) + well_known: DiscoveryInformation = field(factory=DiscoveryInformation) @dataclass @@ -191,8 +196,9 @@ class WhoamiResponse(SerializableAttrs): The response for a whoami request, as specified in the `GET /account/whoami endpoint`_ .. _GET /account/whoami endpoint: - https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3accountwhoami + https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami """ user_id: UserID device_id: Optional[DeviceID] = None + is_guest: bool = False diff --git a/mautrix/types/crypto.py b/mautrix/types/crypto.py index f5d41db9..fe4ab742 100644 --- a/mautrix/types/crypto.py +++ b/mautrix/types/crypto.py @@ -1,15 +1,16 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, NamedTuple, Optional +from enum import IntEnum from attr import dataclass -from .event.encrypted import EncryptionAlgorithm, EncryptionKeyAlgorithm -from .primitive import DeviceID, IdentityKey, SigningKey, UserID -from .util import SerializableAttrs +from .event import EncryptionAlgorithm, EncryptionKeyAlgorithm, KeyID, ToDeviceEvent +from .primitive import DeviceID, IdentityKey, Signature, SigningKey, UserID +from .util import ExtensibleEnum, SerializableAttrs, field @dataclass @@ -22,8 +23,8 @@ class DeviceKeys(SerializableAttrs): user_id: UserID device_id: DeviceID algorithms: List[EncryptionAlgorithm] - keys: Dict[str, str] - signatures: Dict[UserID, Dict[str, str]] + keys: Dict[KeyID, str] + signatures: Dict[UserID, Dict[KeyID, Signature]] unsigned: UnsignedDeviceInfo = None def __attrs_post_init__(self) -> None: @@ -33,25 +34,138 @@ def __attrs_post_init__(self) -> None: @property def ed25519(self) -> Optional[SigningKey]: try: - return SigningKey(self.keys[f"{EncryptionKeyAlgorithm.ED25519}:{self.device_id}"]) + return SigningKey(self.keys[KeyID(EncryptionKeyAlgorithm.ED25519, self.device_id)]) except KeyError: return None @property def curve25519(self) -> Optional[IdentityKey]: try: - return IdentityKey(self.keys[f"{EncryptionKeyAlgorithm.CURVE25519}:{self.device_id}"]) + return IdentityKey(self.keys[KeyID(EncryptionKeyAlgorithm.CURVE25519, self.device_id)]) except KeyError: return None +class CrossSigningUsage(ExtensibleEnum): + MASTER: "CrossSigningUsage" = "master" + SELF: "CrossSigningUsage" = "self_signing" + USER: "CrossSigningUsage" = "user_signing" + + +@dataclass +class CrossSigningKeys(SerializableAttrs): + user_id: UserID + usage: List[CrossSigningUsage] + keys: Dict[KeyID, SigningKey] + signatures: Dict[UserID, Dict[KeyID, Signature]] = field(factory=lambda: {}) + + @property + def first_key(self) -> Optional[SigningKey]: + try: + return 
next(iter(self.keys.values())) + except StopIteration: + return None + + @property + def first_ed25519_key(self) -> Optional[SigningKey]: + return self.first_key_with_algorithm(EncryptionKeyAlgorithm.ED25519) + + def first_key_with_algorithm(self, alg: EncryptionKeyAlgorithm) -> Optional[SigningKey]: + if not self.keys: + return None + try: + return next(key for key_id, key in self.keys.items() if key_id.algorithm == alg) + except StopIteration: + return None + + @dataclass class QueryKeysResponse(SerializableAttrs): - failures: Dict[str, Any] - device_keys: Dict[UserID, Dict[DeviceID, DeviceKeys]] + device_keys: Dict[UserID, Dict[DeviceID, DeviceKeys]] = field(factory=lambda: {}) + master_keys: Dict[UserID, CrossSigningKeys] = field(factory=lambda: {}) + self_signing_keys: Dict[UserID, CrossSigningKeys] = field(factory=lambda: {}) + user_signing_keys: Dict[UserID, CrossSigningKeys] = field(factory=lambda: {}) + failures: Dict[str, Any] = field(factory=lambda: {}) @dataclass class ClaimKeysResponse(SerializableAttrs): - failures: Dict[str, Any] - one_time_keys: Dict[UserID, Dict[DeviceID, Dict[str, Any]]] + one_time_keys: Dict[UserID, Dict[DeviceID, Dict[KeyID, Any]]] + failures: Dict[str, Any] = field(factory=lambda: {}) + + +class TrustState(IntEnum): + BLACKLISTED = -100 + UNVERIFIED = 0 + UNKNOWN_DEVICE = 10 + FORWARDED = 20 + CROSS_SIGNED_UNTRUSTED = 50 + CROSS_SIGNED_TOFU = 100 + CROSS_SIGNED_TRUSTED = 200 + VERIFIED = 300 + + def __str__(self) -> str: + return _trust_state_to_name[self] + + @classmethod + def parse(cls, val: str) -> "TrustState": + try: + return _name_to_trust_state[val] + except KeyError as e: + raise ValueError(f"Invalid trust state {val!r}") from e + + +_trust_state_to_name: Dict[TrustState, str] = { + val: val.name.lower().replace("_", "-") for val in TrustState +} +_name_to_trust_state: Dict[str, TrustState] = { + value: key for key, value in _trust_state_to_name.items() +} + + +@dataclass +class DeviceIdentity: + user_id: UserID + 
device_id: DeviceID + identity_key: IdentityKey + signing_key: SigningKey + + trust: TrustState + deleted: bool + name: str + + +@dataclass +class OlmEventKeys(SerializableAttrs): + ed25519: SigningKey + + +@dataclass +class DecryptedOlmEvent(ToDeviceEvent, SerializableAttrs): + keys: OlmEventKeys + recipient: UserID + recipient_keys: OlmEventKeys + sender_device: Optional[DeviceID] = None + sender_key: IdentityKey = field(hidden=True, default=None) + + +class TOFUSigningKey(NamedTuple): + """ + A tuple representing a single cross-signing key. The first value is the current key, and the + second value is the first seen key. If the values don't match, it means the key is not valid + for trust-on-first-use. + """ + + key: SigningKey + first: SigningKey + + +class CrossSigner(NamedTuple): + """ + A tuple containing a user ID and a signing key they own. + + The key can either be a device-owned signing key, or one of the user's cross-signing keys. + """ + + user_id: UserID + key: SigningKey diff --git a/mautrix/types/event/__init__.py b/mautrix/types/event/__init__.py index f79eeb0f..db0658db 100644 --- a/mautrix/types/event/__init__.py +++ b/mautrix/types/event/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -6,10 +6,19 @@ from .account_data import ( AccountDataEvent, AccountDataEventContent, + DirectAccountDataEventContent, RoomTagAccountDataEventContent, RoomTagInfo, + SecretStorageDefaultKeyEventContent, ) from .base import BaseEvent, BaseRoomEvent, BaseUnsigned, GenericEvent +from .batch import BatchSendEvent, BatchSendStateEvent +from .beeper import ( + BeeperMessageStatusEvent, + BeeperMessageStatusEventContent, + MessageStatus, + MessageStatusReason, +) from .encrypted import ( EncryptedEvent, EncryptedEventContent, @@ -17,6 +26,7 @@ EncryptedOlmEventContent, EncryptionAlgorithm, EncryptionKeyAlgorithm, + KeyID, OlmCiphertext, OlmMsgType, ) @@ -37,10 +47,12 @@ AudioInfo, BaseFileInfo, BaseMessageEventContent, + BaseMessageEventContentFuncs, EncryptedFile, FileInfo, Format, ImageInfo, + InReplyTo, JSONWebKey, LocationInfo, LocationMessageEventContent, @@ -73,6 +85,7 @@ RoomPredecessor, RoomTombstoneStateEventContent, RoomTopicStateEventContent, + RoomType, SpaceChildStateEventContent, SpaceParentStateEventContent, StateEvent, @@ -81,6 +94,7 @@ StrippedStateEvent, ) from .to_device import ( + ASToDeviceEvent, ForwardedRoomKeyEventContent, KeyRequestAction, RequestedKeyInfo, diff --git a/mautrix/types/event/account_data.py b/mautrix/types/event/account_data.py index c4d3c8d8..fdfa0a30 100644 --- a/mautrix/types/event/account_data.py +++ b/mautrix/types/event/account_data.py @@ -1,9 +1,9 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from typing import Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Union from attr import dataclass import attr @@ -12,6 +12,9 @@ from ..util import Obj, SerializableAttrs, deserializer from .base import BaseEvent, EventType +if TYPE_CHECKING: + from mautrix.crypto.ssss import EncryptedAccountDataEventContent, KeyMetadata + @dataclass class RoomTagInfo(SerializableAttrs): @@ -23,12 +26,26 @@ class RoomTagAccountDataEventContent(SerializableAttrs): tags: Dict[str, RoomTagInfo] = attr.ib(default=None, metadata={"json": "tags"}) +@dataclass +class SecretStorageDefaultKeyEventContent(SerializableAttrs): + key: str + + DirectAccountDataEventContent = Dict[UserID, List[RoomID]] -AccountDataEventContent = Union[RoomTagAccountDataEventContent, DirectAccountDataEventContent, Obj] +AccountDataEventContent = Union[ + RoomTagAccountDataEventContent, + DirectAccountDataEventContent, + SecretStorageDefaultKeyEventContent, + "EncryptedAccountDataEventContent", + "KeyMetadata", + Obj, +] account_data_event_content_map = { EventType.TAG: RoomTagAccountDataEventContent, - EventType.DIRECT: DirectAccountDataEventContent, + EventType.SECRET_STORAGE_DEFAULT_KEY: SecretStorageDefaultKeyEventContent, + # m.direct doesn't really need deserializing + # EventType.DIRECT: DirectAccountDataEventContent, } @@ -42,10 +59,13 @@ class AccountDataEvent(BaseEvent, SerializableAttrs): @classmethod def deserialize(cls, data: JSON) -> "AccountDataEvent": try: - data.get("content", {})["__mautrix_event_type"] = EventType.find(data.get("type")) + evt_type = EventType.find(data.get("type")) + data.get("content", {})["__mautrix_event_type"] = evt_type except ValueError: return Obj(**data) - return super().deserialize(data) + evt = super().deserialize(data) + evt.type = evt_type + return evt @staticmethod @deserializer(AccountDataEventContent) diff --git a/mautrix/types/event/base.py b/mautrix/types/event/base.py index 84e2b166..5e71430b 100644 --- a/mautrix/types/event/base.py +++ 
b/mautrix/types/event/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/event/batch.py b/mautrix/types/event/batch.py new file mode 100644 index 00000000..b08ad21e --- /dev/null +++ b/mautrix/types/event/batch.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 Tulir Asokan, Sumner Evans +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Any, Optional + +from attr import dataclass +import attr + +from ..primitive import EventID, UserID +from ..util import SerializableAttrs +from .base import BaseEvent + + +@dataclass(kw_only=True) +class BatchSendEvent(BaseEvent, SerializableAttrs): + """Base event class for events sent via a batch send request.""" + + sender: UserID + timestamp: int = attr.ib(metadata={"json": "origin_server_ts"}) + content: Any + # N.B. Overriding event IDs is not allowed in standard room versions + event_id: Optional[EventID] = None + + +@dataclass(kw_only=True) +class BatchSendStateEvent(BatchSendEvent, SerializableAttrs): + """ + State events to be used as initial state events on batch send events. These never need to be + deserialized. + """ + + state_key: str diff --git a/mautrix/types/event/beeper.py b/mautrix/types/event/beeper.py new file mode 100644 index 00000000..0ec16479 --- /dev/null +++ b/mautrix/types/event/beeper.py @@ -0,0 +1,64 @@ +# Copyright (c) 2022 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from typing import Optional + +from attr import dataclass + +from ..primitive import EventID, RoomID, SessionID +from ..util import SerializableAttrs, SerializableEnum, field +from .base import BaseRoomEvent +from .message import RelatesTo + + +class MessageStatusReason(SerializableEnum): + GENERIC_ERROR = "m.event_not_handled" + UNSUPPORTED = "com.beeper.unsupported_event" + UNDECRYPTABLE = "com.beeper.undecryptable_event" + TOO_OLD = "m.event_too_old" + NETWORK_ERROR = "m.foreign_network_error" + NO_PERMISSION = "m.no_permission" + + @property + def checkpoint_status(self): + from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus + + if self == MessageStatusReason.UNSUPPORTED: + return MessageSendCheckpointStatus.UNSUPPORTED + elif self == MessageStatusReason.TOO_OLD: + return MessageSendCheckpointStatus.TIMEOUT + return MessageSendCheckpointStatus.PERM_FAILURE + + +class MessageStatus(SerializableEnum): + SUCCESS = "SUCCESS" + PENDING = "PENDING" + RETRIABLE = "FAIL_RETRIABLE" + FAIL = "FAIL_PERMANENT" + + +@dataclass(kw_only=True) +class BeeperMessageStatusEventContent(SerializableAttrs): + relates_to: RelatesTo = field(json="m.relates_to") + network: str = "" + status: Optional[MessageStatus] = None + + reason: Optional[MessageStatusReason] = None + error: Optional[str] = None + message: Optional[str] = None + + last_retry: Optional[EventID] = None + + +@dataclass +class BeeperMessageStatusEvent(BaseRoomEvent, SerializableAttrs): + content: BeeperMessageStatusEventContent + + +@dataclass +class BeeperRoomKeyAckEventContent(SerializableAttrs): + room_id: RoomID + session_id: SessionID + first_message_index: int diff --git a/mautrix/types/event/encrypted.py b/mautrix/types/event/encrypted.py index ad1580ab..735e481a 100644 --- a/mautrix/types/event/encrypted.py +++ b/mautrix/types/event/encrypted.py @@ -1,16 +1,16 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of 
the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Dict, NewType, Optional, Union from enum import IntEnum +import warnings from attr import dataclass -import attr -from ..primitive import JSON, DeviceID, IdentityKey, SessionID -from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer +from ..primitive import JSON, DeviceID, IdentityKey, SessionID, SigningKey +from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field from .base import BaseRoomEvent, BaseUnsigned from .message import RelatesTo @@ -26,6 +26,36 @@ class EncryptionKeyAlgorithm(ExtensibleEnum): SIGNED_CURVE25519: "EncryptionKeyAlgorithm" = "signed_curve25519" +@dataclass(frozen=True) +class KeyID(Serializable): + algorithm: EncryptionKeyAlgorithm + key_id: str + + def serialize(self) -> JSON: + return str(self) + + @classmethod + def deserialize(cls, raw: JSON) -> "KeyID": + assert isinstance(raw, str), "key IDs must be strings" + alg, key_id = raw.split(":", 1) + return cls(EncryptionKeyAlgorithm(alg), key_id) + + def __str__(self) -> str: + return f"{self.algorithm.value}:{self.key_id}" + + @classmethod + def ed25519(cls, key_id: SigningKey | DeviceID) -> "KeyID": + return cls(EncryptionKeyAlgorithm.ED25519, key_id) + + @classmethod + def curve25519(cls, key_id: IdentityKey) -> "KeyID": + return cls(EncryptionKeyAlgorithm.CURVE25519, key_id) + + @classmethod + def signed_curve25519(cls, key_id: IdentityKey) -> "KeyID": + return cls(EncryptionKeyAlgorithm.SIGNED_CURVE25519, key_id) + + class OlmMsgType(Serializable, IntEnum): PREKEY = 0 MESSAGE = 1 @@ -56,12 +86,37 @@ class EncryptedMegolmEventContent(SerializableAttrs): """The content of an m.room.encrypted event""" ciphertext: str - sender_key: IdentityKey - device_id: DeviceID session_id: SessionID - _relates_to: Optional[RelatesTo] = attr.ib(default=None, 
metadata={"json": "m.relates_to"}) algorithm: EncryptionAlgorithm = EncryptionAlgorithm.MEGOLM_V1 + _sender_key: Optional[IdentityKey] = field(default=None, json="sender_key") + _device_id: Optional[DeviceID] = field(default=None, json="device_id") + _relates_to: Optional[RelatesTo] = field(default=None, json="m.relates_to") + + @property + def sender_key(self) -> Optional[IdentityKey]: + """ + .. deprecated:: 0.17.0 + Matrix v1.3 deprecated the device_id and sender_key fields in megolm events. + """ + warnings.warn( + "The sender_key field in Megolm events was deprecated in Matrix 1.3", + DeprecationWarning, + ) + return self._sender_key + + @property + def device_id(self) -> Optional[DeviceID]: + """ + .. deprecated:: 0.17.0 + Matrix v1.3 deprecated the device_id and sender_key fields in megolm events. + """ + warnings.warn( + "The sender_key field in Megolm events was deprecated in Matrix 1.3", + DeprecationWarning, + ) + return self._device_id + @property def relates_to(self) -> RelatesTo: if self._relates_to is None: @@ -96,7 +151,7 @@ class EncryptedEvent(BaseRoomEvent, SerializableAttrs): """A m.room.encrypted event""" content: EncryptedEventContent - _unsigned: Optional[BaseUnsigned] = attr.ib(default=None, metadata={"json": "unsigned"}) + _unsigned: Optional[BaseUnsigned] = field(default=None, json="unsigned") @property def unsigned(self) -> BaseUnsigned: diff --git a/mautrix/types/event/ephemeral.py b/mautrix/types/event/ephemeral.py index 4ab43988..a1483a23 100644 --- a/mautrix/types/event/ephemeral.py +++ b/mautrix/types/event/ephemeral.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -8,7 +8,7 @@ from attr import dataclass from ..primitive import JSON, EventID, RoomID, UserID -from ..util import SerializableAttrs, SerializableEnum, deserializer +from ..util import ExtensibleEnum, SerializableAttrs, SerializableEnum, deserializer from .base import BaseEvent, GenericEvent from .type import EventType @@ -49,8 +49,9 @@ class SingleReceiptEventContent(SerializableAttrs): ts: int -class ReceiptType(SerializableEnum): +class ReceiptType(ExtensibleEnum): READ = "m.read" + READ_PRIVATE = "m.read.private" ReceiptEventContent = Dict[EventID, Dict[ReceiptType, Dict[UserID, SingleReceiptEventContent]]] @@ -69,13 +70,15 @@ class ReceiptEvent(BaseEvent, SerializableAttrs): def deserialize_ephemeral_event(data: JSON) -> EphemeralEvent: event_type = EventType.find(data.get("type", None)) if event_type == EventType.RECEIPT: - return ReceiptEvent.deserialize(data) + evt = ReceiptEvent.deserialize(data) elif event_type == EventType.TYPING: - return TypingEvent.deserialize(data) + evt = TypingEvent.deserialize(data) elif event_type == EventType.PRESENCE: - return PresenceEvent.deserialize(data) + evt = PresenceEvent.deserialize(data) else: - return GenericEvent.deserialize(data) + evt = GenericEvent.deserialize(data) + evt.type = event_type + return evt setattr(EphemeralEvent, "deserialize", deserialize_ephemeral_event) diff --git a/mautrix/types/event/generic.py b/mautrix/types/event/generic.py index d9a19df6..155aef90 100644 --- a/mautrix/types/event/generic.py +++ b/mautrix/types/event/generic.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -9,6 +9,7 @@ from ..util import Obj, deserializer from .account_data import AccountDataEvent, AccountDataEventContent from .base import EventType, GenericEvent +from .beeper import BeeperMessageStatusEvent, BeeperMessageStatusEventContent from .encrypted import EncryptedEvent, EncryptedEventContent from .ephemeral import ( EphemeralEvent, @@ -22,7 +23,7 @@ from .reaction import ReactionEvent, ReactionEventContent from .redaction import RedactionEvent, RedactionEventContent from .state import StateEvent, StateEventContent -from .to_device import ToDeviceEvent, ToDeviceEventContent +from .to_device import ASToDeviceEvent, ToDeviceEvent, ToDeviceEventContent from .voip import CallEvent, CallEventContent, type_to_class as voip_types Event = NewType( @@ -37,7 +38,9 @@ PresenceEvent, EncryptedEvent, ToDeviceEvent, + ASToDeviceEvent, CallEvent, + BeeperMessageStatusEvent, GenericEvent, ], ) @@ -53,6 +56,7 @@ EncryptedEventContent, ToDeviceEventContent, CallEventContent, + BeeperMessageStatusEventContent, Obj, ] @@ -81,6 +85,8 @@ def deserialize_event(data: JSON) -> Event: return AccountDataEvent.deserialize(data) elif event_type.is_ephemeral: return EphemeralEvent.deserialize(data) + elif event_type == EventType.BEEPER_MESSAGE_STATUS: + return BeeperMessageStatusEvent.deserialize(data) else: return GenericEvent.deserialize(data) diff --git a/mautrix/types/event/message.py b/mautrix/types/event/message.py index b0deaead..32033581 100644 --- a/mautrix/types/event/message.py +++ b/mautrix/types/event/message.py @@ -1,9 +1,9 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from typing import Any, Dict, List, Optional, Pattern, Union +from typing import Dict, List, Optional, Pattern, Union from html import escape import re @@ -11,7 +11,7 @@ import attr from ..primitive import JSON, ContentURI, EventID -from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field +from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field from .base import BaseRoomEvent, BaseUnsigned # region Message types @@ -54,90 +54,35 @@ def is_media(self) -> bool: # region Relations -class InReplyTo: - def __init__( - self, event_id: Optional[EventID] = None, proxy_target: Optional["RelatesTo"] = None - ) -> None: - self._event_id = event_id - self._proxy_target = proxy_target - - @property - def event_id(self) -> EventID: - if self._proxy_target: - return self._proxy_target.event_id - return self._event_id - - @event_id.setter - def event_id(self, event_id: EventID) -> None: - if self._proxy_target: - self._proxy_target.rel_type = RelationType.REPLY - self._proxy_target.event_id = event_id - else: - self._event_id = event_id +@dataclass +class InReplyTo(SerializableAttrs): + event_id: EventID class RelationType(ExtensibleEnum): ANNOTATION: "RelationType" = "m.annotation" REFERENCE: "RelationType" = "m.reference" REPLACE: "RelationType" = "m.replace" - REPLY: "RelationType" = "net.maunium.reply" + THREAD: "RelationType" = "m.thread" @dataclass -class RelatesTo(Serializable): +class RelatesTo(SerializableAttrs): """Message relations. 
Used for reactions, edits and replies.""" rel_type: RelationType = None event_id: Optional[EventID] = None key: Optional[str] = None - _extra: Dict[str, Any] = attr.ib(factory=lambda: {}) + is_falling_back: Optional[bool] = None + in_reply_to: Optional[InReplyTo] = field(default=None, json="m.in_reply_to") def __bool__(self) -> bool: - return bool(self.rel_type) and bool(self.event_id) - - @classmethod - def deserialize(cls, data: JSON) -> Optional["RelatesTo"]: - if not data: - return None - try: - return cls( - rel_type=RelationType.deserialize(data.pop("rel_type")), - event_id=data.pop("event_id", None), - key=data.pop("key", None), - extra=data, - ) - except KeyError: - pass - try: - return cls(rel_type=RelationType.REPLY, event_id=data["m.in_reply_to"]["event_id"]) - except KeyError: - pass - return None + return (bool(self.rel_type) and bool(self.event_id)) or bool(self.in_reply_to) def serialize(self) -> JSON: if not self: return attr.NOTHING - data = { - **self._extra, - "rel_type": self.rel_type.serialize(), - } - if self.rel_type == RelationType.REPLY: - data["m.in_reply_to"] = {"event_id": self.event_id} - if self.event_id: - data["event_id"] = self.event_id - if self.key: - data["key"] = self.key - return data - - def __setitem__(self, key: str, value: Any) -> None: - if key in ("rel_type", "event_id", "key"): - return setattr(self, key, value) - self._extra[key] = value - - def __getitem__(self, item: str) -> None: - if item in ("rel_type", "event_id", "key"): - return getattr(self, item) - return self._extra[item] + return super().serialize() # endregion @@ -152,12 +97,40 @@ class BaseMessageEventContentFuncs: _relates_to: Optional[RelatesTo] def set_reply(self, reply_to: Union[EventID, "MessageEvent"], **kwargs) -> None: - self.relates_to.rel_type = RelationType.REPLY - self.relates_to.event_id = reply_to if isinstance(reply_to, str) else reply_to.event_id + self.relates_to.in_reply_to = InReplyTo( + event_id=reply_to if isinstance(reply_to, str) 
else reply_to.event_id + ) + + def set_thread_parent( + self, + thread_parent: Union[EventID, "MessageEvent"], + last_event_in_thread: Union[EventID, "MessageEvent", None] = None, + disable_reply_fallback: bool = False, + **kwargs, + ) -> None: + self.relates_to.rel_type = RelationType.THREAD + self.relates_to.event_id = ( + thread_parent if isinstance(thread_parent, str) else thread_parent.event_id + ) + if isinstance(thread_parent, MessageEvent) and isinstance( + thread_parent.content, BaseMessageEventContentFuncs + ): + self.relates_to.event_id = ( + thread_parent.content.get_thread_parent() or self.relates_to.event_id + ) + if not disable_reply_fallback: + self.set_reply(last_event_in_thread or thread_parent, **kwargs) + self.relates_to.is_falling_back = True def set_edit(self, edits: Union[EventID, "MessageEvent"]) -> None: self.relates_to.rel_type = RelationType.REPLACE self.relates_to.event_id = edits if isinstance(edits, str) else edits.event_id + # Library consumers may create message content by setting a reply first, + # then later marking it as an edit. As edits can't change the reply, just remove + # the reply metadata when marking as a reply. 
+ if self.relates_to.in_reply_to: + self.relates_to.in_reply_to = None + self.relates_to.is_falling_back = None def serialize(self) -> JSON: data = SerializableAttrs.serialize(self) @@ -183,8 +156,8 @@ def relates_to(self, relates_to: RelatesTo) -> None: self._relates_to = relates_to def get_reply_to(self) -> Optional[EventID]: - if self._relates_to and self._relates_to.rel_type == RelationType.REPLY: - return self._relates_to.event_id + if self._relates_to and self._relates_to.in_reply_to: + return self._relates_to.in_reply_to.event_id return None def get_edit(self) -> Optional[EventID]: @@ -192,6 +165,11 @@ def get_edit(self) -> Optional[EventID]: return self._relates_to.event_id return None + def get_thread_parent(self) -> Optional[EventID]: + if self._relates_to and self._relates_to.rel_type == RelationType.THREAD: + return self._relates_to.event_id + return None + def trim_reply_fallback(self) -> None: pass @@ -293,40 +271,13 @@ class LocationInfo(SerializableAttrs): # region Event content -@dataclass -class MediaMessageEventContent(BaseMessageEventContent, SerializableAttrs): - """The content of a media message event (m.image, m.audio, m.video, m.file)""" - - url: Optional[ContentURI] = None - info: Optional[MediaInfo] = None - file: Optional[EncryptedFile] = None - - @staticmethod - @deserializer(MediaInfo) - @deserializer(Optional[MediaInfo]) - def deserialize_info(data: JSON) -> MediaInfo: - if not isinstance(data, dict): - return Obj() - msgtype = data.pop("__mautrix_msgtype", None) - if msgtype == "m.image" or msgtype == "m.sticker": - return ImageInfo.deserialize(data) - elif msgtype == "m.video": - return VideoInfo.deserialize(data) - elif msgtype == "m.audio": - return AudioInfo.deserialize(data) - elif msgtype == "m.file": - return FileInfo.deserialize(data) - else: - return Obj(**data) - - @dataclass class LocationMessageEventContent(BaseMessageEventContent, SerializableAttrs): geo_uri: str = None info: LocationInfo = None 
-html_reply_fallback_regex: Pattern = re.compile("^" r"[\s\S]+?") +html_reply_fallback_regex: Pattern = re.compile(r"^[\s\S]+") @dataclass @@ -336,20 +287,10 @@ class TextMessageEventContent(BaseMessageEventContent, SerializableAttrs): format: Format = None formatted_body: str = None - def set_reply( - self, reply_to: Union["MessageEvent", EventID], *, displayname: Optional[str] = None - ) -> None: - super().set_reply(reply_to) - if isinstance(reply_to, str): - return - if not self.formatted_body or len(self.formatted_body) == 0 or self.format != Format.HTML: + def ensure_has_html(self) -> None: + if not self.formatted_body or self.format != Format.HTML: self.format = Format.HTML self.formatted_body = escape(self.body).replace("\n", "
") - if isinstance(reply_to, MessageEvent): - self.formatted_body = ( - reply_to.make_reply_fallback_html(displayname) + self.formatted_body - ) - self.body = reply_to.make_reply_fallback_text(displayname) + self.body def formatted(self, format: Format) -> Optional[str]: if self.format == format: @@ -363,20 +304,48 @@ def trim_reply_fallback(self) -> None: setattr(self, "__reply_fallback_trimmed", True) def _trim_reply_fallback_text(self) -> None: - if not self.body.startswith("> ") or "\n" not in self.body: + if ( + not self.body.startswith("> <") and not self.body.startswith("> * <") + ) or "\n" not in self.body: return lines = self.body.split("\n") while len(lines) > 0 and lines[0].startswith("> "): lines.pop(0) - # Pop extra newline at end of fallback - lines.pop(0) - self.body = "\n".join(lines) + self.body = "\n".join(lines).strip() def _trim_reply_fallback_html(self) -> None: if self.formatted_body and self.format == Format.HTML: self.formatted_body = html_reply_fallback_regex.sub("", self.formatted_body) +@dataclass +class MediaMessageEventContent(TextMessageEventContent, SerializableAttrs): + """The content of a media message event (m.image, m.audio, m.video, m.file)""" + + url: Optional[ContentURI] = None + info: Optional[MediaInfo] = None + file: Optional[EncryptedFile] = None + filename: Optional[str] = None + + @staticmethod + @deserializer(MediaInfo) + @deserializer(Optional[MediaInfo]) + def deserialize_info(data: JSON) -> MediaInfo: + if not isinstance(data, dict): + return Obj() + msgtype = data.pop("__mautrix_msgtype", None) + if msgtype == "m.image" or msgtype == "m.sticker": + return ImageInfo.deserialize(data) + elif msgtype == "m.video": + return VideoInfo.deserialize(data) + elif msgtype == "m.audio": + return AudioInfo.deserialize(data) + elif msgtype == "m.file": + return FileInfo.deserialize(data) + else: + return Obj(**data) + + MessageEventContent = Union[ TextMessageEventContent, MediaMessageEventContent, LocationMessageEventContent, 
Obj ] @@ -436,36 +405,3 @@ def deserialize_content(data: JSON) -> MessageEventContent: return LocationMessageEventContent.deserialize(data) else: return Obj(**data) - - def make_reply_fallback_html(self, displayname: Optional[str] = None) -> str: - """Generate the HTML fallback for messages replying to this event.""" - if self.content.msgtype.is_text: - body = self.content.formatted_body or escape(self.content.body).replace("\n", "
") - else: - sent_type = media_reply_fallback_body_map[self.content.msgtype] or "a message" - body = f"sent {sent_type}" - displayname = escape(displayname) if displayname else self.sender - return html_reply_fallback_format.format( - room_id=self.room_id, - event_id=self.event_id, - sender=self.sender, - displayname=displayname, - content=body, - ) - - def make_reply_fallback_text(self, displayname: Optional[str] = None) -> str: - """Generate the plaintext fallback for messages replying to this event.""" - if self.content.msgtype.is_text: - body = self.content.body - else: - try: - body = media_reply_fallback_body_map[self.content.msgtype] - except KeyError: - body = "an unknown message type" - lines = body.strip().split("\n") - first_line, lines = lines[0], lines[1:] - fallback_text = f"> <{displayname or self.sender}> {first_line}" - for line in lines: - fallback_text += f"\n> {line}" - fallback_text += "\n\n" - return fallback_text diff --git a/mautrix/types/event/reaction.py b/mautrix/types/event/reaction.py index abebd579..437d4b31 100644 --- a/mautrix/types/event/reaction.py +++ b/mautrix/types/event/reaction.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/event/redaction.py b/mautrix/types/event/redaction.py index 5076e9e3..7714d8ba 100644 --- a/mautrix/types/event/redaction.py +++ b/mautrix/types/event/redaction.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/types/event/state.py b/mautrix/types/event/state.py index 5a66741f..5ffc855f 100644 --- a/mautrix/types/event/state.py +++ b/mautrix/types/event/state.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -9,12 +9,17 @@ import attr from ..primitive import JSON, ContentURI, EventID, RoomAlias, RoomID, UserID -from ..util import Obj, SerializableAttrs, SerializableEnum, deserializer, field +from ..util import ExtensibleEnum, Obj, SerializableAttrs, SerializableEnum, deserializer, field from .base import BaseRoomEvent, BaseUnsigned from .encrypted import EncryptionAlgorithm from .type import EventType, RoomType +@dataclass +class NotificationPowerLevels(SerializableAttrs): + room: int = 50 + + @dataclass class PowerLevelStateEventContent(SerializableAttrs): """The content of a power level event.""" @@ -27,6 +32,8 @@ class PowerLevelStateEventContent(SerializableAttrs): ) events_default: int = 0 + notifications: NotificationPowerLevels = attr.ib(factory=lambda: NotificationPowerLevels()) + state_default: int = 50 invite: int = 50 @@ -34,7 +41,19 @@ class PowerLevelStateEventContent(SerializableAttrs): ban: int = 50 redact: int = 50 - def get_user_level(self, user_id: UserID) -> int: + def get_user_level( + self, + user_id: UserID, + create: Optional["StateEvent"] = None, + ) -> int: + if ( + create + and create.content.supports_creator_power + and (user_id == create.sender or user_id in (create.content.additional_creators or [])) + ): + # This is really meant to be infinity, but involving floats would be annoying, + # so we use an integer larger than the maximum power level (2^53-1) instead. 
+ return 2**60 - 1 return int(self.users.get(user_id, self.users_default)) def set_user_level(self, user_id: UserID, level: int) -> None: @@ -43,7 +62,16 @@ def set_user_level(self, user_id: UserID, level: int) -> None: else: self.users[user_id] = level - def ensure_user_level(self, user_id: UserID, level: int) -> bool: + def ensure_user_level( + self, user_id: UserID, level: int, create: Optional["StateEvent"] = None + ) -> bool: + if ( + create + and create.content.supports_creator_power + and (user_id == create.sender or user_id in (create.content.additional_creators or [])) + ): + # Don't try to set creator power levels + return False if self.get_user_level(user_id) != level: self.set_user_level(user_id, level) return True @@ -74,7 +102,7 @@ class Membership(SerializableEnum): The membership state of a user in a room as specified in section `8.4 Room membership`_ of the spec. - .. _8.4 Room membership: https://spec.matrix.org/v1.1/client-server-api/#room-membership + .. _8.4 Room membership: https://spec.matrix.org/v1.2/client-server-api/#room-membership """ JOIN = "join" @@ -88,7 +116,7 @@ class Membership(SerializableEnum): class MemberStateEventContent(SerializableAttrs): """The content of a membership event. `Spec link`_ - .. _Spec link: https://spec.matrix.org/v1.1/client-server-api/#mroommember""" + .. _Spec link: https://spec.matrix.org/v1.2/client-server-api/#mroommember""" membership: Membership = Membership.LEAVE avatar_url: ContentURI = None @@ -109,7 +137,7 @@ class CanonicalAliasStateEventContent(SerializableAttrs): See also: `m.room.canonical_alias in the spec`_ - .. _m.room.canonical_alias in the spec: https://spec.matrix.org/v1.1/client-server-api/#mroomcanonical_alias + .. 
_m.room.canonical_alias in the spec: https://spec.matrix.org/v1.2/client-server-api/#mroomcanonical_alias """ canonical_alias: RoomAlias = attr.ib(default=None, metadata={"json": "alias"}) @@ -131,15 +159,15 @@ class RoomAvatarStateEventContent(SerializableAttrs): url: Optional[ContentURI] = None -class JoinRule(SerializableEnum): +class JoinRule(ExtensibleEnum): PUBLIC = "public" KNOCK = "knock" RESTRICTED = "restricted" INVITE = "invite" - PRIVATE = "private" + KNOCK_RESTRICTED = "knock_restricted" -class JoinRestrictionType(SerializableEnum): +class JoinRestrictionType(ExtensibleEnum): ROOM_MEMBERSHIP = "m.room_membership" @@ -185,6 +213,24 @@ class RoomCreateStateEventContent(SerializableAttrs): federate: bool = field(json="m.federate", omit_default=True, default=True) predecessor: Optional[RoomPredecessor] = None type: Optional[RoomType] = None + additional_creators: Optional[List[UserID]] = None + + @property + def supports_creator_power(self) -> bool: + return self.room_version not in ( + "", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + ) @dataclass @@ -250,7 +296,9 @@ def deserialize(cls, data: JSON) -> "StrippedStateEvent": try: event_type = EventType.find(data.get("type", None)) data.get("content", {})["__mautrix_event_type"] = event_type - data.get("unsigned", {}).get("prev_content", {})["__mautrix_event_type"] = event_type + (data.get("unsigned") or {}).get("prev_content", {})[ + "__mautrix_event_type" + ] = event_type except ValueError: pass return super().deserialize(data) @@ -297,12 +345,17 @@ def deserialize(cls, data: JSON) -> "StateEvent": try: event_type = EventType.find(data.get("type"), t_class=EventType.Class.STATE) data.get("content", {})["__mautrix_event_type"] = event_type - if "prev_content" in data and "prev_content" not in data.get("unsigned", {}): + if "prev_content" in data and "prev_content" not in (data.get("unsigned") or {}): + # This if is a workaround for Conduit being extremely dumb + if 
data.get("unsigned", {}) is None: + data["unsigned"] = {} data.setdefault("unsigned", {})["prev_content"] = data["prev_content"] data.get("unsigned", {}).get("prev_content", {})["__mautrix_event_type"] = event_type except ValueError: return Obj(**data) - return super().deserialize(data) + evt = super().deserialize(data) + evt.type = event_type + return evt @staticmethod @deserializer(StateEventContent) diff --git a/mautrix/types/event/to_device.py b/mautrix/types/event/to_device.py index f2f189d1..d6392177 100644 --- a/mautrix/types/event/to_device.py +++ b/mautrix/types/event/to_device.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -9,18 +9,21 @@ import attr from ..primitive import JSON, DeviceID, IdentityKey, RoomID, SessionID, SigningKey, UserID -from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer +from ..util import ExtensibleEnum, Obj, SerializableAttrs, deserializer, field from .base import BaseEvent, EventType +from .beeper import BeeperRoomKeyAckEventContent from .encrypted import EncryptedOlmEventContent, EncryptionAlgorithm class RoomKeyWithheldCode(ExtensibleEnum): BLACKLISTED: "RoomKeyWithheldCode" = "m.blacklisted" UNVERIFIED: "RoomKeyWithheldCode" = "m.unverified" - UNAUTHORIZED: "RoomKeyWithheldCode" = "m.unauthorized" + UNAUTHORIZED: "RoomKeyWithheldCode" = "m.unauthorised" UNAVAILABLE: "RoomKeyWithheldCode" = "m.unavailable" NO_OLM_SESSION: "RoomKeyWithheldCode" = "m.no_olm" + BEEPER_REDACTED: "RoomKeyWithheldCode" = "com.beeper.redacted" + @dataclass class RoomKeyWithheldEventContent(SerializableAttrs): @@ -39,6 +42,10 @@ class RoomKeyEventContent(SerializableAttrs): session_id: SessionID session_key: str + beeper_max_age_ms: Optional[int] = field(json="com.beeper.max_age_ms", default=None) + beeper_max_messages: Optional[int] = 
field(json="com.beeper.max_messages", default=None) + beeper_is_scheduled: Optional[bool] = field(json="com.beeper.is_scheduled", default=False) + class KeyRequestAction(ExtensibleEnum): REQUEST: "KeyRequestAction" = "request" @@ -61,7 +68,7 @@ class RoomKeyRequestEventContent(SerializableAttrs): body: Optional[RequestedKeyInfo] = None -@dataclass +@dataclass(kw_only=True) class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs): sender_key: IdentityKey signing_key: SigningKey = attr.ib(metadata={"json": "sender_claimed_ed25519_key"}) @@ -75,6 +82,7 @@ class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs): RoomKeyEventContent, RoomKeyRequestEventContent, ForwardedRoomKeyEventContent, + BeeperRoomKeyAckEventContent, ] to_device_event_content_map = { EventType.TO_DEVICE_ENCRYPTED: EncryptedOlmEventContent, @@ -82,12 +90,10 @@ class ForwardedRoomKeyEventContent(RoomKeyEventContent, SerializableAttrs): EventType.ROOM_KEY_REQUEST: RoomKeyRequestEventContent, EventType.ROOM_KEY: RoomKeyEventContent, EventType.FORWARDED_ROOM_KEY: ForwardedRoomKeyEventContent, + EventType.BEEPER_ROOM_KEY_ACK: BeeperRoomKeyAckEventContent, } -# TODO remaining account data event types - - @dataclass class ToDeviceEvent(BaseEvent, SerializableAttrs): sender: UserID @@ -100,7 +106,9 @@ def deserialize(cls, data: JSON) -> "ToDeviceEvent": data.setdefault("content", {})["__mautrix_event_type"] = evt_type except ValueError: return Obj(**data) - return super().deserialize(data) + evt = super().deserialize(data) + evt.type = evt_type + return evt @staticmethod @deserializer(ToDeviceEventContent) @@ -110,3 +118,9 @@ def deserialize_content(data: JSON) -> ToDeviceEventContent: if not content_type: return Obj(**data) return content_type.deserialize(data) + + +@dataclass +class ASToDeviceEvent(ToDeviceEvent, SerializableAttrs): + to_user_id: UserID + to_device_id: DeviceID diff --git a/mautrix/types/event/type.py b/mautrix/types/event/type.py index 
7ca8fd57..5faf3785 100644 --- a/mautrix/types/event/type.py +++ b/mautrix/types/event/type.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -195,6 +195,7 @@ def is_to_device(self) -> bool: "m.call.hangup": "CALL_HANGUP", "m.call.reject": "CALL_REJECT", "m.call.negotiate": "CALL_NEGOTIATE", + "com.beeper.message_send_status": "BEEPER_MESSAGE_STATUS", }, EventType.Class.EPHEMERAL: { "m.receipt": "RECEIPT", @@ -206,6 +207,11 @@ def is_to_device(self) -> bool: "m.push_rules": "PUSH_RULES", "m.tag": "TAG", "m.ignored_user_list": "IGNORED_USER_LIST", + "m.secret_storage.default_key": "SECRET_STORAGE_DEFAULT_KEY", + "m.cross_signing.master": "CROSS_SIGNING_MASTER", + "m.cross_signing.self_signing": "CROSS_SIGNING_SELF_SIGNING", + "m.cross_signing.user_signing": "CROSS_SIGNING_USER_SIGNING", + "m.megolm_backup.v1": "MEGOLM_BACKUP_V1", }, EventType.Class.TO_DEVICE: { "m.room.encrypted": "TO_DEVICE_ENCRYPTED", @@ -215,6 +221,7 @@ def is_to_device(self) -> bool: "m.room_key_request": "ROOM_KEY_REQUEST", "m.forwarded_room_key": "FORWARDED_ROOM_KEY", "m.dummy": "TO_DEVICE_DUMMY", + "com.beeper.room_key.ack": "BEEPER_ROOM_KEY_ACK", }, EventType.Class.UNKNOWN: { "__ALL__": "ALL", # This is not a real event type diff --git a/mautrix/types/event/type.pyi b/mautrix/types/event/type.pyi index 567ec5f1..a2788d6f 100644 --- a/mautrix/types/event/type.pyi +++ b/mautrix/types/event/type.pyi @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -18,6 +18,7 @@ class EventType(Serializable): ACCOUNT_DATA = "account_data" EPHEMERAL = "ephemeral" TO_DEVICE = "to_device" + _by_event_type: ClassVar[dict[str, EventType]] ROOM_CANONICAL_ALIAS: "EventType" @@ -50,6 +51,8 @@ class EventType(Serializable): CALL_REJECT: "EventType" CALL_NEGOTIATE: "EventType" + BEEPER_MESSAGE_STATUS: "EventType" + RECEIPT: "EventType" TYPING: "EventType" PRESENCE: "EventType" @@ -58,6 +61,11 @@ class EventType(Serializable): PUSH_RULES: "EventType" TAG: "EventType" IGNORED_USER_LIST: "EventType" + SECRET_STORAGE_DEFAULT_KEY: "EventType" + CROSS_SIGNING_MASTER: "EventType" + CROSS_SIGNING_SELF_SIGNING: "EventType" + CROSS_SIGNING_USER_SIGNING: "EventType" + MEGOLM_BACKUP_V1: "EventType" TO_DEVICE_ENCRYPTED: "EventType" TO_DEVICE_DUMMY: "EventType" @@ -66,6 +74,7 @@ class EventType(Serializable): ORG_MATRIX_ROOM_KEY_WITHHELD: "EventType" ROOM_KEY_REQUEST: "EventType" FORWARDED_ROOM_KEY: "EventType" + BEEPER_ROOM_KEY_ACK: "EventType" ALL: "EventType" diff --git a/mautrix/types/event/voip.py b/mautrix/types/event/voip.py index ce034d06..579f93a6 100644 --- a/mautrix/types/event/voip.py +++ b/mautrix/types/event/voip.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/filter.py b/mautrix/types/filter.py index c76f67de..6aa7e78d 100644 --- a/mautrix/types/filter.py +++ b/mautrix/types/filter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -17,7 +17,7 @@ class EventFormat(SerializableEnum): Federation event format enum, as specified in the `create filter endpoint`_. .. 
_create filter endpoint: - https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-user-userid-filter + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ CLIENT = "client" @@ -30,7 +30,7 @@ class EventFilter(SerializableAttrs): Event filter object, as specified in the `create filter endpoint`_. .. _create filter endpoint: - https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-user-userid-filter + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ limit: int = None @@ -59,7 +59,7 @@ class RoomEventFilter(EventFilter, SerializableAttrs): Room event filter object, as specified in the `create filter endpoint`_. .. _create filter endpoint: - https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-user-userid-filter + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ lazy_load_members: bool = False @@ -95,7 +95,7 @@ class StateFilter(RoomEventFilter, SerializableAttrs): same as :class:`RoomEventFilter`. .. _create filter endpoint: - https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-user-userid-filter + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ pass @@ -107,7 +107,7 @@ class RoomFilter(SerializableAttrs): Room filter object, as specified in the `create filter endpoint`_. .. _create filter endpoint: - https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-user-userid-filter + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ not_rooms: List[RoomID] = None @@ -144,7 +144,7 @@ class Filter(SerializableAttrs): Base filter object, as specified in the `create filter endpoint`_. .. 
_create filter endpoint: - https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-user-userid-filter + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter """ event_fields: List[str] = None diff --git a/mautrix/types/matrixuri.py b/mautrix/types/matrixuri.py index 3fb35eb7..a8c62da7 100644 --- a/mautrix/types/matrixuri.py +++ b/mautrix/types/matrixuri.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/matrixuri_test.py b/mautrix/types/matrixuri_test.py index 5762c48a..611ead0e 100644 --- a/mautrix/types/matrixuri_test.py +++ b/mautrix/types/matrixuri_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/media.py b/mautrix/types/media.py index a7661064..1d0ee66a 100644 --- a/mautrix/types/media.py +++ b/mautrix/types/media.py @@ -1,62 +1,74 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Optional + from attr import dataclass -import attr from .primitive import ContentURI -from .util import SerializableAttrs +from .util import SerializableAttrs, field @dataclass class MediaRepoConfig(SerializableAttrs): """ - Matrix media repo config. See `GET /_matrix/media/r0/config`_. + Matrix media repo config. See `GET /_matrix/media/v3/config`_. - .. _GET /_matrix/media/r0/config: - https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-media-r0-config + .. 
_GET /_matrix/media/v3/config: + https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3config """ - upload_size: int = attr.ib(metadata={"json": "m.upload.size"}) + upload_size: int = field(default=50 * 1024 * 1024, json="m.upload.size") @dataclass class OpenGraphImage(SerializableAttrs): - url: ContentURI = attr.ib(default=None, metadata={"json": "og:image"}) - mimetype: str = attr.ib(default=None, metadata={"json": "og:image:type"}) - height: int = attr.ib(default=None, metadata={"json": "og:image:width"}) - width: int = attr.ib(default=None, metadata={"json": "og:image:height"}) - size: int = attr.ib(default=None, metadata={"json": "matrix:image:size"}) + url: ContentURI = field(default=None, json="og:image") + mimetype: str = field(default=None, json="og:image:type") + height: int = field(default=None, json="og:image:height") + width: int = field(default=None, json="og:image:width") + size: int = field(default=None, json="matrix:image:size") @dataclass class OpenGraphVideo(SerializableAttrs): - url: ContentURI = attr.ib(default=None, metadata={"json": "og:video"}) - mimetype: str = attr.ib(default=None, metadata={"json": "og:video:type"}) - height: int = attr.ib(default=None, metadata={"json": "og:video:width"}) - width: int = attr.ib(default=None, metadata={"json": "og:video:height"}) - size: int = attr.ib(default=None, metadata={"json": "matrix:video:size"}) + url: ContentURI = field(default=None, json="og:video") + mimetype: str = field(default=None, json="og:video:type") + height: int = field(default=None, json="og:video:height") + width: int = field(default=None, json="og:video:width") + size: int = field(default=None, json="matrix:video:size") @dataclass class OpenGraphAudio(SerializableAttrs): - url: ContentURI = attr.ib(default=None, metadata={"json": "og:audio"}) - mimetype: str = attr.ib(default=None, metadata={"json": "og:audio:type"}) + url: ContentURI = field(default=None, json="og:audio") + mimetype: str = field(default=None,
json="og:audio:type") @dataclass class MXOpenGraph(SerializableAttrs): """ - Matrix URL preview response. See `GET /_matrix/media/r0/preview_url`_. + Matrix URL preview response. See `GET /_matrix/media/v3/preview_url`_. + + .. _GET /_matrix/media/v3/preview_url: + https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url + """ + + title: str = field(default=None, json="og:title") + description: str = field(default=None, json="og:description") + image: OpenGraphImage = field(default=None, flatten=True) + video: OpenGraphVideo = field(default=None, flatten=True) + audio: OpenGraphAudio = field(default=None, flatten=True) - .. _GET /_matrix/media/r0/preview_url: - https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-media-r0-preview-url + +@dataclass +class MediaCreateResponse(SerializableAttrs): + """ + Matrix media create response including MSC3870 """ - title: str = attr.ib(default=None, metadata={"json": "og:title"}) - description: str = attr.ib(default=None, metadata={"json": "og:description"}) - image: OpenGraphImage = attr.ib(default=None, metadata={"flatten": True}) - video: OpenGraphVideo = attr.ib(default=None, metadata={"flatten": True}) - audio: OpenGraphAudio = attr.ib(default=None, metadata={"flatten": True}) + content_uri: ContentURI + unused_expired_at: Optional[int] = None + unstable_upload_url: Optional[str] = field(default=None, json="com.beeper.msc3870.upload_url") diff --git a/mautrix/types/misc.py b/mautrix/types/misc.py index 61dc3a54..5a07699c 100644 --- a/mautrix/types/misc.py +++ b/mautrix/types/misc.py @@ -1,15 +1,15 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from typing import Dict, List, NamedTuple, NewType +from typing import List, NamedTuple, NewType, Optional from enum import Enum from attr import dataclass import attr -from .event import Event +from .event import Event, StateEvent from .primitive import BatchID, ContentURI, EventID, RoomAlias, RoomID, SyncToken, UserID from .util import SerializableAttrs @@ -19,11 +19,14 @@ class DeviceLists(SerializableAttrs): changed: List[UserID] = attr.ib(factory=lambda: []) left: List[UserID] = attr.ib(factory=lambda: []) + def __bool__(self) -> bool: + return bool(self.changed or self.left) + @dataclass class DeviceOTKCount(SerializableAttrs): - curve25519: int - signed_curve25519: int + signed_curve25519: int = 0 + curve25519: int = 0 class RoomCreatePreset(Enum): @@ -31,7 +34,7 @@ class RoomCreatePreset(Enum): Room creation preset, as specified in the `createRoom endpoint`_ .. _createRoom endpoint: - https://spec.matrix.org/v1.1/client-server-api/#post_matrixclientv3createroom + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom """ PRIVATE = "private_chat" @@ -44,7 +47,7 @@ class RoomDirectoryVisibility(Enum): Room directory visibility, as specified in the `createRoom endpoint`_ .. _createRoom endpoint: - https://spec.matrix.org/v1.1/client-server-api/#post_matrixclientv3createroom + https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom """ PRIVATE = "private" @@ -64,7 +67,7 @@ class RoomAliasInfo(SerializableAttrs): Room alias query result, as specified in the `alias resolve endpoint`_ .. 
_alias resolve endpoint: - https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3directoryroomroomalias + https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3directoryroomroomalias """ room_id: RoomID = None @@ -84,7 +87,7 @@ class PublicRoomInfo(SerializableAttrs): num_joined_members: int world_readable: bool - guests_can_join: bool + guest_can_join: bool name: str = None topic: str = None @@ -103,14 +106,18 @@ class RoomDirectoryResponse(SerializableAttrs): PaginatedMessages = NamedTuple( - "PaginatedMessages", start=SyncToken, end=SyncToken, events=List[Event] + "PaginatedMessages", start=SyncToken, end=Optional[SyncToken], events=List[Event] ) @dataclass -class VersionsResponse(SerializableAttrs): - versions: List[str] - unstable_features: Dict[str, bool] = attr.ib(factory=lambda: {}) +class EventContext(SerializableAttrs): + end: SyncToken + start: SyncToken + event: Event + events_after: List[Event] + events_before: List[Event] + state: List[StateEvent] @dataclass @@ -120,6 +127,10 @@ class BatchSendResponse(SerializableAttrs): insertion_event_id: EventID batch_event_id: EventID - base_insertion_event_id: EventID - next_batch_id: BatchID + base_insertion_event_id: Optional[EventID] = None + + +@dataclass +class BeeperBatchSendResponse(SerializableAttrs): + event_ids: List[EventID] diff --git a/mautrix/types/primitive.py b/mautrix/types/primitive.py index 7c92b0ee..dec7dc87 100644 --- a/mautrix/types/primitive.py +++ b/mautrix/types/primitive.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -32,7 +32,7 @@ A Matrix `content URI`_, used by the content repository. .. 
_content URI: - https://spec.matrix.org/v1.1/client-server-api/#matrix-content-mxc-uris + https://spec.matrix.org/v1.2/client-server-api/#matrix-content-mxc-uris """ SyncToken = NewType("SyncToken", str) @@ -53,3 +53,5 @@ SigningKey.__doc__ = "A ed25519 public key as unpadded base64" IdentityKey = NewType("IdentityKey", str) IdentityKey.__doc__ = "A curve25519 public key as unpadded base64" +Signature = NewType("Signature", str) +Signature.__doc__ = "An ed25519 signature as unpadded base64" diff --git a/mautrix/types/push_rules.py b/mautrix/types/push_rules.py index a36c9663..12875770 100644 --- a/mautrix/types/push_rules.py +++ b/mautrix/types/push_rules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/users.py b/mautrix/types/users.py index 69176101..aacd612e 100644 --- a/mautrix/types/users.py +++ b/mautrix/types/users.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/util/enum.py b/mautrix/types/util/enum.py index 072dd927..2ee7cab4 100644 --- a/mautrix/types/util/enum.py +++ b/mautrix/types/util/enum.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/types/util/enum_test.py b/mautrix/types/util/enum_test.py index 68495970..38917cf9 100644 --- a/mautrix/types/util/enum_test.py +++ b/mautrix/types/util/enum_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/util/serializable.py b/mautrix/types/util/serializable.py index 760df3cb..9bdd7d6b 100644 --- a/mautrix/types/util/serializable.py +++ b/mautrix/types/util/serializable.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/util/serializable_attrs.py b/mautrix/types/util/serializable_attrs.py index 49f6f619..1172133c 100644 --- a/mautrix/types/util/serializable_attrs.py +++ b/mautrix/types/util/serializable_attrs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/types/util/serializable_attrs_test.py b/mautrix/types/util/serializable_attrs_test.py index 25fef77d..d311cd5d 100644 --- a/mautrix/types/util/serializable_attrs_test.py +++ b/mautrix/types/util/serializable_attrs_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/types/versions.py b/mautrix/types/versions.py new file mode 100644 index 00000000..52a62f59 --- /dev/null +++ b/mautrix/types/versions.py @@ -0,0 +1,144 @@ +# Copyright (c) 2022 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from typing import Dict, List, NamedTuple, Optional, Union +from enum import IntEnum +import re + +from attr import dataclass +import attr + +from . import JSON +from .util import Serializable, SerializableAttrs + + +class VersionFormat(IntEnum): + UNKNOWN = -1 + LEGACY = 0 + MODERN = 1 + + def __repr__(self) -> str: + return f"VersionFormat.{self.name}" + + +legacy_version_regex = re.compile(r"^r(\d+)\.(\d+)\.(\d+)$") +modern_version_regex = re.compile(r"^v(\d+)\.(\d+)$") + + +@attr.dataclass(frozen=True) +class Version(Serializable): + format: VersionFormat + major: int + minor: int + patch: int + raw: str + + def __str__(self) -> str: + if self.format == VersionFormat.MODERN: + return f"v{self.major}.{self.minor}" + elif self.format == VersionFormat.LEGACY: + return f"r{self.major}.{self.minor}.{self.patch}" + else: + return self.raw + + def serialize(self) -> JSON: + return str(self) + + @classmethod + def deserialize(cls, raw: JSON) -> "Version": + assert isinstance(raw, str), "versions must be strings" + if modern := modern_version_regex.fullmatch(raw): + major, minor = modern.groups() + return Version(VersionFormat.MODERN, int(major), int(minor), 0, raw) + elif legacy := legacy_version_regex.fullmatch(raw): + major, minor, patch = legacy.groups() + return Version(VersionFormat.LEGACY, int(major), int(minor), int(patch), raw) + else: + return Version(VersionFormat.UNKNOWN, 0, 0, 0, raw) + + +class SpecVersions: + R010 = Version.deserialize("r0.1.0") + R020 = Version.deserialize("r0.2.0") + R030 = 
Version.deserialize("r0.3.0") + R040 = Version.deserialize("r0.4.0") + R050 = Version.deserialize("r0.5.0") + R060 = Version.deserialize("r0.6.0") + R061 = Version.deserialize("r0.6.1") + V11 = Version.deserialize("v1.1") + V12 = Version.deserialize("v1.2") + V13 = Version.deserialize("v1.3") + V14 = Version.deserialize("v1.4") + V15 = Version.deserialize("v1.5") + V16 = Version.deserialize("v1.6") + V17 = Version.deserialize("v1.7") + V18 = Version.deserialize("v1.8") + V19 = Version.deserialize("v1.9") + V110 = Version.deserialize("v1.10") + V111 = Version.deserialize("v1.11") + + +@dataclass +class VersionsResponse(SerializableAttrs): + versions: List[Version] + unstable_features: Dict[str, bool] = attr.ib(factory=lambda: {}) + + def supports(self, thing: Union[Version, str]) -> Optional[bool]: + """ + Check if the versions response contains the given spec version or unstable feature. + + Args: + thing: The spec version (as a :class:`Version` or string) + or unstable feature name (as a string) to check. + + Returns: + ``True`` if the exact version or unstable feature is supported, + ``False`` if it's not supported, + ``None`` for unstable features which are not included in the response at all. + """ + if isinstance(thing, Version): + return thing in self.versions + elif (parsed_version := Version.deserialize(thing)).format != VersionFormat.UNKNOWN: + return parsed_version in self.versions + return self.unstable_features.get(thing) + + def supports_at_least(self, version: Union[Version, str]) -> bool: + """ + Check if the versions response contains the given spec version or any higher version. + + Args: + version: The spec version as a :class:`Version` or a string. + + Returns: + ``True`` if a version equal to or higher than the given version is found, + ``False`` otherwise. 
+ """ + if isinstance(version, str): + version = Version.deserialize(version) + return any(v for v in self.versions if v >= version) + + @property + def latest_version(self) -> Version: + return max(self.versions) + + @property + def has_legacy_versions(self) -> bool: + """ + Check if the response contains any legacy (r0.x.y) versions. + + .. deprecated:: 0.16.10 + :meth:`supports_at_least` and :meth:`supports` methods are now preferred. + """ + return any(v for v in self.versions if v.format == VersionFormat.LEGACY) + + @property + def has_modern_versions(self) -> bool: + """ + Check if the response contains any modern (v1.1 or higher) versions. + + .. deprecated:: 0.16.10 + :meth:`supports_at_least` and :meth:`supports` methods are now preferred. + """ + return self.supports_at_least(SpecVersions.V11) diff --git a/mautrix/util/__init__.py b/mautrix/util/__init__.py index 6a7827d1..fd349bef 100644 --- a/mautrix/util/__init__.py +++ b/mautrix/util/__init__.py @@ -1,22 +1,28 @@ __all__ = [ + # Directory modules + "async_db", + "config", + "db", "formatter", "logging", - "config", - "signed_token", - "simple_template", - "manhole", - "markdown", - "simple_lock", + # File modules + "async_body", + "async_getter_lock", + "background_task", + "bridge_state", + "color_log", + "ffmpeg", "file_store", - "program", - "async_db", - "db", - "opt_prometheus", + "format_duration", "magic", - "bridge_state", + "manhole", + "markdown", "message_send_checkpoint", - "variation_selector", - "format_duration", - "ffmpeg", + "opt_prometheus", + "program", + "signed_token", + "simple_lock", + "simple_template", "utf16_surrogate", + "variation_selector", ] diff --git a/mautrix/util/async_body.py b/mautrix/util/async_body.py new file mode 100644 index 00000000..4db4d1e5 --- /dev/null +++ b/mautrix/util/async_body.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0.
If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from __future__ import annotations + +from typing import AsyncGenerator, Union +import logging + +import aiohttp + +AsyncBody = AsyncGenerator[Union[bytes, bytearray, memoryview], None] + + +async def async_iter_bytes(data: bytearray | bytes, chunk_size: int = 1024**2) -> AsyncBody: + """ + Return memory views into a byte array in chunks. This is used to prevent aiohttp from copying + the entire request body. + + Args: + data: The underlying data to iterate through. + chunk_size: How big each returned chunk should be. + + Returns: + An async generator that yields the given data in chunks. + """ + with memoryview(data) as mv: + for i in range(0, len(data), chunk_size): + yield mv[i : i + chunk_size] + + +class FileTooLargeError(Exception): + def __init__(self, max_size: int) -> None: + super().__init__(f"File size larger than maximum ({max_size / 1024 / 1024} MiB)") + + +_default_dl_log = logging.getLogger("mau.util.download") + + +async def read_response_chunks( + resp: aiohttp.ClientResponse, max_size: int, log: logging.Logger = _default_dl_log +) -> bytearray: + """ + Read the body from an aiohttp response in chunks into a mutable bytearray. + + Args: + resp: The aiohttp response object to read the body from. + max_size: The maximum size to read. FileTooLargeError will be raised if the Content-Length + is higher than this, or if the body exceeds this size during reading. + log: A logger for logging download status. + + Returns: + The body data as a byte array. + + Raises: + FileTooLargeError: if the body is larger than the provided max_size. 
+ """ + content_length = int(resp.headers.get("Content-Length", "0")) + if 0 < max_size < content_length: + raise FileTooLargeError(max_size) + size_str = "unknown length" if content_length == 0 else f"{content_length} bytes" + log.info(f"Reading file download response with {size_str} (max: {max_size})") + data = bytearray(content_length) + mv = memoryview(data) if content_length > 0 else None + read_size = 0 + remaining = max_size + 1 + while True: + block = await resp.content.readany() + if not block: + break + remaining -= len(block) + if remaining <= 0: + raise FileTooLargeError(max_size) + if len(data) >= read_size + len(block): + mv[read_size : read_size + len(block)] = block + elif len(data) > read_size: + log.warning("File being downloaded is bigger than expected") + mv[read_size:] = block[: len(data) - read_size] + mv.release() + mv = None + data.extend(block[len(data) - read_size :]) + else: + if mv is not None: + mv.release() + mv = None + data.extend(block) + read_size += len(block) + if mv is not None: + mv.release() + log.info(f"Successfully read {read_size} bytes of file download response") + return data + + +__all__ = ["AsyncBody", "FileTooLargeError", "async_iter_bytes", "read_response_chunks"] diff --git a/mautrix/util/async_db/__init__.py b/mautrix/util/async_db/__init__.py index 0fa484bb..500ffe0e 100644 --- a/mautrix/util/async_db/__init__.py +++ b/mautrix/util/async_db/__init__.py @@ -2,6 +2,12 @@ from .connection import LoggingConnection as Connection from .database import Database +from .errors import ( + DatabaseException, + DatabaseNotOwned, + ForeignTablesFound, + UnsupportedDatabaseVersion, +) from .scheme import Scheme from .upgrade import UpgradeTable, register_upgrade @@ -13,11 +19,14 @@ PostgresDatabase = None try: + from aiosqlite import Cursor as SQLiteCursor + from .aiosqlite import SQLiteDatabase except ImportError: if __optional_imports__: raise SQLiteDatabase = None + SQLiteCursor = None __all__ = [ "Database", @@ -25,6 +34,11 @@
"register_upgrade", "PostgresDatabase", "SQLiteDatabase", + "SQLiteCursor", "Connection", "Scheme", + "DatabaseException", + "DatabaseNotOwned", + "UnsupportedDatabaseVersion", + "ForeignTablesFound", ] diff --git a/mautrix/util/async_db/aiosqlite.py b/mautrix/util/async_db/aiosqlite.py index c0e5a5ad..934379a8 100644 --- a/mautrix/util/async_db/aiosqlite.py +++ b/mautrix/util/async_db/aiosqlite.py @@ -1,14 +1,16 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Any +from typing import Any, AsyncContextManager from contextlib import asynccontextmanager +from contextvars import ContextVar import asyncio import logging +import os import re import sqlite3 @@ -23,6 +25,9 @@ POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)") +in_transaction = ContextVar("in_transaction", default=False) + + class TxnConnection(aiosqlite.Connection): def __init__(self, path: str, **kwargs) -> None: def connector() -> sqlite3.Connection: @@ -34,7 +39,11 @@ def connector() -> sqlite3.Connection: @asynccontextmanager async def transaction(self) -> None: + if in_transaction.get(): + yield + return await self.execute("BEGIN TRANSACTION") + token = in_transaction.set(True) try: yield except Exception: @@ -42,17 +51,23 @@ async def transaction(self) -> None: raise else: await self.commit() + finally: + in_transaction.reset(token) def __execute(self, query: str, *args: Any): query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query) return super().execute(query, args) - async def execute(self, query: str, *args: Any, timeout: float | None = None) -> None: - await self.__execute(query, *args) + async def execute( + self, query: str, *args: Any, timeout: float | None = None + ) -> aiosqlite.Cursor: + return await 
self.__execute(query, *args) - async def executemany(self, query: str, *args: Any, timeout: float | None = None) -> None: + async def executemany( + self, query: str, *args: Any, timeout: float | None = None + ) -> aiosqlite.Cursor: query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query) - await super().executemany(query, *args) + return await super().executemany(query, *args) async def fetch( self, query: str, *args: Any, timeout: float | None = None @@ -60,7 +75,9 @@ async def fetch( async with self.__execute(query, *args) as cursor: return list(await cursor.fetchall()) - async def fetchrow(self, query: str, *args: Any, timeout: float | None = None) -> sqlite3.Row: + async def fetchrow( + self, query: str, *args: Any, timeout: float | None = None + ) -> sqlite3.Row | None: async with self.__execute(query, *args) as cursor: return await cursor.fetchone() @@ -75,9 +92,11 @@ async def fetchval( class SQLiteDatabase(Database): scheme = Scheme.SQLITE + _parent: SQLiteDatabase | None _pool: asyncio.Queue[TxnConnection] _stopped: bool _conns: int + _init_commands: list[str] def __init__( self, @@ -85,34 +104,100 @@ def __init__( upgrade_table: UpgradeTable, db_args: dict[str, Any] | None = None, log: logging.Logger | None = None, + owner_name: str | None = None, + ignore_foreign_tables: bool = True, ) -> None: - super().__init__(url, db_args=db_args, upgrade_table=upgrade_table, log=log) + super().__init__( + url, + db_args=db_args, + upgrade_table=upgrade_table, + log=log, + owner_name=owner_name, + ignore_foreign_tables=ignore_foreign_tables, + ) + self._parent = None self._path = url.path - if self._path.startswith("/"): - self._path = self._path[1:] self._pool = asyncio.Queue(self._db_args.pop("min_size", 1)) self._db_args.pop("max_size", None) self._stopped = False self._conns = 0 + self._init_commands = self._add_missing_pragmas(self._db_args.pop("init_commands", [])) + + @staticmethod + def _add_missing_pragmas(init_commands: list[str]) -> list[str]: + 
has_foreign_keys = False + has_journal_mode = False + has_synchronous = False + has_busy_timeout = False + for cmd in init_commands: + if "PRAGMA" not in cmd: + continue + if "foreign_keys" in cmd: + has_foreign_keys = True + elif "journal_mode" in cmd: + has_journal_mode = True + elif "synchronous" in cmd: + has_synchronous = True + elif "busy_timeout" in cmd: + has_busy_timeout = True + if not has_foreign_keys: + init_commands.append("PRAGMA foreign_keys = ON") + if not has_journal_mode: + init_commands.append("PRAGMA journal_mode = WAL") + if not has_synchronous and "PRAGMA journal_mode = WAL" in init_commands: + init_commands.append("PRAGMA synchronous = NORMAL") + if not has_busy_timeout: + init_commands.append("PRAGMA busy_timeout = 5000") + return init_commands + + def override_pool(self, db: Database) -> None: + assert isinstance(db, SQLiteDatabase) + self._parent = db async def start(self) -> None: + if self._parent: + await super().start() + return + if self._conns: + raise RuntimeError("database pool has already been started") + elif self._stopped: + raise RuntimeError("database pool can't be restarted") self.log.debug(f"Connecting to {self.url}") + self.log.debug(f"Database connection init commands: {self._init_commands}") + if os.path.exists(self._path): + if not os.access(self._path, os.W_OK): + self.log.warning("Database file doesn't seem writable") + elif not os.access(os.path.dirname(os.path.abspath(self._path)), os.W_OK): + self.log.warning("Database file doesn't exist and directory doesn't seem writable") for _ in range(self._pool.maxsize): conn = await TxnConnection(self._path, **self._db_args) + if self._init_commands: + cur = await conn.cursor() + for command in self._init_commands: + self.log.trace("Executing init command: %s", command) + await cur.execute(command) + await conn.commit() conn.row_factory = sqlite3.Row self._pool.put_nowait(conn) self._conns += 1 await super().start() async def stop(self) -> None: + if self._parent: + return 
self._stopped = True while self._conns > 0: conn = await self._pool.get() self._conns -= 1 await conn.close() + def acquire_direct(self) -> AsyncContextManager[LoggingConnection]: + if self._parent: + return self._parent.acquire() + return self._acquire() + @asynccontextmanager - async def acquire(self) -> LoggingConnection: + async def _acquire(self) -> LoggingConnection: if self._stopped: raise RuntimeError("database pool has been stopped") conn = await self._pool.get() diff --git a/mautrix/util/async_db/asyncpg.py b/mautrix/util/async_db/asyncpg.py index aad9ef45..97b49f6c 100644 --- a/mautrix/util/async_db/asyncpg.py +++ b/mautrix/util/async_db/asyncpg.py @@ -9,6 +9,8 @@ from contextlib import asynccontextmanager import asyncio import logging +import sys +import traceback from yarl import URL import asyncpg @@ -23,6 +25,7 @@ class PostgresDatabase(Database): scheme = Scheme.POSTGRES _pool: asyncpg.pool.Pool | None _pool_override: bool + _exit_on_ice: bool def __init__( self, @@ -30,12 +33,25 @@ def __init__( upgrade_table: UpgradeTable, db_args: dict[str, Any] = None, log: logging.Logger | None = None, + owner_name: str | None = None, + ignore_foreign_tables: bool = True, ) -> None: if url.scheme in ("cockroach", "cockroachdb"): self.scheme = Scheme.COCKROACH # Send postgres scheme to asyncpg url = url.with_scheme("postgres") - super().__init__(url, db_args=db_args, upgrade_table=upgrade_table, log=log) + self._exit_on_ice = True + if db_args: + self._exit_on_ice = db_args.pop("meow_exit_on_ice", True) + db_args.pop("init_commands", None) + super().__init__( + url, + db_args=db_args, + upgrade_table=upgrade_table, + log=log, + owner_name=owner_name, + ignore_foreign_tables=ignore_foreign_tables, + ) self._pool = None self._pool_override = False @@ -45,8 +61,13 @@ def override_pool(self, db: PostgresDatabase) -> None: async def start(self) -> None: if not self._pool_override: + if self._pool: + raise RuntimeError("Database has already been started") 
self._db_args["loop"] = asyncio.get_running_loop() - self.log.debug(f"Connecting to {self.url}") + log_url = self.url + if log_url.password: + log_url = log_url.with_password("password-redacted") + self.log.debug(f"Connecting to {log_url}") self._pool = await asyncpg.create_pool(str(self.url), **self._db_args) await super().start() @@ -57,13 +78,29 @@ def pool(self) -> asyncpg.pool.Pool: return self._pool async def stop(self) -> None: - if not self._pool_override: - await self.pool.close() + if not self._pool_override and self._pool is not None: + await self._pool.close() + + async def _handle_exception(self, err: Exception) -> None: + if self._exit_on_ice and isinstance(err, asyncpg.InternalClientError): + pre_stack = traceback.format_stack()[:-2] + post_stack = traceback.format_exception(err) + header = post_stack[0] + post_stack = post_stack[1:] + self.log.critical( + "Got asyncpg internal client error, exiting...\n%s%s%s", + header, + "".join(pre_stack), + "".join(post_stack), + ) + sys.exit(26) @asynccontextmanager - async def acquire(self) -> LoggingConnection: + async def acquire_direct(self) -> LoggingConnection: async with self.pool.acquire() as conn: - yield LoggingConnection(self.scheme, conn, self.log) + yield LoggingConnection( + self.scheme, conn, self.log, handle_exception=self._handle_exception + ) Database.schemes["postgres"] = PostgresDatabase diff --git a/mautrix/util/async_db/connection.py b/mautrix/util/async_db/connection.py index 6d726e5d..1d27c9f0 100644 --- a/mautrix/util/async_db/connection.py +++ b/mautrix/util/async_db/connection.py @@ -5,7 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import Any, Callable, TypeVar +from typing import Any, Awaitable, Callable, TypeVar from contextlib import asynccontextmanager from logging import WARNING import functools @@ -19,6 +19,7 @@ if __optional_imports__: from sqlite3 import Row + from aiosqlite import Cursor from asyncpg import Record import asyncpg @@ -42,20 +43,22 @@ async def wrapper(self: LoggingConnection, arg: str, *args: Any, **kwargs: str) return wrapper -class LoggingConnection: - scheme: Scheme - wrapped: aiosqlite.TxnConnection | asyncpg.Connection - log: TraceLogger +async def handle_exception_noop(_: Exception) -> None: + pass + +class LoggingConnection: def __init__( self, scheme: Scheme, wrapped: aiosqlite.TxnConnection | asyncpg.Connection, log: TraceLogger, + handle_exception: Callable[[Exception], Awaitable[None]] = handle_exception_noop, ) -> None: self.scheme = scheme self.wrapped = wrapped self.log = log + self._handle_exception = handle_exception self._inited = True def __setattr__(self, key: str, value: Any) -> None: @@ -65,32 +68,72 @@ def __setattr__(self, key: str, value: Any) -> None: @asynccontextmanager async def transaction(self) -> None: - async with self.wrapped.transaction(): - yield + try: + async with self.wrapped.transaction(): + yield + except Exception as e: + await self._handle_exception(e) + raise @log_duration - async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: - return await self.wrapped.execute(query, *args, timeout=timeout) + async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor: + try: + return await self.wrapped.execute(query, *args, timeout=timeout) + except Exception as e: + await self._handle_exception(e) + raise @log_duration - async def executemany(self, query: str, *args: Any, timeout: float | None = None) -> str: - return await self.wrapped.executemany(query, *args, timeout=timeout) + async def executemany( + self, query: str, 
*args: Any, timeout: float | None = None + ) -> str | Cursor: + try: + return await self.wrapped.executemany(query, *args, timeout=timeout) + except Exception as e: + await self._handle_exception(e) + raise @log_duration async def fetch( self, query: str, *args: Any, timeout: float | None = None ) -> list[Row | Record]: - return await self.wrapped.fetch(query, *args, timeout=timeout) + try: + return await self.wrapped.fetch(query, *args, timeout=timeout) + except Exception as e: + await self._handle_exception(e) + raise @log_duration async def fetchval( self, query: str, *args: Any, column: int = 0, timeout: float | None = None ) -> Any: - return await self.wrapped.fetchval(query, *args, column=column, timeout=timeout) + try: + return await self.wrapped.fetchval(query, *args, column=column, timeout=timeout) + except Exception as e: + await self._handle_exception(e) + raise @log_duration - async def fetchrow(self, query: str, *args: Any, timeout: float | None = None) -> Row | Record: - return await self.wrapped.fetchrow(query, *args, timeout=timeout) + async def fetchrow( + self, query: str, *args: Any, timeout: float | None = None + ) -> Row | Record | None: + try: + return await self.wrapped.fetchrow(query, *args, timeout=timeout) + except Exception as e: + await self._handle_exception(e) + raise + + async def table_exists(self, name: str) -> bool: + if self.scheme == Scheme.SQLITE: + return await self.fetchval( + "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name=?1)", name + ) + elif self.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): + return await self.fetchval( + "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)", name + ) + else: + raise RuntimeError(f"Unknown scheme {self.scheme}") @log_duration async def copy_records_to_table( @@ -104,6 +147,14 @@ async def copy_records_to_table( ) -> None: if self.scheme != Scheme.POSTGRES: raise RuntimeError("copy_records_to_table is only supported on Postgres") - return 
await self.wrapped.copy_records_to_table( - table_name, records=records, columns=columns, schema_name=schema_name, timeout=timeout - ) + try: + return await self.wrapped.copy_records_to_table( + table_name, + records=records, + columns=columns, + schema_name=schema_name, + timeout=timeout, + ) + except Exception as e: + await self._handle_exception(e) + raise diff --git a/mautrix/util/async_db/connection.pyi b/mautrix/util/async_db/connection.pyi index 2d967cef..27b86be7 100644 --- a/mautrix/util/async_db/connection.pyi +++ b/mautrix/util/async_db/connection.pyi @@ -3,7 +3,7 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from typing import Any, AsyncContextManager +from typing import Any, AsyncContextManager, Awaitable, Callable from sqlite3 import Row from asyncpg import Record @@ -17,12 +17,14 @@ from .scheme import Scheme class LoggingConnection: scheme: Scheme wrapped: aiosqlite.TxnConnection | asyncpg.Connection + _handle_exception: Callable[[Exception], Awaitable[None]] log: TraceLogger def __init__( self, scheme: Scheme, wrapped: aiosqlite.TxnConnection | asyncpg.Connection, log: TraceLogger, + handle_exception: Callable[[Exception], Awaitable[None]] = None, ) -> None: ... async def transaction(self) -> AsyncContextManager[None]: ... async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: ... @@ -35,7 +37,8 @@ class LoggingConnection: ) -> Any: ... async def fetchrow( self, query: str, *args: Any, timeout: float | None = None - ) -> Row | Record: ... + ) -> Row | Record | None: ... + async def table_exists(self, name: str) -> bool: ... 
async def copy_records_to_table( self, table_name: str, diff --git a/mautrix/util/async_db/database.py b/mautrix/util/async_db/database.py index 420f093d..b5128b74 100644 --- a/mautrix/util/async_db/database.py +++ b/mautrix/util/async_db/database.py @@ -7,8 +7,9 @@ from typing import Any, AsyncContextManager, Type from abc import ABC, abstractmethod +from contextlib import asynccontextmanager +from contextvars import ContextVar import logging -import sys from yarl import URL @@ -16,12 +17,16 @@ from mautrix.util.logging import TraceLogger from .connection import LoggingConnection +from .errors import DatabaseNotOwned, ForeignTablesFound from .scheme import Scheme from .upgrade import UpgradeTable, upgrade_tables if __optional_imports__: + from aiosqlite import Cursor from asyncpg import Record +conn_var: ContextVar[LoggingConnection | None] = ContextVar("db_connection", default=None) + class Database(ABC): schemes: dict[str, Type[Database]] = {} @@ -30,18 +35,24 @@ class Database(ABC): scheme: Scheme url: URL _db_args: dict[str, Any] - upgrade_table: UpgradeTable + upgrade_table: UpgradeTable | None + owner_name: str | None + ignore_foreign_tables: bool def __init__( self, url: URL, - upgrade_table: UpgradeTable, + upgrade_table: UpgradeTable | None, db_args: dict[str, Any] | None = None, log: TraceLogger | None = None, + owner_name: str | None = None, + ignore_foreign_tables: bool = True, ) -> None: self.url = url self._db_args = {**db_args} if db_args else {} self.upgrade_table = upgrade_table + self.owner_name = owner_name + self.ignore_foreign_tables = ignore_foreign_tables self.log = log or logging.getLogger("mau.db") assert isinstance(self.log, TraceLogger) @@ -53,6 +64,8 @@ def create( db_args: dict[str, Any] | None = None, upgrade_table: UpgradeTable | str | None = None, log: logging.Logger | TraceLogger | None = None, + owner_name: str | None = None, + ignore_foreign_tables: bool = True, ) -> Database: url = URL(url) try: @@ -75,31 +88,75 @@ def create( 
upgrade_table = UpgradeTable() elif not isinstance(upgrade_table, UpgradeTable): raise ValueError(f"Can't use {type(upgrade_table)} as the upgrade table") - return impl(url, db_args=db_args, upgrade_table=upgrade_table, log=log) + return impl( + url, + db_args=db_args, + upgrade_table=upgrade_table, + log=log, + owner_name=owner_name, + ignore_foreign_tables=ignore_foreign_tables, + ) def override_pool(self, db: Database) -> None: pass async def start(self) -> None: - try: + if not self.ignore_foreign_tables: + await self._check_foreign_tables() + if self.owner_name: + await self._check_owner() + if self.upgrade_table and len(self.upgrade_table.upgrades) > 0: await self.upgrade_table.upgrade(self) - except Exception: - self.log.critical("Failed to upgrade database", exc_info=True) - sys.exit(25) + + async def _check_foreign_tables(self) -> None: + if await self.table_exists("state_groups_state"): + raise ForeignTablesFound("found state_groups_state likely belonging to Synapse") + elif await self.table_exists("roomserver_rooms"): + raise ForeignTablesFound("found roomserver_rooms likely belonging to Dendrite") + + async def _check_owner(self) -> None: + await self.execute(""" + CREATE TABLE IF NOT EXISTS database_owner ( + key INTEGER PRIMARY KEY DEFAULT 0, + owner TEXT NOT NULL + ) + """) + owner = await self.fetchval("SELECT owner FROM database_owner WHERE key=0") + if not owner: + await self.execute( + "INSERT INTO database_owner (key, owner) VALUES (0, $1)", self.owner_name + ) + elif owner != self.owner_name: + raise DatabaseNotOwned(owner) @abstractmethod async def stop(self) -> None: pass @abstractmethod - def acquire(self) -> AsyncContextManager[LoggingConnection]: + def acquire_direct(self) -> AsyncContextManager[LoggingConnection]: pass - async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: + @asynccontextmanager + async def acquire(self) -> LoggingConnection: + conn = conn_var.get(None) + if conn is not None: + yield conn 
+ return + async with self.acquire_direct() as conn: + token = conn_var.set(conn) + try: + yield conn + finally: + conn_var.reset(token) + + async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor: async with self.acquire() as conn: return await conn.execute(query, *args, timeout=timeout) - async def executemany(self, query: str, *args: Any, timeout: float | None = None) -> str: + async def executemany( + self, query: str, *args: Any, timeout: float | None = None + ) -> str | Cursor: async with self.acquire() as conn: return await conn.executemany(query, *args, timeout=timeout) @@ -113,6 +170,12 @@ async def fetchval( async with self.acquire() as conn: return await conn.fetchval(query, *args, column=column, timeout=timeout) - async def fetchrow(self, query: str, *args: Any, timeout: float | None = None) -> Record: + async def fetchrow( + self, query: str, *args: Any, timeout: float | None = None + ) -> Record | None: async with self.acquire() as conn: return await conn.fetchrow(query, *args, timeout=timeout) + + async def table_exists(self, name: str) -> bool: + async with self.acquire() as conn: + return await conn.table_exists(name) diff --git a/mautrix/util/async_db/errors.py b/mautrix/util/async_db/errors.py new file mode 100644 index 00000000..082d3b8e --- /dev/null +++ b/mautrix/util/async_db/errors.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from __future__ import annotations + + +class DatabaseException(RuntimeError): + pass + + @property + def explanation(self) -> str | None: + return None + + +class UnsupportedDatabaseVersion(DatabaseException): + def __init__(self, name: str, version: int, latest: int) -> None: + super().__init__( + f"Unsupported {name} schema version v{version} (latest known is v{latest})" + ) + + @property + def explanation(self) -> str: + return "Downgrading is not supported" + + +class ForeignTablesFound(DatabaseException): + def __init__(self, explanation: str) -> None: + super().__init__(f"The database contains foreign tables ({explanation})") + + @property + def explanation(self) -> str: + return "You can use --ignore-foreign-tables to ignore this error" + + +class DatabaseNotOwned(DatabaseException): + def __init__(self, owner: str) -> None: + super().__init__(f"The database is owned by {owner}") + + @property + def explanation(self) -> str: + return "Sharing the same database with different programs is not supported" diff --git a/mautrix/util/async_db/upgrade.py b/mautrix/util/async_db/upgrade.py index 0b1dd438..c084d28b 100644 --- a/mautrix/util/async_db/upgrade.py +++ b/mautrix/util/async_db/upgrade.py @@ -14,17 +14,14 @@ from .. 
import async_db from .connection import LoggingConnection +from .errors import UnsupportedDatabaseVersion from .scheme import Scheme Upgrade = Callable[[LoggingConnection, Scheme], Awaitable[Optional[int]]] UpgradeWithoutScheme = Callable[[LoggingConnection], Awaitable[Optional[int]]] -class UnsupportedDatabaseVersion(Exception): - pass - - -async def noop_upgrade(_: LoggingConnection) -> None: +async def noop_upgrade(_: LoggingConnection, _2: Scheme) -> None: pass @@ -64,17 +61,18 @@ def __init__( def register( self, + _outer_fn: Upgrade | UpgradeWithoutScheme | None = None, + *, index: int = -1, description: str = "", - _outer_fn: Upgrade | None = None, transaction: bool = True, upgrades_to: int | Upgrade | None = None, - ) -> Upgrade | Callable[[Upgrade], Upgrade] | None: + ) -> Upgrade | Callable[[Upgrade | UpgradeWithoutScheme], Upgrade]: if isinstance(index, str): description = index index = -1 - def actually_register(fn: Upgrade) -> Upgrade: + def actually_register(fn: Upgrade | UpgradeWithoutScheme) -> Upgrade: fn = _wrap_upgrade(fn) fn.__mau_db_upgrade_description__ = description fn.__mau_db_upgrade_transaction__ = transaction @@ -99,23 +97,22 @@ async def _save_version(self, conn: LoggingConnection, version: int) -> None: await conn.execute(f"INSERT INTO {self.version_table_name} (version) VALUES ($1)", version) async def upgrade(self, db: async_db.Database) -> None: - await db.execute( - f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} ( + await db.execute(f""" + CREATE TABLE IF NOT EXISTS {self.version_table_name} ( version INTEGER PRIMARY KEY - )""" - ) + ) + """) row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1") version = row["version"] if row else 0 if len(self.upgrades) < version: - error = ( - f"Unsupported database version v{version} " - f"(latest known is v{len(self.upgrades) - 1})" + unsupported_version_error = UnsupportedDatabaseVersion( + self.database_name, version, len(self.upgrades) ) if not 
self.allow_unsupported: - raise UnsupportedDatabaseVersion(error) + raise unsupported_version_error else: - self.log.warning(error) + self.log.warning(str(unsupported_version_error)) return elif len(self.upgrades) == version: self.log.debug(f"Database at v{version}, not upgrading") @@ -125,7 +122,9 @@ async def upgrade(self, db: async_db.Database) -> None: while version < len(self.upgrades): old_version = version upgrade = self.upgrades[version] - new_version = getattr(upgrade, "__mau_db_upgrade_destination__", version + 1) + new_version = ( + getattr(upgrade, "__mau_db_upgrade_destination__", None) or version + 1 + ) if callable(new_version): new_version = await new_version(conn, db.scheme) desc = getattr(upgrade, "__mau_db_upgrade_description__", None) @@ -179,6 +178,6 @@ def _find_upgrade_table(fn: Upgrade) -> UpgradeTable: def register_upgrade(index: int = -1, description: str = "") -> Callable[[Upgrade], Upgrade]: def actually_register(fn: Upgrade) -> Upgrade: - return _find_upgrade_table(fn).register(index, description, fn) + return _find_upgrade_table(fn).register(fn, index=index, description=description) return actually_register diff --git a/mautrix/util/async_getter_lock.py b/mautrix/util/async_getter_lock.py new file mode 100644 index 00000000..31766976 --- /dev/null +++ b/mautrix/util/async_getter_lock.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+from __future__ import annotations + +from typing import Any +import functools + +from mautrix import __optional_imports__ + +if __optional_imports__: + from typing import Awaitable, Callable, ParamSpec + + Param = ParamSpec("Param") + Func = Callable[Param, Awaitable[Any]] + + +def async_getter_lock(fn: Func) -> Func: + """ + A utility decorator for locking async getters that have caches + (preventing race conditions between cache check and e.g. async database actions). + + The class must have an ``_async_get_locks`` defaultdict that contains :class:`asyncio.Lock`s + (see example for exact definition). Non-cache-affecting arguments should be only passed as + keyword args. + + Args: + fn: The function to decorate. + + Returns: + The decorated function. + + Examples: + >>> import asyncio + >>> from collections import defaultdict + >>> class User: + ... _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) + ... db: Any + ... cache: dict[str, User] + ... @classmethod + ... @async_getter_lock + ... async def get(cls, id: str, *, create: bool = False) -> User | None: + ... try: + ... return cls.cache[id] + ... except KeyError: + ... pass + ... user = await cls.db.fetch_user(id) + ... if user: + ... return user + ... elif create: + ... return await cls.db.create_user(id) + ... return None + """ + + @functools.wraps(fn) + async def wrapper(cls, *args, **kwargs) -> Any: + async with cls._async_get_locks[args]: + return await fn(cls, *args, **kwargs) + + return wrapper diff --git a/mautrix/util/background_task.py b/mautrix/util/background_task.py new file mode 100644 index 00000000..e22e74f1 --- /dev/null +++ b/mautrix/util/background_task.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 Tulir Asokan +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from __future__ import annotations + +from typing import Coroutine +import asyncio +import logging + +_tasks = set() +log = logging.getLogger("mau.background_task") + + +async def catch(coro: Coroutine, caller: str) -> None: + try: + await coro + except Exception: + log.exception(f"Uncaught error in background task (created in {caller})") + + +# Logger.findCaller finds the 3rd stack frame, so add an intermediate function +# to get the caller of create(). +def _find_caller() -> tuple[str, int, str, None]: + return log.findCaller() + + +def create(coro: Coroutine, *, name: str | None = None, catch_errors: bool = True) -> asyncio.Task: + """ + Create a background asyncio task safely, ensuring a reference is kept until the task completes. + It also catches and logs uncaught errors (unless disabled via the parameter). + + Args: + coro: The coroutine to wrap in a task and execute. + name: An optional name for the created task. + catch_errors: Should the task be wrapped in a try-except block to log any uncaught errors? + + Returns: + An asyncio Task object wrapping the given coroutine. + """ + if catch_errors: + try: + file_name, line_number, function_name, _ = _find_caller() + caller = f"{function_name} at {file_name}:{line_number}" + except ValueError: + caller = "unknown function" + task = asyncio.create_task(catch(coro, caller), name=name) + else: + task = asyncio.create_task(coro, name=name) + _tasks.add(task) + task.add_done_callback(_tasks.discard) + return task diff --git a/mautrix/util/bridge_state.py b/mautrix/util/bridge_state.py index 3d25b305..d28448bf 100644 --- a/mautrix/util/bridge_state.py +++ b/mautrix/util/bridge_state.py @@ -1,9 +1,9 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from typing import ClassVar, Dict, Optional +from typing import Any, ClassVar, Dict, Optional import logging import time @@ -62,8 +62,8 @@ class BridgeStateEvent(SerializableEnum): class BridgeState(SerializableAttrs): human_readable_errors: ClassVar[Dict[Optional[str], str]] = {} default_source: ClassVar[str] = "bridge" - default_error_ttl: ClassVar[int] = 60 - default_ok_ttl: ClassVar[int] = 240 + default_error_ttl: ClassVar[int] = 3600 + default_ok_ttl: ClassVar[int] = 21600 state_event: BridgeStateEvent user_id: Optional[UserID] = None @@ -74,6 +74,8 @@ class BridgeState(SerializableAttrs): source: Optional[str] = None error: Optional[str] = None message: Optional[str] = None + info: Optional[Dict[str, Any]] = None + reason: Optional[str] = None send_attempts_: int = field(default=0, hidden=True) @@ -100,11 +102,12 @@ def should_deduplicate(self, prev_state: Optional["BridgeState"]) -> bool: not prev_state or prev_state.state_event != self.state_event or prev_state.error != self.error + or prev_state.info != self.info ): # If there's no previous state or the state was different, send this one. 
return False - # If there's more than ⅘ of the previous pong's time-to-live left, drop this one - return prev_state.timestamp + (prev_state.ttl / 5) > self.timestamp + # If the previous state is recent, drop this one + return prev_state.timestamp + prev_state.ttl > self.timestamp async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool = True) -> bool: if not url: @@ -112,9 +115,10 @@ async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool = self.send_attempts_ += 1 headers = {"Authorization": f"Bearer {token}", "User-Agent": HTTPAPI.default_ua} try: - async with aiohttp.ClientSession() as sess, sess.post( - url, json=self.serialize(), headers=headers - ) as resp: + async with ( + aiohttp.ClientSession() as sess, + sess.post(url, json=self.serialize(), headers=headers) as resp, + ): if not 200 <= resp.status < 300: text = await resp.text() text = text.replace("\n", "\\n") @@ -133,5 +137,5 @@ async def send(self, url: str, token: str, log: logging.Logger, log_sent: bool = @dataclass(kw_only=True) class GlobalBridgeState(SerializableAttrs): - remote_states: Optional[Dict[str, BridgeState]] = field(json="remoteState") + remote_states: Optional[Dict[str, BridgeState]] = field(json="remoteState", default=None) bridge_state: BridgeState = field(json="bridgeState") diff --git a/mautrix/util/config/__init__.py b/mautrix/util/config/__init__.py index 87ee3262..89322556 100644 --- a/mautrix/util/config/__init__.py +++ b/mautrix/util/config/__init__.py @@ -4,3 +4,18 @@ from .recursive_dict import RecursiveDict from .string import BaseStringConfig from .validation import BaseValidatableConfig, ConfigValueError, ForbiddenDefault, ForbiddenKey + +__all__ = [ + "BaseConfig", + "BaseMissingError", + "ConfigUpdateHelper", + "BaseFileConfig", + "yaml", + "BaseProxyConfig", + "RecursiveDict", + "BaseStringConfig", + "BaseValidatableConfig", + "ConfigValueError", + "ForbiddenDefault", + "ForbiddenKey", +] diff --git 
a/mautrix/util/config/base.py b/mautrix/util/config/base.py index 7d14ae0e..46bb3678 100644 --- a/mautrix/util/config/base.py +++ b/mautrix/util/config/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/config/file.py b/mautrix/util/config/file.py index ce8c5c66..5911af3c 100644 --- a/mautrix/util/config/file.py +++ b/mautrix/util/config/file.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/config/proxy.py b/mautrix/util/config/proxy.py index 06b6acdc..5b9324b0 100644 --- a/mautrix/util/config/proxy.py +++ b/mautrix/util/config/proxy.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -30,7 +30,7 @@ def __init__( def load(self) -> None: self._data = self._load_proxy() or CommentedMap() - def load_base(self) -> Optional[RecursiveDict[CommentedMap]]: + def load_base(self) -> RecursiveDict[CommentedMap] | None: return self._load_base_proxy() def save(self) -> None: diff --git a/mautrix/util/config/recursive_dict.py b/mautrix/util/config/recursive_dict.py index cf6d0330..2d3c2b69 100644 --- a/mautrix/util/config/recursive_dict.py +++ b/mautrix/util/config/recursive_dict.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/util/config/string.py b/mautrix/util/config/string.py index 1967700f..45a2ce77 100644 --- a/mautrix/util/config/string.py +++ b/mautrix/util/config/string.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/config/validation.py b/mautrix/util/config/validation.py index 6ab36432..6b863b27 100644 --- a/mautrix/util/config/validation.py +++ b/mautrix/util/config/validation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/db/__init__.py b/mautrix/util/db/__init__.py index 15ade9f4..b13bcfd6 100644 --- a/mautrix/util/db/__init__.py +++ b/mautrix/util/db/__init__.py @@ -1 +1,3 @@ from .base import Base, BaseClass + +__all__ = ["Base", "BaseClass"] diff --git a/mautrix/util/db/base.py b/mautrix/util/db/base.py index 4acf7906..b9baeb21 100644 --- a/mautrix/util/db/base.py +++ b/mautrix/util/db/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -23,6 +23,9 @@ class BaseClass: """ Base class for SQLAlchemy models. Provides SQLAlchemy declarative base features and some additional utilities. + + .. deprecated:: 0.15.0 + The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy. """ __tablename__: str @@ -237,4 +240,9 @@ def __iter__(self): @as_declarative() class Base(BaseClass): + """ + .. deprecated:: 0.15.0 + The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy. 
+ """ + pass diff --git a/mautrix/util/ffmpeg.py b/mautrix/util/ffmpeg.py index 2a415390..41cebf32 100644 --- a/mautrix/util/ffmpeg.py +++ b/mautrix/util/ffmpeg.py @@ -1,13 +1,15 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Iterable +from typing import Any, Iterable from pathlib import Path import asyncio +import json +import logging import mimetypes import os import shutil @@ -34,7 +36,89 @@ def __init__(self) -> None: ffmpeg_path = _abswhich("ffmpeg") -ffmpeg_default_params = ("-hide_banner", "-loglevel", "warning") +ffmpeg_default_params = ("-hide_banner", "-loglevel", "warning", "-y") + +ffprobe_path = _abswhich("ffprobe") +ffprobe_default_params = ( + "-loglevel", + "quiet", + "-print_format", + "json", + "-show_optional_fields", + "1", + "-show_format", + "-show_streams", +) + + +async def probe_path( + input_file: os.PathLike[str] | str, + logger: logging.Logger | None = None, +) -> Any: + """ + Probes a media file on the disk using ffprobe. + + Args: + input_file: The full path to the file. + + Returns: + A Python object containing the parsed JSON response from ffprobe + + Raises: + ConverterError: if ffprobe returns a non-zero exit code. 
+ """ + if ffprobe_path is None: + raise NotInstalledError() + + input_file = Path(input_file) + proc = await asyncio.create_subprocess_exec( + ffprobe_path, + *ffprobe_default_params, + str(input_file), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})" + raise ConverterError(f"ffprobe error: {err_text}") + elif stderr and logger: + logger.warning(f"ffprobe warning: {stderr.decode('utf-8')}") + return json.loads(stdout) + + +async def probe_bytes( + data: bytes, + input_mime: str | None = None, + logger: logging.Logger | None = None, +) -> Any: + """ + Probe media file data using ffprobe. + + Args: + data: The bytes of the file to probe. + input_mime: The mime type of the input data. If not specified, will be guessed using magic. + + Returns: + A Python object containing the parsed JSON response from ffprobe + + Raises: + ConverterError: if ffprobe returns a non-zero exit code. + """ + if ffprobe_path is None: + raise NotInstalledError() + + if input_mime is None: + if magic is None: + raise ValueError("input_mime was not specified and magic is not installed") + input_mime = magic.mimetype(data) + input_extension = mimetypes.guess_extension(input_mime) + with tempfile.TemporaryDirectory(prefix="mautrix_ffmpeg_") as tmpdir: + input_file = Path(tmpdir) / f"data{input_extension}" + with open(input_file, "wb") as file: + file.write(data) + return await probe_path(input_file=input_file, logger=logger) async def convert_path( @@ -44,6 +128,7 @@ async def convert_path( output_args: Iterable[str] | None = None, remove_input: bool = False, output_path_override: os.PathLike[str] | str | None = None, + logger: logging.Logger | None = None, ) -> Path | bytes: """ Convert a media file on the disk using ffmpeg. 
@@ -76,6 +161,10 @@ async def convert_path( else: input_file = Path(input_file) output_file = input_file.parent / f"{input_file.stem}{output_extension}" + if input_file == output_file: + output_file = Path(output_file) + output_file = output_file.parent / f"{output_file.stem}-new{output_extension}" + proc = await asyncio.create_subprocess_exec( ffmpeg_path, *ffmpeg_default_params, @@ -92,9 +181,8 @@ async def convert_path( if proc.returncode != 0: err_text = stderr.decode("utf-8") if stderr else f"unknown ({proc.returncode})" raise ConverterError(f"ffmpeg error: {err_text}") - elif stderr: - # TODO log warnings? - pass + elif stderr and logger: + logger.warning(f"ffmpeg warning: {stderr.decode('utf-8')}") if remove_input and isinstance(input_file, Path): input_file.unlink(missing_ok=True) return stdout if output_file == "-" else output_file @@ -106,6 +194,7 @@ async def convert_bytes( input_args: Iterable[str] | None = None, output_args: Iterable[str] | None = None, input_mime: str | None = None, + logger: logging.Logger | None = None, ) -> bytes: """ Convert media file data using ffmpeg. @@ -140,6 +229,7 @@ async def convert_bytes( output_extension=output_extension, input_args=input_args, output_args=output_args, + logger=logger, ) with open(output_file, "rb") as file: return file.read() @@ -152,4 +242,6 @@ async def convert_bytes( "NotInstalledError", "convert_bytes", "convert_path", + "probe_bytes", + "probe_path", ] diff --git a/mautrix/util/file_store.py b/mautrix/util/file_store.py index 1ba88c5b..a0e43c56 100644 --- a/mautrix/util/file_store.py +++ b/mautrix/util/file_store.py @@ -1,23 +1,17 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
from __future__ import annotations -from typing import IO, Any +from typing import IO, Any, Protocol from abc import ABC, abstractmethod from pathlib import Path import json import pickle -import sys import time -if sys.version_info >= (3, 8): - from typing import Protocol -else: - from typing_extensions import Protocol - class Filer(Protocol): def dump(self, obj: Any, file: IO) -> None: diff --git a/mautrix/util/format_duration.py b/mautrix/util/format_duration.py index af2977ce..7b1ca1a1 100644 --- a/mautrix/util/format_duration.py +++ b/mautrix/util/format_duration.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/format_duration_test.py b/mautrix/util/format_duration_test.py index 7405845c..9693651c 100644 --- a/mautrix/util/format_duration_test.py +++ b/mautrix/util/format_duration_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/formatter/__init__.py b/mautrix/util/formatter/__init__.py index aa99777a..f922ff57 100644 --- a/mautrix/util/formatter/__init__.py +++ b/mautrix/util/formatter/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -12,3 +12,19 @@ async def parse_html(input_html: str) -> str: return (await MatrixParser().parse(input_html)).text + + +__all__ = [ + "AbstractEntity", + "EntityString", + "SemiAbstractEntity", + "SimpleEntity", + "EntityType", + "FormattedString", + "HTMLNode", + "read_html", + "MarkdownString", + "MatrixParser", + "RecursionContext", + "parse_html", +] diff --git a/mautrix/util/formatter/entity_string.py b/mautrix/util/formatter/entity_string.py index 53977e2f..f40beb0b 100644 --- a/mautrix/util/formatter/entity_string.py +++ b/mautrix/util/formatter/entity_string.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -39,6 +39,8 @@ def adjust_offset(self, offset: int, max_length: int = -1) -> SemiAbstractEntity entity.offset += offset if entity.offset < 0: entity.length += entity.offset + if entity.length < 0: + return None entity.offset = 0 elif entity.offset > max_length > -1: return None diff --git a/mautrix/util/formatter/formatted_string.py b/mautrix/util/formatter/formatted_string.py index 3414f4e3..8fbcdb89 100644 --- a/mautrix/util/formatter/formatted_string.py +++ b/mautrix/util/formatter/formatted_string.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/formatter/html_reader.py b/mautrix/util/formatter/html_reader.py index c7b582f3..29697180 100644 --- a/mautrix/util/formatter/html_reader.py +++ b/mautrix/util/formatter/html_reader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/util/formatter/html_reader.pyi b/mautrix/util/formatter/html_reader.pyi index d70e1bae..63b5b5c3 100644 --- a/mautrix/util/formatter/html_reader.pyi +++ b/mautrix/util/formatter/html_reader.pyi @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/formatter/html_reader_htmlparser.py b/mautrix/util/formatter/html_reader_htmlparser.py deleted file mode 100644 index 6461a80d..00000000 --- a/mautrix/util/formatter/html_reader_htmlparser.py +++ /dev/null @@ -1,2 +0,0 @@ -# TODO: remove this file in v0.15 -from .html_reader import HTMLNode, read_html diff --git a/mautrix/util/formatter/html_reader_lxml.py b/mautrix/util/formatter/html_reader_lxml.py deleted file mode 100644 index 4ca1a8a0..00000000 --- a/mautrix/util/formatter/html_reader_lxml.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2021 Tulir Asokan -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -from lxml import html - -HTMLNode = html.HtmlElement - - -def read_html(data: str) -> HTMLNode: - return html.fromstring(data) diff --git a/mautrix/util/formatter/markdown_string.py b/mautrix/util/formatter/markdown_string.py index 50fd8087..7eaca9c9 100644 --- a/mautrix/util/formatter/markdown_string.py +++ b/mautrix/util/formatter/markdown_string.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/util/formatter/parser.py b/mautrix/util/formatter/parser.py index 83f0e57f..dcff1eb8 100644 --- a/mautrix/util/formatter/parser.py +++ b/mautrix/util/formatter/parser.py @@ -106,6 +106,9 @@ async def blockquote_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> msg = await self.tag_aware_parse_node(node, ctx) return msg.format(self.e.BLOCKQUOTE) + async def hr_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: + return self.fs("---") + async def header_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: children = await self.node_to_fstrings(node, ctx) length = int(node.tag[1]) @@ -174,6 +177,9 @@ async def event_link_to_fstring( ) -> T | None: return None + async def img_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: + return self.fs(node.attrib.get("alt") or node.attrib.get("title") or "") + async def custom_node_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T | None: return None @@ -191,6 +197,8 @@ async def node_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: return self.fs("") elif node.tag == "blockquote": return await self.blockquote_to_fstring(node, ctx) + elif node.tag == "hr": + return await self.hr_to_fstring(node, ctx) elif node.tag == "ol": return await self.list_to_fstring(node, ctx) elif node.tag == "ul": @@ -203,6 +211,8 @@ async def node_to_fstring(self, node: HTMLNode, ctx: RecursionContext) -> T: return await self.basic_format_to_fstring(node, ctx) elif node.tag == "a": return await self.link_to_fstring(node, ctx) + elif node.tag == "img": + return await self.img_to_fstring(node, ctx) elif node.tag == "p": return (await self.tag_aware_parse_node(node, ctx)).append("\n") elif node.tag in ("font", "span"): diff --git a/mautrix/util/formatter/parser_test.py b/mautrix/util/formatter/parser_test.py index 730b62b7..ccfc2724 100644 --- a/mautrix/util/formatter/parser_test.py +++ 
b/mautrix/util/formatter/parser_test.py @@ -8,7 +8,6 @@ from . import parse_html -@pytest.mark.asyncio async def test_basic_markdown() -> None: tests = { "test": "**test**", @@ -27,7 +26,6 @@ async def test_basic_markdown() -> None: assert await parse_html(html) == markdown_ish -@pytest.mark.asyncio async def test_nested_markdown() -> None: input_html = """

Hello, World!

diff --git a/mautrix/util/logging/__init__.py b/mautrix/util/logging/__init__.py index fedf9dc6..e0792a17 100644 --- a/mautrix/util/logging/__init__.py +++ b/mautrix/util/logging/__init__.py @@ -1,2 +1,4 @@ from .color import ColorFormatter from .trace import SILLY, TRACE, TraceLogger + +__all__ = ["ColorFormatter", "TraceLogger", "SILLY", "TRACE"] diff --git a/mautrix/util/logging/color.py b/mautrix/util/logging/color.py index 86572d35..62701119 100644 --- a/mautrix/util/logging/color.py +++ b/mautrix/util/logging/color.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/logging/trace.py b/mautrix/util/logging/trace.py index 6b7ac943..f2c893aa 100644 --- a/mautrix/util/logging/trace.py +++ b/mautrix/util/logging/trace.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/magic.py b/mautrix/util/magic.py index a3b9ffc7..5c061993 100644 --- a/mautrix/util/magic.py +++ b/mautrix/util/magic.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -17,7 +17,7 @@ _from_filename = lambda file: magic.detect_from_filename(file).mime_type -def mimetype(data: bytes | str) -> str: +def mimetype(data: bytes | bytearray | str) -> str: """ Uses magic to determine the mimetype of a file on disk or in memory. 
@@ -33,6 +33,9 @@ def mimetype(data: bytes | str) -> str: return _from_filename(data) elif isinstance(data, bytes): return _from_buffer(data) + elif isinstance(data, bytearray): + # Magic doesn't like bytearrays directly, so just copy the first 1024 bytes for it. + return _from_buffer(bytes(data[:1024])) else: raise TypeError( f"mimetype() argument must be a string or bytes, not {type(data).__name__!r}" diff --git a/mautrix/util/manhole.py b/mautrix/util/manhole.py index 02b122da..c6505822 100644 --- a/mautrix/util/manhole.py +++ b/mautrix/util/manhole.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this @@ -76,7 +76,13 @@ class StatefulCommandCompiler(codeop.CommandCompiler): def __init__(self) -> None: super().__init__() self.compiler = functools.partial( - compile, optimize=1, flags=ast.PyCF_ONLY_AST | codeop.PyCF_DONT_IMPLY_DEDENT + compile, + optimize=1, + flags=( + ast.PyCF_ONLY_AST + | codeop.PyCF_DONT_IMPLY_DEDENT + | codeop.PyCF_ALLOW_INCOMPLETE_INPUT + ), ) self.buf = BytesIO() @@ -104,9 +110,7 @@ def reset(self) -> None: class Interpreter(ABC): @abstractmethod - def __init__( - self, namespace: Dict[str, Any], banner: Union[bytes, str], loop: asyncio.AbstractEventLoop - ) -> None: + def __init__(self, namespace: Dict[str, Any], banner: Union[bytes, str]) -> None: pass @abstractmethod @@ -126,17 +130,13 @@ class AsyncInterpreter(Interpreter): namespace: Dict[str, Any] banner: bytes compiler: StatefulCommandCompiler - loop: asyncio.AbstractEventLoop running: bool - def __init__( - self, namespace: Dict[str, Any], banner: Union[bytes, str], loop: asyncio.AbstractEventLoop - ) -> None: - super().__init__(namespace, banner, loop) + def __init__(self, namespace: Dict[str, Any], banner: Union[bytes, str]) -> None: + super().__init__(namespace, banner) self.namespace = namespace 
self.banner = banner if isinstance(banner, bytes) else str(banner).encode("utf-8") self.compiler = StatefulCommandCompiler() - self.loop = loop async def send_exception(self) -> None: """When an exception has occurred, write the traceback to the user.""" @@ -261,7 +261,6 @@ async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWri class InterpreterFactory: namespace: Dict[str, Any] banner: bytes - loop: asyncio.AbstractEventLoop interpreter_class: Type[Interpreter] clients: List[Interpreter] whitelist: Set[int] @@ -272,12 +271,10 @@ def __init__( namespace: Dict[str, Any], banner: Union[bytes, str], interpreter_class: Type[Interpreter], - loop: asyncio.AbstractEventLoop, whitelist: Set[int], ) -> None: self.namespace = namespace or {} self.banner = banner - self.loop = loop self.interpreter_class = interpreter_class self.clients = [] self.whitelist = whitelist @@ -304,9 +301,7 @@ async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWri return namespace = {**self.namespace} - interpreter = self.interpreter_class( - namespace=namespace, banner=self.banner, loop=self.loop - ) + interpreter = self.interpreter_class(namespace=namespace, banner=self.banner) namespace["exit"] = interpreter.close self.clients.append(interpreter) conn_id = self.conn_id @@ -336,15 +331,13 @@ async def start_manhole( """ if not SO_PEERCRED: raise ValueError("SO_PEERCRED is not supported on this platform") - loop = loop or asyncio.get_event_loop() factory = InterpreterFactory( namespace=namespace, banner=banner, interpreter_class=AsyncInterpreter, - loop=loop, whitelist=whitelist, ) - server = await asyncio.start_unix_server(factory, path=path, loop=loop) + server = await asyncio.start_unix_server(factory, path=path) os.chmod(path, 0o666) def stop(): diff --git a/mautrix/util/markdown.py b/mautrix/util/markdown.py index 6d4b275f..1f24ab4e 100644 --- a/mautrix/util/markdown.py +++ b/mautrix/util/markdown.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 
Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/message_send_checkpoint.py b/mautrix/util/message_send_checkpoint.py index 276f7b7a..ee0c17f3 100644 --- a/mautrix/util/message_send_checkpoint.py +++ b/mautrix/util/message_send_checkpoint.py @@ -1,3 +1,8 @@ +# Copyright (c) 2022 Sumner Evans +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. from typing import Optional import logging @@ -24,6 +29,7 @@ class MessageSendCheckpointStatus(SerializableEnum): PERM_FAILURE = "PERM_FAILURE" UNSUPPORTED = "UNSUPPORTED" TIMEOUT = "TIMEOUT" + DELIVERY_FAILED = "DELIVERY_FAILED" class MessageSendCheckpointReportedBy(SerializableEnum): @@ -51,26 +57,32 @@ async def send(self, endpoint: str, as_token: str, log: logging.Logger) -> None: return try: headers = {"Authorization": f"Bearer {as_token}", "User-Agent": HTTPAPI.default_ua} - async with aiohttp.ClientSession() as sess, sess.post( - endpoint, - json={"checkpoints": [self.serialize()]}, - headers=headers, - timeout=ClientTimeout(5), - ) as resp: + async with ( + aiohttp.ClientSession() as sess, + sess.post( + endpoint, + json={"checkpoints": [self.serialize()]}, + headers=headers, + timeout=ClientTimeout(30), + ) as resp, + ): if not 200 <= resp.status < 300: text = await resp.text() text = text.replace("\n", "\\n") log.warning( - f"Unexpected status code {resp.status} sending message send checkpoints " - f"for {self.event_id}: {text}" + f"Unexpected status code {resp.status} sending checkpoint " + f"for {self.event_id} ({self.step}/{self.status}): {text}" ) else: log.info( - f"Successfully sent message send checkpoints for {self.event_id} " - f"(step: {self.step})" + f"Successfully sent checkpoint for 
{self.event_id} " + f"({self.step}/{self.status})" ) except Exception as e: - log.warning(f"Failed to send message send checkpoints for {self.event_id}: {e}") + log.warning( + f"Failed to send checkpoint for {self.event_id} ({self.step}/{self.status}): " + f"{type(e).__name__}: {e}" + ) CHECKPOINT_TYPES = { @@ -78,6 +90,9 @@ async def send(self, endpoint: str, as_token: str, log: logging.Logger) -> None: EventType.ROOM_MESSAGE, EventType.ROOM_ENCRYPTED, EventType.ROOM_MEMBER, + EventType.ROOM_NAME, + EventType.ROOM_AVATAR, + EventType.ROOM_TOPIC, EventType.STICKER, EventType.REACTION, EventType.CALL_INVITE, diff --git a/mautrix/util/opt_prometheus.py b/mautrix/util/opt_prometheus.py index 91a1c90a..1afc4a54 100644 --- a/mautrix/util/opt_prometheus.py +++ b/mautrix/util/opt_prometheus.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/mautrix/util/opt_prometheus.pyi b/mautrix/util/opt_prometheus.pyi index e3d943ed..10c763a5 100644 --- a/mautrix/util/opt_prometheus.pyi +++ b/mautrix/util/opt_prometheus.pyi @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Tulir Asokan +# Copyright (c) 2022 Tulir Asokan # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this diff --git a/mautrix/util/program.py b/mautrix/util/program.py index cc4d7845..74752ca5 100644 --- a/mautrix/util/program.py +++ b/mautrix/util/program.py @@ -98,7 +98,7 @@ def _prepare(self) -> None: self.log.info(f"Initializing {self.name} {self.version}") try: - self.prepare() + self.loop.run_until_complete(self._async_prepare()) except Exception: self.log.critical("Unexpected error in initialization", exc_info=True) sys.exit(1) @@ -117,9 +117,10 @@ def preinit(self) -> None: self.prepare_config() self.prepare_log() self.check_config() + self.init_loop() @property - def _default_base_config(self) -> str: + def base_config_path(self) -> str: return f"pkg://{self.module}/example-config.yaml" def prepare_arg_parser(self) -> None: @@ -133,21 +134,13 @@ def prepare_arg_parser(self) -> None: metavar="", help="the path to your config file", ) - self.parser.add_argument( - "-b", - "--base-config", - type=str, - default=self._default_base_config, - metavar="", - help="the path to the example config (for automatic config updates)", - ) self.parser.add_argument( "-n", "--no-update", action="store_true", help="Don't save updated config to disk" ) def prepare_config(self) -> None: """Pre-init lifecycle method. Extend this if you want to customize config loading.""" - self.config = self.config_class(self.args.config, self.args.base_config) + self.config = self.config_class(self.args.config, self.base_config_path) self.load_and_update_config() def load_and_update_config(self) -> None: @@ -155,13 +148,10 @@ def load_and_update_config(self) -> None: try: self.config.update(save=not self.args.no_update) except BaseMissingError: - if self.args.base_config != self._default_base_config: - print(f"Failed to read base config from {self.args.base_config}") - else: - print( - "Failed to read base config from the default path " - f"({self._default_base_config}). Maybe your installation is corrupted?" 
- ) + print( + "Failed to read base config from the default path " + f"({self.base_config_path}). Maybe your installation is corrupted?" + ) sys.exit(12) def check_config(self) -> None: @@ -179,25 +169,29 @@ def prepare_log(self) -> None: logging.config.dictConfig(copy.deepcopy(self.config["logging"])) self.log = cast(TraceLogger, logging.getLogger("mau.init")) + async def _async_prepare(self) -> None: + self.prepare() + def prepare(self) -> None: """ Lifecycle method where the primary program initialization happens. Use this to fill startup_actions with async startup tasks. """ - self.prepare_loop() - def prepare_loop(self) -> None: + def init_loop(self) -> None: """Init lifecycle method where the asyncio event loop is created.""" if uvloop is not None: uvloop.install() self.log.debug("Using uvloop for asyncio") - self.loop = asyncio.get_event_loop() + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) def start_prometheus(self) -> None: try: enabled = self.config["metrics.enabled"] listen_port = self.config["metrics.listen_port"] + hostname = self.config.get("metrics.hostname", "0.0.0.0") except KeyError: return if not enabled: @@ -207,12 +201,14 @@ def start_prometheus(self) -> None: "Metrics are enabled in config, but prometheus_client is not installed" ) return - prometheus.start_http_server(listen_port) + prometheus.start_http_server(listen_port, addr=hostname) def _run(self) -> None: signal.signal(signal.SIGINT, signal.default_int_handler) signal.signal(signal.SIGTERM, signal.default_int_handler) + self._stop_task = self.loop.create_future() + exit_code = 0 try: self.log.debug("Running startup actions...") start_ts = time() @@ -222,13 +218,13 @@ def _run(self) -> None: f"Startup actions complete in {round(end_ts - start_ts, 2)} seconds, " "now running forever" ) - self._stop_task = self.loop.create_future() - self.loop.run_until_complete(self._stop_task) + exit_code = self.loop.run_until_complete(self._stop_task) 
self.log.debug("manual_stop() called, stopping...") except KeyboardInterrupt: self.log.debug("Interrupt received, stopping...") except Exception: self.log.critical("Unexpected error in main event loop", exc_info=True) + self.loop.run_until_complete(self.system_exit()) sys.exit(2) except SystemExit: self.loop.run_until_complete(self.system_exit()) @@ -236,8 +232,10 @@ def _run(self) -> None: self.prepare_stop() self.loop.run_until_complete(self.stop()) self.prepare_shutdown() + self.loop.close() + asyncio.set_event_loop(None) self.log.info("Everything stopped, shutting down") - sys.exit(0) + sys.exit(exit_code) async def system_exit(self) -> None: """Lifecycle method that is called if the main event loop exits using ``sys.exit()``.""" @@ -267,9 +265,9 @@ async def stop(self) -> None: def prepare_shutdown(self) -> None: """Lifecycle method that is called right before ``sys.exit(0)``.""" - def manual_stop(self) -> None: + def manual_stop(self, exit_code: int = 0) -> None: """Tell the event loop to cleanly stop and run the stop lifecycle steps.""" - self._stop_task.set_result(None) + self._stop_task.set_result(exit_code) def add_startup_actions(self, *actions: NewTask) -> None: self.startup_actions = self._add_actions(self.startup_actions, actions) diff --git a/mautrix/util/proxy.py b/mautrix/util/proxy.py new file mode 100644 index 00000000..f36da73d --- /dev/null +++ b/mautrix/util/proxy.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Awaitable, Callable, TypeVar +import asyncio +import json +import logging +import time +import urllib.request + +from aiohttp import ClientConnectionError +from yarl import URL + +from mautrix.util.logging import TraceLogger + +try: + from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError +except ImportError: + + class ProxyError(Exception): + pass + + ProxyConnectionError = ProxyTimeoutError = ProxyError + +RETRYABLE_PROXY_EXCEPTIONS = ( + ProxyError, + ProxyTimeoutError, + 
ProxyConnectionError, + ClientConnectionError, + ConnectionError, + asyncio.TimeoutError, +) + + +class ProxyHandler: + current_proxy_url: str | None = None + log = logging.getLogger("mau.proxy") + + def __init__(self, api_url: str | None) -> None: + self.api_url = api_url + + def get_proxy_url_from_api(self, reason: str | None = None) -> str | None: + assert self.api_url is not None + + api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {})) + + # NOTE: using urllib.request to intentionally block the whole bridge until the proxy change applied + request = urllib.request.Request(api_url, method="GET") + self.log.debug("Requesting proxy from: %s", api_url) + + try: + with urllib.request.urlopen(request) as f: + response = json.loads(f.read().decode()) + except Exception: + self.log.exception("Failed to retrieve proxy from API") + return self.current_proxy_url + else: + return response["proxy_url"] + + def update_proxy_url(self, reason: str | None = None) -> bool: + old_proxy = self.current_proxy_url + new_proxy = None + + if self.api_url is not None: + new_proxy = self.get_proxy_url_from_api(reason) + else: + new_proxy = urllib.request.getproxies().get("http") + + if old_proxy != new_proxy: + self.log.debug("Set new proxy URL: %s", new_proxy) + self.current_proxy_url = new_proxy + return True + + self.log.debug("Got same proxy URL: %s", new_proxy) + return False + + def get_proxy_url(self) -> str | None: + if not self.current_proxy_url: + self.update_proxy_url() + + return self.current_proxy_url + + +T = TypeVar("T") + + +async def proxy_with_retry( + name: str, + func: Callable[[], Awaitable[T]], + logger: TraceLogger, + proxy_handler: ProxyHandler, + on_proxy_change: Callable[[], Awaitable[None]], + max_retries: int = 10, + min_wait_seconds: int = 0, + max_wait_seconds: int = 60, + multiply_wait_seconds: int = 10, + retryable_exceptions: tuple[Exception] = RETRYABLE_PROXY_EXCEPTIONS, + reset_after_seconds: int | None = None, +) -> T: + 
errors = 0 + last_error = 0 + + while True: + try: + return await func() + except retryable_exceptions as e: + errors += 1 + if errors > max_retries: + raise + wait = errors * multiply_wait_seconds + wait = max(wait, min_wait_seconds) + wait = min(wait, max_wait_seconds) + logger.warning( + "%s while trying to %s, retrying in %d seconds", + e.__class__.__name__, + name, + wait, + ) + if errors > 1 and proxy_handler.update_proxy_url( + f"{e.__class__.__name__} while trying to {name}" + ): + await on_proxy_change() + + # If sufficient time has passed since the previous error, reset the + # error count. Useful for long running tasks with rare failures. + if reset_after_seconds is not None: + now = time.time() + if last_error and now - last_error > reset_after_seconds: + errors = 0 + last_error = now diff --git a/mautrix/util/simple_lock.py b/mautrix/util/simple_lock.py index 8cd71d34..c6bd08ce 100644 --- a/mautrix/util/simple_lock.py +++ b/mautrix/util/simple_lock.py @@ -13,31 +13,41 @@ class SimpleLock: _event: asyncio.Event log: logging.Logger | None message: str | None - - def __init__(self, message: str | None = None, log: logging.Logger | None = None) -> None: - self._event = asyncio.Event() - self._event.set() + noop_mode: bool + + def __init__( + self, + message: str | None = None, + log: logging.Logger | None = None, + noop_mode: bool = False, + ) -> None: + self.noop_mode = noop_mode + if not noop_mode: + self._event = asyncio.Event() + self._event.set() self.log = log self.message = message def __enter__(self) -> None: - self._event.clear() + if not self.noop_mode: + self._event.clear() async def __aenter__(self) -> None: self.__enter__() def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self._event.set() + if not self.noop_mode: + self._event.set() def __aexit__(self, exc_type, exc_val, exc_tb) -> None: self.__exit__(exc_type, exc_val, exc_tb) @property def locked(self) -> bool: - return not self._event.is_set() + return not self.noop_mode and not 
self._event.is_set() - async def wait(self, task: Optional[str] = None) -> None: - if not self._event.is_set(): + async def wait(self, task: str | None = None) -> None: + if self.locked: if self.log and self.message: self.log.debug(self.message, task) await self._event.wait() diff --git a/mautrix/util/utf16_surrogate.py b/mautrix/util/utf16_surrogate.py index d202e15e..92dc49c2 100644 --- a/mautrix/util/utf16_surrogate.py +++ b/mautrix/util/utf16_surrogate.py @@ -16,9 +16,11 @@ def add(text: str) -> str: The text with surrogate pairs. """ return "".join( - "".join(chr(y) for y in struct.unpack(" dict[str, str]: @@ -43,11 +43,12 @@ async def fetch_data() -> dict[str, str]: if __name__ == "__main__": import asyncio + import importlib.resources + import pathlib import sys - import pkg_resources - - path = pkg_resources.resource_filename("mautrix.util", "variation_selector.json") + path = importlib.resources.files("mautrix.util").joinpath("variation_selector.json") + assert isinstance(path, pathlib.Path) emojis = asyncio.run(fetch_data()) with open(path, "w") as file: json.dump(emojis, file, indent=" ", ensure_ascii=False) @@ -59,8 +60,12 @@ async def fetch_data() -> dict[str, str]: ADD_VARIATION_TRANSLATION = str.maketrans( {ord(emoji): f"{emoji}{VARIATION_SELECTOR_16}" for emoji in read_data().values()} ) -SKIN_TONE_MODIFIERS = ("\U0001F3FB", "\U0001F3FC", "\U0001F3FD", "\U0001F3FE", "\U0001F3FF") +SKIN_TONE_MODIFIERS = ("\U0001f3fb", "\U0001f3fc", "\U0001f3fd", "\U0001f3fe", "\U0001f3ff") SKIN_TONE_REPLACEMENTS = {f"{VARIATION_SELECTOR_16}{mod}": mod for mod in SKIN_TONE_MODIFIERS} +VARIATION_SELECTOR_REPLACEMENTS = { + **SKIN_TONE_REPLACEMENTS, + "\U0001f408\ufe0f\u200d\u2b1b\ufe0f": "\U0001f408\u200d\u2b1b", +} def add(val: str) -> str: @@ -88,7 +93,7 @@ def add(val: str) -> str: The string with variation selectors added. 
""" added = remove(val).translate(ADD_VARIATION_TRANSLATION) - for invalid_selector, replacement in SKIN_TONE_REPLACEMENTS.items(): + for invalid_selector, replacement in VARIATION_SELECTOR_REPLACEMENTS.items(): added = added.replace(invalid_selector, replacement) return added diff --git a/optional-requirements.txt b/optional-requirements.txt index 6cfdcece..a6e0227d 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -1,6 +1,6 @@ python-magic ruamel.yaml -SQLAlchemy +SQLAlchemy<2 commonmark lxml asyncpg @@ -11,3 +11,4 @@ uvloop python-olm unpaddedbase64 pycryptodome +base58 diff --git a/pyproject.toml b/pyproject.toml index 64b48c44..bc17b4d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,5 +7,8 @@ line_length = 99 [tool.black] line-length = 99 -target-version = ["py38"] -required-version = "22.1.0" +target-version = ["py310"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +addopts = "--ignore mautrix/util/db/ --ignore mautrix/bridge/" diff --git a/requirements.txt b/requirements.txt index 77e93ba7..e88b12f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ aiohttp attrs yarl -typing_extensions; python_version<"3.8" diff --git a/setup.py b/setup.py index 08a10cbb..23ac1cb1 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,8 @@ from mautrix import __version__ -test_dependencies = ["aiosqlite", "sqlalchemy", "asyncpg"] +encryption_dependencies = ["python-olm", "unpaddedbase64", "pycryptodome", "base58"] +test_dependencies = ["aiosqlite", "asyncpg", "ruamel.yaml", *encryption_dependencies] setuptools.setup( name="mautrix", @@ -27,11 +28,12 @@ ], extras_require={ "detect_mimetype": ["python-magic>=0.4.15,<0.5"], - "lint": ["black==22.1.0", "isort"], + "lint": ["black~=25.1", "isort"], "test": ["pytest", "pytest-asyncio", *test_dependencies], + "encryption": encryption_dependencies, }, tests_require=test_dependencies, - python_requires="~=3.8", + python_requires="~=3.10", classifiers=[ "Development Status :: 4 - 
Beta", @@ -40,9 +42,11 @@ "Framework :: AsyncIO", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ], package_data={